In [1]:
from LSTM_BigBirdA2C_GumbelSoftmax import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
# Dataset directory and the files read from it.
folder = '/tmp2/Food/'
data_name = '{}data.json'.format(folder)
# validation_name = '{}valid_seq.json'.format(folder)
# testdata_name = '{}testdata_seq.json'.format(folder)
vocab_name = '{}vocab.json'.format(folder)

In [3]:
import os

num_epochs = 10
save_rate = 1  # how many epochs per model save
#continue_from = "trained/Model1" # if none, put None
continue_from = None
epsilon = 1e-8
validation_size = 10000
# Use the GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint directory; pure-Python equivalent of `!mkdir -p trained`.
os.makedirs('trained', exist_ok=True)

In [4]:
# Token -> id mapping. The context manager closes the file handle, and the
# explicit UTF-8 encoding keeps non-ASCII word pieces readable on any locale.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
# Sequence-length limits passed to the data generator and the models:
# source (review) length and generated-summary length.
INPUT_MAX = 100
SUMM_MAX = 20
# BERT-style special tokens used as unknown / begin / end / padding markers.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 32

# Training data: shuffled batches of (src, label, tgt) tensors, padded with the PAD id.
training_set, training_generator = make_data_generator(
    data_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size,
    cutoff=None, shuffle=True, num_workers=4)

# validation_set, validation_generator = make_data_generator(
#     validation_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield training `Batch` objects whose tensors live on `device`.

    Each batch gets an extra `.label` attribute holding the labels as a long tensor.
    """
    for raw_src, raw_label, raw_tgt in training_generator:
        src_d, label_d, tgt_d = raw_src.to(device), raw_label.long().to(device), raw_tgt.to(device)
        batch = Batch(src_d, tgt_d, vocab[PAD])
        batch.label = label_d
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=568454), HTML(value='')))




In [6]:
import math
# Batches per epoch: ceil(size / batch_size), so a trailing partial batch is
# counted. Used as the tqdm `total` for the training progress bar.
total_train = int(math.ceil(training_set.size / batch_size))
# total_valid = int(math.ceil(validation_set.size / batch_size))
# print(total_train, total_valid)

In [7]:
def init_param(model):
    """Apply Glorot/Xavier-uniform init, in place, to every parameter with 2+ dims.

    1-D parameters (e.g. biases) are left with their existing values.
    """
    matrices = (weight for weight in model.parameters() if weight.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)

def make_big_bird(vocab, N=6, 
               d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False, bert_share=False):
    """Build a BigBird model from hyperparameters.

    The returned model combines an LSTM_Gumbel_Encoder_Decoder ("translator",
    INPUT_MAX -> SUMM_MAX tokens), an LSTM_Normal_Encoder_Decoder
    ("reconstructor", SUMM_MAX -> INPUT_MAX tokens) and an attention-Encoder
    based Discriminator.

    Args:
        vocab: token -> id mapping; ``len(vocab)`` sets the vocabulary size and
            ``vocab[PAD]`` supplies the padding id.
        N: number of EncoderLayer copies in each attention Encoder.
        d_model: hidden / embedding size shared by the LSTMs and attention blocks.
        d_ff: inner size of the position-wise feed-forward layers.
        h: number of attention heads.
        dropout: dropout for the feed-forward, positional-encoding and encoder
            layers (it is not passed to the LSTM translator/reconstructor).
        emb_share: if True, src/tgt/BERT embedding modules are one shared object.
        bert_share: if True, ``bert_discr`` is the same object as ``bert_class``.

    Returns:
        A ``BigBird`` wrapping translator, discriminator and reconstructor.

    Reads the module-level globals INPUT_MAX, SUMM_MAX, PAD and device.

    NOTE(review): src_emb, tgt_emb, bert_class_emb, bert_discr_emb, bert_class
    and bert_discr are constructed but never passed to BigBird, so emb_share /
    bert_share only affect these unused modules (plus memory and RNG draws).
    Confirm whether this is intended before removing them.
    """
    
    vocab_sz = len(vocab)
    
    # Prototype sub-modules; each user receives its own deep copy via c(...).
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    
    # Token embedding followed by a copy of the positional encoding.
    src_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
    if emb_share:        
        tgt_emb = src_emb
        bert_class_emb = src_emb
        bert_discr_emb = src_emb
    else:
        tgt_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
        bert_class_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
        bert_discr_emb = nn.Sequential(Embeddings(d_model, vocab_sz), c(position))
    
    
    # NOTE(review): bert_class / bert_discr are not used by the returned model.
    bert_class = BERT(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        bert_class_emb,
        vocab[PAD]
    )
    
    if bert_share:
        bert_discr = bert_class
    else:
        bert_discr = BERT(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
            bert_discr_emb,
            vocab[PAD]
        )
        
    # Generator: maps INPUT_MAX-token inputs to SUMM_MAX-token outputs (the
    # summaries printed during training). The critic network takes a
    # 2*d_model-wide input (presumably concatenated LSTM states; verify).
    # NOTE(review): eps duplicates the notebook-level `epsilon` constant.
    translator = LSTM_Gumbel_Encoder_Decoder(
        hidden_dim=d_model, 
        emb_dim=d_model, 
        input_len=INPUT_MAX, 
        output_len=SUMM_MAX, 
        voc_size=vocab_sz, 
        critic_net=CriticNet(2*d_model),
        device=device,
        eps=1e-8
    )
    
    # Reconstructor: the reverse direction, SUMM_MAX tokens back to INPUT_MAX tokens.
    reconstructor = LSTM_Normal_Encoder_Decoder(
        hidden_dim=d_model, 
        emb_dim=d_model, 
        input_len=SUMM_MAX, 
        output_len=INPUT_MAX, 
        voc_size=vocab_sz, 
        pad_index=vocab[PAD],
        eps=1e-8
    )
#     translator = Translator(
#         Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
#         Decoder(DecoderLayer(d_model, c(attn), c(attn), 
#                              c(ff), dropout), N),
#         src_emb,
#         tgt_emb,
#         Generator(d_model, vocab_sz, device),
#         CriticNet(d_model)
#         )
    
#     classifier = Classifier(
#         bert_class,
#         out_class = 5
#         # criterion = BCE
#     )
#     reconstructor = Reconstructor(
#         Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
#         Decoder(DecoderLayer(d_model, c(attn), c(attn), 
#                              c(ff), dropout), N),
#         src_emb,
#         tgt_emb,
#         Generator(d_model, vocab_sz, device),
#         vocab[PAD]
#     )   
    # Discriminator built on its own N-layer attention Encoder.
    discriminator = Discriminator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        d_model,
        len(vocab),
        vocab[PAD]
    )


    # This was important from their code. 
    # Initialize parameters with Glorot / fan_avg.
    # init_param re-initialises every parameter with more than one dimension
    # (xavier_uniform_); 1-D parameters keep their constructor values.
    for m in [translator, reconstructor, discriminator]:
        init_param(m)
        
#     if(str(device) == 'cpu'):
#         savedmodel = torch.load("pretrained/Translator4", map_location=lambda storage, location: storage)
#     else:
#         savedmodel = torch.load("pretrained/Translator4")
#     translator.load_state_dict(savedmodel['model'], strict=False)
#     if(str(device) == 'cpu'):
#         savedmodel = torch.load("pretrained/Translator4", map_location=lambda storage, location: storage)
#     else:
#         savedmodel = torch.load("pretrained/Translator4")
#     reconstructor.load_state_dict(savedmodel['model'])#, strict=False)
            
    # creation of big bird
    # lr_G / lr_D / lr_R are presumably the translator / discriminator /
    # reconstructor learning rates (judging by names; verify in BigBird).
    model = BigBird(
        translator, discriminator, reconstructor , 
        vocab, gamma=0.99, clip_value=0.1, #for WGAN, useless if WGAN-GP is used 
        lr_G = 5e-5,
        lr_D = 1e-4,
        lr_R = 2e-5,
        LAMBDA = 10, # Gradient penalty lambda hyperparameter
        RL_scale = 1,
        device = device
    )

    return model


In [8]:
# Small configuration: 4-layer encoders, 128-d model, 256-d feed-forward,
# 4 attention heads, with embeddings and BERT encoders shared.
model = make_big_bird(
    vocab,
    N=4,
    d_model=128,
    d_ff=256,
    h=4,
    dropout=0.1,
    emb_share=True,
    bert_share=True,
)
#model.load("Nest/NewbornBirdA2C_GumbelSoftmax")

In [9]:
# Reverse lookup table: token id -> token string.
vocab_inv = dict((token_id, token) for token, token_id in vocab.items())

def convert_ids_to_tokens(ids):
    """Map an iterable of token ids to a list of token strings (KeyError on unknown ids)."""
    return list(map(vocab_inv.__getitem__, ids))

In [None]:
# Epoch numbering: start at 1, or resume after the epoch encoded in a checkpoint
# name such as "trained/Model3" (-> start at epoch 4).
# NOTE(review): only the epoch counter changes here; checkpoint weights are not
# loaded, so load them into `model` before running this cell when resuming.
start = 1 if continue_from is None else int(continue_from.split("Model")[-1]) + 1
history = []


# Optional TensorBoard logging: uncomment both lines to enable it. While
# writer is None, every writer call below is skipped.
#from tensorboardX import SummaryWriter
#writer = SummaryWriter('mygraph')
writer = None
all_loss = []    # mean training loss per epoch
all_reward = []  # mean training reward per epoch

for epoch in range(start, num_epochs + 1):
    print("Epoch", epoch)

    # training
    stats = Stats()
    model.train()

    rewards = []

    progress = tqdm(data_gen_train(), total=total_train)
    for batch in progress:
        # Per the progress-bar labels: loss = (RL, G, D) losses and
        # score = (real_score, fake_score, acc, reward).
        loss, score = model.run_iter(batch.src, batch.src_mask, SUMM_MAX, batch.trg, writer, D_iters=5, verbose=1)
        progress.set_postfix(
            RL_loss='{:.3f}'.format(loss[0]),
            G_loss='{:.3f}'.format(loss[1]),
            D_loss='{:.3f}'.format(loss[2]),
            real_score='{:.3f}'.format(score[0]),
            fake_score='{:.3f}'.format(score[1]),
            acc='{:.3f}'.format(score[2]),
            reward='{:.3f}'.format(score[3]),
        )
        stats.update(sum(loss), 1, log=0)
        rewards.append(score[3])

    t_h = stats.history
    history.append(t_h)
    epoch_loss = np.mean(t_h)
    # An epoch with no batches yields NaN instead of raising ZeroDivisionError.
    epoch_reward = sum(rewards) / len(rewards) if rewards else float('nan')
    if writer is not None:
        # Log each metric under its own tag.
        writer.add_scalar('loss', epoch_loss, epoch)
        writer.add_scalar('reward', epoch_reward, epoch)
    print("[info] epoch train loss:", epoch_loss)
    print("[info] epoch train reward:", epoch_reward)
    all_loss.append(epoch_loss)
    all_reward.append(epoch_reward)

    # Checkpointing is currently disabled. To re-enable it, save every
    # `save_rate` epochs. The previous snippet also stored a validation loss
    # `v_h`, which no longer exists because validation is commented out:
    # if epoch % save_rate == 0:
    #     torch.save({'model': model.state_dict(), 'training_history': t_h},
    #                "trained/Model" + str(epoch))

if writer is not None:
    writer.close()

Epoch 1


HBox(children=(IntProgress(value=0, max=17765), HTML(value='')))

origin:
['[CLS]', 'but', 'my', 'daughter', 'is', 'and', 'she', 'said', "'", 'well', 'it', "'", 's', 'not', 'great', ',', 'but', 'it', "'", 's', 'coffee', 'in', 'a', 'can', 'so', 'what', 'can', 'you', 'expect', "'", '.', 'i', 'took', 'a', 'sip', 'and', 'it', "'", 's', 'really', 'a', 'me', '##h', '!', 'product', '.', '.', '.', 'not', 'as', 'bad', 'as', 'it', 'could', 'be', 'but', 'also', 'not', 'as', 'good', 'as', 'it', 'could', 'be', 'either', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', 'dealings', '##ং', '##washed', 'debates', 'cougars', 'which', 'vast', 'guillermo', 'empathy', 'ـ', 'personality', 'executive', 'chen', '##erated', '##weather', '1900', 'cyril', 'ministry', 'fragrance']
real summary:
['[CLS]

origin:
['[CLS]', 'these', 'chips', 'are', 'a', 'great', 'addition', 'to', 'lots', 'of', 'recipes', '!', 'i', 'add', 'them', 'to', 'pancakes', ',', 'banana', 'bread', ',', 'pumpkin', 'mu', '##ffin', '##s', 'or', 'bread', ',', 'my', 'morning', 'o', '##at', '##me', '##al', ',', 'o', '##at', '##me', '##al', 'cookie', 'dough', 'etc', '.', '.', '.', 'i', 'also', 'roll', 'a', 'tbs', '##p', '.', 'of', 'them', 'up', 'in', 'crescent', 'rolls', 'and', 'make', 'a', 'delicious', 'breakfast', 'treat', '!', 'very', 'versatile', 'and', 'flavor', '##ful', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', '区', 'curtiss', '##redo', 'chinese', 'inland', 'lisbon', 'returns', '##tan', 'swing', 'unaware', 'stephane', '##ulu', '##fles', 'attract', 'nostrils', '##ahl', 'winner',

origin:
['[CLS]', 'reality', 'strikes', '.', 'i', 'have', 'been', 'drinking', 'this', 'tea', 'every', 'night', 'for', '10', 'years', '!', 'it', 'is', 'the', 'best', 'thing', 'to', 'keep', 'me', 'on', 'the', 'straight', 'n', 'narrow', '!', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', 'challenges', '##year', 'can', '##g', 'percy', 'regards', 'after', '##ener', 'confederate', 'dairy', '##ص', 'dep

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Reward and loss have unrelated scales, so plot them on separate labeled axes.
fig, (ax_reward, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
ax_reward.plot(range(len(all_reward)), all_reward)
ax_reward.set(title="Mean training reward per epoch", xlabel="epoch index", ylabel="reward")
ax_loss.plot(range(len(all_loss)), all_loss)
ax_loss.set(title="Mean training loss per epoch", xlabel="epoch index", ylabel="loss")
fig.tight_layout()
plt.show()

In [None]:
#print(model.all_rewards)
#plt.plot(range(len(model.all_rewards)), model.all_rewards)

In [None]:
#plt.plot(range(len(model.all_rewards)-1), [sum(model.all_rewards[:i])/i for i in range(1,len(model.all_rewards))])