In [None]:
import torch
import math
from torch import nn
import torch.nn.functional as F

# =============================================================================
# TRANSFORMER DECODER - COMPLETE IMPLEMENTATION
# =============================================================================
# The Decoder is the second half of the Transformer architecture (Encoder-Decoder).
# Its job is to take context-aware embeddings from the Encoder (source language, e.g. English)
# and sequentially predict output tokens in the target language (e.g. Kannada).
#
# Architecture flow per Decoder Layer:
#   1. Masked Multi-Head Self-Attention (prevents peeking at future tokens)
#   2. Add & Layer Normalization (residual/skip connection)
#   3. Multi-Head Cross-Attention (attends to Encoder output)
#   4. Add & Layer Normalization
#   5. Position-wise Feed-Forward Network
#   6. Add & Layer Normalization
#
# Key dimensions used throughout (with default hyperparameters):
#   batch_size = 30, sequence_length = 200, d_model = 512, num_heads = 8
#   head_dim = d_model // num_heads = 64, ffn_hidden = 2048
# =============================================================================


def scaled_dot_product(q, k, v, mask=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    The raw scores are divided by sqrt(d_k), where d_k is the size of the
    last dimension of ``q``, so that their magnitude does not grow with the
    head dimension before the softmax is applied.

    Args:
        q: Query tensor; its last dimension is used as d_k.
           Typically (batch, heads, seq_len, head_dim).
        k: Key tensor; its last two dimensions are swapped before the product.
        v: Value tensor, multiplied by the attention weights.
        mask: Optional additive mask. It is added IN PLACE to the score
              tensor, so it must broadcast to the score shape without
              enlarging it. Entries of -inf receive zero attention weight.

    Returns:
        (values, attention): the attention-weighted values and the softmax
        weights over the last (key) dimension.
    """
    # d_k: width of each query/key vector
    head_dim = q.shape[-1]

    # Similarity of every query with every key, scaled down by sqrt(d_k)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    print(f"scaled.size() : {scores.size()}")

    # Additive masking: -inf entries vanish after the softmax below
    if mask is not None:
        print(f"-- ADDING MASK of shape {mask.size()} --")
        scores += mask

    # Normalise over the key axis so each query's weights sum to 1
    attention = torch.softmax(scores, dim=-1)

    # Weighted combination of the value vectors
    return attention @ v, attention


class PositionwiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network applied to the last dimension of its input.

    Computes ``linear2(dropout(relu(linear1(x))))``: the features are widened
    from ``d_model`` to ``hidden``, passed through ReLU and dropout, and then
    projected back to ``d_model``. Because ``nn.Linear`` acts only on the last
    dimension, every position is transformed independently and the output has
    the same shape as the input.

    Args:
        d_model: Size of the input and output feature dimension.
        hidden: Size of the intermediate (expanded) feature dimension.
        drop_prob: Dropout probability applied after the ReLU (default 0.1).
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)   # expansion: d_model -> hidden
        self.linear2 = nn.Linear(hidden, d_model)   # projection: hidden -> d_model
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Run the FFN on ``x`` (last dim = d_model), printing each stage's shape."""
        expanded = self.linear1(x)
        print(f"x after first linear layer: {expanded.size()}")

        activated = self.relu(expanded)
        print(f"x after relu layer: {activated.size()}")

        # Dropout is only active in training mode
        regularized = self.dropout(activated)
        print(f"x after dropout layer: {regularized.size()}")

        projected = self.linear2(regularized)
        print(f"x after 2nd linear layer: {projected.size()}")
        return projected


class LayerNormalization(nn.Module):
    """
    Layer normalization over the trailing dimensions of the input.

    The last ``len(parameters_shape)`` dimensions are standardised to zero
    mean and unit (biased) variance, then rescaled and shifted:

        out = gamma * (x - mean) / sqrt(var + eps) + beta

    ``gamma`` (initialised to ones) and ``beta`` (initialised to zeros) are
    learnable parameters of shape ``parameters_shape``.

    Args:
        parameters_shape: Shape of the trailing dimensions to normalise over,
                          e.g. ``[d_model]``.
        eps: Constant added to the variance before the square root so the
             division stays finite (default 1e-5).
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(parameters_shape))   # learnable shift

    def forward(self, inputs):
        """Normalise ``inputs`` over its trailing dims and apply gamma/beta."""
        # Trailing axes to reduce over: [-1, -2, ...], one per entry of parameters_shape
        dims = list(range(-1, -len(self.parameters_shape) - 1, -1))
        print(f"dims: {dims}")

        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean ({mean.size()})")

        # Biased variance (mean of squared deviations); eps keeps the sqrt away from zero
        centered = inputs - mean
        var = centered.pow(2).mean(dim=dims, keepdim=True)
        std = torch.sqrt(var + self.eps)
        print(f"Standard Deviation  ({std.size()})")

        y = centered / std
        print(f"y: {y.size()}")

        # Learnable affine transform applied on top of the standardised values
        out = self.gamma * y + self.beta
        print(f"out: {out.size()}")
        return out


class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention.

    The input is projected once into concatenated queries, keys and values,
    split into ``num_heads`` heads of width ``head_dim = d_model // num_heads``,
    attended per head with ``scaled_dot_product``, and the per-head results are
    concatenated back to ``d_model`` features and passed through a final
    linear layer.

    In self-attention Q, K and V are all derived from the same input ``x``.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Fail early: otherwise the head reshape in forward() fails with an
        # opaque shape error.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # One projection produces Q, K and V together (3 * d_model features)
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)

        # Output projection applied after the heads are concatenated
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input of shape (batch, seq_len, d_model).
            mask: Optional additive mask forwarded to ``scaled_dot_product``
                  (e.g. a (seq_len, seq_len) look-ahead mask).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, sequence_length, _ = x.size()
        print(f"x.size(): {x.size()}")

        # Step 1: project the input to Q, K, V in one shot
        qkv = self.qkv_layer(x)  # (batch, seq, 3 * d_model)
        print(f"qkv.size(): {qkv.size()}")

        # Step 2: give every head its own slice of 3 * head_dim features
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        print(f"qkv after reshape .size(): {qkv.size()}")

        # Step 3: move heads in front of the sequence axis
        qkv = qkv.permute(0, 2, 1, 3)  # (batch, heads, seq, 3 * head_dim)
        print(f"qkv after permutation: {qkv.size()}")

        # Step 4: split each head's slice into Q, K, V
        q, k, v = qkv.chunk(3, dim=-1)  # each (batch, heads, seq, head_dim)
        print(f"q: {q.size()}, k:{k.size()}, v:{v.size()}")

        # Step 5: attention for all heads in parallel
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values: {values.size()}, attention:{attention.size()}")

        # Step 6: concatenate the heads.
        # `values` is (batch, heads, seq, head_dim). The head axis must be moved
        # back next to head_dim BEFORE flattening; reshaping directly would
        # interleave data from different positions and heads (and would let
        # information from future positions leak past the causal mask).
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values after reshaping: {values.size()}")

        # Step 7: final projection mixes information across heads
        out = self.linear_layer(values)
        print(f"out after passing through linear layer: {out.size()}")
        return out


class MultiHeadCrossAttention(nn.Module):
    """
    Multi-head cross-attention (encoder-decoder attention).

    Unlike self-attention, the queries and the keys/values come from
    different tensors:
      - ``x`` (encoder output) is projected to keys and values,
      - ``y`` (decoder state) is projected to queries.

    Each decoder position therefore attends over the encoder positions.
    The two inputs may have different sequence lengths; the output follows
    the length of ``y``.

    Args:
        d_model: Feature size of both inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Fail early: otherwise the head reshape in forward() fails with an
        # opaque shape error.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # K and V are produced together from the encoder output (2 * d_model features)
        self.kv_layer = nn.Linear(d_model, 2 * d_model)

        # Q is produced from the decoder state
        self.q_layer = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are concatenated
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, mask=None):
        """
        Args:
            x: Encoder output (batch, source_len, d_model); source of K and V.
            y: Decoder state (batch, target_len, d_model); source of Q.
               Assumed to share the batch size of ``x``.
            mask: Optional additive mask forwarded to ``scaled_dot_product``;
                  it must broadcast to (batch, heads, target_len, source_len).

        Returns:
            Tensor of shape (batch, target_len, d_model).
        """
        batch_size, source_length, d_model = x.size()
        # Queries follow the decoder's own length, which need not match the encoder's
        target_length = y.size(1)
        print(f"x.size(): {x.size()}")

        # Step 1: keys and values from the encoder output
        kv = self.kv_layer(x)  # (batch, source_len, 2 * d_model)
        print(f"kv.size(): {kv.size()}")

        # Step 2: queries from the decoder state
        q = self.q_layer(y)  # (batch, target_len, d_model)
        print(f"q.size(): {q.size()}")

        # Step 3: split features into heads
        kv = kv.reshape(batch_size, source_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, target_length, self.num_heads, self.head_dim)

        # Step 4: move heads in front of the sequence axis
        kv = kv.permute(0, 2, 1, 3)  # (batch, heads, source_len, 2 * head_dim)
        q = q.permute(0, 2, 1, 3)    # (batch, heads, target_len, head_dim)

        # Step 5: separate keys from values
        k, v = kv.chunk(2, dim=-1)  # each (batch, heads, source_len, head_dim)

        # Step 6: decoder queries attend over encoder keys/values
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values: {values.size()}, attention:{attention.size()}")

        # Step 7: concatenate the heads and project.
        # `values` is (batch, heads, target_len, head_dim). The head axis must be
        # moved back next to head_dim BEFORE flattening; reshaping directly
        # would interleave data from different positions and heads.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, target_length, d_model)
        out = self.linear_layer(values)
        print(f"out after passing through linear layer: {out.size()}")
        return out


class DecoderLayer(nn.Module):
    """
    One block of the decoder stack.

    The block chains three sub-layers. Each is followed by dropout, a
    residual (skip) connection and layer normalization, i.e.
    ``state = norm(dropout(sub_layer(state)) + state)``:

      1. Masked multi-head self-attention over the decoder state, using the
         mask supplied by the caller.
      2. Multi-head cross-attention: queries from the decoder state, keys and
         values from the encoder output. No mask is applied here.
      3. Position-wise feed-forward network.

    The residual connections give gradients a direct path around every
    sub-layer, which is what makes deep stacks of these blocks trainable.

    Args:
        d_model: Feature size of the encoder output and decoder state.
        ffn_hidden: Hidden size of the feed-forward network.
        num_heads: Number of heads used by both attention sub-layers.
        drop_prob: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        # Sub-layer 1: masked self-attention
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)

        # Sub-layer 2: encoder-decoder (cross) attention
        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

        # Sub-layer 3: position-wise feed-forward network
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, decoder_mask):
        """
        Args:
            x: Encoder output, used as keys/values in cross-attention.
            y: Decoder state entering this block.
            decoder_mask: Mask passed to the self-attention sub-layer
                          (a look-ahead mask in the decoder use case).

        Returns:
            The updated decoder state, same shape as ``y``.
        """
        # ---- 1) masked self-attention -> dropout -> add & norm ----
        residual = y
        print("MASKED SELF ATTENTION")
        attended = self.self_attention(y, mask=decoder_mask)
        print("DROP OUT 1")
        attended = self.dropout1(attended)
        print("ADD + LAYER NORMALIZATION 1")
        y = self.norm1(attended + residual)

        # ---- 2) cross-attention -> dropout -> add & norm ----
        residual = y
        print("CROSS ATTENTION")
        # mask=None: every decoder position may look at every encoder position
        crossed = self.encoder_decoder_attention(x, y, mask=None)
        print("DROP OUT 2")
        crossed = self.dropout2(crossed)
        print("ADD + LAYER NORMALIZATION 2")
        y = self.norm2(crossed + residual)

        # ---- 3) feed-forward -> dropout -> add & norm ----
        residual = y
        print("FEED FORWARD 1")
        transformed = self.ffn(y)
        print("DROP OUT 3")
        transformed = self.dropout3(transformed)
        print("ADD + LAYER NORMALIZATION 3")
        return self.norm3(transformed + residual)


class SequentialDecoder(nn.Sequential):
    """
    ``nn.Sequential`` variant for layers that take three inputs.

    A plain ``nn.Sequential`` threads a single value through its modules.
    Decoder layers need ``(x, y, mask)``, so ``forward`` is overridden to call
    every contained module as ``module(x, y, mask)``.

    Only ``y`` is replaced by each module's result; ``x`` and ``mask`` are
    handed unchanged to every module. With no modules, ``y`` is returned
    as is.
    """

    def forward(self, *inputs):
        # Exactly three positional inputs are expected: (x, y, mask)
        encoder_out, state, mask = inputs
        for layer in self:
            state = layer(encoder_out, state, mask)
        return state


class Decoder(nn.Module):
    """
    Transformer decoder: a stack of ``num_layers`` ``DecoderLayer`` blocks.

    Every block is constructed separately (so each has its own parameters)
    and receives the same encoder output and the same mask; only the decoder
    state is updated from block to block. Each block applies masked
    self-attention, cross-attention over the encoder output, and a
    position-wise feed-forward network.

    Args:
        d_model: Feature size of the encoder output and decoder state.
        ffn_hidden: Hidden size of each block's feed-forward network.
        num_heads: Number of attention heads per block.
        drop_prob: Dropout probability used inside each block.
        num_layers: Number of stacked blocks (default 1; the original
                    "Attention Is All You Need" model uses 6).
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers=1):
        super().__init__()
        # Fresh DecoderLayer per iteration -> no parameter sharing between blocks
        blocks = [
            DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
            for _ in range(num_layers)
        ]
        self.layers = SequentialDecoder(*blocks)

    def forward(self, x, y, mask):
        """
        Args:
            x: Encoder output.
            y: Decoder input (target-side embeddings).
            mask: Mask handed to every block's self-attention.

        Returns:
            The decoder state after the last block.
        """
        return self.layers(x, y, mask)

In [None]:
# -----------------------------------------------------------------------------
# DECODER DEMO: hyperparameters, dummy inputs, causal mask, one forward pass
# -----------------------------------------------------------------------------

# --- Model hyperparameters ---
d_model = 512                # Width of every token vector (value used in the original paper).
num_heads = 8                # Parallel attention heads; 512 / 8 = 64 features per head.
ffn_hidden = 2048            # Inner width of the feed-forward network (4 x d_model).
num_layers = 5               # Stacked decoder layers (the original paper uses 6).
drop_prob = 0.1              # Dropout probability: 10% of activations zeroed while training.

# --- Data dimensions ---
batch_size = 30              # Sentence pairs processed together in one batch.
max_sequence_length = 200    # Tokens per sentence; shorter sentences are padded,
                             # longer ones truncated.

# --- Stand-in inputs ---
# Random tensors take the place of what a real model would provide:
#   x: the encoder's output for the source (English) sentences
#   y: the positional-encoded embeddings of the target (Kannada) sentences
x = torch.randn(batch_size, max_sequence_length, d_model)
y = torch.randn(batch_size, max_sequence_length, d_model)

# --- Look-ahead (causal) mask ---
# Strictly upper-triangular -inf, zeros on and below the diagonal:
#   [[0, -inf, -inf, ..., -inf],
#    [0,    0, -inf, ..., -inf],
#    ...
#    [0,    0,    0, ...,    0]]
# During training the decoder is fed the whole target sentence at once, while
# at inference it produces one token at a time. Adding this mask to the
# attention scores stops position i from attending to any position j > i, so
# the model cannot "cheat" by reading the tokens it is supposed to predict.
mask = torch.triu(
    torch.full((max_sequence_length, max_sequence_length), float("-inf")),
    diagonal=1,
)

# --- Build the decoder and run a single forward pass ---
decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)
out = decoder(x, y, mask)
# `out` has shape (batch_size, max_sequence_length, d_model) = 30 x 200 x 512.
# A full model would follow it with a Linear layer (d_model -> vocab_size) and a
# softmax to obtain a probability distribution over the target vocabulary.

MASKED SELF ATTENTION
x.size(): torch.Size([30, 200, 512])
qkv.size(): torch.Size([30, 200, 1536])
qkv after reshape .size(): torch.Size([30, 200, 8, 192])
qkv after permutation: torch.Size([30, 8, 200, 192])
q: torch.Size([30, 8, 200, 64]), k:torch.Size([30, 8, 200, 64]), v:torch.Size([30, 8, 200, 64])
scaled.size() : torch.Size([30, 8, 200, 200])
-- ADDING MASK of shape torch.Size([200, 200]) --
values: torch.Size([30, 8, 200, 64]), attention:torch.Size([30, 8, 200, 200])
values after reshaping: torch.Size([30, 200, 512])
out after passing through linear layer: torch.Size([30, 200, 512])
DROP OUT 1
ADD + LAYER NORMALIZATION 1
dims: [-1]
Mean (torch.Size([30, 200, 1]))
Standard Deviation  (torch.Size([30, 200, 1]))
y: torch.Size([30, 200, 512])
out: torch.Size([30, 200, 512])
CROSS ATTENTION
x.size(): torch.Size([30, 200, 512])
kv.size(): torch.Size([30, 200, 1024])
q.size(): torch.Size([30, 200, 512])
scaled.size() : torch.Size([30, 8, 200, 200])
values: torch.Size([30, 8, 200, 64]),

In [None]:
# Display the look-ahead mask matrix built above.
# Entries on and below the diagonal are 0 (positions a token CAN attend to);
# entries above the diagonal are -inf (future positions that are BLOCKED).
# Row i describes what token i can see: tokens 0..i, but not i+1..N-1,
# where N = max_sequence_length.
mask

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])