In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu
from scroutines.gene_modules import GeneModules  


In [None]:
# Directory for saved figures (absolute, machine-specific path).
# NOTE(review): no cell visible in this section reads this variable.
outdir_fig = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures"

# load gene annotation and data

In [None]:
# Build the gene-module annotation helper and try it on a single gene.
# check_genes returns three values, each printed here as one tab-joined line.
# NOTE(review): presumably (annotation, style, styled annotation) -- confirm in scroutines.gene_modules.
gene_modules = GeneModules()
g, gs, ms = gene_modules.check_genes('Cdh13')
print("\t".join(g))
print("\t".join(gs))
print("\t".join(ms))

In [None]:
# AC genes
# Table of gene names read from CSV and split by row position into two blocks.
# NOTE(review): presumably rows 0-24 hold the A genes and the remaining rows the
# C genes -- confirm against the file; the 25 is not derived from the data.
f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/Saumya_P6-21_AC_genes.csv'
df = pd.read_csv(f)

df_a = df.iloc[:25]
df_c = df.iloc[25:]

# Unique entries pooled over every column of each block.
alltime_a = np.unique(df_a.values)
alltime_c = np.unique(df_c.values)
# Plain concatenation: an entry present in both blocks appears twice.
alltime_ac = np.hstack([alltime_a, alltime_c])

# Entries found in both blocks.
ac_overlap = np.intersect1d(alltime_a, alltime_c)

print(df_a.shape, df_c.shape, alltime_a.shape, alltime_c.shape, alltime_ac.shape, ac_overlap.shape)
df.head()

In [None]:
%%time
adata = anndata.read("../../data/v1_multiome/L23_allmultiome_proc_P6toP21.h5ad")
adata

In [None]:
# Two numeric tables loaded from plain text.
# NOTE(review): judging by the file names these are q-values and log2 fold
# changes for the A-vs-C comparison; a later cell indexes both as [timepoint]
# and uses the result to mask `genes`, so they are presumably
# (n_timepoints, n_genes) in adata.var order -- confirm.
f_rna1 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_qs_avc_p6to21.txt'
f_rna2 = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/rna_l2fc_avc_p6to21.txt'

rna_qs_avc = np.loadtxt(f_rna1)
rna_l2fc_avc = np.loadtxt(f_rna2)

rna_qs_avc.shape, rna_l2fc_avc.shape

In [None]:
# Shorthand arrays: gene names and per-cell metadata columns.
genes = adata.var.index.values
conds = adata.obs['cond'].values
types = adata.obs['Type'].values
samps = adata.obs['sample'].values

# switch back to float64
adata.layers[    'norm'] = adata.layers['norm'][...].astype(np.float64)
# lognorm = log10(1 + norm)
adata.layers[ 'lognorm'] = np.log10(1+adata.layers['norm'][...]) # np.array(xln.todense())
# zlognorm = z-score of lognorm along axis 0 (one column per gene, as used
# when `res0` is built from this layer with columns=genes).
# NOTE(review): a column with zero variance presumably becomes NaN here -- confirm.
adata.layers['zlognorm'] = zscore(adata.layers['lognorm'][...], axis=0)

In [None]:
# Two precomputed PCA embeddings stored on the AnnData object.
pcs_p8 = adata.obsm['pca_p8']
pcs_p17on = adata.obsm['pca_p17on']

# One row per cell: z-scored expression (one column per gene) plus metadata.
res0 = pd.DataFrame(adata.layers['zlognorm'][...], columns=genes)
res0['cond'] = conds
res0['type'] = types
res0['samp'] = samps
# replicate id = last character of the sample name
res0['rep']  = [samp[-1] for samp in samps]
# overwrite 'type' with its string form prefixed by 'c' (the key format of `palette_types`)
res0['type'] = np.char.add('c', res0['type'].values.astype(str))

# PC coordinates with 1-based column names: p8PC1, p8PC2, ... / p17onPC1, ...
res1 = pd.DataFrame(pcs_p8, columns=np.char.add("p8PC", ((1+np.arange(pcs_p8.shape[1])).astype(str))))
res2 = pd.DataFrame(pcs_p17on, columns=np.char.add("p17onPC", ((1+np.arange(pcs_p17on.shape[1])).astype(str))))
# Column-wise join; all three frames carry the default integer index.
res = pd.concat([res0, res1, res2], axis=1)

In [None]:
# 20 colours from 'tab20c'; the timepoint palette below picks entries by index.
allcolors = sns.color_palette('tab20c', 20)
allcolors

In [None]:
# 20 entries requested from 'tab10'.
# NOTE(review): 'tab10' presumably has only 10 distinct colours, so the second
# half would repeat the first -- confirm.
allcolors2 = sns.color_palette('tab10', 20)
allcolors2

In [None]:
# Timepoint -> colour, listed in developmental order (colours picked from `allcolors`).
palette = collections.OrderedDict({
     "P6": allcolors[2],
     "P8": allcolors[1],
    "P10": allcolors[0],
    "P12": allcolors[4+2],
    "P14": allcolors[4+0],

    "P17": allcolors[8+2],
    "P21": allcolors[8+0],
})
# Timepoint labels in palette order; the plotting helpers draw one panel per entry.
cases = np.array(list(palette.keys()))

# Timepoint -> rank along the time axis.
cond_order_dict = {
    'P6':  0,
    'P8':  1,
    'P10': 2,
    'P12': 3,
    'P14': 4,
    'P17': 5,
    'P21': 6,
}
unq_conds = np.array(list(cond_order_dict.keys()))
# Store the rank as a plain integer column. When `cond` is categorical, `apply`
# can hand back a categorical whose category order follows the labels
# ('P10' < 'P6' as strings) rather than the ranks; casting to int makes any
# later sort or groupby on `cond_order` follow the numeric rank.
adata.obs['cond_order'] = adata.obs['cond'].apply(lambda x: cond_order_dict[x]).astype(int)

# Cluster label (as stored in res['type']) -> matplotlib colour-cycle name.
# Two sets of three labels share the same three colours.
# (An earlier 'L2/3_A/B/C' palette assigned to this same name was overwritten
# right away by this dict and has been dropped as dead code.)
palette_types = {
    'c14': 'C0',
    'c18': 'C1',
    'c16': 'C2',

    'c13': 'C0',
    'c15': 'C1',
    'c17': 'C2',
}
type_order = list(palette_types)
type_order

In [None]:
from scipy import stats
from matplotlib.ticker import MaxNLocator

def plot(x, y, aspect_equal=False, density=False, hue='type'):
    """Scatter two columns of the global `res`, one panel per timepoint.

    Each panel shows every cell in light gray and, on top, the cells of that
    panel's timepoint (taken from the global `cases`) in shuffled row order.

    Parameters
    ----------
    x, y : str
        Columns of `res` plotted on the horizontal / vertical axis.
    aspect_equal : bool
        Force an equal aspect ratio on every panel.
    density : bool
        Additionally draw a 2-D histogram of the timepoint's cells.
    hue : str
        'type' colours the overlay by res['type'] using the global
        `palette_types`; any other value colours it by res['rep'].
    """
    n = 7
    fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True)

    # The overlay colouring depends only on `hue`, so settle it once.
    if hue == 'type':
        overlay_kws = dict(
            hue='type',
            hue_order=list(palette_types.keys()),
            palette=palette_types,
        )
    else:
        overlay_kws = dict(hue='rep')

    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)

        # background: all cells
        sns.scatterplot(
            data=res, x=x, y=y,
            c='lightgray', s=1, edgecolor='none',
            legend=False, ax=ax,
        )

        # foreground: this timepoint only, rows shuffled before drawing
        cells = res[res['cond']==cond]
        sns.scatterplot(
            data=cells.sample(frac=1, replace=False), x=x, y=y,
            s=3, edgecolor='none',
            legend=False, ax=ax,
            **overlay_kws,
        )

        if density:
            sns.histplot(data=cells, x=x, y=y, legend=False, ax=ax)

        sns.despine(ax=ax)
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        # keep axis labels on the first panel only
        if i > 0:
            ax.set_xlabel('')
            ax.set_ylabel('')
    plt.show()
    
def plot2(x, y, hue=None, aspect_equal=False):
    """Scatter two columns of the global `res` per timepoint, coloured by a value.

    Each panel shows every cell in faint gray and, on top, the cells of that
    panel's timepoint (taken from the global `cases`).

    Parameters
    ----------
    x, y : str
        Columns of `res` plotted on the horizontal / vertical axis.
    hue : str or None
        Column of `res` mapped to the 'coolwarm' colormap, clipped to [-3, 3];
        also used as the figure title. When falsy, the points keep the default
        colour and the panel title reports the Spearman correlation of x and y.
    aspect_equal : bool
        Force an equal aspect ratio on every panel.
    """
    n = 7
    fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True)
    fig.suptitle(hue, x=0, ha='left')

    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)

        # background: all cells
        sns.scatterplot(
            data=res, x=x, y=y,
            c='lightgray', alpha=0.3,
            s=1, edgecolor='none',
            legend=False, ax=ax,
        )

        cells = res[res['cond']==cond]
        if not hue:
            rho, _ = stats.spearmanr(cells[x], cells[y])
            ax.scatter(cells[x], cells[y], s=5, edgecolor='none')
            ax.set_title(f'{cond}\n r={rho:.2f}')
        else:
            ax.scatter(
                cells[x], cells[y], c=cells[hue],
                cmap='coolwarm', vmin=-3, vmax=3,
                s=5, edgecolor='none',
            )

        sns.despine(ax=ax)
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        # keep axis labels on the first panel only
        if i > 0:
            ax.set_xlabel('')
            ax.set_ylabel('')
        ax.grid(False)

    fig.tight_layout()
    plt.show()

# Plot A vs C genes aligning cells along early vs late PCs

In [None]:
# First two PCs of each embedding, one panel per timepoint, coloured by type (the default hue).
plot('p17onPC1', 'p17onPC2', aspect_equal=True)
plot('p8PC1', 'p8PC2', aspect_equal=True)

In [None]:
# get avsc genes
# For every timepoint, keep genes with q < 0.05 and |log2 fold change| > 1:
# fold change below -1 goes to the "a" list, above +1 to the "c" list.
# NOTE(review): assumes row i of both tables corresponds to unq_conds[i] and the
# columns follow `genes`; presumably a negative fold change means higher in A -- confirm.

# genes = adata_rna.var.index.values
degs_a = []
degs_c = []
for i, t in enumerate(unq_conds):
    cond_a = np.logical_and(rna_qs_avc[i]<0.05, rna_l2fc_avc[i] < -1)
    cond_c = np.logical_and(rna_qs_avc[i]<0.05, rna_l2fc_avc[i] >  1)
    deg_a = np.sort(genes[cond_a])
    deg_c = np.sort(genes[cond_c])
    degs_a.append(deg_a)
    degs_c.append(deg_c)

    # Print the timepoint, then the annotated "a" genes and the annotated "c"
    # genes; the "\033[0;m" escape resets terminal text styling after each list.
    print(f"{t}")
    _, _, gs = gene_modules.check_genes(deg_a)
    print("\t".join(gs))
    print("\033[0;m ---")
    _, _, gs = gene_modules.check_genes(deg_c)
    print("\t".join(gs))
    print("\033[0;m ---")
    

# big tensor

In [None]:
# by condition
nt, nc, nr, ng = len(unq_conds), 5, 2, len(adata.var)  # time/condition, cluster, rep, gene
# tensor[t, c, r, g] = mean 'lognorm' value of gene g over the cells of
# timepoint t, quantile bin c and replicate r. Slots that are never filled
# (e.g. a timepoint with fewer than nr samples) stay at zero.
tensor = np.zeros((nt, nc, nr, ng))

for cond_order, obssub in adata.obs.groupby('cond_order'):
    # Index the time axis by the rank stored in `cond_order`, not by an
    # iteration counter: pandas orders the groups of a categorical column by
    # category order, which need not be the numeric order, and a missing
    # timepoint would otherwise shift every later slot.
    i = int(cond_order)
    adatasub = adata[obssub.index]

    # by type -- cut into nc=5 equal bins
    # (equal-count quantile bins of the first 'pca_p17on' component, recomputed per timepoint)
    x = adatasub.obsm['pca_p17on'][...,0]
    type_lbls = pd.qcut(x, nc, labels=np.arange(nc)).astype(int)
    unq_types = np.unique(type_lbls)

    # by replicates -- first two
    sample_lbls = adatasub.obs['sample'].values
    unq_samples = np.unique(sample_lbls)[:nr] # first nr samples if n > nr

    for j, qtype in enumerate(unq_types):
        for k, samp in enumerate(unq_samples):
            selector = ((sample_lbls==samp) & (type_lbls==qtype))
            tensor[i,j,k] = np.mean(adatasub[selector].layers['lognorm'][...], axis=0)

In [None]:
# (timepoints, quantile bins, replicates, genes)
tensor.shape

In [None]:
# Per-gene vector for the first timepoint, first bin and first replicate.
tensor[0,0,0]

# gene triangle analysis 

In [None]:
def get_2way_eta2_allgenes(nums):
    """Per-gene variance fractions (eta squared) for a two-factor layout with replicates.

    Parameters
    ----------
    nums : ndarray, shape (c0, c1, r, g)
        Values indexed by (level of factor 0, level of factor 1, replicate, gene).

    Returns
    -------
    eta2_01, eta2_0, eta2_1 : ndarray, each of shape (g,)
        Share of the total sum of squares carried by the (c0, c1) cell means,
        by the factor-0 means and by the factor-1 means, in that order.
        A small offset (1e-10) is added to numerator and denominator, so a
        gene with zero total variation gets 1 in all three outputs.
    """
    _, _, n_rep, _ = nums.shape  # unpacking also rejects inputs that are not 4-D

    not_gene = (0, 1, 2)  # every axis except the gene axis

    def sum_sq(a, b):
        # squared deviations of a from b, summed over all non-gene axes
        return np.sum(np.power(a - b, 2), axis=not_gene)

    # Means at each level, kept broadcastable against `nums`.
    m_all  = np.mean(nums, axis=not_gene, keepdims=True)  # global mean
    m_c0   = np.mean(nums, axis=(1, 2),   keepdims=True)  # per c0 level
    m_c1   = np.mean(nums, axis=(0, 2),   keepdims=True)  # per c1 level
    m_cell = np.mean(nums, axis=2,        keepdims=True)  # per (c0, c1) cell

    ss_total = sum_sq(nums, m_all)
    ss_rep   = sum_sq(nums, m_cell)            # replicate noise within each cell
    ss_in_c0 = n_rep * sum_sq(m_cell, m_c0)    # cell means around their c0 mean
    ss_in_c1 = n_rep * sum_sq(m_cell, m_c1)    # cell means around their c1 mean

    # total = replicate noise + explained;
    # explained = within-c0 + between-c0 = within-c1 + between-c1
    ss_cells = ss_total - ss_rep
    ss_c0 = ss_cells - ss_in_c0
    ss_c1 = ss_cells - ss_in_c1

    eps = 1e-10
    return tuple((ss + eps) / (ss_total + eps) for ss in (ss_cells, ss_c0, ss_c1))


In [None]:
# Factor 0 of `tensor` is time and factor 1 is the type bin, so the outputs are
# time&type jointly, time alone and type alone.
eta2_tc, eta2_t, eta2_c = get_2way_eta2_allgenes(tensor)
# remainder not carried by the (time, type) cell means
eta2_r = 1-eta2_tc
# joint fraction minus the two single-factor fractions
eta2_tic = eta2_tc-(eta2_t+eta2_c)

In [None]:
# Per-gene distribution of each variance fraction, one box per entry
# (label -> values, in plotting order).
box_data = {
    'time': eta2_t,
    'type': eta2_c,
    'rep': eta2_r,
    'time+\ntype': eta2_t+eta2_c,
    'time int\ntype': eta2_tic,
    'time&\ntype': eta2_tc,
}

fig, ax = plt.subplots(figsize=(5,6))
sns.boxplot(list(box_data.values()))
ax.set_xticklabels(list(box_data.keys()), rotation=0, fontsize=12)
ax.set_ylabel('variance explained by')
plt.show()

In [None]:
# One dot per gene: time fraction vs type fraction, coloured by the remainder
# 1 - eta2_tc (the same quantity as eta2_r).
fig, ax = plt.subplots(figsize=(8,6))
g = ax.scatter(eta2_t, eta2_c, c=1-eta2_tc, s=1, cmap='viridis', vmin=0, vmax=1)
fig.colorbar(g, shrink=0.5, ticks=[0, 0.5, 1], label='var exp replicates')
ax.set_aspect('equal')
ax.set_xlabel('var exp time')
ax.set_ylabel('var exp type')
plt.show()

In [None]:
import plotly.graph_objects as go

# Per-gene variance fractions: time (x), type (y) and the remainder 1 - eta2_tc (z).
# Note: this rebinds the notebook-level names x, y and z.
x = eta2_t
y = eta2_c
z = 1-eta2_tc

# Create a 3D scatter plot, one marker per gene
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=1))])

# Update layout
fig.update_layout(scene=dict(
                    xaxis_title='time',
                    yaxis_title='type',
                    zaxis_title='rep'),
                  title='time type rep', 
                  height=800,
                  width=1000,
                 )

# Display the plot in the Jupyter notebook
fig.show()

# check a few genes

In [None]:
def plot_query_gene_landscape(key, val):
    """Highlight and label a set of genes on the time-vs-type variance plane.

    All genes are drawn in light gray at (eta2_t, eta2_c); the queried genes
    are drawn in colour on top and annotated with their names.

    Parameters
    ----------
    key : str
        Panel title.
    val : sequence of str
        Gene names to highlight; they are located in the global `genes_comm`
        through `basicu.get_index_from_array`.
        NOTE(review): presumably returns the position of each name, in the
        order given -- confirm in scroutines.basicu.
    """
    hit_idx = basicu.get_index_from_array(genes_comm, val)
    hit_t, hit_c = eta2_t[hit_idx], eta2_c[hit_idx]

    fig, ax = plt.subplots(1, 1, figsize=(5*1,6*1))
    ax.set_title(key)
    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0, 0.5, 1])

    # background cloud, then the queried genes on top
    ax.scatter(eta2_t, eta2_c, s=1, c='lightgray')
    ax.scatter(hit_t, hit_c, s=1, c='C1', zorder=2)
    for gx, gy, gname in zip(hit_t, hit_c, val):
        ax.text(gx, gy, gname, fontsize=10)

    sns.despine(ax=ax)
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])
    ax.set_xlabel('time')
    ax.set_ylabel('type')
    ax.set_aspect('equal')
    plt.show()

In [None]:
def plot_query_genes(query_genes, colors=None, nxset=5):
    """Plot expression over time for a list of genes, one panel per gene.

    Within a panel there is one line per cluster (mean over replicates) and
    one dot per replicate and timepoint. Values are read from the global
    pseudobulk tensor `X` (time, cluster, replicate, gene); timepoints come
    from the global `ts`, cluster labels from the global `xclsts_sel` and gene
    positions from the global `genes_comm`.

    Parameters
    ----------
    query_genes : sequence of str
        Gene names to plot.
    colors : sequence of colours, optional
        One colour per cluster; defaults to a 'coolwarm' palette with one
        entry per label in `xclsts_sel`.
    nxset : int
        Maximum number of panels per row.
    """
    types = xclsts_sel
    if colors is None:
        colors = sns.color_palette('coolwarm', len(types))

    query_gis = basicu.get_index_from_array(genes_comm, query_genes)
    pbulks_sub = X[:,:,:,query_gis]
    pbulks_sub = np.swapaxes(pbulks_sub, 0, 3) # nt, nc, nr, ng -> ng, nc, nr, nt
    gnames = genes_comm[query_gis]

    # Grid size; both dimensions are kept >= 1 so an empty query still yields a figure.
    n = len(query_gis)
    nx = max(1, min(n, nxset))
    ny = max(1, (n + nx - 1) // nx)

    # squeeze=False keeps `axs` a 2-D array even for a single panel, so
    # `axs.flat` and `axs[row, col]` work for any number of genes.
    fig, axs = plt.subplots(ny, nx, figsize=(nx*3, ny*4), sharex=True, squeeze=False)
    for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
        ax.set_title(gname)
        # Iterate over whatever clusters / replicates the tensor holds instead
        # of relying on a global cluster count and exactly two replicates.
        for i, pbulks_gc in enumerate(pbulks_g):
            color = colors[i]
            lbl = types[i]
            ax.plot(ts, np.mean(pbulks_gc, axis=0), color=color, label=lbl)
            for rep_vals in pbulks_gc:
                ax.scatter(ts, rep_vals, s=5, color=color)

        sns.despine(ax=ax)
        ax.grid(False)

        ax.set_xticks(ts)
        ax.tick_params(axis='both', which='major', labelsize=10)

        # y label on the first panel of every row (row length is nx, not always 5)
        if j % nx == 0:
            ax.set_ylabel('log10(CP10k+1)')

    # Hide panels without a gene. They can only be in the last row; with
    # sharex the panel above would lose its tick labels, so switch them back on.
    for k in range(n, nx*ny):
        row, col = divmod(k, nx)
        axs[row, col].set_visible(False)
        if row > 0:
            axs[row-1, col].tick_params(labelbottom=True)

    fig.tight_layout()
    plt.show()

In [None]:
# set up
# Notebook-level names read by plot_query_gene_landscape / plot_query_genes.
genes_comm = adata.var.index.values          # gene names, in tensor gene order
xclsts_sel = ['A', 'AB', 'B', 'BC', 'C']     # labels for the five quantile bins
ts = [6, 8, 10, 12, 14, 17, 21,]             # x positions for the seven timepoints
X = tensor

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Black -> colour ramps for the first three colour-cycle entries.
colors_a = [(0.0, 'black'), (1.0, 'C0')]      
colors_b = [(0.0, 'black'), (1.0, 'C1')]      
colors_c = [(0.0, 'black'), (1.0, 'C2')]      

# Create a custom colormap using LinearSegmentedColormap
cmap_a = LinearSegmentedColormap.from_list('cmap_a', colors_a)
cmap_b = LinearSegmentedColormap.from_list('cmap_b', colors_b)
cmap_c = LinearSegmentedColormap.from_list('cmap_c', colors_c)

# Five RGBA colours: the three ramp end colours with the 50/50 blend of each
# neighbouring pair in between (order: a, a/b, b, b/c, c).
# NOTE(review): the plot_query_genes calls visible here leave `colors=colors` commented out.
colors = [
    np.array(cmap_a(1.0)),
    0.5*np.array(cmap_a(1.0))+0.5*np.array(cmap_b(1.0)),
    np.array(cmap_b(1.0)),
    0.5*np.array(cmap_b(1.0))+0.5*np.array(cmap_c(1.0)),
    np.array(cmap_c(1.0)),
]

In [None]:

# Hand-picked gene list: locate the genes on the time/type plane, then plot
# their expression over time.
key = 'A TFs'
val = [
    'Trps1', 'Sox5', 
    'Rfx3', 'Meis2', 'Nfib', 
      ] # alternative source: genes_annots_overlap[key]
print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val) #, colors=colors)

In [None]:
# Hand-picked gene list: locate the genes on the time/type plane, then plot
# their expression over time.
key = 'C TFs'
val = [
    'Cux1', 'Satb1', 'Thrb', 'Tshz3', 'Rora', 'Lcorl', 'Mkx', 'Tox', 'Rorb', 'Foxp1',
      ] # alternative source: genes_annots_overlap[key]
print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val) #, colors=colors)

In [None]:
# Hand-picked gene list: locate the genes on the time/type plane, then plot
# their expression over time (six panels in a single row).
key = 'BMP / TGFbeta'
val = [
    'Vwc2l', 'Fst', 
    'Igfbp5', 'Brinp1', 'Brinp3',  'Lrp1b',
] # alternative source: genes_annots_overlap[key]
print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val, nxset=6) #, colors=colors)

# check gene groups

In [None]:
# Restrict every annotation group to genes present in `genes_comm` (sorted),
# printing the group size before and after plus the first few entries of each.
genes_annots_overlap = {}
for key, val in gene_modules.annots.items():
    overlap = np.sort(np.intersect1d(val, genes_comm))
    genes_annots_overlap[key] = overlap
    print(key, len(val), len(overlap))
    print(val[:5], overlap[:5])

In [None]:
# One time-vs-type panel per annotation group, with the group's genes highlighted.
n = len(genes_annots_overlap)
print(n)
# Size the grid from the number of groups (5 panels per row) rather than a
# fixed 3x5 layout, so more than 15 groups no longer runs past the axes array.
ncols = 5
nrows = max(1, (n + ncols - 1) // ncols)
fig, axs = plt.subplots(nrows, ncols, figsize=(5*ncols, 6*nrows), squeeze=False)
for i, (key, val) in enumerate(genes_annots_overlap.items()):
    ax = axs.flat[i]
    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0, 0.5, 1])
    # background: all genes
    g = ax.scatter(eta2_t, eta2_c, s=1, c='lightgray')

    ax.set_title(f'{key}\nn={len(val)}')
    # foreground: genes of this group
    val_idx = basicu.get_index_from_array(genes_comm, val)
    g2 = ax.scatter(eta2_t[val_idx], eta2_c[val_idx], s=5, c='C1', zorder=2)
    sns.despine(ax=ax)
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])

    ax.set_aspect('equal')
# hide the panels left over in the last row
for ax in axs.flat[n:]:
    ax.set_visible(False)
plt.show()

# check top genes

In [None]:
# Ten genes with the largest time fraction, in descending order.
# NOTE(review): a gene with zero total variation scores exactly 1 in eta2_t
# (offset in get_2way_eta2_allgenes) and would rank at the top -- confirm such
# genes were filtered out upstream.
key = 'top time genes'
val = genes_comm[np.argsort(eta2_t)[::-1][:10]]

print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val)

In [None]:
# Twenty genes with the largest type fraction, in descending order.
# NOTE(review): a gene with zero total variation scores exactly 1 in eta2_c
# (offset in get_2way_eta2_allgenes) and would rank at the top -- confirm such
# genes were filtered out upstream.
key = 'top type genes'
val = genes_comm[np.argsort(eta2_c)[::-1][:20]]

print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val)

In [None]:
# Ten genes with the largest remainder fraction (eta2_r = 1 - eta2_tc), in descending order.
key = 'top rep genes'
val = genes_comm[np.argsort(eta2_r)[::-1][:10]]

print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val)

In [None]:
# Ten genes with the largest ratio of interaction fraction to remainder fraction.
# NOTE(review): eta2_r can be 0, which makes the ratio infinite or NaN; argsort
# presumably places NaN last, so after the reversal NaN genes would come
# first -- confirm no such genes are present.
key = 'top int genes'
val = genes_comm[np.argsort(eta2_tic/eta2_r)[::-1][:10]]

print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val)

In [None]:
# Annotate the current `val` and print the styled annotations on one line.
# NOTE(review): `val` is whatever the most recently executed query cell left
# behind, so the output depends on execution order.
val_anno, val_style, val_anno_style = gene_modules.check_genes(val)
# print("\t".join(val))
# print("\t".join(val_anno))
# print("\t".join(val_style))
print("\t".join(val_anno_style))

# check A vs C genes
- import it somewhere else

# top type; top time; top interaction (tic/rep)
# transform this triangle

# for any genes with a significant type component - separate them further by time (early, mid, late) and run GO - emulate Saumya