<!--  -->
# Preprocessing - Joint Embedding & Doublet Removal with MultiVI
adapted from Michael Sterr

2024-02-09 09:28:15 


# Setup

In [None]:
# General
import scipy as sci
import numpy as np
import pandas as pd
import logging
import time
import pickle
from itertools import chain
import session_info
import gc # Free memory #gc.collect()
import scipy.stats as stats

# Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import rcParams
from matplotlib.pyplot import rc_context
from matplotlib import cm
import seaborn as sb

# Analysis
import muon as mu
from muon import atac as ac # Import a module with ATAC-seq-related functions
import scanpy as sc
import anndata as ad

In [None]:
# Notebook-wide configuration

import warnings
# Silence every warning category for the rest of the session.
warnings.filterwarnings("ignore")

## Directory layout (figure and cache locations are derived from base_dir)
base_dir = '/mnt/hdd/'
data_dir = 'data/Healthy/'
nb_dir = 'Notebooks/Gut_project/'
sc.settings.figdir = ''.join([base_dir, nb_dir, 'Figures'])
sc.settings.cachedir = ''.join([base_dir, 'Cache'])

## Scanpy verbosity and environment report
sc.settings.verbosity = 3
sc.logging.print_versions()
session_info.show()

In [None]:
%run utils.ipynb

In [None]:
mymap = load_RdOrYl_cmap_settings()

# Setup R

In [None]:
#R
import rpy2
import rpy2.robjects as ro
import rpy2.rinterface_lib.callbacks
from rpy2.robjects import pandas2ri
import anndata2ri
setup_R('/home/scanalysis/mnt/envs/scUV/lib/R')

In [None]:
%%R

.libPaths()

In [None]:
%%R
# Parallelization: one worker count shared by every backend
n_workers <- 20

library(BiocParallel)
register(MulticoreParam(n_workers, progressbar = TRUE))

library(future)
plan("multicore", workers = n_workers)
# Allow up to 64 GiB of exported globals per future.
# The previous value (64 * 1024^2 = 64 MiB) was lower than the package
# default of 500 MiB, i.e. it tightened the limit instead of raising it.
options(future.globals.maxSize = 64 * 1024^3)
plan()

library(doParallel)
registerDoParallel(n_workers)

sessionInfo()

# Load Data

## MuData

In [None]:
# Sample IDs that this notebook treats as Multiome samples
multiome_samples = [
    '597_NVF_Crypts_Rep1',
    '598_FVF_Crypts_Rep1',
    '599_FVF_Crypts_Rep2',
    '604_NVF_Crypts_Rep2',
    'FVF-high',
    'FVF-low',
]

In [None]:
# Load the ATAC object from the configured data directory (base_dir + data_dir)
# rather than from a second, hard-coded copy of the same absolute path.
atac = sc.read_h5ad(base_dir + data_dir + 'atac_normalized_concatednated.h5ad')

#### load adata

In [None]:
# Load the GEX object from the configured data directory (base_dir + data_dir)
# rather than from a second, hard-coded copy of the same absolute path.
adata = sc.read_h5ad(base_dir + data_dir + 'adata_markedDoublets_normalized_initialAnno_woimmune.h5ad')

In [None]:
adata

In [None]:
# Restrict .obs to a fixed set of columns
obs_columns_kept = [
    'sample', 'n_counts', 'log_counts', 'n_counts_rank', 'n_genes', 'log_genes',
    'mt_frac', 'rp_frac', 'ambi_frac', 'final_doublets', 'final_doublets_cat',
    'doublet_calls', 'cells_remain', 'batch', 'leiden', 'size_factors', 'S_score',
    'G2M_score', 'phase', 'proliferation', 'leiden_wnn', 'initial_cell_type', 'is_paneth',
]
adata.obs = adata.obs.loc[:, obs_columns_kept].copy()
# Drop uns/obsm/varm/obsp/raw; the layers are not deleted here
for slot in ('uns', 'obsm', 'varm', 'obsp', 'raw'):
    delattr(adata, slot)
gc.collect()

In [None]:
## add metadata
# Read the sample sheet through the read_excel_metadata helper.
metadata_df =read_excel_metadata(f'/mnt/hdd/data/metadata_mouse_gut.xlsx')
# Drop the rows whose kit is 'Multiome_ATAC_v1'
metadata_df.drop(metadata_df[metadata_df['kit'] == 'Multiome_ATAC_v1'].index, inplace=True)
# Drop the rows whose condition is neither 'Ctr' nor 'Ctr/WT'
metadata_df.drop(metadata_df[~metadata_df['condition'].isin(['Ctr','Ctr/WT'])].index, inplace=True)
# Use the folder name as the index so rows can be looked up by label
metadata_df.set_index('folder name', inplace=True)
# Remove these columns from the metadata table
metadata_df.drop(['Sample Pooling - confounded with Project?','date','Project Name','Link_id','sample name','Cell Count [cells/µl]','Viable Cells [%]','Lib. Concentration [ng/µl]','Lib. Molarity [nM]','Average Lib. Size [bp]','cDNA Cycles','Lib. Cycles','10x Sample Index','Sequencing Depth [reads/cell]','MUC ID','exclusion, reason'], axis=1, inplace=True)

In [None]:
metadata_df

In [None]:
# Copy every metadata column into adata.obs, looking up each cell's value
# by its 'sample' entry in the metadata index.
# A KeyError in the lookup skips the whole column; only a message is printed.
# NOTE(review): this performs one .at lookup per element; a vectorised
# adata.obs['sample'].map(metadata_df[col]) may be faster but could change the
# resulting dtypes — confirm before switching.
for col in metadata_df.columns:
    try:
        adata.obs[col] = adata.obs['sample'].apply(lambda x: metadata_df.at[x, col])
    except KeyError as err:
        print(f'no such key: {err} in col {col}')

#### extract multiome normalized gex data

In [None]:
# Materialise the multiome subset as its own object: later cells assign to
# .X and .var, which should not happen on a view of adata.
adata_gex = adata[adata.obs['sample'].isin(multiome_samples)].copy()

In [None]:
# Text before the first underscore of every GEX obs name
barcodes_all = [obs_name.partition('_')[0] for obs_name in adata_gex.obs_names]

In [None]:
# Subset the ATAC object to the barcodes derived from the multiome GEX cells
# NOTE(review): assumes these barcodes are unique across samples in atac — confirm
atac= atac[barcodes_all] #get rid of immune cells

In [None]:
# Replace the ATAC .obs with a copy of the GEX .obs so that both objects
# carry the same per-cell annotations.
# NOTE(review): assumes the atac rows are in the same order as the adata_gex rows — confirm
atac.obs = adata_gex.obs.copy() # for concatenation

In [None]:
adata_gex.obs['sample'].value_counts()

In [None]:
adata_gex

#### clean up gex data from multiome samples

In [None]:
# Use the raw counts as X
adata_gex.X = adata_gex.layers['raw_counts'].copy()

# Tag every feature as gene expression
adata_gex.var['feature_types'] = 'Gene Expression'

# Keep only the feature type and genome columns in .var.
# (A 'modality' column used to be assigned just before this selection and was
# dropped by it immediately, so that dead assignment has been removed.)
adata_gex.var = adata_gex.var.loc[:, ['feature_types', 'genome']].copy()

In [None]:
# Remove all unnecessary var data
# NOTE(review): the previous cell already ends with this exact selection, so
# re-running it here leaves .var unchanged.
adata_gex.var = adata_gex.var.loc[:,['feature_types','genome']].copy()

In [None]:
# Free memory: remove the layers and .raw of the multiome GEX object
del adata_gex.layers, adata_gex.raw
gc.collect()

In [None]:
adata_gex

In [None]:
gc.collect()

#### concat normalized gex and atac

In [None]:
atac

In [None]:
atac.var

In [None]:
atac.var['feature_types'] = 'Peaks'

In [None]:
# Keep only the peaks flagged as highly variable
hv_peak_mask = atac.var['highly_variable'].eq(True)
atac = atac[:, hv_peak_mask]

In [None]:
# Joint data: transpose both objects, concatenate them, and transpose the
# result back, so genes and peaks end up side by side for the same cells.
# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in
# favour of anndata.concat — consider migrating once the output is verified.
adata_multi = adata_gex.copy().transpose().concatenate(atac.copy().transpose()).transpose()

# Add modality to .obs
adata_multi.obs['modality'] = 'Multiome'

# Fix var_names: genes first, then peaks, using the names of the inputs
# (presumably needed because concatenate alters the names — confirm)
adata_multi.var_names = list(adata_gex.var_names) + list(atac.var_names)

In [None]:
# Restrict .obs of the joint object to a fixed set of annotation/metadata columns
multi_obs_columns = [
    'sample', 'doublet_calls', 'final_doublets', 'final_doublets_cat', 'phase',
    'proliferation', 'initial_cell_type', 'Project', 'sequencing', 'condition',
    'kit', 'linienhintergrund', 'strain', 'enriched', 'enrichment proportion',
    'diet', 'Index Type', 'sequencing machine',
]
adata_multi.obs = adata_multi.obs.loc[:, multi_obs_columns].copy()

In [None]:
# Keep only the first two .var columns (selected by position, not by name)
# NOTE(review): presumably 'feature_types' and 'genome' — confirm the column order after concatenation
adata_multi.var = adata_multi.var.iloc[:,0:2].copy()

In [None]:
adata_multi.var.feature_types.value_counts()

In [None]:
adata_multi.var

# Prepare for MultiVI

In [None]:
# extract non-multiome gex data as an independent object; the next line
# writes to .obs, which should not happen on a view of adata
adata_rna = adata[~adata.obs['sample'].isin(multiome_samples)].copy()
adata_rna.obs['modality'] = 'expression'

In [None]:
adata_rna

In [None]:
adata_rna.obs['sample'].value_counts()

In [None]:
# Use the raw counts as X
adata_rna.X = adata_rna.layers['raw_counts'].copy()

# Tag every feature as gene expression
adata_rna.var['feature_types'] = 'Gene Expression'

# Keep only the feature type and genome columns in .var.
# (A 'modality' column used to be assigned just before this selection and was
# dropped by it immediately, so that dead assignment has been removed.)
adata_rna.var = adata_rna.var.loc[:, ['feature_types', 'genome']].copy()

In [None]:
# Free memory: remove the layers and .raw of the RNA-only object
del adata_rna.layers, adata_rna.raw
gc.collect()

In [None]:
# Restrict .obs of the RNA-only object to a fixed set of annotation/metadata columns
rna_obs_columns = [
    'sample', 'doublet_calls', 'final_doublets', 'final_doublets_cat', 'phase',
    'proliferation', 'initial_cell_type', 'Project', 'sequencing', 'condition',
    'kit', 'linienhintergrund', 'strain', 'enriched', 'enrichment proportion',
    'diet', 'Index Type', 'sequencing machine',
]
adata_rna.obs = adata_rna.obs.loc[:, rna_obs_columns].copy()

#### from github HLCA: visualize main covariates

In [None]:
# Number of principal components; defined before the PCA and passed to it so
# that the per-PC loop further down always matches the number of computed PCs.
n_pcs = 50

sc.tl.pca(adata, n_comps=n_pcs)

# Specify the covariates we want to check (we will quantify their correlation
# with the first n_pcs PCs, to see how much variance they can each explain):

covariates = [
    "sample",
    'doublet_calls', 'final_doublets', 'final_doublets_cat', 'phase', 'proliferation',
    'initial_cell_type', 'Project', 'sequencing', 'condition', 'kit',
    'linienhintergrund', 'strain', 'enriched', 'enrichment proportion', 'diet',
    'Index Type', 'sequencing machine',
]

Create shuffled assignments of `Project` (and of `sequencing machine` if `include_processing_site` is set), to compare the actual variance explained with the variance explained that is expected at random. All cells of the same sample are assigned the same value.

In [None]:
include_processing_site =True

In [None]:
# Create shuffled versions of 'Project' (and of 'sequencing machine' when
# include_processing_site is set). The shuffle is done per sample, so all cells
# of one sample receive the same shuffled value.
rng = np.random.default_rng(0)  # fixed seed: the random baseline is reproducible
shuffle_keys = ["Project", "sequencing machine"] if include_processing_site else ["Project"]
sample_to_scplatform = adata.obs.groupby("sample").agg(
    {key: "first" for key in shuffle_keys}
)
for i in range(10):
    for key in shuffle_keys:
        shuffled_name = f"{key}_shuffled_{i}"
        # Permute a plain array of the per-sample values. The previous in-place
        # np.random.shuffle on a DataFrame column accessed as an attribute is not
        # guaranteed to write back to the frame (e.g. under pandas copy-on-write).
        shuffled_values = rng.permutation(sample_to_scplatform[key].to_numpy())
        adata.obs[shuffled_name] = adata.obs["sample"].map(
            dict(zip(sample_to_scplatform.index, shuffled_values))
        )
        # Re-running this cell must not register the same covariate twice
        if shuffled_name not in covariates:
            covariates.append(shuffled_name)

Now check for every covariate, for every PC how much variance among the cells' PC scores the covariate can explain. Add this variance explained per PC up across PCs for every covariate. This will give us the total amount of variance explained per covariate.

In [None]:
from sklearn.linear_model import LinearRegression

def check_if_nan(value):
    """return Boolean version of value that is True if value is
    some type of NaN (e.g. np.nan, None, "nan" etc). 
    Example use:
    none_entries = subadata.obs.applymap(check_if_nan)
    subadata.obs = subadata.obs.mask(none_entries.values)
    """
    if value == "nan":
        return True
    elif value == None:
        return True
    if isinstance(value, float):
        if np.isnan(value):
            return True
    if value == "ND":
        return True
    return False

In [None]:
# Variance of every PC that a linear model on each covariate reproduces.
# Rows: PCs; columns: the covariates plus the PC's total variance ("overall").
var_explained = pd.DataFrame(index=range(n_pcs), columns=covariates + ["overall"])
for pc in range(n_pcs):
    y_true_unfiltered = adata.obsm["X_pca"][:, pc]
    var_explained.loc[pc, "overall"] = np.var(y_true_unfiltered)
    for cov in covariates:
        x = adata.obs[cov].values.copy()
        # Ignore the cells whose covariate value is missing
        x_nans = np.vectorize(check_if_nan)(x)
        x = x[~x_nans]
        if len(x) == 0:
            # No usable values: the entry stays NaN
            continue
        y_true = y_true_unfiltered[~x_nans].reshape(-1, 1)
        if x.dtype in ["float32", "float", "float64"]:
            # Continuous covariate: a single numeric regressor
            x = x.reshape(-1, 1)
        else:
            if len(set(x)) == 1:
                # A constant covariate cannot explain any variance
                var_explained.loc[pc, cov] = np.nan
                continue
            x = pd.get_dummies(x)
            # Cast the dummy column names to str. This used to run for the float
            # branch too, where x is an ndarray without .columns (AttributeError).
            x.columns = x.columns.astype(str)
        lrf = LinearRegression(fit_intercept=True).fit(
            x,
            y_true,
        )
        y_pred = lrf.predict(x)
        var_explained.loc[pc, cov] = np.var(y_pred)
# Sum over PCs, then express every covariate as a fraction of the total variance
total_variance_explained = np.sum(var_explained, axis=0).sort_values(ascending=False)
total_variance_explained_fractions = (
    total_variance_explained / total_variance_explained["overall"]
)

Do the same for the shuffled covariates. Calculate mean over shuffling instances, add as one value to clean fractions:

In [None]:
# Collect the shuffled-covariate entries once, using the same prefixes for
# filtering and for averaging (the averaging used the looser prefix "Project_").
fraction_index = total_variance_explained_fractions.index
project_shuffled_cols = [
    name for name in fraction_index if name.startswith("Project_shuffled")
]
machine_shuffled_cols = [
    name for name in fraction_index if name.startswith("sequencing machine_shuffled")
]

# All real covariates, i.e. everything except the individual shuffles
total_variance_explained_clean = total_variance_explained_fractions.drop(
    project_shuffled_cols + machine_shuffled_cols
)

# Summarise the shuffles of each covariate by their mean (and keep the spread)
total_variance_explained_clean["Project_shuffled"] = np.mean(
    total_variance_explained_fractions[project_shuffled_cols]
)
stdev_Project_shuffled = np.std(
    total_variance_explained_fractions[project_shuffled_cols]
)
if include_processing_site:
    total_variance_explained_clean["sequencing machine_shuffled"] = np.mean(
        total_variance_explained_fractions[machine_shuffled_cols]
    )
    stdev_processing_site_shuffled = np.std(
        total_variance_explained_fractions[machine_shuffled_cols]
    )



Sort results:

In [None]:
total_variance_explained_clean.sort_values(ascending=False, inplace=True)

Plot:

In [None]:
# Bar chart of the fraction of PC variance explained by each covariate,
# drawn in reversed order of total_variance_explained_clean.
plt.figure(figsize=(8, 4))
plt.bar(
    total_variance_explained_clean[::-1].index,
    total_variance_explained_clean[::-1].values,
)
# Take the PC count from n_pcs instead of repeating the literal 50
plt.title(
    f"covariate correlation with first {n_pcs} PCs",
    fontsize=14,
)
plt.xticks(rotation=90)
plt.show()

In [None]:
del adata
gc.collect()

In [None]:
import torch
torch.cuda.empty_cache()

In [None]:
import scvi
gc.collect()

In [None]:
del adata_gex
del atac
gc.collect()

In [None]:
gc.collect()

In [None]:
adata_multi

In [None]:
adata_multi.X.shape

In [None]:
adata_rna

In [None]:
adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna)

OR

In [None]:
# Sort the features by feature type, so that genes appear before genomic regions
feature_order = adata_mvi.var["feature_types"].argsort()
adata_mvi = adata_mvi[:, feature_order].copy()
adata_mvi.var

In [None]:
# Drop features present in less than 1% of cells; print the shape before and after
print(adata_mvi.shape)
min_cells_required = int(adata_mvi.shape[0] * 0.01)
sc.pp.filter_genes(adata_mvi, min_cells=min_cells_required)
print(adata_mvi.shape)

# Run MultiVI

n_hidden: 1024, layers: 4, inject: True

This model does a very good job in clustering doublets together.

## Setup Model

In [None]:
# Network architecture
n_hidden, n_latent, n_layers = 1024, 50, 4

# Keys registered with the model
batch_key = 'modality'
# labels_key = None
categorical_covariate_keys = ['sample', 'kit']
continuous_covariate_keys = None

# Covariate injection
deeply_inject_covariates = True

# Only referenced by the commented-out arguments of the model constructor
modality_weights = 'cell'
model_depth = True

In [None]:
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key=batch_key, categorical_covariate_keys=categorical_covariate_keys, continuous_covariate_keys=continuous_covariate_keys)

In [None]:
# Count genes and peaks from the feature type annotation, then build the model
feature_types = adata_mvi.var['feature_types']
model_mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(feature_types == 'Gene Expression').sum(),
    n_regions=(feature_types == 'Peaks').sum(),
    n_hidden=n_hidden,
    n_latent=n_latent,
    n_layers_encoder=n_layers,
    n_layers_decoder=n_layers,
    deeply_inject_covariates=deeply_inject_covariates,
    # Not passed: modality_weights, model_depth, gene_dispersion='gene-batch'
)
model_mvi.view_anndata_setup()

## Train

In [None]:
torch.cuda.empty_cache()

In [None]:
model_mvi.train()

In [None]:
# Plot the reconstruction loss of the training and validation splits
for split in ('train', 'validation'):
    history_key = f'reconstruction_loss_{split}'
    plt.plot(model_mvi.history[history_key][history_key], label=split)
plt.legend()

## Save

In [None]:
import datetime

# Location and base name of the saved model
file_path = '/mnt/hdd/data/Healthy'
directory_path = file_path + '/Models/'
file_base_name = 'healthy_atlas'
base_name = file_base_name + '_MultiVI-Doublet-Removal'
date = str(datetime.date.today()) + '_'

# Covariate keys are turned into a CamelCase tag, e.g. ['sample', 'kit'] -> 'SampleKit'.
# '_'.join() raises TypeError when the keys are None; only that case maps to the
# 'None' tag (a bare except used to hide every other error as well).
try:
    covarCat = '_covarCat' + '_'.join(categorical_covariate_keys).replace('_', ' ').title().replace(' ', '')
except TypeError:
    covarCat = '_covarCatNone'

try:
    covarCont = '_covarCont' + '_'.join(continuous_covariate_keys).replace('_', ' ').title().replace(' ', '')
except TypeError:
    covarCont = '_covarContNone'

# Hyper-parameter tags
deep = '_inject' + str(deeply_inject_covariates)
layers = '_layers' + str(n_layers)
hidden = '_hidden' + str(n_hidden)
latent = '_latent' + str(n_latent)

model_type = '_MultiVI'

# <directory><date>_<base name><covariate tags><hyper-parameter tags>_MultiVI
model_path = ''.join([
    directory_path,
    date,
    base_name,
    covarCat,
    covarCont,
    deep,
    layers,
    hidden,
    latent,
    model_type
])

model_path

In [None]:
model_mvi.save(model_path, overwrite=True, save_anndata=True)

In [None]:
gc.collect()

## Results

In [None]:
adata_mvi.obsm['X_MultiVI'] = model_mvi.get_latent_representation(adata_mvi)

In [None]:
adata_mvi.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_MutliVI_meta.h5ad')

In [None]:
sc.pp.neighbors(adata_mvi, use_rep='X_MultiVI')

In [None]:
sc.tl.umap(adata_mvi, min_dist=0.3)

In [None]:
sc.tl.leiden(adata_mvi, resolution=1.5)

In [None]:
sc.pl.umap(adata_mvi, color=['modality','sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=2, wspace=0.65, color_map=mymap)

# Remove Doublets

In [None]:
adata_mvi.obs['doublet_calls_cat'] = [str(x) for x in adata_mvi.obs['doublet_calls']]

In [None]:
adata_mvi.uns['doublet_calls_cat_colors'] = np.array([mpl.colors.to_hex(color, keep_alpha=True) for color in mymap(np.linspace(0,1,8))])

In [None]:
sc.pl.umap(adata_mvi, color=['doublet_calls_cat'], size=15, add_outline=True, alpha=1, outline_width=(0.3, 0.0), ncols=5, color_map=mymap)

In [None]:
plot_composition(adata_mvi, x_key='initial_cell_type', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))

In [None]:
sc.tl.leiden(adata_mvi, resolution=2.5)

In [None]:
sc.pl.umap(adata_mvi, color=['leiden','doublet_calls','final_doublets_cat'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.65,color_map=mymap)

In [None]:
pct_doublets = plot_composition(adata_mvi, x_key='leiden', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))
pct_doublets

In [None]:
sc.tl.leiden(adata_mvi, restrict_to=('leiden', ['20','21']), resolution=0.5, key_added='leiden_sub1')

In [None]:
sc.pl.umap(adata_mvi, color=['leiden_sub1','final_doublets_cat','doublet_calls'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.65,color_map=mymap)

In [None]:
sc.tl.leiden(adata_mvi, restrict_to=('leiden_sub1', ['20-21,0']), resolution=0.25, key_added='leiden_sub2')

In [None]:
sc.pl.umap(adata_mvi, color=['leiden_sub2','final_doublets_cat','doublet_calls'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4, wspace=0.65,color_map=mymap)

In [None]:
plot_composition(adata_mvi, x_key='leiden_sub2', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))

In [None]:
sc.tl.leiden(adata_mvi, restrict_to=('leiden_sub2', ['20-21,0,2']), resolution=0.25, key_added='leiden_sub3')

In [None]:
# Recompute the per-cluster doublet composition for the final sub-clustering and
# store it: the removal step that follows filters 'leiden_sub3' clusters, but
# pct_doublets still held the table computed for the coarser 'leiden' clustering.
pct_doublets = plot_composition(adata_mvi, x_key='leiden_sub3', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))
pct_doublets

In [None]:
# Remove doublet clusters: every leiden_sub3 cluster that pct_doublets lists
# with a '0' value below 4
doublet_clusters = list(pct_doublets['x_labels'][pct_doublets['0'] < 4])
in_doublet_cluster = adata_mvi.obs.leiden_sub3.isin(doublet_clusters)
adata_mvi = adata_mvi[~in_doublet_cluster].copy()

In [None]:
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI", metric='correlation', n_pcs=50, n_neighbors=20)
sc.tl.umap(adata_mvi, min_dist=0.3)

In [None]:
sc.tl.leiden(adata_mvi, resolution=3)

In [None]:
sc.pl.umap(adata_mvi, color=['leiden','doublet_calls','final_doublets_cat'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=4,wspace=0.55, color_map=mymap)

In [None]:
pct_doublets = plot_composition(adata_mvi, x_key='leiden', y_key='doublet_calls_cat', x_rotation=90, figsize=(8,4))
pct_doublets

In [None]:
# Remove doublet clusters: every leiden cluster that pct_doublets lists
# with a '0' value below 5
doublet_clusters = list(pct_doublets['x_labels'][pct_doublets['0'] < 5])
in_doublet_cluster = adata_mvi.obs.leiden.isin(doublet_clusters)
adata_mvi = adata_mvi[~in_doublet_cluster].copy()

In [None]:
# remove cells with > 3 doublet calls
adata_mvi = adata_mvi[adata_mvi.obs.doublet_calls < 4 ].copy()

In [None]:
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI", n_pcs=50, n_neighbors=20)
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=0.5)

In [None]:
sc.tl.leiden(adata_mvi, resolution=1)

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=12, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=3, wspace=0.95, color_map=mymap)

In [None]:
sc.tl.paga(adata_mvi, groups='initial_cell_type')
sc.pl.paga(adata_mvi)

In [None]:
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=1, init_pos='paga') #0.2,1,1,0.5

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=2, wspace=0.95, color_map=mymap)

# Update Original Files

In [None]:
adata_mvi.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi.h5ad')

In [None]:
gc.collect()

In [None]:
adata = sc.read_h5ad('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_woimmune.h5ad')


In [None]:
# Restrict .obs of the reloaded object to a fixed set of columns
reloaded_obs_columns = [
    'sample', 'n_counts', 'log_counts', 'n_counts_rank', 'n_genes', 'log_genes',
    'mt_frac', 'rp_frac', 'ambi_frac', 'is_paneth', 'doublet_calls',
    'final_doublets', 'final_doublets_cat', 'phase', 'proliferation', 'initial_cell_type',
]
adata.obs = adata.obs.loc[:, reloaded_obs_columns].copy()
# Drop uns/obsm/varm/obsp/raw; the layers are not deleted here
for slot in ('uns', 'obsm', 'varm', 'obsp', 'raw'):
    delattr(adata, slot)
gc.collect()

In [None]:
# Recover the original obs names of the cells that survived doublet removal.
# Only the last '_'-separated field is stripped (presumably the modality suffix
# added by organize_multiome_anndatas — confirm), so this no longer requires the
# original names to contain exactly one underscore.
barcodes_all = [name.rsplit('_', 1)[0] for name in adata_mvi.obs_names]
# Split by modality with one boolean mask instead of building two AnnData views
is_paired = adata_mvi.obs.modality.isin(['paired']).to_numpy()
barcodes_multiome = [barcode for barcode, paired in zip(barcodes_all, is_paired) if paired]
barcodes_gex = [barcode for barcode, paired in zip(barcodes_all, is_paired) if not paired]

In [None]:
adata_multi = adata_multi[barcodes_multiome]
adata_rna = adata_rna[barcodes_gex]

In [None]:
adata = adata[barcodes_all]

In [None]:
adata.obsm['X_MultiVI_meta'] = adata_mvi.obsm['X_MultiVI'].copy()

In [None]:
del model_mvi
del adata_mvi
gc.collect()

In [None]:
torch.cuda.empty_cache()
gc.collect()

In [None]:
adata.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi_orig_wodblts_meta.h5ad')

In [None]:
del adata
gc.collect()

# Run MultiVI without Doublets

n_hidden: 1024, layers: 4, inject: True


In [None]:
torch.cuda.empty_cache()

## Setup Data

In [None]:
import torch
import scvi

In [None]:
adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna)

In [None]:
del adata_multi
del adata_rna
gc.collect()

In [None]:
# Sort the features by feature type, so that genes appear before genomic regions
feature_order = adata_mvi.var["feature_types"].argsort()
adata_mvi = adata_mvi[:, feature_order].copy()
adata_mvi.var

In [None]:
# Drop features present in less than 1% of cells; print the shape before and after
print(adata_mvi.shape)
min_cells_required = int(adata_mvi.shape[0] * 0.01)
sc.pp.filter_genes(adata_mvi, min_cells=min_cells_required)
print(adata_mvi.shape)

## Setup Model

In [None]:
# Network architecture
n_hidden, n_latent, n_layers = 1024, 50, 4

# Keys registered with the model
batch_key = 'modality'
# labels_key = None
categorical_covariate_keys = ['sample', 'kit']
continuous_covariate_keys = None

# Covariate injection
deeply_inject_covariates = True

# Only referenced by the commented-out arguments of the model constructor
modality_weights = 'cell'
model_depth = True

In [None]:
adata_mvi.obs = adata_mvi.obs.astype({'enrichment proportion': str})
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key=batch_key, categorical_covariate_keys=categorical_covariate_keys, continuous_covariate_keys=continuous_covariate_keys)

In [None]:
# Count genes and peaks from the feature type annotation, then build the model
feature_types = adata_mvi.var['feature_types']
model_mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(feature_types == 'Gene Expression').sum(),
    n_regions=(feature_types == 'Peaks').sum(),
    n_hidden=n_hidden,
    n_latent=n_latent,
    n_layers_encoder=n_layers,
    n_layers_decoder=n_layers,
    deeply_inject_covariates=deeply_inject_covariates,
    # Not passed: modality_weights, model_depth, gene_dispersion='gene-batch'
)
model_mvi.view_anndata_setup()

## Train

In [None]:
gc.collect()

In [None]:
torch.cuda.empty_cache()

In [None]:
torch.cuda.memory_allocated()

In [None]:
model_mvi.train()

In [None]:
# Plot the reconstruction loss of the training and validation splits
for split in ('train', 'validation'):
    history_key = f'reconstruction_loss_{split}'
    plt.plot(model_mvi.history[history_key][history_key], label=split)
plt.legend()

## Save

In [None]:
# Location and base name of the second (doublet-free) model; file_path,
# file_base_name and the datetime import come from the first save cell.
directory_path = file_path + '/Models/'
base_name = file_base_name + '_MultiVI-Integration_v6'
date = str(datetime.date.today()) + '_'

# Covariate keys are turned into a CamelCase tag, e.g. ['sample', 'kit'] -> 'SampleKit'.
# '_'.join() raises TypeError when the keys are None; only that case maps to the
# 'None' tag (a bare except used to hide every other error as well).
try:
    covarCat = '_covarCat' + '_'.join(categorical_covariate_keys).replace('_', ' ').title().replace(' ', '')
except TypeError:
    covarCat = '_covarCatNone'

try:
    covarCont = '_covarCont' + '_'.join(continuous_covariate_keys).replace('_', ' ').title().replace(' ', '')
except TypeError:
    covarCont = '_covarContNone'

# Hyper-parameter tags
deep = '_inject' + str(deeply_inject_covariates)
layers = '_layers' + str(n_layers)
hidden = '_hidden' + str(n_hidden)
latent = '_latent' + str(n_latent)

model_type = '_MultiVI'

# <directory><date>_<base name><covariate tags><hyper-parameter tags>_MultiVI
model_path = ''.join([
    directory_path,
    date,
    base_name,
    covarCat,
    covarCont,
    deep,
    layers,
    hidden,
    latent,
    model_type
])

model_path

In [None]:
model_mvi.save(model_path, overwrite=True, save_anndata=True)

## Results

In [None]:
adata_mvi.obsm['X_MultiVI_rmDoublets_meta'] = model_mvi.get_latent_representation(adata_mvi)

In [None]:
del model_mvi
gc.collect()

In [None]:
sc.pp.neighbors(adata_mvi, use_rep="X_MultiVI_rmDoublets_meta", n_pcs=50, n_neighbors=20)
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=0.5)

In [None]:
sc.tl.leiden(adata_mvi, resolution=1)

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=2, wspace =0.65, color_map=mymap)

In [None]:
sc.tl.paga(adata_mvi, groups='initial_cell_type')
sc.pl.paga(adata_mvi)

In [None]:
sc.tl.umap(adata_mvi, min_dist=0.2, spread=0.8, negative_sample_rate=1, gamma=0.5, init_pos='paga')

In [None]:
sc.pl.umap(adata_mvi, color=['sample','leiden','initial_cell_type','doublet_calls','final_doublets_cat'], size=7, add_outline=True, alpha=0.7, outline_width=(0.3, 0.0), ncols=3,wspace=0.99, color_map=mymap)

# Update Original Files & Save

## Multiome

In [None]:
adata_mvi.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi_wodoublets.h5ad')

In [None]:
adata_mvi = sc.read_h5ad('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi_wodoublets.h5ad')

In [None]:
# Recover the original obs names by stripping only the last '_'-separated field
# (presumably the modality suffix added by organize_multiome_anndatas — confirm);
# this no longer requires the original names to contain exactly one underscore.
barcodes_all = [name.rsplit('_', 1)[0] for name in adata_mvi.obs_names]

In [None]:
adata = sc.read_h5ad('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_woimmune.h5ad')

In [None]:
# Restrict .obs of the reloaded object to a fixed set of columns
reloaded_obs_columns = [
    'sample', 'n_counts', 'log_counts', 'n_counts_rank', 'n_genes', 'log_genes',
    'mt_frac', 'rp_frac', 'ambi_frac', 'is_paneth', 'doublet_calls',
    'final_doublets', 'final_doublets_cat', 'phase', 'proliferation', 'initial_cell_type',
]
adata.obs = adata.obs.loc[:, reloaded_obs_columns].copy()
# Drop uns/obsm/varm/obsp/raw; the layers are not deleted here
for slot in ('uns', 'obsm', 'varm', 'obsp', 'raw'):
    delattr(adata, slot)
gc.collect()

In [None]:
adata = adata[barcodes_all]

In [None]:
adata.obsm['X_MultiVI_rmDoublets_meta'] = adata_mvi.obsm['X_MultiVI_rmDoublets_meta']

In [None]:
del adata_mvi

In [None]:
adata.write('/mnt/hdd/data/Healthy/adata_markedDoublets_normalized_initialAnno_noimmune_multivi_orig_wodblts_2_meta.h5ad')

In [None]:
adata.obs['sample'].value_counts()