inspiration: https://www.youtube.com/watch?v=ISNdQcPhsts&ab_channel=UmarJamil

https://pytorch.org/tutorials/beginner/transformer_tutorial.html

https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [None]:
import os
import torch
import torchmetrics
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from pathlib import Path
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter


# Download the English to Spanish sentence pairs from Hugging Face datasets.

In [None]:
class BilingualDataset(Dataset):
    """Pre-tokenized, fixed-length translation pairs for seq2seq training.

    Every usable sentence pair is converted once, at construction time, into
    three ``int64`` tensors of length ``seq_len``:

    * ``encoder_input``: ``[SOS] source tokens [EOS] [PAD]...``
    * ``decoder_input``: ``[SOS] target tokens [PAD]...``
    * ``label``:         ``target tokens [EOS] [PAD]...``

    Pairs that do not fit into ``seq_len`` are skipped (one message is printed
    per skipped pair), so ``len(self)`` may be smaller than ``len(dataset)``.

    Args:
        dataset: Iterable of items shaped like
            ``{"translation": {lang_src: str, lang_tgt: str}}``.
        tokenizer_src: Source-language tokenizer exposing ``encode(text).ids``.
        tokenizer_tgt: Target-language tokenizer exposing ``encode(text).ids``
            and ``token_to_id``.
        lang_src: Key of the source sentence inside ``item["translation"]``.
        lang_tgt: Key of the target sentence inside ``item["translation"]``.
        seq_len: Fixed length of every produced tensor.
    """

    def __init__(self, dataset, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, seq_len):
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.lang_src = lang_src
        self.lang_tgt = lang_tgt
        self.seq_len = seq_len

        # The special-token ids are looked up in the *target* tokenizer and
        # reused on the source side as well, so both tokenizers must assign
        # the same ids to [SOS], [EOS] and [PAD].
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)
        self._pad_id = int(self.pad_token.item())

        # Filter out data points that do not meet the required sequence length
        self.processed_data = self._filter_dataset(dataset)

    def _padding(self, count):
        """Return ``count`` pad ids as a 1-D ``int64`` tensor (empty for 0)."""
        return torch.full((count,), self._pad_id, dtype=torch.int64)

    def _filter_dataset(self, dataset):
        """Tokenize, frame and pad every pair that fits into ``seq_len``.

        Returns:
            A list of dicts with the keys ``"encoder_input"``,
            ``"decoder_input"`` and ``"label"``.
        """
        processed_data = []
        for idx, src_target_pair in enumerate(dataset):
            translation = src_target_pair["translation"]
            source_ids = self.tokenizer_src.encode(translation[self.lang_src]).ids
            target_ids = self.tokenizer_tgt.encode(translation[self.lang_tgt]).ids

            # The encoder side carries [SOS] and [EOS]; the decoder input
            # and the label each carry only one special token.
            encode_padding_tokens = self.seq_len - len(source_ids) - 2
            decode_padding_tokens = self.seq_len - len(target_ids) - 1

            if encode_padding_tokens < 0 or decode_padding_tokens < 0:
                print(f"Skipping data point at index {idx} due to insufficient sequence length.")
                continue

            # An explicit dtype matters: torch.tensor([]) would otherwise be
            # float32 and torch.cat would promote the whole sequence to float.
            source_tokens = torch.tensor(source_ids, dtype=torch.int64)
            target_tokens = torch.tensor(target_ids, dtype=torch.int64)

            encoder_input = torch.cat([
                self.sos_token,
                source_tokens,
                self.eos_token,
                self._padding(encode_padding_tokens),
            ])
            decoder_input = torch.cat([
                self.sos_token,
                target_tokens,
                self._padding(decode_padding_tokens),
            ])
            label = torch.cat([
                target_tokens,
                self.eos_token,
                self._padding(decode_padding_tokens),
            ])

            assert encoder_input.size(0) == self.seq_len, f"Encoder input length does not match sequence length at index {idx}"
            assert decoder_input.size(0) == self.seq_len, f"Decoder input length does not match sequence length at index {idx}"
            assert label.size(0) == self.seq_len, f"Label length does not match sequence length at index {idx}"

            processed_data.append(
                {
                    "encoder_input": encoder_input,
                    "decoder_input": decoder_input,
                    "label": label,
                }
            )

        return processed_data

    def __len__(self):
        """Number of pairs that survived the length filter."""
        return len(self.processed_data)

    def __getitem__(self, idx):
        """Return the pre-built tensor dict for sample ``idx``."""
        return self.processed_data[idx]


In [None]:
def get_or_build_tokenizer(config, dataset, language):
    """Load the word-level tokenizer for ``language`` or train and save one.

    The file location is ``config['tokenizer_file']`` formatted with
    ``language``. If that file exists it is loaded as-is; otherwise a new
    ``WordLevel`` tokenizer is trained on the ``language`` side of
    ``dataset`` and written to that location.

    Returns:
        The loaded or freshly trained ``Tokenizer``.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(language))

    # Fast path: reuse a tokenizer produced by an earlier run.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    # The order of the special tokens is the order handed to the trainer.
    special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"]
    trainer = WordLevelTrainer(special_tokens=special_tokens, min_frequency=2)
    sentences = get_all_sentences(dataset, language)
    tokenizer.train_from_iterator(sentences, trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_all_sentences(dataset, language):
    """Lazily yield the ``language`` sentence of every pair in ``dataset``.

    Each item is expected to look like ``{"translation": {language: str}}``.
    """
    yield from (pair["translation"][language] for pair in dataset)

def get_dataset(config):
    """Download the corpus, build tokenizers and return the data loaders.

    Steps: load the ``opus_books`` ``train`` split for the configured
    language pair, load or train one tokenizer per language, split the pairs
    80/20 at random, wrap both parts in ``BilingualDataset`` and report the
    longest tokenized source/target sentence of the full corpus.

    Returns:
        ``(train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)``.
        The training loader uses ``config['batch_size']``; the validation
        loader uses a batch size of 1. Both loaders shuffle.
    """
    src_lang = config["lang_src"]
    tgt_lang = config["lang_tgt"]

    # Download (or reuse the cached copy of) the sentence pairs.
    raw_pairs = load_dataset('opus_books', f'{src_lang}-{tgt_lang}', split='train')

    # One tokenizer per language, trained on the complete corpus.
    tokenizer_src = get_or_build_tokenizer(config, raw_pairs, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, raw_pairs, tgt_lang)

    # 80 % training, the remainder validation.
    total = len(raw_pairs)
    train_size = int(total * 0.8)
    val_size = int(total - train_size)
    train_raw, val_raw = random_split(raw_pairs, [train_size, val_size])

    seq_len = config['seq_len']
    train_dataset = BilingualDataset(train_raw, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len)
    valid_dataset = BilingualDataset(val_raw, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len)

    # Informational only: longest tokenized sentence on each side.
    longest_src = 0
    longest_tgt = 0
    for item in raw_pairs:
        pair = item['translation']
        src_count = len(tokenizer_src.encode(pair[src_lang]).ids)
        tgt_count = len(tokenizer_tgt.encode(pair[tgt_lang]).ids)
        if src_count > longest_src:
            longest_src = src_count
        if tgt_count > longest_tgt:
            longest_tgt = tgt_count

    print(f'Max length of source sentence: {longest_src}')
    print(f'Max length of target sentence: {longest_tgt}')

    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [None]:
def get_config():
    """Return a fresh dict holding the default training configuration.

    ``tokenizer_file`` is a format string whose ``{0}`` placeholder receives
    the language code. ``preload`` is ``None`` by default, i.e. no checkpoint
    is loaded before training.
    """
    return dict(
        batch_size=16,
        num_epochs=2,
        lr=1e-4,
        seq_len=350,
        d_model=512,
        lang_src="en",
        lang_tgt="es",
        model_folder="weights",
        model_basename="tmodel_",
        preload=None,
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/transformer_seq2seq",
    )

def get_weights_file_path(config, epoch: str):
    """Build the checkpoint path ``<model_folder>/<model_basename><epoch>.pt``.

    ``epoch`` is interpolated into the file name with default string
    formatting. The result is returned as a ``str`` relative to the current
    directory.
    """
    folder = config["model_folder"]
    filename = "{}{}.pt".format(config["model_basename"], epoch)
    return str(Path(".").joinpath(folder, filename))

# Let's build a Transformer model to translate English to Spanish

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then dropout.

    A table of shape ``(1, seq_len, d_model)`` is computed once and stored as
    the buffer ``positional_encoding``. Even feature columns hold sines and
    odd feature columns hold cosines of ``position * division_term``.

    Args:
        d_model: Size of the feature dimension. Both even and odd values are
            supported.
        seq_len: Number of positions to precompute; inputs longer than this
            cannot be encoded.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, seq_len: int = 100, dropout: float = 0.1):
        """Initialize the PositionalEncoding module."""
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        division_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model)
        )  # one entry per even column
        angles = position * division_term

        positional_encoding = torch.zeros(seq_len, d_model)
        positional_encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one odd column fewer than even columns,
        # so the cosine part uses only the first d_model // 2 frequencies.
        positional_encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("positional_encoding", positional_encoding.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + encoding)`` for ``x`` of shape (batch, seq, d_model)."""
        # Buffers never require gradients, so the slice can be added directly.
        x = x + self.positional_encoding[:, : x.size(1)]
        return self.dropout(x)

In [None]:
class TransformerModel(nn.Module):
    """Sequence-to-sequence translation model wrapping ``nn.Transformer``.

    Token ids are embedded, scaled by ``sqrt(d_model)``, combined with
    positional encodings and passed through a batch-first ``nn.Transformer``.
    A final linear layer projects decoder states to the target vocabulary.

    Training (``forward``) and greedy inference (``translate``) share the same
    private encode/decode helpers, so both use identical masking: source
    padding is hidden from the encoder and from the decoder's cross-attention,
    and the decoder's self-attention is causal.

    Args:
        src_vocab_size: Number of source token ids.
        tgt_vocab_size: Number of target token ids.
        src_seq_len: Maximum source length (size of the positional table).
        tgt_seq_len: Maximum target length (size of the positional table).
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Width of the feed-forward sub-layers.
        dropout: Dropout probability.
        pad_token_id: Id of the padding token in the source sequences.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_seq_len: int,
        tgt_seq_len: int,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        pad_token_id: int = 1,
    ) -> None:
        super().__init__()

        self.model_type = "Transformer"
        self.d_model = d_model
        self.pad_token_id = pad_token_id
        self.tgt_seq_len = tgt_seq_len
        self.embed_scale = d_model ** 0.5
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        self.src_positional_encoding = PositionalEncoding(d_model, src_seq_len, dropout)
        self.tgt_positional_encoding = PositionalEncoding(d_model, tgt_seq_len, dropout)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # (batch, seq, d_model)
        )
        # Non-persistent buffers follow the module across devices but stay out
        # of the state dict, so existing checkpoints remain loadable.
        self.register_buffer("src_mask", self.get_tgt_mask(src_seq_len), persistent=False)
        self.register_buffer("tgt_mask", self.get_tgt_mask(tgt_seq_len), persistent=False)
        print(f"Source mask size: {self.src_mask.shape}")
        print(f"Target mask size: {self.tgt_mask.shape}")
        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def _encode(self, src: torch.Tensor):
        """Encode source token ids of shape (batch, src_len).

        Returns:
            ``(memory, memory_key_padding_mask)``.
        """
        # The mask has to come from the token ids, before they are embedded.
        pad_mask = self.get_pad_mask(src, self.pad_token_id)
        embedded = self.src_positional_encoding(self.src_embedding(src) * self.embed_scale)
        memory = self.transformer.encoder(embedded, src_key_padding_mask=pad_mask)
        # NOTE(review): some PyTorch releases appear to return a shorter
        # memory from the encoder's inference fast path when trailing padding
        # is masked; trimming keeps the mask aligned with memory either way.
        return memory, pad_mask[:, : memory.size(1)]

    def _decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Decode target token ids (batch, tgt_len) into vocabulary logits."""
        length = tgt.size(1)
        if length <= self.tgt_mask.size(0):
            # The top-left corner of a causal mask is itself a causal mask.
            causal_mask = self.tgt_mask[:length, :length]
        else:
            causal_mask = self.get_tgt_mask(length).to(memory.device)
        embedded = self.tgt_positional_encoding(self.tgt_embedding(tgt) * self.embed_scale)
        # No tgt_key_padding_mask: target padding only trails the sentence,
        # so the causal mask already hides it from every real position.
        hidden = self.transformer.decoder(
            embedded,
            memory,
            tgt_mask=causal_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.linear(hidden)

    def translate(
        self,
        src: torch.Tensor,
        max_length: int = 50,
        start_symbol: int = 2,
        stop_symbol: int = 3,
    ) -> torch.Tensor:
        """Greedily decode a translation for a single source sentence.

        The module is switched to eval mode and left in eval mode.

        Args:
            src: Source token ids, shape (1, src_len) or (src_len,).
            max_length: Maximum number of output tokens including the start
                symbol; capped at the target positional-encoding length.
            start_symbol: Token id that starts the output.
            stop_symbol: Token id that ends decoding once generated.

        Returns:
            ``int64`` tensor of shape (1, out_len) starting with
            ``start_symbol``.

        Raises:
            ValueError: If ``src`` holds more than one sentence.
        """
        self.eval()

        if src.dim() == 1:
            src = src.unsqueeze(0)
        if src.size(0) != 1:
            raise ValueError(
                f"translate() decodes one sentence at a time, got batch size {src.size(0)}"
            )

        with torch.no_grad():
            memory, memory_pad_mask = self._encode(src.long())

            output_sequence = torch.full(
                (1, 1), start_symbol, dtype=torch.long, device=memory.device
            )
            for _ in range(min(max_length, self.tgt_seq_len) - 1):
                logits = self._decode(output_sequence, memory, memory_pad_mask)
                # Greedy choice: most likely token at the last position.
                next_word = logits[0, -1].argmax().view(1, 1)
                output_sequence = torch.cat([output_sequence, next_word], dim=1)
                if next_word.item() == stop_symbol:
                    break
        return output_sequence

    def get_tgt_mask(self, seq_len: int) -> torch.Tensor:
        """Return a float causal mask of shape (seq_len, seq_len).

        sample output: tensor([[0., -inf, -inf, -inf],
                               [0., 0., -inf, -inf],
                               [0., 0., 0., -inf],
                               [0., 0., 0., 0.]])
        """
        # Built as floats from the start; -inf cannot be stored in an int tensor.
        return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

    def get_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        """Return a boolean mask, same shape as ``matrix``, True at padding.

        sample input = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        """
        return matrix == pad_token

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
    ) -> torch.Tensor:
        """Return log-probabilities over the target vocabulary.

        Args:
            src: Source token ids, shape (batch, src_len).
            tgt: Decoder input token ids, shape (batch, tgt_len).

        Returns:
            Tensor of shape (batch, tgt_len, tgt_vocab_size) holding
            ``log_softmax`` scores.
        """
        memory, memory_pad_mask = self._encode(src)
        logits = self._decode(tgt, memory, memory_pad_mask)
        return F.log_softmax(logits, dim=-1)

# Run validation: translate sentences from the validation set with the model and print the results.

In [None]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, num_examples=2):
    """Translate a few validation samples and print them for inspection.

    For each of the first ``num_examples`` batches the source sentence, the
    reference translation and the model's prediction are decoded to text and
    printed. The model is switched to eval mode and left in eval mode.

    Args:
        model: Object providing ``eval()`` and
            ``translate(encoder_input, max_length=...)``.
        validation_ds: Iterable of batches with the keys ``"encoder_input"``
            and ``"label"``; only the first sample of each batch is shown.
        tokenizer_src: Tokenizer used to decode the source ids.
        tokenizer_tgt: Tokenizer used to decode reference and prediction.
        max_len: Maximum length passed to ``model.translate``.
        device: Device the encoder input is moved to.
        num_examples: Number of batches to show; ``None`` shows all of them.
    """
    from itertools import islice

    model.eval()

    with torch.no_grad():
        # Stop after num_examples batches instead of decoding the whole set.
        for batch in islice(validation_ds, num_examples):
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            model_out = model.translate(encoder_input, max_length=max_len)

            # The source text comes from the encoder input (source-language
            # ids); the decoder input holds target-language ids.
            src_input_text = tokenizer_src.decode(batch["encoder_input"][0].tolist())
            tgt_output_text = tokenizer_tgt.decode(batch["label"][0].tolist())
            model_out_text = tokenizer_tgt.decode(model_out[0].tolist())

            print(f"Source: {src_input_text}")
            print(f"Target: {tgt_output_text}")
            print(f"Predicted: {model_out_text}")

In [None]:
def train_model(config):
    """Train the translation model and write one checkpoint per epoch.

    Builds the data loaders and the model, optionally resumes from the
    checkpoint named by ``config["preload"]``, then for every epoch runs the
    training loop, prints a few validation translations and saves model
    state, optimizer state, epoch, global step and the last batch loss.

    Args:
        config: Dict as returned by ``get_config()``. ``preload`` is either
            ``None`` or the epoch suffix of a checkpoint inside
            ``model_folder``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device.")

    Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)

    print("Loading dataset...")
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_dataset(config)
    print("Dataset loaded.")

    print("Building model...")
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = TransformerModel(
        src_vocab_size=tokenizer_src.get_vocab_size(),
        tgt_vocab_size=tgt_vocab_size,
        src_seq_len=config["seq_len"],
        tgt_seq_len=config["seq_len"],
    ).to(device)
    print("Model built.")

    # Setup Tensorboard
    print("Setting up hyperparameters...")
    writer = SummaryWriter(config["experiment_name"])
    try:
        optimizer = optim.Adam(model.parameters(), lr=config["lr"], eps=1e-09)

        initial_epoch = 0
        global_step = 0
        if config["preload"] is not None:
            model_filename = get_weights_file_path(config, config["preload"])
            print(f"Preloading model {model_filename}")
            # Load from the resolved checkpoint path (not the bare epoch
            # suffix) and map tensors onto the device used for this run.
            checkpoint = torch.load(model_filename, map_location=device)
            initial_epoch = checkpoint["epoch"] + 1
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            global_step = checkpoint["global_step"]
            model.load_state_dict(checkpoint["model_state_dict"])

        criteria = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1)
        print("Hyperparameters set.")

        print("Starting training...")
        for epoch in range(initial_epoch, config["num_epochs"]):
            model.train()
            last_loss = None  # stays None if the loader yields no batch
            batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
            for batch in batch_iterator:
                encoder_input = batch["encoder_input"].to(device)  # [batch, seq_len]
                decoder_input = batch["decoder_input"].to(device)  # [batch, seq_len]
                label = batch["label"].to(device)  # [batch, seq_len]
                optimizer.zero_grad()

                output = model(encoder_input, decoder_input)
                loss = criteria(output.view(-1, tgt_vocab_size), label.view(-1))
                last_loss = loss.item()
                batch_iterator.set_postfix({"loss": f"{last_loss:6.3f}"})

                writer.add_scalar("Loss/train", last_loss, global_step)
                loss.backward()
                optimizer.step()
                global_step += 1

            # Flush once per epoch rather than after every step.
            writer.flush()

            # Run validation
            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config["seq_len"], device, num_examples=2)

            # Save model
            model_filename = get_weights_file_path(config, epoch)
            print(f"Saving model {model_filename}")
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "global_step": global_step,
                "loss": last_loss,
            }, model_filename)
    finally:
        # Make sure pending TensorBoard events are written even on failure.
        writer.close()

In [None]:
# Build the default configuration and start the full training run. This
# executes immediately when the cell is run; `config` stays available as a
# module-level name afterwards.
config = get_config()
train_model(config)

In [None]:
# batch_size=2
# d_model=512

# src_vocab_size=30000
# tgt_vocab_size=30000

# src_seq_length=100
# tgt_seq_length=100

# model = TransformerModel(
#     src_vocab_size= src_vocab_size,
#     tgt_vocab_size = tgt_vocab_size,
#     src_seq_len = src_seq_length,
#     tgt_seq_len = tgt_seq_length,
#     )

# src = torch.rand(batch_size, src_seq_length).long()
# tgt = torch.rand(batch_size, tgt_seq_length).long()

# output = model(src, tgt)
# print(f"Output shape: {output.shape}") # torch.Size([batch, tgt_seq_len, tgt_vocab_size]) Seq2Seq

In [None]:
# #source_sequence = "[SOS] Hello world [EOS] [PAD]"
# source_sequence = "Hello world"

# src_tokenizer = get_or_build_tokenizer(get_config(), "/home/saul/workspace/projects/pytorch/tokenizer_en.json", "en")
# trg_tokenizer = get_or_build_tokenizer(get_config(), "/home/saul/workspace/projects/pytorch/tokenizer_es.json", "es")
# encode_input_tokens = src_tokenizer.encode(source_sequence).ids
# encode_padding_tokens = src_seq_length - len(encode_input_tokens) - 2 # -2 for sos and eos token

# encoder_input = torch.cat([
#     torch.tensor([2], dtype=torch.int64),
#     torch.tensor(src_tokenizer.encode(source_sequence).ids, dtype=torch.int64),
#     torch.tensor([3], dtype=torch.int64),
#     torch.tensor([1] * encode_padding_tokens, dtype=torch.int64)
# ])

# print(f"Source sequence: {encoder_input}")
# print(encoder_input.shape)
# out = model.translate(encoder_input)
# output = trg_tokenizer.decode(out[0].tolist())
# print(f"Output: {output}")