In [None]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from data_prepare import TextPairs
import numpy as np
import pickle as pkl
from torch.utils.data import DataLoader
from RNN import RNN_encoder, RNN_decoder
from Transformer import Transformer_encoder, Transformer_decoder
from datetime import datetime
from math import sin, cos

%load_ext autoreload
%autoreload 2

NUM_EPOCHS = 1000
BATCH_SIZE = 128
VOCA_SIZE = 4000 # smaller than len(text_pairs.voca['en']), len(text_pairs.voca['de'])
NUM_LAYERS = 2
NUM_HEADS = 20
EMBEDDING_DIM = 200

start = datetime.now()
train_pairs = TextPairs(VOCA_SIZE, train=True, toy=True)
val_pairs = TextPairs(VOCA_SIZE, train=False, toy=True)

print( len(train_pairs.voca['en']) )
print( len(train_pairs.voca['de']) )

trainLoader = DataLoader(train_pairs, batch_size=BATCH_SIZE, shuffle=True)
valLoader = DataLoader(val_pairs)

SAMPLE = [23, 44, 67]

print(f'\nElapsed time: {datetime.now() - start}')
print(f'\nData_length: {len(train_pairs)}')
print(f'VOCA_SIZE: {VOCA_SIZE}')
print(f'SAMPLE: {SAMPLE}')

In [None]:
# Build the English embedding table (VOCA_SIZE x EMBEDDING_DIM) from pretrained GloVe vectors.
# Words with no GloVe vector keep the all-zero row created by torch.zeros.
# NOTE(review): pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open('Data/Glove/glove.6B.200d.pkl', 'rb') as f:
    glove = pkl.load(f)

embedding_matrix = torch.zeros(VOCA_SIZE, EMBEDDING_DIM)

num_missing = 0
for w in train_pairs.voca['en']:
    vector = glove.get(w)  # single lookup per word
    if vector is None:
        num_missing += 1   # row is already zero; nothing to write
        continue
    # as_tensor + explicit dtype: GloVe arrays may be float64; the table is float32.
    embedding_matrix[ train_pairs.word2id['en'][w] ] = torch.as_tensor(vector, dtype=embedding_matrix.dtype)

num_words = len(train_pairs.voca['en'])
print(f'GloVe coverage: {num_words - num_missing}/{num_words} English words')

In [None]:
def sinusoidal_positional_encoding(max_len, dim):
    """Return the fixed sinusoidal positional-encoding table, shape (1, max_len, dim).

    For pos in range(max_len) and i in range(dim // 2):
        PE[0, pos, 2i]   = sin(pos / 10000**(2i/dim))
        PE[0, pos, 2i+1] = cos(pos / 10000**(2i/dim))

    Angles are computed in float64 and stored in a float32 table. When dim is odd the
    last column stays zero (only dim // 2 sin/cos pairs are filled).
    """
    pe = torch.zeros(1, max_len, dim)
    num_pairs = dim // 2
    positions = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)  # (max_len, 1)
    exponents = torch.arange(num_pairs, dtype=torch.float64) * 2 / dim    # 2i/dim, (num_pairs,)
    angles = positions / torch.pow(10000.0, exponents)                    # (max_len, num_pairs)
    # Even columns get sin, odd columns get cos -- one vectorized write each.
    pe[0, :, 0:2 * num_pairs:2] = torch.sin(angles).to(pe.dtype)
    pe[0, :, 1:2 * num_pairs:2] = torch.cos(angles).to(pe.dtype)
    return pe


MAX_LEN = train_pairs.max_len
PE = sinusoidal_positional_encoding(MAX_LEN, EMBEDDING_DIM)

In [None]:
# Instantiate the Transformer encoder (handed the GloVe-based embedding_matrix) and the
# decoder (handed MAX_LEN), then report their combined number of parameters.
encoder = Transformer_encoder(NUM_LAYERS, NUM_HEADS, VOCA_SIZE, EMBEDDING_DIM, embedding_matrix)
decoder = Transformer_decoder(NUM_LAYERS, NUM_HEADS, VOCA_SIZE, EMBEDDING_DIM, MAX_LEN)

encoder_params, decoder_params = (
    sum(param.numel() for param in model.parameters())
    for model in (encoder, decoder)
)
print(f'Total number of parameters: {encoder_params+decoder_params}')

In [None]:
# Train the encoder/decoder with teacher forcing. Every LOG_EVERY epochs, print the mean
# training loss and the decoder's output for the validation examples listed in SAMPLE.

GPU_INDEX = 2   # GPU to train on; falls back to CPU when that device is not present
PAD_ID = 0      # assumed padding id (the target print below already skips id 0) -- TODO confirm
SOS_ID = 2      # assumed <sos> id in the German vocab -- TODO confirm against TextPairs
LOG_EVERY = 10  # epochs between progress reports

if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f'cuda:{GPU_INDEX}')
else:
    device = torch.device('cpu')

# Move models and the positional-encoding table to the device before building optimizers.
encoder = encoder.to(device)
decoder = decoder.to(device)
PE = PE.to(device)

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)
# Padded target positions must not contribute to the loss or its gradients.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

# Inference prompt: SOS followed by zeros, length MAX_LEN. Loop-invariant, so built once.
# NOTE(review): training feeds de_text[:, :-1] (one shorter than the padded target);
# confirm the decoder handles a length-MAX_LEN input in inference mode.
sos = torch.zeros(1, MAX_LEN, dtype=torch.long, device=device)
sos[0, 0] = SOS_ID

# Visit only the requested validation examples (in ascending order, each once),
# instead of iterating the whole validation set to find them.
sample_ids = sorted({j for j in SAMPLE if 0 <= j < len(val_pairs)})
sampleLoader = DataLoader(val_pairs, sampler=sample_ids)

print(datetime.now())
print()

for epoch in range(NUM_EPOCHS):
    encoder.train()
    decoder.train()
    running_loss = torch.zeros((), device=device)
    num_batches = 0

    for data in trainLoader:
        en_text, de_text = data['en'], data['de']

        # Teacher forcing: the decoder sees tokens [0, L-1) and predicts tokens [1, L).
        encoder_inputs = en_text.to(device)
        decoder_inputs = de_text[:, :-1].to(device)
        targets = de_text[:, 1:].to(device)

        context = encoder(encoder_inputs, PE)
        # last dim presumably VOCA_SIZE (logits) -- required by the reshape below
        preds = decoder(decoder_inputs, context, PE, device, train=True)

        # reshape rather than view: sliced tensors may be non-contiguous.
        loss = criterion(preds.reshape(-1, VOCA_SIZE), targets.reshape(-1))

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        # Accumulate on-device to avoid a host sync every batch.
        running_loss += loss.detach()
        num_batches += 1

    if (epoch+1) % LOG_EVERY == 0:
        mean_loss = running_loss.item() / max(num_batches, 1)
        print(f'{mean_loss:.3f}\t{str(datetime.now())}')

        # Switch off training-only behaviour (e.g. dropout, if the modules use it) for sampling.
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            for data in sampleLoader:
                en_text, de_text = data['en'], data['de']

                context = encoder(en_text.to(device), PE)
                preds = decoder(sos, context, PE, device, train=False)
                tokens = torch.argmax(preds[0], dim=-1).tolist()

                print('Pred:\t', [val_pairs.voca['de'][t] for t in tokens])
                # Drop padding, then the leading SOS token, before printing the reference.
                target_ids = [t for t in de_text[0].tolist() if t != PAD_ID]
                print('Target:\t', [val_pairs.voca['de'][t] for t in target_ids][1:])
                print()