In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [5]:
class LightAttention(nn.Module):
    def __init__(self, dim_input_q, dim_input_kv, dim_q, dim_k, device="cpu", with_mask=False):
        super().__init__()
        self.device = device
        self.with_mask = with_mask
        self.softmax_col = nn.Softmax(dim=-1)
        self.softmax_row = nn.Softmax(dim=-2)
        self.W_q = nn.Linear(in_features=dim_input_q, out_features=dim_q)
        self.W_k = nn.Linear(in_features=dim_input_kv, out_features=dim_k)
        self.W_v = nn.Linear(in_features=dim_input_kv, out_features=dim_k)
        self.d_q = torch.pow(torch.Tensor([dim_q]).to(device), 1/4)
        self.d_k = torch.pow(torch.Tensor([dim_k]).to(device), 1/4)

    def mask(self, dim: (int, int)) -> Tensor :
        a, b = dim
        mask = torch.ones(b, a)
        mask = torch.triu(mask, diagonal=0)
        mask = torch.log(mask.T)
        return mask.to(self.device)
        
    def forward(self, x_q, x_k, x_v):
        Q = self.W_q(x_q)
        K = self.W_k(x_k)
        V = self.W_v(x_v)
        if self.with_mask == True:
            Q += self.mask(Q.shape[-2:])
        A = self.softmax_row(Q / self.d_q)
        B = torch.matmul(self.softmax_col(K.transpose(-2, -1) / self.d_k), V)
        Z = torch.matmul(A, B)
        return Z


class MHLA(nn.Module):
    def __init__(self, 
                 num_heads, 
                 dim_input_q,
                 dim_input_kv,
                 dim_q = 64,
                 dim_k = 64,
                 device="cpu",
                 mask=False
                ):
        """
        Args:
        dim_input - if shape is (B, C, H, W), then dim_input is W
        """
        #TODO: Make parallel heads like channel dim
        super().__init__()
        heads = [LightAttention(dim_input_q, dim_input_kv, dim_q, dim_k, device, mask) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)                
        self.W_o = nn.Linear(dim_k*num_heads, dim_input_kv)
        
    def forward(self, x_q, x_k, x_v):
        x = torch.cat([latt(x_q, x_k, x_v) for latt in self.heads], dim=-1)
        y = self.W_o(x)
        return y

In [9]:
#x_q.shape=torch.Size([2, 1, 64, 256])
X = torch.randn(2, 1, 64, 256)

mhla = MHLA(6, 256, 256)
y = mhla(X, X, X)
print(f"{y.shape=}")
print(f"{y.shape==X.shape}")

y.shape=torch.Size([2, 1, 64, 256])
True


In [68]:
class MHLA2(nn.Module):
    def __init__(self, 
                 num_heads, 
                 dim_last_input_q,
                 dim_last_input_kv,
                 dim_q = 16,
                 dim_k = 16,
                 device="cpu",
                 mask=False
                ):
        """
        Args:
        
        """
        #TODO: Make parallel heads like channel dim
        super().__init__()
        # self.W_Q = nn.Linear(dim_last_input_q, dim_q, bias=False)
        # self.W_K = nn.Linear(dim_last_input_kv, dim_k, bias=False)
        # self.W_V = nn.Linear(dim_last_input_kv, dim_k, bias=False)      
        self.with_mask = mask
        self.W_Q = torch.ones((num_heads, dim_last_input_q, dim_q), device=device, requires_grad=True)
        self.W_K = torch.ones((num_heads, dim_last_input_q, dim_q), device=device, requires_grad=True)
        self.W_V = torch.ones((num_heads, dim_last_input_q, dim_q), device=device, requires_grad=True)
        self.W_O = nn.Linear(dim_k*num_heads, dim_k*num_heads, bias=False)
        nn.init.xavier_uniform_(self.W_Q)
        nn.init.xavier_uniform_(self.W_K)
        nn.init.xavier_uniform_(self.W_V)
        self.d_q = torch.pow(torch.Tensor([dim_q]).to(device), 1/4)
        self.d_k = torch.pow(torch.Tensor([dim_k]).to(device), 1/4)
        self.softmax_col = nn.Softmax(dim=-1)
        self.softmax_row = nn.Softmax(dim=-2)
        
    def mask(self, dim: (int, int)) -> Tensor :
        a, b = dim
        mask = torch.ones(b, a)
        mask = torch.triu(mask, diagonal=0)
        mask = torch.log(mask.T)
        return mask.to(self.device)
        
    def forward(self, x_q, x_k, x_v):
        Q = torch.matmul(x_q.transpose(-1, -2).contiguous(), self.W_Q)
        K = torch.matmul(x_k.transpose(-1, -2).contiguous(), self.W_K)
        V = torch.matmul(x_v.transpose(-1, -2).contiguous(), self.W_V)
        if self.with_mask == True:
            Q += self.mask(Q.shape[-2:])
        print(f"{Q.shape=}")
        print(f"{K.shape=}")
        print(f"{V.shape=}")
        A = torch.matmul(self.softmax_col(K.transpose(-1, -2).contiguous() / self.d_k), V)
        B = torch.matmul(self.softmax_row(Q / self.d_q), A)
        #A = torch.matmul(Q, K.transpose(-1, -2).contiguous())
        print(f"{A.shape=}")
        #B = torch.matmul(A, V)
        print(f"{B.shape=}")
        b, h, w, d = B.shape
        B = self.W_O(B.view(b, w, h*d))
        B = B.unsqueeze(dim=1).permute(0,1,3,2)
        
        return B
    
    
X = torch.randn(2, 1, 64, 256)

mhla2 = MHLA2(4, 64, 64)
y = mhla2(X, X, X)
print(f"{y.shape=}")
print(f"{y.shape==X.shape}")

Q.shape=torch.Size([2, 4, 256, 16])
K.shape=torch.Size([2, 4, 256, 16])
V.shape=torch.Size([2, 4, 256, 16])
A.shape=torch.Size([2, 4, 16, 16])
B.shape=torch.Size([2, 4, 256, 16])
y.shape=torch.Size([2, 1, 64, 256])
True


In [53]:
A = autograd_tensor = torch.randn((2, 3, 4), requires_grad=True)
A

tensor([[[-0.8611,  0.2410,  0.3465,  1.3166],
         [-1.1952,  1.1787, -1.0054,  0.7407],
         [-0.7926,  0.2395, -0.8082, -1.3719]],

        [[ 0.0121, -0.0247,  1.6018,  1.1075],
         [ 0.0031,  0.2784, -2.0812, -2.1760],
         [-0.5154,  0.3037,  2.8692, -0.0420]]], requires_grad=True)

In [40]:
x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
out = x.pow(2).sum()
out.backward()
x.grad


tensor([[ 2., -2.],
        [ 2.,  2.]])