In [None]:
import torch
from torch.nn.modules.activation import MultiheadAttention

In [None]:
# Turn autograd off for the whole notebook: nothing here is trained, and the
# perturbation helpers write into tensors in place. The function form is used
# instead of entering a `no_grad()` context manager that would never be exited.
torch.set_grad_enabled(False);

In [None]:
EMBED_DIM = 4  # embedding width used for the attention layer and for q/k/v
T = 3  # sequence length (number of positions) of the random inputs
SEED = 0  # fixed seed so a fresh "Restart & Run All" prints the same numbers

# Seed before the attention layer is built and before any q/k/v are drawn, so
# both the layer's initial weights and the random inputs are reproducible.
torch.manual_seed(SEED);

In [None]:
# Single-head attention layer shared by every experiment below.
# batch_first=True: inputs are passed as (batch, sequence, embedding).
mha = MultiheadAttention(
    embed_dim=EMBED_DIM,
    num_heads=1,
    batch_first=True,
)

In [None]:
def f(perturb, attn):
    """Measure how an in-place perturbation of the inputs changes ``attn``'s output.

    Fresh random query/key/value tensors of shape ``(1, T, EMBED_DIM)`` are
    drawn, ``attn`` is evaluated on them, ``perturb`` is given the chance to
    modify them in place, and ``attn`` is evaluated a second time on the same
    (now possibly modified) tensors.

    Args:
        perturb: callable invoked as ``perturb(q, k, v)``; its return value is
            ignored, so it only has an effect by mutating its arguments.
        attn: callable invoked as ``attn(q, k, v)`` that returns a tensor.

    Returns:
        For batch element 0, the squared difference between the two outputs
        averaged over dim 1 of that slice, i.e. one number per position.
    """
    # Drawn in q, k, v order so the random stream is consumed exactly as before.
    query, key, value = (torch.randn(1, T, EMBED_DIM) for _ in range(3))

    before: torch.Tensor = attn(query, key, value)
    perturb(query, key, value)
    after: torch.Tensor = attn(query, key, value)

    delta = (after - before)[0, :, :]
    return delta.square().mean(dim=1)

In [None]:
def causalMask():
    """Build the ``(T, T)`` additive float mask for causal attention.

    Entries on and below the diagonal are ``0``; entries above the diagonal
    are ``-inf``.
    """
    minus_inf = torch.full((T, T), float("-inf"))
    # triu(diagonal=1) keeps the strictly-upper triangle and zeroes the rest.
    return minus_inf.triu(diagonal=1)

def attnNoMask(q, k, v):
    """Run the shared ``mha`` layer on ``(q, k, v)`` without any mask.

    Attention weights are not requested (``need_weights=False``); only the
    attention output, the first element of the layer's result, is returned.
    """
    # Call the module itself rather than ``.forward`` so that any hooks
    # registered on the layer are executed.
    output, _ = mha(q, k, v, need_weights=False)
    return output

def attnCausalMask(q, k, v):
    """Run the shared ``mha`` layer on ``(q, k, v)`` with the causal mask.

    The mask from ``causalMask()`` is passed as ``attn_mask``. Attention
    weights are not requested; only the attention output is returned.
    """
    # Call the module itself rather than ``.forward`` so that any hooks
    # registered on the layer are executed.
    output, _ = mha(q, k, v, attn_mask=causalMask(), need_weights=False)
    return output

def attnKeyPaddingMask(q, k, v):
    """Run the shared ``mha`` layer on ``(q, k, v)`` with the last key masked out.

    A float ``key_padding_mask`` of shape ``(1, T)`` is used: ``0`` for every
    key except the last one, which gets ``-inf``. Attention weights are not
    requested; only the attention output is returned.
    """
    # Build the mask from T instead of hard-coding three entries, so it keeps
    # the right length if the sequence length is changed.
    padding = torch.zeros(1, T)
    padding[0, -1] = float("-inf")
    # Call the module itself rather than ``.forward`` so that any hooks
    # registered on the layer are executed.
    output, _ = mha(q, k, v, key_padding_mask=padding, need_weights=False)
    return output


In [None]:
def nop(q, k, v):
    """Perturbation that leaves ``q``, ``k`` and ``v`` untouched (baseline)."""
    return None

# Baseline: nothing is perturbed, so each attention variant is expected to
# report zero change at every position.
for attn_fn in (attnNoMask, attnCausalMask, attnKeyPaddingMask):
    print(f(nop, attn_fn))

In [None]:
def p0(q, k, v):
    """Zero the value vector at position 0 of batch element 0, in place."""
    v[0, 0, :].zero_()

def p1(q, k, v):
    """Zero the value vector at position 1 of batch element 0, in place."""
    v[0, 1, :].zero_()

def p2(q, k, v):
    """Zero the value vector at position 2 of batch element 0, in place."""
    v[0, 2, :].zero_()

# Reference run: zero value position 0 with no mask at all.
print(f(p0, attnNoMask))

# Then zero each value position in turn with the key-padding mask applied.
# The last key is masked, so zeroing position 2 is expected to change nothing.
for perturb_fn in (p0, p1, p2):
    print(f(perturb_fn, attnKeyPaddingMask))

In [None]:
# Zero each value position in turn under the causal mask. A change at
# position i is expected to show up only at output positions >= i.
for perturb_fn in (p0, p1, p2):
    print(f(perturb_fn, attnCausalMask))