In [None]:
import torch
import torchvision
import torch.nn as nn
import numpy as np

In [None]:
class SelfAttention:
  """Single-head scaled dot-product self-attention over explicit weight tensors.

  X     : input, b * T * d
  W_qkv : stacked projection weights, (3*d) * d. The three d * d row blocks
          are used, in order, as the query, key and value projections.
  mask  : additive mask, added to the scaled scores before the softmax; it
          must broadcast against b * T * T (e.g. a T * T causal mask).
  W_out : output projection with d rows, applied as `... @ W_out`.
  """

  def __init__(self,X,W_qkv,mask,W_out):
    self.X = X
    self.W_qkv = W_qkv
    self.mask = mask
    self.W_out = W_out

  def provideAttention(self):
    """Compute the attention output and the attention weights.

    Returns a tuple `(out, atten)`:
      out   : softmax(Q K^T / sqrt(d) + mask) @ V @ W_out, b * T * (columns of W_out)
      atten : the softmax weights, b * T * T; each row sums to 1 over the
              last axis.
    """
    # Take the width from the stored input, not from a notebook-level `X`,
    # so the object does not depend on global state.
    d = self.X.shape[-1]
    W_q, W_k, W_v = torch.split(self.W_qkv, [d, d, d], dim=0)  # each d * d

    Q = torch.matmul(self.X, W_q.T)  # b * T * d
    K = torch.matmul(self.X, W_k.T)  # b * T * d
    V = torch.matmul(self.X, W_v.T)  # b * T * d

    # Rows index the query position, columns index the key position.
    scores = torch.matmul(Q, torch.transpose(K, -2, -1)) / np.sqrt(d)  # b * T * T
    scores = scores + self.mask

    # Normalise over the key axis.
    atten = nn.Softmax(dim=-1)(scores)  # b * T * T
    out = atten @ V @ self.W_out
    return out, atten

In [None]:
# Seed the default generator first: both the reference layer's initial weights
# and the random input below are drawn from it, so this makes the cell
# reproducible on every run.
torch.manual_seed(0)

# Batch size, sequence length, embedding width.
b, T, d = 50, 100 , 64

# Reference layer: 4 heads, no bias terms, inputs laid out as b * T * d.
torch_attn = nn.MultiheadAttention(d,4,bias=False,batch_first=True)

# Additive causal mask: -inf strictly above the diagonal, 0 elsewhere.
Mask = torch.triu(-float("inf")*torch.ones(T,T),1)

X = torch.rand(b,T,d)

# Reference output and attention weights for self-attention on X.
Y_ , A_ = torch_attn(X,X,X,attn_mask=Mask)

In [None]:
# Run the hand-written single-head attention on the reference layer's own
# weights. The output projection is passed transposed because
# provideAttention right-multiplies by W_out.
# NOTE(review): torch_attn is built with 4 heads while SelfAttention uses a
# single head, so the two outputs are presumably only comparable when the
# reference layer has one head -- confirm before relying on the difference.
attn = SelfAttention(
    X=X,
    W_qkv=torch_attn.in_proj_weight,
    mask=Mask,
    W_out=torch_attn.out_proj.weight.T,
)
Y, A = attn.provideAttention()

torch.Size([192, 64])
torch.Size([64, 64])
torch.Size([50, 100, 100])


In [None]:
# Norm of the difference between the hand-written output and the reference output.
(Y - Y_).norm()

tensor(0., grad_fn=<LinalgVectorNormBackward0>)

In [None]:
class MultiheadAttention:
  """Multi-head scaled dot-product self-attention over explicit weight tensors.

  X     : input, b * T * d
  h     : number of heads; must divide d
  W_qkv : stacked projection weights, (3*d) * d. Stored transposed, as
          d * (3*d), so the input can be right-multiplied by it; the three
          d-wide column blocks of that product are used, in order, as
          queries, keys and values.
  mask  : additive mask, added to the scaled scores before the softmax; it
          must broadcast against b * h * T * T (e.g. a T * T causal mask).
  W_out : output projection with d rows, applied as `... @ W_out`.
  """

  def __init__(self,X,h,W_qkv,mask,W_out):
    self.X = X
    self.h = h
    self.W_qkv = W_qkv.T
    self.mask = mask
    self.W_out = W_out

  def provideAttention(self):
    """Compute the multi-head attention output and the per-head weights.

    Returns a tuple `(out, atten)`:
      out   : concatenated head outputs times W_out, b * T * (columns of W_out)
      atten : per-head softmax weights, b * h * T * T (not averaged over
              heads).

    Raises ValueError if d is not divisible by h.
    """
    # Take the shape from the stored input, not from a notebook-level `X`,
    # so the object does not depend on global state.
    b, T, d = self.X.shape
    h = self.h
    if d % h != 0:
      raise ValueError(f"embedding width {d} is not divisible by the number of heads {h}")
    head_dim = d // h

    projected = self.X @ self.W_qkv  # b * T * (3*d)

    # Each d-wide block: b * T * d -> b * T * h * head_dim -> b * h * T * head_dim
    Q, K, V = (
        torch.swapaxes(torch.reshape(part, (b, T, h, head_dim)), 1, 2)
        for part in torch.split(projected, [d, d, d], dim=-1)
    )

    # Rows index the query position, columns index the key position.
    scores = torch.matmul(Q, torch.swapaxes(K, -1, -2)) / np.sqrt(head_dim) + self.mask
    atten = nn.Softmax(dim=-1)(scores)  # b * h * T * T

    # b * h * T * head_dim -> b * T * h * head_dim -> b * T * d (heads concatenated)
    merged = torch.reshape(torch.swapaxes(atten @ V, 1, 2), (b, T, d))
    return merged @ self.W_out, atten

In [None]:
# Run the hand-written multi-head attention with the same head count (4) and
# the same weights as the reference layer. in_proj_weight is passed as-is
# (the constructor transposes it); the output projection is passed transposed
# because provideAttention right-multiplies by W_out.
attn = MultiheadAttention(
    X=X,
    h=4,
    W_qkv=torch_attn.in_proj_weight,
    mask=Mask,
    W_out=torch_attn.out_proj.weight.T,
)
Y, A = attn.provideAttention()

<class 'torch.Tensor'>


In [None]:
# Norm of the difference between the hand-written multi-head output and the reference output.
(Y - Y_).norm()

tensor(0., grad_fn=<LinalgVectorNormBackward0>)