In [95]:
import pandas as pd

train = pd.read_json('data/trainmodel.json')
validate = pd.read_json('data/val.json')

train['answers'] = train['answers'].apply(lambda x: x[0])
validate['answers'] = validate['answers'].apply(lambda x: x[0])

In [96]:
train.head(5)

Unnamed: 0,qId,answers,qText
0,wqr000001,Padmé Amidala,what character did natalie portman play in sta...
1,wqr000002,New York City,what state does selena gomez?
2,wqr000003,Bahamas,what country is the grand bahama island in?
3,wqr000005,Denethor II,what character did john noble play in lord of ...
4,wqr000006,Chicago Bulls,who does joakim noah play for?


In [97]:
questions = train['qText'].values
answers = train['answers'].values
questions_val = validate['qText'].values
answers_val = validate['answers'].values

questions_text = " ".join(list(questions))
answers_text = " ".join(list(answers))
questions_val_text = " ".join(list(questions_val))
answers_val_text = " ".join(list(answers_val))

full_text = "$" + questions_text + " " + answers_text + " " + questions_val_text + " " + answers_val_text + "@"

In [98]:
vocab_size = len(set(full_text))
encoder_map = {'$': 0, '@': 1}
decoder_map = {0: '$', 1: '@'}
unique_chars = set(full_text)

for i, c in enumerate(unique_chars.difference({'$', '@'}), start=2):
    encoder_map[c] = i
    decoder_map[i] = c

encode = lambda x: [encoder_map[c] for c in x]
decode = lambda x: ''.join([decoder_map[i] for i in x])


In [99]:
encode("$$$Test string@"), decode(encode("$$$Test string@"))

([0, 0, 0, 94, 92, 23, 47, 26, 23, 47, 50, 63, 51, 16, 1], '$$$Test string@')

In [100]:
import torch

torch.manual_seed(2115)
batch_size = 16
seq_len = 48
n_embed = 32
num_heads = 8


def pad_sequences(sequences, seq_len):
    padded_sequences = [[0] * (seq_len - len(seq)) + seq for seq in sequences]
    return padded_sequences


def truncate_sequences(sequences, max_len):
    truncated_sequences = [seq[:max_len] for seq in sequences]
    return truncated_sequences


def get_batches(questions, answers):
    idx = torch.randint(0, len(questions), (batch_size,))
    batch_questions = questions[idx].tolist()
    batch_answers = answers[idx].tolist()

    encoded_questions = [encode(q)[:-1] for q in batch_questions]
    encoded_answers = [encode(a)[:-1] for a in batch_answers]

    y_questions = [encode(q)[1:] for q in batch_questions]
    y_answers = [encode(a)[1:] for a in batch_answers]

    encoded_questions = truncate_sequences(encoded_questions, seq_len)
    encoded_answers = truncate_sequences(encoded_answers, seq_len)
    y_questions = truncate_sequences(y_questions, seq_len)
    y_answers = truncate_sequences(y_answers, seq_len)

    x = pad_sequences(encoded_questions, seq_len)
    y = pad_sequences(y_questions, seq_len)
    x_ans = pad_sequences(encoded_answers, seq_len)
    y_ans = pad_sequences(y_answers, seq_len)

    x, y, x_ans, y_ans = map(lambda seqs: torch.tensor(seqs), [x, y, x_ans, y_ans])

    return x, y, x_ans, y_ans


x, y, x_ans, y_ans = get_batches(questions, answers)

In [101]:
%reload_ext autoreload
%autoreload 2
from EncoderDecoder.layers.Decoder.Decoder import Decoder
from EncoderDecoder.layers.Encoder.Encoder import Encoder
import torch.nn as nn
import torch.nn.functional as F


def generate_mask(src, tgt):
    src_mask = (src != 0).unsqueeze(1)
    tgt_mask = (tgt != 0).unsqueeze(1)
    seq_length = tgt.size(1)
    nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
    tgt_mask = tgt_mask & nopeak_mask
    return src_mask, tgt_mask


class EncoderDecoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = Encoder(vocab_size, n_embed, num_heads, seq_len)
        self.decoder = Decoder(vocab_size, n_embed, num_heads, seq_len)

    def forward(self, prompt, response, targets=None):
        src_mask, tgt_mask = generate_mask(prompt, response)
        k, v = self.encoder(prompt, src_mask)
        x, loss = self.decoder(response, k, v, targets, tgt_mask)
        return x, loss

    @staticmethod
    def generate_mask(src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def generate(self, prompt, idx, n):
        for _ in range(n):
            idx_crop = idx[:, -seq_len:]
            logits, _ = self(idx_crop, prompt)
            logits = logits[:, -1, :]
            p = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(p, num_samples=1)
            if next_token == 1 or next_token == [1]:
                return idx
            idx = torch.cat((idx, next_token), dim=1)
        return idx

In [102]:
@torch.no_grad()
def estimate_loss(model, eval_iters=200):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y, x_ans, y_ans = get_batches(questions, answers) if split == 'train' else get_batches(questions_val,
                                                                                                      answers_val)
            logits, loss = model(x, y, x_ans)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [None]:
import os
def notify_end_of_cell(message="Cell execution completed!"):
    os.system(f'notify-send "Jupyter Cell Notification" "{message}"')


m = EncoderDecoder()
optimizer = torch.optim.Adam(m.parameters(), lr=0.0003)
history = []
eval_interval = 10
max_iter = 50
for iter in range(max_iter):
    if iter % eval_interval == 0:
        losses = estimate_loss(m)
        history.append((iter, losses))
        print(f'Iter {iter}, train loss: {losses["train"]:.3f}, val loss: {losses["val"]:.3f}')

    x, y, x_ans, y_ans = get_batches(questions, answers)
    logits, loss = m(x, y, x_ans)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
notify_end_of_cell()

Iter 0, train loss: 4.797, val loss: 4.793
Iter 10, train loss: 3.881, val loss: 3.905


In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(25, 5))

x = [h[0] for h in history]
y = [h[1]['train'] for h in history]
y_val = [h[1]['val'] for h in history]
plt.plot(x, y, label='train')
plt.plot(x, y_val, label='val')

In [None]:
prompt = 'where is Perpignan located?'
prompt = torch.tensor(truncate_sequences([encode(prompt)], seq_len))
idx = torch.ones((1, 1), dtype=torch.long)
print(decode(m.generate(prompt, idx, 100)[0].tolist()))

In [None]:
torch.save(m.state_dict(), './models/parallel_checkpoints_64_2.pth')