In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu


In [None]:
# Directory for saved figures (absolute, machine-specific path).
# NOTE(review): the commented-out save call in the volcano cell refers to `outfigdir`, a different name -- confirm which is intended.
outdir_fig = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures"

# load gene annotation and data

In [None]:
# Reload the project package and its gene_modules submodule so that source edits
# are picked up without restarting the kernel.
import importlib
import scroutines
import scroutines.gene_modules
importlib.reload(scroutines)
importlib.reload(scroutines.gene_modules)

# Import the class after the reload so the name is bound to the reloaded module.
from scroutines.gene_modules import GeneModules  
gene_modules = GeneModules()
# Sanity check on a single gene. Judging by how the results are used in later cells,
# check_genes returns (annotations, colors, styled strings) -- TODO confirm.
g_anno, g_color, g_styled = gene_modules.check_genes('Cdh13')
print("\t".join(g_anno))
print("\t".join(g_color))
print("\t".join(g_styled))

In [None]:
# AC genes
f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/Saumya_P6-21_AC_genes.csv'
df = pd.read_csv(f)

# the table is split by row position: the first 25 rows go to df_a, the rest to df_c
n_rows_a = 25
df_a, df_c = df.iloc[:n_rows_a], df.iloc[n_rows_a:]

# unique entries of each half, their concatenation, and the entries found in both
alltime_a, alltime_c = (np.unique(part.values) for part in (df_a, df_c))
alltime_ac = np.concatenate([alltime_a, alltime_c])
ac_overlap = np.intersect1d(alltime_a, alltime_c)

print(*(obj.shape for obj in (df_a, df_c, alltime_a, alltime_c, alltime_ac, ac_overlap)))
df.head()

In [None]:
# use those 286 genes
# df = pd.read_csv("../../data/cheng21_cell_scrna/res/candidate_genes_vincent_0503_v2.csv")
df = pd.read_csv("../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv")

# gene names and their 'P17on' group labels, both as plain strings
genes_grp = df['P17on'].astype(str).values
genes_l23 = df['gene'].astype(str).values

# one gene array per 'P17on' group
genes_l23a, genes_l23b, genes_l23c = (
    df.loc[df['P17on'] == grp, 'gene'].astype(str).values for grp in ('A', 'B', 'C')
)
print(genes_l23a.shape, genes_l23b.shape, genes_l23c.shape)

# the gene list must not contain duplicates
assert len(genes_l23) == len(np.unique(genes_l23))

genes_l23.shape

In [None]:
%%time
adata = anndata.read("../../data/v1_multiome/L23_allmultiome_proc_P6toP21.h5ad")
adata

In [None]:
# Cache gene names and per-cell labels as plain arrays for the cells below.
genes = adata.var.index.values
conds = adata.obs['cond'].values
types = adata.obs['Type'].values
samps = adata.obs['sample'].values

# switch back to float64
adata.layers[    'norm'] = adata.layers['norm'][...].astype(np.float64)
# lognorm = log10(1 + norm)
adata.layers[ 'lognorm'] = np.log10(1+adata.layers['norm'][...]) # np.array(xln.todense())
# z-score each column along axis 0 (columns are labelled with `genes` in the next cell).
# NOTE(review): a column with zero variance presumably becomes NaN here -- confirm downstream code tolerates that.
adata.layers['zlognorm'] = zscore(adata.layers['lognorm'][...], axis=0)

In [None]:
# PC coordinates stored on the AnnData object
pcs_p8, pcs_p17on = adata.obsm['pca_p8'], adata.obsm['pca_p17on']

# z-scored expression (one column per gene) followed by per-cell metadata
res0 = pd.DataFrame(adata.layers['zlognorm'][...], columns=genes)
res0['cond'] = conds
res0['type'] = types
res0['samp'] = samps
res0['rep']  = [samp[-1] for samp in samps]  # last character of the sample name
res0['type'] = np.char.add('c', res0['type'].values.astype(str))  # prefix labels with 'c'

# PC columns are named <prefix>1, <prefix>2, ...
res1 = pd.DataFrame(pcs_p8, columns=[f"p8PC{k}" for k in range(1, pcs_p8.shape[1] + 1)])
res2 = pd.DataFrame(pcs_p17on, columns=[f"p17onPC{k}" for k in range(1, pcs_p17on.shape[1] + 1)])

# everything side by side, one row per cell
res = pd.concat((res0, res1, res2), axis=1)

In [None]:
# 20 colors from the 'tab20c' palette; indexed below when building `palette`.
allcolors = sns.color_palette('tab20c', 20)
allcolors

In [None]:
# NOTE(review): 'tab10' defines only 10 colors; requesting 20 presumably cycles through them twice -- confirm this is intended.
allcolors2 = sns.color_palette('tab10', 20)
allcolors2

In [None]:
# Color of each time point, listed in developmental order.
palette = collections.OrderedDict({
     "P6": allcolors[2],
     "P8": allcolors[1],
    "P10": allcolors[0],
    "P12": allcolors[4+2],
    "P14": allcolors[4+0],
    
    "P17": allcolors[8+2],
    "P21": allcolors[8+0],
    
})
cases = np.array(list(palette.keys()))

# Rank of each time point (P6 -> 0, ..., P21 -> 6). Derived from `palette` so the
# plotting order and the integer codes cannot drift apart.
cond_order_dict = {cond: rank for rank, cond in enumerate(palette)}
unq_conds = np.array(list(cond_order_dict.keys()))
adata.obs['cond_order'] = adata.obs['cond'].apply(lambda x: cond_order_dict[x])

# Color of each cluster label as it appears in res['type'] ('c' + cluster id).
# Two clusters share each color; the dict order is also used as hue order.
palette_types = {
    'c14': 'C0', 
    'c18': 'C1',
    'c16': 'C2', 
    
    'c13': 'C0', 
    'c15': 'C1', 
    'c17': 'C2',
}
type_order = list(palette_types)
type_order

In [None]:
from scipy import stats
from matplotlib.ticker import MaxNLocator

def plot(x, y, aspect_equal=False, density=False, hue='type'):
    """Scatter two columns of the global `res` frame, one panel per time point.

    Each panel shows every cell in light gray as background and overlays the
    cells whose `res['cond']` equals that panel's entry of the global `cases`.

    Parameters
    ----------
    x, y : str
        Column names of `res` to put on the x and y axes.
    aspect_equal : bool
        If True, force an equal aspect ratio on every panel.
    density : bool
        If True, overlay a 2D histogram (`sns.histplot`) of the panel's cells.
    hue : str
        'type' colors the overlaid cells by `res['type']` using `palette_types`;
        any other value colors them by `res['rep']` with seaborn's default colors.
    """
    # one panel per time point instead of a hard-coded 7; squeeze=False keeps
    # `axs` an array even when there is a single panel
    n = len(cases)
    fig, axs = plt.subplots(1,n,figsize=(4*n,4*1), sharex=True, sharey=True, squeeze=False)
    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)
        # background: all cells. `color` is seaborn's own keyword for a single color;
        # passing matplotlib's `c` through can clash with the color seaborn supplies.
        sns.scatterplot(data=res, 
                        x=x, y=y, 
                        color='lightgray',
                        s=1, edgecolor='none', 
                        legend=False,
                        ax=ax,
                       )

        # foreground: cells of this time point, shuffled so that the drawing order
        # does not systematically put one group on top
        show = res[res['cond']==cond]
        shuffled = show.sample(frac=1, replace=False)
        if hue == 'type':
            hue_kws = dict(hue='type',
                           hue_order=list(palette_types.keys()),
                           palette=palette_types)
        else:
            hue_kws = dict(hue='rep')
        sns.scatterplot(data=shuffled,
                        x=x, y=y, 
                        s=3, edgecolor='none', 
                        legend=False,
                        ax=ax,
                        **hue_kws,
                       )
            
        if density:
            sns.histplot(data=show,
                         x=x, y=y, 
                         legend=False,
                         ax=ax,
                        )
        sns.despine(ax=ax)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        # keep axis labels on the first panel only
        if i > 0:
            ax.set_xlabel('')
            ax.set_ylabel('')
    plt.show()
    
def plot2(x, y, hue=None, aspect_equal=False):
    """Scatter two columns of the global `res` frame per time point, colored by a value.

    Each panel shows every cell in faint gray as background and overlays the
    cells whose `res['cond']` equals that panel's entry of the global `cases`.

    Parameters
    ----------
    x, y : str
        Column names of `res` to put on the x and y axes.
    hue : str, optional
        Column of `res` used to color the overlaid cells on a 'coolwarm' scale
        clipped to [-3, 3]; it is also shown as the figure title. When omitted,
        cells are drawn in a single color and the Spearman correlation between
        `x` and `y` is appended to each panel title.
    aspect_equal : bool
        If True, force an equal aspect ratio on every panel.
    """
    # one panel per time point instead of a hard-coded 7; squeeze=False keeps
    # `axs` an array even when there is a single panel
    n = len(cases)
    fig, axs = plt.subplots(1,n,figsize=(4*n,4*1), sharex=True, sharey=True, squeeze=False)
    fig.suptitle(hue, x=0, ha='left')
    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)
        # background: all cells. `color` is seaborn's own keyword for a single color;
        # passing matplotlib's `c` through can clash with the color seaborn supplies.
        sns.scatterplot(data=res, 
                        x=x, y=y, 
                        color='lightgray',
                        alpha=0.3,
                        s=1, edgecolor='none', 
                        legend=False,
                        ax=ax,
                       )
        show = res[res['cond']==cond]
        if hue:
            ax.scatter(show[x], show[y], c=show[hue], 
                       cmap='coolwarm',
                       vmin=-3, vmax=3,
                       s=5, 
                       edgecolor='none', 
                      )
        else:
            # no color variable: report how strongly x and y co-vary instead
            r, p = stats.spearmanr(show[x], show[y])
            ax.scatter(show[x], show[y],  
                       s=5, 
                       edgecolor='none', 
                      )
            ax.set_title(f'{cond}\n r={r:.2f}')
        sns.despine(ax=ax)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        # keep axis labels on the first panel only
        if i > 0:
            ax.set_xlabel('')
            ax.set_ylabel('')
        ax.grid(False)
    fig.tight_layout()
    plt.show()

# Plot A vs C genes aligning cells along early vs late PCs

In [None]:
# First two PCs of each stored embedding (columns prefixed 'p17onPC' and 'p8PC' in `res`),
# one panel per time point, cells colored by type.
plot('p17onPC1', 'p17onPC2', aspect_equal=True)
plot('p8PC1', 'p8PC2', aspect_equal=True)

# Compare the bottom 20% (A-like) vs. top 20% (C-like) of cells along p17on PC1 at each time point

In [None]:
from statsmodels.stats.multitest import multipletests

In [None]:
# one entry per time point: adjusted p-values and log2 fold changes of every gene
qs_avc, l2fc_avc = [], []

for cond_order, obssub in adata.obs.groupby('cond_order'):
    # cells of this time point
    adatasub = adata[obssub.index]
    
    # the two extremes along the first p17on PC: bottom 20% -> A, top 20% -> C
    x = adatasub.obsm['pca_p17on'][...,0]
    lo_cut, hi_cut = np.percentile(x, [20, 80])
    cond_a, cond_c = x < lo_cut, x > hi_cut
    adatasub_a, adatasub_c = adatasub[cond_a], adatasub[cond_c]
    
    # per-gene t-test between the two groups on the log-normalized layer
    mat_a = adatasub_a.layers['lognorm'][...]
    mat_c = adatasub_c.layers['lognorm'][...]
    ts, ps = stats.ttest_ind(mat_a, mat_c)
    # NaN p-values (genes that are not expressed) count as p = 1 before BH correction
    rs, qs, _, _ = multipletests(np.nan_to_num(ps, nan=1).ravel(), method='fdr_bh')
    # the layer is in log10 units; rescale the C - A mean difference to log2 units
    lfc = (np.mean(mat_c, axis=0) - np.mean(mat_a, axis=0)) * np.log2(10)
    
    num_sig = np.sum((qs < 0.05) & (np.abs(lfc) > 1))
    print(cond_order, adatasub_a.shape, num_sig)
    
    # keep the statistics of this time point
    qs_avc.append(qs)
    l2fc_avc.append(lfc)
    
# stack into (time point, gene) arrays
qs_avc, l2fc_avc = np.array(qs_avc), np.array(l2fc_avc)

In [None]:
# Keep the A-vs-C statistics with the data object: one row per 'cond_order' group
# (in groupby order), one column per gene.
adata.uns['qs_avc'] = qs_avc
adata.uns['l2fc_avc'] = l2fc_avc
adata

In [None]:
# significance thresholds: |log2FC| > 1 and adjusted p < 0.05
l2fc_th, alpha_th = 1, 0.05

# boolean (time point, gene) masks; negative log2FC -> A, positive -> C
passes_alpha = qs_avc < alpha_th
cond_sig   = passes_alpha & (np.abs(l2fc_avc) > l2fc_th)
cond_sig_c = passes_alpha & (l2fc_avc >  l2fc_th)
cond_sig_a = passes_alpha & (l2fc_avc < -l2fc_th)

for label, mask in (('num AvsC-DEGs per cond:\t', cond_sig),
                    ('num A-DEGs per cond:\t',    cond_sig_a),
                    ('num C-DEGs per cond:\t',    cond_sig_c)):
    print(label, mask.sum(axis=1))

In [None]:
# For every gene: the number of time points at which it is significant (0 .. n_conds).
n_conds = cond_sig.shape[0]
gene_sig_instances   = cond_sig.sum(axis=0)
gene_sig_instances_a = cond_sig_a.sum(axis=0)
gene_sig_instances_c = cond_sig_c.sum(axis=0)

# Histogram over *every* possible value 0 .. n_conds, so that counts[k] is always the
# number of genes significant at exactly k time points. np.unique(..., return_counts=True)
# omits values that never occur, which silently shifts positional lookups such as
# counts_a[1:] or counts_a[7] in the plotting cells further down.
instances   = np.arange(n_conds + 1)
instances_a = np.arange(n_conds + 1)
instances_c = np.arange(n_conds + 1)
counts   = np.bincount(gene_sig_instances,   minlength=n_conds + 1)
counts_a = np.bincount(gene_sig_instances_a, minlength=n_conds + 1)
counts_c = np.bincount(gene_sig_instances_c, minlength=n_conds + 1)

# per-k counts, followed by the number of genes significant at least once
print('num AvsC-DEGs in num conds:\t',  counts  , np.any(cond_sig, axis=0).sum())
print('num A-DEGs in num conds:\t',     counts_a, np.any(cond_sig_a, axis=0).sum())
print('num C-DEGs in num conds:\t',     counts_c, np.any(cond_sig_c, axis=0).sum())

In [None]:
# gene names, to be indexed with per-gene boolean masks
var_names = adata.var.index.values
sig_a_any, sig_c_any = np.any(cond_sig_a, axis=0), np.any(cond_sig_c, axis=0)
sig_a_all, sig_c_all = np.all(cond_sig_a, axis=0), np.all(cond_sig_c, axis=0)

# significant at >= 1 time point ("any") or at every time point ("all"), sorted by name
a_any = np.sort(var_names[sig_a_any])
c_any = np.sort(var_names[sig_c_any])
a_all = np.sort(var_names[sig_a_all])
c_all = np.sort(var_names[sig_c_all])
# called A at some time point and C at another
ac_overlap = np.sort(var_names[sig_a_any & sig_c_any])

for label, found in (('a any', a_any), ('c any', c_any),
                     ('a all', a_all), ('c all', c_all),
                     ('ac overlap', ac_overlap)):
    print(label, found)

In [None]:
# Annotate the genes that are significant at every time point.
# NOTE(review): the other check_genes call sites in this notebook unpack the result as
# (annotations, colors, styled strings); here the 2nd and 3rd values are named
# *_styled and *_annots_styled -- confirm these names match what check_genes returns.
a_all_annots, a_all_styled, a_all_annots_styled = gene_modules.check_genes(a_all)
c_all_annots, c_all_styled, c_all_annots_styled = gene_modules.check_genes(c_all)

print("\t".join(a_all))
print("\t".join(a_all_annots_styled)) # _styled))

In [None]:
# Same listing for the genes called C at every time point: plain names, then the third check_genes output.
print("\t".join(c_all))
print("\t".join(c_all_annots_styled)) # _styled))

In [None]:
# "\033[0;m" is an ANSI reset sequence, emitted first so earlier styled output does not bleed into these lines.
# NOTE(review): `a_all_styled` / `c_all_styled` hold the *second* value returned by check_genes, which other
# cells treat as colors -- confirm this prints what is intended.
print("\033[0;m alltime A:", "\t".join(a_all_styled))
print("\033[0;m alltime C:", "\t".join(c_all_styled))

In [None]:
# Number of A / C DEGs at each of the 7 time points, as paired bars.
fig, ax = plt.subplots(1,1,figsize=(4,4))
ax.bar(np.arange(1,1+7)-0.2, cond_sig_a.sum(axis=1), label='A', color='C0', width=0.4, edgecolor='none')
ax.bar(np.arange(1,1+7)+0.2, cond_sig_c.sum(axis=1), label='C', color='C2', width=0.4, edgecolor='none')
# Fix the tick positions first and label them afterwards: labelling the automatic
# ticks and moving them later ties the labels to positions that no longer exist.
ax.set_xticks(np.arange(1,1+7))
ax.set_xticklabels(['6', '8', '10', '12', '14', '17', '21'])
ax.grid(False, axis='x')

ax.set_ylim(ymin=0)
ax.legend()
ax.set_ylabel('number of gene instances')
ax.set_xlabel('time (P)')
sns.despine(ax=ax)
plt.show()

In [None]:
# total number of (time point, gene) pairs called A
cond_sig_a.sum()

In [None]:
# total number of (time point, gene) pairs called C
cond_sig_c.sum()

In [None]:
# Number of genes that are significant at exactly k time points, k = 1 .. n_conds.
# bincount with `minlength` keeps a slot for every k even when no gene falls into it;
# slicing the output of np.unique would drop such slots and shift (or break) the bars.
n_conds = cond_sig_a.shape[0]
ks = np.arange(1, 1 + n_conds)
hist_a = np.bincount(cond_sig_a.sum(axis=0), minlength=n_conds + 1)[1:]
hist_c = np.bincount(cond_sig_c.sum(axis=0), minlength=n_conds + 1)[1:]

fig, ax = plt.subplots(1,1,figsize=(4,4))
ax.bar(ks-0.2, hist_a, label='A', color='C0', width=0.4, edgecolor='none')
ax.bar(ks+0.2, hist_c, label='C', color='C2', width=0.4, edgecolor='none')
ax.set_xticks(ks)
ax.grid(False, axis='x')

ax.set_ylim(ymin=0)
ax.legend()
ax.set_ylabel('number of genes')
ax.set_xlabel('number of time points')
sns.despine(ax=ax)
plt.show()

In [None]:
# Same histogram, grouped into genes significant at 1-3, 4-6 and all 7 time points.
# hist[k] = number of genes significant at exactly k time points; minlength=8 guarantees
# that indices 0..7 exist even when some k never occurs (np.unique counts would not).
hist_a = np.bincount(cond_sig_a.sum(axis=0), minlength=8)
hist_c = np.bincount(cond_sig_c.sum(axis=0), minlength=8)
grouped_a = [hist_a[1:4].sum(), hist_a[4:7].sum(), hist_a[7]]
grouped_c = [hist_c[1:4].sum(), hist_c[4:7].sum(), hist_c[7]]

fig, ax = plt.subplots(1,1,figsize=(4,4))
ax.bar(np.arange(3)-0.2, grouped_a, label='A', color='C0', width=0.4, edgecolor='none')
ax.bar(np.arange(3)+0.2, grouped_c, label='C', color='C2', width=0.4, edgecolor='none')
ax.set_xticks(np.arange(3))
ax.set_xticklabels(['1~3', '4~6', '7 (all)'])
ax.grid(False, axis='x')

ax.set_ylim(ymin=0)
ax.legend()
ax.set_ylabel('number of genes')
ax.set_xlabel('number of time points')
sns.despine(ax=ax)
plt.show()

In [None]:
# Total number of genes called A / C at one or more time points.
n_genes_a = np.any(cond_sig_a, axis=0).sum()
n_genes_c = np.any(cond_sig_c, axis=0).sum()

fig, ax = plt.subplots(1,1,figsize=(4,4))
ax.bar(np.arange(1)-0.2, [n_genes_a], label='A', color='C0', width=0.4, edgecolor='none')
ax.bar(np.arange(1)+0.2, [n_genes_c], label='C', color='C2', width=0.4, edgecolor='none')
# A single A/C pair has no x positions worth labelling. The previous ticks at 0, 1, 2
# were left over from the three-group plot and stretched the axis past the bars.
ax.set_xticks([])
ax.grid(False, axis='x')

ax.set_ylim(ymin=0)
ax.legend()
ax.set_ylabel('number of genes')
sns.despine(ax=ax)
plt.show()

In [None]:
# %%time
# fout1 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_qs_avc_p6to21.txt'
# fout2 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_l2fc_avc_p6to21.txt'

# np.savetxt(fout1, qs_avc)
# np.savetxt(fout2, l2fc_avc)

# Volcano

In [None]:
def show_volcano_v2(thetypeidx, thetype, lfc, qs,
                    cond1, cond2up, cond2dn, 
                    querygenes_idx=None, 
                    gene_annots=None,
                    ax=None, bbox_to_anchor=(1,1), loc=None,
                    clip_p=None,
                   ): 
    """Draw a volcano plot for one column of per-gene statistics.

    Parameters
    ----------
    thetypeidx : int
        Column of `lfc`, `qs`, `cond1`, `cond2up` and `cond2dn` to plot.
    thetype : str
        Panel title.
    lfc, qs : 2D arrays indexed [gene, column]
        Effect sizes (x axis) and adjusted p-values (y axis shows -log10).
    cond1, cond2up, cond2dn : 2D boolean arrays, same indexing
        A gene is highlighted as "up" where cond1 and cond2up both hold,
        and as "down" where cond1 and cond2dn both hold.
    querygenes_idx : sequence of int, optional
        Row indices of genes to circle.
    gene_annots : sequence, optional
        Labels indexed by row; when given, the circled genes are labelled with it.
    ax : matplotlib Axes, optional
        Axes to draw into; a new figure is created when omitted.
    bbox_to_anchor, loc
        Not used; accepted so that existing calls keep working.
    clip_p : float, optional
        Upper cap applied to -log10(q).

    Returns
    -------
    None
    """
    eff = lfc[:,thetypeidx]
    # q == 0 maps to +inf; that is expected (and capped by `clip_p`), so silence the
    # divide-by-zero warning instead of flooding the cell output
    with np.errstate(divide='ignore'):
        pvl = -np.log10(qs[:,thetypeidx])
    if clip_p is not None:
        pvl = np.clip(pvl, None, clip_p)
    cnd_up = np.logical_and(cond1[:,thetypeidx], cond2up[:,thetypeidx])
    cnd_dn = np.logical_and(cond1[:,thetypeidx], cond2dn[:,thetypeidx])

    if ax is None: 
        fig, ax = plt.subplots()
    
    # all genes
    ax.scatter(eff, pvl, s=1, color='lightgray', rasterized=True)

    # up genes
    ax.scatter(eff[cnd_up], pvl[cnd_up], s=3, color='C0', rasterized=True)
    # dn genes
    ax.scatter(eff[cnd_dn], pvl[cnd_dn], s=3, color='C1', rasterized=True)
    
    # query genes: circle them, and label them when labels were supplied
    if querygenes_idx is not None:
        ax.scatter(eff[querygenes_idx], pvl[querygenes_idx], s=15, 
                   facecolors='none', edgecolors='k', linewidth=1, rasterized=True)
        if gene_annots is not None:
            for idx in querygenes_idx:
                ax.text(eff[idx], pvl[idx], gene_annots[idx], fontsize=10)

    sns.despine(ax=ax)
    ax.set_xlabel('log2(FC) (C/A in CP10k)')
    ax.set_ylabel('-log10(adj. p)')
    ax.set_title(f'{thetype}')
    # number of highlighted genes, bottom-right corner in axes coordinates
    ax.text(1,0.1,
            f'up (n={cnd_up.sum():,})\ndown (n={cnd_dn.sum():,})', 
            ha='right',
            fontsize=10, transform=ax.transAxes)
    return 

In [None]:
# Integer time-point codes present in the data; used below both as column index into the
# statistics and as index into `unq_conds`.
# NOTE(review): .unique() keeps order of first appearance, so this is not necessarily sorted.
unq_condidx = adata.obs['cond_order'].unique()
unq_condidx, unq_conds

In [None]:
# Inputs for the volcano plots. The statistics are transposed so that rows index genes
# and columns index time points, which is the layout show_volcano_v2 slices ([:, idx]).
genes_comm = adata.var.index.values
lfc = l2fc_avc.T #adata.uns['l2fc_avc'].T
qs  = qs_avc.T # adata.uns['qs_avc'].T
# same threshold values as l2fc_th / alpha_th above
lfc_th, qs_th = 1, 0.05

In [None]:
# Boolean (gene, time point) masks passed to show_volcano_v2:
# significant, log2FC above the threshold, log2FC below the negative threshold.
cond1   = qs  <  qs_th
cond2up = lfc >  lfc_th
cond2dn = lfc < -lfc_th

In [None]:
# Columns of `lfc` / `qs` are time points (one A-vs-C comparison each), not cell types,
# so title the panel with the time point that column 0 belongs to, not 'L2/3_A'.
thetypeidx = 0 # first time point
thetype = unq_conds[thetypeidx]
show_volcano_v2(thetypeidx, thetype, lfc, qs, cond1, cond2up, cond2dn, clip_p=300) #  typegenes_idx)
plt.show()

In [None]:
# one volcano panel per time point, side by side
n = len(unq_conds)
fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True)
panels = axs.flat
for cond_idx in unq_condidx:
    # column `cond_idx` of the statistics is drawn in the panel at the same position
    ax, thecond = panels[cond_idx], unq_conds[cond_idx]
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, 
                    ax=ax, clip_p=300,
                    # typegenes_idx, 
                    bbox_to_anchor=(0.5, -0.3), loc='upper center')
    sns.despine(ax=ax)
fig.tight_layout()

# output = os.path.join(outfigdir, "volcano.pdf")
# powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Volcano plots with a few hand-picked genes circled and labelled.
querygenes = ['Meis2','Foxp1','Cdh13','Cdh12']
# NOTE(review): presumably returns the row index of each query gene in `genes_comm`;
# how genes missing from `genes_comm` are reported is not visible here -- verify.
querygenes_idx = basicu.get_index_from_array(genes_comm, querygenes) 
gene_annots = genes_comm

n = len(unq_conds)
fig, axs = plt.subplots(2,4,figsize=(4*4,4*2), sharex=True, sharey=True)
for cond_idx in unq_condidx:
    ax = axs.flat[cond_idx]
    thecond = unq_conds[cond_idx]
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, 
                    clip_p=300,
                    querygenes_idx=querygenes_idx, 
                    gene_annots=gene_annots,
                    ax=ax, bbox_to_anchor=(0.5, -0.3), loc='upper center')
    sns.despine(ax=ax)

# The grid has more cells than there are time points: blank the unused cells and give
# the panel directly above each of them its x tick labels back (sharex hides them).
n_cols = axs.shape[1]
for k in range(n, axs.size):
    axs.flat[k].axis('off')
    if k >= n_cols:
        axs.flat[k - n_cols].tick_params(labelbottom=True)
fig.tight_layout()
plt.show()

In [None]:
# Volcano plots with transcription factors circled and labelled: every gene in the 'tf'
# annotation whose |log2FC| exceeds the threshold at one or more time points.
# NOTE(review): despite the name, `genes_sig` is not filtered by adjusted p-value -- confirm this is intended.
genes_sig = genes_comm[np.any(np.abs(lfc) > lfc_th, axis=1)]
querygenes = np.intersect1d(gene_modules.annots['tf'], genes_sig) #['Meis2','Foxp1','Cdh13','Cdh12']
querygenes_idx = basicu.get_index_from_array(genes_comm, querygenes) 
gene_annots = genes_comm

n = len(unq_conds)
fig, axs = plt.subplots(2,4,figsize=(4*4,4*2), sharex=True, sharey=True)
for cond_idx in unq_condidx:
    ax = axs.flat[cond_idx]
    thecond = unq_conds[cond_idx]
    show_volcano_v2(cond_idx, thecond, lfc, qs, cond1, cond2up, cond2dn, 
                    clip_p=300,
                    querygenes_idx=querygenes_idx, 
                    gene_annots=gene_annots,
                    ax=ax, bbox_to_anchor=(0.5, -0.3), loc='upper center')
    sns.despine(ax=ax)

# The grid has more cells than there are time points: blank the unused cells and give
# the panel directly above each of them its x tick labels back (sharex hides them).
n_cols = axs.shape[1]
for k in range(n, axs.size):
    axs.flat[k].axis('off')
    if k >= n_cols:
        axs.flat[k - n_cols].tick_params(labelbottom=True)
fig.tight_layout()
plt.show()

# A vs C genes

In [None]:
# time-point labels, in the order used for the rows of qs_avc / l2fc_avc
unq_conds

In [None]:
# genes = adata_rna.var.index.values
# sorted A / C DEG names at every time point, in the order of `unq_conds`
degs_a, degs_c = [], []
for i, t in enumerate(unq_conds):
    # adjusted p < 0.05 and |log2FC| > 1; the sign of the fold change decides A vs C
    passes_q = qs_avc[i] < 0.05
    cond_a = passes_q & (l2fc_avc[i] < -1)
    cond_c = passes_q & (l2fc_avc[i] >  1)
    deg_a, deg_c = np.sort(genes[cond_a]), np.sort(genes[cond_c])
    degs_a.append(deg_a)
    degs_c.append(deg_c)
    
    # time point, then the styled A list and the styled C list, each closed by an ANSI reset
    print(f"{t}")
    for deg in (deg_a, deg_c):
        _, _, gs = gene_modules.check_genes(deg)
        print("\t".join(gs))
        print("\033[0;m ---")
    

In [None]:
# organize it into a heatmap
# every gene called A (resp. C) at one or more time points
degs_a_union, degs_c_union = (np.unique(np.concatenate(per_cond)) for per_cond in (degs_a, degs_c))

a_annos, a_colors, a_styled = gene_modules.check_genes(degs_a_union)
c_annos, c_colors, c_styled = gene_modules.check_genes(degs_c_union)

# gene -> label color, used to recolor the heatmap tick labels
colordict_a = dict(zip(degs_a_union, a_colors))
colordict_c = dict(zip(degs_c_union, c_colors))

degs_a_union.shape, degs_c_union.shape, np.intersect1d(degs_a_union, degs_c_union)

In [None]:
# the check_genes colors double as category tags here: 'blue' -> *_tf, 'red' -> *_cam
a_color_arr, c_color_arr = np.array(a_colors), np.array(c_colors)

degs_a_union_tf, degs_a_union_cam = (degs_a_union[a_color_arr == tag] for tag in ('blue', 'red'))
degs_c_union_tf, degs_c_union_cam = (degs_c_union[c_color_arr == tag] for tag in ('blue', 'red'))

In [None]:
# genes tagged 'blue' among the A and the C DEGs
degs_a_union_tf, degs_c_union_tf

In [None]:
# Binary membership matrices: rows are the genes of the union, columns the time points;
# an entry is 1 where the gene is a DEG at that time point.
# The number of columns follows the per-time-point lists instead of a hard-coded 7.
mat_a = np.zeros((len(degs_a_union), len(degs_a)), dtype=int)
for i, deg_a in enumerate(degs_a):
    # presumably the positions of `deg_a` within `degs_a_union` -- verify against basicu
    indices = basicu.get_index_from_array(degs_a_union, deg_a) 
    mat_a[indices,i] = 1
mat_a = pd.DataFrame(mat_a, index=degs_a_union)
    
mat_c = np.zeros((len(degs_c_union), len(degs_c)), dtype=int)
for i, deg_c in enumerate(degs_c):
    indices = basicu.get_index_from_array(degs_c_union, deg_c) 
    mat_c[indices,i] = 1
mat_c = pd.DataFrame(mat_c, index=degs_c_union)
    

In [None]:
# column indices (time points) at which each A gene is called
hits = [np.nonzero(row)[0] for row in mat_a.values]
num_times = mat_a.values.sum(axis=1)                # how many time points
fst_times = np.array([idx.min() for idx in hits])   # first time point
avg_times = np.array([idx.mean() for idx in hits])  # mean time-point index

# row order for the heatmap: persistence bin first, then onset, then mean time
df_a = pd.DataFrame({
    'gene': degs_a_union,
    # bins: 1-3 time points -> 0, 4-6 -> 1, 7 -> 2
    'num_times': pd.cut(num_times, bins=[0,3.5,6.5,7.5], labels=[0,1,2]),
    'fst_times': fst_times,
    'avg_times': avg_times,
}).sort_values(['num_times', 'fst_times', 'avg_times'])
df_a

In [None]:
# column indices (time points) at which each C gene is called
hits = [np.nonzero(row)[0] for row in mat_c.values]
num_times = mat_c.values.sum(axis=1)                # how many time points
fst_times = np.array([idx.min() for idx in hits])   # first time point
avg_times = np.array([idx.mean() for idx in hits])  # mean time-point index

# row order for the heatmap: persistence bin first, then onset, then mean time
df_c = pd.DataFrame({
    'gene': degs_c_union,
    # bins: 1-3 time points -> 0, 4-6 -> 1, 7 -> 2
    'num_times': pd.cut(num_times, bins=[0,3.5,6.5,7.5], labels=[0,1,2]),
    'fst_times': fst_times,
    'avg_times': avg_times,
}).sort_values(['num_times', 'fst_times', 'avg_times'])
df_c

In [None]:
def recolor_yticks(ax, colordict, fontsize=None):
    """Color, and optionally resize, the y tick labels of `ax`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose y tick labels are modified in place.
    colordict : dict
        Maps tick label text to a color; labels not in the dict keep their color.
    fontsize : optional
        Font size applied to every y tick label. With the default `None` the
        current size is left untouched.

    Returns
    -------
    The same `ax`, to allow chaining.
    """
    for tick in ax.get_yticklabels():
        # only touch the size when one was requested; set_fontsize(None) would
        # reset the label to the default font size rather than leave it alone
        if fontsize is not None:
            tick.set_fontsize(fontsize)

        text = tick.get_text()
        if text in colordict:
            tick.set_color(colordict[text])

    return ax

In [None]:
# subset to TFs
# subset to CAMs

In [None]:
with sns.axes_style('whitegrid'): # , {'ticks': 'off'}):
    # rows in display order
    mat_a_show = mat_a.loc[df_a['gene'].values]
    mat_c_show = mat_c.loc[df_c['gene'].values]

    # One mosaic row per gene, so that a gene row has the same height in both panels.
    # The shorter panel is padded with '.', subplot_mosaic's "leave empty" marker, which
    # works whichever panel is longer and also when both have the same length.
    n_rows = max(len(mat_a_show), len(mat_c_show))
    mosaic = np.vstack([['A']*len(mat_a_show) + ['.']*(n_rows-len(mat_a_show)),
                        ['B']*len(mat_c_show) + ['.']*(n_rows-len(mat_c_show))]).T
    fig, axdict = plt.subplot_mosaic(mosaic, figsize=(4,10))

    for key, title, mat, colordict in [('A', 'A genes', mat_a_show, colordict_a),
                                       ('B', 'C genes', mat_c_show, colordict_c)]:
        if key not in axdict:  # nothing to draw for an empty gene list
            continue
        ax = axdict[key]
        ax.set_title(title, fontsize=10)
        sns.heatmap(mat, 
                    cbar=False, #dict(shrink=0.5), 
                    xticklabels=False,
                    yticklabels=True, 
                    cmap='gray_r', ax=ax,)

        # color the gene names by annotation
        recolor_yticks(ax, colordict, fontsize=6)
        # cell borders
        m, n = mat.shape
        ax.hlines(np.arange(m), 0, n, linewidth=0.5, color='gray') 
        ax.vlines(np.arange(n), 0, m, linewidth=0.5, color='gray') 

    fig.subplots_adjust(wspace=1)
    plt.show()

In [None]:
# display order of the genes, as sorted in df_a / df_c
order_a, order_c = df_a['gene'].values, df_c['gene'].values

# category members as sets for the membership tests below
tf_a,  tf_c  = set(degs_a_union_tf),  set(degs_c_union_tf)
cam_a, cam_c = set(degs_a_union_cam), set(degs_c_union_cam)

# rows of the membership matrices restricted to each category, in display order
mat_a_tf  = mat_a.loc[[g for g in order_a if g in tf_a]]
mat_c_tf  = mat_c.loc[[g for g in order_c if g in tf_c]]

mat_a_cam = mat_a.loc[[g for g in order_a if g in cam_a]]
mat_c_cam = mat_c.loc[[g for g in order_c if g in cam_c]]

In [None]:
with sns.axes_style('whitegrid'): # , {'ticks': 'off'}):
    # One mosaic row per gene, so that a gene row has the same height in both panels.
    # The shorter panel is padded with '.', subplot_mosaic's "leave empty" marker, which
    # works whichever panel is longer and also when both have the same length.
    n_rows = max(len(mat_a_tf), len(mat_c_tf))
    mosaic = np.vstack([['A']*len(mat_a_tf) + ['.']*(n_rows-len(mat_a_tf)),
                        ['B']*len(mat_c_tf) + ['.']*(n_rows-len(mat_c_tf))]).T
    fig, axdict = plt.subplot_mosaic(mosaic, figsize=(4,3))

    for key, title, mat, colordict in [('A', 'A TFs', mat_a_tf, colordict_a),
                                       ('B', 'C TFs', mat_c_tf, colordict_c)]:
        if key not in axdict:  # nothing to draw for an empty gene list
            continue
        ax = axdict[key]
        ax.set_title(title, fontsize=10)
        sns.heatmap(mat, 
                    cbar=False, #dict(shrink=0.5), 
                    xticklabels=False,
                    yticklabels=True, 
                    cmap='gray_r', ax=ax,)

        # color the gene names by annotation
        recolor_yticks(ax, colordict, fontsize=6)
        # cell borders
        m, n = mat.shape
        ax.hlines(np.arange(m), 0, n, linewidth=0.5, color='gray') 
        ax.vlines(np.arange(n), 0, m, linewidth=0.5, color='gray') 

    fig.subplots_adjust(wspace=1)
    plt.show()

In [None]:
with sns.axes_style('whitegrid'): # , {'ticks': 'off'}):
    # One mosaic row per gene, so that a gene row has the same height in both panels.
    # The shorter panel is padded with '.', subplot_mosaic's "leave empty" marker, which
    # works whichever panel is longer and also when both have the same length.
    n_rows = max(len(mat_a_cam), len(mat_c_cam))
    mosaic = np.vstack([['A']*len(mat_a_cam) + ['.']*(n_rows-len(mat_a_cam)),
                        ['B']*len(mat_c_cam) + ['.']*(n_rows-len(mat_c_cam))]).T
    fig, axdict = plt.subplot_mosaic(mosaic, figsize=(4,5))

    for key, title, mat, colordict in [('A', 'A CAMs', mat_a_cam, colordict_a),
                                       ('B', 'C CAMs', mat_c_cam, colordict_c)]:
        if key not in axdict:  # nothing to draw for an empty gene list
            continue
        ax = axdict[key]
        ax.set_title(title, fontsize=10)
        sns.heatmap(mat, 
                    cbar=False, #dict(shrink=0.5), 
                    xticklabels=False,
                    yticklabels=True, 
                    cmap='gray_r', ax=ax,)

        # color the gene names by annotation
        recolor_yticks(ax, colordict, fontsize=6)
        # cell borders
        m, n = mat.shape
        ax.hlines(np.arange(m), 0, n, linewidth=0.5, color='gray') 
        ax.vlines(np.arange(n), 0, m, linewidth=0.5, color='gray') 

    fig.subplots_adjust(wspace=1)
    plt.show()