In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP
from statsmodels.stats.multitest import multipletests

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu
from scroutines.gene_modules import GeneModules  


In [None]:
# Directory that receives every figure saved by this notebook.
outfigdir = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures/250409"
# Pure-Python equivalent of `mkdir -p`: also works when the notebook is run as
# a plain script and avoids interpolating a variable into a shell command.
os.makedirs(outfigdir, exist_ok=True)

# load gene annotation and data

In [None]:
# Build the gene-annotation helper and sanity-check it on a single gene.
# check_genes returns three sequences of strings; each is printed tab-joined.
# NOTE(review): later cells name the three outputs (annots, styled, annots_styled) —
# presumably the same order applies here; confirm against GeneModules.
gene_modules = GeneModules()
g, gs, ms = gene_modules.check_genes('Cdh13')
print("\t".join(g))
print("\t".join(gs))
print("\t".join(ms))

In [None]:
# Load the all-timepoint highly-variable-gene list as an array of strings.
genes_alltime_hvgs = np.loadtxt('/u/home/f/f7xiesnm/v1_multiome/l23_alltime_hvgs_n4940.txt', dtype='str')
genes_alltime_hvgs

In [None]:
# Load the annotated L2/3 gene table (the "286 genes") and split it by the
# 'P17on' column into the A, B and C groups.
df = pd.read_csv("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot_v3_july8_2024.csv")
genes_l23 = df['gene'].astype(str).values
genes_l23a = df[df['P17on']=='A']['gene'].astype(str).values
genes_l23b = df[df['P17on']=='B']['gene'].astype(str).values
genes_l23c = df[df['P17on']=='C']['gene'].astype(str).values

print(genes_l23a.shape, genes_l23b.shape, genes_l23c.shape)
genes_grp = df['P17on'].astype(str).values
# every gene must appear exactly once in the table
assert len(genes_l23) == len(np.unique(genes_l23))

genes_l23.shape

In [None]:
# HVGs with the L2/3 A/B/C genes removed.
genes_alltime_hvgs_rm_l23 = genes_alltime_hvgs[~np.isin(genes_alltime_hvgs, genes_l23)]
genes_alltime_hvgs_rm_l23.shape

In [None]:
# Per-cell archetype scores (first CSV column is used as the index).
scores_abc = pd.read_csv("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/scores_l23abc.csv", 
                         index_col=0,
                        )
# C-minus-A contrast: used below to rank cells along the A <-> C axis.
scores_abc['scores_c-a'] = scores_abc['scores_c'] - scores_abc['scores_a']
scores_abc

In [None]:
# Load the L2/3 multiome RNA object. `anndata.read` is a deprecated alias;
# `read_h5ad` is the supported reader for .h5ad files.
adata = anndata.read_h5ad("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/superdupermegaRNA_hasraw_multiome_l23.h5ad")
adata

In [None]:
# Replace the main matrix with the matrix stored in .raw; everything below
# treats adata.X as counts (see the CP10k normalization cell).
adata.X = adata.raw.X

In [None]:
# Attach each archetype score column to the cells, aligned on the obs index.
for col in ('scores_a', 'scores_b', 'scores_c', 'scores_c-a'):
    adata.obs[col] = scores_abc.loc[adata.obs.index, col].copy()

In [None]:
# Derive per-cell sample and time labels from the 'Sample' column.
sample_labels = adata.obs['Sample'].values
# drop the last character of the sample name, then remove the 'DR' tag
time_labels = [s[:-1].replace('DR', '') for s in sample_labels]

adata.obs['sample'] = sample_labels # lower-case copy of 'Sample'
adata.obs['time']   = time_labels

# split the (naturally sorted) sample names by whether they contain "DR"
uniq_samples = natsorted(np.unique(sample_labels))
nr_samples = [s for s in uniq_samples if "DR" not in s]
dr_samples = [s for s in uniq_samples if "DR" in s]

# naturally sorted condition labels; their positions are the condition codes used below
uniq_conds = np.array(natsorted(np.unique(adata.obs['cond'].values)))
print(uniq_conds)

In [None]:
# Positions used to pull NR vs DR entries out of per-condition arrays, and the
# matching postnatal days for the x axis.
# NOTE(review): these hard-coded positions assume a specific order of
# `uniq_conds` (each DR condition sitting right before the same-age NR one) —
# confirm by printing uniq_conds[nr_idx] and uniq_conds[dr_idx].
nr_idx = np.array([0,1,2,4,6,8,10])
dr_idx = np.array([3,5,7,9])

nr_times = np.array([6,8,10,12,14,17,21])
dr_times = np.array(       [12,14,17,21])

In [None]:
# remove mitochondrial genes (names starting with 'mt-')
adata = adata[:,~adata.var.index.str.contains(r'^mt-')]
# remove the sex-linked gene Xist (exact name match)
adata = adata[:,~adata.var.index.str.contains(r'^Xist$')]

# filter genes
cond = np.ravel((adata.X>0).sum(axis=0)) > 10 # expressed in more than 10 cells
adata = adata[:,cond].copy()

In [None]:
# Show the object after gene filtering.
adata

In [None]:
# output = '/u/home/f/f7xiesnm/v1_multiome/multiome_l23_allgenes.txt'

# all_genes = adata.var.index.values
# print(all_genes.shape)
# np.savetxt(output, all_genes, fmt='%s')

In [None]:
# counts
x = adata.X
cov = np.ravel(np.sum(x, axis=1)) # total counts per cell
genes = adata.var.index.values

# CP10k: scale every cell (row) to 10,000 total counts
# NOTE(review): a cell with zero total counts would give 1/0 here — presumably
# such cells were filtered upstream; verify.
xn = (sparse.diags(1/cov).dot(x))*1e4

# log2(CP10k+1), applied to the stored (non-zero) entries only; zeros stay zero
xln = xn.copy()
xln.data = np.log2(xln.data+1)

# keep dense copies as layers for the per-gene statistics below
adata.layers[    'norm'] = np.array(xn.todense())
adata.layers[ 'lognorm'] = np.array(xln.todense())

In [None]:
# Locate the non-L2/3 HVGs within adata's gene axis.
# NOTE(review): presumably returns integer positions into adata.var.index —
# confirm against basicu.get_index_from_array.
genes_idx_alltime_hvgs_rm_l23 = basicu.get_index_from_array(adata.var.index.values, genes_alltime_hvgs_rm_l23)
genes_idx_alltime_hvgs_rm_l23

In [None]:
# Seed NumPy's global RNG so the optional SHUFFLE control below is reproducible.
# (The previous `np.random.rand(0)` only created an empty array and seeded nothing.)
np.random.seed(0)

num_archetypal_cells = 100 # cells taken from each extreme of the score rankings
offset = 1 # CP10k + offset (CPM + 100*offset)
SHUFFLE = False # True: permute cells after the archetype masks are built (null control)

n_pseudo_genes = 35

n_cond = len(uniq_conds)
n_gene = adata.shape[1] 

qs_tensor   = np.zeros((n_cond,3,n_gene))  # 3 represents 3 pairwise comparisons (ca, ba, bc)
l2fc_tensor = np.zeros((n_cond,3,n_gene))


for cond_code, cond in enumerate(uniq_conds):
    # get sub
    adatasub = adata[adata.obs['cond']==cond]
    n_cells = adatasub.shape[0]
    
    # rank cells along the A<->C axis (scores_c - scores_a) and along the B score
    ranks_ac = adatasub.obs['scores_c-a'].rank()
    ranks_b  = adatasub.obs['scores_b'].rank()
    
    # candidates: lowest c-a ranks -> A, highest c-a ranks -> C, highest B ranks -> B
    precond_a = ranks_ac <= num_archetypal_cells
    precond_c = ranks_ac > adatasub.shape[0] - num_archetypal_cells
    precond_b = ranks_b  > adatasub.shape[0] - num_archetypal_cells
    
    # keep only cells that are a candidate for exactly one archetype
    cond_a = np.all([ precond_a, ~precond_b, ~precond_c], axis=0)
    cond_b = np.all([~precond_a,  precond_b, ~precond_c], axis=0)
    cond_c = np.all([~precond_a, ~precond_b,  precond_c], axis=0)
    
    # SHUFFLE: reorder the cells so the (positional) masks above select random cells
    if SHUFFLE:
        adatasub = adatasub[np.random.choice(n_cells, size=n_cells, replace=False)]
    
    # print(precond_a.sum(), 
    #       precond_b.sum(), 
    #       precond_c.sum(),)
    print(cond, cond_a.sum(), cond_b.sum(), cond_c.sum())
    
    adatasub_a = adatasub[cond_a]
    adatasub_b = adatasub[cond_b]
    adatasub_c = adatasub[cond_c]
    
    # DEGs
    mat_a = adatasub_a.layers['norm'][...]
    mat_b = adatasub_b.layers['norm'][...]
    mat_c = adatasub_c.layers['norm'][...]
    
    logmat_a = adatasub_a.layers['lognorm'][...]
    logmat_b = adatasub_b.layers['lognorm'][...]
    logmat_c = adatasub_c.layers['lognorm'][...]
    
    # per-gene two-sample t-tests on log-normalized expression
    ts_ca, ps_ca = stats.ttest_ind(logmat_c, logmat_a)
    ts_ba, ps_ba = stats.ttest_ind(logmat_b, logmat_a)
    ts_bc, ps_bc = stats.ttest_ind(logmat_b, logmat_c)
    
    # Benjamini-Hochberg FDR; NaN p-values are treated as p=1
    _, qs_ca, _, _ = multipletests(np.nan_to_num(ps_ca, nan=1).reshape(-1,), method='fdr_bh') # why nan in ps -- not expressed
    _, qs_ba, _, _ = multipletests(np.nan_to_num(ps_ba, nan=1).reshape(-1,), method='fdr_bh') # why nan in ps -- not expressed
    _, qs_bc, _, _ = multipletests(np.nan_to_num(ps_bc, nan=1).reshape(-1,), method='fdr_bh') # why nan in ps -- not expressed
    
    l2fc_ca = np.log2(np.mean(mat_c, axis=0)+offset) - np.log2(np.mean(mat_a, axis=0)+offset) # log2FC (CP10k as raw counts)
    l2fc_ba = np.log2(np.mean(mat_b, axis=0)+offset) - np.log2(np.mean(mat_a, axis=0)+offset) # log2FC (CP10k as raw counts)
    l2fc_bc = np.log2(np.mean(mat_b, axis=0)+offset) - np.log2(np.mean(mat_c, axis=0)+offset) # log2FC (CP10k as raw counts)
    
    qs_a   = np.minimum(qs_ca, qs_ba) # the better of the two
    qs_c   = np.minimum(qs_ca, qs_bc) # the better of the two
    qs_b   = np.minimum(qs_ba, qs_bc) # the better of the two

    l2fc_a = np.max([-l2fc_ca, -l2fc_ba], axis=0) # larger of the two fold changes
    l2fc_c = np.max([ l2fc_ca, -l2fc_bc], axis=0) # larger of the two fold changes
    l2fc_b = np.max([ l2fc_ba,  l2fc_bc], axis=0) # larger of the two fold changes

    # enriched vs both other archetypes, >2-fold in at least one, FDR < 0.05 in at least one
    cond_sig_a = np.all([-l2fc_ca > 0, -l2fc_ba > 0, l2fc_a > 1, qs_a < 0.05], axis=0)
    cond_sig_c = np.all([ l2fc_ca > 0, -l2fc_bc > 0, l2fc_c > 1, qs_c < 0.05], axis=0)
    cond_sig_b = np.all([ l2fc_ba > 0,  l2fc_bc > 0, l2fc_b > 1, qs_b < 0.05], axis=0)
    
    # save this
    l2fc_tensor[cond_code, 0] = l2fc_ca
    l2fc_tensor[cond_code, 1] = l2fc_ba
    l2fc_tensor[cond_code, 2] = l2fc_bc
    
    qs_tensor[cond_code, 0] = qs_ca
    qs_tensor[cond_code, 1] = qs_ba
    qs_tensor[cond_code, 2] = qs_bc
    
    # [0,1] scaled scores
    mat = adatasub.layers['lognorm'][...]
    mins = np.min(mat, axis=0)
    maxs = np.max(mat, axis=0)
    nmat = (mat - mins)/(maxs-mins+1e-10)
    
    # print(cond, cond_sig_a.sum(), cond_sig_c.sum(), cond_sig_b.sum()) 

# output, check results and stats

In [None]:
# %%time
# fout1 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/DEG_l23abc_qs_250409.npy'
# fout2 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/DEG_l23abc_l2fc_250409.npy'
# fout3 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/DEG_l23abc_gene_list_250409.csv'

# np.save(fout1, qs_tensor)
# np.save(fout2, l2fc_tensor)

In [None]:
# qs_tensor = np.load(fout1)
# l2fc_tensor = np.load(fout2)

# Significance thresholds used to call archetype-enriched genes below.
l2fc_th = np.log2(2)     # main fold-change cut-off: 2-fold
l2fc_th_s = np.log2(1.2) # secondary (weaker) fold-change cut-off: 1.2-fold
alpha_th = 0.05          # FDR cut-off

In [None]:
# Unpack the (condition, comparison, gene) tensors into one (condition, gene)
# matrix per pairwise comparison: 0 = C vs A, 1 = B vs A, 2 = B vs C.
qs_ca   = qs_tensor[:,0,:]
qs_ba   = qs_tensor[:,1,:]
qs_bc   = qs_tensor[:,2,:]

l2fc_ca = l2fc_tensor[:,0,:]
l2fc_ba = l2fc_tensor[:,1,:]
l2fc_bc = l2fc_tensor[:,2,:]

In [None]:
# For each archetype, combine its two pairwise comparisons (vs each other archetype).
qs_a   = np.minimum(qs_ca, qs_ba) # the better of the two
qs_c   = np.minimum(qs_ca, qs_bc) # the better of the two
qs_b   = np.minimum(qs_ba, qs_bc) # the better of the two

l2fc_a = np.max([-l2fc_ca, -l2fc_ba], axis=0) # larger of the two fold changes
l2fc_c = np.max([ l2fc_ca, -l2fc_bc], axis=0) # larger of the two fold changes
l2fc_b = np.max([ l2fc_ba,  l2fc_bc], axis=0) # larger of the two fold changes


l2fc_as = np.min([-l2fc_ca, -l2fc_ba], axis=0) # smaller of the two fold changes
l2fc_cs = np.min([ l2fc_ca, -l2fc_bc], axis=0) # smaller of the two fold changes
l2fc_bs = np.min([ l2fc_ba,  l2fc_bc], axis=0) # smaller of the two fold changes


# (condition, gene) boolean calls: higher than both other archetypes, larger
# fold change above l2fc_th, and best q-value below alpha_th
cond_sig_a = np.all([-l2fc_ca > 0, -l2fc_ba > 0, l2fc_a > l2fc_th, qs_a < alpha_th], axis=0)
cond_sig_c = np.all([ l2fc_ca > 0, -l2fc_bc > 0, l2fc_c > l2fc_th, qs_c < alpha_th], axis=0)
# cond_sig_a = np.all([-l2fc_ca > 0, -l2fc_ba > 0, l2fc_a > l2fc_th, qs_a < alpha_th, l2fc_as > l2fc_th_s], axis=0)
# cond_sig_c = np.all([ l2fc_ca > 0, -l2fc_bc > 0, l2fc_c > l2fc_th, qs_c < alpha_th, l2fc_cs > l2fc_th_s], axis=0)
# B additionally requires the smaller fold change to exceed l2fc_th_s
cond_sig_b = np.all([ l2fc_ba > 0,  l2fc_bc > 0, l2fc_b > l2fc_th, qs_b < alpha_th, l2fc_bs > l2fc_th_s], axis=0)


# histogram of "in how many conditions is each gene significant"
instances, counts_a = np.unique(cond_sig_a.sum(axis=0), return_counts=True)
instances, counts_c = np.unique(cond_sig_c.sum(axis=0), return_counts=True)
instances, counts_b = np.unique(cond_sig_b.sum(axis=0), return_counts=True)

# per-condition counts, followed by the number of genes significant in any condition
print('num A genes for each cond:\t', cond_sig_a.sum(axis=1), np.any(cond_sig_a, axis=0).sum())
print('num C genes for each cond:\t', cond_sig_c.sum(axis=1), np.any(cond_sig_c, axis=0).sum())
print('num B genes for each cond:\t', cond_sig_b.sum(axis=1), np.any(cond_sig_b, axis=0).sum())

# NOTE(review): [1:] drops the first bin, which is the "0 conditions" bin only
# if at least one gene is significant in no condition — presumably always true here.
print('num A genes in num conds:\t',  counts_a[1:])
print('num C genes in num conds:\t',  counts_c[1:])
print('num B genes in num conds:\t',  counts_b[1:])

In [None]:
# Long-format table with one row per significant (condition, gene) call,
# tagged with the archetype the call belongs to.
df_res_all = []
for label, cond_sig in (('A', cond_sig_a), ('C', cond_sig_c), ('B', cond_sig_b)):
    # row/column positions of the True entries
    cond_idx, gene_idx = np.nonzero(cond_sig.astype(int))

    df_res = pd.DataFrame({
        'cond': uniq_conds[cond_idx],
        'gene': genes[gene_idx],
    })
    df_res['archetype'] = label
    df_res_all.append(df_res)

df_res_all = pd.concat(df_res_all)
df_res_all
    

In [None]:
# Number of significant genes per condition (rows) and archetype (columns).
df_res_all.groupby(['cond', 'archetype']).size().unstack()

In [None]:
# Per-gene roll-up of the hit table.
# NOTE(review): the remaining columns ('cond', 'archetype') hold strings, so
# .sum() would concatenate them per gene rather than count — confirm this is
# the intended view (vs. e.g. .size()).
df_res_all.groupby('gene').sum()

In [None]:
# df_res_all.to_csv(fout3, header=True, index=False)

# further check

In [None]:
# Sorted gene names per archetype: significant in ANY condition vs in ALL
# conditions, plus the genes called for both A and C in some condition.
a_any = np.sort(adata.var[np.any(cond_sig_a, axis=0)].index.values)
a_all = np.sort(adata.var[np.all(cond_sig_a, axis=0)].index.values)

c_any = np.sort(adata.var[np.any(cond_sig_c, axis=0)].index.values)
c_all = np.sort(adata.var[np.all(cond_sig_c, axis=0)].index.values)
ac_overlap = np.sort(adata.var[np.logical_and(np.any(cond_sig_a, axis=0), np.any(cond_sig_c, axis=0))].index.values)

b_any = np.sort(adata.var[np.any(cond_sig_b, axis=0)].index.values)
b_all = np.sort(adata.var[np.all(cond_sig_b, axis=0)].index.values)

print('a any', a_any.shape)
print('a all', a_all.shape)

print('c any', c_any.shape)
print('c all', c_all.shape)
print('ac overlap', ac_overlap.shape)

print('b any', b_any.shape)
print('b all', b_all.shape)

In [None]:
# Annotate the always-significant genes (and the A/C overlap) and print the
# third check_genes output for each list, separated by dashed rules.
a_all_annots, a_all_styled, a_all_annots_styled = gene_modules.check_genes(a_all)
c_all_annots, c_all_styled, c_all_annots_styled = gene_modules.check_genes(c_all)
b_all_annots, b_all_styled, b_all_annots_styled = gene_modules.check_genes(b_all)

ac_overlap_annots, ac_overlap_styled, ac_overlap_annots_styled = gene_modules.check_genes(ac_overlap)

print("\t".join(a_all_annots_styled)) # A, all conditions
print("---"*10)
print("\t".join(c_all_annots_styled)) # C, all conditions
print("---"*10)
print("\t".join(b_all_annots_styled)) # B, all conditions
print("---"*10)
print("\t".join(ac_overlap_annots_styled)) # called for both A and C

In [None]:
# Number of significant genes per condition, for each archetype.
nums_a = cond_sig_a.sum(axis=1)
nums_c = cond_sig_c.sum(axis=1)
nums_b = cond_sig_b.sum(axis=1)

# Left panel: A and C genes over time; right panel: B genes.
# Solid line + circles = NR conditions, dashed line + squares = DR conditions.
fig, axs = plt.subplots(1,2,figsize=(4*2,4))
ax = axs[0]
ax.plot(nr_times, nums_a[nr_idx], '-o' , fillstyle='none', label='A NR', color='C0')
ax.plot(dr_times, nums_a[dr_idx], '--s', fillstyle='none', label='A DR', color='C0')
ax.plot(nr_times, nums_c[nr_idx], '-o' , fillstyle='none', label='C NR', color='C2')
ax.plot(dr_times, nums_c[dr_idx], '--s', fillstyle='none', label='C DR', color='C2')
ax.set_xticks(nr_times)
ax.grid(False, axis='x')
ax.set_ylim(ymin=0) # , ymax=120)
ax.legend()
ax.set_ylabel('num. of gene')
ax.set_xlabel('time (P)')
sns.despine(ax=ax)

ax = axs[1]
ax.plot(nr_times, nums_b[nr_idx], '-o' , fillstyle='none', label='B NR', color='C1')
ax.plot(dr_times, nums_b[dr_idx], '--s', fillstyle='none', label='B DR', color='C1')
ax.set_xticks(nr_times)
ax.grid(False, axis='x')
ax.set_ylim(ymin=0) #, ymax=80)
ax.legend()
sns.despine(ax=ax)


output = os.path.join(outfigdir, 'num_degs_abc_1.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Same counts as the previous figure, NR conditions only, one panel per archetype.
nums_a = cond_sig_a.sum(axis=1)
nums_b = cond_sig_b.sum(axis=1)
nums_c = cond_sig_c.sum(axis=1)

fig, axs = plt.subplots(1,3,figsize=(3*3,4))
ax = axs[0]
ax.plot(nr_times, nums_a[nr_idx], '-o' , fillstyle='none', label='A', color='C0')
ax.set_xticks([6,10,14,17,21])
ax.grid(False, axis='x')
ax.set_ylim(ymin=0, ymax=110)
# ax.legend()
ax.set_ylabel('num. of gene')
ax.set_xlabel('time (P)')
ax.set_title('A genes')
sns.despine(ax=ax)

ax = axs[1]
ax.plot(nr_times, nums_c[nr_idx], '-o' , fillstyle='none', label='C', color='C2')
ax.set_xticks([6,10,14,17,21])
ax.grid(False, axis='x')
ax.set_ylim(ymin=0, ymax=110)
ax.set_title('C genes')
# ax.legend()
sns.despine(ax=ax)

ax = axs[2]
ax.plot(nr_times, nums_b[nr_idx], '-o' , fillstyle='none', label='B', color='C1')
ax.set_xticks([6,10,14,17,21])
ax.grid(False, axis='x')
ax.set_ylim(ymin=0, ymax=70) # B panel uses a smaller fixed y-range than A/C
# ax.legend()
sns.despine(ax=ax)
ax.set_title('B genes')
fig.tight_layout()
output = os.path.join(outfigdir, 'num_degs_abc_2.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Number of distinct genes called for each archetype in at least one condition
# (used in the figure titles further down).
num_uniq_a = np.sum(np.any(cond_sig_a, axis=0))
num_uniq_b = np.sum(np.any(cond_sig_b, axis=0))
num_uniq_c = np.sum(np.any(cond_sig_c, axis=0))

# check effect size 

In [None]:
# Per-gene flags: called for the archetype in at least one condition.
cond_sig_a_any = np.any(cond_sig_a, axis=0)
cond_sig_b_any = np.any(cond_sig_b, axis=0)
cond_sig_c_any = np.any(cond_sig_c, axis=0)

# Union over the three archetypes. It is defined here because later cells
# index with it before the cell that originally first assigned it, which
# raised NameError on a fresh Restart & Run All.
cond_sig_abc_any = np.any([
    cond_sig_a_any, 
    cond_sig_b_any, 
    cond_sig_c_any, 
], axis=0)

In [None]:
# adata.obs['sample'].unique()
import re

# Plotting order of conditions and samples: the DR entries come first, then NR.
# The replicate groups in `todo_samps` are consecutive and follow the same
# order as `todo_conds`; `mean_over_samples` relies on that layout.
todo_conds = [
    'P12DR', 'P14DR', 'P17DR', 'P21DR',
    'P6', 'P8', 'P10', 'P12', 'P14', 'P17', 'P21', 
]
todo_samps = [
    'P12DRa', 'P12DRb',
    'P14DRa', 'P14DRb',
    'P17DRa', 'P17DRb',
    'P21DRa', 'P21DRb',
    'P6a', 'P6b', 'P6c', 
    'P8a', 'P8b', 'P8c', 
    'P10a', 'P10b', 
    'P12a', 'P12b', 'P12c', 
    'P14a', 'P14b',
    'P17a', 'P17b', 
    'P21a', 'P21b', 
]
# numeric postnatal day of each entry: strip every letter and parse what is left
todo_conds_t = np.array([int(re.sub(r'[a-zA-Z]', '', a)) for a in todo_conds])
todo_samps_t = np.array([int(re.sub(r'[a-zA-Z]', '', a)) for a in todo_samps])
print(todo_conds_t)
print(todo_samps_t)

def mean_over_samples(mmat_res_samp, group_sizes=(2, 2, 2, 2, 3, 3, 2, 3, 2, 2, 2)):
    """Average consecutive replicate samples into per-condition means.

    Axis 0 of ``mmat_res_samp`` indexes samples, ordered so that replicates of
    the same condition are adjacent. Each run of ``group_sizes[k]`` consecutive
    rows is averaged into output row ``k``.

    Parameters
    ----------
    mmat_res_samp : array-like, shape (n_samples, ...)
        Per-sample values; any trailing dimensions are preserved.
    group_sizes : sequence of int, optional
        Number of consecutive samples in each condition. The default is the
        25-sample -> 11-condition layout this function originally hard-coded.

    Returns
    -------
    numpy.ndarray, float64, shape (len(group_sizes), ...)
        Mean over the replicates of each condition.

    Raises
    ------
    AssertionError
        If the number of rows does not equal ``sum(group_sizes)`` or a group
        size is not positive.
    """
    mmat_res_samp = np.asarray(mmat_res_samp)
    group_sizes = [int(size) for size in group_sizes]
    assert all(size > 0 for size in group_sizes)
    assert mmat_res_samp.shape[0] == sum(group_sizes)

    # group k covers rows edges[k]:edges[k+1]
    edges = np.concatenate([[0], np.cumsum(group_sizes)])

    mmat_res_samp_mean = np.zeros((len(group_sizes),) + mmat_res_samp.shape[1:])
    for k, (start, stop) in enumerate(zip(edges[:-1], edges[1:])):
        mmat_res_samp_mean[k] = np.mean(mmat_res_samp[start:stop], axis=0)

    return mmat_res_samp_mean

def transform_bigredmat(bigmat, n_lead_conds=4, n_type=None):
    """Flatten a (cond, type, gene) array into gene-by-(cond*type) matrices.

    The first two axes are flattened (condition-major), the result is
    transposed so genes are rows, and the columns belonging to the first
    ``n_lead_conds`` conditions are moved to the END (the column order is
    therefore different from the condition order of ``bigmat``).

    Parameters
    ----------
    bigmat : array-like, shape (n_cond, n_type, n_gene)
        Either the per-sample ("bigmat") or per-condition ("redmat") array.
    n_lead_conds : int, optional
        Number of leading conditions whose columns are moved to the end
        (default 4).
    n_type : int or None, optional
        Number of columns per condition. By default it is read from
        ``bigmat.shape[1]``; the previous hard-coded ``4*5`` shift was only
        right for 5 types and split a condition block for any other count.
        For inputs that are not 3-D the old value of 5 is used.

    Returns
    -------
    fmat : numpy.ndarray, shape (n_gene, n_cond*n_type)
        Reordered values.
    zmat : numpy.ndarray, same shape
        Row-wise (per-gene) z-score of ``fmat``.
    """
    bigmat = np.asarray(bigmat)
    if n_type is None:
        n_type = bigmat.shape[1] if bigmat.ndim == 3 else 5
    n_lead_cols = n_lead_conds * n_type

    fmat = bigmat.reshape(-1, bigmat.shape[-1]).T
    fmat = np.hstack([fmat[:,n_lead_cols:], fmat[:,:n_lead_cols]]) # CHANGED COLUMN ORDER!!
    zmat = zscore(fmat, axis=1)
    
    return fmat, zmat

In [None]:
%%time

# Per-sample pseudo-bulk expression relative to the all-cell baseline.
mat = adata.layers['norm'][...]
# baseline: log2 of the mean expression over all cells
gexp_l23baseline = np.log2(np.mean(mat, axis=0)*1e2+offset) # CP10k -> CPM

n_type = 5 # number of rank bins along the C-minus-A score
frac_archetypal_cells_viz = 0.2 # fraction of cells taken per archetype for visualization
bigmat_nfd = np.zeros((len(todo_samps), n_type, mat.shape[1])) # (sample, bin, gene)
bigmat_abc = np.zeros((len(todo_samps),      3, mat.shape[1])) # (sample, A/B/C, gene)

for i, samp in enumerate(todo_samps):
    print(samp)
    
    # get sub
    adatasub = adata[adata.obs['sample']==samp]
    n_cells = adatasub.shape[0]
    
    # rank cells along the A<->C axis and along the B score
    ranks_ac = adatasub.obs['scores_c-a'].rank()
    ranks_b  = adatasub.obs['scores_b'].rank()
    
    # per type: split the cells into n_type equal-sized bins by c-a rank
    cells_type_nfd = pd.qcut(ranks_ac, n_type, labels=False)
    for j in range(n_type):
        mat_j = adatasub[cells_type_nfd==j].layers['norm'][...]
        mmat_j = np.log2(np.mean(mat_j, axis=0)*1e2+offset)-gexp_l23baseline # CP10k -> CPM
        bigmat_nfd[i,j] = mmat_j
    
    # A, B, C: same selection logic as the DEG loop, but using a fraction of
    # the sample's cells instead of a fixed cell count
    num_archetypal_cells_viz = int(n_cells*frac_archetypal_cells_viz)
    
    precond_a = ranks_ac <= num_archetypal_cells_viz
    precond_c = ranks_ac > adatasub.shape[0] - num_archetypal_cells_viz
    precond_b = ranks_b  > adatasub.shape[0] - num_archetypal_cells_viz
    
    # keep only cells that are a candidate for exactly one archetype
    cond_a = np.all([ precond_a, ~precond_b, ~precond_c], axis=0)
    cond_b = np.all([~precond_a,  precond_b, ~precond_c], axis=0)
    cond_c = np.all([~precond_a, ~precond_b,  precond_c], axis=0)
    
    # archetype order along axis 1 of bigmat_abc is A, B, C
    for j, cond in enumerate([cond_a, cond_b, cond_c]):
        mat_j = adatasub[cond].layers['norm'][...]
        mmat_j = np.log2(np.mean(mat_j, axis=0)*1e2+offset)-gexp_l23baseline # CP10k -> CPM
        bigmat_abc[i,j] = mmat_j


In [None]:
# axis 0 follows todo_samps (one row per sample), then archetype, then gene
print(bigmat_abc.shape) # sample, type, gene

In [None]:
# Mean signal over each archetype's gene set -> one (sample, A/B/C) matrix per
# gene set, in the order [A genes, B genes, C genes].
bigmat_abc_ig_list = [
    np.mean(bigmat_abc[:,:,cond_sig_a_any], axis=-1),
    np.mean(bigmat_abc[:,:,cond_sig_b_any], axis=-1),
    np.mean(bigmat_abc[:,:,cond_sig_c_any], axis=-1),
]

# collapse replicate samples into per-condition means
redmat_abc_ig_list = [mean_over_samples(x) for x in bigmat_abc_ig_list]

In [None]:
# One panel per gene set (A, B, C genes); within a panel one color per
# archetypal cell group. Markers are individual samples, lines are condition
# means; the DR entries (first 8 samples / first 4 conditions) are drawn faded.
fig, axs = plt.subplots(1, 3, figsize=(3*3,1*3), sharex=True, sharey=True)
type_colors = ('C0', 'C1', 'C2')
for i, ax in enumerate(axs):
    bigmat_mean_ig = bigmat_abc_ig_list[i]
    redmat_mean_ig = redmat_abc_ig_list[i]

    # NR samples
    for j, color in enumerate(type_colors):
        ax.plot(todo_samps_t[8:], bigmat_mean_ig[8:,j], 'o', markersize=5, fillstyle='none', color=color)

    # DR samples
    for j, color in enumerate(type_colors):
        ax.plot(todo_samps_t[:8], bigmat_mean_ig[:8,j], 's', markersize=5, fillstyle='none', color=color, alpha=0.5)

    # NR condition means
    for j, color in enumerate(type_colors):
        ax.plot(todo_conds_t[4:], redmat_mean_ig[4:,j], '-', color=color)

    # DR condition means
    for j, color in enumerate(type_colors):
        ax.plot(todo_conds_t[:4], redmat_mean_ig[:4,j], '-', color=color, alpha=0.5)

    ax.grid(False)
    ax.set_xticks([6,10,14,17,21])
    sns.despine(ax=ax)

axs[0].set_xlabel('Postnatal day (P)')
axs[0].set_ylabel('Gene expr.\nlog2(archetype / baseline)')
for ax, title in zip(axs, (f'A genes\nn={num_uniq_a:,}',
                           f'B genes\nn={num_uniq_b:,}',
                           f'C genes\nn={num_uniq_c:,}')):
    ax.set_title(title)
output = os.path.join(outfigdir, 'abc_degs_signals_over_time_withDR.pdf') 
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Collapse the rank-bin matrix to per-condition means, then flatten it to
# gene x (cond*type) matrices (raw and row-z-scored).
redmat_nfd = mean_over_samples(bigmat_nfd)
fmat_nfd, zmat_nfd = transform_bigredmat(redmat_nfd)
print(redmat_nfd.shape) # cond, type, gene
print(fmat_nfd.shape)   # gene, cond*type
print(zmat_nfd.shape)   # gene, cond*type

# NOTE(review): `cond_sig_abc_any` must already exist when this cell runs.
# Row subsets for each archetype's gene set and for their union.
fmat_nfd_ag = fmat_nfd[cond_sig_a_any]
fmat_nfd_bg = fmat_nfd[cond_sig_b_any]
fmat_nfd_cg = fmat_nfd[cond_sig_c_any]
fmat_nfd_abcg = fmat_nfd[cond_sig_abc_any]

zmat_nfd_ag = zmat_nfd[cond_sig_a_any]
zmat_nfd_bg = zmat_nfd[cond_sig_b_any]
zmat_nfd_cg = zmat_nfd[cond_sig_c_any]
zmat_nfd_abcg = zmat_nfd[cond_sig_abc_any]

print(fmat_nfd_ag.shape, zmat_nfd_ag.shape, zmat_nfd_abcg.shape)

In [None]:
# Same gene subsets, taken on the gene (last) axis of the (cond, type, gene) array.
# NOTE(review): `cond_sig_abc_any` must already exist when this cell runs.
redmat_nfd_ag = redmat_nfd[:,:,cond_sig_a_any] #.shape
redmat_nfd_bg = redmat_nfd[:,:,cond_sig_b_any] #.shape
redmat_nfd_cg = redmat_nfd[:,:,cond_sig_c_any] #.shape
redmat_nfd_abcg = redmat_nfd[:,:,cond_sig_abc_any] #.shape
print(redmat_nfd_ag.shape)

In [None]:
from sklearn.cluster import KMeans
def mean_shape(vec):
    """Return the centre of mass of ``vec`` along its index axis.

    Negative entries are clipped to zero, the remainder is normalised to sum
    to one, and the weighted mean of the positions 0..len(vec)-1 is returned.
    """
    positions = np.arange(len(vec))

    # non-negative weights, rescaled to sum to 1
    weights = np.clip(vec, 0, None)
    weights = weights / np.sum(weights)

    return positions.dot(weights)

def organize_zmat(zmat, fmat, redmat, title='', n_geneset_clsts=5, genes=None):
    """Cluster genes on their z-scored profiles and order the clusters in time.

    Genes (rows of ``zmat``) are grouped with KMeans. For each cluster a time
    sketch is built from ``redmat`` (mean over the cluster's genes, max over
    cell types); clusters are then ordered by the centre of mass of that
    sketch (``mean_shape``), and genes are sorted by their cluster's new rank.

    NOTE ON CONDITION ORDER: the sketches drop the first 4 conditions of
    ``redmat`` to keep NR only, i.e. ``redmat`` is expected to list the 4 DR
    conditions FIRST. ``zmat``/``fmat`` are only re-ordered row-wise here, so
    their column order is whatever the caller passed in.

    Parameters
    ----------
    zmat, fmat : numpy.ndarray, shape (n_gene, n_cond*n_type)
        Z-scored and raw gene profiles; rows must be the same genes.
    redmat : numpy.ndarray, shape (n_cond, n_type, n_gene)
        Per-condition means for the same genes (gene axis last).
    title : str
        Stored unchanged in the result.
    n_geneset_clsts : int
        Number of KMeans clusters.
    genes : array-like or None
        Gene labels, one per row of ``zmat``. If None, row indices are used
        (previously None raised a TypeError when indexing).

    Returns
    -------
    dict with keys 'title', 'order', 'zmat', 'fmat', 'genes', 'clst',
    'time_sketches', 'geneset_list'; matrices and labels are row-ordered by
    the renamed cluster id.
    """
    if genes is None:
        genes = np.arange(zmat.shape[0])
    else:
        genes = np.asarray(genes)

    method = KMeans(n_clusters=n_geneset_clsts, n_init=10, random_state=0)
    geneset_clst = method.fit_predict(zmat)

    # average over genes per geneset and cell clusters - leave genesets and conditions there
    time_sketches = []
    for i in range(n_geneset_clsts):
        time_sketch = np.mean(redmat[:,:,geneset_clst==i], axis=2) # mean over genes
        time_sketch = np.max(time_sketch, axis=1) # max over cell types
        time_sketches.append(time_sketch)
    time_sketches = np.vstack(time_sketches)[:,4:] # n_geneset_clsts, n_cond (select NR only)

    # order clusters by where along the NR time axis their signal is centred
    clst_order = np.argsort([mean_shape(time_sketch) for time_sketch in time_sketches]) 
    # map original KMeans labels -> rank in clst_order
    geneset_clst_renamed = pd.Series({clst: i for i, clst in enumerate(clst_order)}).reindex(geneset_clst).values
    geneset_order = np.argsort(geneset_clst_renamed)
    
    # reorder 
    genes_ordered = genes[geneset_order]
    clsts_ordered = geneset_clst_renamed[geneset_order]
    zmat_ordered = zmat[geneset_order] 
    fmat_ordered = fmat[geneset_order] 
    
    # gene list per group
    geneset_list = []
    for i in range(n_geneset_clsts):
        geneset_list.append(genes_ordered[clsts_ordered==i])
    
    res = {
        'title': title,
        'order': geneset_order,
        'zmat':  zmat_ordered,
        'fmat':  fmat_ordered,
        'genes': genes_ordered,
        'clst':  clsts_ordered,
        'time_sketches':  time_sketches[clst_order],
        'geneset_list': geneset_list,
    }
    return res

In [None]:
# ctrds = organize_zmat(zmat_nfd_ag, fmat_nfd_ag, redmat_nfd_ag, title='A genes', genes=genes[cond_sig_a_any])

# Cluster and order each archetype's gene set, then the union of all three.
res_a = organize_zmat(zmat_nfd_ag, fmat_nfd_ag, redmat_nfd_ag, title='A genes', genes=genes[cond_sig_a_any])
res_b = organize_zmat(zmat_nfd_bg, fmat_nfd_bg, redmat_nfd_bg, title='B genes', genes=genes[cond_sig_b_any])
res_c = organize_zmat(zmat_nfd_cg, fmat_nfd_cg, redmat_nfd_cg, title='C genes', genes=genes[cond_sig_c_any])

res_abc = organize_zmat(zmat_nfd_abcg, fmat_nfd_abcg, redmat_nfd_abcg, title='ABC genes', genes=genes[cond_sig_abc_any])

# Profile these modules

# Quantify time vs DR effect
- late are DR sensitive
- time effect (P21NR - P10NR) vs DR effect (P21DR - P21NR)
- refine this as the average time effect vs average DR effect

In [None]:
# Postnatal days of the NR time course and of the DR conditions.
# (`dr_times` repeats the values assigned near the top of the notebook.)
times = np.array([6,8,10,12,14,17,21])
dr_times = np.array([12,14,17,21])

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Two-stop gradients from black to each archetype's color (A=C0, B=C1, C=C2).
colors_a = [(0.0, 'black'), (1.0, 'C0')]      
colors_b = [(0.0, 'black'), (1.0, 'C1')]      
colors_c = [(0.0, 'black'), (1.0, 'C2')]      

# Create a custom colormap using LinearSegmentedColormap
cmap_a = LinearSegmentedColormap.from_list('cmap_a', colors_a)
cmap_b = LinearSegmentedColormap.from_list('cmap_b', colors_b)
cmap_c = LinearSegmentedColormap.from_list('cmap_c', colors_c)

# Five RGBA colors running A -> B -> C; entries 2 and 4 are 70/30 blends of
# the archetype color with the next archetype's color.
colors_l23 = [
    np.array(cmap_a(1.0)),
    0.7*np.array(cmap_a(1.0))+0.3*np.array(cmap_b(1.0)),
    np.array(cmap_b(1.0)),
    0.7*np.array(cmap_b(1.0))+0.3*np.array(cmap_c(1.0)),
    np.array(cmap_c(1.0)),
]

In [None]:
from sklearn.metrics import r2_score

In [None]:
def calc_del(red_y):
    """Time effect and DR effect per gene, averaged over axis 0 of each row.

    ``red_y`` is indexed by condition on its first axis; rows 0-3 are paired
    with rows 7-10, and rows 6 and 10 give the time contrast.
    Assumes the row order of ``todo_conds`` (DR P12..P21, then NR P6..P21) and
    a (cond, type, gene) layout — TODO confirm at the call site.

    Returns
    -------
    del_t   : red_y[10] - red_y[6], mean over axis 0
    del_v   : mean over the four (row k) - (row k+7) contrasts
    del_v21 : the contrast red_y[3] - red_y[10] alone
    """
    # (first row, second row) of each contrast, ordered 21, 17, 14, 12
    row_pairs = ((3, 10), (2, 9), (1, 8), (0, 7))

    del_t = np.mean(red_y[10] - red_y[6], axis=0) # mean over ABC
    del_varr = np.array([
        np.mean(-red_y[hi_row] + red_y[lo_row], axis=0) # mean over ABC
        for lo_row, hi_row in row_pairs
    ])
    del_v = np.mean(del_varr, axis=0) # mean over time
    del_v21 = del_varr[0]

    return del_t, del_v, del_v21

def calc_del_typespec(red_y):
    """Time effect and DR effect per gene, kept separate for every type.

    Same contrasts as ``calc_del`` but without averaging over axis 0 of each
    row, so every output keeps the shape of one row of ``red_y``.
    Assumes the row order of ``todo_conds`` (DR P12..P21, then NR P6..P21) and
    a (cond, type, gene) layout — TODO confirm at the call site.

    Returns
    -------
    del_t   : red_y[10] - red_y[6]
    del_v   : mean over the four (row k) - (row k+7) contrasts
    del_v21 : the contrast red_y[3] - red_y[10] alone
    """
    # (first row, second row) of each contrast, ordered 21, 17, 14, 12
    row_pairs = ((3, 10), (2, 9), (1, 8), (0, 7))

    del_t = red_y[10] - red_y[6] # (n_type, n_gene)
    del_varr = np.array([-red_y[hi_row] + red_y[lo_row] for lo_row, hi_row in row_pairs])
    del_v = np.mean(del_varr, axis=0) # mean over time
    del_v21 = del_varr[0]

    return del_t, del_v, del_v21

In [None]:
# Genes called for any archetype in any condition.
cond_sig_abc_any = np.any([
    cond_sig_a_any, 
    cond_sig_b_any, 
    cond_sig_c_any, 
], axis=0)
print(cond_sig_abc_any.sum())

# Time and DR effects for the union gene set: per-sample matrix -> condition
# means -> contrasts (averaged over archetypes, and archetype-specific).
genes_abc = genes[cond_sig_abc_any]
big_y = bigmat_abc[:,:,cond_sig_abc_any]# .shape
red_y = mean_over_samples(big_y)
del_t_abc, del_v_abc, del_v21_abc = calc_del(red_y)
del_t_abc_typespec, del_v_abc_typespec, del_v21_abc_typespec = calc_del_typespec(red_y)

# Same for each archetype's own gene set.
big_y_a = bigmat_abc[:,:,cond_sig_a_any]# .shape
big_y_b = bigmat_abc[:,:,cond_sig_b_any]# .shape
big_y_c = bigmat_abc[:,:,cond_sig_c_any]# .shape

red_y_a = mean_over_samples(big_y_a)
red_y_b = mean_over_samples(big_y_b)
red_y_c = mean_over_samples(big_y_c)

del_t_a, del_v_a, del_v21_a = calc_del(red_y_a)
del_t_b, del_v_b, del_v21_b = calc_del(red_y_b)
del_t_c, del_v_c, del_v21_c = calc_del(red_y_c)

# del_t_a_typespec, del_v_a_typespec = calc_del_typespec(red_y_a)
# del_t_b_typespec, del_v_b_typespec = calc_del_typespec(red_y_b)
# del_t_c_typespec, del_v_c_typespec = calc_del_typespec(red_y_c)

In [None]:
# Gene names of each archetype's gene set (same order as the del_* arrays).
genes_a = genes[cond_sig_a_any]
genes_b = genes[cond_sig_b_any]
genes_c = genes[cond_sig_c_any]

In [None]:
# Time effect (x) vs DR effect (y), one point per gene.
_x = del_t_abc
_y = del_v_abc

_xa = del_t_a
_xb = del_t_b
_xc = del_t_c

_ya = del_v_a
_yb = del_v_b
_yc = del_v_c

n = len(_x)

# Pearson correlation and least-squares line over all archetype genes
r, _ = stats.pearsonr(_x, _y)
slope, intercept = np.polyfit(_x, _y, 1)
xbase = np.linspace(-4,4,5) 
ybase = slope*xbase + intercept
r2 = r2_score(_y, _x*slope+intercept)
# for a least-squares line r**2 equals R^2; check the difference in both directions
assert abs(r**2 - r2) < 1e-3

fig, ax = plt.subplots(figsize=(5,4))
ax.scatter(_x, _y, s=5, color='k')#s=10, facecolors='none', edgecolors='C0', linewidths=1)
    
ax.plot(xbase, ybase, '--r', linewidth=1) #, zorder=0)
ax.axvline(0, color='gray',  linewidth=1, zorder=0)
ax.axhline(0, color='gray',  linewidth=1, zorder=0)
ax.grid(False)
sns.despine(ax=ax)
ax.set_ylabel('log2(DR/NR)')
ax.set_xlabel('log2(P21/P10)')
# '+' forces the sign so a positive intercept reads "x+0.10" instead of "x0.10"
ax.set_title(f'y={slope:.2f}x{intercept:+.2f}; r={r:.2f}; n={n}', fontsize=15)

# output = os.path.join(outfigdir, f'time_vs_dr_linear.pdf') 
# powerplots.savefig_autodate(fig, output)
plt.show()

# Same scatter, one panel per archetype gene set highlighted over all genes (gray).
fig, axs = plt.subplots(1,3,figsize=(5*3,4), sharex=True, sharey=True)
for j in range(3):
    _xj = [_xa, _xb, _xc][j]
    _yj = [_ya, _yb, _yc][j]
    _gj = [genes_a, genes_b, genes_c][j]
    ax = axs[j]
    ax.scatter(_x, _y, s=5, color='lightgray')
    ax.scatter(_xj, _yj, s=5, color=f'C{j}')
    
    # label the 5 genes with the largest absolute DR effect
    idx = np.argsort(np.abs(_yj))[::-1][:5]
    for idx_i in idx: 
        xt = _xj[idx_i]
        yt = _yj[idx_i]
        tt = _gj[idx_i]
        ax.text(xt, yt+0.01, tt, fontsize=8, va='bottom', ha='center',)

    ax.plot(xbase, ybase, '--k', linewidth=1) #, zorder=1)
    ax.axvline(0, color='gray',  linewidth=1, zorder=0)
    ax.axhline(0, color='gray',  linewidth=1, zorder=0)
    ax.grid(False)
    sns.despine(ax=ax)
    ax.set_ylabel('log2(DR/NR)')
    ax.set_xlabel('log2(P21/P10)')
    ax.set_title(f'n = {len(_gj)}', fontsize=15)

# output = os.path.join(outfigdir, f'time_vs_dr_linear.pdf') 
# powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Annotation labels that count as cell-surface molecules (CSMs) in the cells below.
# NOTE(review): presumably these match labels returned by gene_modules.check_genes — verify.
csm_annots = [
    'igsf',    
    'cad',     
    'fbrn',    
    'eph',     
    'sema',    
    'tene',    
    'astn',    
    'cntnap',  
    'nrxn',    
    'axon',    
    'wiring',
]

In [None]:
def _annotate_csm(gene_list):
    """Return (annotations, boolean CSM mask) for ``gene_list`` and print the CSM genes and their count."""
    annots = gene_modules.check_genes(gene_list)[0]
    is_csm = np.array([_g in csm_annots for _g in annots])
    print(gene_list[is_csm], np.sum(is_csm))
    return annots, is_csm

# union gene set first, then each archetype's own gene set
gene_annots_abc, cond_csm_abc = _annotate_csm(genes_abc)
gene_annots_a, cond_csm_a = _annotate_csm(genes_a)
gene_annots_b, cond_csm_b = _annotate_csm(genes_b)
gene_annots_c, cond_csm_c = _annotate_csm(genes_c)

In [None]:
# combine all of them and show the changes are more in B and C
# Restrict the P21 DR contrast to the CSM genes of the union gene set.
n_abc_csm = np.sum(cond_csm_abc)
genes_abc_csm = genes_abc[cond_csm_abc]
del_v21_abc_csm = del_v21_abc[cond_csm_abc]
del_v21_abc_typespec_csm = del_v21_abc_typespec[:,cond_csm_abc]

# One row per gene, sorted by the archetype-averaged contrast (largest first).
support = np.arange(n_abc_csm)
fig, ax = plt.subplots(1,1,figsize=(1*4,1*15), sharey=True)
idx = np.argsort(del_v21_abc_csm)[::-1]

# horizontal bar spanning the min..max of the three archetype-specific values
for i, isort in enumerate(idx):
    vals = del_v21_abc_typespec_csm[:,isort]
    # val_max = vals[np.argmax(np.abs(vals))]
    # ax.plot([0,val_max], [i,i], '-k')
    val_min = vals[np.argmin(vals)]
    val_max = vals[np.argmax(vals)]
    ax.plot([val_min,val_max], [i,i], '-k', zorder=1)

# archetype-specific values: rows 0/1/2 drawn in C0/C1/C2 (A/B/C colors)
ax.scatter(del_v21_abc_typespec_csm[0,idx], support, color='C0', s=20)#facecolor='none') 
ax.scatter(del_v21_abc_typespec_csm[1,idx], support, color='C1', s=20) #facecolor='none') 
ax.scatter(del_v21_abc_typespec_csm[2,idx], support, color='C2', s=20) #facecolor='none') 
    
ax.set_yticks(support)
ax.set_yticklabels(genes_abc_csm[idx], rotation=0, fontsize=10)
ax.set_ylim([-1, n_abc_csm])
# ax.set_xlim([-1,1])
ax.axvline(0, color='k', linestyle='-', zorder=1, linewidth=1)
sns.despine(ax=ax)
ax.set_ylabel(f'Archetype-enriched CSMs (n={n_abc_csm})')
plt.show()

In [None]:
# Absolute P21 DR contrast per CSM gene (rows) and archetype (columns A, B, C).
# NOTE: this rebinds `mat`, which earlier cells used for expression matrices.
mat = np.abs(del_v21_abc_typespec_csm).T
print(mat.shape)
# paired Wilcoxon signed-rank tests between archetypes (genes are the pairs)
_, p_ab = stats.wilcoxon(mat[:,0], mat[:,1])
_, p_ac = stats.wilcoxon(mat[:,0], mat[:,2])
_, p_bc = stats.wilcoxon(mat[:,1], mat[:,2])
print(p_ab, p_ac, p_bc)

fig, ax = plt.subplots(figsize=(3,4))
sns.violinplot(mat, color='white', cut=0, ax=ax)
sns.stripplot(mat, ax=ax)
sns.despine(ax=ax)
ax.set_xticklabels(['A cells', 'B cells', 'C cells'])
ax.set_ylabel('|log2(DR/NR)|')
ax.set_title('archetypal cell surface molecules')
plt.show()


# vision-dependent CSMs - examples

In [None]:
def plot_genes(query, query_idx, output=None):
    """Plot expression time courses of the queried genes, one panel per gene.

    For each gene the slice ``bigmat_abc[:, :, gidx]`` is drawn as per-sample
    markers and its ``mean_over_samples`` reduction as lines, one color
    (C0/C1/C2) per column. Samples ``[8:]`` / conditions ``[4:]`` are drawn at
    their own time points (circles, solid lines). Samples ``[:8]`` /
    conditions ``[:4]`` -- the "DR" block -- are drawn shifted right by
    ``plot_offset`` (squares, dashed lines) so both series share one x-axis.

    Parameters
    ----------
    query : sequence of str
        Gene names; used as panel titles.
    query_idx : sequence of int
        Index of each gene along the last axis of ``bigmat_abc``; must be
        aligned element-wise with ``query``.
    output : str, optional
        If given, the figure is saved with ``powerplots.savefig_autodate``
        before being shown (same convention as ``plot_genes_nr``).
        The default ``None`` saves nothing.

    Raises
    ------
    ValueError
        If ``query`` is empty.

    Notes
    -----
    Reads the notebook globals ``bigmat_abc``, ``mean_over_samples``,
    ``todo_samps_t`` and ``todo_conds_t``.
    """
    n = len(query)
    if n == 0:
        raise ValueError("query must contain at least one gene")

    # squeeze=False keeps `axs` a 2-D array even when n == 1; without it
    # plt.subplots returns a bare Axes and `axs[i]` fails for a single gene.
    fig, axs = plt.subplots(1, n, figsize=(n*3,1*4), sharex=True, squeeze=False) # , sharey=True)
    axs = axs[0]

    # assumes the first 8 samples / 4 conditions are the DR ones and the rest
    # are the non-DR (presumably NR) ones -- TODO confirm against how
    # bigmat_abc / todo_samps_t / todo_conds_t are built
    n_dr_samps = 8
    n_dr_conds = 4
    # horizontal shift applied to the DR block
    plot_offset = 12
    # columns of big_y / red_y that are drawn, colored C0, C1, C2
    type_cols = (0, 1, 2)

    for i in range(n):
        ax = axs[i]
        gn = query[i]
        gidx = query_idx[i]

        big_y = bigmat_abc[:,:,gidx]
        red_y = mean_over_samples(big_y)

        # non-DR block: individual samples, then the reduced values as lines
        for k in type_cols:
            ax.plot(todo_samps_t[n_dr_samps:], big_y[n_dr_samps:,k], 'o',
                    markersize=5, fillstyle='none', color=f'C{k}')
        for k in type_cols:
            ax.plot(todo_conds_t[n_dr_conds:], red_y[n_dr_conds:,k], '-', color=f'C{k}')

        # DR
        # shade x in [11, 22] (orange) and the same window shifted by plot_offset (light gray)
        ax.axvspan(12-2+1            , 21+1            , color='orange', alpha=0.1, linewidth=0, zorder=0)
        ax.axvspan(12-2+1+plot_offset, 21+1+plot_offset, color='lightgray', alpha=0.3, linewidth=0, zorder=0)

        for k in type_cols:
            ax.plot(todo_samps_t[:n_dr_samps]+plot_offset, big_y[:n_dr_samps,k], 's',
                    markersize=5, fillstyle='none', color=f'C{k}')
        for k in type_cols:
            ax.plot(todo_conds_t[:n_dr_conds]+plot_offset, red_y[:n_dr_conds,k], '--',
                    color=f'C{k}', alpha=1)

        ax.grid(False, axis='x')
        # the shifted ticks are relabelled with their un-shifted values
        ax.set_xticks([6, 12, 21,
                          12+plot_offset, 21+plot_offset])
        ax.set_xticklabels([6, 12, 21,
                               12, 21])
        sns.despine(ax=ax)
        ax.set_title(gn)

    axs[0].set_xlabel('Postnatal day (P)')
    axs[0].set_ylabel('Gene expr.\nlog2(archetype / baseline)')
    fig.tight_layout()

    if output is not None:
        powerplots.savefig_autodate(fig, output)
    plt.show()
    
def plot_genes_nr(query, query_idx, output=None):
    """Plot expression time courses of the queried genes, one panel per gene.

    For each gene the slice ``bigmat_abc[:, :, gidx]`` is drawn as per-sample
    circles and its ``mean_over_samples`` reduction as solid lines, one color
    (C0/C1/C2) per column. Only samples ``[8:]`` / conditions ``[4:]`` are
    drawn (presumably the NR ones, going by the function name -- TODO confirm).

    Parameters
    ----------
    query : sequence of str
        Gene names; used as panel titles.
    query_idx : sequence of int
        Index of each gene along the last axis of ``bigmat_abc``; must be
        aligned element-wise with ``query``.
    output : str, optional
        If given, the figure is saved with ``powerplots.savefig_autodate``
        before being shown. The default ``None`` saves nothing.

    Raises
    ------
    ValueError
        If ``query`` is empty.

    Notes
    -----
    Reads the notebook globals ``bigmat_abc``, ``mean_over_samples``,
    ``todo_samps_t`` and ``todo_conds_t``.
    """
    n = len(query)
    if n == 0:
        raise ValueError("query must contain at least one gene")

    # squeeze=False keeps `axs` a 2-D array even when n == 1; without it
    # plt.subplots returns a bare Axes and `axs[i]` fails for a single gene.
    fig, axs = plt.subplots(1, n, figsize=(n*3,1*4), sharex=True, squeeze=False) # , sharey=True)
    axs = axs[0]

    # assumes samples [8:] / conditions [4:] are the ones to show here
    # -- TODO confirm against how bigmat_abc / todo_samps_t / todo_conds_t are built
    first_samp = 8
    first_cond = 4
    # columns of big_y / red_y that are drawn, colored C0, C1, C2
    type_cols = (0, 1, 2)

    for i in range(n):
        ax = axs[i]
        gn = query[i]
        gidx = query_idx[i]

        big_y = bigmat_abc[:,:,gidx]
        red_y = mean_over_samples(big_y)

        # individual samples first, then the reduced values as lines
        for k in type_cols:
            ax.plot(todo_samps_t[first_samp:], big_y[first_samp:,k], 'o',
                    markersize=5, fillstyle='none', color=f'C{k}')
        for k in type_cols:
            ax.plot(todo_conds_t[first_cond:], red_y[first_cond:,k], '-', color=f'C{k}')

        ax.grid(False, axis='x')
        ax.set_xticks([6, 10, 14, 21])
        sns.despine(ax=ax)
        ax.set_title(gn)

    axs[0].set_xlabel('Postnatal day (P)')
    axs[0].set_ylabel('Gene expr.\nlog2(archetype / baseline)')
    fig.tight_layout()

    if output is not None:
        powerplots.savefig_autodate(fig, output)
    plt.show()

In [None]:
# Genes in `genes_abc` whose name starts with 'Robo', followed by those
# starting with 'Slit' (order within each prefix follows `genes_abc`).
roboslit_genes = [
    g
    for prefix in ('Robo', 'Slit')
    for g in genes_abc
    if g.startswith(prefix)
]
query = roboslit_genes
# Positions of the queried genes within `genes`.
query_idx = basicu.get_index_from_array(genes, query)
plot_genes_nr(query=query, query_idx=query_idx)


In [None]:
# Hand-picked example genes, drawn with plot_genes (one panel per gene).
query = [
    'Cdh13', 'Sema6a', 'Igsf9b', 'Megf11',
    'Cdh4', 'Cdh7', 'Cntnap4', 'Cdh20',
]
# Positions of the queried genes within `genes`.
query_idx = basicu.get_index_from_array(genes, query)
plot_genes(query=query, query_idx=query_idx)

In [None]:
# Second hand-picked gene set, drawn with plot_genes (one panel per gene).
# NOTE: `query` is also read by the following scatter cell to choose which
# genes get highlighted and labelled.
query = [
    'Amph', 'Dtna', 'Plcl1',
    'Bdnf', 'Nptx2', 'Scg2',
]
# Positions of the queried genes within `genes`.
query_idx = basicu.get_index_from_array(genes, query)
plot_genes(query=query, query_idx=query_idx)

In [None]:
# Scatter of `del_v_abc` (y, labelled log2(DR/NR)) against `del_t_abc`
# (x, labelled log2(P21/P10)), with the genes in `query` highlighted.
# NOTE(review): `query`, `xbase` and `ybase` are not defined in this cell;
# they are whatever earlier cells last assigned.
_x = del_t_abc
_y = del_v_abc
_g = genes[cond_sig_abc_any]
# Positions of the query genes within the `cond_sig_abc_any` subset of `genes`.
# presumably `_x` / `_y` are aligned with that same subset -- TODO confirm
query_idx = basicu.get_index_from_array(_g, query)

n = len(_x)
r = stats.pearsonr(_x, _y)[0]

fig, ax = plt.subplots(figsize=(4,4))

# all genes in light gray
ax.scatter(_x, _y, s=5, color='lightgray')

# dashed reference curve plus both zero axes
ax.plot(xbase, ybase, '--k', linewidth=1)
ax.axvline(0, color='gray',  linewidth=1, zorder=0)
ax.axhline(0, color='gray',  linewidth=1, zorder=0)
ax.grid(False)
sns.despine(ax=ax)
ax.set_xlabel('log2(P21/P10)')
ax.set_ylabel('log2(DR/NR)')
ax.set_title(f'r={r:.2f}; n={n}', fontsize=15)

# highlight the query genes and write each name just right of its point
ax.scatter(_x[query_idx], _y[query_idx], s=5, color='C1')
for xt, yt, tt in zip(_x[query_idx], _y[query_idx], _g[query_idx]):
    ax.text(xt+0.1, yt, tt, fontsize=12, va='bottom',)

# output = os.path.join(outfigdir, f'time_vs_dr_linear.pdf') 
# powerplots.savefig_autodate(fig, output)
plt.show()

    



In [None]:
# One heatmap panel per result dict, restricted to CSM-annotated genes.
# Note the panel order: res_a, res_c, then res_b.
fig, axs = plt.subplots(3,1,figsize=(6,4*3))
for ax, res_this in zip(axs, [res_a, res_c, res_b]):
    # unpack the result dict (`order` and `title` are not used below)
    order, title = res_this['order'], res_this['title']
    zmat, clsts = res_this['zmat'], res_this['clst']
    genes_this = res_this['genes']

    # keep only rows whose annotation (first value returned by check_genes)
    # is a member of `csm_annots`
    gene_annots_this = gene_modules.check_genes(genes_this)[0]
    cond_csm_this = np.array([annot in csm_annots for annot in gene_annots_this])
    zmat, clsts, genes_this = zmat[cond_csm_this], clsts[cond_csm_this], genes_this[cond_csm_this]

    # only the first 7*5 columns are displayed, color scale clipped to [-3, 3]
    sns.heatmap(
        zmat[:,:7*5],
        ax=ax,
        cmap='coolwarm',
        vmin=-3, vmax=3,
        xticklabels=False,
        rasterized=True,
        cbar_kws=dict(shrink=0.5),
    )
    ax.set_yticks(0.5+np.arange(len(zmat)))
    ax.set_yticklabels(genes_this, fontsize=4, rotation=0)

    # white separators after each cluster (cumulative cluster sizes)
    # assumes rows are grouped by ascending cluster label -- TODO confirm
    clst_sizes = np.unique(clsts, return_counts=True)[1]
    ax.hlines(np.cumsum(clst_sizes), 0, 55, color='white', linewidth=1)
    # white separators every 5 columns, black line at column 7*5
    ax.vlines(np.arange(0,35,5), 0, len(zmat), color='white', linewidth=1)
    ax.vlines(7*5, 0, len(zmat), color='black', linewidth=1)

    ax.grid(False)

# Column annotations go on the top panel only.
ax = axs[0]
ax.set_xticks(0.5+np.arange(n_type))
ax.set_xticklabels(['A', '<-', '-', '->', 'C'], fontsize=10, rotation=0)
# one text label per entry of todo_conds[4:], spaced 5 columns apart
for i, cond in enumerate(np.hstack([todo_conds[4:]])):
    ax.text(i*5, -0.5, f'{cond}', fontsize=10, va='bottom')

# output = os.path.join(outfigdir, f'heatmap_csm_all.pdf')
# powerplots.savefig_autodate(fig, output)
plt.show()

# break

In [None]:
# Single heatmap for `res_abc`, restricted to CSM-annotated genes.
fig, ax = plt.subplots(1,1,figsize=(6,10))

res_this = res_abc

# unpack the result dict (`order` and `title` are not used below)
order, title = res_this['order'], res_this['title']
zmat, clsts = res_this['zmat'], res_this['clst']
genes_this = res_this['genes']

# keep only rows whose annotation (first value returned by check_genes)
# is a member of `csm_annots`
gene_annots_this = gene_modules.check_genes(genes_this)[0]
cond_csm_this = np.array([annot in csm_annots for annot in gene_annots_this])
zmat, clsts, genes_this = zmat[cond_csm_this], clsts[cond_csm_this], genes_this[cond_csm_this]

# only the first 7*5 columns are displayed, color scale clipped to [-3, 3]
sns.heatmap(
    zmat[:,:7*5],
    ax=ax,
    cmap='coolwarm',
    vmin=-3, vmax=3,
    xticklabels=False,
    rasterized=True,
    cbar_kws=dict(shrink=0.5),
)
ax.set_yticks(0.5+np.arange(len(zmat)))
ax.set_yticklabels(genes_this, fontsize=4, rotation=0)

# white separators after each cluster (cumulative cluster sizes)
# assumes rows are grouped by ascending cluster label -- TODO confirm
clst_sizes = np.unique(clsts, return_counts=True)[1]
ax.hlines(np.cumsum(clst_sizes), 0, 55, color='white', linewidth=1)
# white separators every 5 columns, black line at column 7*5
ax.vlines(np.arange(0,35,5), 0, len(zmat), color='white', linewidth=1)
ax.vlines(7*5, 0, len(zmat), color='black', linewidth=1)

ax.grid(False)

# column annotations
ax.set_xticks(0.5+np.arange(n_type))
ax.set_xticklabels(['A', '<-', '-', '->', 'C'], fontsize=10, rotation=0)
# one text label per entry of todo_conds[4:], spaced 5 columns apart
for i, cond in enumerate(np.hstack([todo_conds[4:]])):
    ax.text(i*5, -0.5, f'{cond}', fontsize=10, va='bottom')

# output = os.path.join(outfigdir, f'heatmap_csm_all.pdf')
# powerplots.savefig_autodate(fig, output)
plt.show()

# break