# Implementing a Model

Models are broadly defined as classes that represent a specific architecture, such as `ResNet18`. Unlike components, they are generally less flexible in terms of input arguments, and it should be possible to pass them directly to applications. Models are designed to be easy to use and to require minimal configuration to get started. They are also designed to be easily extensible, so that you can add new features without modifying the existing code.

## What Should Be Implemented as a Model?

The first step is to ensure that what you want to implement is actually a model.
Most models are composed of a few named components (for example, `ConvolutionalNeuralNetwork`) and are generally intended as a complete transformation from input to output for a given task.

Most models are standard neural networks with exact architectures (like `ResNet50`), but models can also be more general architectures (like a `RecurrentModel`). 

Unlike components, models generally have a rigid structure: the number of blocks and the sizes of the layers are not expected to be configurable through the input arguments. However, where possible, the input and output shapes should be flexible.

Examples of models include `ViT`, `CycleGANGenerator`, `ResNet`, and `RecurrentModel`.

## What Should a Model Contain?

Generally, a model should define an `.__init__()` method that takes all the necessary arguments to define the model and a `.forward()` method that defines the forward pass of the model.

Ideally, a model's forward pass should be as simple as possible, and a fully sequential forward pass is best.
This is because any hard-coded structure in the forward pass limits the flexibility of the model. For example, if the forward pass is defined as `self.conv1(x) + self.conv2(x)`, then `self.conv1` and `self.conv2` cannot be replaced with a single `self.conv` without modifying the model.
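The contrast can be made concrete without PyTorch. In the sketch below, plain Python callables stand in for layers, and the names `SequentialModel`, `BranchingModel`, `double`, and `add_one` are hypothetical. In the sequential version, swapping layers is just an edit to a list; in the branching version, the `+` is written into `forward` itself, so changing the structure means rewriting the method.

```python
from typing import Callable, List

# Stand-in "layer" type: any function from a number to a number.
Layer = Callable[[float], float]


def double(x: float) -> float:
    return 2 * x


def add_one(x: float) -> float:
    return x + 1


class SequentialModel:
    """Forward pass is a plain loop over a list of layers."""

    def __init__(self, layers: List[Layer]) -> None:
        self.layers = list(layers)

    def forward(self, x: float) -> float:
        for layer in self.layers:
            x = layer(x)
        return x


class BranchingModel:
    """Forward pass hard-codes how conv1 and conv2 are combined."""

    def __init__(self) -> None:
        self.conv1: Layer = double
        self.conv2: Layer = add_one

    def forward(self, x: float) -> float:
        # The "+" is fixed here: merging conv1 and conv2 into a single
        # layer requires editing this method, not just the attributes.
        return self.conv1(x) + self.conv2(x)


seq = SequentialModel([double, add_one])
print(seq.forward(3.0))  # 7.0

# Replacing both layers with a single one only changes the list.
seq.layers = [lambda x: 2 * x + 1]
print(seq.forward(3.0))  # 7.0

print(BranchingModel().forward(3.0))  # 10.0
```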

Moreover, in almost all cases the model architecture should be defined purely in terms of components and operations. Limit direct use of `torch.nn` modules and `blocks`, because `torch.nn` modules are not as flexible as Deeplay components and operations. If no component exists for the desired architecture, it is a good idea to create a new one and add it to the `components` folder.

## Implementing a Model

Here, you'll see the steps to follow to implement a model in Deeplay, using the `ResNet18` model as an example.

### 1. Create a New File

The first step is to create a new file in the `deeplay/models` directory. It
can be in a deeper subdirectory if it makes sense.

**The base class.** 
Models generally don't have a fixed base class. Sometimes it makes sense to subclass an existing model, but it is not necessary. In some cases, you can instead subclass a component, if the model is simply that component with some additional layers or with a fixed architecture. If neither applies, use `DeeplayModule` as the base class.

**Styled components and blocks.**
What is specific to implementing models is the expectation to use styled components
and blocks where possible. This ensures that the modules can be reused in other
models.
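Conceptually, a style is a named function that reconfigures a block in place: it is registered on the class once and applied later by name. The toy sketch below is not Deeplay's implementation; `ToyBlock`, its `register_style` and `style` methods, and the `wide` style are hypothetical stand-ins that only illustrate the pattern of decorating a function with `@Conv2dBlock.register_style` and then calling `.style("name", ...)`.

```python
from typing import Callable, Dict


class ToyBlock:
    """Hypothetical block with a class-level registry of named styles."""

    # Shared registry: style name -> function that mutates a block.
    _styles: Dict[str, Callable[..., None]] = {}

    def __init__(self, in_channels: int, out_channels: int) -> None:
        self.config = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": 3,
        }

    @classmethod
    def register_style(cls, func: Callable[..., None]) -> Callable[..., None]:
        # Store the function under its own name so it can be applied later.
        cls._styles[func.__name__] = func
        return func

    def style(self, name: str, **kwargs) -> "ToyBlock":
        # Look up the registered function and let it mutate this block.
        self._styles[name](self, **kwargs)
        return self  # returning self allows ToyBlock(...).style(...) chaining


@ToyBlock.register_style
def wide(block: ToyBlock, kernel_size: int = 5) -> None:
    """Wide-kernel style.

    Parameters
    ----------
    kernel_size : int
        Kernel size to use, by default 5.
    """
    block.config["kernel_size"] = kernel_size


block = ToyBlock(16, 16).style("wide", kernel_size=7)
print(block.config)
```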

### 2a. Implement the ResNet18 Block

First, implement the ResNet block as a styled block. It should be implemented
in the same file as the model.

**NOTE:** The style function should have a short docstring, just like a method. Its first argument should not be documented (just as `self` is not documented in methods).

In [1]:
from deeplay.blocks import Conv2dBlock

@Conv2dBlock.register_style
def resnet(block: Conv2dBlock, stride: int = 1) -> None:
    """ResNet style block composed of two residual blocks.

    Parameters
    ----------
    stride : int
        Stride of the first block, by default 1
    """
    
    # 1. Create two blocks.
    block.multi(2)

    # 2. Style both blocks as residual blocks.
    block.blocks[0].style("residual", order="lnaln|a")
    block.blocks[1].style("residual", order="lnaln|a")

    # 3. If stride > 1, stride first block and add normalization to shortcut.
    if stride > 1:
        block.blocks[0].strided(stride)
        block.blocks[0].shortcut_start.normalized()

    # 4. Remove the pooling layer if it exists.
    block[...].isinstance(Conv2dBlock).all.remove("pool", allow_missing=True)

You can now instantiate this block and verify its structure.

In [2]:
block = Conv2dBlock(16, 16).style("resnet")

print(block)

Conv2dBlock(
  (blocks): Sequential(
    (0-1): 2 x Conv2dBlock(
      (shortcut_start): Conv2dBlock(
        (layer): Layer[Identity](in_channels=16, out_channels=16, kernel_size=1, stride=1, padding=0)
      )
      (blocks): Sequential(
        (0): Conv2dBlock(
          (layer): Layer[Conv2d](in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
          (normalization): Layer[BatchNorm2d](num_features=16)
          (activation): Layer[ReLU]()
        )
        (1): Conv2dBlock(
          (layer): Layer[Conv2d](in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
          (normalization): Layer[BatchNorm2d](num_features=16)
        )
      )
      (shortcut_end): Add()
      (activation): Layer[ReLU]()
    )
  )
)
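Reading the printed structure against the `order="lnaln|a"` argument suggests how the string maps to layers: `l` is the layer, `n` the normalization, `a` the activation, and `|` marks where the shortcut is added back (`shortcut_end`). The helper below is a hypothetical decoder written from that reading of the output, not a Deeplay function.

```python
from typing import List, Tuple

# Letter-to-layer mapping inferred from the printed Conv2dBlock structure.
NAMES = {"l": "layer", "n": "normalization", "a": "activation"}


def decode_residual_order(order: str) -> Tuple[List[str], List[str]]:
    """Split a residual order string into (main path, after the shortcut add)."""
    main, _, after = order.partition("|")
    return [NAMES[c] for c in main], [NAMES[c] for c in after]


main_path, after_add = decode_residual_order("lnaln|a")
print(main_path)  # ['layer', 'normalization', 'activation', 'layer', 'normalization']
print(after_add)  # ['activation']
```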


### 2b. Implement the ResNet18 Input Block

The input block is slightly different from the normal block. You can also implement this block as a styled block.

In [3]:
from deeplay.external.layer import Layer
import torch.nn as nn

@Conv2dBlock.register_style
def resnet18_input(block: Conv2dBlock) -> None:
    """ResNet18 input block.

    The block used on the input of the ResNet18 architecture.
    """
    
    block.configure(kernel_size=7, stride=2, padding=3, bias=False)
    block.normalized(mode="insert", after="layer")
    block.activated(
        Layer(nn.ReLU, inplace=True), mode="insert", after="normalization",
    )
    pool = Layer(
        nn.MaxPool2d, kernel_size=3, stride=2, padding=1, ceil_mode=False, 
        dilation=1,
    )
    block.pooled(pool, mode="append")

As before, you can instantiate this block and verify its architecture.

In [4]:

block = Conv2dBlock(3, 64).style("resnet18_input")

print(block)

Conv2dBlock(
  (layer): Layer[Conv2d](in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
  (normalization): Layer[BatchNorm2d](num_features=64)
  (activation): Layer[ReLU](inplace=True)
  (pool): Layer[MaxPool2d](kernel_size=3, stride=2, padding=1, ceil_mode=False, dilation=1)
)
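As a quick check of what this input block does to the spatial size, the PyTorch output-size formula for `Conv2d` and `MaxPool2d` (with `ceil_mode=False`), floor((H + 2·padding - dilation·(kernel_size - 1) - 1) / stride) + 1, can be applied by hand. For a 224×224 image, the 7×7 stride-2 convolution halves the size and the 3×3 stride-2 max-pool halves it again. The helper name `out_size` is hypothetical.

```python
def out_size(size: int, kernel_size: int, stride: int = 1,
             padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial axis (Conv2d/MaxPool2d, ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


h = 224
h = out_size(h, kernel_size=7, stride=2, padding=3)  # Conv2d in the input block
print(h)  # 112
h = out_size(h, kernel_size=3, stride=2, padding=1)  # MaxPool2d in the input block
print(h)  # 56
```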


### 2c. Implement the ResNet18 Backbone

The backbone is a styled component and should be implemented in the same file as the
model. Since it is a convolutional encoder, you can style a `ConvolutionalEncoder2d` component.

In [5]:
from deeplay.components import ConvolutionalEncoder2d
from deeplay.initializers import Kaiming, Constant

@ConvolutionalEncoder2d.register_style
def resnet18(
    encoder: ConvolutionalEncoder2d, 
    pool_output: bool = True,
    set_hidden_channels: bool = False,
) -> None: 
    """ResNet18 backbone.

    Styles a ConvolutionalEncoder2d to have the ResNet18 architecture.

    Parameters
    ----------
    pool_output : bool
        Whether to append a pooling layer at the end of the encoder, by default 
        True.
    set_hidden_channels : bool
        Whether to set the hidden channels to the default ResNet18 values, by 
        default False.
    """

    if set_hidden_channels:
        encoder.configure(hidden_channels=[64, 64, 128, 256])

    # 1. Style the first block.
    encoder.blocks[0].style("resnet18_input")

    # 2. The second block is not strided (stride 1).
    encoder.blocks[1].style("resnet", stride=1)

    # 3. The rest of the blocks have a stride of 2.
    encoder["blocks", 2:].hasattr("style").all.style("resnet", stride=2)

    # 4. Initialize the weights.
    encoder.initialize(Kaiming(targets=(nn.Conv2d,)))
    encoder.initialize(Constant(targets=(nn.BatchNorm2d,)))

    # 5. Set postprocess to pool the output if needed.
    if pool_output:
        encoder.postprocess.configure(nn.AdaptiveAvgPool2d, output_size=(1, 1))

You can now instantiate the backbone and print out its architecture.

In [6]:
backbone = ConvolutionalEncoder2d(3, [16], 32).style("resnet18", pool_output=True)

print(backbone)

ConvolutionalEncoder2d(
  (blocks): LayerList(
    (0): Conv2dBlock(
      (layer): Layer[Conv2d](in_channels=3, out_channels=16, kernel_size=7, stride=2, padding=3)
      (normalization): Layer[BatchNorm2d](num_features=16)
      (activation): Layer[ReLU](inplace=True)
      (pool): Layer[MaxPool2d](kernel_size=3, stride=2, padding=1, ceil_mode=False, dilation=1)
    )
    (1): Conv2dBlock(
      (blocks): Sequential(
        (0): Conv2dBlock(
          (shortcut_start): Conv2dBlock(
            (layer): Layer[Conv2d](in_channels=16, out_channels=32, kernel_size=1, stride=1, padding=0)
          )
          (blocks): Sequential(
            (0): Conv2dBlock(
              (layer): Layer[Conv2d](in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
              (normalization): Layer[BatchNorm2d](num_features=32)
              (activation): Layer[ReLU]()
            )
            (1): Conv2dBlock(
              (layer): Layer[Conv2d](in_channels=32, out_channels=32, ker

### 2d. Implement the ResNet18 Model

Finally, you can implement the `ResNet18` model by subclassing `DeeplayModule`.

In [7]:
from deeplay.module import DeeplayModule
from deeplay.components import MultiLayerPerceptron

class ResNet18(DeeplayModule):

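    def __init__(self, in_channels=3, latent_channels=512, num_classes=1000):
        super().__init__()

        # Input block followed by the four ResNet18 stages
        # (64, 128, 256, latent_channels).
        self.backbone = ConvolutionalEncoder2d(
            in_channels, 
            [64, 64, 128, 256], 
            latent_channels
        )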
        self.backbone.style("resnet18", pool_output=True)

        self.head = MultiLayerPerceptron(latent_channels, [], num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
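Because the forward pass is purely sequential, the output shape follows from composing the two stages. The sketch below tracks shapes only (plain tuples, no tensors). It assumes that the backbone ends in the `AdaptiveAvgPool2d((1, 1))` set by `pool_output=True` and that the head flattens its input before a single linear layer with the default bias; `backbone_shape`, `head_shape`, and `linear_params` are hypothetical helper names.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def backbone_shape(x: Shape, latent_channels: int = 512) -> Shape:
    # Adaptive average pooling to (1, 1) removes the spatial dependence.
    n, _, _, _ = x
    return (n, latent_channels, 1, 1)


def head_shape(x: Shape, num_classes: int = 1000) -> Shape:
    # Flatten (N, C, 1, 1) -> (N, C), then one linear layer C -> num_classes.
    return (x[0], num_classes)


def linear_params(in_features: int, out_features: int) -> int:
    # Weight matrix plus bias vector.
    return in_features * out_features + out_features


y = head_shape(backbone_shape((4, 3, 224, 224)))
print(y)  # (4, 1000)
print(linear_params(512, 1000))  # 513000
```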

### 3. Add Annotations

It's important to add type annotations to the class and its methods so that
users know what to expect. Annotations also allow IDEs to provide
autocompletion.
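Annotations are not just for human readers: IDEs and type checkers read them through the standard `inspect` and `typing` machinery. The sketch below uses a hypothetical stand-in class, `AnnotatedModel`, with the same constructor signature as `ResNet18` but without PyTorch, to show what such tools see.

```python
import inspect
from typing import get_type_hints


class AnnotatedModel:
    """Stand-in with the same constructor signature as ResNet18."""

    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 512,
        num_classes: int = 1000,
    ) -> None:
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.num_classes = num_classes


# What an IDE sees: parameter names, types, and defaults.
print(inspect.signature(AnnotatedModel.__init__))
print(get_type_hints(AnnotatedModel.__init__))
```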

In [9]:
import torch

class ResNet18(DeeplayModule):

    def __init__(
        self, 
        in_channels: int = 3, 
        latent_channels: int = 512, 
        num_classes: int = 1000,
    ) -> None: 
        super().__init__()

        self.backbone = ConvolutionalEncoder2d(
            in_channels, 
            [64, 64, 128, 256], 
            latent_channels
        )
        self.backbone.style("resnet18", pool_output=True)

        self.head = MultiLayerPerceptron(latent_channels, [], num_classes)

    def forward(
        self, 
        x: torch.Tensor,  
    ) -> torch.Tensor: 
        x = self.backbone(x)
        x = self.head(x)
        return x

### 4. Document the Model

The next step is to document the model. This should include a description of 
the model, the input and output shapes, and the arguments that can be passed to
the model.

In [10]:
class ResNet18(DeeplayModule):
    """A ResNet18 model.

    A ResNet18 model composed of a ConvolutionalEncoder2d backbone and a 
    MultiLayerPerceptron head.

    Parameters
    ----------
    in_channels : int
        The number of input channels, by default 3.
    latent_channels : int
        The number of latent channels (at the end of the backbone), by default 
        512.
    num_classes : int
        The number of classes, by default 1000.
    
    Attributes
    ----------
    backbone : ConvolutionalEncoder2d
        The backbone of the model.
    head : MultiLayerPerceptron
        The head of the model. By default a simple linear layer.
    
    Input
    -----
    x : torch.Tensor
        The input tensor of shape (N, in_channels, H, W).
        Where N is the batch size, in_channels is the number of input channels,
        H is the height, and W is the width.
        H and W should be at least 33, but ideally 224.

    
    Output
    ------
    y : torch.Tensor
        The output tensor of shape (N, num_classes).
        Where N is the batch size and num_classes is the number of classes.

    Evaluation
    ----------
    ```python
    x = backbone(x)
    x = head(x)
    ```

    Examples
    --------
    >>> model = ResNet18(3, 512, 1000).build()
    >>> x = torch.randn(4, 3, 224, 224)
    >>> y = model(x)
    >>> y.shape
    torch.Size([4, 1000])

    """

    def __init__( 
        self, 
        in_channels: int = 3, 
        latent_channels: int = 512, 
        num_classes: int = 1000,
    ) -> None: 
        super().__init__()

        self.backbone = ConvolutionalEncoder2d(
            in_channels, 
            [64, 64, 128, 256], 
            latent_channels
        )
        self.backbone.style("resnet18", pool_output=True)

        self.head = MultiLayerPerceptron(latent_channels, [], num_classes)

    def forward( 
        self, 
        x: torch.Tensor,  
    ) -> torch.Tensor:   
        """Forward pass of the model.
        
        Evaluates `backbone` and `head` sequentially.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor of shape (N, in_channels, H, W).
            Where N is the batch size, in_channels is the number of input channels,
            H is the height, and W is the width.
            H and W should be at least 33, but ideally 224.
        
        Returns
        -------
        torch.Tensor
            The output tensor of shape (N, num_classes).
            Where N is the batch size and num_classes is the number of classes.
        """

        x = self.backbone(x)
        x = self.head(x)
        return x
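The `Examples` section uses doctest syntax (`>>>` prompts followed by the expected output), so it can be checked automatically with Python's standard `doctest` module. Running the `ResNet18` example itself requires Deeplay and PyTorch, so the sketch below applies the same mechanism to a small hypothetical function, `total_stride`, written in the same docstring style.

```python
import doctest


def total_stride(strides: list) -> int:
    """Total downsampling factor of a sequence of strided layers.

    Parameters
    ----------
    strides : list of int
        The stride of each layer, in order.

    Returns
    -------
    int
        The product of all strides.

    Examples
    --------
    >>> total_stride([2, 2, 1, 2, 2, 2])
    32
    """
    result = 1
    for s in strides:
        result *= s
    return result


# Collect the docstring examples and run them, comparing the printed output.
finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner(verbose=False)
for test in finder.find(total_stride, "total_stride", globs={"total_stride": total_stride}):
    runner.run(test)
results = runner.summarize(verbose=False)
print(results.failed, results.attempted)  # 0 1
```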