In [9]:
import torch
import torch.nn as nn

## Building submodels for Convnet
#### Linear Block

In [10]:
class _LinearBlock(nn.Module):
    # Basic Linear Block

    def __init__(self, inFeatures, outFeatures, act="relu", flatten=False, **kwargs):
        super().__init__()

        # Define activation function to be used in this block
        match act:
            case "identity":
                activationFunction = nn.Identity()
            case "relu":
                activationFunction = nn.ReLU()
            case "leaky":
                activationFunction = nn.LeakyReLU()
            case "gelu":
                activationFunction = nn.GELU()
            case "sigmoid":
                activationFunction = nn.Sigmoid()
            case _:
                raise Exception(f"{act} is not a recognised activation function for this class")

        self.linear = nn.Sequential(
            nn.Flatten(start_dim=1) if flatten else nn.Identity(), 
            nn.Linear(
                inFeatures,
                outFeatures,
                **kwargs,
            ),
            nn.BatchNorm1d(outFeatures),
            activationFunction,
        )

    def forward(self, x):
        # Output result of conv block when object is called

        return self.linear(x)

#### Convolutional Block

In [11]:
class _ConvBlock(nn.Module):
    # Basic convolutional block

    def __init__(self, inChannels, outChannels, down=True, act="relu", **kwargs):
        super().__init__()

        # Define activation function to be used in this block
        match act:
            case "identity":
                activationFunction = nn.Identity()
            case "relu":
                activationFunction = nn.ReLU()
            case "leaky":
                activationFunction = nn.LeakyReLU()
            case "gelu":
                activationFunction = nn.GELU()
            case _:
                raise Exception(f"{act} is not a recognised activation function for this class")

        # Define generic convolutional block / transpose convolutional block

        self.conv = nn.Sequential(
            nn.Conv3d(
                inChannels, 
                outChannels,
                padding_mode="reflect",
                **kwargs,
            ) 
            if down else nn.ConvTranspose3d(
                inChannels,
                outChannels,
                **kwargs,
            ),
            nn.BatchNorm3d(outChannels),
            activationFunction,
        )

    def forward(self, x):
        # Output result of conv block when object is called

        return self.conv(x)

#### Non Downsampling Residual block

In [12]:
class _ResidualBlock(nn.Module):
    # Basic Residual block

    def __init__(self,channels, act="leaky"):
        super().__init__()

        # Define convolutional blocks in residual blocks
        self.resBlock = nn.Sequential(
            _ConvBlock(
                channels, 
                channels,
                act=act,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            _ConvBlock(
                channels,
                channels,
                act="identity",
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
        )
    
    def forward(self, x):
        # Define operations to be made to input when object is called
        
        return x + self.resBlock(x) 

###  DownSampling Residual Block

In [13]:
class _DownsampleResidualBlock(nn.Module):
    # Basic Residual block

    def __init__(self, inChannels, outChannels, act="leaky"):
        super().__init__()

        # Define convolutional blocks in residual blocks
        self.resBlockDown = nn.Sequential(
            _ConvBlock(
                inChannels, 
                outChannels,
                act=act,
                kernel_size=3,
                stride=2,
                padding=1
            ),
            _ConvBlock(
                outChannels, 
                outChannels,
                act="identity",
                kernel_size=3,
                stride=1,
                padding=1
            ),
        )

        # Define normal downsample convolution for resudial calculation in input is downsamples
        self.resBlockDownSkip = nn.Sequential(
            _ConvBlock(
                inChannels,
                outChannels,
                act= "identity",
                kernel_size=1,
                stride=2,
                padding=0
            ),      
        )
        
    def forward(self, x):
        # Define operations to be made to input when object is called
        
        return self.resBlockDownSkip(x) + self.resBlockDown(x)

#### :warning: No activation function in output of resblock as described in original paper. :warning:
Maybe change in future using relu as suggested in original paper

## Test layer outputs

In [15]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
N = 128
C = 10
layer1 = _ResidualBlock(C, act="leaky").to(DEVICE)
layer2 = _DownsampleResidualBlock(C, C*2, act="leaky").to(DEVICE)
layer3 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1,).to(DEVICE)
layer4 = _ConvBlock(
            C,
            C*2,
            act="leaky",
            kernel_size=7,
            padding=3,
            bias=False,
        ).to(DEVICE)

input = torch.randn((1, C, N, N, N), dtype=torch.float32).to(DEVICE)

# Outputs should be the same
output = layer1(input)
print(output.shape)
print(1, C, N, N, N)

print("\n")
# Outputs should be the same
output = layer2(input)
print(output.shape)
print(1, C*2, N/2, N/2, N/2)

print("\n")
# Outputs should be the same
output = layer3(input)
print(output.shape)
print(1, C, N/2, N/2, N/2)

print("\n")
# Outputs should be the same
output = layer4(input)
print(output.shape)
print(1, C*2, N, N, N)

torch.Size([1, 10, 128, 128, 128])
1 10 128 128 128


torch.Size([1, 20, 64, 64, 64])
1 20 64.0 64.0 64.0


torch.Size([1, 10, 64, 64, 64])
1 10 64.0 64.0 64.0


torch.Size([1, 20, 128, 128, 128])
1 20 128 128 128


## Defining Multiclass Classifier

In [30]:
class Classifier3D(nn.Module):
    def __init__(self, imageChannels, numFeatures=64, listResiduals=[3, 4, 6, 3]):
        super().__init__()

        # Define initial block of the analiser
        self.initialLayer = nn.Sequential(
            _ConvBlock(
                imageChannels,
                numFeatures,
                act="leaky",
                kernel_size=7,
                bias=False,
            ),
            nn.BatchNorm3d(numFeatures),
            nn.LeakyReLU(),
            nn.MaxPool3d(kernel_size=3, stride=2, padding=1,),
        )

        # Define first residual block with 64 features
        # Note: we use actual number in list as there is no downsample res block
        resBlock64 = nn.Sequential(
            *[_ResidualBlock(numFeatures, act="leaky") for _ in range(listResiduals[0])],
        )

        # Define second residual block with 128 features
        resBlock128 = nn.Sequential(
            _DownsampleResidualBlock(numFeatures, numFeatures*2, act="leaky"),
            *[_ResidualBlock(numFeatures*2, act="leaky") for _ in range(listResiduals[1]-1)],
        )

        # Define third residual block with 256 features
        resBlock256 = nn.Sequential(
            _DownsampleResidualBlock(numFeatures*2, numFeatures*4, act="leaky"),
            *[_ResidualBlock(numFeatures*4, act="leaky") for _ in range(listResiduals[2]-1)],
        )

        # Define third residual block with 512 features
        resBlock512 = nn.Sequential(
            _DownsampleResidualBlock(numFeatures*4, numFeatures*8, act="leaky"),
            *[_ResidualBlock(numFeatures*8, act="leaky") for _ in range(listResiduals[3]-1)],
        )

        self.resBlocksAll = nn.ModuleList([
            resBlock64,
            resBlock128,
            resBlock256,
            resBlock512
        ])

    def forward(self, x):

        x = self.initialLayer(x)
        print("layer 1 done")

        for idx, layer in enumerate(self.resBlocksAll):
            x = layer(x)
            print(idx, x.shape)
        
        return x

## Test classifier outputs

In [32]:
N = 128
C = 1
image3D = torch.randn((1, C, N, N, N)).to(DEVICE)
classifier = Classifier3D(C, numFeatures=64, listResiduals=[3, 4, 6, 3]).to(DEVICE)

# Outputs should be the same
output = classifier(image3D)
print(output.shape)
print(1, 512, N/16, N/16, N/16)

layer 1 done
0 torch.Size([1, 64, 61, 61, 61])
1 torch.Size([1, 128, 31, 31, 31])
2 torch.Size([1, 256, 16, 16, 16])
3 torch.Size([1, 512, 8, 8, 8])
torch.Size([1, 512, 8, 8, 8])
1 512 8.0 8.0 8.0
