In [7]:
import numpy as np
import torch
from torch import nn
from einops import rearrange

In [111]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (no masking, no positional bias).

    Queries, keys and values are projected with separate linear layers into
    ``n_heads`` heads. Attention weights are ``softmax(q k^T / sqrt(d_k))`` taken
    over the key axis. The per-head outputs are concatenated, projected back to
    ``d_model`` by ``fc``, and dropout is applied to that final projection.

    Args:
        d_model: feature size of the inputs and of the output.
        n_heads: number of attention heads.
        d_k: per-head query/key size; defaults to ``d_model // n_heads``.
        d_v: per-head value size; defaults to ``d_model // n_heads``.
        dropout: dropout probability applied to the output projection.
    """

    def __init__(self, d_model, n_heads, d_k=None, d_v=None, dropout=0.1):
        super().__init__()
        self.d_model, self.n_heads = d_model, n_heads
        # Default each per-head size independently, so an explicitly passed
        # d_k or d_v is always honoured (and d_v is never left as None).
        if d_k is None:
            d_k = d_model // n_heads
        if d_v is None:
            d_v = d_model // n_heads
        self.d_k, self.d_v = d_k, d_v

        self.w_qs = nn.Linear(d_model, n_heads * d_k)
        self.w_ks = nn.Linear(d_model, n_heads * d_k)
        self.w_vs = nn.Linear(d_model, n_heads * d_v)

        self.fc = nn.Linear(n_heads * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key=None, value=None, mask='causal', alibi=False):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape ``(b, l, d_model)``.
            key: tensor of shape ``(b, t, d_model)``; defaults to ``query``
                (self-attention).
            value: tensor of shape ``(b, t, d_model)``; defaults to ``key``, so
                the values always line up with the keys that weight them.
            mask: NOTE(review): currently ignored. No mask (causal or
                otherwise) is applied, despite the ``'causal'`` default.
            alibi: NOTE(review): currently ignored. No ALiBi bias is added.

        Returns:
            ``(output, attn)``: ``output`` has shape ``(b, l, d_model)``, and
            ``attn`` holds the softmax weights with shape ``(b, n_heads, l, t)``.
            ``attn`` is taken before dropout, which is applied only to ``output``.
        """
        # TODO(review): implement `mask` / `alibi` or drop them from the signature.
        if key is None:
            key = query
        if value is None:
            # Values must correspond position-by-position to the keys.
            value = key
        q = rearrange(self.w_qs(query), 'b l (head q) -> b head l q', head=self.n_heads)
        k = rearrange(self.w_ks(key), 'b t (head k) -> b head t k', head=self.n_heads)
        v = rearrange(self.w_vs(value), 'b t (head v) -> b head t v', head=self.n_heads)
        # Standard scaled dot-product: divide logits by sqrt(per-head key size).
        attn = torch.einsum('bhlk,bhtk->bhlt', q, k) / np.sqrt(q.shape[-1])
        attn = torch.softmax(attn, dim=-1)  # normalise over the key axis t
        output = torch.einsum('bhlt,bhtv->bhlv', attn, v)
        output = rearrange(output, 'b head l v -> b l (head v)')
        output = self.dropout(self.fc(output))
        return output, attn
    
# Sanity check: without a mask or any positional information, self-attention is
# permutation-equivariant along the sequence axis. Permuting the inputs permutes
# the outputs in the same way: f(p(a)) == p(f(a)).
torch.manual_seed(0)
mha = MultiHeadAttention(384, 12).eval()  # eval() turns dropout off so both passes match
p = torch.randperm(100)  # random permutation of the 100 sequence positions
a = torch.randn(10, 100, 384)  # random inputs, (b, l, d_model) = (10, 100, 384)
# Pass mask=None explicitly: a causal mask would break equivariance, so the
# check must not rely on the 'causal' default happening to be a no-op.
o1, _ = mha(a, mask=None)
o1 = o1[:, p]  # p(f(a))
o2, _ = mha(a[:, p], mask=None)  # f(p(a))
torch.allclose(o1, o2, atol=1e-4)  # displayed as the cell output: True

True