In [None]:
import anndata
import numpy as np
import pandas as pd
import seaborn as sns
import os

from scroutines.config_plots import *
from scroutines import basicu
import importlib
importlib.reload(basicu)

# Build pseudobulk profiles (L2/3 types by time point and replicate)

In [None]:
ddir = '../../data/cheng21_cell_scrna/organized/'
# One .h5ad per time point; the P8NR and P14NR files are not used in this analysis.
files = [f'P{age}NR.h5ad' for age in (17, 21, 28, 38)]

In [None]:
pbulks = []
xclsts = []
xcnsts = []

ncond, nrep, nclst, ngene = len(files), 2, 3, -1

genes = None
for f in files:
    print(f)

    path = os.path.join(ddir, f)
    # read_h5ad is the supported reader (anndata.read is a deprecated alias)
    adata = anndata.read_h5ad(path)
    # all conditions are stacked into one array below, so the gene order must be
    # identical in every file (otherwise columns would be silently misaligned)
    if genes is not None and not np.array_equal(adata.var.index.values, genes):
        raise ValueError(f"{f}: genes differ from those in {files[0]}")
    genes = adata.var.index.values

    adata = adata[adata.obs['Type'].isin(['L2/3_A', 'L2/3_B', 'L2/3_C',])]
    mat = adata.X
    # type
    types = adata.obs['Type'].astype(str).values

    # replicate code (expected '1' or '2'): last '_' token of the sample name without
    # its final character, with '3' mapped to '2'
    sample_codes = adata.obs['sample'].apply(lambda x: x.split('_')[-1][:-1].replace('3', '2')).astype(str).values
    sample_and_type = sample_codes + "_" + types
    unqs, cnts = np.unique(sample_and_type, return_counts=True)
    _xclsts, Xk, Xk_n, Xk_ln = basicu.counts_to_bulk_profiles(mat, sample_and_type)

    # cnts follows np.unique's sorted label order; it is aligned with the profiles only if
    # counts_to_bulk_profiles returns the same labels in the same order -- check every file
    if not np.array_equal(_xclsts, unqs):
        raise ValueError(f"{f}: profile labels {_xclsts} do not match counted labels {unqs}")
    # check all types + reps are the same as in the first file
    if len(xclsts) == 0:
        xclsts = _xclsts
        print(xclsts)
    elif not np.array_equal(_xclsts, xclsts):
        raise ValueError(f"{f}: labels {_xclsts} differ from {files[0]}: {xclsts}")

    print(Xk_ln.shape)
    pbulks.append(Xk_ln)
    xcnsts.append(cnts)

pbulks = np.array(pbulks)
xcnsts = np.array(xcnsts)
print(pbulks.shape)
# labels are sorted "<rep>_<type>" strings (rep-major), so axis 1 splits into (rep, type)
pbulks = pbulks.reshape(ncond,nrep,nclst,-1)
xcnsts = xcnsts.reshape(ncond,nrep,nclst)
xclsts = xclsts.reshape(      nrep,nclst)
print(pbulks.shape)

In [None]:
# check the pbulks are good -- log10(CPM+1) for each sample:
# undoing the transform, every profile (summed over the last, gene axis) should total 1e6.
checkpbulks = np.sum(np.power(10, pbulks)-1, axis=-1)
# Relative tolerance: an absolute 1e-6 on sums of ~1e6 is below float rounding error
# (float32 spacing near 1e6 is 0.0625), so correctly normalized data could fail the check.
checkpbulks.shape, np.allclose(checkpbulks, 1e6, rtol=1e-4, atol=0)

In [None]:
# cell types: strip the "<rep>_" prefix from the rep-0 labels to get bare type names
xclsts_short = np.array([clst.split('_', 1)[1] for clst in xclsts[0]])
# cells per (type, sample): xcnsts.T is (type, rep, cond), flattened rep-major, so the
# columns are labelled (rep code, file stem) in that same order
numcells = pd.DataFrame(
    xcnsts.T.reshape(-1, nrep*ncond),
    index=xclsts_short,
    columns=pd.MultiIndex.from_product(
        [[clst.split('_', 1)[0] for clst in xclsts[:, 0]], [f.split('.')[0] for f in files]],
        names=['rep', 'cond'],
    ),
)
numcells.min(axis=1).sort_values()

In [None]:
# select cell types: keep types with >20 cells in every (rep, cond) sample
xclsts_sel = xclsts_short[(numcells.min(axis=1) > 20).to_numpy()]
xclsts_selidx = basicu.get_index_from_array(xclsts_short, xclsts_sel)
X = pbulks[:,:,xclsts_selidx,:]
print(xclsts_sel)

# select genes - mean (across reps) log10(CPM+1) above log10(10+1), i.e. CPM=10,
# in any selected subclass at any time.
# X is (cond, rep, type, gene): average over the rep axis (1), then any() over the remaining
# (cond, type) axes. Averaging over axis 2 would instead average across cell types.
expressed_any = np.any(np.mean(X, axis=1) > np.log10(10+1), axis=(0,1))
genes_comm = genes[expressed_any]
genes_cidx = np.flatnonzero(expressed_any)
X = X[:,:,:,expressed_any]
print(X.shape)

# reorder: (cond, rep, type, gene) -> (cond, type, rep, gene)
X = np.swapaxes(X,1,2)
print(X.shape)
nt, nc, nr, ng = X.shape # ntime, nclst, nrep, ngene


# Load annotations

In [None]:
# !ls -alhtr ../../data/annot/*

In [None]:
annot_files = [
    '../../data/annot/Lrr_superfamily.txt',
    '../../data/annot/Igsf_uniprot.txt',
    '../../data/annot/GPCR.txt',
    '../../data/annot/diffgenes_2022RNA-2023Multiome.txt',
    '../../data/annot/CdhSF_interpro.txt',
    '../../data/annot/All_TFs.txt',
]

# key = file name up to the first '.'; ndmin=1 keeps a one-gene file as a 1-D array
# (loadtxt would otherwise return a 0-d array, and len()/slicing below would fail)
genes_annots = {os.path.basename(f).split('.')[0]: np.loadtxt(f, dtype=str, ndmin=1) for f in annot_files}
for key, val in genes_annots.items():
    print(key, len(val), val[:5])

In [None]:
# Gene families defined by gene-symbol prefix, matched against every gene in the data.
# Dict order fixes the order in which families are added (and summarized below).
# 'Ntn' deliberately matches both Ntn* and Ntng* symbols.
family_prefixes = {
    'Netrin':     ('Ntn',),
    'Unc5':       ('Unc5',),
    'Semaphorin': ('Sema',),
    'Slit':       ('Slit',),
    'EphEphrin':  ('Epha', 'Ephb', 'Efn'),
    'Fgf':        ('Fgf',),
    'Grm':        ('Grm',),
}
for family, prefixes in family_prefixes.items():
    # str.startswith accepts a tuple: True if the symbol starts with any of the prefixes
    mdl = np.sort([g for g in genes if g.startswith(prefixes)])
    print(mdl)
    genes_annots[family] = mdl

# restrict every annotation set to genes that passed the expression filter (genes_comm)
genes_annots_overlap = {}
for key, val in genes_annots.items():
    overlap = np.sort(np.intersect1d(val, genes_comm))
    genes_annots_overlap[key] = overlap
    print(key, len(val), len(overlap))
    print(val[:5], overlap[:5])

# Compute per-gene variance decomposition (eta^2 for time, type, and replicates)

In [None]:
def get_2way_eta2_allgenes(nums, eps=1e-10):
    """
    Per-gene two-way (balanced design) variance decomposition, reported as eta^2.

    nums: array of shape (nc0, nc1, nr, ng) -- levels of factor 0, levels of
          factor 1, replicates per (factor 0, factor 1) cell, genes.
    eps:  small constant added to each numerator and to the denominator so a
          gene with zero total variance does not give 0/0 (such a gene gets
          all three fractions ~1).

    return (eta2_01, eta2_0, eta2_1) - vectors with one entry per gene:
        eta2_01 -- fraction of total variance explained by the (c0, c1) cell
                   means (both main effects + interaction); 1 - eta2_01 is the
                   within-cell (replicate) fraction
        eta2_0  -- fraction explained by the main effect of factor 0
        eta2_1  -- fraction explained by the main effect of factor 1

    raises ValueError if nums is not 4-D.
    """
    nums = np.asarray(nums)
    if nums.ndim != 4:
        raise ValueError(f"nums must be 4-D (nc0, nc1, nr, ng); got shape {nums.shape}")
    nc0, nc1, nr, ng = nums.shape # (num cond0, cond1, num rep, num genes)

    rm   = np.mean(nums, axis=(0,1,2)) # grand mean per gene; reduced form
    rm0  = np.mean(nums, axis=(1,2))   # mean per c0 level (over c1 and reps)
    rm1  = np.mean(nums, axis=(0,2))   # mean per c1 level (over c0 and reps)
    rm01 = np.mean(nums, axis=(2))     # cell mean per (c0, c1), over reps

    # expanded forms, broadcastable against nums
    em   = np.expand_dims(rm  , axis=(0,1,2))
    em0  = np.expand_dims(rm0 , axis=(1,2))
    em1  = np.expand_dims(rm1 , axis=(0,2))
    em01 = np.expand_dims(rm01, axis=(2))

    # SSt: total sum of squares around the grand mean
    SSt  = np.sum(np.power(nums-em, 2),   axis=(0,1,2))

    # SSwr (noise): replicates around their (c0, c1) cell mean
    SSwr = np.sum(np.power(nums-em01, 2), axis=(0,1,2))

    # SSw0: cell means around their c0-level mean (= c1 main effect + interaction)
    # SSw1: cell means around their c1-level mean (= c0 main effect + interaction)
    # scaled by nr because each cell mean stands for nr observations
    SSw0 = nr*np.sum(np.power(em01-em0, 2),  axis=(0,1,2))
    SSw1 = nr*np.sum(np.power(em01-em1, 2),  axis=(0,1,2))

    # SSt = SSwr + SSexp, and in a balanced design SSexp = SSexp0 + SSw0 = SSexp1 + SSw1,
    # so SSexp0 / SSexp1 are the c0 / c1 main-effect sums of squares
    SSexp  = SSt   - SSwr
    SSexp0 = SSexp - SSw0
    SSexp1 = SSexp - SSw1

    eta2_01 = (SSexp +eps)/(SSt+eps)
    eta2_0  = (SSexp0+eps)/(SSt+eps)
    eta2_1  = (SSexp1+eps)/(SSt+eps)

    return eta2_01, eta2_0, eta2_1


In [None]:
# per-gene variance fractions; X is (time, type, rep, gene), so factor 0 = time, factor 1 = type
(eta2_tc,   # time & type jointly (all variance of the cell means)
 eta2_t,    # time main effect
 eta2_c,    # type main effect
 ) = get_2way_eta2_allgenes(X)
eta2_r = 1.0 - eta2_tc  # remainder: replicate-to-replicate variance

In [None]:
fig, ax = plt.subplots(figsize=(5, 6))
# distribution across genes of each variance fraction (and some sums of them)
sns.boxplot(
    [
        eta2_t,                    # time
        eta2_c,                    # type
        eta2_r,                    # rep
        eta2_t + eta2_c,           # time + type main effects
        eta2_tc,                   # time & type incl. interaction
        eta2_t + eta2_c + eta2_r,  # time + type + rep
    ],
    ax=ax,
)
ax.set_xticklabels(
    ['time', 'type', 'rep', 'time+\ntype', 'time&\ntype', 'time+\ntype+\nrep'],
    rotation=0,
    fontsize=12,
)
ax.set_ylabel('variance explained by')
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
# one dot per gene; colour = share of variance left to replicates
sc = ax.scatter(eta2_t, eta2_c, c=1 - eta2_tc, s=1, cmap='viridis', vmin=0, vmax=1)
fig.colorbar(sc, shrink=0.5, ticks=[0, 0.5, 1], label='var exp replicates')
ax.set(aspect='equal', xlabel='var exp time', ylabel='var exp type')
plt.show()

In [None]:
import plotly.graph_objects as go

# Per-gene variance fractions (one point per expressed gene, not random data):
# time share, type share, and replicate share (1 - eta2_tc)
x = eta2_t
y = eta2_c
z = 1-eta2_tc

# Create a 3D scatter plot; hovering a point shows its gene name
# (eta2_* entries follow the gene order of genes_comm)
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z, text=genes_comm, mode='markers', marker=dict(size=1))])

# Update layout
fig.update_layout(scene=dict(
                    xaxis_title='time',
                    yaxis_title='type',
                    zaxis_title='rep'),
                  title='time type rep',
                  height=800,
                  width=1000,
                 )

# Display the plot in the Jupyter notebook
fig.show()


In [None]:
# One panel per annotation set: all expressed genes in grey, the set's members in 'C1'.
# The grid is sized from the number of sets (3 per row); a fixed 5x3 grid raises
# IndexError past 15 sets and leaves empty frames when there are fewer.
n = len(genes_annots_overlap)
ncols = 3
nrows = max(1, -(-n // ncols))  # ceil(n / ncols), at least one row
fig, axs = plt.subplots(nrows, ncols, figsize=(5*ncols, 6*nrows), squeeze=False)
for i, (key, val) in enumerate(genes_annots_overlap.items()):
    ax = axs.flat[i]
    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0, 0.5, 1])
    g = ax.scatter(eta2_t, eta2_c, s=1, c='lightgray')

    ax.set_title(f'{key}\nn={len(val)}/{len(genes_annots[key])}')
    val_idx = basicu.get_index_from_array(genes_comm, val)
    g2 = ax.scatter(eta2_t[val_idx], eta2_c[val_idx], s=1, c='C1', zorder=2)
    sns.despine(ax=ax)
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])
    ax.set_aspect('equal')
# blank out grid slots that have no annotation set
for ax in axs.flat[n:]:
    ax.axis('off')
plt.show()

# Inspect selected genes and gene families

In [None]:
def plot_query_gene_landscape(key, val):
    """Show where the genes in `val` sit on the time-vs-type eta^2 map.

    Every gene in `genes_comm` is drawn in light grey; the query genes are
    drawn on top in colour 'C1' and labelled with their names. `key` is used
    as the figure title. Reads the notebook globals genes_comm, eta2_t, eta2_c.
    """
    idx = basicu.get_index_from_array(genes_comm, val)
    qx, qy = eta2_t[idx], eta2_c[idx]

    fig, ax = plt.subplots(1, 1, figsize=(5, 6))
    ax.set_title(key)
    ax.set_xticks([0, 0.5, 1])
    ax.set_yticks([0, 0.5, 1])

    # background: all expressed genes
    ax.scatter(eta2_t, eta2_c, s=1, c='lightgray')
    # foreground: the query genes, each labelled by name
    ax.scatter(qx, qy, s=1, c='C1', zorder=2)
    for name, px, py in zip(val, qx, qy):
        ax.text(px, py, name, fontsize=10)

    sns.despine(ax=ax)
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_xlabel('time')
    ax.set_ylabel('type')
    ax.set_aspect('equal')
    plt.show()

In [None]:
def plot_query_genes(query_genes, ts=None):
    """
    Plot per-gene time courses of log10(CPM+1), one panel per gene.

    In each panel every selected cell type (xclsts_sel) gets a line for the
    mean across replicates plus one dot series per replicate.

    query_genes: gene names, looked up in genes_comm.
    ts: x positions, one per entry of X's first (time) axis; defaults to
        [17, 21, 28, 38], the time points of the loaded files.

    Reads the notebook globals X, genes_comm and xclsts_sel.
    Raises ValueError if len(ts) does not match X's time axis.
    """
    types = xclsts_sel
    colors = sns.color_palette('tab10', len(types))
    if ts is None:
        ts = [17, 21, 28, 38]  # [8, 14, 17, 21, 28, 38] if P8/P14 are loaded too
    if len(ts) != X.shape[0]:
        raise ValueError(f'{len(ts)} time points given but X has {X.shape[0]}')

    query_gis = basicu.get_index_from_array(genes_comm, query_genes)
    pbulks_sub = X[:,:,:,query_gis]
    pbulks_sub = np.swapaxes(pbulks_sub, 0, 3) # nt, nc, nr, ng -> ng, nc, nr, nt
    gnames = genes_comm[query_gis]

    n = len(gnames)
    if n == 0:
        print('no query genes to plot')
        return
    nx = min(n, 5)
    ny = -(-n // nx)  # ceil(n / nx)

    # squeeze=False: always a 2-D array of axes, so a single gene (1x1 grid) still has .flat
    fig, axs = plt.subplots(ny, nx, figsize=(nx*3, ny*4), sharex=True, squeeze=False)
    for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
        ax.set_title(gname)
        # pbulks_g is (type, rep, time) for this gene
        for type_vals, lbl, color in zip(pbulks_g, types, colors):
            ax.plot(ts, np.mean(type_vals, axis=0), color=color, label=lbl)
            for rep_vals in type_vals:  # every replicate, not a hard-coded two
                ax.scatter(ts, rep_vals, s=5, color=color)

        sns.despine(ax=ax)
        ax.grid(False)

        ax.set_xticks(ts)
        ax.tick_params(axis='both', which='major', labelsize=10)

        if j % nx == 0:  # first panel of each row
            ax.set_ylabel('log10(CPM+1)')

    # hide unused grid slots; with sharex only the bottom row shows x tick labels,
    # so re-enable them on the panel above each hidden slot
    for k in range(n, nx * ny):
        axs.flat[k].axis('off')
        axs.flat[k - nx].tick_params(axis='x', labelbottom=True)

    axs.flat[0].legend()
    fig.tight_layout()
    plt.show()

In [None]:
key = 'Rfx3'
# a hand-picked gene list, not an annotation set (key is only used as the plot title)
val = ['Rfx3', 'Npas4']
# Sets in genes_annots_overlap are already restricted to genes_comm, but a hand-written list
# may name genes that failed the expression filter; report and drop them explicitly rather
# than relying on how basicu.get_index_from_array treats missing names (not visible here).
in_comm = np.isin(val, genes_comm)
if not in_comm.all():
    print('not in genes_comm (dropped):', [g for g, ok in zip(val, in_comm) if not ok])
val = [g for g, ok in zip(val, in_comm) if ok]
print(val)

plot_query_gene_landscape(key, val)
plot_query_genes(val)

In [None]:
# Same two views for each prefix-defined gene family: where the family sits on the
# time-vs-type eta^2 map, then per-gene time courses for each selected cell type.
for key in ['Netrin', 'Grm', 'Unc5', 'Semaphorin', 'Slit', 'EphEphrin', 'Fgf']:
    val = genes_annots_overlap[key]
    print(val)

    plot_query_gene_landscape(key, val)
    plot_query_genes(val)