# `ProteinWorkshop` Tutorial, Part 4 - Adding a New Model
![Models](../docs/source/_static/box_models.png)

In [None]:
%load_ext autoreload
%autoreload 2
#%load_ext blackcellmagic

The `ProteinWorkshop` encompasses several models, along with pre-trained weights for them, so that you can use them right away. However, you may want to add your own model to the `ProteinWorkshop` to cover a specific use case. This tutorial shows you how to do that.

To add your custom model to the `ProteinWorkshop`, follow this four-step procedure (the files you create are given in parentheses):

1. Create a new subclass of the `nn.Module` class (`my_new_model.py`).
2. Create a new model config file to accompany the custom `MyNewModel` (`my_new_model.yaml`).
3. Compose and instantiate your config for pre-training or finetuning using your model.
4. Use your custom model in a pre-training or finetuning task.

### 1. Create a new subclass of the `nn.Module` class

Use the `EGNNModel` below (from `proteinworkshop/models/graph_encoders/egnn.py`) as a reference when writing your own `proteinworkshop/models/graph_encoders/my_new_model.py` in a similar style.

In [None]:
"""
class EGNNModel(nn.Module):
    def __init__(
        self,
        num_layers: int = 5,
        emb_dim: int = 128,
        activation: str = "relu",
        norm: str = "layer",
        aggr: str = "sum",
        pool: str = "sum",
        residual: bool = True
    ):
        '''E(n) Equivariant GNN model

        Args:
            num_layers: (int) - number of message passing layers
            emb_dim: (int) - hidden dimension
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
            pool: (str) - global pooling function (sum/mean)
            residual: (bool) - whether to use residual connections
        '''
        super().__init__()

        # Linear projection of the initial node features (input width is inferred lazily)
        self.emb_in = torch.nn.LazyLinear(emb_dim)

        # Stack of GNN layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr))

        # Global pooling/readout function
        self.pool = get_aggregation(pool)

        self.residual = residual

    @property
    def required_batch_attributes(self) -> Set[str]:
        return {"x", "pos", "edge_index", "batch"}

    def forward(self, batch) -> EncoderOutput:
        h = self.emb_in(batch.x)  # (n, d_in) -> (n, d)
        pos = batch.pos  # (n, 3)

        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, batch.edge_index)

            # Update node features (n, d) -> (n, d)
            h = h + h_update if self.residual else h_update

            # Update node coordinates (no residual) (n, 3) -> (n, 3)
            pos = pos_update

        return EncoderOutput({
            "node_embedding": h,
            "graph_embedding": self.pool(h, batch.batch),  # (n, d) -> (batch_size, d)
            "pos": pos # Position
        })
"""

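As a starting point for `my_new_model.py`, here is a minimal sketch of an encoder that follows the same pattern as `EGNNModel`. It is an illustration under several assumptions, not the repository's code: the per-node MLP layers stand in for real message passing, a plain `dict` stands in for `EncoderOutput`, and the toy `batch` built with `SimpleNamespace` stands in for a real graph batch.

```python
from types import SimpleNamespace
from typing import Dict, Set

import torch
import torch.nn as nn


class MyNewModel(nn.Module):
    """Toy encoder: per-node MLP layers followed by sum/mean pooling."""

    def __init__(
        self,
        num_layers: int = 3,
        emb_dim: int = 64,
        activation: str = "relu",
        pool: str = "sum",
        residual: bool = True,
    ):
        super().__init__()
        activations = {"relu": nn.ReLU, "swish": nn.SiLU}
        # Input width is inferred on the first forward pass
        self.emb_in = nn.LazyLinear(emb_dim)
        # Stand-in for message passing layers (assumption: no edges are used)
        self.layers = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(emb_dim, emb_dim), activations[activation]())
                for _ in range(num_layers)
            ]
        )
        self.pool = pool
        self.residual = residual

    @property
    def required_batch_attributes(self) -> Set[str]:
        return {"x", "batch"}

    def forward(self, batch) -> Dict[str, torch.Tensor]:
        h = self.emb_in(batch.x)  # (n, d_in) -> (n, d)
        for layer in self.layers:
            h_update = layer(h)
            h = h + h_update if self.residual else h_update

        # Pool node embeddings per graph: (n, d) -> (num_graphs, d)
        num_graphs = int(batch.batch.max()) + 1
        graph = torch.zeros(num_graphs, h.size(-1), dtype=h.dtype, device=h.device)
        graph = graph.index_add(0, batch.batch, h)
        if self.pool == "mean":
            counts = torch.bincount(batch.batch, minlength=num_graphs).clamp(min=1)
            graph = graph / counts.unsqueeze(-1).to(h.dtype)

        # A plain dict is used here in place of EncoderOutput
        return {"node_embedding": h, "graph_embedding": graph}


# Toy batch: 5 nodes with 7 features each, split over 2 graphs
model = MyNewModel()
batch = SimpleNamespace(x=torch.randn(5, 7), batch=torch.tensor([0, 0, 1, 1, 1]))

# Check that the batch provides everything the model declares it needs
missing = model.required_batch_attributes - set(vars(batch))
if missing:
    raise ValueError(f"Batch is missing attributes: {missing}")

out = model(batch)
print(out["node_embedding"].shape, out["graph_embedding"].shape)
```

In a real model you would return an `EncoderOutput` as `EGNNModel` does, and list every batch attribute that `forward` reads in `required_batch_attributes`.
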
### 2. Create a new model config file to accompany the custom `MyNewModel`

Use the `EGNN` config below (from `configs/encoder/egnn.yaml`) as a reference when writing your own `configs/encoder/my_new_model.yaml`. This config file sets the actual values of your model's parameters. Which parameters appear here depends on the model you implemented; for the `EGNN` model used as the example in this tutorial, they include the number of layers, the embedding dimension, and the activation function.

In [None]:
"""
_target_: "proteinworkshop.models.graph_encoders.egnn.EGNNModel"
num_layers: 6
emb_dim: 512
activation: relu
norm: layer
aggr: "sum"
pool: "sum"
residual: True
"""

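The `_target_` key is what links the config file to your class: it holds the dotted import path of the class, and the remaining keys become constructor arguments. The sketch below imitates that mechanism in plain Python to show the idea. It is a simplified illustration, not Hydra's implementation (in practice this is handled by `hydra.utils.instantiate`), and the helper name `instantiate_from_target` is hypothetical. A standard-library class is used as the target so that the example runs on its own.

```python
import importlib
from typing import Any, Dict


def instantiate_from_target(config: Dict[str, Any]) -> Any:
    """Import the class named by `_target_` and call it with the other keys."""
    kwargs = dict(config)  # copy so the caller's config is left untouched
    target = kwargs.pop("_target_")
    module_path, _, attr_name = target.rpartition(".")
    module = importlib.import_module(module_path)
    cls = getattr(module, attr_name)
    return cls(**kwargs)


# Stand-in config: same shape as the YAML above, but targeting a stdlib class
example_config = {"_target_": "datetime.timedelta", "days": 1, "hours": 6}
delta = instantiate_from_target(example_config)
print(delta)  # 1 day, 6:00:00
```

For your own model, the `_target_` in `my_new_model.yaml` would follow the same pattern as the `EGNN` config, for example `proteinworkshop.models.graph_encoders.my_new_model.MyNewModel`, and every other key must match an argument of your model's `__init__`.
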
### 3. Compose and instantiate your config for pre-training or finetuning using your model

Now we need to use the created config file in our code. To do this, we use `Hydra`, a library that helps with managing configuration options via `.yaml` files.

In the following code block, we initialize Hydra and then compose the `cfg` object which we will later use to perform downstream or pre-training tasks. We can pass `hydra.compose` various overrides in order to customize our setup. We can specify for example:
- the encoder to use (here our new custom model)
- the task to perform later on
- the dataset to use
- the features that are used
- which auxiliary task should be performed (if any)
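Each `group=option` override selects one YAML file from the matching config group, which is why step 2 placed the new config at `configs/encoder/my_new_model.yaml`. The following sketch makes that mapping explicit. The helper `override_to_config_file` is hypothetical and covers only config-group overrides; overrides that set a plain value rather than select a file do not map to a path this way.

```python
from pathlib import PurePosixPath


def override_to_config_file(override: str, config_root: str = "configs") -> PurePosixPath:
    """Map a config-group override such as 'encoder=my_new_model' to its YAML file."""
    # A leading '+' only tells Hydra to add a group that is not in the defaults list
    group, _, option = override.lstrip("+").partition("=")
    return PurePosixPath(config_root) / group / f"{option}.yaml"


print(override_to_config_file("encoder=my_new_model"))
# configs/encoder/my_new_model.yaml
```

If Hydra reports that it cannot find the `my_new_model` option, checking that the file exists at this location is a sensible first step.
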

In [None]:
# Misc. tools
import os

# Hydra tools
import hydra

from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig

from proteinworkshop.constants import HYDRA_CONFIG_PATH
from proteinworkshop.utils.notebook import init_hydra_singleton

version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=my_new_model",  # our new custom model
        "task=inverse_folding",
        "dataset=afdb_swissprot_v4",
        "features=ca_angles",
        "+aux_task=none",
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"

HydraConfig.instance().set_config(cfg)

### 4. Use your custom model in a pre-training or finetuning task

With the config object created, you can pass it directly to the training and finetuning infrastructure that the `ProteinWorkshop` provides, depending on your goal.

In [None]:
from proteinworkshop.configs import config
from proteinworkshop.finetune import finetune
from proteinworkshop.train import train_model

cfg = config.validate_config(cfg)

# train_model(cfg)  # Pre-train a model using the selected data
# finetune(cfg)  # Fine-tune a model using the selected data

When we instantiated the config, we specified `ca_angles` as the feature set. However, the custom model can easily be reconfigured to use side-chain atom context instead, as the following code block shows.

In [None]:
version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=my_new_model",
        "task=inverse_folding",
        "dataset=afdb_swissprot_v4",
        "features=ca_sc",  # side-chain atom context instead of ca_angles
        "+aux_task=none",
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"

HydraConfig.instance().set_config(cfg)

cfg = config.validate_config(cfg)

# train_model(cfg)  # Pre-train a model using the selected data
# finetune(cfg)  # Fine-tune a model using the selected data
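The two compose calls in this tutorial differ only in their `features` override. If you plan to try several setups, it can help to build the override list from a dictionary so that a single entry can be swapped without retyping the rest. The sketch below shows one way to do this; `build_overrides` is a hypothetical helper, not part of the `ProteinWorkshop` or Hydra.

```python
from typing import Dict, List, Optional


def build_overrides(
    groups: Dict[str, str], extra: Optional[Dict[str, str]] = None
) -> List[str]:
    """Turn {'encoder': 'my_new_model'} into ['encoder=my_new_model'].

    Entries in `extra` get a leading '+', which Hydra uses for keys
    that are not already part of the defaults list.
    """
    overrides = [f"{group}={option}" for group, option in groups.items()]
    overrides += [f"+{group}={option}" for group, option in (extra or {}).items()]
    return overrides


base = {
    "encoder": "my_new_model",
    "task": "inverse_folding",
    "dataset": "afdb_swissprot_v4",
    "features": "ca_angles",
}

ca_angles_overrides = build_overrides(base, extra={"aux_task": "none"})
ca_sc_overrides = build_overrides({**base, "features": "ca_sc"}, extra={"aux_task": "none"})

print(ca_angles_overrides)
print(ca_sc_overrides)
```

Either list can then be passed as the `overrides` argument of `hydra.compose`.
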

### 5. Wrapping up

Have any additional questions about adding your custom model to the `ProteinWorkshop`? [Create a new issue](https://github.com/a-r-j/ProteinWorkshop/issues/new/choose) on our [GitHub repository](https://github.com/a-r-j/ProteinWorkshop). We would be happy to work with you to add your new model to the repository!