In [2]:
import math

import torch

![Comparison of multi-head (MHA), multi-query (MQA) and grouped-query (GQA) attention](https://www.ibm.com/content/dam/connectedassets-adobe-cms/worldwide-content/creative-assets/s-migr/ul/g/3f/b4/mha-mqa-and-gqa.component.xl-retina.ts=1744898975663.png/content/adobe-cms/us/en/think/topics/grouped-query-attention/jcr:content/root/table_of_contents/body-article-8/image_700444495)

In [None]:
# Grouped-query attention, grouped-einsum formulation: the N query heads are
# viewed as K groups of G heads so the un-expanded K/V tensors can be
# contracted against them directly (no copy of K/V per query head).

# --- problem sizes ---
B = 25          # batch size
T = 50          # query sequence length
S = T           # key/value sequence length (self-attention, so same as T)
D = 8 * 50      # model / embedding width
N = 8           # number of query heads
K = 4           # number of key/value heads
H = D // N      # width of one attention head
G = N // K      # query heads that share one key/value head

# --- random activations and projection weights (uniform in [0, 1)) ---
x = torch.rand(B, T, D)
Q_w = torch.rand(D, N, H)
K_w = torch.rand(D, K, H)
V_w = torch.rand(D, K, H)
output_w = torch.rand(N, H, D)

# --- per-head projections ---
query = torch.einsum("BTD,DNH->BTNH", x, Q_w)    # (B,T,N,H)
key = torch.einsum("BTD,DKH->BTKH", x, K_w)      # (B,T,K,H)
value = torch.einsum("BTD,DKH->BTKH", x, V_w)    # (B,T,K,H)

# Split the head axis N into (K, G): query head n = k*G + g uses kv head k.
query = query.reshape(B, T, K, G, H)             # (B,T,K,G,H)

# --- scaled dot-product scores ---
attn_logits = torch.einsum("BTKGH,BSKH->BTSKG", query, key)  # (B,T,S,K,G)
scale = torch.sqrt(torch.tensor(H, dtype=attn_logits.dtype, device=attn_logits.device))
attn_logits = attn_logits.div(scale)

# --- causal masking: query position t may only see key positions s <= t ---
causal = torch.tril(
    torch.ones(T, S, dtype=torch.bool, device=attn_logits.device)
).view(1, T, S, 1, 1)
attn_logits = attn_logits.masked_fill(causal.logical_not(), float("-inf"))

# Normalise over the key axis S.
attn_probs = attn_logits.softmax(dim=2)          # (B,T,S,K,G)

# --- weighted sum of values, then merge (K, G) back into N ---
context = torch.einsum("BTSKG,BSKH->BTKGH", attn_probs, value)  # (B,T,K,G,H)
context = context.reshape(B, T, N, H)                           # (B,T,N,H)

# --- output projection back to model width ---
outputs = torch.einsum("BTNH,NHD->BTD", context, output_w)      # (B,T,D)

In [None]:
# Grouped-query attention, expand-K/V formulation: every key/value head is
# repeated G times so the rest of the computation looks like ordinary
# multi-head attention over N heads.

# --- problem sizes ---
B = 25          # batch size
T = 50          # query sequence length
S = T           # key/value sequence length (self-attention, so same as T)
D = 8 * 50      # model / embedding width
N = 8           # number of query heads
K = 4           # number of key/value heads
H = D // N      # width of one attention head
G = N // K      # query heads that share one key/value head

# --- random activations and projection weights (uniform in [0, 1)) ---
x = torch.rand(B, T, D)
Q_w = torch.rand(D, N, H)
K_w = torch.rand(D, K, H)
V_w = torch.rand(D, K, H)
output_w = torch.rand(N, H, D)

# --- per-head projections ---
query = torch.einsum("BTD,DNH->BTNH", x, Q_w)    # [B,T,N,H]
key = torch.einsum("BTD,DKH->BTKH", x, K_w)      # [B,T,K,H]
value = torch.einsum("BTD,DKH->BTKH", x, V_w)    # [B,T,K,H]

# GQA: duplicate each kv head G times along the head axis -> N heads.
key = torch.repeat_interleave(key, G, dim=2)     # [B,T,N,H]
value = torch.repeat_interleave(value, G, dim=2) # [B,T,N,H]

# --- scaled dot-product scores ---
attn_logits = torch.einsum("BTNH,BSNH->BTSN", query, key)  # [B,T,S,N]
scale = torch.sqrt(torch.tensor(H, dtype=attn_logits.dtype, device=attn_logits.device))
attn_logits = attn_logits.div(scale)

# Lower-triangular causal mask, shaped to broadcast against [B,T,S,N].
mask = torch.tril(torch.ones(T, S, dtype=torch.bool))[None, :, :, None]
mask = mask.to(attn_logits.device)

# masked_fill is out-of-place, so the result must be rebound; -inf makes the
# masked positions vanish after softmax.
attn_logits = attn_logits.masked_fill(mask.logical_not(), float("-inf"))

# Normalise over the key axis S, mix values, project back to model width.
attn_probs = attn_logits.softmax(dim=2)                    # [B,T,S,N]
context = torch.einsum("BTSN,BSNH->BTNH", attn_probs, value)
output = torch.einsum("BTNH,NHD->BTD", context, output_w)

print(output.shape)  # torch.Size([25, 50, 400])

torch.Size([25, 50, 400])


In [47]:
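# Tensor-parallel GQA sketch, left commented out. Running it requires an
# initialized torch.distributed process group (see dist.get_world_size() / dist.get_rank() below).
# import math, torch, torch.distributed as dist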

# # ----- config -----
# B, T = 25, 50
# S = T
# D = 8*50
# N = 8   # query heads
# K = 4   # kv heads
# H = D // N
# assert D % N == 0
# P = dist.get_world_size()
# rank = dist.get_rank()
# assert N % P == 0 and K % P == 0, "Heads must divide across TP ranks"

# N_p = N // P
# K_p = K // P
# G = N // K
# G_local = N_p // K_p    # should equal G

# # ----- data -----
# device = torch.device("cuda", index=rank) if torch.cuda.is_available() else torch.device("cpu")
# x = torch.rand((B, T, D), device=device)

# # Shard weights by head axis
# # global shapes: Q:(D,N,H), K/V:(D,K,H), O:(N,H,D)
# Q_w_local = torch.rand((D, N_p, H), device=device)
# K_w_local = torch.rand((D, K_p, H), device=device)
# V_w_local = torch.rand((D, K_p, H), device=device)
# O_w_local = torch.rand((N_p, H, D), device=device)  # row-parallel shard

# # ----- forward on each rank -----
# # Project to local heads
# q = torch.einsum("BTD,DNH->BTNH", x, Q_w_local)     # [B,T,N_p,H]
# k = torch.einsum("BTD,DKH->BTKH", x, K_w_local)     # [B,T,K_p,H]
# v = torch.einsum("BTD,DKH->BTKH", x, V_w_local)     # [B,T,K_p,H]

# # GQA: expand local K/V to match local Q heads
# k = k.repeat_interleave(G_local, dim=2)             # [B,T,N_p,H]
# v = v.repeat_interleave(G_local, dim=2)             # [B,T,N_p,H]

# # Attention per-head (local, no comms)
# attn_logits = torch.einsum("BTNH,BSNH->BTSN", q, k) # [B,T,S,N_p]
# attn_logits = attn_logits / math.sqrt(H)

# mask = torch.ones(T, S, dtype=torch.bool, device=device).tril().view(1,T,S,1)
# attn_logits = attn_logits.masked_fill(~mask, float("-inf"))

# attn_probs = torch.softmax(attn_logits, dim=2)      # [B,T,S,N_p]
# ctx_local  = torch.einsum("BTSN,BSNH->BTNH", attn_probs, v)  # [B,T,N_p,H]

# # Row-parallel output: each rank produces a partial [B,T,D], then sum-reduce
# y_partial = torch.einsum("BTNH,NHD->BTD", ctx_local, O_w_local)  # [B,T,D]
# dist.all_reduce(y_partial, op=dist.ReduceOp.SUM)                 # -> [B,T,D]
# y = y_partial  # final output on all ranks

# # y shape is [B,T,D], matches your single-GPU result
# print(rank, y.shape)

In [None]:
import torch.nn as nn 

class MHA(nn.Module):
    """Self-attention with grouped key/value heads (covers MHA, GQA and MQA).

    ``num_q_heads`` query heads share ``num_kv_heads`` key/value heads in
    contiguous groups of ``num_q_heads // num_kv_heads``: query head ``n``
    reads key/value head ``n // group_size``. Equal head counts give ordinary
    multi-head attention; ``num_kv_heads == 1`` gives multi-query attention.

    Args:
        d_model: Width of the input and output embeddings.
        num_q_heads: Number of query heads; must divide ``d_model``.
        num_kv_heads: Number of key/value heads; must divide ``num_q_heads``.

    Raises:
        AssertionError: If either divisibility requirement is violated.
    """

    def __init__(self, d_model, num_q_heads, num_kv_heads):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"
        # Without this, head_dim is floored and num_q_heads * head_dim != d_model.
        assert d_model % num_q_heads == 0, "d_model must be divisible by num_q_heads"

        self.d_model = d_model
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_q_heads
        self.group_size = num_q_heads // num_kv_heads

        # Fan-in scaled Gaussian init for every projection. Unscaled weights
        # push the logits far enough out that softmax saturates, and a
        # uniform [0, 1) output projection would be biased positive.
        in_scale = 1.0 / math.sqrt(d_model)
        out_scale = 1.0 / math.sqrt(num_q_heads * self.head_dim)
        self.q_proj = nn.Parameter(torch.randn(d_model, num_q_heads, self.head_dim) * in_scale)
        self.k_proj = nn.Parameter(torch.randn(d_model, num_kv_heads, self.head_dim) * in_scale)
        self.v_proj = nn.Parameter(torch.randn(d_model, num_kv_heads, self.head_dim) * in_scale)
        self.out_proj = nn.Parameter(torch.randn(num_q_heads, self.head_dim, d_model) * out_scale)

    def forward(self, x, causal_mask=True):
        """Apply grouped-query self-attention.

        Args:
            x: Input of shape ``[B, T, d_model]``.
            causal_mask: When true, position ``t`` attends only to positions
                ``<= t``; when false, attention is bidirectional.

        Returns:
            Tensor of shape ``[B, T, d_model]``.
        """
        seq_len = x.shape[1]
        head_dim = self.head_dim

        # Per-head projections.
        q = torch.einsum("BTD,DNH->BTNH", x, self.q_proj)  # [B,T,N,H]
        k = torch.einsum("BTD,DKH->BTKH", x, self.k_proj)  # [B,T,K,H]
        v = torch.einsum("BTD,DKH->BTKH", x, self.v_proj)  # [B,T,K,H]

        # Repeat each kv head so every query head has a matching key/value head.
        k = k.repeat_interleave(self.group_size, dim=2)    # [B,T,N,H]
        v = v.repeat_interleave(self.group_size, dim=2)    # [B,T,N,H]

        # Scaled dot-product scores; axis 2 (S) indexes key positions.
        attn_logits = torch.einsum("BTNH,BSNH->BTSN", q, k) / math.sqrt(head_dim)  # [B,T,S,N]

        if causal_mask:
            # True on/below the diagonal = allowed; broadcast over batch and heads.
            allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).tril()
            attn_logits = attn_logits.masked_fill(~allowed.view(1, seq_len, seq_len, 1), float("-inf"))

        attn_probs = torch.softmax(attn_logits, dim=2)             # [B,T,S,N]
        context = torch.einsum("BTSN,BSNH->BTNH", attn_probs, v)   # [B,T,N,H]

        # Merge heads and project back to model width.
        return torch.einsum("BTNH,NHD->BTD", context, self.out_proj)

In [51]:
import math
import torch
import torch.nn as nn

class GQAWithSlicedMask(nn.Module):
    def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int, max_seq_len: int, causal: bool = True):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0, "num_q_heads must be divisible by num_kv_heads"
        assert d_model % num_q_heads == 0, "d_model must be divisible by num_q_heads"

        self.d_model = d_model
        self.N = num_q_heads
        self.K = num_kv_heads
        self.H = d_model // num_q_heads
        self.G = self.N // self.K
        self.max_seq_len = max_seq_len
        self.causal = causal

        # Projections
        self.q_proj = nn.Parameter(torch.randn(d_model, self.N, self.H) / math.sqrt(d_model))
        self.k_proj = nn.Parameter(torch.randn(d_model, self.K, self.H) / math.sqrt(d_model))
        self.v_proj = nn.Parameter(torch.randn(d_model, self.K, self.H) / math.sqrt(d_model))
        self.out_proj = nn.Parameter(torch.randn(self.N, self.H, d_model) / math.sqrt(self.N * self.H))

        # Precompute full additive mask [1, Lmax, Lmax, 1]: 0 on/below diag, -inf above
        if causal:
            full = torch.full((max_seq_len, max_seq_len), float("-inf"), dtype=torch.float32)
            full = torch.tril(full, diagonal=0)  # keep lower triangle as -inf; fix below
            # Set allowed (on/below diag) to 0
            full[torch.tril(torch.ones_like(full, dtype=torch.bool))] = 0.0
        else:
            full = torch.zeros(max_seq_len, max_seq_len, dtype=torch.float32)
        self.register_buffer("additive_mask_full", full.view(1, max_seq_len, max_seq_len, 1), persistent=False)

        # Scale buffer
        self.register_buffer("scale", torch.tensor(1.0 / math.sqrt(self.H), dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, kv: torch.Tensor | None = None, padding_mask: torch.Tensor | None = None):
        """
        x:  [B, T, D]           - queries
        kv: [B, S, D] or None   - keys/values source (defaults to x for self-attn)
        padding_mask: optional boolean mask, shape:
            - self-attn: [B, T]  (True = keep, False = pad)
            - cross-attn: [B, S]
        Returns: [B, T, D]
        """
        B, T, D = x.shape
        assert D == self.d_model, f"d_model mismatch: got {D}, expected {self.d_model}"
        if kv is None:
            kv = x
        S = kv.shape[1]
        if T > self.max_seq_len or S > self.max_seq_len:
            raise ValueError(f"Sequence length exceeds max_seq_len ({self.max_seq_len}): T={T}, S={S}")

        # Projections
        Q = torch.einsum("BTD,DNH->BTNH", x, self.q_proj)   # [B,T,N,H]
        K = torch.einsum("BSD,DKH->BSKH", kv, self.k_proj)  # [B,S,K,H]
        V = torch.einsum("BSD,DKH->BSKH", kv, self.v_proj)  # [B,S,K,H]

        # Expand K/V heads to N (GQA)
        K = K.repeat_interleave(self.G, dim=2)  # [B,S,N,H]
        V = V.repeat_interleave(self.G, dim=2)  # [B,S,N,H]

        # Attention logits
        attn_logits = torch.einsum("BTNH,BSNH->BTSN", Q, K)  # [B,T,S,N]
        attn_logits = attn_logits * self.scale.to(attn_logits.dtype)

        # Add sliced causal mask (broadcasts to [B,T,S,N])
        if self.causal:
            mask_slice = self.additive_mask_full[:, :T, :S, :].to(attn_logits.dtype)
            attn_logits = attn_logits + mask_slice

        # Optional padding mask (True=keep, False=pad). Convert to additive.
        if padding_mask is not None:
            if padding_mask.shape[1] == S:   # key padding (common)
                keep = padding_mask.view(B, 1, S, 1)  # -> [B,1,S,1], broadcast to [B,T,S,N]
                add = torch.where(keep, 0.0, float("-inf"))
                attn_logits = attn_logits + add.to(attn_logits.dtype)
            elif padding_mask.shape[1] == T: # query padding (rare, but supported)
                keep = padding_mask.view(B, T, 1, 1)  # -> [B,T,1,1]
                add = torch.where(keep, 0.0, float("-inf"))
                attn_logits = attn_logits + add.to(attn_logits.dtype)
            else:
                raise ValueError("padding_mask second dim must match T or S")

        attn_probs = torch.softmax(attn_logits, dim=2)        # [B,T,S,N]
        context    = torch.einsum("BTSN,BSNH->BTNH", attn_probs, V)  # [B,T,N,H]
        out        = torch.einsum("BTNH,NHD->BTD", context, self.out_proj)   # [B,T,D]
        return out