In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn

In [None]:
# Seed the global RNG so the random demo inputs (and the conv weights of the
# modules constructed in later cells) are reproducible on Restart & Run All.
torch.manual_seed(0)

# Two random float32 feature maps of shape (1, 17, 128, 128) used as the
# "local" and "global" inputs of the attention blocks below.
low_f = torch.randn(1, 17, 128, 128, dtype=torch.float32)

global_f = torch.randn(1, 17, 128, 128, dtype=torch.float32)

In [None]:
class LinearAttentionBlock(nn.Module):
    """Linear attention block that pools local features with a spatial attention map.

    A bias-free 1x1 convolution maps ``low_f + global_f`` to a single
    compatibility score per spatial position. The scores become attention
    weights, the local features ``low_f`` are multiplied by those weights,
    and the result is pooled over space into one C-dimensional vector per
    sample.

    Args:
        in_f: number of input channels C of ``low_f`` / ``global_f``.
        normalize_attn: if True, the weights are a softmax over all H*W
            positions and pooling is a spatial sum. If False, the weights are
            an element-wise sigmoid and pooling is a spatial average.
    """

    def __init__(self, in_f, normalize_attn=True):
        super(LinearAttentionBlock, self).__init__()
        self.normalize_attn = normalize_attn
        self.conv_pointwise = nn.Conv2d(
            in_channels=in_f, out_channels=1,
            kernel_size=1, padding=0, bias=False
        )

    def forward(self, low_f, global_f):
        """Attend over ``low_f`` using scores computed from ``low_f + global_f``.

        Args:
            low_f: local features, (N, C, H, W).
            global_f: global features; must be addable to ``low_f``
                (assumed same shape — TODO confirm for other callers).

        Returns:
            (x_pointwised, f_attented): the raw (pre-softmax/sigmoid)
            compatibility scores of shape (N, 1, H, W), and the
            attention-pooled features of shape (N, C).
        """
        # Conv2d tensors are NCHW: dim 2 is height, dim 3 is width. Unpacking
        # them in this order keeps the reshapes below correct for non-square maps.
        N, C, H, W = low_f.size()
        # (N, 1, H, W) raw compatibility scores
        x_pointwised = self.conv_pointwise(low_f + global_f)
        if self.normalize_attn:
            # Softmax jointly over all H * W positions:
            # (N, 1, H, W) -> (N, 1, H * W) -> softmax -> (N, 1, H, W)
            x_flatten = x_pointwised.view(N, 1, -1)
            x_attention = F.softmax(x_flatten, dim=2).view(N, 1, H, W)
        else:
            x_attention = torch.sigmoid(x_pointwised)
        # Weight the local features by the attention map (not by the raw
        # scores), broadcasting the single attention channel over all C channels.
        f_attented = torch.mul(x_attention.expand_as(low_f), low_f)
        if self.normalize_attn:
            # Softmax weights sum to 1 over space, so a spatial sum is a
            # weighted mean: (N, C, H, W) -> (N, C, H * W) -> (N, C)
            f_attented = f_attented.view(N, C, -1).sum(dim=2)
        else:
            f_attented = F.adaptive_avg_pool2d(f_attented, (1, 1)).view(N, C)
        # The conv output is already (N, 1, H, W); no reshape needed.
        return x_pointwised, f_attented

In [None]:
# 17-channel block using the sigmoid gate + average pooling variant.
att_l = LinearAttentionBlock(in_f=17, normalize_attn=False)

In [None]:
# One forward pass on the demo tensors; display both output shapes.
x_pointwised, f_attented = att_l(low_f, global_f)
tuple(t.shape for t in (x_pointwised, f_attented))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# If the forward pass ran with autograd enabled, x_pointwised requires grad,
# and matplotlib's implicit .numpy() conversion raises. torch.no_grad() does
# not clear requires_grad on an existing tensor, so detach explicitly.
fig, ax = plt.subplots()
im = ax.imshow(x_pointwised[0, 0].detach().cpu().numpy())
ax.set_title("LinearAttentionBlock compatibility scores (sample 0)")
ax.set_xlabel("width")
ax.set_ylabel("height")
fig.colorbar(im, ax=ax);

In [None]:
# Display the attention-pooled features returned by att_l, shape (N, C).
f_attented

In [None]:
class PAM_Module(nn.Module):
    """
    Position attention module: self-attention over spatial positions.

    Every spatial position attends to every other position using query/key
    1x1-conv projections. The attended values (from a 1x1-conv value
    projection) are added back onto the input, scaled by a learnable
    ``gamma``. ``gamma`` is initialised to 0, so a freshly constructed
    module returns its input unchanged.

    Parameters
    ----------
    in_dim : int
        Number of input channels C.
    dim_reduse : int
        Channel reduction factor for the query/key projections. They use
        ``max(1, in_dim // dim_reduse)`` channels (assumes dim_reduse >= 1).

    Notes
    -----
    The attention matrix is (H*W) x (H*W) per sample, so memory grows
    quadratically with the number of spatial positions.
    """
    #Ref from SAGAN
    def __init__(self, in_dim, dim_reduse: int = 8):
        super(PAM_Module, self).__init__()
        self.chanel_in = in_dim

        # Guard against in_dim < dim_reduse: plain floor division would give
        # zero-channel query/key projections, an all-zero energy and therefore
        # uniform attention that silently ignores the input.
        qk_dim = max(1, in_dim // dim_reduse)

        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=qk_dim,
            kernel_size=1
        )
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=qk_dim,
            kernel_size=1
        )
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim,
            kernel_size=1
        )
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, return_attention=False):
        """
        Parameters:
        ----------
            inputs :
                x : input feature maps( B X C X H X W)
                return_attention : if True, also return the attention matrix
            returns :
                out : gamma * attention value + input feature (B X C X H X W)
                attention: B X (HxW) X (HxW), rows sum to 1
        """
        m_batchsize, C, height, width = x.size()
        # (B, HW, C') queries and (B, C', HW) keys
        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)

        # (B, HW, HW): energy[i, j] = <query_i, key_j>; softmax over j
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)

        # out[:, :, i] = sum_j value[:, :, j] * attention[i, j]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x

        if return_attention:
            return out, attention

        return out


class CAM_Module(nn.Module):
    """
    Channel attention module: re-mixes channels by their pairwise affinity.

    A C x C affinity matrix is built from the flattened feature map. Each row
    is subtracted from its maximum before the softmax, so channels with the
    highest affinity get the smallest weights. The channel mixture is added
    back onto the input, scaled by a learnable ``gamma`` initialised to 0.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.chanel_in = in_dim

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, return_attention=False):
        """
        Parameters:
        ----------
            inputs :
                x : input feature maps( B X C X H X W)
                return_attention : if True, also return the attention matrix
            returns :
                out : gamma * attention value + input feature (B X C X H X W)
                attention: B X C X C
        """
        batch, channels, rows, cols = x.size()
        flat = x.view(batch, channels, -1)                    # (B, C, H*W)

        affinity = torch.bmm(flat, flat.transpose(1, 2))      # (B, C, C)
        row_max = affinity.max(dim=-1, keepdim=True).values   # (B, C, 1)
        attention = self.softmax(row_max - affinity)

        mixed = torch.bmm(attention, flat).view(batch, channels, rows, cols)
        result = self.gamma * mixed + x

        return (result, attention) if return_attention else result


In [None]:
# Build both attention modules for the 17-channel demo maps (PAM first, then CAM).
pam_l, cam_l = PAM_Module(17), CAM_Module(17)

In [None]:
# Input shape fed to both modules.
global_f.size()

In [None]:
# PAM builds a (H*W) x (H*W) attention matrix per sample. For the 128x128 demo
# map that is 16384 x 16384 float32 values (~1 GiB). Run without autograd so
# that buffer is not retained for a backward pass that never happens.
with torch.no_grad():
    out_pam = pam_l(global_f)

In [None]:
# PAM returns gamma * attended + input, so the output must keep the input shape.
assert out_pam.shape == global_f.shape, (out_pam.shape, global_f.shape)
out_pam.shape

In [None]:
# Channel attention on the same input; only the attended output is needed here.
out_cam = cam_l(global_f, return_attention=False)

In [None]:
# CAM returns gamma * attended + input, so the output must keep the input shape.
assert out_cam.shape == global_f.shape, (out_cam.shape, global_f.shape)
out_cam.shape

In [None]:
# Report GPU status when an NVIDIA driver is present. Skip cleanly on
# CPU-only hosts instead of printing a shell "command not found" error.
# Output is captured so it appears in the cell rather than the server log.
import shutil
import subprocess

if shutil.which("nvidia-smi"):
    print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
else:
    print("nvidia-smi not found: no NVIDIA driver visible on this host.")