In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu
from scroutines.gene_modules import GeneModules  

from atac_utils import merge_peaks

In [None]:
# Output directory for figures (presumably where saved plots go; TODO confirm it is used).
outdir_fig = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures"

# big tensor

In [None]:
# Peak table from a headerless, tab-separated BED file. Columns 0-2 (presumably
# chrom/start/end) are handed to merge_peaks to build one name per peak in 'peak'.
f = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results_atac/all_AvsC_peaks_unique.bed"
peaks = pd.read_csv(f, header=None, sep='\t')
peaks['peak'] = merge_peaks(peaks, 0, 1, 2)
peaks

In [None]:
# Load the peak tensor; its four axes are unpacked as (time, type, rep, peak).
f = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results_atac/all_AvsC_peak_tensor.npy"
tensor = np.load(f)
print(tensor.shape)
n_time, n_type, n_rep, n_peak = tensor.shape

In [None]:
# Pooled distribution of every tensor entry on a log2(x+1) scale.
# NOTE: only meaningful before `tensor` itself has been log-transformed.
sns.histplot(np.log2(tensor.ravel() + 1))

In [None]:
# Transform to log2(x + 1). CPM normalisation is currently disabled:
# tensor = (tensor/np.sum(tensor, axis=-1, keepdims=True))*1e6
# NOTE: not idempotent -- re-running this cell without re-loading `tensor`
# applies the log a second time.
tensor = np.log2(tensor + 1)

# gene triangle analysis 

In [None]:
def get_2way_eta2_allgenes(nums, min_ss=1e-10):
    """
    Two-way ANOVA variance decomposition, computed independently for every feature.

    nums: array (cond0, cond1, reps, genes) = (nc0, nc1, nr, ng). Being a dense array,
          the design is balanced (same number of replicates in every cell).
    min_ss: features whose total sum of squares is <= min_ss are treated as constant
          and get 0 for every returned fraction (rather than an undefined 0/0).

    Returns (eta2_01, eta2_0, eta2_1), each a vector with one entry per gene:
        eta2_01: share of total variance explained by the (cond0, cond1) cells jointly
                 (both main effects + interaction); 1 - eta2_01 is the replicate share.
        eta2_0 : share explained by the cond0 main effect.
        eta2_1 : share explained by the cond1 main effect.
    The interaction share is eta2_01 - eta2_0 - eta2_1.
    """
    # unpacking also validates that nums is 4-D
    nc0, nc1, nr, ng = nums.shape

    # grand / marginal / cell means, kept broadcastable against `nums`
    em   = np.mean(nums, axis=(0, 1, 2), keepdims=True)  # (1,   1,   1, ng)
    em0  = np.mean(nums, axis=(1, 2), keepdims=True)     # (nc0, 1,   1, ng)
    em1  = np.mean(nums, axis=(0, 2), keepdims=True)     # (1,   nc1, 1, ng)
    em01 = np.mean(nums, axis=2, keepdims=True)          # (nc0, nc1, 1, ng)

    SSt  = np.sum((nums - em) ** 2, axis=(0, 1, 2))          # total
    SSwr = np.sum((nums - em01) ** 2, axis=(0, 1, 2))        # replicate noise within cells
    # spread of cell means around each marginal mean
    SSw0 = nr * np.sum((em01 - em0) ** 2, axis=(0, 1, 2))    # = SS(cond1) + SS(interaction)
    SSw1 = nr * np.sum((em01 - em1) ** 2, axis=(0, 1, 2))    # = SS(cond0) + SS(interaction)

    SSexp  = SSt - SSwr      # between cells: both main effects + interaction
    SSexp0 = SSexp - SSw0    # cond0 main effect
    SSexp1 = SSexp - SSw1    # cond1 main effect

    # Constant features have nothing to explain. The previous (SS+o)/(SSt+o) form made
    # all three fractions ~1 for them, ranking constant peaks first; report 0 instead.
    has_var = SSt > min_ss

    def _frac(ss):
        out = np.zeros(np.shape(SSt), dtype=float)
        np.divide(ss, SSt, out=out, where=has_var)
        return out

    return _frac(SSexp), _frac(SSexp0), _frac(SSexp1)


In [None]:
# Axis 0 of `tensor` is time and axis 1 is type, so cond0 -> time (t), cond1 -> type (c).
eta2_tc, eta2_t, eta2_c = get_2way_eta2_allgenes(tensor)
eta2_r = 1 - eta2_tc                      # remaining share: replicates (noise)
eta2_tic = eta2_tc - (eta2_t + eta2_c)    # time x type interaction share

In [None]:
# Distribution across peaks of the variance share attributed to each source.
# Each label is stored next to its array so tick labels cannot drift from the boxes.
box_components = [
    ('time', eta2_t),
    ('type', eta2_c),
    ('rep', eta2_r),
    ('time+\ntype', eta2_t + eta2_c),
    ('time int\ntype', eta2_tic),
    ('time&\ntype', eta2_tc),
]
box_labels, box_values = zip(*box_components)

fig, ax = plt.subplots(figsize=(5, 6))
sns.boxplot(data=list(box_values), ax=ax)
# fix tick positions before labelling them (avoids FixedFormatter warnings)
ax.set_xticks(range(len(box_labels)))
ax.set_xticklabels(box_labels, rotation=0, fontsize=12)
ax.set_ylabel('variance explained by')
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
# One point per peak: time share (x) vs type share (y), colored by the replicate share.
g = ax.scatter(eta2_t, eta2_c, c=1 - eta2_tc, s=1, cmap='viridis', vmin=0, vmax=1)
fig.colorbar(g, shrink=0.5, ticks=[0, 0.5, 1], label='var exp replicates')
ax.set_aspect('equal')
ax.set(xlabel='var exp time', ylabel='var exp type')
plt.show()

In [None]:
# import plotly.graph_objects as go

# # Generate some random data
# x = eta2_t
# y = eta2_c
# z = 1-eta2_tc

# # Create a 3D scatter plot
# fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=1))])

# # Update layout
# fig.update_layout(scene=dict(
#                     xaxis_title='time',
#                     yaxis_title='type',
#                     zaxis_title='rep'),
#                   title='time type rep', 
#                   height=800,
#                   width=1000,
#                  )

# # Display the plot in the Jupyter notebook
# fig.show()

# check a few genes

In [None]:
def plot_query_gene_landscape(key, val):
    """
    Highlight selected features on the time-share vs type-share scatter.

    key: panel title.
    val: feature names to highlight. They are looked up in the global `genes_comm`, and
         their coordinates are read from the globals `eta2_t` (x) and `eta2_c` (y).
    All features are drawn in light gray. The queried ones are overlaid in C1 and
    labelled with the names from `val`.
    """
    idx = basicu.get_index_from_array(genes_comm, val)
    qx, qy = eta2_t[idx], eta2_c[idx]

    fig, ax = plt.subplots(1, 1, figsize=(5, 6))
    ax.set_title(key)
    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0, 0.5, 1])

    # background: every feature; foreground (zorder=2): the queried ones
    ax.scatter(eta2_t, eta2_c, s=1, c='lightgray')
    ax.scatter(qx, qy, s=1, c='C1', zorder=2)
    for px, py, name in zip(qx, qy, val):
        ax.text(px, py, name, fontsize=10)

    sns.despine(ax=ax)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_xlabel('time')
    ax.set_ylabel('type')

    ax.set_aspect('equal')
    plt.show()

In [None]:
def plot_query_genes(query_genes, X, ts, colors=None, nxset=5, X2=None, ts2=None):
    """
    Plot per-type trajectories of selected features, one panel per feature.

    query_genes: feature names, looked up in the global `genes_comm`.
    X  : array (n_time, n_type, n_rep, n_gene).
    ts : x values for the time axis of X (length n_time).
    colors: one color per type. Defaults to a coolwarm palette with len(types) colors
            (`types` is a notebook global, also used for line labels).
    nxset: maximum number of panels per row.
    X2, ts2: optional second tensor with the same layout and its own time points.
            It is overlaid with square replicate markers and a dashed black mean line.

    For each type, the replicate mean is drawn as a thick line and every replicate as
    small dots. Returns without plotting if no query features are found.
    """
    nc = X.shape[1]

    if colors is None:
        colors = sns.color_palette('coolwarm', len(types))

    query_gis = basicu.get_index_from_array(genes_comm, query_genes)
    gnames = genes_comm[query_gis]
    n = len(query_gis)
    if n == 0:
        return  # an empty grid would otherwise raise ZeroDivisionError below

    # (nt, nc, nr, ng) -> (ng, nc, nr, nt): iterate features, then types, then reps
    pbulks_sub = np.swapaxes(X[:, :, :, query_gis], 0, 3)
    pbulks_sub2 = None if X2 is None else np.swapaxes(X2[:, :, :, query_gis], 0, 3)

    nx = min(n, nxset)
    ny = (n + nx - 1) // nx  # ceil(n / nx)
    s = 3

    # squeeze=False keeps `axs` 2-D, so `.flat` also works for a single panel
    fig, axs = plt.subplots(ny, nx, figsize=(nx*3, ny*4), sharex=True, squeeze=False)
    for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
        ax.set_title(gname, fontsize=10)
        for i in range(nc):
            color = colors[i]
            ax.plot(ts, np.mean(pbulks_g[i], axis=0), color=color, label=types[i], linewidth=3)
            for rep_vals in pbulks_g[i]:  # every replicate, not only the first two
                ax.scatter(ts, rep_vals, s=s, color=color)

        sns.despine(ax=ax)
        ax.grid(False)

        ax.set_xticks(ts)
        ax.tick_params(axis='both', which='major', labelsize=10)

        if j % nx == 0:  # first panel of each row
            ax.set_ylabel('log2(atac+1)')

    if pbulks_sub2 is not None:
        for pbulks_g, gname, ax in zip(pbulks_sub2, gnames, axs.flat):
            ax.set_title(gname, fontsize=10)
            for i in range(nc):
                color = colors[i]
                lbl = types[i]
                mean_g = np.mean(pbulks_g[i], axis=0)
                ax.plot(ts2, mean_g, color=color, label=lbl, linewidth=3)
                ax.plot(ts2, mean_g, color='k', label=lbl, linewidth=2, linestyle='--')
                for rep_vals in pbulks_g[i]:
                    ax.scatter(ts2, rep_vals, s=s, color=color, marker='s')

    fig.tight_layout()
    plt.show()

In [None]:
def plot_query_genes_2ends(query_genes, X, ts, colors=None, nxset=5, X2=None, ts2=None):
    """
    Plot per-type trajectories of selected features, restricted to the first, middle
    and last type (one panel per feature).

    query_genes: feature names, looked up in the global `genes_comm`.
    X  : array (n_time, n_type, n_rep, n_gene).
    ts : x values for the time axis of X (length n_time).
    colors: one color per type. Defaults to a coolwarm palette with len(types) colors
            (`types` is a notebook global, also used for line labels).
    nxset: maximum number of panels per row.
    X2, ts2: optional second tensor with the same layout and its own time points,
            overlaid as semi-transparent mean lines with square replicate markers.

    Returns without plotting if no query features are found.
    """
    # number of types, taken from X itself (it was previously an undefined global)
    nc = X.shape[1]

    if colors is None:
        colors = sns.color_palette('coolwarm', len(types))

    query_gis = basicu.get_index_from_array(genes_comm, query_genes)
    gnames = genes_comm[query_gis]
    n = len(query_gis)
    if n == 0:
        return  # an empty grid would otherwise raise ZeroDivisionError below

    # (nt, nc, nr, ng) -> (ng, nc, nr, nt): iterate features, then types, then reps
    pbulks_sub = np.swapaxes(X[:, :, :, query_gis], 0, 3)
    pbulks_sub2 = None if X2 is None else np.swapaxes(X2[:, :, :, query_gis], 0, 3)

    # first, middle and last type; the set drops duplicates when nc < 3
    type_idx = sorted({0, nc // 2, nc - 1})

    nx = min(n, nxset)
    ny = (n + nx - 1) // nx  # ceil(n / nx)
    s = 3

    # squeeze=False keeps `axs` 2-D, so `.flat` also works for a single panel
    fig, axs = plt.subplots(ny, nx, figsize=(nx*3, ny*4), sharex=True, squeeze=False)
    for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
        ax.set_title(gname)
        for i in type_idx:
            color = colors[i]
            ax.plot(ts, np.mean(pbulks_g[i], axis=0), color=color, label=types[i], linewidth=3)
            for rep_vals in pbulks_g[i]:  # every replicate, not only the first two
                ax.scatter(ts, rep_vals, s=s, color=color)

        sns.despine(ax=ax)
        ax.grid(False)

        ax.set_xticks(ts)
        ax.tick_params(axis='both', which='major', labelsize=10)

        if j % nx == 0:  # first panel of each row
            # same label as plot_query_genes: this notebook's tensor is log2(count+1), not CPM
            ax.set_ylabel('log2(atac+1)')

    if pbulks_sub2 is not None:
        for pbulks_g, gname, ax in zip(pbulks_sub2, gnames, axs.flat):
            ax.set_title(gname)
            for i in type_idx:
                color = colors[i]
                ax.plot(ts2, np.mean(pbulks_g[i], axis=0), color=color, label=types[i], linewidth=3, alpha=0.5)
                for rep_vals in pbulks_g[i]:
                    ax.scatter(ts2, rep_vals, s=s, color=color, marker='s', linewidth=1)

    fig.tight_layout()
    plt.show()

In [None]:
np.shape(tensor)  # (n_time, n_type, n_rep, n_peak)

In [None]:
# Shared metadata for the plotting helpers: feature names, type labels, time points.
genes_comm = peaks['peak'].values
types = ['A', 'AB', 'B', 'BC', 'C']

X = tensor
ts = [6, 8, 10, 12, 14, 17, 21]

# The plotting helpers index X, types and genes_comm positionally, so fail fast
# if the metadata does not line up with the (time, type, rep, peak) tensor axes.
assert X.ndim == 4, X.shape
assert len(ts) == X.shape[0], f"{len(ts)} time points vs tensor shape {X.shape}"
assert len(types) == X.shape[1], f"{len(types)} types vs tensor shape {X.shape}"
assert len(genes_comm) == X.shape[3], f"{len(genes_comm)} peaks vs tensor shape {X.shape}"

In [None]:
# X.shape, X2.shape, len(ts), len(ts2)

In [None]:
# Per-type colors. Each colormap runs from black to a base color; only its top color,
# cmap(1.0), is used below. (LinearSegmentedColormap comes from the imports cell.)
colors_a = [(0.0, 'black'), (1.0, 'C0')]
colors_b = [(0.0, 'black'), (1.0, 'C1')]
colors_c = [(0.0, 'black'), (1.0, 'C2')]

# lighter than lightgray: #DCDCDC, #E8E8E8, #F0F0F0
colors_g = [(0.0, 'black'), (1.0, '#DCDCDC')]

cmap_a = LinearSegmentedColormap.from_list('cmap_a', colors_a)
cmap_b = LinearSegmentedColormap.from_list('cmap_b', colors_b)
cmap_c = LinearSegmentedColormap.from_list('cmap_c', colors_c)
cmap_g = LinearSegmentedColormap.from_list('cmap_g', colors_g)

# RGBA end colors, evaluated once
top_a, top_b, top_c, top_g = (np.array(cm(1.0)) for cm in (cmap_a, cmap_b, cmap_c, cmap_g))

# Five colors: endpoints with 50/50 blends in between (a -> b -> c)
colors_abc = [
    top_a,
    0.5*top_a + 0.5*top_b,
    top_b,
    0.5*top_b + 0.5*top_c,
    top_c,
]

# Same layout, but diverging through light gray in the middle (a -> gray -> c)
colors_ac = [
    top_a,
    0.5*top_g + 0.5*top_a,
    top_g,
    0.5*top_g + 0.5*top_c,
    top_c,
]


In [None]:

# Hand-picked peaks to inspect (alternative source: genes_annots_overlap[key]).
key = ''
val = [
    'chr3:141618499-141619148',
    'chr6:47272174-47273739',
    'chr9:34703319-34704364',
    'chr17:64689527-64690121',
]
print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val, X, ts, colors=colors_ac)

# Top time, type, and time × type interaction peaks

In [None]:
key = 'top time peaks'
# the 10 peaks with the largest time share, in descending order
val = genes_comm[np.argsort(eta2_t)[-10:][::-1]]

print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val, X, ts, colors=colors_ac)

In [None]:
key = 'top type peaks'
# the 10 peaks with the largest type share, in descending order
val = genes_comm[np.argsort(eta2_c)[-10:][::-1]]

print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val, X, ts, colors=colors_ac)

In [None]:
key = 'top int peaks'
# Rank by interaction share relative to replicate noise. eta2_r can be exactly 0, and
# 0/0 gives NaN; argsort puts NaN last, i.e. FIRST after [::-1]. Map NaN to -inf so
# undefined ratios rank at the bottom (finite and +inf scores keep their order).
with np.errstate(divide='ignore', invalid='ignore'):
    int_score = eta2_tic / eta2_r
int_score = np.where(np.isnan(int_score), -np.inf, int_score)
val = genes_comm[np.argsort(int_score)[::-1][:10]]

print(val)
plot_query_gene_landscape(key, val)
plot_query_genes(val, X, ts, colors=colors_ac, nxset=5, X2=None, ts2=None)