In [1]:
from src.train import train_epoch, evaluate
import torch
from src.train.config import *
from src.preprocessing.config import *
from src.preprocessing import vocab_transform
from src.models.transformer import Seq2SeqTransformer
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from timeit import default_timer as timer
# Run-mode switches.
# INITIAL_TRAIN: True starts from freshly initialised weights, False resumes
# from the checkpoint stored at MODELS_PATH.
INITIAL_TRAIN = False
MODELS_PATH = '../data/interim/transf_cp.tar'

# Number of entries in each language's vocabulary; both values are handed to
# the Seq2SeqTransformer constructor further down.
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])



In [2]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [3]:
# Display the (source, target) vocabulary sizes as a quick sanity check.
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE

(19214, 10837)

In [4]:
# Cross-entropy over target tokens; targets equal to PAD_IDX are ignored.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
transformer = transformer.to(DEVICE)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

if INITIAL_TRAIN:
    # Fresh run: Xavier-initialise every parameter with more than one dimension
    # (1-D parameters keep their default initialisation).
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    # No epoch has been completed yet. Without this assignment the later
    # `train_wrapper(epoch, ...)` call fails with NameError on a fresh run.
    epoch = 0
else:
    # Resume: restore model weights, optimizer state and the last finished epoch.
    # NOTE(review): torch.load unpickles the file, so only point MODELS_PATH at
    # checkpoints you produced yourself.
    checkpoint = torch.load(MODELS_PATH, map_location=DEVICE)
    transformer.load_state_dict(checkpoint['model_state_dict'])
    # Restoring the optimizer state also restores its saved hyperparameters,
    # which take precedence over the lr passed to Adam above.
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']


In [5]:
# Learning-rate scheduler driven by the validation loss: `train_wrapper` reads
# this module-level object and calls `scheduler.step(val_loss)` once per epoch.
# mode='min' treats a lower metric as better; patience=2 is the number of
# non-improving epochs tolerated before the learning rate is reduced.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# against the installed version before upgrading.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2, verbose=True)

In [6]:
def train_wrapper(epoch_num, model, optimizer, loss_fn, run_name, num_epochs=10):
    """Train for ``num_epochs`` further epochs, logging losses and saving the best model.

    Each epoch runs ``train_epoch`` and ``evaluate``, steps the module-level
    ``scheduler`` with the validation loss, and writes both losses to
    TensorBoard. Whenever the validation loss beats the best value seen during
    this call, a checkpoint is written to
    ``../data/interim/runs/{run_name}/transf_cp.tar``.

    Args:
        epoch_num: Number of the last completed epoch; training continues at
            ``epoch_num + 1``.
        model: Model passed to ``train_epoch`` / ``evaluate``; its
            ``state_dict()`` is checkpointed.
        optimizer: Optimizer passed to ``train_epoch``; its ``state_dict()``
            is checkpointed.
        loss_fn: Loss function passed to ``train_epoch`` / ``evaluate``.
        run_name: Sub-directory name under ``../data/interim/runs`` used for
            the TensorBoard logs and the checkpoint.
        num_epochs: How many epochs to run (default 10).

    Note:
        Relies on the global ``scheduler`` being defined before the call.
    """
    run_dir = f'../data/interim/runs/{run_name}'
    best_val_loss = float('inf')
    writer = SummaryWriter(run_dir)
    try:
        for epoch in range(epoch_num + 1, epoch_num + num_epochs + 1):
            # The reported epoch time covers the training pass only.
            start_time = timer()
            train_loss = train_epoch(model, optimizer, loss_fn)
            end_time = timer()

            val_loss = evaluate(model, loss_fn)
            scheduler.step(val_loss)
            writer.add_scalars(
                'Training vs validation loss',
                {'Training': train_loss, 'Validation': val_loss},
                epoch,
            )

            if val_loss < best_val_loss:
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        # Must be the optimizer's own state so that a later
                        # optimizer.load_state_dict(...) can resume training.
                        'optimizer_state_dict': optimizer.state_dict(),
                    },
                    f'{run_dir}/transf_cp.tar',
                )
                print('## Vall loss decreased, model succefully saved ##')
                best_val_loss = val_loss

            print(
                f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
                f"Epoch time = {(end_time - start_time):.3f}s"
            )
    finally:
        # Flush pending events and release the log file even if training aborts.
        writer.close()

In [7]:
import warnings
# Silence every warning for the rest of the session so the training log stays
# readable. NOTE(review): this is a blanket filter and also hides deprecation
# warnings — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

In [8]:
# Continue training after the epoch restored from the checkpoint; TensorBoard
# logs and the best checkpoint are written under ../data/interim/runs/try_2.
train_wrapper(epoch, transformer, optimizer, loss_fn, 'try_2')

## Vall loss decreased, model succefully saved ##
Epoch: 10, Train loss: 2.155, Val loss: 3.690, Epoch time = 102.823s
Epoch: 11, Train loss: 1.966, Val loss: 3.755, Epoch time = 100.174s
Epoch: 12, Train loss: 1.789, Val loss: 3.808, Epoch time = 101.949s
Epoch 00004: reducing learning rate of group 0 to 1.0000e-05.
Epoch: 13, Train loss: 1.632, Val loss: 3.826, Epoch time = 102.310s
Epoch: 14, Train loss: 1.430, Val loss: 3.725, Epoch time = 101.608s
Epoch: 15, Train loss: 1.371, Val loss: 3.725, Epoch time = 105.622s
Epoch 00007: reducing learning rate of group 0 to 1.0000e-06.
Epoch: 16, Train loss: 1.335, Val loss: 3.729, Epoch time = 105.825s
Epoch: 17, Train loss: 1.314, Val loss: 3.720, Epoch time = 102.386s
Epoch: 18, Train loss: 1.303, Val loss: 3.720, Epoch time = 101.134s
Epoch 00010: reducing learning rate of group 0 to 1.0000e-07.
Epoch: 19, Train loss: 1.299, Val loss: 3.718, Epoch time = 101.474s
