In [None]:
# Automatically reload imported modules (such as benchmark_utils) before each cell runs,
# so edits to helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Data loading and preparation

In [None]:
import concord as ccd
import scanpy as sc
import torch
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')
import time
from pathlib import Path
# Torch device: GPU index 3 when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:3')
else:
    device = torch.device('cpu')
seed = 0  # global random seed

In [None]:
# Project identifiers: proj_name names the save/data folders, file_name prefixes h5ad files.
proj_name = "cel_packerN2_downsample"
file_name = "cel_packerN2_downsample"
# Timestamp tag (format like "Jun26-1610") appended to output file names.
file_suffix = time.strftime('%b%d-%H%M')

seed = 0
ccd.ul.set_seed(seed)

# Output folders for results (save_dir) and intermediate AnnData files (data_dir).
save_dir = Path(f"../save/{proj_name}")
data_dir = Path(f"../data/{proj_name}")
for _dir in (save_dir, data_dir):
    _dir.mkdir(parents=True, exist_ok=True)


In [None]:
# Build the preprocessed input: flag 10k HVGs (seurat_v3), compute a 300-component PCA on
# those genes, then keep only the HVG columns and write the result under data_dir.
raw_path = Path('../data/CBCEcombineN2/') / 'adata_celsub_Jun26-1610.h5ad'
preprocessed_path = data_dir / f"{file_name}_preprocessed.h5ad"

adata = sc.read_h5ad(raw_path)
sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=10000, subset=False)
sc.tl.pca(adata, n_comps=300, svd_solver='arpack', use_highly_variable=True)
hvg_mask = adata.var.highly_variable
adata = adata[:, hvg_mask].copy()
adata.write_h5ad(preprocessed_path)
print(f"✅ Preprocessed data saved to {preprocessed_path}")

In [None]:
# Reload the preprocessed AnnData from disk so downstream cells start from the saved file.
DATA_FILE = Path(f"{file_name}_preprocessed.h5ad")
adata = sc.read_h5ad(data_dir.joinpath(DATA_FILE))

### Create jobs

In [None]:
import json
import math
import subprocess
import sys
from pathlib import Path

import numpy as np

# Fractions of cells to keep; each becomes its own project "<proj_name>_ds<pct>".
fractions = [1.0, 0.5, 0.25, 0.1, 0.05, 0.01]

# Hyper-parameters forwarded (as JSON) to generate_py_jobs.py via --concord_kwargs.
concord_args = {
    "latent_dim": 300,
    "batch_size": 256,
    "encoder_dims": [1000],
    "p_intra_domain": 1.0,
    "augmentation_mask_prob": 0.30,
    "clr_temperature": 0.30,
    "sampler_knn": 1000,
    "n_epochs": 15,
    "lr": 1e-2,
}

py_methods = ["scvi", "harmony", "scanorama", "liger", "unintegrated", "concord_hcl", "concord_knn", "contrastive"]
output_dir = '../jobs'
# Device string for the generated jobs; a separate name so the torch `device` is not overwritten.
job_device = 'auto'
conda_env = 'concord'
batch_key = 'batch'
# Literal string 'None' on the CLI — presumably parsed by generate_py_jobs.py as "no state key"; confirm.
state_key = 'None'
# CLI expects a string; derived from concord_args so the two cannot drift apart.
latent_dim = str(concord_args["latent_dim"])

# One seeded permutation shared by all fractions: each subset is a random sample (not biased by
# row order, e.g. cells grouped by batch), and smaller subsets stay nested inside larger ones.
perm = np.random.default_rng(seed).permutation(adata.n_obs)

for frac in fractions:
    pct = int(frac * 100)
    tag = f"ds{pct}"  # keeps job names unique
    cur_proj = f"{proj_name}_{tag}"
    adata_name = f"{file_name}_downsampled_{pct}.h5ad"
    cur_dir = Path("../data") / cur_proj
    cur_dir.mkdir(parents=True, exist_ok=True)
    out_path = cur_dir / adata_name

    if frac < 1.0:
        n_cells = max(1, int(adata.n_obs * frac))  # never write an empty dataset
        print(f"Downsampling to {frac * 100:.1f}% of the original data ({n_cells} cells)")
        keep = np.sort(perm[:n_cells])  # sorted to preserve the original cell order
        downsampled_adata = adata[keep].copy()
    else:
        downsampled_adata = adata.copy()
    downsampled_adata.obs['downsample_fraction'] = frac
    downsampled_adata.write_h5ad(out_path)
    print(f"✅ Downsampled data saved to {out_path}")

    # sys.executable runs the generator with the kernel's own interpreter rather than whatever
    # "python" happens to be first on PATH.
    subprocess.run([
        sys.executable, "./generate_py_jobs.py",
        "--proj_name", cur_proj,
        "--adata_filename", adata_name,
        "--methods", *py_methods,
        "--batch_key", batch_key,
        "--state_key", state_key,
        "--latent_dim", latent_dim,
        "--output_dir", output_dir,
        "--device", job_device,
        "--conda_env", conda_env,
        "--runtime", "02:00:00",
        "--concord_kwargs", json.dumps(concord_args),
        "--root_save_dir", "../save",
        "--root_data_dir", "../data",
    ], check=True)

In [None]:
# ------------------------------------------------------------------
# Create ../jobs/submit_sequential_<proj>.sh, which runs every *.py job
# inside each benchmark_<proj>_* folder one after another, logging to
# <job>.log and skipping jobs whose log already records success.
# ------------------------------------------------------------------
jobs_dir = Path("../jobs")
jobs_dir.mkdir(parents=True, exist_ok=True)  # write_text below fails if the folder is missing
sequential_submit = jobs_dir / f"submit_sequential_{proj_name}.sh"

# f-string: {proj_name} is interpolated; doubled braces yield literal bash ${...} expansions.
sequential_template = f"""#!/usr/bin/env bash
set -euo pipefail
cd "$(dirname "$0")"            # work inside this folder (../jobs)
shopt -s nullglob

for folder in benchmark_{proj_name}_*; do
  [[ -d "$folder" ]] || continue
  echo "===== entering $folder  $(date) ====="

  for job in "$folder"/*.py; do
    [[ -e "$job" ]] || continue

    base=${{job%.py}}
    log="${{base}}.log"

    # ───────────────────────────────────────────────────────────────
    # skip if a previous run finished successfully
    #   • If you only care that the log exists (no success check),
    #     drop the grep clause.
    # ───────────────────────────────────────────────────────────────
    if [[ -f "$log" ]] && grep -q "finished OK" "$log"; then
        echo ">>> SKIP $job  — already completed"
        continue
    fi

    echo ">>> $job   $(date)" | tee -a "$log"
    if python "$job" >>"$log" 2>&1; then
        echo ">>> finished OK" | tee -a "$log"
    else
        echo ">>> FAILED"      | tee -a "$log"
    fi
  done
done
"""

sequential_submit.write_text(sequential_template)
sequential_submit.chmod(0o755)  # make the script directly executable
print(f"📌  Run “{sequential_submit}” to queue the down-sample benchmarks sequentially.")


### Collect results

In [None]:
from benchmark_utils import add_embeddings

# Methods whose embeddings are attached; also passed as embedding keys in the benchmark cells below.
methods = ["scvi", "harmony", "scanorama", "liger", "unintegrated", "concord_hcl", "concord_knn", "contrastive"]
fractions = [1.0, 0.5, 0.25, 0.1, 0.05, 0.01]

for frac in fractions:
    pct = int(frac * 100)
    cur_proj = f"{proj_name}_ds{pct}"
    cur_dir = Path("../data") / cur_proj
    stem = f"{file_name}_downsampled_{pct}"

    cur_adata = sc.read_h5ad(cur_dir / f"{stem}.h5ad")
    print(f"Adding embeddings for {cur_proj} with {cur_adata.n_obs} cells and {cur_adata.n_vars} genes.")
    cur_adata = add_embeddings(cur_adata, proj_name=cur_proj, methods=methods)
    # "_final" files are the inputs read by the annotation benchmark cells.
    cur_adata.write_h5ad(cur_dir / f"{stem}_final.h5ad")
    

### Benchmarking

In [None]:
from benchmark_utils import collect_benchmark_logs, plot_benchmark_performance
import matplotlib.pyplot as plt

# Font override applied only while the benchmark figures are drawn.
custom_rc = {'font.family': 'Arial'}

for frac in fractions:
    cur_proj = f"{proj_name}_ds{int(frac * 100)}"
    bench_df = collect_benchmark_logs(cur_proj, methods)

    # Save the per-fraction benchmark table.
    summary_path = save_dir / f"{cur_proj}_benchmark_summary_{file_suffix}.tsv"
    bench_df.to_csv(summary_path, sep="\t", index=False)
    print(f"✅ Benchmark summary saved to: {summary_path}")

    # Plot the benchmark results to PDF.
    with plt.rc_context(rc=custom_rc):
        plot_benchmark_performance(
            bench_df,
            figsize=(8, 2),
            dpi=300,
            save_path=save_dir / f"{cur_proj}_benchmark_plot_{file_suffix}.pdf",
        )


#### Benchmark cell-type annotation (`cell_type`)

In [None]:
# Show the fractions smallest-first.
list(reversed(fractions))

In [None]:
# Cell-type label distribution of the full preprocessed dataset (before annotation filtering).
# Uses `adata` from the loading cell; `adata_ct` is only created in a later cell.
adata.obs['cell_type'].value_counts()

In [None]:
import pickle

import numpy as np

# Label values treated as missing annotations; those cells are excluded before benchmarking.
bad_annotation = [np.nan, '', 'unknown', 'None', 'nan', 'NaN', 'NA', 'na', 'unannotated']
label_key = 'cell_type'    # obs column benchmarked in this cell
min_cells_per_class = 0    # set e.g. 3 to also drop classes with fewer cells; 0 disables

state_benchmarks = {}
for frac in fractions:  # fractions are listed largest-first
    pct = int(frac * 100)
    tag = f"ds{pct}"
    cur_proj = f"{proj_name}_{tag}"
    cur_dir = Path("../data") / cur_proj
    cur_adata = sc.read_h5ad(cur_dir / f"{file_name}_downsampled_{pct}_final.h5ad")

    labels = cur_adata.obs[label_key]
    # isna() catches missing values regardless of column dtype (e.g. categorical NaN).
    bad_cells = labels.isin(bad_annotation) | labels.isna()
    if min_cells_per_class:
        class_sizes = labels.value_counts()
        bad_cells |= labels.isin(class_sizes[class_sizes < min_cells_per_class].index)

    print(f"Filtering {cur_proj} to remove bad annotations: {bad_cells.sum()} cells out of {len(cur_adata)}")
    adata_ct = cur_adata[~bad_cells].copy()
    print(f"✅ Filtered adata to remove bad annotations, new shape: {adata_ct.shape}")

    state_counts = adata_ct.obs[label_key].nunique()
    batch_counts = adata_ct.obs['batch'].nunique()
    print(f"Cell types: {state_counts}, Batches: {batch_counts}")
    # A single class / batch carries no signal for the benchmark, so pass None instead.
    cur_state_key = label_key if state_counts > 1 else None
    cur_batch_key = 'batch' if batch_counts > 1 else None

    out = ccd.bm.run_benchmark_pipeline(
        adata_ct,
        embedding_keys=methods,
        state_key=cur_state_key,
        batch_key=cur_batch_key,
        save_dir=save_dir / f"{cur_proj}_benchmarks_{label_key}",
        file_suffix=file_suffix,
        run=("probe",),          # one-element tuple; ("probe") would be the plain string "probe"
        plot_individual=False,   # skip the intermediate PDFs
    )
    state_benchmarks[tag] = out["combined"]

    # Save the per-fraction benchmark results.
    result_path = save_dir / f"{cur_proj}_benchmark_{label_key}_{file_suffix}.pkl"
    with open(result_path, "wb") as f:
        pickle.dump(out, f)
    print(f"✅ Benchmark results saved to: {result_path}")

# Name the aggregate file by the benchmarked column, not by the last iteration's (possibly None) key.
summary_path = save_dir / f"{proj_name}_{label_key}_benchmarks_{file_suffix}.pkl"
with open(summary_path, "wb") as f:
    pickle.dump(state_benchmarks, f)
print(f"✅ State benchmarks saved to: {summary_path}")

#### Benchmark lineage annotation (`lineage_complete`)

In [None]:
import pickle

import numpy as np

# Label values treated as missing annotations; those cells are excluded before benchmarking.
bad_annotation = [np.nan, '', 'unknown', 'None', 'nan', 'NaN', 'NA', 'na', 'unannotated']
label_key = 'lineage_complete'  # obs column benchmarked in this cell
min_cells_per_class = 0         # set e.g. 3 to also drop classes with fewer cells; 0 disables

state_benchmarks = {}
for frac in fractions:  # fractions are listed largest-first
    pct = int(frac * 100)
    tag = f"ds{pct}"
    cur_proj = f"{proj_name}_{tag}"
    cur_dir = Path("../data") / cur_proj
    cur_adata = sc.read_h5ad(cur_dir / f"{file_name}_downsampled_{pct}_final.h5ad")

    labels = cur_adata.obs[label_key]
    # isna() catches missing values regardless of column dtype (e.g. categorical NaN).
    bad_cells = labels.isin(bad_annotation) | labels.isna()
    if min_cells_per_class:
        class_sizes = labels.value_counts()
        bad_cells |= labels.isin(class_sizes[class_sizes < min_cells_per_class].index)

    print(f"Filtering {cur_proj} to remove bad annotations: {bad_cells.sum()} cells out of {len(cur_adata)}")
    adata_ct = cur_adata[~bad_cells].copy()
    print(f"✅ Filtered adata to remove bad annotations, new shape: {adata_ct.shape}")

    state_counts = adata_ct.obs[label_key].nunique()
    batch_counts = adata_ct.obs['batch'].nunique()
    print(f"Lineages: {state_counts}, Batches: {batch_counts}")
    # A single class / batch carries no signal for the benchmark, so pass None instead.
    cur_state_key = label_key if state_counts > 1 else None
    cur_batch_key = 'batch' if batch_counts > 1 else None

    out = ccd.bm.run_benchmark_pipeline(
        adata_ct,
        embedding_keys=methods,
        state_key=cur_state_key,
        batch_key=cur_batch_key,
        save_dir=save_dir / f"{cur_proj}_benchmarks_{label_key}",
        file_suffix=file_suffix,
        run=("probe",),          # one-element tuple; ("probe") would be the plain string "probe"
        plot_individual=False,   # skip the intermediate PDFs
    )
    state_benchmarks[tag] = out["combined"]

    # Save the per-fraction benchmark results.
    result_path = save_dir / f"{cur_proj}_benchmark_{label_key}_{file_suffix}.pkl"
    with open(result_path, "wb") as f:
        pickle.dump(out, f)
    print(f"✅ Benchmark results saved to: {result_path}")

# Name the aggregate file by the benchmarked column, not by the last iteration's (possibly None) key.
summary_path = save_dir / f"{proj_name}_{label_key}_benchmarks_{file_suffix}.pkl"
with open(summary_path, "wb") as f:
    pickle.dump(state_benchmarks, f)
print(f"✅ State benchmarks saved to: {summary_path}")