In [1]:
from transformer_nb2 import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
# Every dataset artefact lives directly under `folder` (note the trailing slash).
folder = 'data/IMDB/'
data_name = '{}data.json'.format(folder)
# validation_name = '{}valid_seq.json'.format(folder)
# testdata_name = '{}testdata_seq.json'.format(folder)
vocab_name = '{}vocab.json'.format(folder)

In [3]:
import os

num_epochs = 500
save_rate = 1  # save a checkpoint every `save_rate` epochs
# Checkpoint to resume from, e.g. "trained/Model1"; None starts a fresh run.
continue_from = None
epsilon = 1e-8
validation_size = 10000
# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Portable replacement for the IPython-only `!mkdir -p trained`.
os.makedirs('trained', exist_ok=True)

In [4]:
# Token -> id mapping. Use a context manager so the file handle is closed, and
# decode as UTF-8 explicitly (JSON's encoding) instead of the platform default.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
# Length limits handed to make_data_generator; SUMM_MAX is also the output
# length passed to model.run_iter during training.
INPUT_MAX = 400
SUMM_MAX = 20
# BERT-style special tokens; each must be a key of `vocab`.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 1

# make_data_generator hands back a (dataset, loader) pair: `training_set.size`
# sizes the epoch below, and `training_generator` is iterated by data_gen_train.
training_set, training_generator = make_data_generator(
    data_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size,
    cutoff=None, shuffle=True, num_workers=4,
)

# Validation loader (disabled); re-enable together with `validation_name`:
# validation_set, validation_generator = make_data_generator(
#     validation_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size,
#     cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Iterate over `training_generator`, yielding device-resident `Batch` objects.

    Each loader item is unpacked as (src, label, tgt). src and tgt are moved to
    `device`, and a `Batch` is built with the vocabulary's padding id. The label
    is cast with `.long()`, moved to `device`, and attached as `batch.label`.
    """
    for source, raw_label, target in training_generator:
        source = source.to(device)
        label_on_device = raw_label.long().to(device)
        target = target.to(device)
        batch = Batch(source, target, vocab[PAD])
        batch.label = label_on_device
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))




In [6]:
import math
# Batches per epoch (used as the tqdm total); the last batch may be partial, hence the ceiling.
total_train = int(math.ceil(training_set.size / batch_size))
# Validation loading is disabled in the data cell; re-enable these together with it.
# total_valid = int(math.ceil(validation_set.size / batch_size))
# print(total_train, total_valid)

In [7]:
def make_translator(src_vocab, tgt_vocab, N=6,
                    d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False):
    """Build an encoder-decoder ``Translator`` from transformer hyperparameters.

    Args:
        src_vocab, tgt_vocab: sizes handed to ``Embeddings`` / ``Generator``
            for the source and target side.
        N: number of stacked encoder and decoder layers.
        d_model, d_ff, h, dropout: model width, feed-forward width, attention
            heads and dropout rate.
        emb_share: if True the target side reuses the very same source
            embedding module (and ``tgt_vocab`` is ignored for it).

    Returns:
        The assembled ``Translator``. Parameters are not re-initialised here.
    """
    clone = copy.deepcopy
    attention = MultiHeadedAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    positional = PositionalEncoding(d_model, dropout)

    source_embedding = nn.Sequential(Embeddings(d_model, src_vocab), clone(positional))
    if emb_share:
        target_embedding = source_embedding
    else:
        target_embedding = nn.Sequential(Embeddings(d_model, tgt_vocab), clone(positional))

    # Every layer receives its own deep copy of the prototype sub-modules.
    encoder = Encoder(
        EncoderLayer(d_model, clone(attention), clone(feed_forward), dropout), N)
    decoder = Decoder(
        DecoderLayer(d_model, clone(attention), clone(attention),
                     clone(feed_forward), dropout), N)
    return Translator(encoder, decoder, source_embedding, target_embedding,
                      Generator(d_model, tgt_vocab))

In [8]:
def make_classifier(src_vocab, N=6,
                    d_model=512, d_ff=2048, h=8, dropout=0.1, pad_idx=None):
    """Build a ``Classifier`` on top of a transformer ``BERT`` encoder.

    Args:
        src_vocab: size handed to ``Embeddings`` (make_big_bird passes
            ``len(vocab)`` in the same position, so presumably the vocab size).
        N, d_model, d_ff, h, dropout: encoder depth, width, feed-forward
            width, attention heads and dropout rate.
        pad_idx: padding token id given to ``BERT``. Defaults to the
            notebook-global ``vocab[PAD]`` (the previously hard-wired value),
            so existing calls behave the same; pass it explicitly to decouple
            the builder from the globals.

    Returns:
        ``Classifier`` wrapping the BERT encoder.
    """
    if pad_idx is None:
        pad_idx = vocab[PAD]

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    bert = BERT(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        pad_idx
    )

    # NOTE(review): no class count / criterion is passed here, unlike the
    # commented-out `Classifier(bert_class, 2)` in make_big_bird — confirm.
    model = Classifier(
        bert
    )

    return model

In [9]:
def make_discriminator(src_vocab, N=6,
                       d_model=512, d_ff=2048, h=8, dropout=0.1, pad_idx=None):
    """Build a ``Discriminator`` on top of a transformer ``BERT`` encoder.

    Args:
        src_vocab: size handed to ``Embeddings`` (make_big_bird passes
            ``len(vocab)`` in the same position, so presumably the vocab size).
        N, d_model, d_ff, h, dropout: encoder depth, width, feed-forward
            width, attention heads and dropout rate.
        pad_idx: padding token id given to ``BERT``. Defaults to the
            notebook-global ``vocab[PAD]`` (the previously hard-wired value),
            so existing calls behave the same.

    Returns:
        ``Discriminator`` wrapping the BERT encoder.
    """
    if pad_idx is None:
        pad_idx = vocab[PAD]

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    bert = BERT(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        pad_idx
    )

    model = Discriminator(
        bert
    )

    return model

In [10]:
def init_param(model):
    """Re-initialise every parameter with more than one dimension using Xavier-uniform.

    One-dimensional parameters (e.g. biases) are left untouched. The update is
    in place; nothing is returned.
    """
    matrices = (param for param in model.parameters() if param.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

def make_big_bird(vocab, N=6,
                  d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False, bert_share=False,
                  gamma=0.99, clip_value=0.1, lr_G=0., lr_D=0., lr_R=1e-4,
                  LAMBDA=10, RL_scale=1000):
    """Assemble a ``BigBird`` model from a translator, a discriminator and a reconstructor.

    Args:
        vocab: token -> id mapping. Its length is the vocabulary size, and
            ``vocab[PAD]`` is the padding id given to the LSTM encoder and the
            reconstructor.
        N, d_model, d_ff, h, dropout: transformer hyperparameters shared by the
            translator and the reconstructor (``d_model`` also sizes the LSTM
            encoder).
        emb_share: if True the target side reuses the source embedding module;
            otherwise a separate target embedding is built.
        bert_share: kept for backward compatibility. The classifier encoder it
            used to share with is disabled, so a single LSTM encoder (for the
            discriminator) is built either way.
        gamma, clip_value, lr_G, lr_D, lr_R, LAMBDA, RL_scale: forwarded to
            ``BigBird`` unchanged. The defaults are the values previously
            hard-coded here. ``clip_value`` only matters for plain WGAN weight
            clipping (it is unused with WGAN-GP); ``LAMBDA`` is the
            gradient-penalty weight.

    Returns:
        The ``BigBird`` instance, built for the notebook-global ``device``.
    """
    vocab_sz = len(vocab)
    pad_idx = vocab[PAD]

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    # Only the source/target embeddings are consumed below. The separate
    # classifier/discriminator embeddings fed BERT encoders that are disabled,
    # so they are no longer allocated.
    src_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
    if emb_share:
        tgt_emb = src_emb
    else:
        tgt_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))

    # Earlier commented-out variants (BERT encoders, a PointerGenerator
    # translator, a separate Classifier) were dropped. The LSTM encoder is the
    # one in use, and it feeds the discriminator.
    discr_encoder = LSTMEncoder(vocab_sz, d_model, pad_idx)

    translator = Translator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        src_emb,
        tgt_emb,
        Generator(d_model, vocab_sz))

    # NOTE(review): the reconstructor receives the same src_emb/tgt_emb module
    # objects as the translator, so their embedding weights are shared —
    # confirm this is intended.
    reconstructor = Reconstructor(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        src_emb,
        tgt_emb,
        Generator(d_model, vocab_sz),
        pad_idx
    )

    discriminator = Discriminator(
        discr_encoder
    )

    # Glorot / fan_avg initialisation of every >1-D parameter (carried over
    # from the original authors' code, where it mattered).
    for m in (translator, reconstructor, discriminator):
        init_param(m)

    model = BigBird(
        translator, discriminator, reconstructor,
        vocab, gamma=gamma, clip_value=clip_value,
        lr_G=lr_G,
        lr_D=lr_D,
        lr_R=lr_R,
        LAMBDA=LAMBDA,
        RL_scale=RL_scale,
        device=device
    )

    return model


In [11]:
# Small single-layer configuration; all embeddings and encoders are shared.
model = make_big_bird(
    vocab,
    N=1,
    d_model=32,
    d_ff=32,
    h=8,
    dropout=0.1,
    emb_share=True,
    bert_share=True,
)
# Optionally restore previously saved weights instead of starting fresh:
# model.load("Nest/NewbornBird_PG")

In [12]:
# Inverse vocabulary: token id -> token string.
vocab_inv = {idx: token for token, idx in vocab.items()}
def convert_ids_to_tokens(ids):
    """Return the list of token strings for an iterable of token ids (KeyError on unknown ids)."""
    return list(map(vocab_inv.__getitem__, ids))

In [13]:
# Checkpoints are named "trained/Model<epoch>", so a resumed run continues at
# the epoch after the one encoded in `continue_from`.
if continue_from is None:
    start = 1
else:
    start = int(continue_from.split("Model")[-1]) + 1
    # Restore the weights too. Previously only the epoch counter was resumed,
    # so a "resumed" run silently trained from scratch.
    # NOTE(review): assumes the checkpoint layout written below
    # ({'model': state_dict, ...}) — confirm for older checkpoints.
    checkpoint = torch.load(continue_from, map_location=device)
    model.load_state_dict(checkpoint['model'])
history = []

from tensorboardX import SummaryWriter
writer = SummaryWriter('LSTMgraph')

# Progress-bar labels for the entries of run_iter's `loss` and `score`
# outputs, in index order.
LOSS_LABELS = ('RL_sample_loss', 'RL_argmax_loss', 'G_loss', 'D_loss')
SCORE_LABELS = ('real_score', 'fake_score', 'sample_acc', 'argmax_acc')

try:
    for epoch in range(start, num_epochs + 1):
        print("Epoch", epoch)

        # training
        stats = Stats()
        model.train()
        trange = tqdm(data_gen_train(), total=total_train)
        for batch in trange:
            loss, score = model.run_iter(batch.src, batch.src_mask, SUMM_MAX, batch.trg,
                                         batch.label, writer, D_iters=1)
            postfix = {name: '{:.3f}'.format(value) for name, value in zip(LOSS_LABELS, loss)}
            postfix.update({name: '{:.3f}'.format(value) for name, value in zip(SCORE_LABELS, score)})
            trange.set_postfix(**postfix)
            stats.update(sum(loss), 1, log=0)

        t_h = stats.history
        history.append(t_h)

        print("[info] epoch train loss:", np.mean(t_h))

        # Periodic checkpoint so `continue_from` has something to resume from.
        # (The previous commented-out save referenced an undefined `v_h`,
        # which a bare `except` swallowed.)
        if epoch % save_rate == 0:
            checkpoint_path = "trained/Model" + str(epoch)
            try:
                torch.save({'model': model.state_dict(), 'training_history': t_h},
                           checkpoint_path)
            except (OSError, RuntimeError) as err:
                # Best effort: keep training, but make the failure visible.
                print("[warn] failed to save checkpoint", checkpoint_path, "-", err)
finally:
    # Flush TensorBoard events even if training is interrupted (e.g. Ctrl-C).
    writer.close()

Epoch 1


HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))

origin:
['[CLS]', 'this', 'is', 'the', 'one', 'movie', 'to', 'see', 'if', 'you', 'are', 'to', 'wed', 'or', 'are', 'a', 'married', 'couple', '.', 'the', 'movie', 'port', '##rai', '##s', 'a', 'couple', 'in', 'italy', 'and', 'deals', 'with', 'such', 'difficult', 'topics', 'as', 'abortion', ',', 'in', '##fide', '##lity', ',', 'jug', '##gling', 'work', 'and', 'family', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'the', 'so', 'called', '"', 'culture', 'of', 'death', '"', 'that', 'we', 'are', 'experiencing', 'nowadays', 'in', 'the', 'world', 'is', 'terrible', 'and', 'this', 'movie', 'will', 'surely', 'make', 'you', 'think', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'a', 'must', 'see', '.', 'i', 'hope', 'it', 'gets', 'distributed', 'as', 'it', 'should', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'congratulations', 'on', 'the', 'cast', 'and', 'director', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'two', 'thumbs', 'up', 'and', 'a', '10', 'star', 'evaluation', 'from', 'me', '!', '

origin:
['[CLS]', 'i', 'had', 'the', 'privilege', 'to', 'see', 'this', 'movie', 'at', 'the', 'int', '##ena', '##tion', '##al', 'film', 'festival', 'of', 'rotterdam', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', "'", 'xi', '##zh', '##ao', "'", 'or', "'", 'shower', "'", 'is', 'a', '$', '200', '.', '000', 'low', '##bu', '##dget', 'movie', 'about', 'a', 'father', 'and', 'his', '2', 'sons', '.', 'the', 'father', 'has', 'a', 'traditional', 'bath', '##house', 'somewhere', 'in', 'a', 'traditional', 'chinese', 'village', 'where', 'local', ',', 'mostly', 'aged', 'men', ',', 'come', 'to', 'relax', 'and', 'to', 'go', 'bathing', '.', 'the', 'father', 'has', 'to', 'sons', ':', 'a', "'", 're', '##tar', '##ded', "'", 'son', 'who', 'lives', 'with', 'him', 'and', 'a', 'son', 'who', 'lives', 'in', 'a', 'big', 'modern', 'city', 'and', 'who', 'comes', 'to', 'visit', 'him', '.', 'to', 'this', 'son', 'the', 'traditional', 'village', ',', 'the', 'bath', '##house', 'and', 'his', "'", 're', '##tar', '##ded', 

origin:
['[CLS]', 'an', 'excellent', 'movie', 'and', 'great', 'example', 'of', 'how', 'scary', 'a', 'movie', 'can', 'be', 'without', 'really', 'showing', 'the', 'viewer', 'anything', '.', 'it', "'", 's', 'a', 'set', 'of', 'four', 'stories', 'all', 'revolving', 'around', 'the', 'tenants', 'of', 'a', 'charming', '##ly', 'old', '-', 'fashioned', 'house', 'and', 'their', 'various', 'gr', '##ues', '##ome', 'and', 'horrific', 'fates', ',', 'all', 'tied', 'together', 'by', 'a', 'wrap', '-', 'around', 'story', 'about', 'a', 'scotland', 'yard', 'inspector', 'searching', 'for', 'a', 'missing', 'horror', 'film', 'star', '.', 'it', 'starts', 'out', 'with', 'a', 'story', 'about', 'a', 'mystery', 'writer', 'whose', 'main', 'character', 'becomes', 'a', 'little', 'too', 'realistic', ',', 'followed', 'by', 'a', 'story', 'about', 'two', 'old', 'romantic', 'rivals', 'who', 'become', 'obsessed', 'over', 'a', 'wax', 'figure', 'in', 'a', 'museum', ',', 'then', 'a', 'story', 'about', 'a', 'sweetly', 'angel',

origin:
['[CLS]', 'spoil', '##ers', '?', 'maybe', 'a', 'few', 'details', ',', 'but', 'nothing', 'too', 'plot', 'related', '.', 'not', 'like', 'it', 'would', 'matter', 'with', 'this', 'movie', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'air', 'rage', 'b', '##lat', '##antly', 'rip', '##s', 'off', 'the', 'mid', '-', 'air', 'in', '##filtration', 'premise', 'of', 'executive', 'decision', '.', 'ice', '-', 't', 'leads', 'a', 'team', 'of', 'four', '"', 'elite', '"', 'commandos', 'who', 'wear', 'bag', '##gy', 'black', 'shirts', 'that', 'we', 'can', 'only', 'imagine', 'must', 'conceal', 'invisible', 'body', 'armor', 'as', 'their', 'idiot', '##ic', 'tactics', '(', 'similar', 'to', 'what', '3rd', 'graders', 'use', 'when', 'playing', 'star', 'wars', 'on', 'the', 'playground', ')', 'lead', 'them', 'to', 'absorb', 'a', 'hail', 'of', 'gunfire', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'what', 'entertained', 'me', 'the', 'most', 'about', 'this', 'flick', 'was', 'the', 'use', 'of', 'look', '-'

origin:
['[CLS]', 'i', 'received', 'this', 'movie', 'as', 'a', 'gift', ',', 'i', 'knew', 'from', 'the', 'dvd', 'cover', ',', 'this', 'movie', 'are', 'going', 'to', 'be', 'bad', '.', 'after', 'not', 'watching', 'it', 'for', 'more', 'than', 'a', 'year', 'i', 'finally', 'watched', 'it', '.', 'what', 'a', 'pathetic', 'movie', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'i', 'almost', 'didn', "'", 't', 'finish', 'watching', 'this', 'bad', 'movie', ',', 'but', 'it', 'will', 'be', 'unfair', 'of', 'me', 'to', 'write', 'a', 'review', 'without', 'watching', 'the', 'complete', 'movie', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'trust', 'me', 'when', 'i', 'say', '"', 'this', 'movie', 'sucks', '"', 'i', 'am', 'truly', 'shocked', 'that', 'some', 'bad', 'filmmaker', 'wan', '##e', 'bee', 'got', 'even', 'financed', 'to', 'make', 'this', 'pathetic', 'movie', ',', 'but', 'it', 'couldn', "'", 't', 'have', 'cost', 'more', 'than', '$', '20', '000', 'to', 'produce', 'this', 'movie', '.', 'all', 'you'

origin:
['[CLS]', 'oh', ',', 'brother', '.', 'the', 'only', 'reason', 'this', 'very', 'irritating', 'film', 'avoids', 'getting', 'the', 'total', '"', 'bomb', '"', 'from', 'me', 'is', 'because', 'it', "'", 's', 'at', 'least', 'historically', 'noteworthy', 'as', 'the', 'first', 'three', 'st', '##oo', '##ges', 'film', '(', 'when', 'they', 'weren', "'", 't', 'yet', 'on', 'their', 'own', 'and', 'were', 'still', 'saddle', '##d', 'with', 'that', 'painfully', 'un', '##fu', '##nn', '##y', 'ted', 'healy', ')', '.', 'but', 'even', 'as', 'a', 'longtime', 'st', '##oo', '##ges', 'fan', 'i', "'", 'd', 'have', 'to', 'say', 'that', 'young', 'moe', ',', 'larry', 'and', 'curly', 'are', 'badly', 'used', 'here', 'as', 'three', 'za', '##ny', 'assistant', 'jan', '##itors', 'to', 'mr', '.', 'healy', "'", 's', 'taller', 'boss', 'jan', '##itor', '.', 'they', "'", 're', 'not', 'featured', 'steadily', 'through', 'the', 'movie', 'and', 'their', 'silly', 'on', '-', 'and', '-', 'off', '-', 'again', 'stint', '##s', '

origin:
['[CLS]', 'so', 'this', 'made', 'for', 'tv', 'film', 'scores', 'only', 'a', '7', '.', '6', 'on', 'this', 'site', '?', 'ba', '##h', '!', 'hum', '##bu', '##g', '!', 'without', 'question', 'this', '1984', 'version', 'of', 'dickens', "'", 'classic', 'tale', 'is', 'the', 'best', 'ever', 'made', '.', 'and', 'yes', ',', 'the', 'hound', 'has', 'seen', 'the', '1951', 'version', 'which', 'was', 'also', 'good', ',', 'but', 'not', 'good', 'enough', '.', 'the', 'lack', 'of', 'color', 'is', 'perhaps', 'the', 'biggest', 'short', '##coming', 'of', 'that', 'version', ',', 'although', 'the', 'acting', 'was', 'wonderful', '.', '<', 'br', '/', '>', '<', 'br', '/', '>', 'george', 'c', '.', 'scott', 'is', 'simply', 'incredible', 'as', 'e', '##ben', '##ezer', 'sc', '##ro', '##oge', '.', 'we', 'all', 'know', 'the', 'story', 'of', 'this', 'sting', '##y', 'businessman', 'who', 'is', 'haunted', 'by', 'the', 'ghost', 'of', 'his', 'dead', 'partner', ',', 'then', 'by', 'three', 'other', 'spirits', 'later', 

origin:
['[CLS]', '*', '*', 'spoil', '##ers', '*', '*', 'k', '##ham', '##osh', 'is', 'totally', 'un', '##real', '##istic', ',', 'lacks', 'a', 'plot', ',', 'and', 'was', 'basically', 'only', 'made', 'to', 'see', 'stars', 'portray', 'themselves', '.', 'the', 'most', 'suspense', '##ful', 'scene', 'in', 'the', 'movie', 'was', 'when', 'sha', '##bana', 'az', '##mi', 'is', 'in', 'the', 'shower', 'and', 'then', 'we', 'see', 'her', 'tv', 'playing', 'the', 'shower', 'scene', 'from', 'psycho', '.', 'this', 'movie', 'actually', 'expected', 'users', 'to', 'believe', 'that', 'nas', '##eer', '##uddin', 'shah', "'", 's', 'character', 'has', 'a', 'good', 'enough', 'memory', 'to', 'remember', 'where', 'certain', 'shots', 'were', 'fired', 'and', 'how', 'many', '!', '<', 'br', '/', '>', '<', 'br', '/', '>', '*', '*', '*', 'spoil', '##er', 'begins', '*', '*', '*', '<', 'br', '/', '>', '<', 'br', '/', '>', 'at', 'the', 'end', ',', 'the', 'killer', 'spill', '##s', 'his', 'guts', 'to', 'sha', '##bana', 'az', 

KeyboardInterrupt: 