In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local benchmark_utils helpers used below) are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

## Basic setup

In [None]:
import concord as ccd
import scanpy as sc
import torch
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')
import time
from pathlib import Path
# Use the first CUDA GPU when torch reports one as available, otherwise the CPU.
# NOTE(review): the job-generation cells further down rebind `device` to the string 'auto'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Seed handed to the UMAP, plotting and CONCORD calls later in the notebook.
seed = 0

In [None]:
import time
from pathlib import Path

# Project identity: the same name is used for the data/save folders and for output file stems.
proj_name = "lungv1_5k"
file_name = proj_name
# Month-day-hour-minute stamp used in the names of result files written below.
file_suffix = time.strftime('%b%d-%H%M')
seed = 0

# Per-project output and input folders, created if they do not exist yet.
save_dir = Path("../save") / proj_name
data_dir = Path("../data") / proj_name
save_dir.mkdir(parents=True, exist_ok=True)
data_dir.mkdir(parents=True, exist_ok=True)


In [None]:
# Read the lung atlas AnnData file from the project data folder.
data_path = data_dir / 'Lung_atlas_public.h5ad'
adata = sc.read(data_path)

In [None]:
# Find highly variable genes
# Flags the top 5000 genes in adata.var; subset=False keeps every gene in adata for now
# (the subsetting itself happens in the following cell).
# NOTE(review): scanpy documents flavor='seurat_v3' as expecting raw count data — confirm adata.X holds counts.
sc.pp.highly_variable_genes(adata, n_top_genes=5000, flavor='seurat_v3', subset=False)

In [None]:
# Restrict adata to the genes flagged as highly variable, then write the reduced
# object next to the raw data so the benchmark jobs can read it.
hvg_path = data_dir / f"{file_name}_preprocessed_HVG.h5ad"
adata = adata[:, adata.var["highly_variable"]].copy()
adata.write_h5ad(hvg_path)
print(f"✅ Preprocessed data saved to {hvg_path}")

In [None]:
import subprocess, json

# Methods handed to generate_py_jobs.py through --methods.
py_methods = ["scvi", "harmony", "scanorama", "liger", "unintegrated", "concord_hcl", "concord_knn", "contrastive"]
output_dir = '../jobs'
# NOTE: this rebinds the notebook-level `device` (a torch.device from the setup cell)
# to the plain string that is passed to the generator's --device flag.
device = 'auto'
conda_env = 'cellpath'
batch_key = "batch"
# All CLI values are strings, including the literal 'None' for a missing state key.
state_key = 'None'
latent_dim = '50'  # Adjust as needed, but should match the encoder_dims in concord_args

# check=True: raise CalledProcessError if the generator exits non-zero, instead of
# silently continuing with missing or stale job scripts.
subprocess.run(
    [
        "python", "./generate_py_jobs.py",
        "--proj_name", proj_name,
        "--adata_filename", f"{file_name}_preprocessed_HVG.h5ad",
        "--methods", *py_methods,
        "--batch_key", batch_key,
        "--state_key", state_key,
        "--latent_dim", latent_dim,
        "--output_dir", output_dir,
        "--device", device,
        "--mem", "32G",  # Adjust memory as needed
        "--conda_env", conda_env,
        "--runtime", "2:00:00",
        "--mode", "wynton",
    ],
    check=True,
)


In [None]:
# Export adata in VisCello format into a project-specific folder under data_dir
# (the Seurat job generator below is pointed at this folder).
# organism is set to 'hsa' (human) instead of the previous 'mmu' (mouse).
# NOTE(review): this presumes Lung_atlas_public.h5ad is the human lung atlas — switch
# back to 'mmu' if the data is actually mouse.
ccd.ul.anndata_to_viscello(
    adata,
    output_dir=data_dir / f"viscello_{proj_name}",
    project_name=proj_name,
    organism='hsa',
)

In [None]:
# Generate script for Seurat
import subprocess

# Methods handed to generate_seurat_script.py through --methods.
r_methods = ["seurat_cca", "seurat_rpca"]
output_dir = '../jobs'
# NOTE: rebinds the notebook-level `device` to the plain string passed to --device.
device = 'auto'

# check=True: raise CalledProcessError if the generator exits non-zero, instead of
# silently continuing without the Seurat job scripts.
subprocess.run(
    [
        "python", "./generate_seurat_script.py",
        "--proj_name", proj_name,
        "--eset_dir", '../' + str(data_dir / f"viscello_{proj_name}"),   # <- folder w/ eset.rds
        "--methods", *r_methods,
        "--batch_key", batch_key,
        "--state_key", state_key,
        "--latent_dim", latent_dim,
        "--mem", "80G",  # Adjust memory as needed
        "--runtime", "10:00:00",
        "--output_dir", output_dir,
        "--device", device,
        "--conda_env", conda_env,
    ],
    check=True,
)

In [None]:
# Folder expected to hold this project's job scripts: ../jobs/benchmark_<proj>
proj_folder = Path(output_dir) / f"benchmark_{proj_name}"
# parents=True so the defensive mkdir also succeeds when ../jobs itself is missing.
proj_folder.mkdir(parents=True, exist_ok=True)

# Collect the per-method job scripts up front so an empty result can be reported.
job_scripts = sorted(proj_folder.glob(f"benchmark_{proj_name}_*.sh"))
if not job_scripts:
    print(f"⚠️  No job scripts matching benchmark_{proj_name}_*.sh found in {proj_folder}; "
          "the submit script will not queue anything.")

# Write a wrapper script that submits every job script with qsub.
submit_all = proj_folder / f"submit_all_{proj_name}.sh"
with submit_all.open("w") as fh:
    fh.write("#!/bin/bash\n")
    fh.write("# Auto-generated — submits every job for this project\n")
    fh.write("# Run from this folder, or let the script cd into it.\n\n")
    # cd to the script's own folder so the relative script names below resolve.
    fh.write('cd "$(dirname "$0")"\n\n')
    for sh_file in job_scripts:
        fh.write(f'qsub "{sh_file.name}"\n')

# Make the wrapper executable.
submit_all.chmod(0o755)
print(f"📌  Run “{submit_all}” to queue every job.")

### Collect results

In [None]:
from benchmark_utils import collect_benchmark_logs, plot_benchmark_performance
import matplotlib.pyplot as plt

# Every benchmarked method: the Python-side methods plus the two Seurat variants.
methods = [
    "scvi", "harmony", "scanorama", "liger", "unintegrated",
    "concord_hcl", "concord_knn", "contrastive",
    "seurat_cca", "seurat_rpca",
]

bench_df = collect_benchmark_logs(file_name, methods)

# Save the benchmark results as a tab-separated table
summary_path = save_dir / f"benchmark_summary_{file_suffix}.tsv"
bench_df.to_csv(summary_path, sep="\t", index=False)
print(f"✅ Benchmark summary saved to: {summary_path}")

# Plot benchmark results with the font override applied only inside this context
custom_rc = {'font.family': 'Arial'}
with plt.rc_context(rc=custom_rc):
    plot_benchmark_performance(
        bench_df,
        figsize=(8, 2),
        dpi=300,
        save_path=save_dir / f"benchmark_plot_{file_suffix}.pdf",
    )


In [None]:
# Reload the full lung atlas (replacing the adata currently in memory) before attaching embeddings.
data_path = data_dir / 'Lung_atlas_public.h5ad'
adata = sc.read(data_path)

In [None]:
# Show which embeddings are currently stored in adata.obsm (displayed as the cell output).
adata.obsm.keys()

In [None]:
from benchmark_utils import add_embeddings
# Attach the embeddings for `methods` to adata; the UMAP loop that follows looks each
# method name up in adata.obsm.
# NOTE(review): where add_embeddings reads the embeddings from is not visible here — confirm in benchmark_utils.
adata = add_embeddings(adata, proj_name=proj_name, methods=methods)

In [None]:
# Run umap for all latent embeddings
# Settings shared by the 2D and 3D runs.
umap_kwargs = dict(n_neighbors=30, min_dist=0.1, metric='euclidean', random_state=seed)
for basis in methods:
    print("Running UMAP for", basis)
    if basis in adata.obsm:
        # 2D layout first, then 3D, both computed from the same latent embedding.
        ccd.ul.run_umap(adata, source_key=basis, result_key=f'{basis}_UMAP', n_components=2, **umap_kwargs)
        ccd.ul.run_umap(adata, source_key=basis, result_key=f'{basis}_UMAP_3D', n_components=3, **umap_kwargs)
    else:
        # Nothing stored under this method name in adata.obsm; skip it.
        print(f"{basis} not found.")

# Save obsm, then the full AnnData object with all embeddings
final_adata_path = data_dir / f"{file_name}_adata_final.h5ad"
ccd.ul.save_obsm_to_hdf5(adata, save_dir / f"{file_name}_obsm_final.h5")
adata.write_h5ad(final_adata_path)
print(f"✅ Saved adata with embeddings to {final_adata_path}")

In [None]:
# plot everything
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams

# Font override applied only while this figure is drawn
custom_rc = {'font.family': 'Arial'}

# Embeddings to show, the obs columns passed as color_bys, and the basis suffixes
show_keys = methods
show_cols = ['dataset', 'nGene', 'protocol', 'batch', 'donor', 'cell_type']
basis_types = ['UMAP']

# Appearance
font_size = 10
point_size = .1
alpha = 0.8

# Layout: one column per embedding, 1.5 units of width per column
ncols = len(show_keys)
nrows = int(np.ceil(len(show_keys) / ncols))
figsize = (ncols * 1.5, 1.5)

plot_kwargs = dict(
    color_bys=show_cols,
    basis_types=basis_types,
    font_size=font_size,
    point_size=point_size,
    alpha=alpha,
    figsize=figsize,
    ncols=ncols,
    seed=seed,
    save_dir=save_dir,
    file_suffix=file_suffix,
    dpi=600,
    save_format='svg',
)

with plt.rc_context(rc=custom_rc):
    ccd.pl.plot_all_embeddings(adata, show_keys, **plot_kwargs)


In [None]:
# Larger figure for a single method: the 2D UMAP stored for the concord_knn embedding,
# drawn once per column in show_cols and saved as a PDF.
basis = 'concord_knn'
show_basis = f'{basis}_UMAP'
ccd.pl.plot_embedding(
    adata,
    show_basis,
    show_cols,
    figsize=(13, 8),
    dpi=600,
    ncols=3,
    font_size=10,
    point_size=1,
    legend_loc="on data",
    save_path=save_dir / f"{show_basis}_{file_suffix}.pdf",
)

### Separately run CONCORD

In [None]:
# Start again from the original lung atlas file for the standalone CONCORD run.
data_path = data_dir / 'Lung_atlas_public.h5ad'
adata = sc.read(data_path)

In [None]:
# Flag 3000 highly variable genes (cell_ranger flavor, with batch_key="batch"), then compute
# a 30-component PCA restricted to the flagged genes. The CONCORD arguments in the next
# cell refer to the PCA result as 'X_pca' (sampler_emb).
# NOTE(review): newer scanpy releases deprecate use_highly_variable in favour of mask_var — confirm against the installed version.
sc.pp.highly_variable_genes(adata, n_top_genes=3000, flavor="cell_ranger", batch_key="batch")
sc.tl.pca(adata, n_comps=30, use_highly_variable=True)

In [None]:
# The job-generation cells earlier in this notebook rebind `device` to the string 'auto',
# so resolve the torch device here instead of relying on whatever `device` currently holds.
concord_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Keyword arguments for ccd.Concord.
concord_args = {
    'adata': adata,
    # Genes flagged by sc.pp.highly_variable_genes in the previous cell.
    'input_feature': adata.var[adata.var['highly_variable']].index.tolist(),
    'latent_dim': 30,
    'batch_size': 256,
    'augmentation_mask_prob': 0.3,
    'clr_temperature': 0.5,
    'p_intra_knn': 0.0,
    'clr_beta': 0.2,
    'sampler_knn': 300,
    # PCA embedding computed by sc.tl.pca in the previous cell.
    'sampler_emb': 'X_pca',
    'p_intra_domain': 1.0,
    'n_epochs': 10,
    'domain_key': 'batch',
    'verbose': True,
    'seed': seed,
    'device': concord_device,
    'save_dir': save_dir,
    'load_data_into_memory': True,
}

# Key under which the next cell looks up the resulting embedding (source_key for UMAP).
output_key = 'Concord_knn'
cur_ccd = ccd.Concord(**concord_args)
cur_ccd.fit_transform(output_key=output_key)

In [None]:
# Compute a 2D UMAP from the CONCORD embedding stored under output_key, then plot it
# once per metadata column and save the figure as a PNG.
show_basis = f'{output_key}_UMAP'
ccd.ul.run_umap(
    adata,
    source_key=output_key,
    result_key=show_basis,
    n_components=2,
    n_neighbors=30,
    min_dist=0.1,
    metric='euclidean',
    random_state=seed,
)

show_cols = ['protocol', 'batch', 'cell_type']
ccd.pl.plot_embedding(
    adata,
    show_basis,
    show_cols,
    figsize=(10, 3),
    dpi=300,
    ncols=3,
    font_size=5,
    point_size=3,
    legend_loc='on data',
    save_path=save_dir / f"embeddings_{show_basis}_{file_suffix}.png",
)