## **The CNN – Convolutional Neural Network**

**Imports**

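* [Docs](https://pytorch.org/docs/stable/nn.html) for `torch.nn`
  * Provides the neural network building blocks (layers such as `nn.Conv3d`, containers such as `nn.Sequential`) and `nn.Module`, the base class for all neural network modules

* [Docs](https://pytorch.org/docs/stable/nn.functional.html) for `torch.nn.functional`
  * Stateless function versions of many of the same operations (e.g. `F.max_pool3d`); imported here as `F`, although the model below does not use it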

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```
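As a minimal sketch of how the two imports are typically used together (the class `TinyNet` and its layer sizes are hypothetical, chosen only for illustration): layers with learnable weights are created as `nn.Module` attributes in `__init__`, which registers their parameters, while stateless operations can be called from `torch` or `torch.nn.functional` directly inside `forward`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical toy module, not part of the model below."""

    def __init__(self):
        super().__init__()  # must run before any submodule is assigned
        # Assigning an nn.Module as an attribute registers its weight and bias
        self.conv = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=3, padding=1)

    def forward(self, x):
        # Stateless ops (sigmoid, max pooling) need no registration
        return F.max_pool3d(torch.sigmoid(self.conv(x)), kernel_size=2)


net = TinyNet()
# Conv weight: 2 * 1 * 3 * 3 * 3 = 54 values, plus 2 biases = 56 parameters
n_params = sum(p.numel() for p in net.parameters())
out = net(torch.rand(1, 1, 4, 4, 4))  # (batch, channels, depth, height, width)
print(n_params, tuple(out.shape))  # 56 (1, 2, 2, 2, 2)
```

The same split appears in the model below: `nn.Conv3d` and `nn.MaxPool3d` are stored as attributes, while `torch.sigmoid` is called inline.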

**Building the Neural Network Class `akasha4_Net`**

*Links to the documentation, along with a brief explanation, for the methods/modules used below*

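* [Docs](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html?highlight=conv3d#torch.nn.Conv3d) for `nn.Conv3d`
  * Creates a 3D convolutional layer that maps an input tensor of shape `(N, C_in, D, H, W)` to an output of shape `(N, C_out, D_out, H_out, W_out)`, where `C_in` and `C_out` are the specified input and output channel counts. With `kernel_size = 3`, `padding = 1` and the default stride of 1, as used below, `D`, `H` and `W` are unchanged.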

* [Docs](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html?highlight=maxpool3d#torch.nn.MaxPool3d) for `nn.MaxPool3d`
  * In our case, the `MaxPool` "kernel" has shape 2 x 2 x 2 (i.e. it looks at 8 elements at once, across depth, height and width) and returns their maximum as a single element. Because the stride defaults to the kernel size, the windows don't overlap, so a `2a x 2a x 2a` input yields an `a x a x a` output; odd sizes are rounded down (e.g. a depth of 5 becomes 2).

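* [Docs](https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html?highlight=sigmoid#torch.nn.Sigmoid) for `nn.Sigmoid`
  * Applies the sigmoid function σ(x) = 1 / (1 + e^(-x)) to every element, i.e. this is an element-wise operation that squashes values into (0, 1). The code below calls the equivalent function form, `torch.sigmoid`.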

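* [Docs](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html?highlight=sequential#torch.nn.Sequential) for `nn.Sequential`
  * A container that runs its modules in the order they were passed, feeding each output into the next; `conv_module` below uses it to stack `nn.Conv3d` and `nn.BatchNorm3d` layers.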

* [Docs](https://pytorch.org/docs/stable/generated/torch.flatten.html?highlight=flatten#torch.flatten) for `torch.flatten`
  * Flattens a range of dimensions of an n-dimensional tensor into a single dimension. We call `torch.flatten(x, 1)`, which keeps dimension 0 (the batch) and merges all the remaining ones, so a `(N, C, D, H, W)` tensor becomes a 2D `(N, C*D*H*W)` tensor.
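To check the descriptions above, here is a small sketch that applies each layer to a random tensor. The input size `(1, 3, 5, 8, 8)` is an arbitrary illustration value, not the one used by the model.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 3, 5, 8, 8)  # (batch, channels, depth, height, width)

# Conv3d with kernel_size=3, padding=1 (stride 1): only the channel count changes
conv = nn.Conv3d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
y = conv(x)
print(tuple(y.shape))  # (1, 8, 5, 8, 8)

# MaxPool3d(2): 2 x 2 x 2 window, stride 2; the odd depth 5 is floored to 2
pool = nn.MaxPool3d(2)
z = pool(y)
print(tuple(z.shape))  # (1, 8, 2, 4, 4)

# Each output element is the max over 8 inputs: check the first window by hand
assert torch.equal(z[0, 0, 0, 0, 0], y[0, 0, :2, :2, :2].max())

# nn.Sigmoid (module) and torch.sigmoid (function) compute the same thing
assert torch.equal(nn.Sigmoid()(z), torch.sigmoid(z))

# nn.Sequential runs its modules in order, i.e. pool(conv(x))
seq = nn.Sequential(conv, pool)
assert torch.allclose(seq(x), z)

# torch.flatten(t, 1) keeps dim 0 (batch) and merges the rest: 8 * 2 * 4 * 4 = 256
flat = torch.flatten(z, 1)
print(tuple(flat.shape))  # (1, 256)
```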

```python
class akasha4_Net(nn.Module):
    def __init__(self):
        # Calling the parent constructor
        super().__init__()

        # Creating the MaxPool layers

        # Notice we're creating two MaxPool layers.
        # The first one is a regular MaxPool layer (2 x 2 x 2 window, stride 2)
        # that applies the max-pool operation to its input as-is.
        self.pool = nn.MaxPool3d(2)
        # The second MaxPool layer, pool_last, pads the depth dimension by 1 on each side
        # (MaxPool pads with -inf, so the padding never wins the max) before pooling.
        # By the last stage the depth has shrunk to 1, and pooling a depth of 1 with a
        # window of 2 would give an output depth of 0, which raises a RuntimeError.
        self.pool_last = nn.MaxPool3d(2, padding=(1, 0, 0))

        # kernel_size 3 with padding 1 (and stride 1) keeps D, H and W unchanged
        kernel_size = 3
        in_padding = 1

        # Creating the convolutional layers: each one doubles the channel count
        self.conv_module1 = conv_module(io_channels=[3, 8], kernel_size=kernel_size, in_padding=in_padding)
        self.conv_module2 = conv_module(io_channels=[8, 16], kernel_size=kernel_size, in_padding=in_padding)
        self.conv_module3 = conv_module(io_channels=[16, 32], kernel_size=kernel_size, in_padding=in_padding)
        self.conv_module4 = conv_module(io_channels=[32, 64], kernel_size=kernel_size, in_padding=in_padding)
        self.conv_module5 = conv_module(io_channels=[64, 128], kernel_size=kernel_size, in_padding=in_padding)

    def forward(self, x):
        # Shapes are (channels x depth x height x width); the batch dimension is omitted
        # Begin with: 3 x 20 x 224 x 224
        x = self.pool(torch.sigmoid(self.conv_module1(x)))  # Yields: 8 x 10 x 112 x 112
        x = self.pool(torch.sigmoid(self.conv_module2(x)))  # Yields: 16 x 5 x 56 x 56
        x = self.pool(torch.sigmoid(self.conv_module3(x)))  # Yields: 32 x 2 x 28 x 28 (depth 5 floored to 2)
        x = self.pool(torch.sigmoid(self.conv_module4(x)))  # Yields: 64 x 1 x 14 x 14
        x = self.pool_last(torch.sigmoid(self.conv_module5(x)))  # Yields: 128 x 1 x 7 x 7

        # Keep the batch dimension, merge the rest: (batch, 128 * 1 * 7 * 7) = (batch, 6272)
        x = torch.flatten(x, 1)

        # Values are already in (0, 1) after the earlier sigmoids,
        # so this second sigmoid maps them into (0.5, ~0.731)
        x = torch.sigmoid(x)
        # With no dim argument this averages over every element, batch included,
        # so the model returns a single scalar per batch; torch.mean(x, dim=1)
        # would instead give one score per sample
        x = torch.mean(x)
        return x


def conv_module(io_channels: list, kernel_size: int, in_padding: int) -> nn.Sequential:
    # Five Conv3d layers with a BatchNorm3d between each consecutive pair.
    # Only the first conv changes the channel count (io_channels[0] -> io_channels[1]).
    # Note: there is no activation between the stacked convolutions.
    assert len(io_channels) == 2
    in_ch, out_ch = io_channels

    layers = [nn.Conv3d(in_channels=in_ch, out_channels=out_ch,
                        kernel_size=kernel_size, padding=in_padding, bias=True)]
    for _ in range(4):
        layers.append(nn.BatchNorm3d(out_ch))
        layers.append(nn.Conv3d(in_channels=out_ch, out_channels=out_ch,
                                kernel_size=kernel_size, padding=in_padding, bias=True))
    return nn.Sequential(*layers)


# Random dummy input: (batch, channels, depth, height, width)
x = torch.rand(1, 3, 20, 224, 224)
model = akasha4_Net()
model(x)
```

tensor(0.6369, grad_fn=<MeanBackward0>)
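The shape comments in `forward` can be re-derived without running the model. This is a pure-Python sketch using the output-size formula from the PyTorch docs for `nn.Conv3d` and `nn.MaxPool3d`, assuming `dilation=1`, `ceil_mode=False` and the default pooling stride (equal to the kernel size). The helpers `out_size` and `trace_shapes` are hypothetical names introduced only for this illustration.

```python
def out_size(n: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length along one dimension for Conv3d / MaxPool3d
    (assuming dilation=1, ceil_mode=False); // floors, as PyTorch does."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_shapes(depth: int = 20, height: int = 224, width: int = 224) -> list:
    """Hypothetical helper: (channels, D, H, W) after each conv_module + pool stage."""
    dims = (depth, height, width)
    shapes = []
    channel_plan = [8, 16, 32, 64, 128]
    for stage, channels in enumerate(channel_plan):
        # Five Conv3d layers per conv_module: kernel 3, padding 1, stride 1
        for _ in range(5):
            dims = tuple(out_size(n, kernel=3, stride=1, padding=1) for n in dims)
        # self.pool has no padding; self.pool_last pads only the depth: (1, 0, 0)
        pads = (1, 0, 0) if stage == len(channel_plan) - 1 else (0, 0, 0)
        dims = tuple(out_size(n, kernel=2, stride=2, padding=p) for n, p in zip(dims, pads))
        shapes.append((channels, *dims))
    return shapes


for shape in trace_shapes():
    print(shape)
# (8, 10, 112, 112)
# (16, 5, 56, 56)
# (32, 2, 28, 28)
# (64, 1, 14, 14)
# (128, 1, 7, 7)

# Why pool_last needs padding: pooling a depth of 1 with a window of 2
print(out_size(1, kernel=2, stride=2))             # 0 -> PyTorch raises an error
print(out_size(1, kernel=2, stride=2, padding=1))  # 1
```

The last stage gives 128 x 1 x 7 x 7 = 6272 features per sample after `torch.flatten(x, 1)`.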