In [1]:
import torch.nn as nn
import torch
import math
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    """Multi-head self-attention with bias-free Q/K/V and output projections.

    Args:
        model_d: model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: number of heads; each head operates on ``model_d // num_heads``
            features.
    """

    def __init__(self, model_d, num_heads) -> None:
        super().__init__()
        self.num_heads = num_heads
        assert model_d % num_heads == 0, (
            f"model_d ({model_d}) must be divisible by num_heads ({num_heads})"
        )
        self.head_d = model_d // num_heads

        self.wq = nn.Linear(model_d, model_d, bias=False)
        self.wk = nn.Linear(model_d, model_d, bias=False)
        self.wv = nn.Linear(model_d, model_d, bias=False)

        self.wo = nn.Linear(model_d, model_d, bias=False)

        # Scaled dot-product attention: logits are divided by sqrt(head dimension).
        self.scale = math.sqrt(self.head_d)

    def forward(self, x, attn_mask=None):
        """Apply self-attention over the sequence dimension.

        Args:
            x: (B, S, model_d) input tensor.
            attn_mask: optional boolean key-padding mask of shape (B, S). ``True``
                marks a key position that must receive zero attention from every
                query and head (the same convention as ``key_padding_mask`` in
                ``torch.nn.functional.multi_head_attention_forward``). ``None``
                disables masking.

        Returns:
            (B, S, model_d) tensor.

        Note:
            If every key of a batch row is masked, the softmax for that row is taken
            over all ``-inf`` logits and its output is NaN.
        """
        B, S, _ = x.shape
        h = self.num_heads
        # Project and split into heads: (B, S, model_d) -> (B, h, S, head_d).
        q = self.wq(x).reshape(B, S, h, self.head_d).transpose(1, 2)
        k = self.wk(x).reshape(B, S, h, self.head_d).transpose(1, 2)
        v = self.wv(x).reshape(B, S, h, self.head_d).transpose(1, 2)

        logits = torch.matmul(q, k.transpose(-1, -2)) / self.scale  # (B, h, S, S)
        if attn_mask is not None:
            # Broadcast the per-key mask over heads and query positions.
            logits = logits.masked_fill(attn_mask.reshape(B, 1, 1, S), -float("inf"))
        attention = F.softmax(logits, dim=-1)  # (B, h, S, S)
        attended = torch.matmul(attention, v)  # (B, h, S, head_d)
        # Merge heads back: (B, h, S, head_d) -> (B, S, model_d).
        outputs = attended.transpose(1, 2).reshape(B, S, -1)
        return self.wo(outputs)
        

In [4]:
# Version-robust comparison against torch.nn.functional.multi_head_attention_forward
import math
import inspect
import torch
import torch.nn.functional as F

def manual_attn_probs_with_module_weights(x_bsd, key_padding_mask, mha):
    """Recompute attention probabilities directly from ``mha``'s projection weights.

    This bypasses ``mha.forward`` and rebuilds softmax(Q K^T / sqrt(head_d)) with
    explicit matrix products, so it can serve as an independent reference when
    checking masking.

    Args:
        x_bsd: (B, S, D) input tensor.
        key_padding_mask: optional (B, S) boolean tensor; ``True`` marks keys that
            must receive zero probability. ``None`` disables masking.
        mha: object exposing ``num_heads``, ``head_d`` and bias-free ``wq``/``wk``
            layers whose ``.weight`` is (D, D), e.g. the ``MultiheadAttention``
            module in this notebook.

    Returns:
        (B, num_heads, S, S) tensor of attention probabilities. Each row sums to 1,
        unless every key in that row is masked, in which case the row is NaN.
    """
    B, S, _ = x_bsd.shape
    h, d = mha.num_heads, mha.head_d

    def split_heads(t):
        # (B, S, D) -> (B, h, S, d)
        return t.reshape(B, S, h, d).transpose(1, 2)

    # Only Q and K determine the probabilities; V is not needed here.
    q = split_heads(x_bsd @ mha.wq.weight.t())
    k = split_heads(x_bsd @ mha.wk.weight.t())

    logits = (q @ k.transpose(-1, -2)) / math.sqrt(d)  # (B, h, S, S)
    if key_padding_mask is not None:
        # Broadcast the per-key mask over heads and query positions.
        logits = logits.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
    return logits.softmax(dim=-1)

def functional_call_compat(x_bsd, key_padding_mask, mha):
    """Run ``mha``'s weights through ``F.multi_head_attention_forward``.

    The functional API takes sequence-first inputs, so ``x_bsd`` is transposed to
    (S, B, D) and the result is transposed back to (B, S, D).

    Every argument is collected as a keyword and then filtered against the
    installed PyTorch's signature. Parameters a given version lacks (e.g.
    ``is_causal``, ``static_k``) are therefore dropped rather than causing a
    ``TypeError``.

    If the signature has ``use_separate_proj_weight``, the Q/K/V weights are
    passed separately and ``in_proj_weight`` stays ``None``. Otherwise they are
    stacked row-wise (Q, K, V) into a single (3D, D) ``in_proj_weight``.

    Args:
        x_bsd: (B, S, D) input tensor, used as query, key and value.
        key_padding_mask: (B, S) boolean mask (``True`` = ignore that key) or ``None``.
        mha: object exposing ``num_heads`` and bias-free ``wq``/``wk``/``wv``/``wo``
            layers whose ``.weight`` is (D, D).

    Returns:
        (B, S, D) attention output.
    """
    D = x_bsd.shape[-1]
    x_sbd = x_bsd.transpose(0, 1)  # (S, B, D)
    params = set(inspect.signature(F.multi_head_attention_forward).parameters)

    kwargs = {
        "query": x_sbd,
        "key": x_sbd,
        "value": x_sbd,
        "embed_dim_to_check": D,
        "num_heads": mha.num_heads,
        "in_proj_weight": None,  # always passed: older signatures require it
        "in_proj_bias": None,
        "bias_k": None,
        "bias_v": None,
        "add_zero_attn": False,
        "dropout_p": 0.0,
        "out_proj_weight": mha.wo.weight,
        "out_proj_bias": None,
        "training": False,
        "key_padding_mask": key_padding_mask,
        "need_weights": False,
        "attn_mask": None,
        "is_causal": False,
    }

    if "use_separate_proj_weight" in params:
        kwargs.update(
            use_separate_proj_weight=True,
            q_proj_weight=mha.wq.weight,
            k_proj_weight=mha.wk.weight,
            v_proj_weight=mha.wv.weight,
            static_k=None,
            static_v=None,
        )
    else:
        # Older signature: one stacked (3D, D) projection, rows ordered Q, K, V.
        kwargs["in_proj_weight"] = torch.cat(
            [mha.wq.weight, mha.wk.weight, mha.wv.weight], dim=0
        )

    # Keep only the arguments this PyTorch version accepts.
    supported = {name: value for name, value in kwargs.items() if name in params}
    y_sbd, *_ = F.multi_head_attention_forward(**supported)
    return y_sbd.transpose(0, 1)  # back to (B, S, D)

def compare_with_functional(B=2, S=5, D=32, H=4, use_mask=True, tol=(1e-6, 1e-5), seed=123):
    """Check ``MultiheadAttention`` against ``F.multi_head_attention_forward``.

    Builds a random input and a key-padding mask, then:
      1. asserts the module output matches the functional output (computed with
         the module's own weights) within ``tol``;
      2. asserts that, in probabilities recomputed from the module's weights,
         every masked key receives (numerically) zero attention from every head
         and every query.

    Args:
        B, S, D, H: batch size, sequence length, model dimension, number of heads.
        use_mask: if True, mask the last key of every sequence (when S >= 3) and
            key 1 of batch element 0 (when S >= 5 and B >= 2).
        tol: (atol, rtol) passed to ``torch.allclose`` for the output comparison.
        seed: RNG seed for the input and the module's weight initialisation.

    Raises:
        AssertionError: if the outputs differ or a masked key receives attention.
    """
    torch.manual_seed(seed)

    x = torch.randn(B, S, D)
    key_padding_mask = torch.zeros(B, S, dtype=torch.bool)
    if use_mask and S >= 3:
        key_padding_mask[:, -1] = True
    if use_mask and S >= 5 and B >= 2:
        key_padding_mask[0, 1] = True

    mha = MultiheadAttention(D, H)

    # Module output vs. functional output sharing the same weights.
    y_mod = mha(x, key_padding_mask)
    y_fun = functional_call_compat(x, key_padding_mask, mha)

    atol, rtol = tol
    max_diff = (y_mod - y_fun).abs().max().item()
    print(f"Output max|diff| = {max_diff:.3e}")
    assert torch.allclose(y_mod, y_fun, atol=atol, rtol=rtol), "Outputs differ!"

    # Check masking via manual probabilities (version-agnostic).
    probs = manual_attn_probs_with_module_weights(x, key_padding_mask, mha)  # (B,H,S,S)
    if key_padding_mask.any():
        # Inspect every (head, query) probability on masked key columns instead of
        # an average over heads/queries, which could hide a single leaking entry.
        masked_probs = probs.masked_select(key_padding_mask[:, None, None, :])
        max_masked = masked_probs.max().item()
        print(f"Max prob on masked keys (manual) = {max_masked:.3e}")
        assert max_masked < 1e-6, "Masked keys still receive attention!"

    print("✓ Functional parity and masking look good.\n")

# Try a few configs
# Parity/masking checks over a few shapes: multi-head with a mask,
# single-head without a mask, and a three-head model with a mask.
for config in (
    dict(B=2, S=5, D=32, H=4, use_mask=True),
    dict(B=1, S=7, D=16, H=1, use_mask=False),
    dict(B=3, S=4, D=24, H=3, use_mask=True),
):
    compare_with_functional(**config)


Output max|diff| = 8.941e-08
Max prob on masked keys (manual) = 0.000e+00
✓ Functional parity and masking look good.

Output max|diff| = 5.960e-08
✓ Functional parity and masking look good.

Output max|diff| = 6.706e-08
Max prob on masked keys (manual) = 0.000e+00
✓ Functional parity and masking look good.

