# Fine-Tune the Caduceus Model for Genomic Benchmark Tasks

In this scenario, we will be fine-tuning a pre-trained [Caduceus](https://caduceus-dna.github.io/) model (Schiff, Y. et al., 2024) for a range of DNA sequence classification tasks published on Genomic Benchmarks (Grešová, K., Martinek, V., Čechák, D. et al., 2023).

Like HyenaDNA, Caduceus is a State Space Model (SSM), allowing for highly efficient training on long context windows. SSMs have been shown to exceed the performance of attention-based models like DNABERT with orders of magnitude fewer parameters.

Caduceus is built on two ideas: (1) a bi-directional implementation of the Mamba block (BiMamba), whose *selective state space* design has shown promising scaling performance compared to Transformers, and (2) a novel module (MambaDNA) that enforces **Reverse Complement (RC) Equivariance**.
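To make the RC property concrete, here is a minimal plain-Python sketch. It is not Caduceus code: `reverse_complement` and `gc_mask` are hypothetical helpers written for illustration. The reverse complement reverses a strand and swaps A with T and C with G. A per-base function is RC equivariant (in the simplest case, with no channel swapping) if applying it to the reverse complement gives the reversed output.

```python
# Illustrative only: plain-Python helpers, not the Caduceus implementation.
COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq: str) -> str:
    """Return the reverse complement of a DNA sequence (N maps to N)."""
    return seq.upper().translate(COMPLEMENT)[::-1]


def gc_mask(seq: str) -> list[int]:
    """Toy per-base 'model': 1 where the base is G or C, else 0."""
    return [1 if base in "GC" else 0 for base in seq.upper()]


seq = "AACGT"
rc = reverse_complement(seq)
print(rc)  # ACGTT

# Equivariance check for the toy function: f(RC(x)) equals reversed f(x).
print(gc_mask(rc) == gc_mask(seq)[::-1])  # True
```

Because both strands of a DNA molecule carry the same biological information, a model with this property gives consistent predictions regardless of which strand was sequenced.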

![Caduceus Comparison](https://caduceus-dna.github.io/static/images/caducues_comparison.png)

At the time of publication, Caduceus represented state-of-the-art performance on all of the Genomic Benchmark tasks, even surpassing HyenaDNA.

![Results](https://caduceus-dna.github.io/static/images/experiments/nt_benchmark.png)

In this notebook, we will attempt to replicate some of the results from the paper by fine-tuning a smaller, pre-trained variant of Caduceus (with an added classification head) on a range of Genomic Benchmark tasks. There are three key ideas to look out for in this SageMaker Training Job.

1. **AWS HealthOmics Integration (Optional)** - If you followed the [`load_genomic_benchmarks_to_omics.ipynb`](load_genomic_benchmarks_to_omics.ipynb) notebook, you can access those read sets as fine-tuning training data. This demonstrates a workflow for healthcare and life sciences (HCLS) organizations that have proprietary DNA sequence datasets they wish to use to train their own models. Otherwise, the benchmark datasets will be loaded directly from HuggingFace.
2. **Distributed Training with Distributed Data Parallel (DDP)** - the training script is set up to run DDP on multi-GPU instances. In this scenario, the full model weights are loaded onto each GPU, but the dataset is partitioned across them. At each step, the GPUs average their gradients via an all-reduce, so every replica applies the same update.
3. **Parameter Efficient Fine-Tuning (PEFT)** - you can also choose to run the fine-tuning via PEFT, which significantly limits how many of the parameters are learnable during fine-tuning (in our case, only about 9-10% of the parameters will be trainable). This greatly reduces memory requirements and speeds up training.
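The data-parallel idea in item 2 can be sketched without any GPU. The following is a pure-Python simulation, not the training script: `shard_indices` and `all_reduce_mean` are hypothetical helpers, whereas the real job relies on `torch.distributed` for both steps. It assumes four workers, matching the four GPUs of an `ml.g5.12xlarge`.

```python
# Pure-Python simulation of data-parallel training; the real job relies on
# torch.distributed rather than these hypothetical helpers.


def shard_indices(num_samples: int, rank: int, world_size: int) -> list[int]:
    """Give each rank a strided, non-overlapping slice of the dataset indices."""
    return list(range(rank, num_samples, world_size))


def all_reduce_mean(per_rank_grads: list[list[float]]) -> list[float]:
    """Average gradients element-wise across ranks, as an all-reduce would."""
    world_size = len(per_rank_grads)
    return [sum(values) / world_size for values in zip(*per_rank_grads)]


world_size = 4  # assumption: one process per GPU on a 4-GPU instance
shards = [shard_indices(10, rank, world_size) for rank in range(world_size)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]

# Each rank computes gradients on its own shard; all ranks then share the mean.
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(all_reduce_mean(grads))  # [4.0, 5.0]
```

Note that PyTorch's own sampler additionally pads the index list so that every rank receives the same number of samples; the simulation skips that detail.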

In [None]:
import sagemaker
import boto3
import json

iam_client = boto3.client('iam')
role = sagemaker.get_execution_role()
sess = sagemaker.Session()

REGION_NAME = sess.boto_region_name
S3_BUCKET = sess.default_bucket()
ACCOUNT_ID = sess.account_id()

print(role, S3_BUCKET)

## 1. (Optional) Integration with AWS HealthOmics

**Note: This step is not required in order to run the fine-tuning job for Genomic Benchmarks datasets. If you do not want to use HealthOmics to store the sequences, the datasets are instead loaded directly from [HuggingFace](https://huggingface.co/katarinagresova).**

If you followed the optional [`load_genomic_benchmarks_to_omics.ipynb`](load_genomic_benchmarks_to_omics.ipynb) notebook, you can access those read sets as fine-tuning training data.

At runtime, the FASTQ read sets will be loaded from the sequence store and parsed into a [`datasets.Dataset`](https://huggingface.co/docs/datasets/en/index). 
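As a rough illustration of the parsing step, the sketch below turns FASTQ text into column lists. It is an assumption-laden example, not the training script's parser: `parse_fastq` is a hypothetical helper, it handles only the standard 4-line record layout, and how class labels are recovered from the read sets is not shown in this notebook.

```python
# Minimal FASTQ parser sketch; the training script's actual parsing (and how
# it recovers class labels) is not shown in this notebook.
from typing import Iterable


def parse_fastq(lines: Iterable[str]) -> dict[str, list[str]]:
    """Parse 4-line FASTQ records into column lists (id, sequence, quality)."""
    columns = {"id": [], "sequence": [], "quality": []}
    rows = [line.rstrip("\n") for line in lines if line.strip()]
    for i in range(0, len(rows) - 3, 4):
        header, seq, plus, qual = rows[i:i + 4]
        if not header.startswith("@") or not plus.startswith("+"):
            raise ValueError(f"Malformed FASTQ record starting at row {i + 1}")
        columns["id"].append(header[1:].split()[0])  # drop '@' and description
        columns["sequence"].append(seq)
        columns["quality"].append(qual)
    return columns


example = ["@read1 demo", "ACGT", "+", "IIII", "@read2", "TTGA", "+", "IIII"]
records = parse_fastq(example)
print(records["sequence"])  # ['ACGT', 'TTGA']
```

A dictionary of equal-length lists like `records` is the shape that `datasets.Dataset.from_dict` accepts, so it could be converted to a `Dataset` in one call.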

If you use HealthOmics read sets, one additional step is required: the SageMaker execution role must be granted access to the read sets.

In [None]:
# optional - get from load_genomic_benchmarks_to_omics.ipynb or replace with None
# SEQUENCE_STORE_ID = 9757315158
SEQUENCE_STORE_ID = None  # to load datasets directly from HF

If you have a sequence store ID and have loaded the read sets for the benchmarks, the output of the following cell is the policy that you should attach to the SageMaker execution role.

In [None]:
# if using Omics, make sure that the execution role has access to GetReadSet and ListReadSets

if SEQUENCE_STORE_ID:
    omics_policy = json.dumps({
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "omics:GetReadSet", 
                    "omics:ListReadSets",
                ],
                # use the current account rather than a hardcoded account ID
                "Resource": f"arn:aws:omics:{REGION_NAME}:{ACCOUNT_ID}:sequenceStore/{SEQUENCE_STORE_ID}/readSet/*"
            },
        ]
    }, indent=2)
    print(omics_policy)
    print(sagemaker.get_execution_role())
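If you prefer to attach the policy from code rather than the console, one option is an inline role policy. The sketch below assumes the caller has `iam:PutRolePolicy` permission; `attach_inline_policy`, `role_name_from_arn`, and the policy name `OmicsReadSetAccess` are hypothetical, while `put_role_policy` is the boto3 IAM call being wrapped.

```python
# Sketch: attach the printed policy as an inline policy on the execution role.
# The helper names and the policy name are hypothetical; the caller needs
# iam:PutRolePolicy permission for the call to succeed.


def role_name_from_arn(role_arn: str) -> str:
    """Extract the role name (last path segment) from an IAM role ARN."""
    return role_arn.split("/")[-1]


def attach_inline_policy(iam_client, role_arn: str, policy_json: str,
                         policy_name: str = "OmicsReadSetAccess") -> str:
    """Attach policy_json to the role and return the role name that was used."""
    role_name = role_name_from_arn(role_arn)
    iam_client.put_role_policy(
        RoleName=role_name,
        PolicyName=policy_name,
        PolicyDocument=policy_json,
    )
    return role_name


# In this notebook the call would look like:
# attach_inline_policy(iam_client, role, omics_policy)
```

In many environments the notebook role cannot modify its own permissions, in which case an administrator would need to attach the policy instead.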

## 2. Training

The [training script](scripts/train_caduceus_dist.py) can be configured for any of the Benchmark tasks with a range of different hyperparameters. 

Feel free to experiment with different learning rates/schedulers, pre-trained model variants, and PEFT configurations. 

Below is a table of the accuracies I was able to produce on some of the tasks, compared to HyenaDNA and DNABERT.

| Task                    | Caduceus-1.93M | HyenaDNA-6.6M [1] | DNABERT-110M [2] |
| ----------------------- | -------------- | ----------------- | ---------------- |
| Human vs. Worm          | 0.964          | 0.966             | 0.965            |
| Human Enhancers Cohn    | 0.747          | 0.742             | 0.740            |
| Human Enhancers Ensembl | 0.874          | 0.892             | 0.857            |
| Human Nontata Promoters | 0.905          | 0.966             | 0.856            |
| Mouse Enhancers         | 0.723          | 0.851             | 0.669            |

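Even with the PEFT configuration and the smaller 1.93M variant of Caduceus, we match or exceed the HyenaDNA results on two of the five tasks (Human vs. Worm and Human Enhancers Cohn), although HyenaDNA remains ahead on the Ensembl enhancer, non-TATA promoter, and mouse enhancer tasks. **Perhaps even more exciting is that we match or beat the DNABERT results on every task (within 0.001 on Human vs. Worm), despite DNABERT having >50x as many parameters!**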

[[1] Nguyen et al., 2023](https://arxiv.org/pdf/2306.15794)

[[2] Zhou et al., 2023](https://arxiv.org/pdf/2306.15006)

In [None]:
from sagemaker.pytorch import PyTorch
from sagemaker.debugger import TensorBoardOutputConfig

from datetime import datetime
import os

task_config = {
    "demo_human_or_worm": {"epochs": 4},
    "dummy_mouse_enhancers_ensembl": {"epochs": 100},
    "human_enhancers_cohn": {"epochs": 20},
    "human_enhancers_ensembl": {"epochs": 25},
    "human_nontata_promoters": {"epochs": 25},
    "human_ocr_ensembl": {"epochs": 5},
}

task = "dummy_mouse_enhancers_ensembl"

print(f"TASK: {task}")

# hyperparameters which are passed to the training job
hyperparameters = {
    "epochs": task_config[task]["epochs"],
    "per_device_train_batch_size": 128,
    # "model_name": "kuleshov-group/caduceus-ps_seqlen-1k_d_model-118_n_layer-4_lr-8e-3",  # tiny ~0.47M 
    # "model_name": "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16",       # large ~7.73M
    "model_name": "kuleshov-group/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3",  # small ~1.93M  
    "benchmark_name": task,
    "peft": True,
    "learning_rate": 1e-3,
    "weight_decay": 0.0,
    "sequence_store_id": SEQUENCE_STORE_ID,
    "region": REGION_NAME,
}

now_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

LOG_DIR = "/opt/ml/output/tensorboard"
TRAINING_JOB_NAME = f"caduceus-{hyperparameters['benchmark_name'].replace('_', '-')}-{now_str}"

output_path = os.path.join(
    "s3://", S3_BUCKET, "sagemaker-output", "training", TRAINING_JOB_NAME
)

tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=os.path.join(output_path, 'tensorboard'),
    container_local_output_path=LOG_DIR
)

image_uri = f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker"

# create the Estimator
estimator = PyTorch(
    entry_point="train_caduceus_dist.py",
    source_dir='./scripts',
    instance_type="ml.g5.12xlarge",  # multi-GPU to take advantage of data parallel
    instance_count=1,
    role=role,
    image_uri=image_uri,
    hyperparameters=hyperparameters,
    tensorboard_output_config=tensorboard_output_config,
    keep_alive_period_in_seconds=1800,
    distribution={"torch_distributed": {"enabled": True}},
)
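SageMaker passes the `hyperparameters` dictionary to the entry point as command-line arguments, so the training script needs to parse them. The sketch below shows one way this could look. It is hypothetical: the real `train_caduceus_dist.py` may define its arguments and defaults differently, and `str2bool` is a helper introduced here because boolean values arrive as strings.

```python
# Hypothetical argument parser for a SageMaker script-mode entry point; the
# real train_caduceus_dist.py may define its arguments differently.
import argparse


def str2bool(value: str) -> bool:
    """Interpret 'True' / 'false' / '1' style strings as booleans."""
    return str(value).lower() in {"true", "1", "yes"}


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=128)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--benchmark_name", type=str, required=True)
    parser.add_argument("--peft", type=str2bool, default=False)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    return parser.parse_args(argv)


# Simulate the arguments the container would pass for the configuration above.
args = parse_args([
    "--epochs", "100",
    "--model_name", "kuleshov-group/caduceus-ps_seqlen-1k_d_model-256_n_layer-4_lr-8e-3",
    "--benchmark_name", "dummy_mouse_enhancers_ensembl",
    "--peft", "True",
])
print(args.epochs, args.peft, args.learning_rate)  # 100 True 0.001
```

Using `type=str2bool` instead of `type=bool` matters here, because `bool("False")` evaluates to `True` in Python.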


By default, the `estimator.fit()` method below will stream the job logs to the notebook synchronously, but you can submit the job asynchronously by setting `wait=False`, e.g. 

`estimator.fit(..., wait=False)`
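If you submit the job asynchronously, you will likely want a way to check on it later. A minimal polling sketch follows, assuming a boto3 SageMaker client; `wait_for_job` is a hypothetical helper, and `describe_training_job` is the API call it wraps.

```python
# Sketch: poll a training job's status after estimator.fit(..., wait=False).
# `wait_for_job` is a hypothetical helper around describe_training_job.
import time

TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_job(sm_client, job_name: str, poll_seconds: float = 60.0) -> str:
    """Poll until the job reaches a terminal state and return that state."""
    while True:
        response = sm_client.describe_training_job(TrainingJobName=job_name)
        status = response["TrainingJobStatus"]
        print(f"{job_name}: {status}")
        if status in TERMINAL_STATES:
            return status
        time.sleep(poll_seconds)


# In this notebook the call would look like:
# wait_for_job(boto3.client("sagemaker"), TRAINING_JOB_NAME)
```

The job's progress is also visible in the SageMaker console, so polling from the notebook is optional.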

In [None]:
# run the training job synchronously
estimator.fit(job_name=TRAINING_JOB_NAME)

For monitoring results, the job has been configured to log performance metrics to [TensorBoard](https://docs.aws.amazon.com/sagemaker/latest/dg/tensorboard-on-sagemaker.html). 

![TensorBoard](images/tensorboard.png)