In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP
from statsmodels.stats.multitest import multipletests

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu
from scroutines.gene_modules import GeneModules  


In [None]:
# Figure output directory. Set the V1_PROJECT_DIR environment variable to point at a
# different project root; when it is unset this reproduces the original hard-coded path.
outdir_fig = os.path.join(
    os.environ.get("V1_PROJECT_DIR", "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1"),
    "figures",
)

# Load gene annotations and data

In [None]:
gene_modules = GeneModules()

# Look up one gene; check_genes returns three sequences of strings,
# each printed on its own line, tab-separated.
g, gs, ms = gene_modules.check_genes('Cdh13')
print(*("\t".join(entries) for entries in (g, gs, ms)), sep="\n")

In [None]:
# Use those 286 genes: the L2/3 A/B/C gene table, where 'P17on' holds each gene's group label.
df = pd.read_csv("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot_v3_july8_2024.csv")
genes_l23 = df['gene'].astype(str).values
genes_grp = df['P17on'].astype(str).values

# Split the gene list by group label A / B / C.
genes_l23a, genes_l23b, genes_l23c = (
    df.loc[df['P17on'] == grp, 'gene'].astype(str).values for grp in ('A', 'B', 'C')
)

print(genes_l23a.shape, genes_l23b.shape, genes_l23c.shape)
# Every gene must appear only once in the table.
assert len(genes_l23) == len(np.unique(genes_l23))

genes_l23.shape

In [None]:
# Per-cell A/B/C scores; the first CSV column becomes the index (cell identifiers).
scores_abc = pd.read_csv(
    "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/scores_l23abc.csv",
    index_col=0,
)
scores_abc

In [None]:
# Load the L2/3 multiome RNA AnnData (file includes a .raw slot, per its name).
# anndata.read is a deprecated alias; read_h5ad is the supported reader for .h5ad files.
adata = anndata.read_h5ad("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/superdupermegaRNA_hasraw_multiome_l23.h5ad")
adata

In [None]:
# Work on the raw matrix: replace X with adata.raw.X.
# adata.raw may hold a different gene set or gene order than adata.var; a shape
# mismatch would error, and a same-shape reordering would silently mislabel genes.
assert adata.raw is not None, "adata has no .raw slot"
assert adata.raw.var_names.equals(adata.var_names), "adata.raw genes do not match adata.var genes"
adata.X = adata.raw.X

In [None]:
# Copy the per-cell A/B/C scores into obs, looked up in adata's own cell order.
for score_col in ('scores_a', 'scores_b', 'scores_c'):
    adata.obs[score_col] = scores_abc.loc[adata.obs.index, score_col].copy()
del score_col

In [None]:
sample_labels = adata.obs['Sample'].values
# Time label = sample name minus its last character (presumably a replicate
# suffix -- TODO confirm) with any 'DR' tag removed.
time_labels = [label[:-1].replace('DR', '') for label in sample_labels]

adata.obs['sample'] = sample_labels
adata.obs['time'] = time_labels

uniq_samples = natsorted(np.unique(sample_labels))
uniq_times = natsorted(np.unique(time_labels))

# Partition samples by whether their name carries the 'DR' tag.
dr_samples = [smp for smp in uniq_samples if "DR" in smp]
nr_samples = [smp for smp in uniq_samples if "DR" not in smp]
print(uniq_times, nr_samples, dr_samples, sep="\n")

adata.obs['sample'].unique(), adata.obs['cond'].unique()

In [None]:
# Remove mitochondrial genes (gene symbols starting with "mt-").
adata = adata[:, ~adata.var_names.str.startswith('mt-')]

# Keep genes detected (count > 0) in more than 10 cells.
cond = np.asarray((adata.X > 0).sum(axis=0)).ravel() > 10
adata = adata[:, cond].copy()


In [None]:

# Library-size normalization of the counts in adata.X.
# Coerce to CSR so the sparse-only steps below (.data, .toarray) also work if X is dense.
x = sparse.csr_matrix(adata.X)
cov = np.ravel(x.sum(axis=1))  # total counts per cell
genes = adata.var.index.values

# CP10k. Cells left with zero total counts (possible after the gene filters) get a
# scale factor of 0 instead of 1/0 = inf; their rows are all-zero either way.
scale = np.divide(1.0, cov, out=np.zeros(cov.shape, dtype=float), where=cov > 0)
xn = sparse.diags(scale).dot(x) * 1e4

# log2(CP10k + 1); only stored nonzeros need transforming because log2(0 + 1) == 0.
xln = xn.copy()
xln.data = np.log2(xln.data + 1)

# Dense copies as layers (used directly by the per-gene t-tests downstream).
adata.layers['norm'] = xn.toarray()
adata.layers['lognorm'] = xln.toarray()

In [None]:
# Per condition: rank cells by each of the A/B/C scores, keep cells that are among the
# top n_top for exactly one score, and test A-top vs C-top cells gene by gene.
#
# Conditions are visited in natural sort order -- the same order later used for
# `unq_conds` when titling the volcano panels -- so row i of qs_avc / l2fc_avc belongs
# to conds_avc[i]. groupby('cond') would instead use lexicographic or category order
# (e.g. 'P10' before 'P8'), and the panel titles would not match the data.
n_top = 200
conds_avc = natsorted(np.unique(adata.obs['cond'].dropna()))

qs_avc = []
l2fc_avc = []

for cond in conds_avc:
    # cells of this condition
    adatasub = adata[(adata.obs['cond'] == cond).to_numpy()]

    # top-n_top cells by each score within this condition
    precond_a = (adatasub.obs['scores_a'].rank(ascending=False) <= n_top).to_numpy()
    precond_b = (adatasub.obs['scores_b'].rank(ascending=False) <= n_top).to_numpy()
    precond_c = (adatasub.obs['scores_c'].rank(ascending=False) <= n_top).to_numpy()

    # cells that are top-ranked for exactly one of the three scores
    cond_a =  precond_a & ~precond_b & ~precond_c
    cond_b = ~precond_a &  precond_b & ~precond_c
    cond_c = ~precond_a & ~precond_b &  precond_c

    print(precond_a.sum(), precond_b.sum(), precond_c.sum())
    print(cond_a.sum(), cond_b.sum(), cond_c.sum())

    adatasub_a = adatasub[cond_a]
    adatasub_c = adatasub[cond_c]

    # DEGs: per-gene two-sample t-test on log2(CP10k+1) (scipy default, equal variances)
    mat_a = adatasub_a.layers['lognorm'][...]
    mat_c = adatasub_c.layers['lognorm'][...]
    ts, ps = stats.ttest_ind(mat_a, mat_c)
    # NaN p-values arise e.g. for genes not expressed in either group; treat them as p = 1
    _, qs, _, _ = multipletests(np.nan_to_num(ps, nan=1).reshape(-1,), method='fdr_bh')
    # difference of mean log2(CP10k+1), C minus A (positive = higher in C)
    lfc = np.mean(mat_c, axis=0) - np.mean(mat_a, axis=0)

    num_sig = np.sum(np.logical_and(qs < 0.05, np.abs(lfc) > 1))
    print(cond, adatasub_a.shape, num_sig)

    qs_avc.append(qs)
    l2fc_avc.append(lfc)

# rows = conditions (in conds_avc order), columns = genes
qs_avc = np.array(qs_avc)
l2fc_avc = np.array(l2fc_avc)

In [None]:
# Stash the per-condition A-vs-C results (rows = conditions, columns = genes) on the AnnData.
for uns_key, uns_val in (('qs_avc', qs_avc), ('l2fc_avc', l2fc_avc)):
    adata.uns[uns_key] = uns_val
del uns_key, uns_val

In [None]:
l2fc_th = 1
alpha_th = 0.05

# Boolean (condition x gene) masks: FDR-significant AND effect beyond the threshold.
# Positive l2fc = higher in C, negative = higher in A.
cond_sig   = (qs_avc < alpha_th) & (np.abs(l2fc_avc) > l2fc_th)
cond_sig_a = (qs_avc < alpha_th) & (l2fc_avc < -l2fc_th)
cond_sig_c = (qs_avc < alpha_th) & (l2fc_avc >  l2fc_th)

print('num AvsC-DEGs per cond:\t', cond_sig.sum(axis=1))
print('num A-DEGs per cond:\t',    cond_sig_a.sum(axis=1))
print('num C-DEGs per cond:\t',    cond_sig_c.sum(axis=1))

In [None]:
# Per gene: in how many conditions is it a DEG (either direction / A-enriched / C-enriched)?
gene_sig_instances   = cond_sig.sum(axis=0)
gene_sig_instances_a = cond_sig_a.sum(axis=0)
gene_sig_instances_c = cond_sig_c.sum(axis=0)

# Histogram over 0..n_conds. bincount(minlength=...) keeps empty bins, so counts[k] is
# always "number of genes that are DEGs in exactly k conditions"; np.unique would drop
# empty bins and shift every later position.
n_conds_deg = cond_sig.shape[0]
counts   = np.bincount(gene_sig_instances,   minlength=n_conds_deg + 1)
counts_a = np.bincount(gene_sig_instances_a, minlength=n_conds_deg + 1)
counts_c = np.bincount(gene_sig_instances_c, minlength=n_conds_deg + 1)
instances   = np.arange(len(counts))
instances_a = np.arange(len(counts_a))
instances_c = np.arange(len(counts_c))

print('num AvsC-DEGs in num conds:\t',  counts  , np.any(cond_sig, axis=0).sum())
print('num A-DEGs in num conds:\t',     counts_a, np.any(cond_sig_a, axis=0).sum())
print('num C-DEGs in num conds:\t',     counts_c, np.any(cond_sig_c, axis=0).sum())

In [None]:
# A-/C-enriched DEGs in any / all conditions, plus genes that are A-enriched in some
# condition and C-enriched in another (the two masks never overlap within one condition).
a_any = np.sort(adata.var.index.values[cond_sig_a.any(axis=0)])
c_any = np.sort(adata.var.index.values[cond_sig_c.any(axis=0)])
a_all = np.sort(adata.var.index.values[cond_sig_a.all(axis=0)])
c_all = np.sort(adata.var.index.values[cond_sig_c.all(axis=0)])
ac_overlap = np.sort(adata.var.index.values[cond_sig_a.any(axis=0) & cond_sig_c.any(axis=0)])

print('a any', a_any)
print('c any', c_any)
print('a all', a_all)
print('c all', c_all)
print('ac overlap', ac_overlap)

In [None]:
# Module annotations for genes that are A- (resp. C-) enriched in every condition.
a_all_annots, a_all_styled, a_all_annots_styled = gene_modules.check_genes(a_all)
c_all_annots, c_all_styled, c_all_annots_styled = gene_modules.check_genes(c_all)

# A genes, a separator line, then C genes -- each list tab-separated.
print("\t".join(a_all_annots_styled), "---" * 10, "\t".join(c_all_annots_styled), sep="\n")

In [None]:
# Cumulative view: number of genes that are A- (or C-) enriched DEGs in at least k
# conditions, for k = 1..n_conds. Computed directly from the masks so it works for
# any number of conditions and does not depend on every histogram bin being present.
n_conds_deg = cond_sig_a.shape[0]
ks = np.arange(1, n_conds_deg + 1)

fig, ax = plt.subplots(figsize=(4,4))
for sig_mask, label, color in ((cond_sig_a, 'A', 'C0'), (cond_sig_c, 'C', 'C2')):
    n_sig_conds = sig_mask.sum(axis=0)  # per gene: number of conditions where it is a DEG
    ax.plot(ks, [(n_sig_conds >= k).sum() for k in ks], '-o', label=label, color=color)
ax.set_ylim(bottom=0)
ax.legend()
# counts are over all genes passing the filters, not only TFs
ax.set_ylabel('number of genes')
ax.set_xlabel('minimum number of conditions')
sns.despine(ax=ax)
plt.show()

# Volcano plots (C vs A DEGs per condition)

In [None]:
def show_volcano_v2(thetypeidx, thetype, lfc, qs,
                    cond1, cond2up, cond2dn, 
                    querygenes_idx=None, 
                    gene_annots=None,
                    ax=None, bbox_to_anchor=(1,1), loc=None,
                    qs_floor=1e-50,
                    color_up='C0', color_dn='C1',
                   ): 
    """Draw a volcano plot (effect size vs. -log10 adjusted p) for one column of `lfc`/`qs`.

    Parameters
    ----------
    thetypeidx : int
        Column index into `lfc`, `qs`, `cond1`, `cond2up` and `cond2dn`
        (all indexed as [gene, column]).
    thetype : str
        Panel title.
    lfc, qs : 2-D arrays
        Effect sizes (x-axis) and adjusted p-values (y-axis, shown as -log10).
    cond1 : 2-D bool array
        Significance mask; a gene is colored only where this is True.
    cond2up, cond2dn : 2-D bool arrays
        Direction masks, combined with `cond1` to select "up" / "down" genes.
    querygenes_idx : sequence of int, optional
        Row indices of genes to circle and label.
    gene_annots : sequence of str, optional
        Per-row labels for `querygenes_idx`; when omitted, the row index is used.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.
    bbox_to_anchor, loc :
        Unused; accepted so existing calls keep working.
    qs_floor : float
        Adjusted p-values are clipped at this value before -log10 so zeros stay finite.
    color_up, color_dn : color
        Marker colors for up / down genes.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.
    """
    eff = lfc[:, thetypeidx]
    pvl = -np.log10(np.clip(qs[:, thetypeidx], qs_floor, None))
    cnd_up = np.logical_and(cond1[:, thetypeidx], cond2up[:, thetypeidx])
    cnd_dn = np.logical_and(cond1[:, thetypeidx], cond2dn[:, thetypeidx])

    if ax is None:
        _, ax = plt.subplots()

    # all genes
    ax.scatter(eff, pvl, s=1, color='lightgray', rasterized=True)
    # up / down genes
    ax.scatter(eff[cnd_up], pvl[cnd_up], s=3, facecolors=color_up, rasterized=True)
    ax.scatter(eff[cnd_dn], pvl[cnd_dn], s=3, facecolors=color_dn, rasterized=True)

    # circle and label the query genes
    if querygenes_idx is not None:
        ax.scatter(eff[querygenes_idx], pvl[querygenes_idx], s=15,
                   facecolors='none', edgecolors='k', linewidth=1, rasterized=True)
        for idx in querygenes_idx:
            label = gene_annots[idx] if gene_annots is not None else str(idx)
            ax.text(eff[idx], pvl[idx], label, fontsize=10)

    sns.despine(ax=ax)
    ax.set_xlabel('log2(FC) (C/A in CP10k)')
    ax.set_ylabel('-log10(adj. p)')
    ax.set_title(f'{thetype}')
    ax.text(1,0.1,
            f'up (n={cnd_up.sum():,})\ndown (n={cnd_dn.sum():,})', 
            ha='right',
            fontsize=10, transform=ax.transAxes)
    return ax

In [None]:
# Gene names plus gene x condition matrices (transposed from the stored condition x gene layout).
genes_comm = adata.var_names.values
lfc = np.transpose(adata.uns['l2fc_avc'])
qs = np.transpose(adata.uns['qs_avc'])
lfc_th = 1
qs_th = 0.05

In [None]:
# Gene x condition masks: significant (adj. p < qs_th), and effect above +lfc_th / below -lfc_th.
cond1 = np.less(qs, qs_th)
cond2up = np.greater(lfc, lfc_th)
cond2dn = np.less(lfc, -lfc_th)

In [None]:
# Conditions in natural sort order; volcano panel i is titled unq_conds[i] and drawn
# from column i of lfc/qs. NOTE(review): assumes those columns were produced in this
# same order -- confirm against the DEG loop.
unq_conds = natsorted(np.unique(adata.obs['cond'].to_numpy()))
unq_conds

In [None]:
# One volcano panel per condition, side by side, in unq_conds order.
n = len(unq_conds)
# squeeze=False keeps axs 2-D even when n == 1, so axs.flat always works
fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True, squeeze=False)
for cond_idx, thecond in enumerate(unq_conds):
    ax = axs.flat[cond_idx]
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, ax=ax)
    sns.despine(ax=ax)
fig.tight_layout()

# output = os.path.join(outdir_fig, "volcano.pdf")
# powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Volcano grid with a few hand-picked genes circled and labeled in every condition.
querygenes = ['Meis2','Foxp1','Cdh13','Cdh12']
querygenes_idx = basicu.get_index_from_array(genes_comm, querygenes) 
gene_annots = genes_comm

n = len(unq_conds)
ncols = 4
nrows = max(1, -(-n // ncols))  # ceil(n / ncols): grid grows with the number of conditions
fig, axs = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows), sharex=True, sharey=True, squeeze=False)
for cond_idx, thecond in enumerate(unq_conds):
    ax = axs.flat[cond_idx]
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, 
                    querygenes_idx=querygenes_idx, 
                    gene_annots=gene_annots,
                    ax=ax)
    sns.despine(ax=ax)
# blank out panels left over when n is not a multiple of ncols
for ax in axs.flat[n:]:
    ax.axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Volcano grid labeling every TF whose |lfc| exceeds lfc_th in at least one condition.
# NOTE(review): this selection uses effect size only (no adj.-p filter), so some
# labeled TFs may be gray (non-significant) points.
genes_sig = genes_comm[np.any(np.abs(lfc) > lfc_th, axis=1)]
querygenes = np.intersect1d(gene_modules.annots['tf'], genes_sig)
querygenes_idx = basicu.get_index_from_array(genes_comm, querygenes) 
gene_annots = genes_comm

n = len(unq_conds)
ncols = 4
nrows = max(1, -(-n // ncols))  # ceil(n / ncols): grid grows with the number of conditions
fig, axs = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows), sharex=True, sharey=True, squeeze=False)
for cond_idx, thecond in enumerate(unq_conds):
    ax = axs.flat[cond_idx]
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, 
                    querygenes_idx=querygenes_idx, 
                    gene_annots=gene_annots,
                    ax=ax)
    sns.despine(ax=ax)
# blank out panels left over when n is not a multiple of ncols
for ax in axs.flat[n:]:
    ax.axis('off')
fig.tight_layout()
plt.show()