In [1]:
# RNN which will (hopefully) generate a random chat between 2 people

In [2]:
# get and unzip dataset - only need to run once
# dataset = 'projjal1/human-conversation-training-data'
# kaggle.api.dataset_download_files(dataset, path=".", unzip=True)

In [11]:
# # for kaggle
# # import kaggle
# for work
import torch
# import nltk
# nltk.download('punkt')
from nltk.tokenize import word_tokenize
from collections import Counter
import numpy as np

In [12]:
# building vocabulary in token form
# words = [line.lower().split() for line in lines] # using nltk tokens for better splitting (fine! vs fine)
words = word_tokenize(text.lower())
print(words[:100])
print("Total words: ",len(words))
print("Unique words: ", len(set((words))))

['human', '1', ':', 'hi', '!', 'human', '2', ':', 'what', 'is', 'your', 'favorite', 'holiday', '?', 'human', '1', ':', 'one', 'where', 'i', 'get', 'to', 'meet', 'lots', 'of', 'different', 'people', '.', 'human', '2', ':', 'what', 'was', 'the', 'most', 'number', 'of', 'people', 'you', 'have', 'ever', 'met', 'during', 'a', 'holiday', '?', 'human', '1', ':', 'hard', 'to', 'keep', 'a', 'count', '.', 'maybe', '25.', 'human', '2', ':', 'which', 'holiday', 'was', 'that', '?', 'human', '1', ':', 'i', 'think', 'it', 'was', 'australia', 'human', '2', ':', 'do', 'you', 'still', 'talk', 'to', 'the', 'people', 'you', 'met', '?', 'human', '1', ':', 'not', 'really', '.', 'the', 'interactions', 'are', 'usually', 'short-lived', 'but', 'it', "'s"]
Total words:  27943
Unique words:  2813


In [13]:
# reading the data line by line into list
text = None
with open('human_chat.txt','r',encoding='utf-8') as dataset_file:
    text = dataset_file.read()
print("Total characters in text: ",len(text))
# print(text)

Total characters in text:  115782


In [14]:
vocabulary = sorted(set(words))
print(len(vocabulary))

2813


In [15]:
# create lookup table for word to in and reverse
word2int_mapping = {word:i for i,word in enumerate(vocabulary)}
word_array = np.array(vocabulary)
# reverse_vocabulary = {i+1: word for i, (word, _) in enumerate(word_counts.items())} # reverse as well
# print(len(vocabulary) + 1)
# print(vocabulary)
# # print(reverse_vocabulary)

In [16]:
# encoding the sentences according to the vocabulary
encoded_lines = np.array([word2int_mapping[word] for word in words], dtype=np.int32)
print(encoded_lines)
print("Encoded lines shape:",encoded_lines.shape)
print("Words", words[:20])
print("Encoding,", encoded_lines[:20])
print("Reverse conversion: ", ' '.join(word_array[encoded_lines[:20]]))

[1187   25   56 ... 1574 1281   20]
Encoded lines shape: (27943,)
Words ['human', '1', ':', 'hi', '!', 'human', '2', ':', 'what', 'is', 'your', 'favorite', 'holiday', '?', 'human', '1', ':', 'one', 'where', 'i']
Encoding, [1187   25   56 1134    0 1187   33   56 2708 1278 2783  876 1159   60
 1187   25   56 1664 2713 1203]
Reverse conversion:  human 1 : hi ! human 2 : what is your favorite holiday ? human 1 : one where i


In [17]:
# make sequences and chunks
sequence_size = 50
chunk_size = sequence_size + 1

text_chunks = [encoded_lines[i:i+chunk_size] for i in range(len(encoded_lines)-chunk_size+1)]
print(text_chunks[:1],end="\n\n")

for seq in text_chunks[:1]:
    input_seq = seq[:sequence_size]
    target = seq[sequence_size]
    print(input_seq,'->',target)
    print(repr(' '.join(word_array[input_seq])),'->',repr(''.join(word_array[target])))

[array([1187,   25,   56, 1134,    0, 1187,   33,   56, 2708, 1278, 2783,
        876, 1159,   60, 1187,   25,   56, 1664, 2713, 1203, 1009, 2501,
       1491, 1427, 1642,  688, 1762,   20, 1187,   33,   56, 2708, 2681,
       2461, 1556, 1628, 1642, 1762, 2781, 1103,  812, 1512,  749,   63,
       1159,   60, 1187,   25,   56, 1095, 2501], dtype=int32)]

[1187   25   56 1134    0 1187   33   56 2708 1278 2783  876 1159   60
 1187   25   56 1664 2713 1203 1009 2501 1491 1427 1642  688 1762   20
 1187   33   56 2708 2681 2461 1556 1628 1642 1762 2781 1103  812 1512
  749   63 1159   60 1187   25   56 1095] -> 2501
'human 1 : hi ! human 2 : what is your favorite holiday ? human 1 : one where i get to meet lots of different people . human 2 : what was the most number of people you have ever met during a holiday ? human 1 : hard' -> 'to'


In [18]:
# convert into proper pytorch dataset form

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks
    def __len__(self):
        return len(self.text_chunks)
    def __getitem__(self,idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()

seq_dataset = TextDataset(torch.tensor((text_chunks)))

  seq_dataset = TextDataset(torch.tensor((text_chunks)))


In [19]:
# checking out our get item chunks functionality
for i,(seq,target) in enumerate(seq_dataset):
    print("Input:",repr(' '.join(word_array[seq])))
    print("Target:",repr(' '.join(word_array[target])))
    print()
    if i == 2:
        break

Input: 'human 1 : hi ! human 2 : what is your favorite holiday ? human 1 : one where i get to meet lots of different people . human 2 : what was the most number of people you have ever met during a holiday ? human 1 : hard'
Target: '1 : hi ! human 2 : what is your favorite holiday ? human 1 : one where i get to meet lots of different people . human 2 : what was the most number of people you have ever met during a holiday ? human 1 : hard to'

Input: '1 : hi ! human 2 : what is your favorite holiday ? human 1 : one where i get to meet lots of different people . human 2 : what was the most number of people you have ever met during a holiday ? human 1 : hard to'
Target: ': hi ! human 2 : what is your favorite holiday ? human 1 : one where i get to meet lots of different people . human 2 : what was the most number of people you have ever met during a holiday ? human 1 : hard to keep'

Input: ': hi ! human 2 : what is your favorite holiday ? human 1 : one where i get to meet lots of differe

In [20]:
# create a dataloader
batch_size = 32
dataloader = torch.utils.data.DataLoader(seq_dataset,batch_size=batch_size,shuffle=True,drop_last=True)

In [21]:
# this code tells what to use, not neccessary if u dont have a gpu to run on
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda:0


In [22]:
# my rnn
class ChatGeneratingRNN(torch.nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embed_size)
        self.rnn_hidden_size = hidden_size
        self.rnn = torch.nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, inputs, hidden, cell):
        layer_output = self.embedding(inputs).unsqueeze(1)
        layer_output, (hidden,cell) = self.rnn(layer_output,(hidden,cell))
        layer_output = self.fc(layer_output).reshape(layer_output.size(0),-1)
        return layer_output, hidden, cell

    def init_hidden(self, batch_size):
        hidden = torch.zeros(1,batch_size,self.rnn_hidden_size)
        cell = torch.zeros(1,batch_size,self.rnn_hidden_size)
        return hidden.to(DEVICE), cell.to(DEVICE)

In [23]:
# Hyperparameters
vocab_size = len(word_array)
embed_size = 256
hidden_size = 512

In [24]:
# creating model, setting output activation and loss
model = ChatGeneratingRNN(vocab_size, embed_size, hidden_size)
model = model.to(DEVICE)
model

ChatGeneratingRNN(
  (embedding): Embedding(2813, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=2813, bias=True)
)

In [25]:
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [36]:
# training the model
training_epochs = 5001
model.to(DEVICE)
model.train()
for epoch in range(training_epochs):
    hidden,cell = model.init_hidden(batch_size)
    seq_batch, target_batch = next(iter(dataloader))
    seq_batch = seq_batch.to(DEVICE)
    target_batch = target_batch.to(DEVICE)
    optimizer.zero_grad()
    loss = 0
    for w in range(sequence_size):
        pred, hidden, cell = model(seq_batch[:,w],hidden,cell)
        loss += loss_function(pred, target_batch[:,w])
    loss.backward()
    optimizer.step()
    loss = loss.item()/sequence_size
    if epoch % 500 == 0:
      print(f'Epoch [{epoch}/{training_epochs}], Loss: {loss:.4f}')

Epoch [0/5001], Loss: 0.1805
Epoch [500/5001], Loss: 0.1706
Epoch [1000/5001], Loss: 0.1361
Epoch [1500/5001], Loss: 0.1916
Epoch [2000/5001], Loss: 0.1714
Epoch [2500/5001], Loss: 0.1692
Epoch [3000/5001], Loss: 0.1872
Epoch [3500/5001], Loss: 0.1980
Epoch [4000/5001], Loss: 0.1551
Epoch [4500/5001], Loss: 0.1749
Epoch [5000/5001], Loss: 0.1878


In [57]:
def top_p_sampling(logits, temperature=1.0, top_p=0.9):
    # Ensure logits are a PyTorch tensor and move to DEVICE

    # Apply temperature scaling
    scaled_logits = logits / temperature

    # Convert logits to probabilities using softmax
    probabilities = torch.softmax(scaled_logits, dim=-1)

    # Sort probabilities and compute cumulative sum
    sorted_indices = torch.argsort(probabilities, descending=True)
    sorted_probabilities = probabilities[sorted_indices]
    cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)

    # Apply top-p filtering
    indices_to_keep = cumulative_probabilities <= top_p
    truncated_probabilities = sorted_probabilities[indices_to_keep]

    # Rescale the probabilities
    truncated_probabilities /= torch.sum(truncated_probabilities)

    # Convert to numpy arrays for random choice
    truncated_probabilities = truncated_probabilities.cpu().numpy()
    sorted_indices = sorted_indices.cpu().numpy()
    indices_to_keep = indices_to_keep.cpu().numpy()

    # Sample from the truncated distribution
    if not indices_to_keep.any():
        # Handle the empty case - for example, using regular sampling without top-p
        probabilities = torch.softmax(logits / temperature, dim=-1)
        next_word_index = torch.multinomial(probabilities, 1).item()
    else:
        # Existing sampling process
        next_word_index = np.random.choice(sorted_indices[indices_to_keep], p=truncated_probabilities)

    return torch.tensor(next_word_index).to(DEVICE)

In [58]:
def generate(model, seed_string, len_generated_text=50, temperature=1.0, top_p=0.95):
    seed_tokens = word_tokenize(seed_string.lower())
    encoded_input = torch.tensor([word2int_mapping[t] for t in seed_tokens])
    encoded_input = torch.reshape(encoded_input, (1, -1)).to(DEVICE)

    generated_str = seed_string

    model.eval()
    with torch.inference_mode():
        hidden, cell = model.init_hidden(1)
        hidden = hidden.to(DEVICE)
        cell = cell.to(DEVICE)

        for w in range(len(seed_tokens) - 1):
            _, hidden, cell = model(encoded_input[:, w].view(1), hidden, cell)

        last_word = encoded_input[:, -1]
        for i in range(len_generated_text):
            logits, hidden, cell = model(last_word.view(1), hidden, cell)
            logits = torch.squeeze(logits, 0)
            last_word = top_p_sampling(logits.cpu(), temperature, top_p)  # Ensure logits is on CPU
            generated_str += " " + str(word_array[last_word])

    return generated_str.replace(" . ", ". ")

In [59]:
model.to(DEVICE)
print(generate(model, seed_string='Hello how'))

Hello how are you doing ? human 1 : i 'm great , thanks. i 'm getting ready for a skydiving lesson. human 2 : ooh , nice. that sounds adventurous. where is it ? human 1 : right near my home town : seville , spain .
