<a href="https://colab.research.google.com/github/SAIROHITHARETI/ATTENTION_IS_ALL_YOU_NEED/blob/main/Attention_is_all_you_need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math, random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# -------------------
# Hyperparameters
# -------------------
# Module-wide configuration read by the dataset, the model and the training loop.
DEVICE      = "cpu"      # device that the model and all tensors are moved to
SEQ_LEN     = 5          # length of input sequence
MAX_NUM     = 20         # numbers are in [1..MAX_NUM]
D_MODEL     = 64         # embedding / hidden width used throughout the model
N_HEAD      = 4          # attention heads per attention module (must divide D_MODEL)
D_FF        = 128        # hidden width of each feed-forward sublayer
N_LAYERS    = 2          # number of encoder layers and of decoder layers
BATCH_SIZE  = 64         # samples per DataLoader batch
TRAIN_STEPS = 1500       # total number of optimizer steps
LR          = 3e-4       # learning rate handed to Adam

BOS_ID      = 0          # only special token
VOCAB_SIZE  = MAX_NUM + 1  # 0 = BOS, 1..MAX_NUM = numbers


# -------------------
# Data: sort numbers
# -------------------
class SortDataset(Dataset):
    """Synthetic map-style dataset of random integer sequences and their sorted form.

    Samples are generated on the fly with the ``random`` module: the index only
    bounds iteration and does not select a particular sample, so the same index
    yields different data on each access.
    """

    def __init__(self, size=10000):
        """Create a dataset that reports ``size`` samples."""
        self.size = size

    def __len__(self):
        """Return the nominal number of samples."""
        return self.size

    def __getitem__(self, idx):
        """Generate one ``(src, tgt_in, tgt)`` triple of 1-D long tensors.

        Returns:
            src: ``SEQ_LEN`` random integers in ``[1, MAX_NUM]`` (encoder input).
            tgt_in: ``BOS_ID`` followed by all but the last sorted value
                (decoder input, shifted right by one).
            tgt: the sorted values (decoder target).

        Raises:
            IndexError: if ``idx`` lies outside ``[-size, size)``.
        """
        # Plain iteration (`for sample in ds`, `list(ds)`) falls back to calling
        # __getitem__ with 0, 1, 2, ... and only stops on IndexError; without
        # this check such a loop would never terminate.
        if not -self.size <= idx < self.size:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.size}"
            )

        nums = [random.randint(1, MAX_NUM) for _ in range(SEQ_LEN)]
        sorted_nums = sorted(nums)

        src = torch.tensor(nums, dtype=torch.long)          # encoder input
        tgt = torch.tensor(sorted_nums, dtype=torch.long)   # decoder target
        # Shift right: the decoder sees BOS first and never sees the last target.
        bos = torch.tensor([BOS_ID], dtype=torch.long)
        tgt_in = torch.cat([bos, tgt[:-1]])
        return src, tgt_in, tgt


# -------------------
# Positional encoding
# -------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    The table is precomputed once for ``max_len`` positions and stored as a
    non-trainable buffer ``pe`` of shape ``(1, max_len, d_model)``. Even feature
    columns hold sines and odd columns hold cosines.
    """

    def __init__(self, d_model, max_len=100):
        """Precompute the encoding table.

        Args:
            d_model: feature width of the embeddings; may be even or odd.
            max_len: largest sequence length ``forward`` will accept.
        """
        super().__init__()
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = pos * div  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so the
        # cosine half must be trimmed to fit; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Args:
            x: tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        # Fail loudly instead of relying on a broadcasting error (or, when
        # max_len == 1, silently reusing position 0 for every token).
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]


# -------------------
# Multi-Head Attention
# -------------------
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``n_head`` parallel heads."""

    def __init__(self, d_model, n_head):
        """Create the query/key/value/output projections.

        Args:
            d_model: width of the input and output features.
            n_head: number of heads; must be positive and divide ``d_model``.

        Raises:
            ValueError: if ``n_head`` is not positive or does not divide
                ``d_model``.
        """
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if n_head <= 0 or d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, attn_mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries of shape ``(B, T_q, d_model)``.
            k: keys of shape ``(B, T_k, d_model)``.
            v: values of shape ``(B, T_k, d_model)``.
            attn_mask: optional additive mask broadcastable to
                ``(B, n_head, T_q, T_k)``; ``-inf`` entries forbid attention.

        Returns:
            Tensor of shape ``(B, T_q, d_model)``.
        """
        B, T_q, _ = q.shape
        B, T_k, _ = k.shape

        # Project, then split the feature axis into heads: (B, h, T, d_head).
        q = self.W_q(q).view(B, T_q, self.n_head, self.d_head).transpose(1, 2)
        k = self.W_k(k).view(B, T_k, self.n_head, self.d_head).transpose(1, 2)
        v = self.W_v(v).view(B, T_k, self.n_head, self.d_head).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B, h, T_q, T_k)
        if attn_mask is not None:
            scores = scores + attn_mask   # mask has -inf on forbidden positions
        attn = scores.softmax(dim=-1)
        context = attn @ v                # (B, h, T_q, d_head)
        # Merge the heads back into a single feature axis before the output projection.
        context = context.transpose(1, 2).contiguous().view(B, T_q, self.d_model)
        return self.W_o(context)


# -------------------
# Encoder / Decoder
# -------------------
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then a feed-forward network.

    Each sublayer is wrapped in a residual connection followed by LayerNorm.
    """

    def __init__(self, d_model, n_head, d_ff):
        """Build the sublayers.

        Args:
            d_model: feature width of the inputs and outputs.
            n_head: number of attention heads.
            d_ff: hidden width of the feed-forward network.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head)

        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply unmasked self-attention and the feed-forward network to ``x``."""
        attended = self.self_attn(x, x, x, attn_mask=None)
        x = self.norm1(x + attended)
        transformed = self.ff(x)
        return self.norm2(x + transformed)


class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each of the three sublayers is wrapped in a residual connection followed
    by LayerNorm.
    """

    def __init__(self, d_model, n_head, d_ff):
        """Build the sublayers.

        Args:
            d_model: feature width of the inputs and outputs.
            n_head: number of attention heads in both attention modules.
            d_ff: hidden width of the feed-forward network.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head)
        self.cross_attn = MultiHeadAttention(d_model, n_head)

        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, tgt_mask):
        """Run one decoder block.

        Args:
            x: decoder-side activations.
            memory: encoder output used as keys and values for cross-attention.
            tgt_mask: additive mask applied to the decoder self-attention scores.
        """
        # Masked self-attention over the decoder input.
        self_ctx = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self_ctx)

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        cross_ctx = self.cross_attn(x, memory, memory, attn_mask=None)
        x = self.norm2(x + cross_ctx)

        return self.norm3(x + self.ff(x))


def subsequent_mask(size):
    """Build an additive causal mask of shape ``(1, 1, size, size)``.

    Entries strictly above the main diagonal are ``-inf`` and every other
    entry is ``0``, so adding the mask to attention scores blocks each
    position from attending to later ones.
    """
    blocked = torch.full((size, size), float("-inf"))
    # triu keeps the strictly-upper triangle and writes zeros everywhere else.
    causal = torch.triu(blocked, diagonal=1)
    return causal[None, None, :, :]  # (1,1,T,T)


# -------------------
# Full Transformer
# -------------------
class TransformerSorter(nn.Module):
    """Encoder-decoder Transformer that maps a number sequence to its sorted order.

    All sizes come from the module-level hyperparameters. The encoder and the
    decoder share one positional-encoding table but use separate embeddings.
    """

    def __init__(self):
        super().__init__()
        self.src_emb = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.tgt_emb = nn.Embedding(VOCAB_SIZE, D_MODEL)
        self.pos_enc = PositionalEncoding(D_MODEL, max_len=SEQ_LEN)
        self.enc_layers = nn.ModuleList(
            EncoderLayer(D_MODEL, N_HEAD, D_FF) for _ in range(N_LAYERS)
        )
        self.dec_layers = nn.ModuleList(
            DecoderLayer(D_MODEL, N_HEAD, D_FF) for _ in range(N_LAYERS)
        )
        # Projects decoder activations onto per-token vocabulary logits.
        self.out = nn.Linear(D_MODEL, VOCAB_SIZE)

    def encode(self, src):
        """Embed the source token ids and run them through every encoder layer."""
        hidden = self.src_emb(src)
        hidden = self.pos_enc(hidden)
        for block in self.enc_layers:
            hidden = block(hidden)
        return hidden

    def decode(self, tgt, memory):
        """Return vocabulary logits for each position of the decoder input.

        Args:
            tgt: decoder input token ids, sequence axis at dimension 1.
            memory: encoder output as returned by ``encode``.
        """
        # Causal mask sized to the current decoder input, on the same device.
        causal = subsequent_mask(tgt.size(1)).to(tgt.device)
        hidden = self.pos_enc(self.tgt_emb(tgt))
        for block in self.dec_layers:
            hidden = block(hidden, memory, tgt_mask=causal)
        return self.out(hidden)

    def forward(self, src, tgt_in):
        """Encode ``src`` and decode ``tgt_in`` against it, returning logits."""
        return self.decode(tgt_in, self.encode(src))


# -------------------
# Training + Demo
# -------------------
def train_model():
    """Train a fresh ``TransformerSorter`` on randomly generated examples.

    Performs exactly ``TRAIN_STEPS`` Adam updates, restarting the data loader
    whenever it is exhausted, and prints the mean loss of each completed
    200-step window.

    Returns:
        The trained model, still in training mode.
    """
    dataset = SortDataset()
    batches = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = TransformerSorter().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    model.train()
    step, running = 0, 0.0

    while step < TRAIN_STEPS:
        for batch in batches:
            src, tgt_in, tgt = (t.to(DEVICE) for t in batch)

            optimizer.zero_grad()
            logits = model(src, tgt_in)  # (B, T, V)
            # Flatten batch and time so every token is one classification example.
            flat_logits = logits.view(-1, VOCAB_SIZE)
            flat_targets = tgt.view(-1)
            loss = criterion(flat_logits, flat_targets)
            loss.backward()
            optimizer.step()

            step += 1
            running += loss.item()
            if step % 200 == 0:
                print(f"step {step}/{TRAIN_STEPS}, loss={running/200:.4f}")
                running = 0.0
            # Stop mid-epoch once the step budget is spent.
            if step >= TRAIN_STEPS:
                break

    return model


@torch.no_grad()
def greedy_sort(model, nums):
    """Decode a sorted sequence for ``nums`` by greedy argmax decoding.

    The model is switched to eval mode (and left there). Decoding starts from
    BOS and appends the most likely next token ``SEQ_LEN`` times.

    Args:
        model: object exposing ``eval``, ``encode`` and ``decode``.
        nums: list of integer token ids to sort.

    Returns:
        List of ``SEQ_LEN`` predicted token ids.
    """
    model.eval()
    source = torch.tensor(nums, dtype=torch.long, device=DEVICE)[None, :]  # (1, T)
    memory = model.encode(source)

    # The running decoder input begins as a single BOS token.
    generated = torch.full((1, 1), BOS_ID, dtype=torch.long, device=DEVICE)
    for _ in range(SEQ_LEN):
        step_logits = model.decode(generated, memory)                 # (1, t, V)
        best = step_logits[:, -1, :].argmax(dim=-1, keepdim=True)     # (1, 1)
        generated = torch.cat((generated, best), dim=1)

    # Skip the leading BOS and keep the SEQ_LEN generated tokens.
    return generated[0, 1:SEQ_LEN + 1].tolist()


In [None]:
if __name__ == "__main__":
    # Train once, then compare the model's output with Python's sorted()
    # on five freshly drawn random inputs.
    model = train_model()
    for _ in range(5):
        nums = [random.randint(1, MAX_NUM) for _ in range(SEQ_LEN)]
        true_sorted = sorted(nums)
        pred_sorted = greedy_sort(model, nums)
        report = (
            ("\nInput:       ", nums),
            ("True sorted: ", true_sorted),
            ("Model sorted:", pred_sorted),
        )
        for label, values in report:
            print(label, values)

step 200/1500, loss=1.1932
step 400/1500, loss=0.0749
step 600/1500, loss=0.0244
step 800/1500, loss=0.0153
step 1000/1500, loss=0.0100
step 1200/1500, loss=0.0145
step 1400/1500, loss=0.0037

Input:        [14, 9, 10, 9, 4]
True sorted:  [4, 9, 9, 10, 14]
Model sorted: [4, 9, 9, 10, 14]

Input:        [11, 3, 11, 6, 2]
True sorted:  [2, 3, 6, 11, 11]
Model sorted: [2, 3, 6, 11, 11]

Input:        [11, 11, 4, 15, 6]
True sorted:  [4, 6, 11, 11, 15]
Model sorted: [4, 6, 11, 11, 15]

Input:        [7, 20, 2, 13, 4]
True sorted:  [2, 4, 7, 13, 20]
Model sorted: [2, 4, 7, 13, 20]

Input:        [9, 14, 17, 7, 13]
True sorted:  [7, 9, 13, 14, 17]
Model sorted: [7, 9, 13, 14, 17]
