# analysis V4 - new basis code - Dec 6, 2023
- cleaned up the data QC and organization a bit
- keep two AnnData objects (and their spin-offs) covering both the FISH gene data and the projection data
- flip the y-axis when plotting (the data themselves are not flipped)
- focus on L2/3 cells

TODO: 
- separate and organize the plotting functions
- organize the plots and generate more insights

In [None]:
import numpy as np
import pandas as pd
import os 
import matplotlib.pyplot as plt
import umap
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import matplotlib

import anndata 

from scroutines import config_plots
from scroutines import basicu
from scroutines import miscu

In [None]:
import sys
# Raise Python's recursion limit; presumably needed because scipy's
# hierarchy.dendrogram() (used below) recurses over the large cell linkage tree.
sys.setrecursionlimit(10000)
from scipy.cluster import hierarchy as sch

In [None]:
def norm_data(adata):
    """Normalize the counts of ``adata`` in place and store them as layers.

    All statistics are computed per gene (column), independently of the
    other genes:

    1. ``layers['nrm']``: counts divided by the cell's size factor
       ``area / median(area)``.
    2. ``layers['log']``: ``log2(1 + nrm)``.
    3. ``layers['zsc']``: ``log`` z-scored per gene across cells. Genes with
       zero variance (e.g. no counts at all) get a z-score of 0 instead of
       the NaN that 0/0 would produce.

    Parameters
    ----------
    adata : AnnData-like
        Needs ``.X`` (cells x genes counts), ``.obs['area']`` and a writable
        ``.layers`` mapping. Min/median/max cell size are printed.

    Returns
    -------
    None
    """
    # size; area is reported as um^3, and its cube root as a length in um
    area = adata.obs['area']
    med_size = area.median()
    max_size = area.max()
    min_size = area.min()

    print(f"Min cell size {min_size:.1f} um^3\t  {np.power(min_size,1/3):.1f} um")
    print(f"Med cell size {med_size:.1f} um^3\t  {np.power(med_size,1/3):.1f} um")
    print(f"Max cell size {max_size:.1f} um^3\t  {np.power(max_size,1/3):.1f} um")

    size_factor = (area / med_size).values

    # norm by size; by log2+1; by zscore
    mat_raw = np.array(adata.X)
    mat_nrm = mat_raw / size_factor.reshape(-1, 1)
    mat_log = np.log2(1 + mat_nrm)

    gene_mean = np.mean(mat_log, axis=0)
    gene_std = np.std(mat_log, axis=0)
    # Constant genes: the centred values are all 0, so dividing by 1 yields
    # z = 0 rather than 0/0 = NaN (which would poison clustering/UMAP later).
    gene_std = np.where(gene_std > 0, gene_std, 1.0)
    mat_zsc = (mat_log - gene_mean) / gene_std

    adata.layers['nrm'] = mat_nrm
    adata.layers['log'] = mat_log
    adata.layers['zsc'] = mat_zsc

    return

In [None]:
# Directory holding the processed outputs read below (roi.csv, spotcount.csv).
pth_dat = '/u/home/f/f7xiesnm/project-zipursky/easifish/lt172/proc/r12345v3/'
# IPython shell escape: list the directory contents (not plain Python syntax).
!ls $pth_dat

In [None]:
# features
# Channel id -> gene/probe name. The ids become the AnnData var index;
# 'r4v3_c1' ('RL Cre') is the projection channel listed in proj_idx below,
# every other channel is a gene.
var_names = {
    'r1v3_c0': 'Sorcs3',
    'r1v3_c1': 'Kcnq5',
    'r1v3_c2': 'Chrm2',
    'r1v3_c4': 'Adamts2',
    
    'r2v3_c0': 'Kcnip3',
    'r2v3_c1': 'Rorb',
    'r2v3_c2': 'Cdh13',
    'r2v3_c4': 'Cntn5',
    
    'r3v3_c0': 'Cdh12',
    'r3v3_c1': 'Gria3',
    'r3v3_c2': 'Cntnap2',
    'r3v3_c4': 'Gabrg3',
    
    'r4v3_c0': 'Kcnh5',
    'r4v3_c1': 'RL Cre',
    'r4v3_c2': 'Slc17a7',
    'r4v3_c4': 'Grm8',
    
    'r5v3_c0': 'Ncam2',
    'r5v3_c1': 'Rfx3',
    'r5v3_c2': 'Epha10',
    'r5v3_c4': 'Baz1a',
}
proj_idx = np.array(['r4v3_c1'])

# cells
# Per-ROI metadata (index = ROI id; needs at least x, y, z, area downstream).
f_meta = os.path.join(pth_dat, 'roi.csv')

# spots
# Per-ROI spot counts; columns presumably one per channel in var_names order.
f_spot = os.path.join(pth_dat, 'spotcount.csv')

In [None]:
# Gene pseudotime annotation from the scRNA-seq analysis (per the file name:
# P28, L2/3). Previous annotation file, kept for reference:
# f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot_v2.csv'
f = '/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/results/gene_ptime_P28_L23_Mar27.tsv'
# The file has a .tsv extension, but the default read_csv separator is ','.
# sep=None lets pandas sniff the delimiter, so tab- and comma-separated
# versions of the file both parse into 'gene' / 'gene_ptime' columns.
df_annot = pd.read_csv(f, sep=None, engine='python').sort_values('gene_ptime')

In [None]:
# Per-channel feature table: gene name, projection-channel flag, and the
# gene's pseudotime from df_annot (NaN for channels absent there, e.g. 'RL Cre').
var = pd.Series(var_names).to_frame('name')
var['proj'] = var.index.isin(proj_idx)
ptime_by_gene = df_annot.set_index('gene')['gene_ptime']
var['ptime'] = ptime_by_gene.reindex(var['name']).values

# Lookups between channel ids and gene names.
var_idx = var.index.values.astype(str)
var_i2n = var['name'] 
var_n2i = var.reset_index().set_index('name')['index']

# Gene channels only (everything that is not a projection channel), in file order.
gene_idx = var_idx[~np.isin(var_idx, proj_idx)]

In [None]:
# Gene names sorted by scRNA-seq pseudotime (NaN ptime, e.g. 'RL Cre', sorts last).
var_order = var.sort_values('ptime')['name']

# Hand-curated channel order used for the "gradient" panels and cluster-mean
# heatmaps; blank lines separate visual groups. Projection channel is last.
var_order_manual = pd.Series({
    'r2v3_c2': 'Cdh13',
    'r1v3_c4': 'Adamts2',
    'r5v3_c1': 'Rfx3',
    'r2v3_c4': 'Cntn5',
    
    'r3v3_c4': 'Gabrg3',
    'r4v3_c4': 'Grm8',
    
    'r2v3_c0': 'Kcnip3',
    'r5v3_c4': 'Baz1a',
    'r1v3_c0': 'Sorcs3',
    'r1v3_c1': 'Kcnq5',
    'r3v3_c2': 'Cntnap2',
    
    'r3v3_c0': 'Cdh12',
    'r5v3_c2': 'Epha10',
    'r4v3_c0': 'Kcnh5',
    
    'r5v3_c0': 'Ncam2',
    'r1v3_c2': 'Chrm2',
    'r2v3_c1': 'Rorb',
    
    'r3v3_c1': 'Gria3',
    'r4v3_c2': 'Slc17a7',
    'r4v3_c1': 'RL Cre',
    
})


print(var_idx)
print(var_order)
print(var_order_manual)

In [None]:
# Column names for the wide per-layer tables, e.g. 'r1v3_c0_nrm':
# *_var_idx cover every channel, *_gene_idx only the gene channels.
_layer_suffixes = ('_raw', '_nrm', '_log', '_zsc')
raw_var_idx, nrm_var_idx, log_var_idx, zsc_var_idx = (
    np.char.add(var_idx, sfx) for sfx in _layer_suffixes
)
raw_gene_idx, nrm_gene_idx, log_gene_idx, zsc_gene_idx = (
    np.char.add(gene_idx, sfx) for sfx in _layer_suffixes
)

In [None]:
# Load ROI metadata and per-channel spot counts. Both tables must list the
# same ROIs in the same order, because spot.values is later used as X.
meta = pd.read_csv(f_meta, index_col=0)
spot = pd.read_csv(f_spot, index_col=0)
if not meta.index.equals(spot.index):
    raise ValueError("roi.csv and spotcount.csv do not list the same ROIs in the same order")
print(meta.shape, spot.shape)

# Bounding box of all ROI centres.
min_x, min_y, min_z = meta[['x', 'y', 'z']].min()
max_x, max_y, max_z = meta[['x', 'y', 'z']].max()
print(f'x: {min_x:.1f}\t{max_x:.1f}')
print(f'y: {min_y:.1f}\t{max_y:.1f}')
print(f'z: {min_z:.1f}\t{max_z:.1f}')

# Distance from each ROI to the nearest face of the bounding box.
# np.minimum only takes two operands (a third positional argument is `out`),
# so reduce over all six distances at once to include the z faces.
meta['to_edge'] = np.min([
    meta['x'] - min_x, max_x - meta['x'],
    meta['y'] - min_y, max_y - meta['y'],
    meta['z'] - min_z, max_z - meta['z'],
], axis=0)
meta['cov'] = spot.sum(axis=1)  # total spots per ROI over all channels

# bin data into 8 or 4 equal bins over a fixed 0-400 range;
# NOTE(review): coordinates outside 0-400 become NaN bins — confirm the range.
bins_8p = np.linspace(0,400,8+1).astype(int)
bins_4p = np.linspace(0,400,4+1).astype(int)
print(bins_8p, bins_4p)

meta['xb_8p'] = pd.cut(meta['x'], bins=bins_8p)
meta['yb_8p'] = pd.cut(meta['y'], bins=bins_8p)
meta['zb_8p'] = pd.cut(meta['z'], bins=bins_8p)

meta['xb_4p'] = pd.cut(meta['x'], bins=bins_4p)
meta['yb_4p'] = pd.cut(meta['y'], bins=bins_4p)
meta['zb_4p'] = pd.cut(meta['z'], bins=bins_4p)

In [None]:
# Full AnnData: all ROIs x all channels (genes + projection channel).
# NOTE(review): X is taken positionally from spot.values, so spotcount.csv's
# column order must match var_names' order; column names are not checked here.
adata = anndata.AnnData(X=spot.values, obs=meta, var=var)
adata

In [None]:
# QC filter (adata -> adata2): keep ROIs with area > 500 that lie more than
# 20 (coordinate units) inside every face of the bounding box. The running
# number of kept ROIs is printed after each criterion.
df = adata.obs
margin = 20
conds = [
    df['area'] > 500,
    df['x'] > min_x + margin,
    df['x'] < max_x - margin,
    df['y'] > min_y + margin,
    df['y'] < max_y - margin,
    df['z'] > min_z + margin,
    df['z'] < max_z - margin,
]
cond_all = pd.Series(True, index=df.index)
for cond in conds:
    cond_all = cond_all & cond
    print(cond_all.sum())
print(f"Num cells before and after: {len(df)} -> {cond_all.sum()}")

adata2 = adata[cond_all].copy()
adata2

In [None]:
# adata2 -> adata3: keep only gene channels in X; the projection channel's
# counts move into obs, and obs['cov_gene'] is the total over gene channels.
adata3 = adata2[:,gene_idx].copy()
proj_counts = np.array(adata2[:, proj_idx].X)
adata3.obs[proj_idx] = proj_counts
gene_totals = adata3.X.sum(axis=1)
adata3.obs['cov_gene'] = np.array(gene_totals)

In [None]:
def _obs_with_nrm(ad, nrm_cols):
    """Normalize ``ad`` in place; return a copy of its obs with 'nrm' columns appended."""
    norm_data(ad)
    table = ad.obs.copy()
    table[nrm_cols] = np.array(ad.layers['nrm'])
    return table


# Wide tables (obs + size-normalized counts) for plotting:
# df_p2 covers all channels, df_p3 only the gene channels.
df_p2 = _obs_with_nrm(adata2, nrm_var_idx)
df_p3 = _obs_with_nrm(adata3, nrm_gene_idx)

# report 

In [None]:
# Per-channel summary on the unfiltered data: id, gene, % ROIs with any
# spots, then min / median / 99th percentile / max counts.
for channel in var_idx:
    counts = adata[:, channel].X[:, 0]
    pct_detected = 100 * np.sum(counts > 0) / len(counts)
    summary = (np.min(counts), np.median(counts), np.percentile(counts, 99), np.max(counts))
    stats_txt = '\t'.join(f'{v:.1f}' for v in summary)
    print(f'{channel}\t{var_i2n.loc[channel]}\t{pct_detected:.2f}%\t' + stats_txt)

In [None]:
# Histograms of ROI position and size on the unfiltered data, one row each.
cols = ['x', 'y', 'z', 'area']
with sns.plotting_context('paper'):
    fig, axs = plt.subplots(4, 1, figsize=(6, 8))
    for row, col in enumerate(cols):
        ax = axs[row]
        sns.histplot(adata.obs[col], ax=ax)
        ax.set_xlabel(col)
    fig.subplots_adjust(hspace=0.5)
    plt.show()


In [None]:
# Spot coverage per ROI vs ROI size, all channels ('cov') vs gene channels
# only ('cov_gene'); first on linear axes, then on log2 axes. The two series
# are overlaid, so label them (they were previously indistinguishable).
obs3 = adata3.obs
sns.scatterplot(data=obs3, x='area', y='cov', s=2, edgecolor='none', label='cov (all channels)')
sns.scatterplot(data=obs3, x='area', y='cov_gene', s=2, edgecolor='none', label='cov_gene (genes only)')
plt.ylabel('spot count')
plt.legend(markerscale=4)
plt.show()

# log2(0) is -inf (with a RuntimeWarning); such ROIs cannot be drawn.
plt.scatter(np.log2(obs3['area']),
            np.log2(obs3['cov']),
            s=2, edgecolor='none', label='cov (all channels)',
           )
plt.scatter(np.log2(obs3['area']),
            np.log2(obs3['cov_gene']),
            s=2, edgecolor='none', label='cov_gene (genes only)',
           )
plt.xlabel('log2 area')
plt.ylabel('log2 spot count')
plt.legend(markerscale=4)
plt.show()

In [None]:
# Per-channel count distributions (unfiltered): full range labelled by channel
# id on top, and zoomed to 0-50 counts labelled by gene name below.
panel_specs = (
    (adata.var.index.values, None),
    (adata.var['name'].values, [0, 50]),
)
with sns.plotting_context('paper'):
    fig, axs = plt.subplots(2, 1, figsize=(10, 8))
    for ax, (tick_labels, ylim) in zip(axs, panel_specs):
        sns.boxplot(data=adata.X, ax=ax)
        ax.set_xticklabels(tick_labels, rotation=90)
        sns.despine(ax=ax)
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.set_ylabel('counts')
        ax.set_xlabel('Genes')

# z-sectioning visuals
- bin cells along one axis (e.g. z) into 4 bins
- plot each bin as a separate panel

In [None]:
# Make the project's plotting helpers (plotting_easifish, one directory up)
# importable, and reload the module so edits to it are picked up without
# restarting the kernel.
sys.path.insert(0, '../')
import plotting_easifish

import importlib
importlib.reload(plotting_easifish)

from plotting_easifish import view_z_sections
from plotting_easifish import view_z_sections_4panels
from plotting_easifish import view_z_sections_labels
from plotting_easifish import gen_discrete_colors



In [None]:
# x-y views of every channel's size-normalized counts, split into 4 z bins.
sp_x, sp_y, sp_z = 'x', 'y', 'zb_4p'
for col in var_idx:
    view_z_sections_4panels(
        df_p2, sp_x, sp_y,
        col=f'{col}_nrm', sp_z=sp_z, cmap='Greys', title=var_i2n[col], flip_y=True,
    )

In [None]:
# Same per-channel views in the other two orientations:
# x-z split into 4 y bins, then z-y split into 4 x bins.
for sp_x, sp_y, sp_z in (('x', 'z', 'yb_4p'), ('z', 'y', 'xb_4p')):
    for col in var_idx:
        view_z_sections_4panels(
            df_p2, sp_x, sp_y,
            col=f'{col}_nrm', sp_z=sp_z, cmap='Greys', title=var_i2n[col], flip_y=True,
        )

In [None]:
# x-y maps of every channel (file order), one panel per channel.
# Colour = size-normalized count / channel's 95th percentile, so the scale
# saturates at 1; vmin=-0.1 puts 0 slightly above the lightest colour.
sp_x, sp_y = 'x', 'y'
fig, axs = plt.subplots(5,4,figsize=(4*4,4*5), sharey=True, sharex=True)
cbar_ax = fig.add_axes([0.92, 0.5, 0.01, 0.2])
axs.flat[0].invert_yaxis()  # y is shared, so this flips every panel
x = df_p2[sp_x].values
y = df_p2[sp_y].values
for col, ax in zip(var_idx, axs.flat):
    c = df_p2[col+'_nrm'].values
    vmax = np.percentile(c, 95)
    if vmax <= 0:
        # Sparse channel: 95th percentile is 0 and c/0 would give NaN/inf.
        # Scale by the max instead (1 for an all-zero channel).
        vmax = c.max() if c.max() > 0 else 1.0

    ax.axis('off')
    ax.grid(False)
    g = ax.scatter(x, y, c=c/vmax, s=5, edgecolor='none', cmap='Greys', vmax=1, vmin=-0.1)
    sns.despine(ax=ax)
    ax.set_title(var_i2n[col])
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')

# cbar_ax fixes the colorbar geometry (shrink/aspect are ignored with cax).
fig.colorbar(g, cax=cbar_ax, label='Normed counts\n(0-95 perctl.)', ticks=[0, 1])
fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# z-y maps of every channel (file order), one panel per channel.
# Colour = size-normalized count / channel's 95th percentile, so the scale
# saturates at 1; vmin=-0.1 puts 0 slightly above the lightest colour.
sp_x, sp_y = 'z', 'y'
fig, axs = plt.subplots(5,4,figsize=(4*4,4*5), sharey=True, sharex=True)
cbar_ax = fig.add_axes([0.92, 0.5, 0.01, 0.2])
axs.flat[0].invert_yaxis()  # y is shared, so this flips every panel
x = df_p2[sp_x].values
y = df_p2[sp_y].values
for col, ax in zip(var_idx, axs.flat):
    c = df_p2[col+'_nrm'].values
    vmax = np.percentile(c, 95)
    if vmax <= 0:
        # Sparse channel: 95th percentile is 0 and c/0 would give NaN/inf.
        # Scale by the max instead (1 for an all-zero channel).
        vmax = c.max() if c.max() > 0 else 1.0

    ax.axis('off')
    ax.grid(False)
    g = ax.scatter(x, y, c=c/vmax, s=5, edgecolor='none', cmap='Greys', vmax=1, vmin=-0.1)
    sns.despine(ax=ax)
    ax.set_title(var_i2n[col])
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')

# cbar_ax fixes the colorbar geometry (shrink/aspect are ignored with cax).
fig.colorbar(g, cax=cbar_ax, label='Normed counts\n(0-95 perctl.)', ticks=[0, 1])
fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# x-z maps of every channel (file order), one panel per channel.
# Colour = size-normalized count / channel's 95th percentile, so the scale
# saturates at 1; vmin=-0.1 puts 0 slightly above the lightest colour.
sp_x, sp_y = 'x', 'z'
fig, axs = plt.subplots(5,4,figsize=(4*4,4*5), sharey=True, sharex=True)
cbar_ax = fig.add_axes([0.92, 0.5, 0.01, 0.2])
# Vertical axis (z here) is shared, so this flips every panel.
# NOTE(review): confirm flipping z (not only y) is intended in this view.
axs.flat[0].invert_yaxis()
x = df_p2[sp_x].values
y = df_p2[sp_y].values
for col, ax in zip(var_idx, axs.flat):
    c = df_p2[col+'_nrm'].values
    vmax = np.percentile(c, 95)
    if vmax <= 0:
        # Sparse channel: 95th percentile is 0 and c/0 would give NaN/inf.
        # Scale by the max instead (1 for an all-zero channel).
        vmax = c.max() if c.max() > 0 else 1.0

    ax.axis('off')
    ax.grid(False)
    g = ax.scatter(x, y, c=c/vmax, s=5, edgecolor='none', cmap='Greys', vmax=1, vmin=-0.1)
    sns.despine(ax=ax)
    ax.set_title(var_i2n[col])
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')

# cbar_ax fixes the colorbar geometry (shrink/aspect are ignored with cax).
fig.colorbar(g, cax=cbar_ax, label='Normed counts\n(0-95 perctl.)', ticks=[0, 1])
fig.subplots_adjust(hspace=0.1, wspace=0.02)

# reorder genes along the expression gradient (manual order, var_order_manual)

In [None]:
# x-y maps of every channel in the manual gradient order (var_order_manual).
# Colour = size-normalized count / channel's 95th percentile, so the scale
# saturates at 1; vmin=-0.1 puts 0 slightly above the lightest colour.
sp_x, sp_y = 'x', 'y'
fig, axs = plt.subplots(5,4,figsize=(4*4,4*5), sharey=True, sharex=True)
cbar_ax = fig.add_axes([0.92, 0.5, 0.01, 0.2])
axs.flat[0].invert_yaxis()  # y is shared, so this flips every panel
x = df_p2[sp_x].values
y = df_p2[sp_y].values
for col, ax in zip(var_order_manual.index.values, axs.flat):
    c = df_p2[col+'_nrm'].values
    vmax = np.percentile(c, 95)
    if vmax <= 0:
        # Sparse channel: 95th percentile is 0 and c/0 would give NaN/inf.
        # Scale by the max instead (1 for an all-zero channel).
        vmax = c.max() if c.max() > 0 else 1.0

    ax.axis('off')
    ax.grid(False)
    g = ax.scatter(x, y, c=c/vmax, s=5, edgecolor='none', cmap='Greys', vmax=1, vmin=-0.1)
    sns.despine(ax=ax)
    ax.set_title(var_i2n[col])
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')

# cbar_ax fixes the colorbar geometry (shrink/aspect are ignored with cax).
fig.colorbar(g, cax=cbar_ax, label='Normed counts\n(0-95 perctl.)', ticks=[0, 1])
fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# z-y maps of every channel in the manual gradient order (var_order_manual).
# Colour = size-normalized count / channel's 95th percentile, so the scale
# saturates at 1; vmin=-0.1 puts 0 slightly above the lightest colour.
sp_x, sp_y = 'z', 'y'
fig, axs = plt.subplots(5,4,figsize=(4*4,4*5), sharey=True, sharex=True)
cbar_ax = fig.add_axes([0.92, 0.5, 0.01, 0.2])
axs.flat[0].invert_yaxis()  # y is shared, so this flips every panel
x = df_p2[sp_x].values
y = df_p2[sp_y].values
for col, ax in zip(var_order_manual.index.values, axs.flat):
    c = df_p2[col+'_nrm'].values
    vmax = np.percentile(c, 95)
    if vmax <= 0:
        # Sparse channel: 95th percentile is 0 and c/0 would give NaN/inf.
        # Scale by the max instead (1 for an all-zero channel).
        vmax = c.max() if c.max() > 0 else 1.0

    ax.axis('off')
    ax.grid(False)
    g = ax.scatter(x, y, c=c/vmax, s=5, edgecolor='none', cmap='Greys', vmax=1, vmin=-0.1)
    sns.despine(ax=ax)
    ax.set_title(var_i2n[col])
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')

# cbar_ax fixes the colorbar geometry (shrink/aspect are ignored with cax).
fig.colorbar(g, cax=cbar_ax, label='Normed counts\n(0-95 perctl.)', ticks=[0, 1])
fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# x-z maps of every channel in the manual gradient order (var_order_manual).
# Colour = size-normalized count / channel's 95th percentile, so the scale
# saturates at 1; vmin=-0.1 puts 0 slightly above the lightest colour.
sp_x, sp_y = 'x', 'z'
fig, axs = plt.subplots(5,4,figsize=(4*4,4*5), sharey=True, sharex=True)
cbar_ax = fig.add_axes([0.92, 0.5, 0.01, 0.2])
# Vertical axis (z here) is shared, so this flips every panel.
# NOTE(review): confirm flipping z (not only y) is intended in this view.
axs.flat[0].invert_yaxis()
x = df_p2[sp_x].values
y = df_p2[sp_y].values
for col, ax in zip(var_order_manual.index.values, axs.flat):
    c = df_p2[col+'_nrm'].values
    vmax = np.percentile(c, 95)
    if vmax <= 0:
        # Sparse channel: 95th percentile is 0 and c/0 would give NaN/inf.
        # Scale by the max instead (1 for an all-zero channel).
        vmax = c.max() if c.max() > 0 else 1.0

    ax.axis('off')
    ax.grid(False)
    g = ax.scatter(x, y, c=c/vmax, s=5, edgecolor='none', cmap='Greys', vmax=1, vmin=-0.1)
    sns.despine(ax=ax)
    ax.set_title(var_i2n[col])
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')

# cbar_ax fixes the colorbar geometry (shrink/aspect are ignored with cax).
fig.colorbar(g, cax=cbar_ax, label='Normed counts\n(0-95 perctl.)', ticks=[0, 1])
fig.subplots_adjust(hspace=0.1, wspace=0.02)

# hierarchical clustering

In [None]:
# Clustering inputs: z-scored log2(1 + size-normalized) gene counts
# (cells x gene channels), plus the projection channel counts from obs.
mat_ftrs = pd.DataFrame(adata3.layers['zsc'], index=adata3.obs.index, columns=adata3.var.index)
mat_prjs = adata3.obs[proj_idx] 

In [None]:
# Inspect the feature matrix (cells x gene channels).
mat_ftrs

In [None]:
# 2-D UMAP embedding of the z-scored gene features.
# NOTE(review): no random_state is set, so the embedding differs between runs.
from umap import UMAP 
ucs = UMAP(n_components=2).fit_transform(mat_ftrs)
ucs.shape

In [None]:
# Colour the UMAP by one marker's z-score, clipped to its 5th-95th percentiles.
# Other markers tried: 'Cdh13', 'Sorcs3'.
marker = 'Slc17a7'
g = mat_ftrs[var_n2i[marker]]
vmin, vmax = np.percentile(g, [5, 95])
p = plt.scatter(
    ucs[:, 0], ucs[:, 1], s=1,
    c=g, vmin=vmin, vmax=vmax, cmap='rocket_r',
)
plt.colorbar(p)

In [None]:


# Ward / Euclidean hierarchical clustering of the adata3 z-scores.
# Genes (columns): only the leaf order is kept, to arrange heatmap columns.
col_link = sch.linkage(mat_ftrs.T, method='ward', metric='euclidean')
col_order = sch.dendrogram(col_link, no_plot=True)['leaves']

# Cells (rows): leaf order for the heatmaps plus flat clusters cut at three
# heights, coarse (l1) to fine (l3).
row_link = sch.linkage(mat_ftrs, method='ward', metric='euclidean')
row_order = sch.dendrogram(row_link, p=30, truncate_mode='none', no_plot=True)['leaves']

cut_heights = (100, 50, 30)
clsts1, clsts2, clsts3 = [sch.fcluster(row_link, h, criterion='distance') for h in cut_heights]
cluster_levels = (clsts1, clsts2, clsts3)
for labels in cluster_levels:
    print(labels.shape, np.unique(labels).shape)

# Register the labels on every table that shares adata3's cell order.
for level, labels in enumerate(cluster_levels, start=1):
    for table in (adata3.obs, adata2.obs, df_p3, df_p2):
        table[f'clst_l{level}'] = labels

# One discrete palette/colormap per level, sized by its number of clusters.
(palette1, dismap1), (palette2, dismap2), (palette3, dismap3) = [
    gen_discrete_colors(len(np.unique(labels))) for labels in cluster_levels
]

In [None]:
# Clustered heatmap (adata3). Layout: A = cells x genes z-scores (rows in
# dendrogram leaf order, columns in gene dendrogram order); B-D = cluster
# labels at levels l1-l3; E = projection channel counts; F = y coordinate.
fig = plt.figure(figsize=(10,15))
axd = fig.subplot_mosaic("A"*12+"BCDEF")
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
xticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<gene> <channel id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order],
    yticklabels=False,
    xticklabels=xticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, location='bottom', label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
ax.set_ylabel('Cells')
# ax.set_xticklabels()

for i, (ax, col) in enumerate(zip(
    (axd["B"], axd["C"], axd["D"]),
    ('clst_l1', 'clst_l2', 'clst_l3'),
    )):
    vec = adata3.obs[col].values
    sns.heatmap(
        adata3.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order],
        yticklabels=False,
        cmap='jet',
        cbar_kws=dict(location='bottom', aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')
    # Cumulative cluster sizes in label order = row where each cluster ends.
    # Assumes labels increase along row_order (fcluster on the same linkage)
    # — NOTE(review): confirm.
    positions = np.cumsum(np.unique(vec, return_counts=True)[1])
    ax.hlines(positions, xmin=-0.2, xmax=1, linewidth=1, linestyle='-', color='k', clip_on=False)
    for i, pos in enumerate(positions):
        ax.text(0, pos, i+1, color='white', fontsize=15)

ax = axd["E"]
col_idx = mat_prjs.columns.values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
xticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)
g = sns.heatmap(
    mat_prjs.iloc[row_order],
    yticklabels=False,
    xticklabels=xticklabels,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, ),
    ax=ax,
)
ax.tick_params(rotation=90)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # Set fontsize to 10 points

ax = axd["F"]
g = sns.heatmap(
    adata3.obs[['y']].iloc[row_order], 
    yticklabels=False,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, label='um'),
    # vmin=0,
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10, rotation=90)  # Set fontsize to 10 points


In [None]:
# x-y maps of the three clustering levels side by side (adata3 cells).
# Colours are looked up by cluster label (fcluster labels start at 1).
# NOTE(review): confirm gen_discrete_colors' palette has an entry per label.
sp_x, sp_y = 'x', 'y'
fig, axs = plt.subplots(1, 3, figsize=(12, 5), sharey=True, sharex=True)
axs.flat[0].invert_yaxis()  # y is shared, so this flips all panels
x = adata3.obs[sp_x].values
y = adata3.obs[sp_y].values
level_cols = ['clst_l1', 'clst_l2', 'clst_l3']
level_palettes = [palette1, palette2, palette3]
level_dismaps = [dismap1, dismap2, dismap3]
for ax, col, palette, dismap in zip(axs.flat, level_cols, level_palettes, level_dismaps):
    c = adata3.obs[col].values
    ax.grid(False)
    g = ax.scatter(x, y, c=pd.Series(palette)[c], s=5, edgecolor='none')
    sns.despine(ax=ax)
    ax.set_title(col)
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')
    ax.axis('off')
    fig.colorbar(dismap, ax=ax, orientation='horizontal', shrink=0.5, aspect=10)

fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# One x-y panel per L2 cluster: the cluster's cells in colour, others grey.
sp_x, sp_y = 'x', 'y'
col = 'clst_l2'
palette = palette2
dismap = dismap2
labels_present = np.sort(adata3.obs[col].unique())
n = len(labels_present)
nx = min(4,n)
ny = -(-n // nx)  # ceil(n / nx)
suptitle = f'L2 ({n} clusters.)'

# Manual panel order (set to None for label order 1..n). It belongs to one
# particular clustering result, so refuse to plot if the labels changed.
clst_order_manual = None
clst_order_manual = [6,8,9,7, 3,1,2,5, 4,10,11]
if clst_order_manual is not None:
    if sorted(clst_order_manual) != labels_present.tolist():
        raise ValueError(
            f"clst_order_manual {clst_order_manual} does not match the current "
            f"{col} labels {labels_present.tolist()}; update it or set it to None"
        )
    clst_order = clst_order_manual
else:
    clst_order = np.arange(n)+1

# squeeze=False keeps axs 2-D even for a single panel.
fig, axs = plt.subplots(ny, nx, figsize=(3*nx,3*ny), sharey=True, sharex=True, squeeze=False)
axs.flat[0].invert_yaxis()  # y is shared, so this flips every panel
x = adata3.obs[sp_x].values
y = adata3.obs[sp_y].values
c = adata3.obs[col].values
for i, ax in zip(clst_order, axs.flat):
    ax.grid(False)
    g = ax.scatter(x, y, c=[palette[_c] if i==_c else 'lightgray' for _c in c], s=5, edgecolor='none',)

    ax.set_title(f"C{i}")
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')
    ax.axis('off')

# Hide the grid cells left over after the n cluster panels.
for ax in axs.flat[n:]:
    ax.axis('off')

fig.suptitle(suptitle)
fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# Mean size-normalized expression per L2 cluster (genes in the manual
# gradient order), z-scored across clusters; shown as genes x clusters.
group = 'clst_l2'
clst_order_manual = [6,8,9,7, 3,1,2,5, 4,10,11]  # set to None for label order

ordered_nrm_cols = np.char.add(var_order_manual.index.values.astype(str), '_nrm')
cluster_means = df_p2.groupby(group)[ordered_nrm_cols].mean()
dfmean = basicu.zscore(cluster_means, axis=0)
if clst_order_manual is not None:
    dfmean = dfmean.reindex(clst_order_manual)
yticklabels = var_order_manual.values

fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(
    dfmean.T, ax=ax,
    yticklabels=yticklabels,
    cmap='coolwarm',
    cbar_kws=dict(label='zscore mean exp.', shrink=0.5, aspect=10),
)

# remove outlier clusters and redo the clustering

In [None]:
# Drop the L2 clusters judged to be outliers (-> v2 objects).
remove_these = [4, 10, 11]
adata2v2, adata3v2 = (
    ad[~ad.obs['clst_l2'].isin(remove_these)].copy() for ad in (adata2, adata3)
)
print(adata2v2.shape, adata3v2.shape)

In [None]:
# Wide tables for the v2 cells. The 'nrm' layer is the one computed on the
# full adata2/adata3 and merely subset here (norm_data is not re-run for v2).
v2_tables = []
for ad, nrm_cols in ((adata2v2, nrm_var_idx), (adata3v2, nrm_gene_idx)):
    tbl = ad.obs.copy()
    tbl[nrm_cols] = np.array(ad.layers['nrm'])
    v2_tables.append(tbl)
df_p2v2, df_p3v2 = v2_tables

In [None]:
# # normalize the data and record
# norm_data(adata2v2)
# norm_data(adata3v2)

In [None]:
# Re-cluster the v2 cells. The 'zsc' layer comes from norm_data on all adata3
# cells and is only subset here, so columns are not re-standardized.
mat_ftrs = pd.DataFrame(adata3v2.layers['zsc'], index=adata3v2.obs.index, columns=adata3v2.var.index)
mat_prjs = adata3v2.obs[proj_idx]

# Ward / Euclidean linkage on genes (leaf order for heatmap columns) ...
col_link = sch.linkage(mat_ftrs.T, method='ward', metric='euclidean')
col_order = sch.dendrogram(col_link, no_plot=True)['leaves']

# ... and on cells (leaf order + flat clusters at three cut heights).
row_link = sch.linkage(mat_ftrs, method='ward', metric='euclidean')
row_order = sch.dendrogram(row_link, p=30, truncate_mode='none', no_plot=True)['leaves']

cut_heights = (100, 50, 30)
clsts1, clsts2, clsts3 = [sch.fcluster(row_link, h, criterion='distance') for h in cut_heights]
cluster_levels = (clsts1, clsts2, clsts3)
for labels in cluster_levels:
    print(labels.shape, np.unique(labels).shape)

# Register the labels on every v2 table (all share adata3v2's cell order).
for level, labels in enumerate(cluster_levels, start=1):
    for table in (adata3v2.obs, adata2v2.obs, df_p3v2, df_p2v2):
        table[f'clst_l{level}'] = labels

(palette1, dismap1), (palette2, dismap2), (palette3, dismap3) = [
    gen_discrete_colors(len(np.unique(labels))) for labels in cluster_levels
]

In [None]:
# Clustered heatmap for the v2 cells. Layout: A = cells x genes z-scores
# (rows in dendrogram leaf order, columns in gene dendrogram order);
# B-D = cluster labels l1-l3; E = projection channel; F = y coordinate.
fig = plt.figure(figsize=(10,15))
axd = fig.subplot_mosaic("A"*12+"BCDEF")
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
xticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<gene> <channel id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order],
    yticklabels=False,
    xticklabels=xticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, location='bottom', label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
ax.set_ylabel('Cells')

for ax, col in zip((axd["B"], axd["C"], axd["D"]), ('clst_l1', 'clst_l2', 'clst_l3')):
    vec = adata3v2.obs[col].values
    sns.heatmap(
        adata3v2.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order],
        yticklabels=False,
        cmap='jet',
        cbar_kws=dict(location='bottom', aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')
    # Cumulative cluster sizes in label order = row where each cluster ends.
    # Assumes labels increase along row_order (fcluster on the same linkage)
    # — NOTE(review): confirm.
    positions = np.cumsum(np.unique(vec, return_counts=True)[1])
    ax.hlines(positions, xmin=-0.2, xmax=1, linewidth=1, linestyle='-', color='k', clip_on=False)
    for k, pos in enumerate(positions):
        ax.text(0, pos, k+1, color='white', fontsize=15)

# Projection panel, labelled like the genes (as in the first heatmap).
ax = axd["E"]
prj_cols = mat_prjs.columns.values.astype(str)
prj_labels = np.char.add(np.char.add(var_i2n[prj_cols].values.astype(str), ' '), prj_cols)
g = sns.heatmap(
    mat_prjs.iloc[row_order],
    yticklabels=False,
    xticklabels=prj_labels,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, ),
    ax=ax,
)
ax.tick_params(rotation=90)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # tick label fontsize 10 pt

ax = axd["F"]
g = sns.heatmap(
    adata3v2.obs[['y']].iloc[row_order],
    yticklabels=False,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, label='um'),
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10, rotation=90)  # tick label fontsize 10 pt


In [None]:
# x-y maps of the three v2 clustering levels side by side.
# Colours are looked up by cluster label (fcluster labels start at 1).
# NOTE(review): confirm gen_discrete_colors' palette has an entry per label.
sp_x, sp_y = 'x', 'y'
fig, axs = plt.subplots(1, 3, figsize=(12, 5), sharey=True, sharex=True)
axs.flat[0].invert_yaxis()  # y is shared, so this flips all panels
x = adata3v2.obs[sp_x].values
y = adata3v2.obs[sp_y].values
level_cols = ['clst_l1', 'clst_l2', 'clst_l3']
level_palettes = [palette1, palette2, palette3]
level_dismaps = [dismap1, dismap2, dismap3]
for ax, col, palette, dismap in zip(axs.flat, level_cols, level_palettes, level_dismaps):
    c = adata3v2.obs[col].values
    ax.grid(False)
    g = ax.scatter(x, y, c=pd.Series(palette)[c], s=5, edgecolor='none')
    sns.despine(ax=ax)
    ax.set_title(col)
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')
    ax.axis('off')
    fig.colorbar(dismap, ax=ax, orientation='horizontal', shrink=0.5, aspect=10)

fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# One x-y panel per v2 L2 cluster: the cluster's cells in colour, others grey.
sp_x, sp_y = 'x', 'y'
col = 'clst_l2'
palette = palette2
dismap = dismap2
labels_present = np.sort(df_p3v2[col].unique())
n = len(labels_present)
nx = min(4,n)
ny = -(-n // nx)  # ceil(n / nx)
suptitle = f'L2 ({n} clusters.)'

# Manual panel order (set to None for label order 1..n). It belongs to one
# particular clustering result, so refuse to plot if the labels changed.
clst_order_manual = None
clst_order_manual = [3,1,4,2,5,7,8,6]
if clst_order_manual is not None:
    if sorted(clst_order_manual) != labels_present.tolist():
        raise ValueError(
            f"clst_order_manual {clst_order_manual} does not match the current "
            f"{col} labels {labels_present.tolist()}; update it or set it to None"
        )
    clst_order = clst_order_manual
else:
    clst_order = np.arange(n)+1

# squeeze=False keeps axs 2-D even for a single panel.
fig, axs = plt.subplots(ny, nx, figsize=(3*nx,3*ny), sharey=True, sharex=True, squeeze=False)
axs.flat[0].invert_yaxis()  # y is shared, so this flips every panel
x = adata3v2.obs[sp_x].values
y = adata3v2.obs[sp_y].values
c = adata3v2.obs[col].values
for i, ax in zip(clst_order, axs.flat):
    ax.grid(False)
    g = ax.scatter(x, y, c=[palette[_c] if i==_c else 'lightgray' for _c in c], s=5, edgecolor='none',)

    ax.set_title(f"C{i}")
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')
    ax.axis('off')

# Hide the grid cells left over after the n cluster panels.
for ax in axs.flat[n:]:
    ax.axis('off')

fig.suptitle(suptitle)
fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# Mean size-normalized expression per v2 L2 cluster (genes in the manual
# gradient order), z-scored across clusters; shown as genes x clusters.
group = 'clst_l2'
# Cluster order for the v2 clustering. It is set explicitly so the result does
# not depend on which notebook cell last assigned clst_order_manual (the v1
# order [6,8,9,7, 3,1,2,5, 4,10,11] does not fit the v2 labels).
# Set to None to keep label order.
clst_order_manual = [3,1,4,2,5,7,8,6]

dfmean = df_p2v2.groupby(group)[np.char.add(var_order_manual.index.values.astype(str), '_nrm')].mean()
dfmean = basicu.zscore(dfmean, axis=0)
if clst_order_manual is not None:
    dfmean = dfmean.reindex(clst_order_manual)
yticklabels = var_order_manual.values


fig, ax = plt.subplots(figsize=(6,6))
sns.heatmap(dfmean.T, ax=ax, 
            yticklabels=yticklabels, 
            cmap='coolwarm', cbar_kws=dict(label='zscore mean exp.', shrink=0.5, aspect=10))

# order cells by y-axis

In [None]:
# Alternative heatmap row order: v2 cells sorted by their y coordinate.
row_order = adata3v2.obs['y'].argsort().values
row_order

In [None]:
# Heatmap of the v2 cells with rows sorted by y (row_order from the cell
# above) instead of by the dendrogram; columns keep the gene dendrogram order.
# Cluster boundary lines are omitted because clusters are generally not
# contiguous in this row order.
fig = plt.figure(figsize=(10,15))
axd = fig.subplot_mosaic("A"*12+"BCDEF")
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
xticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<gene> <channel id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order],
    yticklabels=False,
    xticklabels=xticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, location='bottom', label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
ax.set_ylabel('Cells')

for ax, col in zip((axd["B"], axd["C"], axd["D"]), ('clst_l1', 'clst_l2', 'clst_l3')):
    sns.heatmap(
        adata3v2.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order],
        yticklabels=False,
        cmap='jet',
        cbar_kws=dict(location='bottom', aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')

# Projection panel, labelled like the genes (as in the first heatmap).
ax = axd["E"]
prj_cols = mat_prjs.columns.values.astype(str)
prj_labels = np.char.add(np.char.add(var_i2n[prj_cols].values.astype(str), ' '), prj_cols)
g = sns.heatmap(
    mat_prjs.iloc[row_order],
    yticklabels=False,
    xticklabels=prj_labels,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, ),
    ax=ax,
)
ax.tick_params(rotation=90)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # tick label fontsize 10 pt

ax = axd["F"]
g = sns.heatmap(
    adata3v2.obs[['y']].iloc[row_order],
    yticklabels=False,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, label='um'),
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10, rotation=90)  # tick label fontsize 10 pt

In [None]:
# Per-section x-y maps of the v2 L2 cluster labels.
sp_x, sp_y, col, palette = 'x', 'y', 'clst_l2', palette2
view_z_sections_labels(
    df_p3v2, sp_x, sp_y, col,
    palette=palette, title=None, flip_y=True,
)

# remove and redo again - L2/3 only

In [None]:
# Optional further removal of L2 clusters (-> v3 objects); currently none
# (previously tried: [8, 6]).
remove_these = []
adata2v3, adata3v3 = (
    ad[~ad.obs['clst_l2'].isin(remove_these)].copy() for ad in (adata2v2, adata3v2)
)
print(adata2v3.shape, adata3v3.shape)

In [None]:
# Re-run normalization on the retained cells (rewrites the nrm/log/zsc layers).
for ad in (adata2v3, adata3v3):
    norm_data(ad)

In [None]:
# Wide tables (obs + size-normalized counts) for the v3 cells.
v3_tables = {}
for key, ad, nrm_cols in (('p2', adata2v3, nrm_var_idx), ('p3', adata3v3, nrm_gene_idx)):
    tbl = ad.obs.copy()
    tbl[nrm_cols] = np.array(ad.layers['nrm'])
    v3_tables[key] = tbl
df_p2v3, df_p3v3 = v3_tables['p2'], v3_tables['p3']

In [None]:
# Re-cluster the v3 cells using the 'zsc' layer as currently set on adata3v3.
mat_ftrs = pd.DataFrame(adata3v3.layers['zsc'], index=adata3v3.obs.index, columns=adata3v3.var.index)
mat_prjs = adata3v3.obs[proj_idx]

# Ward / Euclidean linkage on genes (leaf order for heatmap columns) ...
col_link = sch.linkage(mat_ftrs.T, method='ward', metric='euclidean')
col_order = sch.dendrogram(col_link, no_plot=True)['leaves']

# ... and on cells (leaf order + flat clusters at three cut heights).
row_link = sch.linkage(mat_ftrs, method='ward', metric='euclidean')
row_order = sch.dendrogram(row_link, p=30, truncate_mode='none', no_plot=True)['leaves']

cut_heights = (100, 50, 30)
clsts1, clsts2, clsts3 = [sch.fcluster(row_link, h, criterion='distance') for h in cut_heights]
cluster_levels = (clsts1, clsts2, clsts3)
for labels in cluster_levels:
    print(labels.shape, np.unique(labels).shape)

# Register the labels on every v3 table (all share adata3v3's cell order).
for level, labels in enumerate(cluster_levels, start=1):
    for table in (adata3v3.obs, adata2v3.obs, df_p3v3, df_p2v3):
        table[f'clst_l{level}'] = labels

(palette1, dismap1), (palette2, dismap2), (palette3, dismap3) = [
    gen_discrete_colors(len(np.unique(labels))) for labels in cluster_levels
]

In [None]:
# Clustered heatmap for the v3 cells. Layout: A = cells x genes z-scores
# (rows in dendrogram leaf order, columns in gene dendrogram order);
# B-D = cluster labels l1-l3; E = projection channel; F = y coordinate.
fig = plt.figure(figsize=(10,15))
axd = fig.subplot_mosaic("A"*12+"BCDEF")
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
xticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<gene> <channel id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order],
    yticklabels=False,
    xticklabels=xticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, location='bottom', label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
ax.set_ylabel('Cells')

for ax, col in zip((axd["B"], axd["C"], axd["D"]), ('clst_l1', 'clst_l2', 'clst_l3')):
    vec = adata3v3.obs[col].values
    sns.heatmap(
        adata3v3.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order],
        yticklabels=False,
        cmap='jet',
        cbar_kws=dict(location='bottom', aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')
    # Cumulative cluster sizes in label order = row where each cluster ends.
    # Assumes labels increase along row_order (fcluster on the same linkage)
    # — NOTE(review): confirm.
    positions = np.cumsum(np.unique(vec, return_counts=True)[1])
    ax.hlines(positions, xmin=-0.2, xmax=1, linewidth=1, linestyle='-', color='k', clip_on=False)
    for k, pos in enumerate(positions):
        ax.text(0, pos, k+1, color='white', fontsize=15)

# Projection panel, labelled like the genes (as in the first heatmap).
ax = axd["E"]
prj_cols = mat_prjs.columns.values.astype(str)
prj_labels = np.char.add(np.char.add(var_i2n[prj_cols].values.astype(str), ' '), prj_cols)
g = sns.heatmap(
    mat_prjs.iloc[row_order],
    yticklabels=False,
    xticklabels=prj_labels,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, ),
    ax=ax,
)
ax.tick_params(rotation=90)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # tick label fontsize 10 pt

ax = axd["F"]
g = sns.heatmap(
    adata3v3.obs[['y']].iloc[row_order],
    yticklabels=False,
    cmap='Greys',
    cbar_kws=dict(location='bottom', aspect=5, label='um'),
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10, rotation=90)  # tick label fontsize 10 pt


In [None]:
sp_x, sp_y = 'x', 'y'
fig, axs = plt.subplots(1,3,figsize=(4*3,5*1), sharey=True, sharex=True)
# flip y for display only; the axes are shared, so one call flips all three panels
axs.flat[0].invert_yaxis()

# cell coordinates are identical for every panel, so read them once
x = adata3v3.obs[sp_x].values
y = adata3v3.obs[sp_y].values

level_specs = zip(
    ('clst_l1', 'clst_l2', 'clst_l3'),
    (palette1, palette2, palette3),
    (dismap1, dismap2, dismap3),
    axs.flat,
)
for i, (col, palette, dismap, ax) in enumerate(level_specs):
    c = adata3v3.obs[col].values

    ax.grid(False)
    # colors are looked up in the palette by cluster label
    g = ax.scatter(x, y, c=pd.Series(palette)[c], s=5, edgecolor='none')
    sns.despine(ax=ax)
    ax.set_title(col)
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')
    ax.axis('off')
    fig.colorbar(dismap, ax=ax, orientation='horizontal', shrink=0.5, aspect=10)

fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# One spatial panel per L2 cluster: the cluster is drawn in its palette color and all other
# cells in light gray. Panels follow clst_order_manual when it is set, else sorted label order.
sp_x, sp_y = 'x', 'y'
col = 'clst_l2'
palette = palette2
dismap = dismap2
n = len(df_p3v3[col].unique())
nx = min(4,n)
ny = (n + nx - 1) // nx  # ceil(n / nx)
suptitle = f'L2 ({n} clusters.)'

clst_order_manual = None
clst_order_manual = [7,6,5,4,3,1,2]
if clst_order_manual is not None:
    clst_order = clst_order_manual
    # a hand-written order silently goes stale when the clustering is re-run with a new cut
    _missing = sorted(set(df_p3v3[col].unique().tolist()) - set(clst_order))
    if _missing:
        print(f"Warning: clusters {_missing} are not in clst_order_manual and will not be plotted")
else:
    clst_order = np.sort(df_p3v3[col].unique())

# squeeze=False keeps axs 2-D even for a single panel, so axs.flat always works
fig, axs = plt.subplots(ny,nx,figsize=(3*nx,3*ny), sharey=True, sharex=True, squeeze=False)
axs.flat[0].invert_yaxis()  # display-only flip; shared axes flip together

# coordinates and labels are the same for every panel, so read them once
x = adata3v3.obs[sp_x].values
y = adata3v3.obs[sp_y].values
c = adata3v3.obs[col].values
for i, ax in zip(clst_order, axs.flat):
    ax.grid(False)
    g = ax.scatter(x, y, c=[palette[_c] if i==_c else 'lightgray' for _c in c], s=5, edgecolor='none',)

    ax.set_title(f"C{i}")
    ax.set_xlabel(sp_x)
    ax.set_ylabel(sp_y)
    ax.set_aspect('equal')
    ax.axis('off')

# hide grid panels left over after the last plotted cluster
for ax in axs.flat[len(clst_order):]:
    ax.axis('off')

fig.suptitle(suptitle)
fig.subplots_adjust(hspace=0.1, wspace=0.02)

In [None]:
# Per-cluster mean of the size-normalized expression ('<id>_nrm' columns), z-scored across
# clusters and drawn genes x clusters. clst_order_manual is taken from an earlier cell;
# when it is None the groupby (sorted label) order is kept.
group = 'clst_l2'
# clst_order_manual = None
# clst_order_manual = [6,8,9,7, 3,1,2,5, 4,10,11]

_nrm_cols = [f'{_gene}_nrm' for _gene in var_order_manual.index.values.astype(str)]
dfmean = basicu.zscore(df_p2v3.groupby(group)[_nrm_cols].mean(), axis=0)
if clst_order_manual is not None:
    dfmean = dfmean.reindex(clst_order_manual)
yticklabels = var_order_manual.values


fig, ax = plt.subplots(figsize=(6,6))
sns.heatmap(
    dfmean.T,
    ax=ax,
    yticklabels=yticklabels,
    cmap='coolwarm',
    cbar_kws=dict(label='zscore mean exp.', shrink=0.5, aspect=10),
)

# order by y-axis

In [None]:
# Order cells by their y coordinate. Sorting the raw values sends any NaN to the end;
# np.argsort on the Series would dispatch to pandas Series.argsort, which reports NaN
# positions as -1, and -1 then silently selects the last cell in .iloc.
row_order = np.argsort(adata3v3.obs['y'].values)
# positions of the manually ordered genes (all but the last entry) among mat_ftrs columns
col_order = basicu.get_index_from_array(mat_ftrs.columns, var_order_manual.index.values[:-1])
col_order

In [None]:
fig = plt.figure(figsize=(20,8))
# stacked rows: A (12 rows tall) is the expression heatmap; B-F are one-row strips below it
axd = fig.subplot_mosaic("A\n"*12+"B\nC\nD\nE\nF")

# A: z-scored expression, transposed so genes (col_order) are rows and cells (row_order) are columns
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
yticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<name> <id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order].T,
    xticklabels=False,
    yticklabels=yticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
# after the transpose genes run along y (cells run along x; labeled on the bottom strip)
ax.set_ylabel('Genes')

# B-D: cluster label strips (L1, L2, L3) in the same cell order
for ax, col in zip((axd["B"], axd["C"], axd["D"]), ('clst_l1', 'clst_l2', 'clst_l3')):
    sns.heatmap(
        adata3v3.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order].T,
        xticklabels=False,
        cmap='jet',
        cbar_kws=dict(aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')

# E: projection values (mat_prjs columns) per cell, same cell order
ax = axd["E"]
g = sns.heatmap(
    mat_prjs.iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, ),
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt

# F: y coordinate of each cell (um); the bottom strip carries the shared 'Cells' x label
ax = axd["F"]
g = sns.heatmap(
    adata3v3.obs[['y']].iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, label='um'),
    ax=ax,
)
ax.set_ylabel('')
ax.set_xlabel('Cells')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt

In [None]:
# features: z-scored expression (cells x genes); projections: the proj_idx columns of obs
mat_ftrs = pd.DataFrame(adata3v3.layers['zsc'], index=adata3v3.obs.index, columns=adata3v3.var.index)
mat_prjs = adata3v3.obs[proj_idx]

# Ward hierarchical clustering of genes (columns) and of cells (rows).
# No truncation is requested ('none' is just SciPy's alias for the default None).
col_link = sch.linkage(mat_ftrs.T, method='ward', metric='euclidean')
col_order = sch.dendrogram(col_link, no_plot=True)['leaves']

row_link = sch.linkage(mat_ftrs, method='ward', metric='euclidean')
row_order = sch.dendrogram(row_link, no_plot=True)['leaves']

# flat cell clusters at three distance cuts, coarse (L1) to fine (L3)
clsts1, clsts2, clsts3 = (sch.fcluster(row_link, _cut, criterion='distance') for _cut in (100, 50, 30))

for _labels in (clsts1, clsts2, clsts3):
    print(_labels.shape, np.unique(_labels).shape)


# register the labels on every object that shares this cell order
for _level, _labels in (('clst_l1', clsts1), ('clst_l2', clsts2), ('clst_l3', clsts3)):
    adata3v3.obs[_level] = _labels
    adata2v3.obs[_level] = _labels
    df_p3v3[_level] = _labels
    df_p2v3[_level] = _labels

# one discrete palette (plus its color-mapping object) per clustering level
palette1, dismap1 = gen_discrete_colors(np.unique(clsts1).size)
palette2, dismap2 = gen_discrete_colors(np.unique(clsts2).size)
palette3, dismap3 = gen_discrete_colors(np.unique(clsts3).size)

In [None]:
fig = plt.figure(figsize=(20,8))
# stacked rows: A (12 rows tall) is the expression heatmap; B-F are one-row strips below it
axd = fig.subplot_mosaic("A\n"*12+"B\nC\nD\nE\nF")

# A: z-scored expression, transposed so genes (col_order) are rows and cells (row_order) are columns
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
yticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<name> <id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order].T,
    xticklabels=False,
    yticklabels=yticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
# after the transpose genes run along y (cells run along x; labeled on the bottom strip)
ax.set_ylabel('Genes')

# B-D: cluster label strips (L1, L2, L3) in the same cell order
for ax, col in zip((axd["B"], axd["C"], axd["D"]), ('clst_l1', 'clst_l2', 'clst_l3')):
    sns.heatmap(
        adata3v3.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order].T,
        xticklabels=False,
        cmap='jet',
        cbar_kws=dict(aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')

# E: projection values (mat_prjs columns) per cell, same cell order
ax = axd["E"]
g = sns.heatmap(
    mat_prjs.iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, ),
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt

# F: y coordinate of each cell (um); the bottom strip carries the shared 'Cells' x label
ax = axd["F"]
g = sns.heatmap(
    adata3v3.obs[['y']].iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, label='um'),
    ax=ax,
)
ax.set_ylabel('')
ax.set_xlabel('Cells')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt

In [None]:
# column positions in mat_ftrs of the manually ordered genes; the final entry of
# var_order_manual is deliberately left out (reason not visible in this cell)
col_order = basicu.get_index_from_array(
    mat_ftrs.columns,
    var_order_manual.index.values[:-1],
)

In [None]:

fig = plt.figure(figsize=(20,8))
# stacked rows: A (12 rows tall) is the expression heatmap; B-F are one-row strips below it
axd = fig.subplot_mosaic("A\n"*12+"B\nC\nD\nE\nF")

# A: z-scored expression, transposed so genes (col_order) are rows and cells (row_order) are columns
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
yticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<name> <id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order].T,
    xticklabels=False,
    yticklabels=yticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
# after the transpose genes run along y (cells run along x; labeled on the bottom strip)
ax.set_ylabel('Genes')

# B-D: cluster label strips (L1, L2, L3) in the same cell order
for ax, col in zip((axd["B"], axd["C"], axd["D"]), ('clst_l1', 'clst_l2', 'clst_l3')):
    sns.heatmap(
        adata3v3.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order].T,
        xticklabels=False,
        cmap='jet',
        cbar_kws=dict(aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')

# E: projection values (mat_prjs columns) per cell, same cell order
ax = axd["E"]
g = sns.heatmap(
    mat_prjs.iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, ),
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt

# F: y coordinate of each cell (um); the bottom strip carries the shared 'Cells' x label
ax = axd["F"]
g = sns.heatmap(
    adata3v3.obs[['y']].iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, label='um'),
    ax=ax,
)
ax.set_ylabel('')
ax.set_xlabel('Cells')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt

# pseudotime ordering

In [None]:
import scanpy as sc
from scipy.stats import spearmanr

In [None]:
# 10-component PCA of the z-scored feature matrix (one row per cell)
pca = PCA(n_components=10)
pcs = pca.fit_transform(mat_ftrs.to_numpy())
np.shape(pcs)

In [None]:
# scree plot: fraction of variance explained by each principal component
plt.plot(pca.explained_variance_ratio_, linestyle='-', marker='o')

In [None]:
# copy the first four PC scores into the per-cell table as pc1..pc4
for _k in range(4):
    df_p2v3[f'pc{_k + 1}'] = pcs[:, _k]

In [None]:
# Diffusion pseudotime on a 30-nearest-neighbor graph built from the PC scores in `pcs`.
adata3v3.obsm['X_pca'] = pcs
sc.pp.neighbors(adata3v3, n_neighbors=30, use_rep='X_pca')
sc.tl.diffmap(adata3v3)
# root cell: the first cell with the largest PC1 + PC2 score
adata3v3.uns['iroot'] = np.argmax(pcs[:, 0] + pcs[:, 1])
sc.tl.dpt(adata3v3)

df_p2v3['ptime'] = adata3v3.obs['dpt_pseudotime'].values

In [None]:
# PC1 vs PC2, then PC3 vs PC4, colored by L2 cluster
for _pc_x, _pc_y in (('pc1', 'pc2'), ('pc3', 'pc4')):
    fig, ax = plt.subplots()
    sns.scatterplot(
        data=df_p2v3, x=_pc_x, y=_pc_y,
        hue='clst_l2', palette=palette2,
        s=5, edgecolor='none', ax=ax,
    )
    ax.legend(bbox_to_anchor=(1, 1))
    ax.grid(False)
    plt.show()

In [None]:
# PC1 vs PC2 colored by diffusion pseudotime
_scatter_kws = dict(data=df_p2v3, x='pc1', y='pc2', hue='ptime', s=5, edgecolor='none')
fig, ax = plt.subplots()
sns.scatterplot(**_scatter_kws, ax=ax)
ax.legend(bbox_to_anchor=(1, 1))
ax.grid(False)
plt.show()

In [None]:
# PC1 vs PC2 colored by the cells' y coordinate
_scatter_kws = dict(data=df_p2v3, x='pc1', y='pc2', hue='y', s=5, edgecolor='none')
fig, ax = plt.subplots()
sns.scatterplot(**_scatter_kws, ax=ax)
ax.legend(bbox_to_anchor=(1, 1))
ax.grid(False)
plt.show()

In [None]:
# pseudotime against the y coordinate, colored by L2 cluster, with the Spearman correlation
x, y = 'y', 'ptime'

fig, ax = plt.subplots()
a, b = df_p2v3[x], df_p2v3[y]
r, p = spearmanr(a, b)
print(r, p)
sns.scatterplot(
    data=df_p2v3, x=x, y=y,
    hue='clst_l2', palette=palette2,
    s=5, edgecolor='none', ax=ax,
)
ax.set_title(f'Spearman r={r:.2f}, p={p:.1e}')
ax.legend(bbox_to_anchor=(1, 1))
ax.grid(False)
plt.show()

In [None]:
# pseudotime against the x and z coordinates, one figure each, with the Spearman correlation
y = 'ptime'
for x in ('x', 'z'):
    fig, ax = plt.subplots()
    a, b = df_p2v3[x], df_p2v3[y]
    r, p = spearmanr(a, b)
    print(r, p)
    sns.scatterplot(
        data=df_p2v3, x=x, y=y,
        hue='clst_l2', palette=palette2,
        s=5, edgecolor='none', ax=ax,
    )
    ax.set_title(f'Spearman r={r:.2f}, p={p:.1e}')
    ax.legend(bbox_to_anchor=(1, 1))
    ax.grid(False)
    plt.show()

In [None]:
# pseudotime against y, each cell shaded by the normalized expression of one gene (`hue` column)
x = 'y'
y = 'ptime'
hue = 'r4v3_c1_nrm'

fig, ax = plt.subplots()
a = df_p2v3[x]
b = df_p2v3[y]
r, p = spearmanr(a, b)
print(r, p)
# continuous coloring through matplotlib's `c`/`cmap` with fixed color limits
g = sns.scatterplot(data=df_p2v3, x=x, y=y, c=df_p2v3[hue], cmap='gray_r',
                vmin=-10, vmax=100,
                s=5, edgecolor='none', ax=ax)
# ax.set_title(f'Spearman r={r:.2f}, p={p:.1e}')
# the points carry no legend labels (no `hue=`), so show the expression scale as a colorbar
fig.colorbar(ax.collections[0], ax=ax, label=hue)
ax.grid(False)
plt.show()

In [None]:
# One panel per gene in var_idx: pseudotime vs y, cells shaded by that gene's normalized expression.
x = 'y'
y = 'ptime'
# size the grid from the gene list (4 per row) so no gene is silently dropped
ncols = 4
nrows = max(1, -(-len(var_idx) // ncols))  # ceil division
fig, axs = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows), sharex=True, sharey=True, squeeze=False)
for col, ax in zip(var_idx, axs.flat):
    hue = col+'_nrm'
    c = df_p2v3[hue].values
    # cap the color scale at the 95th percentile so a few bright cells don't wash out the rest
    vmax = np.percentile(c,95)
    # start slightly below zero so zero-expression cells are light gray rather than pure white
    vmin = -0.1*vmax
    sns.scatterplot(data=df_p2v3, x=x, y=y, c=c, cmap='gray_r',
                    vmin=vmin, vmax=vmax, s=5, edgecolor='none', ax=ax)
    ax.set_title(col+'_'+var_i2n[col])
    ax.grid(False)
    sns.despine(ax=ax)
# hide grid panels left over after the last gene
for ax in axs.flat[len(var_idx):]:
    ax.axis('off')
plt.show()

In [None]:
# One panel per gene in var_order_manual: pseudotime vs y, cells shaded by normalized expression.
x = 'y'
y = 'ptime'
genes = var_order_manual.index
# size the grid from the gene list (4 per row) so no gene is silently dropped
ncols = 4
nrows = max(1, -(-len(genes) // ncols))  # ceil division
fig, axs = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows), sharex=True, sharey=True, squeeze=False)
for col, ax in zip(genes, axs.flat):
    hue = col+'_nrm'
    c = df_p2v3[hue].values
    # cap the color scale at the 95th percentile so a few bright cells don't wash out the rest
    vmax = np.percentile(c,95)
    # start slightly below zero so zero-expression cells are light gray rather than pure white
    vmin = -0.1*vmax
    sns.scatterplot(data=df_p2v3, x=x, y=y, c=c, cmap='gray_r',
                    vmin=vmin, vmax=vmax, s=5, edgecolor='none', ax=ax)
    ax.set_title(col+'_'+var_i2n[col])
    ax.grid(False)
    sns.despine(ax=ax)
# hide grid panels left over after the last gene
for ax in axs.flat[len(genes):]:
    ax.axis('off')
plt.show()

In [None]:
row_order = df_p2v3['ptime'].to_numpy().argsort()  # cell positions by increasing pseudotime

In [None]:
# column positions in mat_ftrs of the manually ordered genes; the final entry of
# var_order_manual is deliberately left out (reason not visible in this cell)
col_order = basicu.get_index_from_array(
    mat_ftrs.columns,
    var_order_manual.index.values[:-1],
)

In [None]:

fig = plt.figure(figsize=(20,8))
# stacked rows: A (12 rows tall) is the expression heatmap; B-F are one-row strips below it
axd = fig.subplot_mosaic("A\n"*12+"B\nC\nD\nE\nF")

# A: z-scored expression, transposed so genes (col_order) are rows and cells (row_order) are columns
ax = axd["A"]
col_idx = mat_ftrs.columns[col_order].values.astype(str)
col_names = var_i2n[col_idx].values.astype(str)
yticklabels = np.char.add(np.char.add(col_names, ' '), col_idx)  # "<name> <id>"

sns.heatmap(
    mat_ftrs.iloc[row_order,col_order].T,
    xticklabels=False,
    yticklabels=yticklabels,
    cmap='coolwarm',
    cbar_kws=dict(shrink=0.5, label='zscore log2(1+norm expr.)'),
    vmax=3, vmin=-3,
    ax=ax,
)
# after the transpose genes run along y (cells run along x; labeled on the bottom strip)
ax.set_ylabel('Genes')

# B-D: cluster label strips (L1, L2, L3) in the same cell order
for ax, col in zip((axd["B"], axd["C"], axd["D"]), ('clst_l1', 'clst_l2', 'clst_l3')):
    sns.heatmap(
        adata3v3.obs[[col]].rename(columns={col: col.replace('clst_l', 'L')}).iloc[row_order].T,
        xticklabels=False,
        cmap='jet',
        cbar_kws=dict(aspect=5, label='', ticks=[]),
        ax=ax,
    )
    ax.set_ylabel('')

# E: projection values (mat_prjs columns) per cell, same cell order
ax = axd["E"]
g = sns.heatmap(
    mat_prjs.iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, ),
    ax=ax,
)
ax.set_ylabel('')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt

# F: y coordinate of each cell (um); the bottom strip carries the shared 'Cells' x label
ax = axd["F"]
g = sns.heatmap(
    adata3v3.obs[['y']].iloc[row_order].T,
    xticklabels=False,
    cmap='Greys',
    cbar_kws=dict(aspect=5, label='um'),
    ax=ax,
)
ax.set_ylabel('')
ax.set_xlabel('Cells')
colorbar = g.collections[0].colorbar
colorbar.ax.tick_params(labelsize=10)  # colorbar tick labels at 10 pt