In [1]:
import numpy as np
import torch
import torch.nn as nn
import json
import random
from preprocessors import BOS, EOS, PAD, UNK
import math
from tqdm import tqdm_notebook as tqdm
from ELMo import LanguageModel
from dataset import *
    
# class Dataset:
#     def __init__(self, filename, vocab):
#         self.data = json.load(open(filename, "r"))['tokens']
#         self.vocab = vocab
#         self.pad_idx = vocab[PAD]
#         idx = np.argsort([len(x) for x in self.data])[::-1] # descending
        
#         self.data = [ self.data[i] for i in idx]
#         self.size = len(self.data)
    
#     def tokens_to_ids(self, s):
#         return [self.vocab.get(t, self.vocab[UNK]) for t in s]
    
#     def np_jagged(self, array):
#         MAX = max([len(i) for i in array])
#         out = [ a + [self.pad_idx]*(MAX-len(a)) if len(a) < MAX else a[:MAX] for a in array ]
#         return np.asarray(out, dtype=np.int64)
    
#     def at(self, i, batch_size):
#         fr = i*batch_size
#         to = min(fr+batch_size, self.size)
#         s = self.data[fr:to] # batch, jagged, 1
#         word_ids = self.np_jagged([self.tokens_to_ids(i)[1:] for i in s])
#         char_ids = batch_to_ids(s)[:,:-1,:]
#         return char_ids, torch.from_numpy(word_ids)
    
# class Loader(object):
#     def __init__(self, dataset, batch_size, shuffle):
#         self.dataset = dataset
#         self.batch_size = batch_size
#         self.shuffle = shuffle
        
#         # preprocess
#         total = dataset.size // batch_size
#         if total * batch_size < dataset.size:
#             total += 1
        
#         self.total = total
                    
#     def __iter__(self):
#         if self.shuffle:            
#             r = list(range(self.total))
#             random.shuffle(r)
#             self.iters = iter(r)
#         else:
#             self.iters = iter(range(self.total))
#         return self
    
#     def __next__(self):
#         return self.next()
    
#     def next(self):
#         index = next(self.iters)
#         return self.dataset.at(index, self.batch_size)

In [2]:
vocab = json.load(open("data-giga/vocab.json", "r"))
vocab_size = len(vocab)
training_set = PretrainDataset("data-giga/train_seq.json", 50, 50, vocab[PAD]) #train_seq
validation_set = PretrainDataset("data-giga/valid_seq.json", 50, 50, vocab[PAD])

loading json
load json done.


HBox(children=(IntProgress(value=0, max=3796361), HTML(value='')))


loading json
load json done.


HBox(children=(IntProgress(value=0, max=7596), HTML(value='')))




In [3]:
batch_size = 64
batch_size_inf = 64
training_generator = Loader(training_set, batch_size=batch_size, shuffle=True)
validation_generator = Loader(validation_set, batch_size=batch_size_inf, shuffle=False)
total_train = int(math.ceil(training_set.size / batch_size))
total_valid = int(math.ceil(validation_set.size / batch_size_inf))

In [4]:
device = torch.device("cuda")
model = LanguageModel(vocab).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD]).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)

In [5]:
def validation():
    model.eval()
    total_loss = []
    with torch.no_grad():   
        trange = tqdm(validation_generator, total=total_valid)
        for src, tgt in trange:
            src = src.to(device)
            tgt = tgt.to(device)
            
            logits = model(src)
            loss = criterion(logits.view(-1, vocab_size), tgt.view(-1))            

            total_loss.append(loss.item())
            
    return np.mean(total_loss)

In [10]:
start = 5
epochs = 4

In [11]:
if start != 1:
    smodel = torch.load("trainedELMo/Model"+str(start-1))
    model.load_state_dict(smodel['model'])

In [12]:
for e in range(start, epochs+1):
    model.train()
    print("[epoch]", e)
    loss_history = []
    trange = tqdm(training_generator, total=total_train)
    
    for src, tgt in trange:
        src = src.to(device)
        tgt = tgt.to(device)
        
        logits = model(src)
        loss = criterion(logits.view(-1, vocab_size), tgt.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        loss_history.append(loss.item())
        trange.set_postfix(**{'loss':'{:.5f}'.format(loss.item())})
        
    print("Epoch train loss:", np.mean(loss_history))
    print("Epoch valid loss:", validation())
        
    !mkdir -p trained
    torch.save({"model":model.state_dict(), "loss":loss_history}, "trained/Model"+str(e))

In [13]:
print("Epoch valid loss:", validation())

HBox(children=(IntProgress(value=0, max=119), HTML(value='')))

Epoch valid loss: 2.917541948687129


In [14]:
import torch
import numpy as np
import matplotlib.pyplot as plt

losses = []
for i in range(5):
    s = torch.load("trainedELMo/Model"+str(i+1))
    mean = np.mean(s['loss'])
    print(mean)
    losses.append(mean)
    


4.025017204176389
3.5235224604258963
3.4998621888060946
3.4962652858272785
3.4944170089836244


In [None]:
plt.plot(losses)
plt.show()