In [1]:
import torch
from src.train.config import *
from src.preprocessing.config import *
from src.preprocessing import vocab_transform
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.optim
import time
import warnings
from torch import optim

# Vocabulary sizes for the source and target languages, taken as the length
# of the matching entries of the project's vocab_transform mapping.
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE = (
    len(vocab_transform[language]) for language in (SRC_LANGUAGE, TGT_LANGUAGE)
)

# Relative path, presumably of a checkpoint archive; it is not referenced by
# the cells visible in this notebook -- TODO confirm it is still needed.
MODELS_PATH = '../data/interim/transf_cp.tar'

# Install a blanket "ignore" filter so Python warnings are not shown from
# here on.
warnings.filterwarnings('ignore')



In [2]:
from torchtext.datasets import Multi30k, IWSLT2016
from torch.utils.data import DataLoader
from src.preprocessing import collate_fn
# Both splits share the same language pair and the same DataLoader settings.
language_pair = (SRC_LANGUAGE, TGT_LANGUAGE)
loader_options = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=NUM_WORKERS)

# Training split. Only Multi30k is used; concatenating the IWSLT2016 'train'
# split onto it is left disabled.
train_iter = Multi30k(split='train', language_pair=language_pair)
train_dataloader = DataLoader(train_iter, **loader_options)

# Validation split.
val_iter = Multi30k(split='valid', language_pair=language_pair)
val_dataloader = DataLoader(val_iter, **loader_options)


In [3]:
from src.models.seq2seq import Attention, Encoder, Decoder, Seq2Seq
from src.train.seq2seq_train import train, evaluate
from src.train.utils import epoch_time
import math

In [5]:
# Assemble the network bottom-up: attention -> encoder -> decoder -> wrapper.
attn = Attention(S2S_ENC_HID_DIM, S2S_DEC_HID_DIM)

enc = Encoder(
    SRC_VOCAB_SIZE,
    S2S_ENC_EMB_DIM,
    S2S_ENC_HID_DIM,
    S2S_DEC_HID_DIM,
    S2S_ENC_DROPOUT,
)

# The decoder receives the attention module built above as its last argument.
dec = Decoder(
    TGT_VOCAB_SIZE,
    S2S_DEC_EMB_DIM,
    S2S_ENC_HID_DIM,
    S2S_DEC_HID_DIM,
    S2S_DEC_DROPOUT,
    attn,
)

# Wrap encoder and decoder, then move the whole model onto DEVICE.
model = Seq2Seq(enc, dec, DEVICE, PAD_IDX)
model = model.to(DEVICE)

In [6]:
# Adam over every model parameter, with the optimiser's default settings.
trainable_params = model.parameters()
optimizer = optim.Adam(trainable_params)

# Cross-entropy loss; targets equal to PAD_IDX are excluded via ignore_index.
criterion = nn.CrossEntropyLoss(
    ignore_index=PAD_IDX,
)

In [7]:
def init_weights(m):
    """Initialise the parameters that module ``m`` owns directly.

    Meant to be used as the callback of ``nn.Module.apply``, which already
    calls it once for every submodule (and for the root module). Iterating
    with ``recurse=False`` therefore initialises each parameter tensor once,
    instead of once per ancestor module as a recursive walk would.

    Parameters whose name contains ``'weight'`` are filled from a normal
    distribution with mean 0 and standard deviation 0.01; every other
    parameter (e.g. biases) is set to 0.

    Args:
        m: the ``nn.Module`` whose own parameters are (re)initialised
            in place.
    """
    # recurse=False: children are visited separately by Module.apply.
    for name, param in m.named_parameters(recurse=False):
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            
# Run init_weights over the model through Module.apply. Because this call is
# the last expression of the cell, its return value (the model) is what the
# notebook displays as the cell output.
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(8014, 256)
    (rnn): GRU(256, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(6191, 256)
    (rnn): GRU(1280, 512)
    (fc_out): Linear(in_features=1792, out_features=6191, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [8]:
# Training schedule: number of epochs, and the clip value handed to train()
# (presumably a gradient-clipping threshold -- confirm in seq2seq_train).
N_EPOCHS = 20
CLIP = 1

# File that receives the weights of the best model seen so far.
CHECKPOINT_PATH = 'tut3-model.pt'

# Lowest validation loss observed so far; starts at infinity so that any
# finite loss from the first epoch counts as an improvement.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    # The timed span covers both the training and the validation pass.
    start_time = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_dataloader, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Checkpoint only on a strict improvement of the validation loss.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    # Per-epoch report: losses plus their exponentials (perplexity).
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 8m 11s
	Train Loss: 16.632 | Train PPL: 16716021.092
	 Val. Loss: 14.780 |  Val. PPL: 2624528.029
Epoch: 02 | Time: 8m 5s
	Train Loss: 11.484 | Train PPL: 97195.760
	 Val. Loss: 12.983 |  Val. PPL: 435067.567
Epoch: 03 | Time: 7m 52s
	Train Loss: 9.043 | Train PPL: 8458.517
	 Val. Loss: 12.930 |  Val. PPL: 412491.602


KeyboardInterrupt: 

In [6]:
# Scratch check of dict.get on a missing key: the lookup yields None, which
# is falsy, so nothing is printed for this dictionary.
test = {'1': 2}
looked_up = test.get('22')
if looked_up:
    print('yes')