In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

In [None]:
class Attention(nn.Module):
    """Scaled dot-product attention; holds no learnable parameters."""

    def forward(self, query, key, value, mask=None):
        """
        Compute ``softmax(query @ key^T / sqrt(d_k)) @ value``.

        Args:
            query, key, value: tensors whose last two dimensions are treated
                as (sequence, feature); ``d_k`` is the size of ``query``'s
                last dimension. Leading dimensions are presumably batch
                and/or head axes -- confirm against the caller.
            mask: optional tensor broadcastable to the score matrix.
                0 = attend, 1 = do not attend.

        Returns:
            The attention-weighted combination of ``value``.
        """
        scale = math.sqrt(query.shape[-1])
        # ``.mT`` swaps only the last two axes, so leading axes are preserved.
        logits = torch.matmul(query, key.mT) / scale

        if mask is not None:
            blocked = mask == 1
            # -inf logits receive exactly zero weight after the softmax.
            logits = logits.masked_fill(blocked, -torch.inf)

        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, value)


class MultiheadAttention(nn.Module):
    """Multi-head attention built on the scaled dot-product ``Attention`` module.

    Inputs are linearly projected, split into ``n_heads`` heads along the
    feature axis, attended per head, merged back and passed through an output
    projection.

    Args:
        dim: feature size of the inputs and outputs; must be divisible by
            ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, dim, n_heads):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"

        self.attention = Attention()
        self.dim = dim
        self.n_heads = n_heads
        # Bias-free projections for queries, keys, values and the output.
        self.wQ, self.wK, self.wV, self.wO = [
            nn.Linear(dim, dim, bias=False) for _ in range(4)
        ]

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q, k, v: tensors of shape ``(batch, seq, dim)``.
            mask: optional mask, 0 = attend, 1 = do not attend. Accepted
                shapes are ``(batch, t_q, t_k)`` (shared by all heads), or
                anything already broadcastable to
                ``(batch, n_heads, t_q, t_k)`` such as ``(t_q, t_k)``.

        Returns:
            Tensor of shape ``(batch, t_q, dim)``.
        """
        q, k, v = self.wQ(q), self.wK(k), self.wV(v)
        q, k, v = [self._split(x) for x in (q, k, v)]

        if mask is not None and mask.dim() == 3:
            # Per-head scores are (batch, heads, t_q, t_k). Insert a singleton
            # head axis so the same mask broadcasts over every head. Repeating
            # along dim 0 would give (batch * heads, t_q, t_k), which does not
            # line up with the 4-D scores when batch > 1.
            mask = mask.unsqueeze(1)

        y = self.attention(q, k, v, mask)
        y = self._combine(y)
        y = self.wO(y)
        return y

    def _split(self, x):
        """Reshape ``(b, t, dim)`` into per-head ``(b, n_heads, t, dim // n_heads)``."""
        return rearrange(x, 'b t (h dh) -> b h t dh', h=self.n_heads)

    def _combine(self, x):
        """Inverse of ``_split``: merge ``(b, h, t, d)`` back into ``(b, t, h * d)``."""
        return rearrange(x, 'b h t d -> b t (h d)')

    @staticmethod
    def gen_causal_mask(x):
        """Return a causal mask for the input ``x``.

        Args:
            x: tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, seq_len)`` on ``x``'s device,
            where entry ``[b, i, j]`` is 1 (do not attend) when ``j > i`` and
            0 (attend) otherwise, so position ``i`` sees only positions
            ``<= i``.
        """
        batch, seq_len, _ = x.shape
        # Strictly upper triangle = future positions = masked out. The diagonal
        # stays 0 so every query can attend to itself and no row is fully
        # masked (a fully masked row would produce NaNs in the softmax).
        future = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        )
        # expand() creates a broadcast view; no per-batch copy is made.
        return future.unsqueeze(0).expand(batch, -1, -1)


In [None]:
# Smoke test: self-attention over a random batch should preserve the input
# shape (batch=32, seq_len=10, dim=768).
mha = MultiheadAttention(768, 12)
x = torch.randn((32, 10, 768))
y = mha(q=x, k=x, v=x)  # self-attention: same tensor as query, key and value
print(y.size())

torch.Size([32, 10, 768])
