In [1]:
# Project modules come first: the star import re-exports many names used below
# (model classes, Batch, Stats, ...). The explicit imports that follow are placed
# after it so they take precedence on any name clash.
from LSTM_BigBirdA2C_GumbelSoftmax import *
from dataset import make_data_generator

import json
import os

import numpy as np
import torch
from tqdm import tqdm_notebook as tqdm
#from torchsummary import summary

In [2]:
folder = 'data/Food/'
# File names are appended directly, so `folder` must keep its trailing slash.
data_name = f'{folder}data.json'
vocab_name = f'{folder}vocab.json'
# Validation / test splits are currently disabled:
# validation_name = f'{folder}valid_seq.json'
# testdata_name = f'{folder}testdata_seq.json'

In [3]:
num_epochs = 10
save_rate = 1  # number of epochs between checkpoint saves
# Checkpoint name to resume numbering from, e.g. "trained/Model1"; None starts at epoch 1.
continue_from = None
epsilon = 1e-8
validation_size = 10000  # referenced only by the (commented-out) validation loader
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pure-Python replacement for `!mkdir -p trained` (portable, no shell needed).
os.makedirs('trained', exist_ok=True)

In [4]:
# Token -> id mapping; the context manager closes the file instead of leaking the handle,
# and an explicit encoding avoids depending on the platform default.
with open(vocab_name, 'r', encoding='utf-8') as vocab_file:
    vocab = json.load(vocab_file)
VOC_SIZE = len(vocab)
INPUT_MAX = 100  # source length passed to the data loader and the model's input_len
SUMM_MAX = 101   # target length for the loader; the model's output_len is SUMM_MAX - 1
# Special tokens (BERT-style names); [CLS] / [SEP] serve as begin / end of sequence.
UNK = "[UNK]"
BOS = "[CLS]"
EOS = "[SEP]"
PAD = "[PAD]"

In [5]:
batch_size = 72

# Build the pre-training dataset and its batching iterator.
# NOTE(review): cutoff=None presumably means "no cap on the number of examples" — confirm in dataset.py.
training_set, training_generator = make_data_generator(
    data_name,
    INPUT_MAX,
    SUMM_MAX,
    vocab[PAD],
    batch_size,
    pretrain=True,
    cutoff=None,
    shuffle=True,
    num_workers=4,
)

# Validation loader is currently disabled:
# validation_set, validation_generator = make_data_generator(
#     validation_name, INPUT_MAX, OUTPUT_MAX, vocab[PAD], batch_size,
#     cutoff=validation_size, shuffle=False, num_workers=4)

def data_gen_train():
    """Yield one ``Batch`` per item of ``training_generator``.

    Source and target tensors are moved to ``device`` before being wrapped,
    with ``vocab[PAD]`` passed to ``Batch`` as the padding id.
    """
    for src_ids, tgt_ids in training_generator:
        yield Batch(src_ids.to(device), tgt_ids.to(device), vocab[PAD])

loading json
load json done.


HBox(children=(IntProgress(value=0, max=568454), HTML(value='')))




In [6]:
import math

# Batches per epoch, rounded up; used as the progress-bar total in the training loop.
total_train = math.ceil(training_set.size / batch_size)
# Validation is disabled; with a validation set this would be:
# total_valid = math.ceil(validation_set.size / batch_size)

In [7]:
def init_param(model):
    """Apply Xavier-uniform initialisation to every parameter with more than one dimension.

    1-D parameters such as biases are left unchanged.
    """
    weight_matrices = (param for param in model.parameters() if param.dim() > 1)
    for weight in weight_matrices:
        nn.init.xavier_uniform_(weight)

# Use the config constants (VOC_SIZE, epsilon) instead of re-typing their values,
# so changing them in one place affects the model and optimizer consistently.
model = LSTM_Normal_Encoder_Decoder(
    hidden_dim=256,
    emb_dim=256,
    input_len=INPUT_MAX,
    output_len=SUMM_MAX - 1,
    voc_size=VOC_SIZE,
    pad_index=vocab[PAD],
    eps=epsilon,
).to(device)

# Xavier initialisation is currently disabled; uncomment to enable it.
# init_param(model)

from adabound import AdaBound

model_opt = AdaBound(model.parameters(), lr=1e-4, betas=(0.9, 0.998), final_lr=0.1, eps=epsilon)
# Alternative optimizer:
# model_opt = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.998), eps=epsilon)

# Summed NLL with padding positions ignored; the training loop divides by batch.ntokens.
criterion = torch.nn.NLLLoss(reduction='sum', ignore_index=vocab[PAD]).to(device)

In [8]:
# !pip install adabound

In [None]:
# Reverse lookup table: token id -> token string.
vocab_inv = dict(zip(vocab.values(), vocab.keys()))


def convert_ids_to_tokens(ids):
    """Return the list of token strings for an iterable of integer ids.

    Raises KeyError for an id that is not in the vocabulary.
    """
    return list(map(vocab_inv.__getitem__, ids))

In [None]:
# Epoch numbering continues after the checkpoint named in `continue_from`
# (e.g. "trained/Model3" -> start at epoch 4).
# NOTE(review): only the epoch counter is resumed here — no weights are loaded, and the
# "Model<N>" naming parsed below differs from the "pretrained/LSTM<N>" files saved by this loop.
start = 1 if continue_from is None else int(continue_from.split("Model")[-1]) + 1
history = []  # one `stats.history` entry appended per epoch

# Create the checkpoint directory once, before training (replaces `!mkdir` in the loop).
os.makedirs("pretrained", exist_ok=True)

for epoch in range(start, num_epochs + 1):
    print("Epoch", epoch)

    # ---- training ----
    stats = Stats()
    model.train()
    trange = tqdm(enumerate(data_gen_train()), total=total_train)

    for i, batch in trange:
        # Pre-training forward pass given the target sequence (`pretrian_forward` is the
        # method name as spelled in the model module).
        logits, next_words = model.pretrian_forward(
            x=batch.src,
            y=batch.trg,
            mode='sample',
        )
        n_positions = batch.trg_y.shape[0] * batch.trg_y.shape[1]
        # Summed NLL over non-pad target positions, normalised by the batch token count.
        loss = criterion(
            logits.view(n_positions, len(vocab)),
            batch.trg_y.contiguous().view(n_positions),
        ) / batch.ntokens

        model.zero_grad()
        loss.backward()
        model_opt.step()

        stats.update(loss, 1, log=0)

        # Every 1000 batches, print a sampled sequence next to its reference as a sanity check.
        if (i + 1) % 1000 == 0:
            print("\n")
            print(convert_ids_to_tokens([tok.item() for tok in next_words[0]]))
            print(convert_ids_to_tokens([tok.item() for tok in batch.trg_y[0]]))

        trange.set_postfix(
            loss='{:.3f}'.format(loss),
            tgt_len='{}'.format(batch.trg_y.shape[1]),
        )

    t_h = stats.history
    history.append(t_h)

    print("[info] epoch train loss:", np.mean(t_h))

    # Checkpoint every `save_rate` epochs. The original saved an undefined `translator`
    # and a bare `except` hid the NameError, so nothing was ever written. Save failures
    # are now reported instead of silently skipped, and training continues.
    if epoch % save_rate == 0:
        ckpt_path = "pretrained/LSTM" + str(epoch)
        try:
            torch.save({'model': model.state_dict(), 'training_history': t_h}, ckpt_path)
        except (OSError, RuntimeError) as err:
            print(f"[warn] failed to save checkpoint {ckpt_path}: {err}")

Epoch 1


HBox(children=(IntProgress(value=0, max=7896), HTML(value='')))

done


['i', 'are', 'a', 'a', 'the', '.', 'the', 'best', '.', '.', 'and', 'i', 'is', 'a', 'to', '.', '.', 'be', 'a', 'a', '.', '.', 'be', 'a', '[SEP]', "'", 'to', '.', 'be', '.', 'a', '.', '.', '.', '.', '.', '.', 'you', 'and', 'a', '.', '.', 'the', '.', 'be', 'best', '.', '[SEP]', 'is', 'is', 'the', 'little', '.', '[SEP]', "'", 'the', 'to', '.', 'little', '.', '.', '.', "'", 'the', '.', '.', '.', 'is', 'the', 'best', '.', '.', '.', '.', 'be', '.', '.', '.', '[SEP]']
['these', 'cost', 'more', 'on', 'amazon', 'than', 'the', 'grocery', 'store', ',', 'but', 'it', 'is', 'sometimes', 'more', 'convenient', 'to', 'just', 'have', 'them', 'shipped', 'to', 'you', '.', 'we', 'used', 'these', 'to', 'help', 'our', 'baby', 'feel', '"', 'more', 'comfortable', '"', 'if', 'she', 'was', 'having', 'issues', 'with', 'going', 'to', 'the', 'bathroom', '.', 'this', 'worked', 'like', 'a', 'charm', '.', 'we', 'ended', 'up', 'having', 'a', 'regular', 'rotation', 'where', 'we', 'made', 'sure', 'she', 'got', 'thi

In [None]:
# Join the per-epoch histories into one 1-D curve. Concatenating each raveled epoch also
# works when epochs recorded different numbers of entries, where np.asarray on the ragged
# list would fail (or build an object array that flatten() cannot unpack). Re-running this
# cell on the already-flat array yields the same array.
history = (
    np.concatenate([np.ravel(epoch_hist) for epoch_hist in history])
    if len(history) > 0
    else np.asarray([], dtype=float)
)

In [None]:
# NOTE(review): the x axis presumably indexes `stats.update` calls (one per batch) across
# all epochs — confirm against the Stats implementation.
fig, ax = plt.subplots()
ax.plot(history)
ax.set(title="Pre-training loss", xlabel="step (all epochs)", ylabel="training loss")
plt.show()