In [1]:
from RelGAN import *
from dataset import make_data_generator
import json
from tqdm import tqdm_notebook as tqdm
import os
#from torchsummary import summary

In [2]:
# Dataset locations: every data file lives under one root folder.
folder = '/tmp2/Food/'
data_name = '{}data.json'.format(folder)
vocab_name = '{}vocab.json'.format(folder)
# Validation / test splits are currently disabled:
# validation_name = '{}valid_seq.json'.format(folder)
# testdata_name = '{}testdata_seq.json'.format(folder)

In [3]:
# Run configuration.
num_epochs = 10
# How many epochs per model save.
# NOTE(review): not read by the training cell below -- no checkpoint is written yet.
save_rate = 1
# Checkpoint to resume from, or None.
# NOTE(review): not read by any cell at the moment.
# continue_from = "trained/Model1"
continue_from = None
epsilon = 1e-8
validation_size = 10000
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pure-Python replacement for the IPython-only `!mkdir -p trained`.
os.makedirs('trained', exist_ok=True)
os.environ['WANDB_NOTEBOOK_NAME'] = 'LSTM_GumbelSoftmax'

In [4]:
# Token -> id vocabulary. JSON text is UTF-8 (RFC 8259), so decode explicitly
# instead of relying on the platform locale; `with` closes the file handle.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
INPUT_MAX = 50   # max source length passed to make_data_generator
SUMM_MAX = 50    # max target/summary length passed to make_data_generator
# BERT-style special token strings; vocab[PAD] is used as the padding id below.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 2

# Training split: a dataset object plus an iterator over (src, tgt) batches.
training_set, training_generator = make_data_generator(
    data_name, INPUT_MAX, SUMM_MAX, vocab[PAD], batch_size,
    pretrain=True, cutoff=None, shuffle=True, num_workers=4)

# Validation split (disabled). NOTE(review): OUTPUT_MAX is not defined in this
# notebook -- SUMM_MAX is presumably the intended target length if re-enabled.
# validation_set, validation_generator = make_data_generator(
#     validation_name, INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size, cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Iterate over `training_generator`, moving each (src, tgt) pair to `device`.

    Yields:
        A `Batch` built from the device-resident source and target tensors,
        with `vocab[PAD]` passed as the padding id.
    """
    for source, target in training_generator:
        yield Batch(source.to(device), target.to(device), vocab[PAD])

loading json
load json done.


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




In [6]:
import math
# Batches per training epoch (a final partial batch counts as one); used as tqdm's total.
total_train = math.ceil(training_set.size / batch_size)
# total_valid = math.ceil(validation_set.size / batch_size)
# print(total_train, total_valid)

In [7]:
def init_param(model):
    """Re-initialise `model` in place with Xavier/Glorot-uniform weights.

    Only parameters with more than one dimension are touched; 1-D parameters
    such as biases keep their current values.
    """
    weight_tensors = (param for param in model.parameters() if param.dim() > 1)
    for weight in weight_tensors:
        nn.init.xavier_uniform_(weight)

def _make_relational_memory(vocab_sz, d_model, cutoffs):
    """Build one RelationalMemory with the settings shared by translator and reconstructor."""
    return RelationalMemory(
        mem_slots=1,
        head_size=192,
        input_size=d_model,
        num_tokens=vocab_sz,
        device=device,
        num_heads=4,
        attention_mlp_layers=3,
        key_size=64,
        use_adaptive_softmax=True,
        # Per the note in the last notebook cell, overly large adaptive-softmax
        # cutoffs trigger `[enforce fail at CPUAllocator.cpp:56]`.
        cutoffs=list(cutoffs),
    )


def make_big_bird(vocab, N=6,
               d_model=512, d_ff=2048, h=8, dropout=0.1, emb_share=False, bert_share=False,
               cutoffs=(1000, 5000, 20000)):
    """Helper: Construct a model from hyperparameters.

    Builds a RelationalMemory translator and reconstructor plus an Encoder-based
    Discriminator, Xavier-initialises their weight matrices, and wraps all three
    in a BigBird.

    Args:
        vocab: token -> id mapping; len(vocab) is the vocabulary size and
            vocab[PAD] is handed to the Discriminator.
        N: number of stacked EncoderLayers in the discriminator.
        d_model: model width for attention / feed-forward layers and the
            RelationalMemory input size.
        d_ff: hidden width passed to PositionwiseFeedForward.
        h: passed to MultiHeadedAttention (presumably the head count).
        dropout: dropout for the feed-forward and encoder layers.
        emb_share, bert_share: kept for backward compatibility only. None of
            the modules built here consumes a separate embedding table, so
            these flags have no effect.
        cutoffs: adaptive-softmax cutoffs given to both RelationalMemory modules.

    Returns:
        The assembled BigBird model.
    """
    vocab_sz = len(vocab)

    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)

    # Earlier versions also built src/tgt/bert `Embeddings(d_model, vocab_sz)`
    # tables (+ positional encodings) here, but none was ever passed to a
    # module, so they only wasted memory and have been dropped.
    translator = _make_relational_memory(vocab_sz, d_model, cutoffs)
    reconstructor = _make_relational_memory(vocab_sz, d_model, cutoffs)

    discriminator = Discriminator(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        d_model,
        vocab_sz,
        vocab[PAD],
    )

    # Initialize parameters with Glorot / fan_avg (this mattered in the
    # reference implementation this code was adapted from).
    for m in (translator, reconstructor, discriminator):
        init_param(m)

    return BigBird(
        translator, discriminator, reconstructor,
        vocab, gamma=0.99, clip_value=0.5,  # for WGAN, useless if WGAN-GP is used
        lr_G=5e-5,
        lr_D=1e-4,
        lr_R=5e-5,
        LAMBDA=10,  # gradient-penalty lambda hyperparameter
        TEMP_END=0.5,
        device=device,
    )


In [8]:
model = make_big_bird(
    vocab,
    N=4,
    d_model=256,
    d_ff=256,
    h=4,
    dropout=0.1,
    emb_share=True,
    bert_share=True,
)
# To warm-start from a saved model instead:
# model.load("Nest/NewbornBird_LSTM_GumbelSoftmax")


In [9]:
# Reverse lookup table: token id -> token string.
vocab_inv = {token_id: token for token, token_id in vocab.items()}


def convert_ids_to_tokens(ids):
    """Map an iterable of token ids to a list of token strings via `vocab_inv`.

    Raises KeyError for an id that is not in the vocabulary.
    """
    return list(map(vocab_inv.__getitem__, ids))

In [10]:
import wandb
import matplotlib.pyplot as plt
import matplotlib

# Experiment tracking. Authenticate with `wandb login` or the WANDB_API_KEY
# environment variable -- never paste an API key into the notebook.
wandb.init(project="seq2seq-discrete-encoder-decoder")
wandb.config.update({
    "batch_size": batch_size,
    "input len": INPUT_MAX,
    "summary len": SUMM_MAX,
    "lr_G": model.lr_G,
    "lr_D": model.lr_D,
    "lr_R": model.lr_R,
    "temperature min": model.TEMP_END,
})
# Trailing `;` stops the cell from echoing the TorchGraph objects watch() returns.
wandb.watch([model.generator, model.discriminator, model.reconstructor]);


[<wandb.wandb_torch.TorchGraph at 0x7f9bbc0515d0>,
 <wandb.wandb_torch.TorchGraph at 0x7f9bb014afd0>,
 <wandb.wandb_torch.TorchGraph at 0x7f9bb011c550>]

In [11]:
# Pretraining loop. `model.epoch` is incremented once per finished epoch
# (end of the outer loop), so it counts completed epochs: training resumes
# from it and stops after `num_epochs` epochs in total.
start = model.epoch

step_ct = 1      # global iteration counter across epochs (1-based)
width = 0.35     # bar width for the distrib-vs-Gumbel histogram
HIST_EVERY = 50  # log the histogram every this many iterations


def _distrib_vs_gumbel_image(distrib, one_hot, bar_width):
    """Plot `distrib` (rescaled so its max is 1) beside the Gumbel-softmax sample
    `one_hot` as paired bars and return the figure as a wandb.Image."""
    fig, ax = plt.subplots()
    x = np.arange(len(distrib))
    ratio = 1.0 / max(distrib)
    ax.bar(x, ratio * distrib, label='distrib hist', align="edge", width=bar_width)
    ax.bar(x, one_hot, label='Gumbel softmax hist', align="edge", width=-bar_width)
    ax.legend()
    ax.set_title("distrib vs gumbel sample (max distrib is upscale to 1)")
    ax.set_xlabel("dictionary [:100]")
    ax.set_ylabel("prob")
    image = wandb.Image(fig)
    # Close the figure so no stray empty figure is left to render in the cell output.
    plt.close(fig)
    return image


# range(start, num_epochs): the previous `num_epochs + 1` ran one extra epoch
# (epochs 0..10 = 11 epochs when starting from 0).
for epoch in range(start, num_epochs):
    print("Epoch", epoch)
    model.train()

    progress = tqdm(data_gen_train(), total=total_train)
    for batch in progress:
        # src is expected to contain [CLS] and [SEP]; batch.trg.shape[1] is
        # the target length of this batch.
        GAN_loss, Rec_loss, score, output, distrib, one_hot = model.pretrainGAN_run_iter(
            batch.src, batch.src_mask, batch.trg.shape[1], batch.trg,
            D_iters=5, D_toggle='Off', verbose=0)

        metrics = {
            "G_loss": GAN_loss[0],
            "D_loss": GAN_loss[1],
            "CE_loss": Rec_loss[0],
            "real_score": score[0],
            "fake_score": score[1],
            "acc": score[2],
        }
        progress.set_postfix(**{name: '{:.3f}'.format(value) for name, value in metrics.items()})

        # One wandb.log call per iteration, so outputs, scalars and the
        # occasional histogram share a single step. Previously the scalars were
        # logged with commit=False after a committing call, which attached them
        # to the *next* step.
        log_data = {
            "input": output[0],
            "encode out": output[1],
            "reconsturct out": output[2],  # key spelling kept for continuity with existing runs
            **metrics,
            "gumbel temperature": model.gumbel_temperature,
        }
        if step_ct % HIST_EVERY == 0:
            log_data["hist"] = _distrib_vs_gumbel_image(distrib, one_hot, width)
        wandb.log(log_data)

        step_ct += 1
    model.epoch += 1
    # TODO: `save_rate` (config cell) is meant to control checkpointing, but no
    # checkpoint is written here yet.


Epoch 0


HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 47])
torch.Size([2, 46, 100000])
torch.Size([2, 46])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 24])
torch.Size([2, 23, 100000])
torch.Size([2, 23])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 46])
torch.Size([2, 45, 100000])
torch.Size([2, 45])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 47])
torch.Size([2, 46, 100000])
torch.Size([2, 46])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])


torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 30])
torch.Size([2, 29, 100000])
torch.Size([2, 29])
torch.Size([2, 43])
torch.Size([2, 42, 100000])
torch.Size([2, 42])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 32])
torch.Size([2, 31, 100000])
torch.Size([2, 31])
torch.Size([2, 44])
torch.Size([2, 43, 100000])
torch.Size([2, 43])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])


torch.Size([2, 41, 100000])
torch.Size([2, 41])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 26])
torch.Size([2, 25, 100000])
torch.Size([2, 25])
torch.Size([2, 32])
torch.Size([2, 31, 100000])
torch.Size([2, 31])
torch.Size([2, 45])
torch.Size([2, 44, 100000])
torch.Size([2, 44])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 27])
torch.Size([2, 26, 100000])
torch.Size([2, 26])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 44])
torch.Size([2, 43, 100000])
torch.Size([2, 43])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])


torch.Size([2, 48, 100000])
torch.Size([2, 48])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 31])
torch.Size([2, 30, 100000])
torch.Size([2, 30])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 33])
torch.Size([2, 32, 100000])
torch.Size([2, 32])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 30])
torch.Size([2, 29, 100000])
torch.Size([2, 29])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 49])
torch.Size([2, 48, 100000])
torch.Size([2, 48])
torch.Size([2, 33])
torch.Size([2, 32, 100000])
torch.Size([2, 32])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 37])
torch.Size([2, 36, 100000])
torch.Size([2, 36])


torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 29])
torch.Size([2, 28, 100000])
torch.Size([2, 28])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 28])
torch.Size([2, 27, 100000])
torch.Size([2, 27])
torch.Size([2, 42])
torch.Size([2, 41, 100000])
torch.Size([2, 41])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 39])
torch.Size([2, 38, 100000])
torch.Size([2, 38])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 34])
torch.Size([2, 33, 100000])
torch.Size([2, 33])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 33])
torch.Size([2, 32, 100000])
torch.Size([2, 32])


HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

torch.Size([2, 48])
torch.Size([2, 47, 100000])
torch.Size([2, 47])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 49])
torch.Size([2, 48, 100000])
torch.Size([2, 48])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 41])
torch.Size([2, 40, 100000])
torch.Size([2, 40])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 28])
torch.Size([2, 27, 100000])
torch.Size([2, 27])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 42])
torch.Size([2, 41, 100000])
torch.Size([2, 41])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 50])
torch.Size([2, 49, 100000])
torch.Size([2, 49])
torch.Size([2, 30])
torch.Size([2, 29, 100000])


KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

In [None]:
import matplotlib.pyplot as plt

In [None]:
# NOTE(review): no cell in this notebook defines `all_loss`, so plotting it
# unconditionally raised NameError on Restart & Run All. Only plot when present.
if 'all_loss' in globals():
    fig, ax = plt.subplots()
    ax.plot(range(len(all_loss)), all_loss)
    ax.set_title("all_loss")
    ax.set_xlabel("index")
    plt.show()
else:
    print("`all_loss` is not defined; nothing to plot.")

In [None]:
#print(model.all_rewards)
#plt.plot(range(len(model.all_rewards)), model.all_rewards)

In [None]:
#plt.plot(range(len(model.all_rewards)-1), [sum(model.all_rewards[:i])/i for i in range(1,len(model.all_rewards))])

In [None]:

# If the error `[enforce fail at CPUAllocator.cpp:56]` appears, the cutoffs of the adaptive softmax are too large.

