In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 1. Synthetic input: random token ids drawn uniformly from [0, 10000),
#    laid out as (batch_size=2, seq_len=8).
input_tensor = torch.randint(0, 10000, size=(2, 8))

# 2. Model 
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The module applies whatever boolean ``mask`` the caller supplies; causal
    attention is obtained by passing an upper-triangular mask (True above the
    diagonal).

    Args:
        d_model: model width; must be a positive multiple of ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``d_model`` cannot be split evenly across ``n_heads``.
    """

    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        # Fail at construction instead of deep inside ``view`` in forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_heads ({n_heads})"
            )
        self.d_head = d_model // n_heads
        self.n_heads = n_heads
        # One fused projection whose output is split into Q, K and V.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape (B, T, C), where C is the ``d_model`` this module
                was built with.
            mask: optional boolean tensor broadcastable to (B, n_heads, T, T).
                ``True`` marks score positions that must not be attended to.
                ``None`` disables masking.

        Returns:
            Tensor of shape (B, T, C).

        Note:
            A query row whose keys are all masked becomes NaN after softmax,
            because every score in that row is -inf.
        """
        B, T, C = x.size()
        q, k, v = self.qkv_proj(x).split(C, dim=2)
        # (B, T, C) -> (B, n_heads, T, d_head) so heads attend independently.
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * (1.0 / (self.d_head ** 0.5))
        if mask is not None:
            attn = attn.masked_fill(mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        # Merge heads back: (B, n_heads, T, d_head) -> (B, T, C).
        x = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(x)

class FeedForward(nn.Module):
    """Position-wise feed-forward network applied independently at each position.

    Computes Dropout(Linear2(Dropout(GELU(Linear1(x))))), expanding the last
    dimension from ``d_model`` to ``d_ff`` and projecting it back.

    Args:
        d_model: input/output feature size.
        d_ff: hidden expansion size.
        dropout: probability used by both Dropout layers.
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        # The order below fixes the Sequential indices (weights live at net.0
        # and net.3), which is what state_dict keys depend on.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        project = nn.Linear(d_ff, d_model)
        output_drop = nn.Dropout(dropout)
        self.net = nn.Sequential(expand, activation, hidden_drop, project, output_drop)

    def forward(self, x):
        """Run the MLP over the last dimension of ``x`` (size ``d_model``)."""
        out = self.net(x)
        return out

class TransformerBlock(nn.Module):
    """Single pre-norm decoder layer.

    Computes ``x + Dropout(Attn(LN(x), mask))`` followed by
    ``x + FFN(LN(x))``.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        d_ff: hidden size of the feed-forward sub-layer.
        dropout: dropout probability for the residual branches.
    """

    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        # Residual dropout for the attention branch (the attention module
        # itself applies no dropout).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply attention then feed-forward, each with a pre-norm residual.

        Args:
            x: hidden states; assumed (B, T, d_model) — TODO confirm at callers.
            mask: boolean attention mask forwarded unchanged to ``self.attn``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.dropout(self.attn(self.attn_norm(x), mask))
        # FeedForward already ends in its own Dropout; wrapping it in
        # self.dropout again would drop with probability 1 - (1 - p)**2.
        x = x + self.ffn(self.ffn_norm(x))
        return x

class Transformer(nn.Module):
    """Decoder-only Transformer language model.

    Pipeline: token embedding + learned positional embedding, then
    ``num_layers`` ``TransformerBlock``s under a causal mask, then a final
    LayerNorm, then a linear projection to vocabulary logits.

    Args:
        vocab_size: number of token ids (embedding rows / output logits).
        d_model: model width.
        n_heads: attention heads per block.
        d_ff: feed-forward hidden size per block.
        num_layers: number of stacked blocks.
        max_seq_len: size of the positional-embedding table, i.e. the longest
            sequence ``forward`` accepts.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, vocab_size=10000, d_model=512, n_heads=8,
                 d_ff=2048, num_layers=6, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout)
                                   for _ in range(num_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-normal weights and zero biases for every ``nn.Linear``.

        Other modules (embeddings, LayerNorm) keep PyTorch's default init.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            x: integer token ids of shape (B, T) with ``T <= max_seq_len``.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if ``T`` exceeds the positional-embedding table size.
        """
        _, T = x.size()
        max_len = self.pos_emb.num_embeddings
        # Without this check the embedding lookup fails with an opaque
        # IndexError (or a device-side assert on CUDA).
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds max_seq_len {max_len}")
        pos = torch.arange(T, device=x.device).unsqueeze(0)  # (1, T), broadcast over batch
        h = self.tok_emb(x) + self.pos_emb(pos)
        # True above the diagonal: position t may not attend to positions > t.
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        mask = mask.view(1, 1, T, T)  # broadcasts over (batch, heads)
        for layer in self.layers:
            h = layer(h, mask)
        return self.out(self.norm(h))

# 3. Run: build the default-sized model, push the synthetic batch through it
#    and report the logits shape (batch, seq_len, vocab_size).
model = Transformer(vocab_size=10000)
output = model(input_tensor)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([2, 8, 10000])
