In [None]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from data_prepare import TextPairs, pad_batch
import numpy as np
import pickle as pkl
from torch.utils.data import DataLoader
from RNN import RNN_encoder, RNN_decoder
from Transformer import Transformer_encoder, Transformer_decoder
from datetime import datetime


%load_ext autoreload
%autoreload 2

NUM_EPOCHS = 200
BATCH_SIZE = 100
VOCA_SIZE = 1500 # smaller than len(text_pairs.voca['en']), len(text_pairs.voca['de'])
NUM_LAYERS = 2
HIDDEN_DIM = 64
EMBEDDING_DIM = 200

text_pairs = TextPairs(VOCA_SIZE)
textLoader = DataLoader(text_pairs, batch_size=BATCH_SIZE, num_workers=2, collate_fn = pad_batch)
valLoader = DataLoader(text_pairs, batch_size=1, collate_fn = pad_batch)

SAMPLE = np.random.choice(len(text_pairs), 3, replace=False)

print(datetime.now())
print(f'\nVOCA_SIZE: {VOCA_SIZE}')
print(f'SAMPLE: {SAMPLE}')

In [None]:
# Build the English input-embedding matrix from pretrained GloVe vectors.
# Rows for words absent from GloVe stay at zero (the matrix is zero-initialised).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted local files.
GLOVE_PATH = 'Data/Glove/glove.6B.200d.pkl'
with open(GLOVE_PATH, 'rb') as f:
    glove = pkl.load(f)

embedding_matrix = torch.zeros(VOCA_SIZE, EMBEDDING_DIM)
en_word2id = text_pairs.word2id['en']
en_voca = text_pairs.voca['en']
# NOTE(review): assumes every id in word2id['en'] is < VOCA_SIZE (i.e. TextPairs
# truncates the vocabulary); otherwise the row assignment raises IndexError.
num_missing = 0
for w in en_voca:
    vec = glove.get(w)  # single lookup instead of two
    if vec is None:
        num_missing += 1
        continue  # row already zero
    embedding_matrix[en_word2id[w]] = torch.as_tensor(vec, dtype=embedding_matrix.dtype)

print(f'GloVe coverage: {len(en_voca) - num_missing}/{len(en_voca)} words ({num_missing} left as zeros)')

In [None]:
# Instantiate the seq2seq RNN pair; the encoder is seeded with the GloVe matrix.
MAX_LEN = text_pairs.max_len
encoder = RNN_encoder(NUM_LAYERS, HIDDEN_DIM, VOCA_SIZE, EMBEDDING_DIM, embedding_matrix)
decoder = RNN_decoder(NUM_LAYERS, HIDDEN_DIM, VOCA_SIZE, EMBEDDING_DIM, MAX_LEN)

# Count every parameter tensor element in each model, then report the sum.
encoder_params, decoder_params = (
    sum(param.numel() for param in model.parameters())
    for model in (encoder, decoder)
)
total_params = encoder_params + decoder_params
print(f'Total number of parameters: {total_params}')

In [None]:
# Train encoder/decoder with teacher forcing and periodically print the mean
# epoch loss plus greedy translations of the fixed SAMPLE sentences.
device = torch.device('cpu')  # RNN.py internals not visible here; keep CPU unless they are device-aware
LEARNING_RATE = 0.005
LOG_EVERY = 1  # epochs between loss / sample printouts
SOS_ID = 2     # NOTE(review): assumed start-of-sentence token id — confirm in data_prepare

# Move the models too, not only the batches, so parameters and inputs share a device.
encoder.to(device)
decoder.to(device)

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=LEARNING_RATE)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=LEARNING_RATE)
# TODO(review): padding positions currently contribute to the loss; pass
# ignore_index=<pad id> once the id used by pad_batch is confirmed.
criterion = nn.CrossEntropyLoss()

print(datetime.now())
print()
for epoch in range(NUM_EPOCHS):
    encoder.train()
    decoder.train()
    epoch_loss, num_batches = 0.0, 0
    for en_text, de_text in textLoader:
        en_text, de_text = en_text.to(device), de_text.to(device)

        # Teacher forcing: the decoder is fed tokens [0, T-1) and predicts tokens [1, T).
        encoder_inputs, decoder_inputs = en_text, de_text[:, :-1]
        labels = de_text[:, 1:]

        context = encoder(encoder_inputs)
        # Logits; the last dim must equal VOCA_SIZE for the reshape below to line up with labels.
        preds = decoder(decoder_inputs, context, train=True)

        loss = criterion(preds.reshape(-1, VOCA_SIZE), labels.reshape(-1))

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    if (epoch + 1) % LOG_EVERY == 0:
        # Mean over the epoch rather than the (noisy) last batch only.
        print(f'{epoch_loss / max(num_batches, 1):.3f}\t{str(datetime.now())}')
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            # Fetch only the sampled items (same order as iterating the dataset)
            # instead of walking the entire dataset through valLoader every epoch.
            for j in sorted(SAMPLE):
                en_text, de_text = pad_batch([text_pairs[int(j)]])
                en_text = en_text.to(device)

                context = encoder(en_text)
                sos = torch.tensor([[SOS_ID]], device=device)
                preds = decoder(sos, context, train=False)
                tokens = torch.argmax(preds[0], dim=-1).tolist()
                text = [text_pairs.voca['de'][t] for t in tokens]

                print('Pred:\t', text)
                print('Target:\t', [text_pairs.voca['de'][t] for t in de_text[0, 1:].tolist()])
                print()