<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://chatgpt.com/c/67ecc4e3-2488-800e-af5d-1738e22c9257

# https://gemini.google.com/app/d25e17f514b6a9d4
# Reference: how nn.MultiheadAttention is implemented

# my practice/experiments
# https://colab.research.google.com/drive/1AmG5aDiyk0OrxnVcXey6Nd9ZxQzJuTBw#scrollTo=f_yJLgxxTXf5

In [None]:
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

class TransformerDecoderConfig(PretrainedConfig):
    """Hyperparameter container for the decoder-only Transformer in this file.

    Every named keyword is stored on the instance under the same name; any
    additional keyword arguments are forwarded untouched to
    ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        vocab_size=50257,
        max_position_embeddings=1024,
        d_model=512,
        num_heads=8,
        num_layers=12,
        d_ff=2048,
        dropout=0.1,
        **kwargs,
    ):
        # Let the base class consume its own options before the
        # model-specific fields are attached.
        super().__init__(**kwargs)

        # Sizes of the token-embedding table and the positional table.
        self.vocab_size, self.max_position_embeddings = (
            vocab_size,
            max_position_embeddings,
        )
        # Width of the residual stream and number of attention heads.
        self.d_model, self.num_heads = d_model, num_heads
        # Depth of the stack and hidden width of each feed-forward sub-layer.
        self.num_layers, self.d_ff = num_layers, d_ff
        # Dropout probability shared by attention and the residual branches.
        self.dropout = dropout

class TransformerDecoderBlock(nn.Module):
    """One post-LayerNorm decoder block: self-attention, then a feed-forward net.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``, i.e. the
    LayerNorm is applied *after* the residual addition.

    Args:
        config: Object exposing ``d_model``, ``num_heads``, ``d_ff`` and
            ``dropout`` attributes.
    """

    def __init__(self, config):
        super().__init__()
        # ``batch_first`` is left at its default (False), so this block works on
        # (seq_len, batch, d_model) tensors. Passing ``batch_first=True`` here
        # would let callers skip the transpose.
        self.self_attn = nn.MultiheadAttention(config.d_model, config.num_heads, dropout=config.dropout)
        self.norm1 = nn.LayerNorm(config.d_model)
        # Position-wise feed-forward network: d_model -> d_ff -> d_model.
        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.ReLU(),
            nn.Linear(config.d_ff, config.d_model)
        )
        self.norm2 = nn.LayerNorm(config.d_model)
        # Shared dropout applied to both residual branches.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, attn_mask):
        """Run the block on ``x``.

        Args:
            x: Input tensor laid out as (seq_len, batch, d_model).
            attn_mask: Attention mask forwarded to ``nn.MultiheadAttention``
                (e.g. an additive float mask with ``-inf`` above the diagonal
                for causal attention).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention weights are never used, so do not ask for them:
        # ``need_weights=False`` skips computing the head-averaged weight
        # matrix and lets PyTorch use its fused attention implementation.
        attn_out, _ = self.self_attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = self.norm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

class TransformerDecoderModel(PreTrainedModel):
    """Decoder-only Transformer language model built from ``TransformerDecoderBlock``s.

    Token embeddings are summed with a learned positional table, passed through
    ``config.num_layers`` causally-masked decoder blocks, normalized once more
    and projected to vocabulary logits.
    """

    config_class = TransformerDecoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Learned positional table, initialized to zeros; one row per position.
        self.positional_encoding = nn.Parameter(torch.zeros(1, config.max_position_embeddings, config.d_model))
        self.layers = nn.ModuleList([TransformerDecoderBlock(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        # Output projection to vocabulary logits (no bias term).
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the number of rows in the
                positional table.
        """
        seq_len = input_ids.size(1)
        max_len = self.positional_encoding.size(1)
        if seq_len > max_len:
            # Without this check the slice below silently truncates and the
            # addition fails with an opaque broadcasting error.
            raise ValueError(
                f"Input sequence length ({seq_len}) exceeds "
                f"max_position_embeddings ({max_len})."
            )

        x = self.embedding(input_ids) + self.positional_encoding[:, :seq_len, :]
        x = x.transpose(0, 1)  # (seq_len, batch, dim)

        # Additive causal mask: 0 on/below the diagonal, -inf strictly above it,
        # so position i can only attend to positions <= i. It is cast to the
        # activation dtype so the mask matches x if the model runs in reduced precision.
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=x.device),
            diagonal=1,
        ).to(x.dtype)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.ln_f(x).transpose(0, 1)  # (batch, seq_len, dim)
        return self.head(x)

# Smoke test: build the model with default hyperparameters and push one
# random batch through it.
config = TransformerDecoderConfig()
model = TransformerDecoderModel(config)
# One sequence of 50 token ids drawn uniformly from [0, vocab_size).
input_ids = torch.randint(low=0, high=config.vocab_size, size=(1, 50))
output = model(input_ids)
# The logits should come back as (1, 50, vocab_size).
print(output.shape)

torch.Size([1, 50, 50257])
