### Load required packages

In [None]:
import os
import warnings
import scvi
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
%matplotlib inline
from helper_functions import *

# Remember the directory the notebook was started from.
pwd = os.getcwd()

# Default number of parallel jobs scanpy may use.
sc.settings.n_jobs = 32

# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides pandas/anndata warnings about indexing and
# deprecated behaviour, so problems in the cells below can go unnoticed.
warnings.filterwarnings("ignore")

### Load and prepare data

In [None]:
# Load the AnnData object from the trained-model h5ad file.
adata_mvi = sc.read_h5ad(filename="multivi_AD_3cohorts_RNA_ATAC_Multiome_v1.2_trained.h5ad")

# Donor / library columns kept from the AnnData obs. The same list is dropped
# from the RNA annotation table before merging, so these columns are not
# duplicated by the merge.
column_select = [
    'donor_name',
    'external_donor_name',
    'age',
    'sex',
    'medical_conditions',
    'method',
    'library_prep',
]

# Flatten obs into a new table with the index moved into a regular column
# (referenced as 'level_0' below), then keep only the identifier, modality and
# scVI bookkeeping columns plus the donor / library columns.
ad_obs = adata_mvi.obs.reset_index()
ad_obs_1 = ad_obs.loc[
    :,
    ['level_0', 'sample_id', 'batch_id', 'modality', '_scvi_batch', '_scvi_labels'] + column_select,
]

# Per-cell RNA / Multiome annotation table.
rna_mult_anno = pd.read_csv("/allen/programs/celltypes/workgroups/hct/SEA-AD/RNAseq/scANVI/output/MTG_AD/metadata/MTG_combined_metadata_full.2022-04-13.csv")

# Drop the columns already present in obs (and the table's own index column)
# so the merge below does not create suffixed duplicates.
rna_mult_anno_1 = rna_mult_anno.drop(columns=column_select + ["index_name"])

# Left merge: every obs row is kept; rows without a matching sample_id get
# missing values in the annotation columns.
ad_obs_2 = ad_obs_1.merge(rna_mult_anno_1, how='left', on='sample_id')

# Build the output barcode from the text before the first "-" in sample_id,
# followed by the literal suffix "-1".
b = ad_obs_2['sample_id'].str.split("-", expand=True).to_numpy()
ad_obs_2['barcodes_out'] = b[:, 0] + "-1"

def _library_fragment_paths(annotation_csv, fragments_suffix):
    """Read a library annotation CSV and return library_prep with its fragments path.

    The path is the row's ``ar_directory`` followed by ``ar_id`` (as text) and
    ``fragments_suffix``.
    """
    table = pd.read_csv(annotation_csv)
    table["path_to_fragments"] = table["ar_directory"] + table["ar_id"].astype(str) + fragments_suffix
    return table[["library_prep", "path_to_fragments"]]


# ATAC-only libraries.
atac_only_anno = _library_fragment_paths(
    "/allen/programs/celltypes/workgroups/rnaseqanalysis/ATAC-Seq/Arrow/ATAC_AD_Center_Grant_complete/ATAC_AD_Center_Grant_complete.csv",
    "/outs/fragments.tsv.gz",
)

# Multiome libraries (different fragments file name).
mult_only_anno = _library_fragment_paths(
    "/allen/programs/celltypes/workgroups/rnaseqanalysis/10x/ARC/Human/ARC_AD_Center_Grant_MTG/ARC_AD_Center_Grant_MTG_MTX-2036_human.csv",
    "/outs/atac_fragments.tsv.gz",
)

# Stack both tables and attach the fragments path to each cell by library_prep.
atac_anno = pd.concat([mult_only_anno, atac_only_anno])
ad_obs_3 = ad_obs_2.merge(atac_anno, how='left', on='library_prep')

# Put the original cell identifiers (held in 'level_0') back as an unnamed
# index and install the enriched table as the AnnData obs.
# NOTE(review): the left merges that built ad_obs_3 keep exactly one row per
# cell only if sample_id / library_prep are unique in the right-hand tables;
# confirm, since duplicated keys would add rows here.
ad_obs_3 = ad_obs_3.set_index("level_0").rename_axis(None)
adata_mvi.obs = ad_obs_3

### Calculate QC metrics and perform label transfer

In [None]:
# Cells whose for_analysis value equals False get the first label; every other
# cell (including, presumably, cells with no for_analysis value at all, which
# do not compare equal to False) gets the second.
is_rna_rejected = adata_mvi.obs["for_analysis"].eq(False)
adata_mvi.obs["Doublet_or_LowQuality"] = np.where(
    is_rna_rejected,
    "RNA doublet or LQ cells",
    "RNA high quality cells or ATAC cells",
)


In [None]:
# One QC ratio per row of the neighbour connectivity matrix.
# compute_cell_quality is not defined in this notebook section (presumably it
# comes from helper_functions).
cells = adata_mvi.uns['neighbors']['connectivities'].shape[0]
qc_ratios = np.zeros(cells)
for i in range(cells):
    qc_ratios[i] = compute_cell_quality(adata_mvi, i)

# Store the per-cell ratio and show it on the UMAP.
adata_mvi.obs["rna_neighbors_qc_ratio"] = qc_ratios
sc.pl.umap(adata_mvi, color="rna_neighbors_qc_ratio")

# Modality of every cell; per the original note the values are paired,
# accessibility or expression.
label = adata_mvi.obs['modality']

# Sum of squared modality fractions. It is computed here but not used by any
# other statement in this section.
c = np.unique(label, return_counts=True)[1]
modality_fractions = c / c.sum()
theoretic_score = (modality_fractions ** 2).sum()

# One mixing ratio per cell. compute_cell_mixing is not defined in this
# notebook section (presumably it comes from helper_functions).
modality_mixing = np.zeros(cells)
for i in range(cells):
    modality_mixing[i] = compute_cell_mixing(adata_mvi, i)

# Store the per-cell ratio and show it on the UMAP.
adata_mvi.obs["modality_mixing_ratio"] = modality_mixing
sc.pl.umap(adata_mvi, color='modality_mixing_ratio')

# Per-cell label purity and transferred label.
# compute_label_purity is not defined in this notebook section (presumably it
# comes from helper_functions); it returns a (ratio, label) pair per cell.
cells = adata_mvi.shape[0]
purity_ratios = np.zeros(cells)

# copy=True is required: without it to_numpy() may hand back a view of the
# obs column, and the element-wise writes below would then overwrite
# "subclass_scANVI" itself while the loop is still running (or fail on a
# read-only view when pandas copy-on-write is active).
labels = adata_mvi.obs["subclass_scANVI"].to_numpy(copy=True)
for i in np.arange(cells):
    ratio, label_ = compute_label_purity(adata_mvi, i)
    purity_ratios[i] = ratio
    labels[i] = label_

# Store both results and show them on the UMAP.
adata_mvi.obs["subclass_purity_ratio"] = purity_ratios
adata_mvi.obs["label_transfer"] = labels
sc.pl.umap(adata_mvi, color='subclass_purity_ratio')

sc.pl.umap(adata_mvi, color='label_transfer', legend_loc="on data")

In [None]:
# Start the refined QC ratio from the original per-cell ratio; entries that are
# missing are filled from neighbouring cells further down in this cell.
adata_mvi.obs["rna_neighbors_qc_ratio_new"] = adata_mvi.obs["rna_neighbors_qc_ratio"]

def compute_cell_quality_all_neighbors(adata_mvi, cell_idx):
    """Average the refined QC ratio over the graph neighbours of one cell.

    Neighbours are the columns with a strictly positive entry in row
    ``cell_idx`` of ``adata_mvi.uns['neighbors']['connectivities']``.

    Parameters
    ----------
    adata_mvi
        Object exposing ``uns['neighbors']['connectivities']`` (sparse matrix)
        and an ``obs`` table with the columns ``rna_neighbors_qc_ratio_new``
        and ``modality``.
    cell_idx
        Row position of the cell in the connectivity matrix.

    Returns
    -------
    Mean of ``rna_neighbors_qc_ratio_new`` over the neighbours, ignoring
    missing values; NaN when no neighbour has a value.
    """
    graph_row = adata_mvi.uns['neighbors']['connectivities'][cell_idx]
    neighbor_positions = np.nonzero(graph_row.todense() > 0)[1]
    neighbor_obs = adata_mvi.obs[["rna_neighbors_qc_ratio_new", "modality"]].iloc[neighbor_positions]
    return neighbor_obs["rna_neighbors_qc_ratio_new"].mean(skipna=True)

# Fill cells that have no QC ratio with the mean ratio of their neighbours.
#
# The positions come from np.where and are therefore row positions, while obs
# is indexed by cell identifier. The write goes through a single positional
# .iloc call on obs itself: chained assignment (obs[col][i] = ...) may write to
# a temporary and leave obs unchanged.
#
# Fills happen in place and in order, so a later cell's average can include
# values that were filled earlier in this loop.
qc_new_col = adata_mvi.obs.columns.get_loc("rna_neighbors_qc_ratio_new")
for i in np.where(adata_mvi.obs["rna_neighbors_qc_ratio_new"].isnull())[0]:
    adata_mvi.obs.iloc[i, qc_new_col] = compute_cell_quality_all_neighbors(adata_mvi, i)

# 1.0 for cells whose ratio is still missing after the fill, 0.0 otherwise.
adata_mvi.obs["all_neighbors_NA"] = adata_mvi.obs["rna_neighbors_qc_ratio_new"].isnull() * 1.0

def compute_label_purity_slot(adata_mvi, cell_idx, slot="subclass_scANVI"):
    """Return the majority label among a cell's graph neighbours and its share.

    Neighbours are the columns with a strictly positive entry in row
    ``cell_idx`` of ``adata_mvi.uns['neighbors']['connectivities']``.
    Neighbours whose ``slot`` value is missing are ignored.

    Parameters
    ----------
    adata_mvi
        Object exposing ``uns['neighbors']['connectivities']`` (sparse matrix)
        and an ``obs`` table containing the column ``slot``.
    cell_idx
        Row position of the cell in the connectivity matrix.
    slot
        Name of the obs column holding the labels.

    Returns
    -------
    (ratio, label)
        ``label`` is the most frequent neighbour label (on ties, the first one
        in sorted order) and ``ratio`` its fraction among the labelled
        neighbours. ``(0, nan)`` when no neighbour carries a label.
    """
    graph_row = adata_mvi.uns['neighbors']['connectivities'][cell_idx]
    neighbor_positions = np.nonzero(graph_row.todense() > 0)[1]
    neighbor_labels = adata_mvi.obs[[slot]].iloc[neighbor_positions]
    neighbor_labels = neighbor_labels.loc[neighbor_labels[slot].notnull()]

    values, counts = np.unique(neighbor_labels, return_counts=True)
    if np.size(counts) == 0:
        return 0, np.nan

    top = np.argmax(counts)
    return counts[top] / counts.sum(), values[top]

# First pass: majority "subclass_scANVI" label among each cell's neighbours.
cells = adata_mvi.shape[0]
purity_ratios = np.zeros(cells)

# copy=True is required: without it to_numpy() may hand back a view of the
# obs column, and the element-wise writes below would then overwrite
# "subclass_scANVI" while compute_label_purity_slot is still reading it (or
# fail on a read-only view when pandas copy-on-write is active).
labels = adata_mvi.obs["subclass_scANVI"].to_numpy(copy=True)
for i in np.arange(cells):
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="subclass_scANVI")
    purity_ratios[i] = ratio
    labels[i] = label_

adata_mvi.obs["subclass_purity_ratio"] = purity_ratios
adata_mvi.obs["label_transfer"] = labels

# Second pass: cells still without a label take the majority of their
# neighbours' transferred labels.
#
# The positions come from np.where and are row positions, while obs is indexed
# by cell identifier, so the writes must be positional. A chained
# obs[col].loc[i] = ... with an integer that is not an index label can enlarge
# a temporary Series instead of updating obs, so every write below goes through
# a single .iloc call on obs.
#
# Writes happen in place and in order, so a later cell can pick up a label that
# was filled earlier in this loop.
purity_col = adata_mvi.obs.columns.get_loc("subclass_purity_ratio")
label_col = adata_mvi.obs.columns.get_loc("label_transfer")
for i in np.where(adata_mvi.obs["label_transfer"].isnull())[0]:
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="label_transfer")
    adata_mvi.obs.iloc[i, purity_col] = ratio
    adata_mvi.obs.iloc[i, label_col] = label_


In [None]:
from joblib import parallel_backend

# Use the same job count for scanpy's own setting and for the joblib backend.
n_parallel_jobs = 32
sc.settings.n_jobs = n_parallel_jobs

# Run Leiden clustering under a joblib threading backend and store the cluster
# assignment in obs["leiden_1.0"].
with parallel_backend('threading', n_jobs=n_parallel_jobs):
    sc.tl.leiden(adata_mvi, key_added="leiden_1.0")

In [None]:
# Annotation columns that are converted to plain strings before writing.
# astype(str) also turns missing values into text, so these columns no longer
# report missing entries afterwards.
convert_columns = [
    'class',
    'neighborhood',
    'subclass',
    'subclass_color',
    'cluster',
    'cluster_color',
    'class_held',
    'subclass_held',
    'cluster_held',
    'supertype',
    'for_analysis',
]

adata_mvi.obs["experiment_component_failed"] = adata_mvi.obs["experiment_component_failed"].astype(str)
adata_mvi.obs[convert_columns] = adata_mvi.obs[convert_columns].astype(str)

# Save the annotated object.
adata_mvi.write("multivi_AD_3cohorts_RNA_ATAC_Multiome_v2.0_annotation_updated.h5ad")

### Filter low-quality cells

In [None]:
# Leiden clusters that are excluded entirely.
bad_clusters = ['33', '22', '21', '6', '29', '10', '25', '36', '39', '47', '32']

# Keep cells that are outside the excluded clusters AND whose refined QC ratio
# is at most 0.2. Cells with a missing ratio fail the comparison and are
# dropped as well.
cut_off = (~adata_mvi.obs['leiden_1.0'].isin(bad_clusters)) & (adata_mvi.obs["rna_neighbors_qc_ratio_new"].to_numpy() <= 0.2)
adata = adata_mvi[cut_off].copy()

# Start the refined columns from the values computed on the full object.
adata.obs["subclass_purity_ratio_new"] = adata.obs["subclass_purity_ratio"]
adata.obs["label_transfer_new"] = adata.obs["label_transfer"]

# Give cells that still have no label the majority label of their neighbours.
#
# Fixes relative to the earlier version of this loop:
#   * the helper is called on `adata`; the previous argument `ad` is not
#     defined anywhere in this notebook section and raised NameError;
#   * writes use one positional .iloc call on obs, because `i` is a row
#     position while obs is indexed by cell identifier, and a chained
#     obs[col].loc[i] = ... can modify a temporary instead of obs.
#
# NOTE(review): compute_label_purity_slot reads
# adata.uns['neighbors']['connectivities']; confirm that on the filtered object
# this matrix has one row per remaining cell rather than the full graph.
purity_new_col = adata.obs.columns.get_loc("subclass_purity_ratio_new")
label_new_col = adata.obs.columns.get_loc("label_transfer_new")
for i in np.where(adata.obs["label_transfer_new"].isnull())[0]:
    ratio, label_ = compute_label_purity_slot(adata, i, slot="label_transfer_new")
    adata.obs.iloc[i, purity_new_col] = ratio
    adata.obs.iloc[i, label_new_col] = label_

# Save the filtered, refined object.
adata.write("multivi_AD_3cohorts_RNA_ATAC_Multiome_v2.0_annotation_updated_refined_1.0.h5ad")