In [1]:
# NOTE: the star import stays first so the explicit imports below win any name clash.
from transformer_nb2 import *
from dataset import make_data_generator
import json
import os
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
# Input files for this run; every path is derived from the single dataset folder.
folder = 'data/IMDB/'
data_name = '{}data.json'.format(folder)
# validation_name = folder+'valid_seq.json'
# testdata_name = folder+'testdata_seq.json'
vocab_name = '{}vocab.json'.format(folder)

In [3]:
# --- Training configuration -------------------------------------------------
num_epochs = 20
save_rate = 1 #how many epochs per modelsave
#continue_from = "trained/Model1" # if none, put None
continue_from = None
epsilon = 1e-8
validation_size = 10000

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pure-Python replacement for the `!mkdir -p trained` shell magic (portable,
# and no dependency on IPython). "pretrained" is created as well because the
# training loop writes its checkpoints to "pretrained/Classifier<epoch>".
os.makedirs('trained', exist_ok=True)
os.makedirs('pretrained', exist_ok=True)

In [4]:
# Load the token -> index vocabulary. The context manager closes the file
# handle deterministically (the original `json.load(open(...))` left it to the
# garbage collector), and the explicit encoding avoids depending on the locale.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)

# Maximum sequence lengths handed to the data generator.
INPUT_MAX = 300
SUMM_MAX = 20

# Special tokens; each one is used as a key into `vocab`.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
# Mini-batch size; also used later to compute the number of steps per epoch.
batch_size = 512

# Build the training set together with the generator that yields its batches.
training_set, training_generator = make_data_generator(
    data_name,
    INPUT_MAX,
    SUMM_MAX,
    vocab[PAD],
    batch_size,
    cutoff=None,
    shuffle=True,
    num_workers=4,
)

# Validation loader, currently disabled.
# NOTE(review): OUTPUT_MAX is not defined anywhere in the visible notebook
# (SUMM_MAX is presumably what was meant) -- fix before re-enabling.
# validation_set, validation_generator = make_data_generator(\
# validation_name, INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield one `Batch` per item of `training_generator`, moved to `device`.

    Each item of the generator is unpacked as (source, label, target). All
    three tensors are transferred to `device`; the label is cast to long
    first. The label is attached to the batch as the extra attribute `label`.
    """
    for source, sentiment, target in training_generator:
        source = source.to(device)
        sentiment = sentiment.long().to(device)
        target = target.to(device)

        batch = Batch(source, target, vocab[PAD])
        batch.label = sentiment
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))




In [6]:
import math
# Number of batches per epoch (rounded up); used as the progress-bar total.
batches_per_epoch = training_set.size / batch_size
total_train = int(math.ceil(batches_per_epoch))
# total_valid = int(math.ceil(validation_set.size / batch_size))
# print(total_train, total_valid)

In [7]:
def init_param(model):
    """Re-initialise every parameter of `model` that has more than one dimension.

    Such parameters are overwritten in place with Xavier-uniform values;
    parameters with a single dimension or fewer are left untouched.
    """
    weight_tensors = (param for param in model.parameters() if param.dim() > 1)
    for weight in weight_tensors:
        nn.init.xavier_uniform_(weight)

def make_classifier(vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Assemble a `Classifier` wrapping a `BERT` encoder and initialise its weights.

    Args:
        vocab: token -> index mapping; its length sizes the embedding and
            `vocab[PAD]` is handed to `BERT`.
        N: passed to `Encoder` together with the layer prototype.
        d_model: passed to the attention, feed-forward, embedding and
            positional-encoding modules.
        d_ff: passed to `PositionwiseFeedForward`.
        h: passed to `MultiHeadedAttention`.
        dropout: passed to the feed-forward, encoder-layer and
            positional-encoding modules.

    Returns:
        The constructed `Classifier` after `init_param` has been applied.
    """
    vocab_sz = len(vocab)

    # Sub-modules are created in the same order as a single nested expression
    # would evaluate them, so random-number consumption is unchanged.
    self_attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    layer_prototype = EncoderLayer(d_model, self_attention, feed_forward, dropout)
    encoder = Encoder(layer_prototype, N)

    token_embedding = Embeddings(d_model, vocab_sz)
    position_encoding = PositionalEncoding(d_model, dropout)
    embedder = nn.Sequential(token_embedding, position_encoding)

    backbone = BERT(encoder, embedder, vocab[PAD])
    # The second argument (2) is passed straight through to `Classifier`.
    # criterion = BCE
    classifier = Classifier(backbone, 2)

    init_param(classifier)

    return classifier

# Architecture hyper-parameters for this run, forwarded to make_classifier.
model_hparams = dict(N=4, d_model=256, d_ff=512, h=8, dropout=0.1)
classifier = make_classifier(vocab, **model_hparams)
classifier = classifier.to(device)

# Adam over every parameter of the classifier.
model_opt = torch.optim.Adam(
    classifier.parameters(),
    lr=1e-4,
    betas=(0.9, 0.998),
    eps=1e-8,
)


In [8]:
# Resume the epoch counter from the numeric suffix after "Model" in
# `continue_from` (e.g. "trained/Model3" -> start at epoch 4).
# NOTE(review): only the counter is resumed here; this cell does not load any
# checkpoint weights, and the checkpoints written below are named
# "Classifier<epoch>", not "Model<epoch>" -- confirm the intended resume flow.
start = 1 if continue_from is None else (int(continue_from.split("Model")[-1])+1)
history = []

# Make sure the checkpoint directory exists. Previously a missing directory
# made torch.save fail and the bare `except` hid the failure, so no checkpoint
# was ever written.
checkpoint_dir = "pretrained"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(start, num_epochs+1):
    print("Epoch", epoch)

    # training
    stats = Stats()
    classifier.train()
    trange = tqdm(enumerate(data_gen_train()), total=total_train)

    for i, batch in trange:
        # Call the module itself rather than .forward() so that any
        # registered nn.Module hooks are honoured.
        loss, acc, pred = classifier(batch.trg_y, batch.label, vocab[EOS])

        loss = loss.mean()

        model_opt.zero_grad()
        loss.backward()
        model_opt.step()

        # Read the scalar once instead of three separate .item() calls.
        step_loss = loss.item()

        # NOTE(review): Stats.update is called twice per batch (once with the
        # token count and log=1, once with 1 and log=0). The recorded output
        # shows "Step: 97" for 49 batches, which suggests each call advances
        # the step counter -- confirm that the double update is intended.
        stats.update(step_loss, batch.ntokens, log=1)

        trange.set_postfix(
            loss='{:.3f}'.format(step_loss),
            accu='{:.3f}'.format(acc)
        )
        stats.update(step_loss, 1, log=0)

    t_h = stats.history
    history.append(t_h)

    print("[info] epoch train loss:", np.mean(t_h))

    # Honour `save_rate` (epochs between checkpoints); a falsy value disables saving.
    if save_rate and epoch % save_rate == 0:
        checkpoint_path = checkpoint_dir + "/Classifier" + str(epoch)
        try:
            torch.save({'model':classifier.state_dict(), 'training_history':t_h},
                       checkpoint_path)
        except Exception as err:
            # Saving stays best-effort (training continues), but the failure is
            # reported, and KeyboardInterrupt is no longer swallowed.
            print("[warn] could not save checkpoint", checkpoint_path, "-", repr(err))

Epoch 1


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000088 Tokens / Sec: 122784.949450
[info] epoch train loss: 0.40046300363933957
Epoch 2


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000088 Tokens / Sec: 122562.708201
[info] epoch train loss: 0.34901340235480655
Epoch 3


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000076 Tokens / Sec: 123501.208078
[info] epoch train loss: 0.3292401435023354
Epoch 4


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000068 Tokens / Sec: 126338.806303
[info] epoch train loss: 0.28285868070063735
Epoch 5


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000063 Tokens / Sec: 122645.008345
[info] epoch train loss: 0.2406986687794884
Epoch 6


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000046 Tokens / Sec: 122834.048967
[info] epoch train loss: 0.2022142922649659
Epoch 7


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000045 Tokens / Sec: 123506.173694
[info] epoch train loss: 0.17285134122779433
Epoch 8


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000039 Tokens / Sec: 125221.200295
[info] epoch train loss: 0.143329030459577
Epoch 9


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000028 Tokens / Sec: 122620.084206
[info] epoch train loss: 0.11600209580445354
Epoch 10


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000026 Tokens / Sec: 125239.301344
[info] epoch train loss: 0.09504024710004841
Epoch 11


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000021 Tokens / Sec: 120798.210420
[info] epoch train loss: 0.07943662759309862
Epoch 12


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000017 Tokens / Sec: 122326.654010
[info] epoch train loss: 0.06859733350032993
Epoch 13


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000022 Tokens / Sec: 122042.999538
[info] epoch train loss: 0.05824756159978175
Epoch 14


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000017 Tokens / Sec: 123288.063634
[info] epoch train loss: 0.05250542892326564
Epoch 15


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000009 Tokens / Sec: 124622.650227
[info] epoch train loss: 0.044035506813450906
Epoch 16


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000012 Tokens / Sec: 122942.653578
[info] epoch train loss: 0.039933913414766084
Epoch 17


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000007 Tokens / Sec: 127165.653752
[info] epoch train loss: 0.03709358504667414
Epoch 18


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000012 Tokens / Sec: 120248.805224
[info] epoch train loss: 0.03151663463147315
Epoch 19


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000008 Tokens / Sec: 122732.322677
[info] epoch train loss: 0.027465347978438994
Epoch 20


HBox(children=(IntProgress(value=0, max=49), HTML(value='')))

Step: 97 Loss: 0.000008 Tokens / Sec: 123198.160281
[info] epoch train loss: 0.025284427823248655
