In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu


In [None]:
# Output directory for figures (absolute, machine-specific path).
# NOTE(review): not referenced by any of the cells visible here -- confirm it is still needed.
outdir_fig = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures"

# load gene annotation and data

In [None]:
# Re-import scroutines (package first, then the submodule) so that edits made
# to the package on disk are picked up without restarting the kernel.
import importlib
import scroutines
import scroutines.gene_modules
importlib.reload(scroutines)
importlib.reload(scroutines.gene_modules)

from scroutines.gene_modules import GeneModules

# Gene annotation helper; sanity-check it on a single gene by printing the
# three sequences returned by check_genes, one tab-separated line each.
gene_modules = GeneModules()
g_anno, g_color, g_styled = gene_modules.check_genes('Cdh13')
for _fields in (g_anno, g_color, g_styled):
    print("\t".join(_fields))

In [None]:
# A / C gene table: rows before index 25 are taken as the A block, the
# remaining rows as the C block.
# NOTE: `df` and `ac_overlap` are both reassigned by later cells.
f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/Saumya_P6-21_AC_genes.csv'
df = pd.read_csv(f)

df_a, df_c = df.iloc[:25], df.iloc[25:]

# unique entries of each block (sorted by np.unique)
alltime_a = np.unique(df_a.values)
alltime_c = np.unique(df_c.values)

# plain concatenation: an entry present in both blocks appears twice here
alltime_ac = np.concatenate([alltime_a, alltime_c])
ac_overlap = np.intersect1d(alltime_a, alltime_c)

print(df_a.shape, df_c.shape, alltime_a.shape, alltime_c.shape, alltime_ac.shape, ac_overlap.shape)
df.head()

In [None]:
# L2/3 gene list (the "286 genes") with a per-gene group label in column 'P17on'.
# df = pd.read_csv("../../data/cheng21_cell_scrna/res/candidate_genes_vincent_0503_v2.csv")
df = pd.read_csv("../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv")

genes_l23 = df['gene'].astype(str).values
genes_grp = df['P17on'].astype(str).values

# split the gene names by their 'P17on' label
genes_l23a = df.loc[df['P17on'] == 'A', 'gene'].astype(str).values
genes_l23b = df.loc[df['P17on'] == 'B', 'gene'].astype(str).values
genes_l23c = df.loc[df['P17on'] == 'C', 'gene'].astype(str).values
print(genes_l23a.shape, genes_l23b.shape, genes_l23c.shape)

# every gene must be listed exactly once
assert len(genes_l23) == len(np.unique(genes_l23))

genes_l23.shape

In [None]:
%%time
adata0 = anndata.read("../../data/v1_multiome/L23_allmultiome_proc_P6toP21.h5ad", backed='r')
adata = anndata.read("../../data/v1_multiome/L23_allmultiome_proc_P6toP21_NRDR.h5ad")
adata

In [None]:
# Convenience arrays pulled out of the two AnnData objects.
# genes0 / genes: variable (gene) index of adata0 and adata respectively.
# conds / types / samps: per-observation columns 'cond', 'Type', 'sample' of adata.
# NOTE: `conds` is rebound to a list of condition names by a later cell.
genes0 = adata0.var.index.values
genes = adata.var.index.values
conds = adata.obs['cond'].values
types = adata.obs['Type'].values
samps = adata.obs['sample'].values

# Rebuild the derived layers in float64.  Order matters: 'lognorm' is computed
# from the freshly cast 'norm', and 'zlognorm' from the new 'lognorm'.
#   norm     -> cast to float64
#   lognorm  -> log10(1 + norm)
#   zlognorm -> z-score of lognorm along axis 0 (per column; presumably per gene
#               across cells -- TODO confirm layer orientation)
# NOTE(review): zscore yields NaN for columns with zero variance -- confirm this
# is acceptable downstream.
adata.layers[    'norm'] = adata.layers['norm'][...].astype(np.float64)
adata.layers[ 'lognorm'] = np.log10(1+adata.layers['norm'][...]) # np.array(xln.todense())
adata.layers['zlognorm'] = zscore(adata.layers['lognorm'][...], axis=0)

In [None]:
# Precomputed A-vs-C tables stored as plain-text matrices:
#   f_rna1 -> adjusted p-values (rna_qs_avc), f_rna2 -> log2 fold changes (rna_l2fc_avc).
# A later cell indexes their rows by position in `time_series` and uses each
# row as a mask over `genes0`.
f_rna1 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_qs_avc_p6to21.txt'
f_rna2 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_l2fc_avc_p6to21.txt'

rna_qs_avc = np.loadtxt(f_rna1)
rna_l2fc_avc = np.loadtxt(f_rna2)

# quick shape check of both tables
rna_qs_avc.shape, rna_l2fc_avc.shape

In [None]:
# Time points, in the row order of rna_qs_avc / rna_l2fc_avc
# (presumably postnatal days matching the 'P6' ... 'P21' condition labels -- confirm).
time_series = [6,8,10,12,14,17,21]

In [None]:
# Per-time-point A-vs-C DEGs from the precomputed q-value / log2FC tables.
# A gene is called when q < avc_qs_th and |log2FC| > avc_l2fc_th; negative
# log2FC goes to the A list, positive log2FC to the C list.
avc_qs_th, avc_l2fc_th = 0.05, 1

degs_a = []
degs_c = []
for i, t in enumerate(time_series):
    is_sig = rna_qs_avc[i] < avc_qs_th
    cond_a = np.logical_and(is_sig, rna_l2fc_avc[i] < -avc_l2fc_th)
    cond_c = np.logical_and(is_sig, rna_l2fc_avc[i] >  avc_l2fc_th)
    deg_a = np.sort(genes0[cond_a])
    deg_c = np.sort(genes0[cond_c])
    degs_a.append(deg_a)
    degs_c.append(deg_c)

    print(f"{t}")
    for deg in (deg_a, deg_c):
        # Skip the annotation lookup for an empty list, matching the guard used
        # by the NR-vs-DR DEG cells further down the notebook.
        if len(deg) > 0:
            _, _, gs = gene_modules.check_genes(deg)
            print("\t".join(gs))
        # reset terminal styling after the styled gene names
        print("\033[0;m ---")



In [None]:
# Sorted union of every A- and C-DEG found at any time point.
degs_ac = np.unique(np.concatenate(degs_a + degs_c))
degs_ac, degs_ac.shape

In [None]:
# Embeddings stored on the AnnData object under obsm keys 'pca_p8' and 'pca_p17on'.
# NOTE(review): in the visible cells these two names are only referenced by
# commented-out code; the DEG cell reads adata.obsm['pca_p17on'] directly.
pcs_p8 = adata.obsm['pca_p8']
pcs_p17on = adata.obsm['pca_p17on']

# res0 = pd.DataFrame(adata.layers['zlognorm'][...], columns=genes)
# res0['cond'] = conds
# res0['type'] = types
# res0['samp'] = samps
# res0['rep']  = [samp[-1] for samp in samps]
# res0['type'] = np.char.add('c', res0['type'].values.astype(str))

# res1 = pd.DataFrame(pcs_p8, columns=np.char.add("p8PC", ((1+np.arange(pcs_p8.shape[1])).astype(str))))
# res2 = pd.DataFrame(pcs_p17on, columns=np.char.add("p17onPC", ((1+np.arange(pcs_p17on.shape[1])).astype(str))))
# res = pd.concat([res0, res1, res2], axis=1)

In [None]:
# 20-color 'tab20c' palette; indexed below to build the per-condition palette.
allcolors = sns.color_palette('tab20c', 20)
allcolors

In [None]:
# 'tab10' palette requested with 20 entries.
# NOTE(review): tab10 has 10 distinct colors, so the extra entries presumably
# repeat -- only the first three are indexed in the visible cells.
allcolors2 = sns.color_palette('tab10', 20)
allcolors2

In [None]:
# Colour per condition (insertion order = display order).
# NOTE(review): 'P12DR' reuses the 'P17' colour and 'P14DR'/'P17DR'/'P21DR' all
# share the 'P21' colour -- this looks like copy-paste; confirm the DR
# conditions are not meant to have distinct colours.
palette = collections.OrderedDict({
     "P6": allcolors[2],
     "P8": allcolors[1],
    "P10": allcolors[0],
    "P12": allcolors[4+2],
    "P14": allcolors[4+0],

    "P17": allcolors[8+2],
    "P21": allcolors[8+0],

    "P12DR": allcolors[8+2],
    "P14DR": allcolors[8+0],
    "P17DR": allcolors[8+0],
    "P21DR": allcolors[8+0],
})
cases = np.array(list(palette))

# Integer rank of each condition: normally reared time points first, then DR.
cond_order_dict = {
    'P6':  0,
    'P8':  1,
    'P10': 2,
    'P12': 3,
    'P14': 4,
    'P17': 5,
    'P21': 6,

    'P12DR': 7,
    'P14DR': 8,
    'P17DR': 9,
    'P21DR': 10,
}
unq_conds = np.array(list(cond_order_dict))

# Raises KeyError if adata contains a condition missing from cond_order_dict.
adata.obs['cond_order'] = adata.obs['cond'].apply(lambda x: cond_order_dict[x]).astype(int)

# Colour per cluster label.  An earlier 'L2/3_A/B/C' version of this dict was
# assigned and then immediately overwritten by this one; that dead assignment
# has been dropped, so the resulting `palette_types` is unchanged.
palette_types = {
    'c14': 'C0',
    'c18': 'C1',
    'c16': 'C2',

    'c13': 'C0',
    'c15': 'C1',
    'c17': 'C2',
}
type_order = list(palette_types)
type_order

In [None]:
# plot('p17onPC1', 'p17onPC2', aspect_equal=True)
# plot('p8PC1', 'p8PC2', aspect_equal=True)

In [None]:
# Sanity check: which condition ranks are actually present in adata.
np.unique(adata.obs['cond_order']) #.unique())

# get top 20% A vs C in each case

In [None]:
from statsmodels.stats.multitest import multipletests

In [None]:

### select C genes - no
# adata = adata[:,degs_ac]
# genes = adata.var.index.values
# adata

In [None]:
# NR-vs-DR differential expression, computed separately within the cells at the
# two ends of the first 'pca_p17on' component (labelled A = lowest 20 %,
# C = highest 20 %) for each time point that has a '<time>DR' counterpart.
genes_comm = adata.var.index.values
# lfc_th, qs_th = np.log2(1.3), 0.05
lfc_th, qs_th = np.log2(2), 0.05
alpha_th = qs_th
l2fc_th = lfc_th
times = ['P12', 'P14', 'P17', 'P21']
# NOTE: this rebinds `conds` (previously the per-cell condition array taken from
# adata.obs) to the list of tested conditions; later cells rely on the list.
conds = list(times)


def _pc1_tail_layers(adata_cond, pct=20, obsm_key='pca_p17on', layer='lognorm'):
    """Return the `layer` matrices of the cells strictly below the `pct`-th and
    strictly above the (100 - `pct`)-th percentile of the first `obsm_key` component."""
    pc1 = adata_cond.obsm[obsm_key][..., 0]
    is_low = pc1 < np.percentile(pc1, pct)
    is_high = pc1 > np.percentile(pc1, 100 - pct)
    return (adata_cond[is_low].layers[layer][...],
            adata_cond[is_high].layers[layer][...])


def _dr_vs_nr_stats(mat_nr, mat_dr):
    """Per-column t-test (BH-adjusted) and log2 fold change (DR minus NR)."""
    _, ps = stats.ttest_ind(mat_nr, mat_dr)
    # NaN p-values (zero variance in both groups, e.g. unexpressed genes) are
    # treated as non-significant before the FDR correction.
    _, qs, _, _ = multipletests(np.nan_to_num(ps, nan=1).reshape(-1,), method='fdr_bh')
    # inputs are log10(1 + norm); rescale the mean difference to log2 units
    lfc = np.log2(10)*(np.mean(mat_dr, axis=0) - np.mean(mat_nr, axis=0))
    return qs, lfc


qs_nrdr_a, qs_nrdr_c = [], []
l2fc_nrdr_a, l2fc_nrdr_c = [], []

for t in times:
    adatasub1 = adata[adata.obs['cond'] == t]
    adatasub2 = adata[adata.obs['cond'] == t + 'DR']
    if adatasub1.n_obs == 0 or adatasub2.n_obs == 0:
        raise ValueError(f"no cells found for condition {t!r} and/or {t + 'DR'!r}")

    # A / C cells are selected independently within NR (1) and DR (2)
    mat_1a, mat_1c = _pc1_tail_layers(adatasub1)
    mat_2a, mat_2c = _pc1_tail_layers(adatasub2)

    qs_a, lfc_a = _dr_vs_nr_stats(mat_1a, mat_2a)
    qs_c, lfc_c = _dr_vs_nr_stats(mat_1c, mat_2c)

    num_sig_a = np.sum(np.logical_and(qs_a < alpha_th, np.abs(lfc_a) > l2fc_th))
    num_sig_c = np.sum(np.logical_and(qs_c < alpha_th, np.abs(lfc_c) > l2fc_th))
    # shapes are those of the NR A / C subsets
    print(t, mat_1a.shape, mat_1c.shape, num_sig_a, num_sig_c)

    qs_nrdr_a.append(qs_a)
    qs_nrdr_c.append(qs_c)
    l2fc_nrdr_a.append(lfc_a)
    l2fc_nrdr_c.append(lfc_c)

# stack to arrays with one row per entry of `times`
qs_nrdr_a = np.array(qs_nrdr_a)
qs_nrdr_c = np.array(qs_nrdr_c)

l2fc_nrdr_a = np.array(l2fc_nrdr_a)
l2fc_nrdr_c = np.array(l2fc_nrdr_c)

In [None]:
# Keep the NR-vs-DR statistics alongside the data, under the same names.
for _uns_key, _uns_val in (
    ('qs_nrdr_a', qs_nrdr_a),
    ('qs_nrdr_c', qs_nrdr_c),
    ('l2fc_nrdr_a', l2fc_nrdr_a),
    ('l2fc_nrdr_c', l2fc_nrdr_c),
):
    adata.uns[_uns_key] = _uns_val
adata

In [None]:
# Boolean masks (one row per tested condition) of genes passing both the
# adjusted-p and the absolute log2FC threshold.
cond_sig_a = (qs_nrdr_a < alpha_th) & (np.abs(l2fc_nrdr_a) > l2fc_th)
cond_sig_c = (qs_nrdr_c < alpha_th) & (np.abs(l2fc_nrdr_c) > l2fc_th)

# number of significant genes in each condition
print('num A-DEGs per cond:\t',    np.sum(cond_sig_a, axis=1))
print('num C-DEGs per cond:\t',    np.sum(cond_sig_c, axis=1))

In [None]:
# For every gene, in how many of the tested conditions is it significant?
gene_sig_instances_a = np.sum(cond_sig_a, axis=0)
gene_sig_instances_c = np.sum(cond_sig_c, axis=0)

# Histogram of those counts: distinct values and how many genes have each.
instances_a, counts_a = np.unique(gene_sig_instances_a, return_counts=True)
instances_c, counts_c = np.unique(gene_sig_instances_c, return_counts=True)

# The last printed number is the count of genes significant in at least one condition.
n_any_a = np.any(cond_sig_a, axis=0).sum()
n_any_c = np.any(cond_sig_c, axis=0).sum()
print('num A-DEGs in num conds:\t', instances_a, counts_a, n_any_a)
print('num C-DEGs in num conds:\t', instances_c, counts_c, n_any_c)

In [None]:
# Gene-name lists (sorted) derived from the significance masks:
#   *_any -> significant in at least one condition
#   *_all -> significant in every condition
# NOTE: `ac_overlap` replaces the value assigned in the AC-gene-table cell.
_var_names = adata.var.index.values
_sig_any_a = np.any(cond_sig_a, axis=0)
_sig_any_c = np.any(cond_sig_c, axis=0)

a_any = np.sort(_var_names[_sig_any_a])
c_any = np.sort(_var_names[_sig_any_c])
a_all = np.sort(_var_names[np.all(cond_sig_a, axis=0)])
c_all = np.sort(_var_names[np.all(cond_sig_c, axis=0)])
ac_overlap = np.sort(_var_names[_sig_any_a & _sig_any_c])

for _label, _names in (
    ('a any', a_any),
    ('c any', c_any),
    ('a all', a_all),
    ('c all', c_all),
    ('ac overlap', ac_overlap),
):
    print(_label, _names)

In [None]:
# Styled (third return value of check_genes) versions of the any-condition
# A and C gene lists.
# NOTE(review): unlike the per-condition DEG cells below, there is no guard for
# an empty gene list here -- confirm check_genes accepts empty input.
_, _, a_any_styled = gene_modules.check_genes(a_any)
_, _, c_any_styled = gene_modules.check_genes(c_any)

# the leading escape sequence resets terminal styling before each label
print("\033[0;m anytime A:", "\t".join(a_any_styled))
print("\033[0;m anytime C:", "\t".join(c_any_styled))

In [None]:
# %%time
# fout1 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_qs_avc_p6to21.txt'
# fout2 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_l2fc_avc_p6to21.txt'

# np.savetxt(fout1, qs_avc)
# np.savetxt(fout2, l2fc_avc)

# Volcano

In [None]:
def show_volcano_v2(thetypeidx, thetype, lfc, qs,
                    cond1, cond2up, cond2dn, 
                    querygenes_idx=None, 
                    gene_annots=None,
                    ax=None, bbox_to_anchor=(1,1), loc=None,
                    clip_p=None,
                    xlabel='log2(FC) (C/A in CP10k)',
                   ): 
    """Draw one volcano panel: effect size vs. -log10(adjusted p).

    Parameters
    ----------
    thetypeidx : int
        Column of `lfc`, `qs`, `cond1`, `cond2up` and `cond2dn` to plot.
    thetype : str
        Used as the panel title.
    lfc, qs : 2-D arrays
        Effect sizes and adjusted p-values, indexed as [row, column]
        (rows are presumably genes -- callers pass transposed per-condition tables).
    cond1, cond2up, cond2dn : 2-D boolean arrays
        Same layout as `lfc`.  A point is highlighted as "up" where
        `cond1 & cond2up` holds and as "down" where `cond1 & cond2dn` holds.
    querygenes_idx : sequence of int, optional
        Row indices to circle in black.
    gene_annots : sequence of str, optional
        Row labels; when given, each circled point is also annotated with
        `gene_annots[idx]`.  When omitted the points are circled but not labelled.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when omitted.
    bbox_to_anchor, loc :
        Accepted for backward compatibility; not used.
    clip_p : float, optional
        Upper bound applied to -log10(q) before plotting.
    xlabel : str, optional
        X-axis label.  The default is kept for backward compatibility; pass the
        label matching the comparison actually stored in `lfc`.

    Returns
    -------
    None
    """
    eff = lfc[:, thetypeidx]
    # NOTE: q == 0 maps to +inf here; use `clip_p` to keep such points on the plot.
    pvl = -np.log10(qs[:, thetypeidx])
    if clip_p is not None:
        pvl = np.clip(pvl, None, clip_p)

    is_sig = cond1[:, thetypeidx]
    cnd_up = np.logical_and(is_sig, cond2up[:, thetypeidx])
    cnd_dn = np.logical_and(is_sig, cond2dn[:, thetypeidx])

    if ax is None: 
        fig, ax = plt.subplots()

    # all genes, then the significant up / down subsets on top
    ax.scatter(eff, pvl, s=1, color='lightgray', rasterized=True)
    ax.scatter(eff[cnd_up], pvl[cnd_up], s=3, color='C0', rasterized=True)
    ax.scatter(eff[cnd_dn], pvl[cnd_dn], s=3, color='C1', rasterized=True)

    # query genes: black open circles, plus a text label when names are available
    if querygenes_idx is not None:
        ax.scatter(eff[querygenes_idx], pvl[querygenes_idx], s=15, 
                   facecolors='none', edgecolors='k', linewidth=1, rasterized=True)
        if gene_annots is not None:
            for idx in querygenes_idx:
                ax.text(eff[idx], pvl[idx], gene_annots[idx], fontsize=10)

    sns.despine(ax=ax)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('-log10(adj. p)')
    ax.set_title(f'{thetype}')
    # up / down counts in the lower-right corner (axes coordinates)
    ax.text(1,0.1,
            f'up (n={cnd_up.sum():,})\ndown (n={cnd_dn.sum():,})', 
            ha='right',
            fontsize=10, transform=ax.transAxes)
    return 

In [None]:
# Distinct condition ranks present in adata (sorted by np.unique).
# NOTE(review): not used by the volcano cells visible below, which iterate over `conds`.
unq_condidx = np.unique(adata.obs['cond_order']) #.unique()
unq_condidx # , unq_conds

In [None]:
# Volcano plots of the DR-vs-NR comparison within A cells, one panel per tested condition.
lfc = l2fc_nrdr_a.T  # transposed so that columns index conditions
qs  = qs_nrdr_a.T
cond1   = qs  <  qs_th
cond2up = lfc >  lfc_th
cond2dn = lfc < -lfc_th

querygenes = ['Fos', 'Arc', 'Nr4a2', 'Nr4a3']
querygenes_idx = basicu.get_index_from_array(genes_comm, querygenes) 
gene_annots = genes_comm

# one panel per entry of `conds` instead of a hard-coded 4
n_panels = len(conds)
fig, axs = plt.subplots(1, n_panels, figsize=(4*n_panels, 4*1))
axs = np.atleast_1d(axs)
for cond_idx, thecond in enumerate(conds):
    ax = axs[cond_idx]
    # show_volcano_v2 already despines the axes and ignores bbox_to_anchor / loc
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, 
                    querygenes_idx=querygenes_idx, 
                    gene_annots=gene_annots,
                    ax=ax)
    # The plotted effect is mean(DR) - mean(NR) in log2 units, so replace the
    # function's default "C/A" axis label.
    ax.set_xlabel('log2(FC) (DR/NR)')
fig.tight_layout()
plt.show()

In [None]:
# DR-vs-NR DEGs within A cells, per tested condition, split into down (dn) and
# up in DR.  Uses the shared thresholds (qs_th, lfc_th) and `genes_comm`, the
# gene index captured together with the statistics, so names and mask columns
# are guaranteed to line up.
degs_a_dn = []
degs_a_up = []
for i, t in enumerate(conds):
    is_sig = qs_nrdr_a[i] < qs_th
    cond_a_dn = np.logical_and(is_sig, l2fc_nrdr_a[i] < -lfc_th)
    cond_a_up = np.logical_and(is_sig, l2fc_nrdr_a[i] >  lfc_th)

    deg_a_dn = np.sort(genes_comm[cond_a_dn])
    deg_a_up = np.sort(genes_comm[cond_a_up])
    degs_a_dn.append(deg_a_dn)
    degs_a_up.append(deg_a_up)

    print(f"{t}")
    for deg in (deg_a_dn, deg_a_up):
        # the annotation lookup is skipped for empty lists
        if len(deg) > 0:
            _, _, gs = gene_modules.check_genes(deg)
            print("\t".join(gs))
        # reset terminal styling after the styled gene names
        print("\033[0;m ---")
    

In [None]:
# DR-vs-NR DEGs within C cells, per tested condition, split into down (dn) and
# up in DR.  Uses the shared thresholds (qs_th, lfc_th) and `genes_comm`, the
# gene index captured together with the statistics, so names and mask columns
# are guaranteed to line up.
degs_c_dn = []
degs_c_up = []
for i, t in enumerate(conds):
    is_sig = qs_nrdr_c[i] < qs_th
    cond_c_dn = np.logical_and(is_sig, l2fc_nrdr_c[i] < -lfc_th)
    cond_c_up = np.logical_and(is_sig, l2fc_nrdr_c[i] >  lfc_th)

    deg_c_dn = np.sort(genes_comm[cond_c_dn])
    deg_c_up = np.sort(genes_comm[cond_c_up])
    degs_c_dn.append(deg_c_dn)
    degs_c_up.append(deg_c_up)

    print(f"{t}")
    for deg in (deg_c_dn, deg_c_up):
        # the annotation lookup is skipped for empty lists
        if len(deg) > 0:
            _, _, gs = gene_modules.check_genes(deg)
            print("\t".join(gs))
        # reset terminal styling after the styled gene names
        print("\033[0;m ---")
    

In [None]:
# Volcano plots of the DR-vs-NR comparison within C cells, one panel per tested condition.
lfc = l2fc_nrdr_c.T  # transposed so that columns index conditions
qs  = qs_nrdr_c.T
cond1   = qs  <  qs_th
cond2up = lfc >  lfc_th
cond2dn = lfc < -lfc_th

# querygenes = ['Ptprg', 'Sorcs3', 'Cdh20', 'Egfem1', 'Kcnh5', 'Rorb', 'Epha10', 'Pcdh15']
querygenes = ['Fos', 'Arc', 'Nr4a2', 'Nr4a3', 'Xist']
querygenes_idx = basicu.get_index_from_array(genes_comm, querygenes) 
gene_annots = genes_comm

# one panel per entry of `conds` instead of a hard-coded 4
n_panels = len(conds)
fig, axs = plt.subplots(1, n_panels, figsize=(4*n_panels, 4*1))
axs = np.atleast_1d(axs)
for cond_idx, thecond in enumerate(conds):
    ax = axs[cond_idx]
    # show_volcano_v2 already despines the axes and ignores bbox_to_anchor / loc
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, 
                    querygenes_idx=querygenes_idx, 
                    gene_annots=gene_annots,
                    ax=ax)
    # The plotted effect is mean(DR) - mean(NR) in log2 units, so replace the
    # function's default "C/A" axis label.
    ax.set_xlabel('log2(FC) (DR/NR)')
    ax.set_xticks([-1,0,1])
fig.tight_layout()
plt.show()