In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections

from sklearn.decomposition import PCA
from umap import UMAP

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu


In [None]:
# Output directory for saved figures. Override with the V1_FIG_DIR environment variable
# when running outside the original cluster account; the default is the original path.
outdir_fig = os.environ.get("V1_FIG_DIR", "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/figures")

In [None]:
# Load all L2/3 cells from the cheng21 dataset and pull out per-cell / per-gene arrays.
# `anndata.read` is a deprecated alias of `read_h5ad` (removed in recent anndata releases).
adata = anndata.read_h5ad("../../data/cheng21_cell_scrna/reorganized/all_L23.h5ad")
genes = adata.var.index.values
conds = adata.obs['cond'].values    # condition label per cell (e.g. "P28NR")
types = adata.obs['Type'].values    # original cell-type label per cell
samps = adata.obs['sample'].values
adata

In [None]:
# Fold the different type-naming schemes onto a single A/B/C label ('easitype'):
# numbered labels map 1->A, 2->B, 3->C; intermediate AB->A and BC->C.
rename = {}
for canonical, aliases in {
    "L2/3_A": ("L2/3_A", "L2/3_1", "L2/3_AB"),
    "L2/3_B": ("L2/3_B", "L2/3_2"),
    "L2/3_C": ("L2/3_C", "L2/3_3", "L2/3_BC"),
}.items():
    for alias in aliases:
        rename[alias] = canonical

# unknown labels raise KeyError rather than silently becoming NaN
adata.obs['easitype'] = adata.obs['Type'].apply(lambda label: rename[label])

In [None]:
# Curated L2/3 A/B/C gene panel: one row per gene, 'P17on' gives each gene's group label.
# df = pd.read_csv("../../data/cheng21_cell_scrna/res/candidate_genes_vincent_0503_v2.csv")
df = pd.read_csv("../../data/cheng21_cell_scrna/res/L23-ABC-genes-n288-n286unq-annot.csv")
genes_sel = df['gene'].astype(str).values
genes_grp = df['P17on'].astype(str).values
assert len(genes_sel) == len(np.unique(genes_sel))

# Column index of each selected gene in adata. get_index_from_array is expected to
# return -1 for a gene it cannot find, so check the indices (not the gene names).
gi = basicu.get_index_from_array(genes, genes_sel)
assert np.all(gi != -1), f"genes missing from adata: {genes_sel[gi == -1]}"

# CP10k for single cells: counts per 10,000 total counts of each cell, then log10(x+1).
# reshape to a column so the division broadcasts per cell for both sparse matrices
# (2-D sums) and sparse/dense arrays (1-D sums).
cov = np.asarray(adata.X.sum(axis=1)).reshape(-1, 1)
counts = adata.X[:, gi]
counts = counts.toarray() if hasattr(counts, 'toarray') else np.asarray(counts)
norm = counts/cov*1e4
lognorm = np.log10(norm+1)
# z-score each gene across cells; a gene with zero variance across cells yields NaN
zlognorm = zscore(lognorm, axis=0)

In [None]:
# zlognorm = np.nan_to_num(zlognorm, 0)

In [None]:
# How many panel genes fall into each 'P17on' group (groups sorted by label).
(gene_types, gene_type_counts) = np.unique(genes_grp, return_counts=True)
(gene_types, gene_type_counts)

In [None]:
# Seed both steps so the PCs and the UMAP embedding are reproducible across runs.
# PCA's 'auto' solver can pick a randomized SVD depending on data shape / sklearn version.
# A seeded UMAP runs single-threaded, which is slower but deterministic.
SEED = 0

pca = PCA(n_components=50, random_state=SEED)
pcs = pca.fit_transform(zlognorm)

ucs = UMAP(n_components=2, n_neighbors=50, random_state=SEED).fit_transform(pcs)

In [None]:
# # fix pc1 to make sure a < c:
# pc1 = pcs[:,0]
# pc_types, unq_types = basicu.group_mean(pc1.reshape(-1,1), types)
# a = pc_types[0,0]
# c = pc_types[-1,0]
# if a > c:
#     pcs[:,0] = -pcs[:,0]

In [None]:
# One row per cell: PC coordinates (PC1..PCn), metadata, and the 2-D UMAP embedding.
pc_names = [f"PC{k}" for k in range(1, pcs.shape[1] + 1)]
res = pd.DataFrame(pcs, columns=pc_names).assign(
    cond=conds,
    type=types,
    samp=samps,
    umap1=ucs[:, 0],
    umap2=ucs[:, 1],
)

In [None]:
# Scree plot: fraction of total variance captured by each PC.
evr = pca.explained_variance_ratio_
pc_number = np.arange(1, len(evr) + 1)
fig, ax = plt.subplots(figsize=(6,4))
ax.plot(pc_number, evr, '-o', markersize=5)
# reference at 1/n_genes: the share of variance carried by a single (standardized) gene
ax.axhline(1/lognorm.shape[1], linestyle='--', color='gray')
ax.set_xlabel('PC')
ax.set_ylabel('explained var')

In [None]:
# tab20c: 20 colours arranged as 5 hue groups of 4 shades; indexed below to colour conditions
allcolors = sns.color_palette(palette='tab20c', n_colors=20)
allcolors

In [None]:
# tab10 only has 10 distinct colours; requesting 20 makes seaborn repeat them (entry i+10 == entry i)
allcolors2 = sns.color_palette(palette='tab10', n_colors=20)
allcolors2

In [None]:
# Condition -> colour, in legend/display order. Colours are taken from tab20c's
# hue groups (indices 0-3, 4-7, 8-11, 12-15).
palette = collections.OrderedDict([
    # NR, P8 / P14 (group 0-3)
    ("P8NR", allcolors[2]),
    ("P14NR", allcolors[0]),
    # NR, P17-P38 (group 4-7)
    ("P17NR", allcolors[7]),
    ("P21NR", allcolors[6]),
    ("P28NR", allcolors[5]),
    ("P38NR", allcolors[4]),
    # DL (group 8-11)
    ("P28DL", allcolors[8]),
    # DR (group 12-15)
    ("P28DR", allcolors[14]),
    ("P38DR", allcolors[12]),
])

# condition names in palette order; used to order panels and loops below
cases = np.array([*palette])
cases

In [None]:
# Cell-type label -> colour. Equivalent labels from the different naming schemes share a
# colour (A/1/AB -> first, B/2 -> second, C/3/BC -> third), mirroring the `rename` mapping.
palette_types = collections.OrderedDict(
    (label, allcolors2[color_idx])
    for label, color_idx in [
        ('L2/3_A', 0), ('L2/3_B', 1), ('L2/3_C', 2),
        ('L2/3_1', 0), ('L2/3_2', 1), ('L2/3_3', 2),
        ('L2/3_AB', 0), ('L2/3_BC', 2),
    ]
)

In [None]:
# Stand-alone colour legend: one dot + label per condition, first condition at the top.
n_conds = len(palette)
fig, ax = plt.subplots(figsize=(1,4))
for rank, (cond, color) in enumerate(palette.items()):
    y = n_conds - rank
    ax.plot(0, y, 'o', c=color)
    ax.text(0.02, y, cond, va='center', fontsize=15)
ax.axis('off')
plt.show()

In [None]:
# PC1 vs PC2, coloured by condition.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
# shuffle rows so no single condition is systematically drawn on top
shuffled = res.sample(frac=1, replace=False)
sns.scatterplot(
    data=shuffled,
    x='PC1',
    y='PC2',
    hue='cond',
    hue_order=list(palette.keys()),
    palette=palette,
    s=5,
    edgecolor='none',
    ax=ax,
)
ax.legend(bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# PC2 vs PC3, coloured by condition (rows shuffled to avoid draw-order bias).
scatter_kws = dict(hue='cond', hue_order=list(palette.keys()), palette=palette,
                   s=5, edgecolor='none')
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
sns.scatterplot(data=res.sample(frac=1, replace=False), x='PC2', y='PC3', ax=ax, **scatter_kws)
ax.legend(bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# PC3 vs PC4, coloured by condition (rows shuffled to avoid draw-order bias).
x_col, y_col = 'PC3', 'PC4'
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
sns.scatterplot(data=res.sample(frac=1, replace=False),
                x=x_col, y=y_col,
                hue='cond', hue_order=[*palette], palette=palette,
                s=5, edgecolor='none',
                ax=ax)
ax.legend(bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# One panel per condition in PC1/PC2: all cells in grey, that condition's cells on top.
fig, axs = plt.subplots(3, 3, figsize=(12, 12), sharex=True, sharey=True)
for cond, ax in zip(cases, axs.flat):
    ax.set_title(cond)
    sns.scatterplot(data=res, x='PC1', y='PC2', c='lightgray',
                    s=1, edgecolor='none', legend=False, ax=ax)
    in_cond = res['cond'] == cond
    sns.scatterplot(data=res.loc[in_cond], x='PC1', y='PC2',
                    hue='cond', hue_order=list(palette.keys()), palette=palette,
                    s=3, edgecolor='none', legend=False, ax=ax)
    sns.despine(ax=ax)
plt.show()

In [None]:
# One panel per condition in PC2/PC3: grey backdrop of all cells, condition highlighted.
x_col, y_col = 'PC2', 'PC3'
background_kws = dict(c='lightgray', s=1, edgecolor='none', legend=False)
highlight_kws = dict(hue='cond', hue_order=list(palette.keys()), palette=palette,
                     s=3, edgecolor='none', legend=False)

fig, axs = plt.subplots(3, 3, figsize=(4*3, 4*3), sharex=True, sharey=True)
for panel, cond in zip(axs.ravel(), cases):
    panel.set_title(cond)
    sns.scatterplot(data=res, x=x_col, y=y_col, ax=panel, **background_kws)
    sns.scatterplot(data=res[res['cond'] == cond], x=x_col, y=y_col, ax=panel, **highlight_kws)
    sns.despine(ax=panel)
plt.show()

In [None]:
# UMAP coloured by condition (rows shuffled so no condition is systematically on top).
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
sns.scatterplot(
    data=res.sample(frac=1, replace=False),
    x='umap1', y='umap2',
    hue='cond', hue_order=list(palette), palette=palette,
    s=5, edgecolor='none',
    ax=ax,
)
ax.legend(bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# One UMAP panel per condition: grey backdrop of all cells, condition cells in its colour.
fig, axs = plt.subplots(3, 3, figsize=(12, 12), sharex=True, sharey=True)
for ax, cond in zip(axs.flat, cases):
    ax.set_title(cond)
    # backdrop
    sns.scatterplot(data=res, x='umap1', y='umap2',
                    c='lightgray', s=1, edgecolor='none',
                    legend=False, ax=ax)
    # highlighted condition
    cond_cells = res[res['cond'] == cond]
    sns.scatterplot(data=cond_cells, x='umap1', y='umap2',
                    hue='cond', hue_order=list(palette.keys()), palette=palette,
                    s=3, edgecolor='none',
                    legend=False, ax=ax)
    sns.despine(ax=ax)
plt.show()

In [None]:
# UMAP coloured by the original cell-type label (shuffled draw order).
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
shuffled = res.sample(frac=1, replace=False)
sns.scatterplot(data=shuffled, x='umap1', y='umap2',
                hue='type', palette=palette_types,
                s=5, edgecolor='none', ax=ax)
ax.legend(bbox_to_anchor=(1, 1))
plt.show()

In [None]:
# Side-by-side UMAPs: cells from NR/DL conditions vs DR conditions, coloured by type.
categories = ['NR|DL', 'DR']  # regex patterns matched against the 'cond' column
fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
for ax, pattern in zip(axs.flat, categories):
    ax.set_title(pattern)
    # all cells in grey as a backdrop
    sns.scatterplot(data=res, x='umap1', y='umap2', color='lightgray',
                    s=1, edgecolor='none', legend=False, ax=ax)
    # matching cells, coloured by type, in shuffled draw order
    matching = res[res['cond'].str.contains(pattern)]
    sns.scatterplot(data=matching.sample(frac=1, replace=False), x='umap1', y='umap2',
                    hue='type', palette=palette_types,
                    s=5, edgecolor='none', legend=False, ax=ax)
fig.tight_layout()
plt.show()

In [None]:
# One UMAP panel per condition: grey backdrop, that condition's cells coloured by type.
fig, axs = plt.subplots(3, 3, figsize=(12, 12), sharex=True, sharey=True)
for panel, cond in zip(axs.ravel(), cases):
    panel.set_title(cond)
    sns.scatterplot(data=res, x='umap1', y='umap2', c='lightgray',
                    s=1, edgecolor='none', legend=False, ax=panel)
    mask = res['cond'] == cond
    sns.scatterplot(data=res[mask], x='umap1', y='umap2',
                    hue='type', palette=palette_types,
                    s=3, edgecolor='none', legend=False, ax=panel)
    sns.despine(ax=panel)
plt.show()

# heatmap

# give each cell a pseudo-time index

In [None]:
ncell, ngene = zlognorm.shape

# Discrete colormaps: with N set to the number of colours, each colour becomes one
# lookup-table entry, so integer category codes map to distinct colours.
colors_d1 = ('C0', 'C1', 'C2')
colors_d2 = ('C0', 'C1', 'C2', 'C3')
cmap_d1 = LinearSegmentedColormap.from_list('', colors_d1, N=len(colors_d1))
cmap_d2 = LinearSegmentedColormap.from_list('', colors_d2, N=len(colors_d2))
my_colors = colors_d2  # keep the previously-bound name pointing at the last tuple

In [None]:
# Use P28 ordering
f = '../../results/gene_ptime_P28_L23_Mar27.tsv'
gpt = pd.read_csv(f)
gpt = gpt['gene_ptime'].values
geneidx = np.argsort(gpt)

In [None]:
# Use ptime from each sample analysis
dfall = []
for case in cases:
    f = f'../../results/cell_ptime_{case}_L23_Mar28.tsv'
    df = pd.read_csv(f, index_col=0)
    df['case'] = case
    dfall.append(df)
dfall = pd.concat(dfall)
assert np.all(dfall.index.values == adata.obs.index.values)
dfall

In [None]:
def _category_codes(labels, order=None):
    """Map ``labels`` to integer codes following ``order``; return ``(codes, order)``.

    ``order`` defaults to the sorted unique labels, which reproduces
    ``pd.factorize(labels, sort=True)``. Raises ValueError if a label is not in ``order``.
    """
    labels = np.asarray(labels)
    order = np.unique(labels) if order is None else np.asarray(order)
    codes = pd.Categorical(labels, categories=order).codes.astype(int)
    if (codes < 0).any():
        missing = np.unique(labels[codes < 0])
        raise ValueError(f"labels not listed in the category order: {list(missing)}")
    return codes, order


def _category_cmap(n, base):
    """Return ``base`` if it has at least ``n`` colours, else a discrete map of ``n`` cycle colours."""
    if base.N >= n:
        return base
    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    colors = [cycle_colors[i % len(cycle_colors)] for i in range(n)]
    return LinearSegmentedColormap.from_list('', colors, n)


def plot(zlognorm, cellidx, geneidx, types, genes_grp, case, figsize=(12,6), xticklabels=1000, hratio=20, vratio=10,
         type_order=None, grp_order=None):
    """Heatmap of z-scored expression with cell-type and gene-group annotation strips.

    Layout (``plt.subplot_mosaic``): "A" is the main heatmap (genes as rows, cells as
    columns), "B" a one-column strip left of it colouring each gene by its group, and
    "C" a one-row strip below it colouring each cell by its type.

    Parameters
    ----------
    zlognorm : 2-D array
        Expression matrix, rows = cells, columns = genes; colour scale fixed to [-3, 3].
    cellidx : index array
        Cell (row) order to display, left to right.
    geneidx : index array
        Gene (column) order to display, top to bottom.
    types : array-like
        Cell-type label for each row of ``zlognorm``.
    genes_grp : array-like
        Group label for each column of ``zlognorm``.
    case : str
        Figure title.
    figsize, xticklabels, hratio, vratio
        Figure size, tick spacing on the cell axis, and width / height (in mosaic
        units) of the main heatmap.
    type_order, grp_order : sequence, optional
        Fixed category order for the type / gene-group strips. Passing the full list keeps
        colours consistent across calls even when some categories are absent. Default:
        sorted unique labels present (previous behaviour).

    Returns
    -------
    fig, axdict
    """
    mosaic = ("B"+"A"*hratio+"\n")*vratio + "."+"C"*hratio
    fig, axdict = plt.subplot_mosaic(mosaic, figsize=figsize)
    fig.suptitle(case)

    # A: expression heatmap (transposed so genes are rows)
    ax = axdict['A']
    sns.heatmap(zlognorm[cellidx][:,geneidx].T,
                xticklabels=False,
                yticklabels=False,
                cbar_kws=dict(shrink=0.3, label='zscore log10CP10k', aspect=10),
                center=0,
                vmax=3,
                vmin=-3,
                cmap='coolwarm',
                rasterized=True,
                ax=ax,
               )

    # C: one-row strip of cell types. vmin/vmax put each integer code at the centre of its
    # own colour bin, so colours do not depend on how many categories happen to be present.
    ax = axdict['C']
    type_codes, type_order = _category_codes(np.asarray(types)[cellidx], type_order)
    type_cmap = _category_cmap(len(type_order), cmap_d1)
    sns.heatmap(type_codes.reshape(1, -1),
                xticklabels=xticklabels,
                yticklabels=False,
                cmap=type_cmap,
                vmin=-0.5,
                vmax=type_cmap.N - 0.5,
                cbar_kws=dict(ticks=list(range(len(type_order))), shrink=2, aspect=5),
                rasterized=True,
                ax=ax,
               )
    ax.set_xlabel('Cells')
    # label the strip's own colorbar with the actual categories (suffix, e.g. 'L2/3_A' -> 'A')
    ax.collections[0].colorbar.ax.set_yticklabels([str(t).rsplit('_', 1)[-1] for t in type_order])
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=10)

    # B: one-column strip of gene groups, coloured with the same discrete scheme
    ax = axdict['B']
    grp_codes, grp_order = _category_codes(np.asarray(genes_grp)[geneidx], grp_order)
    grp_cmap = _category_cmap(len(grp_order), cmap_d1)
    sns.heatmap(grp_codes.reshape(-1, 1),
                xticklabels=False,
                yticklabels=100,
                cmap=grp_cmap,
                vmin=-0.5,
                vmax=grp_cmap.N - 0.5,
                cbar=False,
                rasterized=True,
                ax=ax,
               )
    ax.set_ylabel('Genes')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=10)

    fig.subplots_adjust(hspace=0.4)
    return fig, axdict

In [None]:
# Sanity check of the per-gene z-scoring: every gene should have mean ~0 and std ~1
# (scipy zscore and ndarray.std both default to ddof=0). A gap in either line would
# indicate a NaN column, e.g. a gene with zero variance across cells.
fig, ax = plt.subplots()
ax.plot(zlognorm.mean(axis=0), label='mean')
ax.plot(zlognorm.std(axis=0), label='std')
ax.set_xlabel('gene (column of zlognorm)')
ax.set_ylabel('zscore log10CP10k')
ax.legend()
plt.show()

In [None]:
# Per-condition heatmaps: cells of each condition ordered by their own pseudotime,
# genes ordered by the shared P28 gene pseudotime (geneidx).
allcells = adata.obs.index.values
cond_labels = adata.obs['cond'].values
easitypes = adata.obs['easitype'].values
mean_exps = []  # per condition: mean z-score of each gene
rnge_exps = []  # per condition: 5th-95th percentile range of each gene's z-score
for case in cases:
    # select this condition's cells (same order as in adata)
    in_case = cond_labels == case
    dfall_sub = dfall.loc[allcells[in_case]]
    zlognorm_sub = zlognorm[in_case]

    # order cells by pseudotime. Use the harmonised A/B/C labels, which match the 3-colour
    # type strip and the combined heatmap; store them under a new name so the global
    # `types` array (all cells) is not overwritten.
    cellidx = np.argsort(dfall_sub['ptime'].values)
    types_sub = easitypes[in_case]

    mean_exps.append(zlognorm_sub.mean(axis=0))
    p5, p95 = np.percentile(zlognorm_sub, [5, 95], axis=0)
    rnge_exps.append(p95 - p5)

    plot(zlognorm_sub, cellidx, geneidx, types_sub, genes_grp, case)
    plt.show()  # render (and release) each figure as it is produced

In [None]:
# a big table of everything
allcells = adata.obs.index.values
bigmat = []
bigtypes = []

ncases = []
for case in cases:
    # select cells
    adata_sub = adata[adata.obs['cond']==case]
    cells_sub = adata_sub.obs.index.values
    cellsidx_sub = basicu.get_index_from_array(allcells, cells_sub)
    
    dfall_sub = dfall.loc[cells_sub]
    zlognorm_sub = zlognorm[cellsidx_sub]
    
    # order cells
    cellidx = np.argsort(dfall_sub['ptime'].values)
    types   = adata_sub.obs['easitype'].values
    
    bigmat.append(zlognorm_sub[cellidx])
    bigtypes.append(types[cellidx])
    ncases.append(len(cellidx))
    
bigmat = np.vstack(bigmat)
bigtypes = np.hstack(bigtypes)
bigmat.shape

In [None]:
# Combined heatmap of all conditions, with dashed separators between condition blocks
# and each condition's name written at the start of its block.
block_ends = np.cumsum(ncases)
block_starts = np.concatenate([[0], block_ends[:-1]])

fig, axdict = plot(bigmat, np.arange(len(bigmat)), geneidx, bigtypes, genes_grp, "",
                   figsize=(20,6), xticklabels=5000, hratio=50)
for key, y_top in (('A', bigmat.shape[1]), ('C', 1)):
    axdict[key].vlines(block_ends, 0, y_top, color='k', linestyle='--', linewidth=1)
for x, case in zip(block_starts, cases):
    axdict['A'].text(x, 0, case, fontsize=15)
plt.show()

# Collapse into types

# Specific genes

# Quantify this