In [1]:
from RelGAN import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Dataset directory. The trailing slash is required because file names are
# appended to it directly.
folder = 'data/IMDB/'
data_name = '{}data.json'.format(folder)
vocab_name = '{}vocab.json'.format(folder)
# Validation / test inputs are currently disabled:
# validation_name = folder+'valid_seq.json'
# testdata_name = folder+'testdata_seq.json'

In [4]:
# Load the token -> id vocabulary. The context manager closes the file handle
# (the previous bare `open(...)` left it open), and the encoding is pinned to
# UTF-8 so the result does not depend on the platform's default encoding.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
# Length limits handed to make_data_generator; SUMM_MAX is also passed to
# model.eval_iter in the evaluation loop.
INPUT_MAX = 150
SUMM_MAX = 50
# Special token strings; PAD is looked up in `vocab` to obtain the padding id.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
# One example per batch.
batch_size = 1

# make_data_generator returns a pair: `training_set` (its `.size` is read
# later to size the progress bar) and `training_generator`, which is iterated
# as (src, label, tgt) triples in data_gen_train.
training_set, training_generator = make_data_generator(
    data_name,
    INPUT_MAX,
    SUMM_MAX,
    vocab[PAD],
    batch_size,
    cutoff=None,
    shuffle=True,
    num_workers=4,
)

# Validation generator is currently disabled (note: OUTPUT_MAX and
# validation_size are not defined anywhere in the visible cells):
# validation_set, validation_generator = make_data_generator(\
# validation_name, INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield one Batch per item produced by `training_generator`.

    Each item is unpacked as (src, label, tgt). `src` and `tgt` are moved to
    `device` and wrapped in a Batch together with the padding id
    `vocab[PAD]`. The label is ignored.
    """
    for source, _unused_label, target in training_generator:
        yield Batch(source.to(device), target.to(device), vocab[PAD])

loading json
load json done.


HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))




In [6]:
import math
# Number of iterations used as the progress-bar total in the evaluation loop:
# training_set.size divided by batch_size, rounded up so a final partial batch
# is counted. (Assumes training_set.size is the number of examples - TODO confirm.)
total_train = int(math.ceil(training_set.size / batch_size))
# Validation counterpart, disabled together with the validation generator:
# total_valid = int(math.ceil(validation_set.size / batch_size))
# print(total_train, total_valid)

In [7]:
def init_param(model):
    """Apply Xavier/Glorot uniform initialisation, in place, to every
    parameter of `model` that has more than one dimension.

    Parameters with a single dimension or fewer (e.g. bias vectors) are left
    untouched. Returns None.
    """
    multi_dim_params = (param for param in model.parameters() if param.dim() > 1)
    for weight in multi_dim_params:
        nn.init.xavier_uniform_(weight)

def _make_relational_memory(d_model, vocab_sz):
    "Helper: build one RelationalMemory with the settings shared by translator and reconstructor."
    return RelationalMemory(
        mem_slots=1,
        head_size=192,
        input_size=d_model,
        num_tokens=vocab_sz,
        device=device,
        num_heads=4,
        attention_mlp_layers=3,
        key_size=64,
        use_adaptive_softmax=True,
        # NOTE(review): if these cutoffs feed an adaptive softmax they
        # presumably require vocab_sz > 20000 - confirm against RelationalMemory.
        cutoffs=[1000, 5000, 20000],
    )


def make_big_bird(vocab, N=6,
                  d_model=512, d_ff=2048, h=8, dropout=0.1,
                  emb_share=False, bert_share=False):
    """Helper: Construct a BigBird model from hyperparameters.

    Builds two identically configured RelationalMemory modules (translator and
    reconstructor) and a Discriminator wrapping an Encoder, runs init_param on
    each of the three, and wraps them in a BigBird with fixed learning rates
    and GAN settings.

    Args:
        vocab: vocabulary mapping; its length is used as the token count, the
            entry for PAD is passed to the Discriminator, and the mapping
            itself is passed to BigBird.
        N: second argument of Encoder (presumably the number of stacked
            encoder layers - confirm against Encoder).
        d_model: width passed to the attention, feed-forward, encoder layer
            and discriminator, and used as `input_size` of both memories.
        d_ff: second argument of PositionwiseFeedForward.
        h: first argument of MultiHeadedAttention.
        dropout: dropout value passed to the feed-forward block and the
            encoder layer.
        emb_share: accepted for backward compatibility; has no effect. The
            previous implementation built embedding stacks depending on this
            flag but never passed them to any component, so that unused
            construction was removed.
        bert_share: accepted for backward compatibility; it was never read.

    Returns:
        The constructed BigBird instance.
    """
    vocab_sz = len(vocab)

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)

    translator = _make_relational_memory(d_model, vocab_sz)
    reconstructor = _make_relational_memory(d_model, vocab_sz)

    discriminator = Discriminator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        d_model,
        vocab_sz,
        vocab[PAD],
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for m in (translator, reconstructor, discriminator):
        init_param(m)

    # creation of big bird
    model = BigBird(
        translator, discriminator, reconstructor,
        vocab, gamma=0.99, clip_value=0.5,  # for WGAN, useless if WGAN-GP is used
        lr_G=5e-5,
        lr_D=1e-4,
        lr_R=2e-4,
        LAMBDA=10,  # Gradient penalty lambda hyperparameter
        TEMP_END=0.8,
        device=device,
    )

    return model


In [8]:
# Build the model with the hyperparameters used for this run, then load the
# saved state from the checkpoint path.
model = make_big_bird(
    vocab,
    N=4,
    d_model=256,
    d_ff=256,
    h=4,
    dropout=0.1,
    emb_share=True,
    bert_share=True,
)
model.load("Nest/DoubleRelationMEM_GAN")

load Bird from Nest/DoubleRelationMEM_GAN


In [None]:
# Reverse lookup table: id -> token (the inverse of `vocab`).
vocab_inv = dict(zip(vocab.values(), vocab.keys()))


def convert_ids_to_tokens(ids):
    """Map each id in `ids` to its token via `vocab_inv`.

    Returns a list of tokens in the same order as `ids`. Raises KeyError for
    an id that is not present in the vocabulary.
    """
    tokens = []
    for token_id in ids:
        tokens.append(vocab_inv[token_id])
    return tokens

In [None]:
# Run model.eval_iter on every batch, collecting the per-batch accuracy and
# loss it returns. `ct` is the running batch counter passed as the last argument.
all_acc, all_CE_loss = [], []
ct = 0
trange = tqdm(enumerate(data_gen_train()), total=total_train)
for i, batch in trange:
    acc, loss = model.eval_iter(batch.src, batch.src_mask, SUMM_MAX, batch.trg, ct)
    all_acc += [acc]
    all_CE_loss += [loss]
    ct = ct + 1

HBox(children=(IntProgress(value=0, max=25000), HTML(value='')))

origin:
[CLS] an absolute classic ! ! the direction is flawless , the acting is just superb . words fall short for this great work . the most definitive movie on mumbai police . this movie has stood the test of times . om puri gives a stellar performance , smita patil no less . all the actors have done their best and the movie races on thrilling you at every moment . this movie shakes your whole being badly and forces you to rethink about many issues that confront our society . this is the story of a cop ( om puri ) who starts out in his career as a honest man but ultimately degenerates into a killer . the first attempt in bollywood to get behind the scenes and expose the depressing truth about mumbai cops . kudos to nihalani ! ! after this movie a slew of
summary:
[CLS] brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally brutally br

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Per-batch loss collected by the evaluation loop.
plt.plot(range(len(all_CE_loss)), all_CE_loss)
plt.xlabel("batch")
plt.ylabel("CE loss")
plt.title("[eval] CE loss per batch")
# The values averaged here are the losses in all_CE_loss, so label them as
# such (the message previously said "reward"). Guard the empty case so an
# evaluation that produced no batches does not raise ZeroDivisionError.
if all_CE_loss:
    print("[eval] epoch CE loss avg:", sum(all_CE_loss) / len(all_CE_loss))
else:
    print("[eval] epoch CE loss avg: n/a (no batches evaluated)")

In [None]:
# Per-batch accuracy collected by the evaluation loop.
plt.plot(range(len(all_acc)), all_acc)
plt.xlabel("batch")
plt.ylabel("accuracy")
plt.title("[eval] accuracy per batch")
# Guard the empty case so an evaluation that produced no batches does not
# raise ZeroDivisionError.
if all_acc:
    print("[eval] epoch acc avg:", sum(all_acc) / len(all_acc))
else:
    print("[eval] epoch acc avg: n/a (no batches evaluated)")