In [1]:
from transformer_nb2 import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
# Dataset location: all input files are resolved relative to this folder.
folder = 'data/BBC/'
data_name = '{}data.json'.format(folder)
# validation_name = folder+'valid_seq.json'
# testdata_name = folder+'testdata_seq.json'
vocab_name = '{}vocab.json'.format(folder)

In [3]:
# Training-run configuration.
num_epochs = 500
save_rate = 1  # how many epochs per model save
# Checkpoint to resume from, or None for a fresh run. The training cell parses
# the epoch number from the text that follows "Model" in this path.
#continue_from = "trained/Model1"
continue_from = None
epsilon = 1e-8
validation_size = 10000
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Shell escape: create the `trained` directory; `-p` makes this a no-op if it already exists.
!mkdir -p trained

In [4]:
# Load the vocabulary mapping. A context manager closes the file handle
# deterministically, and the explicit UTF-8 encoding keeps the load independent
# of the platform's default locale.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
# Length limits handed to make_data_generator; SUMM_MAX is also passed to
# model.run_iter in the training cell.
INPUT_MAX = 400
SUMM_MAX = 20
# Special-token strings; vocab[PAD] is looked up wherever a padding value is needed.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 16

# Build the training dataset object and the loader that iterates over it.
# vocab[PAD] is handed over positionally -- presumably as the padding value;
# confirm against dataset.make_data_generator.
training_set, training_generator = make_data_generator(
    data_name,
    INPUT_MAX,
    SUMM_MAX,
    vocab[PAD],
    batch_size,
    cutoff=None,
    shuffle=True,
    num_workers=4,
)

# validation_set, validation_generator = make_data_generator(\
# validation_name, INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield one `Batch` per item of `training_generator`, moved to `device`.

    Each item is unpacked as (src, label, tgt). A `Batch` is built from the
    source and target tensors with `vocab[PAD]`, and the label (cast to long)
    is attached to it as the `label` attribute.
    """
    for source, category, target in training_generator:
        batch = Batch(source.to(device), target.to(device), vocab[PAD])
        batch.label = category.long().to(device)
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=2225), HTML(value='')))




In [6]:
import math

# Number of training batches per epoch, rounding a final partial batch up.
# Used as the progress-bar total in the training cell.
batches_per_epoch = training_set.size / batch_size
total_train = int(math.ceil(batches_per_epoch))
# total_valid = int(math.ceil(validation_set.size / batch_size))
# print(total_train, total_valid)

In [7]:
def make_translator(src_vocab, tgt_vocab, N=6,
                    d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False):
    """Assemble a `Translator` from hyperparameters.

    Args:
        src_vocab: passed to the source `Embeddings` (presumably the source
            vocabulary size -- confirm against transformer_nb2).
        tgt_vocab: passed to the target `Embeddings` and to `Generator`.
        N: passed to `Encoder` and `Decoder` (presumably the layer count).
        d_model, d_ff, h, dropout: forwarded to the attention, feed-forward,
            positional-encoding and layer constructors.
        emb_share: when true, the target side reuses the very same embedding
            module object as the source side instead of building its own.

    Returns:
        The constructed `Translator`.
    """
    clone = copy.deepcopy
    # Prototype sub-modules; every consumer below receives its own deep copy.
    attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_encoding = PositionalEncoding(d_model, dropout)

    src_emb = nn.Sequential(Embeddings(d_model, src_vocab), clone(pos_encoding))
    if emb_share:
        tgt_emb = src_emb
    else:
        tgt_emb = nn.Sequential(Embeddings(d_model, tgt_vocab), clone(pos_encoding))

    encoder = Encoder(
        EncoderLayer(d_model, clone(attention), clone(feed_forward), dropout),
        N,
    )
    decoder = Decoder(
        DecoderLayer(d_model, clone(attention), clone(attention),
                     clone(feed_forward), dropout),
        N,
    )
    generator = Generator(d_model, tgt_vocab)

    return Translator(encoder, decoder, src_emb, tgt_emb, generator)

In [None]:
def make_classifier(src_vocab, N=6,
                    d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Assemble a `Classifier` wrapping a `BERT` encoder.

    Args:
        src_vocab: passed to `Embeddings` (presumably the vocabulary size --
            confirm against transformer_nb2).
        N: passed to `Encoder` (presumably the layer count).
        d_model, d_ff, h, dropout: forwarded to the attention, feed-forward,
            positional-encoding and encoder-layer constructors.

    Note:
        The padding value handed to `BERT` is read from the notebook-level
        `vocab[PAD]`, not from an argument.

    Returns:
        The constructed `Classifier`.
    """
    clone = copy.deepcopy
    attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_encoding = PositionalEncoding(d_model, dropout)

    encoder = Encoder(
        EncoderLayer(d_model, clone(attention), clone(feed_forward), dropout),
        N,
    )
    embedding = nn.Sequential(Embeddings(d_model, src_vocab), clone(pos_encoding))
    bert = BERT(encoder, embedding, vocab[PAD])

    # criterion = CE
    return Classifier(bert)

In [None]:
def make_discriminator(src_vocab, N=6,
                       d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Assemble a `Discriminator` wrapping a `BERT` encoder.

    Args:
        src_vocab: passed to `Embeddings` (presumably the vocabulary size --
            confirm against transformer_nb2).
        N: passed to `Encoder` (presumably the layer count).
        d_model, d_ff, h, dropout: forwarded to the attention, feed-forward,
            positional-encoding and encoder-layer constructors.

    Note:
        The padding value handed to `BERT` is read from the notebook-level
        `vocab[PAD]`, not from an argument.

    Returns:
        The constructed `Discriminator`.
    """
    clone = copy.deepcopy
    attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_encoding = PositionalEncoding(d_model, dropout)

    encoder = Encoder(
        EncoderLayer(d_model, clone(attention), clone(feed_forward), dropout),
        N,
    )
    embedding = nn.Sequential(Embeddings(d_model, src_vocab), clone(pos_encoding))
    bert = BERT(encoder, embedding, vocab[PAD])

    return Discriminator(bert)

In [None]:
def init_param(model):
    """Re-initialise, in place, every parameter of `model` with more than one dimension.

    Such parameters are filled with Xavier/Glorot uniform values; parameters
    with a single dimension (or none) are left untouched.
    """
    multi_dim_params = (param for param in model.parameters() if param.dim() > 1)
    for param in multi_dim_params:
        nn.init.xavier_uniform_(param)

def make_big_bird(vocab, N=6, 
               d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False, bert_share=False):
    """Build the combined `BigBird` model (translator + discriminator + classifier).

    Args:
        vocab: vocabulary mapping; its length is used as the vocabulary size
            and `vocab[PAD]` is passed to the `LSTMEncoder` backbones.
        N: passed to the translator's `Encoder` and `Decoder` (presumably the
            layer count -- confirm against transformer_nb2).
        d_model, d_ff, h, dropout: forwarded to the attention, feed-forward,
            positional-encoding and layer constructors; `d_model` is also
            passed to `LSTMEncoder`.
        emb_share: when true, the translator's target embedding is the same
            module object as its source embedding.
        bert_share: when true, the classifier and the discriminator wrap the
            same `LSTMEncoder` instance.

    Returns:
        The constructed `BigBird`.

    Note:
        The classifier and discriminator backbones are `LSTMEncoder`s, which
        are given the vocabulary size directly. Separate embedding modules for
        them were previously built here but never used, so they are no longer
        allocated.
    """
    vocab_sz = len(vocab)

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    src_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
    if emb_share:
        tgt_emb = src_emb
    else:
        tgt_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))

    # Backbones for the classifier and the discriminator (optionally shared).
    bert_class = LSTMEncoder(
        vocab_sz,
        d_model,
        vocab[PAD]
    )

    if bert_share:
        bert_discr = bert_class
    else:
        bert_discr = LSTMEncoder(
            vocab_sz,
            d_model,
            vocab[PAD]
        )

    translator = Translator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        src_emb,
        tgt_emb,
        Generator(d_model, vocab_sz))

    classifier = Classifier(
        bert_class
        # criterion = BCE
    )

    discriminator = Discriminator(
        bert_discr
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    # With bert_share=True the shared backbone is simply re-initialised twice.
    for m in (translator, classifier, discriminator):
        init_param(m)

    # creation of big bird
    model = BigBird(
        translator, discriminator, classifier,
        vocab, gamma=0.99, clip_value=0.1, #for WGAN, if WGAN-GP is used this is useless
        # Generator and discriminator learning rates are 0 here; only lr_C is
        # non-zero. NOTE(review): presumably only the classifier is being
        # trained in this configuration -- confirm against BigBird.
        lr_G = 0.,
        lr_D = 0.,
        lr_C = 1e-4,
        LAMBDA = 10, # Gradient penalty lambda hyperparameter
        RL_scale = 1000,
        device = device
    )

    return model


In [None]:
# Instantiate the combined model with shared embeddings and a shared
# classifier/discriminator backbone.
model = make_big_bird(
    vocab,
    N=4,
    d_model=256,
    d_ff=512,
    h=8,
    dropout=0.1,
    emb_share=True,
    bert_share=True,
)
#model.load("Nest/NewbornBird")

In [None]:
# Inverse of `vocab`: maps each value of `vocab` back to its key.
vocab_inv = {index: token for token, index in vocab.items()}


def convert_ids_to_tokens(ids):
    """Return the list of `vocab_inv` entries for every id in `ids`.

    Raises:
        KeyError: if an id is not a key of `vocab_inv`.
    """
    tokens = []
    for index in ids:
        tokens.append(vocab_inv[index])
    return tokens

In [None]:
# First epoch number: 1 for a fresh run, otherwise one past the epoch number
# parsed from the text that follows "Model" in the checkpoint path.
start = 1 if continue_from is None else (int(continue_from.split("Model")[-1])+1)
history = []  # one entry per epoch: that epoch's `stats.history`
for epoch in range(start, num_epochs+1):
    print("Epoch", epoch)

    # training
    stats = Stats()
    model.train()
    progress = tqdm(data_gen_train(), total=total_train)
    for batch in progress:
        loss, score = model.run_iter(batch.src, batch.src_mask, SUMM_MAX, batch.trg, batch.label, D_iters=1)
        # Show the individual loss / score components on the progress bar.
        progress.set_postfix(
            RL_sample_loss='{:.3f}'.format(loss[0]),
            RL_argmax_loss='{:.3f}'.format(loss[1]),
            G_loss='{:.3f}'.format(loss[2]),
            D_loss='{:.3f}'.format(loss[3]),
            real_score='{:.3f}'.format(score[0]),
            fake_score='{:.3f}'.format(score[1]),
            sample_acc='{:.3f}'.format(score[2]),
            argmax_acc='{:.3f}'.format(score[3]),
        )
        # Track the sum of all loss components for this iteration.
        stats.update(sum(loss), 1, log=0)

    t_h = stats.history
    history.append(t_h)

    print("[info] epoch train loss:", np.mean(t_h))

#     try:
#         torch.save({'model':model.state_dict(), 'training_history':t_h, 'validation_loss':np.mean(v_h)}, 
#                    "trained/Model"+str(epoch))
#     except:
#         continue

Epoch 1


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.9883736379211768
Epoch 2


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.89097191062548
Epoch 3


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.7722771819547883
Epoch 4


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.8200897198469778
Epoch 5


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.47852791998801486
Epoch 6


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5239164763824582
Epoch 7


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.43511700211425447
Epoch 8


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))

lay egg to ./Nest ... save as ./Nest/NewbornBird
origin:
['[CLS]', 'russia', 'w', '##to', 'talks', 'make', 'progress', 'talks', 'on', 'russia', 's', 'proposed', 'membership', 'of', 'the', 'world', 'trade', 'organisation', '(', 'w', '##to', ')', 'have', 'been', 'making', 'good', 'progress', 'say', 'those', 'behind', 'the', 'negotiations', '.', 'but', 'the', 'chairman', 'of', 'the', 'working', 'party', 'ambassador', 'stefan', 'johannes', '##son', 'of', 'iceland', 'warned', 'that', 'there', 'was', 'still', 'a', 'lot', 'of', 'work', 'has', 'to', 'be', 'done', '.', 'his', 'comments', 'came', 'as', 'president', 'george', 'w', 'bush', 'said', 'the', 'us', 'backed', 'russian', 'entry', '.', 'but', 'he', 'said', 'for', 'russia', 'to', 'make', 'progress', 'the', 'government', 'must', 'renew', 'a', 'commitment', 'to', 'democracy', 'and', 'the', 'rule', 'of', 'law', '.', 'his', 'comments', 'come', 'three', 'days', 'before', 'he', 'is', 'due', 'to', 'meet', 'president', 'vladimir', 'putin', '.', 'r

HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.44479747750530285
Epoch 10


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.40922129466323115
Epoch 11


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.31284679506040575
Epoch 12


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.4485895508659139
Epoch 13


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.6773063407678689
Epoch 14


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.640531074235748
Epoch 15


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))

lay egg to ./Nest ... save as ./Nest/NewbornBird
origin:
['[CLS]', 'blair', 'and', 'brown', 'criticised', 'by', 'mps', 'labour', 'mps', 'have', 'angrily', 'criticised', 'tony', 'blair', 'and', 'gordon', 'brown', 'amid', 'renewed', 'reports', 'of', 'a', 'rift', 'between', 'the', 'two', 'men', '.', 'a', 'meeting', 'of', 'the', 'parliamentary', 'labour', 'party', 'saw', 'a', 'succession', 'of', 'normally', 'loyal', 'members', 'warn', 'that', 'feud', '##ing', 'could', 'je', '##opa', '##rdi', '##se', 'labour', 's', 'election', 'hopes', '.', 'the', 'pm', 'insisted', 'nothing', 'would', 'der', '##ail', 'labour', 's', 'campaign', 'despite', 'a', 'new', 'book', 'saying', 'he', 'has', 'upset', 'his', 'chancellor', 'by', 'backing', 'out', 'of', 'a', 'pledge', 'to', 'stand', 'aside', '.', 'mr', 'brown', 'will', 'again', 'be', 'in', 'the', 'public', 'eye', 'at', 'the', 'party', 's', 'new', 'poster', 'launch', '.', 'in', 'what', 'the', 'party', 'had', 'hoped', 'would', 'be', 'perceived', 'as', 'a', 

HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.7608524268742518
Epoch 17


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.7133681554812938
Epoch 18


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.6947987538163684
Epoch 19


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.7663106510083058
Epoch 20


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5952939310200496
Epoch 21


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.807897162962971
Epoch 22


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))

lay egg to ./Nest ... save as ./Nest/NewbornBird
origin:
['[CLS]', 'radcliffe', 'eyes', 'hard', 'line', 'on', 'drugs', 'paula', 'radcliffe', 'has', 'called', 'for', 'all', 'athletes', 'found', 'guilty', 'on', 'drugs', 'charges', 'to', 'be', 'treated', 'as', 'criminals', '.', 'the', 'marathon', 'world', 'record', 'holder', 'believes', 'more', 'needs', 'to', 'be', 'done', 'to', 'rid', 'athletics', 'of', 'the', 'suspicions', 'and', 'inn', '##uen', '##do', '##es', 'which', 'greet', 'any', 'fast', 'time', '.', 'doping', 'in', 'sport', 'is', 'a', 'criminal', 'offence', 'and', 'should', 'be', 'treated', 'as', 'such', 'the', '30', '-', 'year', '-', 'old', 'told', 'the', 'sunday', 'times', '.', 'it', 'not', 'only', 'cheat', '##s', 'other', 'athletes', 'but', 'also', 'cheat', '##s', 'promoters', 'sponsors', 'and', 'the', 'general', 'public', '.', 'radcliffe', 's', 'comments', 'come', 'at', 'a', 'time', 'when', 'several', 'american', 'sports', 'stars', 'are', 'under', 'suspicion', 'of', 'ste', '#

HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5606743120165057
Epoch 24


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.684172047906603
Epoch 25


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.38467955098687007
Epoch 26


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.29863179895681763
Epoch 27


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.4237308389806588
Epoch 28


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.387594613103842
Epoch 29


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))

lay egg to ./Nest ... save as ./Nest/NewbornBird
origin:
['[CLS]', 'aviator', 'wins', 'top', 'globe', '##s', 'accolades', 'the', 'aviator', 'has', 'been', 'named', 'best', 'film', 'at', 'the', 'golden', 'globe', 'awards', 'with', 'its', 'star', 'leonardo', 'di', '##cap', '##rio', 'named', 'best', 'actor', '.', 'hollywood', 'veteran', 'clint', 'eastwood', 'took', 'the', 'best', 'director', 'prize', 'for', 'million', 'dollar', 'baby', 'while', 'its', 'star', 'hilary', 'swan', '##k', 'was', 'best', 'actress', '.', 'qui', '##rky', 'comedy', 'sideways', 'was', 'named', 'best', 'screenplay', 'and', 'best', 'comedy', '.', 'ray', 'star', 'jamie', 'fox', '##x', 'was', 'best', 'actor', 'in', 'a', 'musical', '/', 'comedy', 'while', 'brit', '##on', 'clive', 'owen', 'and', 'natalie', 'port', '##man', 'won', 'prizes', 'for', 'best', 'supporting', 'roles', 'in', 'closer', '.', 'the', 'aviator', 'in', 'which', 'di', '##cap', '##rio', 'plays', 'millionaire', 'howard', 'hughes', 'edged', 'ahead', 'of', 

HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.48924458612217936
Epoch 31


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.41866667142788044
Epoch 32


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5342530919995625
Epoch 33


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.6013748055955928
Epoch 34


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5553597445100812
Epoch 35


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5163336638760354
Epoch 36


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))

lay egg to ./Nest ... save as ./Nest/NewbornBird
origin:
['[CLS]', 'south', 'africa', 'sweep', 'top', 'awards', 'south', 'africa', 's', 'sc', '##hal', '##k', 'burger', 'was', 'named', 'player', 'of', 'the', 'year', 'as', 'the', 'tri', '-', 'nations', 'champions', 'swept', 'the', 'top', 'honours', 'at', 'the', 'international', 'rugby', 'board', 's', 'awards', '.', 'the', 'flank', '##er', 'topped', 'a', 'list', 'which', 'included', 'ireland', 'star', 'gordon', 'd', 'arc', '##y', 'and', 'australian', 'sensation', 'matt', 'gi', '##tea', '##u', '.', 'jake', 'white', 'claimed', 'the', 'coaching', 'award', 'while', 'his', 'side', 'held', 'off', 'grand', 'slam', 'winners', 'france', 'to', 'take', 'the', 'team', 'award', '.', 'england', 'player', 'simon', 'amor', 'beat', 'team', '-', 'mate', 'ben', 'go', '##lling', '##s', 'and', 'argentine', 'luc', '##io', 'lopez', 'fleming', 'to', 'win', 'the', 'sevens', 'award', '.', 'burger', 's', 'award', 'came', 'just', 'a', 'week', 'after', 'he', 'won', '

HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.7395722578470928
Epoch 38


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.6411216468749834
Epoch 39


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.49394358060588794
Epoch 40


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5180931232175291
Epoch 41


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.6427383688710896
Epoch 42


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.6935137434430154
Epoch 43


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))

lay egg to ./Nest ... save as ./Nest/NewbornBird
origin:
['[CLS]', 'wine', 'comedy', 'up', 'for', 'six', 'film', 'gong', '##s', 'sideways', 'a', 'wine', '-', 'tasting', 'comedy', 'starring', 'paul', 'gia', '##mat', '##ti', 'is', 'up', 'for', 'six', 'independent', 'spirit', 'awards', 'the', 'art', '-', 'house', 'version', 'of', 'the', 'oscar', '##s', '.', 'the', 'awards', 'are', 'held', 'on', '26', 'february', 'the', 'day', 'before', 'the', 'oscar', '##s', '.', 'spanish', 'drama', 'maria', 'full', 'of', 'grace', 'about', 'a', 'colombian', 'woman', 'who', 'becomes', 'a', 'drug', 'courier', 'got', 'five', 'nominations', '.', 'controversial', 'bio', '##pic', 'kin', '##sey', 'starring', 'liam', 'nee', '##son', 'as', 'sex', 'researcher', 'alfred', 'kin', '##sey', 'was', 'one', 'of', 'four', 'films', 'to', 'get', 'four', 'nominations', '.', 'the', 'awards', 'now', 'in', 'their', '20th', 'year', 'honour', 'qui', '##rky', 'low', '-', 'budget', 'films', 'all', 'of', 'which', 'must', 'have', 'a',

HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5770496840501437
Epoch 45


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.547917638371499
Epoch 46


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))


[info] epoch train loss: 0.5120374041310113
Epoch 47


HBox(children=(IntProgress(value=0, max=140), HTML(value='')))