In [1]:
import itertools
import numpy as np
import pandas as pd
from scipy.stats import ttest_ind
import anndata

import matplotlib.pyplot as plt
from statsmodels.stats.multitest import multipletests
import seaborn as sns

from dredFISH.Utils import basicu
from dredFISH.Utils import powerplots

In [2]:
def get_normed_bulks(mat, genes, types, genes_sel_idx=None):
    """
    Assume cell by gene
    - select expressed genes
    - sparse to dense
    - merge cells to clusters
    - log10(CPM+1) norm bulk samples
    """
    ncell, ngene = mat.shape
    cellcov = np.asarray(mat.sum(axis=1)).reshape(-1,)
    genecov = np.asarray(mat.sum(axis=0)).reshape(-1,)
    
    if genes_sel_idx is None:
        cond = genecov>ncell*0.01 # expressed in at least 1% of cells
        matsub = np.asarray(mat[:,cond].todense())
        genes_sel = genes[cond]
    else:
        matsub = np.asarray(mat[:,genes_sel_idx].todense())
        genes_sel = genes[genes_sel_idx]
    
    # pseudo-bulk samples -- pull counts from cells
    Xk, xclsts = basicu.group_sum(matsub, types)
    ck, xclsts = basicu.group_sum(cellcov.reshape(-1,1), types)
    Xk = np.log10(1+(np.array(Xk)/np.array(ck))*1e6) # log10(1+CPM)
    df = pd.DataFrame(Xk, index=xclsts, columns=genes_sel)
    return df.T # gene by types 

def get_normed_bulks_for_adata_by_types(adata, genes_cndd=None):
    """
    """
    if genes_cndd is None:
        genes_sel_idx = None
    else:
        genes_sel_idx = basicu.get_index_from_array(adata.var.index.values, genes_cndd)
        if np.sum(genes_sel_idx == -1) > 0:
            print("some genes are not there")
            genes_sel_idx = genes_sel_idx[genes_sel_idx!=-1]
            
    return get_normed_bulks(adata.X.copy(), adata.var.index.values, adata.obs['Type'], 
                            genes_sel_idx=genes_sel_idx)

In [3]:
sns.set_context('talk')

In [4]:
f = "/greendata/GeneralStorage/fangming/projects/visctx/data_dump/MERFISH_gene_panel_ver_3-3_Feb2023.csv"
df_merfish = pd.read_csv(f)

# # df_merfish = df_merfish[~df_merfish[0].isnull()]
cnddts = df_merfish['gene_name_data'].values
unq, cnts = np.unique(cnddts, return_counts=True)
print(len(cnddts), unq.shape, unq[cnts>1])

df_merfish

685 (683,) ['Hsd11b1' 'Whrn']


Unnamed: 0,gene_name_vizgen,gene_name_data,why included,source,Annot1,Annot2,Annot3,Unnamed: 7
0,Matn2,Matn2,L2/3 subtypes,Cheng22_Cell,A>C>B,screened,,
1,Egfem1,Egfem1,L2/3 subtypes,Cheng22_Cell,A>C>B,screened,,
2,Grb14,Grb14,L2/3 subtypes,Cheng22_Cell,A>C>B,*,,
3,Adamts17,Adamts17,L2/3 subtypes,Cheng22_Cell,A>C>B,*,,
4,Ldb2,Ldb2,L2/3 subtypes,Cheng22_Cell,A>C>B,*,,
...,...,...,...,...,...,...,...,...
680,Tox2,Tox2,L5 IT subtypes,Tasic18_Nature,,,,
681,Batf3,Batf3,L5 IT subtypes,Tasic18_Nature,,,,
682,Col6a1,Col6a1,L5 IT subtypes,Tasic18_Nature,,,,
683,Fezf2,Fezf2,L5 IT subtypes,Tasic18_Nature,,,,


In [5]:
df_merfish.groupby(['why included', 'source']).size().to_frame('number')

Unnamed: 0_level_0,Unnamed: 1_level_0,number
why included,source,Unnamed: 2_level_1
All cell types in V1,PROPOSE,50
Astrocyte NRvsDR DEGs,Cheng22_Cell,7
Astrocyte related,Bayraktar20_NatNeuro,44
Cortical excitatory types,Chen22_biorxiv_Zador,51
Cortical types,Cheng22_Cell,69
DEG_NRvsDR,our analysis,88
Early on marker,Cheng22_Cell,4
IEG,Hrvatin17_NatNeuro,22
L2/3 subtypes,Cheng22_Cell,122
L2/3 subtypes in SSp,Condylis22_Science,10


In [6]:
f = '../data_dump/counts/P28_nr_allcells_Oct24.h5ad'
adata_nr = anndata.read(f)
f = '../data_dump/counts/P28_dr_allcells_Oct25.h5ad'
adata_dr = anndata.read(f)

genes_data = adata_nr.var.index.values

adata_nr, adata_dr, genes_data



(AnnData object with n_obs × n_vars = 23930 × 53547
     obs: 'cov', 'covfactor', 'batch', 'n_genes', 'percent_mito', 'n_counts', 'leiden', 'Doublet', 'Doublet Score', 'Class_broad', 'sample', 'Type', 'Subclass'
     var: 'id',
 AnnData object with n_obs × n_vars = 24816 × 53547
     obs: 'cov', 'covfactor', 'n_genes', 'percent_mito', 'n_counts', 'Doublet', 'Doublet Score', 'batch', 'leiden', 'sample', 'Type', 'Subclass'
     var: 'id',
 array(['4933401J01Rik', 'Gm26206', 'Xkr4', ..., 'CAAA01064564.1',
        'Vmn2r122', 'CAAA01147332.1'], dtype=object))

In [7]:
gunsel = [g for g in cnddts if g not in genes_data]
len(gunsel), gunsel

(2, ['Magi1', 'Bcl11b'])

In [8]:
genes_cndd = np.array([g for g in cnddts if g in genes_data])
len(genes_cndd)


KeyboardInterrupt



In [None]:
nr_types = np.sort(adata_nr.obs['Type'].unique().astype(str))
dr_types = np.sort(adata_dr.obs['Type'].unique().astype(str))
nr_types, dr_types

In [None]:
adata_nr.obs['biosample'] = adata_nr.obs['sample'].apply(lambda x: x[:-1])
adata_dr.obs['biosample'] = adata_dr.obs['sample'].apply(lambda x: x[:-1])
samples_nr = np.unique(adata_nr.obs['biosample'])
samples_dr = np.unique(adata_dr.obs['biosample'])
samples_nr, samples_dr

In [None]:
%%time
dfs_nr = []
for samp in samples_nr:
    adatasub = adata_nr[adata_nr.obs['biosample']==samp]
    df = get_normed_bulks_for_adata_by_types(adatasub, genes_cndd=genes_cndd)
    print(df.shape)
    print(df.columns)
    dfs_nr.append(df)

dfs_dr = []
for samp in samples_dr:
    adatasub = adata_dr[adata_dr.obs['biosample']==samp]
    df = get_normed_bulks_for_adata_by_types(adatasub, genes_cndd=genes_cndd)
    print(df.shape)
    print(df.columns)
    dfs_dr.append(df)

In [None]:
genes_comm = adata_nr.var.index.values
for df in dfs_nr:
    genes_comm = np.intersect1d(genes_comm, df.index.values)
    print(genes_comm.shape)
for df in dfs_dr:
    genes_comm = np.intersect1d(genes_comm, df.index.values)
    print(genes_comm.shape)
    
for i in range(2):
    dfs_nr[i] = dfs_nr[i].loc[genes_comm]
    dfs_dr[i] = dfs_dr[i].loc[genes_comm]

In [None]:
tensor_nr = np.stack(dfs_nr, axis=2)
tensor_dr = np.stack(dfs_dr, axis=2)
tensor_nr.shape, tensor_dr.shape

# visualize

In [None]:
gsel = genes_comm
gidx = basicu.get_index_from_array(genes_comm, gsel)
gidx.shape

In [None]:
types_order = [
    'L2/3_A', 'L2/3_B', 'L2/3_C', 
    'L4_A', 'L4_B', 'L4_C', 
    'L5IT', 
    'L6IT_A', 'L6IT_B', 
    'L5PT_A', 'L5PT_B', 
    'L5NP', 
    'L6CT_A', 'L6CT_B', 'L6CT_C', 
    'L6b',
    
    'Pvalb_A', 'Pvalb_B', 'Pvalb_C', 'Pvalb_D', 
    'Sst_A', 'Sst_B', 'Sst_C', 'Sst_D', 'Sst_E', 
    'Vip_A', 'Vip_B', 'Vip_C',
    'Lamp5', 
    
    'Stac', 
    'Frem1', 
    
    'Astro_A', 'Astro_B', 
    'OD_A', 'OD_B', 'OD_C', 
    'OPC_A', 'OPC_B',
    'Micro', 
    'Endo', 'VLMC_A', 'VLMC_B', 
]

In [None]:
types_idx = basicu.get_index_from_array(nr_types, types_order)
types_idx

In [None]:
bigmat = np.hstack([
    tensor_nr[:,:,0][:,types_idx],
    tensor_nr[:,:,1][:,types_idx],
    tensor_dr[:,:,0][:,types_idx],
    tensor_dr[:,:,1][:,types_idx],
])

ncols = int(bigmat.shape[1]/4)
reidx = np.array([i+j*ncols for (i,j) in itertools.product(np.arange(ncols), np.arange(4))])
print(reidx)
bigmat = bigmat[:,reidx]

In [None]:
bigmatz = (bigmat-np.mean(bigmat, axis=1).reshape(-1,1))/np.std(bigmat, axis=1).reshape(-1,1)

In [None]:
toplot = pd.DataFrame(bigmatz[gidx], index=gsel, columns=np.repeat(types_order,4))# .iloc[:50]
nrow, ncol = toplot.shape

with sns.axes_style('ticks'):
    fig, ax = plt.subplots(figsize=(0.2*ncol,0.2*nrow))
    sns.heatmap(toplot, 
                cmap='coolwarm', 
                yticklabels=True, 
                xticklabels=4, 
                cbar_kws=dict(pad=0.01, location='top', orientation='horizontal', shrink=0.3, label='zscored mean log10(CPM+1)'),
                center=0,
                ax=ax
               )
    ax.vlines(np.arange(0,ncol,2), 0, len(gidx)+1, linewidth=0.5, color='gray')
    ax.vlines(np.arange(0,ncol,4), 0, len(gidx)+1, linewidth=1, color='gray')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)

    for num in np.arange(0, nrow, 10):
        ax.text(ncol, num, num, fontsize=10)
    for i, samp in enumerate(['NR1', 'NR2', 'DR1', 'DR2']):
        ax.text(0.5+i, -0.5, samp, rotation=90, ha='center', va='bottom', fontsize=10)

    powerplots.savefig_autodate(fig, "../results/merfish_v3_zscore_bigmatz.pdf")
    plt.show()

In [None]:
toplot2 = pd.DataFrame(bigmat[gidx], index=gsel, columns=np.repeat(types_order,4))# .iloc[:50]
nrow, ncol = toplot.shape

with sns.axes_style('ticks'):
    fig, ax = plt.subplots(figsize=(0.2*ncol,0.2*nrow))
    sns.heatmap(toplot2, 
                cmap='rocket', # _r', #coolwarm', 
                yticklabels=True, 
                xticklabels=4, 
                cbar_kws=dict(pad=0.01, location='top', orientation='horizontal', shrink=0.3, label='mean log10(CPM+1)'),
                ax=ax
               )
    ax.vlines(np.arange(0,ncol,2), 0, len(gidx)+1, linewidth=0.5, color='gray')
    ax.vlines(np.arange(0,ncol,4), 0, len(gidx)+1, linewidth=1, color='gray')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)

    for num in np.arange(0, nrow, 10):
        ax.text(ncol, num, num, fontsize=10)
    for i, samp in enumerate(['NR1', 'NR2', 'DR1', 'DR2']):
        ax.text(0.5+i, -0.5, samp, rotation=90, ha='center', va='bottom', fontsize=10)

    powerplots.savefig_autodate(fig, "../results/merfish_v3_bigmat.pdf")
    plt.show()

In [None]:
# save toplot
toplot

In [None]:
toplot.to_csv('../results/merfish_v3_zscore_bigmatz.csv')

# subsets - exc

In [None]:
types_order_sub = [
    'L2/3_A', 'L2/3_B', 'L2/3_C', 
    'L4_A', 'L4_B', 'L4_C', 
    'L5IT', 
    'L6IT_A', 'L6IT_B', 
    'L5PT_A', 'L5PT_B', 
    'L5NP', 
    'L6CT_A', 'L6CT_B', 'L6CT_C', 
    'L6b',
    
#     'Pvalb_A', 'Pvalb_B', 'Pvalb_C', 'Pvalb_D', 
#     'Sst_A', 'Sst_B', 'Sst_C', 'Sst_D', 'Sst_E', 
#     'Vip_A', 'Vip_B', 'Vip_C',
#     'Lamp5', 
    
#     'Stac', 
#     'Frem1', 
    
#     'Astro_A', 'Astro_B', 
#     'OD_A', 'OD_B', 'OD_C', 
#     'OPC_A', 'OPC_B',
#     'Micro', 
#     'Endo', 'VLMC_A', 'VLMC_B', 
]

In [None]:
types_idx = basicu.get_index_from_array(nr_types, types_order_sub)
types_idx

In [None]:
bigmat = np.hstack([
    tensor_nr[:,:,0][:,types_idx],
    tensor_nr[:,:,1][:,types_idx],
    tensor_dr[:,:,0][:,types_idx],
    tensor_dr[:,:,1][:,types_idx],
])

ncols = int(bigmat.shape[1]/4)
reidx = np.array([i+j*ncols for (i,j) in itertools.product(np.arange(ncols), np.arange(4))])
print(reidx)
bigmat = bigmat[:,reidx]

In [None]:
bigmatz = (bigmat-np.mean(bigmat, axis=1).reshape(-1,1))/np.std(bigmat, axis=1).reshape(-1,1)

In [None]:
toplot = pd.DataFrame(bigmatz[gidx], index=gsel, columns=np.repeat(types_order_sub,4)).iloc[:50]
nrow, ncol = toplot.shape

with sns.axes_style('ticks'):
    fig, ax = plt.subplots(figsize=(0.2*ncol,0.2*nrow))
    sns.heatmap(toplot, 
                cmap='coolwarm', 
                yticklabels=True, 
                xticklabels=4, 
                cbar_kws=dict(pad=0.01, location='top', orientation='horizontal', shrink=0.3, label='zscored mean log10(CPM+1)'),
                center=0,
                ax=ax
               )
    ax.vlines(np.arange(0,ncol,2), 0, len(gidx)+1, linewidth=0.5, color='gray')
    ax.vlines(np.arange(0,ncol,4), 0, len(gidx)+1, linewidth=1, color='gray')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)

    for num in np.arange(0, nrow, 10):
        ax.text(ncol, num, num, fontsize=10)
    for i, samp in enumerate(['NR1', 'NR2', 'DR1', 'DR2']):
        ax.text(0.5+i, -0.5, samp, rotation=90, ha='center', va='bottom', fontsize=10)

    powerplots.savefig_autodate(fig, "../results/merfish_v3_zscore_bigmatz_exc_demo.pdf")
    plt.show()

In [None]:
toplot = pd.DataFrame(bigmatz[gidx], index=gsel, columns=np.repeat(types_order_sub,4))# .iloc[:50]
nrow, ncol = toplot.shape

with sns.axes_style('ticks'):
    fig, ax = plt.subplots(figsize=(0.2*ncol,0.2*nrow))
    sns.heatmap(toplot, 
                cmap='coolwarm', 
                yticklabels=True, 
                xticklabels=4, 
                cbar_kws=dict(pad=0.01, location='top', orientation='horizontal', shrink=0.3, label='zscored mean log10(CPM+1)'),
                center=0,
                ax=ax
               )
    ax.vlines(np.arange(0,ncol,2), 0, len(gidx)+1, linewidth=0.5, color='gray')
    ax.vlines(np.arange(0,ncol,4), 0, len(gidx)+1, linewidth=1, color='gray')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)

    for num in np.arange(0, nrow, 10):
        ax.text(ncol, num, num, fontsize=10)
    for i, samp in enumerate(['NR1', 'NR2', 'DR1', 'DR2']):
        ax.text(0.5+i, -0.5, samp, rotation=90, ha='center', va='bottom', fontsize=10)

    powerplots.savefig_autodate(fig, "../results/merfish_v3_zscore_bigmatz_exc.pdf")
    plt.show()

In [None]:
toplot2 = pd.DataFrame(bigmat[gidx], index=gsel, columns=np.repeat(types_order_sub,4))# .iloc[:50]
nrow, ncol = toplot.shape

with sns.axes_style('ticks'):
    fig, ax = plt.subplots(figsize=(0.2*ncol,0.2*nrow))
    sns.heatmap(toplot2, 
                cmap='rocket', # _r', #coolwarm', 
                yticklabels=True, 
                xticklabels=4, 
                cbar_kws=dict(pad=0.01, location='top', orientation='horizontal', shrink=0.3, label='mean log10(CPM+1)'),
                ax=ax
               )
    ax.vlines(np.arange(0,ncol,2), 0, len(gidx)+1, linewidth=0.5, color='gray')
    ax.vlines(np.arange(0,ncol,4), 0, len(gidx)+1, linewidth=1, color='gray')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=10)

    for num in np.arange(0, nrow, 10):
        ax.text(ncol, num, num, fontsize=10)
    for i, samp in enumerate(['NR1', 'NR2', 'DR1', 'DR2']):
        ax.text(0.5+i, -0.5, samp, rotation=90, ha='center', va='bottom', fontsize=10)

    powerplots.savefig_autodate(fig, "../results/merfish_v3_bigmat_exc.pdf")
    plt.show()