## Understanding the self-attention mechanism on images

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [29]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 2D feature map.

    Every spatial position of the input is treated as one token, so a
    ``B x C x H x W`` input yields ``N = H * W`` tokens. Queries and keys are
    produced by 1x1 convolutions that reduce the channel count to
    ``in_channels // 8`` (never less than 1); values keep all ``in_channels``
    channels. The attended values are scaled by the learnable scalar
    ``gamma`` and added to the input as a residual.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        # `in_channels // 8` is 0 for fewer than 8 channels, which would give
        # query/key projections with no output channels; keep at least one.
        reduced_channels = max(in_channels // 8, 1)

        # 1x1 convolutions act as per-position linear projections.
        self.query_conv = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)  # dimensionality reduction
        self.key_conv = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)  # dimensionality reduction
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Scalar weight of the attention branch. It starts at zero, so a
        # freshly constructed module returns its input unchanged.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply spatial self-attention and add the result to ``x``.

        Args:
            x: 4-D tensor of shape ``B x C x H x W`` with ``C == in_channels``.

        Returns:
            Tensor with the same shape as ``x``: ``gamma * attended + x``.
        """
        batch_size, channels, height, width = x.size()
        num_positions = height * width  # N: analogous to the context length

        # B x N x C': one reduced-channel query vector per spatial position
        # (C' plays the role of the embedding dimension).
        proj_query = self.query_conv(x).view(batch_size, -1, num_positions).permute(0, 2, 1)
        # B x C' x N: keys are kept channel-first, i.e. already transposed
        # for the query @ key^T product below.
        proj_key = self.key_conv(x).view(batch_size, -1, num_positions)

        energy = torch.bmm(proj_query, proj_key)  # B x N x N pairwise scores
        attention = F.softmax(energy, dim=-1)  # each row sums to 1 over the keys

        proj_value = self.value_conv(x).view(batch_size, -1, num_positions).permute(0, 2, 1)  # B x N x C

        out = torch.bmm(attention, proj_value)  # B x N x C
        out = out.permute(0, 2, 1)  # B x C x N
        out = out.view(batch_size, channels, height, width)  # B x C x H x W

        # Attention is added as a residual, weighted by the learnable gamma.
        return self.gamma * out + x


In [30]:
# Random feature map used to exercise the layer: batch 1, 32 channels, 64 x 60 spatial grid.
x = torch.randn(1, 32, 64, 60)

In [31]:
# Build the layer for a 32-channel input; the last expression displays the output shape.
self_attention = SelfAttention(in_channels=32)

self_attention(x).shape


torch.Size([1, 32, 64, 60])