### Load required packages

In [None]:
import os
import warnings
import scvi
import anndata
import scipy
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from joblib import parallel_backend
%matplotlib inline
from helper_functions import *

# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")
# Set scanpy's global job count.
sc.settings.n_jobs = 32

# Directory the notebook was started from.
pwd = os.getcwd()

### Load and prepare data

In [None]:
# Load the MTG RNA-seq AnnData object (path is relative to the working directory).
adata_rna = sc.read_h5ad(os.path.join("input", "MTG_RNAseq", "final.2022-04-14.h5ad"))
# Number of cells per (studies, method) combination. As a bare expression in the
# middle of a cell the result is not displayed; wrap it in print() to inspect it.
adata_rna.obs.groupby(["studies", "method"]).size()

# Values of obs["studies"] to keep. Each pairing is listed in both orders
# because the field is a comma-joined string.
cohorts = [
    "AD_Center_Grant, AD_Mult_Cohort",
    "AD_Center_Grant, AD_Mult_Cohort2",
    "AD_Center_Grant, AD_Mult_Cohort3",
    "AD_Mult_Cohort2, AD_Center_Grant",
    "AD_Mult_Cohort3, AD_Center_Grant",
    "AD_Mult_Cohort, AD_Center_Grant",
]
# Vectorised membership test instead of a per-cell Python loop.
adata_rna_sub = adata_rna[adata_rna.obs["studies"].isin(cohorts).to_numpy()].copy()

# Keep only the cells whose `method` is "10xMulti".
adata_mult_rna = adata_rna_sub[adata_rna_sub.obs["method"].isin(["10xMulti"]).to_numpy()].copy()

# Tag every RNA feature with its modality label.
adata_mult_rna.var["modality"] = "Gene Expression"

In [None]:
# One "<library_prep>-<ar_id>" identifier per distinct library in the multiome RNA data.
library_prep_ar_id = (
    adata_mult_rna.obs['library_prep'].astype(str)
    .add("-")
    .add(adata_mult_rna.obs['ar_id'].astype(str))
    .unique()
)

# Location and file-name suffixes of the per-library multiome ATAC peak counts.
base_path = "/allen/programs/celltypes/workgroups/hct/SEA-AD/ATACseq/Human_AD_MTG_peakCounts/multiome_ATAC/"
mtx_ext = ".multiome.atac.tsv_grps13-6.merged_counts.mtx"
peak_ext = ".multiome.atac.tsv_grps13-6.merged_peaks.bed"
barcode_ext = ".multiome.atac.tsv_grps13-6.merged_barcodes.tsv"

# Build one AnnData per library with the helper from helper_functions,
# then stack them into a single object.
adata_mult_atac_list = [
    create_atac_anndata(base_path, lib, mtx_ext, peak_ext, barcode_ext)
    for lib in library_prep_ar_id
]

adata_mult_atac = anndata.concat(adata_mult_atac_list, merge="same")


In [None]:
# Barcodes present in both modalities. assume_unique=True tells NumPy that
# neither index contains duplicates.
cells = np.intersect1d(
    adata_mult_rna.obs.index,
    adata_mult_atac.obs.index,
    assume_unique=True,
)

# RNA: restrict to the shared cells, then order rows by barcode.
adata_mult_rna = adata_mult_rna[cells, :].copy()
adata_mult_rna = adata_mult_rna[adata_mult_rna.obs.index.argsort(), :].copy()

# ATAC: same restriction and ordering, so both objects list cells identically.
adata_mult_atac = adata_mult_atac[cells, :].copy()
adata_mult_atac = adata_mult_atac[adata_mult_atac.obs.index.argsort(), :].copy()

# Join the two modalities along the feature axis (axis=1) into one paired object.
adata_mult = anndata.concat(
    [adata_mult_rna, adata_mult_atac],
    axis=1,
    join="inner",
    merge="unique",
)


### Prepare singleome ATAC anndata

In [None]:
# Location and file-name suffixes of the per-library singleome ATAC peak counts.
base_path = "/allen/programs/celltypes/workgroups/hct/SEA-AD/ATACseq/Human_AD_MTG_peakCounts/singleome_ATAC/"
mtx_ext = ".tsv_grps13-6.merged_counts.mtx"
peak_ext = ".tsv_grps13-6.merged_peaks.bed"
barcode_ext = ".tsv_grps13-6.merged_barcodes.tsv"

# os.listdir returns entries in arbitrary order; sorting makes the library order,
# and therefore the cell order of the concatenated object, reproducible.
file_names = sorted(fn for fn in os.listdir(base_path) if fn.endswith(mtx_ext))
print(file_names[:3])
print("there are " + str(len(file_names)) + " libraries in this folder")
# Library identifier = file name with the count-matrix suffix removed.
library_prep_ar_id = [fn[: -len(mtx_ext)] for fn in file_names]

# Build one AnnData per library with the helper from helper_functions,
# then stack them into a single object.
adata_atac_list = [
    create_atac_anndata(base_path, lib, mtx_ext, peak_ext, barcode_ext)
    for lib in library_prep_ar_id
]

adata_atac_only = anndata.concat(adata_atac_list, merge="same")


### Prepare MultiVI input datasets

In [None]:
# Donor/library metadata that is joined onto the singleome ATAC cells via `library_prep`.
atac_only_anno = pd.read_csv("/allen/programs/celltypes/workgroups/rnaseqanalysis/ATAC-Seq/Arrow/ATAC_AD_Center_Grant_complete/ATAC_AD_Center_Grant_complete.csv")
atac_only_anno = atac_only_anno.loc[:, ['age', 'sex', 'medical_conditions', 'donor_name', 'method', 'library_prep']]
# NOTE(review): every value other than 'M' (including missing values) is labelled
# 'Female' here; confirm the CSV only contains the expected codes.
atac_only_anno['sex'] = np.where(atac_only_anno['sex'] == 'M', 'Male', 'Female')
atac_only_anno['sex'] = atac_only_anno['sex'].astype('category')
# Collapse repeated identical rows so a library that appears several times with
# the same metadata does not multiply cells in the left merge below.
atac_only_anno = atac_only_anno.drop_duplicates()

# reset_index() returns a new frame, so adata_atac_only.obs is untouched until
# the merged table has been validated.
df = adata_atac_only.obs.reset_index()
df_new = pd.merge(df, atac_only_anno, how='left', on='library_prep')
if len(df_new) != adata_atac_only.n_obs:
    raise ValueError(
        "Merging the annotation on 'library_prep' changed the number of cells "
        f"({adata_atac_only.n_obs} -> {len(df_new)}); the annotation table has "
        "conflicting rows for at least one library_prep."
    )
# Assumes the obs index is named "barcode" (presumably set by create_atac_anndata) - TODO confirm.
adata_atac_only.obs = df_new.set_index("barcode")

# Keep peaks detected in at least 1% of the singleome ATAC cells.
min_cells = int(adata_atac_only.shape[0] * 0.01)
sc.pp.filter_genes(adata_atac_only, min_cells=min_cells)

peaks = adata_atac_only.var_names

# Same 1% detection filter for the RNA-only dataset.
adata_rna_only = anndata.read_h5ad(filename="AD_cohort2_RNA_only_anndata.h5ad")
min_cells = int(adata_rna_only.shape[0] * 0.01)
sc.pp.filter_genes(adata_rna_only, min_cells=min_cells)
genes = adata_rna_only.var_names
# Restrict the paired data to the features that survived either filter.
# Indexing raises a KeyError if any of these features is absent from adata_mult.
features = genes.union(peaks)
adata_mult = adata_mult[:, features].copy()

# Combine paired, RNA-only and ATAC-only cells into one AnnData for MultiVI.
adata_mvi = scvi.data.organize_multiome_anndatas(adata_mult, adata_rna_only, adata_atac_only)

# Sort features by modality label so "Gene Expression" columns precede "Peaks".
adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()
# Drop cells with fewer than 1000 detected features in the combined matrix.
sc.pp.filter_cells(adata_mvi, min_genes=1000)
adata_mvi.write(filename='multivi_AD_3cohorts_RNA_ATAC_Multiome_v1.2_anndata.h5ad')

### Run MultiVI

In [None]:
# Register the obs columns the model reads: `modality` as the batch key,
# donor and sex as categorical covariates.
scvi.model.MULTIVI.setup_anndata(
    adata_mvi,
    batch_key="modality",
    categorical_covariate_keys=["donor_name", "sex"],
)

# Feature counts per modality, taken from the var annotation.
feature_modality = adata_mvi.var["modality"]
mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(feature_modality == "Gene Expression").sum(),
    n_regions=(feature_modality == "Peaks").sum(),
)

# Train with the library defaults, then store the model together with its AnnData.
mvi.train()
mvi.save("trained_multivi_AD_3cohorts_RNA_ATAC_Multiome_v1.2", save_anndata=True)

### Reload trained data and construct latent space

In [None]:
# Reload the AnnData written before training and attach the saved model to it.
adata_mvi = anndata.read_h5ad(filename='multivi_AD_3cohorts_RNA_ATAC_Multiome_v1.2_anndata.h5ad')
mvi = scvi.model.MULTIVI.load("trained_multivi_AD_3cohorts_RNA_ATAC_Multiome_v1.2", adata=adata_mvi)

# Compute the latent representation in consecutive slices of `step` cells;
# each slice is embedded exactly once.
step = 100000
n_cells = adata_mvi.shape[0]
cds_ls = []
for left in range(0, n_cells, step):
    right = min(left + step, n_cells)
    print(left, right)  # progress: half-open cell range [left, right)
    cds_ls.append(mvi.get_latent_representation(adata_mvi[left:right]))

# Slices were produced in row order, so concatenating restores one row per cell.
adata_mvi.obsm["X_MultiVI"] = np.concatenate(cds_ls, axis=0)
# Neighbour graph on the MultiVI latent space rather than on the raw features.
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI")

sc.settings.n_jobs = 32
# Run UMAP inside a joblib threading backend with 32 workers.
with parallel_backend('threading', n_jobs=32):
    sc.tl.umap(adata_mvi, min_dist=0.2)

adata_mvi.write("multivi_AD_3cohorts_RNA_ATAC_Multiome_v1.2_trained.h5ad")