### Load required packages

In [None]:
import os
import warnings
import scvi
import anndata
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
%matplotlib inline

# Notebook-wide settings: let scanpy use up to 32 jobs, silence all Python
# warnings for the rest of the session, and remember the starting directory.
sc.settings.n_jobs = 32
warnings.simplefilter("ignore")

pwd = os.getcwd()

### Load singleome ATAC dataset and filter peak features

In [None]:
# Singleome (ATAC-only) Pvalb object; the relative path is resolved against the
# current working directory.
adata_atac_only = anndata.read_h5ad(filename="atac_Pvalb_12.06.2022.h5ad")

In [None]:
# Keep only peak features observed in at least 1% of the ATAC-only cells.
min_cells = int(0.01 * adata_atac_only.n_obs)
sc.pp.filter_genes(adata_atac_only, min_cells=min_cells)

In [None]:
# Peak names that survived the 1% filter; later combined with the retained RNA
# genes to choose which multiome features to keep.
peaks = adata_atac_only.var_names

### Prepare singleome RNA dataset

In [None]:
# Singleome (RNA-only) Pvalb object from the shared SEA-AD integration directory.
adata_rna_only = anndata.read_h5ad(filename='/allen/programs/celltypes/workgroups/hct/SEA-AD/Integration/multivi_subclasses/singleomeRNA_MTG_Full/Pvalb.h5ad')

In [None]:
# Keep only genes observed in at least 1% of the RNA-only cells.
min_cells = int(0.01 * adata_rna_only.n_obs)
sc.pp.filter_genes(adata_rna_only, min_cells=min_cells)

In [None]:
# Feature universe for the multiome subset: genes retained in the RNA object
# plus peaks retained in the ATAC object (pd.Index.union, which may sort).
genes = adata_rna_only.var_names
features = genes.union(peaks)

In [None]:
# Persist the gene-filtered RNA-only object to the working directory.
adata_rna_only.write(filename='rna_Pvalb_all_02.23.2023.h5ad')

### Prepare multiome dataset and restrict its features to those retained in the singleome ATAC and RNA datasets

In [None]:
# Load the multiome object and immediately restrict it (as an independent copy)
# to the genes and peaks retained in the singleome datasets.
adata_mult = anndata.read_h5ad(filename='mult_Pvalb_12.06.2022.h5ad')[:, features].copy()

### Prepare MVI dataset and run MultiVI

In [None]:
# Sample-level donor metadata for the ATAC-only cells.
# The CSV is a per-cell obs table (it is filtered on the per-cell
# 'label_transfer' column), so after selecting the sample-level columns it must
# be collapsed to ONE row per sample_id. Otherwise the left merge on sample_id
# duplicates every ATAC cell once per matching annotation row.
mvi_anno = pd.read_csv("/allen/programs/celltypes/workgroups/hct/SEA-AD/Integration/multivi_AD_3cohorts_RNA_ATAC_Multiome_v2.0_annotation_updated_refined_1.0_obs_meta.csv")
mvi_anno = mvi_anno[(mvi_anno['modality'] == 'accessibility') & (mvi_anno['label_transfer'] == 'Pvalb')]
mvi_anno = mvi_anno[['sample_id', 'age', 'sex', 'medical_conditions', 'donor_name', 'method']]
# Keep the first row seen for each sample (drop_duplicates returns a new frame,
# so the column assignments below do not touch a slice of the CSV table).
mvi_anno = mvi_anno.drop_duplicates(subset='sample_id')
# NOTE(review): every value other than 'Male' (including missing) becomes 'F'.
mvi_anno['sex'] = np.where(mvi_anno['sex'] == 'Male', 'M', 'F')
mvi_anno['sex'] = mvi_anno['sex'].astype('category')

In [None]:
# Attach sample-level donor metadata to every ATAC-only cell by sample_id.
# validate='many_to_one' makes pandas raise a MergeError if mvi_anno holds more
# than one row per sample_id. Such a merge would duplicate cells and produce a
# table longer than adata_atac_only.n_obs.
df = adata_atac_only.obs.copy()
df.reset_index(inplace=True)
df_new = pd.merge(df, mvi_anno, how='left', on='sample_id', validate='many_to_one')

In [None]:
# Replace obs with the merged table and restore the cell barcodes as the index.
# NOTE(review): assumes the original obs index was named "barcode", so that
# reset_index() produced a "barcode" column — confirm.
adata_atac_only.obs = df_new
adata_atac_only.obs.set_index("barcode", inplace=True)

In [None]:
# Tag every RNA feature with the modality string that is later used to count
# genes for MULTIVI (var['modality'] == 'Gene Expression').
adata_rna_only.var["modality"] = "Gene Expression"

In [None]:
# Remove scvi-tools covariate arrays stored in obsm (presumably left over from
# an earlier setup_anndata on these objects) before concatenating.
# pop(..., None) instead of `del` keeps the cell re-runnable: a second execution,
# or an object that never had these keys, no longer raises KeyError.
for _adata in (adata_mult, adata_rna_only):
    for _key in ('_scvi_extra_categoricals', '_scvi_extra_continuous'):
        _adata.obsm.pop(_key, None)

In [None]:
# We can now use the organizing method from scvi to concatenate these anndata.
# Argument order: paired multiome, RNA-only, ATAC-only. The result carries the
# obs 'modality' column used as batch_key when the model is set up.
adata_mvi = scvi.data.organize_multiome_anndatas(adata_mult, adata_rna_only, adata_atac_only)

In [None]:
# Reorder features by their modality label. "Gene Expression" sorts before
# "Peaks", so all genes come first, followed by all peaks. This is presumably
# the layout the n_genes / n_regions split passed to MULTIVI relies on (as in
# the scvi-tools MultiVI tutorial) — confirm if the modality labels change.
adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()

In [None]:
# Register the object with MultiVI: modality acts as the batch key, and donor
# and sex enter as extra categorical covariates.
scvi.model.MULTIVI.setup_anndata(
    adata_mvi,
    batch_key="modality",
    categorical_covariate_keys=["donor_name", "sex"],
)

# Tell the model how many leading features are genes and how many are peaks.
mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=adata_mvi.var["modality"].eq("Gene Expression").sum(),
    n_regions=adata_mvi.var["modality"].eq("Peaks").sum(),
)

In [None]:
# Fit MultiVI with scvi-tools' default training settings (no arguments passed).
mvi.train()

In [None]:
# Store the MultiVI latent embedding, build the kNN graph on it (this graph is
# what the label-purity helper reads), and compute a UMAP for plotting.
adata_mvi.obsm["X_MultiVI"] = mvi.get_latent_representation()
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI")
sc.tl.umap(adata_mvi, min_dist=0.2)

In [None]:
# 5x5 figures from here on; UMAP colored by modality (multiome / RNA / ATAC).
plt.rcParams["figure.figsize"] = (5, 5)
sc.pl.umap(adata_mvi, color='modality')

In [None]:
# UMAP colored by the 'sex' covariate.
sc.pl.umap(adata_mvi, color='sex')

In [None]:
# UMAP colored by donor.
sc.pl.umap(adata_mvi, color='donor_name')

In [None]:
def _neighbor_connectivities(adata_mvi):
    """Return the kNN connectivity matrix written by ``sc.pp.neighbors``."""
    neighbors = adata_mvi.uns["neighbors"]
    # scanpy >= 1.5 stores the graph in ``obsp`` and records its key in
    # ``uns['neighbors']['connectivities_key']``; older versions kept the
    # matrix directly in ``uns['neighbors']['connectivities']``.
    key = neighbors.get("connectivities_key", "connectivities")
    obsp = getattr(adata_mvi, "obsp", None)
    if obsp is not None and key in obsp:
        return obsp[key]
    return neighbors["connectivities"]


def compute_label_purity_slot(adata_mvi, cell_idx, slot="subclass_scANVI"):
    """Return the majority label among a cell's graph neighbors and its share.

    A neighbor is any cell with a connectivity > 0 in the query cell's row of
    the neighbors graph. Neighbors whose ``obs[slot]`` is missing are ignored.

    Parameters
    ----------
    adata_mvi : AnnData
        Object holding ``obs[slot]`` and a neighbors graph from ``sc.pp.neighbors``.
    cell_idx : int
        Integer (positional) row of the query cell.
    slot : str
        Column of ``adata_mvi.obs`` with the labels to vote on.

    Returns
    -------
    tuple
        ``(ratio, label)``. ``label`` is the most frequent non-missing neighbor
        label (ties go to the label that sorts first). ``ratio`` is that label's
        count divided by the number of labelled neighbors. Returns ``(0, nan)``
        when no neighbor carries a label.
    """
    connectivities = _neighbor_connectivities(adata_mvi)
    if hasattr(connectivities, "tocsr"):
        # Sparse graph: read the row's stored entries directly instead of
        # densifying an n_obs-long row for every query cell (O(n^2) overall).
        row = connectivities[cell_idx:cell_idx + 1].tocsr()
        row.sum_duplicates()
        idx = row.indices[row.data > 0]
    else:
        idx = np.flatnonzero(np.asarray(connectivities[cell_idx]).ravel() > 0)

    neighbor_labels = adata_mvi.obs[slot].iloc[idx]
    neighbor_labels = neighbor_labels[neighbor_labels.notna()].to_numpy()
    if neighbor_labels.size == 0:
        return 0, np.nan

    values, counts = np.unique(neighbor_labels, return_counts=True)
    best = np.argmax(counts)
    return counts[best] / counts.sum(), values[best]

In [None]:
# First pass: give every cell the majority "supertype_scANVI" label of its graph
# neighbors, together with that label's share of the labelled neighbors.
cells = adata_mvi.shape[0]
purity_ratios = np.zeros(cells)
# Fresh object array, NOT obs[...].to_numpy(). For an object-dtype column,
# to_numpy() can return a view of obs, so writing labels[i] would overwrite the
# very column being voted on while the loop is still reading it. Every slot is
# assigned below, so the initial None values never survive.
labels = np.empty(cells, dtype=object)
for i in np.arange(cells):
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI")
    purity_ratios[i] = ratio
    labels[i] = label_

adata_mvi.obs["supertype_scANVI_purity_ratio"] = purity_ratios
adata_mvi.obs["supertype_scANVI_label_transfer"] = labels

In [None]:
# Second pass: a cell whose neighbors had no "supertype_scANVI" label (NaN after
# the first pass) takes the majority of its neighbors' transferred labels.
# i is a POSITION, so write through obs.iloc. The former obs[col].loc[i] = ...
# was chained assignment with a positional integer on a barcode index: it
# enlarged a temporary Series and never updated adata_mvi.obs. Writing in place
# also lets later cells in the loop see earlier fills.
ratio_pos = adata_mvi.obs.columns.get_loc("supertype_scANVI_purity_ratio")
label_pos = adata_mvi.obs.columns.get_loc("supertype_scANVI_label_transfer")
for i in np.where(adata_mvi.obs["supertype_scANVI_label_transfer"].isnull())[0]:
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI_label_transfer")
    adata_mvi.obs.iloc[i, ratio_pos] = ratio
    adata_mvi.obs.iloc[i, label_pos] = label_

In [None]:
# Number of cells still without a transferred supertype label after both passes.
print(adata_mvi.obs["supertype_scANVI_label_transfer"].isnull().sum())

In [None]:
# First pass: give every cell the majority "supertype_scANVI_leiden" label of its
# graph neighbors, together with that label's share of the labelled neighbors.
cells = adata_mvi.shape[0]
purity_ratios = np.zeros(cells)
# Fresh object array, NOT obs[...].to_numpy(). For an object-dtype column,
# to_numpy() can return a view of obs, so writing labels[i] would overwrite the
# very column being voted on while the loop is still reading it. Every slot is
# assigned below, so the initial None values never survive.
labels = np.empty(cells, dtype=object)
for i in np.arange(cells):
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI_leiden")
    purity_ratios[i] = ratio
    labels[i] = label_

adata_mvi.obs["supertype_scANVI_leiden_purity_ratio"] = purity_ratios
adata_mvi.obs["supertype_scANVI_leiden_label_transfer"] = labels

In [None]:
# Second pass: a cell whose neighbors had no "supertype_scANVI_leiden" label
# (NaN after the first pass) takes the majority of its neighbors' transferred
# labels. i is a POSITION, so write through obs.iloc. The former
# obs[col].loc[i] = ... was chained assignment with a positional integer on a
# barcode index and never updated adata_mvi.obs.
ratio_pos = adata_mvi.obs.columns.get_loc("supertype_scANVI_leiden_purity_ratio")
label_pos = adata_mvi.obs.columns.get_loc("supertype_scANVI_leiden_label_transfer")
for i in np.where(adata_mvi.obs["supertype_scANVI_leiden_label_transfer"].isnull())[0]:
    ratio, label_ = compute_label_purity_slot(adata_mvi, i, slot="supertype_scANVI_leiden_label_transfer")
    adata_mvi.obs.iloc[i, ratio_pos] = ratio
    adata_mvi.obs.iloc[i, label_pos] = label_

In [None]:
# Number of cells still without a transferred leiden-supertype label after both passes.
print(adata_mvi.obs["supertype_scANVI_leiden_label_transfer"].isnull().sum())

In [None]:
# 6x6 figures; transferred leiden-supertype labels written on the UMAP itself.
plt.rcParams["figure.figsize"] = (6, 6)
sc.pl.umap(adata_mvi, color='supertype_scANVI_leiden_label_transfer', legend_loc="on data")

In [None]:
# Same coloring with scanpy's default (side) legend placement.
sc.pl.umap(adata_mvi, color='supertype_scANVI_leiden_label_transfer')

In [None]:
from joblib import parallel_backend

# Leiden clustering of the MultiVI neighbor graph at resolution 1.0 (scanpy's
# default, spelled out to match the "leiden_1.0" key); run under a 32-thread
# joblib backend.
sc.settings.n_jobs = 32
with parallel_backend("threading", n_jobs=32):
    sc.tl.leiden(adata_mvi, resolution=1.0, key_added="leiden_1.0")


In [None]:
# Leiden clusters at the default resolution (1.0), labelled on the UMAP.
sc.pl.umap(adata_mvi, color=['leiden_1.0'], legend_loc='on data') ## this is default resolution

In [None]:
# Per-cell sum over the peak features only, flattened to a 1-D array
# (a sparse .sum(axis=1) returns an (n_obs, 1) matrix).
total_peaks = np.ravel(
    np.asarray(adata_mvi[:, adata_mvi.var["modality"] == "Peaks"].X.sum(axis=1))
)

In [None]:
# Store the per-cell total peak count in obs.
adata_mvi.obs['total_peak_count'] = total_peaks

In [None]:
# Cast every object-dtype obs column (e.g. columns mixing strings and NaN) to
# str before writing to h5ad; missing values become the string "nan".
# Despite its name, non_string_cols holds exactly the object-dtype columns;
# categorical and numeric columns are left untouched.
dtypes = adata_mvi.obs.dtypes
non_string_cols = [column for column, kind in dtypes.items() if kind == 'object']
adata_mvi.obs[non_string_cols] = adata_mvi.obs[non_string_cols].astype(str)

In [None]:
# Save the integrated AnnData (MultiVI embedding, UMAP, transferred labels, clusters).
adata_mvi.write("multivi_AD_Pvalb_02.23.23.h5ad")

In [None]:
# Save the trained model directory, including its AnnData (save_anndata=True).
mvi.save("trained_multivi_AD_Pvalb_02.23.23", save_anndata=True)