In [None]:
import numpy as np
import pandas as pd
from scipy.stats import zscore
from scipy import sparse 
from scipy.spatial import ConvexHull
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
from py_pcha import PCHA

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted
import re

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu
from scroutines.gene_modules import GeneModules  


In [None]:
from scipy import stats
from matplotlib.ticker import MaxNLocator

def plot(x, y, aspect_equal=False, density=False, hue='type'):
    """
    Scatter two columns of the global `res` table, one panel per condition.

    Every panel first draws all rows of `res` in light gray as a backdrop, then
    overlays (in random order) the rows whose `cond` equals the panel's condition.
    Conditions come from the global `cases`.

    Arguments:
        x, y - column names of `res` used for the horizontal / vertical axis
        aspect_equal - if True, force an equal aspect ratio on every panel
        density - if True, overlay a 2D histogram of the panel's own cells
        hue - 'type' colors cells by `res['type']` with the global `palette_types`;
              any other value colors cells by `res['rep']`

    Output:
        fig - the matplotlib Figure holding the panels
    """
    n = len(cases)
    # squeeze=False keeps `axs` a 2D array, so a single condition (n == 1) works too
    fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True,
                            squeeze=False)

    # coloring of the per-condition overlay
    if hue == 'type':
        hue_kws = dict(hue='type',
                       hue_order=list(palette_types.keys()),
                       palette=palette_types)
    else:
        hue_kws = dict(hue='rep')
    # settings shared by the backdrop and the overlay
    shared_kws = dict(x=x, y=y, edgecolor='none', legend=False, rasterized=True)

    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)

        # backdrop of all cells; a single color is passed through seaborn's own
        # `color` keyword rather than matplotlib's `c`
        sns.scatterplot(data=res, color='lightgray', s=1, ax=ax, **shared_kws)

        # cells of this condition, shuffled so the draw order is not tied to row order
        in_cond = res[res['cond']==cond]
        sns.scatterplot(data=in_cond.sample(frac=1, replace=False),
                        s=3, ax=ax, **hue_kws, **shared_kws)

        if density:
            sns.histplot(data=in_cond,
                         x=x, y=y,
                         legend=False,
                         ax=ax,
                         rasterized=True,
                        )
        sns.despine(ax=ax)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        if i > 0:
            # only the first panel keeps its axis labels
            ax.set_xlabel('')
            ax.set_ylabel('')
    return fig
    # plt.show()
    
def plot2(x, y, hue=None, aspect_equal=False, s=10, vmin=-2.5, vmax=2.5, vminp=None, vmaxp=None, cmap='coolwarm'):
    """
    Scatter two columns of the global `res` table, one panel per condition
    (conditions come from the global `cases`), optionally colored by a third column.

    Arguments:
        x, y - column names of `res` used for the horizontal / vertical axis
        hue - column of `res` used to color the points; when empty/None the points
              are uncolored and the Spearman correlation of x and y is added to
              each panel title
        aspect_equal - if True, force an equal aspect ratio on every panel
        s - marker size
        vmin, vmax - fixed color limits (used when `hue` is given)
        vminp, vmaxp - if given, the matching color limit is recomputed for every
              panel as this percentile of the panel's own `hue` values
        cmap - colormap (used when `hue` is given)

    Output:
        fig - the matplotlib Figure holding the panels
    """
    n = len(cases)
    # squeeze=False keeps `axs` a 2D array, so a single condition (n == 1) works too
    fig, axs = plt.subplots(1, n, figsize=(4*n, 4*1), sharex=True, sharey=True,
                            squeeze=False)
    fig.suptitle(hue, x=0, ha='left')
    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)
        show = res[res['cond']==cond]

        if hue:
            # percentile limits need the hue column, so they are only evaluated
            # here (previously hue=None combined with vminp/vmaxp raised a KeyError)
            if vminp is not None:
                vmin = np.percentile(show[hue], vminp)
            if vmaxp is not None:
                vmax = np.percentile(show[hue], vmaxp)

            ax.scatter(show[x], show[y], c=show[hue],
                       cmap=cmap,
                       vmin=vmin, vmax=vmax,
                       s=s,
                       edgecolor='none',
                       rasterized=True,
                      )
        else:
            r, p = stats.spearmanr(show[x], show[y])
            ax.scatter(show[x], show[y],
                       s=s,
                       edgecolor='none',
                       rasterized=True,
                      )
            ax.set_title(f'{cond}\n r={r:.2f}')

        sns.despine(ax=ax)
        if aspect_equal:
            ax.set_aspect('equal')
        if i == 0:
            ax.set_xlabel(x)
            ax.set_ylabel(y)
        else:
            ax.set_xlabel('')
            ax.set_ylabel('')
        ax.grid(False)
        # ticks are removed entirely; the MaxNLocator calls that used to precede
        # this were overridden by it and have been dropped
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()

    return fig
    # plt.show()

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import zscore
from scipy import sparse 
from scipy.spatial import ConvexHull
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
from py_pcha import PCHA

from scipy.stats import gaussian_kde

In [None]:
def norm(x, depths):
    """
    Arguments: 
        x - cell by gene count matrix
        depths - sequencing depth per cell (any 1D array-like: ndarray, list, pandas Series)
        
    Output:
        xn - normalized count matrix

    This function takes raw counts as the input, and does the following steps sequentially.
         1. size normalization (CP10k) 
         2. log1p normalization (base 2 - log2(1+CP10k))
         3. zscore per gene  

    If the result contains NaN values (e.g. a gene with zero variance across cells),
    a warning is printed and the NaNs are replaced by 0.
    """

    # np.asarray lets lists and pandas Series through; they have no usable .reshape
    depths = np.asarray(depths).reshape(-1, 1)

    xn = x/depths*1e4
    xn = np.log2(1+xn)
    xn = zscore(xn, axis=0)

    if np.any(np.isnan(xn)):
        print('Warning: the normalized matrix contains nan values. Check input.')
        # the fill value must be given by keyword: the second positional argument
        # of np.nan_to_num is `copy`, not `nan`
        xn = np.nan_to_num(xn, nan=0.0)

    return xn

In [None]:
def proj(x_norm, ndim, method='PCA'):
    """
    Project a normalized feature matrix into a low-dimensional space.

    Arguments: 
        x_norm - normalized cell by gene feature matrix
        ndim   - number of dimensions to keep
        method - projection method; only 'PCA' is implemented

    Output:
        x_proj - a low-dimensional representation of `x_norm` 

    Raises:
        ValueError - if `method` is anything other than 'PCA'

    PCA is a widely used projection method, including by Adler et al. 2019 and
    Xie et al. 2024 for the Archetypal Analysis of scRNA-seq data. In principle,
    other projection methods could be added here as needed.
    """
    # guard clause: reject unsupported methods up front
    if method != 'PCA':
        raise ValueError('methods other than PCA are not implemented...')

    reducer = PCA(n_components=ndim)
    return reducer.fit_transform(x_norm)


In [None]:
def pcha(X, noc=3, delta=0, **kwargs):
    """
    Run PCHA on `X` and return only the archetype matrix, with its columns
    reordered so that the values in its first row are ascending.

    Arguments:
        X - data matrix handed to PCHA (callers in this file pass `xp.T`,
            i.e. dimensions by cells)
        noc - number of archetypes (forwarded to PCHA)
        delta - forwarded to PCHA
        **kwargs - any further keyword arguments are forwarded to PCHA

    Output:
        archetype matrix as a numpy array, one archetype per column
    """
    # PCHA returns five values; only the first (the archetypes) is used here
    archetypes, _S, _C, _sse, _varexpl = PCHA(X, noc=noc, delta=delta, **kwargs)
    archetypes = np.array(archetypes)
    # fix a reproducible ordering: sort archetypes along the first axis (x-axis)
    order = np.argsort(archetypes[0])
    return archetypes[:, order]

In [None]:
def downsamp(x, which='cell', p=0.8, seed=None):
    """
    Randomly keep a subset of the rows or columns of a matrix.

    Every row (or column) is kept independently with probability `p`, so the
    number of kept entries varies from call to call.

    Arguments:
        x - cell by gene matrix
        which - downsample cells/rows (0, 'cell', 'row') or
                genes/columns (1, 'gene', 'col', 'column')
        p - fraction of cells/genes to keep - should be a value between ~ [0,1]
        seed - seed for the random number generator (None: not reproducible)

    Raises:
        ValueError - if `which` is not one of the accepted values
    """
    n_rows, n_cols = x.shape
    rng = np.random.default_rng(seed=seed)

    along_rows = which in (0, 'cell', 'row')
    along_cols = which in (1, 'gene', 'col', 'column')
    if not (along_rows or along_cols):
        raise ValueError('choose from cell or gene')

    if along_rows:
        keep = rng.random(n_rows) < p
        return x[keep, :]

    keep = rng.random(n_cols) < p
    return x[:, keep]

In [None]:
def shuffle_rows_per_col(x, seed=None):
    """
    Shuffle the entries of every column (gene) independently across rows (cells).

    Arguments:
       x - cell by gene matrix (left unmodified; a shuffled copy is returned)
       seed - a random seed for reproducibility
    """
    generator = np.random.default_rng(seed=seed)
    # `permuted` along axis 0 permutes each column on its own
    return generator.permuted(x, axis=0)

In [None]:
def plot_archetype(ax, aa, fmt='--o', color='k', **kwargs):
    """
    Draw the archetypes as a closed polygon on `ax`.

    Arguments:
        ax - matplotlib Axes to draw on
        aa - archetypes as a 2D array: row 0 holds x coordinates, row 1 holds y
        fmt - matplotlib format string for the outline
        color - line/marker color
        **kwargs - forwarded to `ax.plot`
    """
    # repeat the first vertex at the end so the outline closes on itself
    xs = [*aa[0].tolist(), aa[0, 0]]
    ys = [*aa[1].tolist(), aa[1, 0]]
    ax.plot(xs, ys, fmt, color=color, **kwargs)

In [None]:
def get_t_ratio(xp, aa):
    """
    Arguments:
     xp -- projected matrix (cell by 2)
     aa -- inferred archetypes (2 by noc)
     note that this function works for 2-dimensional space only
     
    Return: 
     t-ratio - ratio of areas: convex hull of the data divided by the polygon
               spanned by the archetypes (PCH)

    The archetype polygon area uses the shoelace formula, which needs the vertices
    in boundary order. `pcha` returns archetypes sorted by x-coordinate, which is
    not a boundary order once there are more than three of them, so the vertices
    are first sorted by angle around their centroid. For three archetypes the
    area does not depend on the order.
    """
    aa = np.asarray(aa)
    assert xp.shape[1] == aa.shape[0] == 2
    
    # for 2D input, ConvexHull.volume is the enclosed area
    ch_area  = ConvexHull(xp).volume
   
    vx, vy = aa[0], aa[1]
    # counter-clockwise order around the centroid
    ccw = np.argsort(np.arctan2(vy - vy.mean(), vx - vx.mean()))
    vx, vy = vx[ccw], vy[ccw]
    pch_area = 0.5*np.abs(np.dot(vx, np.roll(vy, 1)) - np.dot(vy, np.roll(vx, 1)))

    return ch_area/pch_area 

In [None]:
class SingleCellArchetype():
    """
    Archetypal analysis of a single-cell count matrix.

    Typical use: construct the object (counts are normalized right away), choose a
    feature matrix with `setup_feature_matrix` (real or shuffled data), then
    project and fit archetypes with `proj_and_pcha`, `downsamp_proj_pcha` or
    `t_ratio_test`.
    """
    def __init__(self, x, depths, types):
        """
        Arguments: 
            x - cell by gene count matrix
            depths - sequencing depth per cell
            types  - cell type labels per cell
        
        Initiate the SingleCellArchetype object

        """
        
        # input
        self.x = x
        self.depths = depths
        self.types = types
        
        # cell type label: integer code per cell + sorted unique labels
        types_idx, types_lbl = pd.factorize(types, sort=True)
        
        self.types_idx = types_idx
        self.types_lbl = types_lbl 
        
        # normalize
        self.xn = norm(self.x, self.depths)
        
        # feature matrix - set by `setup_feature_matrix`
        self.xf = None 
        return 
        
    def setup_feature_matrix(self, method='data'):
        """
        Choose the feature matrix (`self.xf`) used by the projection/fit methods.

        Arguments:
            method - 'data':   the normalized data itself
                     'gshuff': each gene shuffled across all cells
                     'tshuff': each gene shuffled across cells within each cell type

        Raises:
            ValueError - for any other `method`
        """
        if method == 'data': 
            self.xf = self.xn
            print('use data')
            return  
        
        elif method == 'gshuff':
            # shuffle gene expression globally across all cells
            self.xf = shuffle_rows_per_col(self.xn)
            print('use shuffled data')
            return
            
        elif method == 'tshuff':
            # shuff each gene across cells independently - internally for each type A,B,C
            xn = self.xn
            xn_tshuff = xn.copy()
            
            types_lbl = self.types_lbl
            types_idx = self.types_idx
            for i in range(len(types_lbl)):
                xn_tshuff[types_idx==i] = shuffle_rows_per_col(xn[types_idx==i])
            self.xf = xn_tshuff
            print('use per-type shuffled data')
            return
        else:
            raise ValueError('choose from (data, gshuff, tshuff)')
    
    def proj_and_pcha(self, ndim, noc, **kwargs):
        """
        Project the current feature matrix to `ndim` dimensions and fit `noc`
        archetypes; extra keyword arguments are forwarded to `pcha`.

        The results are stored on the object (`self.xp`, `self.aa`) and returned
        as the tuple (xp, aa).

        Raises:
            ValueError - if `setup_feature_matrix` has not been called yet
        """
        if self.xf is None:
            raise ValueError('feature matrix not set - call setup_feature_matrix first')

        xp = proj(self.xf, ndim)
        aa = pcha(xp.T, noc=noc, **kwargs)
        
        self.xp = xp
        self.aa = aa
        return (xp, aa)
        
    def downsamp_proj_pcha(self, ndim, noc, nrepeats=10, which='cell', p=0.8, **kwargs): 
        """
        Repeat (downsample -> project -> fit archetypes) `nrepeats` times on the
        current feature matrix and return the list of archetype matrices.

        `which` and `p` are forwarded to `downsamp`; extra keyword arguments are
        forwarded to `pcha`.

        Raises:
            ValueError - if `setup_feature_matrix` has not been called yet
        """
        if self.xf is None:
            raise ValueError('feature matrix not set - call setup_feature_matrix first')

        aa_dsamps = []
        for _ in range(nrepeats):
            xn_dsamp = downsamp(self.xf, which=which, p=p)
            xp_dsamp = proj(xn_dsamp, ndim)
            aa_dsamp = pcha(xp_dsamp.T, noc=noc, **kwargs)
            aa_dsamps.append(aa_dsamp)
            
        return aa_dsamps
    
    def t_ratio_test(self, ndim, noc, nrepeats=10, **kwargs): 
        """
        Compare the t-ratio of the real data with that of globally shuffled data.

        this only works for 2-dimensional space for now

        Arguments:
            ndim - must be 2
            noc - number of archetypes
            nrepeats - number of shuffled data sets
            **kwargs - forwarded to `proj_and_pcha` (and from there to `pcha`)
                       for both the real and the shuffled fits

        Output:
            t_ratio - t-ratio of the real data
            t_ratios_shuff - list with the t-ratio of every shuffled data set
            pvalue - (k + 1)/(nrepeats + 1), where k is the number of shuffled
                     data sets whose t-ratio is less than or equal to the real one;
                     the +1 terms keep the estimate within (0, 1]

        After the call the object holds the fit of the real data again
        (`self.xf`, `self.xp`, `self.aa`), not that of the last shuffle.
        """
        assert ndim == 2
        
        self.setup_feature_matrix(method='data')
        xp, aa = self.proj_and_pcha(ndim, noc, **kwargs)
        t_ratio = get_t_ratio(xp, aa)
        
        t_ratios_shuff = []
        for _ in range(nrepeats):
            self.setup_feature_matrix(method='gshuff')
            xp_shuff, aa_shuff = self.proj_and_pcha(ndim, noc, **kwargs)
            t_ratios_shuff.append(get_t_ratio(xp_shuff, aa_shuff))

        # put the real-data state back so later calls do not silently see a shuffle
        self.setup_feature_matrix(method='data')
        self.xp, self.aa = xp, aa
            
        n_as_extreme = np.sum(np.asarray(t_ratios_shuff) <= t_ratio)
        pvalue = (n_as_extreme + 1)/(nrepeats + 1)
        
        return t_ratio, t_ratios_shuff, pvalue

In [None]:
# Directory that receives the figures saved by this notebook.
outdirfig = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures/250407"
# Pure-Python equivalent of `mkdir -p`: creates missing parents and is a no-op if
# the directory already exists. Unlike the `!mkdir` shell escape it also works
# outside IPython and does not depend on shell quoting of the path.
os.makedirs(outdirfig, exist_ok=True)

# load gene annotation

In [None]:
# Look up an example gene; `check_genes` returns three iterables of strings
# (annotation, color, styled gene name), each printed as one tab-separated line.
gene_modules = GeneModules()
anno, color, gene_styled = gene_modules.check_genes('Cdh13')
print(*("\t".join(fields) for fields in (anno, color, gene_styled)), sep="\n")

In [None]:
# use those 286 genes
# The table has one row per gene; column 'P17on' assigns each gene to group A, B or C.
df = pd.read_csv("../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv")
genes_l23 = df['gene'].astype(str).values

# gene names per group
genes_l23a, genes_l23b, genes_l23c = (
    df.loc[df['P17on'] == grp, 'gene'].astype(str).values
    for grp in ('A', 'B', 'C')
)
print(genes_l23a.shape, genes_l23b.shape, genes_l23c.shape)

# group label per gene, in the same order as `genes_l23`
genes_grp = df['P17on'].astype(str).values
# gene names must be unique
assert len(genes_l23) == len(np.unique(genes_l23))

genes_l23.shape

In [None]:
# Per-cell A/B/C scores; the first CSV column becomes the index.
scores_abc = pd.read_csv(
    "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/scores_l23abc.csv",
    index_col=0,
)
scores_abc

# load data 

In [None]:
# alternative input (not used):
# f_anndata_in  = "../../data/v1_multiome/superdupermegaRNA_hasraw_multiome_l23.h5ad"
f_anndata_in  = "../../data/v1_multiome/L23_allmultiome_raw.h5ad"
# `anndata.read` is a deprecated alias; `read_h5ad` is the documented reader for .h5ad files
adata = anndata.read_h5ad(f_anndata_in)
adata

In [None]:
# Sample label per cell, parsed from the cell name: take the text before the first
# space, drop its first two '-'-separated fields, and remove any '-2023'.
sample_labels = [
    "-".join(cell.split(' ')[0].split('-')[2:]).replace('-2023', '')
    for cell in adata.obs.index
]
# Time label: the sample label without its last character and without 'DR'.
time_labels = [label[:-1].replace('DR', '') for label in sample_labels]

# naturally sorted unique labels (so that e.g. 'P8' precedes 'P10')
uniq_samples = natsorted(np.unique(sample_labels))
uniq_times = natsorted(np.unique(time_labels))

adata.obs['n_counts'] = adata.obs['nCount_RNA']
adata.obs['sample'] = sample_labels
adata.obs['time']   = time_labels

# split samples by whether their label contains 'DR'
nr_samples = list(filter(lambda s: "DR" not in s, uniq_samples))
dr_samples = list(filter(lambda s: "DR" in s, uniq_samples))
print(uniq_times, nr_samples, dr_samples, sep='\n')

In [None]:
# Condition label = sample label without its last character
# (that last character is what the notebook later stores as the replicate, 'rep').
adata.obs['cond'] = [label[:-1] for label in adata.obs['sample']]

# Optional gene filters, currently disabled:
# # remove mitochondrial genes
# adata = adata[:,~adata.var['features'].str.contains(r'^mt-')]
# adata = adata[:,~adata.var['features'].str.contains(r'Xist')]

# # inspect
# adata.obs['sample'].unique(), adata.obs['cond'].unique()

In [None]:
# Precomputed PC coordinates (three separate files: "early", "later", "all").
f_in1 = "../../data/v1_multiome/L23_allmultiome_raw_pca_early.npy"
f_in2 = "../../data/v1_multiome/L23_allmultiome_raw_pca_later.npy"
f_in3 = "../../data/v1_multiome/L23_allmultiome_raw_pca_all.npy"

pcs_early, pcs_later, pcs_all = (np.load(path) for path in (f_in1, f_in2, f_in3))

# set up for plotting

In [None]:
# Assemble `res`: one row per cell (same order as `adata.obs`), holding the cell
# metadata, the three sets of PC coordinates and the A/B/C scores side by side.
res0 = pd.DataFrame({
    'cond': adata.obs['cond'].values,
    'type': np.char.add('c', adata.obs['Type'].values.astype(str)),  # 'c' + type id
    'samp': adata.obs['sample'].values,
})
res0['rep'] = res0['samp'].apply(lambda samp: samp[-1])  # last character of the sample label

# PC tables with 1-based column names: pcs_early1, pcs_early2, ...
res1 = pd.DataFrame(pcs_early, columns=[f"pcs_early{k+1}" for k in range(pcs_early.shape[1])])
res2 = pd.DataFrame(pcs_later, columns=[f"pcs_later{k+1}" for k in range(pcs_later.shape[1])])
res3 = pd.DataFrame(pcs_all,   columns=[f"pcs_all{k+1}"   for k in range(pcs_all.shape[1])])

# scores aligned to the cells in `adata`; its own 'cond' column is dropped because
# res0 already provides one
res5 = (
    scores_abc
    .reindex(adata.obs.index.values)
    .reset_index()
    .drop(columns='cond')
)
res = pd.concat([res0, res1, res2, res3, res5], axis=1)

In [None]:
# 20 colors of the 'tab20b' palette; indexed by position when building `palette`
allcolors = sns.color_palette(palette='tab20b', n_colors=20)
allcolors

In [None]:
# 20 entries drawn from the 'tab10' palette
allcolors2 = sns.color_palette(palette='tab10', n_colors=20)
allcolors2

In [None]:
# Color per condition. The insertion order of this dict is the canonical
# condition order used for the panels (`cases`) and for `cond_order`.
palette = collections.OrderedDict({
     "P6": allcolors[1],
     "P8": allcolors[0],
    "P10": allcolors[4+2],
    "P12": allcolors[4+1],
    "P14": allcolors[4+0],
    
    "P17": allcolors[8+2],
    "P21": allcolors[8+0],
    
    "P12DR": allcolors[8+2],
    "P14DR": allcolors[8+0],
    "P17DR": allcolors[8+0],
    "P21DR": allcolors[8+0],
    
})
cases = np.array(list(palette.keys()))

# Rank of every condition (P6=0 ... P21=6, P12DR=7 ... P21DR=10), derived from
# `palette` instead of being typed out a second time so the two cannot drift apart.
cond_order_dict = {cond: rank for rank, cond in enumerate(palette.keys())}
unq_conds = np.array(list(cond_order_dict.keys()))
# a condition missing from the dict raises KeyError here, on purpose
adata.obs['cond_order'] = adata.obs['cond'].apply(lambda x: cond_order_dict[x])

# Color per cell-type label as stored in res['type'] ('c' + type id).
# An earlier definition of `palette_types` (keys 'L2/3_A', 'L2/3_B', 'L2/3_C') was
# overwritten by this one immediately after being created, so it has been removed.
palette_types = {
    'c14': 'C0', 
    'c18': 'C1',
    'c16': 'C2', 
    
    'c13': 'C0', 
    'c15': 'C1', 
    'c17': 'C2',
}
type_order = list(palette_types)
type_order

In [None]:
# 7-step cubehelix palette
palette_time = sns.cubehelix_palette(7, start=0.5, rot=-0.5)
palette_time

# Plot A vs C genes aligning cells along early vs late PCs

In [None]:
# plot('pcs_all1', 'pcs_all2', aspect_equal=True)
# plt.show()
# plot('pcs_early1', 'pcs_early2', aspect_equal=True)
# plt.show()
# plot('pcs_later1', 'pcs_later2', aspect_equal=True)
# plt.show()

# Archetypes

In [None]:
# Conditions and samples to analyse. Order matters: later cells slice these lists
# by position, relying on the 4 'DR' conditions (8 'DR' samples) coming first.
todo_conds = (
    ['P12DR', 'P14DR', 'P17DR', 'P21DR']
    + ['P6', 'P8', 'P10', 'P12', 'P14', 'P17', 'P21']
)
todo_samps = (
    ['P12DRa', 'P12DRb', 'P14DRa', 'P14DRb', 'P17DRa', 'P17DRb', 'P21DRa', 'P21DRb']
    + ['P6a', 'P6b', 'P6c']
    + ['P8a', 'P8b', 'P8c']
    + ['P10a', 'P10b']
    + ['P12a', 'P12b', 'P12c']
    + ['P14a', 'P14b']
    + ['P17a', 'P17b']
    + ['P21a', 'P21b']
)

In [None]:
# per-cell difference between the C score and the A score
res['scores_c-a'] = res['scores_c'].sub(res['scores_a'])

In [None]:
# Fit archetypes separately for every condition, on the first two "later" PCs.
ndim, noc = 2, 3
delta = 0.2 # 0

aas = []  # one archetype matrix per entry of `todo_conds`, in the same order
for cond in todo_conds:
    print(cond)
    xp = res.loc[res['cond'] == cond, ['pcs_later1', 'pcs_later2']].to_numpy()
    aa = pcha(xp.T, noc=noc, delta=delta)  # pcha expects dimensions by cells
    aas.append(aa)
    

In [None]:
# One panel per condition: its cells in PC1/PC2, with the archetypes fitted for
# the LAST condition in `todo_conds` (aas[-1]) overlaid on every panel.
n = len(todo_conds)
aa = aas[-1]
fig, axs = plt.subplots(1, n, figsize=(8*n, 6), sharex=True, sharey=True)
for ax, cond in zip(axs, todo_conds):
    xp = res.loc[res['cond'] == cond, ['pcs_later1', 'pcs_later2']].to_numpy()

    ax.scatter(xp[:, 0], xp[:, 1], s=2)
    plot_archetype(ax, aa, fmt='-o', color='k', zorder=0)

    ax.set_title(cond)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)

plt.show()

In [None]:
# bin edges in steps of 0.5: x from -6 to 6, y from -5 to 5
xbins = np.linspace(-6, 6, num=2*12 + 1)
ybins = np.linspace(-5, 5, num=2*10 + 1)

xbins, ybins

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Two-stop colormaps running from white to a single color.
colors_a, colors_b, colors_c, colors_n = (
    [(0.0, 'white'), (1.0, tip)] for tip in ('C0', 'C1', 'C2', 'lightgray')
)
cmap_a, cmap_b, cmap_c, cmap_n = (
    LinearSegmentedColormap.from_list(name, stops)
    for name, stops in (
        ('cmap_a', colors_a),
        ('cmap_b', colors_b),
        ('cmap_c', colors_c),
        ('cmap_n', colors_n),
    )
)

# Diverging map: color of A -> light gray -> color of C, with a half-way blend on each side.
rgba_a, rgba_n, rgba_c = (np.array(cm(1.0)) for cm in (cmap_a, cmap_n, cmap_c))
colors_ac = [
    rgba_a,
    0.5*rgba_a + 0.5*rgba_n,
    rgba_n,
    0.5*rgba_c + 0.5*rgba_n,
    rgba_c,
]
cmap_ac = LinearSegmentedColormap.from_list('cmap_ac', colors_ac)

# Light gray -> faint B color -> full B color.
colors_b2 = ['lightgray', cmap_b(0.2), cmap_b(1.0)]
cmap_b2 = LinearSegmentedColormap.from_list('cmap_b2', colors_b2)

# Black -> white -> 'C1'.
colors_nrdr = ['black', 'white', 'C1']
cmap_nrdr = LinearSegmentedColormap.from_list('cmap_nrdr', colors_nrdr)

In [None]:
# Reference archetypes: those fitted for the last condition in `todo_conds`.
# Set explicitly here instead of relying on whatever value `aa` was left with by
# an earlier cell's loop.
aa = aas[-1]

# score columns to show, each with its own colormap
genes_viz = ['scores_b', 'scores_c-a']
cmaps = [
    cmap_b2,
    cmap_ac,
]
for gn, cmap in zip(genes_viz, cmaps):

    # color limits: 5th and 95th percentile over all cells, shared by all panels
    vmin = np.percentile(res[gn],  5)
    vmax = np.percentile(res[gn], 95)

    fig = plot2('pcs_later1', 'pcs_later2', hue=gn, aspect_equal=True, s=20,
                vmin=vmin, vmax=vmax, cmap=cmap)
    for ax in fig.get_axes():
        plot_archetype(ax, aa, fmt='--', color='gray', zorder=2)

    output = os.path.join(outdirfig, f'pc12_{gn}.pdf')
    powerplots.savefig_autodate(fig, output)
    plt.show()
    
    # break
    

In [None]:
# C-minus-A score against B score, one panel per condition; without `hue`, plot2
# puts the Spearman correlation of the two scores in each panel title
fig = plot2(x='scores_c-a', y='scores_b', aspect_equal=True, s=20)
plt.show()
    
    # break

# variance along A-C axis vs B-axis

In [None]:
# define A-C (t1) and B (t2) axis

# Reference archetypes: those fitted for the last condition in `todo_conds`.
# Set explicitly here instead of relying on a value left over from an earlier cell.
aa = aas[-1]

# t1: unit vector from archetype 0 to archetype 2 (pcha orders archetypes by their
# first coordinate); t2: t1 rotated by 90 degrees, hence orthogonal to it
t1 = aa[:,2]-aa[:,0]
t1 = t1/np.sqrt(t1.dot(t1))
t2 = np.array([t1[1], -t1[0]])

# sanity check: expect 1 (unit length) and 0 (orthogonal)
print(np.sum(t1**2), t1.dot(t2))

In [None]:
# Visual check of the two axes: draw t1 and t2 starting at archetype 0.
# (`fig` was previously misspelled `fjg`.)
aa = aas[-1]  # reference archetypes, set explicitly rather than inherited from an earlier cell
fig, ax = plt.subplots(figsize=(2,2))
plot_archetype(ax, aa, fmt='--', color='gray', zorder=2)
ax.plot([aa[0,0],aa[0,0]+t1[0]], [aa[1,0], aa[1,0]+t1[1]])
ax.plot([aa[0,0],aa[0,0]+t2[0]], [aa[1,0], aa[1,0]+t2[1]])
ax.set_aspect('equal')

In [None]:
# coordinate of every cell along the t1 and t2 axes (dot product with each axis vector)
res['pcs_later_t1'] = res[['pcs_later1', 'pcs_later2']].to_numpy() @ t1
res['pcs_later_t2'] = res[['pcs_later1', 'pcs_later2']].to_numpy() @ t2


In [None]:
# variance (squared standard deviation) per condition, rows ordered as `todo_conds`
var_mtx = (
    res.groupby(['cond'])[['pcs_later_t1', 'pcs_later_t2', 'pcs_early1']]
       .std(numeric_only=True)
)
var_mtx = np.power(var_mtx, 2).reindex(todo_conds)
var_mtx

In [None]:
# same as above, but per sample (replicate), rows ordered as `todo_samps`
var_mtx_rep = (
    res.groupby(['samp'])[['pcs_later_t1', 'pcs_later_t2', 'pcs_early1']]
       .std(numeric_only=True)
)
var_mtx_rep = np.power(var_mtx_rep, 2).reindex(todo_samps)
var_mtx_rep

In [None]:
# Numeric time of every condition / sample: the label with all letters removed.
times = np.array([6, 8, 10, 12, 14, 17, 21])
todo_conds_t = np.array(list(map(lambda lbl: int(re.sub(r'[a-zA-Z]', '', lbl)), todo_conds)))
todo_samps_t = np.array(list(map(lambda lbl: int(re.sub(r'[a-zA-Z]', '', lbl)), todo_samps)))
print(times, todo_conds_t, todo_samps_t, sep='\n')

In [None]:

scale1 = var_mtx['pcs_later_t1'].max()
# NOTE(review): scale2 is computed but never used - both panels below are divided
# by scale1. Confirm whether the B-axis panel was meant to use scale2.
scale2 = var_mtx['pcs_later_t2'].max()


def _draw_variance_panel(ax, col, nr_label=None, dr_label=None):
    """Variance of `col` over time on `ax`: lines per condition, open markers per sample."""
    # positions 4+ of todo_conds (8+ of todo_samps) are the non-'DR' entries, the rest are 'DR'
    ax.plot(todo_conds_t[4:], var_mtx[col].values[4:]/scale1, color='C1', label=nr_label)
    ax.plot(todo_conds_t[:4], var_mtx[col].values[:4]/scale1, color='k',  label=dr_label)

    ax.plot(todo_samps_t[8:], var_mtx_rep[col].values[8:]/scale1, 'o', markersize=5, fillstyle='none', color='C1')
    ax.plot(todo_samps_t[:8], var_mtx_rep[col].values[:8]/scale1, 's', markersize=5, fillstyle='none', color='k')

    sns.despine(ax=ax)
    ax.grid(False)


fig, axs = plt.subplots(1, 2, figsize=(3*2, 1*3), sharex=True, sharey=True)

ax = axs[0]
_draw_variance_panel(ax, 'pcs_later_t1')
ax.set_xlabel('Postnatal day (P)')
ax.set_ylabel('norm. variance')
ax.set_title('A-C axis')

ax = axs[1]
_draw_variance_panel(ax, 'pcs_later_t2', nr_label='NR', dr_label='DR')
# axes are shared, so the limits and ticks set here apply to both panels
ax.set_ylim(bottom=0)
ax.set_xticks(times)
ax.legend(bbox_to_anchor=(1,1))
ax.set_title('B axis')

output = os.path.join(outdirfig, "ACvsB_var_vs_time.pdf")
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Fit a 2D Gaussian KDE per condition on the first two "later" PCs.
# kdes[i]       - fitted estimator for todo_conds[i]
# kde_values[i] - that estimator evaluated at the condition's own cells
kdes, kde_values = [], []
n = len(todo_conds)
for cond in todo_conds:
    xp = res.loc[res['cond'] == cond, ['pcs_later1', 'pcs_later2']].to_numpy()

    kde = gaussian_kde(xp.T)  # gaussian_kde expects dimensions by points
    kdes.append(kde)
    kde_values.append(kde(xp.T))
    

In [None]:
# unit-width bin edges: the integers -20 .. 19 on both axes
xbins = np.arange(-20, 20)
ybins = np.arange(-20, 20)
xbins, ybins

In [None]:
# 2D histogram per condition (binned alternative to the KDE), using the
# `xbins` / `ybins` edges; each histogram is normalized to sum to 1.
hist_values = []  # hist_values[i] belongs to todo_conds[i]
n = len(todo_conds)
for cond in todo_conds:
    xp = res.loc[res['cond'] == cond, ['pcs_later1', 'pcs_later2']].to_numpy()

    hist, _, _ = np.histogram2d(xp[:, 0], xp[:, 1], bins=(xbins, ybins))
    hist_values.append(hist / hist.sum())
    

In [None]:
# One panel per condition: cells in PC1/PC2 colored by their KDE density, with the
# archetypes of the last condition in `todo_conds` (aas[-1]) overlaid.
n = len(todo_conds)
aa = aas[-1]
fig, axs = plt.subplots(1, n, figsize=(4*n, 3), sharex=True, sharey=True)
for i, cond in enumerate(todo_conds):
    ax = axs[i]
    xp = res.loc[res['cond'] == cond, ['pcs_later1', 'pcs_later2']].to_numpy()

    # density of each cell under its own condition's KDE
    c = kde_values[i]
    ax.scatter(xp[:, 0], xp[:, 1], c=c, s=10, cmap='rocket_r', rasterized=True)

    # archetype outline plus one colored dot per archetype
    plot_archetype(ax, aa, fmt='--', color='gray', zorder=2)
    for k, dot_color in enumerate(('C0', 'C1', 'C2')):
        ax.scatter(aa[0, k], aa[1, k], color=dot_color, zorder=2)

    ax.set_title(cond)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)

output = os.path.join(outdirfig, 'pc12_density.pdf')
powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Evaluation grid: the integer positions -20 .. 19 along both PCs (40 x 40 points)
x = np.arange(-20, 20, 1)
y = np.arange(-20, 20, 1)

# flatten the mesh into a (2, n_points) array, the layout the KDEs are called with
X, Y = (grid.reshape(-1,) for grid in np.meshgrid(x, y))
XY = np.vstack([X, Y])
# baseline density: KDE of the last condition in `todo_conds`
kde_base = kdes[-1](XY)

# One panel per condition at positions 4 .. -2 of `todo_conds`,
# showing (density of that condition) - (baseline density).
shown_conds = todo_conds[4:-1]
n = len(shown_conds)
aa = aas[-1]
fig, axs = plt.subplots(1, n, figsize=(4*n, 3), sharex=True, sharey=True)
for i, cond in enumerate(shown_conds):
    ax = axs[i]

    kde_val = kdes[4+i](XY) - kde_base
    c = kde_val

    # symmetric color limits from the 1st/99th percentile of the difference
    vmax = np.max([np.abs(np.percentile(c, 99)), np.abs(np.percentile(c, 1))])
    vmin = -vmax

    g = ax.scatter(XY[0], XY[1], c=c, s=16,
                   marker='s', edgecolor='none',
                   cmap='coolwarm',
                   vmax=vmax, vmin=vmin, rasterized=True)

    # archetype outline plus one colored dot per archetype
    plot_archetype(ax, aa, fmt='--', color='gray', zorder=2)
    for k, dot_color in enumerate(('C0', 'C1', 'C2')):
        ax.scatter(aa[0, k], aa[1, k], color=dot_color, zorder=2)

    ax.set_title(f'{cond}-P21')

    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)

axs[0].set_xlabel('PC1')
axs[0].set_ylabel('PC2')

# figure is not saved:
# output = os.path.join(outdirfig, f'pc12_density.pdf')
# powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Evaluation grid: the integer positions -20 .. 19 along both PCs (40 x 40 points)
x = np.arange(-20, 20, 1)
y = np.arange(-20, 20, 1)

# Create a 2D mesh grid, flattened into a (2, n_points) array
X, Y = np.meshgrid(x, y)
X = X.reshape(-1,)
Y = Y.reshape(-1,)
XY = np.vstack([X,Y])

# Conditions to compare pairwise (each one against the next). Their positions in
# `kdes` are looked up from `todo_conds` rather than typed in by hand, so the two
# lists cannot get out of sync.
toplot_conds = ['P6', 'P10', 'P14', 'P17', 'P21']
toplot_indices = [todo_conds.index(cond) for cond in toplot_conds]
n_panels = len(toplot_conds) - 1

# plot: one panel per consecutive pair, showing density(later) - density(earlier)
aa = aas[-1]
fig, axs = plt.subplots(1, n_panels, figsize=(4*n_panels, 3), sharex=True, sharey=True)
for i in range(n_panels):
    ax = axs[i]
    
    ind_i, ind_j = toplot_indices[i], toplot_indices[i+1]
    cond_i, cond_j = toplot_conds[i], toplot_conds[i+1]
    
    kde_val = kdes[ind_j](XY) - kdes[ind_i](XY)
    c = kde_val
    
    # symmetric color limits from the 1st/99th percentile of the difference
    vmax = np.percentile(c,99)
    vmin = np.percentile(c, 1)
    vmax = np.max([np.abs(vmax), np.abs(vmin)])
    vmin = -vmax
    
    g = ax.scatter(XY[0], XY[1], c=c, s=16, 
                   marker='s', edgecolor='none',
                   cmap='coolwarm', 
                   vmax=vmax, vmin=vmin, rasterized=True)
    
    plot_archetype(ax, aa, fmt='--', color='gray', zorder=2)
    ax.scatter(aa[0,0], aa[1,0], color='C0', zorder=2)
    ax.scatter(aa[0,1], aa[1,1], color='C1', zorder=2)
    ax.scatter(aa[0,2], aa[1,2], color='C2', zorder=2)
    
    ax.set_title(f'{cond_i}->{cond_j}')

    ax.set_aspect('equal')
    sns.despine(ax=ax)
    ax.grid(False)
    
axs[0].set_xlabel('PC1')
axs[0].set_ylabel('PC2')
    
# figure is not saved:
# output = os.path.join(outdirfig, f'pc12_density.pdf')
# powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Evaluation grid: the integer positions -20 .. 19 along both PCs (40 x 40 points)
x = np.arange(-20, 20, 1)
y = np.arange(-20, 20, 1)

# Create a 2D mesh grid, flattened into a (2, n_points) array
X, Y = np.meshgrid(x, y)
X = X.reshape(-1,)
Y = Y.reshape(-1,)
XY = np.vstack([X,Y])

# plot
fig, ax = plt.subplots(1,1,figsize=(4*1,3), sharex=True, sharey=True)
aa = aas[-1]

# Positions of the two conditions in `kdes`, looked up from `todo_conds` instead
# of hard-coded; in `todo_conds` the normally-reared P21 condition is named 'P21'.
ind_i, ind_j = todo_conds.index('P21'), todo_conds.index('P21DR')
cond_i, cond_j = 'P21NR', 'P21DR'

# density(P21DR) - density(P21)
kde_val = kdes[ind_j](XY) - kdes[ind_i](XY)
c = kde_val

# symmetric color limits from the 2nd/98th percentile of the difference
vmax = np.percentile(c,98)
vmin = np.percentile(c, 2)
vmax = np.max([np.abs(vmax), np.abs(vmin)])
vmin = -vmax

g = ax.scatter(XY[0], XY[1], c=c, s=16, 
               marker='s', edgecolor='none',
               cmap='coolwarm', 
               vmax=vmax, vmin=vmin, rasterized=True)

plot_archetype(ax, aa, fmt='--', color='gray', zorder=2)
ax.scatter(aa[0,0], aa[1,0], color='gray', zorder=2)
ax.scatter(aa[0,1], aa[1,1], color='gray', zorder=2)
ax.scatter(aa[0,2], aa[1,2], color='gray', zorder=2)

ax.set_title(f'{cond_i}->{cond_j}')

ax.set_aspect('equal')
sns.despine(ax=ax)
ax.grid(False)

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')

# figure is not saved:
# output = os.path.join(outdirfig, f'pc12_density.pdf')
# powerplots.savefig_autodate(fig, output)
plt.show()

In [None]:
# Difference of the normalized 2D histograms of two conditions, shown as a heatmap.
# (The evaluation grid and archetypes that this cell used to rebuild were not used
# by anything drawn here and have been dropped.)
fig, ax = plt.subplots(1,1,figsize=(4*1,3), sharex=True, sharey=True)

# Positions of the two conditions in `hist_values`, looked up from `todo_conds`
# instead of hard-coded; in `todo_conds` the normally-reared P21 condition is named 'P21'.
ind_i, ind_j = todo_conds.index('P21'), todo_conds.index('P21DR')
cond_i, cond_j = 'P21NR', 'P21DR'

# histogram(P21DR) - histogram(P21)
kde_val = hist_values[ind_j] - hist_values[ind_i] 
c = kde_val

# symmetric color limits from the 1st/99th percentile of the difference
vmax = np.percentile(c,99)
vmin = np.percentile(c, 1)
vmax = np.max([np.abs(vmax), np.abs(vmin)])
vmin = -vmax

# histogram2d puts the first variable (PC1) on axis 0, so transpose to get PC1 on
# the horizontal axis; draw on `ax` explicitly rather than on the implicit current axes
sns.heatmap(kde_val.T, 
            cmap='coolwarm', 
            vmax=vmax, vmin=vmin, rasterized=True,
            ax=ax,
           )
# heatmaps put row 0 at the top; flip so that PC2 increases upwards
ax.invert_yaxis()

ax.set_title(f'{cond_i}->{cond_j}')

ax.set_aspect('equal')
sns.despine(ax=ax)
ax.grid(False)

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')

# figure is not saved:
# output = os.path.join(outdirfig, f'pc12_density.pdf')
# powerplots.savefig_autodate(fig, output)
plt.show()