In [29]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, random_split
import torch.nn.functional as F
import sys
import os
import pickle
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
# Put the paper's code directory on sys.path so its local packages
# (e.g. the BPE package imported further down) can be imported.
parent_dir = os.path.abspath("../papers/Attention Is All You Need")
# Guard the append: re-running this cell must not stack duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from pathlib import Path
from tqdm import tqdm

In [12]:
import yaml

# Location of the experiment configuration, relative to the notebook directory.
PAPER_DIR = Path("../") / "papers" / "Attention Is All You Need"
YAML_PATH = PAPER_DIR / "config.yaml"

# safe_load only builds plain Python objects from the YAML document.
with open(YAML_PATH, mode="r") as file:
    config = yaml.safe_load(file)

## Get Training Data and Encode Training Data

In [13]:
from BPE.bpe import BPEEncoder, BPEDecoder


In [14]:
# Directory holding the English/German parallel data.
PATH = Path("../", "data", "EnglishGerman")

# Read the de-en validation CSV into a DataFrame.
validation_csv = PATH / "wmt14_translate_de-en_validation.csv"
dataset = pd.read_csv(validation_csv)

In [15]:
# Directory containing the BPE implementation and its saved merge table.
BPE_PATH = Path("../", "papers", "Attention Is All You Need", "BPE")

# SECURITY: pickle.load can execute arbitrary code while deserialising, so
# only load merge files that were produced locally / come from a trusted source.
merges_file = BPE_PATH / "vocab_merges_2500.pkl"
with open(merges_file, mode="rb") as f:
    vocab = pickle.load(f)

In [16]:
bpe_encoder = BPEEncoder(vocab=vocab)

# Column 1 of the CSV is treated as English (source) and column 0 as German (target).
german_column, english_column = dataset.iloc[:, 0], dataset.iloc[:, 1]
english_encoded = bpe_encoder.encode(english_column)
german_encoded = bpe_encoder.encode(german_column)

## Preparing the Data

In [17]:
# Despite its name this is the largest id stored in the BPE vocab, not a count.
max_vocab_size = max(vocab.values())

# The three special symbols take the ids immediately above that largest id.
special_tokens = {
    max_vocab_size + offset: symbol
    for offset, symbol in enumerate(("<SOS>", "<EOS>", "<pad>"), start=1)
}

In [18]:
class LanguageTranslationDataset(Dataset):
    """Paired source/target token sequences prepared for teacher-forced training.

    Every item is brought to exactly ``seq_length`` tokens (padded with
    ``pad_token``, or truncated when longer) and is returned together with the
    encoder padding mask and the decoder causal+padding mask.

    Args:
        seq_length: Fixed length of every returned sequence.
        src_encodings: Iterable of source token-id sequences.
        tgt_encodings: Iterable of target token-id sequences, aligned with
            ``src_encodings``.
        sos_token: Id prepended to the source and to the decoder input.
        eos_token: Id appended to the source and to the expected output.
        pad_token: Id used to fill sequences up to ``seq_length``.
    """

    def __init__(self, seq_length, src_encodings, tgt_encodings, sos_token, eos_token, pad_token):
        super().__init__()
        self.paired_encodings = LanguageTranslationDataset.augment_encodings(src_encodings, tgt_encodings, sos_token, eos_token)
        self.seq_len = seq_length
        self.pad_token = pad_token

    @staticmethod
    def augment_encodings(src_encodings, tgt_encodings, sos_token, eos_token):
        """Attach SOS/EOS markers and build (source, decoder input, label) triples.

        For target tokens ``t1..tn`` the decoder input is ``[SOS, t1..tn]`` and
        the label is ``[t1..tn, EOS]``, so the label at position ``i`` is the
        token that follows decoder input position ``i``.

        Returns:
            A list of ``(src, tgt, output)`` tuples of token-id lists, sorted
            by source length (shortest first).
        """
        # Materialise the targets once: they are needed for both the decoder
        # input and the label, and the caller may have passed a one-shot iterable.
        tgt_tokens = [list(tokens) for tokens in tgt_encodings]

        src_encodings = [[sos_token] + list(tokens) + [eos_token] for tokens in src_encodings]
        decoder_inputs = [[sos_token] + tokens for tokens in tgt_tokens]
        # Labels are built from the raw targets, NOT from the SOS-prefixed decoder
        # inputs; otherwise label[i] == input[i] and the model learns to copy.
        output_encodings = [tokens + [eos_token] for tokens in tgt_tokens]

        full_encoding = list(zip(src_encodings, decoder_inputs, output_encodings))
        full_encoding.sort(key=lambda x: len(x[0]))  # sort by source sequence length
        return full_encoding

    def _to_fixed_length(self, token_ids):
        """Return ``token_ids`` as a long tensor of exactly ``self.seq_len`` entries."""
        # Sequences longer than seq_len are cut at the end (which drops a trailing EOS).
        tensor = torch.tensor(token_ids[:self.seq_len], dtype=torch.long)
        return F.pad(tensor, (0, self.seq_len - tensor.size(0)), value=self.pad_token)

    def __getitem__(self, idx):
        """Return the padded tensors and attention masks for example ``idx``."""
        src_seq, tgt_seq, output_seq = self.paired_encodings[idx]
        src_tensor = self._to_fixed_length(src_seq)
        tgt_tensor = self._to_fixed_length(tgt_seq)
        output_tensor = self._to_fixed_length(output_seq)

        # 1 where the source holds a real token, 0 on padding.
        encoder_mask = (src_tensor != self.pad_token).int()

        # Lower-triangular mask: position i may only attend to positions <= i ...
        subsequent_mask = torch.tril(torch.ones((self.seq_len, self.seq_len), dtype=torch.int))
        # ... and never to padded key positions.
        padding_mask = (tgt_tensor != self.pad_token).int()
        decoder_mask = subsequent_mask & padding_mask.unsqueeze(0)

        return {
            "src": src_tensor,  # seq_len
            "tgt": tgt_tensor,  # seq_len
            "output": output_tensor,  # seq_len
            "encoder_mask": encoder_mask.unsqueeze(0).unsqueeze(0),  # 1 x 1 x seq_len
            "decoder_mask": decoder_mask.unsqueeze(0),  # 1 x seq_len x seq_len
        }

    def __len__(self):
        """Number of sentence pairs in the dataset."""
        return len(self.paired_encodings)

In [19]:
# English is the source language and German the target; the special-token ids
# are the three ids directly above the largest BPE id.
full_data = LanguageTranslationDataset(
    seq_length=config['SEQ_LEN'],
    src_encodings=english_encoded,
    tgt_encodings=german_encoded,
    sos_token=max_vocab_size + 1,
    eos_token=max_vocab_size + 2,
    pad_token=max_vocab_size + 3,
)

# Seed the split so the same examples end up in the test set on every run.
# An unseeded split changes after each kernel restart, which makes results
# irreproducible and lets former test examples leak into training.
# An optional SEED entry in the config overrides the default.
split_generator = torch.Generator().manual_seed(config.get('SEED', 42))
train_data, test_data = random_split(
    full_data,
    [config['TRAIN_RATIO'], 1 - config['TRAIN_RATIO']],
    generator=split_generator,
)

In [20]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_data, batch_size=config['BATCH_SIZE'], shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the batch
# composition (and therefore per-batch losses) identical between runs.
test_dataloader = DataLoader(test_data, batch_size=config['BATCH_SIZE'], shuffle=False)

## Model Creation

In [21]:
from TransformerComponents.Encoder import Encoder
from TransformerComponents.Decoder import Decoder
from TransformerComponents.PE import PositionalEmbedding
from TransformerComponents.Transformer import Transformer
from TransformerComponents.UtilsLayers import Projection

In [22]:
# NOTE(review): 256 is presumably the number of base (single-byte) token ids
# that precede the learned merges - confirm against the BPE implementation.
N_BASE_TOKENS = 256
N_SPECIAL_TOKENS = 3  # <SOS>, <EOS>, <pad>
vocab_size = len(vocab) + N_BASE_TOKENS + N_SPECIAL_TOKENS

In [23]:
# Width of a single attention head; it is passed twice to each stack.
# NOTE(review): presumably these are the key and value dimensions - confirm
# against the Encoder/Decoder signatures.
head_dim = config['D_MODEL'] // config['N_HEADS']

encoder_transformer = Encoder(
    config['N_ENCODERS'], config['N_HEADS'], config['D_MODEL'], head_dim, head_dim, config['FF_HIDDEN']
)
decoder_transformer = Decoder(
    config['N_DECODERS'], config['N_HEADS'], config['D_MODEL'], head_dim, head_dim, config['FF_HIDDEN']
)

# Source and target sides get separate embedding modules of identical shape.
embedding_args = (vocab_size, config['D_MODEL'], config['SEQ_LEN'])
src_embeddings = PositionalEmbedding(*embedding_args)
tgt_embeddings = PositionalEmbedding(*embedding_args)

projection = Projection(config['D_MODEL'], vocab_size)

In [24]:
# Assemble the full model from the components built above.
model = Transformer(
    encoder_transformer,
    decoder_transformer,
    src_embeddings,
    tgt_embeddings,
    projection,
)

In [25]:
# NOTE(review): presumably applies the model's weight-initialisation scheme;
# the method is defined in TransformerComponents.Transformer - confirm there.
model.initialise()

## Model Training

In [26]:
# TensorBoard event files for this run are written below this directory.
writer = SummaryWriter(
    log_dir='../papers/Attention Is All You Need/Tensorboard/experiment_1'
)

In [32]:
optimiser = torch.optim.Adam(model.parameters(), lr=config['LR'], eps=1e-9)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Device: {device}")
model = model.to(device)
# Padded label positions (pad id = max_vocab_size + 3) do not contribute to the loss.
# NOTE(review): NLLLoss expects log-probabilities, so this assumes the model's
# projection ends in log_softmax - confirm; with raw logits use CrossEntropyLoss.
loss_fn = torch.nn.NLLLoss(ignore_index=max_vocab_size + 3).to(device)


def compute_batch_loss(batch):
    """Run one forward pass on ``batch`` and return the scalar loss tensor.

    ``batch`` is a dict as produced by the dataloaders above, with the keys
    'src', 'tgt', 'output', 'encoder_mask' and 'decoder_mask'.
    """
    target_indices = batch['output'].to(device)  # B x seq_len
    encoder_input = batch['src'].to(device)  # B x seq_len
    tgt_input = batch['tgt'].to(device)  # B x seq_len
    encoder_mask = batch['encoder_mask'].to(device)  # B x 1 x 1 x seq_len
    decoder_mask = batch['decoder_mask'].to(device)  # B x 1 x seq_len x seq_len
    logits = model(encoder_input, tgt_input, encoder_mask=encoder_mask, decoder_mask=decoder_mask)
    # reshape (rather than view) also works if the model output is non-contiguous.
    return loss_fn(logits.reshape(-1, vocab_size), target_indices.reshape(-1))


global_step = 0
losses = []       # mean training loss per epoch
test_losses = []  # mean evaluation loss per epoch
for epoch in range(config['NUM_EPOCHS']):
    # ---- training pass ----
    model.train()
    batch_train = tqdm(train_dataloader, desc=f"epoch: {epoch:02d}")
    loss_sum = 0.0
    count = 0

    for data in batch_train:
        loss = compute_batch_loss(data)
        loss_value = loss.item()
        batch_train.set_postfix({"loss": f"{loss_value: 6.3f}"})
        writer.add_scalar("train_loss", loss_value, global_step)
        writer.flush()

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        global_step += 1

        loss_sum += loss_value
        count += 1

    if count:  # skip the average when the loader yielded no batches
        losses.append(loss_sum / count)

    # ---- evaluation pass: no gradients, no parameter updates ----
    model.eval()
    loss_sum = 0.0
    count = 0
    with torch.no_grad():
        for data in test_dataloader:
            loss_sum += compute_batch_loss(data).item()
            count += 1

    if count:
        test_losses.append(loss_sum / count)
        writer.add_scalar("test_loss", test_losses[-1], epoch)
        writer.flush()

Using Device: cpu


epoch: 00:  11%|█         | 4/38 [03:32<30:09, 53.22s/it, loss=21.191]


KeyboardInterrupt: 