In [2]:
# Automatically re-import edited modules (e.g. concord, benchmark_utils) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Data loading and preparation

In [3]:
import concord as ccd
import scanpy as sc
import torch
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')
import time
from pathlib import Path
# Preferred CUDA device index. It is clamped to the GPUs actually visible, so machines
# with fewer GPUs do not fail later with "invalid device ordinal". Falls back to CPU
# when CUDA is unavailable.
GPU_INDEX = 3
if torch.cuda.is_available():
    device = torch.device(f'cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}')
else:
    device = torch.device('cpu')
seed = 0

In [4]:
# Project identifiers: outputs go under ../save/<proj_name>, inputs under ../data/<proj_name>.
proj_name = "cel_packerN2_predict"
file_name = "cel_packerN2_predict"        # stem of the preprocessed .h5ad file
# Run timestamp (e.g. "Jul05-0140") appended to artefacts written in this session.
file_suffix = time.strftime('%b%d-%H%M')

seed = 0
ccd.ul.set_seed(seed)

# Output folder first, then input folder; both are created if missing.
save_dir = Path("../save") / proj_name
save_dir.mkdir(parents=True, exist_ok=True)

data_dir = Path("../data") / proj_name
data_dir.mkdir(parents=True, exist_ok=True)


### Predict with CONCORD and scVI

In [5]:

from pathlib import Path
import scanpy as sc
from scvi.model import SCVI

def predict_with_scvi(run_dir: Path, adata: sc.AnnData, out_key: str):
    """
    Embed ``adata`` with a trained scVI model and store the latent space in ``adata.obsm``.

    The reference model is attached to the query cells with ``SCVI.load_query_data``
    and used as-is, without further training.

    Parameters
    ----------
    run_dir : Path
        Folder like "…/scvi_0705-0140". The actual checkpoint is in run_dir / "scvi_model.pt".
    adata : AnnData
        Full dataset to embed. Only ``adata.obsm[out_key]`` is written. Gene harmonisation
        and scVI registration happen on a private copy, so the caller's genes / ``.obs`` /
        ``.uns`` are left untouched for the other models that embed the same object.
    out_key : str
        Where to store the latent representation in adata.obsm.

    Returns
    -------
    The latent representation (one row per cell of ``adata``, same order), also stored
    in ``adata.obsm[out_key]``.

    Raises
    ------
    FileNotFoundError
        If ``run_dir / "scvi_model.pt" / "model.pt"`` does not exist.
    """
    run_dir = Path(run_dir).resolve()

    # The checkpoint is saved as a *directory* named "scvi_model.pt" that holds model.pt.
    model_dir = run_dir / "scvi_model.pt"
    if not (model_dir / "model.pt").is_file():
        raise FileNotFoundError(f"{model_dir}/model.pt not found")

    # Work on a copy. With its default inplace=True, prepare_query_anndata subsets,
    # rearranges and pads the genes of the object it is given, and load_query_data
    # registers scVI fields on it. Doing either on the caller's AnnData would change
    # the gene set seen by the CONCORD models that run on it afterwards. The logged
    # run shows prep succeeding and then load failing, which left adata already altered.
    query = adata.copy()

    # 1️⃣  harmonise genes / categories (on the copy)
    SCVI.prepare_query_anndata(query, str(model_dir))

    # 2️⃣  stitch query cells onto the trained weights
    # NOTE(review): in the logged run this raised KeyError('pyro_param_store'). That
    # looks like a scvi-tools version mismatch between saving and loading — confirm.
    vae_q = SCVI.load_query_data(query, str(model_dir))
    # No fine-tuning: reuse the reference weights directly. Presumably the flag is set
    # so get_latent_representation accepts the untrained query model — confirm.
    vae_q.is_trained = True

    # 3️⃣  forward pass; rows follow query.obs_names, which copies adata's cell order
    adata.obsm[out_key] = vae_q.get_latent_representation()
    return adata.obsm[out_key]



def predict_with_concord(model_dir: Path, adata, out_key: str):
    """
    Embed ``adata`` with a saved CONCORD model.

    Parameters
    ----------
    model_dir : Path
        Run directory handed to ``ccd.Concord.load``.
    adata : AnnData
        Dataset to embed. ``predict_adata`` is expected to fill ``adata.obsm[out_key]``.
    out_key : str
        ``.obsm`` key that receives the embedding.

    Returns
    -------
    The array stored at ``adata.obsm[out_key]``.
    """
    loaded_model = ccd.Concord.load(model_dir=model_dir)
    loaded_model.predict_adata(adata, output_key=out_key)
    embedding = adata.obsm[out_key]
    return embedding


In [6]:
# Echo the resolved data directory as a quick sanity check.
data_dir

PosixPath('../data/cel_packerN2_predict')

In [7]:
import scanpy as sc
from pathlib import Path
import concord as ccd
from scvi.model import SCVI            # make sure scvi-tools is ≥ 1.1
from benchmark_utils import latest_run_dir

# Model families to reload, and the downsampling fractions they were trained on.
methods = ["scvi", "concord_hcl", "concord_knn"]
fractions = [1.0]
ds_proj_name = "cel_packerN2_downsample"
adata = sc.read_h5ad(data_dir / f"{file_name}_preprocessed.h5ad")

# out_key -> exception for every prediction that failed; inspect after the loop.
failed_predictions = {}

for frac in fractions:
    # NOTE(review): int() truncates float error (int(0.57 * 100) == 56). It is kept so
    # the tags still match the existing ds<N> folder names — confirm before changing.
    tag        = f"ds{int(frac * 100)}"                         # ds100, ds10, …
    cur_proj   = f"{ds_proj_name}_{tag}"                       # e.g. cel_packerN2_downsample_ds10
    proj_save  = Path("../save") / cur_proj
    print(proj_save)
    # fetch latest run directories for each method
    paths = {m: latest_run_dir(proj_save, m) for m in methods}
    for method, model_path in paths.items():
        if model_path is None:
            print(f"[{method:12}]  {tag:<5}  –  no model directory found, skipping.")
            continue

        out_key = f"{method}_{tag}"
        if out_key in adata.obsm:
            print(f"[{method:12}]  {tag:<5}  –  already exists, skipping.")
            continue

        # Only the prediction itself is guarded, so a reporting error is not mislabelled
        # as a model failure.
        try:
            if method == "scvi":
                latent = predict_with_scvi(model_path, adata, out_key)
            else:                          # concord_hcl / concord_knn
                latent = predict_with_concord(model_path, adata, out_key)
        except Exception as e:
            # Keep going so the remaining models still run, but record the failure and
            # show the exception type. str(KeyError) is only the key, e.g. 'pyro_param_store'.
            failed_predictions[out_key] = e
            print(f"[{method:12}]  {tag:<5}  –  FAILED: {type(e).__name__}: {e}")
            continue

        print(f"[{method:12}]  {tag:<5}  →  stored in .obsm['{out_key}']  "
              f"({latent.shape[0]}×{latent.shape[1]})")

if failed_predictions:
    print(f"⚠️  {len(failed_predictions)} prediction(s) failed: {sorted(failed_predictions)}")

../save/cel_packerN2_downsample_ds100
[34mINFO    [0m File                                                                                                      
         [35m/Users/QZhu/Documents/CONCORD/Concord_benchmark/save/cel_packerN2_downsample_ds100/scvi_0705-0140/scvi_mod[0m
         [35mel.pt/[0m[95mmodel.pt[0m already downloaded                                                                         
[34mINFO    [0m Found [1;36m100.0[0m% reference vars in query data.                                                                
[34mINFO    [0m File                                                                                                      
         [35m/Users/QZhu/Documents/CONCORD/Concord_benchmark/save/cel_packerN2_downsample_ds100/scvi_0705-0140/scvi_mod[0m
         [35mel.pt/[0m[95mmodel.pt[0m already downloaded                                                                         
[scvi        ]  ds100  –  FAILED: 'pyro_param_store'
conco

#### Benchmark cell type

In [30]:
import numpy as np

# Placeholder labels treated as "no cell-type annotation"; cells carrying them are dropped.
bad_annotation = [np.nan, '', 'unknown', 'None', 'nan', 'NaN', 'NA', 'na', 'unannotated']
cell_type_labels = adata.obs['cell_type']
bad_cells = cell_type_labels.isin(bad_annotation)
adata_ct = adata[~bad_cells].copy()   # annotated cells only, as an independent AnnData
print(f"✅ Filtered adata to remove bad annotations, new shape: {adata_ct.shape}")


✅ Filtered adata to remove bad annotations, new shape: (43686, 10000)


In [29]:
# List the embeddings stored on `adata`. The prediction loop writes keys named <method>_<tag>.
adata.obsm

AxisArrays with keys: Concord, Concord-decoder, Concord-decoder_UMAP, Concord-decoder_UMAP_3D, Concord_UMAP, Concord_UMAP_3D, X_pca, scvi_ds100, concord_hcl_ds100, concord_knn_ds100

In [None]:
# Probe-only benchmark of the predicted embeddings on cell-type labels.
import pickle

state_key = 'cell_type'
batch_key = 'batch'
# NOTE: `methods` now holds .obsm embedding keys; the lineage cell below reuses it.
methods = ["scvi_ds100", "concord_hcl_ds100", "concord_knn_ds100"]
out = ccd.bm.run_benchmark_pipeline(
    adata_ct,
    embedding_keys=methods,
    state_key=state_key,
    batch_key=batch_key,
    save_dir=save_dir / "benchmarks_celltype_probe",
    file_suffix=file_suffix,        # run timestamp from the config cell, e.g. "Jul05-0140"
    run=("probe",),                 # tuple of blocks to run; ("probe") would be a plain str
    plot_individual=False,          # skip the intermediate PDFs
)
combined_celltype = out["combined"]

# Save the benchmark results. The path is built once so the message names the real file.
result_path = save_dir / f"benchmark_probe_{state_key}_{file_suffix}.pkl"
with open(result_path, "wb") as f:
    pickle.dump(out, f)

print(f"✅ Benchmark results saved to: {result_path}")

concord.benchmarking.benchmark - INFO - Running Probe benchmark
concord.benchmarking.benchmark - INFO - Running linear probe for state with keys ['scvi_ds100', 'concord_hcl_ds100', 'concord_knn_ds100']
Detected task: classification


#### Benchmark lineage

In [None]:
import pickle

import numpy as np

# Placeholder labels treated as "no lineage annotation"; cells carrying them are dropped.
bad_annotation = [np.nan, '', 'unknown', 'None', 'nan', 'NaN', 'NA', 'na', 'unannotated']
lineage_key = 'lineage_complete'

# tag (e.g. "ds100") -> combined probe-benchmark table for that fraction
state_benchmarks = {}
for frac in fractions:  # processed in the order listed in `fractions`
    pct        = int(frac * 100)   # int() kept so names match the existing ds<N> files
    adata_name = f"{file_name}_downsampled_{pct}_final.h5ad"
    tag        = f"ds{pct}"                                # keeps job names unique
    cur_proj   = f"{proj_name}_{tag}"
    cur_dir    = Path("../data") / cur_proj
    # NOTE(review): the embeddings listed in `methods` were written to the in-memory
    # `adata` by the prediction loop and are not saved in this notebook. Confirm that
    # this file actually carries those .obsm keys.
    cur_adata  = sc.read_h5ad(cur_dir / adata_name)
    bad_cells  = cur_adata.obs[lineage_key].isin(bad_annotation)

    print(f"Filtering {cur_proj} to remove bad annotations: {bad_cells.sum()} cells out of {len(cur_adata)}")
    # Separate name so the cell-type-filtered `adata_ct` from the earlier cell is not clobbered.
    adata_lineage = cur_adata[~bad_cells].copy()

    print(f"✅ Filtered adata to remove bad annotations, new shape: {adata_lineage.shape}")
    # nunique() counts only observed labels. len(value_counts()) on a categorical also
    # counts categories that the filter above just emptied (e.g. 'unknown').
    state_counts = adata_lineage.obs[lineage_key].nunique()
    batch_counts = adata_lineage.obs['batch'].nunique()
    print(f"Cell types: {state_counts}, Batches: {batch_counts}")
    # Drop a key when it has a single observed level.
    state_key = lineage_key if state_counts > 1 else None
    batch_key = 'batch' if batch_counts > 1 else None
    out = ccd.bm.run_benchmark_pipeline(
        adata_lineage,
        embedding_keys=methods,
        state_key=state_key,
        batch_key=batch_key,
        save_dir=save_dir / f"{cur_proj}_benchmarks_{state_key}",
        file_suffix=file_suffix,        # run timestamp from the config cell, e.g. "Jul05-0140"
        run=("probe",),                 # tuple of blocks to run; ("probe") would be a plain str
        plot_individual=False,          # skip the intermediate PDFs
    )
    # Stored per tag; `combined_celltype` from the cell-type cell is left intact.
    state_benchmarks[tag] = out["combined"]

    # Save this fraction's full results; the path is built once so the message matches.
    result_path = save_dir / f"{cur_proj}_benchmark_{state_key}_{file_suffix}.pkl"
    with open(result_path, "wb") as f:
        pickle.dump(out, f)
    print(f"✅ Benchmark results saved to: {result_path}")

# Name the summary after the lineage column, not the loop's last `state_key`. That value
# can be None, or a stale 'cell_type' from an earlier cell when `fractions` is empty.
summary_path = save_dir / f"{proj_name}_{lineage_key}_benchmarks_{file_suffix}.pkl"
with open(summary_path, "wb") as f:
    pickle.dump(state_benchmarks, f)
print(f"✅ State benchmarks saved to: {summary_path}")