References 
* https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch

# Motivations
* Transformers is one of the building block for the modern LLMs.
* It has been established far back from 2017
* Understanding the foundation of LLM, might uncover how the module works and how it stores data and generelized it's weights so that it becames the current LLM.
* Then possibly unlocking new discoveries on new neural networks techniques.

# Components of transformers
## Multi attention layer
* computes attention between each pair of positions in a sequence.
* It has multiple attention heads, capturing different aspect of input sequence


In [92]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
from tqdm import tqdm_notebook


In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert dim % num_heads == 0

        self.dim = dim
        self.num_heads = num_heads

        self.dim_k = self.dim // self.num_heads  # dimensions for each head's key, query, and value
        self.lin_query = nn.Linear(self.dim, self.dim)
        self.lin_key = nn.Linear(self.dim, self.dim)
        self.lin_val = nn.Linear(self.dim, self.dim)
        self.lin_out = nn.Linear(self.dim, self.dim)
    
    def split_heads(self, x: torch.Tensor):
        """
            Split input, x, into multiple heads input
                x           Tensor (batch_size, seq_length, dim)
                return      Tensor (batch_size, num_heads, seq_length, dim_per_head)
                            where 
                                dim_per_head = dim // num_heads
            "head" might be the brain to pay attention and judge / specialized to certain number of "segments" of dimension in a sequence
        """
        batch_size, seq_length, dim = x.size()
        # say the x is in size (batch_size=1, seq_length=4, dim=4)
        # torch.tensor([
        #     [
        #         [1,1,1,1], 
        #         [2,2,2,2],
        #         [3,3,3,3],
        #         [4,4,4,4]
        #     ],
        # ])
        # then after view, it will split the dimension into (1, 4, 2, 2)
        # tensor([
        # [[[1, 1],
        #   [1, 1]],

        #  [[2, 2],
        #   [2, 2]],

        #  [[3, 3],
        #   [3, 3]],

        #  [[4, 4],
        #   [4, 4]]]])
        # then transpose will do (1, 2, 4, 2)
        # tensor([
        # [[[1, 1],
        #   [2, 2],
        #   [3, 3],
        #   [4, 4]],

        #  [[1, 1],
        #   [2, 2],
        #   [3, 3],
        #   [4, 4]]]])
        
        return x.view(
            batch_size, seq_length, self.num_heads, self.dim_k
        ).transpose(1, 2)

    def scaled_dot_product_attention(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask=None):
        """
            Q, K, V are Tensors with size (batch_size, n_heads, seq_length, dim_per_head)
        """
        # calculate attention scores
        # https://stackoverflow.com/questions/73924697/whats-the-difference-between-torch-mm-torch-matmul-and-torch-mul
        # https://pytorch.org/docs/main/generated/torch.matmul.html
        # broadcasting https://numpy.org/doc/stable/user/basics.broadcasting.html
        # this will do matrix multiplication between (seq_length, dim_per_head) x (dim_per_head, seq_length) for each n_heads in all batch
        # attn_scores = (batch_size, n_heads, dim_per_head, seq_length)
        # then attn_scores will be // by sqrt(dim) (not sure why sqrt is chosen for scaling, why not other scaling factors?)
        attn_scores = torch.matmul(
            Q,
            K.transpose(-2, -1)
        ) / math.sqrt(self.dim_k)

        # determine which attn_scores should be hidden or ignored with mask, for example:
        # torch.tensor([
        #     [0,1,1,1,0,1,1]
        # ]).masked_fill(
        #     torch.tensor([
        #         [True,True,False,True,True,True,True]
        #     ]), 
        #     33
        # )
        # tensor([[33, 33,  1, 33, 33, 33, 33]])
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        attn_probs = torch.softmax(attn_scores, dim=-1)
        # attn_scores = (batch_size, n_heads, dim_per_head, seq_length)
        # V = (batch_size, n_heads, seq_length, dim_per_head)
        # out = (batch_size, n_heads, seq_length, dim_per_head) -> following the second argument
        out = torch.matmul(attn_probs, V)
        return out

    def combine_heads(self, multi_head_attns: torch.Tensor):
        batch_size, _, seq_length, dim_k = multi_head_attns.size()
        
        # little notes about pytorch operation
        # transpose = switch the dimensions of the tensor
        # contiguous = make a new copy of the tensor and save it in new memory
        return multi_head_attns.transpose(1, 2).contiguous().view(batch_size, seq_length, self.dim)
    
    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask=None):
        """
         Q, K, V are Tensor with (batch_size, seq_length, dim)
        """
        # split will turn each tensor into multiple heads tensors
        # (batch_size, seq_length, dim) -> (batch_size, n_heads, seq_length, dim_per_head)
        # each heads will get the same sequence length, but will only train for a specific set of dimensions
        query_heads = self.split_heads(self.lin_query(Q))
        key_heads = self.split_heads(self.lin_key(K))
        val_heads = self.split_heads(self.lin_val(V))

        attn_output = self.scaled_dot_product_attention(query_heads, key_heads, val_heads, mask)

        out = self.lin_out(self.combine_heads(attn_output))
        return out




In [10]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, dim: int, dim_ff: int):
        """
            dim: hidden dimension
            dim_ff: dimension of feed forward
        """
        super(PositionWiseFeedForward, self).__init__()
        self.lin_1 = nn.Linear(dim, dim_ff)
        self.lin_2 = nn.Linear(dim_ff, dim)
        self.relu = nn.ReLU()
    
    def forward(self, x: torch.Tensor):
        y = self.lin_1(x)
        y = self.relu(y)
        y = self.lin_2(y)
        return y


In [11]:
class PositionalEncoding(nn.Module):
    def __init__(self, dim: int, max_seq_length: int):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor):
        return x + self.pe[:, :x.size(1)]


In [12]:
class EncoderLayer(nn.Module):
    def __init__(self, dim, n_heads, dim_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(dim, n_heads)
        self.feed_forward = PositionWiseFeedForward(dim, dim_ff)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor, return_attn=False):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm1(x + self.dropout(ff_output))

        if return_attn:
            return x, attn_ouput
        
        return x

In [13]:
class DecoderLayer(nn.Module):
    def __init__(self, dim: int, n_heads: int, dim_ff: int, dropout: float):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(dim, n_heads)
        self.cross_attn = MultiHeadAttention(dim, n_heads)
        self.feed_forward = PositionWiseFeedForward(dim, dim_ff)
        self.lin = nn.Linear(dim_ff, dim_ff)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, enc_output: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        attn_output = self.self_attn(x, x, x, mask=tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, mask=src_mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


In [14]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, dim, n_heads, n_layers, dim_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, dim)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, dim)
        self.positional_encoding = PositionalEncoding(dim, max_seq_length)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(dim, n_heads, dim_ff, dropout) for _ in range(n_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(dim, n_heads, dim_ff, dropout) for _ in range(n_layers)
        ])
        self.fc = nn.Linear(dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(
            self.positional_encoding(
                self.encoder_embedding(src)
            )
        )
        tgt_embedded = self.dropout(
            self.positional_encoding(
                self.decoder_embedding(tgt)
            )
        )
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
        
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [16]:
from datasets import load_dataset

ds = load_dataset("spdenisov/word_aligned_translation", split="train")

In [81]:
class Tokenizer:
    def __init__(self, start_idx=100, unk_token="[UNK]", pad_token="[PAD]", pad_token_id=10, max_seq_length=100):
        self.start_idx = start_idx
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.pad_token_id = pad_token_id
        self.max_seq_length = max_seq_length
        self.word_to_id = {self.unk_token: start_idx}
        self.id_to_word = {k: self.unk_token for k in range(start_idx)}
        self.word_to_id[self.pad_token] = pad_token_id
        self.id_to_word[self.pad_token_id] = self.pad_token

    def train_tokenize(self, sentence):
        splitted = sentence.split(" ")
        res = []
        for s in splitted:
            if s not in self.word_to_id:
                max_id = max(self.word_to_id.values()) 
                cur_id = max_id + 1
                self.word_to_id[s] = cur_id
                self.id_to_word[cur_id] = s
            
            res.append(self.word_to_id[s]) 
        return res

    def tokenize(self, sentence):
        splitted = sentence.split(" ")
        res = []
        for s in splitted:
            if s not in self.word_to_id:
                res.append(self.word_to_id[self.unk_token])    
                continue
            res.append(self.word_to_id[s]) 
        if len(res) < self.max_seq_length:
            diff = self.max_seq_length - len(res)
            for _ in range(diff):
                res.append(self.word_to_id[self.pad_token])
        if len(res) > self.max_seq_length:
            res = res[:self.max_seq_length]
        return res
        

In [82]:
en_tokenizer = Tokenizer()
fr_tokenizer = Tokenizer()

In [83]:
en_to_fr = ds.filter(
    lambda x: x["target_language"] == "French" and x["source_language"] == "English")
en_to_fr = en_to_fr.remove_columns(["target_language", "source_language"])
en_to_fr = en_to_fr.map(
    lambda x: {
        "source_sentence": " ".join(["_".join(w.split(" ")[1:]) for w in x["source_words"]]),
        "target_sentence": " ".join(["_".join(w.split(" ")[1:]) for w in x["target_lines"]])
    }).remove_columns(["source_words", "target_lines"])

In [84]:
for x in en_to_fr:
    en_tokenizer.train_tokenize(x["source_sentence"])
    fr_tokenizer.train_tokenize(x["target_sentence"])

In [85]:
en_to_fr = en_to_fr.map(
    lambda x: {
        "tokenized_source": en_tokenizer.tokenize(x["source_sentence"]),
        "tokenized_target": fr_tokenizer.tokenize(x["target_sentence"])
    }
)

Map:   0%|          | 0/321 [00:00<?, ? examples/s]

In [86]:
splitted = en_to_fr.train_test_split(test_size=0.3)

In [87]:
splitted["train"].set_format(type="torch")

In [88]:
splitted["train"]

Dataset({
    features: ['source_sentence', 'target_sentence', 'tokenized_source', 'tokenized_target'],
    num_rows: 224
})

In [89]:
train_data = splitted["train"].remove_columns(["source_sentence", "target_sentence"])

In [90]:
train_dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=8
)

In [108]:
params = dict(
    src_vocab_size = max(en_tokenizer.word_to_id.values()) + 100,
    tgt_vocab_size = max(fr_tokenizer.word_to_id.values()) + 100,
    dim = 256,
    n_heads = 32,
    n_layers = 1,
    dim_ff = 256,
    max_seq_length = 100,
    dropout = 0.3
)

t = Transformer(**params)

In [110]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(t.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

t.train()

epoch_losses = []
pbar = tqdm_notebook(range(100))
for epoch in pbar:
    for batch in train_dataloader:
        optimizer.zero_grad()
        src_data = batch["tokenized_source"]
        tgt_data = batch["tokenized_target"]
        output = t(src_data, tgt_data[:, :-1])
        loss = criterion(output.contiguous().view(-1, params["tgt_vocab_size"]), tgt_data[:, 1:].contiguous().view(-1))
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        pbar.set_postfix({"loss": epoch_losses[-1]})
        # print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  pbar = tqdm_notebook(range(100))


  0%|          | 0/100 [00:00<?, ?it/s]

In [149]:
pbar = tqdm_notebook(range(100))
for epoch in pbar:
    for batch in train_dataloader:
        optimizer.zero_grad()
        src_data = batch["tokenized_source"]
        tgt_data = batch["tokenized_target"]
        output = t(src_data, tgt_data[:, :-1])
        loss = criterion(output.contiguous().view(-1, params["tgt_vocab_size"]), tgt_data[:, 1:].contiguous().view(-1))
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        pbar.set_postfix({"loss": epoch_losses[-1]})


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  pbar = tqdm_notebook(range(100))


  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [144]:
def decode(tokenizer, ten):
    res = torch.argmax(ten, dim=-1).squeeze(0).detach().cpu().numpy()
    result = []
    for a in res:
        if a in tokenizer.id_to_word:
            result.append(tokenizer.id_to_word[a])
            continue
        result.append(tokenizer.id_to_word[tokenizer.unk_token])
    return " ".join(result)

In [158]:
t.eval()

# Generate random sample validation data
# val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
# val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

with torch.no_grad():
    sentence = "You can filter fields in the row column , and detail areas ."
    tokenized = en_tokenizer.tokenize(sentence)
    tokenized_tensor = torch.tensor([tokenized])
    target_tensor = torch.ones(tokenized_tensor.size(), dtype=torch.long) * 12
    
    out = t(tokenized_tensor, target_tensor)
    print(decode(fr_tokenizer, out))
    # val_output = t(val_src_data, val_tgt_data[:, :-1])
    # val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data[:, 1:].contiguous().view(-1))
    # print(f"Validation Loss: {val_loss.item()}")

volonté volonté volonté volonté volonté et et et , , , , et . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] . . . . . .


In [147]:
en_to_fr[0]

{'source_sentence': 'You can filter fields in the row , column , and detail areas .',
 'target_sentence': 'Vous peut filtre les_champs dans le/la/les rangée , colonnes , et détail zones .',
 'tokenized_source': [101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  108,
  110,
  111,
  112,
  113,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10],
 'tokenized_target': [101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  108,
  110,
  111,
  112,
  113,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  10,
  1