# Creating Models in PyTorch

This notebook is a tutorial accompanying the manuscript "Perspectives: Comparison of Deep Learning Based Segmentation Models on Typical Biophysics and Biomedical Data" by JS Bryan IV, M Tavakoli, and S Presse. In this tutorial, we will learn the basics of using the `nn.Module` class in PyTorch with prebuilt and custom layers.

**Before reading this tutorial, make sure you have properly installed PyTorch and downloaded the data as explained in this repository's README.**

## Introduction

Welcome to the tutorial on creating models in PyTorch! In this tutorial, we will learn how to create models in PyTorch using the `nn.Module` class. We will also learn how to use prebuilt layers and create custom layers. The specific aim of this tutorial is to explain the models used in the accompanying manuscript, which can be found in the `models/` directory of this repository. This tutorial will focus on the basics of `nn.Module`, specifically with regard to convolutional neural networks. For specific details on the transformer architecture, please see the transformer tutorial.

Models in PyTorch are a convenient way to package together a neural network architecture. The `nn.Module` class is the base class for all models in PyTorch. It provides a convenient way to define the forward pass of a neural network, which is the process of passing input data through the network to get an output. The `nn.Module` class also provides a way to define the parameters of the network, which are the weights and biases that are learned during training.

### Importing libraries

Before we start, let's import the necessary libraries. We will be using Python's built-in `math` module, along with `torch` and its `torch.nn` module.

In [None]:
# Import libraries
import math
import torch
import torch.nn as nn

## Basics of nn.Module

The Module class is the core of model creation in PyTorch. It provides a simple way to set up the parameters of the network and define the operations that are performed on the input data. The Module class has two main methods that need to be implemented: `__init__` and `forward`. The `__init__` method is used to define the layers and parameters of the network, while the `forward` method specifies how the input data is processed through these layers to produce the output.

Let's start by creating a very simple model that scales and shifts the input data. To do this we will create a custom module called `ScaleShift` that inherits from `nn.Module`. In the `__init__` method, we will initialize the parent class of `ScaleShift` and then define two parameters, `scale` and `shift`, that will be learned during training. Notice that to define parameters in a module, we use the `nn.Parameter` class. The `nn.Parameter` class is a wrapper around a tensor that tells PyTorch that this tensor should be treated as a parameter of the network during training. In the `forward` method, we will apply the scaling and shifting operations to the input data. Lastly, we apply our custom module to some input data to see how it works.

In [None]:
## Create simple model that simply scales and shifts data
class ScaleShift(nn.Module):
    def __init__(self):
        super(ScaleShift, self).__init__()         # Call parent class constructor, this is required for model
        self.scale = nn.Parameter(torch.rand(1))   # Create scale parameter, initialized uniformly in [0, 1)
        self.shift = nn.Parameter(torch.randn(1))  # Create shift parameter, initialized from a standard normal
    def forward(self, x):
        return x * self.scale + self.shift         # Return scaled and shifted data
    
# Instantiate model and data
model = ScaleShift()
data = torch.linspace(1, 10, 10)

# Run model on data
output = model(data)

# Print results
print("Data: ", data)
print("Output: ", output)
print("Scale: ", model.scale)
print("Shift: ", model.shift)

A convenient consequence of building models from `nn.Module` is that we never have to compute gradients by hand. Every `nn.Parameter` assigned to a module is registered with it (and returned by `model.parameters()`), and because parameters have `requires_grad=True`, PyTorch's `autograd` engine records the operations performed on them during the forward pass. Calling `.backward()` on a loss then computes the gradient of the loss with respect to every parameter and stores it in that parameter's `.grad` attribute.
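As a minimal sketch of this, the cell below repeats the `ScaleShift` definition so that it runs on its own, computes a mean squared error against a made-up target (the values 2 and 1 are arbitrary assumptions for illustration), and checks that the gradients `autograd` stores in `.grad` match the ones obtained by differentiating the loss by hand.

```python
import torch
import torch.nn as nn

class ScaleShift(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.rand(1))
        self.shift = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return x * self.scale + self.shift

torch.manual_seed(0)
model = ScaleShift()
x = torch.linspace(1, 10, 10)
target = 2.0 * x + 1.0  # arbitrary target for illustration: scale 2, shift 1

# Forward pass and loss; autograd records the operations involving the parameters
loss = ((model(x) - target) ** 2).mean()

# Backward pass: fills in .grad for every registered parameter
loss.backward()

# Hand-derived gradients of the mean squared error L = mean((scale * x + shift - target)^2)
with torch.no_grad():
    residual = model(x) - target
    grad_shift = (2 * residual.mean()).reshape(1)
    grad_scale = (2 * (residual * x).mean()).reshape(1)

print("dL/dscale:", model.scale.grad)
print("dL/dshift:", model.shift.grad)
assert torch.allclose(model.shift.grad, grad_shift, atol=1e-5)
assert torch.allclose(model.scale.grad, grad_scale, atol=1e-4)
```

An optimizer such as `torch.optim.SGD(model.parameters(), lr=...)` would then read these `.grad` values when `optimizer.step()` is called.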

Additionally, `nn.Module` makes the model callable: we run the forward pass by calling the module as if it were a function, for example `output = model(input)`. This calls our `forward` method for us (along with any hooks registered on the module), so we should prefer `model(input)` over calling `model.forward(input)` directly.
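The difference between `model(x)` and `model.forward(x)` can be seen with a forward hook, which `nn.Module.__call__` runs but a direct `forward` call skips. The `Double` module below is a hypothetical toy used only for this check.

```python
import torch
import torch.nn as nn

class Double(nn.Module):
    """Hypothetical toy module that doubles its input."""
    def forward(self, x):
        return 2 * x

model = Double()
hook_calls = []

# A forward hook receives (module, inputs, output) after each forward pass run through __call__
model.register_forward_hook(lambda module, inputs, output: hook_calls.append(output.shape))

x = torch.ones(3)
y_call = model(x)             # goes through __call__: runs forward and then the hook
y_direct = model.forward(x)   # bypasses __call__: the hook is not run

print(len(hook_calls))  # 1
assert torch.equal(y_call, y_direct)
```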

## Using Prebuilt Layers

PyTorch provides a rich set of prebuilt layers in the `torch.nn` module, which makes it easy to build complex neural network architectures. These layers include convolutional layers, pooling layers, activation functions, and more. Using these prebuilt layers saves time and gives you efficient, well-tested implementations of common model components. In this section, we will build a simple convolutional neural network (CNN) for image segmentation.

Before we start, it's important to note that most PyTorch layers, and all of the models in this tutorial, expect input tensors to have a batch dimension. This means that even if we are processing a single grayscale image, we need to structure our input as a 4D tensor with dimensions (batch, channel, height, width), in this case (1, 1, height, width). Keep this in mind especially when loading images from files, which are often stored channel-last as (height, width, channel) and must be permuted before being passed to the network. Each layer expects a particular input shape, so it is important to read the [documentation](https://pytorch.org/docs/stable/index.html) for each layer to understand how to properly structure its input.
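A minimal sketch of this bookkeeping, using random tensors as stand-ins for images loaded from disk (the image sizes are assumptions for illustration):

```python
import torch

# Stand-in for an RGB image loaded from a file, stored channel-last: (height, width, channel)
rgb_hwc = torch.rand(64, 48, 3)

# Move channels first, then add a batch dimension of size 1: (1, 3, 64, 48)
rgb_batch = rgb_hwc.permute(2, 0, 1).unsqueeze(0)

# Stand-in for a single grayscale image with no channel axis: (height, width)
gray_hw = torch.rand(64, 48)

# Add both a batch and a channel dimension: (1, 1, 64, 48)
gray_batch = gray_hw.unsqueeze(0).unsqueeze(0)

print(rgb_batch.shape)   # torch.Size([1, 3, 64, 48])
print(gray_batch.shape)  # torch.Size([1, 1, 64, 48])
```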

Now that we understand the basics of Modules and inputs, let us move on to creating our CNN. A simple model consists of layers and blocks. Layers are individual components of the network, such as convolutional layers, pooling layers, and activation functions. Blocks are groups of layers that together perform one stage of the computation; some blocks, like our convolutional block, are repeated multiple times in the network. Our network will have three main blocks: an input block, a convolutional block, and an output block. Let us go over these in detail.

### Attributes

Our model will have several attributes that define the architecture of the network. These attributes include the number of input channels, the number of feature channels, and the number of output channels. The number of input channels is the number of channels in the input data, which is typically 1 for grayscale images and 3 for color images. The number of feature channels is the number of channels in the intermediate feature maps of the network, which is a hyperparameter that can be tuned to control the capacity of the network. The number of output channels is the number of channels in the output data, which is typically the number of classes in a segmentation task.

```python

        # Set up attributes
        self.in_channels = 3
        self.out_channels = 2
        self.n_features = 8
        self.n_layers = 1

```

We specify 3 input channels for RGB images, 2 output channels (one is background and one is target), 8 feature channels, and one convolutional layer. These values can be adjusted depending on the specific task and dataset.

As a quick aside, one point of confusion frequently encountered in segmentation is the number of output channels. The number of output channels is the number of classes in the segmentation task, where a class is loosely defined as "the different types of objects that we want to segment". For example, in a binary segmentation task, we have two classes: the background and the target object. In a multi-class segmentation task, for example if we wanted to segment nucleus, cytoplasm, and background, we would have three classes.

The output of the network will be a tensor with dimensions (batch, classes, height, width), where the class dimension holds the logits for each class. Logits are unnormalized log-probabilities: applying the softmax function across the class dimension turns them into probabilities. If a pixel has a higher logit for a class, it is more likely to belong to that class, so we can classify each pixel by finding the class with the highest logit. Lastly, while our models output logits, other networks in the literature may output probabilities (the result of applying softmax to the logits). It is therefore always important to understand what the network outputs and how that output is used in the loss function.
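A short sketch of these views of the output, using a random tensor as a stand-in for the output of a three-class model on a single 4x4 image (the sizes are assumptions for illustration):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 3, 4, 4)  # stand-in model output: (batch, classes, height, width)

# Softmax over the class dimension (dim=1) turns logits into per-pixel probabilities
probs = torch.softmax(logits, dim=1)

# Logits differ from log-probabilities only by a per-pixel constant (the log-sum-exp)
log_probs = torch.log_softmax(logits, dim=1)
offset = logits - log_probs
assert torch.allclose(offset, offset[:, :1].expand_as(offset), atol=1e-6)

# Predicted class per pixel: identical whether taken from the logits or the probabilities
labels = logits.argmax(dim=1)  # shape (1, 4, 4), values in {0, 1, 2}
assert torch.equal(labels, probs.argmax(dim=1))
print(labels)
```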


### Input Block

The input block is the first part of the network and preprocesses the input data. For our purposes, we need to normalize each input image to have zero mean and unit variance, and then expand it to the number of feature channels used by the rest of the network. To do this we will use `nn.GroupNorm` to normalize the input data, and `nn.Conv2d` to map the input channels to feature channels.

```python

        # Set up input block
        self.input_block = nn.Sequential(
            nn.GroupNorm(1, in_channels, affine=False),  # Normalize input
            nn.Conv2d(in_channels, n_features, kernel_size=3, padding=1),
        )

```

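The `nn.GroupNorm` layer splits the channels into groups and normalizes each group of each sample to zero mean and unit variance. With a single group, as here, each image is normalized over all of its channels and pixels at once, and `affine=False` means no learnable scale or shift is applied afterwards. The `nn.Conv2d` layer then maps the `in_channels` input channels to `n_features` feature channels by applying a set of learned filters. The `kernel_size` parameter specifies the size of the filters, and the `padding` parameter specifies how many zeros to add to each side of the input.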

Notice that we group layers together using the `nn.Sequential` class, a convenient way to define a sequence of layers in PyTorch. Calling the `nn.Sequential` object on the input data, e.g. `self.input_block(x)`, applies its layers in order, passing the output of each layer to the next.
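As a standalone sketch, the cell below builds the same input block and checks both claims: one-group `GroupNorm` gives each image zero mean and unit variance, and calling the `nn.Sequential` is the same as applying its layers one after another. The input statistics (mean around 5, standard deviation around 3) and image sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_channels, n_features = 3, 8

# Input block as in the tutorial
input_block = nn.Sequential(
    nn.GroupNorm(1, in_channels, affine=False),
    nn.Conv2d(in_channels, n_features, kernel_size=3, padding=1),
)

# Two stand-in RGB images whose pixel values have mean around 5 and std around 3
x = 5 + 3 * torch.randn(2, in_channels, 16, 16)

# GroupNorm with one group: each image is normalized over all of its channels and pixels
normed = input_block[0](x).flatten(1)  # (2, 3*16*16)
mean = normed.mean(dim=1)
var = ((normed - mean[:, None]) ** 2).mean(dim=1)
print(mean, var)  # close to 0 and 1 for each image

# Calling the Sequential applies its layers in order, like conv(norm(x))
out = input_block(x)
assert torch.allclose(out, input_block[1](input_block[0](x)))
assert out.shape == (2, n_features, 16, 16)
```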

### Convolutional Block

The convolutional block is the main part of the network that processes the input data. A standard convolutional block consists of a convolutional layer, followed by a normalization layer, and finally an activation layer. The convolutional layer applies a set of filters to the input data to extract features. The normalization layer normalizes the output of the convolutional layer to stabilize the training process. The activation layer applies a non-linear function to the output of the normalization layer to introduce non-linearity into the network. For our convolutional block, we will create a `ModuleList` of convolutional layers, where each layer will be packaged into a `Sequential` block with normalization and activation layers. Notice that during the forward pass we will need to loop through the elements of the `ModuleList` explicitly, while the layers within each `Sequential` block are applied in sequence automatically.

```python

        # Set up convolutional block
        self.conv_block = nn.ModuleList()
        for _ in range(n_layers):
            self.conv_block.append(nn.Sequential(
                nn.Conv2d(n_features, n_features, kernel_size=3, padding=1),
                nn.GroupNorm(1, n_features, affine=False),
                nn.ReLU(),
            ))

```

When we construct `nn.Conv2d`, the positional arguments are `(in_channels, out_channels, kernel_size, stride, padding, ...)`, so `padding` is passed as a keyword argument. The `in_channels` parameter specifies the number of input channels to the convolutional layer, `out_channels` the number of output channels (one per filter), `kernel_size` the size of the filters, and `padding` the number of zeros added to each side of the input. With `kernel_size=3` and `padding=1`, the output has the same height and width as the input.
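A small sketch of why the keyword matters, using the standard output-size formula for a convolution with dilation 1; `conv_out_size` is a hypothetical helper written for this example, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

def conv_out_size(size, kernel_size, stride=1, padding=0):
    # Output height/width of nn.Conv2d with dilation=1
    return (size + 2 * padding - kernel_size) // stride + 1

x = torch.randn(1, 8, 32, 32)

same = nn.Conv2d(8, 8, kernel_size=3, padding=1)  # padding passed by keyword
valid = nn.Conv2d(8, 8, kernel_size=3)            # default padding=0
strided = nn.Conv2d(8, 8, 3, 1)                   # fourth positional argument is stride, not padding

assert same(x).shape[-1] == conv_out_size(32, 3, padding=1) == 32
assert valid(x).shape[-1] == conv_out_size(32, 3) == 30
assert strided(x).shape[-1] == 30  # the trailing 1 set stride=1, so no padding was added
```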

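Notice that there are many different choices for normalization. Batch normalization (`nn.BatchNorm2d`) is widely used in the literature, but it normalizes each channel using statistics computed across the whole batch. In this model we instead use `nn.GroupNorm` with a single group, which normalizes each sample over all of its channels and pixels, independently of the other samples in the batch. A closely related per-sample option is instance normalization (`nn.InstanceNorm2d`, equivalent to `nn.GroupNorm(n_features, n_features)`), which normalizes each channel of each sample separately. Per-sample normalizations like these are useful for style transfer and other tasks where we do not want the output for one sample to be affected by the statistics of the rest of the batch.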
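A quick numerical sketch of how these options relate, on random feature maps (the sizes are arbitrary assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8, 16, 16)  # stand-in feature maps: (batch, n_features, H, W)

instance = nn.InstanceNorm2d(8)                       # per sample, per channel (affine=False by default)
group_per_channel = nn.GroupNorm(8, 8, affine=False)  # one group per channel
group_single = nn.GroupNorm(1, 8, affine=False)       # one group for all channels, as in ConvolutionalNet

# Instance normalization is group normalization with one group per channel
assert torch.allclose(instance(x), group_per_channel(x), atol=1e-5)

# A single group uses different statistics, so the results differ
assert not torch.allclose(instance(x), group_single(x), atol=1e-3)

# Both are per-sample: normalizing the first image alone matches normalizing it inside the batch
assert torch.allclose(group_single(x)[:1], group_single(x[:1]), atol=1e-5)
assert torch.allclose(instance(x)[:1], instance(x[:1]), atol=1e-5)
```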

Lastly, we use a Rectified Linear Unit (ReLU) activation function to introduce non-linearity into the network. ReLU computes `f(x) = max(0, x)` elementwise on the output of the normalization layer, so the activations passed to the next layer are non-negative (the output block has no activation, so the final logits can still be negative). PyTorch provides many other activation functions, such as `nn.Sigmoid`, `nn.Tanh`, and `nn.GELU`, but ReLU remains one of the most widely used activation functions in deep learning.

### Output Block

The output block is the final part of the network that produces the output data. For our purposes, we need to convert the intermediate feature maps of the network into the final output data. To do this we will use `nn.Conv2d` to ensure that the output data has the correct number of channels.

```python

        # Set up output block
        self.output_block = nn.Sequential(
            nn.Conv2d(n_features, out_channels, kernel_size=1),  # 1x1 conv: per-pixel features -> class logits
        )

```

**One big consideration:** there are many conventions for the outputs of segmentation networks. We can choose to output the probability of each class for each pixel, the logits of each class, or the class with the highest probability for each pixel. In this model, we output the logits of each class for each pixel. The logits are the raw output of the network before applying the softmax function, which converts them into probabilities. This is a common choice for segmentation networks because PyTorch's cross-entropy loss, `nn.CrossEntropyLoss`, expects raw logits and applies the softmax internally. However, we must keep in mind that when we use the model for evaluation and want class probabilities, we must apply the softmax function to the output ourselves. To get only the predicted class, taking the argmax of the logits is enough, since softmax does not change which class is largest.
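A minimal sketch of training on logits, with random tensors standing in for a model's output and for the ground-truth masks (batch size, class count, and image size are assumptions). It checks that `nn.CrossEntropyLoss` on logits equals the mean negative log-softmax of the true class at each pixel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8, requires_grad=True)  # stand-in model output: (batch, classes, H, W)
target = torch.randint(0, 3, (2, 8, 8))               # one integer class label per pixel: (batch, H, W)

# CrossEntropyLoss takes raw logits and applies log-softmax over the class dimension internally
loss = nn.CrossEntropyLoss()(logits, target)

# Same value by hand: the log-probability of the true class at each pixel, averaged over pixels
log_probs = torch.log_softmax(logits, dim=1)
manual = -log_probs.gather(1, target.unsqueeze(1)).mean()
assert torch.allclose(loss, manual, atol=1e-6)

loss.backward()  # gradients flow back into the logits (and, in a real model, its parameters)
```

Note that the target holds class indices with no channel dimension, not one-hot masks or probabilities.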

Lastly, one thing we should point out about output blocks is that in the code in the `models/` directory of this repository, we create a function, `set_output_block`, to set the output block. This is because the number of output channels can change depending on the task. For example, in the manuscript, we use a model with 2 output channels for binary segmentation, and a model with 3 output channels for multiclass segmentation. By creating a separate function to set the output block, we can easily change the number of output channels without having to modify the rest of the model. This is beyond the scope of this tutorial and we will not cover it here.

### Forward Method

The `forward` method of the model specifies how the input data is processed through the network to produce the output. In our model, we apply the input block to the input data, then pass the result through each layer of the convolutional block in turn, and finally apply the output block to produce the final output.

```python

    def forward(self, x):

        # Apply input block
        x = self.input_block(x)

        # Apply convolutional block
        for layer in self.conv_block:  # Loops over layers in the ModuleList
            x = layer(x)               # Applies all layers in the Sequential block

        # Apply output block
        x = self.output_block(x)

        # Return x
        return x

```

### Putting it all together

Now that we have defined the input block, convolutional block, and output block, we can put them all together to create the full model. We do this by defining a class called `ConvolutionalNet` that inherits from `nn.Module` and combines the input block, convolutional block, and output block into a single model.

Let us define our model, instantiate it, and apply it to some input data to see how it works. After that we will print the input and output shape, as well as the number of parameters in the model.

In [None]:
# Define the ConvolutionalNet class
class ConvolutionalNet(nn.Module):
    def __init__(self, in_channels, out_channels, n_features=8, n_layers=1):
        super(ConvolutionalNet, self).__init__()

        # Set up attributes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_features = n_features
        self.n_layers = n_layers


        ### SET UP BLOCKS ###

        # Set up input block
        self.input_block = nn.Sequential(
            nn.GroupNorm(1, in_channels, affine=False),  # Normalize input
            nn.Conv2d(in_channels, n_features, kernel_size=3, padding=1),
        )

        # Set up layers
        self.conv_block = nn.ModuleList()
        for _ in range(n_layers):
            self.conv_block.append(nn.Sequential(
                nn.Conv2d(n_features, n_features, kernel_size=3, padding=1),
                nn.GroupNorm(1, n_features, affine=False),
                nn.ReLU(),
            ))

        # Set up output block
        self.output_block = nn.Sequential(
            nn.Conv2d(self.n_features, out_channels, kernel_size=1),
        )

    def forward(self, x):

        # Apply input block
        x = self.input_block(x)

        # Apply convolutional block
        for layer in self.conv_block:  # Loops over layers in the ModuleList
            x = layer(x)               # Applies all layers in the Sequential block

        # Apply output block
        x = self.output_block(x)

        # Return x
        return x
    

# Set up model and data
model = ConvolutionalNet(3, 2, n_features=8)  # 3 input channels (RGB), 2 output channels, 8 features
data = torch.randn(1, 3, 32, 32)              # 1 batch, 3 channels, 32x32 image

# Run model on data
output = model(data)

# Print results
print("Data shape: ", data.shape)      # Shape will be (batch, channels, height, width)
print("Output shape: ", output.shape)  # Shape will be (batch, channels, height, width)

# Print number of parameters
#  - `model.parameters()` returns an iterator over all parameters in the model.
#  - `p.numel()` returns the number of elements in a tensor.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: ", n_params)
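The parameter count printed above can be checked by hand. A `Conv2d` layer has one `in_channels x kernel_size x kernel_size` filter per output channel plus one bias per output channel, and `GroupNorm` with `affine=False` contributes nothing. The sketch below assumes the configuration used in the cell above (`ConvolutionalNet(3, 2, n_features=8)` with the default `n_layers=1`); `conv2d_param_count` is a hypothetical helper.

```python
import torch.nn as nn

def conv2d_param_count(c_in, c_out, kernel_size):
    # One (c_in x k x k) filter per output channel, plus one bias per output channel
    return c_in * c_out * kernel_size * kernel_size + c_out

# Layers of ConvolutionalNet(3, 2, n_features=8, n_layers=1), in order
expected = (
    conv2d_param_count(3, 8, 3)    # input block conv: 224
    + conv2d_param_count(8, 8, 3)  # one conv layer in the convolutional block: 584
    + conv2d_param_count(8, 2, 1)  # 1x1 output conv: 18
)
print(expected)  # 826

# GroupNorm with affine=False has no learnable parameters
assert sum(p.numel() for p in nn.GroupNorm(1, 8, affine=False).parameters()) == 0
layers = [nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 8, 3, padding=1), nn.Conv2d(8, 2, 1)]
assert sum(p.numel() for layer in layers for p in layer.parameters()) == expected
```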

## Sub Modules

In the last section we showed how to use the `nn.Module` class to create a model in PyTorch. The model we created was a simple CNN for image segmentation. However, in practice, we often want to create more complex models that consist of multiple submodules. This will allow us to reuse these submodules in other models and create more complex architectures.

Submodules are very simple in practice: we create a submodule in exactly the same way we would create a model, by subclassing `nn.Module` and defining its `__init__` and `forward` methods. When we instantiate the submodule inside the model and assign it to an attribute of the model (directly, or inside an `nn.ModuleList` or `nn.Sequential`), its parameters are registered as parameters of the model. This means that we do not have to do anything extra to train the parameters of the submodule; they will be trained automatically when we train the model.
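Registration is also why the tutorial's models store their repeated blocks in an `nn.ModuleList` rather than a plain Python list. The two classes below are hypothetical examples that differ only in that choice.

```python
import torch.nn as nn

class Registered(nn.Module):
    """Hypothetical model whose blocks live in an nn.ModuleList."""
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

class Unregistered(nn.Module):
    """Hypothetical model whose blocks live in a plain Python list."""
    def __init__(self):
        super().__init__()
        self.blocks = [nn.Linear(4, 4) for _ in range(2)]  # not registered as submodules

names = [name for name, _ in Registered().named_parameters()]
print(names)  # ['blocks.0.weight', 'blocks.0.bias', 'blocks.1.weight', 'blocks.1.bias']
print(len(list(Unregistered().parameters())))  # 0: an optimizer would never see these weights
```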

As a simple example of submodules in PyTorch, let us recreate the CNN model from the last section, but this time we will define the convolutional block as its own `ConvBlock` submodule, which could then be reused in other models.

In [None]:
# Create convolutional block
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.InstanceNorm2d(out_channels)
        self.relu = nn.ReLU()
    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))
    
# Create model using convolutional block
class ConvModel(nn.Module):
    def __init__(self, in_channels, out_channels, n_features=8, n_layers=3):
        super(ConvModel, self).__init__()

        # Input block
        self.input_block = nn.Sequential(
            nn.GroupNorm(1, in_channels, affine=False),
            nn.Conv2d(in_channels, n_features, kernel_size=3, padding=1),
        )
        
        # Convolutional block
        self.conv_block = nn.ModuleList()
        for _ in range(n_layers):
            self.conv_block.append(ConvBlock(n_features, n_features))  # Here we use ConvBlock

        # Output block
        self.output_block = nn.Sequential(
            nn.Conv2d(n_features, out_channels, kernel_size=1),
        )

    def forward(self, x):

        # Apply input block
        x = self.input_block(x)

        # Apply convolutional block
        for layer in self.conv_block:  # Loops over layers in the ModuleList
            x = layer(x)               # Calls ConvBlock.forward: conv, norm, ReLU

        # Apply output block
        x = self.output_block(x)

        # Return x
        return x
    
# Set up model and data
model = ConvModel(3, 2, n_features=8)  # 3 input channels (RGB), 2 output channels, 8 features
data = torch.randn(1, 3, 32, 32)        # 1 batch, 3 channels, 32x32 image

# Run model on data
output = model(data)

# Print results
print("Data shape: ", data.shape)      # Shape will be (batch, channels, height, width)
print("Output shape: ", output.shape)  # Shape will be (batch, channels, height, width)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters: ", n_params)

## Transformers

We will end our tutorial on creating models in PyTorch by introducing a new architecture: the transformer. The transformer is a powerful architecture that has been used in natural language processing, time series analysis, and computer vision. The transformer is based on the self-attention mechanism, which allows the model to attend to different parts of the input data with different weights. This makes the transformer very effective at capturing long-range dependencies in the data.

A transformer block consists of two main components: multi-head self-attention and a feedforward neural network. The multi-head self-attention mechanism allows the model to attend to different parts of the input data with different weights, while the feedforward network processes the output of the attention step for each patch independently. In our implementation, each component is preceded by a layer normalization (`nn.LayerNorm`) and wrapped in a residual connection, so its output is added back to its input. We can build a transformer block by combining these components into a single module, where we specify the number of features, the number of heads, and the expansion factor of the feedforward network.

In [None]:
# Define transformer block
class TransformerBlock(nn.Module):
    def __init__(self, n_features, n_heads=8, expansion=1):
        super(TransformerBlock, self).__init__()

        # Set up attributes
        self.n_features = n_features
        self.n_heads = n_heads
        self.expansion = expansion
        
        # Calculate constants
        n_features_inner = int(n_features * expansion)
        self.n_features_inner = n_features_inner

        # Set up multi-head self-attention
        self.self_attn = nn.MultiheadAttention(n_features, n_heads, batch_first=True)

        # Set up feedforward layer
        self.mlp = nn.Sequential(
            nn.Linear(n_features, n_features_inner),
            nn.ReLU(),
            nn.Linear(n_features_inner, n_features),
        )

        # Set up normalization layers
        self.norm1 = nn.LayerNorm(n_features)
        self.norm2 = nn.LayerNorm(n_features)

    def forward(self, x):

        # Apply self-attention
        h = self.norm1(x)                         # Normalize once, reuse as query, key, and value
        attn_output, _ = self.self_attn(h, h, h)
        x = x + attn_output

        # Feedforward layer
        x = x + self.mlp(self.norm2(x))

        return x
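To make the shapes concrete, here is a small standalone check of `nn.MultiheadAttention`, the layer at the heart of `TransformerBlock`, using the same `batch_first=True` layout. The sizes (2 images, 16 patches, 64 features) are illustrative assumptions. It also shows a constraint worth remembering: the number of features must be divisible by the number of heads, so with the default `n_heads=8` the `n_features` of a `TransformerBlock` must be a multiple of 8.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_features, n_heads = 64, 8  # 64 / 8 = 8 features per head
attn = nn.MultiheadAttention(n_features, n_heads, batch_first=True)

x = torch.randn(2, 16, n_features)  # (batch, patches, features)
out, weights = attn(x, x, x)        # self-attention: query, key, and value are all x

print(out.shape)      # torch.Size([2, 16, 64]): same shape as the input
print(weights.shape)  # torch.Size([2, 16, 16]): attention of each patch to every patch, averaged over heads
assert torch.allclose(weights.sum(dim=-1), torch.ones(2, 16), atol=1e-5)

# The number of features must be divisible by the number of heads
try:
    nn.MultiheadAttention(n_features, num_heads=6)
    divisible_check_failed = False
except (AssertionError, ValueError):
    divisible_check_failed = True
print(divisible_check_failed)  # True
```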

We can use the `TransformerBlock` from above to create a transformer model. However, there is an important difference between the transformer and the CNN model we created earlier. Because we set `batch_first=True`, the input shape of a `TransformerBlock` is (batch, patches, features), where patches is the number of patches in the input data and features is the number of features in each patch. This is different from the input shape of a CNN model, which is (batch, channels, height, width), so we need to preprocess the input data differently for the transformer model: we split each image into patches, embed each patch as a feature vector, and flatten the grid of patches into a sequence. We can calculate the number of patches by specifying a patch size and an image size for the network. Notice that this means that, unlike the CNN model, the transformer model can only process square images of the size it was built for, and `img_size` should be a multiple of `patch_size` (otherwise the output will be smaller than the input).

```python

        # Calculate constants
        n_patches = (img_size // patch_size) ** 2
        shape_after_patch = (img_size // patch_size, img_size // patch_size)

```

To preprocess the input data to convert it into patches we can use the `Conv2d` class with a stride and kernel equal to patch size, so that each patch is extracted from the input data. After the transformer is applied to the patches, we can use a `ConvTranspose2d` layer to convert the patches back into the original image shape.

```python

        # Patch embedding
        self.patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, n_features, kernel_size=patch_size, stride=patch_size),
        )

        # Set up output block
        self.patch_expansion = nn.Sequential(
            nn.ConvTranspose2d(
                self.n_features, 
                out_channels, 
                kernel_size=self.patch_size, 
                stride=self.patch_size,
            ),
        )

```
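One way to see what this patch embedding does: a `Conv2d` whose kernel and stride both equal the patch size is the same as cutting the image into non-overlapping patches, flattening each patch, and applying one shared linear layer. The sketch below checks this with `torch.nn.functional.unfold`; the sizes (32x32 RGB images, 8x8 patches, 24 features, 2 output classes) are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, in_channels, img_size, patch_size, n_features = 2, 3, 32, 8, 24
x = torch.randn(B, in_channels, img_size, img_size)

# Patch embedding as in the tutorial: kernel and stride both equal the patch size
patch_embedding = nn.Conv2d(in_channels, n_features, kernel_size=patch_size, stride=patch_size)
grid = patch_embedding(x)                      # (B, n_features, 4, 4)
conv_tokens = grid.flatten(2).transpose(1, 2)  # (B, 16 patches, n_features)

# The same computation done explicitly: cut out non-overlapping patches, flatten each one,
# and apply the convolution's weights as a linear layer
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, in_channels*8*8, 16)
weight = patch_embedding.weight.reshape(n_features, -1)            # (n_features, in_channels*8*8)
linear_tokens = patches.transpose(1, 2) @ weight.T + patch_embedding.bias
assert torch.allclose(conv_tokens, linear_tokens, atol=1e-4)

# ConvTranspose2d with the same kernel and stride expands each patch back to 8x8 pixels
patch_expansion = nn.ConvTranspose2d(n_features, 2, kernel_size=patch_size, stride=patch_size)
print(grid.shape, patch_expansion(grid).shape)
# torch.Size([2, 24, 4, 4]) torch.Size([2, 2, 32, 32])
```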

In the `forward` method we then flatten the grid of patch embeddings into a sequence using the `flatten` and `transpose` methods, giving a tensor of shape (batch, patches, features). We can then apply the transformer blocks to this sequence. Afterwards, we transpose back and use the `view` method to reshape the sequence into a grid of patches, which the `ConvTranspose2d` layer expands back to the original image size.

```python

        # Convert image to patch embeddings
        x = self.patch_embedding(x)  # Expected shape: (B, C, H, W)
        x = x.flatten(2)             # Flatten patch embeddings to (B, C, N)
        x = x.transpose(1, 2)        # Transpose to (B, N, C): one feature vector per patch

        # ... Apply transformer blocks ...
        
        # Convert patch embeddings to image
        x = x.transpose(1, 2)    # Transpose back to (B, C, N)
        x = x.view(x.shape[0], self.n_features, *self.shape_after_patch)
        x = self.patch_expansion(x)

```
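A standalone sketch of this round trip, with a random tensor standing in for the output of `patch_embedding` (a 4x4 grid of patches with 64 features each, sizes assumed for illustration). It shows which grid position each sequence element comes from and that transposing back and calling `view` recovers the original grid.

```python
import torch

torch.manual_seed(0)
B, C, h, w = 2, 64, 4, 4
grid = torch.randn(B, C, h, w)  # stand-in for the output of patch_embedding

# Grid -> sequence: (B, C, h, w) -> (B, C, N) -> (B, N, C), with N = h * w
tokens = grid.flatten(2).transpose(1, 2)

# Patch n in the sequence comes from grid row n // w, column n % w
assert torch.equal(tokens[:, 1 * w + 2], grid[:, :, 1, 2])

# Make tokens contiguous in (B, N, C) order, so the transpose below is a non-contiguous view
tokens = tokens.contiguous()

# Sequence -> grid: (B, N, C) -> (B, C, N) -> (B, C, h, w);
# view works here because it only splits the last dimension N into (h, w)
back = tokens.transpose(1, 2).view(B, C, h, w)
assert torch.equal(back, grid)
print(tokens.shape, back.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 64, 4, 4])
```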

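Notice that by processing the data in this way we lose information about the spatial structure of the input data: self-attention treats the patches as an unordered set, so the model cannot tell where in the image each patch came from. Transformer architectures solve this problem by adding a positional embedding to the patch embeddings, which encodes the position of each patch and allows the model to learn the spatial structure of the input data. Our positional embedding is a fixed sinusoidal encoding, defined as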

```python

        # Set up positional encoding
        pos_embed = torch.zeros(n_patches, n_features)
        position = torch.arange(0, n_patches).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_features, 2) * -(math.log(10000.0) / n_features))
        pos_embed[:, 0::2] = torch.sin(position * div_term)
        pos_embed[:, 1::2] = torch.cos(position * div_term)
        pos_embed = pos_embed.unsqueeze(0)

```
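The sketch below rebuilds this table for a small assumed size (16 patches, 8 features) and checks two things. First, `div_term[i]` equals 10000^(-2i / n_features), so column 2i holds sin(pos / 10000^(2i / n_features)) and column 2i + 1 the matching cosine. Second, it illustrates why the embedding is needed: a plain `nn.MultiheadAttention` layer (standing in for a transformer block) only permutes its output when the patches are permuted, but once positions are added, reordering the patches changes the result.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
n_patches, n_features = 16, 8

# Sinusoidal positional embedding, built as in the snippet above
pos_embed = torch.zeros(n_patches, n_features)
position = torch.arange(0, n_patches).unsqueeze(1)
div_term = torch.exp(torch.arange(0, n_features, 2) * -(math.log(10000.0) / n_features))
pos_embed[:, 0::2] = torch.sin(position * div_term)
pos_embed[:, 1::2] = torch.cos(position * div_term)

# div_term[i] = 10000^(-2i / n_features), the frequency used by feature pair i
two_i = torch.arange(0, n_features, 2)
assert torch.allclose(div_term, 10000.0 ** (-two_i / n_features))

# Self-attention alone ignores patch order: permuting the patches just permutes the output
attn = nn.MultiheadAttention(n_features, num_heads=2, batch_first=True)

def self_attend(t):
    out, _ = attn(t, t, t)
    return out

x = torch.randn(1, n_patches, n_features)  # stand-in for patch embeddings
perm = torch.arange(n_patches).flip(0)     # reverse the patch order
assert torch.allclose(self_attend(x)[:, perm], self_attend(x[:, perm]), atol=1e-5)

# With positions added, reordering the patches changes the result
with_pos = self_attend(x + pos_embed)[:, perm]
reordered = self_attend(x[:, perm] + pos_embed)
assert not torch.allclose(with_pos, reordered, atol=1e-3)
```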

Altogether, the full transformer model looks like the following:

In [None]:
# Define vision transformer class
class VisionTransformer(nn.Module):
    def __init__(self, 
            img_size, in_channels, out_channels,
            n_layers=8, n_features=64,
            patch_size=8, **kwargs
        ):
        super(VisionTransformer, self).__init__()

        # Set attributes
        self.img_size = img_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.n_features = n_features

        # Calculate constants
        n_patches = (img_size // patch_size) ** 2
        shape_after_patch = (img_size // patch_size, img_size // patch_size)
        self.n_patches = n_patches
        self.shape_after_patch = shape_after_patch

        # Set up positional encoding
        pos_embed = torch.zeros(n_patches, n_features)
        position = torch.arange(0, n_patches).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_features, 2) * -(math.log(10000.0) / n_features))
        pos_embed[:, 0::2] = torch.sin(position * div_term)
        pos_embed[:, 1::2] = torch.cos(position * div_term)
        pos_embed = pos_embed.unsqueeze(0)
        self.register_buffer('pos_embed', pos_embed)

        # Patch embedding
        self.patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, n_features, kernel_size=patch_size, stride=patch_size),
        )

        # Set up transformer blocks
        self.transformer_blocks = nn.ModuleList()
        for i in range(n_layers):
            self.transformer_blocks.append(TransformerBlock(n_features, **kwargs))

        # Set up output block
        self.patch_expansion = nn.Sequential(
            nn.ConvTranspose2d(
                self.n_features, 
                out_channels, 
                kernel_size=self.patch_size, 
                stride=self.patch_size,
            ),
        )
    
    def forward(self, x):
        """Forward pass."""

        # Convert image to patch embeddings
        x = self.patch_embedding(x)  # Expected shape: (B, C, H, W)
        x = x.flatten(2)             # Flatten patch embeddings to (B, C, N)
        x = x.transpose(1, 2)        # Transpose to (B, N, C): one feature vector per patch

        # Add positional embedding
        x = x + self.pos_embed

        # Apply transformer blocks
        for block in self.transformer_blocks:
            x = block(x)
        
        # Convert patch embeddings to image
        x = x.transpose(1, 2)    # Transpose back to (B, C, N)
        x = x.view(x.shape[0], self.n_features, *self.shape_after_patch)
        x = self.patch_expansion(x)

        # Return
        return x

## Conclusion

This covers all the basics of creating models in PyTorch. We have learned how to use the `nn.Module` class to define the architecture of a neural network, how to use prebuilt layers to build complex models, and how to create custom layers for specific tasks. We have also learned how to structure the input data for a neural network and how to define the forward pass of the network to produce the output. By following these steps, you can create your own models in PyTorch for a wide range of tasks.