In [None]:
import torch
import torch.nn as nn
import math

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings (Vaswani et al., 2017) to embeddings.

    The table is precomputed once for ``max_len`` positions and registered as a
    buffer of shape (1, max_len, d_model). It therefore follows the module
    across devices and is saved in the state dict, but it is not trained.

    Args:
        d_model: embedding width. Both even and odd widths are supported.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as (batch, seq, d_model). Sequences longer than
        ``max_len`` are not supported: the sliced table would be too short
        to broadcast.
        """
        return x + self.pe[:, :x.size(1)]

# Self-Attention
class SelfAttention(nn.Module):
    """Unmasked multi-head self-attention over a (batch, seq, d_model) input.

    Wraps ``nn.MultiheadAttention`` (``batch_first=True``), using ``x`` as
    query, key and value. Only the attended output is returned; the
    attention weights are discarded.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, x):
        attended, _weights = self.attn(x, x, x)
        return attended

# Masked Self-Attention (for Decoder)
class MaskedSelfAttention(nn.Module):
    """Causal multi-head self-attention for the decoder.

    Query position ``i`` may attend only to positions ``j <= i``, so a token's
    output never depends on later tokens. Input and output are
    (batch, seq, d_model) because the wrapped attention uses
    ``batch_first=True``.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, x):
        seq_len = x.size(1)
        # Boolean causal mask allocated directly on x's device. True marks a
        # disallowed (future) position, i.e. j > i. This is equivalent to the
        # additive -inf mask, but it avoids a per-call CPU allocation plus
        # device copy and does not depend on x's floating dtype.
        causal_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=x.device
        ).triu(diagonal=1)
        return self.attn(x, x, x, attn_mask=causal_mask)[0]

# Cross-Attention
class CrossAttention(nn.Module):
    """Encoder-decoder (cross) attention.

    Queries come from the decoder stream ``x``; keys and values both come
    from the encoder output ``memory``. Both are (batch, seq, d_model)
    because the wrapped attention uses ``batch_first=True``. Returns only
    the attended output.
    """

    def __init__(self, d_model, nhead):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, x, memory):
        attended, _ = self.attn(query=x, key=memory, value=memory)
        return attended

# Encoder Block
class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer: self-attention, then a position-wise FFN.

    Each sub-layer's output passes through dropout, is added back to the
    sub-layer input (residual connection) and is then layer-normalised, as in
    "Attention Is All You Need" (Vaswani et al., 2017, Sec. 5.4).

    Args:
        d_model: model (embedding) width.
        nhead: number of attention heads; nn.MultiheadAttention requires it
            to divide ``d_model``.
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = SelfAttention(d_model, nhead)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # nn.Dropout holds no parameters, so one module can serve both
        # sub-layers without changing the state_dict layout.
        self.dropout = nn.Dropout(dropout)

    def _feed_forward(self, x):
        """Position-wise FFN: Linear -> ReLU -> Linear."""
        return self.linear2(torch.relu(self.linear1(x)))

    def forward(self, src):
        src = self.norm1(src + self.dropout(self.self_attn(src)))
        src = self.norm2(src + self.dropout(self._feed_forward(src)))
        return src

# Decoder Block
class TransformerDecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Sub-layers, in order: causal self-attention, cross-attention over the
    encoder ``memory``, and a position-wise FFN. Each sub-layer's output
    passes through dropout, is added to the sub-layer input (residual) and is
    then layer-normalised, as in Vaswani et al. (2017), Sec. 5.4.

    Args:
        d_model: model (embedding) width.
        nhead: number of attention heads; nn.MultiheadAttention requires it
            to divide ``d_model``.
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.masked_attn = MaskedSelfAttention(d_model, nhead)
        self.cross_attn = CrossAttention(d_model, nhead)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # nn.Dropout holds no parameters, so one module can serve all three
        # sub-layers without changing the state_dict layout.
        self.dropout = nn.Dropout(dropout)

    def _feed_forward(self, x):
        """Position-wise FFN: Linear -> ReLU -> Linear."""
        return self.linear2(torch.relu(self.linear1(x)))

    def forward(self, tgt, memory):
        tgt = self.norm1(tgt + self.dropout(self.masked_attn(tgt)))
        tgt = self.norm2(tgt + self.dropout(self.cross_attn(tgt, memory)))
        tgt = self.norm3(tgt + self.dropout(self._feed_forward(tgt)))
        return tgt

# Full Transformer Model
class Transformer(nn.Module):
    """Encoder-decoder Transformer with one token embedding shared by source and target.

    Args:
        vocab_size: number of token ids; also the width of the output logits.
        d_model: model (embedding) width.
        nhead: number of attention heads per layer.
        num_layers: number of encoder layers and, separately, of decoder layers.
        dim_feedforward: hidden width of every layer's feed-forward sub-layer.
        max_len: longest sequence the positional-encoding table covers.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward=2048, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList(
            [TransformerEncoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, nhead, dim_feedforward) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def _embed(self, tokens):
        """Embed token ids, scale by sqrt(d_model) and add positional encodings."""
        # Vaswani et al. multiply embeddings by sqrt(d_model). generate_text()
        # applies the same scaling at inference time, so the training path
        # here must do the same or the decoder sees differently scaled inputs.
        scale = math.sqrt(self.embedding.embedding_dim)
        return self.positional_encoding(self.embedding(tokens) * scale)

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, vocab_size).

        ``src`` and ``tgt`` are token-id tensors of shape (batch, seq); the
        layers use batch_first attention.
        """
        memory = self._embed(src)
        for layer in self.encoder_layers:
            memory = layer(memory)

        out = self._embed(tgt)
        for layer in self.decoder_layers:
            out = layer(out, memory)

        return self.fc_out(out)

# Model parameters
vocab_size = 30000
d_model = 512
nhead = 8
num_layers = 6

# Instantiate model
model = Transformer(vocab_size, d_model, nhead, num_layers)

# Dummy input (batch_size=32, seq_len=10)
src = torch.randint(0, vocab_size, (32, 10))  # Encoder input
tgt = torch.randint(0, vocab_size, (32, 10))  # Decoder input

# Shape smoke test only. no_grad stops autograd from keeping every layer's
# activations alive for a backward pass that never runs.
with torch.no_grad():
    output = model(src, tgt)
print("Output shape:", output.shape)  # Expected: (32, 10, 30000)
assert output.shape == (32, 10, vocab_size), output.shape


### Inference: greedy autoregressive decoding

In [None]:
import torch

def generate_text(model, src, max_len=50, start_token=2, end_token=3):
    """
    Greedy autoregressive decoding with the encoder-decoder Transformer.

    The source is encoded once. The decoder is then re-run on the growing
    target prefix, and at each step the highest-scoring next token is
    appended, until ``end_token`` is produced or ``max_len`` tokens have been
    generated.

    The model runs in eval mode with gradients disabled. Its previous
    train/eval mode is restored afterwards, so calling this during training
    does not silently switch off dropout.

    Args:
        model: The trained Transformer model. Must expose ``embedding``,
            ``positional_encoding``, ``encoder_layers``, ``decoder_layers``
            and ``fc_out``.
        src: Encoder input token ids of shape (1, seq_len) or (seq_len,).
        max_len: Maximum number of tokens to generate (not counting
            ``start_token``).
        start_token: Token id the decoder is primed with.
        end_token: Token id that stops decoding; it is kept in the output.

    Returns:
        1-D LongTensor of token ids on the model's device, starting with
        ``start_token``.

    Raises:
        ValueError: if ``src`` contains more than one sequence.
    """
    device = next(model.parameters()).device
    src = src.to(device)
    if src.dim() == 1:
        src = src.unsqueeze(0)  # accept an unbatched sequence
    if src.size(0) != 1:
        raise ValueError(
            f"generate_text decodes one sequence at a time; got a batch of {src.size(0)}"
        )

    was_training = model.training
    model.eval()  # disable dropout for deterministic decoding
    try:
        with torch.no_grad():
            # Embeddings are scaled by sqrt(d_model) (Vaswani et al.); this
            # must match how the model embedded tokens during training.
            scale = math.sqrt(model.embedding.embedding_dim)

            # Encode the source once; it is reused at every decoding step.
            memory = model.positional_encoding(model.embedding(src) * scale)
            for layer in model.encoder_layers:
                memory = layer(memory)

            tgt = torch.tensor([[start_token]], dtype=torch.long, device=device)
            for _ in range(max_len):
                decoded = model.positional_encoding(model.embedding(tgt) * scale)
                for layer in model.decoder_layers:
                    decoded = layer(decoded, memory)

                # Only the last position predicts the next token.
                next_token_logits = model.fc_out(decoded[:, -1])  # (1, vocab_size)
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # (1, 1)
                tgt = torch.cat([tgt, next_token], dim=1)

                if next_token.item() == end_token:
                    break
    finally:
        model.train(was_training)

    return tgt.squeeze(0)  # drop the batch dimension