In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt

In [174]:
class conv_attention2d(nn.Module):
  
  def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=0):
    super(conv_attention2d, self).__init__()

    self.in_dim = in_dim
    self.out_dim = out_dim
    self.w_q = nn.Parameter(torch.ones((in_dim, out_dim)))
    self.w_k = nn.Parameter(torch.ones((in_dim, out_dim)))
    self.w_v = nn.Parameter(torch.ones((in_dim, out_dim)))
    self.padding = nn.ZeroPad2d(padding)
    self.kernel_size = kernel_size
    self.stride = stride

  def forward(self, x):
    
    x = self.padding(x)
    # x.shape = (h, w, c)
    x = x.permute(1,2,0)
    print(x.shape)
    # x_q.shape = x_k.shape = x_v.shape = (h, w, c)
    
    x_q = x @ self.w_q
    x_k = x @ self.w_k
    x_v = x @ self.w_v

    output_size = (x.shape[0]-self.kernel_size)//self.stride + 1
    ans = torch.zeros(output_size, output_size, self.out_dim)

    for h in range(0, x.shape[0] - self.kernel_size +1, self.stride):
      for w in range(0, x.shape[1] - self.kernel_size +1, self.stride):
        # q.shape = (1, 1, c)
        q = x_q[h+self.kernel_size//2, w+self.kernel_size//2, :]
        # q.shape = (c)
        q = q.squeeze()

        # kT.shape = (c, ks*ks) Tranpose
        kT = x_k[h:h+self.kernel_size, w:w+self.kernel_size, :].flatten(0,1).permute(1, 0)

        # qkT.shape = (ks*ks)
        qkT = q @ kT
        qkT /= sqrt(self.out_dim)
        qkT = F.softmax(qkT)

        # v.shape = (ks*ks, c)
        v = x_k[h:h+self.kernel_size, w:w+self.kernel_size, :].flatten(0,1)
        qkTv = qkT @ v

        ans[h//self.stride, w//self.stride, :] = qkTv
        # ans[h, w, :] = x[h, w, :]

    return ans.permute(2,0,1)


In [175]:
img = torch.randn((3, 7, 7))
print(img.shape)

ca = conv_attention2d(img.shape[0], 768, kernel_size = 3, padding=1, stride=1)


print(ca(img).shape)



torch.Size([3, 7, 7])
torch.Size([9, 9, 3])
torch.Size([768, 7, 7])


  qkT = F.softmax(qkT)
