In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import os
import anndata as ad
mpl.rcParams['figure.dpi'] = 150
plt.rcParams['pdf.fonttype'] = 42

import sys
from spatial_analysis import *
from plotting import *
sns.set_style('white')

In [None]:
def unbinarize_strings(A):
    """Decode byte-string labels in an AnnData-like object to ``str`` in place.

    Converts ``A.var_names``, ``A.obs.index`` and every non-numeric ``obs``
    column whose values contain ``bytes`` from ASCII bytes to ``str``.
    Values that are already ``str`` (or otherwise not ``bytes``) are left as
    they are, so calling the function twice is safe.

    Parameters
    ----------
    A : AnnData or any object with ``var_names`` and a pandas ``obs`` DataFrame.

    Returns
    -------
    The same object ``A`` (modified in place).

    Raises
    ------
    UnicodeDecodeError
        If a byte string is not valid ASCII.
    """
    import pandas as pd

    def _to_str(value):
        # only bytes need decoding; str and other values pass through unchanged
        return value.decode('ascii') if isinstance(value, bytes) else value

    A.var_names = [_to_str(name) for name in A.var_names]
    A.obs.index = [_to_str(name) for name in A.obs.index]
    for col in A.obs.columns:
        # numeric and boolean columns can never hold bytes
        if pd.api.types.is_numeric_dtype(A.obs[col].dtype):
            continue
        values = list(A.obs[col])
        # rewrite only columns that actually contain bytes, so str categoricals keep their dtype
        if any(isinstance(v, bytes) for v in values):
            A.obs[col] = [_to_str(v) for v in values]
    return A


In [None]:
# Re-apply the seaborn 'white' style (the import cell already sets it; repeating is harmless).
sns.set_style('white')

In [None]:
def plot_gene_by_cells(A, gene_name, s=0.1, vmin=0,vmax=5, cmap=plt.cm.Reds):
    """Scatter cells at their spatial centroids, coloured by one gene's expression.

    Parameters
    ----------
    A : AnnData with ``obs.center_x`` / ``obs.center_y`` columns.
    gene_name : str
        Gene to plot; the first matching entry of ``A.var.index`` is used.
    s, vmin, vmax, cmap
        Passed through to ``plt.scatter``.

    Returns
    -------
    The ``PathCollection`` created by ``plt.scatter``.

    Raises
    ------
    IndexError
        If ``gene_name`` is not in ``A.var.index``.
    """
    matches = np.flatnonzero(np.asarray(A.var.index == gene_name))
    if matches.size == 0:
        raise IndexError(f"gene {gene_name!r} not found in A.var.index")
    gene_expr = A.X[:, matches[0]]
    # a sparse X yields a sparse column; matplotlib needs a flat dense array
    if hasattr(gene_expr, 'toarray'):
        gene_expr = gene_expr.toarray()
    gene_expr = np.asarray(gene_expr).ravel()
    return plt.scatter(A.obs.center_x, A.obs.center_y, s=s, c=gene_expr, vmin=vmin, vmax=vmax, cmap=cmap)

In [None]:
# load all merged MERFISH data (all ages) from an absolute, machine-specific path
adata = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/102821_merged_combined_merfish_allages.h5ad")

In [None]:
# cell count before any QC filtering
print("Starting with", adata.shape[0], "cells")

In [None]:
# Genes whose barcode includes bit 40, which had very low signal; they are excluded below.
bad_genes = ['Prom1',
 'Parp8',
 'Rbpj',
 'Skap2',
 'Ago3',
 'Cntnap3',
 'Meis2',
 'Arnt2',
 'Hivep2',
 'Foxn3',
 'Parp2',
 'Zfp608',
 'Fbxl7',
 'Htr2c',
 'Klf7',
 'Timp2',
 'Zbtb16',
 'Egflam',
 'Ikzf2',
 'Cdh13',
 'Cd63',
 'Marcks',
 'Parp11',
 'Herc6',
 'Cdh9',
 'Tsc22d1',
 'Lef1',
 'Shisa6',
 'St8sia6',
 'Trp53',
 'Plch1',
 'Cp',
 '9630014M24Rik',
 'Elf2',
 'Tafa1',
 'Ntn1',
 'Rarb',
 'Zfp462',
 'Sirt5',
 'Mamdc2',
 'Bach2']


In [None]:
# Drop the low-signal bit-40 genes. `.copy()` materializes the subset so the
# in-place QC steps below act on a real AnnData instead of a view (which
# AnnData would otherwise convert implicitly with an ImplicitModificationWarning).
adata = adata[:, ~adata.var_names.isin(bad_genes)].copy()

In [None]:
# Per-cell QC metrics written into adata.obs (total_counts and n_genes_by_counts
# are used by the plots and filters below).
sc.pp.calculate_qc_metrics(adata, qc_vars=[], percent_top=None, log1p=False, inplace=True)


In [None]:
# Counts vs. detected genes per cell, coloured by age in weeks (the 2-character
# 'wk' suffix is stripped from labels like '4wk'); the lines mark the
# min_counts=20 / min_genes=5 cutoffs applied by sc.pp.filter_cells further down.
plt.scatter(adata.obs.total_counts, adata.obs.n_genes_by_counts,s=0.1,alpha=0.1,c=np.array([int(i[:-2]) for i in adata.obs.age]))
plt.xlabel('Counts')
plt.ylabel('Genes')
plt.axvline(20,color='k')
plt.axhline(5,color='k')

In [None]:
# Same QC scatter, coloured by imaging batch.
plt.scatter(adata.obs.total_counts, adata.obs.n_genes_by_counts,s=0.1,alpha=0.1,c=adata.obs.batch)
plt.xlabel('Counts')
plt.ylabel('Genes')
plt.axvline(20,color='k')
plt.axhline(5,color='k')

In [None]:
# use scrublet
import scrublet as scr

all_doublet_scores = []
for i in adata.obs.batch.unique():
    print("Doubleting", i)
    curr_adata = adata[adata.obs.batch==i]
    scrub = scr.Scrublet(curr_adata.X)
    doublet_scores, predicted_doublets = scrub.scrub_doublets(log_transform=True)
    all_doublet_scores.append(doublet_scores)
    #scrub.plot_histogram()

In [None]:
adata.obs["doublet_scores"] = np.hstack(all_doublet_scores)

In [None]:
plt.hist(np.hstack(all_doublet_scores),100);


In [None]:
# Checkpoint the QC'd data together with the per-cell doublet scores.
adata.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_allages.h5ad")

In [None]:
# Uncomment to resume from the checkpoint instead of re-running scrublet.
#adata = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_allages.h5ad")

In [None]:
# Keep cells with a scrublet doublet score below 0.2.
adata = adata[adata.obs.doublet_scores<0.2]

In [None]:
# remove cells with volume < 100 or >= 3x the median volume of the cells remaining after doublet removal
median_vol = np.median(adata.obs.volume)
adata = adata[np.logical_and(adata.obs.volume >= 100, adata.obs.volume < 3*median_vol)]

In [None]:
# Volume distribution after filtering; the line marks the lower cutoff of 100.
plt.hist(adata.obs.volume,100);
plt.axvline(100)

In [None]:
# Minimum detected genes / total counts per cell (the cutoffs drawn on the QC scatters).
sc.pp.filter_cells(adata, min_genes=5)
sc.pp.filter_cells(adata, min_counts=20)


In [None]:
# normalize counts by volume of cell: divide every row of X by that cell's volume.
# Vectorized instead of a per-cell Python loop; this also avoids positional
# Series[int] indexing on the string cell index (deprecated in recent pandas)
# and works for dense or sparse X.
import scipy.sparse

cell_volumes = adata.obs.volume.to_numpy()
if np.issubdtype(adata.X.dtype, np.floating):
    # keep X's float precision, as the old in-place /= did
    cell_volumes = cell_volumes.astype(adata.X.dtype)
if scipy.sparse.issparse(adata.X):
    adata.X = scipy.sparse.diags(1 / cell_volumes) @ adata.X
else:
    adata.X = adata.X / cell_volumes[:, None]

# We removed the cells that had total RNA counts lower than 2% quantile or higher than 98% quantile
# (asarray/ravel: X.sum(1) is an (n, 1) matrix when X is sparse)
norm_rna_counts = np.asarray(adata.X.sum(1)).ravel()
quantile2 = np.quantile(norm_rna_counts, 0.02)
quantile98 = np.quantile(norm_rna_counts, 0.98)
adata = adata[np.logical_and(norm_rna_counts>=quantile2, norm_rna_counts<=quantile98)].copy()
# then by sum: scale each cell to a total of 250
sc.pp.normalize_total(adata, target_sum=250)


In [None]:
# Number of cells remaining after all QC filters.
print(adata.shape[0])

In [None]:
# NOTE: n_genes_by_counts / total_counts were computed by calculate_qc_metrics
# before the volume and sum normalization, so they describe the raw counts.
sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts'],
             jitter=0.4, multi_panel=True)


In [None]:
# Per-batch summary: batch id, cell count, mean genes and mean counts per cell.
for i in adata.obs.batch.unique():
    curr_adata = adata[adata.obs.batch==i]
    print(i, curr_adata.shape[0], np.mean(curr_adata.obs.n_genes_by_counts), np.mean(curr_adata.obs.total_counts))

In [None]:
sns.violinplot(x='age',y='total_counts', data=adata.obs)

In [None]:
sns.violinplot(x='age',y='n_genes_by_counts', data=adata.obs)

In [None]:
sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts')


In [None]:
# Log-transform the volume- and sum-normalized expression.
sc.pp.log1p(adata)


In [None]:
# Keep the log-normalized matrix in .raw before scaling (it is restored later with adata.raw.to_adata()).
adata.raw = adata
#sc.pp.regress_out(adata, ['total_counts', 'volume'])

# Z-score each gene, clipping values at 10.
sc.pp.scale(adata, max_value=10)

sc.tl.pca(adata, svd_solver='arpack')


In [None]:
# PCA coloured by counts, marker genes, age and batch.
sc.pl.pca(adata, color=['total_counts','Vtn','Csf1r','Adora2a','Slc17a7','Slc32a1','Mbp','Cx3cr1', 'age', 'batch'])

In [None]:
sc.pl.pca_variance_ratio(adata, log=True)


In [None]:
# Batch-balanced kNN graph (bbknn) instead of a plain sc.pp.neighbors graph.
#sc.pp.neighbors(adata, n_neighbors=10, n_pcs=30)
import bbknn
bbknn.bbknn(adata,batch_key='batch')

#sc.external.pp.bbknn(adata, batch_key='batch',n_pcs=30)

In [None]:
sc.tl.umap(adata)


In [None]:
sc.pl.umap(adata, color=['age','batch','total_counts', 'volume'])

In [None]:
# Optional checkpoint of the MERFISH-only embedding (currently disabled).
#adata.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_umap_allages.h5ad")

# Run integration

In [None]:
# start with log transformed counts
#adata = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/102821_merged_combined_merfish_with_doublet_umap_allages.h5ad")
# .raw holds the log-normalized, unscaled matrix stored before sc.pp.scale
adata = adata.raw.to_adata()
adata.raw = adata
# increment batch for MERFISH (integrate_10x_merfish gives the scRNA-seq cells batch 0)
# NOTE(review): not idempotent — re-running this cell shifts the batch ids again.
adata.obs.batch = adata.obs.batch+1

In [None]:
# Every MERFISH gene (the library-based selection is commented out).
celltype_markers = adata.var_names#[adata.var.library == "cell_type"]

In [None]:
# start with raw 10X data
# NOTE(review): .raw.to_adata() returns whatever that file stored in .raw;
# whether that is raw counts or log-normalized values is not visible here — confirm.
adata10x = sc.read_h5ad("/faststorage/brain_aging/rna_analysis/adata_finalclusts_annot.h5ad")
adata10x = adata10x.raw.to_adata()

In [None]:
# Keep only scRNA-seq cells with area == "PFC".
adata10x_subset = adata10x[adata10x.obs.area=="PFC"]

In [None]:
import anndata as ad

def integrate_10x_merfish(adata_10x, adata_merfish):
    """Concatenate scRNA-seq and MERFISH cells on the MERFISH gene panel and run Harmony.

    Both inputs are copied first, so the caller's objects are left unmodified.
    Each modality is scaled separately (``max_value=10``) before concatenation;
    PCA is then run on the combined matrix and Harmony corrects the PCs for
    the ``dtype`` column ("scrnaseq" vs "merfish").

    Parameters
    ----------
    adata_10x : AnnData
        scRNA-seq cells. Must contain every gene in ``adata_merfish.var_names``
        and the obs columns ``cell_type``, ``clust_label``, ``age``,
        ``total_counts`` and ``n_genes``.
    adata_merfish : AnnData
        MERFISH cells, expected to carry obs columns ``fov``, ``volume``,
        ``center_x``, ``center_y``, ``min_x``, ``max_x``, ``min_y``, ``max_y``,
        ``age``, ``batch``, ``total_counts`` and ``n_genes`` (columns missing
        from either side are dropped by the inner-join concatenation).

    Returns
    -------
    AnnData
        Combined cells with ``obsm['X_pca']`` and ``obsm['X_pca_harmony']``.
        ``uns['raw_scrnaseq_X']`` (all 10x genes) and ``uns['raw_merfish_X']``
        hold the per-modality matrices as they were before scaling.
    """
    # work on copies so scaling and the obs edits below do not leak into the caller's objects
    adata_10x = adata_10x.copy()
    adata_merfish = adata_merfish.copy()

    # spatial columns only MERFISH has. "volume" was previously misspelled "volumne",
    # so the inner-join concat below silently dropped the MERFISH cell volumes.
    spatial_fields = ["fov", "volume", "center_x", "center_y", "min_x", "max_x", "min_y", "max_y"]
    shared_fields = spatial_fields + ["cell_type", "clust_label", "age", "dtype", 'batch', 'total_counts', 'n_genes']
    # deal with 10X: placeholder spatial values, batch 0, modality tag
    for i in spatial_fields:
        adata_10x.obs[i] = [""]*adata_10x.obs.shape[0]
    adata_10x.obs['batch'] = 0
    adata_10x.obs["dtype"] = "scrnaseq"

    adata10x_reduced = adata_10x[:, adata_merfish.var_names].copy()
    adata10x_reduced.obs = adata10x_reduced.obs[shared_fields]

    # deal with merfish: no reference labels yet
    for i in ["cell_type", "clust_label"]:
        adata_merfish.obs[i] = ["Unlabeled"]*adata_merfish.obs.shape[0]
    adata_merfish.obs["dtype"] = "merfish"
    # keep the unscaled matrices (raw_merfish_X used to be copied after scaling)
    raw_scrnaseq_X = adata_10x.X.copy()
    raw_merfish_X = adata_merfish.X.copy()
    # scale each modality separately before combining
    sc.pp.scale(adata_merfish, max_value=10)
    sc.pp.scale(adata10x_reduced, max_value=10)
    # combine datasets
    print("Concatenating")
    adata_combined = ad.concat({
        '10x': ad.AnnData(
            adata10x_reduced.X,
            obs=adata10x_reduced.obs,
            var=adata10x_reduced.var
        ),
        'merfish': ad.AnnData(
            adata_merfish.X,
            obs=adata_merfish.obs,
            var=adata_merfish.var
        )
    },)
    adata_combined.uns['raw_scrnaseq_X'] = raw_scrnaseq_X
    adata_combined.uns['raw_merfish_X'] = raw_merfish_X
    # NOTE: .raw is the scaled, concatenated matrix, so plots with use_raw=True
    # show scaled values rather than log-normalized expression.
    adata_combined.raw = adata_combined

    # per-modality scaling already happened above; combined scaling stays disabled
    print("Scaling")
    #sc.pp.scale(adata_combined, max_value=10)
    print("PCA")
    sc.tl.pca(adata_combined, svd_solver='arpack')
    print("Harmony")
    sc.external.pp.harmony_integrate(adata_combined, 'dtype')
    return adata_combined

In [None]:
# Integrate the PFC scRNA-seq subset with the MERFISH cells.
adata_combined = integrate_10x_merfish(adata10x_subset, adata)

In [None]:
adata_combined

In [None]:
# copy over harmony PCA: downstream tools read obsm['X_pca'], so swap in the
# Harmony-corrected PCs and keep the uncorrected ones as X_pca_orig.
# NOTE(review): not idempotent — a second run overwrites X_pca_orig with the Harmony PCs.
temp = adata_combined.obsm['X_pca']
adata_combined.obsm['X_pca'] = adata_combined.obsm['X_pca_harmony']
adata_combined.obsm['X_pca_orig'] = temp

In [None]:
# first try without bbknn: plain kNN graph on the (now Harmony) X_pca
sc.pl.pca(adata_combined, color=['dtype', 'age'])
sc.pp.neighbors(adata_combined, n_pcs=30)


In [None]:
sc.tl.umap(adata_combined, n_components=2)

sc.pl.umap(adata_combined, color=['dtype','age', 'batch'])

In [None]:
# then try with bbknn (replaces the neighbour graph and UMAP computed above)
import bbknn
bbknn.bbknn(adata_combined, batch_key='batch') 
sc.tl.umap(adata_combined, n_components=2)


In [None]:
sc.pl.umap(adata_combined, color=['dtype','age', 'batch','Cd3e', 'Cd74','Cx3cr1','Foxj1','Aqp4','Vtn'],size=1, cmap=plt.cm.Reds)

In [None]:
sc.tl.leiden(adata_combined, resolution=1.0)

In [None]:
sc.pl.umap(adata_combined, color=['dtype','age', 'batch','leiden'],size=5, cmap=plt.cm.Reds)

In [None]:
sc.pl.pca(adata_combined[adata_combined.obs.dtype=="scrnaseq"], color=['dtype', 'age'])

In [None]:
sc.pl.pca(adata_combined[adata_combined.obs.dtype=="merfish"], color=['dtype', 'age'])

In [None]:
f, ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_combined, color=['dtype'],ax=ax)
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_integration_umap.png",dpi=300,bbox_inches='tight')

In [None]:
f, ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_combined[adata_combined.obs.age.isin(['4wk','90wk'])], color=['age'],ax=ax)
ax.axis('off')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_integration_age_twocolor.png",dpi=300,bbox_inches='tight')

In [None]:
# 0/1 indicator column per age, used to highlight one age at a time on the UMAP.
for i in ['4wk', '24wk', '90wk']:
    adata_combined.obs['is_'+i] = [1 if j==i else 0 for j in adata_combined.obs.age]

In [None]:
# UMAP cell density per age group (read back below with key 'umap_density_age').
sc.tl.embedding_density(adata_combined, basis='umap', groupby='age')


In [None]:
# One density figure per age, saved as fig1_density_<age>.png.
for i in ['4wk', '24wk', '90wk']:
    f = sc.pl.embedding_density(adata_combined, basis='umap', key='umap_density_age',bg_dotsize=10,fg_dotsize=1,color_map=plt.cm.Reds,group=i,show=False, return_fig=True)
    plt.axis('off')
    f.set_size_inches((5,5))
    f.savefig(f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_density_{i}.png",bbox_inches='tight',dpi=300)
#cbar = plt.colorbar()
#plt.colorbar()

In [None]:
sc.pl.umap(adata_combined, color=['is_4wk'],size=0.5,cmap=plt.cm.Greys, add_outline=True)


In [None]:
sc.pl.umap(adata_combined, color=['is_24wk'],size=0.5,cmap=plt.cm.Greys, add_outline=True)


In [None]:
sc.pl.umap(adata_combined, color=['is_90wk'],size=0.5,cmap=plt.cm.Reds, add_outline=True)


In [None]:
# Marker genes on the integrated UMAP: all cells, then each modality separately.
genes_to_plot = ["Cux2","Rorb", "Vip", "Slc32a1", "Ctss",'Aqp4','Vtn', 'Cd3e','Cd74','Foxj1']
sc.pl.umap(adata_combined, color=genes_to_plot,cmap=plt.cm.Reds,size=5,use_raw=True)


In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.dtype=='scrnaseq'], color=genes_to_plot,cmap=plt.cm.Reds,size=5)


In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.dtype=='merfish'], color=genes_to_plot,cmap=plt.cm.Reds,size=5)


# De novo clustering: cluster the integrated MERFISH and scRNA-seq data

In [None]:
# first leiden cluster integrated data (overwrites the resolution-1.0 'leiden' column)
sc.tl.leiden(adata_combined, resolution=0.5)

In [None]:
# Markers used to name the major Leiden clusters.
major_celltype_markers = ['Slc17a6', 'Slc17a7', 'Gad1', 'Drd1', 'Adora2a', 'Trem2','Ctss','Aqp4','Foxj1','Vtn','Flt1','Olig1','Plp1','Pdgfra','Igf2','Cd74','Cd3e']

In [None]:
sc.pl.umap(adata_combined, color=major_celltype_markers,size=5)

In [None]:
sc.pl.umap(adata_combined, color=['age','leiden','Foxj1'],size=1)

In [None]:
sc.pl.dotplot(adata_combined, var_names=major_celltype_markers, groupby='leiden',figsize=(5,5))

In [None]:
# Manual Leiden cluster -> major cell type assignment.
# Cluster ids are specific to this clustering run; re-check them after any upstream change.
major_cell_type_map = {
    '0' : 'Olig',
    '1' : 'ExN',
    '2' : 'Astro',
    '3' : 'MSN',
    '4' : 'ExN',
    '5' : 'Vascular',
    '6' : 'MSN',
    '7' : 'Micro',
    '8' : 'ExN',
    '9' : 'InN',
    '10' : 'InN',
    '11' : 'Vascular',
    '12' : 'OPC',
    '13' : 'ExN',
    '14' : 'Olig',
    '15' : 'Immune',
}

In [None]:
# Raises KeyError if Leiden produced a cluster id that is missing from the map.
adata_combined.obs["cell_type"] = [major_cell_type_map[i] for i in adata_combined.obs.leiden]

In [None]:
f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_combined, color=['cell_type'],size=5,ax=ax)

In [None]:
#adata_combined.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_integrated_merfish_10x_allages.h5ad")

In [None]:
# NOTE(review): the write above is commented out, so this reload replaces the
# in-memory adata_combined (including the cell_type labels just assigned) with
# whatever was previously saved to disk.
adata_combined = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_integrated_merfish_10x_allages.h5ad")

# Subcluster neurons

In [None]:
def hcluster_adata(A, threshold=0.1, plot=True):
    """Hierarchically cluster the mean expression profiles of the Leiden clusters.

    Parameters
    ----------
    A : AnnData-like with ``obs.leiden`` and expression matrix ``X``.
    threshold : float, default 0.1
        Cosine distance drawn as a horizontal reference line.
    plot : bool, default True
        Draw the dendrogram (and the threshold line) with matplotlib.

    Returns
    -------
    dict
        Output of ``scipy.cluster.hierarchy.dendrogram``. Leaves are labelled
        with the Leiden cluster names (``dn['ivl']``).

    Raises
    ------
    ValueError
        If ``A`` contains fewer than two Leiden clusters.
    """
    from scipy.spatial.distance import pdist
    import scipy.cluster.hierarchy as hc

    leiden = np.asarray(A.obs.leiden)
    clusters = sorted(A.obs.leiden.unique())
    if len(clusters) < 2:
        raise ValueError("hcluster_adata needs at least two Leiden clusters")
    expr_mat = np.zeros((len(clusters), A.X.shape[1]))
    for i, k in enumerate(clusters):
        # mean profile of the cluster's cells; asarray/ravel also handles sparse X
        expr_mat[i, :] = np.asarray(A.X[leiden == k].mean(0)).ravel()

    D = pdist(expr_mat, 'cosine')
    Z = hc.linkage(D, 'complete', optimal_ordering=True)
    # Label leaves with cluster names: without labels the leaf numbers index the
    # lexicographically sorted names ('10' sorts before '2'), not the cluster ids.
    dn = hc.dendrogram(Z, no_plot=not plot, labels=clusters)
    if plot:
        plt.axhline(threshold, color='k')
    return dn


In [None]:
# subset into inhibitory and excitatory
# The uns matrices stored by integrate_10x_merfish cover all cells, not this subset, so drop them.
# NOTE(review): these subsets are AnnData views; modifying them presumably triggers
# AnnData's implicit view-to-copy conversion — .copy() would make that explicit.
adata_subset_inhib = adata_combined[adata_combined.obs.cell_type=="InN"]
del adata_subset_inhib.uns['raw_scrnaseq_X']
del adata_subset_inhib.uns['raw_merfish_X']

#adata_subset_msn = adata_subset_neurons[adata_subset_neurons.obs.msn==1]

adata_subset_excite = adata_combined[adata_combined.obs.cell_type=="ExN"]
del adata_subset_excite.uns['raw_scrnaseq_X']
del adata_subset_excite.uns['raw_merfish_X']


In [None]:
def recluster_subset(A,resolution=0.6,npcs=25, batch_key='batch'):
    """Rebuild the batch-balanced kNN graph, UMAP and Leiden clusters of a subset.

    PCA / Harmony are not recomputed: bbknn uses its default representation
    (presumably ``obsm['X_pca']``, i.e. the Harmony PCs swapped in earlier — confirm).
    ``A`` is modified in place and also returned.

    Parameters
    ----------
    A : AnnData
        Typically a cell-type subset of ``adata_combined``.
    resolution : float, default 0.6
        Leiden resolution.
    npcs : int, default 25
        Number of PCs bbknn uses.
    batch_key : str, default 'batch'
        obs column whose batches bbknn balances neighbours across.

    Returns
    -------
    AnnData
        ``A`` with an updated neighbour graph, UMAP and ``obs['leiden']``.
    """
    bbknn.bbknn(A, n_pcs=npcs, batch_key=batch_key)
    print("UMAP")
    sc.tl.umap(A)
    print("Leiden")
    sc.tl.leiden(A, resolution=resolution)
    return A

def crosstab_clusts(A):
    """Print the majority ``clust_annot`` label of each Leiden cluster as dict-literal lines.

    For every ``obs.leiden`` cluster, picks the ``obs.clust_annot`` label that
    covers the most of its cells and prints ``"<leiden>" : "<label>",``, so the
    output can be pasted into a mapping dict.

    Returns
    -------
    dict
        ``{leiden cluster: majority clust_annot label}`` (the printed mapping).
    """
    # local import: pandas is otherwise imported explicitly only later in the notebook
    import pandas as pd

    majority = pd.crosstab(index=A.obs.leiden, columns=A.obs.clust_annot, normalize=True).idxmax(axis=1)
    # Series.iteritems() was removed in pandas 2.0; items() works on all versions
    for clust, label in majority.items():
        print(f"\"{clust}\" : \"{label}\",")
    return majority.to_dict()
        
def count_ages(A, key="clust_annot",normalize=True):
    """Cross-tabulate cells by ``obs[key]`` (rows) against ``obs.age`` (columns).

    ``normalize`` is forwarded to ``pd.crosstab`` (True: fraction of all cells;
    False: raw counts; 'index' / 'columns': per-row / per-column fractions).
    """
    groups = A.obs[key]
    ages = A.obs.age
    return pd.crosstab(index=groups, columns=ages, normalize=normalize)

In [None]:
# Re-cluster the inhibitory neurons at a higher Leiden resolution.
adata_subset_inhib = recluster_subset(adata_subset_inhib,resolution=1.6, npcs=30)

In [None]:
sc.pl.umap(adata_subset_inhib, color=['age','leiden','Sst','Vip','Pvalb','Th','Sncg', 'Chat','Lamp5', 'dtype', 'total_counts'])

In [None]:
def plot_clusts(curr_adata, key='leiden'):
    """Draw one spatial scatter per group of ``obs[key]``, highlighting that group.

    All cells are drawn in light grey at ``(obs.center_x, obs.center_y)`` and the
    cells of the current group in red; one new figure per group, titled with the
    group name, in sorted group order.
    """
    # coordinates and labels are the same for every group: compute them once
    X = np.array(curr_adata.obs.center_x).astype(np.float64)
    Y = np.array(curr_adata.obs.center_y).astype(np.float64)
    labels = np.asarray(curr_adata.obs[key])
    for i in sorted(curr_adata.obs[key].unique()):
        in_group = labels == i
        plt.figure()
        plt.title(i)
        plt.scatter(X, Y, s=1, c='lightgray')
        plt.scatter(X[in_group], Y[in_group], s=1, c='r')


In [None]:
curr_adata = adata_subset_inhib[adata_subset_inhib.obs.batch==5]
plot_clusts(curr_adata)

In [None]:
sc.tl.rank_genes_groups(adata_subset_inhib, 'leiden', method='t-test')


In [None]:
sc.pl.rank_genes_groups(adata_subset_inhib, n_genes=3, sharey=False)


In [None]:
sc.pl.dotplot(adata_subset_inhib, ["Lhx6","Adarb2","Cnr1","Egfr","Crhr2","Prox1","Serpinf1","Clic4","Th","Crhr2", "Vip",'Sst','Pvalb','Chat','Calb2','Calb1','Lamp5','Sncg'], groupby='leiden')

In [None]:
hcluster_adata(adata_subset_inhib)

In [None]:
crosstab_clusts(adata_subset_inhib)

In [None]:
inhib_mapping = {
    "0" : "InN-Pvalb-1",
    "1" : "InN-LatSept-1",
    "2" : "InN-Sst-1",
    "3" : "InN-Sst-2",
    "4" : "InN-Pvalb-3",
    "5" : "InN-Vip",
    "6" : "InN-Lamp5",
    "7" : "InN-LatSept-2",
    "8" : "InN-Pvalb-2",
    "9" : "InN-Calb2-1",
    "10" : "InN-Calb2-2",
    "11" : "InN-Lhx6",
    "12" : "InN-Chat"
}

In [None]:
adata_subset_inhib.obs["clust_annot"] = [inhib_mapping[i] for i in adata_subset_inhib.obs.leiden]

In [None]:
sc.pl.dotplot(adata_subset_inhib, 
              ["Adarb2","Cnr1","Egfr","Crhr2","Prox1","Serpinf1","Clic4","Th","Crhr2", "Vip",'Sst','Pvalb','Chat','Calb2','Calb1','Lamp5','Sncg'], groupby='clust_annot')

In [None]:
import bbknn
adata_subset_excite = recluster_subset(adata_subset_excite, resolution=1.6, npcs=30)

In [None]:
sc.pl.umap(adata_subset_excite, color=['clust_annot','leiden','age','dtype'],size=5)

In [None]:
curr_adata = adata_subset_excite[adata_subset_excite.obs.batch==3]
for i in sorted(curr_adata.obs.leiden.unique()):
    X = np.array(curr_adata.obs.center_x).astype(np.float64)
    Y = np.array(curr_adata.obs.center_y).astype(np.float64)

    pos = np.array([X,Y]).T
    plt.figure()
    plt.title(i)
    plt.scatter(pos[:,0], pos[:,1],s=1,c='lightgray')
    curr_cells = np.argwhere(np.array(curr_adata.obs.leiden==i)).flatten()
    plt.scatter(pos[curr_cells,:][:,0], pos[curr_cells,:][:,1],s=1,c='r')


In [None]:
sc.tl.rank_genes_groups(adata_subset_excite, 'leiden', method='t-test')

sc.pl.rank_genes_groups(adata_subset_excite, n_genes=3, sharey=False)


In [None]:
sc.pl.dotplot(adata_subset_excite, ["Cux2","Lamp5",'Calb1','Gbp10',"Syt6","Foxp2", "Ptpru","Fezf2",'Rorb','Tshz2',   "Npr3",  "Nxph4",'Nr4a2', 'Scube1','Deptor','Htr4'], groupby='leiden')

In [None]:
hcluster_adata(adata_subset_excite)

In [None]:
crosstab_clusts(adata_subset_excite)

In [None]:
# Manual Leiden -> excitatory subtype labels (several clusters share a label;
# ids are specific to this clustering run).
excite_map =  {
    "0" : "ExN-L6-1",
    "1" : "ExN-L6-2",
    "2" : "ExN-L2/3-1",
    "3" : "ExN-L5-2",
    "4" : "ExN-L5-1",
    "5" : "ExN-L5-2",
    "6" : "ExN-L2/3-2",
    "7" : "ExN-LatSept-1",
    "8" : "ExN-L6-2",
    "9" : "ExN-L2/3-1",
    "10" : "ExN-L5-3",
    "11" : "ExN-L6-3",
    "12" : "ExN-L6-2",
    "13" : "ExN-LatSept-2",
}


In [None]:
adata_subset_excite.obs["clust_annot"] = [excite_map[i] for i in adata_subset_excite.obs.leiden]

In [None]:
curr_adata = adata_subset_excite[adata_subset_excite.obs.batch==3]
for i in sorted(curr_adata.obs.clust_annot.unique()):
    X = np.array(curr_adata.obs.center_x).astype(np.float64)
    Y = np.array(curr_adata.obs.center_y).astype(np.float64)

    pos = np.array([X,Y]).T
    plt.figure()
    plt.title(i)
    plt.scatter(pos[:,0], pos[:,1],s=1,c='lightgray')
    curr_cells = np.argwhere(np.array(curr_adata.obs.clust_annot==i)).flatten()
    plt.scatter(pos[curr_cells,:][:,0], pos[curr_cells,:][:,1],s=1,c='r')



In [None]:
sc.pl.dotplot(adata_subset_excite, ["Aqp4","Slc17a6", "Slc17a7","Cldn5", "Nr4a2",'Cux2','Rorb', "Npr3","Fezf2", 'Syt6', 'Nxph4', "Tshz2", 'Otof'], groupby='clust_annot')

In [None]:
sc.pl.umap(adata_subset_excite, color=["clust_annot"])

In [None]:
# MSNs: subset, drop the all-cell uns matrices, re-cluster with the defaults (resolution=0.6, npcs=25)
adata_subset_msn = adata_combined[adata_combined.obs.cell_type=="MSN"]
del adata_subset_msn.uns['raw_scrnaseq_X']
del adata_subset_msn.uns['raw_merfish_X']

adata_subset_msn = recluster_subset(adata_subset_msn)

In [None]:
#sc.tl.leiden(adata_subset_msn, resolution=0.6)

In [None]:
sc.pl.umap(adata_subset_msn, color=['leiden', 'age','dtype', 'Drd1', 'Adora2a'])

In [None]:
# Spatial map of each MSN Leiden cluster in MERFISH batch 11
# (plot_clusts draws the same per-cluster figures the copy-pasted loop did).
curr_adata = adata_subset_msn[adata_subset_msn.obs.batch==11]
plot_clusts(curr_adata)


In [None]:
# Marker genes for the MSN subclusters over all MSN cells (both modalities), as done
# for the inhibitory/excitatory subsets. This previously ran on `curr_adata`, i.e.
# only the batch-11 section selected for the spatial plots.
sc.tl.rank_genes_groups(adata_subset_msn, 'leiden', method='t-test')

sc.pl.rank_genes_groups(adata_subset_msn, n_genes=3, sharey=False)


In [None]:
sc.pl.dotplot(adata_subset_msn,["Adora2a", "Drd1",'Gad1', 'Cxcl9', 'Otof'],groupby='leiden')

In [None]:
hcluster_adata(adata_subset_msn)

In [None]:
crosstab_clusts(adata_subset_msn)

In [None]:
# Manual Leiden -> MSN subtype labels (ids specific to this clustering run).
msn_mapping = {
"0" : "MSN-D1-1",
"1" : "MSN-D2",
"2" : "MSN-D1-1",
"3" : "MSN-D2",
"4" : "MSN-D1-2",
}

In [None]:
adata_subset_msn.obs["clust_annot"] = [msn_mapping[i] for i in adata_subset_msn.obs.leiden]

In [None]:
sc.pl.dotplot(adata_subset_msn,["Adora2a", "Drd1",'Gad1', 'Cxcl9', 'Otof'],groupby='clust_annot')

In [None]:
# Save the three annotated neuronal subsets.
adata_subset_msn.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_subset_msn.h5ad")
adata_subset_inhib.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_subset_inhib.h5ad")
adata_subset_excite.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_subset_excite.h5ad")

## Subcluster non-neuronal cells

In [None]:
# don't accidentally run this! Re-running it discards every stored non-neuronal subset.
non_neuronal_adatas = {}


In [None]:
# First-pass re-clustering of each non-neuronal class (Leiden resolution 1.6).
# NOTE(review): the uns deletions are commented out, so every subset keeps the
# full-size raw matrices from adata_combined.uns.
import bbknn
non_neuronal_celltypes = ['Olig', 'Astro', 'Micro', 'OPC', 'Vascular']
for i in non_neuronal_celltypes:
    print(i)
    curr_adata = adata_combined[adata_combined.obs.cell_type==i]
    #del curr_adata.uns['raw_merfish_X']
    #del curr_adata.uns['raw_scrnaseq_X']
    non_neuronal_adatas[i] = recluster_subset(curr_adata, resolution=1.6)

In [None]:
non_neuronal_celltypes

## Subcluster astrocytes

In [None]:
import bbknn
curr_adata = adata_combined[adata_combined.obs.cell_type=="Astro"]#,npcs=20,resolution=1.6)
#sc.tl.leiden(curr_adata, resolution=1.2)


In [None]:
curr_adata = recluster_subset(curr_adata, npcs=30, resolution=0.6)

In [None]:
# NOTE(review): this scores the same gene list as the microglia cell further down,
# under the name 'activate_micro', but the astrocyte cells below plot
# 'activate_astro', which nothing in this notebook computes — on a fresh kernel
# the next cell fails. Confirm the intended astrocyte gene list / score_name.
sc.tl.score_genes(curr_adata, gene_list=['B2m','Trem2', 'Ccl2', 'Apoe',  'Axl', 'Itgax', 'Cd9','C1qa','C1qc','Lyz2','Ctss'], score_name='activate_micro', use_raw=False)


In [None]:
sc.pl.umap(curr_adata,color=['activate_astro', 'age','leiden','dtype','Foxj1','C3','C4b','Gfap'],size=10)

In [None]:
sc.tl.rank_genes_groups(curr_adata, 'leiden', method='t-test')

sc.pl.rank_genes_groups(curr_adata, n_genes=3, sharey=False)


In [None]:
temp = curr_adata[curr_adata.obs.batch==8]
plot_clusts(temp)

In [None]:
sc.pl.dotplot(curr_adata,['activate_astro','Gfap', 'Aqp4', 'B2m','Cldn5', 'Gfap', 'Tnc','Foxj1'],groupby='leiden')

In [None]:
import pandas as pd
def assign_clusts(A, clust_labels, new_key='clust_annot', key='leiden'):
    """Map each cell's cluster id to a label and store it in ``A.obs[new_key]``.

    Parameters
    ----------
    A : AnnData-like with a pandas ``obs`` DataFrame.
    clust_labels : dict
        ``{cluster id: label}``; must cover every value in ``A.obs[key]``.
    new_key : str, default 'clust_annot'
        obs column to write.
    key : str, default 'leiden'
        obs column holding the cluster ids.

    Raises
    ------
    KeyError
        If some cluster ids have no entry in ``clust_labels`` (all of them are
        listed); ``A`` is left unchanged in that case.
    """
    clusters = list(A.obs[key])
    missing = sorted({c for c in clusters if c not in clust_labels}, key=str)
    if missing:
        raise KeyError(f"no label for cluster(s) {missing} in column {key!r}")
    A.obs[new_key] = [clust_labels[c] for c in clusters]

In [None]:
crosstab_clusts(curr_adata)

In [None]:
# Manual Leiden -> astrocyte subtype labels (ids specific to this clustering run).
astro_clusts = {
"0" : "Astro-1",
"1" : "Astro-2",
"2" : "Astro-2",
"3" : "Astro-2",
"4" : "Astro-3",
"5" : "Astro-3",
}
assign_clusts(curr_adata, astro_clusts)


In [None]:
xtab = np.array(pd.crosstab(curr_adata.obs.age, curr_adata.obs.clust_annot).values.astype(np.float))

In [None]:
for i in range(3):
    xtab[i,:] = xtab[i,:] / xtab[i,:].sum()

In [None]:
pd.DataFrame(xtab.T, index=['Astro-1','Astro-2','Astro-3'], columns=['4wk','24wk','90wk']).plot(kind='bar')

In [None]:
sc.pl.violin(curr_adata, 'activate_astro', groupby='clust_annot')

In [None]:
# NOTE(review): astro_clusts assigns no "Epen" label, so on a fresh run this
# selection is empty and recluster_subset below is expected to fail; it only
# works if clust_annot still holds "Epen" from an earlier mapping. Confirm which
# Leiden cluster(s) should be labelled "Epen" first.
temp = curr_adata[curr_adata.obs.clust_annot=="Epen"]

In [None]:
temp = recluster_subset(temp)

In [None]:
sc.pl.umap(temp, color=['Foxj1','leiden'])

In [None]:
# Re-apply the astrocyte labels, then split the selected cells: sub-cluster "1"
# becomes "Epen", the remaining ones "Astro-2".
assign_clusts(curr_adata, astro_clusts)
curr_adata.obs.loc[temp[temp.obs.leiden=="1"].obs.index, "clust_annot"] = "Epen"
curr_adata.obs.loc[temp[temp.obs.leiden!="1"].obs.index, "clust_annot"] = "Astro-2"
non_neuronal_adatas['Astrocyte'] = curr_adata


In [None]:
counts = count_ages(curr_adata,normalize=False)
val_counts = counts.values.astype(np.float)
for i in range(val_counts.shape[1]):
    val_counts[:,i] = val_counts[:,i] / float(val_counts[:,i].sum())

In [None]:
pd.DataFrame(val_counts, index=counts.index, columns=counts.columns).transpose().plot(kind='bar')

In [None]:
count_ages(curr_adata).plot(kind='bar')
sc.pl.violin(curr_adata, 'activate_astro',groupby='clust_annot')

In [None]:
temp = curr_adata[curr_adata.obs.batch==3]
print(temp.obs.age[0])
plot_clusts(temp, 'clust_annot')

## Subcluster vascular cells

In [None]:
# Vascular cells: start from a copy of the first-pass subset.
curr_adata = non_neuronal_adatas['Vascular'].copy()
#sc.tl.leiden(curr_adata, resolution=0.6)


In [None]:
# T- and B-cell marker lists, scored per cell into obs['tcell'] / obs['bcell'].
tcell_markers = ["Tcrd",
"Tcrb",
"Ptprc",
"Rorc",
"Gata3",
"Foxp3",
"Tbx21",
"Il2ra",
"Il7r",
"Il2rb",
"Il2rg",
"Il15ra",
"Pdcd1",
"Ctla4",
"Cd3e"]
bcell_markers = [
    "Ms4a1",
    "Cd19",
    "Prdm1"
]

sc.tl.score_genes(curr_adata, gene_list=tcell_markers,score_name='tcell')
sc.tl.score_genes(curr_adata, gene_list=bcell_markers,score_name='bcell')

In [None]:
# Spatial map of each vascular Leiden cluster in MERFISH batch 7.
temp = curr_adata[curr_adata.obs.batch==7]
plot_clusts(temp)

In [None]:
sc.tl.rank_genes_groups(curr_adata, 'leiden', method='t-test')

sc.pl.rank_genes_groups(curr_adata, n_genes=3, sharey=False)


In [None]:
sc.pl.dotplot(curr_adata,["Cldn5", "Vtn", "Cspg4", "Il33", "Pdgfra"],groupby='leiden')

In [None]:
crosstab_clusts(curr_adata)

In [None]:
# Manual Leiden -> vascular subtype labels (ids specific to this clustering run).
vascular_clusts = {
"0" : "Endo-2",
"1" : "Endo-1",
"2" : "Endo-1",
"3" : "Peri-1",
"4" : "Endo-2",
"5" : "Vlmc",
"6" : "Endo-3",
"7" : "Endo-1",
"8" : "Endo-1",
"9" : "Peri-1",
"10" : "Peri-2",
"11" : "Peri-1",
"12" : "Endo-3",
"13" : "Endo-1",
"14" : "Peri-1",
"15" : "Endo-1",
"16" : "Endo-1",
    
}
assign_clusts(curr_adata, vascular_clusts)
non_neuronal_adatas['Vascular'] = curr_adata


## Subcluster oligodendrocytes

In [None]:
# Oligodendrocytes: annotate the first-pass subset directly (no copy).
curr_adata = non_neuronal_adatas['Olig']
#sc.tl.leiden(curr_adata, resolution=0.6)
sc.pl.umap(curr_adata,color=['age','leiden', 'dtype','Il33'])

In [None]:
# Spatial map of each oligodendrocyte Leiden cluster in MERFISH batch 7.
temp = curr_adata[curr_adata.obs.batch==7]
plot_clusts(temp)

In [None]:
sc.tl.rank_genes_groups(curr_adata, 'leiden', method='t-test')

sc.pl.rank_genes_groups(curr_adata, n_genes=3, sharey=False)


In [None]:
sc.pl.dotplot(curr_adata,['Flt1', 'Foxj1', 'Atp10b','Cldn5', 'Gfap', 'Vtn', 'Olig1', 'Olig2','Plp1'],groupby='leiden')

In [None]:
crosstab_clusts(curr_adata)

In [None]:
# Manual Leiden -> oligodendrocyte subtype labels (ids specific to this clustering run).
oligo_clusts = {
"0" : "Olig-1",
"1" : "Olig-2",
"2" : "Olig-2",
"3" : "Olig-1",
"4" : "Olig-1",
"5" : "Olig-1",
"6" : "Olig-1",
"7" : "Olig-3",
"8" : "Olig-1",
"9" : "Olig-1",
"10" : "Olig-1",
"11" : "Olig-1",
"12" : "Olig-3",
"13" : "Olig-2",
"14" : "Olig-1",
"15" : "Olig-3",
"16" : "Olig-2",
"17" : "Olig-2",
"18" : "Olig-1",
"19" : "Olig-1",
}
assign_clusts(curr_adata, oligo_clusts)


In [None]:
temp = curr_adata[curr_adata.obs.batch==3]
print(temp.obs.age.unique())
plot_clusts(temp,key='clust_annot')

In [None]:
sc.pl.umap(curr_adata, color=['clust_annot'])

In [None]:
non_neuronal_adatas['Oligodendrocyte'] = curr_adata


## Subcluster microglia

In [None]:
# Microglia and immune cells are re-clustered together first.
curr_adata = adata_combined[adata_combined.obs.cell_type.isin(['Micro','Immune'])]#non_neuronal_adatas['Micro'].copy()
#sc.tl.leiden(curr_adata, resolution=1.6)
#curr_adata = curr_adata[curr_adata.obs.dtype=="merfish"]
#curr_adata
#sc.tl.pca(curr_adata)
curr_adata = recluster_subset(curr_adata)
sc.pl.umap(curr_adata,color=['age','leiden', 'dtype', 'Cd3e', 'Cd74','Trem2','Cx3cr1','Igf2','F13a1'],size=5)

In [None]:
#hcluster_adata(curr_adata)
hcluster_adata(curr_adata)
# hcluster_adata already draws this 0.1 reference line; repeating it is harmless.
plt.axhline(0.1,color='k')

In [None]:
sc.tl.score_genes(curr_adata, gene_list=tcell_markers,score_name='tcell')
sc.tl.score_genes(curr_adata, gene_list=bcell_markers,score_name='bcell')

In [None]:
# Leiden cluster '4' is split off as curr_adata_immune; the rest is treated as microglia.
curr_adata_immune = curr_adata[curr_adata.obs.leiden.isin(['4'])]
curr_adata_micro = curr_adata[~curr_adata.obs.leiden.isin(['4'])]

In [None]:
# The return value is ignored: this relies on recluster_subset updating
# curr_adata_immune in place (it returns the object it was given).
recluster_subset(curr_adata_immune, npcs=30)
sc.pl.umap(curr_adata_immune,color=['leiden'])


In [None]:
sc.tl.rank_genes_groups(curr_adata_immune, 'leiden', method='t-test')

sc.pl.rank_genes_groups(curr_adata_immune, n_genes=3, sharey=False)


In [None]:
sc.pl.umap(curr_adata_immune,color=['age','dtype','leiden','Cd3e','Cd74','Ctss','tcell'])

In [None]:
# Manual Leiden -> immune labels (ids specific to this clustering run).
immune_clusts = {
    '0' : "Macro",
    '1' : "Macro",
    '2' : "T cell",
    '3' : "Macro",

}
assign_clusts(curr_adata_immune, immune_clusts)
non_neuronal_adatas['Immune'] = curr_adata_immune


In [None]:
# NOTE(review): the next cell overwrites curr_adata_micro, so this immune-exclusion
# filter has no effect; the subset used below is every cell_type=="Micro" cell,
# which may overlap the cells labelled as immune above. Confirm which is intended.
curr_adata_micro = non_neuronal_adatas['Micro'].copy()
curr_adata_micro = curr_adata_micro[~curr_adata_micro.obs.index.isin(curr_adata_immune.obs.index)]

In [None]:
curr_adata_micro = adata_combined[adata_combined.obs.cell_type=="Micro"]

In [None]:
# Return value ignored; relies on recluster_subset updating curr_adata_micro in place.
recluster_subset(curr_adata_micro, npcs=30, resolution=1.2)


In [None]:
sc.tl.rank_genes_groups(curr_adata_micro, 'leiden', method='t-test')

sc.pl.rank_genes_groups(curr_adata_micro, n_genes=3, sharey=False)


In [None]:
sc.pl.umap(curr_adata_micro, color=['leiden','age','Abi3', 'Selplg','Sptan1','Apoe','C4b','Trem2'])

In [None]:
sc.pl.dotplot(curr_adata_micro, var_names=['Abi3', 'Selplg','Sptan1','Apoe'],groupby='leiden')

In [None]:
# Activation score over the listed genes, computed on .X (use_raw=False).
sc.tl.score_genes(curr_adata_micro, gene_list=['B2m','Trem2', 'Ccl2', 'Apoe',  'Axl', 'Itgax', 'Cd9','C1qa','C1qc','Lyz2','Ctss'], score_name='activate_micro', use_raw=False)


In [None]:
crosstab_clusts(curr_adata_micro)

In [None]:
micro_clusts = {
"0" : "Micro-3",
"1" : "Micro-1",
"2" : "Micro-2",
"3" : "Micro-3",
"4" : "Micro-2",
"5" : "Micro-1",
"6" : "Micro-2",
"7" : "Micro-1",
"8" : "Micro-1",
"9" : "Micro-1",
"10" : "Micro-1",
}
assign_clusts(curr_adata_micro, micro_clusts)
#non_neuronal_adatas['Microglia'] = curr_adata_micro


In [None]:
sc.pl.violin(curr_adata_micro, 'activate_micro', groupby='clust_annot')

In [None]:
adata_combined.obs.loc[curr_adata_micro.obs.index,'clust_annot'] = curr_adata_micro.obs.clust_annot

## Subcluster OPC

In [None]:
# OPCs: reuse the first-pass subset and re-run Leiden at a lower resolution.
curr_adata = non_neuronal_adatas['OPC']
sc.tl.leiden(curr_adata,resolution=0.4)
sc.pl.umap(curr_adata,color=['age','leiden', 'dtype','Pdgfra','Grin2b'])

In [None]:
sc.tl.rank_genes_groups(curr_adata, 'leiden', method='t-test')

sc.pl.rank_genes_groups(curr_adata, n_genes=3, sharey=False)


In [None]:
plot_clusts(curr_adata[curr_adata.obs.batch==11])

In [None]:
hcluster_adata(curr_adata)

In [None]:
# Both Leiden clusters get the single label "OPC".
# NOTE(review): assumes resolution=0.4 yields exactly clusters '0' and '1';
# any other id makes assign_clusts raise KeyError.
opc_clusts = {
    '0' : "OPC",
    '1' : "OPC",
}
assign_clusts(curr_adata, opc_clusts)
non_neuronal_adatas['OPC'] = curr_adata


In [None]:
non_neuronal_adatas

In [None]:
# NOTE(review): repeats the rank_genes_groups cell above unchanged.
sc.tl.rank_genes_groups(curr_adata, 'leiden', method='t-test')

sc.pl.rank_genes_groups(curr_adata, n_genes=3, sharey=False)


In [None]:
# Export each annotated non-neuronal subset to its own h5ad file.
for i in ['Astrocyte', 'Vascular', 'Oligodendrocyte', 'Immune', 'Microglia', 'OPC']:
    print(i)
    non_neuronal_adatas[i].write_h5ad(f"/faststorage/brain_aging/merfish/exported/011722_adata_subset_{i.lower()}.h5ad")

In [None]:
# Number of microglia whose cell_type is the empty string.
np.sum(non_neuronal_adatas["Microglia"].obs.cell_type=="")

# Transfer manual annotations to main dataset

In [None]:
# Reload the saved integrated dataset (in-memory edits to adata_combined are discarded).
adata_combined = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_integrated_merfish_10x_allages.h5ad")

In [None]:
# NOTE(review): superseded below — combined_obs resets clust_annot to "" and then
# replaces adata_combined.obs.
adata_combined.obs['clust_annot'] = 'Unlabeled'

In [None]:
# Load every exported subset and tag its cells with the subset name.
subsets = {}
subset_names = ["OPC","Microglia","Immune", "Oligodendrocyte","Vascular","Astrocyte","Excite","Inhib","MSN"]
for i in subset_names:
    print(i)
    curr_subset = ad.read_h5ad(f"/faststorage/brain_aging/merfish/exported/011722_adata_subset_{i.lower()}.h5ad")
    curr_subset.obs['cell_type'] = i
    subsets[i] = curr_subset#adata_subset.obs["clust_annot"] = ["NA"] * adata_subset.obs.shape[0]

In [None]:
# Work on a copy of obs; cells not found in any subset keep "" labels.
combined_obs = adata_combined.obs.copy()
combined_obs["cell_type_annot"] = ""
combined_obs["clust_annot"] = ""


In [None]:
adata_combined[adata_combined.obs.dtype=="merfish"]

In [None]:
import pandas as pd
# Copy each subset's name and subcluster labels onto the matching cells.
for i in subsets.keys():
    print(i)
    combined_obs.loc[subsets[i].obs.index,"cell_type_annot"] = i
    combined_obs.loc[subsets[i].obs.index,"clust_annot"] = subsets[i].obs.clust_annot


In [None]:
adata_combined.obs = combined_obs

In [None]:
# fix mistakes: old label -> corrected label. Every entry is currently commented
# out, so the loop below leaves clust_annot unchanged.
annot_fix_map = {
    #'OPC-1' : 'OPC',
    #"OPC-2" : "OPC",
    #"MSN-D1-3" : "MSN-D1-2"
    #'ExN-LateSept' : "ExN-LatSept",
    #"ExN-L2/3-3" : "ExN-L5-1",
    #'ExN-L4-1' : 'ExN-L5-2',
    # 'ExN-L4-2' : 'ExN-L5-3',
    # 'ExN-L5-1' : 'ExN-L6-1',
    # 'ExN-L5-2' : "ExN-L6-2",
    # 'ExN-L6' : "ExN-L6-3"

}
annots = list(adata_combined.obs.clust_annot)
for i in range(len(annots)):
    if annots[i] in annot_fix_map:
        annots[i] = annot_fix_map[annots[i]]
adata_combined.obs['clust_annot'] = annots

In [None]:
# Major cell type = label prefix before the first "-" (e.g. "ExN-L6-1" -> "ExN");
# cells without a label ("") get "".
adata_combined.obs['cell_type'] = [i.split("-")[0] for i in adata_combined.obs.clust_annot]

In [None]:
# fix missing labels
from sklearn.neighbors import KNeighborsClassifier
# first classify cell type
adata_unlabeled =  adata_combined[adata_combined.obs.cell_type=='']
adata_labeled = adata_combined[adata_combined.obs.cell_type!='']
# train on labeled data
clf = KNeighborsClassifier(n_jobs=-1).fit(adata_labeled.obsm['X_pca'], adata_labeled.obs.cell_type)

In [None]:
preds = clf.predict(adata_unlabeled.obsm['X_pca'])

In [None]:
adata_unlabeled.obs.cell_type = preds

In [None]:
# train classifier for each cell_type
for i in adata_unlabeled.obs.cell_type.unique():
    print(i)
    curr_adata = adata_labeled[adata_labeled.obs.cell_type==i]
    curr_unlabeled = adata_unlabeled[adata_unlabeled.obs.cell_type==i]
    curr_clf = KNeighborsClassifier(n_jobs=-1).fit(curr_adata.obsm['X_pca'], 
                                                   curr_adata.obs.cell_type)
    KNeighborsClassifier(n_jobs=-1).fit(curr_adata.obsm['X_pca'], curr_adata.obs.cell_type)
    preds = curr_clf.predict(curr_unlabeled.obsm['X_pca'])
    adata_unlabeled.obs.loc[curr_unlabeled.obs.index, "clust_annot"] = preds

In [None]:
# all cluster labels after the label transfer
sorted(adata_combined.obs.clust_annot.unique())

In [None]:
# colors per cell type / cluster (generate_palettes is not defined in this
# notebook; presumably from one of the star imports)
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata_combined)

In [None]:
# high-resolution re-clustering of astrocytes + ependymal cells to separate
# the ependymal population
temp = adata_combined[adata_combined.obs.cell_type.isin(['Astro','Epen'])]
sc.tl.leiden(temp, resolution=2)

In [None]:
sc.pl.umap(temp,color='leiden')

In [None]:
# NOTE(review): leiden cluster '10' was presumably chosen by inspecting the
# UMAP above; the id is only meaningful for this exact leiden run
temp = temp[temp.obs.leiden=='10']
adata_combined.obs.loc[temp.obs.index, 'cell_type'] = "Epen"
adata_combined.obs.loc[temp.obs.index, 'clust_annot'] = "Epen"


In [None]:
sc.pl.umap(adata_combined,color=['clust_annot','Foxj1'],palette=clust_pals)

In [None]:
sc.pl.umap(adata_combined,color=['cell_type_annot', 'clust_annot','age','dtype'],size=1)

In [None]:
# same panels without the 24wk cells
sc.pl.umap(adata_combined[adata_combined.obs.age!='24wk'],color=['cell_type_annot', 'clust_annot','age','dtype'],size=1)

In [None]:
# convert to squidpy representation
# MERFISH cells get their measured (center_x, center_y); scRNA-seq cells have
# no position and sit at the origin. Rows are filled through a boolean mask so
# they stay aligned with adata_combined.obs whatever the order of the two
# modalities (stacking the zeros first would assume every scRNA-seq cell
# precedes every MERFISH cell).
n_scrnaseq = int(np.sum(adata_combined.obs.dtype=="scrnaseq"))
is_merfish = np.asarray(adata_combined.obs.dtype=="merfish")
coords = np.zeros((adata_combined.n_obs, 2), dtype=np.float64)
coords[is_merfish] = adata_combined.obs.loc[is_merfish, ["center_x", "center_y"]].to_numpy(dtype=np.float64)
adata_combined.obsm['spatial'] = coords

In [None]:
# assign points to slices
from sklearn.cluster import KMeans
# number of slices for eachbatch
nslices = {
    0 : 1,
    1 : 1,
    2 : 2,
    3 : 2,
    4 : 3,
    5 : 3,
    6 : 3,
    7 : 3,
    8 : 3,
    9 : 3,
    10 : 4,
    11 : 3,
    12 : 2
} 
slice_labels = []
adata_combined.obs["slice"] = 0
for i in list(adata_combined.obs.batch.unique()):
    if i > 0:
        curr_adata = adata_combined[adata_combined.obs.batch==i]
        pos = curr_adata.obsm['spatial']
        lbl = KMeans(n_clusters=nslices[i]).fit_predict(pos)
        #slice_labels.extend(lbl)
        print(pos.shape, curr_adata.shape)
        adata_combined.obs.loc[curr_adata.obs.index, "slice"] = lbl
    
#    plt.figure()
#    plt.scatter(curr_adata.obs.center_x, curr_adata.obs.center_y, s=1, c=lbl)
#adata_annot.obs["slice"] = slice_labels

In [None]:
for i in list(adata_combined.obs.batch.unique()):
    curr_adata = adata_combined[adata_combined.obs.batch==i]
    pos = curr_adata.obsm['spatial']
    plt.figure()
    plt.scatter(pos[:,0], pos[:,1], s=1, c=curr_adata.obs.slice)


In [None]:
# adjust coordinates so that each brain section is far away from others 
# (a bit of a hack for neighborhood graph computation)
coords = []
index = []
n = 0
for i,b in enumerate(adata_combined.obs.batch.unique()):
    print('--')
    curr_adata = adata_combined[adata_combined.obs.batch==b]
    for j,s in enumerate(sorted(curr_adata.obs.slice.unique())):
        print(s)
        curr_slice = curr_adata[curr_adata.obs.slice==s]
        curr_coords = curr_slice.obsm['spatial']#np.vstack((curr_slice.obs.center_x, curr_slice.obs.center_y)).T
        #curr_coords = curr_slice.obsm['spatial']
        curr_coords += n*1e5
        plt.figure()
        plt.scatter(curr_coords[:,0], curr_coords[:,1], s=1)
        n += 1
        coords.append(curr_coords)
        index.extend(list(curr_slice.obs.index))
#adata_combined[index,:].obsm['spatial'] = np.vstack(coords)

In [None]:
adata_combined = adata_combined[index]
adata_combined.obsm['spatial'] = np.vstack(coords)

In [None]:
# all cells after the per-section offsets
x = adata_combined.obsm['spatial'][:,0]
y = adata_combined.obsm['spatial'][:,1]
plt.plot(x,y,'k.')

In [None]:
# one figure per batch to check the offset sections
for i in list(adata_combined.obs.batch.unique()):
    curr_adata = adata_combined[adata_combined.obs.batch==i]

    temp = curr_adata.obsm['spatial']
    plt.figure()
    plt.scatter(temp[:,0], temp[:,1],s=1)

In [None]:
# remove bad section (batch 10, slice 3)
# NOTE(review): the slice id comes from the k-means slice assignment above;
# confirm it still points at the intended section after re-running.
good_cells = np.argwhere(np.array(~np.logical_and(adata_combined.obs.batch==10, adata_combined.obs.slice==3))).flatten()
adata_combined_merfish = adata_combined[adata_combined.obs.dtype=="merfish"]
good_cells_merfish = np.argwhere(np.array(~np.logical_and(adata_combined_merfish.obs.batch==10, adata_combined_merfish.obs.slice==3))).flatten()
adata_combined = adata_combined[good_cells]
# NOTE(review): assumes the rows of uns['raw_merfish_X'] follow the current
# order of the MERFISH cells; adata_combined was reordered by (batch, slice)
# above, so this alignment should be confirmed (a later cell realigns raw
# MERFISH counts by cell name instead).
adata_combined.uns['raw_merfish_X'] = adata_combined.uns['raw_merfish_X'][good_cells_merfish]

In [None]:
sc.pl.umap(adata_combined, color='clust_annot')

In [None]:
# recompute coarse cell type from the cluster label (text before the first "-")
adata_combined.obs['cell_type'] = [i.split("-")[0] for i in adata_combined.obs.clust_annot]

In [None]:
# selected genes and age on the UMAP
sc.pl.umap(adata_combined, color=['Gfap','C4b', 'Il33', 'Il13','Il18','Tnf','C3','age'])

In [None]:
# (disabled) export of the cell-type UMAP panel
#f,ax = plt.subplots(figsize=(5,5))
#sc.pl.umap(adata_combined, color=['cell_type'], palette=celltype_pals, size=1,ax=ax)
#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_celltype.pdf", bbox_inches='tight', dpi=200)

## Segment layers

In [None]:
# MERFISH cells only (the scRNA-seq cells have no measured position)
adata_merfish = adata_combined[adata_combined.obs.dtype=="merfish"]

In [None]:
# per-cell neighborhood statistics over the cluster labels within radius 125
# (compute_neighborhood_stats comes from a star import; NOTE(review): radius
# assumed to be in the units of obsm['spatial'] — confirm)
nbor_stats = compute_neighborhood_stats(adata_merfish.obsm['spatial'], adata_merfish.obs.clust_annot,radius=125)

In [None]:
# replace NaN statistics by 0
nbor_stats[np.isnan(nbor_stats)] = 0

In [None]:
# PCA embedding of the neighborhood statistics (fixed seed)
from sklearn.decomposition import PCA
xform = PCA(random_state=50).fit_transform(nbor_stats)

In [None]:
labels_quant = LabelEncoder().fit_transform(adata_merfish.obs.clust_annot)


In [None]:
plt.scatter(xform[:,0],xform[:,1],s=1, c=labels_quant, cmap=mpl.colors.ListedColormap(np.vstack(label_colors.values())))

In [None]:
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=25, random_state=42).fit_predict(xform)

In [None]:
plt.scatter(xform[:,0],xform[:,1],s=1, c=kmeans, cmap=mpl.colors.ListedColormap(np.vstack(label_colors.values())))

In [None]:
adata_merfish.obs['kmeans'] = kmeans

In [None]:
# spatial map of the k-means neighborhood clusters for one section
curr_adata = adata_merfish[np.logical_and(adata_merfish.obs.batch==12, adata_merfish.obs.slice==1)]
# positional access; `Series[0]` on a string index is deprecated in pandas
print(curr_adata.obs.age.iloc[0])
pos = curr_adata.obsm['spatial']
plt.scatter(pos[:,0], pos[:,1],s=1, c=curr_adata.obs.kmeans, cmap=plt.cm.nipy_spectral)
# the scatter has no labelled artists, so a legend would be empty; a colorbar
# maps colors back to the k-means cluster ids
plt.colorbar(label='kmeans cluster')

In [None]:
def plot_clust(A,clust_name, ax,s=0.1,key='kmeans'):
    """Highlight one cluster on the spatial map of a set of cells.

    Every cell of ``A`` is drawn in gray (marker size 1); the cells whose
    ``A.obs[key]`` equals ``clust_name`` are drawn on top in red with marker
    size ``s``. The axis is switched off and titled with ``clust_name``.

    Parameters
    ----------
    A : AnnData-like
        Cells to draw; positions come from ``A.obsm['spatial']`` (x, y in
        columns 0 and 1), labels from ``A.obs[key]``.
    clust_name : label value to highlight.
    ax : matplotlib Axes to draw into.
    s : float, marker size of the highlighted cells.
    key : str, name of the ``A.obs`` column holding the labels.
    """
    # positions and mask both come from the AnnData passed in (not a global)
    pos = np.asarray(A.obsm['spatial'])
    in_clust = np.asarray(A.obs[key] == clust_name)
    ax.scatter(pos[:,0], pos[:,1],s=1, c='gray')
    ax.scatter(pos[in_clust,0], pos[in_clust,1],s=s, c='r')
    ax.axis('off')
    ax.set_title(clust_name)

In [None]:
# one panel per k-means neighborhood cluster for one section (5x5 grid for the
# 25 clusters); each panel shows that cluster's cells in red over gray
curr_adata = adata_merfish[np.logical_and(adata_merfish.obs.batch==8, adata_merfish.obs.slice==1)]

plt.figure(figsize=(20,20))
for i in range(curr_adata.obs.kmeans.max()+1):
    ax = plt.subplot(5,5,i+1)
    plot_clust(curr_adata,i,ax,key='kmeans')

In [None]:
# count cell types per kmeans clust
# MERFISH-only cluster labels: they define the columns of clust_counts and
# therefore the x tick labels of the heatmap below
merfish_clust_labels = sorted(adata_merfish.obs.clust_annot.unique())
clust_counts = np.vstack(adata_merfish.obs.groupby('kmeans').apply(lambda x: [np.sum(x.clust_annot==i) for i in merfish_clust_labels]).reset_index()[0])

# mean neighborhood statistics of each k-means cluster, z-scored per feature
# across clusters
clust_avgs = np.zeros((kmeans.max()+1, nbor_stats.shape[1]))
for i in sorted(np.unique(kmeans)):
    clust_avgs[i,:] = nbor_stats[kmeans==i,:].mean(0)
from scipy.stats import zscore
clust_avgs = zscore(clust_avgs, axis=0)
# a feature constant across clusters gives 0/0 = NaN, which would make the
# cosine distances (and the linkage) fail; treat it as "no deviation"
clust_avgs = np.nan_to_num(clust_avgs)

# hierarchically cluster
from scipy.spatial.distance import pdist
import scipy.cluster.hierarchy as hc

D = pdist(clust_avgs,'cosine')
# NOTE(review): SciPy documents 'ward' as correctly defined only for Euclidean
# distances; here it is applied to cosine distances
Z = hc.linkage(D,'ward',optimal_ordering=True)
dn = hc.dendrogram(Z)

f, ax = plt.subplots(figsize=(5,2))
ax.imshow(clust_avgs[ dn['leaves']],aspect='auto',vmin=-5,vmax=5, cmap=plt.cm.seismic)
ax.set_xticks(np.arange(clust_counts.shape[1]));
ax.set_yticks(np.arange(clust_counts.shape[0]));
ax.set_yticklabels(dn['leaves'],fontsize=6)
# labels from the same MERFISH label set as the ticks (adata_combined may also
# hold clusters absent from MERFISH, giving more labels than ticks)
ax.set_xticklabels(merfish_clust_labels,rotation=90,fontsize=6);

In [None]:
def crosstab_spatial_clusts(A):
    """Map every k-means cluster to its most frequent spatial annotation.

    Cross-tabulates ``A.obs.kmeans`` against ``A.obs.spatial_clust_annots``
    and picks, per k-means cluster, the annotation with the largest share.

    Returns
    -------
    dict
        ``{kmeans_cluster: spatial_clust_annot}``.
    """
    temp = pd.crosstab(index=A.obs.kmeans,columns=A.obs.spatial_clust_annots, normalize=True).idxmax(axis=1)
    # Series.iteritems() was removed in pandas 2.0; items() works everywhere
    return dict(temp.items())


In [None]:
# NOTE(review): needs adata_merfish.obs['spatial_clust_annots'], which is only
# assigned further down, so this raises on a fresh kernel; the hard-coded
# mapping in the next cell replaces its result anyway.
spatial_clust_annots = crosstab_spatial_clusts(adata_merfish)

In [None]:
# manual region label for each of the 25 k-means neighborhood clusters
spatial_clust_annots = {
    0: 'L5',
 1: 'Striatum',
 2: 'L6',
 3: 'L2/3',
 4: 'L6',
 5: 'L6',
 6: 'LatSept',
 7: 'Ventricle',
 8: 'LatSept',
 9: 'Pia',
 10: 'CC',
 11: 'L6',
 12: 'CC',
 13: 'L6',
 14: 'L2/3',
 15: 'L6',
 16: 'CC',
 17: 'Striatum',
 18: 'Striatum',
 19: 'L6',
 20: 'L5',
 21: 'L5',
 22: 'Pia',
 23: 'L5',
 24: 'CC'
}



In [None]:
# integer code per region (used for coloring / smoothing)
spatial_clust_annots_values = {
    'Pia' : 0,
    'L2/3' : 1, 
    'L5' : 2,
    'L6' : 3, 
    'LatSept' : 4,
    'CC' : 5,
    'Striatum' : 6,
    'Ventricle' : 7
    }

In [None]:
# per-cell region name and region code from the cell's k-means cluster
adata_merfish.obs['spatial_clust_annots'] = [spatial_clust_annots[i] for i in adata_merfish.obs.kmeans]
adata_merfish.obs['spatial_clust_annots_value'] = [spatial_clust_annots_values[i] for i in adata_merfish.obs.spatial_clust_annots]

In [None]:
# one panel per region code for one section
curr_adata = adata_merfish[np.logical_and(adata_merfish.obs.batch==3, adata_merfish.obs.slice==1)]

plt.figure(figsize=(20,20))
for i in range(curr_adata.obs.spatial_clust_annots_value.max()+1):
    ax = plt.subplot(4,5,i+1)
    plot_clust(curr_adata,i,ax,key='spatial_clust_annots_value')

In [None]:
curr_adata = adata_merfish[np.logical_and(adata_merfish.obs.batch==4, adata_merfish.obs.slice==0)]

pos = curr_adata.obsm['spatial']

# NOTE(review): region codes span 0-7 while the color range is 0-9
plt.scatter(pos[:,0], pos[:,1],s=1, c=curr_adata.obs.spatial_clust_annots_value,cmap=plt.cm.turbo,vmin=0,vmax=9)


In [None]:
curr_adata = adata_merfish[np.logical_and(adata_merfish.obs.batch==12, adata_merfish.obs.slice==0)]

plt.figure(figsize=(20,20))
# NOTE(review): region codes are 0-7, so the 9th panel (code 8) stays empty
for i in range(9):
    ax = plt.subplot(4,4,i+1)
    plot_clust(curr_adata,i,ax,key='spatial_clust_annots_value')

In [None]:
# defaults; only MERFISH cells are overwritten below
adata_combined.obs['spatial_clust_annots_value'] = -1
adata_combined.obs['spatial_clust_annots'] = "NA"


In [None]:
# plot_seg is not defined in this notebook (presumably from a star import)
plot_seg(curr_adata,plt.cm.rainbow)

In [None]:
# copy the MERFISH region annotations into the combined object (by cell name)
adata_combined.obs.loc[adata_merfish.obs.index, "spatial_clust_annots_value"] = adata_merfish.obs.spatial_clust_annots_value
adata_combined.obs.loc[adata_merfish.obs.index, "spatial_clust_annots"] = adata_merfish.obs.spatial_clust_annots

In [None]:
adata_combined[adata_combined.obs.dtype=="merfish"].shape

# Save out data

In [None]:
#adata_combined.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_combined_harmony.h5ad")

In [None]:
# reload the saved combined object
adata_combined = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_combined_harmony.h5ad")
#print(adata_combined.shape[0])

In [None]:
# decode byte-string names / categorical labels to str
adata_combined = unbinarize_strings(adata_combined)

In [None]:
np.sum(adata_combined.obs.cell_type == "T cell")

In [None]:
from utils import *


In [None]:
# NOTE(review): duplicate of the previous cell
from utils import *

In [None]:
# run this
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata_combined)

In [None]:
# UMAP panels colored by age, modality and cluster, saved as PNGs.
# age_cmap / dtype_cmap are also used by the cluster heatmap figure later.
age_colors = ['cornflowerblue','thistle','lightcoral']
age_cmap = mpl.colors.ListedColormap(age_colors)
f,ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_combined, color=['age'], size=1, legend_loc='bottom',ax=ax, palette=sns.color_palette(age_colors))
#ax.set_rastera\ized(True)
ax.axis('off')
ax.set_title('')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_age.png",bbox_inches='tight', dpi=200)

f,ax = plt.subplots(figsize=(5,5))

dtype_colors = ['mediumslateblue', 'goldenrod']
dtype_cmap = mpl.colors.ListedColormap(dtype_colors)
sc.pl.umap(adata_combined, color='dtype', palette=sns.color_palette(dtype_colors),ax=ax, size=1)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_dtype.png", bbox_inches='tight', dpi=200)

f,ax = plt.subplots(figsize=(5,5))

# NOTE(review): dtype_colors_temp / dtype_cmap_temp are never used (and the
# cmap is built from dtype_colors, not dtype_colors_temp)
dtype_colors_temp = ['mediumslateblue']
dtype_cmap_temp = mpl.colors.ListedColormap(dtype_colors)
sc.pl.umap(adata_combined[adata_combined.obs.dtype=='merfish'], color='clust_annot', palette=clust_pals,ax=ax, size=5,legend_loc='bottom')
ax.axis('off')
ax.set_title('')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_dtype_merfish.png", bbox_inches='tight', dpi=200)


f,ax = plt.subplots(figsize=(5,5))

# overwrites dtype_colors; dtype_cmap was already built above and is unaffected
dtype_colors = [ 'goldenrod']
# NOTE(review): clust_pals10x is computed but the plot below uses clust_pals —
# confirm which palette is intended
_, _, _, clust_pals10x = generate_palettes(adata_combined[adata_combined.obs.dtype=='scrnaseq'])
sc.pl.umap(adata_combined[adata_combined.obs.dtype=='scrnaseq'], color='clust_annot', palette=clust_pals,ax=ax, size=5,legend_loc='bottom')
ax.axis('off')
ax.set_title('')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_dtype_scrnaseq.png", bbox_inches='tight', dpi=200)

f,ax = plt.subplots(figsize=(5,5))

sc.pl.umap(adata_combined, color=['clust_annot'], palette=clust_pals, size=1, legend_loc='bottom',ax=ax)
ax.axis('off')
ax.set_title('')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_clust.png", bbox_inches='tight', dpi=200)

In [None]:
# Cleanup all sections


In [None]:
# Smooth the cortical layer codes (L2/3, L5, L6 cells only) of every section,
# batch by batch and slice by slice, with cleanup_section (not defined in
# this notebook; the meaning of its argument 50 is not visible here — TODO confirm).
#adata_combined.obs["smoothed_spatial_clust_annot_values"] = np.nan
for i in adata_combined.obs.batch.unique():
    print(i)
    if i > 0:
        curr_adata = adata_combined[adata_combined.obs.batch==i]
        for j in curr_adata.obs.slice.unique():
            A_section = curr_adata[np.logical_and(curr_adata.obs.slice==j, curr_adata.obs.spatial_clust_annots.isin(['L2/3','L5','L6']))]
            A_section = cleanup_section(A_section,50)
            adata_combined.obs.loc[A_section.obs.index, "smoothed_spatial_clust_annot_values"] = np.array(A_section.obs["smoothed_spatial_clust_annot_values"])
# NOTE(review): on a fresh run, cells outside L2/3/L5/L6 (and batch 0) never
# receive a smoothed value, so this overwrite sets their code to NaN — confirm
# that is intended.
adata_combined.obs['spatial_clust_annots_value'] = adata_combined.obs.smoothed_spatial_clust_annot_values

In [None]:
curr_adata = adata_combined[np.logical_and(adata_combined.obs.batch==12, adata_combined.obs.slice==1)]
plt.scatter(curr_adata.obsm['spatial'][:,0],curr_adata.obsm['spatial'][:,1],c=curr_adata.obs.spatial_clust_annots_value,s=1,cmap=plt.cm.rainbow)

# Fig. 2: MERFISH Expression matrix

In [None]:
# make dendrogram of cell types for MERFISH
# Mean expression profile per cluster; rows follow the sorted order of
# `clust_ids`. These profiles feed the dendrogram built next.
import seaborn as sns
adata_subset_celltype = adata_combined#[:, celltype_markers]
clust_ids = sorted(adata_subset_celltype.obs.clust_annot.unique())
cluster_labels = adata_subset_celltype.obs.clust_annot
clust_avg = np.vstack([adata_subset_celltype[cluster_labels == cid].X.mean(0) for cid in clust_ids])

In [None]:
# hierarchical clustering of the cluster mean profiles (correlation distance,
# complete linkage, optimal leaf ordering)
from scipy.spatial.distance import pdist
import scipy.cluster.hierarchy as hc

D = pdist(clust_avg,'correlation')
Z = hc.linkage(D,'complete',optimal_ordering=True)
#label_colors['NA'] = (0,0,0)

In [None]:
dn = hc.dendrogram(Z)


In [None]:
# cluster names in dendrogram leaf order (column order of the summary figure)
lbl_order = [clust_ids[c] for c in dn['leaves']]

## Fraction of cell types per age

In [None]:
# compute fraction of each cluster per age and per brain area
# For every cluster (in lbl_order): the share of its cells from each age,
# divided by the total number of cells of that age (correcting for unequal
# sampling) and renormalised to 1, encoded as a row of n_bins values
# (2 = 90wk, 1 = 24wk, 0 = 4wk) for a stacked-bar image.
n_bins = 100
frac_per_age = np.zeros((len(lbl_order), n_bins))

total_90wk = np.sum(adata_combined.obs.age=='90wk')
total_24wk = np.sum(adata_combined.obs.age=='24wk')
total_4wk = np.sum(adata_combined.obs.age=='4wk')

for n,c in enumerate(lbl_order):
    curr_clust = adata_combined[adata_combined.obs.clust_annot==c]
    # fraction of this cluster's cells coming from each age
    curr4 = np.sum(curr_clust.obs.age == "4wk")/curr_clust.shape[0]
    curr24 = np.sum(curr_clust.obs.age == "24wk")/curr_clust.shape[0]
    curr90 = np.sum(curr_clust.obs.age == "90wk")/curr_clust.shape[0]

    # scale based on the relative number of cells in each age in the total experiment
    curr4 /= total_4wk
    curr24 /= total_24wk
    curr90 /= total_90wk
    denom = curr4+curr24+curr90
    curr4 /= denom
    curr24 /= denom
    curr90 /= denom
    nbins90 = int(round(n_bins*curr90))
    # rounding the two segments independently (plus floating-point error in the
    # renormalised fractions) can exceed n_bins; clamp so the 4wk segment is
    # never given a negative length
    nbins24 = min(int(round(n_bins*curr24)), n_bins - nbins90)
    print(n, c, curr4, curr24, curr90)
    frac_per_age[n,:] = np.hstack([2*np.ones(nbins90),
                                   np.ones(nbins24),
                                   np.zeros(n_bins-nbins90-nbins24)])


In [None]:
# fraction of cells in MERFISH vs scRNAseq
frac_per_dtype = np.zeros((len(lbl_order), n_bins))

for n,c in enumerate(lbl_order):
    curr_clust = adata_combined[adata_combined.obs.clust_annot==c]
    curr_merfish = np.sum(curr_clust.obs.dtype == "merfish")/curr_clust.shape[0] #np.sum(adata_combined.obs.age == "4wk")
    curr_10x = np.sum(curr_clust.obs.dtype == "scrnaseq")/curr_clust.shape[0] #np.sum(adata_combined.obs.age == "4wk")

    curr_merfish /= np.sum(adata_combined.obs.dtype=='merfish')
    curr_10x /=  np.sum(adata_combined.obs.dtype=='scrnaseq')
    denom = curr_merfish + curr_10x
    curr_merfish /= denom
    curr_10x /= denom
    print(n, c, curr_merfish, curr_10x)
    frac_per_dtype[n,:] = np.hstack([np.zeros(round(n_bins*(curr_merfish))), 
                                   np.ones(round(n_bins*(1-curr_merfish)))])


In [None]:
def unbinarize_strings(A):
    """Decode byte-string names and labels of an AnnData-like object to ``str``.

    ASCII-decodes ``A.var_names``, the ``A.obs`` index and every categorical
    ``A.obs`` column that holds ``bytes`` values; a decoded column is stored
    back as a plain list (object dtype). Values that are not ``bytes`` (already
    ``str``, NaN, numbers) are left as they are, and categorical columns
    without any ``bytes`` keep their dtype, so calling this twice is harmless.

    Parameters
    ----------
    A : AnnData, or anything with ``var_names`` and an ``obs`` DataFrame.

    Returns
    -------
    The same object ``A``, modified in place.
    """
    import warnings

    import pandas as pd

    def _decode(value):
        # only bytes need decoding; everything else passes through
        return value.decode('ascii') if isinstance(value, bytes) else value

    A.var_names = [_decode(v) for v in A.var_names]
    A.obs.index = [_decode(v) for v in A.obs.index]
    for col in A.obs.columns:
        # only categorical columns are decoded (numeric, bool and object
        # columns are left untouched); the isinstance check also works for
        # plain NumPy dtypes, which have no `.is_dtype`
        if not isinstance(A.obs[col].dtype, pd.CategoricalDtype):
            continue
        values = list(A.obs[col])
        if not any(isinstance(v, bytes) for v in values):
            continue
        try:
            A.obs[col] = [_decode(v) for v in values]
        except UnicodeDecodeError as err:
            # best effort: keep the column unchanged, but report it
            warnings.warn(f"could not decode obs column {col!r}: {err}")
    return A


In [None]:
adata_raw_merfish = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_umap_allages.h5ad")


In [None]:
adata_raw_merfish = unbinarize_strings(adata_raw_merfish)
#adata_combined = unbinarize_strings(adata_combined)

In [None]:
adata_raw_merfish = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_umap_allages.h5ad")

adata_raw_merfish = adata_raw_merfish[adata_combined[adata_combined.obs.dtype=="merfish"].obs.index]
adata_combined_merfish = adata_combined[adata_combined.obs.dtype=='merfish']
adata_combined_merfish.X = adata_raw_merfish.X#adata_combined_merfish.uns['raw_merfish_X']

In [None]:
# make values for dotplot
dotplot_genes = [
    'Slc17a7',
    'Gad2',
    'Cux2',
    'Rspo1',
    'Scube1',
    'Fezf2',
    'Ndst4',
    'Nxph4',
    'Hs3st4',
    'Tshz2',
    'Chat',
    'Ptpru',
    'Sst',
    'Pvalb',
    'Syt6',
    'Cpne7',
    'Lamp5',
    'Lhx6',
    'Vip',
    'Adarb2',
    'Calb2',
    'Otof',
    'Drd1',
    'Adora2a',
    'Pdgfra',
    'Olig1',
    'Rorb',
    'Aqp4',
    'Foxj1',
    'Cspg4',
    'Vtn',
    'Cldn5',
    'F13a1',
    'Cd3e',
    'Ctss']

from scipy.sparse import issparse
from scipy.stats import zscore
# rows = genes, columns = clusters (lbl_order): mean MERFISH expression, and
# fraction of the cluster's MERFISH cells with non-zero expression
dotplot_vals = np.zeros((len(dotplot_genes), len(lbl_order)))
dotplot_frac = np.zeros((len(dotplot_genes), len(lbl_order)))

for n,i in enumerate(lbl_order):
    curr_X = adata_combined_merfish[adata_combined_merfish.obs.clust_annot == i][:, dotplot_genes].X
    curr_X = curr_X.toarray() if issparse(curr_X) else np.asarray(curr_X)
    if curr_X.shape[0] == 0:
        # lbl_order covers both modalities, so a cluster may have no MERFISH
        # cells: no expression value and no dot
        dotplot_vals[:,n] = np.nan
        dotplot_frac[:,n] = 0
        continue
    dotplot_vals[:,n] = curr_X.mean(0)
    dotplot_frac[:,n] = (curr_X > 0).mean(0)
for n,i in enumerate(dotplot_genes):
    # z-score each gene across clusters; 'omit' keeps one empty cluster from
    # turning the whole row into NaN
    dotplot_vals[n,:] = zscore(dotplot_vals[n,:], nan_policy='omit')
max_idx = np.arange(len(dotplot_genes))

# uncomment optimize order
#from scipy.optimize import linear_sum_assignment
#_, max_idx = linear_sum_assignment(-dotplot_frac.T)
#dotplot_genes = [dotplot_genes[i] for i in max_idx]


In [None]:
from scipy.stats import ttest_ind, ranksums, fisher_exact, chisquare
from statsmodels.stats.proportion import proportions_ztest, test_proportions_2indep

In [None]:
# get significance star
adata_obs = adata_combined.obs.copy()
adata_obs = adata_obs[adata_obs.dtype=="merfish"]
clust_names = adata_obs.clust_annot.unique()
change_freq_pvals = []
change_freq_names = []
for k in clust_names:
    #curr4 = []
    #curr90 = []
    curr_obs = adata_obs[adata_obs.clust_annot==k]
    n4 = np.sum(adata_obs.age=='4wk')
    n90 = np.sum(adata_obs.age=='90wk')
    curr4 = np.sum(curr_obs[curr_obs.age=='4wk'].clust_annot==k)#/curr_obs.shape[0]
    curr90 = np.sum(curr_obs[curr_obs.age=='90wk'].clust_annot==k)#/curr_obs.shape[0]
    #for j in curr_obs.batch.unique():
    #    temp = curr_obs[curr_obs.batch == j]
    #    counts = 100*np.sum(temp.clust_annot==k)/curr_obs.shape[0]
    #    if temp.age[0] == '4wk':
    #        curr4.append(counts)
    #    elif temp.age[0] == '90wk':
    #        curr90.append(counts)
    #change_freq_pvals.append(ttest_ind(curr4, curr90)[1])
    change_freq_pvals.append(test_proportions_2indep(curr4, n4, curr90, n4)[1])
    change_freq_names.append(k)
    

In [None]:
#change_freq_pvals = multipletests(change_freq_pvals,method='fdr_by')[1]

In [None]:
change_freq_pvals = np.array(change_freq_pvals)

In [None]:
change_freq_pvals

In [None]:
change_freq_qvals = multipletests(change_freq_pvals,method='fdr_bh')[1]

In [None]:
signif_change = list(np.array(change_freq_names)[change_freq_qvals<0.05])

In [None]:
sns.set_style('white')

In [None]:
lbl_order_starred = [i  if i in signif_change else i for i in lbl_order]


In [None]:
# Cluster summary figure, one column per cluster in dendrogram order:
# dendrogram / color strip with cell counts / marker dot plot /
# age composition / MERFISH vs scRNA-seq composition.
dotscale = 35
f = plt.figure(figsize=(8,8))
gs = plt.GridSpec(nrows=5, ncols=1, height_ratios=[5,1,30,6,6], hspace=0.1)

# dendrogram; its leaf labels define the column order
ax = plt.subplot(gs[0])
hc.dendrogram(Z,ax=ax,labels=clust_ids,leaf_font_size=10,color_threshold=0,above_threshold_color='k');
sns.despine(ax=ax,left=True)
ax.axis('off')
# every leaf is kept (a `Text != 'NA'` comparison is always true, so no leaf
# was ever filtered; comparing the text would desync from the dot-plot columns)
lbl_order = [lbl.get_text() for lbl in ax.get_xmajorticklabels()]

# cluster color strip, ticks annotated with the number of cells per cluster
ax = plt.subplot(gs[1])
curr_cols = mpl.colors.ListedColormap([label_colors[c] for c in lbl_order])
# one strip cell per plotted cluster so colors line up with the columns
ax.imshow(np.expand_dims(np.arange(len(lbl_order)),1).T, cmap=curr_cols,aspect='auto',interpolation='none')
ax.set_yticklabels([])
ax.set_xticks(np.arange(len(lbl_order)))
ax.set_xticklabels([np.sum(adata_combined.obs.clust_annot==c) for c in lbl_order],rotation=90)
sns.despine(ax=ax,left=True)

# marker dot plot: color = z-scored mean expression, size = fraction expressing
ax = plt.subplot(gs[2])
for i in range(dotplot_vals.shape[1]):
    ax.scatter( i*np.ones((dotplot_vals.shape[0])),-np.arange(dotplot_vals.shape[0]), c=dotplot_vals[max_idx,:][:,i], s=dotscale*dotplot_frac[max_idx,:][:,i], cmap=plt.cm.seismic, vmin=-5,vmax=5)
ax.set_yticks(-np.arange(len(dotplot_genes)));
ax.set_yticklabels(dotplot_genes, fontsize=8)
ax.set_xlim([-0.5, dotplot_vals.shape[1]-0.5])
ax.set_xticks([])
sns.despine(ax=ax,bottom=True)

# age composition (bins: 2 = 90wk, 1 = 24wk, 0 = 4wk; guide lines at 33, 66)
ax = plt.subplot(gs[3])
ax.imshow(frac_per_age.T, vmin=0,vmax=2,aspect='auto',interpolation='none', cmap=age_cmap)
ax.set_yticklabels([])
ax.set_xticks([])
ax.axhline(33,color='w',linestyle='--')
ax.axhline(66,color='w',linestyle='--')
sns.despine(ax=ax, left=True)

# modality composition (bins: 0 = MERFISH, 1 = scRNA-seq; guide line at 50)
ax = plt.subplot(gs[4])
ax.imshow(frac_per_dtype.T, vmin=0,vmax=1,aspect='auto',interpolation='none', cmap=dtype_cmap)
ax.set_yticklabels([])
ax.set_xticks(np.arange(len(lbl_order)))
ax.axhline(50,color='w',linestyle='--')
sns.despine(ax=ax, left=True)
# set the (possibly starred) names first, then color each tick by the cluster
# at its position; looking colors up by the displayed text would fail once a
# name carries a star
ax.set_xticklabels(lbl_order_starred,rotation=90);
for lbl, clust in zip(ax.get_xmajorticklabels(), lbl_order):
    lbl.set_color(label_colors[clust])

f.savefig('/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_cluster_heatmap.pdf',bbox_inches='tight', dpi=200)

In [None]:
age_pal = sns.color_palette(age_colors)


In [None]:
# show per-batch number of cells for every cluster: for each MERFISH batch, the
# percentage of its cells in cluster k, grouped by age (one bar panel per
# cluster; NOTE: the 7x7 grid holds at most 49 clusters)
import pandas as pd
adata_obs = adata_combined.obs.copy()
adata_obs = adata_obs[adata_obs.dtype=="merfish"]
clust_names = sorted(adata_obs.clust_annot.unique())
ct_counts = []
ages = []
clusts = []
for k in clust_names:
    for i in ['4wk','24wk','90wk']:
        curr_obs = adata_obs[adata_obs.age==i]
        for j in curr_obs.batch.unique():
            temp = curr_obs[curr_obs.batch == j]
            ct_counts.append(100*np.sum(temp.clust_annot==k)/temp.shape[0])
            ages.append(i)
            clusts.append(k)
            
counts = pd.DataFrame({'count':ct_counts, 'age': ages, 'clust': clusts})
f = plt.figure(figsize=(20,20))
gs = plt.GridSpec(nrows=7, ncols=7, wspace=0.5,hspace=0.5)
for n,i in enumerate(clust_names):
    curr_counts = counts[counts.clust==i]
    ax = plt.subplot(gs[n])
    sns.barplot(x='age',y='count',data=curr_counts,ax=ax, palette=sns.color_palette(age_colors),linewidth=0, errwidth=1,zorder=0)

    #sns.scatterplot(x='age',y='count',data=curr_counts,ax=ax,color='k',zorder=1,linewidth=1)

    sns.despine(ax=ax)
    ax.set_ylabel('')
    ax.set_title(i)
f.savefig('/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS6_perbatch_counts.pdf',bbox_inches='tight', dpi=200)

In [None]:
# show normalized fraction of cell types for each time point
# For each cluster (in lbl_order): the share of its cells from 4wk, 24wk and
# 90wk, each divided by the total number of cells of that age (correcting for
# unequal sampling) and renormalised to sum to 1.
# Result: frac_per_age, shape (n_clusters, 3), columns (4wk, 24wk, 90wk).
total_90wk = np.sum(adata_combined.obs.age=='90wk')
total_24wk = np.sum(adata_combined.obs.age=='24wk')
total_4wk = np.sum(adata_combined.obs.age=='4wk')
frac_per_age = []
for c in lbl_order:
    curr_clust = adata_combined[adata_combined.obs.clust_annot==c]
    # fraction of this cluster's cells coming from each age
    curr4 = np.sum(curr_clust.obs.age == "4wk")/curr_clust.shape[0]
    curr24 = np.sum(curr_clust.obs.age == "24wk")/curr_clust.shape[0]
    curr90 = np.sum(curr_clust.obs.age == "90wk")/curr_clust.shape[0]

    # scale based on the relative number of cells in each age in the total experiment
    curr4 /= total_4wk
    curr24 /= total_24wk
    curr90 /= total_90wk
    denom = curr4+curr24+curr90
    frac_per_age.append((curr4/denom, curr24/denom, curr90/denom))
frac_per_age = np.vstack(frac_per_age)

In [None]:
# one column of three dots per cluster (4wk, 24wk, 90wk from bottom to top);
# dot size = normalised age fraction, color = cluster color
clust_colors = [label_colors[i] for i in lbl_order]
plt.figure(figsize=(10,1))
for i in range(frac_per_age.shape[0]):
    print(clust_colors[i])
    plt.scatter(i*np.ones(3), 0.1*np.arange(3), s=100*frac_per_age[i,:], c=i*np.ones(3),cmap=mpl.colors.ListedColormap(clust_colors),vmin=0,vmax=len(clust_colors))
plt.ylim([-0.1,0.3])

In [None]:
adata_combined.obs.batch.unique()

In [None]:
## Pie charts
import pandas as pd
# NOTE(review): these three Series are not used by the pie-chart cells that
# follow (they recompute their inputs)
cell_types_young = adata_combined[adata_combined.obs.age=='4wk'].obs.cell_type_annot
cell_types_med = adata_combined[adata_combined.obs.age=='24wk'].obs.cell_type_annot
cell_types_old = adata_combined[adata_combined.obs.age=='90wk'].obs.cell_type_annot


In [None]:
def count_celltypes(A, age, key='cell_type'):
    """Tally the cells of one age into broad neuronal classes.

    A label containing "ExN" counts as Excitatory, otherwise "InN" as
    Inhibitory, otherwise "MSN" as MSN; everything else is Non-neuronal.

    Parameters
    ----------
    A : AnnData-like, indexable by a boolean mask over ``A.obs``.
    age : value of ``A.obs.age`` to keep.
    key : ``A.obs`` column holding the labels.

    Returns
    -------
    DataFrame with a single ``counts`` column, indexed by
    Inhibitory, Excitatory, MSN, Non-neuronal (in that order).
    """
    # (substring, class) pairs in priority order; the first match wins
    rules = (("ExN", "Excitatory"), ("InN", "Inhibitory"), ("MSN", "MSN"))
    tally = dict.fromkeys(["Inhibitory", "Excitatory", "MSN", "Non-neuronal"], 0)
    for label in A[A.obs.age == age].obs[key]:
        bucket = next((name for tag, name in rules if tag in label), "Non-neuronal")
        tally[bucket] += 1
    return pd.DataFrame({'counts': list(tally.values())}, index=list(tally.keys()))

In [None]:
def simplify_celltypes(A, age, key='cell_type'):
    """Neuronal class of every cell of one age, one row per cell.

    A label containing "ExN", "InN" or "MSN" (checked in that order) is
    replaced by that tag; cells with any other label are dropped.

    Returns a DataFrame with columns ``cell_type``, ``count`` (1.0 per row)
    and ``age``.
    """
    neuronal_tags = ("ExN", "InN", "MSN")
    kept = []
    for label in A[A.obs.age == age].obs[key]:
        tag = next((t for t in neuronal_tags if t in label), None)
        if tag is not None:
            kept.append(tag)
    return pd.DataFrame({'cell_type': kept, 'count': np.ones(len(kept)), 'age': age})


def simplify_clusts(A, age, key='cell_type'):
    """Labels of every cell of one age, unchanged, one row per cell.

    Returns a DataFrame with columns ``cell_type`` (the ``A.obs[key]`` value),
    ``count`` (1.0 per row) and ``age``.
    """
    labels = [*A[A.obs.age == age].obs[key]]
    return pd.DataFrame({'cell_type': labels, 'count': np.ones(len(labels)), 'age': age})

In [None]:
# one row per neuronal cell (ExN / InN / MSN) for each age
young_ct = simplify_celltypes(adata_combined, '4wk')
med_ct = simplify_celltypes(adata_combined, '24wk')
old_ct = simplify_celltypes(adata_combined, '90wk')
combined_ct = pd.concat([young_ct, med_ct, old_ct])

In [None]:
colors=[celltype_colors[i] for i in ["ExN","InN","MSN"]]# + [np.array([0.7, 0.7, 0.7])]

In [None]:
# (disabled) stacked-bar version of the neuronal composition per age
#f = plt.figure(figsize=(3,3))
#gs = plt.GridSpec(ncols=3, nrows=1, wspace=0.1)
#ax = plt.subplot(gs[0])
#sns.histplot(x='age',data=young_ct,multiple='stack',hue='cell_type',palette=sns.color_palette(colors), linewidth=0,hue_order=["ExN", "InN","MSN","Non-neuronal"], stat='percent',legend=None, ax=ax)
#sns.despine(ax=ax)
#ax = plt.subplot(gs[1])
#sns.histplot(x='age',data=med_ct,multiple='stack',hue='cell_type',palette=sns.color_palette(colors), linewidth=0,hue_order=["ExN", "InN","MSN","Non-neuronal"],stat='percent',legend=None, ax=ax)
#sns.despine(ax=ax,left=True)
#ax.set_yticks([])
#ax.set_ylabel('')
#ax = plt.subplot(gs[2])
#sns.histplot(x='age',data=old_ct,multiple='stack',hue='cell_type',palette=sns.color_palette(colors), linewidth=0,hue_order=["ExN", "InN","MSN","Non-neuronal"],stat='percent',legend=None, ax=ax)
#sns.despine(ax=ax, left=True)
#ax.set_yticks([])
#ax.set_ylabel('')
#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig2_celltype_barplot.pdf",bbox_inches='tight', dpi=200)

In [None]:
# one pie per age (4wk, 24wk, 90wk): number of cells per neuronal class.
# Counts are taken from the cell_type column alone: DataFrame.value_counts()
# .reset_index() names its count column 0 in pandas < 2 but 'count' in
# pandas >= 2 (where it also clashes with the existing 'count' column).
f, axes = plt.subplots(figsize=(8,24), nrows=1, ncols=3, gridspec_kw={'wspace':0.1})
young_ct_agg = young_ct['cell_type'].value_counts().rename_axis('cell_type').reset_index(name='n_cells')
axes[0].pie(young_ct_agg['n_cells'],colors=[celltype_colors[i] for i in young_ct_agg.cell_type], labels=young_ct_agg.cell_type);

med_ct_agg = med_ct['cell_type'].value_counts().rename_axis('cell_type').reset_index(name='n_cells')
axes[1].pie(med_ct_agg['n_cells'],colors=[celltype_colors[i] for i in med_ct_agg.cell_type], labels=med_ct_agg.cell_type);

old_ct_agg = old_ct['cell_type'].value_counts().rename_axis('cell_type').reset_index(name='n_cells')
axes[2].pie(old_ct_agg['n_cells'],colors=[celltype_colors[i] for i in old_ct_agg.cell_type], labels=old_ct_agg.cell_type);
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_celltype_comp_neuronal.pdf", bbox_inches='tight')

In [None]:
adata_nonneuronal = adata_combined[~adata_combined.obs.cell_type.isin(['ExN',"InN","MSN"])]
young_ct = simplify_clusts(adata_nonneuronal, '4wk')
med_ct = simplify_clusts(adata_nonneuronal, '24wk')
old_ct = simplify_clusts(adata_nonneuronal, '90wk')
combined_ct = pd.concat([young_ct, med_ct, old_ct])

In [None]:
# Fixed display order of the non-neuronal cell types (hue order of the stacked histograms below).
#non_neuronal_celltypes = list(combined_ct.cell_type.unique())
non_neuronal_celltypes = ['Olig',
                           'OPC',
 'Astro',
 'Epen',
 'Vlmc',
 'Endo',
 'Peri',
 'Micro',
 'Macro',
 'T cell',
]

# Colors aligned index-for-index with non_neuronal_celltypes.
non_neuronal_celltype_colors = [celltype_colors[i] for i in non_neuronal_celltypes]

In [None]:
# UMAP of all combined cells, colored by major cell type.
sc.pl.umap(adata_combined, color='cell_type',palette=celltype_pals)

In [None]:
# Three side-by-side 100%-stacked bars (young, med, old) of non-neuronal cell-type composition.
f = plt.figure(figsize=(3,3))
gs = plt.GridSpec(ncols=3, nrows=1, wspace=0.1)
ax = plt.subplot(gs[0])
sns.histplot(x='age',data=young_ct,multiple='stack',hue='cell_type',linewidth=0,palette=sns.color_palette(non_neuronal_celltype_colors), hue_order=non_neuronal_celltypes, stat='percent',legend=None, ax=ax)
sns.despine(ax=ax)
# Middle and right panels share the left panel's y scale, so their y axes are hidden.
ax = plt.subplot(gs[1])
sns.histplot(x='age',data=med_ct,multiple='stack',hue='cell_type',linewidth=0,palette=sns.color_palette(non_neuronal_celltype_colors), hue_order=non_neuronal_celltypes,stat='percent',legend=None, ax=ax)
sns.despine(ax=ax,left=True)
ax.set_yticks([])
ax.set_ylabel('')
ax = plt.subplot(gs[2])
sns.histplot(x='age',data=old_ct,multiple='stack',hue='cell_type',linewidth=0,palette=sns.color_palette(non_neuronal_celltype_colors), hue_order=non_neuronal_celltypes,stat='percent',legend=None, ax=ax)
sns.despine(ax=ax, left=True)
ax.set_yticks([])
ax.set_ylabel('')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig2_neurons_barplot_nonneuronal.pdf",bbox_inches='tight', dpi=200)

In [None]:
# One pie per age (4wk, 24wk, 90wk) showing the non-neuronal cell-type composition.
f, axes = plt.subplots(figsize=(8,24), nrows=1, ncols=3, gridspec_kw={'wspace':0.1})
# reset_index(name='count') pins the name of the count column: a bare reset_index() names it 0 on pandas < 2.0
# but 'count' on pandas >= 2.0, so indexing the result with [0] breaks after a pandas upgrade.
young_ct_agg = young_ct.value_counts().reset_index(name='count')
med_ct_agg = med_ct.value_counts().reset_index(name='count')
old_ct_agg = old_ct.value_counts().reset_index(name='count')
for ax, ct_agg in zip(axes, [young_ct_agg, med_ct_agg, old_ct_agg]):
    ax.pie(ct_agg['count'], colors=[celltype_colors[i] for i in ct_agg.cell_type], labels=ct_agg.cell_type)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_celltype_comp_nonneuronal.pdf", bbox_inches='tight')

In [None]:
# For each batch, estimate imaged tissue area as the sum of per-section convex-hull areas, then express the
# number of cells of each neuronal type as a density (cells per mm^2).
from scipy.spatial import ConvexHull
# NOTE(review): assumes obsm['spatial'] is in µm (the crop windows used below span ~2000-4000 units) — confirm.
UM2_PER_MM2 = 1e6
section_area = {}
age_total_area = {}
ages = []
cell_type = []
counts = []
all_ct = ["ExN", "InN","MSN"]#adata_combined.obs.cell_type.unique()
# NOTE(review): np.arange(1, max) stops at max - 1, so the highest batch id is never included — confirm intended.
for i in np.arange(1, adata_combined.obs.batch.max()):
    curr_batch = adata_combined[adata_combined.obs.batch==i]
    if curr_batch.n_obs == 0:
        # Batch id not present in the data: nothing to normalize (would divide by a zero area).
        continue
    curr_batch_area = 0
    for j in curr_batch.obs.slice.unique():
        curr_slice = curr_batch[curr_batch.obs.slice==j]
        curr_age = curr_slice.obs.age.iloc[0]
        # For 2-D points ConvexHull.volume is the enclosed area; ConvexHull.area is the *perimeter*.
        hull_area = ConvexHull(curr_slice.obsm['spatial'][:, :2]).volume
        section_area[(i,j)] = hull_area
        age_total_area[curr_age] = age_total_area.get(curr_age, 0) + hull_area
        curr_batch_area += hull_area
    for ct in all_ct:
        counts.append(UM2_PER_MM2*curr_batch[curr_batch.obs.cell_type==ct].shape[0]/curr_batch_area)
        cell_type.append(ct)
        ages.append(curr_batch.obs.age.unique()[0])

In [None]:
# Long-format table: one row per (batch, cell type) with the batch's age and its area-normalized count.
area_norm_data = pd.DataFrame({'age':ages,'cell_type':cell_type,'counts':counts})


In [None]:
# 4wk vs 90wk two-sample t-test per cell type on the per-batch area-normalized counts.
from scipy.stats import ttest_ind
# Despite the name, tstats collects p-values (element [1] of ttest_ind's result).
tstats = []
for i in all_ct:
    curr_ct = area_norm_data[area_norm_data.cell_type==i]
    tstats.append(ttest_ind(curr_ct[curr_ct.age=='4wk'].counts, curr_ct[curr_ct.age=='90wk'].counts)[1])
# Benjamini-Hochberg adjustment. NOTE(review): multipletests is not imported in this section — presumably
# statsmodels.stats.multitest.multipletests via a star import; [1] is then the adjusted p-values.
tstats = multipletests(tstats, method='fdr_bh')[1]

In [None]:
# Mean area-normalized density per neuronal type, one bar per age (error bars from seaborn's default estimator CI).
age_pal = sns.color_palette(age_colors)

sns.barplot(x='cell_type',y='counts',data=area_norm_data,hue='age',palette=age_pal,errwidth=1,order=all_ct,hue_order=['4wk','24wk','90wk'])
sns.despine()
plt.ylabel('Cells/mm^2')
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS6_area_normalized_counts_neurons.pdf",bbox_inches='tight')

In [None]:
# For each batch, estimate imaged tissue area as the sum of per-section convex-hull areas, then express the
# number of cells of each major cell type as a density (cells per mm^2).
from scipy.spatial import ConvexHull
# NOTE(review): assumes obsm['spatial'] is in µm (the crop windows used below span ~2000-4000 units) — confirm.
UM2_PER_MM2 = 1e6
section_area = {}
age_total_area = {}
ages = []
cell_type = []
counts = []
all_ct = ["ExN","InN","MSN","Olig", "Astro", "Endo", "Micro","OPC", "Peri", "Vlmc","Epen",   "Macro", "T cell"]#adata_combined.obs.cell_type.unique()
# NOTE(review): np.arange(1, max) stops at max - 1, so the highest batch id is never included — confirm intended.
for i in np.arange(1, adata_combined.obs.batch.max()):
    curr_batch = adata_combined[adata_combined.obs.batch==i]
    if curr_batch.n_obs == 0:
        # Batch id not present in the data: nothing to normalize (would divide by a zero area).
        continue
    curr_batch_area = 0
    for j in curr_batch.obs.slice.unique():
        curr_slice = curr_batch[curr_batch.obs.slice==j]
        curr_age = curr_slice.obs.age.iloc[0]
        # For 2-D points ConvexHull.volume is the enclosed area; ConvexHull.area is the *perimeter*.
        hull_area = ConvexHull(curr_slice.obsm['spatial'][:, :2]).volume
        section_area[(i,j)] = hull_area
        age_total_area[curr_age] = age_total_area.get(curr_age, 0) + hull_area
        curr_batch_area += hull_area
    for ct in all_ct:
        counts.append(UM2_PER_MM2*curr_batch[curr_batch.obs.cell_type==ct].shape[0]/curr_batch_area)
        cell_type.append(ct)
        ages.append(curr_batch.obs.age.unique()[0])

In [None]:
# Long-format table: one row per (batch, cell type) with the batch's age and its area-normalized count.
area_norm_data = pd.DataFrame({'age':ages,'cell_type':cell_type,'counts':counts})


In [None]:
# 4wk vs 90wk two-sample t-test per major cell type on the per-batch area-normalized counts.
from scipy.stats import ttest_ind
# Despite the name, tstats collects p-values (element [1] of ttest_ind's result).
tstats = []
for i in all_ct:
    curr_ct = area_norm_data[area_norm_data.cell_type==i]
    tstats.append(ttest_ind(curr_ct[curr_ct.age=='4wk'].counts, curr_ct[curr_ct.age=='90wk'].counts)[1])
# Benjamini-Hochberg adjustment. NOTE(review): multipletests is not imported in this section — presumably
# statsmodels.stats.multitest.multipletests via a star import; [1] is then the adjusted p-values.
tstats = multipletests(tstats, method='fdr_bh')[1]

In [None]:
# Cell types whose adjusted 4wk-vs-90wk p-value is below 0.05.
np.array(all_ct)[tstats<0.05]

In [None]:
# Mean area-normalized density per major cell type, one bar per age.
age_pal = sns.color_palette(age_colors)
#sns.scatterplot(x='cell_type',y='counts',data=area_norm_data,gr='age',palette=age_pal,hue_order=['4wk','24wk','90wk'])

sns.barplot(x='cell_type',y='counts',data=area_norm_data,hue='age',palette=age_pal,errwidth=1,order=all_ct,hue_order=['4wk','24wk','90wk'])
sns.despine()
plt.ylabel('Cells/mm^2')
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS6_area_normalized_counts.pdf",bbox_inches='tight')

In [None]:
# Inset of the previous barplot restricted to four low-abundance cell types so their bars are readable.
age_pal = sns.color_palette(age_colors)

# Rebuilt from the same ages / cell_type / counts lists as area_norm_data above (same contents).
area_norm_data = pd.DataFrame({'age':ages,'cell_type':cell_type,'counts':counts})
sns.barplot(x='cell_type',y='counts',data=area_norm_data,hue='age',palette=age_pal,errwidth=1,order=['Vlmc','Epen','Macro','T cell'],hue_order=['4wk','24wk','90wk'])
sns.despine()
plt.ylabel('Cells/mm^2')
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS6_area_normalized_counts_inset.pdf",bbox_inches='tight')

In [None]:
# Fold change of each batch's area-normalized count relative to the 4wk baseline of the same cell type.
# The baseline is the MEAN over 4wk batches. It used to be the first 4wk row only (counts.values[0]), which made
# the result depend on batch order and left the 4wk bars away from 1.
area_norm_data_normed = area_norm_data.copy()
baseline_4wk = area_norm_data[area_norm_data.age=='4wk'].groupby('cell_type').counts.mean()
# Only the three study ages are rescaled, as before; a cell type without 4wk rows yields NaN instead of IndexError.
in_ages = area_norm_data_normed.age.isin(['4wk','24wk','90wk'])
area_norm_data_normed.loc[in_ages, 'counts'] = (
    area_norm_data_normed.loc[in_ages, 'counts']
    / area_norm_data_normed.loc[in_ages, 'cell_type'].map(baseline_4wk)
)

In [None]:
# Fold change in density relative to 4wk, per cell type and age.
# Bar order and hue order are pinned like the other area-normalized barplots. Without hue_order, seaborn assigns
# age_pal colors in order of first appearance, so the age -> color mapping could differ from the sibling panels.
sns.barplot(x='cell_type',y='counts',data=area_norm_data_normed,hue='age',palette=age_pal,errwidth=1,order=all_ct,hue_order=['4wk','24wk','90wk'])
sns.despine()
plt.ylabel('Fold change in cells/mm^2 relative to 4 wk')
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS6_fc_area_normalized_counts.pdf",bbox_inches='tight')

In [None]:
# For each batch, estimate imaged tissue area as the sum of per-section convex-hull areas, then express the
# number of cells in each cluster (clust_annot) as a density (cells per mm^2).
# The cell_type list built here actually holds cluster names.
from scipy.spatial import ConvexHull
# NOTE(review): assumes obsm['spatial'] is in µm (the crop windows used below span ~2000-4000 units) — confirm.
UM2_PER_MM2 = 1e6
section_area = {}
age_total_area = {}
ages = []
cell_type = []
counts = []
all_clust = adata_combined.obs.clust_annot.unique()
# NOTE(review): np.arange(1, max) stops at max - 1, so the highest batch id is never included — confirm intended.
for i in np.arange(1, adata_combined.obs.batch.max()):
    curr_batch = adata_combined[adata_combined.obs.batch==i]
    if curr_batch.n_obs == 0:
        # Batch id not present in the data: nothing to normalize (would divide by a zero area).
        continue
    curr_batch_area = 0
    for j in curr_batch.obs.slice.unique():
        curr_slice = curr_batch[curr_batch.obs.slice==j]
        curr_age = curr_slice.obs.age.iloc[0]
        # For 2-D points ConvexHull.volume is the enclosed area; ConvexHull.area is the *perimeter*.
        hull_area = ConvexHull(curr_slice.obsm['spatial'][:, :2]).volume
        section_area[(i,j)] = hull_area
        age_total_area[curr_age] = age_total_area.get(curr_age, 0) + hull_area
        curr_batch_area += hull_area
    for ct in all_clust:
        counts.append(UM2_PER_MM2*curr_batch[curr_batch.obs.clust_annot==ct].shape[0]/curr_batch_area)
        cell_type.append(ct)
        ages.append(curr_batch.obs.age.unique()[0])


In [None]:
#from scipy.stats import ttest_ind
#tstats = []
#for i in all_clust:
#    curr_ct = area_norm_data[area_norm_data.cell_type==i]
#    tstats.append(ttest_ind(curr_ct[curr_ct.age=='4wk'].counts, curr_ct[curr_ct.age=='90wk'].counts)[1])
#tstats = multipletests(tstats, method='fdr_bh')[1]

In [None]:
# (A stray bare `np.cumsum()` call was here. Called without an array it raises TypeError, which stopped
# Restart & Run All at this cell, so it has been removed.)

In [None]:
# Stacked 100% bars of cluster composition within four major cell types (one row each), one bar per age
# (4wk, 24wk, 90wk), drawn manually with ax.bar so each segment gets its cluster's label_colors color.
f,axes = plt.subplots(nrows=4, ncols=1, figsize=(4,6), gridspec_kw={'wspace':0.1, 'hspace':0.1})
k = 0
for k,ct in enumerate(['Micro','Astro','Endo', 'Olig']):
    ax = axes[k]

    for n,i in enumerate(['4wk','24wk','90wk']):
        # Cells of major type ct at age i.
        curr_adata = adata_combined[adata_combined.obs.age==i]
        curr_adata = curr_adata[curr_adata.obs.cell_type==ct]
        df = curr_adata.obs.clust_annot.value_counts()#.plot(kind='pie',ax=ax)
        df = df.sort_index()
        vals = df.values
        # Fraction of this type's cells in each cluster (0/0 if the type has no cells at this age).
        vals = vals / vals.sum()
        # NOTE: rebinds the notebook-level name `colors`.
        colors = [label_colors[c] for c in df.index]
        for j in range(len(df.index)):
        #    print(ct, i, vals[j])
            # Segment j sits on top of segments 0..j-1; bar n spans x in [n+0.25, n+0.75].
            if j == 0:
                ax.bar(n+0.25, vals[j], 0.5, color=colors[j],bottom=0, align='edge')
            else:
                ax.bar(n+0.25, vals[j], 0.5, color=colors[j], bottom=np.sum(vals[:j]), align='edge')
        #ax.set_ylim([0,1])
        ax.axis('off')
        #df = df.reset_index()
        #sns.histplot(x='clust_annot',data=df,multiple='stack',hue='index',linewidth=0,stat='percent',legend=None, ax=ax)
        #ax.pie(df, colors=[label_colors[i] for i in df.index],labels=df.index)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_clust_stacked.pdf",bbox_inches='tight')

In [None]:
# Pie-chart version of the cluster composition: rows = major cell type, columns = age.
# NOTE(review): identical to the next cell except that its savefig is commented out — likely a leftover preview.
f,axes = plt.subplots(nrows=4, ncols=3, figsize=(12,6), gridspec_kw={'wspace':0.1, 'hspace':0.1})
k = 0
for k,ct in enumerate(['Micro','Astro','Endo', 'Olig']):
    for n,i in enumerate(['4wk','24wk','90wk']):
        ax = axes[k,n]
        curr_adata = adata_combined[adata_combined.obs.age==i]
        curr_adata = curr_adata[curr_adata.obs.cell_type==ct]
        df = curr_adata.obs.clust_annot.value_counts()#.plot(kind='pie',ax=ax)
        df = df.sort_index()
        ax.pie(df, colors=[label_colors[i] for i in df.index],labels=df.index)
#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_clust_piechart.pdf",bbox_inches='tight')

In [None]:
# Pie charts of cluster composition: rows = major cell type, columns = age; saved to fig1_clust_piechart.pdf.
f,axes = plt.subplots(nrows=4, ncols=3, figsize=(12,6), gridspec_kw={'wspace':0.1, 'hspace':0.1})
k = 0
for k,ct in enumerate(['Micro','Astro','Endo', 'Olig']):
    for n,i in enumerate(['4wk','24wk','90wk']):
        ax = axes[k,n]
        curr_adata = adata_combined[adata_combined.obs.age==i]
        curr_adata = curr_adata[curr_adata.obs.cell_type==ct]
        df = curr_adata.obs.clust_annot.value_counts()#.plot(kind='pie',ax=ax)
        df = df.sort_index()
        # The comprehension's `i` is local to it (Python 3) and does not clobber the age loop variable.
        ax.pie(df, colors=[label_colors[i] for i in df.index],labels=df.index)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_clust_piechart.pdf",bbox_inches='tight')

In [None]:
# Per-batch, area-normalized cluster densities for selected clusters, one barplot panel per cluster.
# counts / ages / cell_type are the per-(batch, cluster) lists built by the cluster-level area-normalization loop above.
import pandas as pd
adata_obs = adata_combined.obs.copy()
# MERFISH cells only; not used below (left over from an earlier per-batch fraction computation).
adata_obs = adata_obs[adata_obs.dtype=="merfish"]
# NOTE(review): 'Astro-3' and 'Astro-4' do not appear in clust_order further down (which lists Astro-1/Astro-2),
# so those two panels may be empty — confirm the current Astro cluster names.
clust_names = ["Olig-3", "Olig-1", "Olig-2",
               'Micro-2', 'Micro-1', 'Micro-3',
               'Astro-2', 'Astro-4', 'Astro-3',
               'Endo-3', 'Endo-1', 'Endo-2']

clust_counts = pd.DataFrame({'counts':counts, 'age': ages, 'clust': cell_type})
f = plt.figure(figsize=(6,6))
gs = plt.GridSpec(nrows=4, ncols=3, wspace=0.5,hspace=0.5)
for n,i in enumerate(clust_names):
    curr_counts = clust_counts[clust_counts.clust==i]
    ax = plt.subplot(gs[n])
    sns.barplot(x='age',y='counts',data=curr_counts,ax=ax,order=['4wk','24wk','90wk'],palette=sns.color_palette(age_colors),linewidth=0, errwidth=1,zorder=0)
    sns.despine(ax=ax)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_title(i)
# Distinct file name: the Astro-only cell below also saves to fig3_perbatch_counts.pdf, which silently overwrote
# this 12-panel figure whenever the notebook was run top to bottom.
f.savefig('/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_perbatch_counts_all.pdf',bbox_inches='tight', dpi=200)

In [None]:
# Per-batch, area-normalized densities for the Astro clusters (same panel style as the cell above).
# counts / ages / cell_type are the per-(batch, cluster) lists from the cluster-level area-normalization loop.
import pandas as pd
adata_obs = adata_combined.obs.copy()
adata_obs = adata_obs[adata_obs.dtype=="merfish"]
clust_names = [
               'Astro-1','Astro-2'
              ]
#for k in clust_names:
#    for i in ['4wk','24wk','90wk']:
##        curr_obs = adata_obs[adata_obs.age==i]
#        for j in curr_obs.batch.unique():
#            temp = curr_obs[curr_obs.batch == j]
##            ct_counts.append(100*np.sum(temp.clust_annot==k)/temp.shape[0])
#            ages.append(i)
#            clusts.append(k)
            
clust_counts = pd.DataFrame({'counts':counts, 'age': ages, 'clust': cell_type})
f = plt.figure(figsize=(6,6))
# 4x3 grid kept from the 12-panel layout; only the first len(clust_names) slots are used.
gs = plt.GridSpec(nrows=4, ncols=3, wspace=0.5,hspace=0.5)
for n,i in enumerate(clust_names):
    curr_counts = clust_counts[clust_counts.clust==i]
    ax = plt.subplot(gs[n])
    sns.barplot(x='age',y='counts',data=curr_counts,ax=ax,order=['4wk','24wk','90wk'],palette=sns.color_palette(age_colors),linewidth=0, errwidth=1,zorder=0)
    #sns.scatterplot(x='age',y='count',data=curr_counts,ax=ax,color='k',zorder=1,linewidth=1)

    sns.despine(ax=ax)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_title(i)
f.savefig('/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_perbatch_counts.pdf',bbox_inches='tight', dpi=200)

In [None]:
# 4wk vs 90wk two-sample t-test per cluster on the per-batch area-normalized densities in clust_counts.
from scipy.stats import ttest_ind
# Despite the name, tstats collects p-values (element [1] of ttest_ind's result).
tstats = []
for i in all_clust:
    curr_ct = clust_counts[clust_counts.clust==i]
    tstats.append(ttest_ind(curr_ct[curr_ct.age=='4wk'].counts, curr_ct[curr_ct.age=='90wk'].counts)[1])
# NOTE: multiple-testing correction is disabled here, so these p-values are raw.
#tstats = multipletests(tstats, method='fdr_bh')[1]

In [None]:
# Raw (uncorrected) per-cluster p-values, sorted by cluster name.
pd.DataFrame({'clust':all_clust, 'pval':tstats}).sort_values("clust")

# Spatial Organization

In [None]:
# Per-age subsets of the combined data (old = 90wk, med = 24wk, young = 4wk).
adata_combined_old = adata_combined[adata_combined.obs.age=='90wk']
# Was '90wk' (copy-paste from the line above), which made the "med" subset a duplicate of the old one.
adata_combined_med = adata_combined[adata_combined.obs.age=='24wk']

adata_combined_young = adata_combined[adata_combined.obs.age=='4wk']

In [None]:
from matplotlib import cm


In [None]:
import seaborn as sns
import pandas as pd

from plotting import *

In [None]:
# Integer codes for clusters and for major cell types, numbered by order of first appearance in obs.
clust_encoding = {label: code for code, label in enumerate(adata_combined.obs.clust_annot.unique())}
celltype_encoding = {label: code for code, label in enumerate(adata_combined.obs.cell_type.unique())}

# Store the codes as obs columns so they can be passed directly as scatter color values.
adata_combined.obs["clust_id"] = [clust_encoding[label] for label in adata_combined.obs.clust_annot]
adata_combined.obs["celltype_id"] = [celltype_encoding[label] for label in adata_combined.obs.cell_type]

In [None]:
def plot_obs_by_cells(A, obs_name, s=0.1, cmap=plt.cm.gist_rainbow, show_legend=False, vmax=None, rot=0):
    """Scatter cells at their ``obsm['spatial']`` positions on the current Axes, colored by a numeric obs column.

    Parameters
    ----------
    A : AnnData
        Cells to draw.
    obs_name : str
        Column of ``A.obs`` used as the color value; values are mapped linearly from [0, vmax] onto ``cmap``.
    s : float
        Marker size.
    cmap : Colormap
        Colormap for the values.
    show_legend : bool
        If True, add a legend with one entry per distinct value of ``A.obs[obs_name]``.
    vmax : float, optional
        Upper color limit; defaults to the number of distinct values.
    rot : float
        Rotation in degrees, applied with ``rotate`` (defined elsewhere) when non-zero.
    """
    pts = A.obsm['spatial']
    if rot != 0:
        pts = rotate(pts, degrees=rot)
    pts = pd.DataFrame({'x': pts[:,0], 'y': pts[:,1], 'obs':A.obs[obs_name]})
    if vmax is None:
        vmax = len(pts.obs.unique())
    plt.scatter(pts.x,pts.y,s=s,vmin=0,vmax=vmax,c=pts.obs,cmap=cmap)
    if show_legend:
        # A single scatter call produces one legend handle, so plt.legend(labels) only labelled the first value.
        # Build one proxy marker per value, colored the way scatter colors it (linear [0, vmax] -> cmap).
        norm = plt.Normalize(vmin=0, vmax=vmax)
        handles = [plt.Line2D([], [], marker='o', linestyle='', color=cmap(norm(v)), label=str(v))
                   for v in pts.obs.unique()]
        plt.legend(handles=handles)


In [None]:
# Major cell types on one section (batch 8, slice 2), rotated for display.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.batch==8, adata_combined.obs.slice==2)]
print(curr_adata.obs.age.unique())
# Colors listed in the same cell_type.unique() order used to build celltype_encoding, so code k gets color k.
celltype_cmap = mpl.colors.ListedColormap([celltype_colors[i] for i in adata_combined.obs.cell_type.unique()])
plt.figure(figsize=(3,5))
plot_obs_by_cells(curr_adata, 'celltype_id',s=1,vmax=adata_combined.obs.celltype_id.max(),
                  cmap=celltype_cmap,rot=-187)
plt.axis('off')
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_young_combined.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Major cell types on one section (batch 12, slice 1); vmax fixed to the global max code so colors match other sections.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.batch==12, adata_combined.obs.slice==1)]
plt.figure(figsize=(3,5))
print(curr_adata.obs.age.unique())

plot_obs_by_cells(curr_adata, 'celltype_id',s=1,vmax=adata_combined.obs.celltype_id.max(),
                  cmap=celltype_cmap, rot=-15)
plt.axis('off')
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_med_combined.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Major cell types on one section (batch 9, slice 2); vmax fixed to the global max code so colors match other sections.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.batch==9, adata_combined.obs.slice==2)]
plt.figure(figsize=(3,5))
print(curr_adata.obs.age.unique())

plot_obs_by_cells(curr_adata, 'celltype_id',s=1,vmax=adata_combined.obs.celltype_id.max(),
                  cmap=celltype_cmap, rot=35)
plt.axis('off')
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_old_combined.pdf",dpi=300,bbox_inches='tight')

In [None]:
from plotting import *
# Re-encode clusters by label_colors key order (replacing the first-appearance clust_encoding built earlier),
# so integer code k indexes color k of curr_cmap.
clust_encoding = {k:i for i,k in enumerate(label_colors.keys())}
#celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes_new(adata)
curr_cmap = mpl.colors.ListedColormap([label_colors[i] for i in label_colors.keys()])
adata_combined.obs['clust_encoding'] = [clust_encoding[i] for i in adata_combined.obs.clust_annot]

In [None]:
# Largest spatial-domain id across all cells; passed as vmax so domain colors agree between sections.
seg_max = adata_combined.obs.spatial_clust_annots_value.max()

In [None]:
# One color per spatial-domain id (8 colors; the domain label lists below also have 8 entries).
seg_cmap = mpl.colors.ListedColormap([ 'gold', 'orange', 'chocolate', 'brown', 'steelblue','gray',  'purple', 'darkkhaki'])

In [None]:
def plot_seg(A, cmap, ax=None, rot=0, s=0.1, xlim=None, ylim=None,key='spatial_clust_annots_value',vmax=7):
    """Draw a section's cells colored by spatial-domain id, rotated, then shifted so the min x and min y are 0.

    Parameters
    ----------
    A : AnnData
        Cells of one section; positions come from ``obsm['spatial']``.
    cmap : Colormap
        Colormap for the domain ids (color range fixed to [0, vmax]).
    ax : Axes, optional
        Target Axes; a new figure/Axes is created when omitted.
    rot : float
        Rotation in degrees passed to ``rotate`` (defined elsewhere).
    s : float
        Marker size.
    xlim, ylim : sequence of 2 floats, optional
        Crop window in the shifted coordinate frame.
    key : str
        obs column holding the value to color by.
    vmax : float
        Upper color limit.
    """
    if ax is None:
        _, ax = plt.subplots()
    # Copy first so the in-place shift below can never write back into obsm['spatial'].
    coords = rotate(A.obsm['spatial'].copy(), degrees=rot)
    # Shift both axes to start at 0 so a shared xlim/ylim crops comparable regions across sections.
    coords[:, :2] -= coords[:, :2].min(axis=0)
    ax.scatter(coords[:, 0], coords[:, 1], s=s, c=A.obs[key], cmap=cmap, vmin=0, vmax=vmax)
    ax.axis('off')
    for limits, apply_limits in ((xlim, ax.set_xlim), (ylim, ax.set_ylim)):
        if limits is not None:
            apply_limits(limits)


In [None]:

# Seven-panel view of the young section (batch 8, slice 1): spatial domains, then the clusters of six
# major cell-type groups.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.batch==8, adata_combined.obs.slice==1)]
aspect_ratio, nx, ny = calculate_aspect_ratio(curr_adata)
curr_rot = -183
curr_size = 3
xlim = [200, 2300]
ylim = [200, 4000]
# Panel aspect ratio follows the crop window rather than the whole section.
aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])
print(aspect_ratio, nx, ny)
# Cell-type groups shown in panels 2-7, in display order.
celltype_panels = [['ExN'], ['InN', 'MSN'], ['Olig', 'OPC'], ['Astro'],
                   ['Epen', 'Endo', 'Vlmc', 'Peri'], ['Micro', 'Macro', 'T cell', 'B cell']]
# plt.figure rather than plt.subplots. Recent Matplotlib no longer deletes the full-figure Axes that
# plt.subplots creates when plt.subplot(1, 7, k) is called, so its frame and ticks were drawn behind the panels.
f = plt.figure(figsize=(5*7*aspect_ratio,5))
ax = plt.subplot(1,7,1)
plot_seg(curr_adata, seg_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,vmax=seg_max)
for panel, panel_celltypes in enumerate(celltype_panels, start=2):
    ax = plt.subplot(1,7,panel)
    cell_types = adata_combined.obs.clust_annot[adata_combined.obs.cell_type.isin(panel_celltypes)].unique()
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax, xlim=xlim, ylim=ylim)
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_young.pdf",dpi=300,bbox_inches='tight')

In [None]:

# Seven-panel view of the med section (batch 12, slice 0): spatial domains, then the clusters of six
# major cell-type groups.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.batch==12, adata_combined.obs.slice==0)]
curr_rot = -12
# Set explicitly (same value as the young panel cell) instead of relying on it leaking from that cell.
curr_size = 3
aspect_ratio, nx, ny = calculate_aspect_ratio(curr_adata, rot=curr_rot)
print(aspect_ratio, nx, ny)
xlim = [1950, 1950+2100]
ylim = [200, 3700]
# Panel aspect ratio follows the crop window rather than the whole section.
aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])
# Cell-type groups shown in panels 2-7, in display order.
celltype_panels = [['ExN'], ['InN', 'MSN'], ['Olig', 'OPC'], ['Astro'],
                   ['Epen', 'Endo', 'Vlmc', 'Peri'], ['Micro', 'Macro', 'T cell', 'B cell']]
# plt.figure rather than plt.subplots. Recent Matplotlib no longer deletes the full-figure Axes that
# plt.subplots creates when plt.subplot(1, 7, k) is called, so its frame and ticks were drawn behind the panels.
f = plt.figure(figsize=(5*7*aspect_ratio,5))
ax = plt.subplot(1,7,1)
# vmax=seg_max, as in the young/old panels; the default vmax would shift domain colors if seg_max != 7.
plot_seg(curr_adata, seg_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,vmax=seg_max)
for panel, panel_celltypes in enumerate(celltype_panels, start=2):
    ax = plt.subplot(1,7,panel)
    cell_types = adata_combined.obs.clust_annot[adata_combined.obs.cell_type.isin(panel_celltypes)].unique()
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax, xlim=xlim, ylim=ylim)

plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_med.pdf",dpi=300,bbox_inches='tight')

In [None]:

# Seven-panel view of the old section (batch 9, slice 1): spatial domains, then the clusters of six
# major cell-type groups.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.batch==9, adata_combined.obs.slice==1)]
print(curr_adata.obs.age.unique())
curr_rot = 35
# Set explicitly (same value as the young panel cell) instead of relying on it leaking from that cell.
curr_size = 3
aspect_ratio, nx, ny = calculate_aspect_ratio(curr_adata, rot=curr_rot)
print(aspect_ratio, nx, ny)
xlim = [200, 2300]
ylim = [400, 4000]
# Panel aspect ratio follows the crop window rather than the whole section.
aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])
# Cell-type groups shown in panels 2-7, in display order.
celltype_panels = [['ExN'], ['InN', 'MSN'], ['Olig', 'OPC'], ['Astro'],
                   ['Epen', 'Endo', 'Vlmc', 'Peri'], ['Micro', 'Macro', 'T cell', 'B cell']]
# plt.figure rather than plt.subplots. Recent Matplotlib no longer deletes the full-figure Axes that
# plt.subplots creates when plt.subplot(1, 7, k) is called, so its frame and ticks were drawn behind the panels.
f = plt.figure(figsize=(5*7*aspect_ratio,5))
ax = plt.subplot(1,7,1)
# vmax=seg_max, as in the young panel, so domain colors are identical across sections.
plot_seg(curr_adata, seg_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim,vmax=seg_max)
for panel, panel_celltypes in enumerate(celltype_panels, start=2):
    ax = plt.subplot(1,7,panel)
    cell_types = adata_combined.obs.clust_annot[adata_combined.obs.cell_type.isin(panel_celltypes)].unique()
    plot_clust_subset(curr_adata, cell_types, curr_cmap, rot=curr_rot,s=curr_size, ax=ax, xlim=xlim, ylim=ylim)

plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_old.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Labels and display orders for the spatial-domain x cluster enrichment heatmaps below.
# spatial_domains: row labels, one per spatial_clust_annots_value id in id order (used as y tick labels).
spatial_domains = ['Pia','L2/3', 'L5', 'L6','LatSept', 'CC', 'Striatum','Ventricle']
# clust_order: heatmap column order over the fine-grained clust_annot labels (neurons first, then glia/vascular/immune).
clust_order = [
 'ExN-L2/3-1',
 'ExN-L2/3-2',
 'ExN-L5-1',
 'ExN-L5-2',
 'ExN-L5-3',
 'ExN-L6-1',
 'ExN-L6-2',
 'ExN-L6-3',
 'ExN-Olf',
 'InN-Olf-1',
 'InN-Olf-2',

 'InN-Vip',

 'InN-Lamp5',

 'InN-Pvalb-1',
 'InN-Pvalb-2',
 'InN-Pvalb-3',
 'InN-Sst-1',
 'InN-Sst-2',
 'InN-Calb2-1',
 'InN-Calb2-2',
 'InN-Chat',
 'InN-Lhx6',

'MSN-D1-1',
 'MSN-D1-2',
 'MSN-D2',
 'OPC',
 'Olig-1',
 'Olig-2',
 'Olig-3',

'Astro-1',
 'Astro-2',
 'Vlmc',
 'Peri-1',
 'Peri-2',
 'Endo-1',
 'Endo-2',
 'Endo-3',
 'Epen',

 'Micro-1',
 'Micro-2',
 'Micro-3',
 'Macro',
 'T cell',
]

# Coarse non-neuronal labels; NOTE: short_clust_order is reassigned further down before the difference heatmap.
short_clust_order = [
 'OPC',
 'Olig',

'Astro',
 'Vlmc',
 'Peri',
 'Endo',
 'Epen',

 'Micro',
 'Macro',
 'T cell',
]




In [None]:
# Coarse cluster label per cell, derived from clust_annot by splitting on "-":
#   no "-"      -> unchanged            ('Vlmc', 'T cell')
#   2 parts     -> first part only       ('Astro-1' -> 'Astro', but also 'InN-Vip' -> 'InN', 'MSN-D2' -> 'MSN')
#   3+ parts    -> first two parts       ('ExN-L2/3-1' -> 'ExN-L2/3', 'MSN-D1-1' -> 'MSN-D1')
# NOTE(review): the 2-part rule collapses named neuronal subtypes (InN-Vip, MSN-D2) but not 3-part ones
# (MSN-D1-1 keeps 'D1'); fine for the non-neuronal labels used below, confirm intent for neurons.
short_clust_annot = []
for i in list(adata_combined.obs.clust_annot):
    if "-" in i:
        split_name = i.split("-")
        if len(split_name) == 2:
            short_clust_annot.append(split_name[0])
        else:
            short_clust_annot.append(split_name[0] + "-" + split_name[1])
    else:
        short_clust_annot.append(i)
adata_combined.obs['short_clust_annot'] = short_clust_annot

In [None]:
# MERFISH cells only. `obs.dtype` resolves to the obs column named 'dtype' (a DataFrame has no .dtype attribute).
adata_merfish = adata_combined[adata_combined.obs.dtype=="merfish"]

In [None]:
def plot_clust_spatial_enrichment(A,vmin=0,vmax=1,uniq_clusts=None,clust_key='clust_annot',label_colors=None, spatial_domains=['Pia','L2/3', 'L5','L6', 'LatSept', 'CC', 'Striatum','Ventricle'],
    seg_cmap=plt.cm.viridis):
    """Heatmap of how each cluster's cells are distributed across spatial domains.

    Parameters
    ----------
    A : AnnData
        Cells to tabulate. Reads ``A.obs['spatial_clust_annots_value']`` (domain id 0..n-1) and ``A.obs[clust_key]``.
    vmin, vmax : float
        Color limits of the heatmap (values are fractions).
    uniq_clusts : list, optional
        Clusters (heatmap columns) in display order; defaults to the sorted unique labels of ``A.obs[clust_key]``.
    clust_key : str
        obs column holding the cluster labels.
    label_colors : dict, optional
        Cluster -> color for the strip under the heatmap; viridis when omitted.
    spatial_domains : list of str
        Row tick labels, one per domain id in id order.
    seg_cmap : Colormap
        Colormap for the domain strip left of the heatmap.

    Returns
    -------
    clust_avgs : ndarray, shape (n_domains, n_clusters)
        Fraction of each cluster's cells that fall in each domain (columns sum to 1; NaN for clusters with no cells).
    clust_counts : ndarray, shape (n_domains, n_clusters)
        Raw cell counts per (domain, cluster).
    """
    if uniq_clusts is None:
        uniq_clusts = sorted(A.obs[clust_key].unique())
    n_spatial_domains = int(A.obs.spatial_clust_annots_value.max() + 1)
    clust_counts = np.zeros((n_spatial_domains, len(uniq_clusts)))
    print(clust_counts.shape)
    # Count from plain obs arrays. Slicing the AnnData per domain (A[mask, :]) built a full view each iteration
    # just to read one obs column.
    domain_ids = A.obs['spatial_clust_annots_value'].to_numpy()
    labels = A.obs[clust_key].to_numpy()
    for i in range(n_spatial_domains):
        domain_labels = labels[domain_ids == i]
        for j, c in enumerate(uniq_clusts):
            clust_counts[i, j] = np.sum(domain_labels == c)
    # Column-normalize. Clusters with no cells give 0/0 = NaN as before; only the RuntimeWarning is silenced.
    with np.errstate(invalid='ignore', divide='ignore'):
        clust_avgs = clust_counts / clust_counts.sum(axis=0, keepdims=True)

    # plt.figure rather than plt.subplots. The full-figure Axes from plt.subplots is not removed by the GridSpec
    # subplots below in recent Matplotlib, so it would be drawn behind them.
    f = plt.figure(figsize=(5.5,1))
    gs = plt.GridSpec(nrows=2,ncols=2,width_ratios=[0.36, 20], height_ratios=[20,2], wspace=0.01, hspace=0.05)

    # Left strip: one colored cell per spatial domain, labelled with the domain names.
    ax = plt.subplot(gs[0,0])
    ax.imshow(np.expand_dims(np.arange(n_spatial_domains),1),aspect='auto',interpolation='none', cmap=seg_cmap,rasterized=True)
    sns.despine(ax=ax,bottom=True,left=True)
    ax.set_yticks(np.arange(clust_avgs.shape[0]));
    ax.set_yticklabels(spatial_domains,fontsize=6)
    ax.set_xticks([])
    # Main heatmap: rows = domains, columns = clusters.
    ax = plt.subplot(gs[0,1])
    ax.imshow(clust_avgs,aspect='auto',vmin=vmin,vmax=vmax, cmap=plt.cm.viridis)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    # Bottom strip: one colored cell per cluster, labelled with the cluster names.
    ax = plt.subplot(gs[1,1])
    if label_colors is None:
        curr_cmap = plt.cm.viridis
    else:
        curr_cmap = mpl.colors.ListedColormap([label_colors[i] for i in uniq_clusts])
    ax.imshow(np.expand_dims(np.arange(len(uniq_clusts)),1).T,aspect='auto',interpolation='none', cmap=curr_cmap,rasterized=True)

    ax.set_xticks(np.arange(clust_avgs.shape[1]));
    ax.set_yticks([])
    ax.set_xticklabels(uniq_clusts,rotation=90,fontsize=6);
    sns.despine(ax=ax, left=True, bottom=True)
    return clust_avgs, clust_counts


In [None]:
# Domain x cluster composition heatmap for 4wk MERFISH cells (column fractions, over clust_order).
young_clusts, young_counts = plot_clust_spatial_enrichment(adata_merfish[adata_merfish.obs.age=='4wk'],vmax=1,uniq_clusts=clust_order,seg_cmap=seg_cmap,label_colors=label_colors)
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_cellcomp_young.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Same heatmap for 24wk cells; the returned arrays are not kept.
plot_clust_spatial_enrichment(adata_merfish[adata_merfish.obs.age=='24wk'],vmax=1,uniq_clusts=clust_order,seg_cmap=seg_cmap,label_colors=label_colors);
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_cellcomp_med.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Same heatmap for 90wk cells.
old_clusts, old_counts = plot_clust_spatial_enrichment(adata_merfish[adata_merfish.obs.age=='90wk'],vmax=1,uniq_clusts=clust_order, seg_cmap=seg_cmap,label_colors=label_colors);
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_cellcomp_old.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Coarse non-neuronal labels (values of obs['short_clust_annot']), in the column order of the
# enrichment-difference heatmap below.
# (A fine-grained per-cluster list was previously assigned here first and immediately overwritten by this one;
# that dead assignment has been removed.)
short_clust_order = [
    'Astro',
    'Endo',
    "Peri",
    'Vlmc',
    'Epen',
    'OPC',
    'Olig',
    'Micro',
    'Macro',
    'T cell',
]

In [None]:
sns.set_style('white')
# Old-minus-young difference in per-domain composition for the coarse non-neuronal labels in short_clust_order.
# young_avgs / old_avgs were not defined anywhere in the notebook: the enrichment cells above bind young_clusts /
# old_clusts, and those are over the fine-grained clust_order. So this cell only ran on a kernel with leftover
# state. They are recomputed here.
# NOTE(review): assumes the originals were column-normalized domain fractions of obs['short_clust_annot'] (the
# normalization plot_clust_spatial_enrichment uses), which matches the 10 short_clust_order x labels — confirm.
n_spatial_domains = int(adata_merfish.obs.spatial_clust_annots_value.max() + 1)

def _domain_fractions(A, clusts, clust_key='short_clust_annot'):
    """Fraction of each cluster's cells found in each spatial domain; shape (n_spatial_domains, len(clusts))."""
    domain_ids = A.obs['spatial_clust_annots_value'].to_numpy()
    labels = A.obs[clust_key].to_numpy()
    counts = np.array([[np.sum(labels[domain_ids == d] == c) for c in clusts]
                       for d in range(n_spatial_domains)], dtype=float)
    with np.errstate(invalid='ignore', divide='ignore'):
        return counts / counts.sum(axis=0, keepdims=True)

young_avgs = _domain_fractions(adata_merfish[adata_merfish.obs.age=='4wk'], short_clust_order)
old_avgs = _domain_fractions(adata_merfish[adata_merfish.obs.age=='90wk'], short_clust_order)
diff = old_avgs-young_avgs
# Kept from the original; a no-op for finite fractions.
diff[np.isinf(diff)] = 5
# plt.figure rather than plt.subplots, so no stray full-figure Axes is drawn behind the GridSpec panel.
f = plt.figure(figsize=(3,2))
gs = plt.GridSpec(nrows=2,ncols=2,width_ratios=[0.36, 20], height_ratios=[20,2], wspace=0.01, hspace=0.05)
uniq_clusts = short_clust_order

# Heatmap: rows = spatial domains, columns = coarse clusters; red = enriched with age, blue = depleted.
ax = plt.subplot(gs[0,1])
ax.imshow(diff,cmap=plt.cm.seismic, vmin=-0.25, vmax=0.25,rasterized=True, aspect='auto',interpolation='none')
ax.set_yticks(np.arange(diff.shape[0]));
ax.set_yticklabels(spatial_domains,fontsize=6)
ax.set_xticks(np.arange(diff.shape[1]));
ax.set_xticklabels(uniq_clusts,rotation=90,fontsize=6);
sns.despine(ax=ax, left=True, bottom=True)

f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_enrichment_diff.pdf",bbox_inches='tight', dpi=200)

In [None]:
# Olig clusters on one representative section per age: rows = young / med / old section, columns = cluster.
age_specific_celltypes = ["Olig-1", "Olig-2","Olig-3"]
figsize_height = 5
# NOTE: aspect_ratio is whatever the last seven-panel cell left behind. matplotlib's figsize is (width, height),
# so figsize_height is used as the width and figsize_width as the height below (names swapped, values intended).
figsize_width = figsize_height*3*1.1*aspect_ratio
f, axes = plt.subplots(figsize=(figsize_height,figsize_width), nrows=3, ncols=3, gridspec_kw={'wspace':0.1, 'hspace':0.1})
curr_size = 0.5

# (section, rotation, xlim, ylim) per row, subset once instead of once per cluster (9 -> 3 AnnData slices).
example_sections = [
    # young
    (adata_combined[np.logical_and(adata_combined.obs.batch==8, adata_combined.obs.slice==1)], -183, [200, 2300], [200, 4000]),
    # med
    (adata_combined[np.logical_and(adata_combined.obs.batch==12, adata_combined.obs.slice==0)], -12, [1950, 1950+2100], [200, 3700]),
    # old
    (adata_combined[np.logical_and(adata_combined.obs.batch==9, adata_combined.obs.slice==1)], 35, [200, 2300], [400, 4000]),
]
for k, c in enumerate(age_specific_celltypes):
    for row, (curr_adata, curr_rot, xlim, ylim) in enumerate(example_sections):
        ax = axes[row, k]
        # NOTE(review): a single cluster name (str) is passed, whereas other callers pass an array of names —
        # confirm plot_clust_subset accepts both.
        plot_clust_subset(curr_adata, c, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim)

#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_examples_olig.png",dpi=300,bbox_inches='tight')

In [None]:
# show examples of cell types that are specific to one age
# One column per cell type; rows are the young / mid / old example sections.
# NOTE(review): `aspect_ratio` and `curr_cmap` are not defined in this cell and must
# already exist in the kernel.
age_specific_celltypes = ['Micro-1','Micro-2','Micro-3']
fig_width = 5
# Three rows of crop windows that are taller (3500-3800) than wide (2100).
fig_height = fig_width*3*1.1*aspect_ratio
f, axes = plt.subplots(figsize=(fig_width, fig_height), nrows=3, ncols=3, gridspec_kw={'wspace':0.1, 'hspace':0.1})
curr_size = 1

# (batch, slice, rotation, xlim, ylim) of the young / mid / old example sections.
example_sections = [
    (8, 1, -183, [200, 2300], [200, 4000]),
    (12, 0, -12, [1950, 1950+2100], [200, 3700]),
    (9, 1, 35, [200, 2300], [400, 4000]),
]
# Subset each section once instead of once per cell type.
section_adatas = [
    adata_combined[np.logical_and(adata_combined.obs.batch==sec_batch, adata_combined.obs.slice==sec_slice)]
    for sec_batch, sec_slice, _, _, _ in example_sections
]
# Sanity check: age label of the young section. Use .iloc: the obs index holds cell IDs,
# so [0] would be a (deprecated) positional fallback.
print(section_adatas[0].obs.age.iloc[0])

for k, c in enumerate(age_specific_celltypes):
    for row, (curr_adata, (_, _, curr_rot, xlim, ylim)) in enumerate(zip(section_adatas, example_sections)):
        ax = axes[row, k]
        plot_clust_subset(curr_adata, c, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim)

f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_examples_micro.png",dpi=300,bbox_inches='tight')

In [None]:
# show examples of cell types that are specific to one age
# One column per cell type (two types, so the third column is hidden); rows are the
# young / mid / old example sections.
# NOTE(review): `aspect_ratio` and `curr_cmap` are not defined in this cell and must
# already exist in the kernel.
age_specific_celltypes = ['Astro-1','Astro-2']
fig_width = 5
# Three rows of crop windows that are taller (3500-3800) than wide (2100).
fig_height = fig_width*3*1.1*aspect_ratio
f, axes = plt.subplots(figsize=(fig_width, fig_height), nrows=3, ncols=3, gridspec_kw={'wspace':0.1, 'hspace':0.1})
curr_size = 1

# (batch, slice, rotation, xlim, ylim) of the young / mid / old example sections.
example_sections = [
    (8, 1, -183, [200, 2300], [200, 4000]),
    (12, 0, -12, [1950, 1950+2100], [200, 3700]),
    (9, 1, 35, [200, 2300], [400, 4000]),
]
# Subset each section once instead of once per cell type.
section_adatas = [
    adata_combined[np.logical_and(adata_combined.obs.batch==sec_batch, adata_combined.obs.slice==sec_slice)]
    for sec_batch, sec_slice, _, _, _ in example_sections
]

for k, c in enumerate(age_specific_celltypes):
    for row, (curr_adata, (_, _, curr_rot, xlim, ylim)) in enumerate(zip(section_adatas, example_sections)):
        ax = axes[row, k]
        plot_clust_subset(curr_adata, c, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim)
# Only two cell types: blank out the unused third column.
for i in range(3):
    axes[i,2].axis('off')
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_examples_astro.png",dpi=300,bbox_inches='tight')

In [None]:
# show examples of cell types that are specific to one age
# One column per cell type; rows are the young / mid / old example sections.
# NOTE(review): `aspect_ratio` and `curr_cmap` are not defined in this cell and must
# already exist in the kernel.
age_specific_celltypes = ['Endo-1','Endo-2','Endo-3']
fig_width = 5
# Three rows of crop windows that are taller (3500-3800) than wide (2100).
fig_height = fig_width*3*1.1*aspect_ratio
f, axes = plt.subplots(figsize=(fig_width, fig_height), nrows=3, ncols=3, gridspec_kw={'wspace':0.1, 'hspace':0.1})
curr_size = 1

# (batch, slice, rotation, xlim, ylim) of the young / mid / old example sections.
example_sections = [
    (8, 1, -183, [200, 2300], [200, 4000]),
    (12, 0, -12, [1950, 1950+2100], [200, 3700]),
    (9, 1, 35, [200, 2300], [400, 4000]),
]
# Subset each section once instead of once per cell type.
section_adatas = [
    adata_combined[np.logical_and(adata_combined.obs.batch==sec_batch, adata_combined.obs.slice==sec_slice)]
    for sec_batch, sec_slice, _, _, _ in example_sections
]
# Sanity check: age label of the young section. Use .iloc: the obs index holds cell IDs,
# so [0] would be a (deprecated) positional fallback.
print(section_adatas[0].obs.age.iloc[0])

for k, c in enumerate(age_specific_celltypes):
    for row, (curr_adata, (_, _, curr_rot, xlim, ylim)) in enumerate(zip(section_adatas, example_sections)):
        ax = axes[row, k]
        plot_clust_subset(curr_adata, c, curr_cmap, rot=curr_rot,s=curr_size, ax=ax,xlim=xlim, ylim=ylim)

f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_majorcelltypes_examples_endo.png",dpi=300,bbox_inches='tight')

# Show marker genes for each cluster (Fig. 1)

In [None]:
# Marker genes for each endothelial cluster. `.copy()` gives an independent AnnData:
# rank_genes_groups / filter_rank_genes_groups write into .uns, and writing into a view
# makes AnnData implicitly convert it (with an ImplicitModificationWarning).
endo_cells = adata_combined_merfish[adata_combined_merfish.obs.cell_type=="Endo"].copy()
sc.tl.rank_genes_groups(endo_cells,groupby='clust_annot')
sc.tl.filter_rank_genes_groups(endo_cells,min_fold_change=1.5,min_in_group_fraction=0.3)
sc.pl.rank_genes_groups_stacked_violin(endo_cells, groupby='clust_annot')


In [None]:
# Stacked violin of selected endothelial genes across endothelial clusters (Fig. 1 panel).
endo_marker_genes = ['Cldn5', 'Sparc', 'Xdh']
endo_cmap = sns.light_palette("darkkhaki", as_cmap=True)
f, ax = plt.subplots()
sc.pl.stacked_violin(endo_cells, var_names=endo_marker_genes, groupby='clust_annot',
                     cmap=endo_cmap, ax=ax)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_endo_cluster_expr.pdf",
          bbox_inches='tight')

In [None]:
# Marker genes for each microglia cluster, after regressing out total counts.
# `.copy()` gives an independent AnnData: regress_out rewrites .X and rank_genes_groups
# writes .uns, and writing into a view makes AnnData implicitly convert it (with an
# ImplicitModificationWarning).
micro_cells = adata_combined_merfish[adata_combined_merfish.obs.cell_type=="Micro"].copy()
sc.pp.regress_out(micro_cells,'total_counts')
sc.tl.rank_genes_groups(micro_cells,groupby='clust_annot')
sc.tl.filter_rank_genes_groups(micro_cells,min_fold_change=1,min_in_group_fraction=0.15)
sc.pl.rank_genes_groups_stacked_violin(micro_cells, groupby='clust_annot')


In [None]:
# Default scanpy ranking plot of the rank_genes_groups results stored on micro_cells
# (default key 'rank_genes_groups', i.e. not the filtered results).
sc.pl.rank_genes_groups(micro_cells)

In [None]:
# Stacked violin of selected microglia genes across microglia clusters (Fig. 1 panel).
micro_marker_genes = ['Selplg', 'Zfhx3', 'B2m']
micro_cmap = sns.light_palette("deeppink", as_cmap=True)
f, ax = plt.subplots()
sc.pl.stacked_violin(micro_cells, var_names=micro_marker_genes, groupby='clust_annot',
                     cmap=micro_cmap, ax=ax)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_microg_cluster_expr_V2.pdf",
          bbox_inches='tight')

In [None]:
# Marker genes for each oligodendrocyte cluster. `.copy()` gives an independent AnnData:
# rank_genes_groups / filter_rank_genes_groups write into .uns, and writing into a view
# makes AnnData implicitly convert it (with an ImplicitModificationWarning).
olig_cells = adata_combined_merfish[adata_combined_merfish.obs.cell_type=="Olig"].copy()

sc.tl.rank_genes_groups(olig_cells,groupby='clust_annot')
sc.tl.filter_rank_genes_groups(olig_cells,min_fold_change=2,min_in_group_fraction=0.3)
sc.pl.rank_genes_groups_stacked_violin(olig_cells, groupby='clust_annot')

In [None]:
# Stacked violin of selected oligodendrocyte genes across oligodendrocyte clusters (Fig. 1 panel).
olig_marker_genes = ['Olig1', 'Neat1', 'Il33']
f, ax = plt.subplots()
sc.pl.stacked_violin(olig_cells, var_names=olig_marker_genes, groupby='clust_annot',
                     cmap=plt.cm.Greys, ax=ax)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_olig_cluster_expr.pdf",
          bbox_inches='tight')

In [None]:
# Marker genes for each astrocyte cluster. `.copy()` gives an independent AnnData:
# rank_genes_groups / filter_rank_genes_groups write into .uns, and writing into a view
# makes AnnData implicitly convert it (with an ImplicitModificationWarning).
astro_cells = adata_combined_merfish[adata_combined_merfish.obs.cell_type=="Astro"].copy()
sc.tl.rank_genes_groups(astro_cells,groupby='clust_annot')
sc.tl.filter_rank_genes_groups(astro_cells,min_fold_change=2,min_in_group_fraction=0.3)
sc.pl.rank_genes_groups_stacked_violin(astro_cells, groupby='clust_annot')


In [None]:
# Stacked violin of selected astrocyte genes across astrocyte clusters (Fig. 1 panel).
astro_marker_genes = ['Mfge8', 'Gfap', 'C4b']
astro_cmap = sns.light_palette("seagreen", as_cmap=True)
f, ax = plt.subplots()
sc.pl.stacked_violin(astro_cells, var_names=astro_marker_genes, groupby='clust_annot',
                     cmap=astro_cmap, ax=ax)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig1_astro_cluster_expr.pdf",
          bbox_inches='tight')

In [None]:
# Same astrocyte violin, with each gene scaled to [0, 1]. scanpy's stacked_violin accepts
# standard_scale='var' (per gene) or 'group'; any other value (e.g. 'col') is ignored
# with a warning, which would leave the plot unscaled.
sc.pl.stacked_violin(astro_cells, var_names=['Mfge8','Gfap','C4b'], groupby='clust_annot',cmap=sns.light_palette("seagreen", as_cmap=True),standard_scale='var')


In [None]:
# Example section per age group for the spatial plots below: the MERFISH batch/slice
# that selects the section, the rotation passed to the plotting helpers, and the
# x/y crop window.
plot_info = dict(
    young=dict(batch=8, slice=1, rot=-183, xlim=[200, 2300], ylim=[200, 4000]),
    mid=dict(batch=12, slice=0, rot=-12, xlim=[1950, 1950 + 2100], ylim=[200, 3700]),
    old=dict(batch=9, slice=1, rot=35, xlim=[200, 2300], ylim=[400, 4000]),
)


In [None]:
# Keep only MERFISH cells (obs column "dtype"; bracket access avoids confusion with the
# DataFrame `dtypes` attribute). `.copy()` gives an independent AnnData: later cells write
# into it in place (score_genes, obs column edits, unbinarize_strings), and writing into
# a view makes AnnData implicitly convert it.
adata_combined_merfish = adata_combined[adata_combined.obs["dtype"]=="merfish"].copy()


In [None]:
def identify_nearest_neighbors_with_idx(X,Y,dist_thresh, min_dist_thresh=15):
    """Find all (Y, X) point pairs whose distance lies in a band.

    Parameters
    ----------
    X, Y : array-like of shape (n_points, n_dims)
        Query points and reference points; a KDTree is built on ``Y``.
    dist_thresh : float
        Search radius passed to ``KDTree.query_radius``.
    min_dist_thresh : float, default 15
        Pairs whose distance is not strictly greater than this are dropped.

    Returns
    -------
    ind_Y, ind_X : np.ndarray of int
        Parallel 1-D arrays of positional indices: pair ``m`` links ``Y[ind_Y[m]]``
        with ``X[ind_X[m]]``. Both are empty when X or Y has no rows, or when no
        pair falls in the band.
    """
    empty = np.array([], dtype=int)
    if X.shape[0] == 0 or Y.shape[0] == 0:
        # Always return two arrays: callers unpack the result.
        return empty, empty.copy()
    kdtree = KDTree(Y)
    ind, dists = kdtree.query_radius(X, r=dist_thresh, count_only=False, return_distance=True)
    # Row in X of every neighbour found. np.repeat also copes with "no neighbours at all",
    # where np.hstack on an empty list would raise.
    ind_X = np.repeat(np.arange(len(ind)), [len(neigh) for neigh in ind])
    ind = np.concatenate(ind).astype(int)
    dists = np.concatenate(dists)
    keep = dists > min_dist_thresh
    # `int` instead of the removed `np.int` alias (NumPy >= 1.24).
    return ind[keep], ind_X[keep].astype(int)

def identify_nearest_neighbors_with_dist(X,Y):
    """Distance from every point of X to its nearest point of Y.

    Parameters
    ----------
    X, Y : array-like of shape (n_points, n_dims)
        Query points and reference points; a KDTree is built on ``Y``.

    Returns
    -------
    dists, ind : np.ndarray of shape (len(X), 1)
        As returned by ``KDTree.query(X, k=1)``. If Y has no rows there is no
        neighbour, so every distance is +inf and every index is -1 (a sentinel, not a
        valid index). If X has no rows, both arrays have shape (0, 1).
    """
    n_X = X.shape[0]
    if n_X == 0 or Y.shape[0] == 0:
        # Two arrays (callers unpack the result) with one row per X point.
        return np.full((n_X, 1), np.inf), np.full((n_X, 1), -1, dtype=int)
    kdtree = KDTree(Y)
    dists, ind = kdtree.query(X, k=1, return_distance=True)
    return dists, ind

def compute_celltype_obs_distance_correlation(A,cell_type_X, cell_type_Y, obs_key_X, celltype_key='cell_type'):
    """Pair an obs value of each `cell_type_X` cell with its distance to the nearest `cell_type_Y` cell.

    Parameters
    ----------
    A : AnnData
        Cells with positions in ``obsm['spatial']``.
    cell_type_X, cell_type_Y : str
        Labels in ``A.obs[celltype_key]`` selecting the query and reference cells.
    obs_key_X : str
        obs column reported for the query cells.
    celltype_key : str, default 'cell_type'
        obs column holding the cell-type labels.

    Returns
    -------
    (values_X, dists)
        ``values_X`` is ``obs[obs_key_X]`` of the query cells as a numpy array;
        ``dists`` is the distance output of ``identify_nearest_neighbors_with_dist``.
    """
    labels = A.obs[celltype_key]
    cells_X = A[labels == cell_type_X]
    cells_Y = A[labels == cell_type_Y]
    values_X = cells_X.obs[obs_key_X]
    nearest_dists, _nearest_idx = identify_nearest_neighbors_with_dist(
        cells_X.obsm['spatial'], cells_Y.obsm['spatial'])
    return values_X.values, nearest_dists

def compute_celltype_obs_correlation(A,cell_type_X, cell_type_Y, obs_key_X, obs_key_Y, celltype_key='cell_type', radius=40, min_dist_thresh=15):
    """Pair obs values of neighbouring `cell_type_X` / `cell_type_Y` cells.

    Neighbour pairs come from ``identify_nearest_neighbors_with_idx`` on
    ``obsm['spatial']``, with search radius ``radius`` and pairs closer than
    ``min_dist_thresh`` dropped. An X cell contributes one entry per Y neighbour.

    Returns
    -------
    (values_X, values_Y) : np.ndarray
        Parallel arrays: element ``m`` holds ``obs[obs_key_X]`` of an X cell and
        ``obs[obs_key_Y]`` of one of its Y neighbours.
    """
    X = A[A.obs[celltype_key] == cell_type_X]
    Y = A[A.obs[celltype_key] == cell_type_Y]
    obs_X = X.obs[obs_key_X]
    obs_Y = Y.obs[obs_key_Y]
    curr_X = X.obsm['spatial']
    curr_Y = Y.obsm['spatial']
    # Positional indices into Y (the neighbour) and X (the query cell) of each pair.
    neighbors_Y, ind_X = identify_nearest_neighbors_with_idx(curr_X, curr_Y, dist_thresh=radius, min_dist_thresh=min_dist_thresh)
    # KDTree returns positions, not obs-name labels: use .iloc. Plain [] on a
    # cell-ID (string) index only works through pandas' deprecated positional fallback.
    curr_expr = obs_Y.iloc[neighbors_Y]
    return obs_X.values[ind_X], curr_expr.values

In [None]:

# Spatial maps of each cell type's activation score: one column per cell type, one row
# per example section (young / mid / old). Colour limits are the 5th-95th percentiles of
# the score over all cells of that type.
# NOTE(review): the activate_* scores are computed by score_genes further down the
# notebook, so this cell needs them to exist already.
k = 1
plt.figure(figsize=(5, 6))
gs = plt.GridSpec(nrows=3, ncols=4, hspace=0.01, wspace=0.1)
for i, ct in enumerate(["Astro", "Micro", "Olig"]):
    score_name = "activate_" + ct.lower()
    is_ct = adata_combined_merfish.obs.cell_type == ct
    age_score = adata_combined_merfish[is_ct].obs[score_name]
    age_score = age_score[~np.isnan(age_score)]
    curr_vmin = np.quantile(age_score, 0.05)
    curr_vmax = np.quantile(age_score, 0.95)
    for j, age in enumerate(['young', 'mid', 'old']):
        info = plot_info[age]
        in_section = np.logical_and(adata_combined_merfish.obs.batch == info['batch'],
                                    adata_combined_merfish.obs.slice == info['slice'])
        curr_adata = adata_combined_merfish[in_section]
        ax = plt.subplot(gs[j, i])
        if j == 0:
            ax.set_title(ct)
        curr_adata_celltype = curr_adata[curr_adata.obs.cell_type == ct]
        plot_obs(curr_adata, curr_adata_celltype.obs.clust_annot.unique(), score_name, plt.cm.rainbow,
                 s=0.5, alpha=0.1, rot=info['rot'], vmin=curr_vmin, vmax=curr_vmax, ax=ax,
                 xlim=info['xlim'], ylim=info['ylim'])
        if i == 0:
            ax.set_ylabel(age)

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig2_activation_score_spatial.pdf",bbox_inches='tight',dpi=300)

In [None]:
# NOTE(review): prints the ages of whatever `curr_adata` the previous cell left behind.
print(curr_adata.obs.age.unique())

# Spatial maps of age_score: one column per cell type, one row per example section.
# Colour limits are the 5th-95th percentiles of age_score over all cells of that type.
k = 1
plt.figure(figsize=(5, 6))
gs = plt.GridSpec(nrows=3, ncols=4, hspace=0.01, wspace=0.1)
for i, ct in enumerate(["Astro", "Micro", "Endo", "Olig"]):
    is_ct = adata_combined.obs.cell_type == ct
    age_score = adata_combined[is_ct].obs.age_score
    age_score = age_score[~np.isnan(age_score)]
    curr_vmin = np.quantile(age_score, 0.05)
    curr_vmax = np.quantile(age_score, 0.95)
    for j, age in enumerate(['young', 'mid', 'old']):
        info = plot_info[age]
        in_section = np.logical_and(adata_combined.obs.batch == info['batch'],
                                    adata_combined.obs.slice == info['slice'])
        curr_adata = adata_combined[in_section]
        ax = plt.subplot(gs[j, i])
        if j == 0:
            ax.set_title(ct)
        curr_adata_celltype = curr_adata[curr_adata.obs.cell_type == ct]
        plot_obs(curr_adata, curr_adata_celltype.obs.clust_annot.unique(), "age_score", plt.cm.turbo,
                 s=0.1, alpha=0.1, rot=info['rot'], vmin=curr_vmin, vmax=curr_vmax, ax=ax,
                 xlim=info['xlim'], ylim=info['ylim'])
        if i == 0:
            ax.set_ylabel(age)

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig2_age_score_spatial.pdf",bbox_inches='tight',dpi=300)

In [None]:
# plot age score by anatomical region
# age_score_spatial[i, j, k] = mean age_score of cell type i, age j, spatial region k
# (NaN when that combination has no cells).
spatial_regions = ['Pia', 'L2/3','L5', 'L6', 'CC','LatSept',  'Striatum',  'Ventricle']
celltypes = ["ExN","InN", "Astro","Micro","Endo","Olig"]
age_labels = ['4wk','24wk','90wk']
age_score_spatial = np.zeros((len(celltypes), len(age_labels), len(spatial_regions)))
for i,ct in enumerate(celltypes):
    print(i)
    for j,age in enumerate(age_labels):
        # Subset once per (cell type, age); the region loop only filters this subset.
        curr_adata = adata_combined[np.logical_and(adata_combined.obs.cell_type==ct, adata_combined.obs.age==age)]
        for k,r in enumerate(spatial_regions):
            age_score_spatial[i,j,k] = np.mean(curr_adata[curr_adata.obs.spatial_clust_annots==r].obs.age_score)


In [None]:

# Heatmap of mean age score (rows: age, columns: spatial region) for each of the first
# four entries of `celltypes`.
# NOTE(review): range(4) covers ExN, InN, Astro and Micro only, and curr_vmin/curr_vmax
# are computed but never passed to imshow — confirm both are intended.
for i in range(4):
    sns.set_style('white')
    ct_name = celltypes[i]
    age_score = adata_combined[adata_combined.obs.cell_type == ct_name].obs.age_score
    age_score = age_score[~np.isnan(age_score)]
    curr_vmin = np.quantile(age_score, 0.02)
    curr_vmax = np.quantile(age_score, 0.98)

    f, ax = plt.subplots()
    ax.imshow(age_score_spatial[i, :, :], cmap=age_score_cmap, rasterized=True)
    ax.set_xticks(np.arange(len(spatial_regions)))
    ax.set_xticklabels(spatial_regions, rotation=90)
    ax.set_yticks([0, 1, 2])
    ax.set_yticklabels(['Young', "Mid", "Old"])
    sns.despine(ax=ax, left=True, bottom=True)
    #plt.savefig(f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig2_age_score_spatial_quant_{celltypes[i]}.pdf",bbox_inches='tight',dpi=300)

# Show differentially expressed genes

In [None]:
age_specific_genes = ['C4b','Il33','Gfap', 'Xdh']#'Nr6a1','Xdh','Ifit3', 'Il33', 'C4b', 'Fmo2', 'Xdh', 'Cdkn1a', 'C3', 'Serpina3n','Sparc', 'Sncg'] 
#age_specific_genes = sorted(list(np.unique([i for i in age_specific_genes if i in adata_combined.var_names])))
# Spatial expression of example age-associated genes: one column per gene, one row per
# example section (young / mid / old).
celltypes = adata_combined_merfish.obs.clust_annot.unique()
diffs = []
aspect_ratio = (2500-200)/(4000-200)
vmax = 5
n_genes = len(age_specific_genes)
# (batch, slice, rotation, xlim, ylim) for each row.
gene_sections = [
    (8, 1, -183, [200, 2500], [200, 4000]),        # young
    (12, 0, -15, [1950, 1950+2100], [200, 3600]),  # mid
    (9, 1, 35, [200, 2300], [400, 4000]),          # old
]
# Create the whole grid with plt.subplots. plt.subplots() followed by plt.subplot(...)
# leaves the initial full-figure Axes behind the panels: automatic removal of
# overlapping Axes was deprecated in Matplotlib 3.6.
f, axes = plt.subplots(nrows=3, ncols=n_genes, figsize=(2.5*aspect_ratio*n_genes, 2.5*3), squeeze=False)

k = 1
for row, (sec_batch, sec_slice, curr_rot, xlim, ylim) in enumerate(gene_sections):
    curr_adata = adata_combined_merfish[np.logical_and(adata_combined_merfish.obs.batch==sec_batch, adata_combined_merfish.obs.slice==sec_slice)]
    if row == 1:
        # Confirm the mid section's age label.
        print(curr_adata.obs.age.unique())
    aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])
    for col, c in enumerate(age_specific_genes):
        ax = axes[row, col]
        if row == 0:
            ax.set_title(c)
        plot_gene_expr(curr_adata, celltypes, c, plt.cm.Reds, s=0.25,vmin=0,vmax=vmax, rot=curr_rot, ax=ax, xlim=xlim, ylim=ylim)
        k += 1
plt.tight_layout()
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_aging_expr_diff.pdf", bbox_inches='tight',dpi=300)

In [None]:
# Expression of age_specific_genes per non-neuronal cell type in young (4wk) animals.
non_neuronal_types = ['Astro', 'Endo', 'Micro', 'Olig']
young_non_neuronal = adata_combined_merfish[np.logical_and(
    adata_combined_merfish.obs.age == '4wk',
    adata_combined_merfish.obs.cell_type.isin(non_neuronal_types))]
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.stacked_violin(young_non_neuronal, var_names=age_specific_genes, groupby='cell_type', ax=ax)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/extrafig_example_expr_violinplot_young.pdf")

In [None]:
# Expression of age_specific_genes per non-neuronal cell type in old (90wk) animals.
non_neuronal_types = ['Astro', 'Endo', 'Micro', 'Olig']
old_non_neuronal = adata_combined_merfish[np.logical_and(
    adata_combined_merfish.obs.age == '90wk',
    adata_combined_merfish.obs.cell_type.isin(non_neuronal_types))]
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.stacked_violin(old_non_neuronal, var_names=age_specific_genes, groupby='cell_type', ax=ax)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/extrafig_example_expr_violinplot_old.pdf")

In [None]:
# unbinarize_strings calls .decode('ascii') on every var name, so it fails on names that
# are already str. Only convert while the names are still bytes; this keeps the cell
# safe to re-run.
if len(adata_combined_merfish.var_names) > 0 and isinstance(adata_combined_merfish.var_names[0], bytes):
    adata_combined_merfish = unbinarize_strings(adata_combined_merfish)

In [None]:
# NOTE(review): this cell defines none of its inputs (curr_adata, celltypes, c, vmin,
# vmax, curr_rot, ax, xlim, ylim); it re-plots with whatever values the last-run cell
# left in the kernel, so its output depends on execution order.
plot_gene_expr(curr_adata, celltypes, c, plt.cm.Reds, s=1,alpha=0.5, vmin=vmin,vmax=vmax, rot=curr_rot, ax=ax, xlim=xlim, ylim=ylim)


In [None]:
# plot area specific genes
# One panel per layer/area marker gene on the old example section; each gene's colour
# scale is clipped to its own 1st-99th percentile.
area_genes = [
    'Otof',
    'Cux2',
    'Rorb',
    'Rspo1',
    'Scube1',
    'Fezf2',
    'Syt6',
    'Drd1',
    'Drd2',
]
curr_size = 5
celltypes = adata_combined_merfish.obs.clust_annot.unique()

curr_adata = adata_combined_merfish[np.logical_and(adata_combined_merfish.obs.batch==9, adata_combined_merfish.obs.slice==1)]
curr_rot = 35
xlim = [200, 2300]
ylim = [400, 4000]
aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])

# Exactly one axis per gene, indexed from 0. Re-enabling the segmentation panel below
# would need an extra column.
f, axes = plt.subplots(nrows=1, ncols=len(area_genes), figsize=(5*aspect_ratio*len(area_genes),5*1), squeeze=False)
#plot_seg(curr_adata, seg_cmap, rot=curr_rot,s=curr_size, ax=axes[0],xlim=xlim, ylim=ylim)

for k, c in enumerate(area_genes):
    ax = axes[0, k]
    gene_expr = curr_adata[:,c].X
    if hasattr(gene_expr, "toarray"):
        # Sparse .X: densify this single gene column before taking quantiles.
        gene_expr = gene_expr.toarray()
    vmin = np.quantile(gene_expr, 0.01)
    vmax = np.quantile(gene_expr, 0.99)
    plot_gene_expr(curr_adata, celltypes, c, plt.cm.Reds, s=1,alpha=0.5, vmin=vmin,vmax=vmax, rot=curr_rot, ax=ax, xlim=xlim, ylim=ylim)
    ax.set_title(c)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/figures/figS4_layer_markers.png",bbox_inches='tight',dpi=300)

In [None]:
# Pro-inflammatory cytokine genes used as a score_genes signature.
pro_inflammatory = [
    'Il1b',
    'Il1a',
    'Tnf',
    'Il6',
    'Il18',
]

In [None]:
# Genes of the endothelial activation signature.
activate_endo = [
    "B2m",
    "Nfkbia",
    "Serinc3",
    "Xdh",
    "Gfap",
    "Tap1",
]

In [None]:
# Score every cell on the pro-inflammatory signature, using adata.raw (use_raw=True),
# into obs['pro_inflammatory'].
sc.tl.score_genes(adata_combined_merfish, pro_inflammatory,
                  use_raw=True, score_name='pro_inflammatory')


In [None]:
#sc.tl.score_genes(adata_combined_merfish, gene_list=activate_endo, score_name='activate_endo', use_raw=True)


In [None]:
# Per-cell-type activation signatures, each scored on .X (use_raw=False) into an obs
# column named after the dict key. Scored in the order listed.
activate_endo = ["B2m", "Nfkbia", "Serinc3","Xdh", "Gfap", "Tap1"]
activation_signatures = {
    'activate_micro': ['B2m','Trem2', 'Ccl2', 'Apoe',  'Axl', 'Itgax', 'Cd9','C1qa','C1qc','Lyz2','Ctss'],
    'activate_astro': ['C4b', 'C3', 'Serpina3n', 'Cxcl10', 'Gfap', 'Vim', 'Il18','Hif3a'],
    'activate_endo': activate_endo,
    'activate_olig': ["Il33", "C4b","Neat1"],
}
for signature_name, signature_genes in activation_signatures.items():
    sc.tl.score_genes(adata_combined_merfish, gene_list=signature_genes,
                      score_name=signature_name, use_raw=False)

In [None]:
# Centre activate_astro (for all cells) on the mean score of young (4wk) astrocytes.
adata_astro = adata_combined_merfish[adata_combined_merfish.obs.cell_type=="Astro"]
young_astro_mean = np.mean(adata_astro[adata_astro.obs.age=='4wk'].obs['activate_astro'])
adata_combined_merfish.obs['activate_astro'] = adata_combined_merfish.obs['activate_astro'] - young_astro_mean

In [None]:
# Astrocyte activation score (activate_astro) per age group.
astro_only = adata_combined_merfish[adata_combined_merfish.obs.cell_type=='Astro']
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(astro_only, keys='activate_astro', groupby='age', ax=ax)
sns.despine()
sns.despine(ax=ax)
ax.set_rasterized(True)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_astro_age_violin.pdf",
          bbox_inches='tight', dpi=300)

In [None]:
# Centre activate_micro (for all cells) on the mean score of young (4wk) microglia.
adata_micro = adata_combined_merfish[adata_combined_merfish.obs.cell_type=="Micro"]
young_micro_mean = np.mean(adata_micro[adata_micro.obs.age=='4wk'].obs['activate_micro'])
adata_combined_merfish.obs['activate_micro'] = adata_combined_merfish.obs['activate_micro'] - young_micro_mean

In [None]:
# Microglia activation score (activate_micro) per age group.
micro_only = adata_combined_merfish[adata_combined_merfish.obs.cell_type=='Micro']
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(micro_only, keys='activate_micro', groupby='age', ax=ax)
sns.despine()
sns.despine(ax=ax)
ax.set_rasterized(True)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_age_violin.pdf",
          bbox_inches='tight', dpi=300)

In [None]:
# UMAP coloured by age and by each activation score.
umap_color_keys = ['age', 'activate_olig', 'activate_micro', 'activate_astro', 'activate_endo']
sc.pl.umap(adata_combined_merfish, color=umap_color_keys)

In [None]:
# Centre activate_olig (for all cells) on the mean score of young (4wk)
# oligodendrocytes. The reference population is adata_olig, not another cell type.
adata_olig = adata_combined_merfish[adata_combined_merfish.obs.cell_type=="Olig"]
adata_combined_merfish.obs.activate_olig = adata_combined_merfish.obs.activate_olig - np.mean(adata_olig[adata_olig.obs.age=='4wk'].obs.activate_olig)


In [None]:
# Oligodendrocyte activation score per oligodendrocyte cluster (Olig-3 drawn first).
olig_subtypes = ['Olig-1', 'Olig-2', 'Olig-3']
olig_subset = adata_combined_merfish[adata_combined_merfish.obs.clust_annot.isin(olig_subtypes)]
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(olig_subset, keys=["activate_olig"], groupby='clust_annot',
             order=['Olig-3', 'Olig-1', 'Olig-2'], ax=ax)
sns.despine(ax=ax)
ax.set_rasterized(True)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_violin.pdf",
          bbox_inches='tight', dpi=300)

In [None]:
# Astrocyte activation score per astrocyte cluster.
astro_subtypes = ['Astro-1', 'Astro-2']
astro_subset = adata_combined_merfish[adata_combined_merfish.obs.clust_annot.isin(astro_subtypes)]
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(astro_subset, keys=["activate_astro"], groupby='clust_annot', order=astro_subtypes, ax=ax)
ax.set_rasterized(True)
sns.despine(ax=ax)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_astro_violin.pdf",
          bbox_inches='tight', dpi=300)

In [None]:
# Microglia activation score per microglia cluster.
micro_subtypes = ['Micro-1', 'Micro-2', 'Micro-3']
micro_subset = adata_combined_merfish[adata_combined_merfish.obs.clust_annot.isin(micro_subtypes)]
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(micro_subset, keys=["activate_micro"], groupby='clust_annot', order=micro_subtypes, ax=ax)
sns.despine(ax=ax)
ax.set_rasterized(True)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_violin.pdf",
          bbox_inches='tight', dpi=300)

In [None]:
# Display order of the spatial regions (spatial_clust_annots) on the x axis.
spatial_order = [
    'Pia',
    'L2/3',
    'L5',
    'L6',
    'CC',
    'LatSept',
    'Striatum',
    'Ventricle',
]

In [None]:
def plot_age_obs_comparison(data, x, y, cell_type, figsize=(5,3), show_pvals=False, order=None, clust_key='cell_type', age_pal=sns.color_palette(['cornflowerblue','thistle','lightcoral'])):
    """Box plot plus per-cell strip plot of an obs column, grouped by `x` and split by age.

    Parameters
    ----------
    data : AnnData
        Cells; only ``data.obs`` is plotted.
    x, y : str
        obs columns used for the x categories and the plotted values.
    cell_type : str
        Value of ``data.obs[clust_key]`` selecting the cells to plot.
    figsize : tuple, default (5, 3)
        Figure size passed to ``plt.subplots``.
    show_pvals : bool, default False
        If True, compute per-category p-values across ages with
        ``calc_pvals_for_grouping`` and draw them with ``plot_pvals``.
    order : list, optional
        Order of the x categories; defaults to the sorted unique values of ``x``.
    clust_key : str, default 'cell_type'
        obs column holding the cell-type labels.
    age_pal : palette
        Colours for the 'age' hue levels of the box plot.

    Returns
    -------
    matplotlib.figure.Figure
    """
    f, ax = plt.subplots(figsize=figsize)
    curr_df = data[data.obs[clust_key]==cell_type].obs
    if order is None:
        order = sorted(curr_df[x].unique())
    sns.boxplot(x=x, y=y, data=curr_df,hue='age',fliersize=0,linewidth=1,palette=age_pal, ax=ax,order=order)
    # Individual cells in black on top of the boxes, dodged to line up with each age's box.
    sns.stripplot(data=curr_df, x=x, y=y, hue="age", ax=ax,jitter=0.15,size=0.5,dodge=True,color='k',order=order, rasterized=True)

    sns.despine(ax=ax)
    # Hide the hue legend (box and strip plots would each add age entries).
    ax.legend([], [], frameon=False)
    if show_pvals:
        pvals = calc_pvals_for_grouping(x,y,curr_df, "age",order=order)
        plot_pvals(ax, pvals)
    return f


In [None]:
# Oligodendrocyte activation score by spatial region and age, with p-values.
olig_spatial_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_spatial_merfish.pdf"
f = plot_age_obs_comparison(adata_combined_merfish, x="spatial_clust_annots", y="activate_olig",
                            cell_type="Olig", show_pvals=True, order=spatial_order)
f.savefig(olig_spatial_path, bbox_inches='tight', dpi=300)

In [None]:
# Endothelial activation score by spatial region and age, with p-values.
endo_spatial_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_endo_spatial_merfish.pdf"
f = plot_age_obs_comparison(adata_combined_merfish, x="spatial_clust_annots", y="activate_endo",
                            cell_type="Endo", show_pvals=True, order=spatial_order)
f.savefig(endo_spatial_path, bbox_inches='tight', dpi=300)

In [None]:
# NOTE(review): repeats the oligodendrocyte spatial plot produced a few cells earlier
# and overwrites the same PDF.
olig_spatial_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_spatial_merfish.pdf"
f = plot_age_obs_comparison(adata_combined_merfish, x="spatial_clust_annots", y="activate_olig",
                            cell_type="Olig", show_pvals=True, order=spatial_order)
f.savefig(olig_spatial_path, bbox_inches='tight', dpi=300)

In [None]:
#f = plot_age_obs_comparison(adata_combined_merfish, "spatial_clust_annots", "pro_inflammatory", "Astro", show_pvals=True, order=spatial_order);


In [None]:
# Microglia activation score by spatial region and age (no p-values).
micro_spatial_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_spatial_merfish.pdf"
f = plot_age_obs_comparison(adata_combined_merfish, x="spatial_clust_annots", y="activate_micro",
                            cell_type="Micro", show_pvals=False, order=spatial_order)
f.savefig(micro_spatial_path, bbox_inches='tight', dpi=300)

In [None]:
# Astrocyte activation score by spatial region and age (no p-values).
astro_spatial_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_astro_spatial_merfish.pdf"
f = plot_age_obs_comparison(adata_combined_merfish, x="spatial_clust_annots", y="activate_astro",
                            cell_type="Astro", show_pvals=False, order=spatial_order)
f.savefig(astro_spatial_path, bbox_inches='tight', dpi=300)

In [None]:
def plot_obs_by_conditions(A, obs_name, vmin=0,vmax=3,cmap=plt.cm.Reds, cell_types=None,key='clust_annot_preds',s=0.1):
    """Plot an obs value on example sections: rows are ages, columns are conditions.

    Parameters
    ----------
    A : AnnData
        Cells with obs columns ``data_batch``, ``slice`` and ``key``.
    obs_name : str
        obs column to colour by, passed to ``plot_obs``.
    vmin, vmax : float
        Colour limits.
    cmap : Colormap
        Colour map passed to ``plot_obs``.
    cell_types : sequence, optional
        Cell-type labels to plot; defaults to all unique values of ``A.obs[key]``.
    key : str, default 'clust_annot_preds'
        obs column holding cell-type labels. Used for the default ``cell_types`` and
        passed to ``plot_obs``.
    s : float
        Marker size.

    Returns
    -------
    matplotlib.figure.Figure
        3 x 2 grid: rows 4wk / 24wk / 90wk, columns ctrl / lps.
    """
    if cell_types is None:
        # Take the default labels from the same column that is handed to plot_obs.
        cell_types = A.obs[key].unique()
    f,ax = plt.subplots(figsize=(4,10), nrows=3, ncols=2, gridspec_kw={'wspace':0.05, 'hspace':0.01})
    for i, cond in enumerate(['ctrl','lps']):
        for j, age in enumerate(['4wk', '24wk', '90wk']):
            # get_plot_info (defined elsewhere) gives the section and crop for this age/condition.
            batch, dslice, xlim, ylim, rot = get_plot_info(age, cond)
            curr_ax = ax[j][i]
            # data_batch is compared against str(batch), presumably a string column — confirm.
            curr_adata = A[np.logical_and(A.obs.data_batch==str(batch), A.obs.slice==dslice)]
            plot_obs(curr_adata, cell_types, obs_name,rot=rot,s=s,vmin=vmin,vmax=vmax,key=key,cmap=cmap,ax=curr_ax)
            curr_ax.set_xlim(xlim)
            curr_ax.set_ylim(ylim)
    return f


In [None]:
# plot celltype specific genes
# Cell-type marker genes on one young example section of adata_annot, one panel per gene.
# NOTE(review): adata_annot and celltypes are not defined in this cell.
celltype_specific_genes = ['Slc17a7', 'Gad1','Drd1', 'Aqp4',  'Olig1', 'Pdgfra', 'Vtn']
vmax = 3
curr_rot = -183
xlim = [200, 2500]
ylim = [200, 4000]
# Compute the crop aspect ratio before sizing the figure, so the size does not depend
# on a value left by another cell.
aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])
curr_adata = adata_annot[np.logical_and(adata_annot.obs.batch==7, adata_annot.obs.slice==2)]

n_genes = len(celltype_specific_genes)
# Build the panels with plt.subplots rather than plt.subplots() + plt.subplot(), which
# leaves a stray full-figure Axes behind the panels in recent Matplotlib.
f, axes = plt.subplots(nrows=2, ncols=n_genes, figsize=(2.5*aspect_ratio*n_genes, 2.5*2), squeeze=False)
# Only the first row is drawn; hide the second so it adds no empty frames.
for unused_ax in axes[1]:
    unused_ax.set_visible(False)
for k, c in enumerate(celltype_specific_genes, start=1):
    ax = axes[0, k-1]
    plot_gene_expr(curr_adata, celltypes, c, plt.cm.Reds, vmin=1,vmax=vmax, rot=curr_rot, ax=ax, xlim=xlim, ylim=ylim)
    ax.set_title(c)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures/fig5_celltype_marker_comparison_merfish.pdf", bbox_inches='tight', dpi=300)