In [None]:
import mira
import anndata
import scanpy as sc
import optuna
mira.utils.pretty_sderr()

In [None]:
# check to make sure we have GPU availability :) 
import torch
torch.cuda.is_available()

In [None]:
data = anndata.read_h5ad('multiome.h5ad')

In [None]:
# Drop genes detected in fewer than 15 cells, then store the filtered object
# in .raw *before* any normalisation so the counts can be pulled back out below.
sc.pp.filter_genes(data, min_cells=15)
data.raw = data

In [None]:
# Scale every cell to a total of 1e4, then log1p-transform.
sc.pp.normalize_total(data, target_sum=1e4)
sc.pp.log1p(data)

In [None]:
# Flag highly variable genes (min_disp=0.2); the flag is copied into the
# 'exog' column, which is later passed to the topic model as exogenous_key.
sc.pp.highly_variable_genes(data, min_disp=0.2)
data.var['exog'] = data.var.highly_variable.copy()

In [None]:
# 'endog' (passed as endogenous_key) is the stricter subset of 'exog':
# genes whose normalised dispersion is also above 0.7.
data.var['endog'] = data.var.exog & (data.var.dispersions_norm > 0.7)

In [None]:
# Copy the matrix held in .raw (stored above, before normalisation) into the
# 'counts' layer that the topic model reads via counts_layer='counts'.
data.layers['counts'] = data.raw.to_adata().X.copy()

## Train and tune RNA topic model

In [None]:
model = mira.topics.ExpressionTopicModel(
    endogenous_key = 'endog',
    exogenous_key = 'exog',
    counts_layer = 'counts',
    seed = 42,
    hidden = 128)

In [None]:
# Estimate learning-rate bounds for the model on this dataset
# (upper limit of the search set to 5, evaluated every step).
model.get_learning_rate_bounds(data, eval_every = 1, upper_bound_lr = 5)

In [None]:
# NOTE(review): 2.5 and 1 are hand-picked trim amounts, presumably chosen by
# eye from the plot produced on the next line -- confirm and record the rationale.
model.trim_learning_rate_bounds(2.5, 1)
_ = model.plot_learning_rate_bounds()

In [None]:
tuner = mira.topics.TopicModelTuner(
    model,
    save_name = 'E4AD_1yr_tuner',
    seed = 42,
    iters = 64,
    max_topics = 55)

In [None]:
tuner.train_test_split(data)

In [None]:
tuner.tune(data, n_workers = 1)

In [None]:
!mkdir -p data
model.save("data/rna_topic_model.pth")

# Trying now with separated RNA and ATAC h5 files

In [None]:
data_RNA = anndata.read_h5ad('multiome_RNA.h5ad')

In [None]:
# Drop genes detected in fewer than 15 cells, then store the filtered object
# in .raw *before* any normalisation so the counts can be pulled back out below.
sc.pp.filter_genes(data_RNA, min_cells=15)
data_RNA.raw = data_RNA

In [None]:
# Scale every cell to a total of 1e4, then log1p-transform.
sc.pp.normalize_total(data_RNA, target_sum=1e4)
sc.pp.log1p(data_RNA)

In [None]:
# Highly variable genes (min_disp=0.2) become the 'exog' feature set.
sc.pp.highly_variable_genes(data_RNA, min_disp=0.2)
data_RNA.var['exog'] = data_RNA.var.highly_variable.copy()

In [None]:
# 'endog' = 'exog' genes whose normalised dispersion is also above 0.7.
data_RNA.var['endog'] = data_RNA.var.exog & (data_RNA.var.dispersions_norm > 0.7)

In [None]:
# Copy the matrix held in .raw (stored above, before normalisation) into the
# 'counts' layer that the topic model reads via counts_layer='counts'.
data_RNA.layers['counts'] = data_RNA.raw.to_adata().X.copy()

### Train model

In [None]:
model_RNA = mira.topics.ExpressionTopicModel(
    endogenous_key = 'endog',
    exogenous_key = 'exog',
    counts_layer = 'counts',
    seed = 42,
    hidden = 128)

In [None]:
model_RNA.get_learning_rate_bounds(data_RNA, eval_every = 1, upper_bound_lr = 5)

In [None]:
model_RNA.trim_learning_rate_bounds(7.5, 0.5)
_ = model_RNA.plot_learning_rate_bounds()

In [None]:
tuner_RNA = mira.topics.TopicModelTuner(
    model_RNA,
    save_name = 'E4AD_1yr_tuner_RNA',
    seed = 42,
    iters = 64,
    max_topics = 55)

In [None]:
tuner_RNA.train_test_split(data_RNA)

To view training on TensorBoard, SSH into the cluster with `ssh -L 16006:127.0.0.1:6006 amillet@login04-hpc.rockefeller.edu`. Then run `conda activate tensorboard` followed by `tensorboard serve --logdir /path/to/MIRA/runs/file`. Open http://127.0.0.1:16006/ in a browser to view.

In [None]:
tuner_RNA.tune(data_RNA, n_workers = 1)

In [None]:
tuner_RNA.select_best_model(data_RNA, record_umaps=True)

In [None]:
model_RNA.save('data/topic_model_rna_h5.pth')

In [None]:
# Apply the RNA topic model to the cells and derive the feature space used
# for the embedding (box_cox=0.5 here; the joint section later uses 0.33).
model_RNA.predict(data_RNA)
model_RNA.get_umap_features(data_RNA, box_cox=0.5)
# Neighbour graph is built on the 'X_umap_features' representation with Manhattan distance.
sc.pp.neighbors(data_RNA, use_rep = 'X_umap_features', metric = 'manhattan')
# NOTE(review): negative_sample_rate=0.05 is a fractional value well below
# scanpy's usual integer setting -- presumably deliberate; confirm.
sc.tl.umap(data_RNA, min_dist=0.1, negative_sample_rate=0.05)

# Uncoloured overview plot of the embedding (large, semi-transparent points with an outline).
sc.pl.umap(data_RNA, frameon=False, size = 1200, alpha = 0.5, add_outline=True,
          outline_width=(0.1,0))

In [None]:
study_RNA = mira.topics.TopicModelTuner.load_study('E4AD_1yr_tuner_RNA')

In [None]:
optuna.visualization.plot_optimization_history(study_RNA)

In [None]:
optuna.visualization.plot_parallel_coordinate(study_RNA)

# Now with ATAC.

In [None]:
data_ATAC = anndata.read_h5ad('multiome_ATAC.h5ad')

In [None]:
# filter_genes acts on the var axis; for this object that presumably means
# peaks, so features seen in fewer than 15 cells are dropped. The filtered
# object is stored in .raw before normalisation.
sc.pp.filter_genes(data_ATAC, min_cells=15)
data_ATAC.raw = data_ATAC

In [None]:
# Scale every cell to a total of 1e4, then log1p-transform.
# NOTE(review): the accessibility model below is built with
# counts_layer='counts'; whether this normalised X is used anywhere is not
# visible in this notebook -- confirm the step is needed.
sc.pp.normalize_total(data_ATAC, target_sum=1e4)
sc.pp.log1p(data_ATAC)

In [None]:
# Copy the matrix held in .raw (stored above, before normalisation) into the 'counts' layer.
data_ATAC.layers['counts'] = data_ATAC.raw.to_adata().X.copy()

### Training model:

In [None]:
model_ATAC = mira.topics.AccessibilityTopicModel(counts_layer='counts',
                                                 seed = 42,
                                                 dataset_loader_workers = 3)

In [None]:
model_ATAC.get_learning_rate_bounds(data_ATAC, eval_every=1, upper_bound_lr=5)

In [None]:
model_ATAC.trim_learning_rate_bounds(5, 0.5)
_ = model_ATAC.plot_learning_rate_bounds()

In [None]:
tuner_ATAC = mira.topics.TopicModelTuner(
    model_ATAC,
    save_name = 'E4AD_1yr_tuner_ATAC',
    seed = 42,
    iters = 64,
    max_topics = 55)

In [None]:
tuner_ATAC.train_test_split(data_ATAC)

To view training on TensorBoard, SSH into the cluster with `ssh -L 16006:127.0.0.1:6006 amillet@login04-hpc.rockefeller.edu`. Then run `conda activate tensorboard` followed by `tensorboard serve --logdir /path/to/MIRA/runs/file`. Open http://127.0.0.1:16006/ in a browser to view.

In [None]:
tuner_ATAC.tune(data_ATAC, n_workers = 1)

In [None]:
tuner_ATAC.select_best_model(data_ATAC, record_umaps=True)

In [None]:
model_ATAC.save('data/topic_model_atac_h5.pth')

In [None]:
model_ATAC.predict(data_ATAC)
model_ATAC.get_umap_features(data_ATAC, box_cox=0.5)
sc.pp.neighbors(data_ATAC, use_rep = 'X_umap_features', metric = 'manhattan')
sc.tl.umap(data_ATAC, min_dist=0.1, negative_sample_rate=0.05)

sc.pl.umap(data_ATAC, frameon=False, size = 1200, alpha = 0.5, add_outline=True,
          outline_width=(0.1,0))

In [None]:
study_ATAC = mira.topics.TopicModelTuner.load_study('E4AD_1yr_tuner_ATAC')

In [None]:
optuna.visualization.plot_optimization_history(study_ATAC)

In [None]:
optuna.visualization.plot_parallel_coordinate(study_ATAC)

# Joint Representation

In [None]:
import mira
import anndata
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import logging
import seaborn as sns
mira.logging.getLogger().setLevel(logging.INFO)
import warnings
warnings.simplefilter("ignore")

umap_kwargs = dict(
    add_outline=True, outline_width=(0.1,0), outline_color=('grey', 'white'),
    legend_fontweight=350, frameon = False, legend_fontsize=12
)
print(mira.__version__)
mira.utils.pretty_sderr()

Re-prep our datasets and our models.

In [None]:
# Rebuild the RNA object from disk with the same preprocessing used before
# training, so the reloaded model sees the same var columns and layer.
data_RNA = anndata.read_h5ad('multiome_RNA.h5ad')
sc.pp.filter_genes(data_RNA, min_cells=15)
# Store the filtered object in .raw before normalising; recovered into 'counts' below.
data_RNA.raw = data_RNA
sc.pp.normalize_total(data_RNA, target_sum=1e4)
sc.pp.log1p(data_RNA)
sc.pp.highly_variable_genes(data_RNA, min_disp=0.2)
# 'exog' = highly variable genes; 'endog' = exog genes with normalised dispersion > 0.7.
data_RNA.var['exog'] = data_RNA.var.highly_variable.copy()
data_RNA.var['endog'] = data_RNA.var.exog & (data_RNA.var.dispersions_norm > 0.7)
data_RNA.layers['counts'] = data_RNA.raw.to_adata().X.copy()

In [None]:
# Rebuild the ATAC object from disk with the same preprocessing used before training.
data_ATAC = anndata.read_h5ad('multiome_ATAC.h5ad')
# Drop var entries (presumably peaks) seen in fewer than 15 cells.
sc.pp.filter_genes(data_ATAC, min_cells=15)
# Store the filtered object in .raw before normalising; recovered into 'counts' below.
data_ATAC.raw = data_ATAC
sc.pp.normalize_total(data_ATAC, target_sum=1e4)
sc.pp.log1p(data_ATAC)
data_ATAC.layers['counts'] = data_ATAC.raw.to_adata().X.copy()

In [None]:
# reload our models
model_RNA = mira.topics.ExpressionTopicModel.load('data/topic_model_rna_h5.pth')
model_ATAC = mira.topics.AccessibilityTopicModel.load('data/topic_model_atac_h5.pth')

In [None]:
model_RNA.predict(data_RNA)
model_ATAC.predict(data_ATAC)

In [None]:
# box-cox of 0.33 looks best
model_RNA.get_umap_features(data_RNA, box_cox=0.33)
model_ATAC.get_umap_features(data_ATAC, box_cox=0.33)

In [None]:
# Per-modality embeddings: 21-neighbour Manhattan graph on each model's
# 'X_umap_features', followed by UMAP with min_dist=0.1.
sc.pp.neighbors(data_RNA, use_rep = 'X_umap_features', metric = 'manhattan', n_neighbors = 21)
sc.tl.umap(data_RNA, min_dist = 0.1)
# RNA embedding: both axes are negated.
data_RNA.obsm['X_umap'] = data_RNA.obsm['X_umap']*np.array([-1,-1]) # flip for consistency
sc.pp.neighbors(data_ATAC, use_rep = 'X_umap_features', metric = 'manhattan', n_neighbors = 21)
sc.tl.umap(data_ATAC, min_dist = 0.1)
# ATAC embedding: only the second axis is negated.
# NOTE(review): the flips are purely cosmetic and tied to one particular UMAP
# run -- re-check the orientation if the embedding is regenerated.
data_ATAC.obsm['X_umap'] = data_ATAC.obsm['X_umap']*np.array([1,-1]) # flip for consistency

In [None]:
# remap our clust_idents to their actual names instead of just level #s
data_RNA.obs.clust_ident = data_RNA.obs.clust_ident.astype(str)
data_ATAC.obs.clust_ident = data_ATAC.obs.clust_ident.astype(str)
mapping_dict = {"0" : "Homeostatic Microglia" ,
                "1" : "Arhgap15-hi Homeostatic Microglia",
                "2" : "mt-Enriched Microglia",
                "3" : "DAM-1",
                "4" : "DAM-2",
                "5" : "TIMs",
                "6" : "Siglech-hi Microglia",
                "7" : "Inteferon Induced Microglia",
                "8" : "Monocytes",
                "9" : "F13a1+ Monocytes",
                "10" : "Macrophages",
                "11" : "Early Neutrophils",
                "12" : "Inflammatory Neutrophils",
                "13" : "B Cells 1",
                "14" : "B Cells 2",
                "15" : "B Cells 3",
                "16" : "B Cells 4",
                "17" : "IgM+ B Cells",
                "18" : "Naive CD4s",
                "19" : "Treg CD4s",
                "20" : "Tem CD4s",
                "21" : "Trm CD4s",
                "22" : "Astrocytes"
               }
data_RNA.obs = data_RNA.obs.replace({"clust_ident":mapping_dict})
data_ATAC.obs = data_ATAC.obs.replace({"clust_ident":mapping_dict})

In [None]:
# One shared colour lookup (cluster name -> Set3 colour) so both panels colour
# a given cluster identically. Colours are taken from Set3 with an offset of
# one and wrap around every 12; at most 30 clusters receive a colour.
set3_colors = sns.color_palette('Set3')
palette = {
    cluster: set3_colors[(i + 1) % 12]
    for i, cluster in zip(range(30), data_ATAC.obs.clust_ident.unique())
}

# Two stacked panels: expression-only embedding on top, accessibility-only below.
fig, ax = plt.subplots(2, 1, figsize=(10, 15))
rna_ax, atac_ax = ax

sc.pl.umap(
    data_RNA,
    color = 'clust_ident',
    palette = palette,
    title = 'Expression Only',
    legend_loc = 'on data',
    size = 20,
    ax = rna_ax,
    show = False,
    **umap_kwargs,
)

sc.pl.umap(
    data_ATAC,
    color = 'clust_ident',
    palette = palette,
    title = 'Accessibility Only',
    legend_loc = 'on data',
    size = 20,
    na_color = 'lightgrey',
    ax = atac_ax,
    show = False,
    **umap_kwargs,
)

plt.tight_layout()
plt.show()

In [None]:
data_RNA, data_ATAC = mira.utils.make_joint_representation(data_RNA, data_ATAC)

In [None]:
sc.pp.neighbors(data_RNA, use_rep = 'X_joint_umap_features', metric = 'manhattan',
               n_neighbors = 20)

In [None]:
sc.tl.leiden(data_RNA, resolution = 1.2)
sc.tl.paga(data_RNA)
sc.pl.paga(data_RNA, plot=False)
sc.tl.umap(data_RNA, init_pos='paga', min_dist = 0.1)

In [None]:
fig, ax = plt.subplots(1,2,figsize=(10,15))
sc.pl.umap(data_RNA, color = 'leiden', legend_loc = 'on data', ax = ax[0], show = False, size = 20,
          **umap_kwargs, title = 'MIRA_clusts')
sc.pl.umap(data_RNA, color = 'clust_ident', legend_loc = 'on data', ax = ax[1], show = False, size = 20,
          **umap_kwargs, title = 'seurat_clusts')
plt.tight_layout()
plt.show()

In [None]:
# transfer metadata over so we can just use data_RNA for plotting from here on out
# Append every ATAC obs column to data_RNA.obs under an 'ATAC_' prefix
# (index-aligned join), so later plots can colour by ATAC-side values.
# NOTE(review): this cell is not idempotent -- running it a second time joins
# columns that already exist; re-run the cells that rebuild data_RNA first.
data_RNA.obs = data_RNA.obs.join(
    data_ATAC.obs.add_prefix('ATAC_'))

# Reuse the RNA object's UMAP coordinates for the ATAC object.
# NOTE(review): assumes both objects hold the same cells in the same order -- confirm.
data_ATAC.obsm['X_umap'] = data_RNA.obsm['X_umap']

In [None]:
# save our files:
data_RNA.write('multiome_RNA_processed.h5ad')
data_ATAC.write('multiome_ATAC_processed.h5ad')

In [None]:
mira.tl.get_cell_pointwise_mutual_information(data_RNA, data_ATAC)

In [None]:
fig, ax = plt.subplots(1,1,figsize=(8,5))
sc.pl.umap(data_RNA, color = 'pointwise_mutual_information', ax = ax, vmin = 0,
          color_map='magma', frameon=False, add_outline=True, vmax = 4, size = 7)

In [None]:
mira.tl.summarize_mutual_information(data_RNA, data_ATAC)

This is a very high level of concordance between the RNA and ATAC modalities; 0.5 is considered the threshold for high concordance :)

In [None]:
cross_correlation = mira.tl.get_topic_cross_correlation(data_RNA, data_ATAC)

In [None]:
sns.clustermap(cross_correlation, vmin = 0,
               cmap = 'magma', method='ward',
               dendrogram_ratio=0.05, cbar_pos=None, figsize=(7,7))

## Let's dive a bit deeper into these topics.

First, we'll check which genes are most strongly activated by each topic and do simple Enrichr analysis to correlate each with a process or pathway.

In [None]:
# Colour the UMAP by each expression topic, topic_0 through topic_10, four panels per row.
rna_topic_columns = [f'topic_{i}' for i in range(11)]
sc.pl.umap(
    data_RNA,
    color = rna_topic_columns,
    frameon = False,
    ncols = 4,
    color_map = 'viridis',
)

In [None]:
# we post the top 5% genes (our model took ~4000 genes) from each of our microglial-enriched topics
model_RNA.post_topic(0, top_n=200)
model_RNA.post_topic(3, top_n=200)
model_RNA.post_topic(6, top_n=200)
model_RNA.post_topic(7, top_n=200)
model_RNA.post_topic(8, top_n=200)
model_RNA.post_topic(9, top_n=200)

In [None]:
# Retrieve enrichment results for each microglial-enriched topic against the
# two ontologies below. A single loop replaces six copy-pasted calls, so the
# ontology list and the topic list are each written once.
enrichment_ontologies = ['WikiPathways_2019_Mouse','GO_Biological_Process_2021']
for topic_num in (0, 3, 6, 7, 8, 9):
    model_RNA.fetch_topic_enrichments(topic_num, ontologies=enrichment_ontologies)

In [None]:
# Plot the top 10 enrichment terms for each microglial-enriched topic.
# A single loop replaces six copy-pasted calls. Because the last statement is
# now a loop rather than an expression, the cell no longer echoes the return
# value of the final call; the figures themselves are still drawn.
for topic_num in (0, 3, 6, 7, 8, 9):
    model_RNA.plot_enrichments(topic_num, show_top=10)

Now, ATAC.

In [None]:
# Colour the UMAP by each accessibility topic, ATAC_topic_0 through
# ATAC_topic_19. These are the 'ATAC_'-prefixed columns on data_RNA.obs.
atac_topic_columns = [f'ATAC_topic_{i}' for i in range(20)]
sc.pl.umap(
    data_RNA,
    color = atac_topic_columns,
    frameon = False,
    ncols = 4,
    color_map = 'magma',
)

In [None]:
sc.pl.umap(data_RNA, color  = ['ATAC_topic_17'], frameon=False, ncols=4,color_map = 'magma')

In [None]:
# download the mus musculus genome for motif scanning analysis over our ATAC frags
!wget https://hgdownload.soe.ucsc.edu/goldenPath/mm10/bigZips/mm10.fa.gz
!gzip -d -f mm10.fa.gz

In [None]:
# prep data_ATAC to have chromosome, start, and end metadata easily accessable
ids = data_ATAC.var.index.to_series().str.split("-")
data_ATAC.var['chr'] = Extract(ids,0)
data_ATAC.var['start'] = Extract(ids,1)
data_ATAC.var['end'] = Extract(ids,2)
data_ATAC.var['chr'] = data_ATAC.var['chr'].astype(str)
data_ATAC.var['start'] = data_ATAC.var['start'].astype(str)
data_ATAC.var['end'] = data_ATAC.var['end'].astype(str)
data_ATAC.var

In [None]:
mira.tools.motif_scan.logger.setLevel(logging.INFO) # make sure progress messages are displayed
mira.tl.get_motif_hits_in_peaks(data_ATAC,
                    genome_fasta='mm10.fa',
                    chrom = 'chr', start = 'start', end = 'end') # use our metadata we just populated

In [None]:
mira.utils.fetch_factor_meta(data_ATAC).head(3)

In [None]:
#mira.utils.subset_factors(data_ATAC,
                          #use_factors=[factor for factor in data_RNA.var_names])
# filter out TFs that don't have matching RNA data
# however atm these are all human names (see above) so we skip this for now
# asked on github, will see if i get a response
# https://github.com/cistrome/MIRA/issues/16

In [None]:
# Fraction of zero entries in the factor-hit matrix.
# The element count comes from the matrix shape; the previous version called
# A.toarray() just to read .size, which builds the full dense matrix in memory.
A = mira.utils.fetch_factor_hits(data_ATAC).X
n_rows, n_cols = A.shape
sparsity = 1.0 - ( A.count_nonzero() / float(n_rows * n_cols) )
sparsity

In [None]:
# Optional restart point: pick up the processed ATAC model from an earlier session.
from pathlib import Path

# NOTE(review): this reads into `data_atac` (lower-case), while the cells that
# follow use `data_ATAC`. The name is left unchanged because the file on disk
# was written before the motif scan -- confirm which object was intended.
data_atac = anndata.read_h5ad('multiome_ATAC_processed.h5ad')

# The processed checkpoint is only written by the final "save our files again"
# cell, so it does not exist the first time the notebook is run top to bottom.
# When it is missing, keep the model_ATAC that is already in memory.
processed_model_path = 'data/topic_model_atac_h5_processed.pth'
if Path(processed_model_path).exists():
    model_ATAC = mira.topics.AccessibilityTopicModel.load(processed_model_path)

In [None]:
# Compute TF enrichment (top_quantile=0.2) for every ATAC topic that is
# queried further down: 2, 3 and 15 as before, plus 16 (exported with
# get_enrichments(16)) and 4 (compared against topic 3 in
# plot_compare_topic_enrichments). Topics 16 and 4 were previously never
# computed in this notebook before being read.
for topic_num in (2, 3, 4, 15, 16):
    model_ATAC.get_enriched_TFs(data_ATAC, topic_num=topic_num, top_quantile=0.2)

In [None]:
import pandas as pd
# Export the 50 lowest-p-value enrichment records for ATAC topic 16.
# NOTE(review): the earlier comment here read "ATAC signature 2 correlates
# highly with TIMs", but the code and the file name both use topic 16 --
# confirm which topic the remark refers to.
topic_16_enrichments = pd.DataFrame(model_ATAC.get_enrichments(16))
df = topic_16_enrichments.sort_values(by=['pval'], ascending = True).head(50)
df.to_csv("atac_topic_16_tfs.csv")

In [None]:
model_ATAC.plot_compare_topic_enrichments(3, 4,
            fontsize=10, label_closeness=3, figsize=(6,6), pval_threshold = (1e-45,1e-300))

In [None]:
motif_scores = model_ATAC.get_motif_scores(data_ATAC)

In [None]:
motif_scores.var = motif_scores.var.set_index('parsed_name')
motif_scores.var_names_make_unique()
motif_scores.obsm['X_umap'] = data_ATAC.obsm['X_umap']

In [None]:
import pandas as pd
# Write the motif score matrix to CSV: rows are labelled with the obs index,
# columns with the (stringified) var index of motif_scores.
df = pd.DataFrame(
    motif_scores.X,
    index = motif_scores.obs.index,
    columns = motif_scores.var.index.map(str),
)
df.to_csv("~/scratch/R_dir/1yr_multiome/MIRA_motif_scores.csv")

In [None]:
sc.pl.umap(motif_scores, color = ['RUNX2', 'RUNX3','JUN','FOS','CEBPA',"CEBPE"],
           frameon=False, color_map='viridis', ncols=2)

In [None]:
# save our files again:
data_RNA.write('multiome_RNA_processed.h5ad')
data_ATAC.write('multiome_ATAC_processed.h5ad')
model_RNA.save('data/topic_model_rna_h5_processed.pth')
model_ATAC.save('data/topic_model_atac_h5_processed.pth')