In [1]:
import ast
import json
import math
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import Dataset

# Re-seed torch's random number generator; the value echoed as this cell's output is
# the seed that was picked.
# NOTE(review): per the PyTorch docs torch.seed() chooses a non-deterministic seed, so
# runs are only reproducible if that value is re-applied (e.g. torch.manual_seed) —
# confirm whether that is intended.
torch.seed()

14691081174242961837

In [2]:
# Locations of the two raw corpus files read by the loading cell.
movie_conv = "data/movie_conversations.txt"
movie_lines = "data/movie_lines.txt"

# Maximum number of tokens kept from each utterance when building pairs.
max_len = 20
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Read both corpus files fully into memory, one list entry per raw line
# (trailing newlines are kept).
with open(movie_conv, 'r') as conv_file:
    conv = list(conv_file)

# The utterance file is decoded as ISO-8859-1.
with open(movie_lines, 'r', encoding="ISO-8859-1") as lines_file:
    lines = list(lines_file)

In [4]:
# Index every record of the utterance file by its first " +++$+++ "-separated
# field, keeping the last field as the value (if a key repeats, the later
# record wins).
lines_dict = {
    fields[0]: fields[-1]
    for fields in (raw.split(" +++$+++ ") for raw in lines)
}

In [5]:
def remove_punc(string):
    """Strip punctuation from ``string`` and lower-case the result.

    Args:
        string: the text to normalise.

    Returns:
        ``string`` with every character listed in ``punctuations`` removed and
        the remaining characters lower-cased. Whitespace is left untouched.
    """
    # Raw string: the backslash is a literal character to remove, not the start
    # of an (invalid) escape sequence.
    punctuations = r'''-[];'"\,<>./:?@#{}$%^&!*()_~'''
    # A translation table deletes all listed characters in a single pass
    # instead of rebuilding the string one character at a time.
    return string.translate(str.maketrans('', '', punctuations)).lower()

In [6]:
# Build (question, reply) token-list pairs from consecutive utterances of every
# conversation.
pairs = []
for con in conv:
    # The last " +++$+++ " field is parsed as a Python literal holding the
    # conversation's line ids (presumably a list such as "['L1', 'L2']" —
    # verify against the data file). ast.literal_eval only accepts literals, so
    # unlike eval it cannot execute code embedded in the file.
    ids = ast.literal_eval(con.split(" +++$+++ ")[-1].strip())
    # Pair each utterance with the one that directly follows it.
    for first_id, second_id in zip(ids, ids[1:]):
        first = remove_punc(lines_dict[first_id].strip())
        second = remove_punc(lines_dict[second_id].strip())
        # Keep at most max_len tokens on each side.
        pairs.append([first.split()[:max_len], second.split()[:max_len]])

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f7555d757d0>>
Traceback (most recent call last):
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


In [None]:
# Count how often every token occurs across both sides of all pairs
# (question tokens first, then reply tokens, pair by pair).
word_freq = Counter(
    token
    for pair in pairs
    for side in (pair[0], pair[1])
    for token in side
)

In [None]:
# Vocabulary: only tokens seen strictly more than min_word_freq times get an id.
min_word_freq = 20
words = [token for token, count in word_freq.items() if count > min_word_freq]
# Regular words are numbered from 1; id 0 is reserved for padding.
word_map = dict(zip(words, range(1, len(words) + 1)))
# Special tokens take the next free ids, in this order.
for special in ('<unk>', '<start>', '<end>'):
    word_map[special] = len(word_map) + 1
word_map['<pad>'] = 0

In [None]:
# Report the size of the vocabulary mapping (regular words plus special tokens).
vocab_total = len(word_map)
print("Word count: {}".format(vocab_total))

Word count: 7322


In [None]:
# Persist the vocabulary mapping as JSON so later cells can reload it.
with open('json_preprocessed_data/WORDMAP_corpus.json', 'w') as wordmap_file:
    wordmap_file.write(json.dumps(word_map))

In [None]:
def encode_q(words, word_map):
    """Encode a question as token ids, right-padded with ``<pad>``.

    Args:
        words: list of tokens.
        word_map: token -> id mapping containing ``<unk>`` and ``<pad>``.

    Returns:
        A list of ids. Unknown tokens map to ``<unk>``; ``<pad>`` ids are
        appended until the module-level ``max_len`` is reached (no padding is
        added when ``words`` already has ``max_len`` or more tokens).
    """
    encoded = []
    for word in words:
        encoded.append(word_map.get(word, word_map['<unk>']))
    pad_count = max_len - len(words)
    encoded.extend([word_map['<pad>']] * pad_count)
    return encoded

In [None]:
def encode_r(words, word_map):
    """Encode a reply as ``<start>`` + token ids + ``<end>``, right-padded.

    Args:
        words: list of tokens.
        word_map: token -> id mapping containing ``<start>``, ``<end>``,
            ``<unk>`` and ``<pad>``.

    Returns:
        A list of ids. Unknown tokens map to ``<unk>``; ``max_len - len(words)``
        ``<pad>`` ids follow ``<end>`` (``max_len`` is the module-level
        constant), so a reply of at most ``max_len`` tokens yields
        ``max_len + 2`` ids.
    """
    encoded = [word_map['<start>']]
    for word in words:
        encoded.append(word_map.get(word, word_map['<unk>']))
    encoded.append(word_map['<end>'])
    pad_count = max_len - len(words)
    encoded.extend([word_map['<pad>']] * pad_count)
    return encoded

In [None]:
# Turn every (question, reply) token pair into its padded id representation.
pairs_encoded = [
    [encode_q(pair[0], word_map), encode_r(pair[1], word_map)]
    for pair in pairs
]

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f0ca90ca350>>
Traceback (most recent call last):
  File "/home/maxim/anaconda3/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


In [None]:
# Persist the encoded pairs as JSON; the Dataset class reads this file back.
with open('json_preprocessed_data/pairs_encoded.json', 'w') as pairs_file:
    pairs_file.write(json.dumps(pairs_encoded))

In [None]:
# The base class is referenced by its qualified name: the class keeps the name
# ``Dataset`` (other cells instantiate it under that name), and re-running this
# cell no longer makes the class inherit from its own previous definition.
class Dataset(torch.utils.data.Dataset):
    """Dataset of encoded (question, reply) id pairs loaded from a JSON file."""

    def __init__(self, path='json_preprocessed_data/pairs_encoded.json'):
        """Load the encoded pairs.

        Args:
            path: JSON file holding a list of ``[question_ids, reply_ids]``
                entries; defaults to the file written by the preprocessing cell.
        """
        # The context manager guarantees the file handle is closed after loading.
        with open(path) as pairs_file:
            self.pairs = json.load(pairs_file)
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        """Return the i-th pair as two int64 tensors ``(question, reply)``."""
        question_ids, reply_ids = self.pairs[i][0], self.pairs[i][1]
        question = torch.tensor(question_ids, dtype=torch.long)
        reply = torch.tensor(reply_ids, dtype=torch.long)

        return question, reply

    def __len__(self):
        """Return the number of pairs."""
        return self.dataset_size

In [None]:
# Shuffled mini-batches of 64 pairs, loaded by 20 worker processes into pinned memory.
train_dataset = Dataset()
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
    num_workers=20,
)

In [None]:
def create_masks(question, reply_input, reply_target):
    """Build the padding / look-ahead masks for one batch.

    Id 0 is treated as padding. All masks are boolean and are created on the
    same device as the tensor they are derived from, so the function does not
    depend on any module-level ``device`` setting.

    Args:
        question: (batch, q_len) tensor of token ids.
        reply_input: (batch, r_len) tensor of decoder input ids.
        reply_target: (batch, r_len) tensor of decoder target ids.

    Returns:
        question_mask: (batch, 1, 1, q_len), True at non-padding positions.
        reply_input_mask: (batch, 1, r_len, r_len), True where the key position
            is not padding and is not after the query position.
        reply_target_mask: (batch, r_len), True at non-padding target positions.
    """
    def next_mask(size):
        # Lower-triangular matrix: position i may attend to positions <= i.
        causal = torch.tril(torch.ones(size, size, dtype=torch.bool, device=reply_input.device))
        return causal.unsqueeze(0)

    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)

    # (batch, 1, r_len) padding mask combined with the (1, r_len, r_len) causal mask.
    reply_input_mask = (reply_input != 0).unsqueeze(1) & next_mask(reply_input.size(-1))
    reply_input_mask = reply_input_mask.unsqueeze(1)

    reply_target_mask = reply_target != 0

    return question_mask, reply_input_mask, reply_target_mask


In [None]:
class Embeddings(nn.Module):
    """Token embedding plus additive position and layer-index encodings.

    ``pe`` holds one encoding row per sequence position and ``te`` one row per
    layer index. Both are registered as non-persistent buffers: they follow the
    module across ``.to(...)`` calls and are not added to ``state_dict``.
    """

    def __init__(self, vocab_size, d_model, max_len = 35, num_layers = 6):
        """
        Args:
            vocab_size: number of rows of the embedding table.
            d_model: embedding width.
            max_len: number of positions covered by the position table.
            num_layers: number of layer indices covered by the layer table.
        """
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.register_buffer("pe", self.create_positinal_encoding(max_len, d_model), persistent=False)
        self.register_buffer("te", self.create_positinal_encoding(num_layers, d_model), persistent=False)

    def create_positinal_encoding(self, max_len, d_model):
        """Return a (1, max_len, d_model) float32 table of sin/cos encodings.

        The table is created on the CPU; callers that need it elsewhere must
        move it (``__init__`` registers it as a buffer for that purpose).

        NOTE(review): ``i`` already steps over the even column indices, so the
        exponents ``2*i/d_model`` and ``2*(i+1)/d_model`` differ from the usual
        sinusoidal formulation. The formula is kept unchanged so the values
        match what existing checkpoints were trained with — confirm intent.
        """
        table = []
        for pos in range(max_len):
            row = [0.0] * d_model
            for i in range(0, d_model, 2):
                row[i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                # With an odd d_model the last sine column has no cosine partner.
                if i + 1 < d_model:
                    row[i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
            table.append(row)
        # Build the tensor once rather than writing it element by element.
        pe = torch.tensor(table, dtype=torch.float32).reshape(max_len, d_model)
        return pe.unsqueeze(0)

    def forward(self, embedding, layer_idx):
        """Add the position and layer encodings, then apply dropout.

        Args:
            embedding: for ``layer_idx == 0`` a (batch, seq) tensor of token
                ids, which is embedded and scaled by sqrt(d_model) first;
                otherwise a (batch, seq, d_model) tensor that is used as is.
            layer_idx: index of the row of ``te`` to add.

        Returns:
            A new (batch, seq, d_model) tensor; the input is not modified.

        Raises:
            ValueError: if the sequence is longer than the position table.
        """
        if layer_idx == 0:
            embedding = self.embed(embedding) * math.sqrt(self.d_model)
        seq_len = embedding.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                "sequence length {} exceeds the positional table length {}".format(seq_len, self.pe.size(1)))
        # Out-of-place additions: the caller's tensor is left untouched.
        embedding = embedding + self.pe[:, :seq_len]
        # (1, 1, d_model) broadcasts over every position of the sequence.
        embedding = embedding + self.te[:, layer_idx, :].unsqueeze(1)
        return self.dropout(embedding)

In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed over ``heads`` parallel heads."""

    def __init__(self, heads, d_model):
        """
        Args:
            heads: number of attention heads; must divide ``d_model``.
            d_model: model width (input and output feature size).
        """
        super().__init__()

        assert d_model % heads == 0

        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)

        # Input projections for queries, keys and values, then the output projection.
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        batch = projected.shape[0]
        return projected.view(batch, -1, self.heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: (batch, seq, d_model) tensors.
            mask: tensor broadcastable to (batch, heads, q_len, k_len);
                positions where it equals 0 are excluded from attention.

        Returns:
            (batch, q_len, d_model) tensor.
        """
        q = self._split_heads(self.query(query))
        k = self._split_heads(self.key(key))
        v = self._split_heads(self.value(value))

        # Similarity of every query with every key, scaled by sqrt(d_k).
        scale = math.sqrt(q.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        # A large negative score makes masked positions vanish in the softmax.
        scores = scores.masked_fill(mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim = -1))

        context = torch.matmul(weights, v)
        # Merge the heads back into a single d_model-wide feature vector.
        batch = context.shape[0]
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.heads * self.d_k)

        return self.concat(merged)

In [None]:
class FF(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, middle_dim = 2048):
        """
        Args:
            d_model: input and output feature size.
            middle_dim: width of the hidden layer.
        """
        super().__init__()

        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``; the shape is preserved."""
        hidden = self.fc1(x)
        hidden = self.dropout(F.relu(hidden))
        return self.fc2(hidden)

In [None]:
class EncodeLayer(nn.Module):
    """Encoder layer: self-attention and feed-forward sub-layers, each followed
    by a residual connection and layer normalisation.

    A single ``LayerNorm`` instance is used after both sub-layers.
    """

    def __init__(self, d_model, heads):
        """
        Args:
            d_model: model width.
            heads: number of attention heads.
        """
        super().__init__()

        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FF(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        """Encode ``embeddings``; ``mask`` is passed to the attention module."""
        # Sub-layer 1: self-attention + residual + norm.
        attended = self.self_multihead(embeddings, embeddings, embeddings, mask)
        interacted = self.layernorm(self.dropout(attended) + embeddings)
        # Sub-layer 2: feed-forward + residual + norm.
        transformed = self.feed_forward(interacted)
        return self.layernorm(self.dropout(transformed) + interacted)

In [None]:
class DecodeLayer(nn.Module):
    """Decoder layer: masked self-attention, attention over the encoder output,
    and a feed-forward block, each followed by a residual connection and layer
    normalisation.

    A single ``LayerNorm`` instance is used after all three sub-layers.
    """

    def __init__(self, d_model, heads):
        """
        Args:
            d_model: model width.
            heads: number of attention heads.
        """
        super().__init__()

        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FF(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, encoded, src_mask, target_mask):
        """Decode ``embeddings`` while attending to ``encoded``.

        Args:
            embeddings: decoder-side input.
            encoded: encoder output, used as keys and values.
            src_mask: mask for the attention over ``encoded``.
            target_mask: mask for the decoder self-attention.
        """
        # Sub-layer 1: self-attention over the decoder input.
        self_attended = self.self_multihead(embeddings, embeddings, embeddings, target_mask)
        query = self.layernorm(self.dropout(self_attended) + embeddings)
        # Sub-layer 2: attention from the decoder state to the encoder output.
        cross_attended = self.src_multihead(query, encoded, encoded, src_mask)
        interacted = self.layernorm(self.dropout(cross_attended) + query)
        # Sub-layer 3: feed-forward.
        transformed = self.feed_forward(interacted)
        return self.layernorm(self.dropout(transformed) + interacted)

In [None]:
class Model(nn.Module):
    """Encoder-decoder model that applies one encoder layer and one decoder
    layer ``num_layers`` times each, re-applying the embedding module (with the
    current layer index) before every pass.
    """

    def __init__(self, d_model, heads, num_layers, word_map):
        """
        Args:
            d_model: model width.
            heads: number of attention heads.
            num_layers: how many times the encoder/decoder layer is applied.
            word_map: token -> id mapping; its length is the vocabulary size.
        """
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model, num_layers = num_layers)
        self.encoder = EncodeLayer(d_model, heads)
        self.decoder = DecodeLayer(d_model, heads)
        self.logit = nn.Linear(d_model, self.vocab_size)

    def encode(self, src_embeddings, src_mask):
        """Run the source through ``num_layers`` passes of the encoder layer."""
        state = src_embeddings
        for layer_idx in range(self.num_layers):
            state = self.encoder(self.embed(state, layer_idx), src_mask)
        return state

    def decode(self, tgt_embeddings, target_mask, src_embeddings, src_mask):
        """Run the target through ``num_layers`` passes of the decoder layer,
        attending to ``src_embeddings`` in every pass."""
        state = tgt_embeddings
        for layer_idx in range(self.num_layers):
            state = self.decoder(self.embed(state, layer_idx), src_embeddings, src_mask, target_mask)
        return state

    def forward(self, src_words, src_mask, target_words, target_mask):
        """Return log-probabilities over the vocabulary (log-softmax over
        dimension 2) for the decoded target."""
        memory = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, memory, src_mask)
        return F.log_softmax(self.logit(decoded), dim = 2)

In [None]:
class AdamOptimize:
    """Wrap an optimizer and set its learning rate from the step count.

    The rate is ``model_size**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)``.
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        """
        Args:
            model_size: value used in the ``model_size**-0.5`` scale factor.
            warmup_steps: value used in the ``step * warmup_steps**-1.5`` term.
            optimizer: object exposing ``param_groups`` and ``step()``.
        """
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        """Return the learning rate for ``current_step``.

        Before the first ``step()`` call (``current_step == 0``) this returns
        0.0, matching the initial ``lr``, instead of raising
        ZeroDivisionError from ``0 ** -0.5``.
        """
        if self.current_step == 0:
            return 0.0
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))

    def step(self):
        """Advance the step count, push the new rate into every parameter
        group, then step the wrapped optimizer."""
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        self.optimizer.step()

In [None]:
class LossWithLS(nn.Module):
    """KL-divergence loss against label-smoothed targets, averaged over the
    positions selected by a mask."""

    def __init__(self, size, smooth):
        """
        Args:
            size: number of classes (vocabulary size).
            smooth: probability mass spread over the non-target classes; the
                target class receives ``1 - smooth``.
        """
        super().__init__()
        # reduction='none' keeps the element-wise loss; it replaces the
        # deprecated size_average=False / reduce=False pair.
        self.criterion = nn.KLDivLoss(reduction='none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        Args:
            prediction: (..., size) log-probabilities.
            target: tensor of class ids with one entry per prediction row.
            mask: same shape as ``target``; non-zero where the position counts.

        Returns:
            Scalar tensor: sum of the masked per-position losses divided by the
            number of selected positions.
        """
        prediction = prediction.view(-1, prediction.size(-1))
        target = target.contiguous().view(-1)
        mask = mask.float().view(-1)
        # Smoothed target distribution: smooth/(size-1) everywhere, then the
        # confidence value written at the target index. Built without autograd
        # history.
        labels = torch.full_like(prediction, self.smooth / (self.size - 1), requires_grad=False)
        labels.scatter_(1, target.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)
        # Sum over classes, drop masked-out positions, average over the rest.
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss

In [None]:
# Hyper-parameters for the model and the training run.
d_model = 128 #maybe 128
heads = 4
num_layers = 2
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 1000

with open('json_preprocessed_data/WORDMAP_corpus.json', 'r') as j:
    word_map = json.load(j)

# Alternative to loading the checkpoint below: start from a fresh model.
# transformer = Model(d_model, heads, num_layers, word_map)
# The checkpoint stores the whole pickled model object, so the model classes
# must already be defined when it is loaded. map_location lets a checkpoint
# saved on a GPU load on a CPU-only machine.
# NOTE(review): unpickling executes code from the file — only load trusted
# checkpoints. Recent PyTorch releases may also require weights_only=False to
# load a full model object; confirm against the installed version.
transformer = torch.load("checkpoint_999|time: 04|02|2024 20:36:55.pt", map_location=device)
transformer = transformer.to(device)
# lr=0 is a placeholder: AdamOptimize overwrites the rate before every step.
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamOptimize(d_model, 4000, adam_optimizer)
criterion = LossWithLS(len(word_map), 0.2)



In [None]:
def train(train_loader, transformer, criterion, epoch):
    """Train ``transformer`` for one pass over ``train_loader``.

    Relies on the module-level ``device``, ``create_masks`` and
    ``transformer_optimizer``. Every 100 batches the running mean loss per
    sample is printed.

    Args:
        train_loader: iterable of ``(question, reply)`` id-tensor batches.
        transformer: model called as ``transformer(question, question_mask,
            reply_input, reply_input_mask)``.
        criterion: loss called as ``criterion(output, target, target_mask)``.
        epoch: epoch number, used only in the progress message.
    """
    transformer.train()
    running_loss = 0
    seen = 0

    for step, batch in enumerate(train_loader):
        question, reply = batch
        batch_size = question.shape[0]

        question, reply = question.to(device), reply.to(device)

        # Teacher forcing: the decoder sees the reply without its last token
        # and is scored against the reply shifted left by one.
        reply_input, reply_target = reply[:, :-1], reply[:, 1:]

        masks = create_masks(question, reply_input, reply_target)
        question_mask, reply_input_mask, reply_target_mask = masks

        out = transformer(question, question_mask, reply_input, reply_input_mask)
        loss = criterion(out, reply_target, reply_target_mask)

        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

        # Weight the batch loss by its size so the running mean is per sample.
        running_loss += loss.item() * batch_size
        seen += batch_size

        if step % 100 != 0:
            continue
        print("Epoch [{}][{}/{}]\tLoss: {:.3f}".format(epoch, step, len(train_loader), running_loss/seen))

In [None]:
def evaluate(transformer, question, question_mask, max_len, word_map):
    """Generate a reply by picking the highest-scoring token at every step.

    Args:
        transformer: model exposing ``eval()``, ``encode``, ``decode`` and
            ``logit``.
        question: (1, q_len) tensor of token ids.
        question_mask: mask for ``question``, passed through to the model.
        max_len: at most ``max_len - 1`` decoding steps are performed.
        word_map: token -> id mapping containing ``<start>`` and ``<end>``.

    Returns:
        The generated tokens joined by single spaces, without ``<start>``.
        Generation stops early when ``<end>`` is predicted; ``<end>`` itself is
        not included.
    """
    rev_word_map = {v: k for k, v in word_map.items()}
    transformer.eval()
    start_token = word_map['<start>']
    end_token = word_map['<end>']
    # Work on the question's device rather than a module-level setting.
    target_device = question.device

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.tensor([[start_token]], dtype=torch.long, device=target_device)

        for step in range(max_len - 1):
            size = words.shape[1]
            # Lower-triangular mask: each position only sees itself and earlier ones.
            target_mask = torch.tril(torch.ones(size, size, dtype=torch.bool, device=target_device))
            target_mask = target_mask.unsqueeze(0).unsqueeze(0)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            # Only the last position predicts the next token.
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim = 1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            next_column = torch.tensor([[next_word]], dtype=torch.long, device=target_device)
            words = torch.cat([words, next_column], dim = 1)

    # words is always (1, n): flatten it and drop every <start> id.
    sen_idx = [w for w in words.squeeze(0).tolist() if w != start_token]
    sentence = ' '.join(rev_word_map[w] for w in sen_idx)

    return sentence

In [None]:
from datetime import datetime

In [None]:
# Train for the configured number of epochs, saving the whole model object
# after each one under a name containing the timestamp and the epoch number.
for epoch in range(epochs):
    train(train_loader, transformer, criterion, epoch)

    dt_string = datetime.now().strftime("%d|%m|%Y %H:%M:%S")
    checkpoint_path = "".join(["checkpoint", "|time: ", dt_string, "___", str(epoch), ".pt"])

    torch.save(transformer, checkpoint_path)

Epoch [0][0/3463]	Loss: 7.499
Epoch [0][100/3463]	Loss: 7.207
Epoch [0][200/3463]	Loss: 6.457
Epoch [0][300/3463]	Loss: 5.853
Epoch [0][400/3463]	Loss: 5.535
Epoch [0][500/3463]	Loss: 5.330
Epoch [0][600/3463]	Loss: 5.181
Epoch [0][700/3463]	Loss: 5.065
Epoch [0][800/3463]	Loss: 4.974
Epoch [0][900/3463]	Loss: 4.898
Epoch [0][1000/3463]	Loss: 4.834
Epoch [0][1100/3463]	Loss: 4.779
Epoch [0][1200/3463]	Loss: 4.730
Epoch [0][1300/3463]	Loss: 4.686
Epoch [0][1400/3463]	Loss: 4.649
Epoch [0][1500/3463]	Loss: 4.616
Epoch [0][1600/3463]	Loss: 4.585
Epoch [0][1700/3463]	Loss: 4.556
Epoch [0][1800/3463]	Loss: 4.531
Epoch [0][1900/3463]	Loss: 4.506
Epoch [0][2000/3463]	Loss: 4.484
Epoch [0][2100/3463]	Loss: 4.463
Epoch [0][2200/3463]	Loss: 4.445
Epoch [0][2300/3463]	Loss: 4.428
Epoch [0][2400/3463]	Loss: 4.412
Epoch [0][2500/3463]	Loss: 4.397
Epoch [0][2600/3463]	Loss: 4.382
Epoch [0][2700/3463]	Loss: 4.369
Epoch [0][2800/3463]	Loss: 4.357
Epoch [0][2900/3463]	Loss: 4.345
Epoch [0][3000/3463]	L

KeyboardInterrupt: 

3.602

In [None]:
# Interactive chat: read a question, generate a reply, repeat until 'quit'.
while True:
    question = input("Question: ")

    if question == 'quit':
        break

    # Normalise the input the same way the training pairs were built
    # (remove_punc lower-cases and strips punctuation); otherwise tokens such
    # as "Hello?" are not found in the vocabulary and map to <unk>.
    tokens = remove_punc(question).split()
    if not tokens:
        # Nothing to encode (empty or punctuation-only input): ask again.
        continue

    # The model's positional table bounds how long a sequence it can process,
    # so both the question and the reply are capped by its length. A local
    # name is used so the module-level max_len read by encode_q/encode_r is
    # not overwritten.
    positions = transformer.embed.pe.size(1)
    reply_limit = min(100, positions)
    tokens = tokens[:positions]

    enc_qus = [word_map.get(word, word_map['<unk>']) for word in tokens]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(transformer, question, question_mask, reply_limit, word_map)
    print(sentence + "\n")