In [None]:
import torch.nn as nn
import resnet
import torch
from torchvision import models
from torchvision.models import ResNet34_Weights, ResNet50_Weights

## DANet

In [None]:
class PAM_Module(nn.Module):
    """Position (spatial) self-attention module in the style of DANet.

    Every spatial position attends to every other position. Queries and keys
    are 1x1-conv projections to ``in_dim // 8`` channels, values are a 1x1-conv
    projection that keeps ``in_dim`` channels. The attended features are added
    back to the input, scaled by the learnable scalar ``gamma`` (initialised
    to zero, so the module starts out as the identity).

    Args:
        in_dim: number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1, stride=1, padding='valid')
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1, stride=1, padding='valid')
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding='valid')
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply position attention.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W): ``gamma * attended + x``.
        """
        m_batchsize, C, height, width = x.size()
        num_positions = width * height
        # (B, N, C') and (B, C', N) -> pairwise position affinities (B, N, N).
        proj_query = self.query_conv(x).view(m_batchsize, -1, num_positions).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, num_positions)
        energy = torch.bmm(proj_query, proj_key)
        # attention[b, i, j]: weight of source position j for target position i.
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, num_positions)
        # Multiply as (B, C, N) x (B, N, N)^T so the result is laid out
        # channel-major. The previous (B, N, C) product was reinterpreted with
        # .view() as (B, C, H, W), which mixed channels with positions.
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x
        return out

class CAM_Module(nn.Module):
    """Channel self-attention module.

    Builds a (C, C) channel-affinity matrix from the flattened input, turns it
    into attention weights with a softmax over ``row_max - affinity``, and adds
    the re-weighted channels back to the input scaled by the learnable scalar
    ``gamma`` (initialised to zero).
    """

    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * channel_attended(x) + x`` for ``x`` of shape (B, C, H, W)."""
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, -1)

        # Channel-to-channel affinities, shape (B, C, C).
        affinity = torch.bmm(flat, flat.permute(0, 2, 1))
        row_max = affinity.max(dim=-1, keepdim=True)[0]
        weights = self.softmax(row_max - affinity)

        attended = torch.bmm(weights, flat).view(batch, channels, height, width)
        return self.gamma * attended + x

class DANet(nn.Module):
    """Dual-attention classification head on top of a torchvision ResNet-34 backbone.

    The backbone is the pretrained ResNet-34 with its last two children
    removed. Two parallel branches (position attention ``sa`` and channel
    attention ``sc``) process the backbone features; their outputs are summed,
    projected by ``conv8``, flattened and classified by ``fc``.

    The fused feature map and its gradient are kept for Grad-CAM style
    inspection (see ``get_activations`` / ``get_activations_gradient``).

    Args:
        in_channels: channels produced by the backbone. NOTE(review): must
            match the backbone output (512 for torchvision's ResNet-34).
        out_channels: channels of the attention-branch outputs and of the
            fused feature map.
        out_dim: number of output logits.
    """

    def __init__(self, in_channels=512, out_channels=128, out_dim=8):
        super(DANet, self).__init__()
        inter_channels = in_channels // 4
        base_model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
        self.backbone = torch.nn.Sequential(*list(base_model.children())[:-2])

        self.conv5a = nn.Conv2d(in_channels, inter_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.relu = nn.ReLU()
        self.conv5c = nn.Conv2d(in_channels, inter_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(inter_channels)
        self.sa = PAM_Module(inter_channels)
        self.sc = CAM_Module()
        self.conv51 = nn.Conv2d(inter_channels, inter_channels, kernel_size=3, padding=1)
        self.conv52 = nn.Conv2d(inter_channels, inter_channels, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(inter_channels)
        self.bn4 = nn.BatchNorm2d(inter_channels)
        self.dropout = nn.Dropout2d(0.1)
        self.conv6 = nn.Conv2d(inter_channels, out_channels, kernel_size=1)
        self.conv7 = nn.Conv2d(inter_channels, out_channels, kernel_size=1)
        # conv8 consumes the sum of the conv6/conv7 outputs, which has
        # out_channels channels (identical to inter_channels with the defaults).
        self.conv8 = nn.Conv2d(out_channels, out_channels, kernel_size=1)

        # The flattened size assumes a 7x7 fused feature map; with the defaults
        # this is 128 * 7 * 7 = 6272.
        # NOTE(review): 7x7 presumably corresponds to 224x224 inputs - confirm.
        self.fc = nn.Sequential(
            nn.Linear(out_channels * 7 * 7, 2048),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, out_dim)
        )

        self.feat = None
        self.gradients = None

    def activations_hook(self, grad):
        """Backward hook: remember the gradient flowing into the fused feature map."""
        self.gradients = grad

    def forward(self, x):
        """Compute class logits and cache the fused feature map in ``self.feat``."""
        x = self.backbone(x)

        # Position-attention branch.
        feat1 = self.relu(self.bn1(self.conv5a(x)))
        sa_feat = self.sa(feat1)
        sa_conv = self.relu(self.bn3(self.conv51(sa_feat)))
        sa_output = self.conv6(self.dropout(sa_conv))

        # Channel-attention branch.
        feat2 = self.relu(self.bn2(self.conv5c(x)))
        sc_feat = self.sc(feat2)
        sc_conv = self.relu(self.bn4(self.conv52(sc_feat)))
        sc_output = self.conv7(self.dropout(sc_conv))

        # Fuse the projected branch outputs (not the pre-projection sa_conv/sc_conv).
        feat_sum = sa_output + sc_output
        sasc_output = self.conv8(self.dropout(feat_sum))
        self.feat = sasc_output
        # register_hook raises on tensors that do not require grad
        # (e.g. inference under torch.no_grad()), so only hook when possible.
        if sasc_output.requires_grad:
            sasc_output.register_hook(self.activations_hook)
        output = torch.flatten(sasc_output, 1)
        output = self.fc(output)
        return output

    def get_activations_gradient(self):
        """Return the gradient captured by ``activations_hook`` (None before any backward pass)."""
        return self.gradients

    def get_activations(self):
        """Return the fused feature map cached by the last ``forward`` call."""
        return self.feat

## CBAM

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class ChannelAttention(nn.Module):
    """CBAM-style channel attention.

    Global average- and max-pooled channel descriptors are passed through a
    shared bottleneck (two bias-free 1x1 convolutions with a ReLU in between),
    summed, and squashed with a sigmoid.

    Args:
        in_planes: number of input channels.
        ratio: bottleneck reduction factor; the hidden width is
            ``in_planes // ratio``. (Previously this argument was ignored and
            16 was always used.)
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden_planes = in_planes // ratio
        self.fc = nn.Sequential(nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
                                nn.ReLU(),
                                nn.Conv2d(hidden_planes, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel weights of shape (B, C, 1, 1) for ``x`` of shape (B, C, H, W)."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """CBAM-style spatial attention.

    The channel-wise mean and max maps are stacked into a 2-channel tensor,
    convolved down to one channel, and passed through a sigmoid.

    Args:
        kernel_size: size of the convolution kernel; padding is
            ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(
            2, 1, kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (B, 1, H', W') attention map for ``x`` of shape (B, C, H, W)."""
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv1(stacked))

class BasicBlock(nn.Module):
    """Two-convolution residual block with channel and spatial attention.

    After the second conv + batch norm, the features are multiplied by the
    ``ChannelAttention`` weights and then by the ``SpatialAttention`` map,
    before the shortcut is added and the final ReLU is applied.

    Args:
        inplanes: input channels.
        planes: output channels.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Attention modules applied to the residual branch.
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the activated sum of branch and shortcut."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Re-weight channels first, then spatial positions.
        out = self.ca(out) * out
        out = self.sa(out) * out

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

class CBAM(nn.Module):
    """ResNet-34 backbone followed by one extra stage of attention ``BasicBlock``s.

    The pretrained torchvision ResNet-34 (minus its last two children) feeds a
    stage of two ``BasicBlock``s (stride 2, 512 channels), global average
    pooling and a linear classifier. The output of the extra stage and its
    gradient are cached for Grad-CAM style inspection.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=8):
        super(CBAM, self).__init__()
        # Channel count entering the next block built by _make_layer.
        self.inplanes = 512
        base_model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
        self.backbone = torch.nn.Sequential(*list(base_model.children())[:-2])
        self.layer = self._make_layer(BasicBlock, 512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Was `out_dim`, which is not defined in this scope (NameError on a fresh kernel).
        self.fc = nn.Linear(512, num_classes)
        self.feat = None
        self.gradients = None

    def activations_hook(self, grad):
        """Backward hook: remember the gradient flowing into the cached feature map."""
        self.gradients = grad

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a stage of ``blocks`` residual blocks.

        A 1x1 conv + batch norm shortcut is created for the first block when
        the stride is not 1 or the channel count changes. Updates
        ``self.inplanes`` to the stage's output channels.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class logits and cache the last feature map in ``self.feat``."""
        x = self.backbone(x)
        x = self.layer(x)
        self.feat = x
        # register_hook raises on tensors that do not require grad
        # (e.g. inference under torch.no_grad()), so only hook when possible.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        output = self.avgpool(x)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output

    def get_activations_gradient(self):
        """Return the gradient captured by ``activations_hook`` (None before any backward pass)."""
        return self.gradients

    def get_activations(self):
        """Return the feature map cached by the last ``forward`` call."""
        return self.feat
    

## PAN

In [None]:
import resnet

class ResNet50(nn.Module):
    """Feature extractor around ``resnet.resnet50`` that exposes per-stage feature maps.

    ``self.blocks`` holds children 4-7 of the wrapped model.
    NOTE(review): these are presumably layer1..layer4 of the local ``resnet``
    module - confirm against that module.
    """

    def __init__(self, pretrained=True):
        """Build the wrapped model and collect the stages used in ``forward``."""
        super().__init__()
        self.model = resnet.resnet50(pretrained=pretrained)
        self.relu = self.model.relu  # alias of the stem ReLU (original note: "Place a hook")

        children = list(self.model.children())
        self.blocks = [children[idx] for idx in (4, 5, 6, 7)]

    def forward(self, x):
        """Return ``(feature_maps, pooled)``.

        ``feature_maps`` is the list of outputs of each entry in
        ``self.blocks`` (in order); ``pooled`` is the last feature map averaged
        over its full spatial extent and flattened to (B, C).
        """
        for stem_layer in (self.model.conv1, self.model.bn1,
                           self.model.relu, self.model.maxpool):
            x = stem_layer(x)

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        out = pooled.view(x.shape[0], -1)
        return feature_map, out
    
class ResNet34(nn.Module):
    """Feature extractor around ``resnet.resnet34`` that exposes per-stage feature maps.

    ``self.blocks`` holds children 4-7 of the wrapped model.
    NOTE(review): these are presumably layer1..layer4 of the local ``resnet``
    module - confirm against that module.
    """

    def __init__(self, pretrained=True):
        """Build the wrapped model and collect the stages used in ``forward``."""
        super().__init__()
        self.model = resnet.resnet34(pretrained=pretrained)
        self.relu = self.model.relu  # alias of the stem ReLU (original note: "Place a hook")

        children = list(self.model.children())
        self.blocks = [children[idx] for idx in (4, 5, 6, 7)]

    def forward(self, x):
        """Return ``(feature_maps, pooled)``.

        ``feature_maps`` is the list of outputs of each entry in
        ``self.blocks`` (in order); ``pooled`` is the last feature map averaged
        over its full spatial extent and flattened to (B, C).
        """
        for stem_layer in (self.model.conv1, self.model.bn1,
                           self.model.relu, self.model.maxpool):
            x = stem_layer(x)

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        out = pooled.view(x.shape[0], -1)
        return feature_map, out

class Classifier(nn.Module):
    """Single linear layer mapping ``in_features`` to ``num_class`` logits.

    A ``relu`` module is created as well but is not applied in ``forward``.
    """

    def __init__(self, in_features=2048, num_class=20):
        super().__init__()
        self.fc1 = nn.Linear(in_features, num_class)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the raw logits ``fc1(x)``."""
        return self.fc1(x)

class FPA(nn.Module):
    """Feature Pyramid Attention.

    A 1x1 "master" projection of the input is multiplied by an attention map
    produced by a three-level pyramid (7x7, 5x5, 3x3 convolutions, each level
    downsampling by stride 2 and merged back with transposed convolutions).
    A globally pooled branch is then added.

    The additions in the merge path require the transposed-conv outputs to
    match the pyramid sizes. NOTE(review): with the layer settings below this
    works out for e.g. 14x14 inputs; other input sizes should be checked.

    :type channels: int
    """

    def __init__(self, channels=2048):
        super().__init__()
        channels_mid = int(channels / 4)
        self.channels_cond = channels

        # Master branch: channel-preserving 1x1 projection.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch: 1x1 projection of the pooled descriptor.
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_gpb = nn.BatchNorm2d(channels)

        # Pyramid, downsampling path (stride 2 at every level).
        down = dict(stride=2, bias=False)
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), padding=3, **down)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), padding=2, **down)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), padding=1, **down)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Pyramid, size-preserving refinement at each level.
        keep = dict(stride=1, bias=False)
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), padding=3, **keep)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), padding=2, **keep)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), padding=1, **keep)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Transposed convolutions used to merge the levels back up.
        self.conv_upsample_3 = nn.ConvTranspose2d(
            channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(
            channels_mid, channels_mid, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_1 = nn.ConvTranspose2d(
            channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        # Master branch.
        master = self.bn_master(self.conv_master(x))

        # Global pooling branch: average over the full spatial extent.
        pooled = nn.AvgPool2d(x.shape[2:])(x)
        pooled = pooled.view(x.shape[0], self.channels_cond, 1, 1)
        global_ctx = self.bn_gpb(self.conv_gpb(pooled))

        # Level 1 (7x7).
        level1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        level1_refined = self.bn1_2(self.conv7x7_2(level1))

        # Level 2 (5x5), fed by level 1.
        level2 = self.relu(self.bn2_1(self.conv5x5_1(level1)))
        level2_refined = self.bn2_2(self.conv5x5_2(level2))

        # Level 3 (3x3), fed by level 2.
        level3 = self.relu(self.bn3_1(self.conv3x3_1(level2)))
        level3_refined = self.bn3_2(self.conv3x3_2(level3))

        # Merge from the coarsest level upwards.
        up3 = self.relu(self.bn_upsample_3(self.conv_upsample_3(level3_refined)))
        merged2 = self.relu(level2_refined + up3)
        up2 = self.relu(self.bn_upsample_2(self.conv_upsample_2(merged2)))
        merged1 = self.relu(level1_refined + up2)

        # Pyramid output gates the master branch; the global context is added last.
        attention = self.relu(self.bn_upsample_1(self.conv_upsample_1(merged1)))
        master = master * attention

        return self.relu(master + global_ctx)

class GAU(nn.Module):
    """Global Attention Upsample.

    A globally pooled descriptor of the high-level features gates the
    (3x3-convolved) low-level features; the result is added to the high-level
    features after they have been brought to ``channels_low`` channels, either
    by a stride-2 transposed convolution (``upsample=True``) or by a 1x1
    convolution (``upsample=False``).

    Args:
        channels_high: channels of the high-level feature map.
        channels_low: channels of the low-level feature map (and of the output).
        upsample: whether the high-level map is upsampled before the addition.
    """

    def __init__(self, channels_high, channels_low, upsample=True):
        super().__init__()
        self.upsample = upsample

        # Low-level path.
        self.conv3x3 = nn.Conv2d(channels_low, channels_low, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_low)

        # Gate computed from the pooled high-level features.
        self.conv1x1 = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
        self.bn_high = nn.BatchNorm2d(channels_low)

        # High-level path: only one of the two variants is created.
        if upsample:
            self.conv_upsample = nn.ConvTranspose2d(
                channels_high, channels_low, kernel_size=4, stride=2, padding=1, bias=False)
            self.bn_upsample = nn.BatchNorm2d(channels_low)
        else:
            self.conv_reduction = nn.Conv2d(
                channels_high, channels_low, kernel_size=1, padding=0, bias=False)
            self.bn_reduction = nn.BatchNorm2d(channels_low)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_high, fms_low, fm_mask=None):
        """
        Weight the low-level features with a gate derived from the high-level
        features and add the (upsampled or channel-reduced) high-level features.

        :param fms_high: Features of high level. Tensor of shape [b, c, h, w].
        :param fms_low: Features of low level. Tensor.
        :param fm_mask: Unused.
        :return: Fused feature map with ``channels_low`` channels.
        """
        _, c, _, _ = fms_high.shape

        # Global descriptor of the high-level map -> per-channel gate.
        pooled = nn.AvgPool2d(fms_high.shape[2:])(fms_high)
        gate = pooled.view(len(fms_high), c, 1, 1)
        gate = self.relu(self.bn_high(self.conv1x1(gate)))

        low = self.bn_low(self.conv3x3(fms_low))
        gated_low = low * gate

        if self.upsample:
            high = self.bn_upsample(self.conv_upsample(fms_high))
        else:
            high = self.bn_reduction(self.conv_reduction(fms_high))

        return self.relu(high + gated_low)
    
    
class PAN(nn.Module):
    """Pyramid Attention Network decoder: one FPA followed by three GAU stages."""

    def __init__(self, blocks=[]):
        """
        :param blocks: Blocks of the network with reverse sequential.
        """
        super().__init__()
        # Channel count per stage, read from the weight of the 5th child of
        # each stage's 3rd sub-block.
        # NOTE(review): presumably a norm/conv layer whose first weight
        # dimension equals the stage's output channels - confirm against the
        # backbone definition.
        channels_blocks = [
            list(list(block.children())[2].children())[4].weight.shape[0]
            for block in blocks
        ]

        self.fpa = FPA(channels=channels_blocks[0])
        # The first GAU keeps the spatial size; the following two upsample.
        self.gau_block1 = GAU(channels_blocks[0], channels_blocks[1], upsample=False)
        self.gau_block2 = GAU(channels_blocks[1], channels_blocks[2])
        self.gau_block3 = GAU(channels_blocks[2], channels_blocks[3])
        self.gau = [self.gau_block1, self.gau_block2, self.gau_block3]

        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms=[]):
        """
        :param fms: Feature maps of forward propagation in the network with reverse sequential. shape:[b, c, h, w]
        :return: fm_high, the output of the last stage applied.
        """
        for idx, fm_low in enumerate(fms):
            # The first map goes through the FPA; each later one is fused by its GAU.
            fm_high = self.fpa(fm_low) if idx == 0 else self.gau[int(idx - 1)](fm_high, fm_low)

        return fm_high

class PAN_final(nn.Module):
    """ResNet-34 + PAN decoder + linear classifier.

    The backbone's stage feature maps are passed (deepest first) through
    ``PAN``; the decoder output is globally average-pooled and classified.
    The decoder output and its gradient are cached for Grad-CAM style
    inspection.

    Args:
        num_class: number of output logits.
    """

    def __init__(self, num_class=8):
        super(PAN_final, self).__init__()

        self.convnet = ResNet34(pretrained=True)
        self.pan = PAN(self.convnet.blocks[::-1])
        # NOTE(review): in_features=64 must equal the channel count of the PAN
        # output (the last GAU's channels_low) - confirm if the backbone changes.
        self.cls = Classifier(in_features=64, num_class=num_class)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feat = None
        self.gradients = None

    def activations_hook(self, grad):
        """Backward hook: remember the gradient flowing into the decoder output."""
        self.gradients = grad

    def forward(self, x):
        """Compute class logits and cache the decoder output in ``self.feat``."""
        fms_blob, _ = self.convnet(x)
        out_ss = self.pan(fms_blob[::-1])
        self.feat = out_ss
        # register_hook raises on tensors that do not require grad
        # (e.g. inference under torch.no_grad()), so only hook when possible.
        if out_ss.requires_grad:
            out_ss.register_hook(self.activations_hook)
        out = self.avgpool(out_ss)
        out = out.view(out.size(0), -1)
        out = self.cls(out)
        return out

    def get_activations_gradient(self):
        """Return the gradient captured by ``activations_hook`` (None before any backward pass)."""
        return self.gradients

    def get_activations(self):
        """Return the decoder output cached by the last ``forward`` call."""
        return self.feat

## CBAM_DA

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

class PAM_Module(nn.Module):
    """Position (spatial) self-attention module in the style of DANet.

    Every spatial position attends to every other position. Queries and keys
    are 1x1-conv projections to ``in_dim // 4`` channels, values are a 1x1-conv
    projection that keeps ``in_dim`` channels. The attended features are added
    back to the input, scaled by the learnable scalar ``gamma`` (initialised
    to zero, so the module starts out as the identity).

    Args:
        in_dim: number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim
        self.query_conv = nn.Conv2d(in_dim, in_dim // 4, kernel_size=1, stride=1, padding='valid')
        self.key_conv = nn.Conv2d(in_dim, in_dim // 4, kernel_size=1, stride=1, padding='valid')
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding='valid')
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply position attention.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W): ``gamma * attended + x``.
        """
        m_batchsize, C, height, width = x.size()
        num_positions = width * height
        # (B, N, C') and (B, C', N) -> pairwise position affinities (B, N, N).
        proj_query = self.query_conv(x).view(m_batchsize, -1, num_positions).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, num_positions)
        energy = torch.bmm(proj_query, proj_key)
        # attention[b, i, j]: weight of source position j for target position i.
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, num_positions)
        # Multiply as (B, C, N) x (B, N, N)^T so the result is laid out
        # channel-major. The previous (B, N, C) product was reinterpreted with
        # .view() as (B, C, H, W), which mixed channels with positions.
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x
        return out

class CAM_Module(nn.Module):
    """Channel self-attention module.

    Builds a (C, C) channel-affinity matrix from the flattened input, turns it
    into attention weights with a softmax over ``row_max - affinity``, and adds
    the re-weighted channels back to the input scaled by the learnable scalar
    ``gamma`` (initialised to zero).
    """

    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * channel_attended(x) + x`` for ``x`` of shape (B, C, H, W)."""
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, -1)

        # Channel-to-channel affinities, shape (B, C, C).
        affinity = torch.bmm(flat, flat.permute(0, 2, 1))
        row_max = affinity.max(dim=-1, keepdim=True)[0]
        weights = self.softmax(row_max - affinity)

        attended = torch.bmm(weights, flat).view(batch, channels, height, width)
        return self.gamma * attended + x


class BasicBlock(nn.Module):
    """Two-convolution residual block with DANet-style attention.

    After the second conv + batch norm, the features go through ``self.ca``
    (a ``PAM_Module``, i.e. position attention) and then ``self.sa`` (a
    ``CAM_Module``, i.e. channel attention) before the shortcut is added and
    the final ReLU is applied. Note that the attribute names do not match the
    kind of attention they hold.

    Args:
        inplanes: input channels.
        planes: output channels.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Attention modules applied to the residual branch.
        self.ca = PAM_Module(planes)
        self.sa = CAM_Module()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the activated sum of branch and shortcut."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Position attention first, then channel attention.
        out = self.sa(self.ca(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class CBAM_DA(nn.Module):
    """ResNet-34 backbone followed by one extra stage of attention ``BasicBlock``s.

    The pretrained torchvision ResNet-34 (minus its last two children) feeds a
    stage of two ``BasicBlock``s (stride 2, 512 channels), global average
    pooling and a linear classifier. The output of the extra stage and its
    gradient are cached for Grad-CAM style inspection.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=8):
        super(CBAM_DA, self).__init__()
        # Channel count entering the next block built by _make_layer.
        self.inplanes = 512
        base_model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
        self.backbone = torch.nn.Sequential(*list(base_model.children())[:-2])
        self.layer = self._make_layer(BasicBlock, 512, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        self.feat = None
        self.gradients = None

    def activations_hook(self, grad):
        """Backward hook: remember the gradient flowing into the cached feature map."""
        self.gradients = grad

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a stage of ``blocks`` residual blocks.

        A 1x1 conv + batch norm shortcut is created for the first block when
        the stride is not 1 or the channel count changes. Updates
        ``self.inplanes`` to the stage's output channels.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class logits and cache the last feature map in ``self.feat``."""
        x = self.backbone(x)
        x = self.layer(x)
        self.feat = x
        # register_hook raises on tensors that do not require grad
        # (e.g. inference under torch.no_grad()), so only hook when possible.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        output = self.avgpool(x)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output

    def get_activations_gradient(self):
        """Return the gradient captured by ``activations_hook`` (None before any backward pass)."""
        return self.gradients

    def get_activations(self):
        """Return the feature map cached by the last ``forward`` call."""
        return self.feat

## PAN_DA

In [None]:
import resnet

class ResNet50(nn.Module):
    """Feature extractor around ``resnet.resnet50`` that exposes per-stage feature maps.

    ``self.blocks`` holds children 4-7 of the wrapped model.
    NOTE(review): these are presumably layer1..layer4 of the local ``resnet``
    module - confirm against that module.
    """

    def __init__(self, pretrained=True):
        """Build the wrapped model and collect the stages used in ``forward``."""
        super().__init__()
        self.model = resnet.resnet50(pretrained=pretrained)
        self.relu = self.model.relu  # alias of the stem ReLU (original note: "Place a hook")

        children = list(self.model.children())
        self.blocks = [children[idx] for idx in (4, 5, 6, 7)]

    def forward(self, x):
        """Return ``(feature_maps, pooled)``.

        ``feature_maps`` is the list of outputs of each entry in
        ``self.blocks`` (in order); ``pooled`` is the last feature map averaged
        over its full spatial extent and flattened to (B, C).
        """
        for stem_layer in (self.model.conv1, self.model.bn1,
                           self.model.relu, self.model.maxpool):
            x = stem_layer(x)

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        out = pooled.view(x.shape[0], -1)
        return feature_map, out
    
class ResNet34(nn.Module):
    """Feature extractor around ``resnet.resnet34`` that exposes per-stage feature maps.

    ``self.blocks`` holds children 4-7 of the wrapped model.
    NOTE(review): these are presumably layer1..layer4 of the local ``resnet``
    module - confirm against that module.
    """

    def __init__(self, pretrained=True):
        """Build the wrapped model and collect the stages used in ``forward``."""
        super().__init__()
        self.model = resnet.resnet34(pretrained=pretrained)
        self.relu = self.model.relu  # alias of the stem ReLU (original note: "Place a hook")

        children = list(self.model.children())
        self.blocks = [children[idx] for idx in (4, 5, 6, 7)]

    def forward(self, x):
        """Return ``(feature_maps, pooled)``.

        ``feature_maps`` is the list of outputs of each entry in
        ``self.blocks`` (in order); ``pooled`` is the last feature map averaged
        over its full spatial extent and flattened to (B, C).
        """
        for stem_layer in (self.model.conv1, self.model.bn1,
                           self.model.relu, self.model.maxpool):
            x = stem_layer(x)

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        out = pooled.view(x.shape[0], -1)
        return feature_map, out

class Classifier(nn.Module):
    """Single linear layer mapping ``in_features`` to ``num_class`` logits.

    A ``relu`` module is created as well but is not applied in ``forward``.
    """

    def __init__(self, in_features=2048, num_class=20):
        super().__init__()
        self.fc1 = nn.Linear(in_features, num_class)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the raw logits ``fc1(x)``."""
        logits = self.fc1(x)
        return logits

class FPA(nn.Module):
    """Feature Pyramid Attention.

    A 1x1 "master" projection of the input is multiplied by an attention map
    produced by a three-level pyramid (7x7, 5x5, 3x3 convolutions, each level
    downsampling by stride 2 and merged back with transposed convolutions).
    A globally pooled branch is then added.

    The additions in the merge path require the transposed-conv outputs to
    match the pyramid sizes. NOTE(review): with the layer settings below this
    works out for e.g. 14x14 inputs; other input sizes should be checked.

    :type channels: int
    """

    def __init__(self, channels=2048):
        super().__init__()
        channels_mid = int(channels / 4)
        self.channels_cond = channels

        # Master branch: channel-preserving 1x1 projection.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch: 1x1 projection of the pooled descriptor.
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_gpb = nn.BatchNorm2d(channels)

        # Pyramid, downsampling path (stride 2 at every level).
        down = dict(stride=2, bias=False)
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), padding=3, **down)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), padding=2, **down)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), padding=1, **down)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Pyramid, size-preserving refinement at each level.
        keep = dict(stride=1, bias=False)
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), padding=3, **keep)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), padding=2, **keep)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), padding=1, **keep)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Transposed convolutions used to merge the levels back up.
        self.conv_upsample_3 = nn.ConvTranspose2d(
            channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(
            channels_mid, channels_mid, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_1 = nn.ConvTranspose2d(
            channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        # Master branch.
        master = self.bn_master(self.conv_master(x))

        # Global pooling branch: average over the full spatial extent.
        pooled = nn.AvgPool2d(x.shape[2:])(x)
        pooled = pooled.view(x.shape[0], self.channels_cond, 1, 1)
        global_ctx = self.bn_gpb(self.conv_gpb(pooled))

        # Level 1 (7x7).
        level1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        level1_refined = self.bn1_2(self.conv7x7_2(level1))

        # Level 2 (5x5), fed by level 1.
        level2 = self.relu(self.bn2_1(self.conv5x5_1(level1)))
        level2_refined = self.bn2_2(self.conv5x5_2(level2))

        # Level 3 (3x3), fed by level 2.
        level3 = self.relu(self.bn3_1(self.conv3x3_1(level2)))
        level3_refined = self.bn3_2(self.conv3x3_2(level3))

        # Merge from the coarsest level upwards.
        up3 = self.relu(self.bn_upsample_3(self.conv_upsample_3(level3_refined)))
        merged2 = self.relu(level2_refined + up3)
        up2 = self.relu(self.bn_upsample_2(self.conv_upsample_2(merged2)))
        merged1 = self.relu(level1_refined + up2)

        # Pyramid output gates the master branch; the global context is added last.
        attention = self.relu(self.bn_upsample_1(self.conv_upsample_1(merged1)))
        master = master * attention

        return self.relu(master + global_ctx)

class GAU(nn.Module):
    """Global Attention Upsample.

    A globally pooled descriptor of the high-level features gates the
    (3x3-convolved) low-level features; the result is added to the high-level
    features after they have been brought to ``channels_low`` channels, either
    by a stride-2 transposed convolution (``upsample=True``) or by a 1x1
    convolution (``upsample=False``).

    Args:
        channels_high: channels of the high-level feature map.
        channels_low: channels of the low-level feature map (and of the output).
        upsample: whether the high-level map is upsampled before the addition.
    """

    def __init__(self, channels_high, channels_low, upsample=True):
        super().__init__()
        self.upsample = upsample

        # Low-level path.
        self.conv3x3 = nn.Conv2d(channels_low, channels_low, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_low)

        # Gate computed from the pooled high-level features.
        self.conv1x1 = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
        self.bn_high = nn.BatchNorm2d(channels_low)

        # High-level path: only one of the two variants is created.
        if upsample:
            self.conv_upsample = nn.ConvTranspose2d(
                channels_high, channels_low, kernel_size=4, stride=2, padding=1, bias=False)
            self.bn_upsample = nn.BatchNorm2d(channels_low)
        else:
            self.conv_reduction = nn.Conv2d(
                channels_high, channels_low, kernel_size=1, padding=0, bias=False)
            self.bn_reduction = nn.BatchNorm2d(channels_low)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_high, fms_low, fm_mask=None):
        """
        Weight the low-level features with a gate derived from the high-level
        features and add the (upsampled or channel-reduced) high-level features.

        :param fms_high: Features of high level. Tensor of shape [b, c, h, w].
        :param fms_low: Features of low level. Tensor.
        :param fm_mask: Unused.
        :return: Fused feature map with ``channels_low`` channels.
        """
        _, c, _, _ = fms_high.shape

        # Global descriptor of the high-level map -> per-channel gate.
        pooled = nn.AvgPool2d(fms_high.shape[2:])(fms_high)
        gate = pooled.view(len(fms_high), c, 1, 1)
        gate = self.relu(self.bn_high(self.conv1x1(gate)))

        low = self.bn_low(self.conv3x3(fms_low))
        gated_low = low * gate

        if self.upsample:
            high = self.bn_upsample(self.conv_upsample(fms_high))
        else:
            high = self.bn_reduction(self.conv_reduction(fms_high))

        return self.relu(high + gated_low)


class PAM_Module(nn.Module):
    """
    Position (spatial) self-attention module.

    Every spatial position attends to every other position: queries and keys are
    1x1-projected to ``in_dim // 8`` channels, their dot products are soft-maxed over
    the key positions, and the resulting weights aggregate the 1x1-projected values.
    The aggregated map is scaled by the learnable ``gamma`` (initialised to zero, so
    the module starts as an identity mapping) and added back to the input.

    :param in_dim: Number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim
        self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1, stride=1, padding='valid')
        self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1, stride=1, padding='valid')
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding='valid')
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: Input feature maps, shape ``[B, C, H, W]``.
        :return: ``gamma * attention_output + x`` with the same shape as ``x``.
        """
        m_batchsize, C, height, width = x.size()
        num_positions = width * height

        # [B, N, C'] x [B, C', N] -> [B, N, N]; row i holds the scores of query position i.
        proj_query = self.query_conv(x).view(m_batchsize, -1, num_positions).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, num_positions)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)

        # Keep the values channel-major ([B, C, N]) and multiply by attention^T so the
        # product is already [B, C, N]. The previous code produced [B, N, C] and viewed
        # it directly as [B, C, H, W], which interleaved channel and position indices.
        proj_value = self.value_conv(x).view(m_batchsize, -1, num_positions)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out

class CAM_Module(nn.Module):
    """
    Channel self-attention module.

    Channel-to-channel affinities are computed from the flattened feature map itself
    (no learned projections). Each affinity row is replaced by ``row_max - affinity``
    before the softmax, and the resulting weights mix the channels. The mixed map is
    scaled by the learnable ``gamma`` (initialised to zero) and added to the input.
    """

    def __init__(self):
        super(CAM_Module, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: Input feature maps, shape ``[B, C, H, W]``.
        :return: ``gamma * channel_attention_output + x`` with the same shape as ``x``.
        """
        batch, channels, rows, cols = x.size()
        flat = x.view(batch, channels, -1)  # [B, C, N]

        # [B, C, N] x [B, N, C] -> [B, C, C]
        affinity = torch.bmm(flat, flat.permute(0, 2, 1))
        row_max = affinity.max(dim=-1, keepdim=True)[0]
        weights = self.softmax(row_max - affinity)

        mixed = torch.bmm(weights, flat).view(batch, channels, rows, cols)
        return self.gamma * mixed + x
    
class PAN(nn.Module):
    """
    Top-down decoder: an FPA head on the first feature map followed by three GAU stages.
    """

    def __init__(self, blocks=[]):
        """
        :param blocks: Blocks of the network with reverse sequential.
        """
        super(PAN, self).__init__()

        # Output width of each block, read from the ``weight`` of the 5th child of the
        # block's 3rd sub-module (so every block needs at least three sub-modules).
        channels_blocks = [
            list(list(stage.children())[2].children())[4].weight.shape[0]
            for stage in blocks
        ]

        self.fpa = FPA(channels=channels_blocks[0])
        self.gau_block1 = GAU(channels_blocks[0], channels_blocks[1], upsample=False)
        self.gau_block2 = GAU(channels_blocks[1], channels_blocks[2])
        self.gau_block3 = GAU(channels_blocks[2], channels_blocks[3])
        # Plain list used only for indexing in forward(); the modules themselves are
        # registered through the gau_block* attributes above.
        self.gau = [self.gau_block1, self.gau_block2, self.gau_block3]

        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms=[]):
        """
        :param fms: Feature maps of forward propagation in the network with reverse sequential. shape:[b, c, h, w]
        :return: Output of the last stage that was applied.
        """
        for idx, fm in enumerate(fms):
            # The first map goes through FPA; each later map is fused by the matching GAU.
            fm_high = self.gau[idx - 1](fm_high, fm) if idx else self.fpa(fm)

        return fm_high

class PAN_DA(nn.Module):
    """
    Classifier combining a PAN branch with a dual-attention (position + channel) branch.

    * PAN branch: the backbone feature maps (deepest first) go through ``PAN``; the
      result is globally average-pooled and classified by ``cls``.
    * Dual-attention branch: the last backbone feature map goes through a position
      attention path (``conv5a`` -> ``sa`` -> ``conv51`` -> ``conv6``) and a channel
      attention path (``conv5c`` -> ``sc`` -> ``conv52`` -> ``conv7``); their sum goes
      through ``conv8``, is pooled and classified by ``fc1``.
    * The two logit vectors are concatenated and mapped to the final logits by ``fcf``.

    The PAN output and (when autograd is active) its gradient are stored so they can be
    read back through ``get_activations`` / ``get_activations_gradient``.

    :param in_channels: Channel count of the last backbone feature map.
    :param num_class: Number of output classes.
    """

    def __init__(self, in_channels=512, num_class=8):
        super(PAN_DA, self).__init__()

        self.convnet = ResNet34(pretrained=True)
        self.pan = PAN(self.convnet.blocks[::-1])
        inter_channels = in_channels // 4
        out_channels = in_channels // 4
        self.conv5a = nn.Conv2d(in_channels, inter_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.relu = nn.ReLU()
        self.conv5c = nn.Conv2d(in_channels, inter_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(inter_channels)
        self.sa = PAM_Module(inter_channels)
        self.sc = CAM_Module()
        self.conv51 = nn.Conv2d(inter_channels, inter_channels, kernel_size=3, padding=1)
        self.conv52 = nn.Conv2d(inter_channels, inter_channels, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(inter_channels)
        self.bn4 = nn.BatchNorm2d(inter_channels)
        self.dropout = nn.Dropout2d(0.1)
        self.conv6 = nn.Conv2d(inter_channels, out_channels, kernel_size=1)
        self.conv7 = nn.Conv2d(inter_channels, out_channels, kernel_size=1)
        self.conv8 = nn.Conv2d(inter_channels, out_channels, kernel_size=1)

        # NOTE(review): in_features=64 is hard-coded; presumably the channel count of the
        # PAN output for this backbone -- confirm before swapping the backbone.
        self.cls = Classifier(in_features=64, num_class=num_class)
        self.avgpool1 = nn.AdaptiveAvgPool2d((1, 1))
        self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = Classifier(in_features=out_channels, num_class=num_class)
        self.fcf = Classifier(in_features=2*num_class, num_class=num_class)
        self.feat = None
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the PAN output."""
        self.gradients = grad

    def forward(self, x):
        """
        :param x: Input batch passed to the backbone.
        :return: Fused class logits produced by ``fcf``.
        """
        fms_blob, _ = self.convnet(x)
        out_ss = self.pan(fms_blob[::-1])
        self.feat = out_ss
        # register_hook raises RuntimeError on tensors that do not require grad
        # (e.g. inside torch.no_grad()), so only attach it when autograd is tracking.
        if out_ss.requires_grad:
            out_ss.register_hook(self.activations_hook)
        out1 = self.avgpool1(out_ss)
        out1 = out1.view(out1.size(0), -1)
        out1 = self.cls(out1)

        # Position-attention path on the last backbone feature map.
        feat1 = self.relu(self.bn1(self.conv5a(fms_blob[-1])))
        sa_feat = self.sa(feat1)
        sa_conv = self.relu(self.bn3(self.conv51(sa_feat)))
        sa_output = self.conv6(self.dropout(sa_conv))

        # Channel-attention path on the same feature map.
        feat2 = self.relu(self.bn2(self.conv5c(fms_blob[-1])))
        sc_feat = self.sc(feat2)
        sc_conv = self.relu(self.bn4(self.conv52(sc_feat)))
        sc_output = self.conv7(self.dropout(sc_conv))

        feat_sum = sa_output + sc_output
        sasc_output = self.conv8(self.dropout(feat_sum))
        out2 = self.avgpool2(sasc_output)
        out2 = out2.view(out2.shape[0], -1)
        out2 = self.fc1(out2)

        out3 = self.fcf(torch.cat([out1, out2], dim=1))

        return out3

    def get_activations_gradient(self):
        """Return the gradient captured by the hook (``None`` until a backward pass ran)."""
        return self.gradients

    def get_activations(self):
        """Return the PAN output stored by the most recent forward pass."""
        return self.feat

## PAN_Inverse

In [None]:
import resnet

class ResNet50(nn.Module):
    """
    Wrapper around ``resnet.resnet50`` that also returns the intermediate feature maps.

    ``blocks`` holds the backbone children at positions 4-7 (presumably the four residual
    stages -- this depends on the project-local ``resnet`` module).
    """

    def __init__(self, pretrained=True):
        """Declare all needed layers."""
        super(ResNet50, self).__init__()
        self.model = resnet.resnet50(pretrained=pretrained)
        self.relu = self.model.relu  # Place a hook

        children = list(self.model.children())
        self.blocks = [children[position] for position in (4, 5, 6, 7)]

    def forward(self, x):
        """
        :param x: Input batch.
        :return: ``(feature_map, out)`` -- the list of per-block outputs and the last
                 output average-pooled over its full spatial size, flattened per sample.
        """
        backbone = self.model
        x = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        return feature_map, pooled.view(x.shape[0], -1)
    
class ResNet34(nn.Module):
    """
    Wrapper around ``resnet.resnet34`` that also returns the intermediate feature maps.

    ``blocks`` holds the backbone children at positions 4-7 (presumably the four residual
    stages -- this depends on the project-local ``resnet`` module).
    """

    def __init__(self, pretrained=True):
        """Declare all needed layers."""
        super(ResNet34, self).__init__()
        self.model = resnet.resnet34(pretrained=pretrained)
        self.relu = self.model.relu  # Place a hook

        children = list(self.model.children())
        self.blocks = [children[position] for position in (4, 5, 6, 7)]

    def forward(self, x):
        """
        :param x: Input batch.
        :return: ``(feature_map, out)`` -- the list of per-block outputs and the last
                 output average-pooled over its full spatial size, flattened per sample.
        """
        backbone = self.model
        x = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        return feature_map, pooled.view(x.shape[0], -1)


class Classifier(nn.Module):
    """
    Single fully-connected layer mapping ``in_features`` inputs to ``num_class`` logits.

    A ReLU module is created as ``self.relu`` but is not applied in ``forward``.
    """

    def __init__(self, in_features=2048, num_class=20):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=num_class)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the raw (un-activated) output of the linear layer."""
        return self.fc1(x)

class FPA(nn.Module):
    """
    Feature Pyramid Attention.

    A three-level strided pyramid (7x7, 5x5, 3x3 convolutions, each with stride 2) builds
    a spatial attention map that multiplies a 1x1 "master" projection of the input; a
    globally pooled branch is added on top.

    The transposed convolutions only add up with the pyramid levels for matching spatial
    sizes: for a 14x14 input the levels are 7x7, 4x4 and 2x2 and every sum lines up
    (``conv_upsample_2`` uses stride 1 / padding 0, i.e. +3 pixels, to go from 4 to 7).
    Other input sizes may not line up.
    """

    def __init__(self, channels=2048):
        """
        :type channels: int
        :param channels: Channel count of both the input and the output feature map.
        """
        super(FPA, self).__init__()
        channels_mid = int(channels/4)

        self.channels_cond = channels

        # Master branch: 1x1 projection of the input.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch.
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_gpb = nn.BatchNorm2d(channels)

        # Downsampling path of the pyramid (stride 2 at every level).
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=5, stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Size-preserving refinement at each pyramid level.
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Transposed convolutions that bring each level back up to the previous one.
        self.conv_upsample_3 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_1 = nn.ConvTranspose2d(channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        master = self.bn_master(self.conv_master(x))

        # One value per channel: the pooling window is the whole spatial extent.
        pooled = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], self.channels_cond, 1, 1)
        global_ctx = self.bn_gpb(self.conv_gpb(pooled))

        # Pyramid level 1
        level1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        level1_refined = self.bn1_2(self.conv7x7_2(level1))

        # Pyramid level 2
        level2 = self.relu(self.bn2_1(self.conv5x5_1(level1)))
        level2_refined = self.bn2_2(self.conv5x5_2(level2))

        # Pyramid level 3
        level3 = self.relu(self.bn3_1(self.conv3x3_1(level2)))
        level3_refined = self.bn3_2(self.conv3x3_2(level3))

        # Merge coarse-to-fine.
        up3 = self.relu(self.bn_upsample_3(self.conv_upsample_3(level3_refined)))
        merged2 = self.relu(level2_refined + up3)
        up2 = self.relu(self.bn_upsample_2(self.conv_upsample_2(merged2)))
        merged1 = self.relu(level1_refined + up2)

        attention = self.relu(self.bn_upsample_1(self.conv_upsample_1(merged1)))
        master = master * attention

        return self.relu(master + global_ctx)


class GAD(nn.Module):
    """
    Global-attention fusion that moves from a low-level map towards a high-level one.

    The low-level map is globally pooled into a per-channel gate that multiplies the
    3x3-convolved high-level map; a convolved copy of the low-level map (stride 2 when
    ``downsample`` is true, size-preserving otherwise) is added before a final
    BatchNorm + ReLU.

    :param channels_low: Channel count of ``fms_low``.
    :param channels_high: Channel count of ``fms_high`` and of the output.
    :param downsample: Use the 4x4 / stride-2 convolution on the low-level path when true,
                       a 3x3 / stride-1 convolution otherwise.
    """

    def __init__(self, channels_low, channels_high, downsample=True):
        super(GAD, self).__init__()
        self.downsample = downsample

        # NOTE: despite the names, bn_low normalises the pooled low-level gate and
        # bn_high normalises the convolved high-level map (see forward).
        self.conv3x3 = nn.Conv2d(channels_high, channels_high, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_high)

        self.conv1x1 = nn.Conv2d(channels_low, channels_high, kernel_size=1, padding=0, bias=False)
        self.bn_high = nn.BatchNorm2d(channels_high)

        down_cfg = dict(kernel_size=4, stride=2, padding=1) if self.downsample else dict(kernel_size=3, padding=1)
        self.conv_down = nn.Conv2d(channels_low, channels_high, bias=False, **down_cfg)
        self.bn_down = nn.BatchNorm2d(channels_high)

        self.bn_final = nn.BatchNorm2d(channels_high)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_low, fms_high, fm_mask=None):
        """
        :param fms_low: Low-level feature maps, unpacked as ``[b, c, h, w]``.
        :param fms_high: High-level feature maps.
        :param fm_mask: Not used by this method.
        :return: Fused feature maps after the final BatchNorm and ReLU.
        """
        _, c, _, _ = fms_low.shape

        # Global average pooling of the low-level map -> one gate value per output channel.
        pooled = nn.AvgPool2d(fms_low.shape[2:])(fms_low).view(len(fms_low), c, 1, 1)
        gate = self.relu(self.bn_low(self.conv1x1(pooled)))

        high_refined = self.bn_high(self.conv3x3(fms_high))
        gated = high_refined * gate

        low_path = self.bn_down(self.conv_down(fms_low))
        return self.relu(self.bn_final(gated + low_path))
    
class PAN(nn.Module):
    """
    Bottom-up ("inverse") decoder: three GAD stages followed by an FPA head.
    """

    def __init__(self, blocks=[]):
        """
        :param blocks: Backbone blocks in forward order (shallowest first).
        """
        super(PAN, self).__init__()

        # Output width of each block, read from the ``weight`` of the 5th child of the
        # block's 3rd sub-module.
        channels_blocks = [
            list(list(stage.children())[2].children())[4].weight.shape[0]
            for stage in blocks
        ]

        self.fpa = FPA(channels=channels_blocks[-1])
        self.gau_block1 = GAD(channels_blocks[0], channels_blocks[1])
        self.gau_block2 = GAD(channels_blocks[1], channels_blocks[2])
        self.gau_block3 = GAD(channels_blocks[2], channels_blocks[3], downsample=False)
        # Plain list used only for indexing in forward(); registration happens through
        # the gau_block* attributes above.
        self.gau = [self.gau_block1, self.gau_block2, self.gau_block3]

        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms=[]):
        """
        :param fms: Backbone feature maps in forward order.
        :return: FPA output computed on the last fused map.
        """
        last = len(fms) - 1
        for idx, fm in enumerate(fms):
            if idx == 0 and idx != last:
                # The first map only seeds the chain.
                fm_high = fm
                continue
            fm_high = self.gau[idx - 1](fm_high, fm)
            if idx == last:
                fm_high = self.fpa(fm_high)

        return fm_high

class PAN_Inverse(nn.Module):
    """
    ResNet34 backbone + bottom-up PAN decoder + linear classifier.

    The PAN output and (when autograd is active) its gradient are stored so they can be
    read back through ``get_activations`` / ``get_activations_gradient``.

    :param num_class: Number of output classes.
    """

    def __init__(self, num_class=8):
        super(PAN_Inverse, self).__init__()

        self.convnet = ResNet34(pretrained=True)
        self.pan = PAN(self.convnet.blocks)
        # NOTE(review): in_features=512 is hard-coded; presumably the channel count of
        # the last backbone block -- confirm before swapping the backbone.
        self.cls = Classifier(in_features=512, num_class=num_class)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feat = None
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the PAN output."""
        self.gradients = grad

    def forward(self, x):
        """
        :param x: Input batch passed to the backbone.
        :return: Class logits.
        """
        fms_blob, _ = self.convnet(x)
        out_ss = self.pan(fms_blob)
        self.feat = out_ss
        # register_hook raises RuntimeError on tensors that do not require grad
        # (e.g. inside torch.no_grad()), so only attach it when autograd is tracking.
        if out_ss.requires_grad:
            out_ss.register_hook(self.activations_hook)
        out = self.avgpool(out_ss)
        out = out.view(out.size(0), -1)
        out = self.cls(out)
        return out

    def get_activations_gradient(self):
        """Return the gradient captured by the hook (``None`` until a backward pass ran)."""
        return self.gradients

    def get_activations(self):
        """Return the PAN output stored by the most recent forward pass."""
        return self.feat

## PAN_DualInverse

In [None]:
import resnet

class ResNet50(nn.Module):
    """
    Wrapper around ``resnet.resnet50`` that also returns the intermediate feature maps.

    ``blocks`` holds the backbone children at positions 4-7 (presumably the four residual
    stages -- this depends on the project-local ``resnet`` module).
    """

    def __init__(self, pretrained=True):
        """Declare all needed layers."""
        super(ResNet50, self).__init__()
        self.model = resnet.resnet50(pretrained=pretrained)
        self.relu = self.model.relu  # Place a hook

        children = list(self.model.children())
        self.blocks = [children[position] for position in (4, 5, 6, 7)]

    def forward(self, x):
        """
        :param x: Input batch.
        :return: ``(feature_map, out)`` -- the list of per-block outputs and the last
                 output average-pooled over its full spatial size, flattened per sample.
        """
        backbone = self.model
        x = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        return feature_map, pooled.view(x.shape[0], -1)
    
class ResNet34(nn.Module):
    """
    Wrapper around ``resnet.resnet34`` that also returns the intermediate feature maps.

    ``blocks`` holds the backbone children at positions 4-7 (presumably the four residual
    stages -- this depends on the project-local ``resnet`` module).
    """

    def __init__(self, pretrained=True):
        """Declare all needed layers."""
        super(ResNet34, self).__init__()
        self.model = resnet.resnet34(pretrained=pretrained)
        self.relu = self.model.relu  # Place a hook

        children = list(self.model.children())
        self.blocks = [children[position] for position in (4, 5, 6, 7)]

    def forward(self, x):
        """
        :param x: Input batch.
        :return: ``(feature_map, out)`` -- the list of per-block outputs and the last
                 output average-pooled over its full spatial size, flattened per sample.
        """
        backbone = self.model
        x = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))

        feature_map = []
        for stage in self.blocks:
            x = stage(x)
            feature_map.append(x)

        pooled = nn.AvgPool2d(x.shape[2:])(x)
        return feature_map, pooled.view(x.shape[0], -1)

class Classifier(nn.Module):
    """
    Single fully-connected layer mapping ``in_features`` inputs to ``num_class`` logits.

    A ReLU module is created as ``self.relu`` but is not applied in ``forward``.
    """

    def __init__(self, in_features=2048, num_class=20):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=num_class)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the raw (un-activated) output of the linear layer."""
        return self.fc1(x)

class FPA(nn.Module):
    """
    Feature Pyramid Attention.

    A three-level strided pyramid (7x7, 5x5, 3x3 convolutions, each with stride 2) builds
    a spatial attention map that multiplies a 1x1 "master" projection of the input; a
    globally pooled branch is added on top.

    The transposed convolutions only add up with the pyramid levels for matching spatial
    sizes: for a 14x14 input the levels are 7x7, 4x4 and 2x2 and every sum lines up
    (``conv_upsample_2`` uses stride 1 / padding 0, i.e. +3 pixels, to go from 4 to 7).
    Other input sizes may not line up.
    """

    def __init__(self, channels=2048):
        """
        :type channels: int
        :param channels: Channel count of both the input and the output feature map.
        """
        super(FPA, self).__init__()
        channels_mid = int(channels/4)

        self.channels_cond = channels

        # Master branch: 1x1 projection of the input.
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch.
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_gpb = nn.BatchNorm2d(channels)

        # Downsampling path of the pyramid (stride 2 at every level).
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=5, stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Size-preserving refinement at each pyramid level.
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Transposed convolutions that bring each level back up to the previous one.
        self.conv_upsample_3 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_3 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_2 = nn.ConvTranspose2d(channels_mid, channels_mid, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn_upsample_2 = nn.BatchNorm2d(channels_mid)

        self.conv_upsample_1 = nn.ConvTranspose2d(channels_mid, channels, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn_upsample_1 = nn.BatchNorm2d(channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        master = self.bn_master(self.conv_master(x))

        # One value per channel: the pooling window is the whole spatial extent.
        pooled = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], self.channels_cond, 1, 1)
        global_ctx = self.bn_gpb(self.conv_gpb(pooled))

        # Pyramid level 1
        level1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        level1_refined = self.bn1_2(self.conv7x7_2(level1))

        # Pyramid level 2
        level2 = self.relu(self.bn2_1(self.conv5x5_1(level1)))
        level2_refined = self.bn2_2(self.conv5x5_2(level2))

        # Pyramid level 3
        level3 = self.relu(self.bn3_1(self.conv3x3_1(level2)))
        level3_refined = self.bn3_2(self.conv3x3_2(level3))

        # Merge coarse-to-fine.
        up3 = self.relu(self.bn_upsample_3(self.conv_upsample_3(level3_refined)))
        merged2 = self.relu(level2_refined + up3)
        up2 = self.relu(self.bn_upsample_2(self.conv_upsample_2(merged2)))
        merged1 = self.relu(level1_refined + up2)

        attention = self.relu(self.bn_upsample_1(self.conv_upsample_1(merged1)))
        master = master * attention

        return self.relu(master + global_ctx)

class GAD(nn.Module):
    """
    Global-attention fusion that moves from a low-level map towards a high-level one.

    The low-level map is globally pooled into a per-channel gate that multiplies the
    3x3-convolved high-level map; a convolved copy of the low-level map (stride 2 when
    ``downsample`` is true, size-preserving otherwise) is added before a final
    BatchNorm + ReLU.

    :param channels_low: Channel count of ``fms_low``.
    :param channels_high: Channel count of ``fms_high`` and of the output.
    :param downsample: Use the 4x4 / stride-2 convolution on the low-level path when true,
                       a 3x3 / stride-1 convolution otherwise.
    """

    def __init__(self, channels_low, channels_high, downsample=True):
        super(GAD, self).__init__()
        self.downsample = downsample

        # NOTE: despite the names, bn_low normalises the pooled low-level gate and
        # bn_high normalises the convolved high-level map (see forward).
        self.conv3x3 = nn.Conv2d(channels_high, channels_high, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_high)

        self.conv1x1 = nn.Conv2d(channels_low, channels_high, kernel_size=1, padding=0, bias=False)
        self.bn_high = nn.BatchNorm2d(channels_high)

        down_cfg = dict(kernel_size=4, stride=2, padding=1) if self.downsample else dict(kernel_size=3, padding=1)
        self.conv_down = nn.Conv2d(channels_low, channels_high, bias=False, **down_cfg)
        self.bn_down = nn.BatchNorm2d(channels_high)

        self.bn_final = nn.BatchNorm2d(channels_high)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_low, fms_high, fm_mask=None):
        """
        :param fms_low: Low-level feature maps, unpacked as ``[b, c, h, w]``.
        :param fms_high: High-level feature maps.
        :param fm_mask: Not used by this method.
        :return: Fused feature maps after the final BatchNorm and ReLU.
        """
        _, c, _, _ = fms_low.shape

        # Global average pooling of the low-level map -> one gate value per output channel.
        pooled = nn.AvgPool2d(fms_low.shape[2:])(fms_low).view(len(fms_low), c, 1, 1)
        gate = self.relu(self.bn_low(self.conv1x1(pooled)))

        high_refined = self.bn_high(self.conv3x3(fms_high))
        gated = high_refined * gate

        low_path = self.bn_down(self.conv_down(fms_low))
        return self.relu(self.bn_final(gated + low_path))

class GADS(nn.Module):
    """
    Spatial-attention counterpart of GAD.

    The low-level map is reduced over its channel axis (mean and max), the two resulting
    planes are convolved by ``conv1`` into a single-channel spatial mask, and that mask
    multiplies the 3x3-convolved high-level map. A convolved copy of the low-level map
    is added before the final ReLU.

    :param in_channels: Channel count of ``fms_high`` and of the output. ``fms_low`` is
                        expected to carry ``in_channels // 2`` channels.
    :param hw_low: Spatial size of the low-level map.
    :param hw_high: Spatial size of the high-level map.
    :raises ValueError: If the two sizes are neither equal nor in a 2:1 ratio
                        (``hw_low // hw_high == 2``); no layers exist for other ratios.
    """

    def __init__(self, in_channels, hw_low, hw_high):
        super(GADS, self).__init__()
        self.conv3x3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn_low = nn.BatchNorm2d(in_channels)

        if hw_low // hw_high == 2:
            # Both the mask and the low-level path halve the spatial size.
            self.conv1 = nn.Conv2d(2, 1, kernel_size=7, stride=2, padding=3, bias=False)
            self.conv_down = nn.Conv2d(in_channels//2, in_channels, kernel_size=3, stride=2, padding=2, dilation=2)
            self.bn_down = nn.BatchNorm2d(in_channels)
        elif hw_low == hw_high:
            # Same resolution: size-preserving convolutions.
            self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
            self.conv_down = nn.Conv2d(in_channels//2, in_channels, kernel_size=3, padding=1)
            self.bn_down = nn.BatchNorm2d(in_channels)
        else:
            # Previously this fell through silently and forward() later failed with an
            # AttributeError on the missing conv1 / conv_down layers.
            raise ValueError(
                "GADS supports hw_low == hw_high or hw_low // hw_high == 2, "
                "got hw_low={} and hw_high={}".format(hw_low, hw_high)
            )

        # Kept for state_dict compatibility; not applied in forward().
        self.bn_final = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_low, fms_high, fm_mask=None):
        """
        :param fms_low: Low-level feature maps.
        :param fms_high: High-level feature maps.
        :param fm_mask: Not used by this method.
        :return: Fused feature maps after the final ReLU.
        """
        # Channel-wise mean and max give a 2-plane description of the low-level map.
        avg_out = torch.mean(fms_low, dim=1, keepdim=True)
        max_out, _ = torch.max(fms_low, dim=1, keepdim=True)
        mask_x = torch.cat([avg_out, max_out], dim=1)
        mask_x = self.conv1(mask_x)

        x = self.conv3x3(fms_high)
        x = self.bn_low(x)

        fms_att = x * mask_x
        out = self.relu(fms_att + self.bn_down(self.conv_down(fms_low)))

        return out
    
class PAN(nn.Module):
    """
    Dual bottom-up decoder: a GAD chain (channel gating) and a GADS chain (spatial
    masking) run over the same feature maps, each ending in its own FPA head; the two
    results are summed and rectified.
    """

    def __init__(self, blocks=[]):
        """
        :param blocks: Backbone blocks in forward order (shallowest first).
        """
        super(PAN, self).__init__()

        # Spatial size assumed for each block's output; hard-coded here.
        spatial_blocks = [56, 28, 14, 14]
        # Output width of each block, read from the ``weight`` of the 5th child of the
        # block's 3rd sub-module.
        channels_blocks = [
            list(list(stage.children())[2].children())[4].weight.shape[0]
            for stage in blocks
        ]

        self.fpa = FPA(channels=channels_blocks[-1])
        self.fpas = FPA(channels=channels_blocks[-1])

        self.gad_block1 = GAD(channels_blocks[0], channels_blocks[1])
        self.gad_block2 = GAD(channels_blocks[1], channels_blocks[2])
        self.gad_block3 = GAD(channels_blocks[2], channels_blocks[3], downsample=False)
        self.gad = [self.gad_block1, self.gad_block2, self.gad_block3]

        self.gads_block1 = GADS(channels_blocks[1], spatial_blocks[0], spatial_blocks[1])
        self.gads_block2 = GADS(channels_blocks[2], spatial_blocks[1], spatial_blocks[2])
        self.gads_block3 = GADS(channels_blocks[3], spatial_blocks[2], spatial_blocks[3])
        self.gads = [self.gads_block1, self.gads_block2, self.gads_block3]

        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms=[]):
        """
        :param fms: Backbone feature maps in forward order.
        :return: ReLU of the sum of the two branch outputs.
        """
        last = len(fms) - 1

        # Branch 1: GAD chain, finished by self.fpa.
        for idx, fm in enumerate(fms):
            if idx == 0 and idx != last:
                fm_high = fm
                continue
            fused = self.gad[idx - 1](fm_high, fm)
            if idx == last:
                fm1 = self.fpa(fused)
            else:
                fm_high = fused

        # Branch 2: GADS chain, restarted from the first map and finished by self.fpas.
        for idx, fm in enumerate(fms):
            if idx == 0 and idx != last:
                fm_high = fm
                continue
            fused = self.gads[idx - 1](fm_high, fm)
            if idx == last:
                fm2 = self.fpas(fused)
            else:
                fm_high = fused

        return self.relu(fm1 + fm2)

class PAN_DualInverse(nn.Module):
    """
    ResNet34 backbone + dual bottom-up PAN decoder + linear classifier.

    The PAN output and (when autograd is active) its gradient are stored so they can be
    read back through ``get_activations`` / ``get_activations_gradient``.

    :param num_class: Number of output classes.
    """

    def __init__(self, num_class=8):
        super(PAN_DualInverse, self).__init__()

        self.convnet = ResNet34(pretrained=True)
        self.pan = PAN(self.convnet.blocks)
        # NOTE(review): in_features=512 is hard-coded; presumably the channel count of
        # the last backbone block -- confirm before swapping the backbone.
        self.cls = Classifier(in_features=512, num_class=num_class)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feat = None
        self.gradients = None

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient flowing into the PAN output."""
        self.gradients = grad

    def forward(self, x):
        """
        :param x: Input batch passed to the backbone.
        :return: Class logits.
        """
        fms_blob, _ = self.convnet(x)
        out_ss = self.pan(fms_blob)
        self.feat = out_ss
        # register_hook raises RuntimeError on tensors that do not require grad
        # (e.g. inside torch.no_grad()), so only attach it when autograd is tracking.
        if out_ss.requires_grad:
            out_ss.register_hook(self.activations_hook)
        out = self.avgpool(out_ss)
        out = out.view(out.size(0), -1)
        out = self.cls(out)
        return out

    def get_activations_gradient(self):
        """Return the gradient captured by the hook (``None`` until a backward pass ran)."""
        return self.gradients

    def get_activations(self):
        """Return the PAN output stored by the most recent forward pass."""
        return self.feat