In [1]:
from torch.nn import CrossEntropyLoss
from transformer import TransformerMT
from torch import optim
from data_loader import *
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = EnFrDataset(used_abridged_data=True, max_seq_length=100)

train_dataloader = DataLoader(data, batch_size=32, shuffle=False, num_workers=0)

transformer_mt = TransformerMT(
    source_vocabulary_size=data.get_src_lang_size(),
    target_vocabulary_size=data.get_tgt_lang_size(),
    embedding_size=512,
    max_num_embeddings=100,
    num_attention_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    linear_layer_size=2048,
    dropout=0.1,
    activation='relu',
    layer_norm_eps=1e-5,
    batch_first=False,
    norm_first=False,
    bias=True
)
transformer_mt.to(device)

transformer_mt.train()

optimizer = optim.Adam(transformer_mt.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

loss_criterion = CrossEntropyLoss(ignore_index=0)

epoch_losses = []
for e in range(60):
    transformer_mt
    running_loss = 0
    for en_token_ids, fr_token_ids in train_dataloader:
        en_token_ids.to(device)
        fr_token_ids.to(device)

        optimizer.zero_grad()
    
        output = transformer_mt(src=en_token_ids, tgt=fr_token_ids[:-1, :])
        
        loss = loss_criterion(output.reshape(-1, output.shape[-1]), fr_token_ids[1:, :].reshape(-1))
    
        loss.backward()
        optimizer.step()
    
        running_loss += loss.item()

    epoch_losses.append(running_loss)
    print(f"epoch: {e+1}, epoch loss: {round(running_loss, 3)}")

Reading the dataframe and storing untokenized pairs...


100%|██████████| 300/300 [00:00<00:00, 79008.61it/s]


Adding sentences to Langs amd geting data pairs...


100%|██████████| 300/300 [00:00<00:00, 2227.36it/s]


Creating tokenized pairs of english and french sentences...


100%|██████████| 300/300 [00:00<00:00, 9908.43it/s]


epoch: 1, epoch loss: 27.167
epoch: 2, epoch loss: 17.472



KeyboardInterrupt



In [None]:
print(epoch_losses)
plt.plot(epoch_losses)
plt.ylabel('cumulative loss')
plt.xlabel('epoch')
plt.show()