### Load needed libraries

In [None]:
import importlib
import iterative_scANVI
importlib.reload(iterative_scANVI)
from iterative_scANVI import *
%matplotlib inline

# Remember the directory the notebook was launched from; the "input" and
# "output" paths used in later cells are joined onto it.
pwd = os.getcwd()

# Parallelism setting on `sc` (presumably scanpy, provided by the star import above).
sc.settings.n_jobs = 32

### Self projection with reference data

In [None]:
## Load the MTG reference dataset
dataset = "reference"
technology = "singleomeCR6"
region = "MTG"
date = "2022-04-08"

mtg_reference_path = os.path.join(pwd, "input", f"{region}_{dataset}_{technology}.{date}.h5ad")
adata_ref_mtg = sc.read_h5ad(filename=mtg_reference_path)

# Define MTG supertypes based on MTG self-projection.
# Every cluster named in this list is relabeled "Unknown" in the "supertype" column;
# all other cells keep their cluster name as their supertype.
low_confidence = [
    "L2/3 IT_4", "L2/3 IT_9", "L2/3 IT_11",
    "L5 IT_4",
    "L5/6 NP_5",
    "Micro-PVM_3",
    "Pvalb_4", "Pvalb_11",
    "Sncg_7",
    "Sst_6", "Sst_8", "Sst_14", "Sst_15", "Sst_16", "Sst_17", "Sst_18", "Sst_21", "Sst_24", "Sst_26",
    "Vip_3", "Vip_7", "Vip_8", "Vip_10", "Vip_17", "Vip_20", "Vip_22",
]

is_low_confidence = adata_ref_mtg.obs["cluster"].isin(low_confidence)

# Work on a plain object-dtype copy so that "Unknown" can be written freely,
# then convert the finished column back to a categorical.
mtg_supertype = adata_ref_mtg.obs["cluster"].copy().astype("object")
mtg_supertype.loc[is_low_confidence] = "Unknown"
adata_ref_mtg.obs["supertype"] = mtg_supertype.astype("category")

## Load the A9 reference dataset
dataset = "reference"
technology = "singleomeCR6"
region = "A9"
date = "2022-08-19"

a9_reference_path = os.path.join(pwd, "input", f"{region}_{dataset}_{technology}.{date}.h5ad")
adata_ref = sc.read_h5ad(filename=a9_reference_path)

# The A9 cells start with every label level set to "Unknown".
for label_key in ("class", "subclass", "supertype"):
    adata_ref.obs[label_key] = "Unknown"

# Run the label transfer with the A9 reference as the first argument and the
# MTG reference as the second.
iterative_scANVI(
    adata_ref,
    adata_ref_mtg,
    output_dir=os.path.join(pwd, "output", f"{region}_{dataset}_{technology}_MTG_liftover"),
    labels_keys=["class", "subclass", "supertype"],
    categorical_covariate_keys=["library_prep"],
    continuous_covariate_keys=["nFeature_RNA"],
)

### Save subclass AnnData objects for reference

In [None]:
# Output folder of the A9 <- MTG liftover run (same folder the projection step wrote to).
liftover_dir = os.path.join(pwd, "output", f"{region}_{dataset}_{technology}_MTG_liftover")

# Covariates handed to save_anndata via model_args.
liftover_model_args = dict(
    layer="UMIs",
    batch_key=None,
    categorical_covariate_keys=["library_prep"],
    continuous_covariate_keys=["nFeature_RNA"],
)

save_anndata(
    adata_query=adata_ref,
    adata_ref=adata_ref_mtg,
    split_key="subclass_scANVI",
    groupby="supertype",
    output_dir=liftover_dir,
    date="2023-06-08",
    diagnostic_plots=["sex", "age_at_death", "donor_name", "roi", "doublet_score", "nFeature_RNA", "fraction_mito"],
    model_args=liftover_model_args,
    n_cores=32,
    cluster_cells=False,
)

### Concatenate A9 and MTG reference and lift over labels onto A9

In [None]:
# Stack the A9 and MTG references into one object, carry over the MTG
# "Great Apes Metadata" entry, and release the standalone MTG object.
adata_ref = adata_ref.concatenate(adata_ref_mtg, index_unique=None)
adata_ref.uns["Great Apes Metadata"] = adata_ref_mtg.uns["Great Apes Metadata"]
del adata_ref_mtg

output_dir = os.path.join(pwd, "output", f"{region}_{dataset}_{technology}_MTG_liftover")
results_file = "iterative_scANVI_results.2023-06-08.csv"
scANVI_results = pd.read_csv(os.path.join(output_dir, results_file), index_col=0)

# Rows whose original "class" was "Unknown" receive the predicted label at
# every level; the same row selection is used for all three levels.
was_unknown = scANVI_results["class"] == "Unknown"
unknown_cells = scANVI_results[was_unknown].index
for label_key in ("class", "subclass", "supertype"):
    predicted = scANVI_results.loc[was_unknown, label_key + "_scANVI"].copy()
    adata_ref.obs.loc[unknown_cells, label_key] = predicted

### Run on singleome and multiome combined data

In [None]:
dataset = "AD"
region = "A9"
date = "2022-08-19"

# Combined query object, read from the "input" folder.
query_path = os.path.join(pwd, "input", f"{region}_{dataset}_combined.{date}.h5ad")
adata_query = sc.read_h5ad(filename=query_path)

# Transfer class / subclass / supertype labels from the merged reference onto the query.
iterative_scANVI(
    adata_query,
    adata_ref,
    output_dir=os.path.join(pwd, "output", f"{region}_{dataset}"),
    labels_keys=["class", "subclass", "supertype"],
    categorical_covariate_keys=["library_prep"],
    continuous_covariate_keys=["nFeature_RNA"],
)

#### Save subclass AnnData objects for manual curation

In [None]:
# Write one AnnData object per predicted subclass for the AD query data.
# The query object created by the projection step is `adata_query`; it is
# passed with the same `adata_query=` keyword used by the reference
# save_anndata call earlier in this notebook. (The previous `adata=adata`
# referenced a name that does not exist yet at this point of a top-to-bottom run.)
save_anndata(
    adata_query=adata_query,
    adata_ref=adata_ref,
    split_key="subclass_scANVI",
    groupby="supertype",
    output_dir=os.path.join(pwd, "output", region + "_" + dataset),
    date = "2023-06-14",
    model_args={
        "layer": "UMIs",
        "batch_key": None,
        "categorical_covariate_keys": ["library_prep"],
        "continuous_covariate_keys": ["nFeature_RNA"]
    },
    **{
        "n_cores": 32,
        "cluster_cells": True,
    }
)

# These files were manually curated in cellxgene

#### Leiden clustering refinement

In [None]:
region = "A9"
dataset = "AD"

# Non-neuronal subclasses; each one gets a {"type": "glia"} entry, in this order.
glia_subclasses = ["Astro", "Endo", "Micro-PVM", "Oligo", "OPC", "VLMC"]
groups = {subclass: {"type": "glia"} for subclass in glia_subclasses}

# Micro-PVM additionally carries an nFeature_RNA cutoff
# (presumably "greater than 1000" -- confirm against clean_taxonomies).
groups["Micro-PVM"]["cutoffs"] = {"nFeature_RNA": (1000, "gt")}

clean_taxonomies(
    groups=groups,
    splitby="supertype_scANVI",
    reference_key="reference_cell",
    object_dir=os.path.join(pwd, "output", region + "_" + dataset, "objects"),
    layer="UMIs",
    categorical_covariate_keys=["donor_name", "sex", "ch_race___1", "method"],
    continuous_covariate_keys=["nFeature_RNA", "age_at_death", "fraction_mito"],
    diagnostic_plots=["ch_cognitivestatus_binary", "adneurochange", "reference_cell", "sex", "roi", "age_at_death", "donor_name", "method", "doublet_score", "nFeature_RNA", "fraction_mito"],
    use_markers=True,
    refine_supertypes=False,
)

In [None]:
region = "A9"
dataset = "AD"

# Neuronal subclasses, listed in the order they are handed to clean_taxonomies.
excitatory_subclasses = [
    "L2/3 IT", "L4 IT", "L5 ET", "L5 IT", "L5/6 NP",
    "L6 CT", "L6 IT", "L6 IT Car3", "L6b",
]
inhibitory_subclasses = [
    "Lamp5", "Lamp5_Lhx6", "Pax6", "Pvalb", "Chandelier",
    "Sncg", "Sst", "Sst Chodl", "Vip",
]

groups = {subclass: {"type": "excitatory"} for subclass in excitatory_subclasses}
groups.update({subclass: {"type": "inhibitory"} for subclass in inhibitory_subclasses})

# L5 ET additionally carries a doublet_score cutoff
# (presumably "less than 0.5" -- confirm against clean_taxonomies).
groups["L5 ET"]["cutoffs"] = {"doublet_score": (0.5, "lt")}

clean_taxonomies(
    groups=groups,
    splitby="supertype_scANVI",
    reference_key="reference_cell",
    object_dir=os.path.join(pwd, "output", region + "_" + dataset, "objects"),
    layer="UMIs",
    categorical_covariate_keys=["donor_name", "sex", "ch_race___1", "method"],
    continuous_covariate_keys=["nFeature_RNA", "age_at_death", "fraction_mito"],
    diagnostic_plots=["ch_cognitivestatus_binary", "adneurochange", "reference_cell", "sex", "roi", "age_at_death", "donor_name", "method", "doublet_score", "nFeature_RNA", "fraction_mito"],
    use_markers=True,
    refine_supertypes=False,
)

### Expand the non-neuronal taxonomy and pull final QC together

In [None]:
region = "A9"
dataset = "AD"
output_dir = os.path.join(pwd, "output", region + "_" + dataset)
corrected_label = "supertype_scANVI"
results_file = "iterative_scANVI_results.2023-06-14.csv"
split_key = "subclass_scANVI"

scANVI_results = pd.read_csv(os.path.join(output_dir, results_file), index_col=0)

objects_dir = os.path.join(output_dir, "objects")

# Folder names that correspond to a split_key group. A folder name cannot
# contain "/", and other cells in this notebook address per-subclass folders
# with replace("/", " "), so both spellings are accepted here.
split_values = scANVI_results[split_key].astype("category").cat.categories
known_folders = set(split_values) | {str(value).replace("/", " ") for value in split_values}

# Collect every per-group corrections table first and concatenate once at the end.
per_group_labels = []
for i in sorted(os.listdir(objects_dir)):
    if i.startswith(".") or not os.path.isdir(os.path.join(objects_dir, i)):
        continue

    # The previous test `i in <categories> is False` was a chained comparison
    # that always evaluated to False, so nothing was ever skipped.
    if i not in known_folders:
        print(str(datetime.now()) + " -- skipping " + i + " (not a " + split_key + " group)")
        continue

    print(str(datetime.now()) + " -- " + i)
    per_group_labels.append(pd.read_csv(os.path.join(objects_dir, i, "corrections.csv"), index_col=0))

if per_group_labels:
    corrected = pd.concat(per_group_labels)
else:
    corrected = pd.DataFrame(columns=[corrected_label])

# Store the curated labels under "<corrected_label>_leiden" and drop the original
# column name so it does not collide with the column already in scANVI_results.
corrected[corrected_label + "_leiden"] = corrected[corrected_label].copy()
corrected = corrected.drop(corrected_label, axis=1)

# .loc raises a KeyError if any cell in scANVI_results has no corrections row.
scANVI_results = pd.concat([scANVI_results, corrected.loc[scANVI_results.index, :]], axis=1)
scANVI_results.to_csv(os.path.join(output_dir, "iterative_scANVI_results_refined." + str(datetime.date(datetime.now())) + ".csv"))

### Add in finalized scANVI results to the AnnData

In [None]:
dataset = "AD"
region = "A9"

# NOTE(review): this path is relative to the current working directory, unlike
# the other inputs in this notebook -- confirm that is intended.
adata = sc.read_h5ad("A9_combined.2023-06-09.h5ad")

output_dir = os.path.join(pwd, "output", region + "_" + dataset)
results_file = "iterative_scANVI_results_refined.2023-07-10.csv"

scANVI_results = pd.read_csv(os.path.join(output_dir, results_file), index_col=0)

# Keep only the result columns that adata.obs does not already have.
new_columns = np.setdiff1d(scANVI_results.columns, adata.obs.columns)
scANVI_results = scANVI_results.loc[:, new_columns]

# When the row counts disagree, restrict the AnnData to the cells present in both tables.
n_result_cells = scANVI_results.shape[0]
n_adata_cells = adata.shape[0]
if n_result_cells != n_adata_cells:
    common_cells = np.intersect1d(adata.obs_names, scANVI_results.index)
    adata = adata[common_cells].copy()
    print("WARNING: Mismatch between cells in scANVI results and merged AnnData object, using " + str(len(common_cells)) + " common cells. Was this expected?")

# Append the result columns, ordered to match the AnnData's cells.
results_for_adata = scANVI_results.loc[adata.obs_names, :]
adata.obs = pd.concat([adata.obs, results_for_adata], axis=1)

### Write out final AnnData object/CSVs that include all nuclei

In [None]:
# A nucleus is kept for analysis when its refined label
#   * is one of the categories of the "supertype" column, or
#   * contains "-SEAAD", or
#   * contains "Endo_".
# This is the De Morgan form of the previous ~(not-in & ~contains & ~contains) mask,
# written with vectorized pandas operations instead of `list & Series`, which
# pandas has deprecated for dtype-less sequences.
refined_labels = adata.obs["supertype_scANVI_leiden"]

in_reference_taxonomy = refined_labels.isin(adata.obs["supertype"].cat.categories)
# regex=False: both patterns are literal substrings; na=False: missing labels never match.
is_seaad_label = refined_labels.str.contains("-SEAAD", regex=False, na=False)
is_endo_label = refined_labels.str.contains("Endo_", regex=False, na=False)

to_keep = in_reference_taxonomy | is_seaad_label | is_endo_label
adata.obs["for_analysis"] = to_keep

adata.write(os.path.join(pwd, "output", region + "_" + dataset, "raw." + str(datetime.date(datetime.now())) + ".h5ad"))

### Remove low quality cells, train model for final representation and write out

In [None]:
# Keep only the nuclei flagged for analysis. The flag is read from
# adata.obs["for_analysis"] (set in the previous cell) rather than from the
# free-standing `to_keep` mask, so re-running this cell on the already
# filtered object is a no-op instead of a length-mismatch error.
adata = adata[adata.obs["for_analysis"].to_numpy()].copy()

final_model_args = {
    "n_layers": 2,
    "n_latent": 20,
    "dispersion": "gene-label"
}
final_model_dir = os.path.join(pwd, "output", region + "_" + dataset, "final_model")

print("Setting up AnnData...")
scvi.model.SCVI.setup_anndata(
    adata,
    layer="UMIs",
    categorical_covariate_keys=["library_prep"],
    continuous_covariate_keys=["nFeature_RNA"],
    labels_key="supertype_scANVI_leiden"
)

# Train once and cache on disk; later runs load the saved model instead.
if not os.path.exists(final_model_dir):
    print("Creating model...")
    final_model = scvi.model.SCVI(adata, **final_model_args)
    print("Training model...")
    final_model.train(max_epochs=200, early_stopping=True)
    final_model.save(final_model_dir)
else:
    print("Loading model...")
    final_model = scvi.model.SCVI.load(final_model_dir, adata)

print("Calculating latent representation and UMAP...")
with parallel_backend('threading', n_jobs=32):
    adata.obsm["X_scVI"] = final_model.get_latent_representation()
    sc.pp.neighbors(adata, use_rep="X_scVI")
    sc.tl.umap(adata)

# Label -> color lookups read from the color table in the "input" folder.
colors = pd.read_csv(os.path.join(pwd, "input", "cluster_order_and_colors.csv"))
subclass_colors = colors.loc[:, ["subclass_label", "subclass_color"]].drop_duplicates()
subclass_colors.index = subclass_colors["subclass_label"].copy()
subclass_colors = subclass_colors["subclass_color"].to_dict()

supertype_colors = colors.loc[:, ["cluster_label", "cluster_color"]]
supertype_colors.index = supertype_colors["cluster_label"].copy()
supertype_colors = supertype_colors["cluster_color"].to_dict()

plt.rcParams["figure.figsize"] = (10, 10)
sc.pl.umap(adata, color="subclass_scANVI", palette=subclass_colors, legend_loc="on data", frameon=False, size=3)
sc.pl.umap(adata, color="supertype_scANVI_leiden", palette=supertype_colors, legend_loc="on data", frameon=False, size=3, legend_fontoutline=3)

adata.write(os.path.join(pwd, "output", region + "_" + dataset, "final." + str(datetime.date(datetime.now())) + ".h5ad"))

### Write out subclass-specific latent spaces and UMAP coordinates

In [None]:
# For every subclass: load its saved scVI model, embed that subclass's nuclei,
# compute a UMAP, and store the coordinates next to the model.
for split_value in adata.obs["subclass_scANVI"].cat.categories:

    # Folder of this subclass's model; "/" in the subclass name is replaced by
    # a space because it cannot appear inside a folder name.
    model_dir = os.path.join(pwd, "output", region + "_" + dataset, "objects", "models", split_value.replace("/", " "), "scVI_model")

    # Genes listed in the model folder's var_names.csv; the subset is restricted to them.
    markers = pd.read_csv(os.path.join(model_dir, "var_names.csv"), header=None)
    markers = markers[0].to_list()

    sub = adata[(adata.obs["subclass_scANVI"] == split_value), markers].copy()

    model = scvi.model.SCVI.load(model_dir, sub)

    with parallel_backend('threading', n_jobs=32):
        sub.obsm["X_scVI"] = model.get_latent_representation()
        sc.pp.neighbors(sub, use_rep="X_scVI")
        sc.tl.umap(sub)

    # All three outputs go inside model_dir. Previously a missing comma fused
    # "scVI_model" with the file name for the UMAP and obs_names outputs.
    np.save(os.path.join(model_dir, "X_scVI.npy"), sub.obsm["X_scVI"])
    np.save(os.path.join(model_dir, "X_umap.npy"), sub.obsm["X_umap"])
    pd.DataFrame(sub.obs_names).to_csv(os.path.join(model_dir, "obs_names.csv"), index=False)