In [1]:
!pip install timm > /dev/null

In [33]:
import os
import shutil
import timm

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

import random
import gc

In [6]:
def get_activation(activ_name: str = 'relu'):
  """Return a newly constructed activation module selected by name.

  Args:
    activ_name: one of "relu", "tanh", "sigmoid", "identity".

  Returns:
    A fresh ``nn.Module`` instance on every call; "relu" is built with
    ``inplace=True``.

  Raises:
    NotImplementedError: if ``activ_name`` is not one of the supported names.
  """
  # Map names to constructors so only the requested module is instantiated.
  act_factories = {
      "relu": lambda: nn.ReLU(inplace=True),
      "tanh": nn.Tanh,
      "sigmoid": nn.Sigmoid,
      "identity": nn.Identity,
  }

  if activ_name not in act_factories:
    raise NotImplementedError(
        f"unsupported activation {activ_name!r}; "
        f"expected one of {sorted(act_factories)}")
  return act_factories[activ_name]()

In [19]:
class Conv2DBNActiv(nn.Module):
  """Convolution block: Conv2d, optional BatchNorm2d, then an activation.

  All stages live in one ``nn.Sequential`` stored as ``self.layers``:
  the convolution at index 0, ``BatchNorm2d`` at index 1 when ``use_bn`` is
  set, and the activation last.

  Args:
    in_channels: input channel count of the convolution.
    out_channels: output channel count of the convolution (and of the BN).
    kernel_size, stride, padding: passed positionally to ``nn.Conv2d``.
    bias: whether the convolution has a bias term.
    use_bn: insert ``nn.BatchNorm2d(out_channels)`` after the convolution.
    activ: activation name resolved through ``get_activation``.
  """

  def __init__(self,
               in_channels : int,
               out_channels : int,
               kernel_size : int,
               stride : int = 1,
               padding : int = 0,
               bias : bool = False,
               use_bn : bool = True,
               activ : str = 'relu',
               ):
    super().__init__()
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                     bias=bias)
    # An empty list leaves the BN slot out entirely when use_bn is False.
    norm = [nn.BatchNorm2d(out_channels)] if use_bn else []
    self.layers = nn.Sequential(conv, *norm, get_activation(activ))

  def forward(self, x):
    """Apply conv -> (BN) -> activation to ``x``."""
    return self.layers(x)

In [22]:
class SpatialAttentionBlock(nn.Module):
  """Spatial attention gate for (B, C, H, W) feature maps.

  A stack of 3x3 ``Conv2DBNActiv`` layers (stride 1, padding 1, so H and W
  are preserved) reduces the input to a single-channel map. Hidden layers use
  ReLU and the final layer uses sigmoid. The resulting (B, 1, H, W) map is
  broadcast-multiplied with the input.

  Args:
    in_channels: channel count C of the input.
    out_channels_list: output channels of each conv layer, in order. Any
      sequence (list or tuple) is accepted; it must be non-empty and its last
      entry must be 1.

  Submodules are registered as ``conv1`` ... ``conv{n}``.
  """

  def __init__(
        self, in_channels: int,
        out_channels_list,
    ):
    super().__init__()
    # Copy into a list so tuples (or other sequences) work with the concatenation below.
    out_channels_list = list(out_channels_list)
    self.n_layers = len(out_channels_list)
    assert self.n_layers > 0, "out_channels_list must not be empty"
    assert out_channels_list[-1] == 1, (
        f"last entry of out_channels_list must be 1, got {out_channels_list[-1]}")
    channels_list = [in_channels] + out_channels_list

    for i in range(self.n_layers):
      in_chs, out_chs = channels_list[i], channels_list[i + 1]
      # Only the last layer is squashed by sigmoid to form the attention weights.
      activ = "sigmoid" if i == self.n_layers - 1 else "relu"
      setattr(self, f"conv{i + 1}",
              Conv2DBNActiv(in_chs, out_chs, 3, 1, 1, activ=activ))

  def forward(self, x):
    """Return ``x`` reweighted by the computed single-channel attention map."""
    attn = x
    for i in range(1, self.n_layers + 1):
      attn = getattr(self, f"conv{i}")(attn)
    # (B, 1, H, W) * (B, C, H, W) broadcasts over the channel dimension.
    return attn * x

In [27]:
# Display the layer structure of a standalone attention block on a 3-channel input.
SpatialAttentionBlock(3, [64, 32, 16, 1])

SpatialAttentionBlock(
  (conv1): Conv2DBNActiv(
    (layers): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (conv2): Conv2DBNActiv(
    (layers): Sequential(
      (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (conv3): Conv2DBNActiv(
    (layers): Sequential(
      (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (conv4): Conv2DBNActiv(
    (layers): Sequential(
      (0): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d

In [45]:
class CustomModel(nn.Module):
  """timm backbone followed by a spatial-attention classification head.

  The backbone's classifier and global pooling are removed, so it emits a
  feature map with ``num_features`` channels. The head applies
  ``SpatialAttentionBlock``, global average pooling, flatten and a two-layer
  MLP with dropout.

  Args:
    base_name: timm model name passed to ``timm.create_model``.
    out_dim: number of output values produced by the final linear layer.
    pretrained: whether timm loads pretrained backbone weights.
    attention_channels: conv widths of the attention block; must end with 1.
    dropout: dropout probability applied before the final linear layer.
  """

  def __init__(
        self,
        base_name: str = 'resnet18',
        out_dim: int = 2,
        pretrained: bool = False,
        attention_channels=(64, 32, 16, 1),
        dropout: float = 0.5,
    ):
    # Initialise nn.Module before assigning any attributes on self.
    super().__init__()
    self.base_name = base_name

    # Load the base model.
    base_model = timm.create_model(base_name, pretrained=pretrained)
    in_features = base_model.num_features

    # num_classes=0 drops the classifier and global_pool='' drops pooling,
    # so the backbone returns the spatial feature map the attention block needs.
    base_model.reset_classifier(0, '')
    self.backbone = base_model

    self.head_fc = nn.Sequential(
        SpatialAttentionBlock(in_features, list(attention_channels)),
        nn.AdaptiveAvgPool2d(output_size=1),
        nn.Flatten(start_dim=1),
        nn.Linear(in_features, in_features),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout),
        nn.Linear(in_features, out_dim))

  def forward(self, x):
    """Run the backbone, then the attention head; returns (B, out_dim)."""
    h = self.backbone(x)
    return self.head_fc(h)

In [46]:
# Build the default model (resnet18 backbone, out_dim=2) and display its layers.
model = CustomModel()
model

CustomModel(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-

In [49]:
# Forward smoke test: one random 3x512x512 image through the default model (out_dim=2).
model.eval()

# Seeded generator keeps the sample input reproducible across runs.
sample_image = torch.rand(1, 3, 512, 512, generator=torch.Generator().manual_seed(0))
with torch.no_grad():
  y = model(sample_image)

print("[forward test]")
print("input:\t{}\noutput:\t{}".format(sample_image.shape, y.shape))
assert y.shape == (1, 2), f"unexpected output shape {tuple(y.shape)}"

# Free the test model and tensors; the trailing ';' suppresses gc.collect()'s return value.
del model, y, sample_image
gc.collect();

[forward test]
input:	torch.Size([1, 3, 512, 512])
output:	torch.Size([1, 2])


917