# Finetune GET Model on PBMC 10k Multiome

This notebook demonstrates finetuning a GET model from a pretrained GET checkpoint (without diffusion).

The checkpoint was generated using:
- Script: `scripts/run_pretrain.py`
- Config: `tutorials/yamls/pretrain_without_diffusion.yaml`
- Model: GETRegionPretrain

## Setup

First, let's import the necessary modules and set up our configuration.

Note:
If you run on a Mac, use Jupyter Notebook rather than the VS Code interactive Python window, as the latter seems to have issues with multiple workers.
On Linux, both should work fine.


In [6]:
from pathlib import Path
import os
import sys

# Root of the get_model checkout. The default is machine-specific; set the
# GET_PROJECT_ROOT environment variable to point elsewhere instead of editing it.
PROJECT_ROOT = os.environ.get("GET_PROJECT_ROOT", "/home/yoyomanzoor/Documents/get_multimodel")

# Make the project's packages importable.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
# Always switch to the project root: relative paths used later (e.g. `yaml_path`)
# resolve against the CWD. This must not depend on whether PROJECT_ROOT was already
# on sys.path (e.g. via PYTHONPATH), otherwise those paths silently point elsewhere.
os.chdir(PROJECT_ROOT)

import matplotlib.pyplot as plt
import seaborn as sns
from gcell.utils.causal_lib import get_subnet, plot_comm, preprocess_net

from get_model.config.config import load_config, export_config, load_config_from_yaml
from get_model.run_region import run_zarr as run

# Relative to PROJECT_ROOT (the CWD set above).
yaml_path = "get_model/config/finetune_GET_config.yaml"


In [10]:
# Set up data paths.
# The defaults are machine-specific; override them with the GET_DATA_DIR and
# GET_CHECKPOINT environment variables instead of editing this cell.
data_dir = os.environ.get("GET_DATA_DIR", "/home/yoyomanzoor/Crucial/get_data")
annotation_dir = f"{data_dir}/annotation_dir"

# Pretrained GET (no diffusion) checkpoint to finetune from; a leading `~` is
# expanded to the user's home directory.
checkpoint_path = os.path.expanduser(
    os.environ.get("GET_CHECKPOINT", "~/greatlakes/transformer-best-v1.ckpt")
)
print(f"Using checkpoint: {checkpoint_path}")
# Fail fast with an explicit error; an `assert` would be skipped under `python -O`.
if not Path(checkpoint_path).exists():
    raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")


Using checkpoint: /home/yoyomanzoor/greatlakes/transformer-best-v1.ckpt


In [11]:
# Cell types to model; their order is preserved when joined into cfg.dataset.celltypes.
celltype_for_modeling = """
    memory_b   cd14_mono  gdt      cd8_tem_1  naive_b    mait  intermediate_b  cd4_naive
    cd8_tem_2  cd8_naive  cd4_tem  cd4_tcm    cd16_mono  nk    cdc             treg
""".split()

# Start from the predefined finetune tutorial config and override per-run fields.
cfg = load_config('finetune_tutorial_pbmc')
cfg.stage = 'fit'

# Run identity (also forms the output sub-directory printed below).
cfg.run.project_name = 'finetune_pbmc10k_multiome_GET'
cfg.run.run_name = 'finetune_from_GET_checkpoint'

# Dataset selection.
cfg.dataset.zarr_path = f"{annotation_dir}/pbmc10k_multiome.zarr"
cfg.dataset.celltypes = ','.join(celltype_for_modeling)
cfg.dataset.leave_out_celltypes = 'cd4_tcm'  # held out for evaluation
cfg.dataset.quantitative_atac = False  # binary ATAC signal, used for motif interpretation analysis

# Initialisation and training length.
cfg.finetune.checkpoint = checkpoint_path
cfg.training.epochs = 20

# Machine / hardware settings.
cfg.machine.codebase = PROJECT_ROOT
cfg.machine.output_dir = f"{data_dir}/output"
cfg.machine.num_devices = 1  # 0 -> CPU training; >=1 -> GPU training
cfg.machine.batch_size = 8  # training batch size

print(
    f"Output path: {cfg.machine.output_dir}/{cfg.run.project_name}/{cfg.run.run_name}",
    f"Training for {cfg.training.epochs} epochs",
    f"Checkpoint: {cfg.finetune.checkpoint}",
    sep="\n",
)


Output path: /home/yoyomanzoor/Crucial/get_data/output/finetune_pbmc10k_multiome_GET/finetune_from_GET_checkpoint
Training for 20 epochs
Checkpoint: /home/yoyomanzoor/greatlakes/transformer-best-v1.ckpt


In [12]:
# Persist the configured run to YAML so it can be reloaded from disk.
export_config(cfg, yaml_path)
print("Configuration exported to", yaml_path)


Configuration exported to get_model/config/finetune_GET_config.yaml


In [13]:
# Load the config from the yaml file.
# Re-reading the exported YAML replaces the in-memory `cfg` with exactly what is on disk.
cfg = load_config_from_yaml(yaml_path)


In [None]:
# Configure finetune settings for loading the pretrain checkpoint.
# This cell MUST run before `run(cfg)`: if it is skipped, the checkpoint is loaded
# with the YAML defaults and fails with a state_dict key mismatch.
#
# Checkpoint keys (GETRegionPretrain)        -> keys GETRegionFinetune expects
#   model.cls_token                          -> cls_token
#   model.region_embed.embed.*               -> region_embed.embed.*
#   model.encoder.encoder.blocks.*           -> encoder.blocks.*
#   model.encoder.norm.*                     -> encoder.norm.*
#   model.mask_token, model.head_mask.*      -> dropped (not in the finetune model)
#   (no counterpart)                         -> head_exp.head.* (trained from scratch)
cfg.finetune.model_key = "model"  # Key for model in checkpoint (as set in run_pretrain.py)

# strict=False is required because head_exp.* has no counterpart in the checkpoint.
# It also stops load_state_dict from raising on any OTHER mismatch, so if the rename
# rules below are wrong the encoder keeps its initial weights without an error.
cfg.finetune.strict = False

# Pretraining-only parameters that do not exist in GETRegionFinetune.
cfg.finetune.patterns_to_drop = ["head_mask.", "mask_token"]

# Map checkpoint keys onto GETRegionFinetune keys. The most specific rule comes first
# so the result is the same whether the loader applies every matching rule in order
# or stops at the first rule that matches.
rename_config = {
    "model.encoder.encoder.": "encoder.",  # unwrap the nested pretrain encoder blocks
    "model.": "",                          # strip the wrapper prefix from all other keys
}
# Keep any rules already present in the loaded config (applied after the ones above).
for old_name, new_name in (cfg.finetune.rename_config or {}).items():
    rename_config.setdefault(old_name, new_name)
# Additional mappings kept from before; none of them matches this checkpoint's keys
# once the rules above have been applied.
rename_config.setdefault("encoder.region_embed", "region_embed")
rename_config.setdefault("region_embed.proj.", "region_embed.embed.")
rename_config.setdefault("encoder.cls_token", "cls_token")
cfg.finetune.rename_config = rename_config

print("Finetune config:")
print(f"  Model key: {cfg.finetune.model_key}")
print(f"  Strict loading: {cfg.finetune.strict}")
print(f"  Patterns to drop: {cfg.finetune.patterns_to_drop}")
print(f"  Rename config: {cfg.finetune.rename_config}")


In [14]:
# Where the best checkpoint of this run is expected to be written
# (NOTE(review): assumes the checkpoint callback uses checkpoints/best.ckpt — confirm).
print(
    "Default checkpoint path is at:",
    f"{cfg.machine.output_dir}/{cfg.run.project_name}/{cfg.run.run_name}/checkpoints/best.ckpt",
)
print("The `trainer.checkpoint_callback.best_model_path` variable will be updated to the checkpoint path after training")


Default checkpoint path is at: /home/yoyomanzoor/Crucial/get_data/output/finetune_pbmc10k_multiome_GET/finetune_from_GET_checkpoint/checkpoints/best.ckpt
The `trainer.checkpoint_callback.best_model_path` variable will be updated to the checkpoint path after training


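Now we can start the finetuning. Make sure the finetune-settings cell above (checkpoint key, strict loading, dropped keys and rename rules) has been executed first. If it is skipped, the checkpoint is loaded with the default settings and loading fails with a `state_dict` key-mismatch `RuntimeError`.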


In [15]:
# Launch finetuning with the finalised `cfg`. `run` returns the trainer, whose
# checkpoint callback records the path of the best checkpoint it saved.
trainer = run(cfg)
print(f"Checkpoint path: {trainer.checkpoint_callback.best_model_path}")




Load ckpt from /home/yoyomanzoor/greatlakes/transformer-best-v1.ckpt


RuntimeError: Error(s) in loading state_dict for GETRegionFinetune:
	Missing key(s) in state_dict: "cls_token", "region_embed.embed.weight", "region_embed.embed.bias", "encoder.blocks.0.norm1.weight", "encoder.blocks.0.norm1.bias", "encoder.blocks.0.attn.q_bias", "encoder.blocks.0.attn.v_bias", "encoder.blocks.0.attn.qkv.weight", "encoder.blocks.0.attn.proj.weight", "encoder.blocks.0.attn.proj.bias", "encoder.blocks.0.norm2.weight", "encoder.blocks.0.norm2.bias", "encoder.blocks.0.mlp.fc1.weight", "encoder.blocks.0.mlp.fc1.bias", "encoder.blocks.0.mlp.fc2.weight", "encoder.blocks.0.mlp.fc2.bias", "encoder.blocks.1.norm1.weight", "encoder.blocks.1.norm1.bias", "encoder.blocks.1.attn.q_bias", "encoder.blocks.1.attn.v_bias", "encoder.blocks.1.attn.qkv.weight", "encoder.blocks.1.attn.proj.weight", "encoder.blocks.1.attn.proj.bias", "encoder.blocks.1.norm2.weight", "encoder.blocks.1.norm2.bias", "encoder.blocks.1.mlp.fc1.weight", "encoder.blocks.1.mlp.fc1.bias", "encoder.blocks.1.mlp.fc2.weight", "encoder.blocks.1.mlp.fc2.bias", "encoder.blocks.2.norm1.weight", "encoder.blocks.2.norm1.bias", "encoder.blocks.2.attn.q_bias", "encoder.blocks.2.attn.v_bias", "encoder.blocks.2.attn.qkv.weight", "encoder.blocks.2.attn.proj.weight", "encoder.blocks.2.attn.proj.bias", "encoder.blocks.2.norm2.weight", "encoder.blocks.2.norm2.bias", "encoder.blocks.2.mlp.fc1.weight", "encoder.blocks.2.mlp.fc1.bias", "encoder.blocks.2.mlp.fc2.weight", "encoder.blocks.2.mlp.fc2.bias", "encoder.blocks.3.norm1.weight", "encoder.blocks.3.norm1.bias", "encoder.blocks.3.attn.q_bias", "encoder.blocks.3.attn.v_bias", "encoder.blocks.3.attn.qkv.weight", "encoder.blocks.3.attn.proj.weight", "encoder.blocks.3.attn.proj.bias", "encoder.blocks.3.norm2.weight", "encoder.blocks.3.norm2.bias", "encoder.blocks.3.mlp.fc1.weight", "encoder.blocks.3.mlp.fc1.bias", "encoder.blocks.3.mlp.fc2.weight", "encoder.blocks.3.mlp.fc2.bias", "encoder.blocks.4.norm1.weight", "encoder.blocks.4.norm1.bias", "encoder.blocks.4.attn.q_bias", "encoder.blocks.4.attn.v_bias", 
"encoder.blocks.4.attn.qkv.weight", "encoder.blocks.4.attn.proj.weight", "encoder.blocks.4.attn.proj.bias", "encoder.blocks.4.norm2.weight", "encoder.blocks.4.norm2.bias", "encoder.blocks.4.mlp.fc1.weight", "encoder.blocks.4.mlp.fc1.bias", "encoder.blocks.4.mlp.fc2.weight", "encoder.blocks.4.mlp.fc2.bias", "encoder.blocks.5.norm1.weight", "encoder.blocks.5.norm1.bias", "encoder.blocks.5.attn.q_bias", "encoder.blocks.5.attn.v_bias", "encoder.blocks.5.attn.qkv.weight", "encoder.blocks.5.attn.proj.weight", "encoder.blocks.5.attn.proj.bias", "encoder.blocks.5.norm2.weight", "encoder.blocks.5.norm2.bias", "encoder.blocks.5.mlp.fc1.weight", "encoder.blocks.5.mlp.fc1.bias", "encoder.blocks.5.mlp.fc2.weight", "encoder.blocks.5.mlp.fc2.bias", "encoder.blocks.6.norm1.weight", "encoder.blocks.6.norm1.bias", "encoder.blocks.6.attn.q_bias", "encoder.blocks.6.attn.v_bias", "encoder.blocks.6.attn.qkv.weight", "encoder.blocks.6.attn.proj.weight", "encoder.blocks.6.attn.proj.bias", "encoder.blocks.6.norm2.weight", "encoder.blocks.6.norm2.bias", "encoder.blocks.6.mlp.fc1.weight", "encoder.blocks.6.mlp.fc1.bias", "encoder.blocks.6.mlp.fc2.weight", "encoder.blocks.6.mlp.fc2.bias", "encoder.blocks.7.norm1.weight", "encoder.blocks.7.norm1.bias", "encoder.blocks.7.attn.q_bias", "encoder.blocks.7.attn.v_bias", "encoder.blocks.7.attn.qkv.weight", "encoder.blocks.7.attn.proj.weight", "encoder.blocks.7.attn.proj.bias", "encoder.blocks.7.norm2.weight", "encoder.blocks.7.norm2.bias", "encoder.blocks.7.mlp.fc1.weight", "encoder.blocks.7.mlp.fc1.bias", "encoder.blocks.7.mlp.fc2.weight", "encoder.blocks.7.mlp.fc2.bias", "encoder.blocks.8.norm1.weight", "encoder.blocks.8.norm1.bias", "encoder.blocks.8.attn.q_bias", "encoder.blocks.8.attn.v_bias", "encoder.blocks.8.attn.qkv.weight", "encoder.blocks.8.attn.proj.weight", "encoder.blocks.8.attn.proj.bias", "encoder.blocks.8.norm2.weight", "encoder.blocks.8.norm2.bias", "encoder.blocks.8.mlp.fc1.weight", "encoder.blocks.8.mlp.fc1.bias", 
"encoder.blocks.8.mlp.fc2.weight", "encoder.blocks.8.mlp.fc2.bias", "encoder.blocks.9.norm1.weight", "encoder.blocks.9.norm1.bias", "encoder.blocks.9.attn.q_bias", "encoder.blocks.9.attn.v_bias", "encoder.blocks.9.attn.qkv.weight", "encoder.blocks.9.attn.proj.weight", "encoder.blocks.9.attn.proj.bias", "encoder.blocks.9.norm2.weight", "encoder.blocks.9.norm2.bias", "encoder.blocks.9.mlp.fc1.weight", "encoder.blocks.9.mlp.fc1.bias", "encoder.blocks.9.mlp.fc2.weight", "encoder.blocks.9.mlp.fc2.bias", "encoder.blocks.10.norm1.weight", "encoder.blocks.10.norm1.bias", "encoder.blocks.10.attn.q_bias", "encoder.blocks.10.attn.v_bias", "encoder.blocks.10.attn.qkv.weight", "encoder.blocks.10.attn.proj.weight", "encoder.blocks.10.attn.proj.bias", "encoder.blocks.10.norm2.weight", "encoder.blocks.10.norm2.bias", "encoder.blocks.10.mlp.fc1.weight", "encoder.blocks.10.mlp.fc1.bias", "encoder.blocks.10.mlp.fc2.weight", "encoder.blocks.10.mlp.fc2.bias", "encoder.blocks.11.norm1.weight", "encoder.blocks.11.norm1.bias", "encoder.blocks.11.attn.q_bias", "encoder.blocks.11.attn.v_bias", "encoder.blocks.11.attn.qkv.weight", "encoder.blocks.11.attn.proj.weight", "encoder.blocks.11.attn.proj.bias", "encoder.blocks.11.norm2.weight", "encoder.blocks.11.norm2.bias", "encoder.blocks.11.mlp.fc1.weight", "encoder.blocks.11.mlp.fc1.bias", "encoder.blocks.11.mlp.fc2.weight", "encoder.blocks.11.mlp.fc2.bias", "encoder.norm.weight", "encoder.norm.bias", "head_exp.head.weight", "head_exp.head.bias". 
	Unexpected key(s) in state_dict: "model.mask_token", "model.cls_token", "model.region_embed.embed.weight", "model.region_embed.embed.bias", "model.encoder.encoder.blocks.0.norm1.weight", "model.encoder.encoder.blocks.0.norm1.bias", "model.encoder.encoder.blocks.0.attn.q_bias", "model.encoder.encoder.blocks.0.attn.v_bias", "model.encoder.encoder.blocks.0.attn.qkv.weight", "model.encoder.encoder.blocks.0.attn.proj.weight", "model.encoder.encoder.blocks.0.attn.proj.bias", "model.encoder.encoder.blocks.0.norm2.weight", "model.encoder.encoder.blocks.0.norm2.bias", "model.encoder.encoder.blocks.0.mlp.fc1.weight", "model.encoder.encoder.blocks.0.mlp.fc1.bias", "model.encoder.encoder.blocks.0.mlp.fc2.weight", "model.encoder.encoder.blocks.0.mlp.fc2.bias", "model.encoder.encoder.blocks.1.norm1.weight", "model.encoder.encoder.blocks.1.norm1.bias", "model.encoder.encoder.blocks.1.attn.q_bias", "model.encoder.encoder.blocks.1.attn.v_bias", "model.encoder.encoder.blocks.1.attn.qkv.weight", "model.encoder.encoder.blocks.1.attn.proj.weight", "model.encoder.encoder.blocks.1.attn.proj.bias", "model.encoder.encoder.blocks.1.norm2.weight", "model.encoder.encoder.blocks.1.norm2.bias", "model.encoder.encoder.blocks.1.mlp.fc1.weight", "model.encoder.encoder.blocks.1.mlp.fc1.bias", "model.encoder.encoder.blocks.1.mlp.fc2.weight", "model.encoder.encoder.blocks.1.mlp.fc2.bias", "model.encoder.encoder.blocks.2.norm1.weight", "model.encoder.encoder.blocks.2.norm1.bias", "model.encoder.encoder.blocks.2.attn.q_bias", "model.encoder.encoder.blocks.2.attn.v_bias", "model.encoder.encoder.blocks.2.attn.qkv.weight", "model.encoder.encoder.blocks.2.attn.proj.weight", "model.encoder.encoder.blocks.2.attn.proj.bias", "model.encoder.encoder.blocks.2.norm2.weight", "model.encoder.encoder.blocks.2.norm2.bias", "model.encoder.encoder.blocks.2.mlp.fc1.weight", "model.encoder.encoder.blocks.2.mlp.fc1.bias", "model.encoder.encoder.blocks.2.mlp.fc2.weight", "model.encoder.encoder.blocks.2.mlp.fc2.bias", 
"model.encoder.encoder.blocks.3.norm1.weight", "model.encoder.encoder.blocks.3.norm1.bias", "model.encoder.encoder.blocks.3.attn.q_bias", "model.encoder.encoder.blocks.3.attn.v_bias", "model.encoder.encoder.blocks.3.attn.qkv.weight", "model.encoder.encoder.blocks.3.attn.proj.weight", "model.encoder.encoder.blocks.3.attn.proj.bias", "model.encoder.encoder.blocks.3.norm2.weight", "model.encoder.encoder.blocks.3.norm2.bias", "model.encoder.encoder.blocks.3.mlp.fc1.weight", "model.encoder.encoder.blocks.3.mlp.fc1.bias", "model.encoder.encoder.blocks.3.mlp.fc2.weight", "model.encoder.encoder.blocks.3.mlp.fc2.bias", "model.encoder.encoder.blocks.4.norm1.weight", "model.encoder.encoder.blocks.4.norm1.bias", "model.encoder.encoder.blocks.4.attn.q_bias", "model.encoder.encoder.blocks.4.attn.v_bias", "model.encoder.encoder.blocks.4.attn.qkv.weight", "model.encoder.encoder.blocks.4.attn.proj.weight", "model.encoder.encoder.blocks.4.attn.proj.bias", "model.encoder.encoder.blocks.4.norm2.weight", "model.encoder.encoder.blocks.4.norm2.bias", "model.encoder.encoder.blocks.4.mlp.fc1.weight", "model.encoder.encoder.blocks.4.mlp.fc1.bias", "model.encoder.encoder.blocks.4.mlp.fc2.weight", "model.encoder.encoder.blocks.4.mlp.fc2.bias", "model.encoder.encoder.blocks.5.norm1.weight", "model.encoder.encoder.blocks.5.norm1.bias", "model.encoder.encoder.blocks.5.attn.q_bias", "model.encoder.encoder.blocks.5.attn.v_bias", "model.encoder.encoder.blocks.5.attn.qkv.weight", "model.encoder.encoder.blocks.5.attn.proj.weight", "model.encoder.encoder.blocks.5.attn.proj.bias", "model.encoder.encoder.blocks.5.norm2.weight", "model.encoder.encoder.blocks.5.norm2.bias", "model.encoder.encoder.blocks.5.mlp.fc1.weight", "model.encoder.encoder.blocks.5.mlp.fc1.bias", "model.encoder.encoder.blocks.5.mlp.fc2.weight", "model.encoder.encoder.blocks.5.mlp.fc2.bias", "model.encoder.encoder.blocks.6.norm1.weight", "model.encoder.encoder.blocks.6.norm1.bias", "model.encoder.encoder.blocks.6.attn.q_bias", 
"model.encoder.encoder.blocks.6.attn.v_bias", "model.encoder.encoder.blocks.6.attn.qkv.weight", "model.encoder.encoder.blocks.6.attn.proj.weight", "model.encoder.encoder.blocks.6.attn.proj.bias", "model.encoder.encoder.blocks.6.norm2.weight", "model.encoder.encoder.blocks.6.norm2.bias", "model.encoder.encoder.blocks.6.mlp.fc1.weight", "model.encoder.encoder.blocks.6.mlp.fc1.bias", "model.encoder.encoder.blocks.6.mlp.fc2.weight", "model.encoder.encoder.blocks.6.mlp.fc2.bias", "model.encoder.encoder.blocks.7.norm1.weight", "model.encoder.encoder.blocks.7.norm1.bias", "model.encoder.encoder.blocks.7.attn.q_bias", "model.encoder.encoder.blocks.7.attn.v_bias", "model.encoder.encoder.blocks.7.attn.qkv.weight", "model.encoder.encoder.blocks.7.attn.proj.weight", "model.encoder.encoder.blocks.7.attn.proj.bias", "model.encoder.encoder.blocks.7.norm2.weight", "model.encoder.encoder.blocks.7.norm2.bias", "model.encoder.encoder.blocks.7.mlp.fc1.weight", "model.encoder.encoder.blocks.7.mlp.fc1.bias", "model.encoder.encoder.blocks.7.mlp.fc2.weight", "model.encoder.encoder.blocks.7.mlp.fc2.bias", "model.encoder.encoder.blocks.8.norm1.weight", "model.encoder.encoder.blocks.8.norm1.bias", "model.encoder.encoder.blocks.8.attn.q_bias", "model.encoder.encoder.blocks.8.attn.v_bias", "model.encoder.encoder.blocks.8.attn.qkv.weight", "model.encoder.encoder.blocks.8.attn.proj.weight", "model.encoder.encoder.blocks.8.attn.proj.bias", "model.encoder.encoder.blocks.8.norm2.weight", "model.encoder.encoder.blocks.8.norm2.bias", "model.encoder.encoder.blocks.8.mlp.fc1.weight", "model.encoder.encoder.blocks.8.mlp.fc1.bias", "model.encoder.encoder.blocks.8.mlp.fc2.weight", "model.encoder.encoder.blocks.8.mlp.fc2.bias", "model.encoder.encoder.blocks.9.norm1.weight", "model.encoder.encoder.blocks.9.norm1.bias", "model.encoder.encoder.blocks.9.attn.q_bias", "model.encoder.encoder.blocks.9.attn.v_bias", "model.encoder.encoder.blocks.9.attn.qkv.weight", 
"model.encoder.encoder.blocks.9.attn.proj.weight", "model.encoder.encoder.blocks.9.attn.proj.bias", "model.encoder.encoder.blocks.9.norm2.weight", "model.encoder.encoder.blocks.9.norm2.bias", "model.encoder.encoder.blocks.9.mlp.fc1.weight", "model.encoder.encoder.blocks.9.mlp.fc1.bias", "model.encoder.encoder.blocks.9.mlp.fc2.weight", "model.encoder.encoder.blocks.9.mlp.fc2.bias", "model.encoder.encoder.blocks.10.norm1.weight", "model.encoder.encoder.blocks.10.norm1.bias", "model.encoder.encoder.blocks.10.attn.q_bias", "model.encoder.encoder.blocks.10.attn.v_bias", "model.encoder.encoder.blocks.10.attn.qkv.weight", "model.encoder.encoder.blocks.10.attn.proj.weight", "model.encoder.encoder.blocks.10.attn.proj.bias", "model.encoder.encoder.blocks.10.norm2.weight", "model.encoder.encoder.blocks.10.norm2.bias", "model.encoder.encoder.blocks.10.mlp.fc1.weight", "model.encoder.encoder.blocks.10.mlp.fc1.bias", "model.encoder.encoder.blocks.10.mlp.fc2.weight", "model.encoder.encoder.blocks.10.mlp.fc2.bias", "model.encoder.encoder.blocks.11.norm1.weight", "model.encoder.encoder.blocks.11.norm1.bias", "model.encoder.encoder.blocks.11.attn.q_bias", "model.encoder.encoder.blocks.11.attn.v_bias", "model.encoder.encoder.blocks.11.attn.qkv.weight", "model.encoder.encoder.blocks.11.attn.proj.weight", "model.encoder.encoder.blocks.11.attn.proj.bias", "model.encoder.encoder.blocks.11.norm2.weight", "model.encoder.encoder.blocks.11.norm2.bias", "model.encoder.encoder.blocks.11.mlp.fc1.weight", "model.encoder.encoder.blocks.11.mlp.fc1.bias", "model.encoder.encoder.blocks.11.mlp.fc2.weight", "model.encoder.encoder.blocks.11.mlp.fc2.bias", "model.encoder.norm.weight", "model.encoder.norm.bias", "model.head_mask.weight", "model.head_mask.bias". 