In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt

In [5]:
from Architectures.PretrainedBert import Transformer
from Dataset.BertTokensSquad_dataset import Dataset, DataLoader

Downloading (…)okenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<?, ?B/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading (…)lve/main/config.json: 100%|██████████| 570/570 [00:00<?, ?B/s] 
Downloading (…)solve/main/vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 2.86MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 7.06MB/s]


In [6]:
# Instantiate the project's Dataset (imported from Dataset.BertTokensSquad_dataset).
# Later cells index it (ds[3]), take len(ds), and read ds.questions / ds.answer.
ds = Dataset()

In [7]:
# Select the compute device once: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [8]:
class GPUDL():
    """Wrap a data loader so that every batch it yields is moved to ``device``.

    The wrapped loader must yield exactly three tensors per batch; they are
    re-yielded, in the same order, as a 3-tuple living on ``self.device``.
    """

    def __init__(self, dl, device=device):
        # `device` defaults to the notebook-level device chosen earlier.
        self.dl, self.device = dl, device

    def __len__(self):
        # Same number of batches as the wrapped loader.
        return len(self.dl)

    def __iter__(self):
        for inputs, first_targets, second_targets in self.dl:
            yield tuple(
                tensor.to(self.device)
                for tensor in (inputs, first_targets, second_targets)
            )

In [9]:
def fit(model, lr, batch_size, epochs, ds, val_ds):
    """Train ``model`` with AdamW and evaluate it on ``val_ds`` after every epoch.

    Every batch is unpacked as ``(xb, in_targs, out_targs)``. The model is called
    as ``model(xb, in_targs)``; its output is permuted with ``permute(0, 2, 1)``
    and scored against ``out_targs`` with a summed cross-entropy, which is then
    divided by the number of examples in that batch.

    Args:
        model: Module to optimise. It is moved to the notebook-level ``device``.
        lr: Learning rate for AdamW (weight decay is fixed at 0.01).
        batch_size: Number of examples per training batch.
        epochs: Number of passes over ``ds``.
        ds: Training dataset (shuffled every epoch).
        val_ds: Validation dataset; evaluated as one batch of ``len(val_ds)``.

    Returns:
        ``(losses, val_losses)``: the per-batch training losses and the
        per-batch validation losses, each accumulated over all epochs.
    """
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    dl = GPUDL(DataLoader(ds, batch_size=batch_size, shuffle=True))
    val_dl = GPUDL(DataLoader(val_ds, batch_size=len(val_ds), shuffle=False))
    model.to(device)

    # Remember the caller's train/eval mode so it can be restored after each
    # validation pass (validation itself must run in eval mode).
    was_training = model.training

    losses = []
    val_losses = []
    for epoch in range(epochs):
        epoch_loss = []
        epoch_val_loss = []
        epoch_start = time.time()

        for xb, in_targs, out_targs in dl:
            preds = model(xb, in_targs)
            # Normalise by the real batch size: the last batch of an epoch can be
            # shorter than `batch_size`, and the validation loop does the same.
            loss = loss_fn(preds.permute(0, 2, 1), out_targs) / preds.shape[0]

            optim.zero_grad()
            loss.backward()
            optim.step()

            loss_value = loss.item()  # single host sync per step
            losses.append(loss_value)
            epoch_loss.append(loss_value)
            # Drop references before releasing cached GPU memory.
            del xb, in_targs, out_targs, loss, preds
            torch.cuda.empty_cache()

        # Validation: eval mode + no autograd, and iterate the *validation* loader.
        model.eval()
        try:
            with torch.no_grad():
                for val_xb, val_in_targs, val_out_targs in val_dl:
                    val_preds = model(val_xb, val_in_targs)
                    val_loss = loss_fn(val_preds.permute(0, 2, 1), val_out_targs) / val_preds.shape[0]
                    val_loss_value = val_loss.item()
                    val_losses.append(val_loss_value)
                    epoch_val_loss.append(val_loss_value)
                    del val_xb, val_in_targs, val_out_targs, val_loss, val_preds
                    torch.cuda.empty_cache()
        finally:
            model.train(was_training)

        # Use the duration of the epoch just finished as the estimate for the next.
        epoch_seconds = time.time() - epoch_start
        print('Epoch', epoch + 1, 'TrainLoss', np.mean(epoch_loss), "next_epoch_drop:",
              time.ctime(time.time() + epoch_seconds))
        print('ValLoss', np.mean(epoch_val_loss))
        if epoch == 0:
            print('Will be done at approx: ', time.ctime(time.time() + ((epochs - 1) * epoch_seconds)))
    return losses, val_losses

In [11]:
# Build the model with freeze_embeddings=True, then call train() on it
# (training mode, assuming the usual nn.Module semantics — confirm in Architectures.PretrainedBert).
transformer = Transformer(freeze_embeddings=True)
transformer.train()
# Ending the cell on a statement keeps the notebook from echoing train()'s return value.
pass

In [None]:
# Hyperparameters handed to fit() further down.
epochs = 300        # number of passes over the training split
batch_size = 128    # examples per training batch
lr = 1e-4           # AdamW learning rate

In [None]:
# Random 98% / remainder split of the dataset; the two sizes always add up to len(ds).
n_total = len(ds)
n_train = int(n_total * 0.98)
train, val = torch.utils.data.random_split(ds, [n_train, n_total - n_train])

In [None]:
# Run training. fit() returns two lists: the recorded training losses and the
# recorded validation losses, bound here to `loss` and `val_loss`.
loss, val_loss = fit(transformer, lr, batch_size, epochs, train, val)

In [None]:
# Plot both recorded loss curves on explicit axes so the figure stands on its own
# (title + axis labels) and does not depend on pyplot's implicit current figure.
fig, ax = plt.subplots()
ax.plot(val_loss, label='val')
ax.plot(loss, label='train')
ax.set_xlabel('recorded batch')
ax.set_ylabel('loss')
ax.set_title('Training vs. validation loss')
ax.legend();  # trailing semicolon hides the Legend repr in the cell output

In [34]:
# Restore saved weights and prepare the model for inference.
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine can still be loaded when only the CPU is available.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
transformer.load_state_dict(torch.load('Models/1.34_loss', map_location=device))
transformer.to(device)
transformer.eval()
# Ending the cell on a statement keeps the notebook from echoing eval()'s return value.
pass

In [54]:
# Ask the loaded model a free-form question. The second argument (30) is presumably a
# cap on the generated answer length — TODO confirm against Transformer.make_inference.
transformer.make_inference('What is a very hefty price for shoes ?', 30)

'$ 17. 5 billion'

In [None]:
# Display the dataset's stored question at index 53123.
ds.questions[53123]

In [None]:
# Display the entry of ds.answer at the same index (53123) used for ds.questions.
# NOTE(review): attribute is singular (`answer`) while the other is plural (`questions`) — verify against Dataset.
ds.answer[53123]

In [None]:
# Another manual probe of the model; 40 is presumably the maximum answer length — TODO confirm.
transformer.make_inference("what is the name of the person who preaches?", 40)

In [None]:
# Pull sample 3 out of the dataset and display its first element.
asdf = ds[3]
# First element of the sample; presumably the encoded question — TODO confirm against Dataset.__getitem__.
question = asdf[0]
question

In [None]:
# Persist the model weights. torch.save opens (and creates or overwrites) the target
# path itself, so pre-creating an empty file with open(..., 'wb') is unnecessary.
torch.save(transformer.state_dict(), '../working/1.34_loss')