# Inference with a pre-trained State Transition (ST) model

## 1. Fetch a file from Hugging Face

In [None]:
from huggingface_hub import hf_hub_download

# Download a single file from the Hugging Face dataset repository into a local
# directory; the resulting path is kept in `file_path` and echoed below.
# NOTE(review): this cell fetches "c40.h5ad", while the preparation step reads
# "c37.h5ad" from the working directory — confirm which file is intended.
file_path = hf_hub_download(
    filename="c40.h5ad",
    repo_type="dataset",
    repo_id="arcinstitute/State-Tahoe-Filtered",
    local_dir_use_symlinks=False,
    local_dir="/data3/fanpeishan/cache",
)
print(f"Downloaded to: {file_path}")

## 2. Prepare for inference

In [None]:
import anndata as ad
import pickle

# Load the held-out AnnData object from the working directory.
adata_holdout = ad.read_h5ad("c37.h5ad")

# Read the gene names stored alongside the model. The context manager makes
# sure the file handle is closed once the pickle has been read.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open('ST-Tahoe/var_dims.pkl', 'rb') as var_dims_file:
    hvg_names = pickle.load(var_dims_file)['gene_names']

# Relabel the variables with the model's gene names and promote the matrix
# stored under obsm['X_hvg'] to adata.X (the inference step reports using
# adata.X as its input features).
# NOTE(review): assumes X_hvg has one column per entry in hvg_names — confirm.
adata_holdout.var.index = hvg_names
adata_holdout.X = adata_holdout.obsm['X_hvg']

# Save it back out
adata_holdout.write_h5ad("c37_real.h5ad")

## 3. Run inference

In [None]:
%%bash
# Run STATE inference on the prepared file (c37_real.h5ad) using the ST-Tahoe
# model directory and checkpoint, writing the result to c37_simulated.h5ad.
# The %%bash cell magic is required: without it the Python kernel would try to
# parse this shell command as Python and fail.
state tx infer \
    --model-dir ST-Tahoe \
    --checkpoint ST-Tahoe/final.ckpt \
    --pert-col drugname_drugconc \
    --batch-col plate \
    --control-pert "[('DMSO_TF', 0.0, 'uM')]" \
    --adata c37_real.h5ad \
    --output c37_simulated.h5ad

In [None]:
==> STATE: tx infer (virtual experiment)
Loaded config: ST-Tahoe/config.yaml
Control perturbation: [('DMSO_TF', 0.0, 'uM')]
Grouping by cell type column: cell_name
StateTransitionPerturbationModel(...)  [model summary omitted]
Model device: cuda:0
Model cell_set_len (max sequence length): 256
Model uses batch encoder: True
Model output space: gene
Using adata.X as input features: shape (1835947, 2000)
Cells: total=1835947, control=45150, non-control=1790797
Running virtual experiment (homogeneous per-perturbation forward passes; controls included)...
Group NCI-H23: 100% 1137/1137 [03:51<00:00,  4.91it/s, Pert: [('γ-Oryzanol', 5.0, 'uM]

=== Inference complete ===
Input cells:         1835947
Controls simulated:  45150
Treated simulated:   1790797
Wrote predictions to adata.X
Saved:               c37_simulated.h5ad

## 4. Compare the simulated and real data

In [None]:
%%bash
# Evaluate the simulated file against the real one with cell-eval, using the
# same control perturbation and perturbation column as the inference step.
# Presumably -ap / -ar are the predicted / real AnnData inputs and -o is the
# output directory — verify against `cell-eval run --help`.
# The %%bash cell magic is required: without it the Python kernel would try to
# parse this shell command as Python and fail.
cell-eval run \
    -ap c37_simulated.h5ad \
    -ar c37_real.h5ad \
    -o . \
    --control-pert "[('DMSO_TF', 0.0, 'uM')]" \
    --pert-col drugname_drugconc \
    --profile minimal \
    --celltype-col cell_name \
    --batch-size 1024 \
    --num-threads 64

In [None]:
import pandas as pd

# Read an aggregated results table from disk and display only the rows whose
# `statistic` column equals 'mean'.
# NOTE(review): this path points at run12 / NCI-H2122, whereas the evaluation
# command above writes to the current directory — confirm this is the
# intended results file.
results = pd.read_csv(
    '/data3/fanpeishan/state/run_results/run12/run_results/NCI-H2122_agg_results.csv'
)
results.loc[results['statistic'] == 'mean']