In [17]:
import torch
import math


def self_attention(q, k, v):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``.

    Args:
        q: Query tensor of shape ``(..., t, d)``.
        k: Key tensor of shape ``(..., s, d)``.
        v: Value tensor of shape ``(..., s, dv)``.

    Returns:
        Tensor of shape ``(..., t, dv)``. Each query position receives a
        softmax-weighted average of the value rows.
    """
    # Pairwise query/key dot products -> (..., t, s). The ellipsis lets the
    # same code serve (b, t, d) inputs as well as inputs with extra leading
    # dimensions such as (b, h, t, d).
    scores = torch.einsum("...td,...sd->...ts", q, k)
    # Scale by a plain Python float: this avoids allocating a tensor on every
    # call and does not rely on torch.sqrt accepting an integer tensor.
    scaled_scores = scores / math.sqrt(q.shape[-1])
    # Normalise over the key axis so each query's weights sum to 1.
    weights = torch.nn.functional.softmax(scaled_scores, dim=-1)
    return torch.matmul(weights, v)


# Smoke test: attention over a random (2, 3, 4) tensor should keep that shape.
# The attention is computed once; only the cell's last expression is displayed.
x = torch.rand(2, 3, 4)
self_attention(x, x, x).shape

torch.Size([2, 3, 4])

In [41]:
from torch import nn

class MHSA(nn.Module):
    """Multi-head attention built on top of ``self_attention``.

    Queries, keys and values are linearly projected, split into ``h`` heads
    of width ``d // h``, attended per head, merged back to width ``d`` and
    passed through an output projection.

    Args:
        d: Model (embedding) width; must be divisible by ``h``.
        h: Number of attention heads.
    """

    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0, f"d ({d}) must be divisible by h ({h})"
        self.d = d
        self.dh = d // h  # width of a single head
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)

    def _split_heads(self, x):
        """Reshape ``(b, n, d)`` into ``(b * h, n, dh)``, one batch entry per head."""
        b, n, _ = x.size()
        # (b, n, d) -> (b, n, h, dh) -> (b, h, n, dh) -> (b * h, n, dh)
        return x.view(b, n, self.h, self.dh).transpose(1, 2).reshape(b * self.h, n, self.dh)

    def forward(self, q, k, v):
        """Apply multi-head attention.

        Args:
            q: Query tensor of shape ``(b, t, d)``.
            k: Key tensor of shape ``(b, s, d)``.
            v: Value tensor of shape ``(b, s, d)``.

        Returns:
            Tensor of shape ``(b, t, d)``.
        """
        b, t, _ = q.size()
        # Each projection is split using its own sequence length, so the
        # key/value length may differ from the query length.
        qh = self._split_heads(self.wq(q))
        kh = self._split_heads(self.wk(k))
        vh = self._split_heads(self.wv(v))
        attn = self_attention(qh, kh, vh)  # (b * h, t, dh)
        # Undo the head split: (b * h, t, dh) -> (b, h, t, dh) -> (b, t, d)
        attn = attn.view(b, self.h, t, self.dh).transpose(1, 2).reshape(b, t, self.d)
        return self.wo(attn)

# Smoke-run the module at its default sizes on a random (2, 3, 512) batch
# and display the shape of the result.
mhsa = MHSA(d=512, h=8)
x = torch.rand(2, 3, 512)
mhsa(x, x, x).size()

In [None]:
class Encoder(nn.Module):
    """Encoder scaffold that currently only owns a multi-head attention layer.

    Args:
        d: Width passed to ``MHSA`` as ``d``.
        h: Head count passed to ``MHSA`` as ``h``.
    """

    def __init__(self, d, h):
        super().__init__()
        self.mhsa = MHSA(d=d, h=h)
        

In [2]:
# Small (2, 3, 2) integer tensor holding 0..11 in row-major order.
a = torch.arange(12).view(2, 3, 2)
a

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]]])

In [9]:
# Move the last axis to the front: shape (2, 3, 2) -> (2, 2, 3).
a.permute((2, 0, 1))

tensor([[[ 0,  2,  4],
         [ 6,  8, 10]],

        [[ 1,  3,  5],
         [ 7,  9, 11]]])