In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu


In [None]:
# Directory intended for figure output.
# NOTE(review): absolute, user-specific path; none of the cells shown here reference it.
outdir_fig = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures"

In [None]:
# Gene table to use (its file name advertises 288 rows / 286 unique genes).
# Earlier candidate list, kept for reference:
#   ../../data/cheng21_cell_scrna/res/candidate_genes_vincent_0503_v2.csv
l23_gene_table = "../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv"
df = pd.read_csv(l23_gene_table)

# gene symbols and their 'P17on' annotation, both coerced to plain strings
genes_l23, genes_grp = (df[col].astype(str).values for col in ('gene', 'P17on'))

# every gene symbol must occur exactly once
assert np.unique(genes_l23).size == genes_l23.size

genes_l23.shape

In [None]:
# report, gene by gene, whether the symbol is part of the L2/3 gene list
for g in ('Epha6', 'Pcdh9', 'Pcdh7'):
    is_listed = g in genes_l23
    print(f'{g} in {is_listed}')

In [None]:
# entries of the L2/3 gene list whose symbol starts with 'Pcdh'
[g for g in genes_l23 if g.startswith('Pcdh')]

In [None]:
# entries of the L2/3 gene list whose symbol starts with 'Cdh'
[g for g in genes_l23 if g.startswith('Cdh')]

In [None]:
# Gabr* gene symbols, assembled prefix by prefix; the numbered families are
# generated, the single-member symbols are listed explicitly.
gaba = (
    [f"Gabra{k}" for k in range(1, 7)]
    + [f"Gabrb{k}" for k in range(1, 4)]
    + [f"Gabrg{k}" for k in range(1, 4)]
    + ["Gabrd", "Gabre", "Gabrp", "Gabrq"]
    + [f"Gabrr{k}" for k in range(1, 4)]
)

In [None]:
# load the L2/3 dataset from its .h5ad file
# (read_h5ad is the explicit reader; the bare `anndata.read` alias is deprecated)
adata = anndata.read_h5ad("../../data/cheng21_cell_scrna/reorganized/all_L23.h5ad")
adata

In [None]:
# select: keep only the cells whose condition label ends in "NR".
# An explicit copy is taken because the boolean subset is a view of the loaded
# object, and a new .obs column is written to it below.
adata = adata[adata.obs['cond'].str.contains(r"NR$")].copy()

# define: gene names and per-cell annotations (in adata.obs row order) reused by later cells
genes = adata.var.index.values
conds = adata.obs['cond'].values
types = adata.obs['Type'].values
samps = adata.obs['sample'].values

# organize: collapse every type label onto one of the three A/B/C labels
rename = {
    "L2/3_A": "L2/3_A",
    "L2/3_B": "L2/3_B",
    "L2/3_C": "L2/3_C",
    
    "L2/3_1": "L2/3_A",
    "L2/3_2": "L2/3_B",
    "L2/3_3": "L2/3_C",
    
    "L2/3_AB": "L2/3_A",
    "L2/3_BC": "L2/3_C",
}
# a type label missing from `rename` raises KeyError here rather than being silently dropped
adata.obs['easitype'] = adata.obs['Type'].apply(lambda x: rename[x])
adata

In [None]:
# filter genes
cond = np.ravel((adata.X>0).sum(axis=0)) > 10 # expressed in more than 10 cells
adata_sub = adata[:,cond].copy()

# counts
# NOTE(review): the steps below (.data assignment, .toarray()) require adata_sub.X
# to be a scipy sparse matrix — confirm for other inputs.
x = adata_sub.X
cov = adata_sub.obs['n_counts'].values
genes = adata_sub.var.index.values

# CP10k: scale every cell (row) by 1e4 / n_counts via a diagonal-matrix product
# xn = x/cov.reshape(x.shape[0], -1)*1e4
xn = (sparse.diags(1/cov).dot(x))*1e4

# log10(CP10k+1), applied to the stored entries only (implicit zeros stay zero)
xln = xn.copy()
xln.data = np.log10(xln.data+1)

# densify the log-normalised matrix once and reuse it for both the layer and the z-score
lognorm_dense = xln.toarray()
adata_sub.layers['norm'] = xn.toarray()
adata_sub.layers['lognorm'] = lognorm_dense
adata_sub.layers['zlognorm'] = zscore(lognorm_dense, axis=0)  # per-gene (column) z-score

In [None]:
# positions of the Gabr* symbols along the filtered gene axis
gaba_lookup = basicu.get_index_from_array(genes, gaba)
# entries equal to -1 are discarded
# (presumably the helper's marker for "not found" — confirm in scroutines.basicu)
gaba_idx = gaba_lookup[gaba_lookup != -1]
gaba_sel = genes[gaba_idx]

gaba_idx, gaba_sel

In [None]:
# select HVGs with mean and var
nbin = 20   # number of equal-count mean-expression bins
qth = 0.3   # fraction of an average-sized bin kept per bin

# per-gene mean of the CP10k values
gm = np.ravel(xn.mean(axis=0))

# per-gene variance as E[x^2] - E[x]^2; only the stored entries are squared
tmp = xn.copy()
tmp.data = np.power(tmp.data, 2)
gv = np.ravel(tmp.mean(axis=0))-gm**2

# cut: assign every gene to one of nbin quantile bins of mean expression
lbl = pd.qcut(gm, nbin, labels=np.arange(nbin))
gres = pd.DataFrame({
    'name': genes,
    'lbl': lbl,
    'mean': gm,
    'var': gv,
    'ratio': gv/gm,
})

# select: within each bin keep the genes with the largest var/mean ratio.
# observed=True is spelled out because 'lbl' is categorical.
n_per_bin = int(qth*(len(gm)/nbin))
gres_sel = gres.groupby('lbl', observed=True)['ratio'].nlargest(n_per_bin)
# level 1 of the result's index carries the original row positions of gres
gsel_idx = np.sort(gres_sel.index.get_level_values(1).values)

# always carry the Gabr* genes along with the HVGs
gsel_idx = np.union1d(gsel_idx, gaba_idx)

# # 
# l23_gidx = basicu.get_index_from_array(genes, genes_l23)
# assert np.all(l23_gidx != -1)

In [None]:
# var/mean ratio per gene, shared by both panels
ratio = gv/gm
fig, axs = plt.subplots(1,2, figsize=(6*2,4), sharex=True, sharey=True)

# both panels get the same gray background cloud and log-log axis styling
for ax in axs:
    ax.scatter(gm, ratio, s=5, edgecolor='none', color='gray')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('mean expr (CP10k)')
    ax.set_ylabel('var/mean')

# left panel: all genes
axs[0].set_title(f'{len(gm):,} genes')

# right panel: selected genes on top, coloured by their mean-expression bin
ax = axs[1]
ax.scatter(gm[gsel_idx], ratio[gsel_idx], c=lbl[gsel_idx], s=5, edgecolor='none', cmap='viridis_r', label=f'{len(gsel_idx):,} HVGs')
# ax.scatter(gm[l23_gidx], (gv/gm)[l23_gidx], s=5, facecolors='none', edgecolor='C0', label='L2/3 type genes')
# ax.scatter(gm[it_gidx], (gv/gm)[it_gidx], s=5, facecolors='none', edgecolor='C1', label='IT genes')
ax.legend(bbox_to_anchor=(1,1))

plt.show()

In [None]:
# Restrict to the selected gene columns (HVGs plus the Gabr* genes).
# NOTE(review): this rebinds `adata`; re-running an earlier cell that reads `adata`
# afterwards would see this gene subset instead of the loaded object.
adata = adata_sub[:,gsel_idx]
# adata = adata_sub[:,l23_gidx]
genes_sel = adata.var.index.values
# dense arrays taken from the 'lognorm' / 'zlognorm' layers of the subset
lognorm = np.array(adata.layers['lognorm'])
zlognorm = np.array(adata.layers['zlognorm'])
zlognorm.shape

In [None]:
# --- PCA basis learned from the P8NR cells only; all cells are then projected onto it ---
adata_p8 = adata[adata.obs['cond'].isin(['P8NR'])]
lognorm_p8 = np.array(adata_p8.layers['lognorm'])
zlognorm_p8 = np.array(adata_p8.layers['zlognorm'])

# random_state is fixed: depending on the data shape PCA may pick its randomized
# solver, and the hard-coded sign flips below assume a reproducible fit.
pca = PCA(n_components=10, random_state=0)
pca.fit(zlognorm_p8)
pcs_p8 = pca.transform(zlognorm)

pcs_p8[:,0] = -pcs_p8[:,0] # flip PC1

# --- PCA basis learned from the P17NR / P21NR / P28NR cells; all cells projected onto it ---
adata_p17on = adata[adata.obs['cond'].isin(['P17NR', 'P21NR', 'P28NR'])]
lognorm_p17on = np.array(adata_p17on.layers['lognorm'])
zlognorm_p17on = np.array(adata_p17on.layers['zlognorm'])

pca2 = PCA(n_components=10, random_state=0)
pca2.fit(zlognorm_p17on)
pcs_p17on = pca2.transform(zlognorm)

pcs_p17on[:,1] = -pcs_p17on[:,1] # flip PC2

# ucs = UMAP(n_components=2, n_neighbors=50).fit_transform(pcs)

In [None]:
# loading vectors (one row per component) of the P8-fitted PCA
vt = pca.components_


def _genes_by_abs_loading(loadings):
    """Return genes_sel ordered from the largest to the smallest absolute loading."""
    ascending = np.argsort(np.abs(loadings))
    return genes_sel[ascending[::-1]]


topgenes_p8_pc1 = _genes_by_abs_loading(vt[0])
print(topgenes_p8_pc1[:10])

topgenes_p8_pc2 = _genes_by_abs_loading(vt[1])
print(topgenes_p8_pc2[:10])



# np.sort(np.abs(vt[3]))[::-1]

In [None]:
# first 50 genes when ranked by absolute loading on the second component of the P8-fitted PCA
topgenes_p8_pc2[:50]

In [None]:
# # fix pc1 to make sure a < c:
# pc1 = pcs[:,0]
# pc_types, unq_types = basicu.group_mean(pc1.reshape(-1,1), types)
# a = pc_types[0,0]
# c = pc_types[-1,0]
# if a > c:
#     pcs[:,0] = -pcs[:,0]

In [None]:
def _pc_columns(prefix, n_pcs):
    """Column names '<prefix>1' ... '<prefix>n_pcs'."""
    return [f"{prefix}{k}" for k in range(1, n_pcs + 1)]


# z-scored expression, one column per selected gene
res0 = pd.DataFrame(zlognorm, columns=genes_sel)

# scores in the P8-fitted PC space plus the per-cell annotations
res1 = pd.DataFrame(pcs_p8, columns=_pc_columns("p8PC", pcs_p8.shape[1]))
res1['cond'] = conds
res1['type'] = types
res1['samp'] = samps
# replicate label: first character of the second '_'-separated field of the sample name
res1['rep']  = [name.split('_')[1][0] for name in samps]

# scores in the P17on-fitted PC space
res2 = pd.DataFrame(pcs_p17on, columns=_pc_columns("p17onPC", pcs_p17on.shape[1]))

# one wide per-cell table: genes | P8 PCs + annotations | P17on PCs
res = pd.concat([res0, res1, res2], axis=1)

In [None]:
# inspect the replicate labels derived from the sample names
res1['rep']

In [None]:
# Explained-variance ratio per component for the two fitted PCAs.
# The curves are labelled so they can be told apart without reading the code.
fig, ax = plt.subplots(figsize=(6,4))
for fitted, label in ((pca, 'fit on P8NR'), (pca2, 'fit on P17NR/P21NR/P28NR')):
    evr = fitted.explained_variance_ratio_
    ax.plot(np.arange(len(evr))+1, evr, '-o', markersize=5, label=label)
# reference level: one over the number of genes that entered the PCA
ax.axhline(1/lognorm.shape[1], linestyle='--', color='gray', label='1 / n genes')
ax.set_xlabel('PC')
ax.set_ylabel('explained var')
ax.legend()
plt.show()

In [None]:
# 20 colours from the 'tab20c' palette; indexed by the condition palette defined in a later cell
allcolors = sns.color_palette('tab20c', 20)
allcolors

In [None]:
# 20 colours requested from 'tab10'; the type palette shown in this notebook only uses indices 0-2
allcolors2 = sns.color_palette('tab10', 20)
allcolors2

In [None]:
# condition -> colour; the key order also fixes the order of `cases` below
palette = collections.OrderedDict([
    ("P8NR",  allcolors[2]),
    ("P14NR", allcolors[0]),

    ("P17NR", allcolors[7]),
    ("P21NR", allcolors[6]),
    ("P28NR", allcolors[5]),
    ("P38NR", allcolors[4]),
])

# type label -> colour; the numbered and two-letter labels reuse the A/B/C colours
_col_a, _col_b, _col_c = allcolors2[:3]
palette_types = collections.OrderedDict([
    ('L2/3_A', _col_a),
    ('L2/3_B', _col_b),
    ('L2/3_C', _col_c),

    ('L2/3_1', _col_a),
    ('L2/3_2', _col_b),
    ('L2/3_3', _col_c),

    ('L2/3_AB', _col_a),
    ('L2/3_BC', _col_c),
])

# conditions in palette order, as an array
cases = np.array(list(palette))

In [None]:
# standalone colour key: one dot + label per condition, first condition on top
fig, ax = plt.subplots(figsize=(1,4))
n_entries = len(palette)
for row, (label, color) in zip(range(n_entries, 0, -1), palette.items()):
    ax.plot(0, row, 'o', c=color)
    ax.text(0.02, row, label, va='center', fontsize=15)
ax.axis('off')
plt.show()

In [None]:
from scipy import stats
from matplotlib.ticker import MaxNLocator


def plot(x, y, aspect_equal=False, density=False, hue='type'):
    """Scatter ``res[x]`` against ``res[y]`` with one panel per condition in ``cases``.

    Every panel shows all cells in light gray as background and, on top, the
    cells of that panel's condition coloured by ``hue``.

    Parameters
    ----------
    x, y : str
        Column names of the module-level ``res`` table.
    aspect_equal : bool
        If True, force an equal aspect ratio on every panel.
    density : bool
        If True, overlay a 2D histogram of the condition's cells.
    hue : str
        Column of ``res`` used for colouring. ``'type'`` uses the fixed
        ``palette_types`` colours and ordering; any other column is passed to
        seaborn as-is with its default palette.
    """
    # one panel per condition; squeeze=False keeps axs 2-D even for a single condition
    n_panels = len(cases)
    fig, axs = plt.subplots(1, n_panels, figsize=(4*n_panels, 4*1),
                            sharex=True, sharey=True, squeeze=False)

    if hue == 'type':
        hue_kwargs = dict(hue='type',
                          hue_order=list(palette_types.keys()),
                          palette=palette_types)
    else:
        # honour the requested column (previously 'rep' was hard-coded here)
        hue_kwargs = dict(hue=hue)

    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)
        # background: every cell, regardless of condition
        sns.scatterplot(data=res, 
                        x=x, y=y, 
                        c='lightgray',
                        s=1, edgecolor='none', 
                        legend=False,
                        ax=ax,
                       )

        subset = res[res['cond']==cond]
        # rows are shuffled so that no hue level is systematically drawn on top
        sns.scatterplot(data=subset.sample(frac=1, replace=False),
                        x=x, y=y, 
                        s=3, edgecolor='none', 
                        legend=False,
                        ax=ax,
                        **hue_kwargs,
                       )

        if density:
            sns.histplot(data=subset,
                         x=x, y=y, 
                         legend=False,
                         ax=ax,
                        )
        sns.despine(ax=ax)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        if i > 0:
            # axis labels are shown on the first panel only
            ax.set_xlabel('')
            ax.set_ylabel('')
    plt.show()
    
def plot2(x, y, hue=None, aspect_equal=False):
    """Scatter ``res[x]`` against ``res[y]`` per condition, coloured by a numeric column.

    Every panel shows all cells in light gray as background and, on top, the
    cells of that panel's condition.

    Parameters
    ----------
    x, y : str
        Column names of the module-level ``res`` table.
    hue : str or None
        Column of ``res`` mapped onto the 'coolwarm' colormap, clipped to
        [-3, 3]; also used as the figure title. If falsy, points are drawn in
        a single colour and the panel title reports the Spearman correlation
        between ``x`` and ``y`` for that condition.
    aspect_equal : bool
        If True, force an equal aspect ratio on every panel.
    """
    # one panel per condition; squeeze=False keeps axs 2-D even for a single condition
    n_panels = len(cases)
    fig, axs = plt.subplots(1, n_panels, figsize=(4*n_panels, 4*1),
                            sharex=True, sharey=True, squeeze=False)
    fig.suptitle(hue, x=0.1, ha='left')
    for i, (ax, cond) in enumerate(zip(axs.flat, cases)):
        ax.set_title(cond)
        # background: every cell, regardless of condition
        sns.scatterplot(data=res, 
                        x=x, y=y, 
                        c='lightgray',
                        alpha=0.3,
                        s=1, edgecolor='none', 
                        legend=False,
                        ax=ax,
                       )
        show = res[res['cond']==cond]
        if hue:
            ax.scatter(show[x], show[y], c=show[hue], 
                       cmap='coolwarm',
                       vmin=-3, vmax=3,
                       s=5, 
                       edgecolor='none', 
                      )
        else:
            # only the correlation coefficient is reported; the p-value is not used
            r, _ = stats.spearmanr(show[x], show[y])
            ax.scatter(show[x], show[y],  
                       s=5, 
                       edgecolor='none', 
                      )
            ax.set_title(f'{cond}\n r={r:.2f}')
        sns.despine(ax=ax)
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
        if aspect_equal:
            ax.set_aspect('equal')
        if i > 0:
            # axis labels are shown on the first panel only
            ax.set_xlabel('')
            ax.set_ylabel('')
    plt.show()

In [None]:
# P8-fitted PC1 vs PC2, one panel per condition, coloured by the 'rep' column
plot('p8PC1', 'p8PC2', aspect_equal=True, hue='rep')

In [None]:
# P17on-fitted PC1 vs PC2, one panel per condition, coloured by the 'rep' column
plot('p17onPC1', 'p17onPC2', aspect_equal=True, hue='rep')

In [None]:
# PC1 vs PC2 in each of the two fitted PC spaces, coloured by type (the default hue)
for pc_x, pc_y in (('p8PC1', 'p8PC2'), ('p17onPC1', 'p17onPC2')):
    plot(pc_x, pc_y, aspect_equal=True)

# plot('PC2', 'PC3')
# plot('PC3', 'PC4')
# plot('PC4', 'PC5')

In [None]:
# same two PC spaces, with the 2D-histogram overlay switched on
for pc_x, pc_y in (('p8PC1', 'p8PC2'), ('p17onPC1', 'p17onPC2')):
    plot(pc_x, pc_y, aspect_equal=True, density=True)

# what are the genes

In [None]:
# expression of three genes shown on the P17on-fitted PC1/PC2 plane
for marker in ('Cdh13', 'Sorcs3', 'Chrm2'):
    plot2('p17onPC1', 'p17onPC2', hue=marker, aspect_equal=True)

In [None]:
# expression of the same three genes shown on the P8-fitted PC1/PC2 plane
for marker in ('Cdh13', 'Sorcs3', 'Chrm2'):
    plot2('p8PC1', 'p8PC2', hue=marker, aspect_equal=True)

# explore

In [None]:

# every detected Gabr* gene, first on the P8-fitted plane, then on the P17on-fitted plane
pc_planes = (('p8PC1', 'p8PC2'), ('p17onPC1', 'p17onPC2'))
for g in gaba_sel:
    for pc_x, pc_y in pc_planes:
        plot2(pc_x, pc_y, hue=g, aspect_equal=True)

In [None]:
# five genes, each shown first on the P8-fitted plane, then on the P17on-fitted plane
pc_planes = (('p8PC1', 'p8PC2'), ('p17onPC1', 'p17onPC2'))
for g in ['Rfx3', 'Egr1', 'Meis2', 'Pou3f2', 'Sox5']:
    for pc_x, pc_y in pc_planes:
        plot2(pc_x, pc_y, hue=g, aspect_equal=True)

In [None]:
# each partner gene (x) against Rfx3 (y), coloured by the P8-fitted PC1 score
rfx3_partners = ['Egr1', 'Meis2', 'Pou3f2', 'Sox5']
for g in rfx3_partners:
    plot2(g, 'Rfx3', hue='p8PC1', aspect_equal=True)
    # plot2('p17onPC1', 'p17onPC2', hue=g, aspect_equal=True)