In [None]:
# Install the pinned dependencies into the environment of the running kernel.
# `%pip` (rather than bare `pip`, which only works via IPython automagic, or
# `!pip`, which may hit a different interpreter) always targets this kernel.
# `-q` keeps the cell output from filling up with download logs.
%pip install -q -r requirements.txt

Collecting timm==1.0.8 (from -r requirements.txt (line 2))
  Downloading timm-1.0.8-py3-none-any.whl.metadata (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting medmnist==3.0.1 (from -r requirements.txt (line 3))
  Downloading medmnist-3.0.1-py3-none-any.whl.metadata (13 kB)
Collecting fvcore (from -r requirements.txt (line 8))
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting fire (from -r requirements.txt (line 11))
  Downloading fire-0.7.0.tar.gz (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torchattacks (from -r requirements.txt (line 12))
  Downloa

In [2]:
# Import PyTorch
import torch
import torch.nn as nn

Local Transformer Block

In [41]:
# CBAM (Convolutional Block Attention Module) for 3D feature maps.

#1. Batch normalization for 3D volumes.
class BatchNorm3DBlock(nn.Module):
    """Thin wrapper around ``nn.BatchNorm3d``.

    Args:
        channels: number of feature channels (dim 1 of the input) to normalize.

    The wrapped layer lives in ``self.bn``, so its parameters and buffers show
    up in the state dict under the ``bn.`` prefix.
    """

    def __init__(self, channels):
        super().__init__()
        norm_layer = nn.BatchNorm3d(channels)
        self.bn = norm_layer

    def forward(self, x):
        normalized = self.bn(x)
        return normalized


#2. Channel Attention

class ChannelAttention(nn.Module):
    """Channel attention for 5-D feature maps (CBAM channel branch).

    A global average-pool and a global max-pool descriptor of the input are
    each passed through a shared two-layer MLP. Their sum is squashed by a
    sigmoid and used to rescale every channel of the input.

    Args:
        channels: number of input channels ``C``.
        reduction: bottleneck ratio of the shared MLP. The hidden width is
            ``max(channels // reduction, 1)``, so ``channels < reduction``
            still yields a usable width-1 bottleneck.
    """

    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        # Without the clamp, e.g. channels=8 / reduction=16 gives a zero-width
        # hidden layer: the MLP then always outputs zeros and the attention
        # degenerates to a constant sigmoid(0) = 0.5 for every channel.
        hidden = max(channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale the channels of ``x``.

        Args:
            x: tensor of shape ``(B, C, D, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, _, _ = x.size()  # batch, channels, depth, height, width
        avg_desc = self.avg_pool(x).view(b, c)
        max_desc = self.max_pool(x).view(b, c)  # named to avoid shadowing built-in max()
        weights = self.sigmoid(self.mlp(avg_desc) + self.mlp(max_desc))
        return x * weights.view(b, c, 1, 1, 1).expand_as(x)

#3. Spatial Attention

class SpatialAttention(nn.Module):
    """Spatial attention for 5-D feature maps (CBAM spatial branch).

    The channel-wise mean and max maps are stacked into a 2-channel volume,
    convolved down to one channel, passed through a sigmoid, and used to
    rescale every voxel of the input.

    Args:
        kernel_size: edge length of the cubic 3-D convolution. Must be a
            positive odd integer so that ``padding = kernel_size // 2`` keeps
            D, H and W unchanged.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # With an even kernel, padding k // 2 makes every spatial dim grow by
        # one, and the attention map can no longer be expanded onto x.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        padding = kernel_size // 2
        self.conv = nn.Conv3d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale every voxel of ``x``.

        Args:
            x: tensor expected to be ``(B, C, D, H, W)``; reductions are over dim 1.

        Returns:
            Tensor with the same shape as ``x``.
        """
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attn = self.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return x * attn.expand_as(x)

# Consolidating CBAM

class CBAM(nn.Module):
    """Convolutional Block Attention Module for 5-D feature maps.

    Applies, in order: batch normalization, channel attention, spatial
    attention. Each stage preserves the input shape.

    Args:
        channels: number of input channels.
        reduction: bottleneck ratio forwarded to ``ChannelAttention``.
        kernel_size: convolution size forwarded to ``SpatialAttention``.
    """

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        # Attribute names (and their registration order) define the
        # state-dict keys, so they are kept stable.
        self.bn = BatchNorm3DBlock(channels)
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        out = x
        for stage in (self.bn, self.channel_attention, self.spatial_attention):
            out = stage(out)
        return out

In [45]:
# Local Feed-Forward Network (LFFN)

class LocalFeedForward3D(nn.Module):
    """Local feed-forward network for 5-D feature maps with a residual skip.

    Computes ``x + BN(Conv1(ReLU(BN(DWConv3(ReLU(BN(Conv1(x))))))))``:
    a 1x1x1 conv expands the channels by ``expand_ratio``, a depth-wise 3x3x3
    conv mixes spatially within each channel, and a 1x1x1 conv projects back
    to ``in_channels`` so the residual add is shape-compatible.

    Args:
        in_channels: number of input (and output) channels.
        expand_ratio: channel multiplier for the hidden layer.
    """

    def __init__(self, in_channels, expand_ratio=4):
        super().__init__()
        width = in_channels * expand_ratio

        # Expand: point-wise 1x1x1 convolution.
        self.conv1 = nn.Conv3d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(width)
        self.relu1 = nn.ReLU(inplace=True)

        # Mix spatially: depth-wise 3x3x3 convolution (groups == channels).
        self.dwconv = nn.Conv3d(
            width,
            width,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=width,
            bias=False,
        )
        self.bn2 = nn.BatchNorm3d(width)
        self.relu2 = nn.ReLU(inplace=True)

        # Project back: point-wise 1x1x1 convolution to the input width.
        self.conv2 = nn.Conv3d(width, in_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(in_channels)

    def forward(self, x):
        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.relu2(self.bn2(self.dwconv(branch)))
        branch = self.bn3(self.conv2(branch))
        return x + branch  # residual connection


torch.Size([2, 32, 8, 32, 32])


In [48]:
# Local Transformer Block (LTB) combining CBAM attention and the LFFN
class LTB_CBAM(nn.Module):
    """Local Transformer Block with CBAM attention for 5-D feature maps.

    Pipeline::

        x1  = ReLU(BN(Conv3x3x3(x)))
        x2  = ReLU(BN(Conv3x3x3(x1 + CBAM(x1))))
        out = LocalFeedForward3D(x2)        # which itself returns x2 + FFN(x2)

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output; defaults to ``in_channels``.
        reduction: channel-attention bottleneck ratio passed to CBAM.
        kernel_size: spatial-attention convolution size passed to CBAM.
        expand_ratio: hidden-width multiplier of the local feed-forward net.
    """

    def __init__(self, in_channels, out_channels=None, reduction=16, kernel_size=7, expand_ratio=4):
        super(LTB_CBAM, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        # First 3D conv layer (padding=1 keeps D, H, W unchanged).
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        # CBAM block.
        self.cbam = CBAM(out_channels, reduction=reduction, kernel_size=kernel_size)
        # Second 3D conv layer.
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        # Local feed-forward network (contains its own residual connection).
        self.lffn = LocalFeedForward3D(out_channels, expand_ratio=expand_ratio)

    def forward(self, x):
        x1 = self.relu1(self.bn1(self.conv1(x)))
        # CBAM only rescales x1 and has no skip of its own, so the bypass path
        # is added explicitly here.
        combined = self.cbam(x1) + x1
        x2 = self.relu2(self.bn2(self.conv2(combined)))
        # LocalFeedForward3D already returns x2 + FFN(x2). Adding x2 again, as
        # the previous version did, counted the skip twice:
        # out = 2 * x2 + FFN(x2).
        return self.lffn(x2)