# T5 (Text-to-Text Transfer Transformer) from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/t5.ipynb)

This notebook implements **T5 from scratch**.

Key Innovations:
1. **Encoder-Decoder Architecture**: Unlike BERT (encoder-only) or GPT (decoder-only), T5 uses the full encoder-decoder Transformer, like the original 2017 paper ("Attention Is All You Need").
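2. **Relative Positional Embeddings**: Instead of absolute position embeddings (fixed sinusoidal or learned), T5 learns a scalar bias per attention head for each bucketed relative distance between query and key, and adds it to the attention logits. The bias table is shared across the layers of a stack.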
3. **Unified Framework**: Every NLP task is cast as text-to-text (e.g., "translate English to German: ..." → "...").

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f'Using device: {device}')

## 1. Relative Position Bias

T5 adds a learned scalar bias (one per attention head) to the attention logits. The bias depends on the bucketed relative distance between the query and key tokens:
`score = (Q @ K^T) + bias`

In [None]:
class RelativePositionBias(nn.Module):
    """Learned T5-style relative position bias for attention logits.

    Relative offsets (key position minus query position) are mapped to a
    fixed set of buckets. Offsets smaller than ``max_exact`` get one bucket
    each. ``max_exact`` is a quarter of ``num_buckets`` when bidirectional and
    half of it when causal. Larger offsets share logarithmically spaced
    buckets that saturate at ``max_dist``. Each bucket owns one learned scalar
    per attention head.

    Args:
        num_buckets: Total number of buckets. When bidirectional they are
            split in half by the sign of the offset.
        max_dist: Distance at which the logarithmic buckets saturate.
        n_heads: Number of attention heads (one bias value per head).
    """

    def __init__(self, num_buckets=32, max_dist=128, n_heads=8):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_dist = max_dist
        self.n_heads = n_heads
        self.relative_attention_bias = nn.Embedding(num_buckets, n_heads)

    def _relative_position_bucket(self, relative_position, bidirectional=True):
        """Map integer relative offsets to bucket indices in ``[0, num_buckets)``.

        Args:
            relative_position: Integer tensor of (key - query) offsets.
            bidirectional: If True, negative offsets (key before query) use
                the upper half of the buckets and positive offsets the lower
                half. If False (causal), every non-negative offset maps to
                bucket 0 and only the distance into the past is bucketed.

        Returns:
            A long tensor of the same shape holding bucket indices.
        """
        ret = 0
        if bidirectional:
            num_buckets = self.num_buckets // 2
            ret += (relative_position < 0).long() * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            num_buckets = self.num_buckets
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # Logarithmic buckets for larger distances. Clamp first so the "small"
        # entries (discarded by torch.where below anyway) never hit log(0) = -inf,
        # whose cast to an integer is undefined.
        safe_pos = relative_position.float().clamp(min=max_exact)
        val_if_large = max_exact + (
            torch.log(safe_pos / max_exact)
            / math.log(self.max_dist / max_exact)
            * (num_buckets - max_exact)
        ).long()
        # Everything at or beyond max_dist shares the last bucket.
        val_if_large = torch.clamp(val_if_large, max=num_buckets - 1)
        ret += torch.where(is_small, relative_position, val_if_large)
        return ret

    def forward(self, seq_len_q, seq_len_k, bidirectional=True):
        """Return the additive attention bias for the given sequence lengths.

        Args:
            seq_len_q: Number of query positions.
            seq_len_k: Number of key positions.
            bidirectional: Bucketing scheme. Pass False for causal (decoder)
                self-attention.

        Returns:
            Tensor of shape ``(1, n_heads, seq_len_q, seq_len_k)``, which
            broadcasts over the batch dimension.
        """
        # Create the position indices on the same device as the learned table,
        # so the module works wherever it was moved with .to(...), independent
        # of any global device setting.
        table_device = self.relative_attention_bias.weight.device
        q_pos = torch.arange(seq_len_q, dtype=torch.long, device=table_device)[:, None]
        k_pos = torch.arange(seq_len_k, dtype=torch.long, device=table_device)[None, :]
        rel_pos = k_pos - q_pos  # (seq_len_q, seq_len_k)

        buckets = self._relative_position_bucket(rel_pos, bidirectional=bidirectional)
        bias = self.relative_attention_bias(buckets)  # (q, k, n_heads)
        return bias.permute(2, 0, 1).unsqueeze(0)     # (1, n_heads, q, k)

## 2. T5 Layer Norm (RMSNorm basically)

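T5 uses a simplified LayerNorm: it skips mean subtraction and has no bias term. It only rescales each vector by its root mean square (RMS) and multiplies by a learned per-feature weight.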

In [None]:
class T5LayerNorm(nn.Module):
    """Root-mean-square layer norm as used by T5 (no mean subtraction, no bias).

    Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``.

    Args:
        d_model: Size of the last (feature) dimension to normalize over.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        # Accumulate the mean square in at least float32. In fp16, x ** 2
        # overflows to inf once |x| exceeds ~256, and rsqrt(inf) = 0 would
        # collapse the whole row. Float32/float64 inputs are left unchanged.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_acc = x.to(compute_dtype)
        variance = x_acc.pow(2).mean(-1, keepdim=True)
        x_normed = (x_acc * torch.rsqrt(variance + self.eps)).to(x.dtype)
        return self.weight * x_normed

## 3. T5 Block (Encoder & Decoder variant)

T5 Block structure:
- Self-Attention
- (If Decoder) Cross-Attention
- Feed-Forward network: a two-layer ReLU MLP without biases, as in the original T5 (T5 v1.1 later switched to a gated-GELU variant)

In [None]:
class T5Block(nn.Module):
    """One pre-norm T5 layer: self-attention, optional cross-attention, then an MLP.

    Each sub-layer is applied as ``x = x + sublayer(norm(x))``. Decoder blocks
    (``is_decoder=True``) use causal self-attention and add a cross-attention
    sub-layer over the encoder output.

    Args:
        d_model: Hidden size.
        n_heads: Number of attention heads.
        d_ff: Inner size of the feed-forward network.
        is_decoder: Build a decoder block (causal self-attention + cross-attention).
    """

    def __init__(self, d_model, n_heads, d_ff, is_decoder=False):
        super().__init__()
        self.is_decoder = is_decoder
        self.ln1 = T5LayerNorm(d_model)
        self.sa = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        if is_decoder:
            self.ln2 = T5LayerNorm(d_model)
            self.ca = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        self.ln3 = T5LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )

    def _self_attn_mask(self, x, self_attn_bias):
        """Build the additive float mask for self-attention, or None if none is needed.

        Combines the optional position bias with a causal mask (decoder only).
        A bias is reshaped to ``(batch * n_heads, L, L)``, the per-head layout
        ``nn.MultiheadAttention`` accepts for ``attn_mask``.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        n_heads = self.sa.num_heads
        mask = None
        if self_attn_bias is not None:
            # (1 or batch, n_heads, L, L) -> (batch * n_heads, L, L). MHA flattens
            # batch and heads batch-major, which matches this reshape.
            mask = self_attn_bias.to(dtype=x.dtype).expand(batch, n_heads, seq_len, seq_len)
            mask = mask.reshape(batch * n_heads, seq_len, seq_len)
        if self.is_decoder:
            # -inf strictly above the diagonal: position i may only attend to j <= i.
            causal = torch.full(
                (seq_len, seq_len), float("-inf"), device=x.device, dtype=x.dtype
            ).triu(1)
            mask = causal if mask is None else mask + causal
        return mask

    def forward(self, x, memory=None, self_attn_bias=None):
        """Apply the block.

        Args:
            x: Input hidden states, (batch, seq_len, d_model); the attention
                modules are built with ``batch_first=True``.
            memory: Encoder output used by cross-attention. Only decoder
                blocks use it; it is ignored otherwise.
            self_attn_bias: Optional additive bias for the self-attention
                logits, shaped (1 or batch, n_heads, seq_len, seq_len), e.g. the
                output of ``RelativePositionBias``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention. PyTorch's MHA adds a float attn_mask to the logits,
        # so the position bias and the causal mask are injected that way.
        # NOTE: MHA still scales logits by 1/sqrt(head_dim), unlike original T5.
        norm_x = self.ln1(x)
        attn_mask = self._self_attn_mask(norm_x, self_attn_bias)
        attn_out, _ = self.sa(norm_x, norm_x, norm_x, attn_mask=attn_mask, need_weights=False)
        x = x + attn_out

        if self.is_decoder and memory is not None:
            norm_x = self.ln2(x)
            attn_out, _ = self.ca(norm_x, memory, memory, need_weights=False)
            x = x + attn_out

        norm_x = self.ln3(x)
        x = x + self.mlp(norm_x)
        return x

## 4. Full T5 Model

In [None]:
class T5(nn.Module):
    """Minimal T5-style encoder-decoder language model.

    Args:
        vocab_size: Vocabulary size, used by the shared embedding and the output head.
        d_model: Hidden size.
        n_heads: Attention heads per layer.
        n_layers: Number of encoder layers, and separately of decoder layers.
        d_ff: Feed-forward inner size; defaults to ``4 * d_model``.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff=None):
        super().__init__()
        d_ff = 4 * d_model if d_ff is None else d_ff
        self.shared = nn.Embedding(vocab_size, d_model)

        # One position-bias table per stack, shared by all layers in that
        # stack, as in T5.
        self.rel_pos = RelativePositionBias(n_heads=n_heads)
        self.dec_rel_pos = RelativePositionBias(n_heads=n_heads)

        self.encoder = nn.ModuleList([T5Block(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.decoder = nn.ModuleList([T5Block(d_model, n_heads, d_ff, is_decoder=True) for _ in range(n_layers)])

        # The blocks are pre-norm, so their residual stream leaves un-normalized.
        # Each stack therefore ends with a norm. The encoder's is applied before
        # its output is used as cross-attention memory.
        self.encoder_final_ln = T5LayerNorm(d_model)
        self.final_ln = T5LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, decoder_input_ids):
        """Run encoder and decoder and return next-token logits.

        Args:
            input_ids: Encoder token ids, (batch, src_len).
            decoder_input_ids: Decoder token ids, (batch, tgt_len).

        Returns:
            Logits of shape (batch, tgt_len, vocab_size).
        """
        # Encode. The position bias is computed once and shared by every layer.
        x = self.shared(input_ids)
        enc_bias = self.rel_pos(x.size(1), x.size(1))
        for block in self.encoder:
            x = block(x, self_attn_bias=enc_bias)
        memory = self.encoder_final_ln(x)

        # Decode.
        y = self.shared(decoder_input_ids)
        dec_bias = self.dec_rel_pos(y.size(1), y.size(1))
        for block in self.decoder:
            y = block(y, memory=memory, self_attn_bias=dec_bias)

        y = self.final_ln(y)
        return self.lm_head(y)

# Instantiate with T5-Small's dimensions (d_model=512, 8 heads, 6 encoder +
# 6 decoder layers, 32128-token vocab).
# NOTE(review): the printed count need not match the published ~60M. For
# example, lm_head is a separate Linear here rather than tied to the embedding.
model = T5(vocab_size=32128, d_model=512, n_heads=8, n_layers=6)
model = model.to(device)

n_params = sum(param.numel() for param in model.parameters())
print(f"T5 Small initialized: {n_params/1e6:.1f}M params")