## Implementation of the self-attention module from the SAGAN paper (*Self-Attention Generative Adversarial Networks*, Zhang et al.).

![self-attention module](images/sa.png)

In [1]:
from fastai2.vision.all import *

Dummy input: a random batch of shape (bs, c, h, w).

In [2]:
x = torch.randn((32, 128, 64, 64))  # dummy batch laid out as (bs, c, h, w)

In [41]:
class SelfAttention(Module):
    """SAGAN-style self-attention over all spatial positions of a feature map.

    The input is flattened to ``(bs, n_c, N)`` with ``N`` spatial positions.
    Three spectrally normalised 1x1 convolutions (``fx``, ``gx``, ``hx``)
    project it to ``n_c // k`` channels. Every position then attends to every
    other position. ``vx`` maps the attended features back to ``n_c``
    channels, and the result is scaled by the learnable scalar ``gamma`` and
    added to the input as a residual. ``gamma`` starts at 0, so the block is
    initially the identity.

    Args:
        n_c: number of input (and output) channels.
        k: channel reduction factor; the projections use ``n_c // k`` channels.

    Raises:
        ValueError: if ``n_c // k`` is less than 1, which would leave the
            projections with zero channels.

    Accepts tensors of shape ``(bs, n_c, *spatial)`` with one or more spatial
    dimensions. Returns a tensor of the same shape.
    """
    def __init__(self, n_c, k):
        n_hid = n_c // k
        if n_hid < 1:
            raise ValueError(f"n_c // k must be >= 1, got n_c={n_c}, k={k}")
        # 1x1 convolutions over the flattened positions (ndim=1), spectrally normalised.
        self.fx, self.gx, self.hx = [ConvLayer(n_c, n_hid, ks=1, act_cls=None,
                                               ndim=1, norm_type=NormType.Spectral)
                                     for _ in range(3)]
        self.vx = ConvLayer(n_hid, n_c, ks=1, act_cls=None,
                            ndim=1, norm_type=NormType.Spectral)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        size = x.size()
        # (bs, n_c, *spatial) -> (bs, n_c, N). Unlike `view`, `flatten` works for
        # any number of spatial dims and for non-contiguous inputs.
        x = x.flatten(2)

        f, g, h = self.fx(x), self.gx(x), self.hx(x)
        # att_map[b, i, j] = softmax over i of (f_i . g_j): weight of position i
        # in the output at position j. Shape (bs, N, N).
        att_map = F.softmax(torch.bmm(f.transpose(1, 2), g), dim=1)

        # out[:, :, j] = sum_i att_map[:, i, j] * h[:, :, i], then back to n_c channels.
        v = self.vx(torch.bmm(h, att_map))

        # Residual connection, then restore the caller's original shape.
        return (v * self.gamma + x).reshape(size)

Try the module on the dummy input:

In [42]:
att = SelfAttention(n_c=x.shape[1], k=8)  # channel count taken from the dummy batch

In [43]:
out = att(x)  # forward pass on the dummy batch; the module returns a tensor shaped like x

In [44]:
out.size()

torch.Size([32, 128, 64, 64])