In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections

from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu


In [None]:
# Directory for saved figures. Defaults to the original cluster location;
# set the V1_FIGURES_DIR environment variable to redirect output when the
# notebook is run on a different machine.
outdir_fig = os.environ.get(
    "V1_FIGURES_DIR",
    "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures",
)

In [None]:
# load
# `anndata.read` is a deprecated alias; `read_h5ad` is the supported reader
# for .h5ad files.
adata = anndata.read_h5ad("../../data/cheng21_cell_scrna/reorganized/all_L23.h5ad")
adata

In [None]:
# select
# adata = adata[adata.obs['cond'].str.contains(r"NR$")]
# adata = adata[adata.obs['cond'].isin(['P17NR', 'P21NR', 'P28NR', 'P38NR'])]
# .copy() materialises the subset. Without it `adata` is a view of the loaded
# object, and the `obs['easitype']` assignment below would write to that view.
adata = adata[adata.obs['cond'].isin(['P17NR', 'P14NR', 'P8NR', ])].copy()
print(adata.obs['cond'].unique())

# define
# Per-gene and per-cell label arrays reused by the cells below.
genes = adata.var.index.values
conds = adata.obs['cond'].values
types = adata.obs['Type'].values
samps = adata.obs['sample'].values

# organize
# Collapse the different 'Type' label sets onto three labels (A/B/C).
rename = {
    "L2/3_A": "L2/3_A",
    "L2/3_B": "L2/3_B",
    "L2/3_C": "L2/3_C",
    
    "L2/3_1": "L2/3_A",
    "L2/3_2": "L2/3_B",
    "L2/3_3": "L2/3_C",
    
    "L2/3_AB": "L2/3_A",
    "L2/3_BC": "L2/3_C",
}
# Dictionary indexing (rather than .map) is kept on purpose: an unexpected
# 'Type' label raises KeyError instead of silently becoming NaN.
adata.obs['easitype'] = adata.obs['Type'].apply(lambda x: rename[x])
adata

In [None]:
# use those 286 genes
# df = pd.read_csv("../../data/cheng21_cell_scrna/res/candidate_genes_vincent_0503_v2.csv")
df = pd.read_csv("../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv")
genes_sel = df['gene'].astype(str).values
genes_grp = df['P17on'].astype(str).values
# the selected gene list must not contain duplicates
assert len(genes_sel) == len(np.unique(genes_sel))

# Column positions of the selected genes within `genes`.
gi = basicu.get_index_from_array(genes, genes_sel)
# Validate the looked-up indices, not the gene-name strings: the previous
# check compared `genes_sel` (strings) with -1 and therefore could never fail.
# NOTE(review): presumably -1 is the "not found" sentinel of
# basicu.get_index_from_array -- confirm against that helper.
assert np.all(gi != -1), "some selected genes were not found in adata.var"

# CP10k for single cells
cov = np.array(adata.X.sum(axis=1))          # per-cell total over all genes
counts = np.array(adata.X[:,gi].todense())   # counts restricted to the selected genes
norm = counts/cov*1e4
lognorm = np.log10(norm+1)
zlognorm = zscore(lognorm, axis=0)           # z-score each gene across cells
print(zlognorm.shape)

In [None]:
# zlognorm = np.nan_to_num(zlognorm, 0)

In [None]:
# Distinct 'P17on' labels among the selected genes, and how many genes carry each.
(gene_types, gene_type_counts) = np.unique(genes_grp, return_counts=True)
(gene_types, gene_type_counts)

In [None]:
# Top 10 principal components of the z-scored expression matrix.
# random_state is fixed so the components are reproducible across runs if
# scikit-learn picks a randomized SVD solver for this input size.
pca = PCA(n_components=10, random_state=0)
pcs = pca.fit_transform(zlognorm)

# ucs = UMAP(n_components=2, n_neighbors=50).fit_transform(pcs)

In [None]:
# # fix pc1 to make sure a < c:
# pc1 = pcs[:,0]
# pc_types, unq_types = basicu.group_mean(pc1.reshape(-1,1), types)
# a = pc_types[0,0]
# c = pc_types[-1,0]
# if a > c:
#     pcs[:,0] = -pcs[:,0]

In [None]:
# One row per cell: PC scores in columns PC1..PCk plus per-cell metadata.
pc_names = [f"PC{k}" for k in range(1, pcs.shape[1] + 1)]
res = pd.DataFrame(pcs, columns=pc_names)
for meta_name, meta_values in (('cond', conds), ('type', types), ('samp', samps)):
    res[meta_name] = meta_values
# res['umap1'] = ucs[:,0]
# res['umap2'] = ucs[:,1]
# res['type'] = types

In [None]:
# Scree plot: explained-variance ratio of each principal component.
# plt.plot(np.cumsum(pca.explained_variance_ratio_), '-o')
var_ratio = pca.explained_variance_ratio_
component_ids = np.arange(1, len(var_ratio) + 1)

fig, ax = plt.subplots(figsize=(6,4))
ax.plot(component_ids, var_ratio, '-o', markersize=5)
# dashed reference line at 1 / (number of genes in the matrix)
ax.axhline(1/lognorm.shape[1], linestyle='--', color='gray')
ax.set_xlabel('PC')
ax.set_ylabel('explained var')

In [None]:
# 20 colours from the 'tab20c' palette; indexed below to colour conditions.
allcolors = sns.color_palette('tab20c', n_colors=20)
allcolors

In [None]:
# Colours from the 'tab10' palette; indexed below to colour cell types.
# NOTE(review): 20 colours are requested from a palette named 'tab10', so
# entries presumably repeat -- the visible cells only use indices 0-2.
allcolors2 = sns.color_palette('tab10', n_colors=20)
allcolors2

In [None]:
# Condition -> colour lookup; insertion order doubles as legend/panel order.
palette = collections.OrderedDict([
    ("P8NR", allcolors[2]),
    ("P14NR", allcolors[0]),
    ("P17NR", allcolors[7]),
])
# Conditions not used in this subset (kept for reference):
#   "P21NR": allcolors[6], "P28NR": allcolors[5], "P38NR": allcolors[4],
#   "P28DR": allcolors[14], "P38DR": allcolors[12],
#   "P28DL": allcolors[8],

# Condition names as an array, in palette order.
cases = np.array(list(palette))
cases

In [None]:
# Cell-type label -> colour. Labels from the different naming schemes that
# share a colour index are drawn in the same colour.
palette_types = collections.OrderedDict(
    (type_label, allcolors2[color_idx])
    for type_label, color_idx in [
        ('L2/3_A', 0),
        ('L2/3_B', 1),
        ('L2/3_C', 2),

        ('L2/3_1', 0),
        ('L2/3_2', 1),
        ('L2/3_3', 2),

        ('L2/3_AB', 0),
        ('L2/3_BC', 2),
    ]
)

In [None]:
# Stand-alone legend: one coloured dot and label per condition, listed top to
# bottom in palette order.
fig, ax = plt.subplots(figsize=(1,4))
n_entries = len(palette)
for row, (label, color) in enumerate(palette.items()):
    ypos = n_entries - row  # first entry gets the highest y position
    ax.plot(0, ypos, 'o', c=color)
    ax.text(0.02, ypos, label, va='center', fontsize=15)
    ax.axis('off')
plt.show()

In [None]:
# PC1 vs PC2, coloured by condition.
fig, ax = plt.subplots(1,1,figsize=(8*1,6))
# Rows are shuffled so no condition is systematically drawn on top; the seed
# keeps the draw order (and hence the figure) identical across runs.
sns.scatterplot(data=res.sample(frac=1, replace=False, random_state=0), 
                x='PC1', y='PC2', 
                hue='cond',
                hue_order=list(palette.keys()),
                palette=palette,
                s=5, edgecolor='none', 
                # legend=False,
                ax=ax,
               )
ax.legend(bbox_to_anchor=(1,1))
plt.show()

In [None]:
# PC2 vs PC3, coloured by condition.
fig, ax = plt.subplots(1,1,figsize=(8*1,6))
# Rows are shuffled so no condition is systematically drawn on top; the seed
# keeps the draw order (and hence the figure) identical across runs.
sns.scatterplot(data=res.sample(frac=1, replace=False, random_state=0), 
                x='PC2', y='PC3', 
                hue='cond',
                hue_order=list(palette.keys()),
                palette=palette,
                s=5, edgecolor='none', 
                # legend=False,
                ax=ax,
               )
ax.legend(bbox_to_anchor=(1,1))
plt.show()

In [None]:
# PC3 vs PC4, coloured by condition.
fig, ax = plt.subplots(1,1,figsize=(8*1,6))
# Rows are shuffled so no condition is systematically drawn on top; the seed
# keeps the draw order (and hence the figure) identical across runs.
sns.scatterplot(data=res.sample(frac=1, replace=False, random_state=0), 
                x='PC3', y='PC4', 
                hue='cond',
                hue_order=list(palette.keys()),
                palette=palette,
                s=5, edgecolor='none', 
                # legend=False,
                ax=ax,
               )
ax.legend(bbox_to_anchor=(1,1))
plt.show()

In [None]:

def plot(x, y):
    """Draw one scatter panel per condition in `cases`, coloured by cell type.

    Every panel first shows all rows of the notebook-level table `res` in
    light gray as a backdrop, then overlays the rows whose 'cond' equals the
    panel's condition, coloured by their 'type' value using `palette_types`.

    Parameters
    ----------
    x, y : str
        Names of the `res` columns plotted on the horizontal and vertical axes.

    Notes
    -----
    Reads the notebook globals `res`, `cases` and `palette_types`. The figure
    is displayed with `plt.show()`; nothing is returned.
    """
    # One panel per condition (previously fixed at 3, which silently dropped
    # any further conditions through zip truncation).
    n_panels = len(cases)
    # squeeze=False keeps `axs` an array even when there is a single panel.
    fig, axs = plt.subplots(1, n_panels, figsize=(4*n_panels, 4*1),
                            sharex=True, sharey=True, squeeze=False)
    for ax, cond in zip(axs.flat, cases):
        ax.set_title(cond)
        # backdrop: all cells, regardless of condition
        sns.scatterplot(data=res, 
                        x=x, y=y, 
                        c='lightgray',
                        s=1, edgecolor='none', 
                        legend=False,
                        ax=ax,
                       )
        # Overlay this condition's cells. Rows are shuffled so no cell type is
        # systematically drawn on top; the seed makes the draw order
        # reproducible across runs.
        sns.scatterplot(data=res[res['cond']==cond].sample(frac=1, replace=False, random_state=0),
                        x=x, y=y, 
                        hue='type',
                        hue_order=list(palette_types.keys()),
                        palette=palette_types,
                        s=3, edgecolor='none', 
                        legend=False,
                        ax=ax,
                       )
        sns.despine(ax=ax)
        # ax.set_aspect('equal')
    plt.show()

In [None]:
# Per-condition panels for each pair of consecutive PCs.
for x_col, y_col in [('PC1', 'PC2'), ('PC2', 'PC3'), ('PC3', 'PC4'), ('PC4', 'PC5')]:
    plot(x_col, y_col)

In [None]:
# Per-condition panels of PC1 against PC5.
plot(x='PC1', y='PC5')

# Try
- ICA
- HVGs
- Project to adult
- 3D visuals
- Understand this

In [None]:
import numpy as np
import matplotlib.pyplot as plt



In [None]:
# Apply ICA
from sklearn.decomposition import FastICA
# Unmix the first five PCs into five independent components. random_state
# pins FastICA's random initialisation so the components (including their
# order and sign) are reproducible across runs.
ica = FastICA(n_components=5, whiten='unit-variance', max_iter=5000, tol=1e-3,
              random_state=0)
ics = ica.fit_transform(pcs[:,:5])

In [None]:
# Name the ICA scores IC1..ICk and attach them to `res` as extra columns.
ic_names = [f"IC{k}" for k in range(1, ics.shape[1] + 1)]
res2 = pd.DataFrame(ics, columns=ic_names)
res[res2.columns] = res2

In [None]:
# PC1 against each of the five independent components.
for ic_idx in range(1, 6):
    plot('PC1', f'IC{ic_idx}')

In [None]:
# Per-condition panels of IC1 against IC5.
plot(x='IC1', y='IC5')

In [None]:
# Per-condition panels of IC5 against IC4.
plot(x='IC5', y='IC4')