In [None]:
import anndata
import numpy as np
import pandas as pd
import seaborn as sns
import os

from scroutines.config_plots import *
from scroutines import basicu
import importlib
importlib.reload(basicu)

# check features

In [None]:
# !ls -alhtr ../../data/annot/*

In [None]:
# Gene-list annotation files; each is read below as an array of strings
annot_files = [
    '../../data/annot/Lrr_superfamily.txt',
    '../../data/annot/Igsf_uniprot.txt',
    '../../data/annot/GPCR.txt',
    '../../data/annot/diffgenes_2022RNA-2023Multiome.txt',
    '../../data/annot/CdhSF_interpro.txt',
    '../../data/annot/All_TFs.txt',
]

# annotation name (file name up to its first '.') -> array of gene names
genes_annots = dict(
    (os.path.basename(f).split('.')[0], np.loadtxt(f, dtype=str))
    for f in annot_files
)

# report the size and the first few entries of every list
for key, val in genes_annots.items():
    print(key, len(val), val[:5])

In [None]:
# Directory with the organized .h5ad files
ddir = '../../data/cheng21_cell_scrna/organized/'

# One file per time point, listed in increasing age (P8 ... P38)
files = [f'P{day}NR.h5ad' for day in (8, 14, 17, 21, 28, 38)]

In [None]:
pbulks = []  # per file: Xk_ln returned by basicu.counts_to_bulk_profiles
xclsts = []  # "<rep code>_<subclass>" group labels, required to be identical in every file
xcnsts = []  # per file: number of cells in each group

# Expected layout: one file per condition, 2 replicates x 4 subclasses per file.
# ngene is a -1 placeholder; the gene axis is inferred when reshaping below.
ncond, nrep, nclst, ngene = len(files), 2, 4, -1

subclasses = ['L2/3', 'L4', 'L5IT', 'L6IT']
genes = None

for f in files:
    print(f)

    path = os.path.join(ddir, f)
    # read_h5ad is the explicit reader; anndata.read is its deprecated alias
    adata = anndata.read_h5ad(path)

    # All files must share one gene order, otherwise the stacked profiles
    # would be silently misaligned along the gene axis.
    _genes = adata.var.index.values
    if genes is None:
        genes = _genes
    elif not np.array_equal(_genes, genes):
        raise ValueError(f'gene index of {f} differs from that of {files[0]}')

    adata = adata[adata.obs['Subclass'].isin(subclasses)]
    mat = adata.X
    # type
    types = adata.obs['Subclass'].astype(str).values

    # replicate code: last '_'-separated token of the sample name without its
    # final character, with every '3' rewritten to '2' (so codes are 1 or 2)
    sample_codes = adata.obs['sample'].apply(lambda x: x.split('_')[-1][:-1].replace('3', '2')).astype(str).values
    sample_and_type = sample_codes + "_" + types
    unqs, cnts = np.unique(sample_and_type, return_counts=True)
    _xclsts, Xk, Xk_n, Xk_ln = basicu.counts_to_bulk_profiles(mat, sample_and_type)

    # cnts is ordered like unqs while the profiles are ordered like _xclsts,
    # so the two orders must agree for every file, including the first one.
    if not np.array_equal(_xclsts, unqs):
        raise ValueError(_xclsts.shape, unqs.shape, _xclsts, unqs,)

    # check all types + reps are the same across files
    if len(xclsts) > 0:
        if not np.array_equal(_xclsts, xclsts):
            raise ValueError(_xclsts.shape, xclsts.shape, _xclsts, xclsts,)
    else:
        xclsts = _xclsts
        print(xclsts)

    print(Xk_ln.shape)
    pbulks.append(Xk_ln)
    xcnsts.append(cnts)

pbulks = np.array(pbulks)
xcnsts = np.array(xcnsts)
print(pbulks.shape)
# split the flat group axis into (replicate, subclass); the labels sort
# replicate-first because the replicate code is the label prefix
pbulks = pbulks.reshape(ncond,nrep,nclst,-1)
xcnsts = xcnsts.reshape(ncond,nrep,nclst)
xclsts = xclsts.reshape(      nrep,nclst)
print(pbulks.shape)

In [None]:
# Sanity check: if the profiles are log10(CPM+1), undoing the transform and
# summing over the gene axis must give one million for every sample.
checkpbulks = (np.power(10, pbulks) - 1).sum(axis=-1)
checkpbulks.shape, (np.abs(checkpbulks - 1e6) < 1e-6).all()

In [None]:
# Subclass names: strip the two-character replicate prefix (e.g. "1_") from
# the labels of the first replicate.
xclsts_short = np.array([label[2:] for label in xclsts[0]])

# Cell counts with one row per subclass and one column per (replicate, condition)
numcells = pd.DataFrame(
    xcnsts.T.reshape(-1, nrep*ncond),
    index=xclsts_short,
)

# smallest sample of every subclass, ascending
numcells.min(axis=1).sort_values()

In [None]:
# select cell types: keep subclasses with more than 20 cells in every sample
xclsts_sel = xclsts_short[(numcells.min(axis=1)>20).values]
xclsts_selidx = basicu.get_index_from_array(xclsts_short, xclsts_sel)
X = pbulks[:,:,xclsts_selidx,:]
print(xclsts_sel)

# select genes - mean (across 2 rep) expr of CPM=10 in any subclass at any time
# pbulks is (time, rep, subclass, gene): the replicate axis is 1, so averaging
# over it leaves (time, subclass, gene) and np.any then scans time and subclass.
expressed_any = np.any(np.mean(pbulks, axis=1) > np.log10(10+1), axis=(0,1))
genes_comm = genes[expressed_any]
genes_cidx = np.arange(len(genes))[expressed_any]  # positions of genes_comm within genes
X = X[:,:,:,expressed_any]
print(X.shape)

# reorder (time, rep, subclass, gene) -> (time, subclass, rep, gene)
X = np.swapaxes(X,1,2)
print(X.shape)
nt, nc, nr, ng = X.shape # ntime, nclst, nrep, ngene


In [None]:
# Restrict every annotation list to the genes that passed the expression filter
genes_annots_overlap = {
    key: np.intersect1d(val, genes_comm)
    for key, val in genes_annots.items()
}

# report list size before/after filtering and a few example entries
for key, overlap in genes_annots_overlap.items():
    val = genes_annots[key]
    print(key, len(val), len(overlap))
    print(val[:5], overlap[:5])

In [None]:
def get_2way_eta2_allgenes(nums):
    """
    Two-way sum-of-squares decomposition (eta squared), for all genes at once.

    nums: 4D array (c0, c1, r, g) - levels of condition 0, levels of
          condition 1, replicates, genes

    return (eta2_01, eta2_0, eta2_1) - vectors with one entry per gene
        eta2_01: share of the total sum of squares explained by the
                 (c0, c1) cell means, i.e. everything but replicate noise
        eta2_0 : share explained by condition 0 on its own
        eta2_1 : share explained by condition 1 on its own

    A 1e-10 offset is added to every numerator and to the denominator, so a
    gene with zero total variation returns 1 for all three values.
    Raises ValueError (from the shape unpacking) if nums is not 4D.
    """
    _, _, n_rep, _ = nums.shape

    def _sum_sq(dev):
        # per-gene sum of squares over both condition axes and the replicate axis
        return np.sum(np.power(dev, 2), axis=(0, 1, 2))

    # means kept broadcastable against nums through keepdims
    mean_all = np.mean(nums, axis=(0, 1, 2), keepdims=True)  # grand mean
    mean_c0 = np.mean(nums, axis=(1, 2), keepdims=True)      # per c0 level
    mean_c1 = np.mean(nums, axis=(0, 2), keepdims=True)      # per c1 level
    mean_c01 = np.mean(nums, axis=2, keepdims=True)          # per (c0, c1) cell

    ss_total = _sum_sq(nums - mean_all)
    ss_noise = _sum_sq(nums - mean_c01)  # replicates around their cell mean

    # spread of the cell means within one level of a single condition; the
    # cell means carry a singleton replicate axis, hence the n_rep factor
    ss_within0 = n_rep * _sum_sq(mean_c01 - mean_c0)
    ss_within1 = n_rep * _sum_sq(mean_c01 - mean_c1)

    # total = noise + explained; explained = within_k + main effect of k
    ss_explained = ss_total - ss_noise
    ss_main0 = ss_explained - ss_within0
    ss_main1 = ss_explained - ss_within1

    eps = 1e-10
    denom = ss_total + eps
    return (ss_explained + eps) / denom, (ss_main0 + eps) / denom, (ss_main1 + eps) / denom

# sst, sswr, sswt, sswc = get_2way_eta2_allgenes(X)
# ssexp = sst - sswr
# fig, ax = plt.subplots(figsize=(8,6))
# ax.scatter(sst, sswr, s=1, color='gray')
# ax.plot([0,10], [0,10])
# ax.set_aspect('equal')
# plt.show()

# fig, ax = plt.subplots(figsize=(8,6))
# ax.scatter(ssexp, sswt, s=1, color='gray')
# ax.plot([0,10], [0,10])
# ax.set_aspect('equal')
# plt.show()

# fig, ax = plt.subplots(figsize=(8,6))
# ax.scatter(ssexp, sswc, s=1, color='gray')
# ax.plot([0,10], [0,10])
# ax.set_aspect('equal')
# plt.show()

# fig, ax = plt.subplots(figsize=(8,6))
# ax.scatter(ssexp, sswc+sswt, s=1, color='gray')
# ax.set_aspect('equal')
# ax.plot([0,10], [0,10])
# plt.show()

# fig, ax = plt.subplots(figsize=(8,6))
# g = ax.scatter(1-sswc/ssexp, 1-sswt/ssexp, c=1-sswr/sst, s=1,)
# fig.colorbar(g)
# ax.set_aspect('equal')
# plt.show()

In [None]:
# X is (time, subclass, rep, gene), so condition 0 is time and condition 1 is
# cell type: eta2_tc = joint (time, type), eta2_t = time, eta2_c = type.
eta2_tc, eta2_t, eta2_c = get_2way_eta2_allgenes(X)
# share left unexplained by the (time, type) means, attributed to replicates
eta2_r = 1-eta2_tc

In [None]:
genes_comm.shape, eta2_t.shape, eta2_c.shape, eta2_r.shape
table = pd.DataFrame(np.vstack([genes_comm, eta2_t, eta2_c, eta2_r]), index=['gene', 'time', 'type', 'rep']).T
output = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results/gene_scores_itsubclasses_240308.csv' 
table.to_csv(output)
!head $output

In [None]:
fig, ax = plt.subplots(figsize=(5,6))

# per-gene variance fractions, one box per component or combination;
# the order must match the tick labels below
eta2_groups = [
    eta2_t,
    eta2_c,
    eta2_r,
    eta2_t+eta2_c,
    eta2_tc,
    eta2_t+eta2_c+eta2_r,
]
eta2_labels = ['time', 'type', 'rep', 'time+\ntype', 'time&\ntype', 'time+\ntype+\nrep']

sns.boxplot(eta2_groups, ax=ax)
ax.set_xticklabels(eta2_labels, rotation=0, fontsize=12)
ax.set_ylabel('variance explained by')
plt.show()

In [None]:
# time vs type variance fraction per gene, coloured by the replicate remainder
fig, ax = plt.subplots(figsize=(8,6))
g = ax.scatter(
    eta2_t, eta2_c,
    c=1-eta2_tc,
    s=1,
    cmap='viridis', vmin=0, vmax=1,
)
fig.colorbar(g, shrink=0.5, ticks=[0, 0.5, 1], label='var exp replicates')
ax.set_aspect('equal')
ax.set_xlabel('var exp time')
ax.set_ylabel('var exp type')
plt.show()

In [None]:
from mpl_toolkits.mplot3d import Axes3D

# 3D view of the per-gene fractions: time (x), type (y), replicate remainder (z)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
g = ax.scatter(eta2_t, eta2_c, 1-eta2_tc, s=1)

# all three fractions live on the unit interval
for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
    set_lim([0,1])

ax.set_aspect('equal')
ax.set_xlabel('var exp time')
ax.set_ylabel('var exp type')
plt.show()

In [None]:
import plotly.graph_objects as go

# Per-gene coordinates: time fraction, type fraction, replicate remainder
x, y, z = eta2_t, eta2_c, 1-eta2_tc

# Interactive 3D scatter with one small marker per gene
points = go.Scatter3d(x=x, y=y, z=z, mode='markers', marker={'size': 1})
fig = go.Figure(data=[points])

# Axis titles, figure title and size
fig.update_layout(
    scene={'xaxis_title': 'time', 'yaxis_title': 'type', 'zaxis_title': 'rep'},
    title='time type rep',
    height=800,
    width=1000,
)

# Render inline in the notebook
fig.show()


In [None]:
# One panel per annotation list: all genes in gray, the listed genes on top
n = len(genes_annots_overlap)
fig, axs = plt.subplots(2, 3, figsize=(5*3,6*2))
unit_ticks = [0, 0.5, 1]
for i, (key, val) in enumerate(genes_annots_overlap.items()):
    ax = axs.flat[i]
    val_idx = basicu.get_index_from_array(genes_comm, val)

    g = ax.scatter(eta2_t, eta2_c, s=1, c='lightgray')
    g2 = ax.scatter(eta2_t[val_idx], eta2_c[val_idx], s=1, c='C1', zorder=2)

    # title: list name and how many of its genes survive the expression filter
    ax.set_title(f'{key}\nn={len(val)}/{len(genes_annots[key])}')
    ax.set_xticks(unit_ticks)
    ax.set_yticks(unit_ticks)
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])
    ax.set_aspect('equal')
    sns.despine(ax=ax)
plt.show()

In [None]:
# additive (time + type) fraction against the replicate remainder, per gene
fig, ax = plt.subplots(figsize=(8,6))
g = ax.scatter(eta2_t+eta2_c, 1-eta2_tc, s=1)
ax.set_xlim([0,1])
ax.set_ylim([0,1])
ax.set_aspect('equal')
plt.show()

In [None]:

# Same panels as for time vs type, but in the (time + type, replicate) plane
n = len(genes_annots_overlap)
fig, axs = plt.subplots(2, 3, figsize=(5*3,6*2))

# coordinates are identical for every panel, so compute them once
x = (eta2_t+eta2_c)
y = 1-eta2_tc
unit_ticks = [0, 0.5, 1]

for i, (key, val) in enumerate(genes_annots_overlap.items()):
    ax = axs.flat[i]
    val_idx = basicu.get_index_from_array(genes_comm, val)

    # background: all genes; foreground: genes of this annotation list
    g = ax.scatter(x, y, s=1, c='lightgray')
    g2 = ax.scatter(x[val_idx], y[val_idx], s=1, c='C1')

    ax.set_title(f'{key}\nn={len(val)}/{len(genes_annots[key])}')
    ax.set_xticks(unit_ticks)
    ax.set_yticks(unit_ticks)
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])
    ax.set_aspect('equal')
    sns.despine(ax=ax)
plt.show()

# strong time component - separate by early mid late.

# top 15 genes ranked by eta2_nl (time-by-type term); labelled "type spec" in the plots below

In [None]:
# Plot settings shared by the per-gene panels below.
# NOTE: `types` is rebound here to the selected subclass names.
types = xclsts_sel 
# one tab10 colour per selected subclass
colors = sns.color_palette('tab10', len(types))
# x positions of the six time points (same order as `files`: P8 ... P38)
ts = [8, 14, 17, 21, 28, 38]
types, colors

In [None]:
# share explained by the joint (time, type) means beyond the two separate
# time and type shares, i.e. the non-additive part
eta2_nl = eta2_tc-(eta2_t+eta2_c)

In [None]:
# indices (into the expression-filtered gene axis) of the 15 genes with the largest eta2_nl
# NOTE(review): ranked by the non-additive term, yet later plots label gi_c
# as 'type spec' - confirm whether eta2_c was intended here.
gi_c = np.argsort(eta2_nl)[::-1][:15]
gi_c

In [None]:
# profiles of the genes in gi_c, taken from the filtered tensor X
pbulks_sub = X[..., gi_c]
print(pbulks_sub.shape)
# (nt, nc, nr, ng) -> (ng, nc, nr, nt): gene first, time last, for plotting
pbulks_sub = pbulks_sub.transpose(3, 1, 2, 0)
print(pbulks_sub.shape)
gnames = genes_comm[gi_c]

In [None]:
# One panel per gene: replicate-mean line plus both replicate points per subclass
nrows, ncols = 3, 5
fig, axs = plt.subplots(nrows, ncols, figsize=(5*3,3*4), sharex=True)
for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
    for i in range(nc):
        color, lbl = colors[i], types[i]
        reps = pbulks_g[i]  # (replicate, time) values of subclass i
        ax.plot(ts, np.mean(reps, axis=0), color=color, label=lbl)
        for r in (0, 1):
            ax.scatter(ts, reps[r], s=5, color=color)
    ax.set_title(gname)
    ax.set_xticks(ts)
    sns.despine(ax=ax)

    # legend once, y label on the first column, x label from the second row on
    if j == 0:
        ax.legend()
    if j % ncols == 0:
        ax.set_ylabel('log10(CPM+1)')
    if j >= ncols:
        ax.set_xlabel('P')

fig.tight_layout()
plt.show()

# top 15 time specific genes

In [None]:
# indices of the 15 genes with the largest time share; eta2_t has one entry
# per expression-filtered gene, so these index genes_comm / the last axis of X
gi_t = np.argsort(eta2_t)[::-1][:15]
gi_t

In [None]:
# gi_t indexes the expression-filtered gene axis (eta2_t comes from X), so it
# must be applied to X and genes_comm. Applying it to the unfiltered
# pbulks / genes would pick unrelated genes.
pbulks_sub = X[:,:,:,gi_t]
pbulks_sub = np.swapaxes(pbulks_sub, 0, 3) # nt, nc, nr, ng -> ng, nc, nr, nt
gnames = genes_comm[gi_t]
print(pbulks_sub.shape)

In [None]:
# One panel per gene: replicate-mean line plus both replicate points per subclass
nrows, ncols = 3, 5
fig, axs = plt.subplots(nrows, ncols, figsize=(5*3,3*4), sharex=True)
for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
    for i in range(nc):
        color, lbl = colors[i], types[i]
        reps = pbulks_g[i]  # (replicate, time) values of subclass i
        ax.plot(ts, np.mean(reps, axis=0), color=color, label=lbl)
        for r in (0, 1):
            ax.scatter(ts, reps[r], s=5, color=color)
    ax.set_title(gname)
    ax.set_xticks(ts)
    sns.despine(ax=ax)

    # legend once, y label on the first column, x label from the second row on
    if j == 0:
        ax.legend()
    if j % ncols == 0:
        ax.set_ylabel('log10(CPM+1)')
    if j >= ncols:
        ax.set_xlabel('P')

fig.tight_layout()
plt.show()

In [None]:
# genes explained mostly additively (time + type > 0.8) with both shares
# within 0.2 of 0.5
cond = np.logical_and(eta2_t+eta2_c>0.8, np.abs(eta2_t-0.5)<0.2)
cond = np.logical_and(cond, np.abs(eta2_c-0.5)<0.2)
# cond has one entry per expression-filtered gene, so take its own positions.
# Masking np.arange(len(genes)) mismatches in length whenever genes were filtered out.
gi_ct = np.flatnonzero(cond)
# gi_ct = np.array([g for g in gi_ct if (g not in gi_c and g not in gi_t)])
gi_ct.shape

In [None]:
# time vs type share for all genes, with the three selected gene sets on top
fig, ax = plt.subplots(figsize=(8,6))
ax.scatter(eta2_t, eta2_c, s=1, color='gray')

highlights = [
    (gi_t,  'C0', 'time spec'),
    (gi_c,  'C1', 'type spec'),
    (gi_ct, 'C2', 'type+time spec'),
]
for sel, col, lab in highlights:
    ax.scatter(eta2_t[sel], eta2_c[sel], s=5, color=col, label=lab)

ax.set_aspect('equal')
ax.set_xlabel('var by time')
ax.set_ylabel('var by type')
ax.legend()
# anti-diagonal: time + type shares summing to 1
ax.plot([0,1], [1,0], color='k', linestyle='--')
plt.show()

In [None]:
# gi_ct is derived from eta2_t / eta2_c, which have one entry per
# expression-filtered gene, so it must index X and genes_comm rather than the
# unfiltered pbulks / genes.
pbulks_sub = X[:,:,:,gi_ct]
pbulks_sub = np.swapaxes(pbulks_sub, 0, 3) # nt, nc, nr, ng -> ng, nc, nr, nt
gnames = genes_comm[gi_ct]
# report the array just built (not the loop variable left over from a previous plot)
print(pbulks_sub.shape)

In [None]:
# One panel per gene (at most 4x5 = 20 are shown; zip stops at the shorter input)
fig, axs = plt.subplots(4,5,figsize=(5*3,4*4), sharex=True)
for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
    ax.set_title(gname)
    for i in range(nc):
        color = colors[i]
        lbl = types[i]
        # line: mean over replicates; points: the two individual replicates
        ax.plot(ts, np.mean(pbulks_g[i], axis=0), color=color, label=lbl)
        ax.scatter(ts, pbulks_g[i][0], s=5, color=color)
        ax.scatter(ts, pbulks_g[i][1], s=5, color=color)
    ax.set_xticks(ts)
    sns.despine(ax=ax)
    if j == 0:
        ax.legend()
    if j % 5 == 0:
        ax.set_ylabel('log10(CPM+1)')
    if j >= 5:
        # x axis is the postnatal day in ts, labelled 'P' as in the other panels
        ax.set_xlabel('P')
    
fig.tight_layout()
plt.show()

In [None]:
# Look up three genes by name in the full `genes` array.
# NOTE(review): the results are used below to index `genes` and `pbulks`, so
# they are presumably positions in the unfiltered gene axis - they are not
# positions in genes_comm / eta2_* (confirm get_index_from_array semantics).
gi1 = basicu.get_index_from_array(genes, ['Slc17a7'])
gi2 = basicu.get_index_from_array(genes, ['Cux2'])
gi3 = basicu.get_index_from_array(genes, ['Cdh13'])

gi1, gi2, gi3

In [None]:
# time vs type share for all genes, the selected gene sets, and three named genes
fig, ax = plt.subplots(figsize=(8,6))
ax.scatter(eta2_t, eta2_c, s=1, color='lightgray')
ax.scatter(eta2_t[gi_t], eta2_c[gi_t], s=5, color='C0', label='time spec')
ax.scatter(eta2_t[gi_c], eta2_c[gi_c], s=5, color='C1', label='type spec')
ax.scatter(eta2_t[gi_ct], eta2_c[gi_ct], s=5, color='C2', label='type+time spec')

# gi1..gi3 were looked up in the full `genes` array, while eta2_t / eta2_c
# have one entry per expression-filtered gene. genes_cidx holds the position
# in `genes` of every filtered gene, so match against it to translate.
# A gene that did not pass the filter yields an empty selection.
named_genes = [(gi1, 'C3', 'Slc17a7'), (gi2, 'C4', 'Cux2'), (gi3, 'C5', 'Cdh13')]
for gi_full, col, lab in named_genes:
    gi_comm = np.flatnonzero(np.isin(genes_cidx, gi_full))
    ax.scatter(eta2_t[gi_comm], eta2_c[gi_comm], s=20, color=col, label=lab)

ax.set_aspect('equal')
ax.set_xlabel('var by time')
ax.set_ylabel('var by type')
ax.legend()
# anti-diagonal: time + type shares summing to 1
ax.plot([0,1], [1,0], color='k', linestyle='--')
plt.show()

In [None]:
# One single-panel figure per named gene, taken from the unfiltered tensor
# (gi1..gi3 index the full `genes` array).
for gi in [gi1, gi2, gi3]:
    # keep only the selected subclasses so that cluster i matches
    # types[i] / colors[i], which describe the selected subclasses
    pbulks_sub = pbulks[:,:,xclsts_selidx,:][:,:,:,gi]
    pbulks_sub = np.swapaxes(pbulks_sub, 0, 3) # nt, nr, nc, ng -> ng, nr, nc, nt
    pbulks_sub = np.swapaxes(pbulks_sub, 1, 2) # ng, nr, nc, nt -> ng, nc, nr, nt
    gnames = genes[gi]
    print(pbulks_sub.shape)

    fig, axs = plt.subplots(1,1,figsize=(1*5,1*4), sharex=True)
    # a single axis is available, so only the first gene of gi is drawn
    for pbulks_g, gname, ax in zip(pbulks_sub, gnames, [axs]):
        ax.set_title(gname)
        for i in range(nc):
            color = colors[i]
            lbl = types[i]
            # line: mean over replicates; points: the two individual replicates
            ax.plot(ts, np.mean(pbulks_g[i], axis=0), color=color, label=lbl)
            ax.scatter(ts, pbulks_g[i][0], s=5, color=color)
            ax.scatter(ts, pbulks_g[i][1], s=5, color=color)
        ax.set_xticks(ts)
        sns.despine(ax=ax)
        ax.legend(bbox_to_anchor=(1,1))
        ax.set_ylabel('log10(CPM+1)')
        # x axis is the postnatal day in ts; the old label sat behind a
        # multi-panel index test that a single panel never satisfies
        ax.set_xlabel('P')

    fig.tight_layout()
    plt.show()

In [None]:
# genes with little additive time/type share (< 0.2) and almost no replicate share (< 0.01)
cond = np.logical_and(eta2_t+eta2_c<0.2, eta2_r<0.01)
# cond has one entry per expression-filtered gene, so take its own positions.
# Masking np.arange(len(genes)) mismatches in length whenever genes were filtered out.
gi_nct = np.flatnonzero(cond)
gi_nct.shape

In [None]:
# replicate share of the genes selected in gi_nct
eta2_r[gi_nct]

In [None]:
# time vs type share for all genes, highlighting the gi_nct selection
fig, ax = plt.subplots(figsize=(8,6))
layers = [
    (slice(None), dict(s=1, color='gray')),
    (gi_nct,      dict(s=5, color='C2', label='non type+time spec')),
]
for sel, style in layers:
    ax.scatter(eta2_t[sel], eta2_c[sel], **style)

ax.set_aspect('equal')
ax.set_xlabel('var by time')
ax.set_ylabel('var by type')
ax.legend()
# anti-diagonal: time + type shares summing to 1
ax.plot([0,1], [1,0], color='k', linestyle='--')
plt.show()

In [None]:
# gi_nct is derived from eta2_t / eta2_c / eta2_r, which have one entry per
# expression-filtered gene, so it must index X and genes_comm rather than the
# unfiltered pbulks / genes.
pbulks_sub = X[:,:,:,gi_nct]
pbulks_sub = np.swapaxes(pbulks_sub, 0, 3) # nt, nc, nr, ng -> ng, nc, nr, nt
gnames = genes_comm[gi_nct]
print(pbulks_sub.shape)

In [None]:
# One panel per gene (at most 4x5 = 20 are shown; zip stops at the shorter input)
fig, axs = plt.subplots(4,5,figsize=(5*3,4*4), sharex=True)
for j, (pbulks_g, gname, ax) in enumerate(zip(pbulks_sub, gnames, axs.flat)):
    ax.set_title(gname)
    for i in range(nc):
        color = colors[i]
        lbl = types[i]
        # line: mean over replicates; points: the two individual replicates
        ax.plot(ts, np.mean(pbulks_g[i], axis=0), color=color, label=lbl)
        ax.scatter(ts, pbulks_g[i][0], s=5, color=color)
        ax.scatter(ts, pbulks_g[i][1], s=5, color=color)
    ax.set_xticks(ts)
    sns.despine(ax=ax)
    if j == 0:
        ax.legend()
    if j % 5 == 0:
        ax.set_ylabel('log10(CPM+1)')
    if j >= 5:
        # x axis is the postnatal day in ts, labelled 'P' as in the other panels
        ax.set_xlabel('P')
    
fig.tight_layout()
plt.show()

# Compare with Saumya's results

In [None]:
# Externally computed per-gene scores, tab-separated, indexed by the 'gene' column
f = '../../data/annot/specificity-dynamicity_scores.txt'

saumya = pd.read_csv(f, sep='\t', index_col='gene')
saumya

In [None]:
# genes present both in the external table and in the expression-filtered set
genes_smy_overlap = np.intersect1d(saumya.index.values, genes_comm)
# their indices looked up in genes_comm (used below to index eta2_t / eta2_c)
genes_smy_overlap_idx = basicu.get_index_from_array(genes_comm, genes_smy_overlap)
# external scores restricted to the shared genes, in the same (sorted) order
saumya_overlap = saumya.loc[genes_smy_overlap]
print(len(genes_comm), len(saumya), len(genes_smy_overlap))

In [None]:
# inspect the lookup result for the shared genes
genes_smy_overlap_idx

In [None]:
# Three comparisons on the shared genes: the two external scores against each
# other, then each external score against the matching eta2 share.
# Tuple layout: x values, y values, x label, y label, equal aspect?
comparisons = [
    (saumya_overlap['temporal-dynamicity'], saumya_overlap['subclass-specificity'],
     'temporal-dynamicity', 'subclass-specificity', True),
    (saumya_overlap['temporal-dynamicity'], eta2_t[genes_smy_overlap_idx],
     'temporal-dynamicity', 'var exp time', False),
    (saumya_overlap['subclass-specificity'], eta2_c[genes_smy_overlap_idx],
     'subclass-specificity', 'var exp type', False),
]

for xvals, yvals, xlabel, ylabel, equal_aspect in comparisons:
    fig, ax = plt.subplots()
    ax.scatter(xvals, yvals, s=1)
    if equal_aspect:
        ax.set_aspect('equal')
    # label the axes so each figure can be read on its own
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    plt.show()