# Transformer-XL from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/transformer_xl.ipynb)

Transformer-XL addresses the **fixed context length limit** of standard Transformers.

Key Innovations:
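1. **Segment-Level Recurrence**: Cache the hidden states computed for the previous segment and reuse them as *extended memory* (extra keys and values) when processing the current segment. This is a recurrence across segments, similar in spirit to an RNN's hidden state. Gradients are not propagated into the cached states.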
2. **Relative Positional Encoding**: Absolute positions become ambiguous once states are reused across segments, because every segment restarts at position 0. Attention therefore uses the relative distance $i - j$ between query $i$ and key $j$ instead.

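Because each layer attends to the previous segment's states of the layer below, the effective context grows with depth (roughly layers × segment length). This allows modeling **very long-term dependencies** beyond the training segment length.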

In [None]:
# Install the notebook's dependencies (IPython shell magic, not plain Python).
!pip install torch matplotlib

In [None]:
import torch
import torch.nn as nn
import math

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Relative-Position Attention (`RelationalAttention`)

Standard attention with token embeddings $E$ and absolute positional embeddings $U$: $A_{ij} = (W_q E_i + W_q U_i)^\top (W_k E_j + W_k U_j)$
Transformer-XL attention disentangles content and position:

$$A^{rel}_{ij} = \underbrace{E_i^\top W_q^\top W_{k,E} E_j}_{\text{content-content}} + \underbrace{E_i^\top W_q^\top W_{k,R} R_{i-j}}_{\text{content-position}} + \underbrace{u^\top W_{k,E} E_j}_{\text{global-content}} + \underbrace{v^\top W_{k,R} R_{i-j}}_{\text{global-position}}$$

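The symbols are:
- $R_{i-j}$ is an embedding of the relative distance $i - j$.
- $W_{k,E}$ and $W_{k,R}$ are separate key projections for content and position.
- $u$ and $v$ are learnable global biases, the same for every query position.

The implementation below simplifies this in two ways. It adds $u$ and $v$ to the projected query, computing $(q_i + u)^\top k_j + (q_i + v)^\top R_{i-j}$. It also uses a learned embedding table for $R$ directly, with no $W_{k,R}$ projection.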

In [None]:
class RelationalAttention(nn.Module):
    """Multi-head causal self-attention with segment memory and relative positions.

    Keys and values span the concatenation ``[mem; x]``; queries span ``x`` only.
    The score of query ``i`` against key ``j`` is

        ((q_i + u) . k_j  +  (q_i + v) . R[(i + mlen) - j]) / sqrt(d_head)

    where ``R`` is a learned embedding of the query-key distance (shared across
    heads) and ``u``/``v`` are learnable per-head global biases.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        max_rel_dist: number of rows in the relative-position table. Distances
            greater than or equal to this are clamped to ``max_rel_dist - 1``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, max_rel_dist=512):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.max_rel_dist = max_rel_dist

        self.qkv_net = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o_net = nn.Linear(d_model, d_model, bias=False)

        # Global content bias (u) and global position bias (v).
        # Zero-initialised: torch.Tensor(n, d) would leave uninitialised memory
        # (arbitrary values, possibly NaN/inf) in these parameters.
        self.u = nn.Parameter(torch.zeros(self.n_heads, self.d_head))
        self.v = nn.Parameter(torch.zeros(self.n_heads, self.d_head))

        # Learned relative-position table, indexed by query-key distance.
        self.r_emb = nn.Embedding(max_rel_dist, self.d_head)

    def _relative_distances(self, qlen, mlen, device):
        """Return a ``[qlen, qlen + mlen]`` long tensor of clamped distances ``(i + mlen) - j``."""
        # Query i sits at absolute index i + mlen inside [mem; x].
        q_pos = torch.arange(qlen, device=device).unsqueeze(1) + mlen
        k_pos = torch.arange(qlen + mlen, device=device).unsqueeze(0)
        # Future keys give negative distances; they are masked out later, so
        # clamping them to 0 only keeps the embedding lookup in range.
        return (q_pos - k_pos).clamp(min=0, max=self.max_rel_dist - 1)

    def forward(self, x, mem=None):
        """Attend from the current segment over memory plus the current segment.

        Args:
            x: ``[qlen, batch, d_model]`` current-segment inputs.
            mem: optional ``[mlen, batch, d_model]`` states from the previous segment.

        Returns:
            ``[qlen, batch, d_model]`` attention output.
        """
        qlen, bsz, _ = x.size()
        mlen = mem.size(0) if mem is not None else 0
        klen = qlen + mlen

        # Keys/values see history + current segment.
        cat = torch.cat([mem, x], 0) if mem is not None else x

        q, k, v = self.qkv_net(cat).chunk(3, dim=-1)
        # Only the current segment issues queries.
        q = q[-qlen:].view(qlen, bsz, self.n_heads, self.d_head)
        k = k.view(klen, bsz, self.n_heads, self.d_head)
        v = v.view(klen, bsz, self.n_heads, self.d_head)

        # Content term: (q_i + u) . k_j  -> [qlen, klen, batch, heads]
        content = torch.einsum('ibnd,jbnd->ijbn', q + self.u, k)

        # Position term: (q_i + v) . R[(i + mlen) - j], so it depends on i - j.
        rel = self.r_emb(self._relative_distances(qlen, mlen, x.device))  # [qlen, klen, d_head]
        position = torch.einsum('ibnd,ijd->ijbn', q + self.v, rel)

        scores = (content + position) / math.sqrt(self.d_head)

        # Causal mask: query i may attend to key j only when j <= i + mlen.
        mask = torch.triu(
            torch.ones((qlen, klen), device=x.device, dtype=torch.bool),
            diagonal=1 + mlen,
        )
        scores = scores.masked_fill(mask[:, :, None, None], float('-inf'))

        attn = torch.softmax(scores, dim=1)  # normalise over the key axis
        out = torch.einsum('ijbn,jbnd->ibnd', attn, v)
        out = out.contiguous().view(qlen, bsz, -1)

        return self.o_net(out)

## 2. Transformer-XL Model with Recurrence

The forward pass takes a `mems` list with one entry per layer. Each entry holds the detached hidden states that layer received on the previous segment.

In [None]:
class TransformerXLBlock(nn.Module):
    """Pre-LayerNorm Transformer-XL layer.

    Applies relative-position attention over ``[mem; x]`` and then a GELU
    feed-forward network, each wrapped in a residual connection.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        d_ff: hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = RelationalAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, mem=None):
        """Run the layer on the current segment.

        Args:
            x: ``[qlen, batch, d_model]`` residual stream for the current segment.
            mem: optional ``[mlen, batch, d_model]`` cached hidden states for this
                layer from the previous segment (un-normalised residual stream).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Pre-LN: the memory must pass through the same ln1 as x. Otherwise keys
        # and values from the previous segment come from the raw residual stream
        # and sit on a different scale from those of the current segment.
        mem_normed = self.ln1(mem) if mem is not None else None
        x = x + self.attn(self.ln1(x), mem_normed)
        x = x + self.ff(self.ln2(x))
        return x

class TransformerXL(nn.Module):
    """Token embedding, a stack of ``TransformerXLBlock`` layers, and a linear LM head.

    Args:
        vocab_size: number of token ids.
        d_model: hidden size.
        n_heads: attention heads per layer.
        n_layers: number of stacked blocks (feed-forward width is ``4 * d_model``).
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            [TransformerXLBlock(d_model, n_heads, 4 * d_model) for _ in range(n_layers)]
        )
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x, mems=None):
        """Process one segment and return the memory to feed into the next one.

        Args:
            x: ``[seq_len, batch]`` token ids.
            mems: ``None`` or an empty list for no memory; otherwise a list with
                one entry per layer (a ``[m_len, batch, d_model]`` tensor or
                ``None``), typically the ``new_mems`` of the previous call.

        Returns:
            ``(logits, new_mems)``: logits of shape ``[seq_len, batch, vocab_size]``
            and a list holding, for each layer, the detached input it received.

        Raises:
            ValueError: if ``mems`` is non-empty and its length differs from
                the number of layers.
        """
        if not mems:
            mems = [None] * len(self.layers)
        elif len(mems) != len(self.layers):
            raise ValueError(
                f"expected {len(self.layers)} memory entries (one per layer), got {len(mems)}"
            )

        h = self.emb(x)
        new_mems = []
        for layer, mem in zip(self.layers, mems):
            # Cache this layer's input as next segment's memory. Detach so that
            # gradients stop at the segment boundary instead of flowing back
            # through every previous segment.
            new_mems.append(h.detach())
            h = layer(h, mem=mem)

        return self.head(h), new_mems

# Build a 6-layer, 512-wide model with a 10k-token vocabulary and move it
# to the device selected earlier.
model = TransformerXL(
    vocab_size=10000,
    d_model=512,
    n_heads=8,
    n_layers=6,
)
model = model.to(device)
print("Transformer-XL Initialized")

## 3. Demonstrate Recurrence

We process a long sequence segment by segment, passing the `mems` returned for one segment into the call for the next.

In [None]:
# Simulate processing two consecutive segments of one long sequence.
seq_len = 20
batch = 1
# Read the vocabulary size from the model rather than repeating the literal.
vocab_size = model.emb.num_embeddings

# Segment 1: no memory yet.
input1 = torch.randint(0, vocab_size, (seq_len, batch), device=device)
out1, mems1 = model(input1, mems=None)
print(f"Segment 1 output: {out1.shape}")
print(f"Memory size: {mems1[0].shape}")

# Segment 2: attends to the cached per-layer states from segment 1.
input2 = torch.randint(0, vocab_size, (seq_len, batch), device=device)
out2, mems2 = model(input2, mems=mems1)
print(f"Segment 2 output: {out2.shape} (Used memory from Seg 1)")

# Verify that the memory actually influences the result: rerun segment 2
# without memory and compare.
with torch.no_grad():
    out2_no_mem, _ = model(input2, mems=None)
if torch.allclose(out2, out2_no_mem):
    print("Warning: segment 2 output did not change when memory was supplied.")
else:
    print("Success! This demonstrates state carry-over across segments.")