# `ProteinWorkshop` Tutorial, Part 5 - Adding a New Task
![Tasks](../docs/source/_static/box_downstream_tasks.png)

In [None]:
%load_ext autoreload
%autoreload 2
# %load_ext blackcellmagic

Welcome to the last entry in the `ProteinWorkshop` tutorial series!

The task is the centerpiece of our framework. It ties together the other components, such as the dataset and the model, and it also specifies the loss function and the metrics. In this tutorial, we show you how to add a new task to the `ProteinWorkshop`.

To add your custom task to the `ProteinWorkshop`, follow this 4-step procedure (created files in parentheses):

1. Create a new subclass of the `BaseTransform` class from `torch_geometric.transforms` (`my_new_task.py`)
2. Create a new task config file to accompany the custom `BaseTransform` (`my_new_task.yaml`)
3. Compose and instantiate your config for pre-training or finetuning using your task
4. Use your custom task

### 1. Create a new subclass of the `BaseTransform` class from `torch_geometric.transforms` (`my_new_task.py`)

Use the `SequenceNoiseTransform` below (i.e., `proteinworkshop/tasks/sequence_denoising.py`) as a reference when filling out your custom `proteinworkshop/tasks/my_new_task.py`.

In [None]:
"""
class SequenceNoiseTransform(BaseTransform):
    def __init__(
        self, corruption_rate: float, corruption_strategy: Literal["mutate", "mask"]
    ):
        self.corruption_rate = corruption_rate
        self.corruption_strategy = corruption_strategy

    @property
    def required_attributes(self) -> Set[str]:
        return {"residue_type"}

    @jaxtyped(typechecker=typechecker)
    def __call__(self, x: Union[Data, Protein]) -> Union[Data, Protein]:
        x.residue_type_uncorrupted = copy.deepcopy(x.residue_type)
        # Get indices of residues to corrupt
        indices = torch.randint(
            0,
            x.residue_type.shape[0],
            (int(x.residue_type.shape[0] * self.corruption_rate),),
            device=x.residue_type.device,
        ).long()

        # Apply corruption
        if self.corruption_strategy == "mutate":
            # Set indices to random residue type
            x.residue_type[indices] = torch.randint(
                0,
                23,  # TODO: probably best to not hardcode this
                (indices.shape[0],),
                device=x.residue_type.device,
            )
        elif self.corruption_strategy == "mask":
            # Set indices to 23 -> "UNK"
            x.residue_type[indices] = 23  # TODO: probably best to not hardcode this
        else:
            raise NotImplementedError(
                f"Corruption strategy: {self.corruption_strategy} not supported."
            )
        # Get indices of applied corruptions
        index = torch.zeros(x.residue_type.shape[0])
        index[indices] = 1
        x.sequence_corruption_mask = index.bool()

        return x

    def __repr__(self) -> str:
        return f"{self.__class__}(corruption_strategy: {self.corruption_strategy} corruption_rate: {self.corruption_rate})"
"""
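
The contract of such a transform can be sketched without any dependencies. The following is a minimal stand-in, not the library's implementation: it uses plain Python lists and `types.SimpleNamespace` in place of tensors and `Data` objects, and the class name `MaskNoiseSketch` and its `seed` argument are hypothetical. It mirrors the "mask" strategy above, including the fact that indices are drawn with replacement, so the number of distinct corrupted positions can be smaller than `int(n * corruption_rate)`.

```python
import copy
import random
from types import SimpleNamespace
from typing import Set

UNK_INDEX = 23  # Mirrors the hardcoded "UNK" index in the reference transform


class MaskNoiseSketch:
    """Hypothetical, dependency-free stand-in for a masking transform."""

    def __init__(self, corruption_rate: float, seed: int = 0):
        self.corruption_rate = corruption_rate
        self.rng = random.Random(seed)  # Seeded so the sketch is reproducible

    @property
    def required_attributes(self) -> Set[str]:
        return {"residue_type"}

    def __call__(self, x: SimpleNamespace) -> SimpleNamespace:
        n = len(x.residue_type)
        # Keep the clean labels as the supervision target
        x.residue_type_uncorrupted = copy.deepcopy(x.residue_type)
        # Draw indices with replacement, so duplicates are possible
        num_draws = int(n * self.corruption_rate)
        indices = [self.rng.randrange(n) for _ in range(num_draws)]
        for i in indices:
            x.residue_type[i] = UNK_INDEX
        # Boolean mask marking the corrupted positions
        corrupted = set(indices)
        x.sequence_corruption_mask = [i in corrupted for i in range(n)]
        return x


protein = SimpleNamespace(residue_type=[i % 20 for i in range(50)])
protein = MaskNoiseSketch(corruption_rate=0.3)(protein)
num_corrupted = sum(protein.sequence_corruption_mask)
```

In a real `my_new_task.py` you would subclass `BaseTransform` and operate on tensors, as `SequenceNoiseTransform` does; the sketch only illustrates which attributes the transform reads and writes.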

### 2. Create a new task config file to accompany the custom `BaseTransform` (`my_new_task.yaml`)

Reference the `sequence_denoising` config below (i.e., `proteinworkshop/config/task/sequence_denoising.yaml`) to fill out a custom `proteinworkshop/config/task/my_new_task.yaml`.

This config file sets the parameter values for your task: the defaults for the metrics, decoder, and transforms; dataset options such as `num_classes`; the callbacks (e.g., which metric early stopping and checkpointing monitor); and task settings such as the losses and outputs.

In [None]:
"""
# @package _global_

defaults:
  - override /metrics:
      - accuracy
      - f1_score
      - perplexity
  - override /decoder:
      - residue_type
  - override /transforms:
      - remove_missing_ca
      - sequence_denoising

dataset:
  num_classes: 23

callbacks:
  early_stopping:
    monitor: val/residue_type/accuracy
    mode: "max"
  model_checkpoint:
    monitor: val/residue_type/accuracy
    mode: "max"

task:
  task: "sequence_denoising"
  classification_type: "multiclass"
  metric_average: "micro"

  losses:
    residue_type: cross_entropy
  label_smoothing: 0.0

  output:
    - residue_type
  supervise_on:
    - residue_type
"""
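
The monitor keys in the config above appear to follow a `<split>/<output>/<metric>` pattern (`val/residue_type/accuracy`). Assuming that pattern holds, a mismatch between the monitored key and the configured outputs or metrics is an easy mistake to make when adapting the file. The sketch below checks for it on a plain dictionary that mirrors the YAML; `check_monitors` is a hypothetical helper, not part of `ProteinWorkshop`.

```python
from typing import Any, Dict, List


def check_monitors(cfg: Dict[str, Any], metrics: List[str]) -> List[str]:
    """Return a list of problems found in the callback monitor keys.

    Assumes monitor keys look like "<split>/<output>/<metric>".
    """
    problems = []
    outputs = set(cfg["task"]["output"])
    for name, callback in cfg["callbacks"].items():
        parts = callback["monitor"].split("/")
        if len(parts) != 3:
            problems.append(f"{name}: unexpected monitor format {callback['monitor']!r}")
            continue
        _split, output, metric = parts
        if output not in outputs:
            problems.append(f"{name}: output {output!r} not in task.output")
        if metric not in metrics:
            problems.append(f"{name}: metric {metric!r} not in metrics")
    return problems


# Plain-dict mirror of the relevant parts of the YAML above
metrics = ["accuracy", "f1_score", "perplexity"]
task_cfg = {
    "callbacks": {
        "early_stopping": {"monitor": "val/residue_type/accuracy", "mode": "max"},
        "model_checkpoint": {"monitor": "val/residue_type/accuracy", "mode": "max"},
    },
    "task": {"output": ["residue_type"], "supervise_on": ["residue_type"]},
}
problems = check_monitors(task_cfg, metrics)
```

For the `sequence_denoising` values, `problems` is empty. If your new task predicts a different output, the monitored key likely needs to change along with it.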

### 3. Compose and instantiate your config for pre-training or finetuning using your task

Now we need to use the created config file in our code. To do this, we use `Hydra`, a library that helps with managing configuration options via `.yaml` files.

In the following code block, we initialize Hydra and then compose the `cfg` object, which we will later use for pre-training or downstream tasks. We can pass various overrides to `hydra.compose` to customize our setup. For example, we can specify:
- the encoder to use
- the task to perform later on (here our task `my_new_task`)
- the dataset to use
- the features that are used
- which auxiliary task should be performed (if any)

In [None]:
# Misc. tools
import os

# Hydra tools
import hydra

from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig

from proteinworkshop.constants import HYDRA_CONFIG_PATH
from proteinworkshop.utils.notebook import init_hydra_singleton

version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=schnet",
        "task=my_new_task",
        "dataset=afdb_swissprot_v4",
        "features=ca_angles",
        "+aux_task=none",
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"

HydraConfig.instance().set_config(cfg)
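
The override strings passed to `hydra.compose` use Hydra's command-line override syntax: a plain `key=value` changes an entry that already exists, while a leading `+` adds one that does not (as with `+aux_task=none` above). The sketch below classifies override strings by prefix. It is a simplified illustration with a hypothetical helper name, `describe_override`, and it does not cover the full grammar (for example, package qualifiers or sweeps).

```python
from typing import Tuple


def describe_override(override: str) -> Tuple[str, str, str]:
    """Split an override string into (action, key, value)."""
    if override.startswith("++"):
        action, body = "add_or_override", override[2:]
    elif override.startswith("+"):
        action, body = "add", override[1:]
    elif override.startswith("~"):
        action, body = "delete", override[1:]
    else:
        action, body = "override", override
    # Everything before the first "=" is the key, the rest is the value
    key, _, value = body.partition("=")
    return action, key, value


for item in ["encoder=schnet", "task=my_new_task", "+aux_task=none"]:
    print(describe_override(item))
```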

### 4. Use your custom task

Now with the config object created, you can make use of the infrastructure that `ProteinWorkshop` provides in order to directly use the config object for training or finetuning a model, depending on what your goal is.

In [None]:
from proteinworkshop.configs import config
from proteinworkshop.finetune import finetune
from proteinworkshop.train import train_model

cfg = config.validate_config(cfg)

# train_model(cfg)  # Pre-train a model using the selected data
# finetune(cfg)  # Fine-tune a model using the selected data

When we instantiated the config, we specified `ca_angles` as the feature context. However, we can easily recompose the config to use side-chain atom context (`ca_sc`) instead, as the following code block shows.

In [None]:
version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=schnet",
        "task=my_new_task",
        "dataset=afdb_swissprot_v4",
        "features=ca_sc",
        "+aux_task=none",
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"

HydraConfig.instance().set_config(cfg)

cfg = config.validate_config(cfg)

# train_model(cfg)  # Pre-train a model using the selected data
# finetune(cfg)  # Fine-tune a model using the selected data
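
The two compose cells above differ only in the `features` override. As a sketch of how to avoid retyping the whole list, the hypothetical helper `with_override` below derives a new override list from a base list by swapping a single key. It assumes simple `key=value` entries with an optional `+` prefix, which is all this tutorial uses.

```python
from typing import List

BASE_OVERRIDES = [
    "encoder=schnet",
    "task=my_new_task",
    "dataset=afdb_swissprot_v4",
    "features=ca_angles",
    "+aux_task=none",
]


def with_override(overrides: List[str], key: str, value: str) -> List[str]:
    """Return a copy of `overrides` with the entry for `key` set to `value`."""
    updated, found = [], False
    for item in overrides:
        stripped = item.lstrip("+")
        prefix = item[: len(item) - len(stripped)]  # Preserve any "+" prefix
        if stripped.partition("=")[0] == key:
            updated.append(f"{prefix}{key}={value}")
            found = True
        else:
            updated.append(item)
    if not found:
        raise KeyError(f"No override for {key!r} in the base list")
    return updated


side_chain_overrides = with_override(BASE_OVERRIDES, "features", "ca_sc")
```

The resulting list can be passed as the `overrides` argument of `hydra.compose`, exactly as in the cells above.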

### 5. Wrapping up

Have any additional questions about adding your custom task to the `ProteinWorkshop`? [Create a new issue](https://github.com/a-r-j/ProteinWorkshop/issues/new/choose) on our [GitHub repository](https://github.com/a-r-j/ProteinWorkshop). We would be happy to work with you to add your new task to the repository!