In [1]:
from transformer_nb2 import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
folder = 'data/'
# Input files, all resolved relative to `folder`.
data_name = '{}data.json'.format(folder)
# validation_name = '{}valid_seq.json'.format(folder)
# testdata_name = '{}testdata_seq.json'.format(folder)
vocab_name = '{}vocab.json'.format(folder)

In [3]:
import os

num_epochs = 10
save_rate = 1  # how many epochs per model save (NOTE(review): not referenced by the training cell)
#continue_from = "trained/Model1" # if none, put None
continue_from = None  # "trained/Model<k>" to resume epoch numbering after epoch k, or None
epsilon = 1e-8
validation_size = 10000  # cutoff for the (currently commented-out) validation generator
# Fall back to CPU so the notebook can at least start on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Portable, pure-Python replacement for the IPython shell escape `!mkdir -p trained`.
os.makedirs('trained', exist_ok=True)

In [4]:
# JSON is UTF-8 by spec and this vocab contains non-ASCII word pieces, so pin the encoding
# instead of relying on the platform default; `with` closes the handle deterministically.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)  # token string -> integer id
VOC_SIZE = len(vocab)
INPUT_MAX = 100  # source length limit passed to make_data_generator (TODO confirm exact semantics)
SUMM_MAX = 20    # summary length limit passed to make_data_generator and model.run_iter
# Special-token strings, looked up in `vocab` to get their ids.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 90

# Training split: shuffled, no size cutoff.
training_set, training_generator = make_data_generator(
    data_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size,
    cutoff=None, shuffle=True, num_workers=4,
)

# Validation split (disabled). NOTE: OUTPUT_MAX does not exist in this notebook; SUMM_MAX is the summary limit.
# validation_set, validation_generator = make_data_generator(
#     validation_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield training `Batch` objects whose tensors live on `device`.

    Wraps the notebook-level `training_generator`, which yields (src, label, tgt) triples.
    Labels are shifted from 1..5 to 0..4 and cast to long so they can serve as
    CrossEntropy class indices; they are attached to each batch as `batch.label`.
    """
    for source, rating, target in training_generator:
        source_dev = source.to(device)
        rating_dev = (rating - 1).long().to(device)  # CrossEntropy classes: 1~5 -> 0~4
        target_dev = target.to(device)
        batch = Batch(source_dev, target_dev, vocab[PAD])
        batch.label = rating_dev
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=568454), HTML(value='')))




In [6]:
import math

# Batches per training epoch (a trailing partial batch counts); used as the progress-bar total.
# math.ceil already returns an int in Python 3.
total_train = math.ceil(training_set.size / batch_size)
# total_valid = math.ceil(validation_set.size / batch_size)
# print(total_train, total_valid)

In [7]:
def make_translator(src_vocab, tgt_vocab, N=6,
               d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False):
    """Build an encoder-decoder `Translator` from hyperparameters.

    Args:
        src_vocab: source vocabulary size (given to `Embeddings`).
        tgt_vocab: target vocabulary size (given to `Embeddings` and `Generator`).
        N: number of encoder layers and of decoder layers.
        d_model: model width.
        d_ff: hidden width of the position-wise feed-forward sublayers.
        h: number of attention heads.
        dropout: dropout rate used by the sublayers and positional encoding.
        emb_share: if True the target side reuses the source embedding module,
            so `tgt_vocab` only affects the `Generator`.

    Returns:
        The assembled `Translator`; no explicit parameter initialisation is applied here.
    """
    clone = copy.deepcopy
    attn_proto = MultiHeadedAttention(h, d_model)
    ff_proto = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_proto = PositionalEncoding(d_model, dropout)

    src_embed = nn.Sequential(Embeddings(d_model, src_vocab), clone(pos_proto))
    if emb_share:
        tgt_embed = src_embed
    else:
        tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab), clone(pos_proto))

    # Modules are built in the same order as before so random initialisation is unchanged.
    encoder = Encoder(EncoderLayer(d_model, clone(attn_proto), clone(ff_proto), dropout), N)
    decoder = Decoder(
        DecoderLayer(d_model, clone(attn_proto), clone(attn_proto), clone(ff_proto), dropout),
        N,
    )
    generator = Generator(d_model, tgt_vocab)
    return Translator(encoder, decoder, src_embed, tgt_embed, generator)

In [8]:
def make_classifier(src_vocab, N=6,
               d_model=512, d_ff=2048, h=8, dropout=0.1, pad_idx=None):
    """Build a `Classifier` on top of a BERT-style encoder.

    Args:
        src_vocab: vocabulary size (given to `Embeddings`).
        N: number of encoder layers.
        d_model: model width.
        d_ff: hidden width of the position-wise feed-forward sublayers.
        h: number of attention heads.
        dropout: dropout rate used by the sublayers and positional encoding.
        pad_idx: padding token id handed to `BERT`. Defaults to ``vocab[PAD]`` read from the
            notebook globals (the previous, implicit behaviour); pass it explicitly to avoid
            depending on that hidden global state.

    Returns:
        The assembled `Classifier`.
    """
    if pad_idx is None:
        pad_idx = vocab[PAD]

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    bert = BERT(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        pad_idx
    )

    model = Classifier(
        bert
        # criterion = CE
    )

    return model

In [9]:
def make_discriminator(src_vocab, N=6,
               d_model=512, d_ff=2048, h=8, dropout=0.1, pad_idx=None):
    """Build a `Discriminator` on top of a BERT-style encoder.

    Args:
        src_vocab: vocabulary size (given to `Embeddings`).
        N: number of encoder layers.
        d_model: model width.
        d_ff: hidden width of the position-wise feed-forward sublayers.
        h: number of attention heads.
        dropout: dropout rate used by the sublayers and positional encoding.
        pad_idx: padding token id handed to `BERT`. Defaults to ``vocab[PAD]`` read from the
            notebook globals (the previous, implicit behaviour); pass it explicitly to avoid
            depending on that hidden global state.

    Returns:
        The assembled `Discriminator`.
    """
    if pad_idx is None:
        pad_idx = vocab[PAD]

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    bert = BERT(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        pad_idx
    )

    model = Discriminator(
        bert
    )

    return model

In [10]:
def init_param(model):
    """Re-initialise, in place, every parameter of `model` that has two or more dimensions.

    Those tensors get Xavier/Glorot-uniform values; 1-D parameters (e.g. biases) are left
    as they are. Returns None.
    """
    matrices = (param for param in model.parameters() if param.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

def make_big_bird(vocab, N=6,
               d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False, bert_share=False,
               gamma=0.99, clip_value=0.1, lr_G=5e-5, lr_D=5e-5, lr_C=1e-4,
               LAMBDA=1, RL_scale=100):
    """Assemble a `BigBird` (translator + discriminator + classifier) from hyperparameters.

    Args:
        vocab: token -> id mapping. Its length sets every embedding / generator vocabulary size,
            ``vocab[PAD]`` is the padding id given to the BERT encoder(s), and the mapping itself
            is passed on to `BigBird`.
        N, d_model, d_ff, h, dropout: Transformer depth, width, feed-forward width, number of
            attention heads and dropout rate, shared by all sub-models.
        emb_share: if True, a single embedding module is used by the translator's source and
            target sides and by both BERT encoders.
        bert_share: if True, the classifier and discriminator use the same `BERT` encoder
            (the separately built discriminator embedding is then unused).
        gamma, clip_value, lr_G, lr_D, lr_C, LAMBDA, RL_scale: forwarded unchanged to `BigBird`.
            The defaults are the values this function previously hard-coded. `clip_value` is
            for WGAN weight clipping (useless if WGAN-GP is used); `LAMBDA` is the gradient
            penalty weight.

    Returns:
        The `BigBird` model. Xavier-uniform init has been applied to every >1-D parameter of
        the translator, classifier and discriminator.
    """
    vocab_sz = len(vocab)

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    src_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
    if emb_share:
        tgt_emb = src_emb
        bert_class_emb = src_emb
        bert_discr_emb = src_emb
    else:
        tgt_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
        bert_class_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
        bert_discr_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))

    bert_class = BERT(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        bert_class_emb,
        vocab[PAD]
    )

    if bert_share:
        bert_discr = bert_class
    else:
        bert_discr = BERT(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
            bert_discr_emb,
            vocab[PAD]
        )

    translator = Translator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        src_emb,
        tgt_emb,
        Generator(d_model, vocab_sz))

    classifier = Classifier(
        bert_class
        # criterion = BCE
    )

    discriminator = Discriminator(
        bert_discr
    )

    # Carried over from the upstream reference code: Glorot / fan_avg (Xavier-uniform) init.
    # Modules shared via emb_share / bert_share are simply re-initialised once per owner.
    for m in [translator, classifier, discriminator]:
        init_param(m)

    model = BigBird(
        translator, discriminator, classifier,
        vocab, gamma=gamma, clip_value=clip_value,
        lr_G=lr_G,
        lr_D=lr_D,
        lr_C=lr_C,
        LAMBDA=LAMBDA,
        RL_scale=RL_scale
    )

    return model


In [11]:
# Small configuration: 4 layers, width 64, with embeddings and the BERT encoder shared.
model = make_big_bird(
    vocab,
    N=4, d_model=64, d_ff=512, h=8, dropout=0.1,
    emb_share=True, bert_share=True,
)
# Presumably restores previously saved weights (the cell output reports "load Bird from ...").
model.load("Nest/NewbornBird")

load Bird from Nest/NewbornBird


In [12]:
vocab_inv = {token_id: token for token, token_id in vocab.items()}  # inverse vocabulary: id -> token
def convert_ids_to_tokens(ids, inv_vocab=None):
    """Map a sequence of token ids back to their vocabulary strings.

    Args:
        ids: iterable of integer ids. These may be Python ints or integer scalars such as the
            elements of a 1-D torch tensor. Each id is converted with ``int()`` before lookup,
            because tensor elements hash by identity and would never match the dict keys.
        inv_vocab: id -> token mapping; defaults to the notebook-level ``vocab_inv``.

    Returns:
        List of token strings, in the same order as ``ids``.

    Raises:
        KeyError: if an id is not present in the mapping.
    """
    lookup = vocab_inv if inv_vocab is None else inv_vocab
    return [lookup[int(i)] for i in ids]

In [13]:
# Resume epoch numbering after a checkpoint named "…Model<k>" (e.g. "trained/Model3" -> start at 4).
# NOTE(review): only the epoch counter resumes; this cell does not load weights from `continue_from`.
start = 1 if continue_from is None else int(continue_from.split("Model")[-1]) + 1
history = []

# Display names for the four `loss` and four `score` entries returned by `model.run_iter`, in order.
LOSS_NAMES = ('RL_sample_loss', 'RL_argmax_loss', 'G_loss', 'D_loss')
SCORE_NAMES = ('real_score', 'fake_score', 'sample_acc', 'argmax_acc')

for epoch in range(start, num_epochs + 1):
    print("Epoch", epoch)

    # training
    stats = Stats()
    model.train()
    progress = tqdm(data_gen_train(), total=total_train)
    for batch in progress:
        loss, score = model.run_iter(batch.src, batch.src_mask, SUMM_MAX, batch.trg, batch.label, D_iters=1)

        postfix = {name: '{:.3f}'.format(value) for name, value in zip(LOSS_NAMES, loss)}
        postfix.update({name: '{:.3f}'.format(value) for name, value in zip(SCORE_NAMES, score)})
        # Passed as the ordered dict (not **kwargs, which tqdm sorts alphabetically) so the
        # metrics appear in the logical order above.
        progress.set_postfix(postfix)

        stats.update(sum(loss), 1, log=0)

    t_h = stats.history
    history.append(t_h)

    print("[info] epoch train loss:", np.mean(t_h))

    # TODO(review): checkpointing is disabled, so `save_rate` is unused and nothing is written
    # to trained/. The old save code below also referenced `v_h` from the commented-out
    # validation pass and hid every error behind a bare `except`:
    #     torch.save({'model': model.state_dict(), 'training_history': t_h, 'validation_loss': np.mean(v_h)},
    #                "trained/Model" + str(epoch))

Epoch 1


HBox(children=(IntProgress(value=0, max=6317), HTML(value='')))

origin:
['[CLS]', 'granddaughter', 'in', 'military', 'and', 'extremely', 'fond', 'of', 'purple', 'ski', '##ttle', '##s', '.', 'she', 'was', 'extremely', 'happy', 'with', 'gift', '.', 'would', 'purchase', 'this', 'again', 'as', 'a', 'christmas', 'gift', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', 'causes', 'causes', 'causes', 'causes', 'causes', 'donation', '##arty', 'causes', 'c

origin:
['[CLS]', 'this', 'one', 'will', 'not', 'di', '##sa', '##pp', '##oint', '!', 'i', 'got', 'this', 'yesterday', 'and', 'my', 'husband', 'and', 'i', 'tried', 'it', 'this', 'morning', 'with', 'breakfast', 'and', 'we', 'both', 'loved', 'it', '.', 'i', 'have', 'gotten', 'many', 'different', 'kinds', 'of', 'the', 'k', '-', 'cups', 'to', 'try', 'but', 'after', 'trying', 'this', 'one', ',', 'i', 'think', 'i', 'could', 'be', 'happy', 'with', 'just', 'this', 'one', 'and', 'black', 'tiger', 'for', 'those', 'mornings', 'when', 'i', 'want', 'a', 'really', 'strong', 'cup', 'of', 'coffee', '.', 'the', 'rodeo', 'drive', 'blend', 'is', 'delicious', 'with', 'a', 'nice', 'round', 'feel', 'on', 'the', 'pal', '##ate', 'and', 'an', 'interesting', 'flavor', 'on', 'the', 'finish', 'with', 'just']
summary:
['[CLS]', 'constructing', 'briefcase', 'healthcare', '##ucible', 'shuddering', 'tombstone', '##lau', '##lau', '##ucible', 'oceania', 'indie', '##cht', '##lau', '##ucible', 'madame', '##ucible', 'healt

origin:
['[CLS]', 'interesting', 'concept', '!', 'very', 'ta', '##sty', 'too', ',', 'i', 'used', 'the', 'chocolate', 'p', '##b', '##2', 'in', 'my', 'vanilla', 'ice', 'cream', 'and', 'it', 'was', 'a', 'lean', 'alternative', 'to', 'flavor', '##ed', 'syrup', '##s', '.', 'i', 'purchased', 'these', 'products', 'to', 'cut', 'down', 'on', 'the', 'fat', 'from', 'regular', 'peanut', 'butter', '.', 'i', 'would', 'order', 'it', 'again', 'although', 'it', 'is', 'a', 'bit', 'price', '##y', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', '↑', '↑', '##cht', 'oceania', 'oceania', 'pt', 'pt', 'pt', 'oceania', 'oceania', 'pt', 'goodwill', '↑', 'oceania', 'oceania', '##lova', 'pt', '

origin:
['[CLS]', 'the', 'wonderful', 'flavor', 'of', 'ja', '##la', '##pen', '##o', "'", 's', 'in', 'this', 'zip', 'locked', 'bag', 'fills', 'the', 'air', 'once', 'it', 'is', 'opened', '.', 'i', 'originally', 'purchased', 'this', 'package', 'of', 'jack', 'link', 'for', 'snack', '##ing', 'but', 'found', 'myself', 'mu', '##nch', '##ing', 'until', 'the', 'entire', 'bag', 'is', 'empty', '.', 'i', 'first', 'across', 'them', 'on', 'store', 'shelf', '##s', 'in', 'toledo', ',', 'ohio', 'and', 'have', 'not', 'been', 'able', 'to', 'find', 'them', 'in', 'any', 'stores', 'in', 'the', 'fort', 'worth', ',', 'texas', 'area', '(', 'until', 'my', 'only', 'other', 'source', 'amazon', '.', 'com', ')', '.', 'i', 'love', 'the', 'flavor', 'of', 'the', 'ja', '##la', '##pen', '##o', 'and']
summary:
['[CLS]', 'marcelo', 'each', 'bears', 'accidents', 'virtues', '[unused709]', 'patio', '[unused801]', '##ɑ', '##acio', 'marshall', 'training', '##ɑ', 'nintendo', '##vac', '[unused801]', 'welcoming', 'robbery', '##va

origin:
['[CLS]', 'i', 'add', 'this', 'to', 'my', 'morning', 'cereal', 'with', 'fresh', 'fruit', 'for', 'the', 'omega', '3', '.', 'it', 'does', 'not', 'really', 'effect', 'the', 'flavor', 'one', 'way', 'or', 'another', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', '##un', 'liquid', 'bobby', 'freddy', 'australian', 'announcements', 'sunlight', 'bit', 'usc', 'replacement', 

origin:
['[CLS]', 'very', 'good', 'product', 'to', 'have', 'on', 'hand', 'if', 'you', 'need', 'a', 'quick', 'on', '-', 'the', '-', 'go', 'snack', '.', 'good', 'tasting', 'and', 'easy', 'to', 'eat', '.', 'i', 'keep', 'a', 'couple', 'in', 'my', 'car', 'to', 'snack', 'on', 'if', 'i', 'can', "'", 't', 'stop', 'for', 'a', 'meal', 'right', 'away', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', '##ᵉ', 'dive', 'programmed', 'automobiles', '##¾', '##η', '##η', '##¾', 'fascist', 'phyllis', '##ones', '##η', '##¾', 'automobiles', '##¾', '##¾', '##ones', 'coffin', '##gil

origin:
['[CLS]', 'this', 'is', 'one', 'of', 'the', 'best', 'chew', '##ies', 'i', 'have', 'ever', 'given', 'my', 'dogs', '.', 'they', 'love', 'them', 'and', 'are', 'busy', 'for', 'days', 'to', 'weeks', 'because', 'of', 'the', 'knot', 'in', 'the', 'center', '.', 'i', 'have', 'never', 'had', 'one', 'go', 'to', 'waste', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', 'terrible', 'bean', 'botswana', 'stationed', '[unused611]', 'happened', '##ヒ', 'explained', 'isabel', '##ished', 'flickered', 'inferno', 'infern

KeyboardInterrupt: 