In [1]:
from transformer_nb2 import *
from dataset import make_data_generator
import json
import os
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
# Locations of the IMDB training data and its vocabulary file.
folder = 'data/IMDB/'
data_name = f'{folder}data.json'
vocab_name = f'{folder}vocab.json'
# Validation / test splits are currently disabled:
# validation_name = folder+'valid_seq.json'
# testdata_name = folder+'testdata_seq.json'

In [3]:
# Run configuration.
num_epochs = 10
save_rate = 1  # write a checkpoint every `save_rate` epochs
# Checkpoint to resume from (e.g. "trained/Model1"); None trains from scratch.
# continue_from = "trained/Model1"
continue_from = None
epsilon = 1e-8  # numerical-stability epsilon
validation_size = 10000  # cutoff for the validation loader (currently commented out)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint directory; pure-Python, cross-platform replacement for `!mkdir -p trained`.
os.makedirs('trained', exist_ok=True)

In [4]:
# Load the vocabulary (token -> id mapping; ids are looked up as vocab[token]).
# The context manager closes the file handle, which the original leaked, and
# the explicit encoding avoids depending on the platform locale.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
INPUT_MAX = 300  # max source length passed to the data generator
SUMM_MAX = 20    # max target length passed to the data generator
# BERT-style special tokens.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 64

# Shuffled training loader; vocab[PAD] is passed as (presumably) the padding id.
training_set, training_generator = make_data_generator(
    data_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size,
    cutoff=None, shuffle=True, num_workers=4,
)

# Validation loader (disabled; note SUMM_MAX, not the undefined OUTPUT_MAX):
# validation_set, validation_generator = make_data_generator(
#     validation_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size,
#     cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train(loader=None):
    """Yield ``Batch`` objects whose tensors live on ``device``.

    Args:
        loader: iterable of ``(src, label, tgt)`` tensor triples. Defaults to
            ``training_generator``; pass another loader (e.g. a validation
            generator) to reuse the same batching logic.

    Yields:
        ``Batch(src, tgt, vocab[PAD])`` with the label tensor, cast to int64
        via ``.long()``, attached as ``batch.label``.
    """
    if loader is None:
        loader = training_generator
    for src, label, tgt in loader:
        batch = Batch(src.to(device), tgt.to(device), vocab[PAD])
        batch.label = label.long().to(device)
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))




In [6]:
import math

# Batches per epoch (the last batch may be partial); used as the tqdm total.
total_train = math.ceil(training_set.size / batch_size)
# Validation counterpart, disabled together with the validation loader:
# total_valid = math.ceil(validation_set.size / batch_size)
# print(total_train, total_valid)

In [7]:
def init_param(model):
    """Re-initialise a module's multi-dimensional parameters in place (Xavier-uniform).

    Only parameters with more than one dimension (e.g. weight matrices) are
    touched; 1-D parameters such as biases keep their current values.

    NOTE(review): nothing in this notebook calls this helper; presumably
    ``make_model`` already initialises the weights -- confirm.
    """
    matrices = (param for param in model.parameters() if param.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

# Encoder-decoder model over the shared vocabulary (source and target both use
# VOC_SIZE). An earlier PointerGenerator variant was tried here and dropped.
# NOTE(review): init_param is not applied; presumably make_model initialises the
# weights itself -- confirm.
translator = make_model(VOC_SIZE, VOC_SIZE, N=4, d_model=256, d_ff=512, h=8,
                        dropout=0.1, emb_share=True).to(device)

# Label-smoothed criterion; padding_idx presumably excludes PAD targets.
criterion = LabelSmoothing(size=VOC_SIZE, padding_idx=vocab[PAD], smoothing=0.1).to(device)
# Use the configured `epsilon` rather than repeating the literal 1e-8.
model_opt = torch.optim.Adam(translator.parameters(), lr=1e-4, betas=(0.9, 0.998), eps=epsilon)
# Combines generator, criterion and optimizer; called once per batch in training.
loss_compute = SimpleLossCompute(translator.generator, criterion, model_opt)



In [8]:
# import torch.nn.functional as F
# def loss_compute(x, y, norm): 
#     x = F.log_softmax(x)
    
#     print(x.shape)
    
#     loss = criterion(x.contiguous().view(-1, x.size(-1)), 
#                           y.contiguous().view(-1)) / norm
#     loss.backward()
#     if model_opt is not None:
#         model_opt.step()
#         model_opt.zero_grad()
#     return loss.item() * norm

In [9]:
# model = make_big_bird(vocab, N=4, d_model=256, d_ff=512, h=8, dropout=0.1, emb_share=True, bert_share=True)
#model.load("Nest/NewbornBird")

In [10]:
# Reverse lookup table (id -> token) used to decode model predictions.
vocab_inv = dict(zip(vocab.values(), vocab.keys()))

def convert_ids_to_tokens(ids):
    """Return the list of tokens for an iterable of ids (KeyError on an unknown id)."""
    return list(map(vocab_inv.__getitem__, ids))

In [11]:
# Training loop. Checkpoints are written to "trained/Model<epoch>" -- the same
# naming `continue_from` is parsed with -- so a run can be resumed by setting
# continue_from = "trained/Model<epoch>". (The original saved to
# "pretrained/Translator<epoch>", a directory that was never created, and a bare
# `except` hid the failure, so no checkpoint was ever written.)
CHECKPOINT_DIR = "trained"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

start = 1
if continue_from is not None:
    # Resume after the checkpoint's epoch AND restore its weights; the original
    # only advanced the epoch counter and kept the fresh initialisation.
    start = int(continue_from.split("Model")[-1]) + 1
    checkpoint = torch.load(continue_from, map_location=device)
    translator.load_state_dict(checkpoint['model'])
    if 'optimizer' in checkpoint:  # checkpoints may contain only the model
        model_opt.load_state_dict(checkpoint['optimizer'])
history = []


for epoch in range(start, num_epochs + 1):
    print("Epoch", epoch)

    # training
    stats = Stats()
    translator.train()
    progress = tqdm(enumerate(data_gen_train()), total=total_train)

    for i, batch in progress:
        out = translator(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        stats.update(loss, batch.ntokens, log=1)

        if i % 150 == 0:
            # Greedy preview of the first sequence in the batch; no_grad so this
            # debug pass does not build an autograd graph.
            with torch.no_grad():
                scores = translator.generator(out.detach())
                next_words = torch.argmax(scores, dim=-1, keepdim=True)
            print("\n")
            print(convert_ids_to_tokens(next_words[0].reshape(-1).tolist()))

        progress.set_postfix(loss='{:.3f}'.format(loss))
        # NOTE(review): this feeds the same batch to Stats a second time (with
        # ntokens=1); the logged step counter reaches 781 for 391 batches. Kept
        # because Stats is defined outside this notebook -- confirm whether
        # stats.history depends on it before removing.
        stats.update(loss, 1, log=0)

    t_h = stats.history
    history.append(t_h)

    # NOTE(review): the scale of this value (~3800 in epoch 1 vs. a logged
    # per-step Loss of ~5.8) suggests stats.history holds un-normalised batch
    # totals -- confirm before comparing it with the per-step Loss.
    print("[info] epoch train loss:", np.mean(t_h))

    if epoch % save_rate == 0:
        ckpt_path = os.path.join(CHECKPOINT_DIR, "Model" + str(epoch))
        try:
            torch.save({'model': translator.state_dict(),
                        'optimizer': model_opt.state_dict(),
                        'training_history': t_h},
                       ckpt_path)
        except OSError as err:
            # Still best-effort, but report the failure instead of hiding it.
            print("[warn] could not save checkpoint", ckpt_path, "-", err)

Epoch 1


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 8.938125 Tokens / Sec: 1952.920921

['examines', 'suppressed', 'vu', 'vu', 'vu', 'directive', '##tower', 'vu', 'examines', 'vu', 'maple', 'vu', 'directive', 'examines', 'directive', 'examines', 'examines', '##face', 'manipulating']
Step: 301 Loss: 6.060853 Tokens / Sec: 4604.086642

['the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the']
Step: 601 Loss: 5.844194 Tokens / Sec: 4644.664581

['i', 'i', 'the', 'the', 'the', 'the', 'the', 'the', 'the', 'the', ',', '.', ',', '.', '.', '.', '.', 'the', 'the']
Step: 781 Loss: 5.839874 Tokens / Sec: 4381.967525
[info] epoch train loss: 3820.1237792560205
Epoch 2


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 5.628731 Tokens / Sec: 3313.561526

['i', 'movie', 'the', 'movie', ',', 'the', 'movie', ',', ',', ',', 'the', ',', 'the', ',', ',', ',', ',', 's', ',']
Step: 301 Loss: 5.134790 Tokens / Sec: 4594.625895

['i', 'movie', 'of', 'a', 'movie', 'movie', 'movie', 'the', 'movie', 'movie', ',', 'the', '.', 'movie', ',', '.', "'", 's', "'"]
Step: 601 Loss: 4.951511 Tokens / Sec: 4595.292388

['i', "'", 'to', 'movie', '.', 'the', '.', 'i', 'have', 'to', 'i', "'", ',', '.', 'the', '.', 'i', "'", "'"]
Step: 781 Loss: 4.974915 Tokens / Sec: 4039.710652
[info] epoch train loss: 3143.7551068290113
Epoch 3


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 4.746830 Tokens / Sec: 2922.777856

['i', 'is', 'a', 'movie', '.', 'i', "'", 'have', 'to', 'this', 'this', 'movie', '.', 'a', '.', '.', 'i', 'br', '/']
Step: 301 Loss: 4.487877 Tokens / Sec: 4590.064990

['<', 'to', 'the', 'the', 'movie', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', '<', '/', '<', '/', '/']
Step: 601 Loss: 4.224488 Tokens / Sec: 4570.760233

['"', 'an', "'", 'be', 'the', 'the', 'best', 'of', 'the', 'movie', '.', '.', "'", 's', 'the', '.', '.', '.', 'the']
Step: 781 Loss: 3.971101 Tokens / Sec: 4264.003628
[info] epoch train loss: 2672.324607294234
Epoch 4


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 3.925173 Tokens / Sec: 2956.594687

['"', '"', ',', ',', '"', ',', 'the', 'story', ',', 'is', 'a', 'great', 'be', 'be', '-', '-', ',', 'i', "'"]
Step: 301 Loss: 3.717526 Tokens / Sec: 4614.709333

['this', 'film', 'is', 'with', 'with', 'great', 'and', 'and', 'all', '##s', 'to', 'the', 'years', 'it', 'i', 'was', 'it', ',', ',']
Step: 601 Loss: 3.216745 Tokens / Sec: 4552.454100

['"', 'the', 'best', '##o', 'film', '"', 'is', 'not', '.', 'not', 'not', 'in', 'the', 'best', 'films', '"', '"', '"', '"']
Step: 781 Loss: 2.825870 Tokens / Sec: 3887.508006
[info] epoch train loss: 2099.927055443339
Epoch 5


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 3.240874 Tokens / Sec: 3211.919021

['this', 'is', 'what', 'not', 'like', 'a', '"', 'two', '-', '##s', '##s', 'man', '-', '"', 'movie', '.', 'it', 'has', 'no']
Step: 301 Loss: 2.871978 Tokens / Sec: 4722.497478

['so', 'it', 'has', 'been', 'to', 'this', '.', 'many', ',', 'made', 'made', 'that', 'have', 'the', 'the', 'old', '?', 'can', 'ever']
Step: 601 Loss: 2.452463 Tokens / Sec: 4599.092192

['if', 'i', 'don', "'", 't', 'been', 'going', 'to', 'watch', 'this', 'for', 'time', 'movies', 'i', 'had', 'have', 'have', 'made', 'it']
Step: 781 Loss: 2.088028 Tokens / Sec: 4556.022188
[info] epoch train loss: 1629.7161681688656
Epoch 6


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 2.510001 Tokens / Sec: 3307.201573

['this', 'is', 'a', 'a', 's', "'", '.', 'it', 'it', 'di', '##l', '.', 'the', 'acting', 'is', 'by', 'and', 'its', "'"]
Step: 301 Loss: 1.975900 Tokens / Sec: 4678.853966

['un', '##re', '##re', '##es', '##l', '##l', 'of', 'a', 'story', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'it']
Step: 601 Loss: 2.047327 Tokens / Sec: 4648.101684

['actually', ',', 'i', 'never', 'found', 'into', 'the', 'original', 'was', 'un', 'and', 'and', 'characters', ',', 'but', 'this', 'movie', 'kind', 'of']
Step: 781 Loss: 1.896146 Tokens / Sec: 4046.818755
[info] epoch train loss: 1279.4530738756785
Epoch 7


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 1.992176 Tokens / Sec: 3065.788694

['i', 'actually', 'wanted', 'to', 'see', 'this', 'movie', 'in', 'the', 'music', '.', 'it', 'was', 'actually', 'up', 'out', '.', 'i', 'actually']
Step: 301 Loss: 1.610950 Tokens / Sec: 4697.976537

['about', 'years', 'years', 'ago', ',', 'i', 'liked', 'this', 'movie', '.', 'i', 'would', 'watch', 'it', 'over', 'over', 'since', 'and', 'over']
Step: 601 Loss: 1.680174 Tokens / Sec: 4636.460683

['for', 'those', 'who', 'never', 'saw', 'a', 'bit', 'line', 'of', 'and', 'their', 'only', 'wanted', 'to', 'the', 'story', 'was', 'this', 'film']
Step: 781 Loss: 1.437979 Tokens / Sec: 4460.027759
[info] epoch train loss: 1022.1669812295443
Epoch 8


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 1.364286 Tokens / Sec: 3296.411411

['ok', ',', 'i', 'have', 'to', 'admit', 'that', 'i', 'have', 'never', 'seen', '"', 'r', 'don', '##en', '"', 'and', 'only', 'one']
Step: 301 Loss: 1.324998 Tokens / Sec: 4695.546827

['this', 'movie', 'is', 'finally', 'out', 'on', 'dvd', 'in', 'course', '(', 'completely', ')', ')', '.', 'i', 'have', 'seen', 'this', 'movie']
Step: 601 Loss: 1.200646 Tokens / Sec: 4736.219422

['this', 'is', 'an', 'ok', 'film', 'but', 'takes', 'any', 'real', 'heart', 'either', 'mr', 'or', 'in', 'each', 'of', 'story', 'showing', '.']
Step: 781 Loss: 1.130839 Tokens / Sec: 4541.314182
[info] epoch train loss: 828.6442471520065
Epoch 9


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 1.318777 Tokens / Sec: 3318.846368

['although', '##nic', '##ic', 'film', 'with', 'a', 'spoil', '##re', '##y', 'slow', 'start', '-', 'give', 'it', 'a', 'chance', 'to', 'start', '5']
Step: 301 Loss: 1.199452 Tokens / Sec: 4632.283342

['right', 'at', 'this', 'moment', 'i', 'am', 'watching', 'this', 'movie', 'for', 'the', 'second', 'time', '(', 'on', 'television', ')', 'and', 'for']
Step: 601 Loss: 0.987085 Tokens / Sec: 4730.372709

['during', 'the', 'opening', 'night', 'of', 'the', 'van', '##gal', 'a', 'woman', 'is', 'found', 'dead', 'on', 'the', 'cat', 'against', 'above', 'the']
Step: 781 Loss: 1.139729 Tokens / Sec: 4474.125210
[info] epoch train loss: 680.3291429493129
Epoch 10


HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

Step: 1 Loss: 0.993821 Tokens / Sec: 3246.894394

['i', "'", 've', 'had', 'a', 'thing', 'for', 'this', 'several', 'gem', 'for', 'a', 'while', ',', 'and', 'as', 'far', 'as', 'how']
Step: 301 Loss: 0.967576 Tokens / Sec: 4672.694116

['in', 'atlantis', ',', 'carr', '##g', 'once', 'again', 'tries', 'to', 'take', 'a', 'whole', '##ta', '##ble', 'between', 'reality', 'and', 'fiction', ',']
Step: 601 Loss: 0.850285 Tokens / Sec: 4386.773696

['before', 'i', 'begin', ',', 'i', 'want', 'to', 'give', 'say', 'that', 'this', 'movie', 'in', 'and', 'of', 'itself', 'is', 'very', 'well']
Step: 781 Loss: 1.000378 Tokens / Sec: 4120.631447
[info] epoch train loss: 564.4049405319154
