# Implementation of Stand-Alone Self-Attention in Vision Models

![Stand Alone SA module](images/sasa.png)

In [1]:
from fastai2.vision.all import *

## Relative Self Attention Module

In [2]:
class RelativeSelfAttention(Module):
    """Local self-attention layer with learned relative positional embeddings.

    Queries, keys and values are produced by three independent 1x1
    convolutions. Each (strided) spatial position attends over a `ks` x `ks`
    window of keys/values, with the channels split into `groups` heads that
    attend independently. A learned embedding is added to the keys before the
    dot product with the query: the first `d_out//2` channels carry an
    embedding indexed by the row offset inside the window, the remaining
    channels one indexed by the column offset.

    Args:
        d_in: number of input channels.
        d_out: number of output channels; must be even and divisible by `groups`.
        ks: side length of the attention window. The padding used is
            `(ks - 1) // 2`, which keeps the query and key grids aligned for
            odd `ks`.
        groups: number of attention heads.
        stride: stride between attended positions (default 1).

    Raises:
        ValueError: if `d_out` is odd or not divisible by `groups`.
    """

    def __init__(self, d_in, d_out, ks, groups, stride=1):
        # Fail early: both conditions below would otherwise surface as an
        # opaque shape error inside `forward`.
        if d_out % 2 != 0:
            raise ValueError(f"d_out must be even to split channels between row and "
                             f"column embeddings, got {d_out}")
        if d_out % groups != 0:
            raise ValueError(f"d_out ({d_out}) must be divisible by groups ({groups})")
        self.n_c, self.ks, self.groups, self.stride = d_out, ks, groups, stride
        # linear transformation for queries, values and keys
        self.qx, self.kx, self.vx = [ConvLayer(d_in, d_out, ks=1, norm_type=None,
                                               act_cls=None) for _ in range(3)]
        # positional embeddings: half of the channels encode the row offset,
        # the other half the column offset
        self.row_embeddings = nn.Parameter(torch.randn(d_out//2, ks))
        self.col_embeddings = nn.Parameter(torch.randn(d_out//2, ks))

    def calc_out_shape(self, inp_shape, pad):
        """Return the output size of each entry of `inp_shape` for a window of
        side `self.ks`, padding `pad` and stride `self.stride`."""
        out_shape = [(sz + 2*pad - self.ks) // self.stride + 1 for sz in inp_shape]
        return out_shape

    def _split_heads(self, t):
        """Reshape an unfolded tensor bs * (C*K) * N into bs * N * G * C//G * K."""
        t = t.view(t.shape[0], self.groups, self.n_c//self.groups, -1, t.shape[-1])  # bs*G*C//G*K*N
        return t.permute(0, 4, 1, 2, 3)

    def forward(self, x):
        """Apply windowed self-attention to `x` (bs * d_in * H * W).

        Returns a tensor of shape bs * d_out * H_out * W_out, where the
        spatial sizes are given by `calc_out_shape`.
        """
        query, keys, values = self.qx(x), self.kx(x), self.vx(x)

        pad = (self.ks - 1) // 2

        # use unfold to extract the memory blocks and their associated queries
        query = F.unfold(query, kernel_size=1, stride=self.stride)
        keys = F.unfold(keys, kernel_size=self.ks, padding=pad, stride=self.stride)
        values = F.unfold(values, kernel_size=self.ks, padding=pad, stride=self.stride)

        # reshape and permute the dimensions into the appropriate format for matrix multiplication
        query = self._split_heads(query)    # bs * N * G * C//G * 1
        keys = self._split_heads(keys)      # bs * N * G * C//G * ks^2
        values = self._split_heads(values)  # bs * N * G * C//G * ks^2

        # get positional embeddings, each broadcast to C//2 * ks * ks
        row_embeddings = self.row_embeddings.unsqueeze(-1).expand(-1, -1, self.ks)
        col_embeddings = self.col_embeddings.unsqueeze(-2).expand(-1, self.ks, -1)

        embeddings = torch.cat((row_embeddings, col_embeddings)).view(self.groups,
                                self.n_c//self.groups, -1) # G * C//G * ks^2
        # add empty dimensions to match the shape of keys; indexing with
        # [None, None] keeps every group (a trailing -1 index would keep only
        # the last group and silently broadcast it to all heads)
        embeddings = embeddings[None, None] # 1 * 1 * G * C//G * ks^2

        # compute attention map: softmax over the ks^2 window positions
        att_map = F.softmax(torch.matmul(query.transpose(-2,-1), keys+embeddings), dim=-1)
        # compute final output
        out = torch.matmul(att_map, values.transpose(-2,-1)).permute(0, 2, 3, 4, 1) # bs * G * 1 * C//G * N

        # reshape (rather than view) because `out` is a permuted, non-contiguous tensor
        return out.reshape(out.shape[0], self.n_c, *self.calc_out_shape(x.shape[-2:], pad))

### Test with dummy input

In [3]:
# Dummy batch of 32 feature maps (64 channels, 56 x 56), created on the CPU and moved to the GPU.
inp = torch.randn(32, 64, 56, 56).cuda()

In [4]:
# Check that the dummy input now lives on the GPU.
inp.device

device(type='cuda', index=0)

In [5]:
# 64 -> 128 channels, 7 x 7 attention window, 8 heads, default stride of 1.
sa = RelativeSelfAttention(d_in=64, d_out=128, ks=7, groups=8)

In [6]:
# Move the layer to the GPU so it is on the same device as `inp`.
sa = sa.cuda()

In [7]:
# Forward pass on the dummy batch.
out = sa(inp)

In [8]:
# With ks=7 and stride 1 the 56 x 56 spatial size is kept; only the channel count changes (64 -> 128).
out.shape

torch.Size([32, 128, 56, 56])

In [11]:
# Per-layer overview of output shapes and parameter counts.
# NOTE: in the recorded output the table lists only the three 1x1 convolutions;
# row_embeddings / col_embeddings (2 x 64 x 7 = 896 values) are not part of its totals.
sa.summary(inp)

RelativeSelfAttention (Input shape: ['32 x 64 x 56 x 56'])
Layer (type)         Output Shape         Param #    Trainable 
Conv2d               32 x 128 x 56 x 56   8,320      True      
________________________________________________________________
Conv2d               32 x 128 x 56 x 56   8,320      True      
________________________________________________________________
Conv2d               32 x 128 x 56 x 56   8,320      True      
________________________________________________________________

Total params: 24,960
Total trainable params: 24,960
Total non-trainable params: 0
