In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt

In [2]:
from Architectures.Transformer import Transformer
from Dataset.squad_dataset import Dataset, DataLoader

Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? 
Answer: Saint Bernadette Soubirous


In [3]:
# Build the dataset object from the project's Dataset class (imported from Dataset.squad_dataset).
ds = Dataset()

In [4]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
class GPUDL():
    """Thin wrapper around a data loader that moves every batch onto a device.

    Each item produced by the wrapped loader must unpack into exactly five
    objects exposing ``.to(device)``; they are yielded as a 5-tuple, in the
    same order, after the transfer.
    """

    def __init__(self, dl, device=device):
        # `device` defaults to the notebook-level device selected earlier.
        self.dl, self.device = dl, device

    def __len__(self):
        # The number of batches is whatever the wrapped loader reports.
        return len(self.dl)

    def __iter__(self):
        for batch in self.dl:
            first, first_idx, second, second_idx, target = batch
            yield (
                first.to(self.device),
                first_idx.to(self.device),
                second.to(self.device),
                second_idx.to(self.device),
                target.to(self.device),
            )

In [6]:
def fit(model, lr, batch_size, epochs, ds):
    """Train ``model`` on ``ds`` with AdamW and return the per-batch losses.

    Parameters
    ----------
    model
        Module to optimise. It is moved to the notebook-level ``device``,
        switched to training mode and called as
        ``model(xb, x_pad_idx, yb, y_pad_idx)``.
    lr
        Learning rate given to ``torch.optim.AdamW``.
    batch_size
        Batch size given to ``DataLoader``.
    epochs
        Number of passes over ``ds``.
    ds
        Dataset given to ``DataLoader`` (iterated without shuffling).

    Returns
    -------
    list of float
        The loss of every batch, in training order, across all epochs.
    """
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    dl = GPUDL(DataLoader(ds, batch_size=batch_size, shuffle=False))

    model.to(device)
    # Guarantee training mode even when an earlier cell left the model in
    # eval mode (the notebook calls `.eval()` before running inference).
    model.train()

    losses = []
    for epoch in range(epochs):
        epoch_loss = []
        epoch_start = time.time()
        for xb, x_pad_idx, yb, y_pad_idx, targets in dl:
            preds = model(xb, x_pad_idx, yb, y_pad_idx)
            # CrossEntropyLoss expects the class axis in position 1, hence the
            # permute. The summed loss is averaged over the number of samples
            # actually present in this batch (first axis of `targets`), so a
            # short final batch is not under-weighted.
            loss = loss_fn(preds.permute(0, 2, 1), targets) / targets.shape[0]

            optim.zero_grad()
            loss.backward()
            optim.step()

            # Read the scalar once; every .item() call forces a device sync.
            batch_loss = loss.item()
            losses.append(batch_loss)
            epoch_loss.append(batch_loss)

            # Drop references to this batch's tensors before releasing the
            # cached GPU memory.
            del xb, x_pad_idx, yb, y_pad_idx, targets, loss, preds
            torch.cuda.empty_cache()

        now = time.time()
        epoch_seconds = now - epoch_start
        # Estimate when the next epoch will finish from this epoch's duration.
        print('Epoch', epoch + 1, 'loss', np.mean(epoch_loss), "next_epoch_drop:",
              time.ctime(now + epoch_seconds))
        if epoch == 0:
            # Extrapolate the total run time from the first epoch.
            print('Will be done at approx: ', time.ctime(now + (epochs - 1) * epoch_seconds))
    return losses

In [7]:
# Build the model from the dataset vocabulary and restore the weights stored in the checkpoint file.
transformer = Transformer(ds.vocab, ds.vocab_hashtable, device=device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
transformer.load_state_dict(torch.load('Models/50_epochs_0.0001lr', map_location=device))

<All keys matched successfully>

In [None]:
# Hyperparameters handed to fit() in the training cell.
lr = 0.0001  # learning rate for the optimizer
batch_size = 256  # number of samples requested per batch
epochs = 55  # number of passes over the dataset

In [None]:
## NOTE: the time is 8 hours ahead of the real time

In [None]:
# Run training; `loss` receives the list of per-batch losses returned by fit().
loss = fit(transformer, lr, batch_size, epochs, ds)

In [None]:
# Plot the loss curve: fit() appends one value per batch, so the x axis counts batches.
plt.plot(loss)
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.title('Training loss');  # trailing semicolon hides the text repr of the last call

In [None]:
# Display the entry at index 20 of the dataset's `questions` collection.
ds.questions[20]

In [None]:
# Display the entry at index 20 of the dataset's `answer` collection.
# NOTE(review): the attribute is `answer` (singular) while the questions live in
# `ds.questions` (plural) — confirm the name against the Dataset class.
ds.answer[20]

In [14]:
# Put the model on the target device and switch it to evaluation mode.
# The trailing semicolon keeps the cell from echoing the value eval() returns.
transformer.to(device)
transformer.eval();

In [20]:
# Ask the model a free-form question and display what make_inference returns.
# NOTE(review): the second argument (30) is presumably a maximum output length —
# confirm against Transformer.make_inference.
transformer.make_inference("who hits the most home runs ?", 30)

'10 music group'

In [None]:
# Fetch item 3 from the dataset and display its first element.
asdf = ds[3]
question = asdf[0]  # first element of the item, bound to `question` here
question

In [None]:
# Persist the model weights. torch.save creates (or overwrites) the target file
# itself, so no empty placeholder file has to be opened beforehand.
torch.save(transformer.state_dict(), '../working/Squad_1_loss.pt')

In [None]:
# Write a second copy of the model weights as a backup. torch.save creates (or
# overwrites) the target file itself, so no empty placeholder file is needed.
torch.save(transformer.state_dict(), '../working/Squad_1_loss_backup')