In [15]:
# implement the architecture from (Attention Is All You Need) https://arxiv.org/abs/1706.03762
import torch
import torch.nn as nn

## Basic architecture 

![image.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F12cdf506-6cd8-4afa-93a3-b77b82770309_2755x1570.png)

#### Position Embeddings

![image.gif](https://i.imgur.com/KgZCdzX.gif)

The typical way to implement the values of the embedding is by hard coding them by using a sine and cosine function of the vectors and elements’ positions

![image.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F58a77f49-ed6d-4614-9c64-505455bd0c83_2043x1300.png)

In [16]:
class PositionalEncoding(nn.Module):
    def __init__(self, context_size: int, d_attn: int):
        """represent positional encoding as harcoded matrix of size (context_size,d_attn)

        Args:
            context_size (int): max context size
            d_attn (int): model hidden size
        """
        super().__init__()
        self.encoding = torch.zeros(
            size=(context_size, d_attn)
        )  # placeholder matrix of the encoding , check above figures (orange matrix)
        pos = torch.arange(0, context_size).unsqueeze(
            dim=1
        )  # positions are ranged from 0 to context size (those are rows indexes in orange matrix in above figures)
        i = torch.arange(
            0, d_attn, 2
        )  # i range from 0 to d_attn in every pos (row in orange matrix)
        arg = pos / (10000 ** (2 * i / d_attn))
        self.encoding[:, 0::2] = torch.sin(arg)  # even columns (even i)
        self.encoding[:, 1::2] = torch.cos(arg)  # odd i

    def forward(self, tokens_sequence: torch.Tensor) -> torch.Tensor:
        """encode embedded tokens sequence

        Args:
            tokens_sequence (torch.Tensor):

        Returns:
            torch.Tensor: position encoded embedded tokens
        """
        return self.encoding[
            : tokens_sequence.shape[1], :
        ]  # just query the self.encoding matrix with tokens sequence

#### Encoder Block

The encoder block is composed of a multi-head attention layer, a position-wise feed-forward network, and two-layer normalization.

![img.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6627a678-0582-4950-a829-a8e9e4e97db9_3289x1326.png)

The attention layer allows to learn complex relationships between the hidden states, whereas the position-wise feed-forward network allows to learn complex relationships between the different elements within each vector.

![img.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fcf4bdbd2-8e45-4f33-9c86-35eede3571ab_3433x1050.png)

In [17]:
class PositionwiseFeedForward(nn.Module):

    def __init__(self, d_attn, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_attn, d_ff)
        self.linear2 = nn.Linear(d_ff, d_attn)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

In [18]:
class EncoderBlock(nn.Module):
    def __init__(self, n_attention_heads: int, d_ff: int, d_attn: int):
        """init encoder

        Args:
            n_attention_heads (int): number of attention heads
            d_ff (int): dimention feed forward network
            d_attn (int): encoder hidden size
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_attn, n_attention_heads)
        self.feed_forward = PositionwiseFeedForward(d_attn, d_ff)
        self.norm1 = nn.LayerNorm(d_attn)
        self.norm2 = nn.LayerNorm(d_attn)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """_summary_

        Args:
            hidden_states (torch.Tensor): hidden state tensor is elementwise addition between token embeddings and positional embeddings for the input sequence

        Returns:
            torch.Tensor: encoder projection tensor (encoder output)
        """
        out1 = (
            self.self_attn(query=hidden_states, key=hidden_states, value=hidden_states)[
                0
            ]
            + hidden_states  # apply resiudal connection
        )  # perform self attention on hidden states (note hidden state tensor is elementwise addition between token embeddings and positional embeddings)
        norm1 = self.norm1(out1)  # layer normalization
        out2 = self.feed_forward(norm1) + norm1
        out3 = self.norm2(out2)
        return out3

The encoder is just the token embedding and the position embedding followed by multiple encoder blocks.

![img.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc3808c9f-715e-4ab0-be11-34e16b3d8644_3540x1022.png)

In [19]:
class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        context_size: int,
        n_blocks: int,
        n_heads: int,
        d_attn: int,
        d_ff: int,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_attn)
        self.pos_embedding = PositionalEncoding(context_size, d_attn)
        self.blocks = nn.ModuleList(
            [
                EncoderBlock(
                    d_attn=d_attn,
                    n_attention_heads=n_heads,
                    d_ff=d_ff,
                )
                for _ in range(n_blocks)
            ]
        )

    def forward(self, tokens_seq: torch.Tensor) -> torch.Tensor:
        embedded_tokens = self.embedding(
            tokens_seq
        )  # apply embeddings layer to tokens input sequence
        pos_embedded_tokens = self.pos_embedding(tokens_seq).cuda()
        hidden_states = embedded_tokens + pos_embedded_tokens
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states

#### Decoder Block

The decoder block is composed of a multi-head attention layer, a position-wise feed-forward network, a cross-attention layer, and three layer normalization.

![img.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fb0287aa3-7a69-41c4-a692-c1940e007f29_3301x1582.png)

In [20]:
class DecoderBlock(nn.Module):
    def __init__(self, d_attn: int, num_heads: int, d_ff: int) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_attn, num_heads)
        self.norm1 = nn.LayerNorm(d_attn)
        self.cross_attn = nn.MultiheadAttention(d_attn, num_heads)
        self.norm2 = nn.LayerNorm(d_attn)
        self.feed_forward = PositionwiseFeedForward(d_attn, d_ff)
        self.norm3 = nn.LayerNorm(d_attn)

    def forward(
        self, hidden_states: torch.Tensor, encoder_output: torch.Tensor
    ) -> torch.Tensor:
        """_summary_

        Args:
            hidden_states (torch.Tensor): hidden state tensor is elementwise addition between token embeddings and positional embeddings for the output sequence

        Returns:
            torch.Tensor: _description_
        """
        out1 = (
            self.self_attn(query=hidden_states, key=hidden_states, value=hidden_states)[
                0
            ]
            + hidden_states
        )  # apply resiudal connection
        out1 = self.norm1(out1)
        # apply corss attention between out1 and encoder output
        out2 = (
            self.cross_attn(query=out1, key=encoder_output, value=encoder_output)[0]
            + out1
        )
        out2 = self.norm2(out2)
        out3 = self.feed_forward(out2) + out2
        out3 = self.norm3(out3)
        return out3

the cross-attention layer computes the attentions between the decoder's hidden states and the encoder output

![img.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F6fa4f653-3985-40ac-932d-3eb023be2eb0_2723x1332.png)

The decoder is just the token embedding and the position embedding followed by multiple decoder blocks and the predicting head.

![img.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0808f5de-f713-4a56-8750-ac1cda39b929_2753x1542.png)

The predicting head is just a linear layer that projects the last hidden states from the d_attn dimension to the size of the vocabulary. To predict, we perform an ArgMax function on the resulting probability vectors

![img.png](https://cdn.fs.teachablecdn.com/ADNupMnWyR7kCWRvm76Laz/https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fc452d478-581f-4baf-941f-0ab07a39bdb3_3386x1342.png)

In [21]:
class Decoder(nn.Module):
    def __init__(self, output_size, context_size, d_attn, d_ff, num_heads, n_blocks):
        super().__init__()
        self.embedding = nn.Embedding(output_size, d_attn)
        self.pos_embedding = PositionalEncoding(context_size, d_attn)
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(
                    d_attn=d_attn,
                    num_heads=num_heads,
                    d_ff=d_ff,
                )
                for _ in range(n_blocks)
            ]
        )
        self.out = nn.Linear(d_attn, output_size)

    def forward(self, x, enc_output):
        x = self.embedding(x) + self.pos_embedding(x).cuda()
        for block in self.blocks:
            x = block(x, enc_output)
        output = self.out(x)
        return output

#### Transfomer

In [22]:
class Transformer(nn.Module):
    def __init__(
        self,
        encoder_vocab_size,
        decoder_vocab_size,
        context_size,
        d_attn,
        d_ff,
        num_heads,
        n_blocks,
    ):
        super().__init__()
        self.encoder = Encoder(
            encoder_vocab_size, context_size, n_blocks, num_heads, d_attn, d_ff
        )
        self.decoder = Decoder(
            decoder_vocab_size, context_size, d_attn, d_ff, num_heads, n_blocks
        )

    def forward(self, input_encoder, input_decoder):
        enc_output = self.encoder(input_encoder)
        output = self.decoder(input_decoder, enc_output)
        return output

In [23]:
# # test the architecture with dummy data
# SOS_token = 0
# EOS_token = 1
# PAD_token = 2
# index2words = {SOS_token: "SOS", EOS_token: "EOS", PAD_token: "PAD"}
# words = "How are you doing ? I am good and you ?"
# words_list = set(words.lower().split(" "))
# for word in words_list:
#     index2words[len(index2words)] = word

# words2index = {w: i for i, w in index2words.items()}


# def convert2tensors(sentence, max_len):
#     words_list = sentence.lower().split(" ")
#     padding = ["PAD"] * (max_len - len(words_list))
#     words_list.extend(padding)
#     indexes = [words2index[word] for word in words_list]
#     return torch.tensor(indexes, dtype=torch.long).view(1, -1)


# d_attn = 10
# VOCAB_SIZE = len(words2index)
# N_BLOCKS = 10
# D_FF = 20
# CONTEXT_SIZE = 100
# NUM_HEADS = 2
# transformer = Transformer(
#     encoder_vocab_size=VOCAB_SIZE,
#     decoder_vocab_size=VOCAB_SIZE,
#     context_size=CONTEXT_SIZE,
#     d_attn=d_attn,
#     d_ff=D_FF,
#     num_heads=NUM_HEADS,
#     n_blocks=N_BLOCKS,
# )
# input_sentence = "How are you doing ?"
# output_sentence = "I am good and"
# input_encoder = convert2tensors(input_sentence, CONTEXT_SIZE)
# input_decoder = convert2tensors(output_sentence, CONTEXT_SIZE)
# output = transformer(input_encoder, input_decoder)
# _, indexes = output.squeeze().topk(1)
# index2words[indexes[3].item()]

train the transformer architecture on machine translation task

In [24]:
from torch.utils.data import Dataset


class BilingualDataset(Dataset):

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.sos_token = torch.tensor(
            [tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64
        )
        self.eos_token = torch.tensor(
            [tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64
        )
        self.pad_token = torch.tensor(
            [tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64
        )

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair["translation"][self.src_lang]
        tgt_text = src_target_pair["translation"][self.tgt_lang]

        # Transform the text into tokens
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Add sos, eos and padding to each sentence
        enc_num_padding_tokens = (
            self.seq_len - len(enc_input_tokens) - 2
        )  # We will add <s> and </s>
        # We will only add <s>, and </s> only on the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor(
                    [self.pad_token] * enc_num_padding_tokens, dtype=torch.int64
                ),
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor(
                    [self.pad_token] * dec_num_padding_tokens, dtype=torch.int64
                ),
            ],
            dim=0,
        )

        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor(
                    [self.pad_token] * dec_num_padding_tokens, dtype=torch.int64
                ),
            ],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

In [25]:
from pathlib import Path
from tokenizers.pre_tokenizers import Whitespace
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from torch.utils.data import DataLoader, random_split


def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 1,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "n_heads": 16,
        "n_blocks": 8,
        "d_ff": 1024,
        "datasource": "opus_books",
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizers/tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }


def get_all_sentences(ds, lang):
    for item in ds:
        yield item["translation"][lang]


def get_or_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config["tokenizer_file"].format(lang))
    if not Path.exists(tokenizer_path):
        # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
        tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = BpeTrainer(
            special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2
        )
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer


def get_ds(config):
    # It only has the train split, so we divide it overselves
    ds_raw = load_dataset(
        f"{config['datasource']}",
        f"{config['lang_src']}-{config['lang_tgt']}",
        split="train",
    )

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config["lang_tgt"])

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(
        train_ds_raw,
        tokenizer_src,
        tokenizer_tgt,
        config["lang_src"],
        config["lang_tgt"],
        config["seq_len"],
    )
    val_ds = BilingualDataset(
        val_ds_raw,
        tokenizer_src,
        tokenizer_tgt,
        config["lang_src"],
        config["lang_tgt"],
        config["seq_len"],
    )

    # Find the maximum length of each sentence in the source and target sentence
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item["translation"][config["lang_src"]]).ids
        tgt_ids = tokenizer_tgt.encode(item["translation"][config["lang_tgt"]]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f"Max length of source sentence: {max_len_src}")
    print(f"Max length of target sentence: {max_len_tgt}")

    train_dataloader = DataLoader(
        train_ds, batch_size=config["batch_size"], shuffle=True
    )
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [26]:
# training params
config = get_config()
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
if device == "cuda":
    print(f"Device name: {torch.cuda.get_device_name(device.index)}")
    print(
        f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB"
    )
# Make sure the weights folder exists
Path(f"{config['datasource']}_{config['model_folder']}").mkdir(
    parents=True, exist_ok=True
)
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = Transformer(
    encoder_vocab_size=tokenizer_src.get_vocab_size(),
    decoder_vocab_size=tokenizer_tgt.get_vocab_size(),
    context_size=config["seq_len"],
    d_attn=config["d_model"],
    num_heads=config["n_heads"],
    n_blocks=config["n_blocks"],
    d_ff=config["d_ff"],
)
model.cuda().train()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
loss_fn = nn.CrossEntropyLoss(
    ignore_index=tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1
).to(device)

Using device: cuda
Device name: NVIDIA GeForce GTX 1660 Ti with Max-Q Design
Device memory: 5.7974853515625 GB
Max length of source sentence: 316
Max length of target sentence: 287
Transformer(
  (encoder): Encoder(
    (embedding): Embedding(22439, 512)
    (pos_embedding): PositionalEncoding()
    (blocks): ModuleList(
      (0-7): 8 x EncoderBlock(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=1024, bias=True)
          (linear2): Linear(in_features=1024, out_features=512, bias=True)
          (relu): ReLU()
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (decoder): Decoder(
    (embedding): Embedding(30000, 512)
    (pos_embedding): PositionalEncoding(

In [27]:
# train loop
for epoch in range(config["num_epochs"]):
    for batch in train_dataloader:
        encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
        decoder_input = batch["decoder_input"].to(device)  # (B, seq_len)
        out = model.forward(input_encoder=encoder_input, input_decoder=decoder_input)
        # out is (batch_size, seq_len, vocab_size)
        # We want to compute the loss for each token in the sequence, so we reshape
        # to (batch_size * seq_len, vocab_size)
        out = out.reshape(-1, out.shape[-1])
        # batch["label"] is (batch_size, seq_len)
        # We reshape it to (batch_size * seq_len)
        labels = batch["label"].to(device).reshape(-1)
        loss = loss_fn(out, labels)  # Removed vocab_idxs, and reshaped out and labels
        print(loss.item())
        # Backpropagate the loss
        loss.backward()
        # Update the weights
        optimizer.step()
        # Zero the gradients
        optimizer.zero_grad()

10.48658561706543
9.852306365966797
9.777769088745117
9.666818618774414
9.666036605834961
9.29448127746582
9.456995010375977
9.34490966796875
9.43882942199707
9.111284255981445
9.216217041015625
8.905555725097656
8.746499061584473
9.016515731811523
9.037217140197754
8.893003463745117
8.916192054748535
8.953568458557129
8.764799118041992
9.039119720458984
8.509190559387207
8.440814971923828
8.660299301147461
8.428574562072754
8.796348571777344
8.523812294006348
8.573486328125
8.465386390686035
8.551826477050781
8.498612403869629
8.760406494140625
8.235589027404785
8.483545303344727
8.360697746276855
8.234393119812012
8.102263450622559
8.315869331359863
8.157088279724121
8.346614837646484
8.317895889282227
7.871866226196289
8.167115211486816
7.770043849945068
7.839936256408691
8.155109405517578
7.991952896118164
8.125548362731934
8.014846801757812
7.817283630371094
8.139261245727539
8.082860946655273
7.735185623168945
7.600800514221191
8.039831161499023
8.015372276306152
7.89137649536132

KeyboardInterrupt: 