In [None]:
%load_ext autoreload
%autoreload 2
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import os
import anndata as ad
mpl.rcParams['figure.dpi'] = 150
plt.rcParams['pdf.fonttype'] = 42

import seaborn as sns
import sys
from spatial_analysis import *
from plotting import *
from utils import *
sns.set_style('white')
import pandas as pd

In [None]:
def unbinarize_strings(A):
    """Decode byte-string identifiers and categorical labels of an AnnData-like object.

    Some h5ad exports store ``var_names``, the ``obs`` index and categorical
    ``obs`` columns as ``bytes``; this converts them to ``str`` in place.

    Parameters
    ----------
    A : AnnData (or any object exposing ``var_names`` and an ``obs`` DataFrame)
        Object to fix; modified in place.

    Returns
    -------
    The same object ``A``, for chaining.

    Notes
    -----
    * Values that are already ``str`` (or non-strings such as NaN) pass
      through unchanged, so calling this on decoded data is a no-op instead
      of raising ``AttributeError`` on ``str.decode``.
    * As before, only categorical ``obs`` columns are converted; a column is
      replaced by plain ``str`` values only if at least one label is ``bytes``
      (otherwise it keeps its categorical dtype).
    """
    def _decode(value):
        # ASCII-decode bytes; leave every other value untouched.
        return value.decode('ascii') if isinstance(value, bytes) else value

    A.var_names = [_decode(name) for name in A.var_names]
    A.obs.index = [_decode(name) for name in A.obs.index]
    for col in A.obs.columns:
        # Test for categorical directly. The old whitelist of numpy dtypes
        # followed by ``dtype.is_dtype('category')`` raised AttributeError for
        # any other numpy dtype (e.g. uint8), since ``is_dtype`` exists only
        # on pandas extension dtypes.
        if not isinstance(A.obs[col].dtype, pd.CategoricalDtype):
            continue
        values = list(A.obs[col])
        if not any(isinstance(v, bytes) for v in values):
            continue
        try:
            A.obs[col] = [_decode(v) for v in values]
        except UnicodeDecodeError:
            # Best effort, as before: leave non-ASCII columns as they are.
            pass
    return A

In [None]:
# Load the merged control + LPS MERFISH object (all ages).
# NOTE(review): absolute server path; adjust when running elsewhere.
adata = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_lps_ctrl_allages.h5ad")


In [None]:
# Convert any bytes-typed names / categorical labels from the export to str.
adata = unbinarize_strings(adata)

In [None]:
# Spatially smooth the cortical-layer annotation values (L2/3, L5, L6) slice by
# slice within each batch, using the project helper ``cleanup_section``
# (presumably from ``spatial_analysis`` via the star import — TODO confirm).
# Cells outside those layers keep their original ``spatial_clust_annots_value``.
adata.obs["smoothed_spatial_clust_annot_values"] = np.nan
cortical_layers = ['L2/3', 'L5', 'L6']
for batch in adata.obs.data_batch.unique():
    print(batch)
    batch_adata = adata[adata.obs.data_batch == batch]
    for slice_id in batch_adata.obs.slice.unique():
        in_section = np.logical_and(batch_adata.obs.slice == slice_id,
                                    batch_adata.obs.spatial_clust_annots.isin(cortical_layers))
        A_section = cleanup_section(batch_adata[in_section], 50)
        adata.obs.loc[A_section.obs.index, "smoothed_spatial_clust_annot_values"] = np.array(
            A_section.obs["smoothed_spatial_clust_annot_values"])

# Use the smoothed value where one exists; otherwise fall back to the raw value.
# ``fillna`` builds a new Series. The previous ``temp[np.isnan(temp)] = ...``
# was a chained assignment on a view of the obs column: it could silently
# overwrite ``smoothed_spatial_clust_annot_values`` too (pandas-version
# dependent) and triggers SettingWithCopy / Copy-on-Write warnings.
smoothed_or_raw = adata.obs["smoothed_spatial_clust_annot_values"].fillna(
    adata.obs["spatial_clust_annots_value"])
adata.obs["spatial_clust_annots_value"] = smoothed_or_raw.to_numpy()

In [None]:
adata_ctl_raw = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_combined_harmony.h5ad")


In [None]:
adata_ctl_raw_merfish = adata_ctl_raw[adata_ctl_raw.obs.dtype=='merfish']

In [None]:
idx = [i+'-0' for i in adata_ctl_raw_merfish.obs.index]

In [None]:
adata.obs.loc[idx,"spatial_clust_annots_value"] = list(adata_ctl_raw_merfish.obs.spatial_clust_annots_value)
adata.obs.loc[idx,"spatial_clust_annots"] = list(adata_ctl_raw_merfish.obs.spatial_clust_annots )

In [None]:
adata.obs.loc[idx,"spatial_clust_annots_value"]

In [None]:
adata_ctl_raw_merfish.obs.spatial_clust_annots.shape

In [None]:
#adata_combined.obs.spatial_clust_annots = [spatial_clust_annots_values[i] if i in spatial_clust_annots_values else '' for i in adata_combined.obs.smoothed_spatial_clust_annot_values]

In [None]:
adata_lps = adata[adata.obs.cond=="lps"]
adata_ctl = adata[adata.obs.cond=="ctrl"]

In [None]:
adata_lps_raw = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/112921_merged_lps_merfish_with_doublet_umap_allages.h5ad")
adata_lps_raw = adata_lps_raw.raw.to_adata()

In [None]:
lps_idx = [i.split('-')[0] for i in adata_lps.obs.index]

In [None]:
adata_lps_raw = adata_lps_raw[lps_idx]
adata_lps_raw.obs = adata_lps.obs.copy()

In [None]:
sc.pp.scale(adata_lps_raw,max_value=10)

In [None]:
sc.pl.pca(adata_lps_raw,color='clust_annot_preds')

In [None]:
import bbknn
sc.tl.pca(adata_lps_raw,n_comps=30)


In [None]:
bbknn.bbknn(adata_lps_raw, 'data_batch')


In [None]:
sc.tl.umap(adata_lps_raw)

In [None]:
f, ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_lps_raw, color=['clust_annot_preds'],palette=clust_pals,ax=ax,legend_loc='bottom')
ax.axis('off')
ax.set_title('')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_celltype_umap.png",bbox_inches='tight',dpi=300)

In [None]:
f, ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_lps_raw, color=['age'], palette=age_pal,ax=ax,legend_loc='bottom',)
ax.axis('off')
ax.set_title('')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_age_umap.png",bbox_inches='tight',dpi=300)

In [None]:
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata, cell_type_key='cell_type_preds',clust_key='clust_annot_preds')

In [None]:
cond_palette = sns.color_palette("Set1", n_colors=2)
cond_palette.reverse()


In [None]:
age_colors = ['cornflowerblue','thistle','lightcoral']

f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata, color=['age'],palette=sns.color_palette(age_colors),size=0.1,legend_loc='bottom',ax=ax)
ax.set_title('')
ax.axis('off')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_lps_integration.png", bbox_inches='tight', dpi=300)

In [None]:
f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata, color=['cond'],palette=sns.color_palette(['g','m']),size=0.1,legend_loc='bottom',ax=ax)
ax.set_title('')
ax.axis('off')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_lps_integration.png", bbox_inches='tight', dpi=300)

In [None]:
#f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_ctl, color=['C4b','Il33','C3','Gfap','age','Ifit3'],palette=age_pal,size=0.5,legend_loc='bottom')


In [None]:
sc.pl.umap(adata_lps[adata_lps.obs.age=='4wk'], color=['C4b','Il33','C3','Gfap','age','Ifit3'],palette=age_pal,size=0.5,legend_loc='bottom')


In [None]:
sc.pl.umap(adata_lps[adata_lps.obs.age=='90wk'], color=['C4b','Il33','C3','Gfap','age','Ifit3'],palette=age_pal,size=0.5,legend_loc='bottom')


In [None]:
# Control cells colored by their original cluster labels (``clust_annot``).
f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_ctl, color=['clust_annot'],palette=clust_pals,size=5,legend_loc='bottom',ax=ax)


In [None]:
# Fig. 4: control cells of the merged object colored by predicted cluster.
f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata[adata.obs.cond=='ctrl'], color=['clust_annot_preds'],palette=clust_pals,size=5,legend_loc='bottom',ax=ax)
ax.set_title('')
ax.axis('off')

f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_ctrl_iclusts.png", bbox_inches='tight', dpi=200)

In [None]:
# Fig. 4: LPS cells of the merged object colored by predicted cluster.
f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata[adata.obs.cond=='lps'], color=['clust_annot_preds'],palette=clust_pals,size=5,legend_loc='bottom',ax=ax)
ax.set_title('')
ax.axis('off')

f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_lps_iclusts.png", bbox_inches='tight', dpi=200)

In [None]:
# Train classifier on cell type labels for integrated data and create confusion matrix
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
all_preds = []
all_y_test = []
for i in range(5):
    print(i)
    X_train, X_test, y_train, y_test = train_test_split(adata_ctl.obsm['X_pca'], np.array(adata_ctl.obs.clust_annot), test_size=0.2, random_state=42)
    mdl = KNeighborsClassifier(n_jobs=-1).fit(X_train, y_train)
    preds = mdl.predict(X_test)
    all_preds.append(preds)
    all_y_test.append(y_test)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
clusts = sorted(adata_ctl.obs.clust_annot.unique())
cmat = np.zeros((len(clusts), len(clusts)))
for i in range(3):
    cmat += confusion_matrix(all_preds[i], all_y_test[i],labels=clusts).astype(np.float64)
for i in range(cmat.shape[0]):
    cmat[i,:] = cmat[i,:]/cmat[i,:].sum()

In [None]:
label_colors = {i:label_colors[i] for i in clusts}

In [None]:
f, axes = plt.subplots(figsize=(10,10), nrows=2, ncols=2, gridspec_kw={'width_ratios':[1,10],'height_ratios':[10,1], 'wspace':0.01, 'hspace':0.01})
ax = axes[1,0]
ax.axis('off')
ax = axes[0,1]
ax.imshow(np.flipud(cmat),vmin=0,vmax=1,cmap=plt.cm.viridis,aspect='auto',interpolation='none',rasterized=True)
ax.axis('off')
ax = axes[0,0]
ax.set_xticks([])
ax.set_yticks(np.arange(len(label_colors)));
ax.set_yticklabels(clusts[::-1]);
sns.despine(ax=ax,bottom=True,left=True)

ax.imshow(np.expand_dims(np.arange(len(label_colors))[::-1],1),cmap=mpl.colors.ListedColormap(label_colors.values()))
#ax.axis('off')
ax = axes[1,1]
ax.imshow(np.expand_dims(np.arange(len(label_colors)),1).T,cmap=mpl.colors.ListedColormap(label_colors.values()))
#ax.axis('off')
ax.set_xticks(np.arange(len(label_colors)));
ax.set_xticklabels(clusts, rotation=90);
ax.set_yticks([])
sns.despine(ax=ax,bottom=True,left=True)
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS12_clust_ident_confusion.pdf",bbox_inches='tight',dpi=300)

In [None]:
sorted(clusts)

In [None]:
# Spatial domains, in display order (seg_cmap below lists one color per entry).
spatial_domains = ['Pia', 'L2/3', 'L5', 'L6', 'LatSept', 'CC', 'Striatum', 'Ventricle']

# Display order of the fine clusters used by the composition figures.
clust_order = [
    # ExN (excitatory)
    'ExN-L2/3-1', 'ExN-L2/3-2',
    'ExN-L5-1', 'ExN-L5-2', 'ExN-L5-3',
    'ExN-L6-1', 'ExN-L6-2', 'ExN-L6-3',
    'ExN-Olf',
    # InN (inhibitory)
    'InN-Olf-1', 'InN-Olf-2',
    'InN-Vip',
    'InN-Lamp5',
    'InN-Pvalb-1', 'InN-Pvalb-2', 'InN-Pvalb-3',
    'InN-Sst-1', 'InN-Sst-2',
    'InN-Calb2-1', 'InN-Calb2-2',
    'InN-Chat',
    'InN-Lhx6',
    # MSN
    'MSN-D1-1', 'MSN-D1-2', 'MSN-D2',
    # OPC / Olig
    'OPC', 'Olig-1', 'Olig-2', 'Olig-3',
    # Astro
    'Astro-1', 'Astro-2',
    # Vlmc / Peri / Endo / Epen
    'Vlmc', 'Peri-1', 'Peri-2', 'Endo-1', 'Endo-2', 'Endo-3', 'Epen',
    # Micro / Macro / T cell
    'Micro-1', 'Micro-2', 'Micro-3', 'Macro', 'T cell',
]

# Eight colors, presumably one per entry of ``spatial_domains`` (same order).
seg_cmap = mpl.colors.ListedColormap([ 'gold','tan', 'peru', 'maroon', 'steelblue','gray',  'purple', 'darkkhaki'])

In [None]:
# quantify celltypes as a function of region
# LPS, 4wk. ``plt.savefig`` saves the current figure, presumably the one
# created by ``plot_clust_spatial_enrichment`` (project helper).
young_clusts, young_counts = plot_clust_spatial_enrichment(adata_lps[adata_lps.obs.age=='4wk'],vmax=1,clust_key='clust_annot_preds',uniq_clusts=clust_order,seg_cmap=seg_cmap,label_colors=label_colors)
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS10_lps_4wk_celltypecomp.pdf",bbox_inches='tight')

In [None]:
# quantify celltypes as a function of region
# LPS, 24wk. NOTE(review): reuses the names young_clusts/young_counts, so
# after this cell they hold the 24wk result, not 4wk.
young_clusts, young_counts = plot_clust_spatial_enrichment(adata_lps[adata_lps.obs.age=='24wk'],vmax=1,clust_key='clust_annot_preds',uniq_clusts=clust_order,seg_cmap=seg_cmap,label_colors=label_colors)
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS10_lps_24wk_celltypecomp.pdf",bbox_inches='tight')

In [None]:
# quantify celltypes as a function of region
# LPS, 90wk.
old_clusts_lps, _ = plot_clust_spatial_enrichment(adata_lps[adata_lps.obs.age=='90wk'],vmax=1,clust_key='clust_annot_preds',uniq_clusts=clust_order,seg_cmap=seg_cmap, label_colors=label_colors)
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS10_lps_90wk_celltypecomp.pdf",bbox_inches='tight')

In [None]:
# quantify celltypes as a function of region
# Control, 4wk (not saved; overwrites young_clusts/young_counts again).
young_clusts, young_counts = plot_clust_spatial_enrichment(adata_ctl[adata_ctl.obs.age=='4wk'],vmax=1,clust_key='clust_annot_preds',uniq_clusts=clust_order,seg_cmap=seg_cmap,label_colors=label_colors)


In [None]:
# Control, 90wk (not saved).
old_clusts_ctl, _ = plot_clust_spatial_enrichment(adata_ctl[adata_ctl.obs.age=='90wk'],vmax=1,clust_key='clust_annot_preds',uniq_clusts=clust_order,seg_cmap=seg_cmap, label_colors=label_colors)


In [None]:
# Helpers that tabulate cell-type composition for a single age (used for the composition plots below)
def count_celltypes(A, age, key='cell_type_preds'):
    """Count the cells of one age in four coarse classes.

    Each label in ``A.obs[key]`` is matched by substring, first match wins:
    "ExN" -> Excitatory, "InN" -> Inhibitory, "MSN" -> MSN, otherwise
    Non-neuronal.

    Returns a one-column DataFrame (``counts``) indexed by Inhibitory,
    Excitatory, MSN, Non-neuronal, in that order.
    """
    row_order = ["Inhibitory", "Excitatory", "MSN", "Non-neuronal"]
    rules = (("ExN", "Excitatory"), ("InN", "Inhibitory"), ("MSN", "MSN"))
    tally = dict.fromkeys(row_order, 0)
    for label in A[A.obs.age == age].obs[key]:
        bucket = next((name for tag, name in rules if tag in label), "Non-neuronal")
        tally[bucket] += 1
    return pd.DataFrame({'counts': [tally[row] for row in row_order]}, index=row_order)

def simplify_celltypes(A, age, key='cell_type_preds'):
    """Long-format table of the neuronal cells of one age, one row per cell.

    Each label in ``A.obs[key]`` is reduced to "ExN", "InN" or "MSN" by
    substring (checked in that order). Cells matching none of them are
    dropped.

    Returns a DataFrame with columns ``cell_type``, ``count`` (always 1.0)
    and ``age``.
    """
    neuronal_tags = ("ExN", "InN", "MSN")
    labels = A[A.obs.age == age].obs[key]
    kept = []
    for label in labels:
        match = next((tag for tag in neuronal_tags if tag in label), None)
        if match is not None:
            kept.append(match)
    n_kept = len(kept)
    return pd.DataFrame({'cell_type': kept, 'count': np.ones(n_kept), 'age': age})

def simplify_clusts(A, age, key='cell_type_preds'):
    """One row per cell of the given age, keeping its full ``A.obs[key]`` label.

    Returns a DataFrame with columns ``cell_type`` (the label), ``count``
    (always 1.0) and ``age``.
    """
    age_obs = A[A.obs.age == age].obs
    labels = [label for label in age_obs[key]]
    return pd.DataFrame(
        {'cell_type': labels, 'count': np.ones(len(labels)), 'age': age}
    )

In [None]:
# Neuronal class (ExN/InN/MSN) of every LPS cell, per age; long format.
young_ct = simplify_celltypes(adata_lps, '4wk')
med_ct = simplify_celltypes(adata_lps, '24wk')
old_ct = simplify_celltypes(adata_lps, '90wk')
combined_ct = pd.concat([young_ct, med_ct, old_ct])

In [None]:
## Pie charts
# NOTE(review): pandas is already imported in the first cell, and the three
# ``cell_types_*`` Series below are not used by the visible pie-chart cells.
import pandas as pd
cell_types_young = adata_lps[adata_lps.obs.age=='4wk'].obs.cell_type_preds
cell_types_med = adata_lps[adata_lps.obs.age=='24wk'].obs.cell_type_preds
cell_types_old = adata_lps[adata_lps.obs.age=='90wk'].obs.cell_type_preds


In [None]:
# Fig. S7: one pie per age (4wk, 24wk, 90wk) of the neuronal class mix.
# Counts come from ``Series.value_counts`` on ``cell_type``. The old
# ``DataFrame.value_counts().reset_index()`` produced a column named "count"
# on pandas >= 2, which collides with the table's own "count" column
# (ValueError); its ``[0]`` lookup only worked on pandas 1.x.
f, axes = plt.subplots(figsize=(8,24), nrows=1, ncols=3, gridspec_kw={'wspace':0.1})
for ax, ct_table in zip(axes, [young_ct, med_ct, old_ct]):
    ct_counts = ct_table['cell_type'].value_counts()
    ax.pie(ct_counts.values, colors=[celltype_colors[c] for c in ct_counts.index], labels=ct_counts.index)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS7_celltype_comp_neuronal.pdf", bbox_inches='tight')

In [None]:
adata_nonneuronal = adata_lps[~adata_lps.obs.cell_type_preds.isin(['ExN',"InN","MSN"])]
young_ct = simplify_clusts(adata_nonneuronal, '4wk')
med_ct = simplify_clusts(adata_nonneuronal, '24wk')
old_ct = simplify_clusts(adata_nonneuronal, '90wk')
combined_ct = pd.concat([young_ct, med_ct, old_ct])

In [None]:
f, axes = plt.subplots(figsize=(8,24), nrows=1, ncols=3, gridspec_kw={'wspace':0.1})
young_ct_agg = young_ct.value_counts().reset_index()
axes[0].pie(young_ct_agg[0],colors=[celltype_colors[i] for i in young_ct_agg.cell_type], labels=young_ct_agg.cell_type);

med_ct_agg = med_ct.value_counts().reset_index()
axes[1].pie(med_ct_agg[0],colors=[celltype_colors[i] for i in med_ct_agg.cell_type], labels=med_ct_agg.cell_type);

old_ct_agg = old_ct.value_counts().reset_index()
axes[2].pie(old_ct_agg[0],colors=[celltype_colors[i] for i in old_ct_agg.cell_type], labels=old_ct_agg.cell_type);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS7_celltype_comp_nonneuronal.pdf", bbox_inches='tight')

In [None]:
def compute_frac_per_clust(A, clust_order, n_bins=100, clust_key='clust_annot_preds'):
    """Encode the age composition of each cluster as a row of age codes.

    For each cluster, the fraction of its cells from each age (4wk, 24wk,
    90wk) is divided by that age's total cell count in ``A``. This corrects
    for unequal numbers of cells per age. The result is renormalised to sum
    to 1 and rendered as ``n_bins`` codes: 2 (90wk) first, then 1 (24wk),
    then 0 (4wk), for display with ``imshow`` and a 3-color colormap.

    Parameters
    ----------
    A : AnnData (only ``A.obs`` is read)
        ``obs`` must contain an ``age`` column and ``clust_key``.
    clust_order : sequence of str
        Clusters to encode; one output row each, in this order.
    n_bins : int, default 100
        Length of each row.
    clust_key : str, default 'clust_annot_preds'
        ``obs`` column holding cluster labels.

    Returns
    -------
    np.ndarray of float, shape (len(clust_order), n_bins)
        Age codes. A row is all NaN when the cluster has no cells of any of
        the three ages; ``imshow`` leaves such rows blank. The old code
        instead crashed on ``int(round(nan))`` in that case.
    """
    obs = A.obs
    ages = ('4wk', '24wk', '90wk')
    frac_per_age = np.zeros((len(clust_order), n_bins))
    totals = {age: np.sum(obs.age == age) for age in ages}

    for n, c in enumerate(clust_order):
        clust_ages = obs.age[obs[clust_key] == c]
        n_cells = len(clust_ages)
        if n_cells == 0:
            frac_per_age[n, :] = np.nan
            continue
        # Fraction of this cluster's cells from each age, scaled by the
        # number of cells of that age in the whole experiment. Ages absent
        # from ``A`` contribute 0 instead of dividing by zero.
        scaled = {}
        for age in ages:
            share = np.sum(clust_ages == age) / n_cells
            scaled[age] = share / totals[age] if totals[age] > 0 else 0.0
        denom = scaled['4wk'] + scaled['24wk'] + scaled['90wk']
        if denom == 0:
            frac_per_age[n, :] = np.nan
            continue
        frac90 = scaled['90wk'] / denom
        frac24 = scaled['24wk'] / denom
        nbins90 = int(round(n_bins * frac90))
        # Clamp so the three segments never exceed the row length.
        nbins24 = min(int(round(n_bins * frac24)), n_bins - nbins90)
        frac_per_age[n, :] = np.hstack([2 * np.ones(nbins90),
                                        np.ones(nbins24),
                                        np.zeros(n_bins - nbins90 - nbins24)])
    return frac_per_age

In [None]:
# make SI plot of fraction of cells per age
# Age-code matrices (one row per cluster in ``clust_order``) from
# compute_frac_per_clust, for LPS and control cells.
frac_per_age_lps = compute_frac_per_clust(adata_lps, clust_order)
frac_per_age_ctrl = compute_frac_per_clust(adata_ctl, clust_order)

In [None]:
# With imshow(vmin=0, vmax=2), age_cmap maps the codes written by
# compute_frac_per_clust -- 0 (4wk), 1 (24wk), 2 (90wk) -- to these colors.
age_colors = ['cornflowerblue','thistle','lightcoral']
age_cmap = mpl.colors.ListedColormap(age_colors)
age_pal = sns.color_palette(age_colors)

In [None]:
# Fig. S9: age composition of each cluster, control (- LPS) vs LPS (+ LPS).
# Top strip: cluster color key, with labels on top. Lower panels: one column
# per cluster with the age codes from ``compute_frac_per_clust`` drawn through
# ``age_cmap``; x tick labels give the number of cells in each cluster.
curr_cols = mpl.colors.ListedColormap([label_colors[c] for c in clust_order])
f,axes=plt.subplots(nrows=3, ncols=1, gridspec_kw={'hspace':0.5, 'height_ratios':[1,10,10]},figsize=(10,5))
ax = axes[0]
# One key cell per entry of ``clust_order`` (the length of ``curr_cols``).
# Sizing it by ``len(label_colors)`` misaligned colors and tick labels
# whenever the palette dict did not hold exactly the clusters in clust_order.
ax.imshow(np.expand_dims(np.arange(len(clust_order)),1).T, cmap=curr_cols,aspect='auto',interpolation='none',rasterized=True)
ax.set_yticks([])
ax.set_xticks(np.arange(len(clust_order)))
ax.set_xticklabels(clust_order,rotation=90)
ax.xaxis.set_tick_params(labeltop=True)
ax.xaxis.set_tick_params(labelbottom=False)
sns.despine(ax=ax, bottom=True, left=True)

composition_panels = [
    (axes[1], frac_per_age_ctrl, adata_ctl, '- LPS'),
    (axes[2], frac_per_age_lps, adata_lps, '+ LPS'),
]
for ax, panel_frac, panel_adata, panel_label in composition_panels:
    ax.imshow(panel_frac.T, vmin=0,vmax=2,aspect='auto',interpolation='none', cmap=age_cmap,rasterized=True)
    ax.set_yticklabels([])
    # Guides at 1/3 and 2/3 of the (default 100-bin) column.
    ax.axhline(33,color='w',linestyle='--')
    ax.axhline(66,color='w',linestyle='--')
    sns.despine(ax=ax, left=True)
    ax.set_ylabel(panel_label)
    ax.set_xticks(np.arange(len(clust_order)))
    ax.set_xticklabels([np.sum(panel_adata.obs.clust_annot_preds==i) for i in clust_order],rotation=90, fontsize=7)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS9_ctl_lps_cellcomp.pdf",bbox_inches='tight',dpi=200)

In [None]:
# show per-batch number of cells for certain cell types
# Fig. S11: for every LPS cluster, the percentage of each batch's cells that
# fall in it, by age (bars = seaborn's default mean over batches, dots = the
# individual batches).
adata_obs = adata_lps.obs.copy()
clust_names = sorted(adata_obs.clust_annot_preds.unique())
ct_counts = []
count_ages = []
count_clusts = []
for k in clust_names:
    for i in ['4wk','24wk','90wk']:
        curr_obs = adata_obs[adata_obs.age==i]
        for j in curr_obs.data_batch.unique():
            batch_obs = curr_obs[curr_obs.data_batch == j]
            ct_counts.append(100*np.sum(batch_obs.clust_annot_preds==k)/batch_obs.shape[0])
            count_ages.append(i)
            count_clusts.append(k)

counts = pd.DataFrame({'count':ct_counts, 'age': count_ages, 'clust': count_clusts})
# 7 panels per row. Rows are added instead of overflowing the grid
# (IndexError) when there are more than 49 clusters; the layout is unchanged
# up to 49.
n_grid_cols = 7
n_grid_rows = max(7, -(-len(clust_names) // n_grid_cols))
f = plt.figure(figsize=(20, 20 * n_grid_rows / 7))
gs = plt.GridSpec(nrows=n_grid_rows, ncols=n_grid_cols, wspace=0.5,hspace=0.5)
for n,i in enumerate(clust_names):
    curr_counts = counts[counts.clust==i]
    ax = plt.subplot(gs[n])
    sns.barplot(x='age',y='count',data=curr_counts,ax=ax, palette=sns.color_palette(age_colors),linewidth=0, errwidth=0,zorder=0)

    sns.scatterplot(x='age',y='count',data=curr_counts,ax=ax,color='k',zorder=1,linewidth=1)

    sns.despine(ax=ax)
    ax.set_ylabel('')
    ax.set_title(i)
f.savefig('/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS11_lps_perbatch_counts.pdf',bbox_inches='tight', dpi=200)

In [None]:
# get normalized number of cells per area

In [None]:
# Differential expression between ctl/LPS in same age for each cell type

In [None]:
adata_raw = adata.raw.to_adata()

In [None]:
adata_lps = adata[adata_raw.obs.cond=="lps"]
adata_ctl = adata[adata_raw.obs.cond=="ctrl"]

In [None]:
from scipy.stats import mannwhitneyu


In [None]:
#for i in range(lps.shape[1]):
#    lps[:,i] = zscore(lps[:,i])
    
#for j in range(ctl.shape[1]):
#    ctl[:,i] = zscore(ctl[:,i])

In [None]:
def compute_expr_diff(a,b,ct,age,gene_names=None):
    """Per-gene comparison of two cells x genes expression matrices (``a`` vs ``b``).

    Parameters
    ----------
    a, b : array-like, shape (n_cells_a, n_genes) and (n_cells_b, n_genes)
        Expression of the two groups over the same gene columns. scipy
        sparse matrices are densified.
    ct : str
        Copied into the ``cell_type`` column.
    age : str
        Copied into the ``age`` column (an age or a comparison label).
    gene_names : sequence of str, optional
        Row labels. Defaults to the global ``adata.var_names``, as before.

    Returns
    -------
    pd.DataFrame, one row per gene, with columns
        gene_names; pval (Mann-Whitney U p-value, uncorrected, scipy default
        alternative); fc = log2(mean_a / mean_b) (0 where both means are 0,
        +/-inf where only one is 0); frac1 / frac2 = fraction of cells in
        ``a`` / ``b`` with expression > 0.1; avg1 / avg2 = mean expression;
        cell_type; age.

    Raises
    ------
    ValueError
        If either group has no cells. Callers skip such combinations.
    """
    if gene_names is None:
        gene_names = adata.var_names
    # Work on dense float arrays so slicing and the U test behave the same
    # for dense and sparse inputs.
    a = np.asarray(a.toarray() if hasattr(a, 'toarray') else a, dtype=float)
    b = np.asarray(b.toarray() if hasattr(b, 'toarray') else b, dtype=float)
    if a.shape[0] == 0 or b.shape[0] == 0:
        raise ValueError(f"empty group for {ct} {age}: {a.shape[0]} vs {b.shape[0]} cells")

    # Test this function's own arguments. The old loop read the globals
    # ``lps`` / ``ctl`` and only worked because callers used those names.
    pvals = np.array([mannwhitneyu(a[:, i], b[:, i])[1] for i in range(a.shape[1])])
    avg1 = a.mean(0)
    avg2 = b.mean(0)
    with np.errstate(divide='ignore', invalid='ignore'):
        fcs = np.log2(avg1 / avg2)
    fcs[np.isnan(fcs)] = 0
    frac1 = (a > 0.1).mean(0)
    # Fraction of *b* cells expressing the gene (previously computed from ``a``).
    frac2 = (b > 0.1).mean(0)
    res = pd.DataFrame({'gene_names':gene_names, 'pval':pvals, 'fc':fcs, 'frac1':frac1, 'frac2':frac2, 'avg1':avg1, 'avg2':avg2, 'cell_type':ct,'age':age})
    return res

In [None]:
# Coarse cell types used for all differential-expression comparisons.
cell_types = [
    'ExN', 'InN', 'MSN',
    'OPC', 'Olig',
    'Astro', 'Vlmc', 'Peri', 'Endo', 'Epen',
    'Micro', 'Macro',
]

# LPS vs control within each (cell type, age). ``lps`` / ``ctl`` stay
# module-level names on purpose: older versions of ``compute_expr_diff``
# read those globals directly.
all_res = []
for ct in cell_types:
    for age in ['4wk','24wk','90wk']:
        try:
            print(ct,age)
            lps = adata_lps[np.logical_and(adata_lps.obs.cell_type_preds==ct,adata_lps.obs.age==age)].X.copy()
            ctl = adata_ctl[np.logical_and(adata_ctl.obs.cell_type_preds==ct,adata_ctl.obs.age==age)].X.copy()
            all_res.append(compute_expr_diff(lps, ctl,ct,age))
        except Exception as e:
            # Best effort (e.g. no cells of this type at this age), but report
            # the skip instead of silently dropping the combination.
            print(f"  skipped {ct} {age}: {e!r}")
all_res_unfilt = pd.concat(all_res)
# Benjamini-Hochberg FDR across all tests. NOTE(review): ``multipletests``
# is not imported in the visible cells; presumably
# statsmodels.stats.multitest.multipletests via a star import — confirm.
all_res_unfilt['qval'] = multipletests(all_res_unfilt['pval'],method='fdr_bh')[1]
# Keep q < 0.01 and |log2 FC| > 1, expressed (> 0.1) in > 10% of either group.
all_res = all_res_unfilt[np.logical_and(all_res_unfilt.qval<0.01, np.abs(all_res_unfilt.fc)>1)]
all_res = all_res[np.logical_or(all_res.frac1>0.1, all_res.frac2>0.1)]

In [None]:
# Age effect within LPS: 90wk vs 4wk LPS cells, per cell type.
# The two groups keep the names ``lps`` (90wk) / ``ctl`` (4wk) because older
# versions of ``compute_expr_diff`` read those globals directly.
all_res_lps = []
for ct in cell_types:
    try:
        # Print the comparison itself; ``age`` was a stale loop variable.
        print(ct, '90wk vs 4wk')
        # ``==`` (was the syntax error ``= =``).
        lps = adata_lps[np.logical_and(adata_lps.obs.cell_type_preds==ct,adata_lps.obs.age=='90wk')].X.copy()
        ctl = adata_lps[np.logical_and(adata_lps.obs.cell_type_preds==ct,adata_lps.obs.age=='4wk')].X.copy()
        all_res_lps.append(compute_expr_diff(lps, ctl,ct,'90wk vs 4wk'))
    except Exception as e:
        # Best effort, but report which cell type was skipped and why.
        print(f"  skipped {ct}: {e!r}")
all_res_lps_unfilt = pd.concat(all_res_lps)
all_res_lps_unfilt['qval'] = multipletests(all_res_lps_unfilt['pval'],method='fdr_bh')[1]
all_res_lps = all_res_lps_unfilt[np.logical_and(all_res_lps_unfilt.qval<0.01, np.abs(all_res_lps_unfilt.fc)>1)]
all_res_lps = all_res_lps[np.logical_or(all_res_lps.frac1>0.1, all_res_lps.frac2>0.1)]

In [None]:
# Age effect within control: 90wk vs 4wk control cells, per cell type.
# The two groups keep the names ``lps`` (90wk) / ``ctl`` (4wk) because older
# versions of ``compute_expr_diff`` read those globals directly.
all_res_ctl = []
for ct in cell_types:
    try:
        # Print the comparison itself; ``age`` was a stale loop variable.
        print(ct, '90wk vs 4wk')
        lps = adata_ctl[np.logical_and(adata_ctl.obs.cell_type_preds==ct,adata_ctl.obs.age=='90wk')].X.copy()
        ctl = adata_ctl[np.logical_and(adata_ctl.obs.cell_type_preds==ct,adata_ctl.obs.age=='4wk')].X.copy()
        all_res_ctl.append(compute_expr_diff(lps, ctl,ct,'90wk vs 4wk'))
    except Exception as e:
        # Best effort, but report which cell type was skipped and why.
        print(f"  skipped {ct}: {e!r}")
all_res_ctl_unfilt = pd.concat(all_res_ctl)
all_res_ctl_unfilt['qval'] = multipletests(all_res_ctl_unfilt['pval'],method='fdr_bh')[1]
all_res_ctl = all_res_ctl_unfilt[np.logical_and(all_res_ctl_unfilt.qval<0.01, np.abs(all_res_ctl_unfilt.fc)>1)]
all_res_ctl = all_res_ctl[np.logical_or(all_res_ctl.frac1>0.1, all_res_ctl.frac2>0.1)]

In [None]:
# create summary tables of numbers of differentially expressed genes
# Significant age-DE genes (90wk vs 4wk) per cell type, split into
# up-regulated (fc >= 0) and down-regulated (stored as negative counts), for
# control ("ctl") and LPS ("lps").
upreg_counts = []
downreg_counts = []
cond = []
cell_type = []
for cond_name, de_table in [("ctl", all_res_ctl), ("lps", all_res_lps)]:
    upreg = de_table[de_table.fc>=0]
    downreg = de_table[de_table.fc<0]
    # Cell types come from the table being counted. The LPS pass used to
    # iterate ``all_res`` (the LPS-vs-control table), dropping cell types with
    # LPS age-DE genes only and adding spurious zero rows for others.
    for i in de_table.cell_type.unique():
        upreg_counts.append(np.sum(upreg.cell_type==i))
        downreg_counts.append(-np.sum(downreg.cell_type==i))
        cell_type.append(i)
        cond.append(cond_name)


In [None]:
# Long table of age-DE gene counts per (cell type, condition).
lps_age_comparison = pd.DataFrame({'upreg':upreg_counts, 'downreg':downreg_counts, 'cond':cond, 'cell_type':cell_type})

In [None]:
# Up-regulated counts above zero and (negative) down-regulated counts below,
# drawn on the same axes, colored by condition.
f, ax= plt.subplots()
sns.barplot(x='cell_type', y='upreg', data=lps_age_comparison, hue='cond', order=cell_types,ax=ax)
sns.barplot(x='cell_type', y='downreg', data=lps_age_comparison, hue='cond', order=cell_types,ax=ax)

sns.despine()
# NOTE(review): fixed limits; counts beyond +/-60 are clipped.
plt.ylim([-60,60])

In [None]:
# create summary tables of numbers of differentially expressed genes
# LPS-vs-control DE genes per cell type and age (from ``all_res``);
# down-regulated counts are stored as negative numbers.
upreg_counts = []
downreg_counts = []
age = []
cell_type = []
# Iterate the cell types of ``all_res`` itself. This previously iterated
# ``all_res_ctl`` (the control age comparison), which dropped cell types that
# only have LPS-vs-control DE genes.
for i in all_res.cell_type.unique():
    for j in ['4wk','24wk','90wk']:
        curr = all_res[all_res.age==j]
        upreg = curr[curr.fc>=0]
        downreg = curr[curr.fc<0]
        upreg_counts.append(np.sum(upreg.cell_type==i))
        downreg_counts.append(-np.sum(downreg.cell_type==i))
        cell_type.append(i)
        age.append(j)


In [None]:
# Long table of LPS-vs-control DE gene counts per (cell type, age).
lps_ctl_comparison = pd.DataFrame({'upreg':upreg_counts, 'downreg':downreg_counts, 'age':age, 'cell_type':cell_type})

In [None]:
# Up-regulated counts above zero and (negative) down-regulated counts below,
# drawn on the same axes, colored by age.
f,ax=plt.subplots()
sns.barplot(x='cell_type', y='upreg', data=lps_ctl_comparison, hue='age', order=cell_types, palette=age_pal,ax=ax)
sns.barplot(x='cell_type', y='downreg', data=lps_ctl_comparison, hue='age', order=cell_types, palette=age_pal,ax=ax)

sns.despine()
# NOTE(review): fixed limits; counts beyond +/-80 are clipped.
plt.ylim([-80,80])

In [None]:
# get all genes that are differentially expressed
# (LPS vs control, any age or cell type, from the filtered ``all_res``).
all_de_genes = all_res.gene_names.unique()
# 4wk LPS-vs-control log2 fold change / q-value for each DE gene (columns)
# in each entry of ``cell_types`` (rows); NaN where a pair was not tested.
# This cell previously used ``genes_to_plot`` and ``pval_mat`` before they
# were defined (NameError on a fresh kernel), never used ``all_de_genes``,
# and sized the rows by the cell types present in the table rather than by
# ``cell_types``, which it used as the row index.
res_4wk = all_res_unfilt[all_res_unfilt.age=='4wk']
fc_mat = (res_4wk.pivot_table(index='cell_type', columns='gene_names', values='fc', aggfunc='first')
          .reindex(index=cell_types, columns=all_de_genes)
          .to_numpy())
pval_mat = (res_4wk.pivot_table(index='cell_type', columns='gene_names', values='qval', aggfunc='first')
            .reindex(index=cell_types, columns=all_de_genes)
            .to_numpy())


In [None]:
genes_to_plot = ['Sparc',"Sst","Vip", "Cux2","Cdkn1a","Cxcl10","Il18","Tnf","Il1b","Il6","Ifng","C4b","C3","Gfap","Il33",'Nfib','Serpina3n','Ifit3','Xdh','Nfkbia','Ifitm3','Hif3a']

In [None]:
# plot fold change in expression between young/old within LPS
fc_mat = np.zeros((len(all_res_unfilt.cell_type.unique()), len(genes_to_plot)))
pval_mat = np.zeros((len(all_res_unfilt.cell_type.unique()), len(genes_to_plot)))
for n,i in enumerate(genes_to_plot):
    for k,j in enumerate(cell_types):
        curr = all_res_unfilt[all_res_unfilt.age=='4wk']
        curr = curr[np.logical_and(curr.cell_type==j, curr.gene_names==i)]
        fc_mat[k,n] = curr.fc.values[0]
        pval_mat[k,n] = curr.qval.values[0]
pval_mat[np.isnan(pval_mat)] = 0
pval_mat[pval_mat<1e-20] = 1e-20

In [None]:
# Dot plot: x = cell type (row of fc_mat), y = gene (column). Marker size is
# 5 * -log10(q); color is the log2 fold change on a seismic colormap clipped
# to +/-10. Points with q < 1e-10 get a black outline.
f,ax = plt.subplots()
for i in range(fc_mat.shape[0]):
    for j in range(fc_mat.shape[1]):
        if pval_mat[i,j] < 1e-10:
            ax.scatter(i,j,s=5*-np.log10(pval_mat[i,j]),c=fc_mat[i,j],vmin=-10,vmax=10,cmap=plt.cm.seismic,edgecolor='k',linewidths=0.5)
        else:
            ax.scatter(i,j,s=5*-np.log10(pval_mat[i,j]),c=fc_mat[i,j],vmin=-10,vmax=10,cmap=plt.cm.seismic)
# Tick labels assume rows follow ``cell_types`` and columns ``genes_to_plot``.
ax.set_xticks(np.arange(fc_mat.shape[0]));
ax.set_yticks(np.arange(fc_mat.shape[1]));
ax.set_xticklabels(cell_types);
ax.set_yticklabels(genes_to_plot);

In [None]:
# plot fold change in expression between young/old within LPS
fc_mat_lps_age = np.zeros((len(all_res_lps.cell_type.unique()), len(genes_to_plot)))
pval_mat_lps_age = np.zeros((len(all_res_lps.cell_type.unique()), len(genes_to_plot)))
for n,i in enumerate(genes_to_plot):
    for k,j in enumerate(cell_types):
        curr = all_res_lps_unfilt[np.logical_and(all_res_lps_unfilt.cell_type==j, all_res_lps_unfilt.gene_names==i)]
        fc_mat_lps_age[k,n] = curr.fc.values[0]
        pval_mat_lps_age[k,n] = curr.qval.values[0]

In [None]:
pval_mat_lps_age[np.isnan(pval_mat_lps_age)] = 0
pval_mat_lps_age[pval_mat_lps_age<1e-20] = 1e-20

In [None]:
# Dot plot of 90wk-vs-4wk changes within LPS: x = cell type, y = gene.
# Marker size is 5 * -log10(q); color is the log2 fold change (seismic, +/-10).
# NOTE(review): the outline threshold here is q < 0.01, vs 1e-10 in the
# LPS-vs-control dot plot.
f,ax = plt.subplots()
for i in range(fc_mat_lps_age.shape[0]):
    for j in range(fc_mat_lps_age.shape[1]):
        if pval_mat_lps_age[i,j] < 0.01:
            ax.scatter(i,j,s=5*-np.log10(pval_mat_lps_age[i,j]),c=fc_mat_lps_age[i,j],vmin=-10,vmax=10,cmap=plt.cm.seismic,edgecolor='k',linewidths=0.5)
        else:
            ax.scatter(i,j,s=5*-np.log10(pval_mat_lps_age[i,j]),c=fc_mat_lps_age[i,j],vmin=-10,vmax=10,cmap=plt.cm.seismic,linewidths=0.5)
# Tick labels assume rows follow ``cell_types`` and columns ``genes_to_plot``.
ax.set_xticks(np.arange(fc_mat_lps_age.shape[0]));
ax.set_yticks(np.arange(fc_mat_lps_age.shape[1]));
ax.set_xticklabels(cell_types);
ax.set_yticklabels(genes_to_plot);

# Build a statistical model for each cell type that separates age and LPS effects


In [None]:
from de import *

In [None]:
# Natural log of total counts per cell. NOTE(review): cells with 0 counts
# give -inf. The outputs below are named "nologumi", so this covariate is
# presumably not used by the models — confirm.
adata.obs['log_umi'] = np.log(adata.obs.total_counts)

In [None]:
# Age model on control cells, per ``cell_type_preds`` group (project helper
# run_glm_de_age_merfish, presumably from ``de``).
curr_adata = adata[adata.obs.cond=='ctrl']
res_age_ctrl = run_glm_de_age_merfish(curr_adata, grouping='cell_type_preds', family='ols')


In [None]:
# The helper returns a dict-like of DataFrames (one per group, presumably);
# stack them into one table.
res_age_ctrl = pd.concat([i for i in res_age_ctrl.values()])

In [None]:
# Same age model on LPS cells.
curr_adata = adata[adata.obs.cond=='lps']
res_age_lps = run_glm_de_age_merfish(curr_adata, grouping='cell_type_preds', family='ols')


In [None]:
res_age_lps = pd.concat([i for i in res_age_lps.values()])

In [None]:
# Cache results to the working directory (re-read a few cells below).
res_age_lps.to_csv("res_age_lps_major_nologumi_V3.csv")

In [None]:
res_age_ctrl.to_csv("res_age_ctrl_major_nologumi_V3.csv")

In [None]:
# run model for just LPS in 4 wk condition
curr_adata = adata[adata.obs.age=='4wk']
res_lps = run_glm_de_age_merfish(curr_adata, grouping='cell_type_preds',obs_name='cond', comp_name="T.lps", family='ols')
all_de_lps = pd.concat([i for i in res_lps.values()])

In [None]:
# run model for just age in control condition
# (only 4wk and 90wk control cells).
curr_adata = adata[np.logical_and(adata.obs.cond=="ctrl", adata.obs.age.isin(['4wk','90wk']))]
res_age = run_glm_de_age_merfish(curr_adata, grouping='cell_type_preds', family='ols')
all_de_age = pd.concat([i for i in res_age.values()])

In [None]:
# Reload cached results. NOTE: read_csv adds the saved index as an
# "Unnamed: 0" column.
res_age_ctrl = pd.read_csv("res_age_ctrl_major_nologumi_V3.csv")
res_age_lps = pd.read_csv("res_age_lps_major_nologumi_V3.csv")

In [None]:
all_de_age.to_csv("all_de_ageonly_minorcelltypes_nologumi_V3.csv")
all_de_lps.to_csv("all_de_lpsonly_minorcelltypes_nologumi_V3.csv")


In [None]:
all_de_age = pd.read_csv("all_de_ageonly_minorcelltypes_nologumi_V3.csv")
all_de_lps = pd.read_csv("all_de_lpsonly_minorcelltypes_nologumi_V3.csv")


In [None]:
de_thresh = np.log(2) # originally 2
qval_thresh = 1e-4
all_de_age['coef_age'] = all_de_age.coef#np.log2(np.exp(all_de_age.coef))
all_de_lps['coef_lps'] = all_de_lps.coef#np.log2(np.exp(all_de_lps.coef))

de_genes_age = list(all_de_age[np.logical_and(np.abs(all_de_age.coef_age) > de_thresh, all_de_age.qval<qval_thresh)].gene.unique())
de_genes_lps = list(all_de_lps[np.logical_and(np.abs(all_de_lps.coef_lps) > de_thresh, all_de_lps.qval<qval_thresh)].gene.unique())
combined_genes = list(set(de_genes_age + de_genes_lps))

In [None]:
len(de_genes_lps)

In [None]:
# Build (cell type x gene) coefficient / q-value matrices for the genes in
# ``combined_genes``; rows follow ``all_de_age.cell_type.unique()``.
# NOTE(review): every ``.values[0]`` lookup below raises IndexError if a
# (cell type, gene) pair is missing from the corresponding table.
coef_age = np.zeros((len(all_de_age.cell_type.unique()), len(combined_genes)))
pval_age = np.zeros_like(coef_age)
coef_age_lps = np.zeros_like(coef_age)
coef_age_ctrl = np.zeros_like(coef_age)
age_count = np.zeros(coef_age.shape[0])
lps_count = np.zeros(coef_age.shape[0])
both_count = np.zeros(coef_age.shape[0])
# Per cell type: number of significant genes for age only, LPS only, and both.
for i,ct in enumerate(all_de_age.cell_type.unique()):
    curr_de_age = all_de_age[all_de_age.cell_type==ct]
    curr_de_lps = all_de_lps[all_de_lps.cell_type==ct]
    curr_age_signif = curr_de_age[np.logical_and(np.abs(curr_de_age.coef_age) > de_thresh, curr_de_age.qval<qval_thresh)]
    curr_lps_signif = curr_de_lps[np.logical_and(np.abs(curr_de_lps.coef_lps) > de_thresh, curr_de_lps.qval<qval_thresh)]
    print(ct, list(curr_lps_signif.gene), list(curr_age_signif.gene))
    both_count[i] = len(np.intersect1d(curr_age_signif.gene, curr_lps_signif.gene))
    age_count[i] = curr_age_signif.shape[0] - both_count[i]
    lps_count[i] = curr_lps_signif.shape[0] - both_count[i]

# Age coefficients / q-values (control 4wk vs 90wk model).
for i,ct in enumerate(all_de_age.cell_type.unique()):
    for j,g in enumerate(combined_genes):
        curr_de = all_de_age[np.logical_and(all_de_age.cell_type==ct, all_de_age.gene==g)]
        coef_age[i,j] = curr_de.coef_age.values[0]
        pval_age[i,j] = curr_de.qval.values[0]

# Age coefficients from the separate control-only and LPS-only age models.
for i,ct in enumerate(all_de_age.cell_type.unique()):
    for j,g in enumerate(combined_genes):
        curr_de = res_age_ctrl[np.logical_and(res_age_ctrl.cell_type==ct, res_age_ctrl.gene==g)]
        coef_age_ctrl[i,j] = curr_de.coef.values[0]
        curr_de = res_age_lps[np.logical_and(res_age_lps.cell_type==ct, res_age_lps.gene==g)]
        coef_age_lps[i,j] = curr_de.coef.values[0]

# LPS coefficients / q-values (4wk LPS vs control model).
coef_lps = np.zeros((len(all_de_age.cell_type.unique()), len(combined_genes)))
pval_lps = np.zeros_like(coef_lps)
for i,ct in enumerate(all_de_age.cell_type.unique()):
    for j,g in enumerate(combined_genes):
        curr_de = all_de_lps[np.logical_and(all_de_lps.cell_type==ct, all_de_lps.gene==g)]
        coef_lps[i,j] = curr_de.coef_lps.values[0]
        pval_lps[i,j] = curr_de.qval.values[0]

# Order genes by clustering their LPS and age coefficient profiles together
# (``order_values`` is a project helper; cosine metric).
#row_idx, dn_row = order_values(coef_lps, return_linkage=True, metric='cosine')
#col_idx = order_values(coef_lps.T)
col_idx, dn_col = order_values(np.hstack((coef_lps.T,coef_age.T)), return_linkage=True, metric='cosine')

# NOTE(review): only coef_lps and coef_age are reordered by col_idx here;
# pval_age, pval_lps, coef_age_ctrl, coef_age_lps and combined_genes keep
# the original gene order, so index them with col_idx before pairing.
coef_lps = coef_lps[:,col_idx]#[row_idx,:][:,col_idx]
coef_age = coef_age[:,col_idx]#[row_idx,:][:,col_idx]



In [None]:
ct_uniq = all_de_age.cell_type.unique()

In [None]:
row_order = ["ExN", "InN", "MSN", "Olig",  "OPC", "Astro", "Epen",  "Vlmc", "Endo", "Peri", "Micro",  "Macro", "T cell"]
row_idx = np.array([np.argwhere(ct_uniq==i)[0] for i in row_order]).flatten()

In [None]:
row_idx

In [None]:
# Quick look: age coefficients from res_age_lps vs res_age_ctrl, one point per (cell type, gene).
plt.scatter(coef_age_lps.flatten(), coef_age_ctrl.flatten())

In [None]:
# Difference of the two age-coefficient matrices, shown in heatmap row/column order.
plt.imshow(coef_age_lps[:,col_idx][row_idx,:]-coef_age_ctrl[:,col_idx][row_idx,:],vmin=-2,vmax=2,cmap=plt.cm.seismic)

In [None]:
# Gene (column) dendrogram returned by order_values in the coefficient cell.
hc.dendrogram(dn_col);

In [None]:
# NOTE(review): dn_row is only assigned in a commented-out line of the coefficient cell
# (`#row_idx, dn_row = order_values(...)`); on a fresh kernel this raises NameError
# unless an earlier cell defines it.
hc.dendrogram(dn_row);

In [None]:
# Distribution of all LPS coefficients.
plt.hist(coef_lps.flatten(),100);


In [None]:
# Data-driven candidate threshold: mean positive coefficient of significant genes minus that of
# non-significant genes, pooling the age and LPS DE tables.
combined_coef_good = np.hstack((all_de_age[np.logical_and(all_de_age.coef_age>0,all_de_age.qval<qval_thresh)].coef_age,
                               all_de_lps[np.logical_and(all_de_lps.coef_lps>0, all_de_lps.qval<qval_thresh)].coef_lps))
combined_coef_bad = np.hstack((all_de_age[np.logical_and(all_de_age.coef_age>0,all_de_age.qval>=qval_thresh)].coef_age,
                               all_de_lps[np.logical_and(all_de_lps.coef_lps>0, all_de_lps.qval>=qval_thresh)].coef_lps))
thresh = np.mean(combined_coef_good)-np.mean(combined_coef_bad)

In [None]:
# Overrides the value above with a fixed threshold; presumably a 1.25-fold change if the
# coefficients are natural-log fold changes — confirm.
thresh = np.log(1.25)

In [None]:
diff_coef = np.abs(coef_lps) - np.abs(coef_age)
coef_signif = np.zeros_like(coef_lps)
diff_coef = zscore(diff_coef.flatten()).reshape(coef_signif.shape)
#thresh = 0.1 #np.log(1.2)
diff_thresh = 1
for i in range(diff_coef.shape[0]):
    for j in range(diff_coef.shape[1]): 
        #if np.abs(diff_coef[i,j]) > diff_thresh and np.abs(coef_lps[i,j]) > thresh:
        #    coef_signif[i,j] = 1
        #elif np.abs(diff_coef[i,j]) > diff_thresh and np.abs(coef_age[i,j]) > thresh:
        #    coef_signif[i,j] = 2
        #elif (np.abs(diff_coef[i,j]) < diff_thresh) and (np.abs(coef_lps[i,j]) > thresh and np.abs(coef_age[i,j]) > thresh):
        #    coef_signif[i,j] = 3
        if  coef_lps[i,j] > thresh and coef_age[i,j] > thresh:
            #if pval_lps[i,j] < 0.01 or pval_age[i,j] < 0.01:
            coef_signif[i,j] = 3
        elif coef_age[i,j] > thresh:
            #if pval_age[i,j] < 0.01:
            coef_signif[i,j] = 2
        elif coef_lps[i,j] > thresh:
            #if pval_lps[i,j] < 0.01:
            coef_signif[i,j] = 1
        
#row_idx = np.argsort(np.sum(coef_signif==1,1))

plt.imshow(coef_signif[row_idx,:], cmap=mpl.colors.ListedColormap(['w','m','g','k']),vmin=0,vmax=3)

In [None]:
fraction_signif = np.vstack((np.sum(coef_signif==1,1), np.sum(coef_signif==3,1), np.sum(coef_signif==2,1) )).T.astype(np.float64)

In [None]:
for i in range(fraction_signif.shape[0]):
    fraction_signif[i,:] = fraction_signif[i,:]/fraction_signif[i,:].sum()
    #curr_sum = float(fraction_signif[i,:].sum())
    #for j in range(fraction_signif.shape[1]):
        #fraction_signif[i,j] /= curr_sum

In [None]:
plt.imshow(fraction_signif[row_idx,:],vmin=0,vmax=1,cmap=plt.cm.viridis)

In [None]:
# Put rows into row_order. NOTE: not idempotent — re-running applies row_idx a second time.
coef_signif = coef_signif[row_idx,:]

In [None]:
# Same row permutation for the coefficient matrices (also not idempotent).
coef_lps = coef_lps[row_idx,:]
coef_age = coef_age[row_idx,:]

In [None]:
f,axes = plt.subplots(ncols=2, figsize=(4,2),gridspec_kw={'width_ratios':[1,20],'wspace':0.1})
ax = axes[0]
sorted_celltypes = np.array(all_de_age.cell_type.unique())[row_idx]
celltype_cmap = mpl.colors.ListedColormap([celltype_colors[i] for i in sorted_celltypes])
ax.imshow(np.expand_dims(celltype_annots,1), cmap=celltype_cmap, aspect='auto',interpolation='none',rasterized=True)
sns.despine(ax=ax,bottom=True,left=True)
ax.set_xticks([])
ax.set_yticks([])
ax = axes[1]
ax.imshow(coef_age,vmin=-3,vmax=3,cmap=plt.cm.seismic,aspect='auto',interpolation='none',rasterized=True)
ax.axis('off')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_ageonly.pdf", bbox_inches='tight', dpi=200)

In [None]:
# Where LPS and age effect sizes differ (red: LPS larger, blue: age larger).
plt.imshow(coef_lps-coef_age,vmin=-3,vmax=3,cmap=plt.cm.seismic)

In [None]:
# Interleave rows: even rows = LPS coefficients, odd rows = age coefficients, per cell type.
combined_coef = np.zeros((coef_lps.shape[0]*2, coef_lps.shape[1]))
combined_coef[1::2,:] = coef_age
combined_coef[::2,:] = coef_lps

In [None]:
combined_pval = np.zeros((pval_lps.shape[0]*2, pval_lps.shape[1]))
combined_pval[1::2,:] = pval_age
combined_pval[::2,:] = pval_lps

In [None]:
combined_pval[combined_pval<1e-4] = 1e-4

In [None]:
row_idx_map = np.zeros(combined_pval.shape[0])
row_idx_map[::2] = np.arange(coef_lps.shape[0])
row_idx_map[1::2] = np.arange(coef_lps.shape[0])
row_idx_map = row_idx_map.astype(np.int)

In [None]:
# Combined heatmap: two rows per cell type (even = LPS coefficient, odd = age coefficient),
# gene columns in col_idx order, with circles marking entries flagged in coef_signif.
# heatmap params
# vmin/vmax: colour limits for the circle markers (the heatmap image below uses -3..3); s: marker size.
vmin = -2 
vmax = 2
s = 25

# plot
aspect_ratio = combined_coef.shape[1]/combined_coef.shape[0]
# 2 x 4 grid. Top row: only the column dendrogram (axes[0,2]) is drawn. Bottom row: cell-type strip,
# condition strip, heatmap, and an empty right-hand panel.
f, axes = plt.subplots(figsize=(aspect_ratio*4,4),nrows=2,ncols=4,gridspec_kw={'height_ratios':[2,20], 'hspace':0.1,'width_ratios':[0.25, 0.25,20,  1], 'wspace':0.02})

# celltype annots
ax = axes[0,0]
ax.axis('off')

ax = axes[0,1]
ax.axis('off')

# Gene dendrogram above the heatmap, drawn in black and flipped to hang downwards.
ax = axes[0,2]
hc.dendrogram(dn_col,orientation='bottom',ax=ax,above_threshold_color='k',color_threshold=0)
ax.invert_yaxis()

ax.axis('off')

ax = axes[0,3]
ax.axis('off')

# Cell-type colour strip: both rows of a pair get the same code k (0,0,1,1,...).
ax = axes[1,0]
celltype_annots  = np.zeros(combined_coef.shape[0])
k = 0
for i in range(0,len(celltype_annots),2):
    celltype_annots[i] = k
    celltype_annots[i+1] = k
    k += 1
sorted_celltypes = np.array(all_de_age.cell_type.unique())[row_idx]
celltype_cmap = mpl.colors.ListedColormap([celltype_colors[i] for i in sorted_celltypes])
ax.imshow(np.expand_dims(celltype_annots,1), cmap=celltype_cmap, aspect='auto',interpolation='none')
sns.despine(ax=ax,bottom=True,left=True)
ax.set_xticks([])
ax.set_yticks(np.arange(2*len(sorted_celltypes)))
# Label only the first (LPS) row of each pair with the cell-type name.
ylabel = []
k = 0
for i in range(2*len(sorted_celltypes)):
    if i%2 == 0:
        ylabel.append(sorted_celltypes[k]) #+ f" {int(lps_count[row_idx][k]), int(both_count[row_idx[k]]), int(age_count[row_idx][k])}")
        k += 1
    else:
        ylabel.append("")
ax.set_yticklabels(ylabel)
# main heatmap
# Condition strip: code 0 ('m') on even/LPS rows, code 1 ('g') on odd/age rows.
ax = axes[1,1]
conds = np.zeros(combined_coef.shape[0])
conds[1::2] = 1
ax.imshow(np.expand_dims(conds,1),aspect='auto',interpolation='none',cmap=mpl.colors.ListedColormap(['m','g']))
ax.set_yticks([])
ax.axis('off')
#
ax = axes[1,2]
# Circle outline colour by category: 1 -> 'm' (LPS only), 2 -> 'g' (age only), 3 -> 'k' (both).
# A size-s circle goes on the row of the matching condition; the other row of the pair gets a smaller (s=15) one.
for i in range(combined_coef.shape[0]):
    for j in range(combined_coef.shape[1]):
        r = row_idx_map[i]
        if coef_signif[r,j] == 1:
            circle_color = 'm'
        elif coef_signif[r,j] == 2:
            circle_color = 'g'
        elif coef_signif[r,j] == 3:
            circle_color = 'k'
        else:
            circle_color = None
        if circle_color is not None:
            # if 'm' or 'g', only color the appropriate row
            plot_circle = True
            if circle_color == 'm':
                if i%2 == 1:
                    plot_circle = False
            elif circle_color == 'g':
                if i%2 == 0:
                    plot_circle = False
            if plot_circle:
                ax.scatter(j,i, c=combined_coef[i,j], s=s, vmin=vmin,vmax=vmax, cmap=plt.cm.seismic, edgecolors=circle_color, linewidths=0.5,facecolors='none',)
            else:
                ax.scatter(j,i, c=combined_coef[i,j], s=15, vmin=-2,vmax=2, cmap=plt.cm.seismic, edgecolors=circle_color, linewidths=0.5,facecolors='none',)


        #else:
            #ax.scatter(j,i, c=combined_coef[i,j], s=5*-np.log10(combined_pval[i,j]), vmin=-2,vmax=2, cmap=plt.cm.seismic)

# Coefficient image underneath the markers; gene names as x tick labels.
ax.imshow(combined_coef,vmin=-3,vmax=3,cmap=plt.cm.seismic,aspect='auto',interpolation='none',rasterized=True)
ax.set_xticks(np.arange(combined_coef.shape[1]));
ax.set_yticks(np.arange(combined_coef.shape[0]));
ax.set_xticklabels(np.array(combined_genes)[col_idx],rotation=90,size=6);
#ax.set_yticklabels(all_de.cell_type.unique()[row_idx],size=6);
# Dashed separators between cell-type pairs.
for i in range(0,combined_coef.shape[0],2):
    ax.axhline(i-0.5,color='k',linestyle='--',lw=0.5)
#for i in range(1,len(coef_idx)):
#    if coef_idx[i-1] != coef_idx[i]:
#        ax.axvline(i-0.5, color='k',linestyle='--',lw=0.5)
sns.despine(ax=ax,bottom=True,left=True) 
#for i in ax.get_xticklabels(): 
#    if i.get_text() in lps_only:
#        i.set_color('m')
#    elif i.get_text() in age_only:
#        i.set_color('g')
#    else:
#        i.set_color('k')
ax.set_yticks([])
# dendrogram
# Row-dendrogram slot; the dendrogram call is disabled, so this panel stays empty.
ax = axes[1,3]
#hc.dendrogram(dn_row,orientation='right',ax=ax,above_threshold_color='k',color_threshold=0)
ax.invert_yaxis()
ax.axis('off')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_lps_age_combined_nologumi.pdf", bbox_inches='tight', dpi=200)

In [None]:
# Keep only genes flagged (code > 0) in at least one cell type.
# NOTE: overwrites coef_signif, so re-running this cell applies a mask of the wrong length.
good_cols = np.any(coef_signif>0,0)
coef_signif = coef_signif[:,good_cols]

In [None]:
# Dot plot of coef_signif over the retained genes: one row per cell type, a filled dot per flagged
# gene coloured by category ('m' = LPS only, 'g' = age only, 'k' = both).
# NOTE(review): aspect_ratio is taken from combined_coef (two rows per cell type), not coef_signif.
aspect_ratio = combined_coef.shape[1]/combined_coef.shape[0]
f, axes = plt.subplots(figsize=(aspect_ratio*3,3),ncols=3,gridspec_kw={'width_ratios':[0.25,20, 2], 'wspace':0.02})
ax = axes[0]
# Cell-type colour strip: code i for row i.
celltype_annots  = np.zeros(coef_signif.shape[0])
k = 0
for i in range(0,len(coef_signif)):
    celltype_annots[i] = k
    k += 1
sorted_celltypes = np.array(all_de_age.cell_type.unique())[row_idx]
celltype_cmap = mpl.colors.ListedColormap([celltype_colors[i] for i in sorted_celltypes])
ax.imshow(np.expand_dims(celltype_annots,1), cmap=celltype_cmap, aspect='auto',interpolation='none')
sns.despine(ax=ax,bottom=True,left=True)
ax.set_xticks([])
ax.set_yticks(np.arange(len(sorted_celltypes)))
ylabel = []
k = 0
for i in range(len(sorted_celltypes)):
    ylabel.append(sorted_celltypes[i]) #+ f" {int(lps_count[row_idx][k]), int(both_count[row_idx[k]]), int(age_count[row_idx][k])}")
ax.set_yticklabels(ylabel)

ax = axes[1]
for i in range(coef_signif.shape[0]):
    for j in range(coef_signif.shape[1]):
        if coef_signif[i,j] == 1:
            circle_color = 'm'
        elif coef_signif[i,j] == 2: 
            circle_color = 'g' 
        elif coef_signif[i,j] == 3:
            circle_color = 'k'
        else:
            circle_color = None
        if circle_color is not None:
            # y = n - i - 1 puts row 0 at the top, matching the imshow strip (origin 'upper').
            # cmap/vmin/vmax are presumably ignored since no `c` values are passed.
            ax.scatter(j,coef_signif.shape[0] - i -1, s=15, vmin=-2,vmax=2, cmap=plt.cm.seismic, edgecolors=circle_color, linewidths=0.5, facecolors=circle_color)
        #else:
            #ax.scatter(j,i, c=combined_coef[i,j], s=5*-np.log10(combined_pval[i,j]), vmin=-2,vmax=2, cmap=plt.cm.seismic)

ax.set_xticks(np.arange(coef_signif.shape[1]));
ax.set_yticks(np.arange(coef_signif.shape[0]));
ax.set_xticklabels(np.array(combined_genes)[col_idx][good_cols],rotation=90,size=6);
#ax.set_yticklabels(all_de.cell_type.unique()[row_idx],size=6);
# Dashed separators between cell-type rows.
for i in range(0,coef_signif.shape[0]):
    ax.axhline(i-0.5,color='k',linestyle='--',lw=0.5)
#for i in range(1,len(coef_idx)):
#    if coef_idx[i-1] != coef_idx[i]:
#        ax.axvline(i-0.5, color='k',linestyle='--',lw=0.5)
sns.despine(ax=ax,bottom=True,left=True) 
#for i in ax.get_xticklabels(): 
#    if i.get_text() in lps_only:
#        i.set_color('m')
#    elif i.get_text() in age_only:
#        i.set_color('g')
#    else:
#        i.set_color('k')
ax.set_yticks([])

# Empty right-hand panel (row dendrogram disabled).
ax = axes[2]
#hc.dendrogram(dn_row,orientation='right',ax=ax,above_threshold_color='k',color_threshold=0)
ax.invert_yaxis()
ax.axis('off')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_lps_age_de_thresh_nologumi.pdf", bbox_inches='tight', dpi=200)

In [None]:
from scipy.stats import pearsonr
cross_corr = np.zeros((14,14))
for i in range(cross_corr.shape[0]):
    for j in range(cross_corr.shape[1]):
        cross_corr[i,j] = pearsonr(coef_lps[i,:], coef_age[j,:])[0]
        #cross_corr[i,j] = 0.5*(np.corrcoef(coef_lps[i,:], coef_age[j,:])[0,1] + np.corrcoef(coef_lps[j,:], coef_age[i,:])[0,1])

In [None]:
f, ax = plt.subplots(figsize=(5,4))
plt.imshow(cross_corr,vmin=-1,vmax=1,cmap=plt.cm.seismic,aspect='auto',interpolation='none')
ax.set_xticks(np.arange(14));
ax.set_yticks(np.arange(14));
ax.set_xticklabels(sorted_celltypes,rotation=90);
ax.set_yticklabels(sorted_celltypes);
plt.colorbar()
ax.set_xlabel('LPS coef')
ax.set_ylabel('Age coef')

In [None]:
# score genes
# Per-cell activation scores from scanpy score_genes on hand-picked gene sets (use_raw=False).
# NOTE: activate_micro is recomputed in a later cell with a shorter gene list.
sc.tl.score_genes(adata, gene_list=['B2m','Trem2', 'Ccl2', 'Apoe', 'Spp1', 'Cst7', 'Axl', 'Itgax', 'Csf1', 'Cd9','C1qa','C1qc','Lyz2','Ctss'], score_name='activate_micro', use_raw=False)
activate_endo = ["B2m", "Nfkbia", "Serinc3","Xdh", "Gfap", "Tap1"]
sc.tl.score_genes(adata, gene_list=activate_endo, score_name='activate_endo',use_raw=False)
sc.tl.score_genes(adata, gene_list=["C4b", "Il33", "Il18"], score_name="activate_olig",use_raw=False)

In [None]:
def get_genes_from_de(A, cell_type):
    """Return genes DE with age and/or LPS for one cell type of a joint DE table.

    Args:
        A: DataFrame with columns ``cell_type``, ``gene``, ``coef_age``, ``coef_lps`` and ``qval``.
        cell_type: value of ``A.cell_type`` to select.

    Returns:
        ``(combined_genes, de_genes_age, de_genes_lps)``: genes with |coefficient| > 1 and
        qval < 0.01 for age and for LPS, plus their union (in arbitrary set order).
    """
    all_de = A[A.cell_type == cell_type]
    significant = all_de.qval < 0.01
    # Threshold the absolute coefficient. np.abs(coef > 1) would take the abs of a boolean
    # and silently drop strongly down-regulated genes.
    de_genes_age = list(all_de[np.logical_and(np.abs(all_de.coef_age) > 1, significant)].gene.unique())
    de_genes_lps = list(all_de[np.logical_and(np.abs(all_de.coef_lps) > 1, significant)].gene.unique())
    combined_genes = list(set(de_genes_age+de_genes_lps))
    return combined_genes, de_genes_age, de_genes_lps

def get_genes_from_separate_de(A_lps, A_age, cell_type):
    """Return DE genes for one cell type from separate LPS and age DE tables.

    Each table needs ``cell_type``, ``gene``, ``coef`` and ``qval`` columns. A gene counts as
    DE when |coef| > 1 and qval < 0.01.

    Returns:
        ``(combined_genes, de_genes_age, de_genes_lps)``; the per-condition lists keep first-
        appearance order and ``combined_genes`` is their union in arbitrary set order.
    """
    def _significant_genes(table):
        # Rows of this cell type passing both the effect-size and q-value cut-offs.
        subset = table[table.cell_type == cell_type]
        passing = (np.abs(subset.coef) > 1) & (subset.qval < 0.01)
        return list(subset[passing].gene.unique())

    age_genes = _significant_genes(A_age)
    lps_genes = _significant_genes(A_lps)
    return list(set(age_genes + lps_genes)), age_genes, lps_genes


In [None]:
cell_types_to_score = ['Vlmc','OPC','Peri','Epen','Astro','Micro','Olig','Endo','Macro']
for i in cell_types_to_score:
    curr_genes, de_genes_age, de_genes_lps = get_genes_from_separate_de(all_de_age, all_de_lps, i)
    print(i, curr_genes)
    sc.tl.score_genes(adata, gene_list=curr_genes, score_name=i+"_score")
    sc.tl.score_genes(adata, gene_list=de_genes_age, score_name=i+"_score_age")
    sc.tl.score_genes(adata, gene_list=de_genes_lps, score_name=i+"_score_lps")

In [None]:
def compute_celltype_neighborhood_regression(A, celltype_key, source, celltypes=None,min_radiu=0, obs_keys=None):
    """Pair cells of each cell type with features of a nearby `source` cell.

    Despite the name, no regression is fitted here; the function only gathers
    (distance, feature) pairs.

    Args:
        A: AnnData-like object with ``obs``, ``obsm['spatial']`` and ``X``.
        celltype_key: column of ``A.obs`` holding cell-type labels.
        source: label of the cell type whose features are returned.
        celltypes: labels to iterate over; defaults to all labels in sorted order.
        min_radiu: unused; kept for backward compatibility.
        obs_keys: if given, use these ``A.obs`` columns as features instead of ``A.X``.

    Returns:
        dict mapping each label ``c`` in ``celltypes`` to ``(dists, feats)``. For every cell of
        type ``c``, these are the distance to, and features of, the source cell picked by
        ``get_nearest_neighbor_dists`` (its second-nearest hit).
    """
    if obs_keys is None:
        expr = A.X
    else:
        expr = np.array(A.obs.loc[:,obs_keys].values)
    if celltypes is None:
        celltypes = list(sorted(A.obs[celltype_key].unique()))
    pos = A.obsm['spatial']
    labels = A.obs[celltype_key]
    # Positions and features of all source cells (mask computed once and reused).
    is_source = labels == source
    source_pos = pos[is_source]
    source_expr = expr[is_source]
    interactions = {}
    for c1 in celltypes:
        neighbor_pos = pos[labels==c1]
        # KD-tree over source cells, queried with the c1 cells; idx indexes source_pos/source_expr.
        dists, idx = get_nearest_neighbor_dists(neighbor_pos, source_pos)
        interactions[c1] = (dists, source_expr[idx])
    return interactions

def get_nearest_neighbor_dists(X,Y):
    """For each point in X, return distance and Y-index of its second-closest point in Y.

    The first hit is skipped, presumably to drop the self-match when X and Y share points.
    """
    tree = KDTree(Y)
    neighbor_dists, neighbor_idx = tree.query(X, k=2)
    runner_up = 1
    return neighbor_dists[:, runner_up], neighbor_idx[:, runner_up]


def moving_average(x, w):
    """Valid-mode moving average of `x` with a flat window of length `w`."""
    kernel = np.ones(w)
    window_sums = np.convolve(x, kernel, mode='valid')
    return window_sums / w

def average_within_bins(dists, score, min_val, max_val, bin_width, smooth_size=5):
    """Mean and std of `score` over sliding distance windows; the means are then smoothed.

    Windows are (start, start + bin_width] for start = min_val, ..., max_val - 1. They advance
    by 1, so they overlap when bin_width > 1.

    Args:
        dists: per-cell distances (any shape that ravels to one value per cell).
        score: per-cell values aligned with `dists`.
        min_val, max_val: integer range of window starts.
        bin_width: window width.
        smooth_size: window of the moving average applied to the per-window means.

    Returns:
        ``(scores_bin, scores_var)``: the valid-mode moving average of the per-window means, and
        the per-window standard deviations (unsmoothed). Empty windows contribute NaN.
    """
    dists = np.asarray(dists).ravel()
    score = np.asarray(score)
    scores_bin = []
    scores_var = []
    for start in range(min_val, max_val):
        # Index with the mask itself. np.argwhere(dists[mask]) returns positions of nonzero
        # *filtered values*, not of the cells in the window.
        in_bin = np.logical_and(dists > start, dists <= (start + bin_width))
        if in_bin.any():
            scores_bin.append(np.mean(score[in_bin]))
            scores_var.append(np.std(score[in_bin]))
        else:
            scores_bin.append(np.nan)
            scores_var.append(np.nan)
    scores_bin = moving_average(np.array(scores_bin), smooth_size)
    scores_var = np.array(scores_var)
    return scores_bin, scores_var



In [None]:
# plot individual genes that change with age, LPS, or both

In [None]:
# Reference sections for spatial figures, keyed "<age>_<cond>":
#   batch - data_batch value (compared as str(batch) by the plotting cells)
#   slice - slice index within that batch
#   rot   - rotation passed to the plotting helpers (presumably degrees — confirm)
#   xlim/ylim - crop window for the plotted section (presumably in the rotated frame — confirm)
plot_info = {
    '4wk_ctrl' : {
        'batch' : 7,
        'slice' : 0,
        'rot' : -183,
        'xlim' : [200, 2300],
        'ylim' : [200, 3800]
    },
    '24wk_ctrl' : {
        'batch' : 11,
        'slice' : 0,
        'rot' : -12,
        'xlim' : [1950, 1950+2100],
        'ylim' : [200, 3700]
    },
    '90wk_ctrl' : {
        'batch' : 8,
        'slice' : 1,
        'rot' : 35,
        'xlim' : [200, 2300],
        'ylim' : [400, 3900]
    },
    '4wk_lps' : {
        'batch' : 17,
        'slice' : 0,
        'xlim' : [2150,2150+2100],
        'ylim' : [600, 4600],
        'rot' : 160
    },
    '24wk_lps' : {
        'batch' : 16,
        'slice' : 1,
        'rot' : 20,
        'xlim' : [1800,1800+2100],
        'ylim' : [200, 4300],
    },
    '90wk_lps' : {
        # Previously used section (batch 13), kept for reference.
        #'batch' : 13,
        #'slice' : 0,
        #'xlim' : [450,100+2100],
        #'ylim' : [150,3700],
        #'rot' : 190
        'batch' : 19,
        'slice' : 1,
        'rot' : 0,
        'xlim' : [1800,1800+2100],
        'ylim' : [150,3900]
    }
}


In [None]:
clust_encoding = {k:i for i,k in enumerate(label_colors.keys())}
curr_cmap = mpl.colors.ListedColormap([label_colors[i] for i in label_colors.keys()])
adata.obs['clust_encoding'] = [clust_encoding[i] for i in adata.obs.clust_annot_preds]

In [None]:
def find_clust_names_with_str(clust_names, to_find):
    """Return the distinct cluster names that contain any of the given substrings.

    Args:
        clust_names: iterable of cluster names; may contain many repeats (e.g. a per-cell obs column).
        to_find: iterable of substrings to look for.

    Returns:
        list of unique matching names, in arbitrary (set) order.
    """
    # Deduplicate first: callers pass per-cell label columns, so checking every cell against every
    # pattern repeated the same substring test once per cell.
    unique_names = set(clust_names)
    return list({name for name in unique_names if any(pattern in name for pattern in to_find)})

In [None]:
for curr_samp in ['4wk_lps','24wk_lps','90wk_lps']:

    curr_adata = adata[np.logical_and(adata.obs.data_batch==str(plot_info[curr_samp]['batch']), 
                                      adata.obs.slice==plot_info[curr_samp]['slice'])]
    print(curr_adata.obs.age.unique())
    curr_rot = plot_info[curr_samp]['rot']
    curr_size = 1
    aspect_ratio, nx, ny = calculate_aspect_ratio(curr_adata, rot=curr_rot)
    print(aspect_ratio, nx, ny)
    xlim = plot_info[curr_samp]['xlim']
    ylim = plot_info[curr_samp]['ylim']
    aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])
    f, ax = plt.subplots(figsize=(5*7*aspect_ratio,5))
    ax = plt.subplot(1,7,1)
    plot_seg(curr_adata, seg_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim)
    cell_types = find_clust_names_with_str(adata.obs.clust_annot_preds, ["ExN"])
    ax = plt.subplot(1,7,2)
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax, xlim=xlim, ylim=ylim,clust_key='clust_annot_preds')

    ax = plt.subplot(1,7,3)
    cell_types = find_clust_names_with_str(adata.obs.clust_annot_preds, ['InN', 'MSN'])
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,clust_key='clust_annot_preds')

    ax = plt.subplot(1,7,4)
    cell_types = find_clust_names_with_str(adata.obs.clust_annot_preds, ['Olig', 'OPC'])
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,clust_key='clust_annot_preds')

    ax = plt.subplot(1,7,5)
    cell_types = find_clust_names_with_str(adata.obs.clust_annot_preds, ['Astro'])
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,clust_key='clust_annot_preds')

    ax = plt.subplot(1,7,6)
    cell_types = find_clust_names_with_str(adata.obs.clust_annot_preds, ['Epen', 'Endo', 'Vlmc', 'Peri'])
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,clust_key='clust_annot_preds')

    ax = plt.subplot(1,7,7)
    cell_types = find_clust_names_with_str(adata.obs.clust_annot_preds, ['Micro','Macro','T cell','B cell'])
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,clust_key='clust_annot_preds')
    f.savefig(f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_spatial_{curr_samp}.pdf", bbox_inches='tight', dpi=200)

In [None]:
# plot area specific genes
area_genes = [
 'Otof',
'Cux2',
'Rorb',
  'Rspo1',

 'Scube1',
  'Fezf2',
 'Ptpru',
 'Syt6',
 'Nxph4',
 'Drd1',
 'Drd2',
]
curr_samp = '4wk_lps'
celltypes = adata.obs.clust_annot.unique()
curr_adata = adata[np.logical_and(adata.obs.data_batch==str(plot_info[curr_samp]['batch']), 
                                  adata.obs.slice==plot_info[curr_samp]['slice'])]

#curr_adata = adata_combined_merfish[np.logical_and(adata_combined_merfish.obs.batch==9, adata_combined_merfish.obs.slice==1)]
curr_rot = plot_info[curr_samp]['rot']
xlim = plot_info[curr_samp]['xlim']
ylim = plot_info[curr_samp]['ylim']
aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])
f, axes = plt.subplots(nrows=1, ncols=len(area_genes)+1, figsize=(5*aspect_ratio*(len(area_genes)+1),5*1))
plot_seg(curr_adata, seg_cmap, rot=curr_rot,s=curr_size, ax=axes[0],xlim=xlim, ylim=ylim)

k = 1
for c in area_genes: 
    ax = axes[k]
    vmin = np.quantile(curr_adata[:,c].X, 0.05)
    vmax = np.quantile(curr_adata[:,c].X, 0.95)
    plot_gene_expr(curr_adata, celltypes, c, plt.cm.Reds, vmin=vmin,vmax=vmax, rot=curr_rot, ax=ax, xlim=xlim, ylim=ylim)
    k += 1
    ax.set_title(c)

In [None]:
curr_adata = adata[np.logical_and(adata.obs.data_batch=='16', adata.obs.slice==1)]
print(curr_adata.obs.age.unique(), curr_adata.obs.cond.unique())
plot_gene_expr(curr_adata, adata.obs.clust_annot_preds.unique(), "Cxcl10",rot=0,s=0.1,vmin=vmin,vmax=5,use_raw=False,key='clust_annot_preds',cmap=plt.cm.Reds)
plot_gene_expr(curr_adata, adata.obs.clust_annot_preds.unique(), "Ifit3",rot=0,s=0.1,vmin=vmin,vmax=10,use_raw=False,key='clust_annot_preds',cmap=plt.cm.Reds)
plot_gene_expr(curr_adata, adata.obs.clust_annot_preds.unique(), "Irf7",rot=0,s=0.1,vmin=vmin,vmax=10,use_raw=False,key='clust_annot_preds',cmap=plt.cm.Reds)


In [None]:
curr_adata = adata[np.logical_and(adata.obs.data_batch=='12', adata.obs.slice==0)]
print(curr_adata.obs.age.unique(), curr_adata.obs.cond.unique())
plot_gene_expr(curr_adata, adata.obs.clust_annot_preds.unique(), "Cxcl10",rot=0,s=0.1,vmin=vmin,vmax=10,use_raw=False,key='clust_annot_preds',cmap=plt.cm.Reds)
plot_gene_expr(curr_adata, adata.obs.clust_annot_preds.unique(), "Ifit3",rot=0,s=0.1,vmin=vmin,vmax=10,use_raw=False,key='clust_annot_preds',cmap=plt.cm.Reds)
plot_gene_expr(curr_adata, adata.obs.clust_annot_preds.unique(), "Irf7",rot=0,s=0.1,vmin=vmin,vmax=10,use_raw=False,key='clust_annot_preds',cmap=plt.cm.Reds)


In [None]:
def get_plot_info(age, cond):
    """Look up the reference section for an (age, condition) pair in `plot_info`.

    Returns:
        ``(batch, slice, xlim, ylim, rot)`` stored under key ``"<age>_<cond>"``.
    """
    section = plot_info[age + "_" + cond]
    fields = ('batch', 'slice', 'xlim', 'ylim', 'rot')
    return tuple(section[field] for field in fields)

def plot_gene_by_conditions(A, gene_name, vmin=0,vmax=3):
    """Plot one gene on the six reference sections: rows = 4wk/24wk/90wk, columns = ctrl/lps.

    Each panel's section, rotation and crop window come from ``get_plot_info``.

    Returns:
        The matplotlib Figure.
    """
    fig, grid = plt.subplots(figsize=(4,10), nrows=3, ncols=2, gridspec_kw={'wspace':0.05, 'hspace':0.01})
    all_clusters = A.obs.clust_annot_preds.unique()
    for col, cond in enumerate(['ctrl','lps']):
        for row, age in enumerate(['4wk', '24wk', '90wk']):
            batch, dslice, xlim, ylim, rot = get_plot_info(age, cond)
            panel = grid[row][col]
            in_section = np.logical_and(A.obs.data_batch == str(batch), A.obs.slice == dslice)
            plot_gene_expr(A[in_section], all_clusters, gene_name, rot=rot, s=0.1, vmin=vmin, vmax=vmax,
                           use_raw=False, key='clust_annot_preds', cmap=plt.cm.Reds, ax=panel)
            panel.set_xlim(xlim)
            panel.set_ylim(ylim)
    return fig

def plot_obs_by_conditions(A, obs_name, vmin=0,vmax=3,cmap=plt.cm.Reds, cell_types=None,key='clust_annot_preds',s=0.1):
    """Plot an obs column on the six reference sections: rows = 4wk/24wk/90wk, columns = ctrl/lps.

    Args:
        cell_types: passed through to ``plot_obs``; defaults to every ``clust_annot_preds`` label.
        key: obs column ``plot_obs`` uses to match ``cell_types``.

    Returns:
        The matplotlib Figure.
    """
    selected_types = A.obs.clust_annot_preds.unique() if cell_types is None else cell_types
    fig, grid = plt.subplots(figsize=(4,10), nrows=3, ncols=2, gridspec_kw={'wspace':0.05, 'hspace':0.01})
    for col, cond in enumerate(['ctrl','lps']):
        for row, age in enumerate(['4wk', '24wk', '90wk']):
            batch, dslice, xlim, ylim, rot = get_plot_info(age, cond)
            panel = grid[row][col]
            in_section = np.logical_and(A.obs.data_batch == str(batch), A.obs.slice == dslice)
            plot_obs(A[in_section], selected_types, obs_name, rot=rot, s=s, vmin=vmin, vmax=vmax,
                     key=key, cmap=cmap, ax=panel)
            panel.set_xlim(xlim)
            panel.set_ylim(ylim)
    return fig


In [None]:
# Re-score microglial activation with a shorter gene list (overwrites the earlier activate_micro)
# and score astrocyte activation.
sc.tl.score_genes(adata, gene_list=['B2m','Trem2', 'Ccl2', 'Apoe',  'Axl', 'Itgax', 'Cd9','C1qa','C1qc','Lyz2','Ctss'], score_name='activate_micro', use_raw=False)
sc.tl.score_genes(adata, gene_list=['C4b', 'C3', 'Serpina3n', 'Cxcl10', 'Gfap', 'Vim', 'Il18','Hif3a'], score_name='activate_astro', use_raw=False)

# Centre each score within its cell type: subtract the mean over 4wk control cells of that type,
# so the young-control baseline sits at ~0. Only rows of that cell type are modified.
adata_micro = adata[adata.obs.cell_type_preds=="Micro"]
adata.obs.loc[adata.obs.cell_type_preds=="Micro","activate_micro"] = adata_micro.obs.activate_micro - np.mean(adata_micro[np.logical_and(adata_micro.obs.cond=='ctrl',
                                                                                                                                   
                                                                                                                                   adata_micro.obs.age=='4wk')].obs.activate_micro)
adata_astro = adata[adata.obs.cell_type_preds=="Astro"]
adata.obs.loc[adata.obs.cell_type_preds=="Astro","activate_astro"] = adata_astro.obs.activate_astro - np.mean(adata_astro[np.logical_and(adata_astro.obs.age=='4wk',
                                                                                                                                   adata_astro.obs.cond=='ctrl')].obs.activate_astro)

In [None]:
# Spatial maps of each activation score on the six reference sections. Colour limits are the
# 5th/95th percentiles of the score within the plotted cell type only.
vmax = np.quantile(adata[adata.obs.cell_type_preds=="Micro"].obs.activate_micro,0.95)#np.quantile(adata.obs.activate_micro,0.99999)
vmin = np.quantile(adata[adata.obs.cell_type_preds=="Micro"].obs.activate_micro,0.05) #np.quantile(adata.obs.activate_micro,0.00001)

f = plot_obs_by_conditions(adata, "activate_micro",s=2.5,vmax=vmax, vmin=vmin,cmap=plt.cm.rainbow,cell_types="Micro",key='cell_type_preds');
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_activate_micro.pdf", bbox_inches='tight', dpi=200)

In [None]:
vmax = np.quantile(adata[adata.obs.cell_type_preds=="Astro"].obs.activate_astro,0.95)
vmin = np.quantile(adata[adata.obs.cell_type_preds=="Astro"].obs.activate_astro,0.05)

f = plot_obs_by_conditions(adata, "activate_astro",s=2.5,vmax=vmax, vmin=vmin,cmap=plt.cm.rainbow,cell_types="Astro",key='cell_type_preds');
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_activate_astro.pdf", bbox_inches='tight', dpi=200)

In [None]:
# NOTE: activate_endo is never centred on the 4wk-ctrl baseline in this section.
vmax = np.quantile(adata[adata.obs.cell_type_preds=="Endo"].obs.activate_endo,0.95)#np.quantile(adata.obs.activate_endo,0.999)
vmin = np.quantile(adata[adata.obs.cell_type_preds=="Endo"].obs.activate_endo,0.05)

f = plot_obs_by_conditions(adata, "activate_endo",s=1, vmax=vmax, vmin=vmin,cmap=plt.cm.rainbow,cell_types="Endo",key='cell_type_preds');
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_activate_endo.pdf", bbox_inches='tight', dpi=200)

In [None]:
# NOTE: activate_olig is only centred in a later cell. Because vmin/vmax are quantiles of the same
# values, a constant shift does not change the rendered colours.
vmax = np.quantile(adata[adata.obs.cell_type_preds=="Olig"].obs.activate_olig,0.95)#np.quantile(adata.obs.activate_endo,0.999)
vmin = np.quantile(adata[adata.obs.cell_type_preds=="Olig"].obs.activate_olig,0.05)

f = plot_obs_by_conditions(adata, "activate_olig",s=1, vmax=vmax, vmin=vmin,cmap=plt.cm.rainbow,cell_types="Olig",key='cell_type_preds');
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_activate_oligo.pdf", bbox_inches='tight', dpi=200)

In [None]:
f = plot_gene_by_conditions(adata, "Rsrp1",vmax=3);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_rsrp1_example.pdf", bbox_inches='tight', dpi=200)

In [None]:
plot_gene_by_conditions(adata, "Cdkn1a",vmax=7);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_cdkn1a_example.pdf", bbox_inches='tight', dpi=200)

In [None]:
f = plot_gene_by_conditions(adata, "Cdkn1a",vmax=5);
#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_il33_example.pdf", bbox_inches='tight', dpi=200)

In [None]:
f = plot_gene_by_conditions(adata, "Il33",vmax=5);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_il33_example.pdf", bbox_inches='tight', dpi=200)

In [None]:
f = plot_gene_by_conditions(adata, "C4b",vmax=5);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_c4b_example.pdf", bbox_inches='tight', dpi=200)

In [None]:
# Centre activate_olig on the mean of 4wk control Olig cells (only Olig rows are modified).
adata_olig = adata[adata.obs.cell_type_preds=="Olig"]
adata.obs.loc[adata.obs.cell_type_preds=="Olig","activate_olig"] = adata_olig.obs.activate_olig - np.mean(adata_olig[np.logical_and(adata_olig.obs.age=='4wk',
                                                                                                                                   adata_olig.obs.cond=='ctrl')].obs.activate_olig)

In [None]:
# plot age/obs comparisons
# activate_astro in Astro cells per spatial region, for 4wk animals and then 90wk animals.
# plot_cond_obs_comparison (from plotting) presumably contrasts ctrl vs LPS — confirm.
from plotting import *
ylim = [-1, 3.5]
spatial_order = ['Pia','L2/3','L5','L6','CC','LatSept','Striatum','Ventricle']
f = plot_cond_obs_comparison(adata[adata.obs.age=='4wk'], "spatial_clust_annots", "activate_astro", "Astro", order=spatial_order,clust_key='cell_type_preds',ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_astro_spatial_lps.pdf",bbox_inches='tight',dpi=300)

f = plot_cond_obs_comparison(adata[adata.obs.age=='90wk'], "spatial_clust_annots", "activate_astro", "Astro", order=spatial_order,clust_key='cell_type_preds',ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_astro_spatial_lps_old.pdf",bbox_inches='tight',dpi=300)

In [None]:
# plot age/obs comparisons
# plot age/obs comparisons
# Same comparison for activate_olig in Olig cells (shared y-range for both ages).
from plotting import *
spatial_order = ['Pia','L2/3','L5','L6','CC','LatSept','Striatum','Ventricle']
ylim = [-1.5, 4]
f = plot_cond_obs_comparison(adata[adata.obs.age=='4wk'], "spatial_clust_annots", "activate_olig", "Olig", order=spatial_order,clust_key='cell_type_preds', ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_spatial_lps.pdf",bbox_inches='tight',dpi=300)

from plotting import *
f = plot_cond_obs_comparison(adata[adata.obs.age=='90wk'], "spatial_clust_annots", "activate_olig", "Olig", order=spatial_order,clust_key='cell_type_preds',ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_spatial_lps_old.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Same comparison for activate_endo in Endo cells.
ylim = [-1, 4.5]
f = plot_cond_obs_comparison(adata[adata.obs.age=='4wk'], "spatial_clust_annots", "activate_endo", "Endo", order=spatial_order,clust_key='cell_type_preds',ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_spatial_lps.pdf",bbox_inches='tight',dpi=300)

f = plot_cond_obs_comparison(adata[adata.obs.age=='90wk'], "spatial_clust_annots", "activate_endo", "Endo", order=spatial_order,clust_key='cell_type_preds',ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_spatial_lps_old.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Same comparison for activate_micro in Micro cells.
ylim = [-1.5, 2.0]
f = plot_cond_obs_comparison(adata[adata.obs.age=='4wk'], "spatial_clust_annots", "activate_micro", "Micro", order=spatial_order,clust_key='cell_type_preds',ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_spatial_lps.pdf",bbox_inches='tight',dpi=300)

f = plot_cond_obs_comparison(adata[adata.obs.age=='90wk'], "spatial_clust_annots", "activate_micro", "Micro", order=spatial_order,clust_key='cell_type_preds',ylim=ylim);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_spatial_lps_old.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Within LPS animals only: 4wk vs 90wk per spatial region (4wk/90wk coloured cornflowerblue/lightcoral).
f = plot_age_obs_comparison(adata[np.logical_and(adata.obs.cond=="lps",adata.obs.age.isin(['4wk','90wk']))], "spatial_clust_annots", "activate_micro", "Micro", order=spatial_order,clust_key='cell_type_preds',age_pal=sns.color_palette(['cornflowerblue','lightcoral']));
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_spatial_lps_agecomp.pdf",bbox_inches='tight',dpi=300)

In [None]:
f = plot_age_obs_comparison(adata[np.logical_and(adata.obs.cond=="lps",adata.obs.age.isin(['4wk','90wk']))], "spatial_clust_annots", "activate_astro", "Astro", order=spatial_order,clust_key='cell_type_preds',age_pal=sns.color_palette(['cornflowerblue','lightcoral']));
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_astro_spatial_lps_agecomp.pdf",bbox_inches='tight',dpi=300)

In [None]:
f = plot_age_obs_comparison(adata[np.logical_and(adata.obs.cond=="lps",adata.obs.age.isin(['4wk','90wk']))], "spatial_clust_annots", "activate_olig", "Olig", order=spatial_order,clust_key='cell_type_preds',age_pal=sns.color_palette(['cornflowerblue','lightcoral']));
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_spatial_lps_agecomp.pdf",bbox_inches='tight',dpi=300)

In [None]:
f = plot_age_obs_comparison(adata[np.logical_and(adata.obs.cond=="lps",adata.obs.age.isin(['4wk','90wk']))], "spatial_clust_annots", "activate_endo", "Endo", order=spatial_order,clust_key='cell_type_preds',age_pal=sns.color_palette(['cornflowerblue','lightcoral']));
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_spatial_lps_agecomp.pdf",bbox_inches='tight',dpi=300)

In [None]:
def identify_nearest_neighbors_with_idx(X,Y,dist_thresh, min_dist_thresh=15):
    """Find every (X point, Y point) pair whose distance lies in (min_dist_thresh, dist_thresh].

    Builds a ``KDTree`` on ``Y`` and runs a radius query for each row of ``X``.
    ``KDTree`` comes from the notebook's star imports; its ``query_radius`` call
    matches scikit-learn's API (assumed — confirm).

    Parameters
    ----------
    X, Y : ndarray, shape (n_X, d) and (n_Y, d)
        Point coordinates (callers pass ``obsm['spatial']``).
    dist_thresh : float
        Search radius.
    min_dist_thresh : float, default 15
        Pairs at distance <= this value are discarded (this also drops
        self-matches when X and Y contain the same cells).

    Returns
    -------
    ind : ndarray of int
        Row index into ``Y`` of each neighbour.
    ind_X : ndarray of int
        Row index into ``X`` of the query point for each neighbour (same length
        as ``ind``). Both arrays are empty when either input is empty or when no
        pair qualifies.
    """
    if X.shape[0] == 0 or Y.shape[0] == 0:
        # Always hand back two arrays so callers can unpack unconditionally.
        return np.array([], dtype=int), np.array([], dtype=int)
    kdtree = KDTree(Y)
    ind, dists = kdtree.query_radius(X, r=dist_thresh, count_only=False, return_distance=True)
    # Owner of each neighbour; np.repeat also copes with query points that have no
    # neighbours, and with the case where no point has any.
    counts = [len(neighbors) for neighbors in ind]
    ind_X = np.repeat(np.arange(len(ind)), counts)
    ind = np.concatenate(list(ind)).astype(int)
    dists = np.concatenate(list(dists))
    keep = dists > min_dist_thresh
    return ind[keep], ind_X[keep].astype(int)

def identify_all_nearest_neighbors_with_dist(X,Y,dist_thresh=80, min_dist_thresh=15):
    """Return distances and ``Y`` row indices of all ``Y`` points within ``dist_thresh`` of each ``X`` point.

    Results for all query points are flattened into two 1-D arrays; which ``X``
    row a neighbour belongs to is not recorded.

    Parameters
    ----------
    X, Y : ndarray, shape (n_X, d) and (n_Y, d)
        Point coordinates.
    dist_thresh : float, default 80
        Search radius passed to ``KDTree.query_radius``.
    min_dist_thresh : float, default 15
        Not used by this function.
        NOTE(review): kept for signature compatibility; no filtering is applied.

    Returns
    -------
    dists, ind : ndarray
        Distances first, then indices (the reverse of
        ``identify_nearest_neighbors_with_idx``). Both are empty when either
        input is empty.
    """
    if X.shape[0] == 0 or Y.shape[0] == 0:
        # Return a pair so callers can always unpack two values.
        return np.array([], dtype=float), np.array([], dtype=int)
    kdtree = KDTree(Y)
    ind, dists = kdtree.query_radius(X, r=dist_thresh, count_only=False, return_distance=True)
    return np.concatenate(list(dists)), np.concatenate(list(ind))

def identify_nearest_neighbors_with_dist(X,Y):
    """For each row of ``X``, return the distance to and index of its nearest point in ``Y``.

    Parameters
    ----------
    X, Y : ndarray, shape (n_X, d) and (n_Y, d)
        Point coordinates.

    Returns
    -------
    dists, ind : ndarray, shape (n_X, 1)
        Output of ``KDTree.query(X, k=1)``. If ``Y`` (or ``X``) is empty there is
        no neighbour: distances are ``inf`` and indices are the sentinel ``-1``.
        Do not use ``-1`` as an index, because it wraps to the last element.
    """
    n_X = X.shape[0]
    if n_X == 0 or Y.shape[0] == 0:
        # Keep the (n_X, 1) layout so results stay aligned with X and callers can unpack.
        return np.full((n_X, 1), np.inf), np.full((n_X, 1), -1, dtype=int)
    kdtree = KDTree(Y)
    dists, ind = kdtree.query(X, k=1, return_distance=True)
    return dists, ind

def compute_celltype_obs_distance_correlation(A,cell_type_X, cell_type_Y, obs_key_X, celltype_key='cell_type'):
    """Pair each ``cell_type_X`` cell's ``obs_key_X`` value with its distance to the nearest ``cell_type_Y`` cell.

    Parameters
    ----------
    A : AnnData-like
        Needs ``.obs[celltype_key]``, ``.obs[obs_key_X]`` and ``.obsm['spatial']``.
    cell_type_X, cell_type_Y : str
        Source and target cell-type labels in ``A.obs[celltype_key]``.
    obs_key_X : str
        Column of per-cell values (e.g. an activation score) for the source cells.

    Returns
    -------
    scores : ndarray
        ``obs_key_X`` values of the source cells.
    dists : ndarray
        Nearest-target distances from ``identify_nearest_neighbors_with_dist``
        (presumably shape (n_source, 1)).
    """
    labels = A.obs[celltype_key]
    source = A[labels == cell_type_X]
    target = A[labels == cell_type_Y]
    nearest_dists, _ = identify_nearest_neighbors_with_dist(source.obsm['spatial'], target.obsm['spatial'])
    return source.obs[obs_key_X].values, nearest_dists

def compute_celltype_obs_correlation(A,cell_type_X, cell_type_Y, obs_key_X, obs_key_Y, celltype_key='cell_type', radius=40, min_dist_thresh=15):
    """Collect paired values (X cell's ``obs_key_X``, neighbouring Y cell's ``obs_key_Y``).

    A pair is formed for every ``cell_type_Y`` cell whose distance from a
    ``cell_type_X`` cell lies in (min_dist_thresh, radius]. The two returned
    arrays are aligned, so they can go straight into ``np.corrcoef``.

    Parameters
    ----------
    A : AnnData-like
        Needs ``.obs`` columns ``celltype_key``, ``obs_key_X`` and ``obs_key_Y``,
        plus ``.obsm['spatial']``.
    radius, min_dist_thresh : float
        Forwarded to ``identify_nearest_neighbors_with_idx``.

    Returns
    -------
    x_values, y_values : ndarray
        Same length; one entry per neighbouring pair.
    """
    X = A[A.obs[celltype_key] == cell_type_X]
    Y = A[A.obs[celltype_key] == cell_type_Y]
    obs_X = X.obs[obs_key_X]
    obs_Y = Y.obs[obs_key_Y]
    neighbors_X, ind_X = identify_nearest_neighbors_with_idx(X.obsm['spatial'], Y.obsm['spatial'], dist_thresh=radius, min_dist_thresh=min_dist_thresh)
    # neighbors_X holds row positions into Y, not obs labels. Index positionally:
    # obs_Y[int_array] on a string-indexed Series relies on pandas' deprecated
    # positional fallback.
    curr_expr = obs_Y.iloc[neighbors_X]
    return obs_X.values[ind_X], curr_expr.values

def compute_binned_values(dists, scores, min_d=0, max_d=100, bin_size=30):
    """Sliding-window mean (± SEM) of ``scores`` as a function of ``dists``.

    Window ``k`` covers distances in ``(min_d + k, min_d + k + bin_size]``, for
    ``k = 0 .. max_d - min_d - bin_size - 1``. The means are then centred by
    subtracting their average over non-empty windows.

    Parameters
    ----------
    dists : array-like
        One distance per score. A column vector of shape (n, 1), as produced by
        ``KDTree.query(k=1)``, is flattened first. Without flattening,
        ``argwhere`` returns 2-column indices that silently mix ``scores[0]``
        into every window.
    scores : array-like, length n
    min_d, max_d, bin_size : int
        Window start range and width.

    Returns
    -------
    binned_mean, binned_std : ndarray
        Centred window means and standard errors. Windows with no samples are NaN.
    """
    dists = np.ravel(np.asarray(dists))
    scores = np.ravel(np.asarray(scores))
    starts = np.arange(min_d, max_d - bin_size)
    binned_mean = np.full(len(starts), np.nan)
    binned_std = np.full(len(starts), np.nan)
    # Store by window position k, not by distance, so that min_d > 0 works.
    for k, lo in enumerate(starts):
        curr_scores = scores[(dists > lo) & (dists <= lo + bin_size)]
        if curr_scores.size == 0:
            continue
        binned_mean[k] = curr_scores.mean()
        binned_std[k] = curr_scores.std() / np.sqrt(curr_scores.size)
    # Centre on finite windows only, so a single empty window cannot NaN the curve.
    if np.isfinite(binned_mean).any():
        binned_mean -= np.nanmean(binned_mean)
    return binned_mean, binned_std

In [None]:
# Preview (not saved) of the young LPS astrocyte panel: binned activate_astro
# (centred mean ± SEM) vs. distance to the nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='4wk',adata.obs.cond=='lps')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Astro", i, "activate_astro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.3])
sns.despine()

In [None]:
# Young (4wk) LPS: binned activate_astro (centred mean ± SEM) vs. distance to the
# nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='4wk',adata.obs.cond=='lps')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Astro", i, "activate_astro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.3])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_astro_activation_lps.pdf",bbox_inches='tight')

In [None]:
# Young (4wk) control: binned activate_astro (centred mean ± SEM) vs. distance to
# the nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='4wk',adata.obs.cond=='ctrl')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Astro", i, "activate_astro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.3])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_astro_activation_ctrl.pdf",bbox_inches='tight')

In [None]:
# Young (4wk) LPS: binned activate_micro (centred mean ± SEM) vs. distance to the
# nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='4wk',adata.obs.cond=='lps')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Micro", i, "activate_micro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.1])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_micro_activation_lps.pdf",bbox_inches='tight')

In [None]:
# Young (4wk) control: binned activate_micro (centred mean ± SEM) vs. distance to
# the nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='4wk',adata.obs.cond=='ctrl')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Micro", i, "activate_micro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.1])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_micro_activation_ctrl.pdf",bbox_inches='tight')

In [None]:
# Old (90wk) control: binned activate_astro (centred mean ± SEM) vs. distance to
# the nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='90wk',adata.obs.cond=='ctrl')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Astro", i, "activate_astro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.3])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_astro_activation_ctrl_old.pdf",bbox_inches='tight')

In [None]:
# Old (90wk) LPS: binned activate_astro (centred mean ± SEM) vs. distance to the
# nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='90wk',adata.obs.cond=='lps')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Astro", i, "activate_astro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.3])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_astro_activation_lps_old.pdf",bbox_inches='tight')

In [None]:
# Old (90wk) control: binned activate_micro (centred mean ± SEM) vs. distance to
# the nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='90wk',adata.obs.cond=='ctrl')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Micro", i, "activate_micro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.1])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_micro_activation_ctrl_old.pdf",bbox_inches='tight')

In [None]:
# Old (90wk) LPS: binned activate_micro (centred mean ± SEM) vs. distance to the
# nearest cell of each listed type.
plt.figure(figsize=(3,3))
celltypes = ["Peri","Endo","Vlmc", "Olig", "Epen"]
subset = adata[np.logical_and(adata.obs.age=='90wk',adata.obs.cond=='lps')]

for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(subset, "Micro", i, "activate_micro",celltype_key='cell_type_preds')
    binned_mean, binned_std = compute_binned_values(dists, scores)
    x = np.arange(len(binned_mean))+30  # upper edge of each 30-unit sliding window
    # Label the line itself. legend(labels) pairs labels with artists in creation
    # order, which interleaves the fill_between bands and mislabels the curves.
    plt.plot(x,binned_mean,color=celltype_colors[i],label=i)
    plt.fill_between(x,binned_mean-binned_std, binned_mean+binned_std,alpha=0.1,color=celltype_colors[i])

plt.legend()
plt.ylim([-0.05, 0.1])
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_distance_micro_activation_lps_old.pdf",bbox_inches='tight')

In [None]:
# Activation-score correlation between neighbouring cells, pooled over all LPS
# sections. Row = reference cell type, column = neighbouring cell type
# (pairs at distances in (15, 40]).
lps_subset = adata[adata.obs.cond == 'lps']
activation_celltypes = ['Astro', 'Micro', 'Olig', 'Endo']
cc_lps = np.zeros((4, 4))
for row, src in enumerate(activation_celltypes):
    for col, dst in enumerate(activation_celltypes):
        x, y = compute_celltype_obs_correlation(
            lps_subset, src, dst,
            f"activate_{src.lower()}", f"activate_{dst.lower()}",
            celltype_key='cell_type_preds', radius=40,
        )
        cc_lps[row, col] = np.corrcoef(x, y)[0, 1]
        print(src, dst, cc_lps[row, col])

In [None]:
# Same neighbour activation-correlation matrix, pooled over all control sections.
# NOTE(review): despite its name, cc_age holds control-condition correlations (all ages).
ctrl_subset = adata[adata.obs.cond == 'ctrl']
activation_celltypes = ['Astro', 'Micro', 'Olig', 'Endo']
cc_age = np.zeros((4, 4))
for row, src in enumerate(activation_celltypes):
    for col, dst in enumerate(activation_celltypes):
        x, y = compute_celltype_obs_correlation(
            ctrl_subset, src, dst,
            f"activate_{src.lower()}", f"activate_{dst.lower()}",
            celltype_key='cell_type_preds', radius=40,
        )
        cc_age[row, col] = np.corrcoef(x, y)[0, 1]
        print(src, dst, cc_age[row, col])

In [None]:
# Heat map of the control correlation matrix; negative correlations are clipped to the low end (vmin=0).
plt.imshow(cc_age, vmin=0, vmax=1, cmap=plt.cm.viridis)

In [None]:
# Endothelial vs. neighbouring microglial activation, young (4wk) animals of both conditions.
young_all = adata[adata.obs.age == '4wk']
x, y = compute_celltype_obs_correlation(young_all, "Endo", "Micro", "activate_endo", "activate_micro",
                                        celltype_key='cell_type_preds', radius=40)
fig = plt.figure(figsize=(5, 5))
r = np.corrcoef(x, y)[0, 1]
plt.title(f"Endo -> Micro (R={r})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
fig.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_micro_corr_age.pdf",
            bbox_inches='tight', dpi=300)

In [None]:
# Oligodendrocyte vs. neighbouring microglial activation, young (4wk) LPS sections.
young_lps = adata[(adata.obs.cond == 'lps') & (adata.obs.age == '4wk')]
x, y = compute_celltype_obs_correlation(young_lps, "Olig", "Micro", "activate_olig", "activate_micro",
                                        celltype_key='cell_type_preds', radius=40)
fig = plt.figure(figsize=(5, 5))
r = np.corrcoef(x, y)[0, 1]
plt.title(f"Olig -> Micro (R={r})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
fig.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_micro_corr_young.pdf",
            bbox_inches='tight', dpi=300)

In [None]:
# Oligodendrocyte vs. neighbouring astrocyte activation, young (4wk) LPS sections.
young_lps = adata[(adata.obs.cond == 'lps') & (adata.obs.age == '4wk')]
x, y = compute_celltype_obs_correlation(young_lps, "Olig", "Astro", "activate_olig", "activate_astro",
                                        celltype_key='cell_type_preds', radius=40)
fig = plt.figure(figsize=(5, 5))
r = np.corrcoef(x, y)[0, 1]
plt.title(f"Olig -> Astro (R={r})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
fig.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_astro_corr_young.pdf",
            bbox_inches='tight', dpi=300)

In [None]:
# Microglial vs. neighbouring astrocyte activation, young (4wk) LPS sections (50-unit radius).
young_lps = adata[(adata.obs.cond == 'lps') & (adata.obs.age == '4wk')]
x, y = compute_celltype_obs_correlation(young_lps, "Micro", "Astro", "activate_micro", "activate_astro",
                                        celltype_key='cell_type_preds', radius=50)
fig = plt.figure(figsize=(5, 5))
r = np.corrcoef(x, y)[0, 1]
plt.title(f"Micro -> Astro (R={r})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
fig.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_astro_corr_young_lps.pdf",
            bbox_inches='tight', dpi=300)

In [None]:
# Microglial vs. neighbouring astrocyte activation, young (4wk) control sections (50-unit radius).
young_ctrl = adata[(adata.obs.cond == 'ctrl') & (adata.obs.age == '4wk')]
x, y = compute_celltype_obs_correlation(young_ctrl, "Micro", "Astro", "activate_micro", "activate_astro",
                                        celltype_key='cell_type_preds', radius=50)
fig = plt.figure(figsize=(5, 5))
r = np.corrcoef(x, y)[0, 1]
plt.title(f"Micro -> Astro (R={r})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
fig.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_astro_corr_young_ctrl.pdf",
            bbox_inches='tight', dpi=300)

In [None]:
# Endothelial vs. neighbouring astrocyte activation, old (90wk) control sections.
# The title and file name describe the data actually plotted (Endo -> Astro, old ctrl),
# and the file name no longer collides with the old-LPS panel.
x,y = compute_celltype_obs_correlation(adata[np.logical_and(adata.obs.cond=='ctrl',adata.obs.age=='90wk')],  "Endo","Astro", "activate_endo","activate_astro",  celltype_key='cell_type_preds', radius=40)
plt.figure(figsize=(5,5))
plt.title(f"Endo -> Astro (R={np.corrcoef(x,y)[0,1]})")
sns.kdeplot(x=x,y=y,fill=True)
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_astro_corr_old_ctrl.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Endothelial vs. neighbouring astrocyte activation, old (90wk) LPS sections.
# The title and file name describe the data actually plotted (Endo -> Astro, old LPS),
# and the file name no longer collides with the old-control panel.
x,y = compute_celltype_obs_correlation(adata[np.logical_and(adata.obs.cond=='lps',adata.obs.age=='90wk')],  "Endo","Astro", "activate_endo","activate_astro",  celltype_key='cell_type_preds', radius=40)
plt.figure(figsize=(5,5))
plt.title(f"Endo -> Astro (R={np.corrcoef(x,y)[0,1]})")
sns.kdeplot(x=x,y=y,fill=True)
sns.despine()
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_astro_corr_old_lps.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Endothelial vs. neighbouring microglial activation scores, young (4wk) LPS sections.
# (This uses the activate_endo score, not Il33 expression.)
young_lps = adata[(adata.obs.cond == 'lps') & (adata.obs.age == '4wk')]
x, y = compute_celltype_obs_correlation(young_lps, "Endo", "Micro", "activate_endo", "activate_micro",
                                        celltype_key='cell_type_preds', radius=40)
fig = plt.figure(figsize=(5, 5))
r = np.corrcoef(x, y)[0, 1]
plt.title(f"Endo -> Micro (R={r})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
fig.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_micro_corr_young.pdf",
            bbox_inches='tight', dpi=300)

In [None]:
# Endothelial vs. neighbouring microglial activation scores, young (4wk) cells from adata_lps.
# NOTE(review): this writes the same PDF as the previous cell. Presumably
# adata_lps[age == '4wk'] selects the same cells as adata[(cond == 'lps') & (age == '4wk')];
# confirm, otherwise one figure silently overwrites the other.
young_lps = adata_lps[adata_lps.obs.age == '4wk']
x, y = compute_celltype_obs_correlation(young_lps, "Endo", "Micro", "activate_endo", "activate_micro",
                                        celltype_key='cell_type_preds', radius=40)
fig = plt.figure(figsize=(5, 5))
r = np.corrcoef(x, y)[0, 1]
plt.title(f"Endo -> Micro (R={r})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
fig.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_micro_corr_young.pdf",
            bbox_inches='tight', dpi=300)

# Map spatial interactions

In [None]:
# Cell types included in the pairwise contact analysis below.
celltypes = ['InN', 'ExN', 'MSN', 'Astro', 'OPC', 'Olig',
             'Endo', 'Vlmc', 'Peri', 'Macro', 'Micro']

# Forwarded to compute_celltype_interactions as niter / perturb_max / dist_thresh.
niter = 1000
perturb_max = 100
dist_thresh = 20

In [None]:
# Contact enrichment between cell types in young (4wk) control and LPS sections.
# NOTE(review): no random seed is set; if compute_celltype_interactions is stochastic,
# results will vary between runs.
interaction_kwargs = dict(niter=niter, dist_thresh=dist_thresh, perturb_max=perturb_max)
ctrl_interactions_clust, ctrl_pvals_clust, ctrl_qvals_clust = compute_celltype_interactions(
    adata_ctl[adata_ctl.obs.age == '4wk'], 'cell_type_preds', celltypes, **interaction_kwargs)
lps_interactions_clust, lps_pvals_clust, lps_qvals_clust = compute_celltype_interactions(
    adata_lps[adata_lps.obs.age == '4wk'], 'cell_type_preds', celltypes, **interaction_kwargs)


In [None]:
from statsmodels.stats.multitest import multipletests
def fdr_correct(X, method='fdr_bh'):
    """FDR-correct a square p-value matrix one row at a time and symmetrize the result.

    Each row ``i`` is corrected independently with ``multipletests`` and written
    to both row ``i`` and column ``i``. Because later rows overwrite earlier
    columns, entry (i, j) ends up holding the value computed for row ``max(i, j)``.
    NOTE(review): confirm that "later row wins" is the intended way to symmetrize.

    Non-finite p-values (NaN) are left out of the correction and come back as NaN,
    so they cannot contaminate the other entries of their row.

    Parameters
    ----------
    X : array-like, shape (n, n)
        Raw p-values.
    method : str, default 'fdr_bh'
        Correction method passed to ``multipletests`` (Benjamini-Hochberg by default).

    Returns
    -------
    ndarray, shape (n, n)
        Corrected q-values (float).
    """
    X = np.asarray(X, dtype=float)
    new_X = np.full(X.shape, np.nan)
    for i in range(X.shape[0]):
        row = X[i, :]
        finite = np.isfinite(row)
        corrected = np.full(row.shape, np.nan)
        if finite.any():
            # Correct only the testable entries. multipletests does not handle NaN, and
            # one NaN can turn every q-value in the row into NaN.
            corrected[finite] = multipletests(row[finite], method=method)[1]
        new_X[i, :] = corrected
        new_X[:, i] = corrected
    return new_X

In [None]:
# Re-derive q-values from the raw p-values with the row-wise correction defined above
# (defensive copies leave the p-value matrices untouched).
ctrl_qvals_clust, lps_qvals_clust = (
    fdr_correct(pvals.copy()) for pvals in (ctrl_pvals_clust, lps_pvals_clust)
)


In [None]:
# A NaN q-value means no test result is available for that cell-type pair. Count it as
# non-significant (q = 1); setting it to 0 would mark it as maximally significant.
ctrl_qvals_clust[np.isnan(ctrl_qvals_clust)] = 1
lps_qvals_clust[np.isnan(lps_qvals_clust)] = 1

In [None]:
# Cap infinite enrichment values (presumably log-scale) at the colour limit while keeping
# their sign: +inf -> 5, -inf -> -5. np.isinf alone would send depleted (-inf) pairs to +5.
# Both matrices are capped so the LPS-minus-control `diff` computed later compares like with like.
for interactions in (ctrl_interactions_clust, lps_interactions_clust):
    interactions[np.isposinf(interactions)] = 5
    interactions[np.isneginf(interactions)] = -5

In [None]:
# Control contact-enrichment matrix (diverging colour scale, ±5).
f = plot_interactions(
    ctrl_qvals_clust, ctrl_interactions_clust, celltypes, celltype_colors,
    cmap=plt.cm.seismic, vmax=5, vmin=-5,
)
f.savefig(
    "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_ctrl_cell_contact_diff.pdf",
    bbox_inches='tight', dpi=200,
)

In [None]:
# LPS contact-enrichment matrix (diverging colour scale, ±5).
f = plot_interactions(
    lps_qvals_clust, lps_interactions_clust, celltypes, celltype_colors,
    cmap=plt.cm.seismic, vmax=5, vmin=-5,
)
f.savefig(
    "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_lps_cell_contact_diff.pdf",
    bbox_inches='tight', dpi=200,
)

In [None]:
# LPS-minus-control change in contact enrichment for each cell-type pair.
# A pair's change is kept only if the pair is significant (q < 0.1) in at least one
# condition, judged on FDR-corrected q-values for BOTH conditions.
sig_mask = (ctrl_qvals_clust < 0.1) | (lps_qvals_clust < 0.1)
diff = lps_interactions_clust - ctrl_interactions_clust
diff[np.isnan(diff)] = 0
diff[~sig_mask] = 0
# For plotting, a pair gets q = 0 only when its enrichment rises by more than log2(1.2)
# under LPS and it is enriched (> 0) in at least one condition. Every other pair gets q = 1.
increased = (diff > np.log2(1.2)) & ((lps_interactions_clust > 0) | (ctrl_interactions_clust > 0))
diff_qvals = np.where(sig_mask & increased, 0.0, 1.0)

In [None]:
# Pairs whose contact enrichment increases under LPS in young animals.
f = plot_interactions(
    diff_qvals, diff, celltypes, celltype_colors,
    cmap=plt.cm.Reds, vmax=1.5, vmin=0,
)
f.savefig(
    "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_lps_ctrl_cell_contact_diff.pdf",
    bbox_inches='tight', dpi=200,
)

In [None]:
# Same control-vs-LPS contact comparison, repeated for old (90wk) animals.
# NOTE(review): unlike the 4wk comparison, these q-values are used as returned by
# compute_celltype_interactions (not re-derived with fdr_correct, NaNs not filled) — confirm intended.
ctrl_interactions_clust, ctrl_pvals_clust, ctrl_qvals_clust = compute_celltype_interactions(adata_ctl[adata_ctl.obs.age=='90wk'], 
                                                                'cell_type_preds', celltypes,niter=niter,dist_thresh=dist_thresh,perturb_max=perturb_max)

lps_interactions_clust, lps_pvals_clust, lps_qvals_clust = compute_celltype_interactions(adata_lps[adata_lps.obs.age=='90wk'], 
                                                                'cell_type_preds', celltypes,niter=niter,dist_thresh=dist_thresh,perturb_max=perturb_max)
# Keep a pair's change only if it is significant (q < 0.1) in at least one condition;
# both conditions are judged on q-values.
sig_mask = (ctrl_qvals_clust < 0.1) | (lps_qvals_clust < 0.1)
diff = lps_interactions_clust - ctrl_interactions_clust
diff[np.isnan(diff)] = 0
diff[~sig_mask] = 0
# q = 0 only for pairs that rise by more than log2(1.2) and are enriched in at least one condition.
increased = (diff > np.log2(1.2)) & ((lps_interactions_clust > 0) | (ctrl_interactions_clust > 0))
diff_qvals = np.where(sig_mask & increased, 0.0, 1.0)

In [None]:
# Pairs whose contact enrichment increases under LPS in old animals (not saved).
f = plot_interactions(
    diff_qvals, diff, celltypes, celltype_colors,
    cmap=plt.cm.Reds, vmax=1.5, vmin=0,
)


In [None]:
# Age comparison within LPS: young (4wk) LPS is held in the ctrl_* variables and old
# (90wk) LPS in the lps_* variables, so `diff` is old-minus-young under LPS.
ctrl_interactions_clust, ctrl_pvals_clust, ctrl_qvals_clust = compute_celltype_interactions(adata_lps[adata_lps.obs.age=='4wk'], 
                                                                'cell_type_preds', celltypes,niter=niter,dist_thresh=dist_thresh,perturb_max=perturb_max)

lps_interactions_clust, lps_pvals_clust, lps_qvals_clust = compute_celltype_interactions(adata_lps[adata_lps.obs.age=='90wk'], 
                                                                'cell_type_preds', celltypes,niter=niter,dist_thresh=dist_thresh,perturb_max=perturb_max)
# Keep a pair's change only if it is significant (q < 0.1) in at least one group;
# both groups are judged on q-values.
sig_mask = (ctrl_qvals_clust < 0.1) | (lps_qvals_clust < 0.1)
diff = lps_interactions_clust - ctrl_interactions_clust
diff[np.isnan(diff)] = 0
diff[~sig_mask] = 0
# q = 0 only for pairs that rise by more than log2(1.2) and are enriched in at least one group.
increased = (diff > np.log2(1.2)) & ((lps_interactions_clust > 0) | (ctrl_interactions_clust > 0))
diff_qvals = np.where(sig_mask & increased, 0.0, 1.0)
f = plot_interactions(diff_qvals, diff, celltypes,celltype_colors,cmap=plt.cm.Reds,vmax=1.5, vmin=0)


In [None]:
def identify_nearest_neighbors_with_idx(X,Y,dist_thresh, min_dist_thresh=15):
    """Find every (X point, Y point) pair whose distance lies in (min_dist_thresh, dist_thresh].

    NOTE(review): identical to the definition earlier in this notebook; this
    redefinition shadows it.

    Builds a ``KDTree`` on ``Y`` and runs a radius query for each row of ``X``.
    ``KDTree`` comes from the notebook's star imports; its ``query_radius`` call
    matches scikit-learn's API (assumed — confirm).

    Parameters
    ----------
    X, Y : ndarray, shape (n_X, d) and (n_Y, d)
        Point coordinates.
    dist_thresh : float
        Search radius.
    min_dist_thresh : float, default 15
        Pairs at distance <= this value are discarded (this also drops
        self-matches when X and Y contain the same cells).

    Returns
    -------
    ind : ndarray of int
        Row index into ``Y`` of each neighbour.
    ind_X : ndarray of int
        Row index into ``X`` of the query point for each neighbour. Both arrays
        are empty when either input is empty or when no pair qualifies.
    """
    if X.shape[0] == 0 or Y.shape[0] == 0:
        # Always hand back two arrays so callers can unpack unconditionally.
        return np.array([], dtype=int), np.array([], dtype=int)
    kdtree = KDTree(Y)
    ind, dists = kdtree.query_radius(X, r=dist_thresh, count_only=False, return_distance=True)
    # Owner of each neighbour; np.repeat also copes with points (or all points)
    # that have no neighbours.
    counts = [len(neighbors) for neighbors in ind]
    ind_X = np.repeat(np.arange(len(ind)), counts)
    ind = np.concatenate(list(ind)).astype(int)
    dists = np.concatenate(list(dists))
    keep = dists > min_dist_thresh
    return ind[keep], ind_X[keep].astype(int)


def compute_celltype_obs_correlation(A,cell_type_X, cell_type_Y, obs_key_X, obs_key_Y, celltype_key='cell_type', radius=40, min_dist_thresh=15):
    """Collect paired values (X cell's ``obs_key_X``, neighbouring Y cell's ``obs_key_Y``).

    NOTE(review): identical to the definition earlier in this notebook; this
    redefinition shadows it.

    A pair is formed for every ``cell_type_Y`` cell whose distance from a
    ``cell_type_X`` cell lies in (min_dist_thresh, radius]. The two returned
    arrays are aligned, so they can go straight into ``np.corrcoef``.

    Returns
    -------
    x_values, y_values : ndarray
        Same length; one entry per neighbouring pair.
    """
    X = A[A.obs[celltype_key] == cell_type_X]
    Y = A[A.obs[celltype_key] == cell_type_Y]
    obs_X = X.obs[obs_key_X]
    obs_Y = Y.obs[obs_key_Y]
    neighbors_X, ind_X = identify_nearest_neighbors_with_idx(X.obsm['spatial'], Y.obsm['spatial'], dist_thresh=radius, min_dist_thresh=min_dist_thresh)
    # neighbors_X holds row positions into Y, not obs labels. Index positionally:
    # obs_Y[int_array] on a string-indexed Series relies on pandas' deprecated
    # positional fallback.
    curr_expr = obs_Y.iloc[neighbors_X]
    return obs_X.values[ind_X], curr_expr.values

In [None]:
import pandas as pd
# Astrocyte vs. neighbouring microglial activation in old (90wk) LPS cells; prints the
# full 2x2 correlation matrix.
plt.figure(figsize=(5, 5))
old_lps = adata_lps[adata_lps.obs.age == '90wk']
x, y = compute_celltype_obs_correlation(
    old_lps, 'Astro', 'Micro', 'activate_astro', 'activate_micro',
    radius=40, celltype_key='cell_type_preds',
)
plt.scatter(x, y, s=1, c='k')
print(np.corrcoef(x, y))
sns.despine()