In [None]:
import ast
import json
import math
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import Dataset

In [None]:
# Paths to the Cornell Movie-Dialogs corpus files, relative to the notebook's working directory.
corpus_movie_conv = './cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = './cornell movie-dialogs corpus/movie_lines.txt'
# Maximum number of words kept per utterance: longer questions/replies are truncated when
# building `pairs`, and shorter ones are padded up to this length when encoded.
max_len = 25

In [None]:
# Read both corpus files with the same explicit encoding so the result does not depend on
# the platform's default (locale) encoding; movie_lines.txt is read as ISO-8859-1 (Latin-1).
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as conv_file:
    conv = conv_file.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as lines_file:
    lines = lines_file.readlines()

In [None]:
# Peek at the first raw conversation records; the last ' +++$+++ '-separated field holds
# the list of line IDs that make up the conversation (parsed further below).
conv[:5]

In [None]:
# Peek at the first raw dialogue lines; the line ID is the first ' +++$+++ ' field and the
# utterance text is the last one.
lines[:5]

In [None]:
# Map each line ID (first field) to its utterance text (last field). Later duplicates of an
# ID overwrite earlier ones.
line_dict = {
    fields[0]: fields[-1]
    for fields in (raw_line.split(' +++$+++ ') for raw_line in lines)
}

In [None]:
# Sanity check: the first few (line ID, raw text) entries.
list(line_dict.items())[:5]

In [None]:
import string

# Translation table that deletes every ASCII punctuation character; built once and reused.
_PUNCT_DELETE_TABLE = str.maketrans('', '', string.punctuation)


def remove_punc(s):
    """Return ``s`` with all ASCII punctuation removed, lower-cased.

    Whitespace (including trailing newlines) is left untouched.
    """
    without_punct = s.translate(_PUNCT_DELETE_TABLE)
    return without_punct.lower()

In [None]:
# Normalise every utterance in place: strip punctuation and lower-case.
for line_id in line_dict:
    line_dict[line_id] = remove_punc(line_dict[line_id])

In [None]:
# The same entries after punctuation removal and lower-casing.
list(line_dict.items())[:5]

In [None]:
# Build (question, reply) pairs from every two consecutive lines of each conversation.
# Each pair is [question_words, reply_words], both truncated to `max_len` words.
pairs = []
for con in conv:
    # The last field is a list literal of line IDs. ast.literal_eval parses it without
    # executing arbitrary code, unlike eval() on file contents.
    ids = ast.literal_eval(con.split(' +++$+++ ')[-1].strip())
    for first_id, second_id in zip(ids, ids[1:]):
        # remove_punc is idempotent, so this is safe whether or not line_dict was normalised.
        first = remove_punc(line_dict[first_id].strip())
        second = remove_punc(line_dict[second_id].strip())
        pairs.append([first.split()[:max_len], second.split()[:max_len]])

In [None]:
# Number of consecutive (question, reply) pairs extracted.
len(pairs)

In [None]:
# Count word occurrences across all questions and replies. Counter keeps first-seen key order,
# which later determines the word indices, so questions are counted before replies per pair.
word_freq = Counter(
    word
    for question_words, reply_words in pairs
    for word in (*question_words, *reply_words)
)

In [None]:
# Keep only words seen strictly more than `min_word_freq` times; all other words are encoded
# as <unk>. Indices start at 1 because 0 is reserved for <pad>.
min_word_freq = 5
words = [word for word, freq in word_freq.items() if freq > min_word_freq]
word2idx = {word: idx for idx, word in enumerate(words, start=1)}

In [None]:
# Append the special tokens after the regular vocabulary (indices 1..len(words)).
# Their IDs are derived from `len(words)` rather than the growing `len(word2idx)`, so
# re-running this cell cannot shift the IDs or make two tokens share one.
num_regular = len(words)
word2idx['<unk>'] = num_regular + 1
word2idx['<start>'] = num_regular + 2
word2idx['<end>'] = num_regular + 3
word2idx['<pad>'] = 0  # padding index; the masks treat index 0 as padding

In [None]:
# Vocabulary size, including the special tokens <unk>, <start>, <end> and <pad>.
print(f"Total no. of words are {len(word2idx)}")

In [None]:
# Persist the vocabulary; the model-setup cell reloads it from this file.
with open('word2idx.json', 'w') as j:
    json.dump(word2idx, j)

In [None]:
def encode_question(words, word2idx, length=None):
    """Encode a question as a fixed-length list of word indices.

    Args:
        words: sequence of (already normalised) word strings.
        word2idx: vocabulary mapping; must contain '<unk>' and '<pad>'.
        length: target length; defaults to the global ``max_len``.

    Returns:
        Exactly ``length`` indices. Unknown words map to '<unk>', the tail is padded with
        '<pad>', and inputs longer than ``length`` are truncated so every question in a
        batch has the same size.
    """
    if length is None:
        length = max_len
    unk = word2idx['<unk>']
    encoded = [word2idx.get(word, unk) for word in words[:length]]
    return encoded + [word2idx['<pad>']] * (length - len(encoded))


def encode_reply(words, word2idx, length=None):
    """Encode a reply as ``<start> w1 ... wn <end>`` followed by '<pad>' entries.

    Args:
        words: sequence of (already normalised) word strings.
        word2idx: vocabulary mapping; must contain '<unk>', '<start>', '<end>' and '<pad>'.
        length: maximum number of words; defaults to the global ``max_len``.

    Returns:
        Exactly ``length + 2`` indices. Inputs longer than ``length`` words are truncated.
    """
    if length is None:
        length = max_len
    unk = word2idx['<unk>']
    body = [word2idx.get(word, unk) for word in words[:length]]
    # Terminate with <end> (not <start>) so the decoder learns when to stop; greedy
    # decoding in evaluate() stops when it predicts <end>.
    encoded = [word2idx['<start>']] + body + [word2idx['<end>']]
    return encoded + [word2idx['<pad>']] * (length - len(body))

In [None]:
# Encode every pair as [question indices, reply indices] using the helpers above.
pairs_encoded = []
for question_words, reply_words in pairs:
    pairs_encoded.append([
        encode_question(question_words, word2idx),
        encode_reply(reply_words, word2idx),
    ])

In [None]:
# Cache the encoded pairs on disk; DialogDataset reads them back from this file.
with open('pairs_encoded.json', 'w') as j:
    json.dump(pairs_encoded, j)

In [None]:
class DialogDataset(Dataset):
    """Dataset of encoded (question, reply) pairs loaded from a JSON file.

    The file must contain a list of ``[question_indices, reply_indices]`` pairs
    (e.g. pairs_encoded.json). Each item is returned as a pair of LongTensors.
    """

    def __init__(self, pairs_encoded_path):
        super().__init__()
        # Context manager closes the handle deterministically; json.load(open(...))
        # leaks it until garbage collection.
        with open(pairs_encoded_path) as f:
            self.pairs = json.load(f)
        self.dataset_size = len(self.pairs)

    def __getitem__(self, index):
        pair = self.pairs[index]
        return torch.LongTensor(pair[0]), torch.LongTensor(pair[1])

    def __len__(self):
        return self.dataset_size

In [None]:
# Shuffled mini-batches of encoded pairs. Pinned (page-locked) host memory only speeds up
# host-to-GPU copies, so it is requested only when CUDA is available; on CPU-only machines
# PyTorch would just emit a warning.
train_loader = torch.utils.data.DataLoader(
    DialogDataset('./pairs_encoded.json'),
    batch_size=100,
    shuffle=True,
    pin_memory=torch.cuda.is_available()
)

In [None]:
# Fetch a single batch to sanity-check the tensor shapes coming out of the DataLoader.
question, reply = next(iter(train_loader))

In [None]:
# Expected: (batch_size, max_len) for questions and (batch_size, max_len + 2) for replies,
# since encode_reply adds two extra tokens around the padded words.
question.size(), reply.size()

In [None]:
def create_masks(question, reply_input, reply_target):
    """Build the attention and loss masks for one batch (padding index is 0).

    Args:
        question: (batch_size, q_len) encoder input indices.
        reply_input: (batch_size, r_len) decoder input indices (reply without its last token).
        reply_target: (batch_size, r_len) decoder targets (reply shifted left by one).

    Returns:
        question_mask: bool (batch_size, 1, 1, q_len); True where the question token is not
            padding. The singleton axes broadcast over attention heads and query positions.
        reply_input_mask: bool (batch_size, 1, r_len, r_len); position i may attend to
            position j only if j <= i (causal) and token j is not padding.
        reply_target_mask: bool (batch_size, r_len); True where the target is not padding,
            used to exclude padding from the loss.
    """
    # All masks are created on the inputs' own device, so no global `device` is required.
    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)

    r_len = reply_input.size(-1)
    # Lower-triangular matrix: row i has True in columns 0..i.
    causal = torch.tril(torch.ones(r_len, r_len, dtype=torch.bool, device=reply_input.device))
    # (batch, 1, r_len) & (r_len, r_len) -> (batch, r_len, r_len), then add the head axis.
    reply_input_mask = ((reply_input != 0).unsqueeze(1) & causal).unsqueeze(1)

    reply_target_mask = reply_target != 0
    return question_mask, reply_input_mask, reply_target_mask

In [None]:
# Scratch check of the causal-mask construction: triu(...).transpose(0, 1) is lower-triangular,
# and the leading singleton axis gives torch.Size([1, 5, 5]).
torch.triu(torch.ones(5, 5)).transpose(0, 1).unsqueeze(0).size()

In [None]:
class Embeddings(nn.Module):
    """Token embedding scaled by sqrt(d_model) plus a fixed sinusoidal positional encoding.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding width.
        max_len: longest sequence the positional table supports.
    """

    def __init__(self, vocab_size, d_model, max_len = 50):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        # Registered as a buffer so `.to(device)` moves it together with the module; it is
        # not a trainable parameter.
        self.register_buffer('pe', self.create_positional_encoding(max_len, d_model))

    def create_positional_encoding(self, max_len, d_model):
        """Return the (1, max_len, d_model) sinusoidal table.

        PE[pos, 2k]   = sin(pos / 10000^(2k / d_model))
        PE[pos, 2k+1] = cos(pos / 10000^(2k / d_model))
        """
        position = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)  # (max_len, 1)
        even = torch.arange(0, d_model, 2, dtype=torch.float64)               # 2k values
        angles = position / (10000.0 ** (even / d_model))                    # (max_len, ceil(d/2))
        pe = torch.zeros(max_len, d_model, dtype=torch.float64)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model the last sine column has no cosine partner.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.float().unsqueeze(0)

    def forward(self, encoded_words):
        """Embed (batch_size, seq_len) indices into (batch_size, seq_len, d_model)."""
        seq_len = encoded_words.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding max_len {self.pe.size(1)}"
            )
        embeddings = self.embed(encoded_words) * math.sqrt(self.d_model)
        # Slice (not index) so every position receives its own encoding.
        embeddings = embeddings + self.pe[:, :seq_len]
        return self.dropout(embeddings)

In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``heads`` parallel heads.

    Args:
        heads: number of attention heads; must divide ``d_model``.
        d_model: model width; each head works in ``d_model // heads`` dimensions.
    """

    def __init__(self, heads, d_model):
        super().__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)  # output projection after merging heads

    def _split_heads(self, x):
        """(batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        return x.view(x.shape[0], -1, self.heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: (batch, q_len, d_model); key, value: (batch, kv_len, d_model).
            mask: broadcastable to (batch, heads, q_len, kv_len); entries equal to 0 are
                excluded by setting their score to -1e9 before the softmax.

        Returns:
            (batch, q_len, d_model) tensor.
        """
        q = self._split_heads(self.query(query))
        k = self._split_heads(self.key(key))
        v = self._split_heads(self.value(value))

        # (batch, h, q_len, d_k) @ (batch, h, d_k, kv_len) -> (batch, h, q_len, kv_len)
        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        logits = logits.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(logits, dim=-1))

        # (batch, h, q_len, kv_len) @ (batch, h, kv_len, d_k) -> (batch, q_len, h, d_k)
        merged = torch.matmul(attn, v).transpose(1, 2)
        merged = merged.reshape(merged.shape[0], -1, self.heads * self.d_k)
        return self.concat(merged)

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input/output width.
        middle_dim: hidden width.
    """

    def __init__(self, d_model, middle_dim = 2048):
        super().__init__()
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then feed-forward.

    Each sublayer is followed by dropout, a residual connection and LayerNorm (post-norm).

    NOTE(review): a single ``self.layernorm`` is shared by both sublayers; the paper's
    Transformer uses a separate LayerNorm per sublayer. Kept as-is because changing it
    alters the module's parameters and breaks previously saved checkpoints.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.layernorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        """embeddings: (batch, seq, d_model); mask is passed to the self-attention."""
        attended = self.self_multihead(embeddings, embeddings, embeddings, mask)
        residual = self.layernorm(embeddings + self.dropout(attended))
        transformed = self.dropout(self.feed_forward(residual))
        return self.layernorm(residual + transformed)

In [None]:
class DecoderLayer(nn.Module):
    """One decoder layer with three sublayers: self-attention over the target, attention over
    the encoder output, and feed-forward.

    Each sublayer is followed by dropout, a residual connection and LayerNorm (post-norm).

    NOTE(review): one ``self.layernorm`` is shared by all three sublayers; the paper's
    Transformer uses a separate LayerNorm per sublayer. Kept for checkpoint compatibility.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.layernorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, encoded, src_mask, target_mask):
        """target_mask applies to the target self-attention, src_mask to attention over `encoded`."""
        self_attended = self.self_multihead(embeddings, embeddings, embeddings, target_mask)
        query = self.layernorm(embeddings + self.dropout(self_attended))
        cross_attended = self.src_multihead(query, encoded, encoded, src_mask)
        interacted = self.layernorm(query + self.dropout(cross_attended))
        transformed = self.dropout(self.feed_forward(interacted))
        return self.layernorm(interacted + transformed)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer; one embedding table is shared by source and target.

    Args:
        d_model: model width.
        heads: attention heads per layer.
        num_layers: number of encoder layers and (separately) of decoder layers.
        word2idx: vocabulary mapping; its size sets the embedding and output dimensions.
    """

    def __init__(self, d_model, heads, num_layers, word2idx):
        super().__init__()
        self.d_model = d_model
        self.vocab = len(word2idx)
        self.embed = Embeddings(self.vocab, d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        self.logit = nn.Linear(d_model, self.vocab)

    def encode(self, src_words, src_mask):
        """Run the encoder stack; returns (batch, src_len, d_model)."""
        src_embeddings = self.embed(src_words)
        # Feed each layer the previous layer's output so the layers are actually stacked
        # (this also makes num_layers == 0 return the plain embeddings).
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        return src_embeddings

    def decode(self, target_words, target_mask, src_embedding, src_mask):
        """Run the decoder stack over ``target_words`` attending to the encoder output."""
        target_embedding = self.embed(target_words)
        for layer in self.decoder:
            target_embedding = layer(target_embedding, src_embedding, src_mask, target_mask)
        return target_embedding

    def forward(self, src_words, src_mask, target_words, target_mask):
        """Return log-probabilities over the vocabulary, shape (batch, tgt_len, vocab)."""
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        # Normalise over the vocabulary axis explicitly: without `dim`, PyTorch's legacy
        # implicit choice for a 3-D tensor is dim 0 (the batch axis).
        # Log-probabilities are what nn.KLDivLoss expects as its input.
        return F.log_softmax(self.logit(decoded), dim=-1)

In [None]:
class AdamWarmup:
    """Wrap an optimizer with the inverse-square-root warm-up learning-rate schedule.

    lr(step) = model_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The rate grows linearly for the first ``warmup_steps`` steps and then decays
    proportionally to step^-0.5. ``step()`` writes the new rate into every param group
    before delegating to ``optimizer.step()``.
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        """Learning rate for ``self.current_step``; 0.0 before the first step."""
        step = self.current_step
        if step <= 0:
            # Limit of the warm-up term at step 0; avoids `0 ** -0.5` (ZeroDivisionError).
            return 0.0
        return self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))

    def zero_grad(self):
        """Passthrough to the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad()

    def step(self):
        """Advance the schedule, apply the new learning rate, then update the weights."""
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        self.optimizer.step()

In [None]:
class LossWithLabelSmoothing(nn.Module):
    """KL-divergence loss against label-smoothed targets, averaged over unmasked tokens.

    The target distribution puts ``1 - smooth`` on the true class and spreads ``smooth``
    evenly over the remaining ``size - 1`` classes.

    Args:
        size: vocabulary size (last dimension of ``prediction``).
        smooth: probability mass moved off the true class.
    """

    def __init__(self, size, smooth):
        super().__init__()
        # reduction='none' is the supported spelling of the deprecated
        # size_average=False, reduce=False combination (element-wise loss).
        self.criterion = nn.KLDivLoss(reduction='none')
        self.confidence = 1 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction: (batch_size, max_words, vocab_size) log-probabilities
        target, mask: (batch_size, max_words)

        Returns the per-token KL divergence summed over unmasked tokens and divided by their
        count (0 when every token is masked, instead of 0/0 = NaN).
        """
        prediction = prediction.reshape(-1, prediction.size(-1))  # (batch_size * max_words, vocab_size)
        target = target.reshape(-1)                                # (batch_size * max_words)
        mask = mask.reshape(-1).float()                            # (batch_size * max_words)
        # Smoothed target distribution; a constant, so build it without autograd tracking.
        with torch.no_grad():
            labels = torch.full_like(prediction, self.smooth / (self.size - 1))
            labels.scatter_(1, target.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)  # (batch_size * max_words, vocab_size)
        return (loss.sum(1) * mask).sum() / mask.sum().clamp(min=1)

In [None]:
# Model and optimisation hyper-parameters.
d_model = 512
heads = 8
layers = 3
# Use the GPU when one is available; the training/evaluation code moves tensors here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 1

# Build the model from the vocabulary persisted earlier in word2idx.json.
with open('word2idx.json', 'r') as j:
    word2idx = json.load(j)

transformer = Transformer(d_model=d_model, heads=heads, num_layers=layers, word2idx=word2idx)
transformer = transformer.to(device)
# lr=0 is only a placeholder: AdamWarmup overwrites each param group's lr on every step.
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size=d_model, warmup_steps=4000, optimizer=adam_optimizer)
# Label smoothing moves 0.3 of the target probability mass onto the non-target words.
criterion = LossWithLabelSmoothing(size=len(word2idx), smooth = 0.3)

In [None]:
def train(train_loader, transformer, criterion, epoch, scheduled_optimizer=None):
    """Run one training epoch and return its average loss.

    Args:
        train_loader: yields (question, reply) LongTensor batches.
        transformer: model to train.
        criterion: loss called as criterion(log_probs, target, target_mask).
        epoch: epoch number, used only for logging.
        scheduled_optimizer: wrapper exposing ``.optimizer.zero_grad()`` and ``.step()``
            (e.g. AdamWarmup). Defaults to the global ``transformer_optimizer`` for
            backward compatibility. Pass it explicitly whenever ``transformer`` has been
            replaced (e.g. after loading a checkpoint); otherwise the global optimizer keeps
            stepping the parameters of the old model.

    Returns:
        Average loss over the epoch (batch losses weighted by batch size), or 0.0 if the
        loader yielded no batches.
    """
    if scheduled_optimizer is None:
        scheduled_optimizer = transformer_optimizer
    transformer.train()
    sum_loss = 0.0
    count = 0

    for i, (question, reply) in enumerate(train_loader):
        samples = question.shape[0]
        question = question.to(device)
        reply = reply.to(device)

        # Teacher forcing: the decoder input is the reply without its last token and the
        # target is the reply shifted left by one.
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

        # Forward propagation
        out = transformer(question, question_mask, reply_input, reply_input_mask)
        loss = criterion(out, reply_target, reply_target_mask)

        # Backward propagation
        scheduled_optimizer.optimizer.zero_grad()
        loss.backward()
        scheduled_optimizer.step()

        sum_loss += loss.item() * samples
        count += samples

        if i % 100 == 0:
            print(f"Epoch: [{epoch}][{i}/{len(train_loader)}]\tLoss: {sum_loss / count:.3f}")

    return sum_loss / count if count else 0.0

In [None]:
def evaluate(transformer, question, question_mask, max_len, word2idx):
    """Greedily decode a reply to a single encoded question.

    Args:
        transformer: model exposing ``eval``, ``encode``, ``decode`` and ``logit``.
        question: (1, q_len) LongTensor of word indices.
        question_mask: encoder mask for ``question``.
        max_len: at most ``max_len - 1`` words are generated.
        word2idx: vocabulary containing '<start>' and '<end>'.

    Returns:
        The generated words joined by spaces, without '<start>'; generation stops before
        '<end>'.
    """
    idx2word = {v: k for k, v in word2idx.items()}
    start_token = word2idx['<start>']
    end_token = word2idx['<end>']
    target_device = question.device
    transformer.eval()

    with torch.no_grad():  # inference only, no autograd graph needed
        encoded = transformer.encode(question, question_mask)
        words = torch.LongTensor([[start_token]]).to(target_device)  # (1, 1)

        for _ in range(max_len - 1):
            # Causal mask sized by the number of tokens generated so far, i.e. dim 1 of
            # `words` (dim 0 is the batch of 1). Shape (1, t, t), broadcast over heads.
            size = words.shape[1]
            target_mask = torch.tril(
                torch.ones(size, size, dtype=torch.bool, device=target_device)
            ).unsqueeze(0)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            predictions = transformer.logit(decoded[:, -1])  # scores for the next token
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            next_tensor = torch.LongTensor([[next_word]]).to(target_device)
            words = torch.cat([words, next_tensor], dim=1)  # (1, t + 1)

    generated = [w for w in words.squeeze(0).tolist() if w != start_token]
    return ' '.join(idx2word[w] for w in generated)

In [None]:
# Train for `epochs` epochs, writing a checkpoint after each one. The checkpoint pickles the
# whole model and optimizer-wrapper objects (not state_dicts), so loading it requires these
# class definitions to be available.
for epoch in range(epochs):
    train(train_loader, transformer, criterion, epoch)
    checkpoint_state = {
        'epoch': epoch,
        'transformer': transformer,
        'transformer_optimizer': transformer_optimizer,
    }
    torch.save(checkpoint_state, f'checkpoint_{epoch}.tar')

In [None]:
# The checkpoint pickles whole Python objects, so only load files you produced yourself.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine. weights_only=False
# is needed on newer PyTorch releases, whose default refuses to unpickle arbitrary objects.
try:
    checkpoint = torch.load('./checkpoint_0.tar', map_location=device, weights_only=False)
except TypeError:  # older PyTorch without the `weights_only` argument
    checkpoint = torch.load('./checkpoint_0.tar', map_location=device)
transformer = checkpoint['transformer']

In [None]:
# Interactive chat loop; submit an empty question to stop.
while True:
    question_text = input("Question: ")
    if not question_text:
        break
    # Keep the answer length in its own variable so the global `max_len` (an int used by the
    # encoding helpers) is not overwritten with a string.
    max_words_text = input("Enter max words to be generated: ")
    try:
        max_words = int(max_words_text)
    except ValueError:
        print("Please enter a whole number.")
        continue
    # Normalise like the training text (punctuation removed, lower-cased) and truncate to the
    # training question length; otherwise e.g. "Hello?" would be encoded as <unk>.
    question_words = remove_punc(question_text).split()[:max_len]
    if not question_words:
        print("No words left after removing punctuation; please try again.")
        continue
    enc_qus = [word2idx.get(word, word2idx['<unk>']) for word in question_words]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(transformer, question, question_mask, max_words, word2idx)
    print(sentence)