# Tahoe Inference - Chemical Perturbation Prediction
This notebook runs State model inference to predict chemical perturbation effects on single cells. We'll test the reproducibility of Tahoe results using our prepared dataset.
## input h5ad 

In [1]:
import h5py
import os
os.environ['MPLBACKEND'] = 'Agg'

In [7]:
# Check the structure of our sample H5 files
with h5py.File('virtual_cell/predicted_only_10.h5ad', 'r') as f:
    print(type(f))
    print("Keys in root:", list(f.keys()))
    if 'obsm' in f:
        print("Keys in obsm:", list(f['obsm'].keys()))
        print("Keys in X:", list(f['X'].keys()))
    else:
        print("No 'obsm' group found")

<class 'h5py._hl.files.File'>
Keys in root: ['X', 'layers', 'obs', 'obsm', 'obsp', 'uns', 'var', 'varm', 'varp']
Keys in obsm: []


AttributeError: 'Dataset' object has no attribute 'keys'

In [3]:
! ls virtural_cell

predicted_only_10.h5ad


In [8]:

# Check the structure of one of COLAB H5 files
with h5py.File('/workspace/training_dataset/competition_support_set/competition_train.h5', 'r') as f:
    print(type(f))
    print("Keys in root:", list(f.keys()))
    if 'obsm' in f:
        print("Keys in obsm:", list(f['obsm'].keys()))
        print("Keys in X:", list(f['X'].keys()))
    else:
        print("No 'obsm' group found")

<class 'h5py._hl.files.File'>
Keys in root: ['X', 'layers', 'obs', 'obsm', 'obsp', 'uns', 'var', 'varm', 'varp']
Keys in obsm: []
Keys in X: ['data', 'indices', 'indptr']


## Model download
### 1. SE600M Download trained parameteres 
### 2. var_dims.pkl from ST-Tahoe [reference](https://github.com/ArcInstitute/state/issues/133)

In [2]:
! pwd # $(pwd):/workspace where I run jupyter locally /Users/ermin/PycharmProjects

/workspace


In [5]:
! pip install huggingface_hub # install huggingface_hub

Collecting huggingface_hub
  Downloading huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface_hub)
  Downloading hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl.metadata (4.7 kB)
Downloading huggingface_hub-0.34.4-py3-none-any.whl (561 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m561.5/561.5 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading hf_xet-1.1.9-cp37-abi3-manylinux_2_28_aarch64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: hf-xet, huggingface_hub
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [huggingface_hub] [huggingface_hub]
[1A[2KSuccessfully installed hf-xet-1.1.9 huggingface_hub-0.34.4
[0m

In [6]:
import os
from huggingface_hub import snapshot_download

# Define the repository and local directory
repo_id = "arcinstitute/ST-Tahoe"
local_dir = "virtual_cell/ST-Tahoe"

# Download all files from the repository
print(f"Downloading all files from {repo_id}...")
local_path = snapshot_download(
    repo_id=repo_id,
    local_dir=local_dir,
    local_dir_use_symlinks=False  # This ensures actual files are downloaded, not symlinks
)

Downloading all files from arcinstitute/ST-Tahoe...


For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

cell_type_onehot_map.pkl:   0%|          | 0.00/518k [00:00<?, ?B/s]

MODEL_ACCEPTABLE_USE_POLICY.md: 0.00B [00:00, ?B/s]

batch_onehot_map.pkl:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

MODEL_LICENSE.md: 0.00B [00:00, ?B/s]

config.yaml: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

LICENSE.md: 0.00B [00:00, ?B/s]

data_module.torch:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

pert_onehot_map.pt:   0%|          | 0.00/5.50M [00:00<?, ?B/s]

final.ckpt:   0%|          | 0.00/3.01G [00:00<?, ?B/s]

final_from_preprint.ckpt:   0%|          | 0.00/3.07G [00:00<?, ?B/s]

var_dims.pkl:   0%|          | 0.00/206k [00:00<?, ?B/s]

wandb_path.txt:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

In [7]:
! ls virtual_cell/ST-Tahoe/

LICENSE.md			cell_type_onehot_map.pkl  pert_onehot_map.pt
MODEL_ACCEPTABLE_USE_POLICY.md	config.yaml		  var_dims.pkl
MODEL_LICENSE.md		data_module.torch	  wandb_path.txt
README.md			final.ckpt
batch_onehot_map.pkl		final_from_preprint.ckpt


## Inferecne our sample
### Killed reason: 1.My MacPro (16G memory) might not enough memory for inference run. 

In [8]:
! state tx infer \
  --output "virtual_cell/prediction_ST-Tahoe_250902/prediction_only_10.h5ad" \
  --model_dir virtual_cell/ST-Tahoe \
  --checkpoint virtual_cell/ST-Tahoe/final.ckpt \
  --adata "virtual_cell/predicted_only_10.h5ad" \
  --pert_col "chemical"

INFO:state._cli._tx._infer:Loaded config from virtual_cell/ST-Tahoe/config.yaml
INFO:state._cli._tx._infer:Loading model from checkpoint: virtual_cell/ST-Tahoe/final.ckpt
PertSetsPerturbationModel(
  (loss_fn): SamplesLoss()
  (gene_decoder): LatentToGeneDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=2000, out_features=1024, bias=True)
      (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=1024, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): GELU(approximate='none')
      (7): Dropout(p=0.1, inplace=False)
      (8): Linear(in_features=1024, out_features=512, bias=True)
      (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (10): GELU(approximate='none')
      (11): Dropout(p=0.1, inplace=False)
      (12): Linear(in_features=512, out_features=2000, bias=True)
      (13): ReLU

In [9]:
# embedding first, but encounter the killed error. The moment memory used is 13G+, less than total 16 G. 
! state emb transform \
  --model-folder virtual_cell/SE-600M \
  --input virtual_cell/predicted_only_10.h5ad \
  --output virtual_cell/predicted_only_10_emb.h5ad

INFO:state._cli._emb._transform:Using model checkpoint: virtual_cell/SE-600M/se600m_epoch4.ckpt
INFO:state._cli._emb._transform:Creating inference object
INFO:state._cli._emb._transform:Loading model from checkpoint: virtual_cell/SE-600M/se600m_epoch4.ckpt
Killed


In [10]:
# try a different input. NG official data
! state tx infer \
  --output "virtual_cell/prediction_ST-Tahoe_250902/prediction_competition_val_template.h5ad" \
  --model_dir virtual_cell/ST-Tahoe \
  --checkpoint virtual_cell/ST-Tahoe/final.ckpt \
  --adata "training_dataset/competition_support_set/competition_val_template.h5ad" \
  --pert_col "chemical"

INFO:state._cli._tx._infer:Loaded config from virtual_cell/ST-Tahoe/config.yaml
INFO:state._cli._tx._infer:Loading model from checkpoint: virtual_cell/ST-Tahoe/final.ckpt
PertSetsPerturbationModel(
  (loss_fn): SamplesLoss()
  (gene_decoder): LatentToGeneDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=2000, out_features=1024, bias=True)
      (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=1024, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): GELU(approximate='none')
      (7): Dropout(p=0.1, inplace=False)
      (8): Linear(in_features=1024, out_features=512, bias=True)
      (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (10): GELU(approximate='none')
      (11): Dropout(p=0.1, inplace=False)
      (12): Linear(in_features=512, out_features=2000, bias=True)
      (13): ReLU

## Wrong model ST-Parse logs for learning structure for ours and COLAB

In [39]:
! state tx infer \
  --output "virtual_cell/prediction_250617/prediction_only_10.h5ad" \
  --model_dir virtual_cell/ST_Parse \
  --checkpoint virtual_cell/ST_Parse/final.ckpt \
  --adata "virtual_cell/predicted_only_10.h5ad" \
  --pert_col "target_gene"

INFO:state._cli._tx._infer:Loaded config from virtual_cell/ST_Parse/config.yaml
INFO:state._cli._tx._infer:Loading model from checkpoint: virtual_cell/ST_Parse/final.ckpt
PertSetsPerturbationModel(
  (loss_fn): SamplesLoss()
  (gene_decoder): LatentToGeneDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=2000, out_features=1024, bias=True)
      (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=1024, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): GELU(approximate='none')
      (7): Dropout(p=0.1, inplace=False)
      (8): Linear(in_features=1024, out_features=512, bias=True)
      (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (10): GELU(approximate='none')
      (11): Dropout(p=0.1, inplace=False)
      (12): Linear(in_features=512, out_features=2000, bias=True)
      (13): ReLU

In [40]:
# using COLAB input for testing
! state tx infer \
  --output "virtual_cell/prediction_250617/prediction_only_10.h5ad" \
  --model_dir virtual_cell/ST_Parse \
  --checkpoint virtual_cell/ST_Parse/final.ckpt \
  --adata "training_dataset/competition_support_set/competition_val_template.h5ad" \
  --pert_col "target_gene"

INFO:state._cli._tx._infer:Loaded config from virtual_cell/ST_Parse/config.yaml
INFO:state._cli._tx._infer:Loading model from checkpoint: virtual_cell/ST_Parse/final.ckpt
PertSetsPerturbationModel(
  (loss_fn): SamplesLoss()
  (gene_decoder): LatentToGeneDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=2000, out_features=1024, bias=True)
      (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (2): GELU(approximate='none')
      (3): Dropout(p=0.1, inplace=False)
      (4): Linear(in_features=1024, out_features=1024, bias=True)
      (5): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (6): GELU(approximate='none')
      (7): Dropout(p=0.1, inplace=False)
      (8): Linear(in_features=1024, out_features=512, bias=True)
      (9): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (10): GELU(approximate='none')
      (11): Dropout(p=0.1, inplace=False)
      (12): Linear(in_features=512, out_features=2000, bias=True)
      (13): ReLU