
# TRANSFORMERS FROM SCRATCH


This notebook is based on a YouTube video:

Pytorch Transformers from Scratch (Attention is all you need)

url: https://www.youtube.com/watch?v=U0s0f995w14&t=2738s

<img src="https://i.ytimg.com/vi/U0s0f995w14/hqdefault.jpg?sqp=-oaymwEcCNACELwBSFXyq4qpAw4IARUAAIhCGAFwAcABBg==&rs=AOn4CLAOaG0P92UCvlf9IUtQWXB6yVq9lA" width=250px/>

Primary source:

url: https://peterbloem.nl/blog/transformers

Attention Is All You Need paper

url: https://arxiv.org/pdf/1706.03762.pdf


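Harnessing the Power of LLMs in Practice: A Survey on ChatGPT and Beyond
- url: https://arxiv.org/pdf/2304.13712.pdf
- source: https://github.com/Mooler0410/LLMsPracticalGuide?tab=readme-ov-file

The survey groups language models into two families: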

- BERT-style Language Models: Encoder-Decoder or Encoder-only

- GPT-style Language Models: Decoder-only
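In practice, a key difference between the two families is the attention mask. The sketch below is a minimal illustration, assuming pad index 0 (as in the notebook's `src_pad_idx`); the helper names `padding_mask` and `causal_mask` are hypothetical. An encoder (BERT-style) lets every token attend to every non-padding token in both directions, while a decoder (GPT-style) only lets position i attend to positions up to i.

```python
import torch

def padding_mask(tokens, pad_idx=0):
    # True where a real token is present; every query position may attend to
    # every non-padding key position (bidirectional, encoder / BERT-style)
    keep = tokens != pad_idx                                   # (N, seq_len)
    return keep[:, None, :].expand(-1, tokens.shape[1], -1)    # (N, seq_len, seq_len)

def causal_mask(seq_len):
    # Lower-triangular: position i may attend only to positions <= i
    # (decoder / GPT-style)
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

tokens = torch.tensor([[5, 6, 7, 0]])  # one sequence; the last position is padding
print(padding_mask(tokens)[0].int())
print(causal_mask(4).int())
```

Rows are query positions and columns are key positions; a 1 means "may attend".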

<img src="https://pbs.twimg.com/media/Fuw9fv9akAA_h0q?format=jpg&name=large" />

<img src="https://miro.medium.com/v2/resize:fit:856/format:webp/1*ZCFSvkKtppgew3cc7BIaug.png" width="400px" />

For now, we focus on the encoder, the left half of the diagram. Multi-Head Attention is the most important component of the transformer: once we understand it, we understand the essence of the whole architecture.

<img src="https://images.squarespace-cdn.com/content/v1/58da330debbd1a5419611082/1555580207391-2ECHQBG0O8MSP4IU4XZ8/ashish+vaswani+headshot.jpg" width=300px />

Ashish Vaswani

Note: BERT is an encoder-only model, which differs from the original paper "Attention Is All You Need". The original paper describes a sequence-to-sequence (encoder-decoder) model trained end-to-end for machine translation.

Attention Is All You Need
- url: https://arxiv.org/pdf/1706.03762.pdf
- author: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin
- source: a third-party implementation of the translation task: https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py


BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
- url: https://arxiv.org/pdf/1810.04805.pdf
- author: Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova
- source: https://github.com/google-research/bert/blob/master/modeling.py

The BERT source code defines only the encoder; the rest is the transformer layer stack itself (run with or without an attention mask).

In [None]:
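#@title Mermaid Ink Setup
# Renders Mermaid graphs via the mermaid.ink service: https://github.com/jihchi/mermaid.ink

import base64
import requests
from IPython.display import HTML, display

def mermaid(graph, scale=1.0):
  mermaid_ink_url = 'https://mermaid.ink/svg/'
  # mermaid.ink reads the base64-encoded graph definition from the URL path;
  # the URL-safe alphabet keeps '/' and '+' out of the path
  encoded = base64.urlsafe_b64encode(graph.encode('utf-8')).decode('ascii')
  response = requests.get(mermaid_ink_url + encoded)
  response.raise_for_status()
  # generate the SVG on the fly, then scale it for display
  svg = response.content.decode('utf-8').replace('<svg ', f'<svg transform="scale({scale})" ', 1)
  display(HTML(svg))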


In [None]:
#@title Encoder/Decoder
mermaid('''
%%{init: {'theme': 'base', 'themeVariables': { 'primaryColor': '#f3f3f3', 'primaryBorderColor':'#000000'}}}%%

graph BT;

style Inputs fill:#ffffff,stroke:#ffffff,stroke-width:2px;
style Outputs fill:#ffffff,stroke:#ffffff,stroke-width:2px;
style A1 fill:#ffffff
style A2 fill:#ffffff
style Input_Embedding fill:#ffdfe0
style Output_Embedding fill:#ffdfe0
style note fill:#f3f3f3,stroke:#f3f3f3,stroke-width:2px;
style note2 fill:#f3f3f3,stroke:#f3f3f3,stroke-width:2px;
style Positional_Encoding fill:#ffffff,stroke:#ffffff,stroke-width:2px;
style Positional_Encoding2 fill:#ffffff,stroke:#ffffff,stroke-width:2px;

style Multi_Head_Attention fill:#ffe3b7
style Multi_Head_Attention2 fill:#ffe3b7
style Masked_Multi_Head_Attention fill:#ffe3b7
style Add_Norm1 fill:#eff9b3
style Add_Norm2 fill:#eff9b3
style Add_Norm3 fill:#eff9b3
style Add_Norm4 fill:#eff9b3
style Add_Norm5 fill:#eff9b3
style Feed_Forward fill:#bae9f9
style Feed_Forward2 fill:#bae9f9

style Encoder fill:#f3f3f3,stroke:#000000,stroke-width:2px,rounded;
style Decoder fill:#f3f3f3,stroke:#000000,stroke-width:2px,rounded;

style Linear fill:#dadfed
style Softmax fill:#c7ebcd
style Output_Probablities fill:#ffffff,stroke:#ffffff,stroke-width:2px;

Inputs --> Input_Embedding("Input\nEmbedding");
Positional_Encoding("Positional\nEncoding") --> A1("+");
Input_Embedding --> A1;
A1 --> |K| Multi_Head_Attention("Multi-Head\nAttention");
A1 --> |V| Multi_Head_Attention("Multi-Head\nAttention");
A1 --> |Q| Multi_Head_Attention("Multi-Head\nAttention");
A1 --> Add_Norm1("Add & Norm");

subgraph Encoder
  Multi_Head_Attention --> Add_Norm1;
  Add_Norm1 --> Feed_Forward("Feed\nForward");
  Add_Norm1 --> Add_Norm2("Add & Norm");
  Feed_Forward --> Add_Norm2;
  note("Nx");

end

Outputs("Outputs\n(shifted right)") --> Output_Embedding("Output\nEmbedding");
Output_Embedding --> A2("+");
Positional_Encoding2("Positional\nEncoding") --> A2("+");
A2 --> |K| Masked_Multi_Head_Attention("Masked\nMulti-Head\nAttention");
A2 --> |V| Masked_Multi_Head_Attention("Masked\nMulti-Head\nAttention");
A2 --> |Q| Masked_Multi_Head_Attention("Masked\nMulti-Head\nAttention");
A2 --> Add_Norm3("Add & Norm");
Add_Norm2 --> |K| Multi_Head_Attention2;
Add_Norm2 --> |V| Multi_Head_Attention2;

subgraph Decoder
  Masked_Multi_Head_Attention --> Add_Norm3;
  Add_Norm3 --> |Q| Multi_Head_Attention2("Multi-Head\nAttention");
  Multi_Head_Attention2 --> Add_Norm4("Add & Norm");
  Add_Norm3 --> Add_Norm4;
  Add_Norm4 --> Feed_Forward2("Feed\nForward");

  Feed_Forward2-->Add_Norm5("Add & Norm")
  Add_Norm4 --> Add_Norm5;
  note2("Nx");

end

Add_Norm5 --> Linear("Linear");
Linear --> Softmax("Softmax");
Softmax --> Output_Probablities("Output\nProbabilities")
''', 0.75)
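The Positional Encoding boxes in the diagram above are, in the paper, fixed sinusoids: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The notebook's `Encoder` and `Decoder` below instead learn position vectors with `nn.Embedding(max_length, embed_size)`. The following is a minimal sketch of the sinusoidal variant, assuming an even `embed_size`; the function name is hypothetical.

```python
import math
import torch

def sinusoidal_positional_encoding(max_length, embed_size):
    # PE[pos, 2i]   = sin(pos / 10000^(2i / embed_size))
    # PE[pos, 2i+1] = cos(pos / 10000^(2i / embed_size))
    position = torch.arange(max_length, dtype=torch.float32).unsqueeze(1)  # (max_length, 1)
    div_term = torch.exp(
        torch.arange(0, embed_size, 2, dtype=torch.float32) * (-math.log(10000.0) / embed_size)
    )                                                                        # (embed_size / 2,)
    pe = torch.zeros(max_length, embed_size)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

pe = sinusoidal_positional_encoding(max_length=100, embed_size=256)
print(pe.shape)   # torch.Size([100, 256])
print(pe[0, :4])  # position 0: sin(0) = 0, cos(0) = 1 -> tensor([0., 1., 0., 1.])
```

Such a table would be added to the word embeddings in place of the learned `position_embedding(positions)`; it has no trainable parameters.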

K, Q and V
- the building block: Multi-Head Attention and Masked Multi-Head Attention each take 3 inputs. My Mermaid graph above can only draw them as a single arrow; conceptually there are 3 arrows, standing for K, Q and V.
- in the decoder, the second Multi-Head Attention takes K and V from the encoder output, and Q from the output of the decoder's Masked Multi-Head Attention (after Add & Norm).
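A minimal single-head sketch of that decoder cross-attention, with random tensors standing in for real activations and the learned projections omitted for brevity (both simplifications are assumptions, not the notebook's code). The point is the shapes: the output has one row per decoder query, while the attention weights span the encoder's source positions.

```python
import torch

def attention(q, k, v):
    # q: (N, q_len, d); k and v: (N, kv_len, d)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # (N, q_len, kv_len)
    weights = torch.softmax(scores, dim=-1)                 # each query row sums to 1
    return weights @ v, weights                             # (N, q_len, d)

torch.manual_seed(0)
N, src_len, trg_len, d = 2, 9, 7, 16
enc_out = torch.randn(N, src_len, d)     # encoder output -> K and V
dec_hidden = torch.randn(N, trg_len, d)  # output of the decoder's masked attention -> Q

out, weights = attention(dec_hidden, enc_out, enc_out)
print(out.shape)      # torch.Size([2, 7, 16])
print(weights.shape)  # torch.Size([2, 7, 9])
```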

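Skip (residual) connections
- the line from "+" directly to "Add & Norm"
- the line from one "Add & Norm" directly to the next "Add & Norm" above it
- each "Add & Norm" computes LayerNorm(x + Sublayer(x)), where x is the input that skipped around the sublayer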
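A minimal sketch of one "Add & Norm" wrapped around a feed-forward sublayer, following the paper's description (dropout on the sublayer output, then add the input, then normalize). The class name `AddNorm` and the sizes are hypothetical; note that the notebook's `TransformerBlock` applies dropout after the norm instead.

```python
import torch
import torch.nn as nn

class AddNorm(nn.Module):
    """Post-norm residual wrapper: LayerNorm(x + Dropout(Sublayer(x)))."""
    def __init__(self, embed_size, sublayer, dropout=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # the skip connection carries x around the sublayer unchanged
        return self.norm(x + self.dropout(self.sublayer(x)))

embed_size = 8
ff = nn.Sequential(nn.Linear(embed_size, 32), nn.ReLU(), nn.Linear(32, embed_size))
block = AddNorm(embed_size, ff).eval()  # eval() disables dropout

x = torch.randn(2, 5, embed_size)
y = block(x)
print(y.shape)                 # torch.Size([2, 5, 8])
print(y.mean(-1).abs().max())  # close to 0: LayerNorm centres each position's vector
```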

Nx
- the encoder and decoder layers are each stacked N times (N = 6 in the paper)
- each encoder layer sends its output to the encoder layer above it
- the output of the final encoder layer is sent to every decoder layer, where it provides K and V
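For comparison, PyTorch ships built-in layers that implement this stacking. The sketch below uses `nn.TransformerEncoderLayer`/`nn.TransformerEncoder` and `nn.TransformerDecoderLayer`/`nn.TransformerDecoder`; this is PyTorch's implementation, not the notebook's, and its defaults (`dim_feedforward=2048`, `dropout=0.1`) differ from the notebook's `Transformer`. The inputs are random stand-ins for already-embedded sequences.

```python
import torch
import torch.nn as nn

embed_size, heads, num_layers = 256, 8, 6  # N = 6 as in the paper

# One layer definition, copied num_layers times; each layer feeds the one above it
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=embed_size, nhead=heads, batch_first=True),
    num_layers=num_layers,
)
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=embed_size, nhead=heads, batch_first=True),
    num_layers=num_layers,
)

src = torch.randn(2, 9, embed_size)  # embedded source, (N, src_len, E)
tgt = torch.randn(2, 7, embed_size)  # embedded shifted target, (N, trg_len, E)
# Boolean mask: True above the diagonal means "may not attend" (future positions)
causal = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)

memory = encoder(src)                        # output of the final encoder layer
out = decoder(tgt, memory, tgt_mask=causal)  # every decoder layer receives `memory`
print(len(encoder.layers), memory.shape, out.shape)
```

Note the opposite mask convention: here `True` blocks a position, whereas the notebook's masks use 0 to block.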


<img src="https://machinelearningmastery.com/wp-content/uploads/2022/03/dotproduct_1.png" width=400px/>

In [None]:
#@title Scaled Dot-Product Attention
mermaid('''
%%{init: {'theme': 'base', 'themeVariables': { 'primaryColor': '#f3f3f3', 'primaryBorderColor':'#000000'}}}%%

graph BT;

style Q fill:#ffffff,stroke:#ffffff,stroke-width:2px;
style K fill:#ffffff,stroke:#ffffff,stroke-width:2px;
style V fill:#ffffff,stroke:#ffffff,stroke-width:2px;



V --> MatMul2("MatMul")
K --> MatMul("MatMul")
Q --> MatMul("MatMul")


MatMul --> Scale
Scale --> Mask("Mask (opt.)")
Mask --> SoftMax("SoftMax")
SoftMax --> MatMul2


''', 0.75)
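The diagram above (MatMul, Scale, optional Mask, SoftMax, MatMul) can be sketched as a single function. This is a minimal single-head version, assuming the notebook's mask convention (0 = blocked); the function name is our own. It fills blocked scores with `-inf`; a row that is entirely masked would then produce NaN, which is one reason the notebook uses a large finite negative (`-1e20`) instead.

```python
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1)         # MatMul: (..., q_len, k_len)
    scores = scores / d_k ** 0.5             # Scale by sqrt(d_k)
    if mask is not None:                     # Mask (opt.): 0 = blocked position
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # SoftMax over the keys
    return weights @ v, weights              # MatMul with V

torch.manual_seed(0)
q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)
causal = torch.tril(torch.ones(4, 4))        # same convention as make_trg_mask below
out, weights = scaled_dot_product_attention(q, k, v, causal)
print(out.shape)      # torch.Size([1, 4, 8])
print(weights[0, 0])  # the first query can only see the first key: tensor([1., 0., 0., 0.])
```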


<img src="https://i.stack.imgur.com/1JdN6.png" width=400px />
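The `SelfAttention` class in the next cell splits the embedding into heads with `reshape` and computes every head's QKᵀ at once with `torch.einsum("nqhd,nkhd->nhqk", ...)`. The minimal sketch below (random input, projections omitted, sizes chosen arbitrarily) checks that this einsum equals an ordinary batched matrix multiply once the head axis is moved next to the batch axis.

```python
import torch

torch.manual_seed(0)
N, seq_len, embed_size, heads = 2, 5, 16, 4
head_dim = embed_size // heads

x = torch.randn(N, seq_len, embed_size)
# Split the embedding into `heads` pieces of size head_dim
q = x.reshape(N, seq_len, heads, head_dim)  # (N, q_len, heads, head_dim)
k = x.reshape(N, seq_len, heads, head_dim)  # (N, k_len, heads, head_dim)

# einsum form, as in SelfAttention.forward
energy_einsum = torch.einsum("nqhd,nkhd->nhqk", q, k)

# equivalent batched matmul: (N, heads, q_len, d) @ (N, heads, d, k_len)
energy_matmul = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)

print(energy_einsum.shape)                                      # torch.Size([2, 4, 5, 5])
print(torch.allclose(energy_einsum, energy_matmul, atol=1e-5))  # True
```

Each head only sees its own `head_dim` slice of the embedding, so the heads can attend to different things in parallel.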

In [None]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embed size needs to be divisible by heads"

        # Full-width projections as in the paper: one (embed_size x embed_size)
        # matrix per role is equivalent to separate per-head W_i^Q, W_i^K, W_i^V
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, queries, mask):
        N = queries.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Project first, then split the embedding into self.heads pieces
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(queries).reshape(N, query_len, self.heads, self.head_dim)

        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k) = sqrt(head_dim) as in the paper;
        # dim=3 normalizes across key_len
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # key_len == value_len: l
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum: (N, query_len, heads, head_dim), then flatten the last two dimensions
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # skip connection is: attention + query
        x = self.dropout(self.norm1(attention + query))
        feed_forward = self.feed_forward(x)
        out = self.dropout(self.norm2(feed_forward + x))

        return out

class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            TransformerBlock(embed_size, heads, dropout=dropout, forward_expansion=forward_expansion)
            for _ in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out

class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        num_layers=6,
        forward_expansions=4,
        heads=8,
        dropout=0,
        device='cpu',
        max_length=100
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansions,
            dropout,
            max_length
        )
        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansions,
            dropout,
            device,
            max_length
        )


        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # (N, 1, 1, src_len)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out

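device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Toy batch; 0 is the padding index
x = torch.tensor([[1,5,6,4,3,9,5,2,0], [1,8,7,3,4,5,6,7,2]]).to(device)
trg = torch.tensor([[1,7,4,3,5,9,2,0], [1,5,6,2,4,7,6,2]]).to(device)
src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
# Move the model's parameters to the same device as the inputs
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)

# The decoder input drops the last target token
out = model(x, trg[:, :-1])
print(out.shape)  # (N, trg_len - 1, trg_vocab_size)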


torch.Size([2, 7, 10])
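The masks built by `make_src_mask` and `make_trg_mask` keep singleton dimensions so that `masked_fill` can broadcast them against the energy tensor of shape (N, heads, query_len, key_len). The standalone sketch below repeats those two constructions on the same toy batch (using zero-filled stand-in energies, an assumption for illustration) to show which positions get blocked. Note that `make_trg_mask` only builds the causal part; it does not use `trg_pad_idx`.

```python
import torch

pad_idx = 0
src = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]])
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2], [1, 5, 6, 2, 4, 7, 6]])  # trg[:, :-1] from above
N, src_len = src.shape
trg_len = trg.shape[1]
heads = 8

# Same constructions as make_src_mask / make_trg_mask
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)                               # (N, 1, 1, src_len)
trg_mask = torch.tril(torch.ones(trg_len, trg_len)).expand(N, 1, trg_len, trg_len)  # (N, 1, trg_len, trg_len)

# Stand-in energies of shape (N, heads, query_len, key_len)
cross_energy = torch.zeros(N, heads, trg_len, src_len)  # decoder queries vs encoder keys
self_energy = torch.zeros(N, heads, trg_len, trg_len)   # decoder self-attention
cross_masked = cross_energy.masked_fill(src_mask == 0, float("-1e20"))
self_masked = self_energy.masked_fill(trg_mask == 0, float("-1e20"))

print(src_mask.shape, trg_mask.shape)        # torch.Size([2, 1, 1, 9]) torch.Size([2, 1, 7, 7])
print((cross_masked[0, 0, 0] < 0).tolist())  # only the padding column (last) is blocked
```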
