Efficient Convolution Block-3D

In [3]:
# Simple patch embedding used to run and test ECB. It can be modified to match the target architecture.
class PatchEmbed3D(nn.Module):
    """Minimal 3D patch embedding built from a single ``nn.Conv3d``.

    The projection uses a kernel size of 3 and a padding of 1; only the
    stride is configurable.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the projection.
        stride: Stride handed to the convolution (defaults to 1).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv_options = {"kernel_size": 3, "stride": stride, "padding": 1}
        self.proj = nn.Conv3d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Return ``x`` after the convolutional projection."""
        embedded = self.proj(x)
        return embedded

In [2]:
# Multi-Head Convolutional Attention (MHCA), 3D version

import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import DropPath # importing DropPath

NORM_EPS = 1e-5  # Use the same epsilon as original architecture

class MHCA3D(nn.Module):
    """
    Multi-Head Convolutional Attention (3D version).

    Layer sequence: grouped 3x3x3 convolution -> BatchNorm3d -> ReLU ->
    1x1x1 convolution. The channel count is the same on input and output.

    Args:
        out_channels: Number of channels entering and leaving the block.
        head_dim: Channels per group; the grouped convolution is built with
            ``out_channels // head_dim`` groups.
    """

    def __init__(self, out_channels, head_dim):
        super().__init__()
        num_groups = out_channels // head_dim

        # Grouped 3x3x3 convolution: each group covers ``head_dim`` channels.
        self.group_conv3x3 = nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=num_groups,
            bias=False,
        )
        self.norm = nn.BatchNorm3d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)
        # Pointwise convolution applied after the activation.
        self.projection = nn.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        """Run ``x`` through conv -> norm -> activation -> projection."""
        features = self.act(self.norm(self.group_conv3x3(x)))
        return self.projection(features)


In [4]:
# Local feed-forward network (already defined in the LTB module)

class LocalFeedForward3D(nn.Module):
    """Local feed-forward block for 3D feature maps with a skip connection.

    The branch is: 1x1x1 conv -> BN -> ReLU -> 3x3x3 depthwise conv -> BN ->
    ReLU -> 1x1x1 conv -> BN. Its output is added to the block input.

    Args:
        in_channels: Number of channels entering and leaving the block.
        expand_ratio: Multiplier giving the hidden width
            (``in_channels * expand_ratio``).
    """

    def __init__(self, in_channels, expand_ratio=4):
        super().__init__()
        width = in_channels * expand_ratio

        # Pointwise expansion to the hidden width.
        self.conv1 = nn.Conv3d(in_channels, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(width)
        self.relu1 = nn.ReLU(inplace=True)

        # Depthwise 3x3x3 convolution (one group per hidden channel).
        self.dwconv = nn.Conv3d(
            width,
            width,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=width,
            bias=False,
        )
        self.bn2 = nn.BatchNorm3d(width)
        self.relu2 = nn.ReLU(inplace=True)

        # Pointwise reduction back to the input width.
        self.conv2 = nn.Conv3d(width, in_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(in_channels)

    def forward(self, x):
        """Return ``x`` plus the output of the feed-forward branch."""
        stages = (
            self.conv1, self.bn1, self.relu1,
            self.dwconv, self.bn2, self.relu2,
            self.conv2, self.bn3,
        )
        branch = x
        for stage in stages:
            branch = stage(branch)
        return x + branch  # residual connection

In [5]:
# Efficient Convolution Block (ECB) class, built from the components defined above


class ECB3D(nn.Module):
    """
    Efficient Convolution Block (3D).

    Forward pipeline: patch embedding -> MHCA3D branch added to its input
    (the branch goes through DropPath) -> BatchNorm3d -> LocalFeedForward3D
    added to the pre-norm tensor.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels after the patch embedding; every
            later layer in the block works at this width.
        stride: Stride passed to the patch embedding.
        path_dropout: Value passed to ``DropPath`` for the attention branch.
        drop: Currently unused; kept so existing call sites keep working.
        head_dim: Channels per group in the MHCA3D grouped convolution.
        mlp_ratio: Expansion ratio passed to ``LocalFeedForward3D``.

    Raises:
        ValueError: If ``head_dim`` is not positive or ``out_channels`` is
            not a multiple of ``head_dim``.
    """
    def __init__(self, in_channels, out_channels, stride=1, path_dropout=0,
                 drop=0, head_dim=32, mlp_ratio=3):
        super().__init__()
        # Validate explicitly instead of with ``assert``: asserts are removed
        # under ``python -O``, and a non-multiple such as 48/32 would then
        # silently build MHCA3D with the wrong number of groups.
        if head_dim <= 0 or out_channels % head_dim != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be a positive multiple "
                f"of head_dim ({head_dim})"
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        norm_layer = partial(nn.BatchNorm3d, eps=NORM_EPS)

        # The input must be embedded in 3D before it reaches the rest of the
        # block. This is a simple stand-in embedding (not the one from the
        # original architecture); swap it for the base architecture's
        # embedding if needed.
        self.patch_embed = PatchEmbed3D(in_channels, out_channels, stride)
        self.mhca = MHCA3D(out_channels, head_dim)  # 3D MHCA
        self.attention_path_dropout = DropPath(path_dropout)

        # Local feed-forward network (the LFFN3D class from the LTB block).
        self.conv = LocalFeedForward3D(out_channels, mlp_ratio)

        self.norm = norm_layer(out_channels)

    def forward(self, x):
        """Apply the block to ``x`` and return the resulting tensor."""
        x = self.patch_embed(x)
        x = x + self.attention_path_dropout(self.mhca(x))
        out = self.norm(x)  # batch normalization, as in the original architecture
        # NOTE(review): LocalFeedForward3D already adds its own input (``out``)
        # to its result, so this sum is x + out + ffn_branch(out). Confirm the
        # extra skip connection is intended.
        x = x + self.conv(out)
        return x

