### Load required packages

In [None]:
import os
import warnings
import scvi
import anndata
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
%matplotlib inline

# Default number of parallel jobs scanpy may use.
sc.settings.n_jobs = 32
# NOTE(review): this silences every warning category for the whole session,
# including deprecation and pandas chained-assignment warnings.
warnings.filterwarnings("ignore")

# Working directory at start-up (not referenced again in the visible cells).
pwd = os.getcwd()


### Prepare singleome ATAC dataset 

In [None]:
# Annotation table for the integrated RNA / ATAC / multiome dataset.
mvi_anno = pd.read_csv("/allen/programs/celltypes/workgroups/hct/SEA-AD/Integration/multivi_AD_3cohorts_RNA_ATAC_Multiome_v2.0_annotation_updated_refined_1.0_obs_meta.csv")

# Library-level view of the annotation. `.copy()` makes an independent frame,
# so the column edits below never operate on a slice of `mvi_anno`.
df = mvi_anno[['sample_id', 'path_to_fragments', 'donor_name', 'library_prep', 'modality']].copy()
# `ar_id` is the last "-"-separated token of `sample_id`.
df['ar_id'] = df['sample_id'].str.split("-").str[-1]
# Without `sample_id` the rows collapse to one per distinct combination.
df = df.drop(columns=['sample_id']).drop_duplicates()
df = df[df['modality'] == 'accessibility']

# Attach the filter-file mapping; its column '1' holds the fragment path and
# column '0' the value embedded in the filtered file name below.
mapping = pd.read_csv("/allen/programs/celltypes/workgroups/hct/SEA-AD/Integration/subclasses_files/filter_files/mapping_filter_atac_libraries.csv")
df_new = pd.merge(df, mapping, how='left', left_on='path_to_fragments', right_on='1')
df_new = df_new.drop(columns=['Unnamed: 0', '1'])
df_new['prefix_frags'] = df_new['library_prep'] + '-' + df_new['ar_id']
# NOTE(review): rows without a match in `mapping` would carry a missing value
# in '0' and yield a bogus file name here — confirm every library is mapped.
df_new['frags'] = 'fragments.tsv.gz_filter' + df_new['0'].astype(str) + 'Sst.tsv'

# Filtered-fragment file name -> "<library_prep>-<ar_id>".
name_change = dict(zip(df_new['frags'], df_new['prefix_frags']))


In [None]:
# Preprocess and load the ATAC datasets from singleome (file processed by Mariano).
def create_atac_anndata(base_path, mtx_ext, peak_ext, barcode_ext, name_change_dict, dict_key):
    """Assemble an AnnData object for one ATAC library from its three files.

    Every input file is located at ``<base_path>/<dict_key><extension>``.

    Parameters
    ----------
    base_path
        Directory containing the library's files.
    mtx_ext, peak_ext, barcode_ext
        Suffixes appended to ``dict_key`` for the count matrix (read with
        ``anndata.read_mtx``), the tab-separated peak table and the
        "-"-separated barcode table.
    name_change_dict
        Maps ``dict_key`` to a string whose first "-"-separated token is the
        library prep name.
    dict_key
        File-name prefix identifying the library.

    Returns
    -------
    A copy of the AnnData whose ``obs`` is indexed by
    ``"<barcode>-<library_prep>"`` and whose ``var`` is indexed by
    ``"chr:start-end"`` and tagged with ``modality == "Peaks"``.
    """

    def _library_file(extension):
        # Full path of one of this library's input files.
        return os.path.join(base_path, dict_key + extension)

    adata = anndata.read_mtx(_library_file(mtx_ext))

    # Peak coordinates: the first three columns are chromosome, start and end.
    peak_table = pd.read_csv(
        _library_file(peak_ext), sep="\t", header=None, index_col=None
    )
    peak_table = peak_table.rename(columns={0: "chr", 1: "start", 2: "end"})
    region_names = (
        peak_table["chr"].astype(str)
        + ":"
        + peak_table["start"].astype(str)
        + "-"
        + peak_table["end"].astype(str)
    )
    peak_table = peak_table.set_index(region_names)
    peak_table.index = peak_table.index.astype(str)

    # Barcode file: each line is split on "-" into barcode and batch id.
    barcode_table = pd.read_csv(
        _library_file(barcode_ext), sep="-", header=None, index_col=None
    )
    barcode_table = barcode_table.rename(columns={0: "barcode", 1: "batch_id"})

    # `sample_id` is derived from the raw barcode, so it is computed before
    # the barcode column is rewritten with the library prep suffix.
    library_name = name_change_dict[dict_key]
    raw_barcodes = barcode_table["barcode"]
    barcode_table["library_prep"] = library_name.split("-")[0]
    barcode_table["sample_id"] = raw_barcodes + "-" + library_name
    barcode_table["barcode"] = raw_barcodes + "-" + barcode_table["library_prep"]
    barcode_table = barcode_table.set_index("barcode")
    barcode_table.index = barcode_table.index.astype(str)

    adata.obs = barcode_table
    adata.var = peak_table
    adata.var["modality"] = "Peaks"
    return adata.copy()

In [None]:
# Directory and file-name suffixes of the per-library Sst ATAC inputs
# (count matrix, peak table, barcode list) passed to create_atac_anndata.
base_path = "/allen/programs/celltypes/workgroups/hct/SEA-AD/Integration/subclasses_files/filter_files/Sst/"
mtx_ext = "_SST_ADNC.concat.merged_counts.mtx"
peak_ext = "_SST_ADNC.concat.merged_peaks.bed"
barcode_ext = "_SST_ADNC.concat.merged_barcodes.tsv"

In [None]:
# List the barcode files found in `base_path`; each one identifies a library.
# Pattern adapted from
# https://stackoverflow.com/questions/2225564/get-a-filtered-list-of-files-in-a-directory
included_extensions = ["_SST_ADNC.concat.merged_barcodes.tsv"]
wanted_suffixes = tuple(included_extensions)
file_names = [entry for entry in os.listdir(base_path) if entry.endswith(wanted_suffixes)]

In [None]:
# Load one AnnData per library; the library prefix is the part of the barcode
# file name that precedes "_SST_".
adata_atac_list = [
    create_atac_anndata(
        base_path,
        mtx_ext,
        peak_ext,
        barcode_ext,
        name_change,
        barcode_file.split("_SST_")[0],
    )
    for barcode_file in file_names
]

# Stack all libraries into a single object.
adata_atac_only = anndata.concat(adata_atac_list, merge="same")

In [None]:
# Save the concatenated, still unfiltered ATAC object (relative path).
adata_atac_only.write(filename="atac_Sst_12.08.2022.h5ad")

In [None]:
# Filter early to reduce memory usage: a peak is kept only if it is detected
# in at least 1% of the cells. The shape is printed before and after.
print(adata_atac_only.shape)
min_cells = int(adata_atac_only.n_obs * 0.01)
sc.pp.filter_genes(adata_atac_only, min_cells=min_cells)
print(adata_atac_only.shape)

# Names of the peaks that survived the filter.
peaks = adata_atac_only.var_names

### Prepare singleome RNA datasets

In [None]:
# Load the single-modality RNA object for Sst.
rna_h5ad_path = '/allen/programs/celltypes/workgroups/hct/SEA-AD/Integration/multivi_subclasses/singleomeRNA_MTG_Full/Sst.h5ad'
adata_rna_only = anndata.read_h5ad(rna_h5ad_path)

# Same 1%-of-cells threshold as applied to the ATAC peaks.
min_cells = int(adata_rna_only.n_obs * 0.01)
sc.pp.filter_genes(adata_rna_only, min_cells=min_cells)

# Combined feature set: retained genes plus retained peaks.
genes = adata_rna_only.var_names
features = genes.union(peaks)

adata_rna_only.write(filename='rna_Sst_all_02.23.2023.h5ad')

In [None]:
# Load the multiome object and keep only the features retained above.
adata_mult = anndata.read_h5ad('mult_Sst_12.08.2022.h5ad')[:, features].copy()

### Prepare MVI dataset and run MultiVI

In [None]:
# Donor-level metadata for the Sst accessibility cells.
mvi_anno = pd.read_csv("/allen/programs/celltypes/workgroups/hct/SEA-AD/Integration/multivi_AD_3cohorts_RNA_ATAC_Multiome_v2.0_annotation_updated_refined_1.0_obs_meta.csv")
is_sst_atac = (mvi_anno['modality'] == 'accessibility') & (mvi_anno['label_transfer'] == 'Sst')
# Select rows and columns in a single `.loc` and copy, so the `sex` recoding
# below writes to an independent frame rather than a filtered slice.
mvi_anno = mvi_anno.loc[
    is_sst_atac,
    ['sample_id', 'age', 'sex', 'medical_conditions', 'donor_name', 'method'],
].copy()
# NOTE(review): every value other than 'Male', missing values included, is
# recoded as 'F'.
mvi_anno['sex'] = np.where(mvi_anno['sex'] == 'Male', 'M', 'F')
mvi_anno['sex'] = mvi_anno['sex'].astype('category')

In [None]:
# Move the obs index into a regular column, then attach the donor metadata to
# each ATAC cell by `sample_id`.
df = adata_atac_only.obs.reset_index()
df_new = df.merge(mvi_anno, how='left', on='sample_id')

In [None]:
# Index the merged table by barcode *before* handing it to AnnData. This
# avoids assigning an integer-indexed frame and then re-indexing `.obs` in
# place through the accessor, which bypasses the obs setter.
adata_atac_only.obs = df_new.set_index("barcode")

In [None]:
# Tag every RNA feature with its modality; the ATAC features are tagged
# "Peaks" in create_atac_anndata.
adata_rna_only.var["modality"] = "Gene Expression"

In [None]:
# Remove the scvi bookkeeping entries from obsm of both loaded objects
# (presumably left over from an earlier scvi setup — TODO confirm).
for stale_adata in (adata_mult, adata_rna_only):
    for stale_key in ('_scvi_extra_categoricals', '_scvi_extra_continuous'):
        del stale_adata.obsm[stale_key]

In [None]:
# Combine the three objects with scvi's helper; arguments are passed in the
# order multiome, RNA-only, ATAC-only.
adata_mvi = scvi.data.organize_multiome_anndatas(adata_mult, adata_rna_only, adata_atac_only)

In [None]:
# Reorder the features so they are grouped by `var["modality"]`.
# NOTE(review): the model below is given separate gene and region counts, so
# it presumably relies on this grouping — confirm against the MULTIVI docs.
adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()

In [None]:
# Register the data: obs "modality" is the batch key, donor and sex are
# categorical covariates.
scvi.model.MULTIVI.setup_anndata(
    adata_mvi,
    batch_key='modality',
    categorical_covariate_keys=["donor_name", "sex"],
)

# Feature counts per modality, taken from the var annotation.
feature_modality = adata_mvi.var['modality']
mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(feature_modality == 'Gene Expression').sum(),
    n_regions=(feature_modality == 'Peaks').sum(),
)

In [None]:
# Train the model with its default settings (no arguments are overridden).
mvi.train()

In [None]:
# Store the model's latent representation, then build the neighbor graph and
# the UMAP embedding on it.
latent = mvi.get_latent_representation()
adata_mvi.obsm["X_MultiVI"] = latent
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI")
sc.tl.umap(adata_mvi, min_dist=0.2)

In [None]:
# Use 5x5 figures from here on (rcParams is global).
plt.rcParams["figure.figsize"] = (5, 5)
# UMAP colored by the obs column "modality".
sc.pl.umap(adata_mvi, color='modality')

In [None]:
# UMAP colored by the obs column "sex".
sc.pl.umap(adata_mvi, color='sex')

In [None]:
# UMAP colored by the obs column "donor_name".
sc.pl.umap(adata_mvi, color='donor_name')

In [None]:
# UMAP colored by the obs column "supertype_scANVI".
sc.pl.umap(adata_mvi, color='supertype_scANVI')

In [None]:
def compute_label_purity_slot(adata_mvi, cell_idx, slot="subclass_scANVI"):
    """Return the majority label among a cell's graph neighbors and its share.

    Parameters
    ----------
    adata_mvi
        Object exposing ``obs`` (a DataFrame containing column ``slot``) and
        the neighbor connectivities. The matrix is read from ``obsp`` under
        the key named by ``uns['neighbors']['connectivities_key']`` (default
        ``"connectivities"``); if it is not there, the legacy location
        ``uns['neighbors']['connectivities']`` is used.
    cell_idx
        Positional index of the query cell (row of the connectivity matrix).
    slot
        Name of the ``obs`` column holding the labels.

    Returns
    -------
    (ratio, label)
        ``label`` is the most frequent non-null label among the cells whose
        connectivity to ``cell_idx`` is > 0, and ``ratio`` is its share of
        those labeled neighbors. Ties go to the label that sorts first.
        When no neighbor carries a label the result is ``(0, nan)``.
    """
    neighbors_info = adata_mvi.uns.get("neighbors", {})
    graph_key = neighbors_info.get("connectivities_key", "connectivities")
    obsp = getattr(adata_mvi, "obsp", None)
    if obsp is not None and graph_key in obsp:
        connectivities = obsp[graph_key]
    else:
        # Legacy layout: the matrix is stored inside uns['neighbors'].
        connectivities = neighbors_info["connectivities"]

    # Take the row as a 1 x n sparse matrix and read its stored entries
    # directly, instead of building a dense row of length n on every call.
    row = connectivities[[cell_idx]].tocoo()
    idx = np.unique(row.col[row.data > 0])

    # Index the column as a Series; obs[[slot]] would copy it on each call.
    neighbor_labels = adata_mvi.obs[slot].iloc[idx].dropna().to_numpy()
    u, c = np.unique(neighbor_labels, return_counts=True)
    if c.size == 0:
        return 0, np.nan

    top = np.argmax(c)
    return c[top] / c.sum(), u[top]

In [None]:
# First pass: give every cell the majority "supertype_scANVI" label of its
# graph neighbors, together with that label's share among them.
cells = adata_mvi.shape[0]
purity_ratios = np.zeros(cells)
# Use a fresh buffer rather than `obs[...].to_numpy()`: that call may return
# the column's own memory, in which case writing results into it would
# overwrite the source labels while the loop is still reading them.
labels = np.empty(cells, dtype=object)
for i in range(cells):
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI")
    purity_ratios[i] = ratio
    labels[i] = label_

adata_mvi.obs["supertype_scANVI_purity_ratio"] = purity_ratios
adata_mvi.obs["supertype_scANVI_label_transfer"] = labels

In [None]:
# Second pass: cells that are still unlabeled are retried using the labels
# transferred so far. `i` is a row position, so results are written with
# positional `.iloc` directly on `obs`; a chained `obs[col].loc[i] = ...`
# treats `i` as an index label and never updates the table.
obs = adata_mvi.obs
ratio_pos = obs.columns.get_loc("supertype_scANVI_purity_ratio")
label_pos = obs.columns.get_loc("supertype_scANVI_label_transfer")
# The set of unlabeled rows is fixed up front; cells filled earlier in the
# loop are visible to later iterations.
for i in np.where(obs["supertype_scANVI_label_transfer"].isnull())[0]:
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI_label_transfer")
    obs.iloc[i, ratio_pos] = ratio
    obs.iloc[i, label_pos] = label_

In [None]:
# Number of cells whose transferred supertype label is still missing.
print(adata_mvi.obs["supertype_scANVI_label_transfer"].isnull().sum())

In [None]:
# First pass for "supertype_scANVI_leiden": majority label of each cell's
# graph neighbors, together with that label's share among them.
cells = adata_mvi.shape[0]
purity_ratios = np.zeros(cells)
# Use a fresh buffer rather than `obs[...].to_numpy()`: that call may return
# the column's own memory, in which case writing results into it would
# overwrite the source labels while the loop is still reading them.
labels = np.empty(cells, dtype=object)
for i in range(cells):
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI_leiden")
    purity_ratios[i] = ratio
    labels[i] = label_

adata_mvi.obs["supertype_scANVI_leiden_purity_ratio"] = purity_ratios
adata_mvi.obs["supertype_scANVI_leiden_label_transfer"] = labels

In [None]:
# Second pass: cells that are still unlabeled are retried using the labels
# transferred so far. `i` is a row position, so results are written with
# positional `.iloc` directly on `obs`; a chained `obs[col].loc[i] = ...`
# treats `i` as an index label and never updates the table.
obs = adata_mvi.obs
ratio_pos = obs.columns.get_loc("supertype_scANVI_leiden_purity_ratio")
label_pos = obs.columns.get_loc("supertype_scANVI_leiden_label_transfer")
# The set of unlabeled rows is fixed up front; cells filled earlier in the
# loop are visible to later iterations.
for i in np.where(obs["supertype_scANVI_leiden_label_transfer"].isnull())[0]:
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI_leiden_label_transfer")
    obs.iloc[i, ratio_pos] = ratio
    obs.iloc[i, label_pos] = label_

In [None]:
# Number of cells whose transferred leiden-supertype label is still missing.
print(adata_mvi.obs["supertype_scANVI_leiden_label_transfer"].isnull().sum())

In [None]:
# UMAP colored by the transferred leiden-supertype labels, with the legend
# drawn on the embedding.
sc.pl.umap(adata_mvi, color='supertype_scANVI_leiden_label_transfer', legend_loc="on data")

In [None]:
from joblib import parallel_backend
# Same job count as configured in the setup cell.
sc.settings.n_jobs=32
# NOTE(review): the joblib threading backend only matters if leiden dispatches
# work through joblib — confirm it has any effect here.
with parallel_backend('threading', n_jobs=32):
    # No resolution is passed, so the library default is used; clusters are
    # stored in obs["leiden_1.0"].
    sc.tl.leiden(adata_mvi, key_added = "leiden_1.0") # default resolution in 1.0


In [None]:
# UMAP colored by the leiden clusters computed above.
sc.pl.umap(adata_mvi, color=['leiden_1.0'], legend_loc='on data') ## this is default resolution

In [None]:
# Per-cell sum of X over the features whose modality is "Peaks".
is_peak = adata_mvi.var["modality"] == 'Peaks'
peak_row_sums = np.sum(adata_mvi[:, is_peak].X, axis=1)
total_peaks = np.asarray(peak_row_sums).ravel()
adata_mvi.obs['total_peak_count'] = total_peaks

In [None]:
# Cast every object-dtype obs column to str (presumably so the table can be
# written to h5ad below — TODO confirm).
# NOTE(review): missing values in these columns are stringified as well.
object_cols = [
    column for column, dtype in adata_mvi.obs.dtypes.items() if dtype == 'object'
]
adata_mvi.obs[object_cols] = adata_mvi.obs[object_cols].astype(str)

In [None]:
# Save the integrated object, including the embeddings and labels added above.
adata_mvi.write("multivi_AD_Sst_02.23.23.h5ad")

In [None]:
# Save the trained model; save_anndata=True asks scvi to store the AnnData
# alongside it.
mvi.save("trained_multivi_AD_Sst_02.23.23", save_anndata=True)