# Fine-Tune and Benchmark Geneformer (a Single-Cell RNA-Seq Foundation Model) for Cell Type/Cell State Classification

# 0. Import Dependencies


In [None]:
# Install pinned versions of boto3, the SageMaker Python SDK, MLflow and the
# SageMaker MLflow plugin used by the cells below.
%pip install --disable-pip-version-check -q -U 'boto3==1.35.16' 'sagemaker==2.231.0' 'mlflow==2.13.2' 'sagemaker-mlflow==0.1.0'

In [None]:
import os
from time import gmtime, strftime

import sagemaker
import boto3
import mlflow
from sagemaker.processing import (
    FrameworkProcessor, ScriptProcessor,
    ProcessingInput, ProcessingOutput
)
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import (
    CategoricalParameter,
    ContinuousParameter,
    HyperparameterTuner,
    IntegerParameter,
    HyperbandStrategyConfig,
    StrategyConfig
)
from sagemaker.sklearn.model import SKLearnModel
from sagemaker.deserializers import CSVDeserializer
from sagemaker.serializers import CSVSerializer
from IPython.core.display import display, HTML

# 1. Preparations

## 1.1 Create Some Necessary Clients

In [None]:
# Shared AWS handles: one boto3 session backs the SageMaker session and every
# low-level client created here.
boto_session = boto3.Session()
region = boto_session.region_name

sagemaker_session = sagemaker.Session(boto_session=boto_session)
sagemaker_execution_role = sagemaker.get_execution_role(sagemaker_session)

# Low-level service clients.
sagemaker_boto_client = boto_session.client("sagemaker")
s3_boto_client = boto_session.client("s3")

# Account id of the caller, resolved through STS.
caller_identity = boto_session.client("sts").get_caller_identity()
account_id = caller_identity.get("Account")

print(f"Assumed SageMaker role is {sagemaker_execution_role}")

## 1.2. Specify S3 Bucket and Prefix

In [None]:
# Default SageMaker bucket for this account/region, and the key prefix under
# which every artifact produced by this notebook is stored.
S3_BUCKET = sagemaker_session.default_bucket()
S3_PREFIX = "scrnaseq-fm-finetune"
# Include the "s3://" scheme so the value is a complete, usable S3 URI
# (without it the printed "S3 path" was just "<bucket>/<prefix>").
S3_PATH = sagemaker.s3.s3_path_join("s3://", S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

# 2. Data Preparation with Amazon SageMaker Processing


Here we download an example 10x Genomics scRNA-seq dataset, add cell type annotations derived from marker gene expression to serve as the ground truth for our classification task, and split the data into train, validation and test sets.

## 2.1. Define parameters of the SageMaker Processing Job

In [None]:
# Preprocess the example scRNA-seq dataset in a managed SKLearn container and
# write the class labels plus the train/validation/test splits to S3.
# NOTE(review): the job name mentions "hao2021" while the script is
# process_10xpbmc3k.py — confirm which dataset the script actually downloads.
processing_job_name = "sc-preprocess-hao2021"
print("Preparing and splitting scRNASeq dataset and saving adata", processing_job_name)

sklearn_processor = SKLearnProcessor(
    framework_version="1.2-1",
    role=sagemaker_execution_role,
    instance_type="ml.m5.4xlarge",
    volume_size_in_gb=20,
    instance_count=1,
    base_job_name=processing_job_name,
)

# Run processor
sklearn_processor.run(
    inputs=[
        ProcessingInput(
            input_name="requirements",
            source="scripts/processing/processing_requirements.txt",
            destination="/opt/ml/processing/input/requirements/",
        )
    ],
    # Each (output name, container sub-directory) pair is uploaded to
    # s3://<bucket>/<prefix>/<processing_job_name>/<output name>.
    # s3_path_join is used instead of os.path.join so the URI always uses "/"
    # separators, regardless of the OS running the notebook.
    outputs=[
        ProcessingOutput(
            output_name=output_name,
            source=f"/opt/ml/processing/h5ad_data/{container_dir}",
            destination=sagemaker.s3.s3_path_join(
                "s3://", S3_BUCKET, S3_PREFIX, processing_job_name, output_name
            ),
        )
        for output_name, container_dir in (
            ("class_labels", "class_labels"),
            ("train", "train"),
            ("validation", "val"),
            ("test", "test"),
        )
    ],
    code="scripts/processing/process_10xpbmc3k.py",
    arguments=["--train_size", "0.8", "--split_by_group"],
)

# 3. ML training with SageMaker Training Jobs

## 3.1. Using MLflow to track model training experiments
Create an MLflow tracking server in SageMaker Studio. Copy the tracking server ARN to the variable `tracking_server_arn` below.

Attach the following policy to the SageMaker execution role to enable MLflow tracking:

```
{
    "Version": "2012-10-17",    
    "Statement": [        
        {            
            "Effect": "Allow",            
            "Action": [
                "sagemaker-mlflow:*",
                "sagemaker:CreateMlflowTrackingServer",
                "sagemaker:UpdateMlflowTrackingServer",
                "sagemaker:DeleteMlflowTrackingServer",
                "sagemaker:StartMlflowTrackingServer",
                "sagemaker:StopMlflowTrackingServer",
                "sagemaker:CreatePresignedMlflowTrackingServerUrl"
            ],            
            "Resource": "*"        
        }        
    ]
}
```

In [None]:
# ARN of the SageMaker managed MLflow tracking server created in SageMaker Studio,
# e.g. "arn:aws:sagemaker:<region>:<account-id>:mlflow-tracking-server/<name>".
# (The original bare placeholder `<YOUR TRACKING SERVER ARN>` was a syntax error.)
tracking_server_arn = "<YOUR TRACKING SERVER ARN>"

# Fail fast here instead of with an opaque MLflow error in the training cells.
if not tracking_server_arn.startswith("arn:"):
    raise ValueError(
        "Set `tracking_server_arn` to your MLflow tracking server ARN before running the training cells."
    )

## 3.2. Train a Logistic Regression model using normalized counts as the Baseline 

To put model performance in context, it is good practice to start with a simple baseline. As the baseline, we use a logistic regression model that takes normalized expression counts as input and predicts the cell type.

### 3.2.1. Define and fit an SKLearn estimator, logging the run to an MLflow experiment

In [None]:
# Baseline: logistic regression on normalized counts, trained in a managed
# SKLearn container. The entry-point script receives the MLflow server and
# experiment through the environment variables set below.
lr_job_name = "baseline-LR"
model_output_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/trained_models/"
experiment_name = "scRNASeq-baseline"

# Point this notebook's MLflow client at the tracking server and create/select
# the experiment.
mlflow.set_tracking_uri(tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name)

dataset_name = "pbmc3k"
processing_job_name = "sc-preprocess-hao2021"

lr_estimator = SKLearn(
    base_job_name=lr_job_name,
    enable_sagemaker_metrics=True,
    entry_point="baseline_lr_train_mlflow.py",
    framework_version="1.2-1",
    hyperparameters={
        'penalty': 'l2',
        'class_weight': 'balanced',
        'max_iter': 1000,
        'solver': 'saga',
        'dataset_name': dataset_name
    },
    instance_count=1,
    instance_type="ml.c5.4xlarge",
    output_path=model_output_path,
    role=sagemaker_execution_role,
    sagemaker_session=sagemaker_session,
    source_dir="scripts/training/lr/",
    environment={
        "MLFLOW_TRACKING_URI": tracking_server_arn,
        "MLFLOW_EXPERIMENT_NAME": experiment.name
    }
)

# Data splits and label mapping written by the processing job in section 2.
train_s3_url      = f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/train/{dataset_name}_train.h5ad"
validation_s3_url = f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/validation/{dataset_name}_val.h5ad"
test_s3_url       = f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/test/{dataset_name}_test.h5ad"
labels_s3_url     = f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/class_labels/{dataset_name}_celltype_labels.pkl"

print(f"train s3 URL: {train_s3_url}")
print(f"validation s3 URL: {validation_s3_url}")
print(f"test s3 URL: {test_s3_url}")
print(f"labels s3 URL: {labels_s3_url}")

lr_estimator.fit(
    {
        'train': train_s3_url,
        'validation': validation_s3_url,
        'test': test_s3_url,
        'labels': labels_s3_url,
    },
)

## 3.3. Fine tune Geneformer scRNA-Seq FM

Geneformer is a foundation transformer model pretrained on a large-scale corpus of single-cell transcriptomes to enable context-aware predictions in network biology settings where only limited data is available.
The pretrained model outputs dense vector embeddings of cells, and we can fine-tune it on a labeled dataset to perform cell type classification.

### 3.3.1. Pre-requisite

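Build a Docker image from the Dockerfile in `scripts/Dockerfile` and push it to an Amazon ECR repository in your account. For example, from the repository base directory (replace `us-west-2` with the region you run this notebook in):

```
cd geneformer/scripts
docker build -t geneformerft .
# Create the ECR repository once (skip if it already exists):
aws ecr create-repository --repository-name geneformerft --region us-west-2
# Authenticate Docker to your ECR registry (also shown under "View push commands" in the ECR console):
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin <YOUR AWS ACCOUNT ID>.dkr.ecr.us-west-2.amazonaws.com
docker tag geneformerft:latest <YOUR AWS ACCOUNT ID>.dkr.ecr.us-west-2.amazonaws.com/geneformerft:latest
docker push <YOUR AWS ACCOUNT ID>.dkr.ecr.us-west-2.amazonaws.com/geneformerft:latest
```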

Finally, make sure `training_img_uri` below points to the image you pushed:

In [None]:
# Private ECR image URIs have the form
#   <account-id>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
# ("<account-id>/repo:tag" is not a valid ECR URI). Built here from the account
# and region of this session; edit it if you pushed the image elsewhere.
training_img_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/geneformerft:latest"

### 3.3.2. Define a PyTorch estimator with a custom image, fit it, and track metrics

In [None]:
# Base name of the preprocessing job; the next cell appends the "-hao2021"
# suffix so it matches the S3 folder written by the processing job in section 2.
processing_job_name = "sc-preprocess"

In [None]:
# Append the dataset suffix only when it is missing, so re-running this cell
# leaves the name unchanged.
processing_job_name = (
    processing_job_name
    if processing_job_name.endswith("-hao2021")
    else f"{processing_job_name}-hao2021"
)

In [None]:
# Display the S3 bucket, prefix and processing-job folder that the Geneformer
# training inputs in the next cell are read from.
S3_BUCKET, S3_PREFIX, processing_job_name

In [None]:
# Fine-tune Geneformer in the custom training image; ft_geneformer_mlflow.py
# receives the MLflow server and experiment through environment variables.
gf_training_job_name = 'geneformer-ft-testmlflow'
model_output_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/trained_models/"

# Training parameters forwarded to the entry-point script.
hyperparameters = dict(
    model_name='gf-12L-30M-i2048',
    max_lr=5e-05,
    freeze_layers=6,
    num_gpus=1,
    num_proc=16,
    geneformer_batch_size=20,
    lr_schedule_fn='linear',
    warmup_steps=200,
    epochs=10,
    optimizer='adamw',
)

# MLflow experiment that collects the fine-tuning runs.
experiment_name = "scRNASeq-fm"
mlflow.set_tracking_uri(tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name)

geneformer_estimator = PyTorch(
    entry_point="ft_geneformer_mlflow.py",
    source_dir="scripts/training/geneformer",
    image_uri=training_img_uri,
    instance_type="ml.g4dn.4xlarge",
    instance_count=1,
    role=sagemaker_execution_role,
    sagemaker_session=sagemaker_session,
    base_job_name=gf_training_job_name,
    output_path=model_output_path,
    hyperparameters=hyperparameters,
    environment={
        "MLFLOW_TRACKING_URI": tracking_server_arn,
        "MLFLOW_EXPERIMENT_NAME": experiment.name,
    },
    tags=[{"Key": "project", "Value": "scrnaseq-fm-finetune"}],
)

# Inputs: data splits and label mapping written by the processing job.
base_prefix = "pbmc3k"
train_s3_url = f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/train/{base_prefix}_train.h5ad"
test_s3_url = f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/test/{base_prefix}_test.h5ad"
labels_s3_url = f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/class_labels/{base_prefix}_celltype_labels.pkl"

print(
    f"train S3 URL:  {train_s3_url}",
    f"test S3 URL:   {test_s3_url}",
    f"labels S3 URL: {labels_s3_url}",
    sep="\n",
)

geneformer_estimator.fit(
    inputs={
        'train': train_s3_url,
        'test': test_s3_url,
        'labels': labels_s3_url,
    }
)

### 3.3.3. Display experiment and run metrics logged by mlflow

In [None]:
import pandas as pd

# Query each experiment separately and stack the results. A single query with
# max_results=3 across both experiments could return only fine-tuning runs,
# leaving nothing from the baseline for the comparison below.
mlflow.set_tracking_uri(tracking_server_arn)
runs = pd.concat(
    [
        mlflow.search_runs(
            experiment_names=[name],
            filter_string="attributes.status='FINISHED' and tags.mlflow.user='root'",
            max_results=3,  # most recent finished runs per experiment
        )
        for name in ("scRNASeq-baseline", "scRNASeq-fm")
    ],
    ignore_index=True,
)
display(runs)

In [None]:
# Keep the run/experiment identifiers plus every logged evaluation metric,
# ordered so runs of the same experiment sit together.
runs[
    ["run_id", "experiment_id", *(col for col in runs.columns if "metrics.eval" in col)]
].sort_values("experiment_id")

### 3.3.4. Compare baseline LR classifier and fine-tuned Geneformer

In [None]:
# Side-by-side test-set predictions of the baseline LR and the fine-tuned model.
# Attributes are whitespace-separated (the original had stray commas after src),
# and spaces in the file names are percent-encoded so each src is a valid URL.
display(HTML(
    "<table><tr>"
    "<td><img src='./images/Logistic%20regression%20classifier%20prediction%20on%20test%20set.png' width='400'></td>"
    "<td><img src='./images/Finetuned%20mdl%20prediction%20on%20test%20set.png' width='400'></td>"
    "</tr></table>"
))


### How to improve foundation model (FM) performance on cell type classification

- Hyperparameter optimization for fine tuning task 
- Use larger finetuning datasets
- Benchmark the baseline model and the FM on more complex datasets for out-of-distribution (OOD) prediction (e.g. batch effects, different donors)

### 3.3.5. Hyperparameter optimization for fine-tuning

In [None]:
# Hyperparameter optimization (HPO) of the Geneformer fine-tuning job using the
# SageMaker Hyperband strategy, launched inside a parent MLflow run.
processing_job_name = "sc-preprocess-hao2021"
experiment_name = "scRNASeq-fm-hpo"
mlflow.set_tracking_uri(tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name)

# Static hyperparameters. Entries that also appear in `hyperparameter_ranges`
# below (max_lr, freeze_layers, epochs) get their values from the tuner.
hyperparameters = {
    'model_name': 'gf-12L-30M-i2048',
    'max_lr': 5e-05,
    'freeze_layers': 6,
    'num_gpus': 1,
    'num_proc': 16,
    'geneformer_batch_size': 20,
    'lr_schedule_fn': 'linear',
    'warmup_steps': 200,
    'epochs': 10,
    'optimizer': 'adamw'
}

# Metrics scraped from log lines of the form "'<name>': <number>". The number
# pattern accepts plain decimals and scientific notation (e.g. 4.9e-05). The
# previous pattern `[0-9]+(.|e\-)[0-9]+` used an unescaped "." and truncated such
# values to "4.9"; it also relied on the invalid "\-" escape in a non-raw string.
# Group 1 (the outer group) is the captured value.
number_regex = r"([-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?)"
metric_definitions = [
    {"Name": metric_name, "Regex": rf"'{metric_name}': {number_regex}"}
    for metric_name in (
        "loss",
        "learning_rate",
        "eval_loss",
        "eval_accuracy",
        "eval_f1",
        "eval_precision",
        "eval_recall",
        "eval_runtime",
        "eval_samples_per_second",
        "epoch",
    )
]

geneformer_estimator = PyTorch(
    base_job_name=gf_training_job_name,
    entry_point="ft_geneformer_mlflow.py",
    source_dir="scripts/training/geneformer",
    output_path=model_output_path,
    instance_type="ml.g4dn.4xlarge",
    instance_count=1,
    image_uri=training_img_uri,
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session,
    tags=[{"Key": "project", "Value": "scrnaseq-fm-finetune-hpo"}],
    environment={
        "MLFLOW_TRACKING_URI": tracking_server_arn,
        "MLFLOW_EXPERIMENT_NAME": experiment.name,
    },
    metric_definitions=metric_definitions
)

hyperparameter_ranges = {
    "max_lr": ContinuousParameter(1e-05, 1e-3, 'Logarithmic'),
    "freeze_layers": CategoricalParameter([2, 8, 12]),
    "epochs": CategoricalParameter([10, 15])
}

# NOTE(review): the objective is the training loss; "eval_loss" is also emitted
# and may be a better criterion for model selection.
objective_metric_name = "loss"
objective_type = "Minimize"
# NOTE(review): max_resource=30 exceeds the largest `epochs` value tried (15);
# confirm how the training script maps Hyperband's resource to epochs.
hsc = HyperbandStrategyConfig(max_resource=30, min_resource=1)
sc = StrategyConfig(hyperband_strategy_config=hsc)

with mlflow.start_run(run_name=sagemaker.utils.name_from_base("HPO")) as run:
    # Pass the full metric list (it includes the objective "loss"). The previous
    # loss-only list replaced the estimator's definitions, so the tuning jobs
    # recorded none of the eval_* metrics.
    tuner = HyperparameterTuner(
        estimator=geneformer_estimator,
        objective_metric_name=objective_metric_name,
        hyperparameter_ranges=hyperparameter_ranges,
        metric_definitions=metric_definitions,
        max_jobs=36,
        max_parallel_jobs=6,
        objective_type=objective_type,
        strategy='Hyperband',
        strategy_config=sc,
        early_stopping_type='Off'  # set to 'Off' to use hyperband internal early stopping
    )
    tuner.fit({
        'train': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/train/pbmc3k_train.h5ad",
        'test': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/test/pbmc3k_test.h5ad",
        'labels': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/class_labels/pbmc3k_celltype_labels.pkl",
    })

In [None]:
# Tuning job to analyse: taken from `tuner` when the HPO cell ran in this kernel
# session. Otherwise (e.g. after a kernel restart) the manually pasted name is used.
try:
    tuning_job_name = tuner.latest_tuning_job.name
except (NameError, AttributeError):
    tuning_job_name = "geneformerft-241013-1800"  # copied from output of previous cell
tuner_analytics = sagemaker.HyperparameterTuningJobAnalytics(
    tuning_job_name, sagemaker_session=sagemaker_session
)

# The objective ("loss") is minimized, so ascending order lists the best trials first.
full_df = tuner_analytics.dataframe()
full_df.sort_values(by=["FinalObjectiveValue"], ascending=True).head()

# 4. Deploy a trained model as an inference endpoint

Deploy the trained LR model with an inference script that:
1. performs custom preprocessing: reads an h5ad file from S3, subsets the data to the genes used by the trained model, and normalizes and transforms the counts;
2. uses the trained logistic regression model to predict cell types.

In [None]:
# Training job whose model artifact is deployed: the baseline LR job launched in
# this kernel session when available, otherwise a previously trained job name.
try:
    model_id = lr_estimator.latest_training_job.name
except (NameError, AttributeError):
    model_id = "baseline-LR-2024-10-13-00-29-21-381"

# Same layout as the estimator's output_path:
# s3://<bucket>/<prefix>/trained_models/<job>/output/model.tar.gz
# (uses S3_PREFIX instead of a hard-coded copy of it).
model_data = f"s3://{S3_BUCKET}/{S3_PREFIX}/trained_models/{model_id}/output/model.tar.gz"
lr_model = SKLearnModel(
    model_data,
    role=sagemaker_execution_role,
    entry_point="scrna_inference.py",
    framework_version="1.2-1",
    py_version="py3",
    source_dir="scripts/inference",
    name="scRNASeq-celltype-lr-clf",
)

# "local" serves the model in SageMaker local mode, which needs Docker on this
# machine. Use a hosted instance type such as "ml.m5.xlarge" for a real endpoint.
predictor = lr_model.deploy(
    instance_type="local",
    initial_instance_count=1,
    endpoint_name='scRNASeq-celltype-lr-clf',
)

# Requests are serialized as CSV text, and responses are parsed as CSV.
predictor.serializer = CSVSerializer()
predictor.deserializer = CSVDeserializer()

In [None]:
# Score the test split written by the processing job in section 2,
# s3://<bucket>/<prefix>/<processing_job_name>/test, instead of a bucket that
# was hard-coded to one AWS account and a stale "sc-preprocess" folder.
# NOTE(review): this passes the folder URI, as the original did; confirm that
# scrna_inference.py accepts a folder rather than an .h5ad file path.
predicted_value = predictor.predict(
    f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/test"
)

In [None]:
# Show the endpoint's response, as parsed by the CSVDeserializer set on the predictor.
predicted_value