In [2]:
import math, copy, sys
import time

import torch
import torch.nn.functional as F

import nltk
nltk.download('wordnet') 

from scripts.MoveData import *
from scripts.Transformer import *
from scripts.TalkTrain import *

[nltk_data] Downloading package wordnet to /home/hari/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on a hard-coded "cuda" device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training configuration: batch size, device, epoch count, initial learning
# rate, maximum sequence length, and where the best weights are written.
opt = Options(batchsize=16, device=device, epochs=40,
              lr=0.01, max_len=40, save_path='data/transformer_custom_weights')

# json2datatools hands back an options object as its fourth value, so `opt` is
# rebound here. NOTE(review): presumably this is where fields such as `trg_pad`
# (read later by the training loop) get filled in -- confirm in scripts.MoveData.
data_iter, infield, outfield, opt = json2datatools(path='data/data1.json', opt=opt)
print('input vocab size', len(infield.vocab), 'output vocab size', len(outfield.vocab))

In [None]:
# Model hyper-parameters, passed positionally to Transformer below.
emb_dim = 32
n_layers = 2
heads = 8
dropout = 0.1

# The first two arguments are the input and output vocabulary sizes.
dole = Transformer(
    len(infield.vocab),
    len(outfield.vocab),
    emb_dim,
    n_layers,
    heads,
    dropout,
)

In [None]:
# Adam over every model parameter; the starting learning rate comes from `opt`.
optimizer = torch.optim.Adam(
    dole.parameters(),
    lr=opt.lr,
    betas=(0.9, 0.98),
    eps=1e-9,
)

# Scale the learning rate by 0.9 once the monitored value (lower is better)
# has gone more than 3 scheduler steps without improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.9,
    patience=3,
)

In [None]:
def trainer(model, data_iterator, options, optimizer, scheduler):
    """Train ``model`` for ``options.epochs`` epochs and return it.

    Each batch is expected to expose ``listen`` (source) and ``reply``
    (target) tensors. Both are transposed on their first two axes before use
    (assumes the iterator yields sequence-first tensors -- TODO confirm).
    The target is shifted by one position: the model is fed ``trg[:, :-1]``
    and scored against ``trg[:, 1:]`` with cross-entropy, skipping positions
    equal to ``options.trg_pad``.

    After every epoch the mean batch loss is passed to ``scheduler.step`` and,
    if it is the lowest seen so far, ``model.state_dict()`` is written to
    ``options.save_path``.

    Args:
        model: module called as ``model(src, src_mask, trg_input, trg_mask)``.
        data_iterator: re-iterable collection of batches (iterated once per epoch).
        options: object providing ``device``, ``epochs``, ``trg_pad`` and
            ``save_path``; it is also forwarded to ``create_masks``.
        optimizer: optimizer already bound to ``model``'s parameters.
        scheduler: object whose ``step(loss)`` is called once per epoch.

    Returns:
        The trained model (moved to ``options.device`` when CUDA is requested
        and available).

    Raises:
        ValueError: if ``data_iterator`` yields no batches in an epoch.
    """
    # Compare on the device *type*: torch.device("cuda") != torch.device("cuda:0"),
    # so an exact equality test silently skipped the GPU for index-less devices.
    device = torch.device(options.device)
    if device.type == "cuda" and torch.cuda.is_available():
        print("a GPU was detected, model will be trained on GPU")
        model = model.to(device)
    else:
        print("training on cpu")

    model.train()
    start = time.time()
    # Start from +inf so the first epoch always produces a checkpoint, whatever
    # the scale of the loss.
    best_loss = float("inf")

    for epoch in range(options.epochs):
        total_loss = 0.0
        n_batches = 0

        for batch in data_iterator:
            src = batch.listen.transpose(0, 1)
            trg = batch.reply.transpose(0, 1)

            # Teacher forcing: feed everything but the last token and predict
            # everything but the first.
            trg_input = trg[:, :-1]
            ys = trg[:, 1:].contiguous().view(-1)
            src_mask, trg_mask = create_masks(src, trg_input, options)

            optimizer.zero_grad()
            preds = model(src, src_mask, trg_input, trg_mask)
            # Flatten to (positions, vocab) so each position is one sample.
            preds = preds.view(-1, preds.size(-1))
            batch_loss = F.cross_entropy(preds, ys,
                                         ignore_index=options.trg_pad)
            batch_loss.backward()
            optimizer.step()

            total_loss += batch_loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("data_iterator yielded no batches; nothing to train on")

        # Average over the batches actually processed this epoch, rather than
        # re-walking the iterator and dividing by an off-by-one count.
        epoch_loss = total_loss / n_batches
        scheduler.step(epoch_loss)

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            print('saving model at', options.save_path)
            torch.save(model.state_dict(), options.save_path)

        print("%dm: epoch %d loss = %.3f" % ((time.time() - start) // 60, epoch, epoch_loss))

    return model

In [None]:
# Run the full training loop; `dole` is rebound to the model returned by trainer.
dole = trainer(
    model=dole,
    data_iterator=data_iter,
    options=opt,
    optimizer=optimizer,
    scheduler=scheduler,
)