# Finetune GET Model on PBMC 10k Multiome

This notebook demonstrates finetuning a GET model from a pretrained GET checkpoint (without diffusion).

The checkpoint was generated using:
- Script: `scripts/run_pretrain.py`
- Config: `tutorials/yamls/pretrain_without_diffusion.yaml`
- Model: GETRegionPretrain

## Setup

First, let's import the necessary modules and set up our configuration.

Note:
If you run from a Mac, make sure you use the Jupyter notebook rather than the VSCode interactive Python editor, as the latter seems to have issues with multiple workers.
If you run from Linux, both should work fine.


In [1]:
from pathlib import Path
import os
import sys

# Make the project importable and run the notebook from the project root.
PROJECT_ROOT = "/home/yoyomanzoor/Documents/get_multimodel"
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
# Change directory unconditionally: previously this only happened when the
# root was missing from sys.path, so a kernel that already had it there
# (e.g. via PYTHONPATH) kept its start-up directory, and the relative
# "exported_finetune_GET_config.yaml" path used later resolved elsewhere.
os.chdir(PROJECT_ROOT)

import matplotlib.pyplot as plt
import seaborn as sns
from gcell.utils.causal_lib import get_subnet, plot_comm, preprocess_net

from get_model.config.config import load_config, export_config, load_config_from_yaml
from get_model.run_region import run_zarr as run


  import pkg_resources


In [2]:
# Tutorial data root and its annotation subdirectory
data_dir = "/scratch/bioinf593f25_class_root/bioinf593f25_class/shared_data/themanifolds/tutorial_data"
annotation_dir = f"{data_dir}/annotation_dir"

# Checkpoint to finetune from ("~" is expanded to the user's home directory)
checkpoint_path = os.path.expanduser("~/greatlakes/GET.ckpt")
print(f"Using checkpoint: {checkpoint_path}")

# Fail fast if the checkpoint file is missing, before any config is built
checkpoint_file = Path(checkpoint_path)
assert checkpoint_file.exists(), f"Checkpoint not found at {checkpoint_path}"


Using checkpoint: /home/yoyomanzoor/greatlakes/GET.ckpt


In [3]:
# Cell types used for modeling; the order is kept when they are joined below
celltype_for_modeling = [
    'memory_b', 'cd14_mono', 'gdt', 'cd8_tem_1',
    'naive_b', 'mait', 'intermediate_b', 'cd4_naive',
    'cd8_tem_2', 'cd8_naive', 'cd4_tem', 'cd4_tcm',
    'cd16_mono', 'nk', 'cdc', 'treg',
]
held_out_celltype = 'cd4_tcm'  # Leave out celltype for evaluation

# Start from the predefined finetune tutorial config and override it for this run
cfg = load_config('finetune_tutorial_pbmc')
cfg.stage = 'fit'

# --- Run identity ---
cfg.run.project_name = 'finetune_pbmc10k_multiome_GET'
cfg.run.run_name = 'finetune_from_GET_checkpoint'

# --- Dataset ---
cfg.dataset.zarr_path = f"{annotation_dir}/pbmc10k_multiome.zarr"
cfg.dataset.quantitative_atac = False  # We use binary ATAC signal for motif interpretation analysis
cfg.dataset.celltypes = ','.join(celltype_for_modeling)
cfg.dataset.leave_out_celltypes = held_out_celltype

# --- Finetuning / training ---
cfg.finetune.checkpoint = checkpoint_path
cfg.training.epochs = 20

# --- Machine ---
cfg.machine.codebase = PROJECT_ROOT
cfg.machine.output_dir = f"{data_dir}/output"
cfg.machine.num_devices = 1  # use 0 for cpu training; >=1 for gpu training
cfg.machine.batch_size = 8  # batch size for training

# Summarize the settings that matter most for this run
run_output_path = f"{cfg.machine.output_dir}/{cfg.run.project_name}/{cfg.run.run_name}"
print(f"Output path: {run_output_path}")
print(f"Training for {cfg.training.epochs} epochs")
print(f"Checkpoint: {cfg.finetune.checkpoint}")


Output path: /scratch/bioinf593f25_class_root/bioinf593f25_class/shared_data/themanifolds/tutorial_data/output/finetune_pbmc10k_multiome_GET/finetune_from_GET_checkpoint
Training for 20 epochs
Checkpoint: /home/yoyomanzoor/greatlakes/GET.ckpt


In [None]:
# Write the in-memory config out as YAML (relative path, so it is created in the current working directory)
exported_config_path = "exported_finetune_GET_config.yaml"
export_config(cfg, exported_config_path)
print(f"Configuration exported to {exported_config_path}")


In [None]:
# Rebuild cfg from the exported YAML file (relative path, resolved against the current working directory)
exported_config_path = "exported_finetune_GET_config.yaml"
cfg = load_config_from_yaml(exported_config_path)


In [None]:
# Show where the checkpoint is expected by default, built from the run settings in cfg
default_best_ckpt = f"{cfg.machine.output_dir}/{cfg.run.project_name}/{cfg.run.run_name}/checkpoints/best.ckpt"
print(f"Default checkpoint path is at: {default_best_ckpt}")
print("The `trainer.checkpoint_callback.best_model_path` variable will be updated to the checkpoint path after training")


Now we can start the finetuning


In [None]:
# Launch finetuning with the configured settings; `run` returns the trainer
trainer = run(cfg)

# Report the best-model path recorded on the trainer's checkpoint callback
best_checkpoint_path = trainer.checkpoint_callback.best_model_path
print("Checkpoint path:", best_checkpoint_path)
