# Fine Tune and Benchmark Geneformer (Single cell RNA-Seq foundation model) For Cell Type/Cell State Classification

# 0. Import Dependencies


In [1]:
%pip install --disable-pip-version-check -q -U 'boto3==1.35.16' 'sagemaker==2.231.0' 'mlflow==2.13.2' 'sagemaker-mlflow==0.1.0'

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
from time import gmtime, strftime

import sagemaker
import boto3
import mlflow
from sagemaker.processing import FrameworkProcessor, ScriptProcessor, ProcessingInput, ProcessingOutput
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch
from sagemaker.tuner import (
    CategoricalParameter,
    ContinuousParameter,
    HyperparameterTuner,
    IntegerParameter,
    HyperbandStrategyConfig,
    StrategyConfig
)
from sagemaker.sklearn.model import SKLearnModel
from sagemaker.deserializers import CSVDeserializer
from sagemaker.serializers import CSVSerializer

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


# 1. Preparations

## 1.1 Create Some Necessary Clients

In [3]:
boto_session = boto3.session.Session()
region = boto_session.region_name
sagemaker_session = sagemaker.session.Session(boto_session)
sagemaker_execution_role = sagemaker.session.get_execution_role(sagemaker_session)
sagemaker_boto_client = boto_session.client("sagemaker")
s3_boto_client = boto_session.client("s3")
account_id = boto_session.client("sts").get_caller_identity().get("Account")
print(f"Assumed SageMaker role is {sagemaker_execution_role}")

Assumed SageMaker role is arn:aws:iam::851725420776:role/service-role/AmazonSageMakerServiceCatalogProductsUseRole


## 1.2. Specify S3 Bucket and Prefix

In [4]:
S3_BUCKET = sagemaker_session.default_bucket()
S3_PREFIX = "scrnaseq-fm-finetune"
S3_PATH = sagemaker.s3.s3_path_join(S3_BUCKET, S3_PREFIX)
print(f"S3 path is {S3_PATH}")

S3 path is sagemaker-us-west-2-851725420776/scrnaseq-fm-finetune


# 2. Data Preparation with Amazon SageMaker Processing


Here we download an example 10x scRNA-Seq dataset, and add cell type annotation using marker gene expression as the ground truth for our classification task.

## 2.1. Define parameters of the SageMaker Processing Job

In [35]:
processing_job_name = "sc-preprocess-hao2021"
print("Preparing and splitting scRNASeq dataset and saving adata", processing_job_name)

sklearn_processor = SKLearnProcessor(
    framework_version="1.2-1",
    role=sagemaker_execution_role,
    instance_type="ml.m5.4xlarge",
    volume_size_in_gb=20,
    instance_count=1,
    base_job_name=processing_job_name
)

# Run processor 
sklearn_processor.run(
    inputs=[
            ProcessingInput(
                input_name="requirements",
                source="scripts/processing/processing_requirements.txt", #"requirements/",
                destination="/opt/ml/processing/input/requirements/",
            )
    ],
    outputs=[
        ProcessingOutput(
            output_name="class_labels",
            source="/opt/ml/processing/h5ad_data/class_labels",
            destination=os.path.join(
                    "s3://{}".format(S3_BUCKET),
                    S3_PREFIX,
                    processing_job_name,
                    "class_labels",
            ),
        ),
        ProcessingOutput(
            output_name="train",
            source="/opt/ml/processing/h5ad_data/train",
            destination=os.path.join(
                    "s3://{}".format(S3_BUCKET),
                    S3_PREFIX,
                    processing_job_name,
                    "train",
            ),
        ),
        ProcessingOutput(
            output_name="validation",
            source="/opt/ml/processing/h5ad_data/val",
            destination=os.path.join(
                    "s3://{}".format(S3_BUCKET),
                    S3_PREFIX,
                    processing_job_name,
                    "validation",
            ),
        ),
        ProcessingOutput(
            output_name="test",
            source="/opt/ml/processing/h5ad_data/test",
            destination=os.path.join(
                    "s3://{}".format(S3_BUCKET),
                    S3_PREFIX,
                    processing_job_name,
                    "test",
            ),
        ),
    ],
    code="scripts/processing/process_hao2021.py",
    arguments=["--train_size", "0.8", "--split_by_group"]
)

INFO:sagemaker.image_uris:Defaulting to only available Python version: py3


Preparing and splitting scRNASeq dataset and saving adata sc-preprocess-hao2021


INFO:sagemaker:Creating processing-job with name sc-preprocess-hao2021-2024-09-25-20-56-41-979


.............[34mInstalling requirements[0m
[34mCollecting anndata>=0.9 (from -r /opt/ml/processing/input/requirements/processing_requirements.txt (line 1))
  Downloading anndata-0.9.2-py3-none-any.whl.metadata (6.1 kB)[0m
[34mCollecting scanpy>=1.9 (from -r /opt/ml/processing/input/requirements/processing_requirements.txt (line 2))
  Downloading scanpy-1.9.8-py3-none-any.whl.metadata (6.0 kB)[0m
[34mCollecting tdigest>=0.5.2 (from -r /opt/ml/processing/input/requirements/processing_requirements.txt (line 3))
  Downloading tdigest-0.5.2.2-py3-none-any.whl.metadata (4.9 kB)[0m
[34mCollecting tqdm>=4.65 (from -r /opt/ml/processing/input/requirements/processing_requirements.txt (line 4))
  Downloading tqdm-4.66.5-py3-none-any.whl.metadata (57 kB)[0m
[34mCollecting igraph (from -r /opt/ml/processing/input/requirements/processing_requirements.txt (line 6))
  Downloading igraph-0.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)[0m
[34mCollecting le

# 3. ML training with SageMaker Training Jobs

## 3.1. Using MLflow to track model training experiments
Create an MLflow tracking server in SageMaker Studio. Copy the tracking server ARN below.

Update the SageMaker service role to have the following policy to enable ML flow tracking:

```
{
    "Version": "2012-10-17",    
    "Statement": [        
        {            
            "Effect": "Allow",            
            "Action": [
                "sagemaker-mlflow:*",
                "sagemaker:CreateMlflowTrackingServer",
                "sagemaker:UpdateMlflowTrackingServer",
                "sagemaker:DeleteMlflowTrackingServer",
                "sagemaker:StartMlflowTrackingServer",
                "sagemaker:StopMlflowTrackingServer",
                "sagemaker:CreatePresignedMlflowTrackingServerUrl"
            ],            
            "Resource": "*"        
        }        
    ]
}
```

In [179]:
tracking_server_arn = "arn:aws:sagemaker:us-west-2:851725420776:mlflow-tracking-server/scrnaseq-ML"

## 3.2. Train a Logistic Regression model using normalized counts as the Baseline 

For comparing model performance, it is always good to have a simple baseline. We will use a simple logistic regression model, taking as input the normalized expression counts to predict cell type as the baseline.

### 3.2.1 Define and fit a SKLearn estimator, logging the run to an MLFlow experiment

In [180]:
lr_job_name = f"baseline-LR"
model_output_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/trained_models/"
experiment_name = "scRNASeq-baseline"
# Uncomment for setting up MLflow exp
mlflow.set_tracking_uri(tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name)

dataset_name = "pbmc3k" #"hao2021_pbmc"
processing_job_name = "sc-preprocess" # "sc-preprocess-hao2021"

lr_estimator = SKLearn(
    base_job_name=lr_job_name,
    enable_sagemaker_metrics=True,
    entry_point="baseline_lr_train_mlflow.py",
    framework_version="1.2-1",
    hyperparameters={
        'penalty': 'l2',
        'class_weight': 'balanced',
        'max_iter': 1000,
        'solver': 'saga', #'lbfgs',
        'dataset_name': dataset_name
    },
    instance_count=1,
    instance_type="ml.c5.4xlarge",
    output_path=model_output_path,
    role=sagemaker_execution_role,
    sagemaker_session=sagemaker_session,
    source_dir="scripts/training/lr/",
    environment={
        "MLFLOW_TRACKING_URI": tracking_server_arn,
        "MLFLOW_EXPERIMENT_NAME": experiment.name,
        #"MLFLOW_PARENT_RUN_ID": run.info.run_id,
    },
)


lr_estimator.fit(
    {'train': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/train/{dataset_name}_train.h5ad",
     'validation': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/validation/{dataset_name}_val.h5ad",
    'test': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/test/{dataset_name}_test.h5ad",
    'labels': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/class_labels/{dataset_name}_celltype_labels.pkl"},
    #wait=False,
)

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:sagemaker:Creating training-job with name: baseline-LR-2024-09-27-20-38-47-755


2024-09-27 20:38:48 Starting - Starting the training job...
2024-09-27 20:39:12 Starting - Preparing the instances for training...
2024-09-27 20:39:51 Downloading - Downloading the training image...
2024-09-27 20:40:22 Training - Training image download completed. Training in progress....[34m2024-09-27 20:40:42,895 sagemaker-containers INFO     Imported framework sagemaker_sklearn_container.training[0m
[34m2024-09-27 20:40:42,897 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2024-09-27 20:40:42,899 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2024-09-27 20:40:42,913 sagemaker_sklearn_container.training INFO     Invoking user training script.[0m
[34m2024-09-27 20:40:43,095 sagemaker-training-toolkit INFO     Installing dependencies from requirements.txt:[0m
[34m/miniconda3/bin/python -m pip install -r requirements.txt[0m
[34mCollecting scanpy==1.9.8 (from -r requirements.txt (line

## 3.3. Fine tune Geneformer scRNA-Seq FM

As mentioned in the Introduction, Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
The pretrained model outputs dense vector embeddings of cells. We can fine tune it with a labeled dataset to perform cell type classification.

### 3.3.1. Pre-requisite: Build a docker image using the docker file in `scripts/Dockerfile` and push to your ECR repo, copy the image uri below: 

In [105]:
training_img_uri = "851725420776.dkr.ecr.us-west-2.amazonaws.com/geneformerft:latest"

### 3.3.2 Define a Pytorch estimator with custom image, fit and track metrics

In [120]:
model_output_path = f"s3://{S3_BUCKET}/{S3_PREFIX}/trained_models/"
gf_training_job_name = 'geneformer-ft-testmlflow'
# Additional training parameters
hyperparameters = {
    'model_name': 'gf-12L-30M-i2048',
    'max_lr': 5e-05,
    'freeze_layers': 6,
    'num_gpus': 1,
    'num_proc': 16,
    'geneformer_batch_size': 20,
    'lr_schedule_fn': 'linear',
    'warmup_steps': 200,
    'epochs': 10,
    'optimizer': 'adamw'
}

experiment_name = "scRNASeq-fm"
mlflow.set_tracking_uri(tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name)

geneformer_estimator = PyTorch(
    base_job_name=gf_training_job_name,
    entry_point="ft_geneformer_mlflow.py",
    source_dir="scripts/training/geneformer",
    output_path=model_output_path,
    instance_type="ml.g4dn.4xlarge", #"ml.g5.4xlarge", "local", "local-gpu"
    instance_count=1,
    image_uri=training_img_uri,
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session,
    #distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "scrnaseq-fm-finetune"}],
    environment={
        "MLFLOW_TRACKING_URI": tracking_server_arn,
        "MLFLOW_EXPERIMENT_NAME": experiment.name,
        #"MLFLOW_PARENT_RUN_ID": run.info.run_id,
    }
    #keep_alive_period_in_seconds=1800,  #Failed - Instances not retained as a result of warmpool resource limits being exceeded
)

processing_job_name = "sc-preprocess"
geneformer_estimator.fit({'train': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/train/pbmc3k_train.h5ad",
                   'test': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/test/pbmc3k_test.h5ad",
                     'labels': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/class_labels/pbmc3k_celltype_labels.pkl"})

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:sagemaker:Creating training-job with name: geneformer-ft-testmlflow-2024-09-26-22-12-42-887


2024-09-26 22:12:45 Starting - Starting the training job...
2024-09-26 22:12:58 Starting - Preparing the instances for training...
2024-09-26 22:13:39 Downloading - Downloading input data...
2024-09-26 22:13:54 Downloading - Downloading the training image.......................................
2024-09-26 22:20:37 Training - Training image download completed. Training in progress......[34m2024-09-26 22:21:22,505 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2024-09-26 22:21:22,506 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2024-09-26 22:21:22,546 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2024-09-26 22:21:22,547 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2024-09-26 22:21:22,581 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2024-09-26 22:21

### 3.3.3. Display experiment and run metrics logged by mlflow

In [121]:
mlflow.set_tracking_uri(tracking_server_arn)
runs = mlflow.search_runs(
    experiment_names=["scRNASeq-baseline", "scRNASeq-fm"],
    filter_string="attributes.status='FINISHED' and tags.mlflow.user='root'",
    max_results=3, 
)
display(runs)

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole


Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.train_loss,metrics.grad_norm,metrics.train_samples_per_second,metrics.eval_global_f1,...,params.include_inputs_for_metrics,params.class_weight,params.dataset_name,params.max_iter,params.solver,params.penalty,tags.mlflow.user,tags.mlflow.source.name,tags.mlflow.runName,tags.mlflow.source.type
0,7f0919f381104694a6daf1e576cedfd2,34,FINISHED,s3://sagemaker-studio-851725420776-tx0da5flyzo...,2024-09-26 22:23:16.124000+00:00,2024-09-26 22:53:32.768000+00:00,1.101465,5.827277,13.012,0.633333,...,False,,,,,,root,ft_geneformer_mlflow.py,caring-hound-246,LOCAL
1,3499df45a33d4f788903312de24689ee,1,FINISHED,s3://sagemaker-studio-851725420776-tx0da5flyzo...,2024-09-26 21:51:05.500000+00:00,2024-09-26 21:52:16.591000+00:00,,,,0.814815,...,,balanced,pbmc3k,1000.0,saga,l2,root,baseline_lr_train_mlflow.py,adaptable-gull-947,LOCAL
2,d614805549464fed8f3d364191dff7d9,34,FINISHED,s3://sagemaker-studio-851725420776-tx0da5flyzo...,2024-09-26 05:41:09.954000+00:00,2024-09-26 06:11:18.372000+00:00,1.469274,6.025624,13.767,0.559259,...,False,,,,,,root,ft_geneformer_mlflow.py,intelligent-boar-55,LOCAL


In [125]:
runs[['run_id', 'experiment_id']+[c for c in runs.columns if 'metrics.eval' in c]].sort_values('experiment_id')

Unnamed: 0,run_id,experiment_id,metrics.eval_global_f1,metrics.eval_samples_per_second,metrics.eval_accuracy,metrics.eval_runtime,metrics.eval_loss,metrics.eval_class_weighted_f1,metrics.eval_class_averaged_accuracy,metrics.eval_macro_f1,metrics.eval_steps_per_second
1,3499df45a33d4f788903312de24689ee,1,0.814815,205640.472126,0.814815,0.001313,,0.813833,0.75801,0.772846,
0,7f0919f381104694a6daf1e576cedfd2,34,0.633333,25.128,0.633333,10.745,0.843233,0.61381,0.422594,0.41954,1.303
2,d614805549464fed8f3d364191dff7d9,34,0.559259,25.331,0.559259,10.6587,1.154671,0.498224,0.323888,0.29748,1.313


### 3.3.4. Compare baseline LR classifier and fine-tuned Geneformer

In [2]:
display(HTML("<table><tr><td><img src='./images/Logistic regression classifier prediction on test set.png', width=400></td><td><img src='./images/Finetuned mdl prediction on test set.png', width=400></td></tr></table>"))


### How to get better performance for FM on celltype classification

- Hyperparameter optimization for fine tuning task 
- Use larger finetuning datasets
- Benchmark baseline model and FM on complex datasets for OOD predictions (e.g. batch effects, different donors etc.)

### 3.3.4. Hyperparameter optimization for the fine tuning

In [None]:
processing_job_name = "sc-preprocess"
experiment_name = "scRNASeq-fm-hpo"
mlflow.set_tracking_uri(tracking_server_arn)
experiment = mlflow.set_experiment(experiment_name)

hyperparameters = {
    'model_name': 'gf-12L-30M-i2048',
    'max_lr': 5e-05,
    'freeze_layers': 6,
    'num_gpus': 1,
    'num_proc': 16,
    'geneformer_batch_size': 20,
    'lr_schedule_fn': 'linear',
    'warmup_steps': 200,
    'epochs': 10,
    'optimizer': 'adamw'
}
metric_definitions = [
    {"Name": "loss", "Regex": "'loss': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "learning_rate", "Regex": "'learning_rate': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_loss", "Regex": "'eval_loss': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_accuracy", "Regex": "'eval_accuracy': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_f1", "Regex": "'eval_f1': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_precision", "Regex": "'eval_precision': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_recall", "Regex": "'eval_recall': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_runtime", "Regex": "'eval_runtime': ([0-9]+(.|e\-)[0-9]+),?"},
    {
        "Name": "eval_samples_per_second",
        "Regex": "'eval_samples_per_second': ([0-9]+(.|e\-)[0-9]+),?",
    },
    {"Name": "epoch", "Regex": "'epoch': ([0-9]+(.|e\-)[0-9]+),?"},
]
geneformer_estimator = PyTorch(
    base_job_name=gf_training_job_name,
    entry_point="ft_geneformer_mlflow.py",
    source_dir="scripts/training/geneformer",
    output_path=model_output_path,
    instance_type="ml.g4dn.4xlarge", #"ml.c5.4xlarge",  #"ml.g4dn.4xlarge",
    instance_count=1,
    image_uri=training_img_uri,
    role=sagemaker_execution_role,
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session,
    #distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "scrnaseq-fm-finetune-hpo"}],
    environment={
        "MLFLOW_TRACKING_URI": tracking_server_arn,
        "MLFLOW_EXPERIMENT_NAME": experiment.name,
        #"MLFLOW_PARENT_RUN_ID": run.info.run_id,
    },
    metric_definitions=metric_definitions
    #keep_alive_period_in_seconds=1800,  #Failed - Instances not retained as a result of warmpool resource limits being exceeded
)

hyperparameter_ranges = {
    "max_lr": ContinuousParameter(1e-05, 1e-3, 'Logarithmic'),
    "freeze_layers": CategoricalParameter([2, 8, 12]),
    "epochs": CategoricalParameter([10, 15])
}

objective_metric_name = "loss"
objective_type = "Minimize"
metric_definitions = [{"Name": "loss", "Regex": "'loss': ([0-9]+(.|e\-)[0-9]+),?"}]
hsc = HyperbandStrategyConfig(max_resource=30, min_resource=1)
sc = StrategyConfig(hyperband_strategy_config=hsc)

with mlflow.start_run(run_name=sagemaker.utils.name_from_base("HPO")) as run:
    tuner = HyperparameterTuner(
        geneformer_estimator,
        objective_metric_name,
        hyperparameter_ranges,
        metric_definitions,
        max_jobs=36,
        max_parallel_jobs=6,
        objective_type=objective_type,
        strategy='Hyperband',
        strategy_config=sc,
        early_stopping_type='Off' # set to 'Off' to use hyperband internal early stopping
    )
    tuner.fit({'train': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/train/pbmc3k_train.h5ad",
               'test': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/test/pbmc3k_test.h5ad",
                 'labels': f"s3://{S3_BUCKET}/{S3_PREFIX}/{processing_job_name}/class_labels/pbmc3k_celltype_labels.pkl"})

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:sagemaker:Creating hyperparameter tuning job with name: geneformerft-240927-0539


.........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................

In [151]:
tuning_job_name = "geneformerft-240927-0539"
tuner_analytics = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name)

full_df = tuner_analytics.dataframe()
full_df.sort_values(by=["FinalObjectiveValue"], ascending=True).head()

Unnamed: 0,epochs,freeze_layers,max_lr,TrainingJobName,TrainingJobStatus,FinalObjectiveValue,TrainingStartTime,TrainingEndTime,TrainingElapsedTimeSeconds
24,"""15""","""12""",0.000546,geneformerft-240927-0539-012-1b04da5d,Stopped,0.4162,2024-09-27 05:58:43+00:00,2024-09-27 06:16:43+00:00,1080.0
33,"""10""","""12""",0.000396,geneformerft-240927-0539-003-4c2dca82,Stopped,0.4598,2024-09-27 05:39:58+00:00,2024-09-27 05:57:46+00:00,1068.0
16,"""10""","""2""",0.000707,geneformerft-240927-0539-020-d100c82e,Stopped,1.7518,2024-09-27 06:12:35+00:00,2024-09-27 06:23:14+00:00,639.0
32,"""10""","""2""",0.000549,geneformerft-240927-0539-004-79641a86,Stopped,1.7543,2024-09-27 05:40:04+00:00,2024-09-27 05:50:12+00:00,608.0
21,"""10""","""8""",0.000431,geneformerft-240927-0539-015-bea7d1e6,Stopped,1.7643,2024-09-27 06:02:02+00:00,2024-09-27 06:12:00+00:00,598.0


# 4. Deploy a trained model as an inference endpoint

Deploy a the trained LR model using an inference script providing:
1. custom preprocessing to read h5ad file from s3, subset the data on genes in the trained model, normalize and transform the counts
2. use the trained logistic regression model to predict cell type

In [58]:
model_data = "s3://sagemaker-us-west-2-851725420776/scrnaseq-fm-finetune/trained_models/baseline-LR-2024-09-27-20-38-47-755/output/model.tar.gz"
lr_model = SKLearnModel(model_data, 
            sagemaker_execution_role, 
            entry_point="scrna_inference.py", 
            framework_version="1.2-1",
            py_version="py3",
            source_dir="scripts/inference",
            name="scRNASeq-celltype-lr-clf"
            )
predictor = lr_model.deploy(instance_type="local", #"ml.m5.xlarge", 
                            initial_instance_count=1,
                           endpoint_name='scRNASeq-celltype-lr-clf')

predictor.serializer = CSVSerializer()
predictor.deserializer = CSVDeserializer()

Attaching to 9ft5dcngdd-algo-1-4534h
9ft5dcngdd-algo-1-4534h  | 2024-09-28 20:50:35,430 INFO - sagemaker-containers - No GPUs detected (normal if no gpus installed)
9ft5dcngdd-algo-1-4534h  | 2024-09-28 20:50:35,433 INFO - sagemaker-containers - No GPUs detected (normal if no gpus installed)
9ft5dcngdd-algo-1-4534h  | 2024-09-28 20:50:35,434 INFO - sagemaker-containers - nginx config: 
9ft5dcngdd-algo-1-4534h  | worker_processes auto;
9ft5dcngdd-algo-1-4534h  | daemon off;
9ft5dcngdd-algo-1-4534h  | pid /tmp/nginx.pid;
9ft5dcngdd-algo-1-4534h  | error_log  /dev/stderr;
9ft5dcngdd-algo-1-4534h  | 
9ft5dcngdd-algo-1-4534h  | worker_rlimit_nofile 4096;
9ft5dcngdd-algo-1-4534h  | 
9ft5dcngdd-algo-1-4534h  | events {
9ft5dcngdd-algo-1-4534h  |   worker_connections 2048;
9ft5dcngdd-algo-1-4534h  | }
9ft5dcngdd-algo-1-4534h  | 
9ft5dcngdd-algo-1-4534h  | http {
9ft5dcngdd-algo-1-4534h  |   include /etc/nginx/mime.types;
9ft5dcngdd-algo-1-4534h  |   default_type application/octet-stream;
9ft5d

In [59]:
predicted_value = predictor.predict("s3://sagemaker-us-west-2-851725420776/scrnaseq-fm-finetune/sc-preprocess/test")

9ft5dcngdd-algo-1-4534h  | 2024-09-28 20:51:21,739 INFO - sagemaker-containers - No GPUs detected (normal if no gpus installed)
9ft5dcngdd-algo-1-4534h  | sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
9ft5dcngdd-algo-1-4534h  | sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
9ft5dcngdd-algo-1-4534h  | <sagemaker.session.Session object at 0x7f581b3d93a0>
9ft5dcngdd-algo-1-4534h  | model_dir: ['model.tar.gz', 'model.joblib', 'feature_names.joblib']
9ft5dcngdd-algo-1-4534h  | s3://sagemaker-us-west-2-851725420776/scrnaseq-fm-finetune/sc-preprocess/test
9ft5dcngdd-algo-1-4534h  | ['pbmc3k_test.h5ad']
9ft5dcngdd-algo-1-4534h  | /tmp/09282024205122/pbmc3k_test.h5ad
9ft5dcngdd-algo-1-4534h  | /tmp/09282024205122 removed successfully!
9ft5dcngdd-algo-1-4534h  | 172.18.0.1 - - [28/Sep/2024:20:51:23 +0000] "POST /invocations HTTP/1.1" 200 540 "-" "python-urllib3/2.2.2"


In [60]:
predicted_value

[['4'],
 ['2'],
 ['2'],
 ['1'],
 ['4'],
 ['2'],
 ['3'],
 ['3'],
 ['0'],
 ['2'],
 ['2'],
 ['2'],
 ['2'],
 ['1'],
 ['2'],
 ['4'],
 ['0'],
 ['6'],
 ['2'],
 ['1'],
 ['3'],
 ['3'],
 ['3'],
 ['1'],
 ['2'],
 ['1'],
 ['3'],
 ['1'],
 ['0'],
 ['2'],
 ['3'],
 ['0'],
 ['1'],
 ['2'],
 ['1'],
 ['2'],
 ['0'],
 ['5'],
 ['2'],
 ['2'],
 ['2'],
 ['0'],
 ['6'],
 ['2'],
 ['2'],
 ['0'],
 ['1'],
 ['2'],
 ['2'],
 ['3'],
 ['2'],
 ['2'],
 ['2'],
 ['3'],
 ['2'],
 ['5'],
 ['2'],
 ['2'],
 ['3'],
 ['2'],
 ['2'],
 ['2'],
 ['1'],
 ['5'],
 ['1'],
 ['2'],
 ['0'],
 ['4'],
 ['2'],
 ['2'],
 ['2'],
 ['2'],
 ['0'],
 ['2'],
 ['2'],
 ['4'],
 ['1'],
 ['2'],
 ['3'],
 ['3'],
 ['2'],
 ['0'],
 ['3'],
 ['4'],
 ['2'],
 ['2'],
 ['2'],
 ['0'],
 ['5'],
 ['2'],
 ['2'],
 ['1'],
 ['2'],
 ['3'],
 ['0'],
 ['3'],
 ['2'],
 ['2'],
 ['6'],
 ['2'],
 ['2'],
 ['2'],
 ['3'],
 ['4'],
 ['1'],
 ['2'],
 ['6'],
 ['2'],
 ['2'],
 ['5'],
 ['7'],
 ['2'],
 ['2'],
 ['0'],
 ['3'],
 ['2'],
 ['1'],
 ['2'],
 ['6'],
 ['5'],
 ['2'],
 ['2'],
 ['6'],
 ['3'],
 ['5'],
