In [1]:
from transformer_BigBirdA2C import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook can still run on a
# machine without CUDA instead of failing when tensors are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# All input files for this run live under a single dataset directory.
folder = 'data/food/'
data_name = f'{folder}data.json'    # fed to make_data_generator as the training data
vocab_name = f'{folder}vocab.json'  # JSON vocabulary loaded in the next cell
# Not used in this notebook; kept for reference only:
# validation_name = folder+'valid_seq.json'
# testdata_name = folder+'testdata_seq.json'

In [4]:
# Read the vocabulary inside a context manager so the file handle is closed
# deterministically rather than being left open until garbage collection.
with open(vocab_name, 'r') as vocab_file:
    vocab = json.load(vocab_file)

VOC_SIZE = len(vocab)
# Both limits are forwarded to make_data_generator; SUMM_MAX is also passed to
# model.eval_iter (presumably the max input / summary lengths -- TODO confirm).
INPUT_MAX = 100
SUMM_MAX = 20
# Special-token strings. PAD is used as a key into `vocab` to obtain the
# padding id; UNK/BOS/EOS are not referenced in the visible cells.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 16

# make_data_generator returns a (dataset, generator) pair; the generator is
# iterated by data_gen_train and the dataset's `.size` is read further down.
training_set, training_generator = make_data_generator(
    data_name,
    INPUT_MAX,
    SUMM_MAX,
    vocab[PAD],
    batch_size,
    cutoff=None,
    shuffle=True,
    num_workers=4,
)

# validation_set, validation_generator = make_data_generator(\
# validation_name, INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield one ``Batch`` per item produced by ``training_generator``.

    Every item is unpacked as ``(src, label, tgt)``. All three are moved to
    ``device`` (the label after a ``.long()`` cast); ``src`` and ``tgt`` are
    wrapped in a ``Batch`` together with the padding id ``vocab[PAD]``, and
    the label is attached to that batch as ``.label``.
    """
    for raw_src, raw_label, raw_tgt in training_generator:
        src_on_device = raw_src.to(device)
        label_on_device = raw_label.long().to(device)
        tgt_on_device = raw_tgt.to(device)
        batch = Batch(src_on_device, tgt_on_device, vocab[PAD])
        batch.label = label_on_device
        yield batch

loading json
load json done.


HBox(children=(IntProgress(value=0, max=568454), HTML(value='')))




In [6]:
import math

# Batches per pass over the training set, rounded up so that a trailing
# partial batch is counted; used as the progress-bar total.
batches_exact = training_set.size / batch_size
total_train = int(math.ceil(batches_exact))
# total_valid = int(math.ceil(validation_set.size / batch_size))
# print(total_train, total_valid)

In [7]:
def init_param(model):
    """Xavier-uniform initialise every parameter of ``model`` with more than one dimension.

    Parameters whose ``dim()`` is 0 or 1 are left untouched. Returns ``None``;
    the parameters are modified through ``nn.init.xavier_uniform_``.
    """
    for weight in model.parameters():
        if weight.dim() <= 1:
            continue
        nn.init.xavier_uniform_(weight)

def make_big_bird(vocab, N=6,
                  d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False, bert_share=False):
    """Helper: construct a ``BigBird`` model from hyperparameters.

    The model is assembled from a ``Translator``, a ``Discriminator`` and a
    ``Reconstructor``. Translator and reconstructor receive the same source
    and target embedding modules. All three sub-networks are passed through
    ``init_param`` before being handed to ``BigBird``.

    Relies on two notebook-level globals: ``PAD`` (key of the padding token
    in ``vocab``) and ``device`` (forwarded to ``BigBird``).

    Args:
        vocab: mapping whose length is used as the vocabulary size and whose
            ``vocab[PAD]`` entry is used as the padding id; also passed to
            ``BigBird`` unchanged.
        N: second argument of every ``Encoder`` / ``Decoder`` (presumably the
            number of stacked layers -- confirm against the model module).
        d_model: width passed to the attention, feed-forward, positional
            encoding, embedding, generator, critic and discriminator modules.
        d_ff: forwarded to ``PositionwiseFeedForward``.
        h: forwarded to ``MultiHeadedAttention`` (presumably the head count).
        dropout: forwarded to the feed-forward, positional-encoding and
            encoder/decoder layer constructors.
        emb_share: if true, the target embedding is the very same module
            object as the source embedding; otherwise a separate one is built.
        bert_share: accepted for backward compatibility only. It used to
            choose between two ``BERT`` encoders that were constructed but
            never passed to ``BigBird`` or returned; they are no longer
            built, so this flag does not influence the returned model.

    Returns:
        The assembled ``BigBird`` instance.
    """
    vocab_sz = len(vocab)
    pad_id = vocab[PAD]

    c = copy.deepcopy
    # Prototype sub-layers: every consumer below gets its own deep copy, so no
    # attention / feed-forward weights are shared between networks.
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    def new_encoder():
        # Fresh encoder stack built from copies of the prototypes.
        return Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N)

    def new_decoder():
        # Fresh decoder stack; the two attention slots get separate copies.
        return Decoder(DecoderLayer(d_model, c(attn), c(attn),
                                    c(ff), dropout), N)

    def new_embedding():
        # Token embedding followed by a private copy of the positional encoding.
        return nn.Sequential(Embeddings(d_model, vocab_sz), c(position))

    src_emb = new_embedding()
    tgt_emb = src_emb if emb_share else new_embedding()

    translator = Translator(
        new_encoder(),
        new_decoder(),
        src_emb,
        tgt_emb,
        Generator(d_model, vocab_sz),
        CriticNet(d_model)
    )

    discriminator = Discriminator(
        new_encoder(),
        d_model,
        vocab_sz,
        pad_id
    )

    # Note: the reconstructor is given the same embedding modules as the translator.
    reconstructor = Reconstructor(
        new_encoder(),
        new_decoder(),
        src_emb,
        tgt_emb,
        Generator(d_model, vocab_sz),
        pad_id
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for m in [translator, reconstructor, discriminator]:
        init_param(m)

    # creation of big bird
    model = BigBird(
        translator, discriminator, reconstructor,
        vocab, gamma=0.99, clip_value=0.1,  # clip_value is for WGAN, useless if WGAN-GP is used
        lr_G=2e-4,
        lr_D=5e-5,
        lr_R=2e-5,
        LAMBDA=10,  # Gradient penalty lambda hyperparameter
        RL_scale=1,
        device=device
    )

    return model

In [8]:
# Build a small model and restore its weights from the saved checkpoint.
# The hyperparameters presumably have to match those the checkpoint was
# trained with -- TODO confirm.
model = make_big_bird(
    vocab,
    N=4,
    d_model=128,
    d_ff=256,
    h=4,
    dropout=0.1,
    emb_share=True,
    bert_share=True,
)
model.load("Nest/NewbornBirdA2C_GumbelSoftmax")

load Bird from Nest/NewbornBirdA2C_GumbelSoftmax


In [9]:
# Reverse lookup of `vocab`: maps each value of `vocab` back to its key.
vocab_inv = {token_id: token for token, token_id in vocab.items()}
def convert_ids_to_tokens(ids):
    """Return a list with the ``vocab_inv`` entry for every element of ``ids``.

    Order is preserved. An element that is not a key of ``vocab_inv`` raises
    ``KeyError``.
    """
    return list(map(vocab_inv.__getitem__, ids))

In [10]:
# Run model.eval_iter on every training batch and collect the two values it
# returns per batch (stored as accuracy and reward).
all_acc, all_reward = [], []
ct = 0  # running batch counter, handed to eval_iter as its last argument
trange = tqdm(enumerate(data_gen_train()), total=total_train)
for i, batch in trange:
    acc, reward = model.eval_iter(
        batch.src, batch.src_mask, SUMM_MAX, batch.trg, ct
    )
    all_acc.append(acc)
    all_reward.append(reward)
    ct += 1

HBox(children=(IntProgress(value=0, max=35529), HTML(value='')))

origin:
['[CLS]', 'i', 'really', 'liked', 'this', 'coffee', ';', 'my', 'boss', 'thought', 'it', 'a', 'little', 'strong', 'because', '<', 'br', '/', '>', 'he', 'likes', 'more', 'medium', 'smooth', '.', 'thus', 'the', 'reason', 'i', 'gave', '4', 'instead', 'of', '5', 'stars', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
summary:
['[CLS]', '!', '[CLS]', 'non', 'with', '[CLS]', '!', 'non', '!', '[CLS]', '!', 'had', 'non', 'with', 'soup', 'non', 'non', 'soup', 'sou

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Per-batch reward curve, followed by the mean over all evaluated batches.
reward_steps = range(len(all_reward))
plt.plot(reward_steps, all_reward)
reward_avg = sum(all_reward)/len(all_reward)
print("[eval] epoch reward avg:", reward_avg)

In [None]:
# Per-batch accuracy curve, followed by the mean over all evaluated batches.
acc_steps = range(len(all_acc))
plt.plot(acc_steps, all_acc)
acc_avg = sum(all_acc)/len(all_acc)
print("[eval] epoch acc avg:", acc_avg)