<h1>Attention Mechanisms</h1>

In [1]:
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: (..., Tq, d_k)
        key:   (..., Tk, d_k)
        value: (..., Tk, d_v)
        mask:  optional, broadcastable to (..., Tq, Tk).
               True / non-zero = keep, False / zero = mask-out.
               Bool masks are used as-is; integer or float masks are
               converted with ``mask != 0``.

    Returns:
        output: (..., Tq, d_v) attention-weighted sum of ``value``.
        attn:   (..., Tq, Tk) attention weights (softmax over the last dim).

    Note:
        Masked scores are set to the most negative finite value of the score
        dtype rather than -inf, so a query row whose keys are all masked
        yields a uniform distribution over those keys instead of NaN.
    """
    d_k = query.size(-1)

    # scores: (..., Tq, Tk)
    scores = torch.matmul(query, key.transpose(-2, -1)) * (1.0 / math.sqrt(d_k))

    if mask is not None:
        # `~` is a logical NOT only on bool tensors: on integer tensors it is a
        # bitwise NOT (~1 == -2) and on float tensors it raises. Normalise to
        # bool first so 0/1 masks of any dtype behave like their bool twin.
        if mask.dtype != torch.bool:
            mask = mask != 0
        scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)

    attn = F.softmax(scores, dim=-1)
    output = torch.matmul(attn, value)
    return output, attn


In [2]:
# -------------------------
# 1) Self-attention example
# -------------------------
# Queries, keys and values are drawn independently (in that order) and all
# share the shape (B, T, d), so every position attends over the same sequence.
B, T, d = 2, 4, 8
q, k, v = (torch.randn(B, T, d) for _ in range(3))

# No mask: each of the T query positions may attend to all T key positions.
out, attn = scaled_dot_product_attention(q, k, v)
print("Self-attn out:", out.shape)     # (B, T, d)
print("Self-attn attn:", attn.shape)   # (B, T, T)



Self-attn out: torch.Size([2, 4, 8])
Self-attn attn: torch.Size([2, 4, 4])


In [3]:
# -----------------------------------
# 2) Causal self-attn (decoder/GPT)
# -----------------------------------
# Reuses q, k, v from the self-attention example above.
T = q.shape[1]
# Lower-triangular keep-mask: query position i may only see key positions <= i.
# Shape (T, T) broadcasts across the batch dimension.
causal_mask = torch.ones(T, T, dtype=torch.bool).tril()
out_causal, attn_causal = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print("Causal out:", out_causal.shape)        # (B, T, d)
print("Causal attn:", attn_causal.shape)      # (B, T, T)



Causal out: torch.Size([2, 4, 8])
Causal attn: torch.Size([2, 4, 4])


In [4]:
# -----------------------------------
# 3) Cross-attention example (enc-dec)
#    decoder queries attend over encoder keys/values
# -----------------------------------
B, T_dec, T_enc, d = 2, 3, 5, 8
# Queries come from the decoder; keys and values both come from the encoder,
# so the query length (T_dec) and key/value length (T_enc) may differ.
q_dec = torch.randn(B, T_dec, d)
k_enc, v_enc = (torch.randn(B, T_enc, d) for _ in range(2))

out_xattn, attn_xattn = scaled_dot_product_attention(q_dec, k_enc, v_enc)
print("Cross-attn out:", out_xattn.shape)     # (B, T_dec, d)
print("Cross-attn attn:", attn_xattn.shape)   # (B, T_dec, T_enc)



Cross-attn out: torch.Size([2, 3, 8])
Cross-attn attn: torch.Size([2, 3, 5])


In [5]:
# -----------------------------------
# 4) Padding mask example (variable-length sequences)
#    Suppose encoder has padding on the right.
# -----------------------------------
B, T_enc, d = 2, 5, 8
lengths = torch.tensor([5, 3])  # valid encoder steps per batch item: 5 and 3
q_dec = torch.randn(B, 2, d)    # two decoder query positions
k_enc, v_enc = (torch.randn(B, T_enc, d) for _ in range(2))

# Position index < sequence length marks a real (non-pad) encoder step.
pad_mask = torch.arange(T_enc)[None, :] < lengths[:, None]  # (B, T_enc) True=valid
# Insert a singleton query axis so the mask broadcasts to (B, T_dec, T_enc).
pad_mask = pad_mask[:, None, :]  # (B, 1, T_enc)

out_pad, attn_pad = scaled_dot_product_attention(q_dec, k_enc, v_enc, mask=pad_mask)
print("Pad-masked cross-attn out:", out_pad.shape)   # (B, T_dec, d)
print("Pad-masked attn:", attn_pad.shape)            # (B, T_dec, T_enc)

Pad-masked cross-attn out: torch.Size([2, 2, 8])
Pad-masked attn: torch.Size([2, 2, 5])
