In [1]:
from transformer_BigBirdA2C import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
# Dataset locations: every file lives under the same IMDB folder.
folder = 'data/IMDB/'
data_name = f'{folder}data.json'
# validation_name = folder+'valid_seq.json'
# testdata_name = folder+'testdata_seq.json'
vocab_name = f'{folder}vocab.json'

In [3]:
import os

num_epochs = 1000
save_rate = 1  # how many epochs per model save
# continue_from = "trained/Model1"  # checkpoint name to resume from; None to start fresh
continue_from = None
epsilon = 1e-8
validation_size = 10000
# Use the GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint directory. os.makedirs is portable, unlike the `!mkdir -p` shell escape.
os.makedirs('trained', exist_ok=True)

In [4]:
# Token -> id mapping (it is indexed by token strings such as vocab[PAD] later on).
# Read it as UTF-8 explicitly because it contains non-ASCII tokens, and close the
# file handle instead of leaking it.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
INPUT_MAX = 100  # source length limit passed to make_data_generator
SUMM_MAX = 20    # summary length limit (passed to the data generator and run_iter)
# BERT-style special tokens: [CLS] doubles as begin-of-sequence, [SEP] as end.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 16

# Build the training dataset and its batched loader. INPUT_MAX / SUMM_MAX are the
# length limits and vocab[PAD] the padding id (presumably used for truncation/padding
# — confirm in dataset.make_data_generator).
training_set, training_generator = make_data_generator(
    data_name,
    INPUT_MAX,
    SUMM_MAX,
    vocab[PAD],
    batch_size,
    cutoff=None,
    shuffle=True,
    num_workers=4,
)

# validation_set, validation_generator = make_data_generator(\
# validation_name, INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield one `Batch` per item of `training_generator`, with tensors moved to `device`.

    The loader yields (src, label, tgt) triples. The labels are cast to long and
    attached to each batch as `batch.label`.
    """
    for src_ids, labels, tgt_ids in training_generator:
        src_ids, tgt_ids = src_ids.to(device), tgt_ids.to(device)
        labels = labels.long().to(device)
        batch = Batch(src_ids, tgt_ids, vocab[PAD])
        batch.label = labels
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))




In [6]:
import math
# Batches per epoch (the last batch may be partial); used as the progress-bar total.
# math.ceil already returns an int in Python 3.
total_train = math.ceil(training_set.size / batch_size)
# total_valid = int(math.ceil(validation_set.size / batch_size))
# print(total_train, total_valid)

In [7]:
def init_param(model):
    """Re-initialise every parameter of `model` with more than one dimension using
    Xavier/Glorot uniform initialisation.

    One-dimensional parameters (e.g. biases) are left as they are.
    """
    weight_tensors = (param for param in model.parameters() if param.dim() > 1)
    for weight in weight_tensors:
        nn.init.xavier_uniform_(weight)

def make_big_bird(vocab, N=6, 
               d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False, bert_share=False,
               gamma=0.99, clip_value=0.1, lr_G=2e-4, lr_D=5e-5, lr_C=2e-4,
               LAMBDA=10, RL_scale=1):
    """Helper: construct a BigBird model from hyperparameters.

    Args:
        vocab: token -> id mapping. Its length sizes the embeddings and the
            generator, and vocab[PAD] is passed as the padding id to the BERT and
            Discriminator modules.
        N: number of layers in each Encoder / Decoder stack.
        d_model: model (embedding) width.
        d_ff: hidden width of the position-wise feed-forward blocks.
        h: number of attention heads.
        dropout: dropout rate passed to the feed-forward, positional-encoding and
            encoder/decoder layers.
        emb_share: if True, a single embedding module (token embedding + positional
            encoding) is reused for the translator source, translator target and
            the BERT encoders.
        bert_share: if True, `bert_discr` is the same object as `bert_class`.
        gamma, clip_value, lr_G, lr_D, lr_C, LAMBDA, RL_scale: forwarded unchanged to
            BigBird. The defaults are the values previously hard-coded here.
            clip_value is for WGAN weight clipping (useless if WGAN-GP is used);
            LAMBDA is the gradient-penalty lambda.

    Returns:
        The assembled BigBird. It is built with the module-level `device`.
    """
    vocab_sz = len(vocab)
    pad_id = vocab[PAD]

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    # Keep the construction order below unchanged. Constructors may draw from the
    # global RNG, and init_param only re-initialises parameters with dim > 1.
    src_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
    if emb_share:
        # One embedding (+ positional encoding) module shared everywhere.
        tgt_emb = src_emb
        bert_class_emb = src_emb
        bert_discr_emb = src_emb
    else:
        tgt_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
        bert_class_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
        bert_discr_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))

    bert_class = BERT(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        bert_class_emb,
        pad_id,
    )

    # NOTE(review): bert_discr is never passed to any component below (the
    # Discriminator receives its own Encoder). `bert_share` therefore only decides
    # whether this unused BERT gets constructed. Confirm the intended wiring.
    if bert_share:
        bert_discr = bert_class
    else:
        bert_discr = BERT(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
            bert_discr_emb,
            pad_id,
        )

    translator = Translator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        src_emb,
        tgt_emb,
        Generator(d_model, vocab_sz),
        CriticNet(d_model),
    )

    classifier = Classifier(
        bert_class,
        2,  # number of output classes
        # criterion = BCE
    )

    discriminator = Discriminator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        d_model,
        vocab_sz,
        pad_id,
    )

    # Initialise parameters with Glorot / fan_avg. This was important in the
    # reference code.
    for module in (translator, classifier, discriminator):
        init_param(module)

    return BigBird(
        translator, discriminator, classifier,
        vocab,
        gamma=gamma,
        clip_value=clip_value,  # for WGAN; useless if WGAN-GP is used
        lr_G=lr_G,
        lr_D=lr_D,
        lr_C=lr_C,
        LAMBDA=LAMBDA,  # gradient-penalty lambda hyperparameter
        RL_scale=RL_scale,
        device=device,  # module-level device from the config cell
    )


In [8]:
# Compact configuration: 4 layers of width 256, shared embeddings and a shared BERT.
model = make_big_bird(
    vocab,
    N=4,
    d_model=256,
    d_ff=512,
    h=8,
    dropout=0.1,
    emb_share=True,
    bert_share=True,
)
# model.load("Nest/NewbornBirdA2C")  # presumably restores previously saved weights

In [9]:
# Inverse vocabulary (id -> token), used to turn model ids back into token strings.
vocab_inv = {token_id: token for token, token_id in vocab.items()}


def convert_ids_to_tokens(ids):
    """Map an iterable of ids to a list of token strings; raises KeyError on unknown ids."""
    return list(map(vocab_inv.__getitem__, ids))

In [10]:
# Resume the epoch counter from a checkpoint name like "trained/Model7" (-> start at 8).
# NOTE(review): only the epoch counter is resumed. No weights are loaded from
# `continue_from`, so load them before resuming or training restarts from scratch.
start = 1 if continue_from is None else int(continue_from.split("Model")[-1]) + 1
history = []


from tensorboardX import SummaryWriter
writer = SummaryWriter('mygraph')

try:
    for epoch in range(start, num_epochs + 1):
        print("Epoch", epoch)

        # training
        stats = Stats()
        model.train()

        rewards = []

        progress = tqdm(data_gen_train(), total=total_train)
        for batch in progress:
            # loss[0..2] = RL / generator / discriminator losses;
            # score = (real score, fake score, accuracy, reward), as shown in the bar.
            loss, score = model.run_iter(batch.src, batch.src_mask, SUMM_MAX, batch.trg, batch.label, writer, D_iters=1, verbose=1)
            progress.set_postfix(
                RL_loss='{:.3f}'.format(loss[0]),
                G_loss='{:.3f}'.format(loss[1]),
                D_loss='{:.3f}'.format(loss[2]),
                real_score='{:.3f}'.format(score[0]),
                fake_score='{:.3f}'.format(score[1]),
                acc='{:.3f}'.format(score[2]),
                reward='{:.3f}'.format(score[3]),
            )
            stats.update(sum(loss), 1, log=0)
            rewards.append(score[3])

        t_h = stats.history
        history.append(t_h)
        epoch_loss = np.mean(t_h)
        # Guard against an epoch with no batches instead of dividing by zero.
        epoch_reward = sum(rewards) / len(rewards) if rewards else float('nan')
        # Log loss and reward under their own tags (the loss used to be logged as 'reward').
        writer.add_scalar('epoch/train_loss', epoch_loss, epoch)
        writer.add_scalar('epoch/reward', epoch_reward, epoch)
        print("[info] epoch train loss:", epoch_loss)
        print("[info] epoch train reward:", epoch_reward)
        # Checkpointing is disabled (validation is off, so v_h is undefined):
        # try:
        #     torch.save({'model':model.state_dict(), 'training_history':t_h, 'validation_loss':np.mean(v_h)}, 
        #                "trained/Model"+str(epoch))
        # except:
        #     continue
finally:
    # Flush and close the event file even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

Epoch 1


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

origin:
['[CLS]', 'of', 'all', 'the', 'movies', 'of', 'the', 'seventies', ',', 'none', 'captured', 'to', 'true', '##st', 'essence', 'of', 'the', 'good', 'versus', 'evil', 'battle', 'as', 'did', 'the', 'sentinel', '.', 'i', 'mean', ',', 'yes', ',', 'there', 'were', 'movies', 'like', 'the', 'ex', '##or', '##cis', '##t', ',', 'and', 'other', 'ones', ';', 'but', 'none', 'of', 'them', 'captured', 'the', 'human', 'element', 'of', 'the', 'protagonist', 'like', 'this', 'one', '.', 'if', 'you', 'have', 'time', ',', 'check', 'this', 'one', 'out', '.', 'you', 'may', 'not', 'be', 'able', 'to', 'get', 'past', 'the', 'dated', 'devices', 'as', 'such', ',', 'but', 'this', 'is', 'a', 'story', 'worth', 'getting', 'into', '.', 'then', 'there', 'are', 'all', 'the', 'stars', 'and']
summary:
['[CLS]', 'consumer', 'pol', 'orientation', 'detailing', 'maneuver', '##quel', '##lk', 'amendments', 'trigger', 'predatory', 'junk', '##kot', 'plaintiff', 'reopened', '[unused34]', 'gen', 'fights', 'amusing', 'person']


origin:
['[CLS]', 'i', 'really', 'think', 'i', 'should', 'make', 'my', 'case', 'and', 'have', 'every', '(', 'horror', 'and', 'or', 'cult', ')', 'movie', '-', 'buff', 'go', 'and', 'see', 'this', 'movie', '.', '.', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'i', 'did', '!', '<', 'br', '/', '>', '<', 'br', '/', '>', 'it', '-', 'is', '-', 'excellent', ':', 'very', 'atmospheric', 'and', 'un', '##sett', '##ling', 'and', 'scary', '.', '.', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'inc', '##rid', '##ible', 'how', 'they', 'could', 'make', 'such', 'a', 'gem', 'of', 'a', 'film', 'with', 'the', 'very', 'low', '(', 'read', '-', '"', 'no', '"', '!', ')', '-', 'budget']
summary:
['[CLS]', 'penang', 'plucked', 'life', 'grand', 'radha', 'qualification', 'b', 'serious', '##onale', '##bham', 'zurich', '##pd', 'unanimous', 'swedish', 'generations', 'talent', 'therapeutic', 'five', 'plague']
real summary:
['[CLS]', 'i', 'really', 'think', 'i', 'should', 'make', 'my', 'case', 'and', 'have', 'every


[info] epoch train loss: -3.3625205616412774
[info] epoch train reward: -0.7024112158834515
Epoch 2


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

origin:
['[CLS]', 'i', 'usually', 'steer', 'clear', 'of', 'tv', 'movies', 'because', 'of', 'the', 'many', 'ways', 'you', 'know', 'that', 'it', "'", 's', 'tv', 'movies', 'five', 'seconds', 'into', 'the', 'picture', '.', 'this', 'one', 'got', 'my', 'attention', 'because', 'of', 'the', 'unusual', 'title', 'and', 'its', 'gloom', '##y', ',', 'well', '-', 'crafted', 'mood', 'that', 'is', 'established', 'from', 'the', 'very', 'start', '.', 'while', 'the', 'ever', 'present', 'rain', 'confirmed', 'my', 'suspicions', 'of', 'a', 'mis', '##placed', 'story', '(', 'even', 'if', 'claiming', 'to', 'be', 'set', 'in', 'california', 'the', 'movie', 'was', 'largely', 'shot', 'around', 'a', 'stormy', 'vancouver', ',', 'b', '.', 'c', '.', ')', ',', 'the', 'dark', 'and', 'oppressive', 'outdoors', 'beautifully', 'complement', 'ol']
summary:
['[CLS]', 'shivered', 'luca', 'ب', 'king', 'starred', 'decorate', 'expanse', 'dragged', 'twilight', '##kali', 'lara', 'graveyard', '##stadt', 'ac', 'blowing', 'elements', 

origin:
['[CLS]', 'calling', 'this', 'a', 'romantic', 'comedy', 'is', 'accurate', 'but', 'nowadays', 'misleading', '.', 'the', 'genre', 'has', 'sadly', 'deteriorated', 'into', 'cl', '##iche', '##s', ',', 'too', 'focused', 'on', 'making', 'the', 'main', 'couple', 'get', 'together', 'and', 'with', 'very', 'little', 'room', 'for', 'am', '##bie', '##nce', 'and', 'other', 'stories', ',', 'making', 'it', 'formula', '##ic', 'and', 'overly', 'predictable', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'the', 'shop', 'around', 'the', 'corner', 'does', 'not', 'suffer', 'from', 'these', 'illnesses', ':', 'it', 'manages', 'to', 'create', 'a', 'rec', '##og', '##nis', '##ably', 'middle', '/', 'eastern', '-', 'european', 'atmosphere', 'and', 'has', 'a', 'strong', 'cast', 'besides', 'the', '(', 'also', 'strong', ')', 'nominal', 'leads']
summary:
['[CLS]', 'rough', 'attention', 'word', 'part', 'who', 'today', 'z', 'mixed', 'series', 'each', 'specialized', 'any', 'main', 'happened', 'london', 'adults',

lay egg to ./Nest ... save as ./Nest/NewbornBirdA2C
origin:
['[CLS]', 'i', 'can', 'not', 'quite', 'understand', 'why', 'any', 'of', 'the', '"', 'reviewers', '"', 'gave', 'this', 'documentary', '"', '0', '"', 'other', 'than', 'for', 'political', 'reasons', '.', 'no', ',', 'the', 'film', 'did', 'not', 'investigate', 'both', '"', 'sides', '"', 'of', 'the', 'story', ',', 'but', 'then', 'surely', 'one', 'film', 'in', 'favour', 'of', 'chavez', 'against', 'the', 'tides', 'of', 'propaganda', 'against', 'him', 'should', 'be', 'seen', 'as', 'an', 'attempt', 'to', 'balance', 'out', 'the', 'narrative', 'overall', '(', 'especially', 'given', 'a', '.', 'the', 'history', 'of', 'cia', 'involvement', 'in', 'latin', 'america', 'in', 'fe', '##rm', '##enting', 'civil', 'unrest', '-', 'google', 'national', 'security', 'archive', 'and', 'b', '.', 'the', 'coverage', 'in', 'that', 'country']
summary:
['[CLS]', 'see', 'andrews', 'fictional', 'psychiatrist', 'ob', 're', 'austin', 'crude', 'should', 'along', 'ba

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

origin:
['[CLS]', 'no', '##ll', "'", 's', 'comfortable', 'way', 'of', 'rolling', 'out', 'blunt', 'comments', ',', 'often', 'with', 'ex', '##ple', '##tive', '##s', ',', 'to', 'describe', 'things', 'that', 'he', 'is', 'more', 'knowledge', '##able', 'about', 'than', 'most', 'is', 'quite', 'refreshing', '.', 'there', 'is', 'one', 'other', 'character', 'in', 'the', 'film', 'that', 'constantly', 'tries', 'to', 'verbal', '##ize', 'complicated', 'issues', ',', 'using', 'more', 'language', 'than', 'necessary', '.', 'this', 'guy', 'should', 'never', 'have', 'been', 'given', 'a', 'the', '##saurus', '.', 'cut', 'to', 'no', '##ll', 'and', 'you', 'know', 'you', "'", 're', 'in', 'for', 'a', 'treat', '!', '<', 'br', '/', '>', '<', 'br', '/', '>', 'the', 'way', 'the', 'pioneers', 'of', 'big', 'wave']
summary:
['[CLS]', 'long', 'preceded', 'loved', 'time', 'trip', '##mp', 'audience', '##ie', 'opinions', '##y', '##y', 'posted', 'don', 'they', 'wore', 'be', 'but', 'certainly', '##eth']
real summary:
['[CL

lay egg to ./Nest ... save as ./Nest/NewbornBirdA2C
origin:
['[CLS]', 'this', 'is', 'a', 'better', 'adaptation', 'of', 'the', 'book', 'than', 'the', 'one', 'with', 'pal', '##tro', '##w', '(', 'although', 'i', 'liked', 'that', 'one', ',', 'too', ')', '.', 'it', 'isn', "'", 't', 'so', 'much', 'that', 'beck', '##ins', '##ale', 'is', 'better', '-', '-', 'they', 'are', 'both', 'very', 'good', '-', '-', 'but', 'that', 'the', 'screenplay', 'is', 'better', '.', 'davies', 'is', 'a', 'master', 'at', 'adapting', 'austen', 'for', 'filming', ',', 'and', 'the', 'production', 'values', 'here', 'are', 'very', 'good', '.', 'it', "'", 's', 'not', 'quite', 'as', 'glossy', 'as', 'the', 'hollywood', 'treatment', ',', 'but', 'it', "'", 's', 'close', ',', 'and', 'i', 'thought', 'that', 'the', 'locations', 'and', 'the', 'costumes']
summary:
['[CLS]', '##lessness', 'each', 'fund', 'hand', 'entertaining', 'anyone', 'put', 'latter', 'something', 'shin', 'sci', '##cturing', 'finds', 'weird', 'even', 'medical', 'o


[info] epoch train loss: -5.164819041368705
[info] epoch train reward: -0.6937446069122505
Epoch 4


HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

origin:
['[CLS]', 'you', 'spend', 'most', 'of', 'this', 'two', '-', 'hour', 'film', 'wondering', '"', 'what', "'", 's', 'the', 'story', 'regarding', 'the', 'lead', 'character', '?', '"', '<', 'br', '/', '>', '<', 'br', '/', '>', 'will', 'smith', ',', 'as', 'a', 'low', '-', 'key', '"', 'ben', 'thomas', '"', 'will', 'keep', 'you', 'guessing', '.', 'the', 'last', '20', '-', '25', 'minutes', 'is', 'when', 'you', 'find', 'out', ',', 'and', 'it', "'", 's', 'a', 'shock', '##er', '.', '.', '.', '.', 'but', 'you', 'knew', 'something', 'dramatic', 'was', 'going', 'to', 'be', 'revealed', '.', 'until', 'then', ',', 'smith', ',', 'plays', 'it', 'mysterious', ',', 'almost', 'stalking', 'people', '.', 'you', 'know', 'he', 'has', 'a']
summary:
['[CLS]', 'haven', 'challenge', 'limbs', '##re', 'hands', 'melting', 'office', '##erin', 'pathetic', 'focal', 'simulation', 'guns', '##re', 'million', 'dancing', 'painter', 'lies', 'weed', 'friday']
real summary:
['[CLS]', 'you', 'spend', 'most', 'of', 'this', '

origin:
['[CLS]', '"', 'sky', 'captain', '"', 'may', 'be', 'considered', 'an', 'homage', 'to', 'comic', 'books', ',', 'pulp', 'adventures', 'and', 'movie', 'serials', 'but', 'it', 'contains', 'little', 'of', 'the', 'magic', 'of', 'some', 'of', 'the', 'best', 'from', 'those', 'genres', '.', 'one', 'contributor', 'says', 'that', 'enjoyment', 'of', 'the', 'film', 'depends', 'on', 'whether', 'or', 'not', 'one', 'recognizes', 'the', 'films', 'influences', '.', 'i', 'don', "'", 't', 'think', 'this', 'is', 'at', 'all', 'true', '.', 'one', "'", 's', 'expectations', 'of', 'the', 'films', ',', 'fiction', 'and', 'serials', 'that', '"', 'captain', '"', 'pays', 'tribute', 'to', 'were', 'entirely', 'different', '.', 'especially', 'so', 'for', 'those', 'who', 'experienced', 'those', 'entertainment', '##s', 'when', 'they', 'were', 'children']
summary:
['[CLS]', 'broadcasters', '##．', 'something', 'mata', '##eri', '##up', '[unused429]', 'few', '##cula', '[unused728]', 'ore', 'tom', 'garfield', '##icon'

origin:
['[CLS]', 'i', 'can', "'", 't', 'believe', 'this', 'film', 'was', 'allowed', 'to', 'be', 'made', '.', 'these', 'people', 'should', 'be', 'drug', 'out', 'and', 'beat', 'with', 'blunt', 'objects', '.', 'they', 'should', 'be', 'tortured', '.', 'this', 'film', 'is', 'an', 'ab', '##omi', '##nation', '.', 'it', "'", 's', 'nothing', 'but', 'footage', 'from', 'the', 'first', 'film', '.', 'whatever', 'is', 'original', 'is', 'freak', '##y', 'and', 'makes', 'no', 'sense', 'whatsoever', '.', 'it', "'", 's', 'like', 'some', 'sort', 'of', 'drug', 'hall', '##uc', '##ination', '.', 'like', ',', 'what', "'", 's', 'with', 'the', 'laying', 'on', 'a', 'mirror', 'naked', 'therapy', '.', 'also', ',', 'whatever', 'mor', '##on', 'patch', '##ed', 'together', 'this', 'tu', '##rd', 'didn']
summary:
['[CLS]', 'salmon', '##psy', 'junctions', 'sidewalk', '##₅', '##ault', 'original', 'symbolism', 'demons', '32nd', 'exchange', 'balkan', '[unused177]', 'jan', 'donor', 'cy', '##zal', 'kn', '[unused614]']
real s

HBox(children=(IntProgress(value=0, max=1563), HTML(value='')))

origin:
['[CLS]', 'i', 'sought', 'this', 'film', 'out', 'because', 'i', "'", 'm', 'a', 'new', 'fra', '##in', 'fan', 'and', 'wanted', 'to', 'see', 'more', 'of', 'his', 'work', '.', 'first', 'of', 'all', ',', 'his', 'irish', 'accent', 'is', 'great', '.', 'he', "'", 's', 'got', 'a', 'keen', 'ear', 'for', 'dialects', ',', 'it', 'seems', '.', 'his', 'acting', 'was', 'marvelous', ',', 'as', 'usual', '.', 'james', 'fra', '##in', 'aside', ',', 'i', 'thought', 'the', 'film', 'was', 'very', 'well', 'done', '.', 'it', 'showed', 'the', 'conflict', 'in', 'northern', 'ireland', 'as', 'the', '*', 'mess', '*', 'it', 'really', 'is', '.', 'both', 'sides', 'are', 'guilty', 'of', 'grave', 'injustice', '##s', ',', 'and', 'the', 'men', 'drawn', 'into', 'the']
summary:
['[CLS]', '##bari', '##script', 'deposited', 'angrily', 'scowl', 'bench', 'statesman', '##lk', 'bar', 'proceeded', 'belong', 'subdivisions', 'greensboro', '##turn', 'prayed', 'declined', '2a', '##dl', 'macbeth']
real summary:
['[CLS]', 'i', 's

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

In [None]:
#print(model.all_rewards)
#plt.plot(range(len(model.all_rewards)), model.all_rewards)

In [None]:
#plt.plot(range(len(model.all_rewards)-1), [sum(model.all_rewards[:i])/i for i in range(1,len(model.all_rewards))])