In [None]:
import torch
import torch.nn as nn
from math import prod

## Building submodels for Convnet
#### Linear Block

In [None]:
class _LinearBlock(nn.Module):
    """Basic linear block: optional flatten -> nn.Linear -> activation.

    Args:
        inFeatures: Number of input features given to ``nn.Linear``.
        outFeatures: Number of output features given to ``nn.Linear``.
        act: Name of the activation applied after the linear layer. One of
            "identity", "relu", "leaky", "gelu" or "sigmoid".
        flatten: If True, the input is first flattened from dimension 1
            onwards (``nn.Flatten(start_dim=1)``); otherwise it is passed
            through unchanged.
        **kwargs: Forwarded unchanged to ``nn.Linear`` (e.g. ``bias``).

    Raises:
        ValueError: If ``act`` is not one of the recognised names.
    """

    # Activation name -> constructor. Kept as classes so that every block
    # instantiates its own activation module.
    _ACTIVATIONS = {
        "identity": nn.Identity,
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "gelu": nn.GELU,
        "sigmoid": nn.Sigmoid,
    }

    def __init__(self, inFeatures, outFeatures, act="relu", flatten=False, **kwargs):
        super().__init__()

        # Define activation function to be used in this block
        try:
            activationFunction = self._ACTIVATIONS[act]()
        except (KeyError, TypeError):
            # TypeError covers unhashable values passed as `act`.
            # ValueError is a subclass of Exception, so callers that caught
            # the previous bare Exception keep working.
            raise ValueError(f"{act} is not a recognised activation function for this class") from None

        # The three positions of the Sequential are fixed (flatten/identity,
        # linear, activation) so parameter names in the state dict do not
        # depend on the `flatten` flag.
        self.linear = nn.Sequential(
            nn.Flatten(start_dim=1) if flatten else nn.Identity(),
            nn.Linear(
                inFeatures,
                outFeatures,
                **kwargs,
            ),
            activationFunction,
        )

    def forward(self, x):
        # Output result of linear block when object is called

        return self.linear(x)

#### Convolutional Block

In [None]:
class _ConvBlock(nn.Module):
    """Basic convolutional block: (transpose) convolution -> optional batch norm -> activation.

    Args:
        inChannels: Number of input channels of the convolution.
        outChannels: Number of output channels of the convolution.
        down: If True an ``nn.Conv3d`` is used, otherwise an
            ``nn.ConvTranspose3d``.
        act: Name of the activation applied last. One of "identity", "relu",
            "leaky" or "gelu".
        batchnorm: If True an ``nn.BatchNorm3d(outChannels)`` follows the
            convolution, otherwise an ``nn.Identity`` takes its place.
        **kwargs: Forwarded to the convolution (``kernel_size``, ``stride``,
            ``padding``, ``bias``, ...). For ``down=True`` the padding mode
            defaults to "reflect" but can be overridden with ``padding_mode``.

    Raises:
        ValueError: If ``act`` is not one of the recognised names.
    """

    # Activation name -> constructor. Kept as classes so that every block
    # instantiates its own activation module.
    _ACTIVATIONS = {
        "identity": nn.Identity,
        "relu": nn.ReLU,
        "leaky": nn.LeakyReLU,
        "gelu": nn.GELU,
    }

    def __init__(self, inChannels, outChannels, down=True, act="relu", batchnorm=True, **kwargs):
        super().__init__()

        # Define activation function to be used in this block
        try:
            activationFunction = self._ACTIVATIONS[act]()
        except (KeyError, TypeError):
            # TypeError covers unhashable values passed as `act`.
            # ValueError is a subclass of Exception, so callers that caught
            # the previous bare Exception keep working.
            raise ValueError(f"{act} is not a recognised activation function for this class") from None

        # Define generic convolutional block / transpose convolutional block
        if down:
            # "reflect" is only a default: passing it as a literal keyword
            # next to **kwargs made any caller-supplied `padding_mode` fail
            # with a duplicate-keyword TypeError. `kwargs` is a fresh dict
            # for every call, so updating it does not leak to the caller.
            kwargs.setdefault("padding_mode", "reflect")
            convolution = nn.Conv3d(inChannels, outChannels, **kwargs)
        else:
            # The transpose convolution receives no padding-mode default and
            # keeps whatever nn.ConvTranspose3d uses on its own.
            convolution = nn.ConvTranspose3d(inChannels, outChannels, **kwargs)

        # The three positions of the Sequential are fixed (convolution,
        # norm/identity, activation) so parameter names in the state dict do
        # not depend on the `batchnorm` flag.
        self.conv = nn.Sequential(
            convolution,
            nn.BatchNorm3d(outChannels) if batchnorm else nn.Identity(),
            activationFunction,
        )

    def forward(self, x):
        # Output result of conv block when object is called

        return self.conv(x)

#### Non-Downsampling Residual Block

In [None]:
class _ResidualBlock(nn.Module):
    """Residual block that keeps the channel count: returns ``x + F(x)``.

    ``F`` is two stacked ``_ConvBlock`` layers, both with kernel size 3,
    stride 1, padding 1 and no convolution bias. The first uses the
    activation named by ``act``; the second uses the identity, so nothing is
    applied between the second convolution block and the skip addition.

    Args:
        channels: Number of input and output channels of both convolutions.
        act: Activation name handed to the first ``_ConvBlock``.
        batchnorm: Handed to both ``_ConvBlock`` layers.
    """

    def __init__(self, channels, act="leaky", batchnorm=True):
        super().__init__()

        # Settings shared by both convolutions of the residual branch
        sharedConvArgs = dict(
            batchnorm=batchnorm,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

        activatedConv = _ConvBlock(channels, channels, act=act, **sharedConvArgs)
        linearConv = _ConvBlock(channels, channels, act="identity", **sharedConvArgs)

        self.resBlock = nn.Sequential(activatedConv, linearConv)

    def forward(self, x):
        # Add the residual branch back onto the untouched input

        residual = self.resBlock(x)
        return x + residual

#### Downsampling Residual Block

In [None]:
class _DownsampleResidualBlock(nn.Module):
    """Residual block with a stride-2 main branch and a stride-2 skip branch.

    The main branch is a stride-2 ``_ConvBlock`` (kernel 3, padding 1) that
    changes the channel count, followed by a stride-1 ``_ConvBlock`` with
    identity activation. The skip branch is a single kernel-1, stride-2
    ``_ConvBlock`` with identity activation. ``forward`` returns the sum of
    the two branches.

    Args:
        inChannels: Channels of the incoming tensor.
        outChannels: Channels produced by both branches.
        act: Activation name handed to the first main-branch ``_ConvBlock``.
        batchnorm: Handed to the two main-branch ``_ConvBlock`` layers.
    """

    def __init__(self, inChannels, outChannels, act="leaky", batchnorm=True):
        super().__init__()

        # Main branch: the first convolution halves the resolution and
        # changes the channel count, the second refines at that resolution
        stridedConv = _ConvBlock(
            inChannels,
            outChannels,
            act=act,
            batchnorm=batchnorm,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        refineConv = _ConvBlock(
            outChannels,
            outChannels,
            act="identity",
            batchnorm=batchnorm,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.resBlockDown = nn.Sequential(stridedConv, refineConv)

        # Skip branch: strided 1x1x1 convolution so the shortcut has the same
        # channel count and stride as the main branch before the addition.
        # NOTE(review): batch norm is always enabled here, independent of the
        # `batchnorm` argument - confirm this is intended.
        shortcutConv = _ConvBlock(
            inChannels,
            outChannels,
            act="identity",
            batchnorm=True,
            kernel_size=1,
            stride=2,
            padding=0,
        )
        self.resBlockDownSkip = nn.Sequential(shortcutConv)

    def forward(self, x):
        # Sum the downsampled shortcut and the downsampled main branch

        shortcut = self.resBlockDownSkip(x)
        mainBranch = self.resBlockDown(x)
        return shortcut + mainBranch

#### :warning: The residual blocks apply no activation function to their output. :warning:
The original paper suggests a ReLU on the block output; consider adding it in a future revision.

## Test layer outputs

In [None]:
# Smoke test: run one random volume through each building block and check
# the output shape against the value expected from its stride / channels.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N = 128
C = 10
layer1 = _ResidualBlock(C, act="leaky").to(DEVICE)
layer2 = _DownsampleResidualBlock(C, C*2, act="leaky").to(DEVICE)
layer3 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1,).to(DEVICE)
layer4 = _ConvBlock(
            C,
            C*2,
            act="leaky",
            batchnorm=True,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        ).to(DEVICE)

# Named `inputTensor` so the builtin `input` is not shadowed for the rest of
# the notebook session.
inputTensor = torch.randn((1, C, N, N, N), dtype=torch.float32, device=DEVICE)

# Integer division keeps the expected sizes comparable with torch.Size
# entries (N/2 would give 64.0 instead of 64).
expectedShapes = [
    (layer1, (1, C, N, N, N)),
    (layer2, (1, C*2, N//2, N//2, N//2)),
    (layer3, (1, C, N//2, N//2, N//2)),
    (layer4, (1, C*2, N//2, N//2, N//2)),
]

# Only shapes are inspected, so no autograd graph is needed
with torch.no_grad():
    for layer, expectedShape in expectedShapes:
        output = layer(inputTensor)
        print(type(layer).__name__, tuple(output.shape), "expected", expectedShape)
        assert tuple(output.shape) == expectedShape, (
            f"{type(layer).__name__}: got {tuple(output.shape)}, expected {expectedShape}"
        )

## Defining Multiclass Classifier

In [None]:
class Classifier3D(nn.Module):
    """ResNet-style classifier for 3D volumes.

    Layout: a stride-2 7x7x7 convolution stem with max pooling, four stages
    of residual blocks with ``numFeatures``, ``2*numFeatures``,
    ``4*numFeatures`` and ``8*numFeatures`` channels (stages two to four
    start with a ``_DownsampleResidualBlock``), and one linear layer applied
    to the flattened feature map that outputs ``numClasses`` values with no
    activation.

    Args:
        imageChannels: Channels of the input volume.
        numClasses: Number of outputs of the final linear layer.
        imageDimentions: Spatial size of the input volume (first three
            entries are used). Only needed to size the linear layer.
        numFeatures: Channels produced by the stem; later stages use
            multiples of it.
        listResiduals: Number of blocks per stage. For stages two to four the
            count includes the leading downsampling block. If only three
            entries are given, the last stage falls back to 3 blocks.
    """

    def __init__(self, imageChannels, numClasses, imageDimentions=(64, 64, 64),
                numFeatures=64, listResiduals=(3, 4, 6, 3)):

        super().__init__()

        # Define initial block of the classifier.
        # NOTE: _ConvBlock already applies BatchNorm3d + LeakyReLU here, so
        # the two layers that follow it repeat them. They are kept so the
        # module layout (and therefore state_dict keys) stays unchanged.
        self.initialLayer = nn.Sequential(
            _ConvBlock(
                imageChannels,
                numFeatures,
                act="leaky",
                batchnorm=True,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False,
            ),
            nn.BatchNorm3d(numFeatures),
            nn.LeakyReLU(),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1,),
        )

        # Define first residual stage with numFeatures channels
        # Note: we use actual number in list as there is no downsample res block
        resBlock64 = nn.Sequential(
            *[_ResidualBlock(numFeatures, act="leaky", batchnorm=True) for _ in range(listResiduals[0])],
        )

        # Define second residual stage with 2*numFeatures channels
        resBlock128 = nn.Sequential(
            _DownsampleResidualBlock(numFeatures, numFeatures*2, act="leaky"),
            *[_ResidualBlock(numFeatures*2, act="leaky", batchnorm=True) for _ in range(listResiduals[1]-1)],
        )

        # Define third residual stage with 4*numFeatures channels
        resBlock256 = nn.Sequential(
            _DownsampleResidualBlock(numFeatures*2, numFeatures*4, act="leaky"),
            *[_ResidualBlock(numFeatures*4, act="leaky", batchnorm=True) for _ in range(listResiduals[2]-1)],
        )

        # Define fourth residual stage with 8*numFeatures channels.
        # The block count now follows listResiduals[3] (it used to be fixed
        # at 3 regardless of the argument). As before, the very last residual
        # block is built without batch normalisation, so the default value
        # produces exactly the previous layout. Sequences with only three
        # entries were accepted before, hence the fallback to 3.
        numLastStage = listResiduals[3] if len(listResiduals) > 3 else 3
        numTrailing = numLastStage - 1
        resBlock512 = nn.Sequential(
            _DownsampleResidualBlock(numFeatures*4, numFeatures*8, act="leaky"),
            *[
                _ResidualBlock(numFeatures*8, act="leaky", batchnorm=(index < numTrailing - 1))
                for index in range(numTrailing)
            ],
        )

        self.resBlocksAll = nn.ModuleList([
            resBlock64,
            resBlock128,
            resBlock256,
            resBlock512
        ])

        # The input is halved five times (stem convolution, max pool and one
        # downsampling block in each of stages two to four), so the latent
        # tensor is (B, 8*numFeatures, ~H/32, ~W/32, ~L/32). The sizes are
        # derived with integer arithmetic so inputs that are not multiples of
        # 32 still give the right number of linear input features.
        latentDimentions = [self._latentSize(size, 5) for size in imageDimentions[:3]]
        flattenedInFeatures = numFeatures*8*prod(latentDimentions)

        self.denseBlocks = nn.ModuleList([
            _LinearBlock(
                flattenedInFeatures,
                numClasses,
                act="identity",
                flatten=True,
                bias=True
            ),
        ])

    @staticmethod
    def _latentSize(size, numHalvings):
        # Size of one spatial axis after `numHalvings` stride-2 layers.
        # Every stride-2 layer used in this network (kernel 7 / padding 3,
        # kernel 3 / padding 1, kernel 1 / padding 0) maps n to
        # floor((n - 1) / 2) + 1.
        size = int(size)
        for _ in range(numHalvings):
            size = (size - 1)//2 + 1
        return size

    def forward(self, x):

        # Apply initial layer (stride-2 convolution followed by stride-2 pooling)
        x = self.initialLayer(x)        # Size: (B, numFeatures, H/4, W/4, L/4)

        # Apply all resnet layers
        for layer in self.resBlocksAll:
            x = layer(x)
        # Size: (B, 8*numFeatures, H/32, W/32, L/32)

        # Apply linear layers
        for layer in self.denseBlocks:
            x = layer(x)

        return x

### Version of ResNet for MedMNIST

In [None]:
class ClassifierMed(nn.Module):
    """Shallower ResNet-style classifier for small 3D volumes.

    Layout: a stride-2 7x7x7 convolution stem with max pooling, two stages of
    residual blocks with ``numFeatures`` and ``2*numFeatures`` channels (the
    second starts with a ``_DownsampleResidualBlock``), and one linear layer
    applied to the flattened feature map that outputs ``numClasses`` values
    with no activation.

    Args:
        imageChannels: Channels of the input volume.
        numClasses: Number of outputs of the final linear layer.
        imageDimentions: Spatial size of the input volume (first three
            entries are used). Only needed to size the linear layer.
        numFeatures: Channels produced by the stem; the second stage uses
            twice as many.
        listResiduals: Number of blocks per stage. For the second stage the
            count includes the leading downsampling block.
    """

    def __init__(self, imageChannels, numClasses, imageDimentions=(64, 64, 64),
                numFeatures=64, listResiduals=(3, 8,)):

        super().__init__()

        # Define initial block of the classifier.
        # NOTE: _ConvBlock already applies BatchNorm3d + LeakyReLU here, so
        # the two layers that follow it repeat them. They are kept so the
        # module layout (and therefore state_dict keys) stays unchanged.
        self.initialLayer = nn.Sequential(
            _ConvBlock(
                imageChannels,
                numFeatures,
                act="leaky",
                batchnorm=True,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False,
            ),
            nn.BatchNorm3d(numFeatures),
            nn.LeakyReLU(),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1,),
        )

        # Define first residual stage with numFeatures channels
        # Note: we use actual number in list as there is no downsample res block
        resBlock64 = nn.Sequential(
            *[_ResidualBlock(numFeatures, act="leaky", batchnorm=True) for _ in range(listResiduals[0])],
        )

        # Define second residual stage with 2*numFeatures channels
        resBlock128 = nn.Sequential(
            _DownsampleResidualBlock(numFeatures, numFeatures*2, act="leaky"),
            *[_ResidualBlock(numFeatures*2, act="leaky", batchnorm=True) for _ in range(listResiduals[1]-1)],
        )

        self.resBlocksAll = nn.ModuleList([
            resBlock64,
            resBlock128,
        ])

        # The input is halved three times (stem convolution, max pool and the
        # downsampling block), each time rounding up, so a 28-voxel axis
        # becomes 28 -> 14 -> 7 -> 4. The previous "divide by 7" shortcut
        # gave the right answer only for 28-voxel inputs; for the default
        # 64-voxel input it produced a linear layer that did not match the
        # real feature map. The channel count is also tied to numFeatures
        # instead of the literal 128.
        latentDimentions = [self._latentSize(size, 3) for size in imageDimentions[:3]]
        flattenedInFeatures = numFeatures*2*prod(latentDimentions)

        self.denseBlocks = nn.ModuleList([
            _LinearBlock(
                flattenedInFeatures,
                numClasses,
                act="identity",
                flatten=True,
                bias=True
            ),
        ])

    @staticmethod
    def _latentSize(size, numHalvings):
        # Size of one spatial axis after `numHalvings` stride-2 layers.
        # Every stride-2 layer used in this network (kernel 7 / padding 3,
        # kernel 3 / padding 1, kernel 1 / padding 0) maps n to
        # floor((n - 1) / 2) + 1.
        size = int(size)
        for _ in range(numHalvings):
            size = (size - 1)//2 + 1
        return size

    def forward(self, x):

        # Apply initial layer (stride-2 convolution followed by stride-2 pooling)
        x = self.initialLayer(x)        # Size: (B, numFeatures, H/4, W/4, L/4)

        # Apply all resnet layers
        for layer in self.resBlocksAll:
            x = layer(x)
        # Size: (B, 2*numFeatures, H/8, W/8, L/8), each axis rounded up

        # Apply linear layers
        for layer in self.denseBlocks:
            x = layer(x)

        return x

## Test classifier outputs

In [None]:
# Smoke test: both classifiers must map one random volume to a
# (batch, numClasses) tensor.
N = 64
C = 1
numClasses = 2
image3D = torch.randn((1, C, N, N, N)).to(DEVICE)
classifier1 = Classifier3D(
    C,
    numFeatures=64,
    numClasses=numClasses,
    imageDimentions=(N, N, N),
    listResiduals=[3, 4, 6, 3]
    ).to(DEVICE)

M = 28
imageMed = torch.randn((1, C, M, M, M)).to(DEVICE)
classifier2 = ClassifierMed(
    C,
    numFeatures=64,
    numClasses=numClasses,
    imageDimentions=(M, M, M),
    listResiduals=[3, 8,]
    ).to(DEVICE)

expectedShape = (1, numClasses)

# Only shapes are inspected, so no autograd graph is needed
with torch.no_grad():
    for classifier, image in ((classifier1, image3D), (classifier2, imageMed)):
        output = classifier(image)
        print(type(classifier).__name__, tuple(output.shape), "expected", expectedShape)
        assert tuple(output.shape) == expectedShape, (
            f"{type(classifier).__name__}: got {tuple(output.shape)}, expected {expectedShape}"
        )