In [1]:
import torch
import torch.nn as nn

import math

# from torch.nn.functional import scaled_dot_product_attention
from einops import rearrange

In [78]:
class MultiHeadAttention(nn.Module):
    def __init__(
            self,
            n_heads,
            input_dim,
            embed_dim,
            val_dim=None,
            key_dim=None
    ):
        super(MultiHeadAttention, self).__init__()

        if val_dim is None:
            val_dim = embed_dim // n_heads
        if key_dim is None:
            key_dim = val_dim

        self.n_heads = n_heads
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.val_dim = val_dim
        self.key_dim = key_dim

        self.norm_factor = 1 / math.sqrt(key_dim)  # See Attention is all you need

        self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))

        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))

        self.init_parameters()

    def init_parameters(self):

        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, q, h=None, mask=None):
        """

        :param q: queries (batch_size, n_query, input_dim)
        :param h: data (batch_size, graph_size, input_dim)
        :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
        Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
        :return:
        """
        if h is None:
            h = q  # compute self-attention

        # h should be (batch_size, graph_size, input_dim)
        batch_size, graph_size, input_dim = h.size()
        n_query = q.size(1)
        assert q.size(0) == batch_size
        assert q.size(2) == input_dim
        assert input_dim == self.input_dim, "Wrong embedding dimension of input"

        hflat = h.contiguous().view(-1, input_dim)
        qflat = q.contiguous().view(-1, input_dim)

        # last dimension can be different for keys and values
        shp = (self.n_heads, batch_size, graph_size, -1)
        shp_q = (self.n_heads, batch_size, n_query, -1)

        # Calculate queries, (n_heads, n_query, graph_size, key/val_size)
        Q = torch.matmul(qflat, self.W_query).view(shp_q)
        # Calculate keys and values (n_heads, batch_size, graph_size, key/val_size)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Calculate compatibility (n_heads, batch_size, n_query, graph_size)
        # compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3))
        compatibility = torch.matmul(Q, K.transpose(2, 3))

        # Optionally apply mask to prevent attention
        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
            compatibility[mask] = -np.inf

        attn = torch.softmax(compatibility, dim=-1)

        # If there are nodes with no neighbours then softmax returns nan so we fix them to 0
        if mask is not None:
            attnc = attn.clone()
            attnc[mask] = 0
            attn = attnc

        heads = torch.matmul(attn, V)

        out = torch.mm(
            heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim)
        ).view(batch_size, n_query, self.embed_dim)

        # Alternative:
        # headst = heads.transpose(0, 1)  # swap the dimensions for batch and heads to align it for the matmul
        # # proj_h = torch.einsum('bhni,hij->bhnj', headst, self.W_out)
        # projected_heads = torch.matmul(headst, self.W_out)
        # out = torch.sum(projected_heads, dim=1)  # sum across heads

        # Or:
        # out = torch.einsum('hbni,hij->bnj', heads, self.W_out)

        return out, heads

In [79]:
N_HEADS = 4
INPUT_DIM = EMBED_DIM = 32
BATCH = 3
N_NODES = 5

x = torch.randn(BATCH, N_NODES, EMBED_DIM)
original_attn = MultiHeadAttention(n_heads=N_HEADS, input_dim=INPUT_DIM, embed_dim=EMBED_DIM)
print(original_attn.norm_factor)

0.35355339059327373


In [80]:
x = torch.randn(BATCH, N_NODES, EMBED_DIM)
x = torch.ones(BATCH, N_NODES, EMBED_DIM)
x[:, 0, :] = 0.0

In [81]:
Wqkv = nn.Linear(EMBED_DIM, 3 * EMBED_DIM, bias=False)
Wqkv.weight.data = torch.ones_like(Wqkv.weight.data)

out_proj = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
out_proj.weight.data = torch.ones_like(out_proj.weight.data)

In [82]:
import torch.nn.functional as F

def scaled_dot_product_attention(
        Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
    ):
        print("Here")
        """Simple Scaled Dot-Product Attention in PyTorch without Flash Attention"""
        if scale is None:
            scale = Q.size(-1) ** -0.5  # scale factor

        print(f"Scale: {scale}")
        # compute the attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1))
        attn_probs = F.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

In [83]:
q, k, v = rearrange(Wqkv(x), "b s (three h d) -> three b s h d", three=3, h=N_HEADS).unbind(dim=0)
o = scaled_dot_product_attention(q, k, v)
h = out_proj(rearrange(o, "b s h d -> b s (h d)"))

Here
Scale: 0.3535533905932738


In [84]:
original_attn.norm_factor

0.35355339059327373

In [85]:
original_attn = MultiHeadAttention(n_heads=N_HEADS, input_dim=INPUT_DIM, embed_dim=EMBED_DIM)
original_attn.W_query.data = torch.ones_like(original_attn.W_query.data)
original_attn.W_key.data = torch.ones_like(original_attn.W_key.data)
original_attn.W_val.data = torch.ones_like(original_attn.W_val.data)
original_attn.W_out.data = torch.ones_like(original_attn.W_out.data)

h2, o2 = original_attn(x)

torch.Size([4, 3, 5, 5]) torch.Size([4, 3, 5, 8])


In [86]:
o.shape

torch.Size([3, 5, 4, 8])

In [87]:
o2.shape

torch.Size([4, 3, 5, 8])

In [88]:
(o.permute(2,0,1,3) - o2).abs().max()

tensor(25.6000, grad_fn=<MaxBackward1>)

In [92]:
b, n = 0,0

h[b, n] - h2[b, n]

tensor([-819.1998, -819.1998, -819.1998, -819.1998, -819.1998, -819.1998,
        -819.1998, -819.1998, -819.1998, -819.1998, -819.1998, -819.1998,
        -819.1998, -819.1998, -819.1998, -819.1998, -819.1998, -819.1998,
        -819.1998, -819.1998, -819.1998, -819.1998, -819.1998, -819.1998,
        -819.1998, -819.1998, -819.1998, -819.1998, -819.1998, -819.1998,
        -819.1998, -819.1998], grad_fn=<SubBackward0>)

In [96]:
pytorch_attn = nn.MultiheadAttention(embed_dim=EMBED_DIM, num_heads=N_HEADS, 
                                     bias=False, batch_first=True)
pytorch_attn.in_proj_weight.data = torch.ones_like(pytorch_attn.in_proj_weight.data)

In [100]:
q.shape

torch.Size([3, 5, 4, 8])

In [103]:
h3, o3 = pytorch_attn(q.view(BATCH, N_NODES, -1), k.view(BATCH, N_NODES, -1), v.view(BATCH, N_NODES, -1))

In [115]:
o3[0].shape

torch.Size([5, 5])

In [121]:
h3[0]

tensor([[1155.6492, -444.0484, -230.0319,  -96.4411,  376.5367,  498.6359,
         -151.0114, -218.2406, -385.9967,  161.0141, -766.5836, -272.5924,
         -295.6412,  847.5152, -744.7882, -551.1451,  257.4138,  183.6791,
         -187.7901,  262.6219, -382.0503, 1638.3092,  -46.4304, -340.5738,
          143.1985, -323.4807, -705.9189, 1044.2151,   28.0299,  252.7086,
          769.5617, -720.9847],
        [1444.5614, -555.0605, -287.5400, -120.5514,  470.6707,  623.2948,
         -188.7643, -272.8008, -482.4958,  201.2676, -958.2294, -340.7405,
         -369.5516, 1059.3940, -930.9854, -688.9313,  321.7672,  229.5989,
         -234.7376,  328.2773, -477.5630, 2047.8865,  -58.0379, -425.7173,
          178.9982, -404.3507, -882.3987, 1305.2687,   35.0373,  315.8858,
          961.9521, -901.2307],
        [1444.5614, -555.0605, -287.5400, -120.5514,  470.6707,  623.2948,
         -188.7643, -272.8008, -482.4958,  201.2676, -958.2294, -340.7405,
         -369.5516, 1059.3940, -930.

In [122]:
h2[0]

tensor([[ 819.1998,  819.1998,  819.1998,  819.1998,  819.1998,  819.1998,
          819.1998,  819.1998,  819.1998,  819.1998,  819.1998,  819.1998,
          819.1998,  819.1998,  819.1998,  819.1998,  819.1998,  819.1998,
          819.1998,  819.1998,  819.1998,  819.1998,  819.1998,  819.1998,
          819.1998,  819.1998,  819.1998,  819.1998,  819.1998,  819.1998,
          819.1998,  819.1998],
        [1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000,
         1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000,
         1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000,
         1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000,
         1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000,
         1024.0000, 1024.0000],
        [1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000,
         1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000, 1024.0000,
         1024.0000, 1024.0000, 1024.