In [1]:
####

In [2]:
import torch
import os
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torchsummary import summary
import torch.nn.functional as F

In [3]:
# Fix the global RNG so random weight init and the random demo inputs below
# are reproducible across Restart & Run All (seeds CPU and CUDA generators).
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available; the cells below move models/tensors here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class Conv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d, ReLU and Dropout, in that order.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
        use_norm: append ``nn.BatchNorm2d(out_channels)``.
        use_activation: append an in-place ``nn.ReLU``.
        use_dropout: append ``nn.Dropout``.
        dropout_p: drop probability used when ``use_dropout`` is set. The
            default 0.5 matches the previous hard-coded ``nn.Dropout()``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True,
                 use_dropout=False,
                 dropout_p=0.5):
        super(Conv, self).__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_dropout = use_dropout
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size,
                               stride,
                               padding)
        if self.use_norm:
            self.norm = nn.BatchNorm2d(out_channels)
        if self.use_activation:
            # In-place only ever touches the freshly computed conv/BN output,
            # never the caller's input tensor.
            self.activation = nn.ReLU(inplace=True)
        if self.use_dropout:
            self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, x):
        x = self.conv1(x)
        if self.use_norm:
            x = self.norm(x)
        if self.use_activation:
            x = self.activation(x)
        if self.use_dropout:
            x = self.dropout(x)
        return x

In [5]:
class CAM(nn.Module):
    """Channel Attention Module (CBAM).

    Squeezes the spatial dimensions of a (B, C, H, W) feature map with both
    global max- and average-pooling. Both descriptors go through one shared
    bottleneck MLP (C -> C//r -> C). The results are summed and passed through
    a sigmoid, giving a per-channel gate of shape (B, C, 1, 1) with values in
    (0, 1) that broadcasts against the input.

    Args:
        in_channels: number of channels C of the input feature map.
        r: reduction ratio of the MLP's hidden layer (hidden width is
            ``max(1, C // r)``).
    """

    def __init__(self,
                 in_channels,
                 r=1):
        super(CAM, self).__init__()

        hidden = max(1, in_channels // r)
        self.adp_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 1))
        self.adp_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # The MLP must map back to C features so the gate matches the input's
        # channels. A single Linear(C, C // r) only worked when r == 1.
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (B, C, H, W) -> (B, C, 1, 1) -> (B, C)
        x_max = self.adp_max_pool(x).flatten(1)
        x_avg = self.adp_avg_pool(x).flatten(1)

        gate = self.sigmoid(self.mlp(x_max) + self.mlp(x_avg))
        return gate.view(gate.shape[0], gate.shape[1], 1, 1)


In [None]:
x = torch.randn(2, 64, 128, 128).to(device)
cam = CAM(64).to(device)
z = cam(x)
# One gate value per (sample, channel), broadcastable against x.
assert z.shape == (2, 64, 1, 1), z.shape
z.shape

In [7]:
class SAM(nn.Module):
    """Spatial Attention Module.

    Collapses a (B, C, H, W) feature map to a single channel with a 1x1 conv,
    refines it with two 3x3 convs and squashes it through a sigmoid. The
    result is a spatial gate of shape (B, 1, H, W) with values in (0, 1).

    Args:
        in_channels: number of channels C of the input feature map.
    """

    def __init__(self,
                 in_channels):
        super(SAM, self).__init__()

        hidden = max(1, in_channels // 2)
        self.conv1 = Conv(in_channels, 1, kernel_size=(1, 1), stride=(1, 1), padding=0)

        self.conv2 = Conv(1, hidden)
        # No ReLU on the last conv: a ReLU here produced an unbounded [0, inf)
        # map with hard zeros instead of a gate. The sigmoid below bounds it.
        self.conv3 = Conv(hidden, 1, use_activation=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.sigmoid(x)

In [None]:
x = torch.randn(2, 512, 32, 32).to(device)
sam = SAM(512).to(device)
z = sam(x)
# A single-channel spatial map with the input's H x W.
assert z.shape == (2, 1, 32, 32), z.shape
z.shape

In [9]:
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial.

    Computes F' = CAM(F) * F and then F'' = SAM(F') * F', so the spatial gate
    is computed from, and applied to, the channel-refined features. The output
    has the same shape as the input.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self,
                 in_channels):
        super(CBAM, self).__init__()

        # Construction order kept (sam, then cam) so seeded weight init is unchanged.
        self.sam = SAM(in_channels)
        self.cam = CAM(in_channels)

    def forward(self, x):
        # Channel refinement: (B, C, 1, 1) gate broadcast over H, W.
        x = x * self.cam(x)
        # Spatial refinement of the channel-refined map. Multiplying by the raw
        # input here (as before) threw the channel attention away.
        x = x * self.sam(x)
        return x


In [None]:
x = torch.randn(2, 32, 128, 128).to(device)
cbam = CBAM(32).to(device)
z = cbam(x)
# CBAM re-weights features but must not change their shape.
assert z.shape == x.shape, z.shape
z.shape

In [11]:
class Resnet_Block(nn.Module):
    """Residual block with CBAM attention on the main branch.

    Main branch: conv1 (2x2 / stride 2 when ``downsample``, else 1x1 / stride 1,
    keeps ``in_channels``) -> 3x3 conv2 -> 1x1 conv3 to ``out_channels`` -> CBAM.
    Shortcut: ``conv_skip``, a learned projection to ``out_channels`` with the
    same kernel/stride as conv1, so both branches have identical shapes.
    Output is the element-wise sum of the two branches.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        downsample: halve the spatial resolution (floor) in both branches.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 downsample=False):
        super(Resnet_Block, self).__init__()

        self.downsample = downsample

        # Both branches must shrink H and W identically so they can be summed.
        window = (2, 2) if self.downsample else (1, 1)

        # Registration order (conv1, conv_skip, conv2, conv3, cbam) is kept so
        # seeded initialisation and state_dict ordering are unchanged.
        self.conv1 = Conv(in_channels,
                          in_channels,
                          kernel_size=window,
                          stride=window,
                          padding=0)

        self.conv_skip = Conv(in_channels,
                              out_channels,
                              kernel_size=window,
                              stride=window,
                              padding=0)

        self.conv2 = Conv(in_channels,
                          in_channels)

        self.conv3 = Conv(in_channels,
                          out_channels,
                          kernel_size=(1, 1),
                          stride=(1, 1),
                          padding=0)

        self.cbam = CBAM(out_channels)

    def forward(self, x):
        # No clone needed: nothing below modifies `x` in place, so the shortcut
        # can read it directly.
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.cbam(out)
        # Out-of-place add: safe even if the last op of the branch saved its
        # output for backward.
        return out + self.conv_skip(x)

In [12]:
class Linear(nn.Module):
    """Classifier head: a fully connected layer followed by softmax over dim 1.

    Args:
        in_channels: size of each input feature vector.
        out_channels: number of output classes.

    Returns class probabilities (each row sums to 1), not logits.
    NOTE(review): nn.CrossEntropyLoss applies log-softmax itself and expects
    raw logits. If this head is trained with it, the softmax is applied twice.
    Confirm against the training code.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        return self.softmax(self.linear1(x))

In [13]:
class Resnet(nn.Module):
    """ResNet-style image classifier whose residual blocks carry CBAM attention.

    Pipeline:
      1. Stem: a 7x7 / stride-2 conv to 64 channels, then a 2x2 / stride-2 max-pool.
      2. Four stages of ``Resnet_Block`` (64 -> 256 -> 512 -> 1024 -> 2048 channels).
         Every stage except the first downsamples in its first block.
      3. Global average pooling.
      4. ``Linear`` head, which outputs softmax probabilities.

    Args:
        in_channels: channels of the input image.
        out_channels: number of classes.
        layers: number of blocks in each of the four stages. The default
            ``(3, 8, 36, 3)`` is the configuration that used to be hard-coded.

    Raises:
        ValueError: if ``layers`` is not four positive block counts.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 layers=(3, 8, 36, 3)):
        super(Resnet, self).__init__()

        layers = tuple(layers)
        if len(layers) != 4 or any(n < 1 for n in layers):
            raise ValueError(f"layers must be four positive block counts, got {layers!r}")

        self.conv1 = Conv(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=3)

        # The first stage keeps resolution; the stem has already reduced it.
        self.conv2 = self._make_repeated_blocks(64, 256, layers[0], downsample=False)
        self.conv3 = self._make_repeated_blocks(256, 512, layers[1])
        self.conv4 = self._make_repeated_blocks(512, 1024, layers[2])
        self.conv5 = self._make_repeated_blocks(1024, 2048, layers[3])
        self.linear = Linear(2048, out_channels)

    def _make_repeated_blocks(self, in_channels, out_channels, repeats, downsample=True):
        """Stack ``repeats`` blocks into an ``nn.Sequential``.

        Only the first block changes the channel count (and, when
        ``downsample`` is set, the resolution). The rest map
        ``out_channels`` to ``out_channels``.
        """
        return nn.Sequential(*[
            Resnet_Block(in_channels if i == 0 else out_channels,
                         out_channels,
                         downsample=(downsample and i == 0))
            for i in range(repeats)
        ])

    def forward(self, x):
        x = self.conv1(x)
        x = torch.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        # Global average pool to (B, 2048, 1, 1), then drop the singleton dims.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.linear(x)

In [14]:
def test():
    """Smoke test: build Resnet(3, 1000), run a random 2x3x224x224 batch
    through it, print the output shape and check it is (2, 1000).

    The forward pass runs under ``torch.no_grad()``. This is a forward-only
    shape check, so there is no reason to keep the autograd graph of the whole
    deep network in memory.
    """
    resnet = Resnet(3, 1000).to(device)
    x = torch.randn(2, 3, 224, 224, device=device)
    with torch.no_grad():
        z = resnet(x)
    print(z.shape)
    assert z.shape == (2, 1000), z.shape

In [None]:
# Smoke test: instantiate the full network and push a random batch through it.
test()