# XLNet: Generalized Autoregressive Pretraining

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/xlnet.ipynb)

XLNet combines the best of BERT (bidirectional context) and GPT (autoregressive generation) using **Permutation Language Modeling**.

Key Concepts:
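1. **Permutation LM**: Predicts tokens in a randomly sampled factorization order (e.g., 3 → 1 → 2 → 4). The input tokens stay in their original positions; only the attention mask encodes the order. Averaged over many orders, each token learns from context on both sides ("future" context, i.e. bidirectional), while every individual prediction remains autoregressive.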
2. **Two-Stream Attention**: To handle the permutation, it uses:
   - **Content Stream**: Encodes the token itself (standard).
   - **Query Stream**: Encodes the target position and the preceding context, but not the target token's own content (used to predict the target).
3. **Transformer-XL Backbone**: Inherits relative positional encoding and recurrence mechanism (recurrence omitted here for simplicity).

In [None]:
!pip install torch matplotlib

In [None]:
import torch
import torch.nn as nn
import math

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 1. Permutation Masking

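Create a mask that enforces a specific permutation (factorization) order. Positions below are 1-based for readability; the code uses 0-based indices. If the order is [3, 1, 2, 4], the query stream sees:
- 3 sees nothing (it is predicted first)
- 1 sees 3
- 2 sees 3, 1
- 4 sees 3, 1, 2

The content stream additionally lets every position see itself.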

In [None]:
def create_permutation_mask(seq_len, device, include_self=True):
    """Sample a random factorization order and build its attention mask.

    XLNet keeps the input tokens in their original positions and encodes the
    sampled prediction order purely in the attention mask. The result is
    equivalent to permuting the inputs, running ordinary causal attention,
    and un-permuting the outputs.

    Args:
        seq_len: Number of positions.
        device: Device on which the tensors are created.
        include_self: If True (content stream), a position may attend to
            itself and to every position predicted earlier. If False (query
            stream), it may attend only to positions predicted strictly
            earlier.

    Returns:
        perm: (seq_len,) LongTensor; ``perm[t]`` is the original position
            predicted at step ``t``.
        mask: (seq_len, seq_len) float tensor of 0/1 where ``mask[i, j] == 1``
            iff position ``i`` may attend to position ``j``.
    """
    perm = torch.randperm(seq_len, device=device)
    # rank[p] = step at which original position p is predicted (inverse of perm).
    rank = torch.argsort(perm)

    step_i = rank.unsqueeze(1)  # column vector, indexed by query position i
    step_j = rank.unsqueeze(0)  # row vector, indexed by key position j
    # i may see j iff j comes no later (or strictly earlier) in the sampled order.
    allowed = step_j <= step_i if include_self else step_j < step_i
    return perm, allowed.float()

# Demo: sample one factorization order over six positions and show it.
perm, mask = create_permutation_mask(seq_len=6, device=device)
print(f"Permutation order: {perm.tolist()}")
print("Conceptually: The model predicts tokens in this order.")

## 2. Two-Stream Attention

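The core of XLNet. Both streams share the query, key, and value projections, and the keys (K) and values (V) always come from the content stream.
- **Content Stream ($h$)**: Attends to the context plus the token itself (standard self-attention under the permutation mask).
- **Query Stream ($g$)**: Queries the content stream of the context using only position information. It never sees the content of the token it is predicting.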

This allows predicting $x_{z_t}$ from its position $z_t$ and the context $x_{z_{<t}}$ (the tokens that come earlier in the factorization order $z$), without seeing $x_{z_t}$ itself.

In [None]:
class TwoStreamAttention(nn.Module):
    """Simplified multi-head two-stream self-attention in the style of XLNet.

    The two streams share the query/key/value/output projections. Keys and
    values always come from the content stream ``h``:

    * Content stream: queries from ``h`` under the attention mask; a position
      may attend to itself.
    * Query stream: queries from ``g`` under the same mask with the diagonal
      removed, so position ``i`` never sees its own content.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Projections are shared between the two streams.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        """Reshape (batch, len, d_model) -> (batch, n_heads, len, d_k)."""
        batch, length, _ = x.shape
        return x.view(batch, length, self.n_heads, self.d_k).transpose(1, 2)

    def _merge_heads(self, x):
        """Reshape (batch, n_heads, len, d_k) -> (batch, len, d_model)."""
        batch, _, length, _ = x.shape
        return x.transpose(1, 2).reshape(batch, length, self.n_heads * self.d_k)

    def _attend(self, q, k, v, allowed):
        """Scaled dot-product attention restricted to ``allowed`` (bool, broadcastable to the scores)."""
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Use a large finite negative rather than -inf so fully-masked rows
        # stay NaN-free (in both the forward and the backward pass).
        scores = scores.masked_fill(~allowed, torch.finfo(scores.dtype).min)
        probs = torch.softmax(scores, dim=-1)
        # A row with nothing to attend to (e.g. the first position in the
        # factorization order, for the query stream) gets a zero context
        # instead of a uniform average over tokens it must not see.
        probs = probs * allowed.any(dim=-1, keepdim=True).to(probs.dtype)
        return torch.matmul(probs, v)

    def forward(self, h, g, mask):
        """Update both streams.

        Args:
            h: Content stream, (batch, len, d_model).
            g: Query stream, (batch, len, d_model).
            mask: ``None`` (every position may attend to every position), or a
                (len, len) / (batch, len, len) tensor where a nonzero/True
                entry ``[i, j]`` means position ``i`` may attend to position
                ``j``. This is the 0/1 format returned by a permutation-mask
                builder. The diagonal is always removed for the query stream.

        Returns:
            ``(h_out, g_out)``, each (batch, len, d_model).
        """
        length = h.size(1)
        if mask is None:
            content_allowed = torch.ones(length, length, dtype=torch.bool, device=h.device)
        else:
            content_allowed = mask.to(device=h.device) != 0
        eye = torch.eye(length, dtype=torch.bool, device=h.device)
        # Query stream: same visibility, but never the position's own content.
        query_allowed = content_allowed & ~eye
        if content_allowed.dim() == 3:
            # (batch, len, len) -> (batch, 1, len, len) to broadcast over heads.
            content_allowed = content_allowed.unsqueeze(1)
            query_allowed = query_allowed.unsqueeze(1)

        # Keys/values come from the content stream only.
        k = self._split_heads(self.k_proj(h))
        v = self._split_heads(self.v_proj(h))
        q_h = self._split_heads(self.q_proj(h))
        q_g = self._split_heads(self.q_proj(g))

        h_new = self._merge_heads(self._attend(q_h, k, v, content_allowed))
        g_new = self._merge_heads(self._attend(q_g, k, v, query_allowed))
        return self.o_proj(h_new), self.o_proj(g_new)

## 3. XLNet Model Structure

Stacking Two-Stream layers.

In [None]:
class XLNetLayer(nn.Module):
    """A single pre-norm XLNet block operating on both attention streams.

    The content stream ``h`` and the query stream ``g`` share every weight in
    this layer: the two-stream attention, both LayerNorms and the
    feed-forward network.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        hidden_dim = 4 * d_model
        self.attn = TwoStreamAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
        )

    def forward(self, h, g, mask):
        """Apply attention then feed-forward, each as a pre-norm residual.

        Returns the updated ``(h, g)`` pair.
        """
        # Residual two-stream attention over the normalised inputs.
        attn_h, attn_g = self.attn(self.ln1(h), self.ln1(g), mask)
        h, g = h + attn_h, g + attn_g

        # Residual feed-forward, applied to each stream independently.
        h, g = (x + self.ff(self.ln2(x)) for x in (h, g))
        return h, g

class XLNet(nn.Module):
    """Minimal XLNet: embeddings, stacked two-stream layers and an LM head.

    Transformer-XL relative positional encoding and recurrence are omitted.
    Learned absolute position embeddings are therefore added to both streams;
    without them neither stream would carry any position information, and
    the query stream could not tell which position it is predicting.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len=512):
        super().__init__()
        self.max_len = max_len
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        # Learnable vector ``w`` shared by every query-stream position. The
        # query stream must not start from the token it is asked to predict.
        self.mask_emb = nn.Parameter(torch.empty(d_model))
        nn.init.normal_(self.mask_emb, std=0.02)
        self.layers = nn.ModuleList([XLNetLayer(d_model, n_heads) for _ in range(n_layers)])
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, mask=None):
        """Compute next-token logits from the query stream.

        Args:
            input_ids: (batch, len) token ids with ``len <= max_len``.
            mask: Optional attention mask, forwarded unchanged to every
                layer. Its expected format is whatever the layers' attention
                accepts. ``None`` keeps the previous behaviour of passing no
                mask.

        Returns:
            Logits of shape (batch, len, vocab_size), computed from the query
            stream ``g``.

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = input_ids.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        positions = torch.arange(seq_len, device=input_ids.device)
        pos = self.pos_emb(positions)  # (len, d_model)

        # Content stream (h): token content plus its position.
        h = self.token_emb(input_ids) + pos
        # Query stream (g): shared vector w plus position only, with no token content.
        g = (self.mask_emb + pos).unsqueeze(0).expand(h.size(0), -1, -1)

        for layer in self.layers:
            h, g = layer(h, g, mask=mask)

        # Predict from the query stream, which never saw the target's content.
        return self.lm_head(g)

# Init XLNet
# Build a 6-layer, 768-wide XLNet and move it to the selected device.
model = XLNet(
    vocab_size=32000,
    d_model=768,
    n_heads=12,
    n_layers=6,
).to(device)
print(f"XLNet Initialized: {sum(param.numel() for param in model.parameters()) / 1e6:.1f}M params")