In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu


In [None]:
# Directory for saved figures. The original cluster path remains the default;
# set the V1BB_FIGURE_DIR environment variable to redirect output when the
# notebook is run on another machine.
outdir_fig = os.environ.get(
    "V1BB_FIGURE_DIR",
    "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures",
)

In [None]:
# Gene panel ("those 286 genes") with a per-gene annotation column 'P17on'.
# Earlier candidate list, kept for reference:
#   ../../data/cheng21_cell_scrna/res/candidate_genes_vincent_0503_v2.csv
gene_table_path = "../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv"
df = pd.read_csv(gene_table_path)

# Both columns are pulled out as plain string arrays.
genes_l23, genes_grp = (df[col].astype(str).values for col in ('gene', 'P17on'))

# Every gene name must appear exactly once.
assert len(np.unique(genes_l23)) == len(genes_l23)



In [None]:
# load
# `anndata.read` is a deprecated alias; `read_h5ad` is the supported reader
# for .h5ad files and returns the same AnnData object.
adata = anndata.read_h5ad("../../data/cheng21_cell_scrna/reorganized/all_L23.h5ad")
adata

In [None]:
# select: keep only cells whose 'cond' label ends in "NR".
# The explicit .copy() turns the subset into a real AnnData object; without it
# the `.obs['easitype']` assignment below would write into a view, which makes
# anndata copy implicitly and emit an ImplicitModificationWarning.
is_nr = adata.obs['cond'].str.contains(r"NR$")
adata = adata[is_nr].copy()

# define: per-cell metadata arrays, in the same row order as `adata`
genes = adata.var.index.values
conds = adata.obs['cond'].values
types = adata.obs['Type'].values
samps = adata.obs['sample'].values

# organize: collapse every 'Type' label onto one of L2/3_A / L2/3_B / L2/3_C.
# A label missing from this table raises KeyError instead of passing through silently.
rename = {
    "L2/3_A": "L2/3_A",
    "L2/3_B": "L2/3_B",
    "L2/3_C": "L2/3_C",
    
    "L2/3_1": "L2/3_A",
    "L2/3_2": "L2/3_B",
    "L2/3_3": "L2/3_C",
    
    "L2/3_AB": "L2/3_A",
    "L2/3_BC": "L2/3_C",
}
adata.obs['easitype'] = adata.obs['Type'].apply(lambda x: rename[x])
adata

In [None]:
# filter genes: keep those with a non-zero count in more than 10 cells
cond = np.ravel((adata.X>0).sum(axis=0)) > 10
adata_sub = adata[:,cond].copy()

# counts
x = adata_sub.X
cov = adata_sub.obs['n_counts'].values  # presumably per-cell total counts -- TODO confirm
genes = adata_sub.var.index.values

# CP10k: left-multiplying by diag(1/cov) divides every row (cell) by its
# `n_counts` value without densifying; then scale by 1e4.
# xn = x/cov.reshape(x.shape[0], -1)*1e4
xn = (sparse.diags(1/cov).dot(x))*1e4

# log10(CP10k+1), applied to the stored entries only; implicit zeros stay
# zero because log10(0+1) == 0.
xln = xn.copy()
xln.data = np.log10(xln.data+1)

# Densify the log-normalised matrix once and reuse it for both the 'lognorm'
# layer and the z-score, instead of converting the same sparse matrix twice.
lognorm_dense = np.array(xln.todense())

adata_sub.layers['norm'] = np.array(xn.todense())
adata_sub.layers['lognorm'] = lognorm_dense
# z-score along axis 0, i.e. each column (gene) is standardised across rows (cells)
adata_sub.layers['zlognorm'] = zscore(lognorm_dense, axis=0)

In [None]:
# select HVGs with mean and var
nbin = 20   # number of equal-population bins along mean expression
qth = 0.3   # fraction of an average bin's genes to keep from each bin

# mean (per gene, over cells, on the CP10k matrix)
gm = np.ravel(xn.mean(axis=0))

# var = E[x^2] - E[x]^2, computed on the sparse matrix
tmp = xn.copy()
tmp.data = np.square(tmp.data)
gv = np.ravel(tmp.mean(axis=0)) - gm**2

# cut: assign each gene to a mean-expression quantile bin (labels 0..nbin-1)
lbl = pd.qcut(gm, nbin, labels=np.arange(nbin))
gres = pd.DataFrame({
    'name': genes,
    'lbl': lbl,
    'mean': gm,
    'var': gv,
    'ratio': gv/gm,
})

# select: within every bin keep the genes with the largest var/mean ratio
n_per_bin = int(qth*(len(gm)/nbin))
gres_sel = gres.groupby('lbl')['ratio'].nlargest(n_per_bin) #.reset_index()
# level 1 of the resulting MultiIndex is the original row position in `gres`
gsel_idx = np.sort(gres_sel.index.get_level_values(1).values)
assert np.all(gsel_idx != -1)

# # 
# l23_gidx = basicu.get_index_from_array(genes, genes_l23)
# assert np.all(l23_gidx != -1)

In [None]:
# Mean vs. var/mean for every gene (left) and with the selected HVGs
# highlighted and coloured by their mean-expression bin (right).
ratio = gv/gm
fig, axs = plt.subplots(1,2, figsize=(6*2,4), sharex=True, sharey=True)
ax_all, ax_hvg = axs

# left panel: all genes
ax_all.scatter(gm, ratio, s=5, edgecolor='none', color='gray')
ax_all.set_title(f'{len(gm):,} genes')

# right panel: all genes in gray, HVGs drawn on top
ax_hvg.scatter(gm, ratio, s=5, edgecolor='none', color='gray')
ax_hvg.scatter(gm[gsel_idx], ratio[gsel_idx], c=lbl[gsel_idx], s=5, edgecolor='none', cmap='viridis_r', label=f'{len(gsel_idx):,} HVGs')
# ax.scatter(gm[l23_gidx], (gv/gm)[l23_gidx], s=5, facecolors='none', edgecolor='C0', label='L2/3 type genes')
# ax.scatter(gm[it_gidx], (gv/gm)[it_gidx], s=5, facecolors='none', edgecolor='C1', label='IT genes')
ax_hvg.legend(bbox_to_anchor=(1,1))

# shared axis formatting: log-log with identical labels on both panels
for ax in axs:
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('mean expr (CP10k)')
    ax.set_ylabel('var/mean')

plt.show()

In [None]:
# Restrict to the selected HVG columns. Note that this rebinds `adata` to the
# gene-subsetted object.
adata = adata_sub[:,gsel_idx]
# adata = adata_sub[:,l23_gidx]
genes_sel = adata.var.index.values

# Dense cells-by-HVG matrices used by the PCA and plotting cells below.
lognorm, zlognorm = (np.array(adata.layers[layer]) for layer in ('lognorm', 'zlognorm'))
zlognorm.shape

In [None]:
# PCA on the z-scored HVG matrix (first 10 components).
# random_state is fixed because scikit-learn may choose a randomized SVD
# solver for a truncated decomposition of a large matrix; without a seed the
# PC scores (and everything derived from them below) can differ between runs.
pca = PCA(n_components=10, random_state=0)
pcs = pca.fit_transform(zlognorm)

# ucs = UMAP(n_components=2, n_neighbors=50).fit_transform(pcs)

In [None]:
# Component matrix of the fitted PCA (the commented lines below index it as
# vt[k] to rank `genes_sel` by absolute weight on component k).
vt = pca.components_

# topgenes = genes_sel[np.argsort(np.abs(vt[3]))[::-1]]
# np.sort(np.abs(vt[3]))[::-1]

In [None]:
# # fix pc1 to make sure a < c:
# pc1 = pcs[:,0]
# pc_types, unq_types = basicu.group_mean(pc1.reshape(-1,1), types)
# a = pc_types[0,0]
# c = pc_types[-1,0]
# if a > c:
#     pcs[:,0] = -pcs[:,0]

In [None]:
# Per-cell table: one column per PC score ("PC1", "PC2", ...) plus the
# metadata arrays captured when the cells were selected.
pc_names = [f"PC{k}" for k in range(1, pcs.shape[1]+1)]
res = pd.DataFrame(pcs, columns=pc_names).assign(
    cond=conds,
    type=types,
    samp=samps,
)


# res2 = pd.DataFrame(zlognorm, columns=genes_sel)
# res = pd.concat([res, res2], axis=1)

In [None]:
# Scree plot: fraction of variance explained by each fitted PC.
evr = pca.explained_variance_ratio_
fig, ax = plt.subplots(figsize=(6,4))
ax.plot(1+np.arange(len(evr)), evr, '-o', markersize=5)
# dashed reference at 1/(number of input features), i.e. an even split of variance
ax.axhline(1/lognorm.shape[1], linestyle='--', color='gray')
ax.set_xlabel('PC')
ax.set_ylabel('explained var')

In [None]:
# 20 colours from the 'tab20c' palette; indexed below to colour the conditions.
allcolors = sns.color_palette('tab20c', 20)
allcolors

In [None]:
# 20 colours requested from the 'tab10' palette; only indices 0-2 are used
# below, one per L2/3 type.
allcolors2 = sns.color_palette('tab10', 20)
allcolors2

In [None]:
# Condition -> colour. Insertion order matters: it defines `cases`, and
# therefore the left-to-right panel order of the per-condition plots.
palette = collections.OrderedDict([
    ("P8NR",  allcolors[2]),
    ("P14NR", allcolors[0]),

    ("P17NR", allcolors[7]),
    ("P21NR", allcolors[6]),
    ("P28NR", allcolors[5]),
    ("P38NR", allcolors[4]),
])

# Cell-type label -> index into `allcolors2`. The alternative labellings
# (_1/_2/_3 and _AB/_BC) reuse the colours of the A/B/C types.
type_color_index = {
    'L2/3_A': 0,
    'L2/3_B': 1,
    'L2/3_C': 2,

    'L2/3_1': 0,
    'L2/3_2': 1,
    'L2/3_3': 2,

    'L2/3_AB': 0,
    'L2/3_BC': 2,
}
palette_types = collections.OrderedDict(
    (label, allcolors2[idx]) for label, idx in type_color_index.items()
)

# Ordered array of condition names
cases = np.array(list(palette))

In [None]:
# Stand-alone legend strip: one coloured marker and label per condition,
# with the first palette entry at the top.
fig, ax = plt.subplots(figsize=(1,4))
n_cond = len(palette)
for i, (key, item) in enumerate(palette.items()):
    ypos = n_cond - i
    ax.plot(0,    ypos, 'o', c=item, )
    ax.text(0.02, ypos, key, va='center', fontsize=15)
# hiding the axes does not depend on the loop variable, so do it once
ax.axis('off')
plt.show()

In [None]:
def _condition_panels(x, y):
    """Create one panel per condition, pre-filled with all cells in light gray.

    The number of panels follows `len(cases)` rather than a fixed count, so
    the figure stays correct if conditions are added to or removed from
    `palette`.

    Parameters
    ----------
    x, y : str
        Column names of `res` to put on the horizontal / vertical axis.

    Returns
    -------
    (fig, axs)
        The figure and a 2-D array of axes with shape (1, len(cases)); panel
        i is titled with `cases[i]`.
    """
    n_panels = len(cases)
    fig, axs = plt.subplots(1, n_panels, figsize=(4*n_panels, 4*1),
                            sharex=True, sharey=True,
                            squeeze=False,  # keep an array even for a single panel
                           )
    for ax, cond in zip(axs.flat, cases):
        ax.set_title(cond)
        # background: every cell in `res`, regardless of condition
        sns.scatterplot(data=res,
                        x=x, y=y,
                        c='lightgray',
                        s=1, edgecolor='none',
                        legend=False,
                        ax=ax,
                       )
    return fig, axs


def plot(x, y):
    """Scatter `res[x]` vs `res[y]`, one panel per condition, coloured by cell type.

    Each panel shows all cells in gray and overlays the cells of that panel's
    condition, coloured by the 'type' column using `palette_types`.

    Parameters
    ----------
    x, y : str
        Column names of `res` (e.g. 'PC1', 'IC2').
    """
    fig, axs = _condition_panels(x, y)
    for ax, cond in zip(axs.flat, cases):
        sns.scatterplot(data=res[res['cond']==cond],
                        x=x, y=y,
                        hue='type',
                        hue_order=list(palette_types.keys()),
                        palette=palette_types,
                        s=3, edgecolor='none',
                        legend=False,
                        ax=ax,
                       )
        sns.despine(ax=ax)
        # ax.set_aspect('equal')
    plt.show()


def plot2(x, y, hue=None):
    """Scatter `res[x]` vs `res[y]` per condition, optionally coloured by a numeric column.

    Parameters
    ----------
    x, y : str
        Column names of `res`.
    hue : str, optional
        Column of `res` used to colour the overlaid points on a 'coolwarm'
        scale clipped to [-3, 3]. When omitted, points use the default colour
        and the panel title additionally reports the Spearman correlation
        between x and y within that condition.
    """
    fig, axs = _condition_panels(x, y)
    for ax, cond in zip(axs.flat, cases):
        show = res[res['cond']==cond]
        if hue:
            ax.scatter(show[x], show[y], c=show[hue],
                       cmap='coolwarm',
                       vmin=-3, vmax=3,
                       s=3,
                       edgecolor='none',
                      )
        else:
            r, p = stats.spearmanr(show[x], show[y])
            ax.scatter(show[x], show[y],
                       s=3,
                       edgecolor='none',
                      )
            ax.set_title(f'{cond}\n r={r:.2f}')
        sns.despine(ax=ax)
        # ax.set_aspect('equal')
    plt.show()

In [None]:
# Consecutive PC pairs, one figure per pair (panels split by condition).
for pc_x, pc_y in [('PC1', 'PC2'), ('PC2', 'PC3'), ('PC3', 'PC4'), ('PC4', 'PC5')]:
    plot(pc_x, pc_y)

# Try
- ICA
- HVGs
- Project to adult
- 3D visuals
- Understand this

In [None]:
# Apply ICA to the first five PC scores.
from sklearn.decomposition import FastICA
# random_state is fixed because FastICA draws its initial unmixing matrix at
# random; unseeded runs can return the components in a different order or with
# flipped sign, and the cells below refer to them by number ('IC1'...'IC5').
ica = FastICA(n_components=5, whiten='unit-variance', max_iter=5000, tol=1e-3, random_state=0)
ics = ica.fit_transform(pcs[:,:5])

In [None]:
# Append the IC scores ("IC1", "IC2", ...) to the per-cell table.
ic_names = [f"IC{k}" for k in range(1, ics.shape[1]+1)]
res2 = pd.DataFrame(ics, columns=ic_names, index=res.index)
# Drop IC columns left over from an earlier run of this cell first; a plain
# concat would otherwise add duplicate column names on re-run, and res['IC1']
# would then return a DataFrame instead of a Series.
res = pd.concat([res.drop(columns=ic_names, errors='ignore'), res2], axis=1)

In [None]:
# PC1 against each independent component in turn, one figure per IC.
for ic_name in ('IC1', 'IC2', 'IC3', 'IC4', 'IC5'):
    plot('PC1', ic_name)

In [None]:
# IC1 vs IC2: one panel per condition, cells coloured by type.
plot('IC1', 'IC2')

In [None]:
# IC3 vs IC4: one panel per condition, cells coloured by type.
plot('IC3', 'IC4')

In [None]:
# IC4 vs IC5: one panel per condition, cells coloured by type.
plot('IC4', 'IC5')

In [None]:
# IC1 vs IC4: one panel per condition, cells coloured by type.
plot('IC1', 'IC4')