In [None]:
# tiny_transformer.py -- minimal Transformer encoder (1 layer)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyPositionalEmbedding(nn.Module):
    """Learned (trainable) absolute positional embeddings added to the input.

    Args:
        max_len: number of positions with their own embedding row; the longest
            sequence ``forward`` accepts.
        d_model: embedding width; must match the last dimension of the input.
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        # trainable positional embeddings, initialised as small Gaussian noise (scaled by 0.02)
        self.pos = nn.Parameter(torch.randn(max_len, d_model) * 0.02)

    def forward(self, x):
        """Add the embeddings for positions ``0 .. seq_len-1`` to ``x``.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of learned positions.
                Without this check, slicing ``self.pos`` would silently return
                fewer rows; that either fails later with an opaque broadcast error
                or, when ``max_len == 1``, broadcasts one embedding onto every position.
        """
        # 3-way unpack also rejects inputs that are not rank 3
        _, seq_len, _ = x.shape
        max_len = self.pos.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} of the positional embedding"
            )
        return x + self.pos[:seq_len].unsqueeze(0)            # (1, seq_len, d_model) broadcast over batch

In [None]:
class TinyTransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer on batch-first tensors.

    Computes ``h = LayerNorm(x + Dropout(SelfAttn(x)))`` followed by
    ``out = LayerNorm(h + Dropout(Linear2(Dropout(ReLU(Linear1(h))))))``.

    Args:
        d_model: input/output feature width (``embed_dim`` of the attention).
        nhead: passed to ``nn.MultiheadAttention`` as ``num_heads``.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: probability used inside the attention and by the shared
            ``nn.Dropout`` applied on both residual branches and the FF hidden layer.
    """

    def __init__(self, d_model=64, nhead=4, dim_feedforward=128, dropout=0.1):
        super().__init__()
        # Submodules are created in this exact order so that parameter
        # initialisation consumes the RNG identically and state_dict keys match.
        # The attention keeps its default (seq_len, batch, d_model) layout;
        # forward() transposes around the call.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
        )
        self.linear1, self.linear2 = (
            nn.Linear(d_model, dim_feedforward),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu

    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        """Run self-attention and the feed-forward sub-layer on ``x``.

        Args:
            x: tensor of shape (batch, seq_len, d_model).
            src_mask: optional mask forwarded as ``attn_mask`` to the attention.
            src_key_padding_mask: optional mask forwarded as ``key_padding_mask``.

        Returns:
            Tuple ``(out, attn_weights)``: ``out`` has the same shape as ``x``;
            ``attn_weights`` is the attention module's second output.
        """
        seq_first = x.transpose(0, 1)                        # (seq_len, batch, d_model)
        context, weights = self.self_attn(
            query=seq_first,
            key=seq_first,
            value=seq_first,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        # residual + norm around attention (back to batch-first first)
        hidden = self.norm1(x + self.dropout(context.transpose(0, 1)))

        # residual + norm around the feed-forward network
        expanded = self.dropout(self.activation(self.linear1(hidden)))
        out = self.norm2(hidden + self.dropout(self.linear2(expanded)))
        return out, weights

In [None]:
class TinyTransformer(nn.Module):
    """Token embedding + learned positions + one encoder layer + final LayerNorm.

    Args:
        vocab_size: number of rows in the token embedding table.
        d_model: model width.
        nhead: number of attention heads in the encoder layer.
        max_len: number of learned positional embeddings (longest supported sequence).
        dim_feedforward: hidden width of the encoder layer's feed-forward network.
            Defaults to 128, the value the layer previously used implicitly.
        dropout: dropout probability forwarded to the encoder layer.
            Defaults to 0.1, the value the layer previously used implicitly.
    """

    def __init__(self, vocab_size, d_model=64, nhead=4, max_len=128,
                 dim_feedforward=128, dropout=0.1):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)          # token embeddings
        self.pos_embed = TinyPositionalEmbedding(max_len, d_model)  # trainable pos embeddings
        self.encoder_layer = TinyTransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        # NOTE(review): the encoder layer already ends with its own LayerNorm (norm2),
        # so this is a second, back-to-back LayerNorm. Kept for compatibility with
        # existing weights/behaviour.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, token_ids, src_key_padding_mask=None, src_mask=None):
        """Encode a batch of token ids.

        Args:
            token_ids: integer tensor of shape (batch, seq_len).
            src_key_padding_mask: optional key padding mask forwarded to the encoder layer.
            src_mask: optional attention mask (e.g. a causal mask) forwarded to the
                encoder layer. Added as a trailing keyword so existing positional
                calls are unaffected.

        Returns:
            Tuple ``(x, attn_weights)``: ``x`` is (batch, seq_len, d_model);
            ``attn_weights`` is whatever the encoder layer returns.
        """
        x = self.tok_embed(token_ids)                       # (batch, seq_len, d_model)
        x = self.pos_embed(x)                               # (batch, seq_len, d_model)
        x, attn_weights = self.encoder_layer(
            x,
            src_mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        x = self.norm(x)
        return x, attn_weights