# Implementing a model

Here we will show the steps you should follow to implement a model in deeplay.

## 1. Should I implement a model?

The first step is to ensure that what you want to implement is actually a model.
Most models are composed of a few named components (e.g. `ConvolutionalNeuralNetwork`)
and are generally intended as a complete transformation from input to output for a given task.

Most models are standard neural networks with exact architectures (like ResNet50), but
models can also be more general, like a RecurrentModel.

Unlike components, models generally have a more rigid structure. It is not expected that
the number of blocks or the sizes of the layers can be set through the input arguments.
However, where possible, the input and output shapes should remain flexible (for example,
the number of input channels or the number of output classes).

Examples of models are:
- ViT
- CycleGANGenerator
- ResNet50
- RecurrentModel

## 2. Implementing the model

The first step is to create a new file in the `deeplay/models` directory. It
can be in a deeper subdirectory if it makes sense.

### 2.1 The base class

Models generally do not have a dedicated base class. Sometimes it makes sense to subclass an
existing model, but it is not necessary. In some cases it is also possible to subclass a
component, if the model is simply that component with some additional layers or with
an exact architecture.

If neither applies, use `DeeplayModule` as the base class.
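To make the "subclass a component" option concrete, here is a minimal pure-Python sketch. `Encoder` and `FixedEncoder` are hypothetical stand-ins and are not deeplay classes. The subclass pins the architecture (the hidden widths) and leaves the input and output sizes as arguments, which follows the guidance in section 1.

```python
from typing import List


class Encoder:
    """Hypothetical stand-in for a flexible component such as ConvolutionalEncoder2d."""

    def __init__(self, in_channels: int, hidden_channels: List[int], out_channels: int):
        self.in_channels = in_channels
        self.hidden_channels = list(hidden_channels)
        self.out_channels = out_channels


class FixedEncoder(Encoder):
    """Hypothetical model: the same component with an exact architecture."""

    def __init__(self, in_channels: int = 3, out_channels: int = 512):
        # The hidden widths are fixed by the architecture; only the input and
        # output sizes remain configurable.
        super().__init__(in_channels, [64, 64, 128, 256], out_channels)


model = FixedEncoder(in_channels=1)
print(model.hidden_channels)  # [64, 64, 128, 256]
```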

### 2.2 Example: ResNet18

A particular expectation when implementing models is that they use styled components
and blocks wherever possible. This ensures that the building blocks can be reused in
other models.
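Conceptually, a style is a named function that modifies a block or component in place. The sketch below shows that registration pattern in pure Python. It is not deeplay's implementation. `StyledBlock`, `normed`, and the `layers` list are hypothetical, and only the call shape of `register_style` and `style(name, **kwargs)` is borrowed from the examples in this guide.

```python
from typing import Callable, Dict


class StyledBlock:
    """Hypothetical block class with a registry of named styles."""

    # In this sketch the registry is a class attribute shared by all instances.
    _styles: Dict[str, Callable] = {}

    def __init__(self, in_channels: int, out_channels: int):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = ["layer"]

    @classmethod
    def register_style(cls, func: Callable) -> Callable:
        # Store the function under its own name so it can be looked up later.
        cls._styles[func.__name__] = func
        return func

    def style(self, name: str, **kwargs) -> "StyledBlock":
        # Apply the registered function to this instance and return the
        # instance so that calls can be chained.
        self._styles[name](self, **kwargs)
        return self


@StyledBlock.register_style
def normed(block: StyledBlock, activation: str = "relu"):
    """Append a normalization and an activation."""
    block.layers += ["normalization", activation]


block = StyledBlock(16, 16).style("normed")
print(block.layers)  # ['layer', 'normalization', 'relu']
```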

#### 2.2.1 The ResNet18 block

First, we implement the ResNet block. This is a styled block and should be implemented
in the same file as the model.

Note that the style function should have a short docstring, as though it were a method. The
first argument should not be documented (just as `self` is not documented in methods).

```python
from deeplay.blocks import Conv2dBlock

@Conv2dBlock.register_style
def resnet(block: Conv2dBlock, stride: int = 1):
    """ResNet style block composed of two residual blocks.

    Parameters
    ----------
    stride : int
        Stride of the first block, by default 1
    """
    # 1. create two blocks
    block.multi(2)

    # 2. make both blocks residual
    block.blocks[0].style("residual", order="lnaln|a")
    block.blocks[1].style("residual", order="lnaln|a")

    # 3. if stride > 1, stride the first block and add normalization to the shortcut
    if stride > 1:
        block.blocks[0].strided(stride)
        block.blocks[0].shortcut_start.normalized()

    # 4. remove the pooling layer if it exists
    block[...].isinstance(Conv2dBlock).all.remove("pool", allow_missing=True)

block = Conv2dBlock(16, 16).style("resnet")
block
```

```
Conv2dBlock(
  (blocks): Sequential(
    (0-1): 2 x Conv2dBlock(
      (shortcut_start): Conv2dBlock(
        (layer): Layer[Identity](in_channels=16, out_channels=16, kernel_size=1, stride=1, padding=0)
      )
      (blocks): Sequential(
        (0): Conv2dBlock(
          (layer): Layer[Conv2d](in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
          (normalization): Layer[BatchNorm2d](num_features=16)
          (activation): Layer[ReLU]()
        )
        (1): Conv2dBlock(
          (layer): Layer[Conv2d](in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1)
          (normalization): Layer[BatchNorm2d](num_features=16)
        )
      )
      (shortcut_end): Add()
      (activation): Layer[ReLU]()
    )
  )
)
```
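The order string `"lnaln|a"` passed to the `residual` style can be read against the printed module. There, `l`, `n`, and `a` appear to stand for a layer, a normalization, and an activation, and `|` marks the point where the shortcut is added. The sketch below encodes that reading. The mapping is inferred from the output above and is not taken from deeplay's documentation.

```python
from typing import List

# Mapping inferred from the printed Conv2dBlock above (an assumption, not deeplay's spec).
ORDER_NAMES = {"l": "layer", "n": "normalization", "a": "activation", "|": "shortcut_end"}


def expand_order(order: str) -> List[str]:
    """Expand a compact order string into the sequence of sublayer names."""
    return [ORDER_NAMES[ch] for ch in order]


print(expand_order("lnaln|a"))
```

The result lists layer, normalization, and activation for the first convolution, then layer and normalization for the second, followed by the shortcut addition and a final activation. This matches the printed module.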

#### 2.2.2 The ResNet18 input block

The input block is slightly different from the normal block. We also implement this
as a styled block.

```python
from deeplay.external.layer import Layer
import torch.nn as nn

@Conv2dBlock.register_style
def resnet18_input(block: Conv2dBlock):
    """ResNet18 input block.

    The block used on the input of the ResNet18 architecture.
    """
    # 7x7 stride-2 convolution; the bias is redundant before batch normalization
    block.configure(kernel_size=7, stride=2, padding=3, bias=False)
    # normalization directly after the convolution
    block.normalized(mode="insert", after="layer")
    # in-place ReLU after the normalization
    block.activated(Layer(nn.ReLU, inplace=True), mode="insert", after="normalization")
    # 3x3 stride-2 max pooling appended at the end of the block
    pool = Layer(nn.MaxPool2d, kernel_size=3, stride=2, padding=1, ceil_mode=False,
                 dilation=1)
    block.pooled(pool, mode="append")

block = Conv2dBlock(3, 64).style("resnet18_input")
block
```

```
Conv2dBlock(
  (layer): Layer[Conv2d](in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
  (normalization): Layer[BatchNorm2d](num_features=64)
  (activation): Layer[ReLU](inplace=True)
  (pool): Layer[MaxPool2d](kernel_size=3, stride=2, padding=1, ceil_mode=False, dilation=1)
)
```
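The input block alone reduces the spatial resolution by a factor of four. The sketch below applies the output-size formula from the PyTorch `Conv2d` and `MaxPool2d` documentation (with `ceil_mode=False`) to the two strided operations in the block. The helper names are illustrative.

```python
def conv_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial axis for Conv2d or MaxPool2d (ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def resnet18_input_size(size: int) -> int:
    # Conv2d(kernel_size=7, stride=2, padding=3)
    size = conv_out(size, kernel_size=7, stride=2, padding=3)
    # MaxPool2d(kernel_size=3, stride=2, padding=1)
    return conv_out(size, kernel_size=3, stride=2, padding=1)


print(resnet18_input_size(224))  # 56
```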

#### 2.2.3 The ResNet18 backbone

The backbone is a styled component and should be implemented in the same file as the
model. Since it is a convolutional encoder, we style a `ConvolutionalEncoder2d` component.

```python
from deeplay.components import ConvolutionalEncoder2d
from deeplay.initializers import Kaiming, Constant

@ConvolutionalEncoder2d.register_style
def resnet18(encoder: ConvolutionalEncoder2d,
             pool_output: bool = True,
             set_hidden_channels: bool = False):
    """ResNet18 backbone.

    Styles a ConvolutionalEncoder2d to have the ResNet18 architecture.

    Parameters
    ----------
    pool_output : bool
        Whether to append a pooling layer at the end of the encoder, by default True
    set_hidden_channels : bool
        Whether to set the hidden channels to the default ResNet18 values, by default
        False
    """

    if set_hidden_channels:
        encoder.configure(hidden_channels=[64, 64, 128, 256])

    # 1. style the first block
    encoder.blocks[0].style("resnet18_input")
    # 2. the second block does not have a stride
    encoder.blocks[1].style("resnet", stride=1)

    # 3. the rest of the blocks have a stride of 2
    encoder["blocks", 2:].hasattr("style").all.style("resnet", stride=2)

    # 4. initialize the weights
    encoder.initialize(Kaiming(targets=(nn.Conv2d,)))
    encoder.initialize(Constant(targets=(nn.BatchNorm2d,)))

    # 5. set postprocess to pool the output if needed
    if pool_output:
        encoder.postprocess.configure(nn.AdaptiveAvgPool2d, output_size=(1, 1))

encoder = ConvolutionalEncoder2d(3, [16], 32).style("resnet18", pool_output=True)
encoder
```

```
ConvolutionalEncoder2d(
  (blocks): LayerList(
    (0): Conv2dBlock(
      (layer): Layer[Conv2d](in_channels=3, out_channels=16, kernel_size=7, stride=2, padding=3)
      (normalization): Layer[BatchNorm2d](num_features=16)
      (activation): Layer[ReLU](inplace=True)
      (pool): Layer[MaxPool2d](kernel_size=3, stride=2, padding=1, ceil_mode=False, dilation=1)
    )
    (1): Conv2dBlock(
      (blocks): Sequential(
        (0): Conv2dBlock(
          (shortcut_start): Conv2dBlock(
            (layer): Layer[Conv2d](in_channels=16, out_channels=32, kernel_size=1, stride=1, padding=0)
          )
          (blocks): Sequential(
            (0): Conv2dBlock(
              (layer): Layer[Conv2d](in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
              (normalization): Layer[BatchNorm2d](num_features=32)
              (activation): Layer[ReLU]()
            )
            (1): Conv2dBlock(
              (layer): Layer[Conv2d](in_channels=32, out_channels=32, ker
...
```
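The printed encoder suggests that `ConvolutionalEncoder2d(in_channels, hidden_channels, out_channels)` builds one block per channel transition, which gives `len(hidden_channels) + 1` blocks. For example, `3, [16], 32` yields a 3-to-16 block and a 16-to-32 block. Assuming that rule, the hypothetical helper below lists the `(in, out, conv stride)` triple that each block gets under the `resnet18` style, using the default widths from `set_hidden_channels`. Block 0 is the input block, which also max-pools after its stride-2 convolution.

```python
from typing import List, Tuple


def resnet18_plan(in_channels: int, hidden_channels: List[int], out_channels: int) -> List[Tuple[int, int, int]]:
    """Return (in_channels, out_channels, conv stride) for each encoder block."""
    channels = [in_channels, *hidden_channels, out_channels]
    plan = []
    for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
        # Block 0 is the input block (stride-2 conv), block 1 keeps the
        # resolution, and every later block is a stride-2 "resnet" block.
        stride = 1 if i == 1 else 2
        plan.append((c_in, c_out, stride))
    return plan


for entry in resnet18_plan(3, [64, 64, 128, 256], 512):
    print(entry)
```

With these widths, the encoder has the input block plus four residual stages (64, 128, 256 and 512 channels), which matches the standard ResNet18 layout.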

#### 2.2.4 The ResNet18 model

Finally, we implement the model. We subclass `DeeplayModule`.

```python
from deeplay.module import DeeplayModule
from deeplay.components import MultiLayerPerceptron

class ResNet18(DeeplayModule):

    def __init__(self, in_channels: int = 3, latent_channels: int = 512, num_classes: int = 1000):
        super().__init__()
        # input block followed by four residual stages (64, 128, 256, latent_channels)
        self.backbone = ConvolutionalEncoder2d(
            in_channels,
            [64, 64, 128, 256],
            latent_channels
        )
        self.backbone.style("resnet18", pool_output=True)

        # a single linear layer mapping the pooled features to class scores
        self.head = MultiLayerPerceptron(latent_channels, [], num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
```
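Combining the pieces, you can trace the spatial size of a 224 x 224 input through the whole model. The pure-Python sketch below reuses the convolution output-size formula and only tracks shapes; it does not run the network. It assumes a backbone with the four ResNet18 stages, where the first stage keeps the resolution and each later stage halves it with a stride-2 3 x 3 convolution (padding 1). It also assumes that pooling to 1 x 1 is followed by the linear head.

```python
from typing import Tuple


def conv_out(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis (dilation 1, ceil_mode=False)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def resnet18_shapes(batch: int, height: int, width: int, num_classes: int = 1000) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Return (backbone feature shape before pooling, logits shape)."""
    sizes = [height, width]
    # Input block: 7x7 stride-2 convolution, then 3x3 stride-2 max pooling.
    sizes = [conv_out(s, 7, 2, 3) for s in sizes]
    sizes = [conv_out(s, 3, 2, 1) for s in sizes]
    # Residual stages: only the first 3x3 convolution of each stage can be
    # strided; the remaining 3x3 convolutions (padding 1) keep the size.
    for stride in (1, 2, 2, 2):
        sizes = [conv_out(s, 3, stride, 1) for s in sizes]
    features = (batch, 512, *sizes)
    # AdaptiveAvgPool2d((1, 1)) followed by the linear head.
    logits = (batch, num_classes)
    return features, logits


print(resnet18_shapes(4, 224, 224))  # ((4, 512, 7, 7), (4, 1000))
```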

### 2.3 Annotations

It is important to add type annotations for the submodules to the class so that the
user knows what to expect. IDEs also use these annotations to provide autocompletion.
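Class-level annotations such as `backbone: ConvolutionalEncoder2d` do not create attributes. They are metadata that type checkers, IDEs, and `typing.get_type_hints` read. The following minimal pure-Python sketch uses hypothetical stand-in classes in place of the deeplay components.

```python
from typing import get_type_hints


class Backbone:  # hypothetical stand-in for ConvolutionalEncoder2d
    pass


class Head:  # hypothetical stand-in for MultiLayerPerceptron
    pass


class Model:
    # Class-level annotations declare the submodule types without assigning them.
    backbone: Backbone
    head: Head

    def __init__(self) -> None:
        self.backbone = Backbone()
        self.head = Head()


print(get_type_hints(Model))
```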

```python
from deeplay.module import DeeplayModule
from deeplay.components import MultiLayerPerceptron

class ResNet18(DeeplayModule):

    backbone: ConvolutionalEncoder2d
    head: MultiLayerPerceptron

    def __init__(self, in_channels: int = 3, latent_channels: int = 512, num_classes: int = 1000):
        super().__init__()
        # input block followed by four residual stages (64, 128, 256, latent_channels)
        self.backbone = ConvolutionalEncoder2d(
            in_channels,
            [64, 64, 128, 256],
            latent_channels
        )
        self.backbone.style("resnet18", pool_output=True)

        # a single linear layer mapping the pooled features to class scores
        self.head = MultiLayerPerceptron(latent_channels, [], num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x
```

### 2.4 Documenting the model

The next step is to document the model. This should include a description of 
the model, the input and output shapes, and the arguments that can be passed to
the model.
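The arguments and the input and output shapes correspond to the `Parameters`, `Input`, and `Output` sections of the docstring in the next cell. As a hedged sketch, a small check such as the hypothetical `missing_sections` helper below can flag a docstring that lacks one of these sections. It only detects numpydoc-style headers, meaning a title line followed by a dashed underline.

```python
import inspect
from typing import List, Sequence

# Section headers used for arguments and shapes in the ResNet18 docstring.
REQUIRED_SECTIONS = ("Parameters", "Input", "Output")


def missing_sections(obj: object, required: Sequence[str] = REQUIRED_SECTIONS) -> List[str]:
    """Return the required section headers that obj's docstring lacks."""
    lines = (inspect.getdoc(obj) or "").splitlines()
    present = set()
    for title, underline in zip(lines, lines[1:]):
        # A numpydoc header is a title line followed by a line of dashes.
        if underline.strip() and set(underline.strip()) == {"-"}:
            present.add(title.strip())
    return [name for name in required if name not in present]


class Documented:
    """A toy model.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    """


print(missing_sections(Documented))  # ['Input', 'Output']
```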

````python
class ResNet18(DeeplayModule):
    """A ResNet18 model.

    A ResNet18 model composed of a ConvolutionalEncoder2d backbone and a MultiLayerPerceptron head.

    Parameters
    ----------
    in_channels : int
        The number of input channels, by default 3
    latent_channels : int
        The number of latent channels (at the end of the backbone), by default 512
    num_classes : int
        The number of classes, by default 1000

    Attributes
    ----------
    backbone : ConvolutionalEncoder2d
        The backbone of the model
    head : MultiLayerPerceptron
        The head of the model. By default a simple linear layer.

    Input
    -----
    x : torch.Tensor
        The input tensor of shape (N, in_channels, H, W),
        where N is the batch size, in_channels is the number of input channels,
        H is the height, and W is the width.
        H and W should be at least 33, but ideally 224.

    Output
    ------
    y : torch.Tensor
        The output tensor of shape (N, num_classes),
        where N is the batch size and num_classes is the number of classes.

    Evaluation
    ----------
    ```python
    x = backbone(x)
    x = head(x)
    ```

    Examples
    --------
    >>> import torch
    >>> model = ResNet18(3, 512, 1000).build()
    >>> x = torch.randn(4, 3, 224, 224)
    >>> y = model(x)
    >>> y.shape
    torch.Size([4, 1000])
    """
    # [rest of the code]
````
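The `Examples` section uses the interactive `>>>` format, so it can be executed with Python's standard `doctest` module. Running the ResNet18 example itself requires deeplay and torch. The sketch below therefore demonstrates the mechanism on a hypothetical pure-Python function, `output_shape`.

```python
import doctest


def output_shape(batch: int, num_classes: int = 1000):
    """Return the logits shape for a batch.

    Examples
    --------
    >>> output_shape(4)
    (4, 1000)
    """
    return (batch, num_classes)


# Collect the >>> examples from the docstring and run them.
finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner(verbose=False)
results = [
    runner.run(test)
    for test in finder.find(output_shape, globs={"output_shape": output_shape})
]
print(results)
```

Each result reports how many examples were attempted and how many failed. This makes it easy to confirm that a docstring example such as `y.shape` still prints what the documentation claims.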