In [None]:
# Widen the notebook container to the full browser width.
# `IPython.display` is the public import path; `IPython.core.display.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
from scipy import sparse
import seaborn as sns
sns.set(style = 'ticks')
import matplotlib.pyplot as plt
import os
import glob
import os
import sys
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import scale
from lisa.core.utils import LoadingBar
import importlib
from multiome_models import RPModel
import tqdm
from scipy.stats import rankdata, kendalltau, spearmanr
import pickle
from skbio.stats.composition import ilr

In [None]:
import logging

# Lower the root logger's threshold to INFO so INFO records from libraries are
# not filtered at the logger level (DEBUG records are still discarded).
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
def clr(x):
    """Centered log-ratio (CLR) transform along axis 1.

    For a 2-D array, each row is log-transformed and then shifted so that its
    mean log value is zero. Entries are expected to be positive; zeros or
    negatives produce non-finite values.
    """
    log_x = np.log(x)
    row_means = log_x.mean(axis = 1, keepdims = True)
    return log_x - row_means

def standardize(x, clip = 3):
    """Z-score ``x`` with its global mean and standard deviation, then clip.

    Parameters
    ----------
    x : array-like supporting ``.mean()`` and ``.std()`` (e.g. a NumPy array).
    clip : float, default 3
        The z-scores are clipped to ``[-clip, clip]``.

    Returns
    -------
    The clipped z-scores. If ``x`` has exactly zero spread (all values
    identical), the result is all zeros rather than NaN.
    """
    centered = x - x.mean()
    spread = x.std()
    if spread == 0:
        # Constant input: avoid the 0/0 -> NaN (and its RuntimeWarning).
        return centered * 0
    return np.clip(centered / spread, -clip, clip)

def min_max_standardize(x, axis = -1):
    """Rescale ``x`` to the range [0, 1] independently along ``axis``.

    Each slice along ``axis`` has its minimum subtracted and is divided by
    its range (max - min). Slices with zero range (constant slices) map to 0
    instead of producing NaN from a 0/0 division.
    """
    lo = x.min(axis, keepdims = True)
    span = x.max(axis, keepdims = True) - lo
    # For constant slices x - lo is all zeros, so dividing by 1 yields 0.
    span = np.where(span == 0, 1, span)
    return (x - lo) / span

def plot_umap(andata,*,color_key, key = 'X_umap', quantitative=True, ax = None, center=True, legend = False,
             hue_order = None, hue_norm = None, size = 0.5, palette = 'vlag'):
    """Scatter-plot columns 0 and 1 of ``andata.obsm[key]``, coloured by ``color_key``.

    Parameters
    ----------
    andata : object exposing ``.obsm`` and ``.obs`` (e.g. an AnnData).
    color_key : str or array-like
        A column name in ``andata.obs`` (any ``str`` instance, including
        ``numpy.str_``), or the colour values themselves.
    key : str
        Entry of ``andata.obsm`` holding the embedding to plot.
    quantitative : bool
        False: colour values are cast to ``str`` and drawn as categories
        (``hue_order`` applies). True: drawn as continuous values
        (``hue_norm`` applies).
    ax : matplotlib Axes or None
        Axes to draw on; forwarded to ``sns.scatterplot``.
    center : bool
        Quantitative mode only: pass colour values through ``standardize`` first.
    legend, palette, hue_order, hue_norm, size :
        Forwarded to ``sns.scatterplot``.

    Returns
    -------
    The Axes drawn on, with ticks removed and spines despined.
    """
    embedding = andata.obsm[key]
    # isinstance (rather than ``type(...) == str``) so str subclasses such as
    # numpy.str_ are looked up in ``obs`` instead of being used as raw colour data.
    if isinstance(color_key, str):
        color_var = andata.obs[color_key].values
    else:
        color_var = color_key

    # NOTE(review): ``size`` goes to seaborn's ``size=`` (a size *semantic*),
    # not ``s=``; a scalar there likely does not set the marker size — confirm.
    if not quantitative:
        ax = sns.scatterplot(x = embedding[:,0], y = embedding[:,1], size = size,
                       hue = color_var.astype('str'), palette=palette,
                       hue_order = hue_order, legend = legend, ax = ax)
    else:
        # NOTE(review): this branch mirrors the x-axis (-embedding[:, 0]) while the
        # categorical branch does not, so the two modes draw mirror-image layouts
        # of the same embedding. Preserved as-is; confirm whether intended.
        ax = sns.scatterplot(x = -embedding[:,0], y = embedding[:,1],
                   hue = standardize(color_var) if center else color_var,
                   palette = palette, size = size, legend = legend, ax = ax, hue_norm = hue_norm)
    ax.set(xticklabels = [], xticks = [], yticks = [])
    sns.despine()
    return ax

def map_plot(rows, columns, plot_fn, figsize = (15,10), **kwargs):
    """Create a ``rows`` x ``columns`` grid of subplots and fill each panel.

    ``plot_fn(k, ax, **kwargs)`` is called once per panel, where ``k`` is the
    row-major panel index (0 .. rows*columns - 1) and ``ax`` is that panel's Axes.

    Returns
    -------
    2-D numpy array of Axes with shape ``(rows, columns)``.
    """
    # squeeze=False always returns a 2-D Axes array, including the 1x1 case,
    # where the old reshape-based handling failed on a bare Axes object.
    _, axes = plt.subplots(rows, columns, figsize = figsize, squeeze = False)
    for k in range(rows*columns):
        i, j = divmod(k, columns)
        plot_fn(k, axes[i,j], **kwargs)
    return axes

def get_cell_relavance(cell_topic_dist, region_topic_dist, query):
    """Score each row of ``cell_topic_dist`` against a set of query regions.

    The queried columns of ``region_topic_dist`` are mixed by each row of
    ``cell_topic_dist`` (a dot product), giving modeled probabilities for the
    query regions. The score is the sum of their logs along axis 1.

    Returns
    -------
    One score per row of ``cell_topic_dist``.
    """
    selected_regions = region_topic_dist[:, query]
    return np.log(np.dot(cell_topic_dist, selected_regions)).sum(axis = 1)

def topic_relevance_corr(cell_topic_dist, document_relevance):
    """Pearson correlation between every relevance column and every topic column.

    Parameters
    ----------
    cell_topic_dist : 2-D array, shape (n_rows, n_topics)
    document_relevance : 2-D array, shape (n_rows, n_queries)
        Rows of both inputs are paired observations.

    Returns
    -------
    Array of shape (n_queries, n_topics). Entry [q, t] is the correlation
    between ``document_relevance[:, q]`` and ``cell_topic_dist[:, t]``.

    Neither input is modified.
    """
    # Out-of-place centering. The previous in-place ``-=`` silently mean-centred
    # the caller's ``document_relevance``, and it failed for integer inputs.
    topics_centered = cell_topic_dist - cell_topic_dist.mean(0)
    relevance_centered = document_relevance - document_relevance.mean(0)

    n_rows = relevance_centered.shape[0]
    cov = np.dot(relevance_centered.T, topics_centered) / n_rows

    corr = cov / (relevance_centered.std(0)[:, np.newaxis] * topics_centered.std(0)[np.newaxis, :])
    return corr

def get_TF_corr(cell_topic_dist, region_topic_dist, queries):
    """Score every row for each query and correlate those scores with topics.

    Parameters
    ----------
    cell_topic_dist, region_topic_dist :
        Passed through to ``get_cell_relavance``.
    queries : iterable
        One region selection per query, each usable as a column index into
        ``region_topic_dist``.

    Returns
    -------
    relevance : array
        One column of raw ``get_cell_relavance`` scores per query.
    corr :
        ``topic_relevance_corr(clr(cell_topic_dist), relevance)``.
    """
    columns = [
        get_cell_relavance(cell_topic_dist, region_topic_dist, query)[:, np.newaxis]
        for query in tqdm.tqdm(queries)
    ]
    relevance = np.hstack(columns)
    # Pass a copy: topic_relevance_corr may centre its input in place, which would
    # otherwise leak into the ``relevance`` matrix returned to the caller.
    return relevance, topic_relevance_corr(clr(cell_topic_dist), relevance.copy())

def fishers_inclusiveness(*,overlaps, chip_hits, num_peaks, genome_size, peak_size = 300,
                          alternative = 'two-sided'):
    """Fisher's exact test p-value for the overlap between peaks and ChIP hits.

    The genome is treated as ``genome_size / peak_size`` equal-sized bins, and
    the 2x2 contingency table is::

        [[overlaps,              chip_hits - overlaps],
         [num_peaks - overlaps,  bins in neither set ]]

    Parameters
    ----------
    overlaps : int
        Count in both sets.
    chip_hits, num_peaks : int
        Size of each set.
    genome_size : number
        Total length, divided into bins of ``peak_size``.
    peak_size : number, default 300
        Bin width used to convert ``genome_size`` into a bin count.
    alternative : {'two-sided', 'less', 'greater'}
        Forwarded to ``scipy.stats.fisher_exact``.

    Returns
    -------
    float
        The test's p-value. ``fisher_exact`` raises ValueError if any cell of
        the table is negative.
    """
    # Local import: fisher_exact is not among the notebook's top-level imports,
    # so the original body raised NameError on every call.
    from scipy.stats import fisher_exact

    neither_regions = int(genome_size/peak_size - (num_peaks + chip_hits - overlaps))
    table = np.array([
        [overlaps, chip_hits - overlaps],
        [num_peaks - overlaps, neither_regions],
    ])
    return fisher_exact(table, alternative = alternative)[1]

In [None]:
# Load the two AnnData objects (ATAC data and gene expression, per the filenames).
# Paths are relative to the kernel's current working directory.
# NOTE(review): hard-coded, dated filenames — consider a DATA_DIR / date constant in a config cell.
atac_data = anndata.read_h5ad('2021-02-01_atac_data.h5ad')
gene_expr = anndata.read_h5ad('2021-02-01_gene_expr.h5ad')