In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP
from statsmodels.stats.multitest import multipletests

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu

from scroutines.gene_modules import GeneModules  

# import importlib
# import scroutines
# importlib.reload(scroutines)
# from scroutines.gene_modules import GeneModules  


In [None]:
# Output directory for every figure saved by this notebook.
outfigdir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures/250409"
# Create it from Python instead of `!mkdir -p $outfigdir`: no shell
# interpolation of the path, and the cell also runs outside IPython.
os.makedirs(outfigdir, exist_ok=True)

# load gene annotation and data

In [None]:
# Look up one example gene and print each of the three returned
# collections as a single tab-separated line.
gene_modules = GeneModules()
g, gs, ms = gene_modules.check_genes('Cdh13')
for fields in (g, gs, ms):
    print("\t".join(fields))

In [None]:
# Score table; the first CSV column becomes the row index.
scores_path = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/scores_l23abc_cheng22_250904.csv"
scores_abc = pd.read_csv(scores_path, index_col=0)
# C-minus-A contrast, used further down to rank cells
scores_abc['scores_c-a'] = scores_abc['scores_c'].sub(scores_abc['scores_a'])
scores_abc

In [None]:
# `anndata.read` is a deprecated alias; `read_h5ad` is the supported reader for .h5ad files.
adata = anndata.read_h5ad("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/superdupermegaRNA_hasraw_cheng22_l23.h5ad")
# Keep only cells whose 'Age' label starts with 'P28'. Copy so that later
# assignments (X, obs columns) act on a real AnnData object, not a view.
adata = adata[adata.obs['Age'].str.contains(r'^P28')].copy()
adata

In [None]:
# Replace X with the matrix stored in adata.raw. The normalization cell below
# sums X per cell and treats it as counts, so .raw presumably holds raw counts
# — TODO confirm.
adata.X = adata.raw.X

In [None]:
# Copy each score column onto adata.obs, selecting rows by the cell index.
for score_col in ('scores_a', 'scores_b', 'scores_c', 'scores_c-a'):
    adata.obs[score_col] = scores_abc.loc[adata.obs.index, score_col].copy()

In [None]:
# The condition label of each cell is its 'Age' annotation.
adata.obs['cond'] = adata.obs['Age']

sample_labels = adata.obs['Sample'].values
# Drop the last character of every sample name and strip any upper-case 'DR'.
# NOTE(review): the sample names listed later in this notebook are lower-case
# (e.g. 'P28_dr_1a'); if 'Sample' holds those names, 'DR' never matches — confirm.
time_labels = [s[:-1].replace('DR', '') for s in sample_labels]

adata.obs['sample'] = sample_labels # same values as 'Sample', under a lower-case column name
adata.obs['time']   = time_labels

# Naturally sorted unique sample names, split by whether they contain 'DR'.
# NOTE(review): the same upper-case 'DR' test is used here; with lower-case
# sample names dr_samples would be empty — confirm against the data.
uniq_samples = natsorted(np.unique(sample_labels))
nr_samples = [s for s in uniq_samples if "DR" not in s]
dr_samples = [s for s in uniq_samples if "DR" in s]

# Naturally sorted unique condition labels.
uniq_conds = np.array(natsorted(np.unique(adata.obs['cond'].values)))
print(uniq_conds)

In [None]:
# remove mitochondrial genes (names starting with 'mt-') and Xist in one pass
gene_names = adata.var.index
is_mito = gene_names.str.contains(r'^mt-')
is_xist = gene_names.str.contains(r'^Xist$')
adata = adata[:, ~(is_mito | is_xist)]

# keep genes detected (value > 0) in more than 10 cells
n_cells_expressing = np.ravel((adata.X>0).sum(axis=0))
cond = n_cells_expressing > 10
adata = adata[:,cond].copy()

In [None]:
# Display the AnnData summary after gene filtering.
adata

In [None]:
# counts
# count matrix, per-cell totals, and gene names
x = adata.X
cov = np.ravel(x.sum(axis=1))
genes = adata.var.index.values

# CP10k: scale each cell (row) by 1/total, then multiply by 10,000
xn = (sparse.diags(1/cov) @ x)*1e4

# log2(CP10k+1), applied to the stored entries of the sparse matrix
xln = xn.copy()
xln.data = np.log2(xln.data+1)

# keep both versions as dense layers
adata.layers['norm'] = xn.toarray()
adata.layers['lognorm'] = xln.toarray()

# annotate samples

In [None]:
# adata.obs['sample'].unique()
import re


def _age_from_label(label):
    """Return the integer left after removing letters from the part of `label` before the first '_'."""
    prefix = label.split('_')[0]
    return int(re.sub(r'[a-zA-Z]', '', prefix))


# Conditions and samples to process, in plotting order (4 samples per condition).
todo_conds = ['P28', 'P28DL', 'P28DR']
todo_samps = [
    'P28_1a', 'P28_1b', 'P28_2a', 'P28_2b',
    'P28_dl_1a', 'P28_dl_1b', 'P28_dl_2a', 'P28_dl_2b',
    'P28_dr_1a', 'P28_dr_1b', 'P28_dr_3a', 'P28_dr_3b',
]

# Numeric age parsed from each condition / sample name.
todo_conds_t = np.array([_age_from_label(name) for name in todo_conds])
todo_samps_t = np.array([_age_from_label(name) for name in todo_samps])
print(todo_conds_t)
print(todo_samps_t)

def mean_over_samples(mmat_res_samp):
    """Average 12 per-sample profiles into 3 per-condition profiles.

    The first axis of `mmat_res_samp` must have length 12. Rows 0-3, 4-7 and
    8-11 are averaged separately; the result is a float array whose first
    axis has length 3 and whose remaining axes are unchanged.
    """
    assert mmat_res_samp.shape[0] == 12

    group_rows = (slice(0, 4), slice(4, 8), slice(8, None))
    cond_means = np.empty((len(group_rows),) + mmat_res_samp.shape[1:])
    for k, rows in enumerate(group_rows):
        cond_means[k] = np.mean(mmat_res_samp[rows], axis=0)

    return cond_means

def transform_bigredmat(bigmat, n_type):
    """Flatten `bigmat` (or a reduced matrix) and z-score it per row.

    All leading axes of `bigmat` are merged into one, and the result is
    transposed so that the last axis of the input becomes the rows.
    `n_type` is accepted but not used.

    Returns (fmat, zmat), where zmat is fmat z-scored along each row.
    """
    n_last = bigmat.shape[-1]
    fmat = np.transpose(np.reshape(bigmat, (-1, n_last)))
    zmat = zscore(fmat, axis=1)
    return fmat, zmat

In [None]:
%%time

# Build per-sample pseudobulk profiles, expressed as log2 fold-change relative
# to the mean over all cells in `adata`.
offset = 1  # pseudocount added before taking log2
mat = adata.layers['norm'][...]
# baseline: log2(mean over all cells * 100 + offset), one value per gene
gexp_l23baseline = np.log2(np.mean(mat, axis=0)*1e2+offset) # CP10k -> CPM

n_type = 10  # number of quantile bins along the C-minus-A rank
frac_archetypal_cells_viz = 0.2  # fraction of a sample's cells used for each of the A/B/C groups
# (sample, bin, gene) and (sample, A/B/C, gene) result arrays
bigmat_nfd = np.zeros((len(todo_samps), n_type, mat.shape[1]))
bigmat_abc = np.zeros((len(todo_samps),      3, mat.shape[1]))

for i, samp in enumerate(todo_samps):
    # cells of this sample only
    adatasub = adata[adata.obs['sample']==samp]
    n_cells = adatasub.shape[0]
    
    print(samp, n_cells)
    
    # rank cells within the sample by the C-minus-A score and by the B score
    ranks_ac = adatasub.obs['scores_c-a'].rank()
    ranks_b  = adatasub.obs['scores_b'].rank()
    
    # split the C-minus-A ranks into n_type quantile bins; one mean profile per bin
    cells_type_nfd = pd.qcut(ranks_ac, n_type, labels=False)
    for j in range(n_type):
        mat_j = adatasub[cells_type_nfd==j].layers['norm'][...]
        mmat_j = np.log2(np.mean(mat_j, axis=0)*1e2+offset)-gexp_l23baseline # CP10k -> CPM
        bigmat_nfd[i,j] = mmat_j
    
    # A, B, C groups: a fixed number of cells at the extremes of each ranking
    num_archetypal_cells_viz = int(n_cells*frac_archetypal_cells_viz)
    
    precond_a = ranks_ac <= num_archetypal_cells_viz                      # lowest C-minus-A ranks
    precond_c = ranks_ac > adatasub.shape[0] - num_archetypal_cells_viz   # highest C-minus-A ranks
    precond_b = ranks_b  > adatasub.shape[0] - num_archetypal_cells_viz   # highest B ranks
    
    # keep only cells that fall in exactly one of the three candidate sets
    cond_a = np.all([ precond_a, ~precond_b, ~precond_c], axis=0)
    cond_b = np.all([~precond_a,  precond_b, ~precond_c], axis=0)
    cond_c = np.all([~precond_a, ~precond_b,  precond_c], axis=0)
    
    # one mean profile per group, in A, B, C order
    for j, cond in enumerate([cond_a, cond_b, cond_c]):
        mat_j = adatasub[cond].layers['norm'][...]
        mmat_j = np.log2(np.mean(mat_j, axis=0)*1e2+offset)-gexp_l23baseline # CP10k -> CPM
        bigmat_abc[i,j] = mmat_j


In [None]:
# Average samples within each condition, then flatten and z-score.
redmat_nfd = mean_over_samples(bigmat_nfd)
fmat_nfd, zmat_nfd = transform_bigredmat(redmat_nfd, n_type)
print(bigmat_nfd.shape, redmat_nfd.shape)
# both matrices are (gene, cond*type)
for arr in (fmat_nfd, zmat_nfd):
    print(arr.shape)


# Profile vision-dependent ABC genes (identified from multiome)

In [None]:
# streamline 2 heatmaps
def order_zmat(zmat, metric0=None, n_type=10):
    """Compute a row order for a z-scored matrix laid out in condition blocks.

    Columns are assumed to be ordered as [NR | DL | DR] with `n_type` columns
    per block: NR is the first `n_type` columns and DR is everything from
    column 2*n_type onward. The DL block is not used for ordering.

    Parameters
    ----------
    zmat : 2-D array, rows x (3*n_type) columns.
    metric0 : optional per-row primary sort key. When None, it is the boolean
        "mean over DR columns > mean over NR columns".
    n_type : number of columns per condition block (default 10).

    Returns
    -------
    dforder : DataFrame with columns 'm0' (primary key) and 'm2' (secondary key).
    gene_order : row indices sorted by 'm0', then by 'm2'.
    """
    zmat_nr = zmat[:, :n_type]
    zmat_dr = zmat[:, 2*n_type:]

    # metric0: categorical key (provided, or DR-vs-NR direction by default)
    if metric0 is None:
        metric0 = (np.mean(zmat_dr, axis=1) - np.mean(zmat_nr, axis=1)) > 0

    # metric2: continuous key, the centroid position within the NR block,
    # using exp(z) normalized per row as weights
    pmat_nr = np.exp(zmat_nr)
    pmat_nr = pmat_nr/np.sum(pmat_nr, axis=1).reshape(-1,1)
    metric2 = pmat_nr.dot(np.arange(n_type))

    # sort first by metric0, then by metric2
    dforder = pd.DataFrame()
    dforder['m0'] = metric0
    dforder['m2'] = metric2

    gene_order = dforder.sort_values(['m0', 'm2']).index.values
    return dforder, gene_order


def mark_ticklabels(highlights, color='red'):
    """Color the y tick labels of the current axes whose text is in `highlights`.

    Operates on the axes returned by `plt.gca()`; labels not listed in
    `highlights` are left untouched. Returns None.
    """
    current_axes = plt.gca()
    for tick in current_axes.get_yticklabels():
        if tick.get_text() in highlights:
            tick.set_color(color)

In [None]:
# import a gene list 
# indices from the big matrix


# Gene list plotted in the first heatmap below (the "vision-dependent ABC
# genes" of this section's heading).
genes_vision_abc = [
    'Xylt1','March4','C1ql3','Ptprg','Shc3',
    'Cdh4','Fnbp1l','Slit3','Btg2','Npas4',
    'Nrp1','Ifrd1','Tnfaip6','1700016P03Rik','Nptx2',
    'Ctnna3','Mapk4','Ppm1h','Bdnf','Mir670hg',
    'Baz1a','Homer1','Phf21b','Plcl1','Dusp14',
    'Megf11','Maml3','Gm15398','Mthfd1l','Inhba',
    'Zbtb16','Ell2','Ppme1','Epha10','Rph3a',
    'Igsf9b','Sntb2','Tmtc2','Cpne8','Sgcd',
    'Eda','Thsd7a','Syt10','Col23a1','Epha6',
    'Glis3','Igfbp5','Hdac9','Cdh7','Zfp536',
    'Sparcl1','Gm15155','Kctd8','L3mbtl4','Cdh18',
    'Pcdh10','Pth2r','Cdh20','Gm28175','Parm1',
    'Lcorl','Elavl2','Kcnh5','Ntng1','Cntnap4',
    'Rorb','Rora','Mgll']

# Gene list plotted in the second heatmap below ("csm" presumably stands for
# cell-surface molecules — TODO confirm). The entries from 'Ptprg' to the end
# also appear in genes_vision_abc.
genes_csm_abc = [
    'Tenm1','Epha5','Cd200','Ptpru','Ptpro',
    'Vcan','Cdh13','Cntn5','Pcdh19','Cntnap5a',
    'Neo1','Chl1','Robo1','Lsamp','Sema6a',
    'Ptprk','Dscaml1','Tnr','Igsf11','Cadm1',
    'Fstl5','Kirrel3','Sorl1','Nell2','Epha3',
    'Lrfn5','Pcdh17','Efna5','Ntrk3','Tenm2',
    'Clstn2','Slit2','Nrg1','Cd47','Ghr',
    'Lingo2','Cntn3','Nrxn3','Sdk1','Pcdh15',
    'Fat3','Dscam','Pcdh11x','Sema3a','Unc5c',
    'Cntn4','Cdh8','Cdh12','Cdh6','Abi3bp',
    'Il1rap','Cntn6','Tenm4','Dcc','Pcdh9',
    'Alcam','Astn2','Flrt2','Sema7a','Robo3',
    'Ptprt','Lrrc4c','Cntnap2','Unc5d','Cntnap5b',
    'Ncam2','Epha7','Sema6d','Fstl4','Ptprg',
    'Cdh4','Slit3','Nrp1','Ctnna3','Megf11',
    'Epha10','Igsf9b','Epha6','Cdh7','Cdh18',
    'Pcdh10','Cdh20','Ntng1','Cntnap4',
]

In [None]:
selected_conditions = [0,1,2]
titles = ['P28NR', 'P28DL', 'P28DR']
# column indices of the three n_type-wide condition blocks, in order
columns = np.concatenate([k*n_type + np.arange(n_type) for k in range(3)])

In [None]:
# Heatmap of the vision-dependent ABC genes across the three condition blocks.
genes_this = genes_vision_abc
genes_idx = basicu.get_index_from_array(genes, genes_this)

fig, ax = plt.subplots(1,1,figsize=(7,14))
# select the genes and columns, then z-score each gene across the shown columns
zmat = zmat_nfd[genes_idx, :][:, columns]
zmat = stats.zscore(zmat, axis=1)

sns.heatmap(zmat, 
            cmap='coolwarm', cbar_kws=dict(shrink=0.3), 
            xticklabels=False,
            vmax=2, vmin=-2,
            rasterized=True,
            ax=ax)
ax.set_yticks(0.5+np.arange(len(zmat)))
ax.set_yticklabels(genes_this, fontsize=12, rotation=0)

# Separators and titles sit on the block boundaries, which are n_type columns
# wide; use n_type rather than a hard-coded 10 so they track the binning.
for i, cond in enumerate(titles):
    if i > 0:
        ax.vlines(i*n_type, 0, len(zmat), color='white', linewidth=1)
    ax.text(i*n_type, -0.5, f'{cond}', fontsize=12, va='bottom')

output = os.path.join(outfigdir, 'vision_abc_heatmap.pdf')
powerplots.savefig_autodate(fig, output)

plt.show()

In [None]:
# Heatmap of the genes_csm_abc list across the three condition blocks.
genes_this = genes_csm_abc
genes_idx = basicu.get_index_from_array(genes, genes_this)

fig, ax = plt.subplots(1,1,figsize=(7,14))
# select the genes and columns, then z-score each gene across the shown columns
zmat = zmat_nfd[genes_idx, :][:, columns]
zmat = stats.zscore(zmat, axis=1)

sns.heatmap(zmat, 
            cmap='coolwarm', cbar_kws=dict(shrink=0.3), 
            xticklabels=False,
            vmax=2, vmin=-2,
            rasterized=True,
            ax=ax)
ax.set_yticks(0.5+np.arange(len(zmat)))
ax.set_yticklabels(genes_this, fontsize=12, rotation=0)

# Separators and titles sit on the block boundaries, which are n_type columns
# wide; use n_type rather than a hard-coded 10 so they track the binning.
for i, cond in enumerate(titles):
    if i > 0:
        ax.vlines(i*n_type, 0, len(zmat), color='white', linewidth=1)
    ax.text(i*n_type, -0.5, f'{cond}', fontsize=12, va='bottom')

# Own file name for this figure: the cell previously reused
# 'vision_abc_heatmap.pdf', the path of the vision-gene heatmap.
output = os.path.join(outfigdir, 'csm_abc_heatmap.pdf')
powerplots.savefig_autodate(fig, output)

plt.show()

# check ARG regulons