In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu


In [None]:
# Output directory for saved figures (absolute cluster path).
outdir_fig = os.path.join("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1", "figures")

# Load gene annotations and data

In [None]:
# AC genes
# AC genes: the first `n_a_rows` rows of the table hold A genes, the rest C genes.
f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/Saumya_P6-21_AC_genes.csv'
df_ac = pd.read_csv(f)  # separate name: `df` is reused below for a different table

n_a_rows = 25
df_a = df_ac.iloc[:n_a_rows]
df_c = df_ac.iloc[n_a_rows:]

# Drop missing cells before np.unique. If the columns are ragged (NaN padding),
# np.unique would have to sort a mix of str and float and would raise.
alltime_a = np.unique(pd.Series(df_a.values.ravel()).dropna().values)
alltime_c = np.unique(pd.Series(df_c.values.ravel()).dropna().values)
alltime_ac = np.hstack([alltime_a, alltime_c])  # genes in ac_overlap appear twice

ac_overlap = np.intersect1d(alltime_a, alltime_c)

print(df_a.shape, df_c.shape, alltime_a.shape, alltime_c.shape, alltime_ac.shape, ac_overlap.shape)
df_ac.head()

In [None]:
# Load the 286 L2/3 type genes and their P17-onward A/B/C group ('P17on' column).
# previous list: "../../data/cheng21_cell_scrna/res/candidate_genes_vincent_0503_v2.csv"
df = pd.read_csv("../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv")
genes_l23 = df['gene'].astype(str).values
genes_grp = df['P17on'].astype(str).values
genes_l23a, genes_l23b, genes_l23c = (genes_l23[genes_grp == grp] for grp in ('A', 'B', 'C'))

print(genes_l23a.shape, genes_l23b.shape, genes_l23c.shape)
assert len(genes_l23) == len(np.unique(genes_l23))

genes_l23.shape

In [None]:
# Raw multiome L2/3 object. `anndata.read` is a deprecated alias; `read_h5ad`
# is the supported reader for .h5ad files.
adata = anndata.read_h5ad("../../data/v1_multiome/L23_allmultiome_raw.h5ad")
adata

In [None]:
# Derive per-cell sample and time labels from the cell IDs:
#   sample = first space-separated token, '-'-fields from index 2 on, without '-2023'
#   time   = sample minus its last character (replicate tag) and any 'DR' marker
sample_labels, time_labels = [], []
for cell_id in adata.obs.index:
    id_fields = cell_id.split(' ')[0].split('-')
    sample = '-'.join(id_fields[2:]).replace('-2023', '')
    sample_labels.append(sample)
    time_labels.append(sample[:-1].replace('DR', ''))

adata.obs['n_counts'] = adata.obs['nCount_RNA']
adata.obs['sample'] = sample_labels
adata.obs['time'] = time_labels

uniq_samples = natsorted(np.unique(sample_labels))
uniq_times = natsorted(np.unique(time_labels))

# split samples into normally reared (NR) and DR groups by the 'DR' marker
nr_samples, dr_samples = [], []
for s in uniq_samples:
    if "DR" in s:
        dr_samples.append(s)
    else:
        nr_samples.append(s)
print(uniq_times)
print(nr_samples)
print(dr_samples)

In [None]:
# Keep non-DR samples at the developmental time points analysed below.
timepoints = ['P6', 'P8', 'P10', 'P12', 'P14', 'P17', 'P21']
keep = adata.obs['sample'].isin(nr_samples) & adata.obs['time'].isin(timepoints)
# .copy() materialises the subset. Writing obs columns on an AnnData *view*
# would otherwise trigger an implicit copy with a warning.
adata = adata[keep.values].copy()
adata.obs['cond'] = adata.obs['time']

# remove mitochondrial genes
adata = adata[:, ~adata.var['features'].str.contains(r'^mt-')].copy()

adata.obs['sample'].unique(), adata.obs['cond'].unique()

In [None]:
# Per-cell metadata arrays used when building the results table.
conds, types, samps = (adata.obs[col].values for col in ('cond', 'Type', 'sample'))

# Mapping from legacy / intermediate type names onto the A/B/C scheme.
rename = {
    **{f"L2/3_{t}": f"L2/3_{t}" for t in "ABC"},
    **{f"L2/3_{num}": f"L2/3_{t}" for num, t in zip("123", "ABC")},
    "L2/3_AB": "L2/3_A",
    "L2/3_BC": "L2/3_C",
}
# NOTE: `rename` is not applied; 'easitype' is a plain copy of 'Type'.
adata.obs['easitype'] = adata.obs['Type']
adata

In [None]:
# Distinct cell-type labels remaining after sample/gene filtering.
adata.obs.loc[:, 'Type'].unique()

In [None]:
# Filter genes: keep those with a nonzero count in more than `min_cells` cells.
min_cells = 10
cond = np.ravel((adata.X > 0).sum(axis=0)) > min_cells
adata = adata[:, cond].copy()
genes = adata.var.index.values

# Counts as CSR so the `.data` transforms below work whether X is stored sparse or dense.
x = sparse.csr_matrix(adata.X)
# Library size from the precomputed nCount_RNA column. Presumably this counts all
# genes, including those removed above. NOTE(review): confirm this is intended.
cov = adata.obs['n_counts'].values

# CP10k
xn = sparse.diags(1 / cov).dot(x) * 1e4

# log10(CP10k + 1); only stored nonzeros need transforming since log10(0 + 1) == 0
xln = xn.copy()
xln.data = np.log10(xln.data + 1)

# Densify the log matrix once and reuse it for both dense layers.
xln_dense = xln.toarray()
adata.layers['norm'] = xn.toarray()
adata.layers['lognorm'] = xln_dense
adata.layers['zlognorm'] = zscore(xln_dense, axis=0)  # per-gene z-score across cells

In [None]:
# Select highly variable genes (HVGs). Bin genes into `nbin` quantiles of mean
# expression, then keep the top `qth` fraction of each bin by variance/mean.
nbin = 20
qth = 0.3

# mean (CP10k) per gene
gm = np.ravel(xn.mean(axis=0))

# variance per gene as E[x^2] - E[x]^2 (squaring only stored nonzeros is exact)
tmp = xn.copy()
tmp.data = np.power(tmp.data, 2)
gv = np.ravel(tmp.mean(axis=0)) - gm**2

# quantile bin of the mean
lbl = pd.qcut(gm, nbin, labels=np.arange(nbin))
gres = pd.DataFrame({
    'name': genes,
    'lbl': lbl,
    'mean': gm,
    'var': gv,
    'ratio': gv / gm,
})

# Top genes per bin. Level 1 of the resulting MultiIndex holds row positions in `gres`.
n_per_bin = int(qth * (len(gm) / nbin))
gres_sel = gres.groupby('lbl', observed=True)['ratio'].nlargest(n_per_bin)
gsel_idx = np.sort(gres_sel.index.get_level_values(1).values)
assert len(gsel_idx) > 0 and len(np.unique(gsel_idx)) == len(gsel_idx)


In [None]:
# gsel_idx = np.union1d(gsel_idx, gaba_idx)

# Positions of the L2/3 type genes within `genes`. Entries equal to -1 (not
# found after gene filtering) are reported and then dropped.
l23_gidx = basicu.get_index_from_array(genes, genes_l23)
n_missing = int(np.sum(l23_gidx == -1))
if n_missing:
    print(f"{n_missing} of {len(genes_l23)} L2/3 genes not found after gene filtering; dropped")
l23_gidx = l23_gidx[l23_gidx != -1]
assert len(l23_gidx) > 0, "none of the L2/3 genes are present"

In [None]:
# Dispersion (var/mean) vs mean for all genes. The right panel overlays the
# selected HVGs, coloured by mean-expression bin.
disp_ratio = gv / gm

fig, axs = plt.subplots(1, 2, figsize=(6*2, 4), sharex=True, sharey=True)
for panel, ax in enumerate(axs):
    ax.scatter(gm, disp_ratio, s=5, edgecolor='none', color='gray')
    if panel == 1:
        ax.scatter(gm[gsel_idx], disp_ratio[gsel_idx], c=lbl[gsel_idx], s=5, edgecolor='none',
                   cmap='viridis_r', label=f'{len(gsel_idx):,} HVGs')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('mean expr (CP10k)')
    ax.set_ylabel('var/mean')
    if panel == 0:
        ax.set_title(f'{len(gm):,} genes')
    else:
        ax.legend(bbox_to_anchor=(1, 1))

plt.show()

In [None]:
# Dense copies of the log-normalised and z-scored layers over all kept genes.
lognorm_full, zlognorm_full = (np.array(adata.layers[key]) for key in ('lognorm', 'zlognorm'))
zlognorm_full.shape

In [None]:
# Restrict to the L2/3 type genes present in the data.
adata_sub = adata[:, l23_gidx]
genes_sel = adata_sub.var.index.values

lognorm, zlognorm = (np.array(adata_sub.layers[key]) for key in ('lognorm', 'zlognorm'))
zlognorm.shape

In [None]:
# Two PCA bases on the L2/3 genes: one fit on P8 cells ("early"), one fit on
# P17+ cells ("late"). Both project *all* cells (zlognorm).
adata_p8 = adata_sub[adata_sub.obs['cond'].isin(['P8'])]
lognorm_p8 = np.array(adata_p8.layers['lognorm'])
zlognorm_p8 = np.array(adata_p8.layers['zlognorm'])

# random_state pins the result when sklearn's 'auto' solver picks randomized SVD.
pca = PCA(n_components=10, random_state=0)
pca.fit(zlognorm_p8)
pcs_p8 = pca.transform(zlognorm)

# PC signs are arbitrary; flip to the orientation used for plotting.
pcs_p8[:, 0] = -pcs_p8[:, 0]  # flip PC1
pcs_p8[:, 2] = -pcs_p8[:, 2]  # flip PC3

# 'P28' matches no cells: the time-point filter keeps only P6-P21.
adata_p17on = adata_sub[adata_sub.obs['cond'].isin(['P17', 'P21', 'P28'])]
lognorm_p17on = np.array(adata_p17on.layers['lognorm'])
zlognorm_p17on = np.array(adata_p17on.layers['zlognorm'])

pca2 = PCA(n_components=10, random_state=0)
pca2.fit(zlognorm_p17on)
pcs_p17on = pca2.transform(zlognorm)

pcs_p17on[:, 1] = -pcs_p17on[:, 1]  # flip PC2

In [None]:
# Genes ranked by |loading| (largest first) on each of the first three P8 PCs.
vt = pca.components_
top_by_pc = [genes_sel[np.argsort(np.abs(loadings))[::-1]] for loadings in vt[:3]]
topgenes_p8_pc1, topgenes_p8_pc2, topgenes_p8_pc3 = top_by_pc
for ranked in top_by_pc:
    print(ranked[:10])

In [None]:
# Number of distinct selected L2/3 genes (compare with len(genes_sel)).
# A preceding PC2-ranking expression here was a no-op: only a cell's last
# expression is displayed.
np.unique(genes_sel).shape

In [None]:
# Disabled: earlier data-driven orientation of PC1 (first type group's mean below
# the last's). The PCA cell now applies fixed sign flips instead.
# # fix pc1 to make sure a < c:
# pc1 = pcs[:,0]
# pc_types, unq_types = basicu.group_mean(pc1.reshape(-1,1), types)
# a = pc_types[0,0]
# c = pc_types[-1,0]
# if a > c:
#     pcs[:,0] = -pcs[:,0]

In [None]:
# Per-cell results table with three parts:
#   res0: z-scored log expression of every gene
#   res1: P8-PCA coordinates + metadata
#   res2: P17+-PCA coordinates
res1 = pd.DataFrame(pcs_p8, columns=[f"p8PC{k}" for k in range(1, pcs_p8.shape[1] + 1)])
res1 = res1.assign(cond=conds, type=types, samp=samps,
                   rep=[sample_name[-1] for sample_name in samps])  # last char = replicate tag

res0 = pd.DataFrame(zlognorm_full, columns=genes)
res2 = pd.DataFrame(pcs_p17on, columns=[f"p17onPC{k}" for k in range(1, pcs_p17on.shape[1] + 1)])
res = pd.concat([res0, res1, res2], axis=1)
# Prefix type labels with 'c' so they match the keys of palette_types.
res['type'] = np.char.add('c', res['type'].values.astype(str))

In [None]:
# Explained-variance ratio of both PCA fits. Dashed line = 1/n_genes, the
# share each PC would get if variance were spread evenly over the z-scored genes.
fig, ax = plt.subplots(figsize=(6,4))
for model, name in ((pca, 'fit on P8'), (pca2, 'fit on P17+')):
    ratios = model.explained_variance_ratio_
    ax.plot(np.arange(1, len(ratios) + 1), ratios, '-o', markersize=5, label=name)
ax.axhline(1/lognorm.shape[1], linestyle='--', color='gray', label='1 / n genes')
ax.set_xlabel('PC')
ax.set_ylabel('explained var')
ax.legend()

In [None]:
# 20-colour palette: tab20c groups 4 shades for each of 5 hues.
allcolors = sns.color_palette(palette='tab20c', n_colors=20)
allcolors

In [None]:
# tab10 has 10 colours, so requesting 20 repeats the cycle.
allcolors2 = sns.color_palette(palette='tab10', n_colors=20)
allcolors2

In [None]:
# Colour per time point: one tab20c hue group per developmental window,
# lighter -> darker shade with age.
palette = collections.OrderedDict({
     "P6": allcolors[2],
     "P8": allcolors[1],
    "P10": allcolors[0],
    "P12": allcolors[4+2],
    "P14": allcolors[4+0],

    "P17": allcolors[8+2],
    "P21": allcolors[8+0],
})

# Colour per cell type, keyed by the 'c'-prefixed Type labels in res['type'].
# (A name-keyed 'L2/3_A' ... version that was overwritten immediately was dead code.)
palette_types = {
    'c14': 'C0',
    'c18': 'C1',
    'c16': 'C2',

    'c13': 'C0',
    'c15': 'C1',
    'c17': 'C2',
}
type_order = list(palette_types)

# time points plotted as panels, in palette order
cases = np.array(list(palette))

In [None]:
# Standalone legend figure: one marker + label per time point, earliest at the top.
fig, ax = plt.subplots(figsize=(1,4))
n_cases = len(palette)
for rank, (key, item) in enumerate(palette.items()):
    ypos = n_cases - rank
    ax.plot(0, ypos, 'o', c=item)
    ax.text(0.02, ypos, key, va='center', fontsize=15)
ax.axis('off')
plt.show()

In [None]:
from scipy import stats
from matplotlib.ticker import MaxNLocator


def plot(x, y, aspect_equal=False, density=False, hue='type', random_state=0):
    """Scatter `y` vs `x` from the global `res`, one panel per time point in `cases`.

    Every panel draws all cells in light gray as background. The cells of that
    panel's time point are overlaid on top.

    Parameters
    ----------
    x, y : str
        Column names of `res` for the axes.
    aspect_equal : bool
        Force an equal aspect ratio on every panel.
    density : bool
        Overlay a 2-D histogram of the highlighted cells.
    hue : str
        'type' colours cells by type using `palette_types`. Any other value names
        the `res` column to colour by (e.g. 'rep').
    random_state : int or None
        Seed for shuffling the highlighted cells' drawing order. Shuffling keeps
        any one group from always painting over another; None = unseeded.
    """
    n = len(cases)  # one panel per time point, kept in sync with `cases`
    # squeeze=False keeps `axs` a 2-D array even when n == 1
    fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True, squeeze=False)
    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)
        # background: all cells
        sns.scatterplot(data=res,
                        x=x, y=y,
                        c='lightgray',
                        s=1, edgecolor='none',
                        legend=False,
                        ax=ax,
                       )
        show = res[res['cond'] == cond]
        shuffled = show.sample(frac=1, replace=False, random_state=random_state)
        if hue == 'type':
            sns.scatterplot(data=shuffled,
                            x=x, y=y,
                            hue='type',
                            hue_order=list(palette_types.keys()),
                            palette=palette_types,
                            s=3, edgecolor='none',
                            legend=False,
                            ax=ax,
                           )
        else:
            sns.scatterplot(data=shuffled,
                            x=x, y=y,
                            hue=hue,
                            s=3, edgecolor='none',
                            legend=False,
                            ax=ax,
                           )

        if density:
            sns.histplot(data=show,
                         x=x, y=y,
                         legend=False,
                         ax=ax,
                        )
        sns.despine(ax=ax)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        if i > 0:
            # axis labels only on the first panel
            ax.set_xlabel('')
            ax.set_ylabel('')
    plt.show()
    
def plot2(x, y, hue=None, aspect_equal=False, vmin=-3, vmax=3, cmap='coolwarm'):
    """Scatter `y` vs `x` from the global `res`, one panel per time point in `cases`.

    All cells are drawn as a faint gray background, and the panel's time point
    is overlaid on top.

    Parameters
    ----------
    x, y : str
        Column names of `res` for the axes.
    hue : str or None
        Column of `res` used to colour the overlaid cells on a fixed
        [vmin, vmax] scale. If None, cells get a single colour and each panel
        title reports the Spearman correlation of `x` and `y` at that time point.
    aspect_equal : bool
        Force an equal aspect ratio on every panel.
    vmin, vmax, cmap
        Colour scale used when `hue` is given; the defaults match the previous fixed values.
    """
    n = len(cases)  # one panel per time point, kept in sync with `cases`
    # squeeze=False keeps `axs` a 2-D array even when n == 1
    fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True, squeeze=False)
    fig.suptitle(hue, x=0, ha='left')
    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)
        # background: all cells
        sns.scatterplot(data=res,
                        x=x, y=y,
                        c='lightgray',
                        alpha=0.3,
                        s=1, edgecolor='none',
                        legend=False,
                        ax=ax,
                       )
        show = res[res['cond'] == cond]
        if hue:
            ax.scatter(show[x], show[y], c=show[hue],
                       cmap=cmap,
                       vmin=vmin, vmax=vmax,
                       s=5,
                       edgecolor='none',
                      )
        else:
            r, p = stats.spearmanr(show[x], show[y])
            ax.scatter(show[x], show[y],
                       s=5,
                       edgecolor='none',
                      )
            ax.set_title(f'{cond}\n r={r:.2f}')
        sns.despine(ax=ax)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        if i > 0:
            # axis labels only on the first panel
            ax.set_xlabel('')
            ax.set_ylabel('')
        ax.grid(False)
    fig.tight_layout()
    plt.show()

# Plot A and C genes with cells aligned along early (P8) and late (P17+) PCs

In [None]:
# Cells in the P17+ PC space, then in the P8 PC space (PC1 vs PC2, PC1 vs PC3).
for pc_x, pc_y in [('p17onPC1', 'p17onPC2'), ('p8PC1', 'p8PC2'), ('p8PC1', 'p8PC3')]:
    plot(pc_x, pc_y, aspect_equal=True)

In [None]:
# Marker genes on the P17+ PC1/PC2 embedding, coloured by z-scored log expression.
# (alternative set: 'Pcdh19', 'Pcdh9', 'Pcdh15')
genes_viz = 'Rfx3 Cux1 Meis2 Nfib Foxp1 Tox Rora'.split()
for g in genes_viz:
    plot2('p17onPC1', 'p17onPC2', hue=g, aspect_equal=True)

In [None]:
# Same marker genes on the P8 PC1/PC2 embedding.
for gene in genes_viz:
    plot2(x='p8PC1', y='p8PC2', hue=gene, aspect_equal=True)

# Heatmap of A and C genes across all cells, ordered by time point and then along p17onPC1

In [None]:
# Chronological rank of each time point, used to sort cells.
sample_order_dict = {
     'P6': 1,
     'P8': 2,
    'P10': 3,
    'P12': 4,
    'P14': 5,
    'P17': 6,
    'P21': 7,
}

# Map through str and cast to int. If 'cond' were categorical, mapping would keep
# a categorical whose sort follows category (lexical) order, not the rank.
sample_order = res['cond'].astype(str).map(sample_order_dict)
assert sample_order.notna().all(), "res['cond'] has a time point missing from sample_order_dict"
res['sample_order'] = sample_order.astype(int)

# Gene columns: A genes, then C genes, each gene once. Genes in both lists
# (ac_overlap) would otherwise become duplicate columns, and selecting
# ressub[alltime_a] would return each of them twice.
ac_cols = pd.unique(np.hstack([alltime_a, alltime_c]))
cols = np.hstack(['cond', 'sample_order', 'p17onPC1', ac_cols])
ressub = res[cols].sort_values(['sample_order', 'p17onPC1'])
ressub

In [None]:
# Column offset where each time point starts in `ressub` (cells are sorted by
# time point), used for the separators on the heatmaps. The order is taken
# explicitly from first appearance: value_counts(sort=False) order is an
# implementation detail (category order for categorical columns).
cond_order = pd.unique(ressub['cond'])
cell_counts = ressub['cond'].value_counts().reindex(cond_order)
num_cells_cumsum = np.cumsum(cell_counts)
num_cells_cumsum_shift = pd.Series(np.hstack([0, num_cells_cumsum.values[:-1]]), index=num_cells_cumsum.index)
num_cells_cumsum_shift

# Rank genes by expression pattern: k-means gene clusters, ordered along PC1 of the cluster centroids

In [None]:
from sklearn.cluster import KMeans
def assign_ordered_gene_cluster(mat, k, direction=1, random_state=0):
    """Cluster the rows of `mat` with k-means and relabel clusters in a reproducible order.

    Clusters are ranked by their centroid's score on the first principal
    component of the k centroids. Label 0 lies at one end of that axis and
    label k-1 at the other.

    Parameters
    ----------
    mat : array-like, shape (n_rows, n_features)
        Rows to cluster (in this notebook: genes x cells).
    k : int
        Number of clusters.
    direction : {1, -1}
        Multiplier on the PC score; -1 reverses the cluster order.
    random_state : int or None
        Seed for KMeans initialisation and PCA. It is passed to the estimators
        directly rather than by seeding numpy's global RNG as a side effect.

    Returns
    -------
    newlabels : ndarray of int, shape (n_rows,)
        Reordered cluster label per row.
    newcntrds : ndarray, shape (k, n_features)
        Centroids in the new label order (row j = centroid of label j).
    """
    # assign gene clusters, and get centroids
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=random_state)
    labels = kmeans.fit_predict(mat)
    cntrds = kmeans.cluster_centers_

    # order the centroids along their first PC
    pc1 = PCA(n_components=1, random_state=random_state).fit_transform(cntrds)[:, 0]
    cntrds_order_pca = np.argsort(direction * pc1)
    # old label cntrds_order_pca[j] becomes new label j
    old2new_labels = np.zeros(k, dtype=int)
    old2new_labels[cntrds_order_pca] = np.arange(k)

    newlabels = old2new_labels[labels]
    newcntrds = cntrds[cntrds_order_pca]

    return newlabels, newcntrds

def get_shifted_cumsum(labels):
    """Start offset of each label's run when rows are sorted by label.

    For the sorted unique values of `labels`, returns the exclusive prefix sum
    of their counts: [0, n_0, n_0 + n_1, ...] (one entry per distinct label).
    """
    _, counts = np.unique(labels, return_counts=True)
    running = np.cumsum(counts)
    return np.concatenate(([0], running[:-1]))

In [None]:
# Genes x cells matrices (cells in `ressub` order) for the A and C gene lists.
# Plain lists are used for the column selection; the old np.hstack([..., []])
# was a no-op that mixed a float array into the string array.
mat_a = ressub[list(alltime_a)].T
labels_a, cntrds_a = assign_ordered_gene_cluster(mat_a, 3, direction=-1)
cumsum_a = get_shifted_cumsum(labels_a)

mat_c = ressub[list(alltime_c)].T
labels_c, cntrds_c = assign_ordered_gene_cluster(mat_c, 3, direction=1)
cumsum_c = get_shifted_cumsum(labels_c)

# check this: cluster centroids (rows = ordered clusters, columns = cells)
for cntrds, name in ((cntrds_a, 'A'), (cntrds_c, 'C')):
    ax = sns.heatmap(cntrds, cmap='coolwarm', vmax=1, vmin=-1)
    ax.set_title(f'{name}-gene cluster centroids')
    ax.set_xlabel('cells (sorted by time point, then p17onPC1)')
    ax.set_ylabel('cluster')
    plt.show()

In [None]:
# A (top) and C (bottom) gene heatmaps. Rows are grouped by ordered cluster
# (horizontal dashed lines); columns are cells ordered by time point
# (vertical dashed lines, labelled on the top panel).
fig, axs = plt.subplots(2,1,figsize=(25,12))
panels = [
    (axs[0], mat_a, labels_a, cumsum_a, 'A genes'),
    (axs[1], mat_c, labels_c, cumsum_c, 'C genes'),
]
for row, (ax, mat, labels, gene_cumsum, ylabel) in enumerate(panels):
    matshow = mat.iloc[np.argsort(labels)]
    sns.heatmap(matshow,
                vmax=3, vmin=-3, cmap='coolwarm',
                xticklabels=False,
                yticklabels=True,
                cbar_kws=dict(shrink=0.5),
                ax=ax,
               )
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=5)
    for cond, num_cell in num_cells_cumsum_shift.items():
        if row == 0:
            ax.text(num_cell, 0, cond)
        ax.axvline(num_cell, color='k', linewidth=1, linestyle='--')
    for num in gene_cumsum:
        ax.axhline(num, color='k', linewidth=1, linestyle='--')
    ax.set_ylabel(ylabel)

fig.subplots_adjust(hspace=0.01)
plt.show()

# Annotate the heatmap rows by gene family


In [None]:
# Gene-family annotations used to colour heatmap tick labels.
annot_dir = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/annot'
annots = {}

# Transcription factors: the 'Symbol' column is used as-is.
f = os.path.join(annot_dir, 'Mus_musculus_TF.txt')
annot = pd.read_csv(f, sep='\t')
annots['tf'] = annot['Symbol'].values

# Tables with an 'Approved symbol' column (presumably upper-case HGNC-style symbols).
# Each symbol keeps its first character and lower-cases the rest. The .str
# accessor passes missing values through as NaN instead of crashing.
symbol_files = {
    'cad': 'genes_cadherins.txt',
    'igsf': 'genes_igsf.txt',
    'eph': 'genes_ephephrins.txt',
    'gpcr': 'genes_gpcr.txt',
    'channel': 'genes_ion_channels.txt',
}
for key, fname in symbol_files.items():
    f = os.path.join(annot_dir, fname)
    annot = pd.read_csv(f, sep='\t')
    symbols = annot['Approved symbol']
    annots[key] = (symbols.str[:1] + symbols.str[1:].str.lower()).values
annots


In [None]:
# # add annot
# for i, g in enumerate(matshow.index):
#     if g in annots['tf']:
#         ycolors[i] = 'blue'
#         # ax.text(0,i,g, ha='right', color='blue')


def recolor_yticks(ax, candidates, color, fontsize=None):
    """Recolour the y tick labels of `ax` whose text is in `candidates`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose y tick labels are modified in place.
    candidates : iterable of str
        Label texts to highlight.
    color : matplotlib colour spec
        Colour applied to matching labels.
    fontsize : float, str or None
        Font size for matching labels. None leaves their current size unchanged.

    Returns
    -------
    ax
        The same Axes, for chaining.
    """
    targets = set(candidates)  # O(1) membership instead of scanning an array per tick
    for tick in ax.get_yticklabels():
        if tick.get_text() in targets:
            tick.set_color(color)
            if fontsize is not None:
                tick.set_fontsize(fontsize)
    return ax


In [None]:
# A/C heatmaps with tick labels coloured by gene family. Families are applied
# in list order, so a later family wins when a gene is in several.
fig, axs = plt.subplots(2,1,figsize=(25,12))
family_colors = [
    ('tf', 'blue'),
    ('cad', 'red'),
    ('igsf', 'red'),
    ('eph', 'red'),
    ('gpcr', 'orange'),
    ('channel', 'orange'),
]
panels = [
    (axs[0], mat_a, labels_a, cumsum_a, 'A genes'),
    (axs[1], mat_c, labels_c, cumsum_c, 'C genes'),
]
for row, (ax, mat, labels, gene_cumsum, ylabel) in enumerate(panels):
    matshow = mat.iloc[np.argsort(labels)]
    sns.heatmap(matshow,
                vmax=3, vmin=-3, cmap='coolwarm',
                xticklabels=False,
                yticklabels=True,
                cbar_kws=dict(shrink=0.5),
                ax=ax,
               )
    for cond, num_cell in num_cells_cumsum_shift.items():
        if row == 0:
            ax.text(num_cell, 0, cond)
        ax.axvline(num_cell, color='k', linewidth=1, linestyle='--')
    for num in gene_cumsum:
        ax.axhline(num, color='k', linewidth=1, linestyle='--')
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='y', which='both', labelsize=5, rotation=0)
    for family, color in family_colors:
        recolor_yticks(ax, annots[family], color, fontsize=5)

fig.subplots_adjust(hspace=0.01)
plt.show()

# Subset of genes: transcription factors only

In [None]:
# Transcription factors among the C genes, in cluster order. The matrix is built
# explicitly instead of relying on whatever `matshow` the last heatmap cell left behind.
mat_c.iloc[np.argsort(labels_c)].loc[lambda m: m.index.isin(annots['tf'])]

In [None]:
# TF-only heatmaps (A genes top, C genes bottom), rows in cluster order,
# columns = cells ordered by time point.
tf_set = set(annots['tf'])  # set membership instead of scanning the TF array per gene
fig, axs = plt.subplots(2,1,figsize=(25,5))
panels = [
    (axs[0], mat_a, labels_a, 'A genes'),
    (axs[1], mat_c, labels_c, 'C genes'),
]
for row, (ax, mat, labels, ylabel) in enumerate(panels):
    matshow = mat.iloc[np.argsort(labels)]
    matshow = matshow.loc[[g for g in matshow.index if g in tf_set]]
    if matshow.empty:
        # sns.heatmap cannot draw a zero-row matrix
        ax.set_axis_off()
        ax.set_title(f'{ylabel}: no TFs')
        continue
    sns.heatmap(matshow,
                vmax=3, vmin=-3, cmap='coolwarm',
                xticklabels=False,
                yticklabels=True,
                cbar_kws=dict(shrink=0.5),
                ax=ax,
               )
    for cond, num_cell in num_cells_cumsum_shift.items():
        if row == 0:
            ax.text(num_cell, 0, cond)
        ax.axvline(num_cell, color='k', linewidth=1, linestyle='--')
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='y', which='both', labelsize=10, rotation=0)

fig.subplots_adjust(hspace=0.01)
plt.show()

# GO analysis: export gene lists per cluster

In [None]:
# labels_a
# ressub
# Display the A-gene matrix (genes x cells, cells in `ressub` order) used by the export below.
mat_a

In [None]:
# import gseapy as gp

In [None]:
# gp.get_library_name()

In [None]:
# Write one gene list per ordered cluster index for external GO enrichment.
# A-gene cluster i and C-gene cluster i are pooled into the same file.
results_dir = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results'
os.makedirs(results_dir, exist_ok=True)
n_clusters = int(max(labels_a.max(), labels_c.max())) + 1  # derived, not hard-coded
for i in range(n_clusters):
    # pd.unique keeps first-seen order and drops genes listed in both A and C
    genesout = pd.unique(np.hstack([mat_a.index.values[labels_a == i],
                                    mat_c.index.values[labels_c == i]]))
    outpath = os.path.join(results_dir, f'ac_overtime_{i}.txt')
    np.savetxt(outpath, genesout, fmt='%s')
    print(f'cluster {i}: {len(genesout)} genes -> {outpath}')
