In [None]:
%load_ext autoreload
%autoreload 2
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import os
import anndata as ad
mpl.rcParams['figure.dpi'] = 80
plt.rcParams['pdf.fonttype'] = 42

import pandas as pd
import sys
from spatial_analysis import *
from plotting import *

# Functions

In [None]:
def save_fig(f, name, dtype="png",
             base_dir="/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/figures_imputed",
             dpi=200):
    """Save a figure as ``<base_dir>/<name>.<dtype>``.

    Parameters
    ----------
    f : object with a matplotlib-style ``savefig`` method (e.g. a Figure)
    name : str
        File name without extension.
    dtype : str
        File extension / output format, e.g. "png" or "pdf".
    base_dir : str
        Output directory. Defaults to the directory this notebook originally
        hard-coded, so existing calls write to the same place.
    dpi : int
        Resolution forwarded to ``savefig``.
    """
    path = os.path.join(base_dir, f"{name}.{dtype}")
    f.savefig(path, bbox_inches="tight", dpi=dpi)
    
from collections import Counter
def majority_vote(votes):
    """Return the most frequent element of ``votes``.

    Equal counts are resolved in favour of the element seen first (the
    ordering ``Counter.most_common`` uses). An empty input raises IndexError.
    """
    tally = Counter(votes)
    top = tally.most_common(1)
    return top[0][0]

def impute_celltype(pc_A, pc_B, celltype_B, k=10, n_per_batch=1000):
    """
    Transfer a label from reference cells (B) to query cells (A) by kNN majority vote.

    A KDTree is built on ``pc_B``. For every row of ``pc_A`` the ``k`` nearest
    rows of ``pc_B`` are found, and the most common ``celltype_B`` value among
    them is taken (see ``majority_vote``).

    Parameters
    ----------
    pc_A : 2-D array (n_A rows)
        Query coordinates.
    pc_B : 2-D array (n_B rows)
        Reference coordinates in the same space.
    celltype_B : array-like of length n_B
        Label per reference row. A pandas Series is used positionally; its
        index is ignored.
    k : int
        Number of neighbours.
    n_per_batch : int
        Query rows per KDTree call, which bounds memory.

    Returns
    -------
    np.ndarray of length n_A with the imputed labels.
    """
    # KDTree returns row positions. Indexing a Series with them directly would be
    # label-based (wrong rows for integer indices; deprecated fallback otherwise).
    labels_B = np.asarray(celltype_B)
    bt = KDTree(pc_B)

    combined_ct = []
    # divide up into chunks of n_per_batch
    for start in tqdm(range(0, pc_A.shape[0], n_per_batch)):
        _, ind = bt.query(pc_A[start:(start + n_per_batch), :], k=k)
        combined_ct.extend(majority_vote(labels_B[idx]) for idx in ind)
    return np.array(combined_ct)
                    

def unbinarize_strings(A):
    """
    Decode ASCII ``bytes`` names (as stored in some h5ad files) into ``str``, in place.

    Decodes ``A.var_names``, ``A.obs.index`` and the values of every categorical
    ``A.obs`` column. Names that are already ``str`` are left as they are.
    Categorical columns whose values cannot be decoded (e.g. already ``str``)
    are left untouched. A decoded categorical column is replaced by a plain
    list-built column, as before.

    Returns
    -------
    The same object ``A``, for chaining.
    """
    def _decode(value):
        return value.decode('ascii') if isinstance(value, bytes) else value

    A.var_names = [_decode(name) for name in A.var_names]
    A.obs.index = [_decode(name) for name in A.obs.index]
    for col in A.obs.columns:
        # Only categorical columns are candidates. An isinstance check replaces
        # ``dtype.is_dtype('category')``, which raised AttributeError for numpy
        # dtypes such as uint8 or fixed-width bytes.
        if isinstance(A.obs[col].dtype, pd.CategoricalDtype):
            try:
                A.obs[col] = [value.decode('ascii') for value in A.obs[col]]
            except (AttributeError, UnicodeDecodeError):
                # Already-str (or mixed/NaN) categories: keep the column as it is.
                pass
    return A

# Marker gene symbols that can be highlighted in plots (e.g. through the `markers`
# argument of identify_variable_genes).
# NOTE(review): not referenced elsewhere in the visible notebook.
markers = ['Snap25', 'Aldh1l1', 'Trem2', 'Olig1', 'Olig2', 'Gad2','Gad1', 'Slc17a7', 'Sst', 'Pvalb', 'Cux2']

import warnings
# Globally silence FutureWarning and UserWarning (presumably library noise). Note that
# this also hides warnings that may point at real problems, e.g. writes to AnnData views.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
import numpy as np

def normalize_data(data):
    """
    Min-max scale ``data`` to the range [0, 1].

    Parameters
    ----------
    data : array-like of numbers

    Returns
    -------
    ``(data - min) / (max - min)``, with the same shape as ``data``.
    If all values are equal (zero range), a float array of zeros is returned
    instead of the NaNs a 0/0 division would produce. NaNs in ``data``
    propagate as before.
    """
    lo = np.min(data)
    span = np.max(data) - lo
    if span == 0:
        return np.zeros(np.shape(data))
    return (data - lo) / span

def identify_variable_genes(adata,markers=None,theta=100,n_top_genes=2000, do_plot=False):
    """
    Flag highly variable genes with scanpy's Pearson-residual method and optionally plot them.

    Calls ``sc.experimental.pp.highly_variable_genes(flavor="pearson_residuals")``.
    That call writes its results into ``adata.var`` in place; this function reads
    ``highly_variable``, ``mean_counts`` and ``residual_variances`` back from there.

    Parameters
    ----------
    adata : AnnData
        Presumably raw counts in ``.X``, as the Pearson-residual flavor expects.
        TODO confirm for each caller.
    markers : str, iterable of gene names, or None
        Genes to highlight in black on the plot.
    theta : float
        Overdispersion parameter forwarded to scanpy.
    n_top_genes : int
        Number of genes to flag.
    do_plot : bool
        If True, draw mean expression vs residual variance on new axes.

    Returns
    -------
    The matplotlib Axes when ``do_plot`` is True, otherwise None.
    """
    sc.experimental.pp.highly_variable_genes(
        adata, flavor="pearson_residuals", n_top_genes=n_top_genes, theta=theta,
    )
    if not do_plot:
        return None

    # plot gene expression
    ax = plt.axes()
    mean_counts = adata.var["mean_counts"]
    residual_var = adata.var["residual_variances"]
    hvgs = adata.var["highly_variable"]

    ax.scatter(mean_counts, residual_var, s=3, edgecolor="none")
    ax.scatter(
        mean_counts[hvgs],
        residual_var[hvgs],
        c="tab:red",
        label="selected genes",
        s=3,
        edgecolor="none",
    )

    # Normalize to a list. `if markers:` raises for numpy arrays / pandas Index,
    # and np.isin does not treat a set as a collection of values.
    if markers is None:
        marker_list = []
    elif isinstance(markers, str):
        marker_list = [markers]
    else:
        marker_list = list(markers)
    if marker_list:
        is_marker = np.isin(adata.var_names, marker_list)
        ax.scatter(
            mean_counts[is_marker],
            residual_var[is_marker],
            c="k",
            label="known marker genes",
            s=10,
            edgecolor="none",
        )
    ax.set_xscale("log")
    ax.set_xlabel("mean expression")
    ax.set_yscale("log")
    ax.set_ylabel("residual variance")

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position("left")
    ax.xaxis.set_ticks_position("bottom")
    return ax
        
def compute_sqrt_norm(A):
    """Store sqrt(library-size-normalized ``A.X``) as ``A.layers["sqrt_norm"]``.

    ``sc.pp.normalize_total`` is run with ``inplace=False``, so ``A.X`` itself
    is not modified. Returns None.
    """
    normalized = sc.pp.normalize_total(A, inplace=False)["X"]
    A.layers["sqrt_norm"] = np.sqrt(normalized)
    
def normalize_10x_data(A,n_genes=2000,theta=100, identify_var_genes=True):
    """
    Preprocess a raw-count AnnData.

    Steps: keep raw copies, add a sqrt-normalized layer, optionally subset to
    highly variable genes, then replace ``.X`` with Pearson residuals.

    Side effects on the input ``A``: sets ``A.raw``, ``A.layers["raw"]`` and
    ``A.layers["sqrt_norm"]``. When ``identify_var_genes`` is True it also adds
    the HVG columns to ``A.var``.

    Parameters
    ----------
    A : AnnData
        Presumably raw counts in ``.X``. TODO confirm.
    n_genes : int
        Number of highly variable genes to keep.
    theta : float
        Overdispersion parameter for HVG selection.
    identify_var_genes : bool
        Whether to subset to HVGs before normalizing.

    Returns
    -------
    The normalized AnnData: a new object when HVG subsetting is applied,
    otherwise ``A`` itself.
    """
    A.raw = A

    A.layers["raw"] = A.X.copy()
    compute_sqrt_norm(A)

    if identify_var_genes:
        identify_variable_genes(A, n_top_genes=n_genes, theta=theta)
        # Materialize the subset. Normalizing a view in place makes anndata
        # implicitly copy it (with an ImplicitModificationWarning).
        A = A[:, A.var["highly_variable"]].copy()
    sc.experimental.pp.normalize_pearson_residuals(A)
    return A

from sklearn.neighbors import BallTree, KDTree
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from scipy.sparse import csr_matrix, vstack
from tqdm import tqdm
from joblib import delayed, Parallel

def impute_expr_sparse(pc_A, pc_B, expr_B, k=10, n_per_batch=1000):
    """
    Impute expression for each query row as the mean over its kNN in the reference.

    Each row of ``pc_A`` gets the unweighted mean of ``expr_B`` over its ``k``
    nearest rows of ``pc_B`` (KDTree). Neighbour distances are not used.

    Parameters
    ----------
    pc_A : 2-D array
        Query coordinates.
    pc_B : 2-D array
        Reference coordinates; row i corresponds to row i of ``expr_B``.
    expr_B : 2-D array or sparse matrix
        Reference expression (rows = cells).
    k : int
        Number of neighbours to average.
    n_per_batch : int
        Query rows handled per KDTree call.

    Returns
    -------
    Sparse matrix (``scipy.sparse.vstack`` of CSR batches) with one row per row of ``pc_A``.
    """
    tree = KDTree(pc_B)
    batches = []
    for start in tqdm(range(0, pc_A.shape[0], n_per_batch)):
        _, neighbor_idx = tree.query(pc_A[start:(start + n_per_batch), :], k=k)
        batch_rows = [np.mean(expr_B[nbrs, :], 0) for nbrs in neighbor_idx]
        batches.append(csr_matrix(np.vstack(batch_rows)))
    return vstack(batches)
                    

def impute_obs(pc_A, pc_B, var_B, k=10, n_per_batch=1000):
    """
    Impute a numeric per-cell value for query cells as the mean over their kNN in the reference.

    A KDTree is built on ``pc_B``. For every row of ``pc_A`` the mean of
    ``var_B`` over its ``k`` nearest reference rows is returned.

    Parameters
    ----------
    pc_A : 2-D array
        Query coordinates.
    pc_B : 2-D array
        Reference coordinates.
    var_B : array-like with len(pc_B) rows
        Values per reference row. A pandas Series is used positionally; its
        index is ignored.
    k : int
        Number of neighbours.
    n_per_batch : int
        Query rows per KDTree call.

    Returns
    -------
    np.ndarray with one entry (or row, for 2-D ``var_B``) per row of ``pc_A``.
    """
    # KDTree yields row positions; indexing a Series with them would be label-based.
    values_B = np.asarray(var_B)
    bt = KDTree(pc_B)

    imputed = []
    # divide up into chunks of n_per_batch
    for start in tqdm(range(0, pc_A.shape[0], n_per_batch)):
        _, ind = bt.query(pc_A[start:(start + n_per_batch), :], k=k)
        imputed.extend(np.mean(values_B[idx], 0) for idx in ind)
    return np.array(imputed)
                    
# compute gene modules
from scipy.cluster.hierarchy import ward,fcluster, dendrogram, linkage
from scipy.spatial.distance import pdist
from scipy.stats import zscore
from statsmodels.stats.multitest import multipletests
from scipy.sparse.linalg import svds
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from umap import UMAP


def cluster_vals(X,n_clusts=5):
    """
    Cluster the rows of ``X`` by complete-linkage hierarchical clustering.

    The tree is cut at half of the largest pairwise Euclidean distance.
    ``n_clusts`` is accepted for API compatibility but is not used.

    Returns
    -------
    Flat cluster labels (1-based ints from ``fcluster``), one per row.
    """
    pairwise = pdist(X)
    cut_height = pairwise.max() / 2
    tree = linkage(pairwise, method='complete')
    return fcluster(tree, cut_height, 'distance')

def zscore_mat(X):
    """Z-score every entry of the 2-D matrix ``X`` against the mean/std of all its entries.

    This is a global z-score, not per row or per column. Returns an array with
    the shape of ``X``.
    """
    rows, cols = X.shape
    standardized = zscore(X.ravel())
    return standardized.reshape((rows, cols))

def zscore_pval_mat(X):
    """
    Globally z-score a matrix and attach Benjamini-Hochberg adjusted two-sided p-values.

    ``X`` holds raw values (e.g. gene-gene similarities). It is z-scored here
    with ``zscore_mat``, so it does not need to be z-scored beforehand.

    Returns
    -------
    (X_Z, X_pval) : two arrays with the shape of ``X``
        X_Z
            Entry-wise z-scores over all entries of ``X``.
        X_pval
            FDR-adjusted (``fdr_bh``) values of ``2 * P(N(0,1) > |z|)``.
    """
    # Local import: there is no visible ``import scipy``, so ``scipy.stats.norm``
    # only resolved if a wildcard import happened to provide ``scipy``.
    from scipy.stats import norm as std_normal

    X_Z = zscore_mat(X)
    n_x, n_y = X_Z.shape

    two_sided = std_normal.sf(np.abs(X_Z.flatten())) * 2
    adjusted = multipletests(two_sided, method='fdr_bh')[1]
    return X_Z, adjusted.reshape((n_x, n_y))

def threshold_zscore_mat(X_Z, X_pval, thresh=0.1):
    """
    Zero out z-scores whose (adjusted) p-value exceeds ``thresh``.

    ``X_Z`` must be 2-D. ``X_pval`` must have the same number of entries;
    entries are matched in flattened (row-major) order. A new array is
    returned and the inputs are not modified.
    """
    rows, cols = X_Z.shape
    not_significant = (X_pval.flatten() > thresh).reshape((rows, cols))
    return np.where(not_significant, 0, X_Z)


from sklearn.decomposition import PCA,FactorAnalysis,NMF
from sklearn.manifold import TSNE
from sklearn.cluster import SpectralClustering, DBSCAN
from sklearn.preprocessing import scale, MinMaxScaler
from sklearn.neighbors import NearestNeighbors
from scipy.stats import ttest_ind, norm, ranksums
from scipy.stats.mstats import gmean

import pandas as pd
import scanpy as sc
def get_genes_for_celltype(de, name, direction=None):
    """
    Return the unique genes of a DE table for every cell type whose label contains ``name``.

    Parameters
    ----------
    de : DataFrame
        Must have ``cell_type``, ``coef`` and ``gene`` columns.
    name : str
        Substring matched against the cell-type labels. Matching labels are
        collected from the unfiltered table.
    direction : {"pos", "neg", None}
        "pos" keeps ``coef >= 0``, "neg" keeps ``coef < 0``; any other value
        keeps every row.

    Returns
    -------
    Numpy array of unique gene names, in order of first appearance.
    """
    matching_types = [label for label in de.cell_type.unique() if name in label]
    if direction == "pos":
        kept = de[de.coef >= 0]
    elif direction == "neg":
        kept = de[de.coef < 0]
    else:
        kept = de
    return kept[kept.cell_type.isin(matching_types)].gene.unique()

# code from: https://github.com/klarman-cell-observatory/inCITE-seq/blob/main/notebooks/inCITE_tools.ipynb
def parse_GO_query(gene_list, species, db_to_keep='all'):
    """
    Run an enrichment query via ``sc.queries.enrich`` and keep significant hits from selected sources.

    Parameters
    ----------
    gene_list : iterable of gene symbols
    species : str
        Organism code passed as ``org`` (e.g. "mmusculus").
    db_to_keep : 'all' or list of source names
        'all' means ['GO:BP', 'GO:MF', 'KEGG', 'REAC', 'TF'].

    Returns
    -------
    Subset of the enrichment table where ``significant`` is True and ``source``
    is one of the kept databases.
    """
    sources = ['GO:BP', 'GO:MF', 'KEGG', 'REAC', 'TF'] if db_to_keep == 'all' else db_to_keep
    enriched = sc.queries.enrich(list(gene_list), org=species)
    significant = enriched[enriched['significant'] == True]  # noqa: E712 (keep exact comparison)
    return significant[significant['source'].isin(sources)]

def sig_genes_GO_query(sig_genes, clust_lim=1000, source=['GO:BP','KEGG']):
    """
    Run mouse ("mmusculus") enrichment of ``sig_genes`` and drop terms under excluded top-level parents.

    Terms whose ``parents`` include GO:0009987 or GO:0008150 are skipped; the
    original code marks these as top-level terms to exclude.

    Parameters
    ----------
    sig_genes : iterable of gene symbols
    clust_lim : int
        Maximum number of terms to keep, in the order returned by the query.
    source : list of str
        Databases forwarded to ``parse_GO_query``. The list is not mutated.

    Returns
    -------
    DataFrame with columns source, name, p_value, description, native and
    parents, and a 0..n-1 integer index. It is empty (with those columns) when
    no term passes.
    """
    columns = ['source', 'name', 'p_value', 'description', 'native', 'parents']
    bad_top_terms = {'GO:0009987', 'GO:0008150'}

    GO_df = parse_GO_query(sig_genes, 'mmusculus', source)
    records = []
    for _, row in GO_df.iterrows():
        if len(records) >= clust_lim:
            break
        # exclude top level terms
        if any(parent in bad_top_terms for parent in row['parents']):
            continue
        records.append({col: row[col] for col in columns})

    # Build the frame once. Concatenating one row per iteration is quadratic,
    # and pandas warns when concatenating onto an empty frame.
    return pd.DataFrame(records, columns=columns)

def plot_GO_terms(df,alpha,filename,colormap='#d3d3d3',xlims=[0,20],ax=None):
    """
    Draw a horizontal bar chart of -log10(p) for enrichment terms with ``p_value <= alpha``.

    Parameters
    ----------
    df : DataFrame
        Needs ``p_value`` and ``name`` columns, plus ``cluster`` when
        ``colormap`` is a mapping. It is not modified.
    alpha : float
        p-value cutoff.
    filename : any
        Unused; kept for backward compatibility (figure saving was disabled).
    colormap : str or mapping
        A single color (default light grey), or a mapping cluster -> color.
    xlims : sequence of two floats
        x-axis limits.
    ax : Axes or None
        Axes to draw on. If None, a new figure about 0.1 inch tall per term is created.

    Returns
    -------
    The Axes drawn on.
    """
    # Filter first, on a copy. Bar colors must line up with the rows actually
    # drawn (they used to come from the unfiltered frame), and the caller's
    # frame should not gain helper columns.
    df = df.loc[df['p_value'] <= alpha].copy()
    if isinstance(colormap, str):
        color = colormap
    else:
        df['color'] = df['cluster'].map(colormap)
        color = df['color'].tolist()

    n_terms = df.shape[0]
    if ax is None:
        # matplotlib rejects a zero figure height, so keep it positive for an empty table.
        fig_height = n_terms * (1/10) if n_terms > 0 else 0.1
        _, ax = plt.subplots(figsize=(3, fig_height))
    y_pos = np.arange(n_terms)
    log10p = -np.log10(df['p_value'].tolist())

    sns.reset_orig()
    ax.barh(y_pos, log10p, align='center', color=color)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(df['name'], fontsize=6)
    ax.invert_yaxis()
    ax.set_xlabel('-log10(P)')
    ax.set_xlim(xlims)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(1)
    return ax

def unbinarize_list(s):
    """Decode every element of ``s`` from ASCII bytes to str.

    Returns a new list on success. If any element cannot be decoded (e.g. it is
    already a str), the original ``s`` object is returned unchanged.
    """
    try:
        decoded = [item.decode('ascii') for item in s]
    except Exception:
        return s
    return decoded


def identify_nearest_neighbors_with_idx(X,Y,dist_thresh, min_dist_thresh=15):
    """
    Find all (Y point, X point) pairs whose distance lies in (min_dist_thresh, dist_thresh].

    Parameters
    ----------
    X, Y : 2-D arrays of point coordinates (rows = points)
    dist_thresh : float
        Search radius for ``KDTree.query_radius``.
    min_dist_thresh : float
        Pairs at distance <= this value are discarded.

    Returns
    -------
    (ind, ind_X) : two int arrays of equal length
        ind[j]
            Row of Y in the j-th pair.
        ind_X[j]
            Row of X in the j-th pair.
        Both are empty when X or Y has no rows, or when no pair qualifies.
    """
    if X.shape[0] == 0 or Y.shape[0] == 0:
        # Two arrays, like the normal path: callers unpack ``a, b = ...``.
        return np.array([], dtype=int), np.array([], dtype=int)

    kdtree = KDTree(Y)
    ind, dists = kdtree.query_radius(X, r=dist_thresh, count_only=False, return_distance=True)
    counts = np.array([len(neighbors) for neighbors in ind], dtype=int)
    # X row of every returned neighbour. np.repeat also copes with the case of no
    # neighbours at all, where np.hstack of an empty list raised ValueError.
    ind_X = np.repeat(np.arange(len(ind)), counts)
    ind = np.concatenate(list(ind))
    dists = np.concatenate(list(dists))

    keep = dists > min_dist_thresh
    # builtin int: the np.int alias was removed in NumPy 1.24.
    return ind[keep].astype(int), ind_X[keep].astype(int)

def count_neighbors_with_idx(X,Y,dist_thresh, ):
    """
    For every point of X, count the points of Y within ``dist_thresh`` (``KDTree.query_radius``).

    Returns
    -------
    (ind_X, counts_Y) : int arrays of length len(X)
        ``ind_X`` is simply 0..len(X)-1, and ``counts_Y[i]`` is the neighbour
        count of X[i]. If Y has no rows, every count is 0. If X has no rows,
        both arrays are empty.
    """
    n_X = X.shape[0]
    if n_X == 0:
        return np.array([], dtype=int), np.array([], dtype=int)
    if Y.shape[0] == 0:
        # No KDTree can be built on an empty Y, but the answer is known: zero neighbours.
        return np.arange(n_X), np.zeros(n_X, dtype=int)

    kdtree = KDTree(Y)
    # count_only avoids materializing every neighbour index just to take len().
    counts_Y = kdtree.query_radius(X, r=dist_thresh, count_only=True)
    ind_X = np.arange(len(counts_Y))
    # builtin int: the np.int alias was removed in NumPy 1.24.
    return ind_X.astype(int), counts_Y.astype(int)

def identify_nearest_neighbors_with_dist(X,Y, min_dist=0):
    """
    Find, for each point of X, the distance to its nearest point in Y, skipping an exact zero-distance match.

    Skipping the zero-distance hit means a point present in both X and Y is not
    reported as its own neighbour; the second-nearest point is used instead.
    ``min_dist`` is accepted for backward compatibility but is not used.

    Returns
    -------
    (good_dists, good_ind)
        good_dists
            Float array of distances, one per row of X.
        good_ind
            Int array of the matching Y row indices (previously stored as floats).
        Both are empty when X or Y has no rows. If Y has a single point, that
        point is used even at distance 0, because there is no second neighbour.
    """
    if X.shape[0] == 0 or Y.shape[0] == 0:
        # Two arrays, like the normal path: callers unpack ``a, b = ...``.
        return np.array([]), np.array([], dtype=int)

    kdtree = KDTree(Y)
    # k=2 lets a zero-distance (duplicate/self) hit fall back to the next neighbour;
    # KDTree.query cannot ask for more neighbours than Y has points.
    k = min(2, Y.shape[0])
    dists, ind = kdtree.query(X, k=k, return_distance=True)

    col = np.zeros(len(dists), dtype=int)
    if k == 2:
        col[dists[:, 0] <= 0] = 1  # remove duplicates
    rows = np.arange(len(dists))
    return dists[rows, col], ind[rows, col].astype(int)
    
def compute_celltype_obs_count_correlation(A,cell_type_X, cell_type_Y, obs_key_X, celltype_key='cell_type',radius=40, min_dist_thresh=15):
    """
    Pair an obs value of every ``cell_type_X`` cell with the number of ``cell_type_Y`` cells within ``radius``.

    Coordinates come from ``A.obsm['spatial']``, and the neighbours are counted
    with ``count_neighbors_with_idx``. ``min_dist_thresh`` is accepted for API
    compatibility but is not used.

    Returns
    -------
    (obs values of the X cells, X row indices, neighbour counts)
    """
    labels = A.obs[celltype_key]
    cells_X = A[labels == cell_type_X]
    cells_Y = A[labels == cell_type_Y]
    values_X = cells_X.obs[obs_key_X].values
    ind_X, counts_Y = count_neighbors_with_idx(
        cells_X.obsm['spatial'], cells_Y.obsm['spatial'], dist_thresh=radius
    )
    return values_X[ind_X], ind_X, counts_Y


def compute_celltype_obs_distance_correlation(A,cell_type_X, cell_type_Y, obs_key_X, celltype_key1='cell_type', celltype_key2='cell_type'):
    """
    Pair an obs value of every ``cell_type_X`` cell with its distance to the nearest ``cell_type_Y`` cell.

    Coordinates come from ``A.obsm['spatial']``; zero-distance matches are
    skipped by ``identify_nearest_neighbors_with_dist``. X cells are selected
    with ``A.obs[celltype_key1]`` and Y cells with ``A.obs[celltype_key2]``.

    Returns
    -------
    (obs values of the X cells, nearest-Y distances)
    """
    cells_X = A[A.obs[celltype_key1] == cell_type_X]
    cells_Y = A[A.obs[celltype_key2] == cell_type_Y]
    obs_X = cells_X.obs[obs_key_X]
    nearest_dists, _ = identify_nearest_neighbors_with_dist(
        cells_X.obsm['spatial'], cells_Y.obsm['spatial']
    )
    return obs_X.values, nearest_dists

def compute_celltype_obs_correlation(A,cell_type_X, cell_type_Y, obs_key_X, obs_key_Y, celltype_key='cell_type', radius=40, min_dist_thresh=15):
    """
    Pair obs values of ``cell_type_X`` cells with obs values of their ``cell_type_Y`` spatial neighbours.

    Every (X cell, Y cell) pair whose ``A.obsm['spatial']`` distance lies in
    (min_dist_thresh, radius] contributes one entry. Pairs are found with
    ``identify_nearest_neighbors_with_idx``.

    Returns
    -------
    (values_X, values_Y) : aligned arrays, one element per pair
        ``A.obs[obs_key_X]`` of the X cell and ``A.obs[obs_key_Y]`` of the Y neighbour.
    """
    cells_X = A[A.obs[celltype_key] == cell_type_X]
    cells_Y = A[A.obs[celltype_key] == cell_type_Y]
    obs_X = cells_X.obs[obs_key_X]
    obs_Y = cells_Y.obs[obs_key_Y]
    neighbors_X, ind_X = identify_nearest_neighbors_with_idx(
        cells_X.obsm['spatial'], cells_Y.obsm['spatial'],
        dist_thresh=radius, min_dist_thresh=min_dist_thresh,
    )
    # neighbors_X are row *positions* in Y. obs_Y is indexed by cell ids, so
    # obs_Y[...] relied on pandas' deprecated positional fallback (and picks
    # the wrong rows for an integer index). .iloc is explicitly positional.
    curr_expr = obs_Y.iloc[neighbors_X]
    return obs_X.values[ind_X], curr_expr.values

def compute_binned_values(dists, scores, min_d=0, max_d=100, bin_size=30):
    """
    Compute a sliding-window mean and standard error of ``scores`` as a function of ``dists``.

    Windows are ``(s, s + bin_size]`` for integer starts ``s`` in
    ``[min_d, max_d - bin_size)``. That gives ``max_d - min_d - bin_size``
    windows with step 1; output slot ``j`` holds the window starting at
    ``min_d + j``.

    Returns
    -------
    (binned_mean, binned_std) : float arrays, one entry per window
        binned_mean
            Window means, centred by subtracting their mean (empty windows
            are ignored when centring).
        binned_std
            Standard error of the mean per window (not centred).
        Empty windows give NaN, with a numpy RuntimeWarning.
    """
    dists = np.asarray(dists)
    scores = np.asarray(scores)
    starts = np.arange(min_d, max_d - bin_size)
    binned_mean = np.zeros(len(starts))
    binned_std = np.zeros(len(starts))
    for j, start in enumerate(starts):
        in_bin = np.logical_and(dists > start, dists <= start + bin_size)
        curr_scores = scores[in_bin]
        # Index by window number, not by distance. The old binned_mean[i] was
        # out of range (or misplaced) whenever min_d > 0.
        binned_mean[j] = np.mean(curr_scores)
        binned_std[j] = np.std(curr_scores) / np.sqrt(len(curr_scores))
    binned_mean -= np.nanmean(binned_mean)
    # The previous `binned_std -= binned_mean.mean()` subtracted ~0 (the mean was
    # already centred), or NaN for every entry if any window was empty; dropped.
    return binned_mean, binned_std

# Process 10X data

## Load and preprocess 10X

In [None]:
# load raw data that has been filtered and leiden clustered recursively
# NOTE(review): absolute machine-specific path; consider a DATA_DIR constant in the config cell.
adata_10x = unbinarize_strings(sc.read_h5ad("/faststorage/brain_aging/aging10x/051722_aging10x_pfc_clustered.h5ad"))

In [None]:
# UMAP of the 10X data colored by age and by leiden cluster.
sc.pl.umap(adata_10x, color=['age','leiden'])

## Subset 10X data to CellphoneDB ligand–receptor pairs and convert gene symbols to human (TODO: this section currently contains no code)

# Integrate MERFISH and snRNA_seq

## Load MERFISH

In [None]:
# Joint MERFISH + 10X object (harmony-integrated, per the file name). Its obs column
# 'dtype' separates the two modalities (see the next cell).
# NOTE(review): absolute machine-specific path.
adata_combined = unbinarize_strings(ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_combined_harmony.h5ad"))

## Copy information from combined to 10X raw data

In [None]:
# Split the integrated object by modality. Use explicit column access: `obs.dtype` only
# resolves to the 'dtype' column through DataFrame attribute fallback, and reads like the
# frame's dtypes.
adata_combined_10x = adata_combined[adata_combined.obs['dtype'] == 'scrnaseq']
adata_combined_mer = adata_combined[adata_combined.obs['dtype'] == 'merfish']

In [None]:
# find just set of 10X cells that were used for integration
# (keeps the cell order of adata_combined_10x; membership is checked against the 10X cell ids)
match_idx_10x = [i for i in adata_combined_10x.obs.index if i in adata_10x.obs.index]
adata_combined_10x = adata_combined_10x[match_idx_10x]

In [None]:
# copy UMAP and obs to adata_subset, for visualization
# adata subset is the subset of 10X cells that were used for integration, but with full transcriptome
# .copy(): adata_10x[...] is a view. Replacing .obs / .obsm on a view makes anndata
# implicitly copy it, with an ImplicitModificationWarning.
adata_subset = adata_10x[match_idx_10x].copy()
adata_subset.obs = adata_combined_10x.obs.copy()
adata_subset.obsm['X_umap'] = adata_combined_10x.obsm['X_umap']

In [None]:
# Attach the standalone 10X leiden labels, aligned by cell id.
adata_subset.obs['leiden_10x'] = adata_10x.obs.loc[adata_subset.obs.index, 'leiden']

In [None]:
# Switch to the full raw matrix (all genes) and keep it as .raw.
# NOTE(review): layers (e.g. 'sqrt_norm', read by later cells) are presumably not carried
# over by raw.to_adata() — confirm.
adata_subset = adata_subset.raw.to_adata()
adata_subset.raw = adata_subset

In [None]:
# Decode bytes gene names (unbinarize_list returns the input unchanged if decoding fails).
adata_subset.var_names = unbinarize_list(adata_subset.var_names)

In [None]:
#adata_subset = adata_subset[:, [i for i in adata_subset.var_names if "Gm" not in i]]
#adata_subset = adata_subset[:, [i for i in adata_subset.var_names if "Rik" not in i]]


In [None]:
#adata_subset = normalize_10x_data(adata_subset, n_genes=4000)

## Find markers for age-dependent subtypes

In [None]:
import diffxpy.api as diffxpy

def find_de_genes_for_subtype(A, celltype, log10_fc_thresh=np.log10(10), qval_thresh=1e-6):
    """
    Find genes that pass a q-value and log10 fold-change cut for one ``clust_annot`` subtype vs all other cells.

    Runs ``diffxpy.test.t_test`` on the boolean grouping
    ``A.obs.clust_annot == celltype``. That grouping is written to ``A.obs['A']``
    on the caller's object.

    Keeps genes with ``qval < qval_thresh`` and
    ``log10_fold_change() > log10_fc_thresh``. The sign convention is whatever
    diffxpy uses for the boolean grouping (presumably True vs False — confirm).

    Returns
    -------
    DataFrame with columns genes, qval, log10_fc, sorted by ascending log10_fc.
    """
    A.obs['A'] = A.obs.clust_annot == celltype
    result = diffxpy.test.t_test(A, "A")
    table = pd.DataFrame({
        'genes': A.var_names,
        'qval': result.qval,
        'log10_fc': result.log10_fold_change(),
    })
    table = table.sort_values('log10_fc')
    passes = (table.qval < qval_thresh) & (table.log10_fc > log10_fc_thresh)
    return table[passes]

In [None]:
# raw.to_adata() (earlier) returns an AnnData without layers, so make sure the
# sqrt-normalized layer exists before it is used as the expression matrix for the t-tests.
if 'sqrt_norm' not in adata_subset.layers:
    compute_sqrt_norm(adata_subset)
adata_de = adata_subset.copy()
adata_de.X = adata_subset.layers['sqrt_norm']

In [None]:
# Top GO terms for genes up in the Olig-3 subtype (fold change > 5).
olig_genes = find_de_genes_for_subtype(adata_de, "Olig-3",log10_fc_thresh=np.log10(5))
sig_genes_GO_query(olig_genes.genes).head(10)

In [None]:
# Same for Astro-2. NOTE(review): reuses the name `olig_genes` for a non-Olig subtype.
olig_genes = find_de_genes_for_subtype(adata_de, "Astro-2",log10_fc_thresh=np.log10(5))
sig_genes_GO_query(olig_genes.genes).head(10)

In [None]:
# Same for Micro-3 (same variable-name caveat as above).
olig_genes = find_de_genes_for_subtype(adata_de, "Micro-3",log10_fc_thresh=np.log10(5))
sig_genes_GO_query(olig_genes.genes).head(10)

In [None]:
# NOTE(review): stray expression — this only displays the bound method DataFrame.clip.
# It looks like a leftover and can probably be deleted.
adata_10x.obs.clip

In [None]:
# For each glial/vascular type: rank marker genes per subtype (clust_annot) on the
# sqrt-normalized layer, filter, and plot.
for i in ["Endo","Astro","Olig","Micro"]:
    # .copy(): rank_genes_groups writes into .uns, which on a view forces an implicit copy.
    curr_adata = adata_subset[adata_subset.obs.cell_type==i].copy()
    sc.tl.rank_genes_groups(curr_adata, 'clust_annot',layer='sqrt_norm',use_raw=False)
    sc.tl.filter_rank_genes_groups(curr_adata, use_raw=False, min_fold_change=2, min_in_group_fraction=0.25)
    # NOTE(review): filter_rank_genes_groups stores its result under a separate uns key
    # ('rank_genes_groups_filtered' by default). This plot shows the unfiltered ranking;
    # pass key='rank_genes_groups_filtered' if the filtered list was intended.
    sc.pl.rank_genes_groups(curr_adata)

In [None]:
# Rank marker genes for every subtype on the sqrt-normalized layer (stored in adata_subset.uns).
sc.tl.rank_genes_groups(adata_subset, 'clust_annot',layer='sqrt_norm',use_raw=False)


In [None]:
# Glial and vascular subtypes shown in the dot plot below.
clusts = ['Astro-1','Astro-2','Olig-1','Olig-2','Olig-3','Endo-1','Endo-2','Endo-3','Micro-1','Micro-2','Micro-3']

In [None]:
temp = adata_subset[adata_subset.obs.clust_annot.isin(clusts)]

In [None]:
sc.tl.dendrogram(temp,groupby='clust_annot')

In [None]:
# Genes for the "fig2_de_genes" dot plot. "Sp100" used to be listed twice, which produced a
# duplicated var name; each gene now appears once, in its first-seen position.
gene_names = ["Msi2", "Slc1a2", "Slc1a3", "Gfap","C4b",  "Efemp1", "Cped1","Shroom3", "Tnc", "Gpc5",
              "Flt1","Cldn5", "Xdh",  "Fmo2", "Vim", "Nr3c2", "Aff1", "Serinc3", "Nfib", "Trim30a","Parp14","Sp100", "Dlg2","Itgam","C1qa",
             "B2m","Tnfsf8","Ctss", "Itgb2", "Lyz2", "C1qc","Trem2", 
              "Mbp", "Cd9","Fth1","Apoe","Edil3","Anln","Il33","Neat1", "Cldn14"]
assert len(gene_names) == len(set(gene_names)), "duplicate genes in dot-plot list"

In [None]:
# Dot plot of the selected genes per subtype (sqrt-normalized layer), saved as a PDF.
f = sc.pl.dotplot(temp, var_names=gene_names, groupby='clust_annot',layer='sqrt_norm',figsize=(10,5),return_fig=True)
save_fig(f, "fig2_de_genes",dtype="pdf")

In [None]:
# Top 20 Olig age-DE genes by coefficient.
# NOTE(review): de_genes_age_major_signif is only defined further down (the "Load DE genes"
# section), so this cell fails on a fresh Restart & Run All.
de_genes_age_major_signif[de_genes_age_major_signif.cell_type=="Olig"].sort_values('coef').tail(20)

In [None]:
# Top-10 ranked genes per subtype (ranking computed on adata_subset, inherited via .uns).
sc.pl.rank_genes_groups_dotplot(temp, groupby='clust_annot',n_genes=10)

## Assign labels to leiden clusters

In [None]:
def get_base_names(class_labels):
    """
    Strip a trailing numeric subtype suffix ("-1" ... "-10") from each cluster label.

    Examples: "Astro-2" -> "Astro", "ExN-L2/3-1" -> "ExN-L2/3". Labels without
    such a suffix (e.g. "Micro", "X-11") are returned unchanged.

    Returns
    -------
    A list with exactly one entry per input label. Previously, labels with 4+
    dash-separated fields (or a bare number) were silently dropped, which
    misaligned the output with the input.
    """
    numeric_suffixes = {str(n) for n in range(1, 11)}
    base_names = []
    for label in class_labels:
        prefix, sep, suffix = label.rpartition('-')
        if sep and suffix in numeric_suffixes:
            base_names.append(prefix)
        else:
            base_names.append(label)
    return base_names

In [None]:
# Name each 10X leiden cluster after the majority base label (clust_annot without its numeric
# suffix) of its cells, numbering repeated names ("Astro-1", "Astro-2", ...). Each cluster also
# gets a major cell type: the first entry of the tuple below that occurs in its name, else "NA".
base_names = np.array(get_base_names(adata_subset.obs['clust_annot']))
clust_map = {}
clust_counts = {}
clust_to_celltype = {}
for i in adata_subset.obs.leiden_10x.unique():
    curr_clust = majority_vote(base_names[adata_subset.obs.leiden_10x==i])
    clust_counts[curr_clust] = clust_counts.get(curr_clust, 0) + 1
    clust_map[i] = curr_clust + "-" + str(clust_counts[curr_clust])
    # first match wins, so the order of this tuple matters
    curr_ct = next(
        (ct for ct in ("ExN", "InN", "MSN", "Olig", "Astro", "OPC", "Endo", "Peri", "Vlmc", "Micro", "Macro")
         if ct in curr_clust),
        "NA",
    )
    clust_to_celltype[i] = curr_ct
adata_subset.obs['clust_annot_10x'] = [clust_map[i] for i in adata_subset.obs.leiden_10x]
adata_subset.obs['cell_type_10x'] = [clust_to_celltype[i] for i in adata_subset.obs.leiden_10x]

In [None]:
# Decode gene names (no-op when they are already str).
# NOTE(review): repeats the identical call a few cells above.
adata_subset.var_names = unbinarize_list(adata_subset.var_names)

In [None]:
# get palettes
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata_subset)

In [None]:
# Integrated UMAP colored by the 10X-derived major cell type.
sc.pl.umap(adata_subset, color='cell_type_10x',palette=celltype_colors)

## Identify gene modules from DE genes

### Load DE genes

In [None]:
# compute DE with t test on normalized data
#import diffxpy
#def compute_de_genes_by_age(A, age_key='age', cell_type_key='cell_type'):
#    for i in A.obs[cell_type_key].unique():
#        curr_A = 

In [None]:
def get_upreg_with_age(df_major, df_minor):
    """Return genes with a positive age coefficient.

    Rows from ``df_minor`` come first, then rows from ``df_major``. Duplicates
    are kept. Both frames need ``gene`` and ``coef`` columns.
    """
    minor_up = list(df_minor.loc[df_minor.coef > 0, 'gene'])
    major_up = list(df_major.loc[df_major.coef > 0, 'gene'])
    return minor_up + major_up

def get_downreg_with_age(df_major, df_minor):
    """Return genes with a non-positive (<= 0) age coefficient.

    Rows from ``df_minor`` come first, then rows from ``df_major``. Duplicates
    are kept.
    """
    minor_down = list(df_minor.loc[df_minor.coef <= 0, 'gene'])
    major_down = list(df_major.loc[df_major.coef <= 0, 'gene'])
    return minor_down + major_down

# Age-DE tables: per major cell type ("major") and per subtype cluster ("minor").
# NOTE(review): the variables are named ttest_* but the files are glm_nb_* results;
# absolute machine-specific paths.
ttest_de_celltype_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_10x_V2.csv")
ttest_de_clust_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age_V3_nools.csv")
# Significance cut-offs. coef is compared against np.log(...), so it appears to be a
# natural-log fold change (confirm against the DE model).
qval_thresh = 0.05
coef_thresh_major = np.log(1.25)
coef_thresh_minor = np.log(2)
de_genes_age_minor_signif = ttest_de_clust_df[np.logical_and(np.abs(ttest_de_clust_df.coef) > coef_thresh_minor, ttest_de_clust_df.qval<qval_thresh)]
de_genes_age_major_signif = ttest_de_celltype_df[np.logical_and(np.abs(ttest_de_celltype_df.coef) > coef_thresh_major, ttest_de_celltype_df.qval<qval_thresh)]
# drop rows with infinite coefficients
de_genes_age_minor_signif = de_genes_age_minor_signif[~np.isinf(de_genes_age_minor_signif.coef)]
de_genes_age_major_signif = de_genes_age_major_signif[~np.isinf(de_genes_age_major_signif.coef)]

In [None]:
# Subtype labels containing ExN/InN/MSN count as neuronal; all others as non-neuronal.
neuronal_clusts = list(set([i for i in de_genes_age_minor_signif.cell_type if "ExN" in i or "InN" in i or "MSN" in i]))
nonneuronal_clusts = list(set([i for i in de_genes_age_minor_signif.cell_type if "ExN" not in i and "InN" not in i and "MSN" not in i]))

In [None]:
# NOTE(review): qval_thresh is reassigned here but not used again in the visible notebook.
qval_thresh = 0.1
de_genes_age_minor_signif_neurons = de_genes_age_minor_signif[de_genes_age_minor_signif.cell_type.isin(neuronal_clusts)]
de_genes_age_minor_signif_nonneurons = de_genes_age_minor_signif[de_genes_age_minor_signif.cell_type.isin(nonneuronal_clusts)]
de_genes_age_major_signif_neurons = de_genes_age_major_signif[de_genes_age_major_signif.cell_type.isin(["ExN","InN","MSN"])]
de_genes_age_major_signif_nonneurons = de_genes_age_major_signif[~de_genes_age_major_signif.cell_type.isin(["ExN","InN","MSN"])]

# Gene lists (subtype-level first, then cell-type-level; duplicates kept) by direction of change with age.
upreg_with_age_neuron = get_upreg_with_age(de_genes_age_major_signif_neurons, de_genes_age_minor_signif_neurons)
upreg_with_age_nonneuron = get_upreg_with_age(de_genes_age_major_signif_nonneurons, de_genes_age_minor_signif_nonneurons)

downreg_with_age_neuron = get_downreg_with_age(de_genes_age_major_signif_neurons, de_genes_age_minor_signif_neurons)
downreg_with_age_nonneuron = get_downreg_with_age(de_genes_age_major_signif_nonneurons, de_genes_age_minor_signif_nonneurons)

In [None]:
#adata_subset.var_names = [i.decode('ascii') for i in adata_subset.var_names]

### Identify gene module genes

In [None]:
# Raw (all-gene) matrices for neurons (ExN/InN/MSN) and for all other cells.
adata_subset_neurons_raw = adata_subset[adata_subset.obs.cell_type.isin(['ExN','InN','MSN'])].raw.to_adata()
adata_subset_nonneurons_raw = adata_subset[~adata_subset.obs.cell_type.isin(['ExN','InN','MSN'])].raw.to_adata()

In [None]:
# HVG subset (2000 genes) + Pearson-residual normalization for each group.
adata_subset_neurons_raw = normalize_10x_data(adata_subset_neurons_raw, n_genes=2000)
adata_subset_nonneurons_raw = normalize_10x_data(adata_subset_nonneurons_raw, n_genes=2000)

In [None]:
def get_normalized_subset_by_genes(A, genes):
    """
    Subset ``A`` to the requested genes and set ``.X`` to sqrt(library-size-normalized counts).

    Requested genes missing from ``A.var_names`` are silently dropped.
    Duplicates are removed while keeping first-seen order; the previous
    ``set`` made the column order vary between runs. As before, bytes
    ``A.var_names`` are decoded to str on ``A`` itself.

    Library sizes are computed on the selected genes only, because
    ``normalize_total`` runs on the subset.

    Returns
    -------
    A new AnnData (a copy, not a view).
    """
    try:
        A.var_names = [name.decode('ascii') for name in A.var_names]
    except (AttributeError, UnicodeDecodeError):
        pass  # names are already str
    present = set(A.var_names)
    genes = [gene for gene in dict.fromkeys(genes) if gene in present]

    # .copy(): assigning .X on a view would make anndata implicitly copy it with a warning.
    subset = A[:, genes].copy()
    subset.X = np.sqrt(sc.pp.normalize_total(subset, inplace=False)["X"])
    return subset

In [None]:
#adata_subset_neurons_upreg = get_normalized_subset_by_genes(adata_subset_neurons_raw, upreg_with_age_neuron)
#adata_subset_neurons_downreg = get_normalized_subset_by_genes(adata_subset_neurons_raw, downreg_with_age_neuron)

#adata_subset_nonneurons_upreg = get_normalized_subset_by_genes(adata_subset_nonneurons_raw, upreg_with_age_nonneuron)
#adata_subset_nonneurons_downreg = get_normalized_subset_by_genes(adata_subset_nonneurons_raw, downreg_with_age_neuron)


In [None]:
# find_gene_networks (its definition in this notebook is further down, under "Identify gene
# modules for each cell type") returns (embedding, labels, gene_names, corr_mat).
# `*_` absorbs the trailing matrix so that unpacking does not raise ValueError.
um_neurons, clu_neurons, genes_neurons, *_ = find_gene_networks(adata_subset_neurons_raw, dbscan_eps=0.5)

In [None]:
# Gene UMAP for neurons, colored by module label.
plt.scatter(um_neurons[:,0], um_neurons[:,1], c=clu_neurons, cmap=plt.cm.gist_ncar,s=1)

In [None]:
# find_gene_networks returns 4 values (embedding, labels, gene_names, corr_mat);
# `*_` absorbs the trailing matrix so that unpacking does not raise ValueError.
um_nonneurons, clu_nonneurons, genes_nonneurons, *_ = find_gene_networks(adata_subset_nonneurons_raw,dbscan_eps=0.5)

In [None]:
# NOTE(review): displays the full gene list; consider genes_nonneurons[:20].
genes_nonneurons

In [None]:
# Gene UMAP for non-neuronal cells, colored by module label.
plt.scatter(um_nonneurons[:,0], um_nonneurons[:,1], c=clu_nonneurons, cmap=plt.cm.gist_ncar,s=1)

In [None]:
# Gene-module tables (one row per gene). unbinarize_list decodes bytes names and leaves
# str names unchanged; find_gene_networks returns A.var_names, which may already be str.
nonneuron_mod_df = pd.DataFrame({'cluster': clu_nonneurons, 'gene': unbinarize_list(list(genes_nonneurons))})
neuron_mod_df = pd.DataFrame({'cluster': clu_neurons, 'gene': unbinarize_list(list(genes_neurons))})

In [None]:
# Persist the module tables.
# NOTE(review): file names say eps0.4, but the find_gene_networks calls above use dbscan_eps=0.5.
neuron_mod_df.to_csv("gene_lists/neuron_module_genes_eps0.4.csv")
nonneuron_mod_df.to_csv("gene_lists/nonneuron_module_genes_eps0.4.csv")

In [None]:
neuron_mod_df = pd.read_csv("gene_lists/neuron_module_genes_eps0.4.csv")
nonneuron_mod_df = pd.read_csv("gene_lists/nonneuron_module_genes_eps0.4.csv")

In [None]:
# NOTE(review): this rebinds genes_* to de-duplicated lists of str. They no longer line up
# with clu_neurons / clu_nonneurons, so boolean indexing with the cluster arrays breaks afterwards.
genes_nonneurons = list(nonneuron_mod_df.gene.unique())
genes_neurons = list(neuron_mod_df.gene.unique())

In [None]:
# Decode gene names if they are bytes. unbinarize_list returns them unchanged when already
# str, where a bare .decode would raise AttributeError.
adata_subset_nonneurons_raw.var_names = unbinarize_list(adata_subset_nonneurons_raw.var_names)
adata_subset_neurons_raw.var_names = unbinarize_list(adata_subset_neurons_raw.var_names)

In [None]:
# Score each neuronal gene module on the cells and print its top GO terms. Genes come from
# neuron_mod_df (built from clu_neurons/genes_neurons and saved/reloaded above), because
# genes_neurons may have been rebound to a de-duplicated str list that no longer aligns with
# clu_neurons. groupby(sort=True) visits modules in the same order as np.unique.
for n, (module_id, module_df) in enumerate(neuron_mod_df.groupby('cluster', sort=True)):
    curr_genes = unbinarize_list(list(module_df.gene))
    print("--> Module", module_id, curr_genes[:20])
    print(sig_genes_GO_query(curr_genes).head(5).name)
    sc.tl.score_genes(adata_subset_neurons_raw, gene_list=curr_genes, score_name='score_'+str(n), use_raw=False)
    print("--------------------")

In [None]:
# Score each non-neuronal gene module and print its top GO terms. Genes come from
# nonneuron_mod_df, because genes_nonneurons may have been rebound to a de-duplicated str list
# that no longer aligns with clu_nonneurons. groupby(sort=True) matches np.unique's order.
for n, (module_id, module_df) in enumerate(nonneuron_mod_df.groupby('cluster', sort=True)):
    curr_genes = unbinarize_list(list(module_df.gene))
    print("--> Module", module_id, curr_genes[:20])
    print(sig_genes_GO_query(curr_genes).head(5).name)
    sc.tl.score_genes(adata_subset_nonneurons_raw, gene_list=curr_genes, score_name='score_'+str(n), use_raw=False)
    print("--------------------")

In [None]:
adata_subset_neurons_raw.obs

In [None]:
sc.pl.dotplot(adata_subset_neurons_raw,['score_'+str(i) for i in range(20)],groupby='cell_type',dendrogram=True)

In [None]:
sc.pl.dotplot(adata_subset_nonneurons_raw,['score_'+str(i) for i in range(27)],groupby='cell_type', dendrogram=True)

In [None]:
# NOTE(review): re-reads the same file already loaded a few cells above.
nonneuron_mod_df = pd.read_csv("gene_lists/nonneuron_module_genes_eps0.4.csv")

In [None]:
# NOTE(review): print_go_terms_for_modules is not defined in this notebook (presumably it comes
# from a wildcard import). save_go_terms_for_modules, defined below, writes to a file instead.
print_go_terms_for_modules(list(nonneuron_mod_df.cluster), list(nonneuron_mod_df.gene))

## Identify gene modules for each cell type

In [None]:
def test_module_correlation(corr_mat, clusters, n_repeat=1000, rng=None):
    """
    Permutation test for whether each module's mean within-module value of ``corr_mat`` is higher than chance.

    For each label ``c``, the observed statistic is the mean of
    ``corr_mat[clusters == c][:, clusters == c]``. The null distribution comes
    from shuffling the labels ``n_repeat`` times, which keeps each module's
    size. The raw p-value is the fraction of shuffles with a statistic >= the
    observed one. There is no +1 correction, so it can be exactly 0; callers
    rely on this with thresholds like 1e-6. p-values are then
    Benjamini-Hochberg adjusted across modules.

    Parameters
    ----------
    corr_mat : square 2-D array (genes x genes)
    clusters : sequence of labels, one per row of ``corr_mat``
    n_repeat : int
        Number of permutations.
    rng : numpy Generator / RandomState, or None
        None uses the global ``np.random`` state (original behaviour). Pass e.g.
        ``np.random.default_rng(0)`` for reproducible p-values.

    Returns
    -------
    dict mapping each label to its adjusted p-value.
    """
    permutation = np.random.permutation if rng is None else rng.permutation
    clusters = np.array(clusters)
    labels = np.unique(clusters)

    pvals = []
    for c in tqdm(labels):
        in_c = clusters == c
        true_corr = corr_mat[in_c, :][:, in_c].mean()
        shuffled_corr = np.empty(n_repeat)
        for r in range(n_repeat):
            shuffled = clusters[permutation(len(clusters))]
            in_shuffled = shuffled == c
            shuffled_corr[r] = corr_mat[in_shuffled, :][:, in_shuffled].mean()
        pvals.append(np.sum(shuffled_corr >= true_corr) / len(shuffled_corr))

    adjusted = multipletests(np.array(pvals), method='fdr_bh')[1]
    return dict(zip(labels, adjusted))

def find_gene_networks(A, n_pcs=50, dbscan_eps=0.5, use_neighbors=True, use_dbscan=True, filter_modules=True, pval_thresh=0.1):
    """
    Group the genes of ``A`` into co-expression modules.

    Steps:
    1. Truncated SVD of ``A.X`` (``n_pcs`` components); gene-gene similarity is
       ``vt.T @ vt``.
    2. Globally z-score the similarity (``zscore_pval_mat``) and zero entries
       whose BH-adjusted p-value exceeds ``threshold_zscore_mat``'s default.
    3. Embed genes with UMAP: on a 10-NN graph of the similarity rows when
       ``use_neighbors`` is True, otherwise on the matrix itself.
    4. Cluster either the embedding with DBSCAN (``use_dbscan``), or the kNN
       graph with igraph ``community_multilevel`` on Jaccard similarities.
    5. If ``filter_modules`` is True: keep only genes whose module passes
       ``test_module_correlation`` at ``pval_thresh``. Then re-run UMAP (on the
       matrix) and DBSCAN on them, regardless of ``use_neighbors`` /
       ``use_dbscan``, and sort the returned matrix by module label.

    Returns
    -------
    (embedding, labels, gene_names, corr_mat)
        embedding
            UMAP coordinates, one row per returned gene.
        labels
            Module label per returned gene (DBSCAN uses -1 for noise).
        gene_names
            The matching entries of ``A.var_names``.
        corr_mat
            Thresholded similarity matrix; rows/columns are sorted by label
            when ``filter_modules`` is True.
    """
    _u, _s, vt = svds(A.X, k=n_pcs)
    corr_mat = np.dot(vt.T, vt)

    corr_mat, corr_pval = zscore_pval_mat(corr_mat)
    corr_mat = threshold_zscore_mat(corr_mat, corr_pval)

    # The community-detection branch needs the kNN graph even when UMAP runs on the
    # matrix itself. Previously use_neighbors=False with use_dbscan=False raised NameError.
    neighbor_graph = None
    if use_neighbors or not use_dbscan:
        nbrs = NearestNeighbors(n_neighbors=10, algorithm='kd_tree').fit(corr_mat)
        neighbor_graph = nbrs.kneighbors_graph(corr_mat)
    umap = UMAP().fit(neighbor_graph if use_neighbors else corr_mat)

    if use_dbscan:
        dbs = DBSCAN(eps=dbscan_eps).fit_predict(umap.embedding_)
    else:
        # igraph is only needed here, so a missing install does not break the DBSCAN path.
        import igraph as ig
        # cluster nearest neighbors graph
        g = ig.GraphBase.Adjacency(neighbor_graph.toarray().tolist(), mode=ig.ADJ_UNDIRECTED)
        sim = np.array(g.similarity_jaccard())
        g = ig.GraphBase.Weighted_Adjacency(sim.tolist(), mode=ig.ADJ_UNDIRECTED)
        dbs = np.array(g.community_multilevel(weights="weight", return_levels=False))

    if filter_modules:
        pvals = test_module_correlation(corr_mat, dbs)
        signif_idx = np.array([pvals[label] < pval_thresh for label in dbs])
        corr_mat = corr_mat[signif_idx, :][:, signif_idx]
        umap = UMAP().fit(corr_mat)
        # redo dbscan on the retained genes
        dbs = DBSCAN(eps=dbscan_eps).fit_predict(umap.embedding_)
        gene_names = A.var_names[signif_idx]
        order = np.argsort(dbs)
        corr_mat = corr_mat[order, :][:, order]

        print(np.sum(np.array(list(pvals.values())) < pval_thresh), "significant clusters")
    else:
        gene_names = A.var_names
    return umap.embedding_, dbs, gene_names, corr_mat


In [None]:
# initial run that looked good eps = 0.5, pcs=30, pval=1e-6
#
# For each major cell type: take its cells' raw matrix, normalize (HVG + Pearson residuals)
# and find gene modules. Results are stored as gene_modules[cell_type] = (umap, labels, genes, corr_mat).
# NOTE(review): the comment above mentions eps=0.5 / pcs=30, but this run uses eps=0.25, n_pcs=50.
gene_modules = {}
for i in adata_subset.obs.cell_type.unique():
    print(i)
    temp = adata_subset[adata_subset.obs.cell_type==i].raw.to_adata()
    #temp.X = temp.layers['raw']
    temp = normalize_10x_data(temp)
    um, clu, genes, corr_mat = find_gene_networks(temp, dbscan_eps=0.25, n_pcs=50, pval_thresh=1e-6)
    genes = unbinarize_list(genes)
    gene_modules[i] = (um, clu, genes, corr_mat)

In [None]:
def save_go_terms_for_modules(clu, genes, fpath):
    """
    Write the top GO terms of every gene module to a text file.

    For each module the file gets a header line with the module id and up to 20 of
    its genes, the names of the top 5 rows returned by ``sig_genes_GO_query``, and a
    separator line.

    Parameters
    ----------
    clu : array-like
        Module label for each gene (same length and order as ``genes``).
    genes : array-like of str
        Gene names.
    fpath : str
        Output file path (overwritten).
    """
    # np.asarray so plain lists / Series give an element-wise mask below; comparing a
    # list to a scalar is just False, which silently selected no genes at all.
    clu = np.asarray(clu)
    genes = np.asarray(genes)
    with open(fpath, 'w') as f:
        for module in np.unique(clu):
            # plain str so the written list does not show numpy scalar reprs (numpy >= 2)
            curr_genes = [str(g) for g in genes[clu == module]]
            f.write(f"--> Module  {module}:  {curr_genes[:20]}\n")
            for term in sig_genes_GO_query(curr_genes).head(5).name:
                f.write(term + "\n")
            f.write("--------------------\n")


In [None]:
# Inspect the per-cell-type results: {cell_type: (um, clu, genes, corr_mat)}
gene_modules

In [None]:
# For every cell type: save a UMAP of its genes colored by module, a gene-gene correlation
# heatmap with module color strips and module boundaries, and the top GO terms per module.
base_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/module_spatial_plots"
for mod in gene_modules.keys():
    print(mod)
    curr_path = os.path.join(base_path, mod)
    # makedirs also creates a missing base_path and tolerates an existing folder
    os.makedirs(curr_path, exist_ok=True)

    um, clu, genes, corr_mat = gene_modules[mod]
    sorted_mod = np.sort(clu)

    # gene UMAP colored by module label
    f, ax = plt.subplots(figsize=(5,5))
    ax.scatter(um[:,0], um[:,1], c=clu, s=1, cmap=plt.cm.gist_ncar)
    f.savefig(os.path.join(curr_path, "gene_umap.pdf"), dpi=300, bbox_inches="tight")

    # 2x2 grid: top/left strips show the sorted module labels, bottom-right shows corr_mat
    # (drawn against the sorted labels, so corr_mat is assumed to be ordered by module)
    f, ax = plt.subplots(figsize=(5,5), ncols=2, nrows=2, gridspec_kw={'wspace':0.01,'hspace':0.01, 'width_ratios':[0.5,10], 'height_ratios':[0.5,10]})
    ax[0][0].axis('off')
    ax[0][1].imshow(np.expand_dims(sorted_mod,1).T,aspect='auto',cmap=plt.cm.gist_ncar,interpolation='none', rasterized=True)
    ax[0][1].axis('off')
    ax[1][1].imshow(corr_mat,vmin=-1,vmax=1,cmap=plt.cm.bwr,aspect='auto',interpolation='bilinear')
    ax[1][1].axis('off')
    ax[1][0].imshow(np.expand_dims(sorted_mod,1),aspect='auto',cmap=plt.cm.gist_ncar,interpolation='none', rasterized=True)
    ax[1][0].axis('off')
    # a line at every position where the sorted module label changes
    for i in np.flatnonzero(np.diff(sorted_mod)) + 1:
        ax[1][1].axvline(i,color='k',lw=1)
        ax[0][1].axvline(i,color='k',lw=1)
        ax[1][1].axhline(i,color='k',lw=1)
        ax[1][0].axhline(i,color='k',lw=1)
    f.savefig(os.path.join(curr_path, "gene_gene_corr.pdf"), dpi=300, bbox_inches="tight")

    save_go_terms_for_modules(clu, genes, os.path.join(curr_path, 'go_terms.txt'))

In [None]:
# make dataframe for gene modules
# One row per (cell type, gene): the gene's module-UMAP coordinates, module label and name.
umap_coords = []
clusters = []
genes = []
ct = []
# NOTE(review): the loop variables (um, clu, g, corr_mat) stay defined after this cell, and
# some later cells use clu / corr_mat without redefining them -- beware of hidden state.
for cell_type, (um, clu, g, corr_mat) in gene_modules.items():
    umap_coords.extend(um)
    clusters.extend(clu)
    genes.extend(g)
    ct.extend([cell_type]*len(clu))
umap_coords = np.vstack(umap_coords)

In [None]:
module_df = pd.DataFrame({'cell_type':ct, 'umap0' : umap_coords[:,0], 'umap1': umap_coords[:,1], 'cluster': clusters, 'gene' : genes})
module_df

In [None]:
def print_go_terms_for_modules(clu, genes):
    """
    Print, for every gene module, up to 20 of its genes and the names of the top 5
    rows returned by ``sig_genes_GO_query``.

    Parameters
    ----------
    clu : array-like
        Module label per gene, aligned with ``genes``.
    genes : array-like of str
        Gene names.
    """
    # np.asarray so lists / Series also produce an element-wise boolean mask below
    clu = np.asarray(clu)
    genes = np.asarray(genes)
    for module in np.unique(clu):
        # plain str so the printed list does not show numpy scalar reprs (numpy >= 2)
        curr_genes = [str(g) for g in genes[clu == module]]
        print("--> Module", module, curr_genes[:20])
        print(sig_genes_GO_query(curr_genes).head(5).name)
        print("--------------------")
    

In [None]:
# do rank-sum tests across ages and save violin plots
# For each cell type: score every gene module per cell (sc.tl.score_genes), compare the scores
# of 90wk vs 4wk cells (Wilcoxon rank-sum p-value and difference of medians), save one violin
# plot per module, and collect one result row per module in ttest_mod_dfs.
from scipy.stats import ttest_ind,ranksums
ttest_mod_dfs = []
for i in ["Olig"]:#adata_subset.obs.cell_type.unique():
    print(i)
    # .copy(): assigning .X on a view of adata_subset would otherwise trigger an implicit copy
    curr_adata = adata_subset[adata_subset.obs.cell_type==i].copy()
    curr_adata.X = np.sqrt(sc.pp.normalize_total(curr_adata,inplace=False)["X"])
    curr_mods_df = module_df[module_df.cell_type==i]
    # violin plots of this cell type go into its module folder
    sc.settings.figdir = f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/module_spatial_plots/{i}/"
    is_young = np.asarray(curr_adata.obs.age == '4wk')
    is_old = np.asarray(curr_adata.obs.age == '90wk')
    pvals = []
    diff = []
    mods = []
    for j in tqdm(curr_mods_df.cluster.unique()):
        # score only the module genes that are measured in this object
        genes = [g for g in curr_mods_df[curr_mods_df.cluster==j].gene if g in curr_adata.var_names]
        score_name = f'module_{j}'
        sc.tl.score_genes(curr_adata, gene_list=genes, score_name=score_name, use_raw=False)
        score = np.asarray(curr_adata.obs[score_name])
        pvals.append(ranksums(score[is_young], score[is_old])[1])
        diff.append(np.median(score[is_old]) - np.median(score[is_young]))
        mods.append(j)
        sc.pl.violin(curr_adata, score_name, groupby='age', show=False, save=f"_aging_module_10x_{j}_score.pdf")
    # without this, ttest_mod_dfs stayed empty and the pd.concat in the next cell failed
    ttest_mod_dfs.append(pd.DataFrame({'cell_type':[i]*len(pvals), 'diff':diff, 'pval':pvals, 'module':mods}))


In [None]:
# Stack the per-cell-type result tables into one DataFrame.
# NOTE: not idempotent -- after the first run ttest_mod_dfs is already a DataFrame and
# pd.concat on it raises, so re-run the testing cell above first.
ttest_mod_dfs = pd.concat(ttest_mod_dfs)

In [None]:
# Benjamini-Hochberg FDR-adjusted p-values across all tested modules.
# NOTE(review): multipletests is imported explicitly only in a later cell ("Define colors").
ttest_mod_dfs['qval'] = multipletests(ttest_mod_dfs['pval'],method='fdr_bh')[1]

In [None]:
ttest_mod_dfs[ttest_mod_dfs.cell_type=="Astro"].sort_values('diff')

In [None]:
temp = ttest_mod_dfs.copy()

In [None]:
# Re-number each cell type's modules by the rank of their age effect ('diff', 0 = smallest);
# ExN modules keep their original ids.
for i in temp.cell_type.unique():
    is_ct = temp.cell_type == i
    if i != "ExN":
        print(i)
        # rank, not argsort: np.argsort returns the permutation that sorts the values, so
        # writing it back row-by-row gave each module the sort position of another module.
        # .to_numpy() makes the assignment positional (row order of the selection).
        temp.loc[is_ct, 'sorted_clust'] = (temp.loc[is_ct, 'diff'].rank(method='first') - 1).to_numpy()
    else:
        temp.loc[is_ct, 'sorted_clust'] = temp.loc[is_ct, 'module']

In [None]:
ttest_mod_dfs = temp

In [None]:
# Add a "sorted_module" column to module_df: each gene's module id translated to that
# module's 'sorted_clust' value from ttest_mod_dfs.
# NOTE: mod_map[i] raises KeyError for any module id of a cell type that has no row in ttest_mod_dfs.
for i in ttest_mod_dfs.cell_type.unique():
    print(i)
    mod_map = {r.module:int(r.sorted_clust) for i,r in ttest_mod_dfs[ttest_mod_dfs.cell_type==i].iterrows()}
    curr_mods = module_df[module_df.cell_type==i]
    module_df.loc[module_df.cell_type==i, "sorted_module"] = [int(mod_map[i]) for i in curr_mods.cluster]

In [None]:
# save out modules
# NOTE(review): the file name says eps0.5 / 50pcs, but the module run above used dbscan_eps=0.25.
module_df.to_csv("gene_lists/062422_gene_modules_eps0.5_50pcs.csv")

In [None]:
# reload (to_csv wrote the index too, so it comes back as an extra unnamed column)
module_df = pd.read_csv("gene_lists/062422_gene_modules_eps0.5_50pcs.csv")

### Save out module activation plots and GO terms for Fig S5

In [None]:
# quantify module scores for each module across both ages
# For each cell type: score every module on every cell, then draw per-module boxplots of the
# score split by age, ordered by the difference of mean score between 90wk and 4wk.
import seaborn as sns  # otherwise first imported further down; needed for a fresh-kernel run
# NOTE(review): age_colors is only defined in the later "Define colors" cell; run it first.
for i in ["Olig","Astro", "Micro"]:
    print(i)
    curr_adata = adata_subset[adata_subset.obs.cell_type==i].raw.to_adata()
    curr_adata.var_names = unbinarize_list(curr_adata.var_names)
    curr_mod_df = module_df[module_df.cell_type==i]
    score_names = []
    scores = []
    ages = []
    for j in tqdm(curr_mod_df.cluster.unique()):
        # -1 is the label DBSCAN gives to genes not assigned to any module
        if int(j) != -1:
            sc.tl.score_genes(curr_adata, gene_list=curr_mod_df[curr_mod_df.cluster==j].gene,use_raw=False, score_name="module_score")
            score_names.extend([j]*curr_adata.shape[0])
            scores.extend(list(curr_adata.obs.module_score.values))
            ages.extend(list(curr_adata.obs.age.values))
    df = pd.DataFrame({'module':score_names, 'module_score':scores, 'age':ages})
    # modules ordered by how much their mean score changes from 4wk to 90wk
    mean_diff = df[df.age=='90wk'].groupby('module')['module_score'].mean()  - df[df.age=='4wk'].groupby('module')['module_score'].mean() 
    f,ax = plt.subplots(figsize=(5,10))
    sns.boxplot(y='module',x='module_score',data=df, hue='age', order=mean_diff.sort_values().index,fliersize=1, orient='h',ax=ax, palette=sns.color_palette(age_colors[::2]))
    sns.despine(ax=ax)
    ax.set_xlabel('Module Score')
    ax.set_ylabel('Module')
    # fixed x-limits chosen per cell type so outliers do not squeeze the boxes
    if i == "Olig":
        ax.set_xlim([-2, 25])
    else:
        ax.set_xlim([-2, 20])
    save_fig(f, f"figS5_{i}_quant",dtype="pdf")

In [None]:
def save_go_terms_for_modules_table(clu, genes, fpath):
    """
    Write a tab-separated table with the top 3 GO terms of each gene module.

    Each line is ``<module>\\t<term1>\\t<term2>\\t<term3>`` (fewer terms if the query
    returns fewer), or ``<module>\\tNA`` when the GO query returns nothing. The file is
    tab-separated even if ``fpath`` ends in ``.csv``.

    Parameters
    ----------
    clu : array-like
        Module label per gene, aligned with ``genes``.
    genes : array-like of str
        Gene names.
    fpath : str
        Output file path (overwritten).
    """
    # np.asarray so lists / Series also produce an element-wise boolean mask below
    clu = np.asarray(clu)
    genes = np.asarray(genes)
    with open(fpath, 'w') as f:
        for module in np.unique(clu):
            curr_genes = [str(g) for g in genes[clu == module]]
            # query once and reuse it (the GO query was previously run twice per module)
            go_terms = list(sig_genes_GO_query(curr_genes).head(3).name)
            # a single tab between columns: "<id>\t\t<term>" used to leave an empty column
            # that the "NA" rows did not have
            f.write(f"{module}")
            if go_terms:
                for term in go_terms:
                    print(term)
                    f.write('\t' + term)
            else:
                f.write("\tNA")
            f.write("\n")

In [None]:
# Write a GO-term table for the modules of each listed cell type.
for i in ["Olig"]:
    print(i)
    curr_mod_df = module_df[module_df.cell_type==i]
    clu = np.array(curr_mod_df.cluster)
    genes = np.array(curr_mod_df.gene)
    save_go_terms_for_modules_table(clu, genes, f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/module_spatial_plots/{i}/go_term_table.csv")

In [None]:
# NOTE(review): identical to the previous cell (rewrites the same file) -- likely a leftover duplicate.
for i in ["Olig"]:
    print(i)
    curr_mod_df = module_df[module_df.cell_type==i]
    clu = np.array(curr_mod_df.cluster)
    genes = np.array(curr_mod_df.gene)
    save_go_terms_for_modules_table(clu, genes, f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/module_spatial_plots/{i}/go_term_table.csv")

In [None]:
clu

## Show gene-gene correlations for a few non-neuronal cell types

In [None]:
# NOTE(review): curr_adata is whatever an earlier cell left behind, and its var_names may
# already be decoded there; applying unbinarize_list a second time may fail -- confirm.
curr_adata.var_names = unbinarize_list(curr_adata.var_names)

In [None]:
# NOTE(review): relies on corr_mat and curr_df from other cells (curr_df is first assigned
# further below), so this only works after those have been run.
plt.figure(figsize=(5,5))
plt.imshow(corr_mat[:, np.argsort(curr_df.cluster)][np.argsort(curr_df.cluster),:],vmin=-1,vmax=1, cmap=plt.cm.bwr,aspect='auto',interpolation='none')

In [None]:
def make_color_index(sort_clusts):
    """
    Map a sorted sequence of cluster labels to consecutive color indices.

    Each run of equal labels gets one index, and the index grows by one at every change
    of label, e.g. ``[-1, -1, 3, 3, 7]`` -> ``[0, 0, 1, 1, 2]``.

    Parameters
    ----------
    sort_clusts : array-like
        Cluster labels, already sorted (or at least grouped into contiguous runs).

    Returns
    -------
    np.ndarray of int
        One color index per element (empty for empty input).
    """
    labels = np.asarray(sort_clusts)
    color_idx = np.zeros(len(labels), dtype=int)
    # 1 wherever a label differs from its predecessor; the running sum is the color index.
    # (The previous version counted the changes but never stored them and returned None.)
    color_idx[1:] = np.cumsum(labels[1:] != labels[:-1])
    return color_idx
    

In [None]:
# NOTE(review): curr_df is only assigned in the next cell; this depends on an earlier run.
np.sort(curr_df.cluster)

In [None]:
# Recompute a gene-gene similarity matrix for one cell type, restricted to its module genes,
# and plot it sorted by module with module color strips on top and left.
base_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/module_spatial_plots"
for mod in ["Astro"]:
    print(mod)
    curr_df = module_df[module_df.cell_type==mod]
    curr_path = os.path.join(base_path, mod)
    # makedirs also creates a missing base_path and tolerates an existing folder
    os.makedirs(curr_path, exist_ok=True)

    curr_adata = adata_subset[adata_subset.obs.cell_type==mod]
    curr_adata = curr_adata.raw.to_adata()
    curr_adata.var_names = unbinarize_list(curr_adata.var_names)
    curr_adata = normalize_10x_data(curr_adata)
    curr_adata = curr_adata[:, curr_df.gene]
    # gene x gene matrix from the top-50 right singular vectors, z-scored and thresholded
    _, _, vt = svds(curr_adata.X, k=50)
    corr_mat = np.dot(vt.T,vt)
    corr_mat, corr_pval = zscore_pval_mat(corr_mat)
    corr_mat = threshold_zscore_mat(corr_mat, corr_pval)

    # compute the module ordering once; reuse it for rows, columns and both color strips
    module_labels = np.asarray(curr_df.cluster)
    order = np.argsort(module_labels)
    corr_mat = corr_mat[:, order][order, :]
    sorted_labels = module_labels[order]

    # one discrete color per module; same colors as plt.cm.get_cmap('gist_ncar', n), which
    # newer matplotlib versions no longer provide
    n_modules = len(np.unique(module_labels))
    cmap = mpl.colors.ListedColormap(plt.cm.gist_ncar(np.linspace(0, 1, n_modules)))
    f, ax = plt.subplots(figsize=(5,5), ncols=2, nrows=2, gridspec_kw={'wspace':0.01,'hspace':0.01, 'width_ratios':[0.5,10], 'height_ratios':[0.5,10]})
    ax[0][0].axis('off')
    ax[0][1].imshow(np.expand_dims(sorted_labels,1).T,aspect='auto',cmap=cmap,interpolation='none', rasterized=True,vmax=module_labels.max(),vmin=-1)
    ax[0][1].axis('off')
    ax[1][1].imshow(corr_mat,vmin=-1,vmax=1,cmap=plt.cm.bwr,aspect='auto',interpolation='bilinear')
    ax[1][1].axis('off')
    ax[1][0].imshow(np.expand_dims(sorted_labels,1),aspect='auto',cmap=cmap,interpolation='none', rasterized=True,vmin=-1,vmax=module_labels.max())
    ax[1][0].axis('off')

    #f.savefig(os.path.join(curr_path, "gene_gene_corr.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Scratch: rebuild a low-rank gene-gene matrix for one cell type and pick its modules.
# NOTE(review): with the two lines below commented out, `temp` is the DataFrame from
# `temp = ttest_mod_dfs.copy()` (or another leftover), not an AnnData -- uncomment/redefine first.
#temp = adata_subset[adata_subset.obs.cell_type==i].raw.to_adata()
#temp.X = temp.layers['raw']
temp = normalize_10x_data(temp)


In [None]:
expr = temp.X

In [None]:
# gene x gene matrix from the top n_pcs right singular vectors of the expression matrix
n_pcs = 30
u,s,vt = svds(expr, k=n_pcs)
corr_mat = np.dot(vt.T,vt)#np.corrcoef(adata_subset_micro.layers['sqrt_norm'].toarray().T)


In [None]:
# Module labels (excluding the -1 label) for genes in curr_genes.
# NOTE(review): `celltype` and `curr_genes` are not defined in the surrounding cells --
# presumably leftovers from an earlier exploration; confirm before running.
modules = module_df[np.logical_and(module_df.cluster != -1, module_df.cell_type==celltype)]
good_modules = modules[modules.gene.isin(curr_genes)].cluster

In [None]:
#pvals = test_module_correlation(corr_mat, modules)

In [None]:
#good_mods = np.unique(modules)[pvals<0.1]

In [None]:
#good_modules = modules[modules.isin(good_mods)]


In [None]:
#good_gene_idx = np.array([i for i in np.arange(len(modules)) if modules.values[i] in good_mods])

In [None]:
#temp = module_df[np.logical_and(module_df.gene.isin(np.array(curr_genes)[good_gene_idx]), module_df.cell_type==celltype)]

In [None]:
#plt.scatter(temp.umap0, temp.umap1, c=temp.cluster,cmap=plt.cm.gist_ncar)

In [None]:
# NOTE(review): no gene subsetting is applied, so corr_mat_subset rows must already line up
# one-to-one with good_modules for the heatmap below to be meaningful -- confirm.
corr_mat_subset = corr_mat#[good_gene_idx,:][:,good_gene_idx]

In [None]:
# Module-sorted heatmap of the z-scored matrix, with module color strips and boundary lines.
f,ax = plt.subplots(figsize=(10,10), ncols=2, nrows=2, gridspec_kw={'wspace':0.01,'hspace':0.01, 'width_ratios':[0.5,10], 'height_ratios':[0.5,10]})
ax[0][0].axis('off')
ax[0][1].imshow(np.expand_dims(np.sort(good_modules),1).T,aspect='auto',cmap=plt.cm.gist_ncar,interpolation='none')
ax[0][1].axis('off')
ax[1][1].imshow(zscore_mat(corr_mat_subset[np.argsort(good_modules),:][:, np.argsort(good_modules)]),vmin=-2.5,vmax=2.5,cmap=plt.cm.seismic,aspect='auto',interpolation='bilinear')
ax[1][1].axis('off')
ax[1][0].imshow(np.expand_dims(np.sort(good_modules),1),aspect='auto',cmap=plt.cm.gist_ncar,interpolation='none')
ax[1][0].axis('off')
sorted_mod = np.sort(good_modules)
# draw a line wherever the sorted module label changes
for i in np.arange(1,len(sorted_mod)):
    if sorted_mod[i-1] != sorted_mod[i]:
        ax[1][1].axvline(i,color='k',lw=1)
        ax[0][1].axvline(i,color='k',lw=1)

        ax[1][1].axhline(i,color='k',lw=1)
        ax[1][0].axhline(i,color='k',lw=1)


## Load ligand/receptor pairs

In [None]:
# load cellchatdb 
cellchat = pd.read_csv("gene_lists/cellchatdb_interactions.csv")
# load celltalkdb databases
celltalk = pd.read_table("gene_lists/mouse_lr_pair.txt")

# all ligand and receptor genes of each database (deduplicated)
cellchat_genes = list(set(list(cellchat['receptor']) + list(cellchat['ligand'])))
celltalk_genes = list(set(list(celltalk['ligand_gene_symbol']) + list(celltalk['receptor_gene_symbol'])))
# Keep the CellChat genes measured in the 10X data. The 10X gene names are decoded once into
# a set; the comprehension used to re-run unbinarize_list over every var name for each gene.
# (CellTalk genes are filtered later, together with the other imputation genes.)
var_names_10x_raw = set(unbinarize_list(adata_10x.raw.var_names))
cellchat_genes = [i for i in cellchat_genes if i in var_names_10x_raw]

## Make list of genes for imputation

In [None]:
# genes to use for imputation
# Union of: the gene lists genes_nonneurons / genes_neurons (defined earlier), the significant
# age-DE genes (minor and major cell types), CellChat and CellTalk ligand/receptor genes, and
# all module genes. Duplicates are removed in the next cell.
genes_for_imputation = unbinarize_list(list(genes_nonneurons)) + \
    unbinarize_list(list(genes_neurons)) + list(de_genes_age_minor_signif.gene) + list(de_genes_age_major_signif.gene) \
    + cellchat_genes + celltalk_genes + list(module_df.gene.unique())

In [None]:
# De-duplicate and sort. "Fezf2" is appended only if it is not already in the list: a repeated
# gene would become a duplicated column when adata_subset is subset to these genes.
genes_for_imputation = sorted(set(genes_for_imputation))
if "Fezf2" not in genes_for_imputation:
    genes_for_imputation.append("Fezf2")

## Compute pseudo-age score for each celltype

In [None]:
from sklearn.decomposition import PCA, TruncatedSVD
def compute_pseudoage_score(A, renormalize=True, densify=False):
    """
    Compute a 1-D "pseudoage" score per cell from the leading principal components.

    The difference of mean expression between 90wk and 4wk cells defines an aging
    direction, and every cell is projected onto it. Among the first 5 PCs, the one whose
    scores correlate most strongly (in absolute value) with that projection is returned.
    Its sign is flipped if needed so that 90wk cells score higher on average than 4wk cells.

    Parameters
    ----------
    A : AnnData
        Cells of one cell type; ``A.obs.age`` is expected to contain '4wk' and '90wk'.
    renormalize : bool
        If True, start from ``A.raw`` and re-normalize with ``normalize_10x_data``.
    densify : bool
        If True, densify ``A.X`` and use PCA; otherwise apply TruncatedSVD to ``A.X``.

    Returns
    -------
    np.ndarray, shape (n_cells,)
    """
    if renormalize:
        A = A.raw.to_adata()
        A = normalize_10x_data(A)
    X = A.X
    if densify:
        dense_X = X.toarray() if hasattr(X, "toarray") else X
        pca = PCA(n_components=5).fit_transform(dense_X)
    else:
        pca = TruncatedSVD(n_components=5).fit_transform(X)
    age = np.asarray(A.obs.age)
    is_old = age == '90wk'
    is_young = age == '4wk'
    # np.asarray(...).ravel(): sparse .mean() returns a 1 x n_genes np.matrix
    coef = np.asarray(X[is_old].mean(axis=0)).ravel() - np.asarray(X[is_young].mean(axis=0)).ravel()
    proj = np.asarray(X @ coef).ravel()
    pc_corrs = np.array([np.corrcoef(pca[:, k], proj)[0, 1] for k in range(pca.shape[1])])
    # Strongest |correlation|: np.argsort(pc_corrs)[0] picked the most *negative* correlation and
    # ignored PCs positively aligned with the aging direction (the sign is fixed below anyway).
    best_pc = int(np.argmax(np.nan_to_num(np.abs(pc_corrs))))
    pseudoage = pca[:, best_pc]
    # invert sign so older age is greater
    if np.mean(pseudoage[is_old]) < np.mean(pseudoage[is_young]):
        pseudoage = -pseudoage
    return pseudoage


In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
def compute_lda_pseudoage_score(A,renormalize=True):
    """
    Score each cell along the first linear discriminant separating the age groups.

    Parameters
    ----------
    A : AnnData
        Cells of one cell type; ``A.obs.age`` holds each cell's age group.
    renormalize : bool
        If True, start from ``A.raw`` and re-normalize with ``normalize_10x_data``.

    Returns
    -------
    np.ndarray, shape (n_cells,)
        The first LDA component, oriented so that 90wk cells score higher than 4wk cells
        on average, passed through ``normalize_data`` and flattened.
    """
    if renormalize:
        A = A.raw.to_adata()
        A = normalize_10x_data(A)
    X = A.X
    # scikit-learn's LDA needs dense input
    if hasattr(X, "toarray"):
        X = X.toarray()
    age = np.asarray(A.obs.age)
    # With k age groups LDA returns k-1 components (two for 4wk/24wk/90wk), which cannot be
    # stored in a single obs column; keep only the first, most discriminative one.
    lda = LinearDiscriminantAnalysis(n_components=1).fit_transform(X, age)
    # the sign of a discriminant axis is arbitrary; orient it like compute_pseudoage_score
    if np.mean(lda[age == '90wk', 0]) < np.mean(lda[age == '4wk', 0]):
        lda = -lda
    # normalize_data gets an (n_cells, 1) array, the same shape it got with two age groups
    pseudoage = np.ravel(normalize_data(lda))
    return pseudoage


In [None]:
# actually compute pseudoage scores
# Initialize float columns: filling an int column with float scores through .loc is an
# incompatible-dtype assignment that recent pandas deprecates (FutureWarning, future error).
adata_subset.obs['pseudoage'] = 0.0
adata_subset.obs['pseudoage_lda'] = 0.0

pseudoage_scores = {}
pseudoage_lda_scores = {}
for i in adata_subset.obs.cell_type.unique():
    print(i)
    curr_A = adata_subset[adata_subset.obs.cell_type==i]
    pseudoage_scores[i] = compute_pseudoage_score(curr_A)
    pseudoage_lda_scores[i] = compute_lda_pseudoage_score(curr_A)

In [None]:
# Write the per-cell-type scores back into obs. The rows selected by the mask come in the same
# order as in curr_A = adata_subset[mask] used to compute them.
for i in adata_subset.obs.cell_type.unique():
    adata_subset.obs.loc[adata_subset.obs.cell_type==i,"pseudoage"] = pseudoage_scores[i]
    adata_subset.obs.loc[adata_subset.obs.cell_type==i,"pseudoage_lda"] = pseudoage_lda_scores[i]

In [None]:
# NOTE: "pseudoage" is a raw PC score (not rescaled), so vmin=0/vmax=1 clips it.
sc.pl.umap(adata_subset, color=['age','pseudoage', 'pseudoage_lda'],cmap=plt.cm.seismic,vmin=0,vmax=1,size=1)

In [None]:
sc.pl.umap(adata_subset, color=['age','leiden_10x'])

In [None]:
# find major cell types for each 
#leiden_clusts = adata_subset.obs.leiden_10x
#int_clusts = adata_subset.obs.cell_type
#clust_map = {}
#xtab = pd.crosstab(leiden_clusts, int_clusts)
#for i,r in xtab.iterrows():
#    clust_map[i] = xtab.columns[r.argmax()]


## Run imputation

In [None]:
# Replace adata_subset by its .raw data (all genes) before selecting the imputation genes.
# NOTE: not idempotent -- the new object presumably has no .raw, so re-running this cell fails.
adata_subset = adata_subset.raw.to_adata()


In [None]:
adata_subset.shape

In [None]:
# decode gene names (presumably stored as bytes) to str
adata_subset.var_names = unbinarize_list(adata_subset.var_names)

In [None]:
adata_subset.var_names

In [None]:
# Keep only imputation genes that are present in the 10X data. A set makes each lookup O(1);
# `in list(adata_subset.var_names)` rebuilt the full list for every candidate gene.
subset_var_names = set(adata_subset.var_names)
good_genes_for_imputation = [i for i in genes_for_imputation if i in subset_var_names]

In [None]:
# number of distinct candidate genes vs. how many of them exist in the 10X data
len(np.unique(genes_for_imputation))

In [None]:
len(good_genes_for_imputation)

In [None]:
#good_genes_for_imputation =adata_imputed.var_names

In [None]:
# .copy() so the layer / X assignments in the next cells modify a real AnnData, not a view
adata_subset = adata_subset[:, good_genes_for_imputation].copy()

In [None]:
# square root of total-count-normalized expression, stored as a separate layer
adata_subset.layers['sqrt_norm'] = np.sqrt(sc.pp.normalize_total(adata_subset, inplace=False)["X"])

In [None]:
# keep a copy of the current (pre-normalization) X -- presumably raw counts from .raw; confirm
adata_subset.layers['raw'] = adata_subset.X.copy()

In [None]:
adata_subset = normalize_10x_data(adata_subset, identify_var_genes=False)

In [None]:
# set up imputation for expression
# pc_A: MERFISH cells to impute; pc_B: 10X reference cells (first npcs harmony PCs).
# NOTE(review): rows of pc_B must be the same cells, in the same order, as adata_subset.
npcs = 30
pc_A = adata_combined_mer.obsm['X_pca_harmony'][:,:npcs]
pc_B = adata_combined_10x.obsm['X_pca_harmony'][:,:npcs]
#expr_B = adata_subset.X

In [None]:
# actually run imputation
# The X, sqrt_norm and raw matrices of the 10X cells are each imputed onto the MERFISH cells
# with the same neighbor lookup (pc_A queries pc_B) and batch size.
subset_for_imputation = adata_subset[:, good_genes_for_imputation]
imputed = impute_expr_sparse(pc_A, pc_B, subset_for_imputation.X, n_per_batch=10000)
imputed_sqrt = impute_expr_sparse(pc_A, pc_B, subset_for_imputation.layers['sqrt_norm'], n_per_batch=10000)
imputed_raw = impute_expr_sparse(pc_A, pc_B, subset_for_imputation.layers['raw'], n_per_batch=10000)

In [None]:
# impute leiden cluster identity from 10X
#imputed_leiden = impute_celltype(pc_A, pc_B, adata_subset.obs.leiden_10x)

In [None]:
# impute top 30 PCs
#imputed_pcs = impute_obs(pc_A, pc_B, adata_subset.obsm['X_pca'])

In [None]:
# impute pseudoage
#imputed_pseudoage = impute_obs(pc_A, pc_B, adata_subset.obs.pseudoage, n_per_batch=10000)

In [None]:
#imputed_pseudoage_lda = impute_obs(pc_A, pc_B, adata_subset.obs.pseudoage_lda, n_per_batch=10000)

In [None]:
#adata_combined_mer.obs['pseudoage'] = imputed_pseudoage

In [None]:
#adata_combined_mer.obs['pseudoage_lda'] = imputed_pseudoage_lda

In [None]:
#adata_combined_mer.obs['leiden_10x'] = imputed_leiden

In [None]:
# NOTE(review): the cell assigning pseudoage_lda to adata_combined_mer is commented out above,
# so this relies on the column already being present.
sc.pl.umap(adata_combined_mer, color=['age','pseudoage_lda'],cmap=plt.cm.coolwarm, vmin=0,vmax=1)

In [None]:
# Wrap the imputed expression (presumably one row per MERFISH cell in pc_A, one column per
# imputation gene) in an AnnData that reuses the MERFISH obs / obsm.
adata_imputed = ad.AnnData(
    X=imputed,
    obs=adata_combined_mer.obs,
    var=adata_subset.var,
    obsm=adata_combined_mer.obsm,
   # varm=adata_combined_mer.varm,
)

In [None]:
adata_imputed.layers['sqrt_norm'] = imputed_sqrt
adata_imputed.layers['raw'] = imputed_raw

In [None]:
#adata_imputed.obsm['X_pca_10x'] = imputed_pcs

In [None]:
#adata_imputed.obs['clust_annot_10x'] = [clust_map[i] for i in adata_imputed.obs.leiden_10x]
#adata_imputed.obs['cell_type_10x'] = [clust_to_celltype[i] for i in adata_imputed.obs.leiden_10x]

In [None]:
#adata_imputed.obs['pseudoage_lda'] = imputed_pseudoage_lda

## Save out imputed data

In [None]:
adata_imputed.write_h5ad("/faststorage/brain_aging/merfish/exported/062422_merfish_combined_imputed.h5ad")

# Load imputed data

In [None]:
# load imputed data
adata_imputed = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/062422_merfish_combined_imputed.h5ad")

In [None]:
# decode byte-string gene names / obs index to str. Not idempotent: unbinarize_strings calls
# .decode on every name, so running it on already-decoded names fails.
adata_imputed = unbinarize_strings(adata_imputed)

In [None]:
sc.pl.umap(adata_imputed, color='age')

# QC imputed clusters

## Check discriminability of the imputed Leiden clusters in the MERFISH data

In [None]:
# clusters with more than 100 imputed cells
vc = adata_imputed.obs['leiden_10x'].value_counts()
good_clusts = list(vc[vc>100].index)

In [None]:
# NOTE: this permanently drops the small clusters from adata_imputed for the rest of the notebook.
adata_imputed = adata_imputed[adata_imputed.obs.leiden_10x.isin(good_clusts)]

In [None]:
# classifier inputs: first 25 imputed 10X PCs (features) and leiden_10x cluster (labels);
# the per-class subsampling (at most 1000 cells each) happens two cells below
class_labels = adata_imputed.obs['leiden_10x']
#class_labels = class_labels[class_labels!='42023350']
#class_labels = class_labels[class_labels!='35390410']
class_X = adata_imputed.obsm['X_pca_10x'][:,:25]
pd.DataFrame(class_labels).value_counts()

In [None]:
# Encode leiden cluster names as integer class labels. LabelEncoder is imported here because
# the cell that otherwise imports it comes later in the notebook (fresh-kernel runs failed).
from sklearn.preprocessing import LabelEncoder
class_labels = LabelEncoder().fit_transform(class_labels)

In [None]:

# Subsample at most n_to_take cells per class so large clusters do not dominate training.
# A seeded generator makes the subsample, and everything trained on it, reproducible.
subsample_rng = np.random.default_rng(0)
class_idx = []
n_to_take = 1000
for i in np.unique(class_labels):
    curr_class = np.flatnonzero(class_labels == i)
    class_idx.extend(subsample_rng.permutation(curr_class)[:n_to_take])
class_idx = np.array(class_idx)
class_X = class_X[class_idx,:]
class_labels = class_labels[class_idx]

In [None]:
# Make a stratified 70/30 train-test split (test_size=0.3) and fit an MLP on the training cells.
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import RandomForestClassifier
#from xgboost import XGBClassifier
from sklearn.preprocessing import LabelEncoder
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=0)
for train_idx, test_idx in sss.split(class_X, class_labels):
    X_train, X_test = class_X[train_idx,:], class_X[test_idx,:]
    y_train, y_test = class_labels[train_idx], class_labels[test_idx]

    klass = MLPClassifier(random_state=42).fit(X_train, y_train)
    # class_X already contains only the first 25 PCs, so no further column slicing is needed
    preds = klass.predict(X_test)

In [None]:
# Row-normalized confusion matrix of the held-out cells. confusion_matrix(y_true, y_pred) has
# true labels on rows; after .T the rows are predicted labels, so each row shows how the cells
# predicted as one cluster are distributed over their true clusters.
# (plot_confusion_matrix was removed in scikit-learn 1.2 and was unused, so it is not imported.)
from sklearn.metrics import confusion_matrix
cmat = confusion_matrix(y_test, preds).astype(np.float64).T
row_sums = cmat.sum(axis=1, keepdims=True)
# rows without any prediction stay all-zero instead of dividing by zero
cmat = np.divide(cmat, row_sums, out=cmat, where=row_sums > 0)


In [None]:
plt.figure(figsize=(10,10))
plt.imshow(cmat,cmap=plt.cm.Reds,vmin=0,vmax=1)
# cmat was transposed after confusion_matrix(true, pred): rows (y axis) are predicted
# clusters and columns (x axis) the actual clusters
plt.xlabel('Actual cluster')
plt.ylabel('Predicted cluster')


### Train a classifier on the non-imputed 10X data and apply it to the imputed data

In [None]:
# Reference classifier trained on the (non-imputed) 10X cells: first 25 PCs -> leiden_10x.
# NOTE(review): assumes adata_subset still carries obsm['X_pca'] after the raw.to_adata() and
# gene-subsetting steps above -- confirm.
klass = MLPClassifier(random_state=42).fit(adata_subset.obsm['X_pca'][:,:25], adata_subset.obs.leiden_10x)


In [None]:
# Use the same first 25 PCs the classifier was trained on; a wider matrix fails predict's
# feature-count check.
preds = klass.predict(adata_imputed.obsm['X_pca_10x'][:, :25])

In [None]:
# Row-normalized confusion matrix: imputed leiden_10x label (true) vs. classifier prediction.
# After .T the rows are predicted clusters, so each row is the distribution of true clusters
# among the cells predicted as that cluster.
# (plot_confusion_matrix was removed in scikit-learn 1.2 and was unused, so it is not imported.)
from sklearn.metrics import confusion_matrix
cmat = confusion_matrix(adata_imputed.obs['leiden_10x'], preds).astype(np.float64).T
row_sums = cmat.sum(axis=1, keepdims=True)
# rows without any prediction stay all-zero instead of dividing by zero
cmat = np.divide(cmat, row_sums, out=cmat, where=row_sums > 0)


In [None]:
plt.figure(figsize=(10,10))
plt.imshow(cmat,cmap=plt.cm.Reds,vmin=0,vmax=1)
# cmat was transposed after confusion_matrix(true, pred): its rows (y axis) are predictions and
# its columns (x axis) the actual clusters -- the previous labels had the two swapped.
plt.xlabel('Actual cluster')
plt.ylabel('Predicted cluster')

## Plot marker genes etc. across the raw 10X, imputed, and MERFISH data

In [None]:
# Selected genes on the 10X UMAP (sqrt-normalized layer) ...
sc.pl.umap(adata_subset, color=['age','Il33', 'Gfap','B2m','C4b','Lyz2'],layer='sqrt_norm',cmap=plt.cm.Reds)

In [None]:
# ... and the same genes imputed onto the MERFISH cells, for comparison
sc.pl.umap(adata_imputed, color=['age','Il33', 'Gfap','B2m','C4b','Lyz2'],layer='sqrt_norm',cmap=plt.cm.Reds)

# Fig 1: Basic data UMAP

In [None]:
# plot integrated stuff

In [None]:
#adata_imputed = unbinarize_strings(adata_imputed)

In [None]:
import seaborn as sns

In [None]:
# Pseudoage per 10X-derived cluster annotation, one cell type at a time.
# NOTE(review): the cells imputing pseudoage / cell_type_10x / clust_annot_10x are commented out
# above, so these columns must already be present in the saved h5ad.
temp = adata_imputed[adata_imputed.obs.cell_type_10x=="Astro"]
sc.pl.violin(temp, ['pseudoage'], groupby='clust_annot_10x')

In [None]:
temp = adata_imputed[adata_imputed.obs.cell_type_10x=="Olig"]
sc.pl.violin(temp, ['pseudoage'], groupby='clust_annot_10x')

In [None]:
sorted(list(adata_imputed.obs.clust_annot_10x.unique()))

In [None]:
temp = adata_imputed[adata_imputed.obs.cell_type_10x=="Micro"]
sc.pl.violin(temp, ['pseudoage'], groupby='clust_annot_10x')

# Fig 2: Cell typing

## Define colors

In [None]:
import seaborn as sns
from cycler import cycler
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R
from statsmodels.stats.multitest import multipletests
import matplotlib as mpl
def gen_light_palette(prefix, color_name, uniq_clusts):
    """
    Light seaborn palette with one shade of ``color_name`` per name in ``uniq_clusts`` that
    contains ``prefix`` (substring match).

    The palette blends from light to ``color_name``; the two lightest shades are dropped.
    """
    # built-in sum of booleans is always an int (np.sum([]) was the float 0.0, a float n_colors)
    n = sum(prefix in name for name in uniq_clusts)
    return sns.light_palette(color_name, n_colors=n+2)[2:]

def gen_dark_palette(prefix, color_name, uniq_clusts):
    """
    Dark seaborn palette with one shade of ``color_name`` per name in ``uniq_clusts`` that
    contains ``prefix`` (substring match).

    The palette blends from dark to ``color_name``; the two darkest shades are dropped.
    """
    # built-in sum of booleans is always an int (np.sum([]) was the float 0.0, a float n_colors)
    n = sum(prefix in name for name in uniq_clusts)
    return sns.dark_palette(color_name, n_colors=n+2)[2:]
# Base color per 10X cluster-name prefix; generate_palettes_10x derives one shade per
# cluster whose name contains the prefix.
clust_cell_types_10x = {
# Astro -- green
"Astro" : "seagreen",
# Excitatory -- red/orange
"ExN-L2/3" : "darkorange",
"ExN-L5" : "lightsalmon",
"ExN-L6" : "maroon",
"ExN-Olf" : "firebrick",
# inhibitory -- blue/purple
'InN-Olf' : "cornflowerblue",
#'InN-Adarb2' : "lightsteelblue",
'InN-Chat' : "lavender",
#'InN-Egfr' : "turquoise",
#'InN-Calb' : "teal",
'InN-Lhx6':'lightsteelblue',

'InN-Calb2' : "navy",
'InN-Lamp5' : "royalblue",
'InN-Pvalb' : "steelblue",
'InN-Sst' : "dodgerblue",
'InN-Vip' : "deepskyblue",
"MSN-D1" : "mediumslateblue",
"MSN-D2" : "rebeccapurple",
# immune cells + microglia -- pink
"Micro" : "deeppink",
"Macro" : "hotpink",

# Endothelial/vasculature -- olive/gold/tan
"Vlmc" : "olive",
"Endo" : "khaki",
"Peri" : "goldenrod",

# Oligodendrocytes
"Olig" : "slategrey",
"OPC" : "black"
}
 

# Base color per major 10X cell type (used for the cell-type palette).
major_cell_types_10x = {
# Astro -- green
"Astro" : "seagreen",
# Excitatory -- red/orange
"ExN" : "lightcoral",
# inhibitory -- blue/purple
"InN" : "cornflowerblue",
"MSN" : "mediumpurple",

# immune cells + microglia -- pink/violet
"Micro" : "pink",
"Macro" : "violet",

# Endothelial/vasculature -- gold/tan
"Vlmc" : "gold",
"Endo" : "khaki",
"Peri" : "goldenrod",

# Oligodendrocytes
"Olig" : "slategrey",
"OPC" : "black"
}

def generate_palettes_10x(A,clust_key="clust_annot_10x", cell_type_key="cell_type_10x"):
    """
    Build color palettes for 10X cell types and clusters.

    Every major cell type gets a dark shade of its base color in ``major_cell_types_10x``.
    Every cluster whose name contains a key of ``clust_cell_types_10x`` gets a shade of that
    key's base color.

    Parameters
    ----------
    A : AnnData
    clust_key : str
        ``A.obs`` column holding cluster names.
    cell_type_key : str
        ``A.obs`` column holding major cell-type names (keys of ``major_cell_types_10x``).

    Returns
    -------
    celltype_colors : dict mapping cell type -> color
    celltype_pals : cycler over the cell-type colors (sorted cell-type order)
    label_colors : dict mapping cluster name -> color
    clust_pals : cycler over the cluster colors
    """
    uniq_celltypes = np.sort(np.unique(A.obs[cell_type_key]))
    uniq_clusts = np.sort(A.obs[clust_key].unique())

    celltype_pals = []
    for ct in uniq_celltypes:
        celltype_pals.append(gen_dark_palette(ct, major_cell_types_10x[ct], uniq_celltypes))
    celltype_pals = cycler(color=np.vstack(celltype_pals))

    celltype_colors = {}
    for idx, c in enumerate(iter(celltype_pals)):
        celltype_colors[uniq_celltypes[idx]] = c['color']

    clust_pals = []
    label_colors = {}
    # Iterate the 10X prefixes: the loop used to run over clust_cell_types (the non-10X dict)
    # while looking colors up in clust_cell_types_10x, which raises KeyError for any prefix the
    # two dicts do not share and silently skips 10X-only prefixes.
    for prefix in sorted(clust_cell_types_10x.keys()):
        curr_clusts = sorted(k for k in uniq_clusts if prefix in k)
        if curr_clusts:
            pal = gen_dark_palette(prefix, clust_cell_types_10x[prefix], uniq_clusts)
            print(prefix, pal)
            clust_pals.append(pal)
            # one shade per matching cluster, assigned in sorted cluster order
            for clust, color in zip(curr_clusts, pal):
                label_colors[clust] = color
        else:
            print("Couldn't find clust", prefix)
    clust_pals = cycler(color=np.vstack(clust_pals))

    return celltype_colors, celltype_pals, label_colors, clust_pals


In [None]:
# Colors for ages (presumably 4wk / 24wk / 90wk in that order) and data types.
age_colors = ['cornflowerblue','thistle','lightcoral']
age_cmap = mpl.colors.ListedColormap(age_colors)
dtype_colors = ['mediumslateblue', 'goldenrod']
dtype_cmap = mpl.colors.ListedColormap(dtype_colors)
# NOTE(review): calls generate_palettes (cell_type / clust_annot columns), not the
# generate_palettes_10x defined above -- confirm this is intended.
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata_imputed)

In [None]:
sc.pl.umap(adata_imputed, color=['cell_type'], palette=celltype_pals)

In [None]:
sc.pl.umap(adata_imputed, color=['clust_annot'], palette=clust_pals)

In [None]:
sc.pl.umap(adata_imputed, color=['age'], palette=sns.color_palette(age_colors))

In [None]:
sc.pl.umap(adata_imputed, color=['dtype'], palette=sns.color_palette(dtype_colors))

## Make heatmap of celltype markers

In [None]:
# Store the current imputed data as .raw.
adata_imputed.raw = adata_imputed


In [None]:
# subset of 10X cells with only clusters in MERFISH also
adata_subset_shared = adata_subset[adata_subset.obs.clust_annot.isin(adata_imputed.obs.clust_annot.unique())]

In [None]:
import seaborn as sns
# Mean imputed expression profile of every cluster, one row per cluster in sorted id order;
# the rows are clustered into a dendrogram in the next cell.
clust_ids = sorted(adata_imputed.obs.clust_annot.unique())
clust_avg = []
for clust_id in clust_ids:
    print(clust_id)
    in_clust = adata_imputed.obs.clust_annot == clust_id
    clust_avg.append(adata_imputed[in_clust].X.mean(0))
clust_avg = np.vstack(clust_avg)

In [None]:
from scipy.spatial.distance import pdist
import scipy.cluster.hierarchy as hc

# Hierarchically cluster the mean cluster profiles (correlation distance, complete linkage
# with optimal leaf ordering) and read off the dendrogram leaf order.
D = pdist(clust_avg,'correlation')
Z = hc.linkage(D,'complete',optimal_ordering=True)
#label_colors['NA'] = (0,0,0)
dn = hc.dendrogram(Z)
# cluster ids in the left-to-right order of the dendrogram leaves
lbl_order = [clust_ids[c] for c in dn['leaves']]

In [None]:
# compute fraction of each cluster per age
# For every cluster (dendrogram order) build a row of n_bins values encoding which age its cells
# come from (2 = 90wk, 1 = 24wk, 0 = 4wk), after correcting for the different total number of
# cells of each age in the experiment.
n_bins = 100
frac_per_age = np.zeros((len(lbl_order), n_bins))

total_90wk = np.sum(adata_imputed.obs.age=='90wk')
total_24wk = np.sum(adata_imputed.obs.age=='24wk')
total_4wk = np.sum(adata_imputed.obs.age=='4wk')

for n,c in enumerate(lbl_order):
    curr_clust = adata_imputed[adata_imputed.obs.clust_annot==c]
    # fraction of this cluster's cells coming from each age
    curr4 = np.sum(curr_clust.obs.age == "4wk")/curr_clust.shape[0]
    curr24 = np.sum(curr_clust.obs.age == "24wk")/curr_clust.shape[0]
    curr90 = np.sum(curr_clust.obs.age == "90wk")/curr_clust.shape[0]

    # scale by the number of cells of each age in the experiment, then renormalize to sum to 1
    curr4 /= total_4wk
    curr24 /= total_24wk
    curr90 /= total_90wk
    denom = curr4+curr24+curr90
    curr4 /= denom
    curr24 /= denom
    curr90 /= denom
    nbins90 = int(round(n_bins*curr90))
    nbins24 = int(round(n_bins*curr24))
    print(n, c, curr4, curr90)

    # 4wk takes the remaining bins, so every row has exactly n_bins entries
    frac_per_age[n,:] = np.hstack([2*np.ones(nbins90),
                                   np.ones(nbins24),
                                   np.zeros(n_bins-nbins90-nbins24)])


In [None]:
# fraction of cells in MERFISH vs scRNAseq
# Same encoding for the data source (0 = MERFISH, 1 = scRNA-seq), after correcting for the
# total number of cells each technology contributes.
frac_per_dtype = np.zeros((len(lbl_order), n_bins))
# loop-invariant totals, computed once
n_merfish_total = np.sum(adata_combined.obs.dtype=='merfish')
n_10x_total = np.sum(adata_combined.obs.dtype=='scrnaseq')

for n,c in enumerate(lbl_order):
    curr_clust = adata_combined[adata_combined.obs.clust_annot==c]
    curr_merfish = np.sum(curr_clust.obs.dtype == "merfish")/curr_clust.shape[0]
    curr_10x = np.sum(curr_clust.obs.dtype == "scrnaseq")/curr_clust.shape[0]

    curr_merfish /= n_merfish_total
    curr_10x /= n_10x_total
    denom = curr_merfish + curr_10x
    curr_merfish /= denom
    curr_10x /= denom
    print(n, c, curr_merfish, curr_10x)
    # Derive the scRNA-seq bin count from the MERFISH one: rounding the two shares separately
    # could give 99 bins (e.g. x.5 vs 49.4999..), which does not fit the n_bins-wide row.
    nbins_merfish = int(round(n_bins*curr_merfish))
    frac_per_dtype[n,:] = np.hstack([np.zeros(nbins_merfish),
                                     np.ones(n_bins - nbins_merfish)])


In [None]:
# Marker genes for the cluster dot plot (list order = display order).
dotplot_genes = [
    'Slc17a7', 'Gad2', 'Olig1', 'Cspg4', 'Aqp4', 'Rorb', 'Tshz2', 'Cldn5',
    'Vtn', 'Pdgfra', 'F13a1', 'Cd3e', 'Ctss', 'Adora2a', 'Drd1', 'Otof',
    'Calb2', 'Cpne7', 'Chat', 'Vip', 'Adarb2', 'Lamp5', 'Cux2', 'Pvalb',
    'Sst', 'Lhx6', 'Hs3st4', 'Syt6', 'Ndst4', 'Ptpru', 'Rspo1', 'Scube1',
]

# keep only genes present in the imputed data
dotplot_genes = [g for g in dotplot_genes if g in adata_imputed.var_names]
from scipy.stats import zscore

# For each cluster (columns, dendrogram order) and gene (rows): the mean expression and the
# fraction of cells with nonzero expression. Mean values are then z-scored across clusters.
dotplot_vals = np.zeros((len(dotplot_genes), len(lbl_order)))
dotplot_frac = np.zeros((len(dotplot_genes), len(lbl_order)))
for col, clust in enumerate(lbl_order):
    in_clust = adata_imputed.obs.clust_annot == clust
    clust_expr = adata_imputed[in_clust][:, dotplot_genes].X.toarray()
    dotplot_vals[:, col] = clust_expr.mean(axis=0)
    dotplot_frac[:, col] = (clust_expr > 0).sum(axis=0) / np.sum(in_clust)
for row in range(len(dotplot_genes)):
    dotplot_vals[row, :] = zscore(dotplot_vals[row, :])
max_idx = np.arange(len(dotplot_genes))

# uncomment to optimize the gene order
from scipy.optimize import linear_sum_assignment
#_, max_idx = linear_sum_assignment(-dotplot_frac.T)
#dotplot_genes = [dotplot_genes[i] for i in max_idx]


In [None]:
from scipy.stats import ttest_ind, ranksums, fisher_exact, chisquare
from statsmodels.stats.proportion import proportions_ztest, test_proportions_2indep

In [None]:
# white seaborn style (no grid) for the figure cells that follow
sns.set_style('white')

In [None]:
# Figure 1 cluster overview, top to bottom:
#   dendrogram of clusters (its leaf order becomes `lbl_order`),
#   cluster colour bar labelled with per-cluster cell counts (clust_annot_10x),
#   marker dot plot (colour = z-scored mean, size = fraction expressing),
#   per-cluster age composition strip and MERFISH / scRNA-seq composition strip.
# NOTE(review): dotplot_vals, frac_per_age and frac_per_dtype were built in earlier cells
# from whatever `lbl_order` existed then; they only line up with the columns drawn here
# if that order equals the dendrogram leaf order recomputed below.
dotscale = 35
f = plt.figure(figsize=(8,8))
gs = plt.GridSpec(nrows=5, ncols=1, height_ratios=[5,1,30,6,6], hspace=0.1)

# --- dendrogram
ax = plt.subplot(gs[0])
hc.dendrogram(Z,ax=ax,labels=clust_ids,leaf_font_size=10,color_threshold=0,above_threshold_color='k');
sns.despine(ax=ax,left=True)
ax.axis('off')
# Compare the label *text*: comparing the Text artist itself to 'NA' is always unequal.
lbl_order = [lbl.get_text() for lbl in ax.get_xmajorticklabels() if lbl.get_text() != 'NA']

# --- cluster colour bar, ticks labelled with cell counts
ax = plt.subplot(gs[1])
curr_cols = mpl.colors.ListedColormap([label_colors[c] for c in lbl_order])
# One column per cluster in lbl_order (not per entry of label_colors), so each colour
# sits under its own tick.
ax.imshow(np.arange(len(lbl_order))[np.newaxis, :], cmap=curr_cols,aspect='auto',interpolation='none')
ax.set_yticklabels([])
ax.set_xticks(np.arange(len(lbl_order)))
ax.set_xticklabels([np.sum(adata_imputed.obs.clust_annot_10x==c) for c in lbl_order],rotation=90)
sns.despine(ax=ax,left=True)

# --- marker dot plot
ax = plt.subplot(gs[2])
gene_y = -np.arange(dotplot_vals.shape[0])
for i in range(dotplot_vals.shape[1]):
    ax.scatter(i*np.ones((dotplot_vals.shape[0])), gene_y, c=dotplot_vals[max_idx,:][:,i], s=dotscale*dotplot_frac[max_idx,:][:,i], cmap=plt.cm.seismic, vmin=-5,vmax=5)
ax.set_yticks(-np.arange(len(dotplot_genes)));
ax.set_yticklabels(dotplot_genes, fontsize=8)
ax.set_xlim([-0.5, dotplot_vals.shape[1]-0.5])
ax.set_xticks([])
sns.despine(ax=ax,bottom=True)

# --- age composition (frac_per_age rows: 2 for the 90wk share, 1 for the 24wk share, 0 for the rest)
ax = plt.subplot(gs[3])
ax.imshow(frac_per_age.T, vmin=0,vmax=2,aspect='auto',interpolation='none', cmap=age_cmap)
ax.set_yticklabels([])
ax.set_xticks([])
# guide line at row 50 (presumably the midpoint, assuming n_bins == 100)
ax.axhline(50,color='w',linestyle='--')
sns.despine(ax=ax, left=True)

# --- data-type composition (0 = MERFISH, 1 = scRNA-seq)
ax = plt.subplot(gs[4])
ax.imshow(frac_per_dtype.T, vmin=0,vmax=1,aspect='auto',interpolation='none', cmap=dtype_cmap)
ax.set_yticklabels([])
ax.set_xticks(np.arange(len(lbl_order)))
ax.set_xticklabels(lbl_order,rotation=90);
ax.axhline(50,color='w',linestyle='--')
sns.despine(ax=ax, left=True)
# colour each cluster name with its cluster colour
for lbl in ax.get_xmajorticklabels():
    if lbl.get_text() != 'NA':
        lbl.set_color(label_colors[lbl.get_text()])

#save_fig(f,"fig1_cluster_heatmap",dtype="pdf")

## Make violin plot of marker genes for each subtype

In [None]:
# Marker genes for the stacked-violin plots, per non-neuronal cell type
# (8 genes per type; the next cell prints the counts as a check).
genes_to_show = {
    "Endo" : ["Ly6a","Apoe","B2m","Rbm20","Tmem209", "Xdh","Fmo2","Bgn"],
    "Micro" : ["Dpp10","Meg3","Ctss","Chrm3","Trem2","Trim30a","Itgb2","B2m"],
    "Astro" : ["Lsamp","Gpc5","Slc1a2","Luzp2","Trpm3","Brinp3","Gfap","C4b"],
    "Olig" : ["Ank2","Trim2","Robo1","Spock1","C4b","Neat1","Dgki","Il33"]
}

In [None]:
# sanity check: number of genes listed per cell type
for i in genes_to_show.values():
    print(len(i))

In [None]:
# Stacked violin of `genes_to_show[cell type]` across that type's clusters (sqrt-normalized
# data), one PDF per cell type, each with its own colormap from `cmaps`.
cmaps = {'Endo' : plt.cm.YlOrBr,
         'Micro' : plt.cm.Reds,
         'Astro' : plt.cm.YlGn,
         'Olig' : plt.cm.Greys}
for i in ["Endo","Micro","Astro","Olig"]:
    curr_adata = adata_sqrt_norm[adata_sqrt_norm.obs.cell_type==i]
    f, ax = plt.subplots(figsize=(5,5))
    sc.pl.stacked_violin(curr_adata, genes_to_show[i], groupby='clust_annot',ax=ax,cmap=cmaps[i])
    save_fig(f, f"fig2_{i}_violin","pdf")

In [None]:
# Exploratory: rank genes separating the clusters within each cell type (scanpy
# 't-test_overestim_var') and plot the top-ranked genes; nothing is saved to disk.
for i in ["Endo","Micro","Astro","Olig"]:
    curr_adata = adata_sqrt_norm[adata_sqrt_norm.obs.cell_type==i]
    sc.tl.rank_genes_groups(curr_adata, 'clust_annot', use_raw=False, method='t-test_overestim_var')
    #sc.tl.filter_rank_genes_groups(curr_adata, min_fold_change=1.5, min_in_group_fraction=0.2)

    sc.pl.rank_genes_groups(curr_adata, key='rank_genes_groups')

## Compute pseudoage based on imputed data

In [None]:
# Pseudoage scores per non-neuronal cell type, from a 30-component TruncatedSVD of the
# imputed expression of that cell type:
#   pseudoage_imputed: projection of each cell onto (mean 90wk - mean 4wk) in SVD space
#   pseudoage_pca:     first SVD component of a fit on 4wk + 90wk cells only
# NOTE(review): the sign of an SVD component is arbitrary, so pseudoage_pca may point
# either way (old high or old low).
from sklearn.decomposition import TruncatedSVD
# Float columns: the per-cell-type .loc writes in the next cell store floats, and writing
# floats into an int64 column is deprecated in recent pandas.
adata_imputed.obs['pseudoage_imputed'] = 0.0
adata_imputed.obs['pseudoage_pca'] = 0.0
pseudoage_ct = ['Astro','Endo','Olig','Micro','OPC']
pseudoage_scores = {}
pseudoage_pca_scores = {}
for i in pseudoage_ct:
    print(i)
    curr_A = adata_imputed[adata_imputed.obs.cell_type==i]
    is_old = (curr_A.obs.age == '90wk').values
    is_young = (curr_A.obs.age == '4wk').values
    # fixed random_state: the default randomized SVD solver is otherwise not reproducible
    pca = TruncatedSVD(n_components=30, random_state=0).fit_transform(curr_A.X)
    diff = pca[is_old].mean(0) - pca[is_young].mean(0)
    pseudoage_scores[i] = pca @ diff
    pca_youngold = TruncatedSVD(n_components=30, random_state=0).fit(curr_A[is_old | is_young].X)
    pseudoage_pca_scores[i] = pca_youngold.transform(curr_A.X)[:, 0]

In [None]:
# Write the per-cell-type scores back into adata_imputed.obs. The arrays are in the row
# order of adata_imputed[mask], which is the order .loc[mask, ...] assigns in.
# Cells of other types keep their initial value.
for i in pseudoage_ct:
    adata_imputed.obs.loc[adata_imputed.obs.cell_type==i,'pseudoage_imputed'] = pseudoage_scores[i]
    adata_imputed.obs.loc[adata_imputed.obs.cell_type==i,'pseudoage_pca'] = pseudoage_pca_scores[i]

In [None]:
def normalize_data_quantile(x, lower=0.05, upper=0.95):
    """
    Linearly rescale `x` so its `lower` quantile maps to 0 and its `upper` quantile to 1.

    Values outside the quantile range fall below 0 / above 1; nothing is clipped.
    If both quantiles are equal the division is by zero, giving inf/NaN values.

    Parameters
    ----------
    x : array-like (numpy array or pandas Series)
        Values to rescale.
    lower, upper : float, optional
        Quantile levels in [0, 1] used as the 0 and 1 anchors (defaults 0.05 / 0.95).

    Returns
    -------
    Same kind as ``x - scalar``: an ndarray for arrays, a Series (same index) for Series.
    """
    # one quantile call for both anchors instead of computing each quantile separately
    q_lo, q_hi = np.quantile(x, [lower, upper])
    return (x - q_lo) / (q_hi - q_lo)
# Per-cell-type quantile normalization of the three pseudoage scores into *_norm columns.
# Float initialization: the .loc writes below store floats (writing floats into an int64
# column is deprecated in recent pandas). Cells outside pseudoage_ct keep 0.0.
adata_imputed.obs['pseudoage_norm'] = 0.0
adata_imputed.obs['pseudoage_imputed_norm'] = 0.0
adata_imputed.obs['pseudoage_pca_norm'] = 0.0

for i in pseudoage_ct:
    ct_mask = adata_imputed.obs.cell_type == i
    for src_col in ["pseudoage", "pseudoage_imputed", "pseudoage_pca"]:
        # normalize_data_quantile returns a Series with the same index, so .loc aligns by cell
        adata_imputed.obs.loc[ct_mask, src_col + "_norm"] = normalize_data_quantile(
            adata_imputed.obs.loc[ct_mask, src_col])

In [None]:
# UMAP coloured by age and by the three normalized pseudoage scores (colour range fixed to [0, 1]).
sc.pl.umap(adata_imputed, color=['age','pseudoage_imputed_norm','pseudoage_pca_norm', 'pseudoage_norm'], vmin=0,vmax=1,cmap=plt.cm.seismic)

## Plot pseudoage distribution scores

In [None]:
# white seaborn style for the pseudoage distribution plots below
sns.set_style('white')

In [None]:
# KDE of the normalized imputed pseudoage, split by age, for each of four cell types;
# one PDF per cell type (fig1_pseudoage_<celltype>.pdf).
# sns.displot is figure-level: it creates its own figure and does not accept `ax=`, so no
# axes are passed (a stale global `ax` used to be passed and was ignored).
pseudoage_xlim = [-0.5, 1.75]
key = 'pseudoage_imputed_norm'
for cell_type in ["Astro", "Micro", "Olig", "Endo"]:
    sns.displot(x=key, data=adata_imputed[adata_imputed.obs.cell_type==cell_type].obs,
                aspect=2, common_norm=False, hue='age', kind='kde',
                palette=sns.color_palette(age_colors), fill=True)
    sns.despine()
    plt.xlim(pseudoage_xlim)
    plt.title(cell_type)
    plt.savefig(f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/figures_imputed/fig1_pseudoage_{cell_type.lower()}.pdf",
                bbox_inches="tight", dpi=300)


In [None]:
# 2x2 box plots of the normalized imputed pseudoage by age for four cell types,
# saved together as fig1_pseudoage.pdf (outliers hidden).
ylim = [-0.5, 1.75]
key = 'pseudoage_imputed_norm'
f,axes = plt.subplots(nrows=2,ncols=2,figsize=(4,8),gridspec_kw={'wspace':0.75,'hspace':0.25})
# axes.flat order: top-left, top-right, bottom-left, bottom-right
for ax, cell_type in zip(axes.flat, ["Astro", "Micro", "Olig", "Endo"]):
    sns.boxplot(x='age',y=key, data=adata_imputed[adata_imputed.obs.cell_type==cell_type].obs,
                palette=sns.color_palette(age_colors),ax=ax,showfliers=False)
    ax.set_ylim(ylim)
    ax.set_title(cell_type)
sns.despine(fig=f)
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/figures_imputed/fig1_pseudoage.pdf",bbox_inches="tight")

In [None]:
# Exploratory: normalized imputed pseudoage per scRNA-seq-derived cluster (clust_annot_10x)
# within four cell types (cell_type_10x). The y-limits and the save call are commented out.
# NOTE(review): the palette is the age palette although x is the cluster, so colours do not
# encode age here and cycle if there are more clusters than colours.
ylim = [-0.5, 1.75]
key = 'pseudoage_imputed_norm'
f,axes = plt.subplots(nrows=2,ncols=2,figsize=(4,8),gridspec_kw={'wspace':0.75,'hspace':0.25})
ax = axes[0][0]
sns.boxplot(x='clust_annot_10x',y=key, data=adata_imputed[adata_imputed.obs.cell_type_10x=="Astro"].obs, palette=sns.color_palette(age_colors),ax=ax,showfliers=False)
sns.despine()
#ax.set_ylim(ylim)
#ax.set_title('Astro')

ax = axes[0][1]
sns.boxplot(x='clust_annot_10x',y=key, data=adata_imputed[adata_imputed.obs.cell_type_10x=="Micro"].obs, palette=sns.color_palette(age_colors),ax=ax,showfliers=False)
sns.despine()
#ax.set_ylim(ylim)
#ax.set_title('Micro')

ax = axes[1][0]
sns.boxplot(x='clust_annot_10x',y=key, data=adata_imputed[adata_imputed.obs.cell_type_10x=="Olig"].obs, palette=sns.color_palette(age_colors),ax=ax,showfliers=False)
#ax.set_ylim([0,1])
#ax.set_ylim(ylim)
#ax.set_title('Olig')

ax = axes[1][1]
sns.boxplot(x='clust_annot_10x',y=key, data=adata_imputed[adata_imputed.obs.cell_type_10x=="Endo"].obs, palette=sns.color_palette(age_colors),ax=ax,showfliers=False)
sns.despine()
ax.set_title('Endo')
#ax.set_ylim(ylim)
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/figures_imputed/fig1_pseudoage_celltype.pdf",bbox_inches="tight")

# Fig 3: Spatial Analysis

## Functions for spatial plotting

In [None]:
class SpatialPlotter(object):
    """
    Spatial plots of one representative MERFISH section per age.

    ``plot_info`` maps each age ('4wk', '24wk', '90wk') to the (batch, slice) of its
    representative section, a rotation passed to the plotting helpers, and x/y limits.
    Drawing is delegated to ``plot_gene_expr``, ``plot_obs`` and ``plot_clust_subset``,
    which are not defined in this notebook (presumably star-imported from ``plotting`` /
    ``spatial_analysis`` -- TODO confirm).

    Parameters
    ----------
    A : AnnData
        Cells to plot; ``A.obs`` must contain ``batch`` and ``slice`` columns.
    """

    # Ages in the order they are stacked (top to bottom) in the 3-row figures.
    _AGES = ('4wk', '24wk', '90wk')

    def __init__(self, A):
        self.A = A
        self.plot_info = {
            '4wk' : {
                'batch' : 8,
                'slice' : 1,
                'rot' : -183,
                'xlim' : [200, 2300],
                'ylim' : [200, 3800]
            },
            '24wk' : {
                'batch' : 12,
                'slice' : 0,
                'rot' : -12,
                'xlim' : [1950, 1950+2100],
                'ylim' : [200, 3700]
            },
            '90wk' : {
                'batch' : 9,
                'slice' : 1,
                'rot' : 35,
                'xlim' : [200, 2300],
                'ylim' : [400, 3900]
        }}

    def _section(self, A, age):
        """Return the cells of `A` that belong to the representative section for `age`."""
        info = self.plot_info[age]
        return A[np.logical_and(A.obs.batch == info['batch'], A.obs.slice == info['slice'])]

    def _age_figure(self, figsize):
        """Create the 3x1 figure (one row per age); `figsize` defaults to (4, 10)."""
        if figsize is None:
            figsize = (4, 10)
        return plt.subplots(figsize=figsize, nrows=3, ncols=1, gridspec_kw={'wspace':0.05, 'hspace':0.01})

    def get_info_for_age(self, age):
        """Return ``(section, rot, xlim, ylim)`` for `age`; `section` is a subset of ``self.A``."""
        info = self.plot_info[age]
        return self._section(self.A, age), info['rot'], info['xlim'], info['ylim']

    def get_plot_info(self, age):
        """Return ``(batch, slice, xlim, ylim, rot)`` for `age`."""
        info = self.plot_info[age]
        return info['batch'], info['slice'], info['xlim'], info['ylim'], info['rot']

    def plot_gene_by_age(self, gene_name, celltype=None, vmin=0,vmax=3, s=1,celltype_key='cell_type',figsize=None,cmap=plt.cm.Reds,alpha=1):
        """
        Plot expression of `gene_name` on each age's section, one row per age.

        `celltype`/`celltype_key`, the colour range (`vmin`, `vmax`, `cmap`), marker size `s`
        and `alpha` are passed to ``plot_gene_expr`` (with ``use_raw=False``).
        Returns the Figure.
        """
        f, ax = self._age_figure(figsize)
        for curr_ax, age in zip(ax, self._AGES):
            _, _, xlim, ylim, rot = self.get_plot_info(age)
            curr_adata = self._section(self.A, age)
            plot_gene_expr(curr_adata, celltype, gene_name,rot=rot,s=s,vmin=vmin,vmax=vmax,use_raw=False,key=celltype_key,cmap=cmap,ax=curr_ax,alpha=alpha)
            curr_ax.set_xlim(xlim)
            curr_ax.set_ylim(ylim)
        return f

    def plot_obs_by_age(self, obs_name, figsize=None,clust_annot_key='clust_annot',vmin=0,vmax=3,cmap=plt.cm.Reds, cell_types=None,key='clust_annot_preds',s=0.1):
        """
        Plot the ``.obs`` column `obs_name` on each age's section, one row per age.

        `cell_types` (default: every value of ``obs[clust_annot_key]``) and `clust_annot_key`
        are passed to ``plot_obs``. `key` is unused and kept only for backward compatibility.
        Returns the Figure.
        """
        if cell_types is None:
            cell_types = self.A.obs[clust_annot_key].unique()
        f, ax = self._age_figure(figsize)
        for curr_ax, age in zip(ax, self._AGES):
            _, _, xlim, ylim, rot = self.get_plot_info(age)
            curr_adata = self._section(self.A, age)
            plot_obs(curr_adata, cell_types, obs_name,rot=rot,s=s,vmin=vmin,vmax=vmax,key=clust_annot_key,cmap=cmap,ax=curr_ax,)
            curr_ax.set_xlim(xlim)
            curr_ax.set_ylim(ylim)
        return f

    def plot_celltypes_by_age(self, age, cell_types, s=1, clust_key='clust_annot', ax=None, cmap=plt.cm.gist_ncar):
        """Plot the clusters `cell_types` (values of `clust_key`) on the section for `age` (returns None)."""
        if ax is None:
            _, ax = plt.subplots()
        curr_adata, curr_rot, xlim, ylim = self.get_info_for_age(age)
        plot_clust_subset(curr_adata, cell_types, cmap, rot=curr_rot,s=s, ax=ax, xlim=xlim, ylim=ylim,clust_key=clust_key)

    def plot_obs_by_age_for_celltype(self, obs_key, cell_type, show_background=False, figsize=None, s=1,vmin=None, vmax=None,cell_type_key='cell_type', curr_size=2.5, cmap=plt.cm.rainbow,alpha=0.25):
        """
        Plot ``.obs[obs_key]`` for cells of `cell_type` on each age's section, one row per age.

        When `vmin`/`vmax` are None they default to the 1st / 99th percentile of `obs_key`
        over all cells of `cell_type`. With `show_background`, ``plot_obs`` receives every
        cell of the section; otherwise only cells of `cell_type`. `curr_size` is unused and
        kept for backward compatibility. Returns the Figure.
        """
        ct_mask = self.A.obs[cell_type_key] == cell_type
        ct_values = self.A.obs.loc[ct_mask, obs_key]
        if vmin is None:
            vmin = np.quantile(ct_values, 0.01)
        if vmax is None:
            vmax = np.quantile(ct_values, 0.99)
        curr_A = self.A if show_background else self.A[ct_mask]
        f, ax = self._age_figure(figsize)
        for curr_ax, age in zip(ax, self._AGES):
            _, _, xlim, ylim, rot = self.get_plot_info(age)
            curr_adata = self._section(curr_A, age)
            plot_obs(curr_adata, cell_type, obs_key,rot=rot,s=s,vmin=vmin,vmax=vmax,key=cell_type_key,cmap=cmap,ax=curr_ax,alpha=alpha)
            curr_ax.set_xlim(xlim)
            curr_ax.set_ylim(ylim)
        return f

    def plot_celltype_by_age(self, cell_type, figsize=None, s=1,vmin=None, vmax=None,cell_type_key='clust_annot', curr_size=2.5, cmap=plt.cm.gist_ncar):
        """
        Plot cluster(s) `cell_type` (values of `cell_type_key`) on each age's section.

        Marker size comes from `curr_size`; `s`, `vmin` and `vmax` are unused and kept for
        backward compatibility. `figsize` is honoured (it used to be ignored). Returns the Figure.
        """
        f, ax = self._age_figure(figsize)
        for curr_ax, age in zip(ax, self._AGES):
            _, _, xlim, ylim, rot = self.get_plot_info(age)
            curr_adata = self._section(self.A, age)
            plot_clust_subset(curr_adata, cell_type, cmap, rot=rot,s=curr_size, ax=curr_ax, xlim=xlim, ylim=ylim,clust_key=cell_type_key)
            curr_ax.set_xlim(xlim)
            curr_ax.set_ylim(ylim)
        return f
            


In [None]:
# Spatial plotter over the imputed data (one representative section per age).
sp = SpatialPlotter(adata_imputed)

In [None]:
# Colormap listing label_colors in dict order.
# NOTE(review): how plot_clust_subset maps these colours onto clusters is not visible here;
# confirm the dict order matches the cluster order it uses.
curr_cmap = mpl.colors.ListedColormap([label_colors[i] for i in label_colors.keys()])


In [None]:
# All clusters on the 24wk section.
f, ax = plt.subplots(figsize=(4,7))
sp.plot_celltypes_by_age('24wk', adata_imputed.obs.clust_annot.unique(), cmap=curr_cmap,ax=ax,s=2)
f.set_facecolor('white')
save_fig(f, "fig3_allcelltypes")

In [None]:
# Cluster names belonging to neuronal cell types (ExN, InN, MSN) vs. all other cell types.
neuronal_subtypes = adata_imputed.obs.loc[adata_imputed.obs.cell_type.isin(["ExN","InN","MSN"]), "clust_annot"].unique()
nonneuronal_subtypes = adata_imputed.obs.loc[~adata_imputed.obs.cell_type.isin(["ExN","InN","MSN"]), "clust_annot"].unique()

In [None]:
# Neuronal clusters only, 24wk section.
f, ax = plt.subplots(figsize=(4,7))
sp.plot_celltypes_by_age('24wk', neuronal_subtypes, cmap=curr_cmap,ax=ax,s=2)
f.set_facecolor('white')

save_fig(f, "fig3_neuronal")

In [None]:
# Non-neuronal clusters only, 24wk section.
f, ax = plt.subplots(figsize=(4,7))
sp.plot_celltypes_by_age('24wk', nonneuronal_subtypes, cmap=curr_cmap,ax=ax,s=2)
f.set_facecolor('white')

save_fig(f, "fig3_nonneuronal")

In [None]:
# Olig / Astro / Micro at the cell-type level (clust_key='cell_type'), 24wk section.
f, ax = plt.subplots(figsize=(4,7))
sp.plot_celltypes_by_age('24wk', ["Olig",'Astro','Micro'], cmap=curr_cmap,ax=ax,s=2,clust_key='cell_type')
f.set_facecolor('white')

save_fig(f, "fig3_olig_astro_micro")

In [None]:
# Single cluster Endo-3 across the three ages.
sp.plot_celltype_by_age('Endo-3', cell_type_key='clust_annot');

In [None]:
# Spatial maps of normalized imputed pseudoage (colour range [0, 1]) for Endo, Astro, Micro
# and Olig across ages; one PDF per cell type.
for i in ['Endo','Astro','Micro','Olig']:
    f = sp.plot_obs_by_age_for_celltype("pseudoage_imputed_norm",i,s=10,vmin=0,vmax=1,cmap=plt.cm.rainbow,figsize=(2.5,12));
    f.savefig(f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/figures_imputed/fig3_pseudoage_spatial_{i}.pdf",bbox_inches="tight",dpi=200)

In [None]:
# Un-normalized 'pseudoage' over all clusters, symmetric colour range.
sp.plot_obs_by_age("pseudoage",vmin=-10,vmax=10,cmap=plt.cm.seismic);

In [None]:
# Exploratory gene map: Ephb1 in OPCs.
f = sp.plot_gene_by_age('Ephb1','OPC',figsize=(4,16),vmax=1)

In [None]:
# NOTE(review): gene_module is not used in any visible cell.
gene_module = ['Ikzf2', 'Ifi206', 'Ifi214', 'Ifi209', 'Ifi208', 'Smyd3', 'Fam107b', 'Prkcq', 'Bcl2l11', 'S100a4', 'Ifi44', 'Cd52', 'Runx3', 'Mlxip', 'Zc3hav1', 'Gimap4', 'Herc6', 'Klrk1', 'Styk1', 'Rinl']

In [None]:
# NOTE(review): 'module_score' is written by the module-plotting cell further down, so this
# cell only works after that cell has run (hidden state).
f = sp.plot_obs_by_age_for_celltype('module_score',"Olig")

In [None]:
# Exploratory gene maps over all cells: Tnf, then C4b.
f = sp.plot_gene_by_age('Tnf', figsize=(4,16),vmax=10)

In [None]:
f = sp.plot_gene_by_age('C4b', figsize=(4,16),vmax=10)

## Spatial plots of example imputed genes

In [None]:

# Load DE results (a glm_nb table, despite the `ttest_` variable name) and keep significant
# rows: |coef| > log(1.25) and qval < 0.05, with infinite coefficients dropped.
# coef_thresh_minor is not used in this cell.
ttest_de_celltype_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age_V3_nools.csv")
qval_thresh = 0.05
coef_thresh_major = np.log(1.25)
coef_thresh_minor = np.log(2)
de_genes_age_major_signif = ttest_de_celltype_df[np.logical_and(np.abs(ttest_de_celltype_df.coef) > coef_thresh_major, ttest_de_celltype_df.qval<qval_thresh)]
de_genes_age_major_signif = de_genes_age_major_signif[~np.isinf(de_genes_age_major_signif.coef)]

In [None]:
# the 30 largest positive coefficients for cluster Astro-1
de_genes_age_major_signif[de_genes_age_major_signif.cell_type=="Astro-1"].sort_values('coef').tail(30)

In [None]:
#adata_imputed_sqrt = adata_imputed.copy()
#adata_imputed_sqrt.X = adata_imputed_sqrt.layers['sqrt_norm']

In [None]:
# Spatial plotter over the sqrt-normalized data.
sp_sqrt = SpatialPlotter(adata_sqrt_norm)

In [None]:
# Example imputed genes. Bmp6 and Ighm use sp_sqrt (sqrt-normalized scale); Efemp1, Pakap
# and Mpp7 use sp (adata_imputed), hence the different vmax values.
f = sp_sqrt.plot_gene_by_age("Bmp6","Astro",vmax=1.5,s=5,celltype_key='cell_type',cmap=plt.cm.Reds, figsize=(3,12));
save_fig(f, "fig4_imputed_Astro_Bmp6",dtype="png")

In [None]:
f = sp_sqrt.plot_gene_by_age("Ighm","ExN",vmax=1.25,s=5,celltype_key='cell_type',cmap=plt.cm.Reds,figsize=(3,12));
save_fig(f, "fig4_imputed_ExN_Ighm",dtype="png")

In [None]:
f = sp.plot_gene_by_age("Efemp1","Astro",s=5,vmax=6,celltype_key='cell_type',cmap=plt.cm.Reds,figsize=(3,12));
save_fig(f, "fig4_imputed_Astro_Efemp1",dtype="png")

In [None]:
f = sp.plot_gene_by_age("Pakap","Olig",s=5,vmax=15,celltype_key='cell_type',cmap=plt.cm.Reds,figsize=(3,12));
save_fig(f, "fig4_imputed_Olig_Pakap",dtype="png")

In [None]:
f = sp.plot_gene_by_age("Mpp7","MSN",vmax=5,s=5,celltype_key='cell_type',cmap=plt.cm.Reds,figsize=(3,12));
save_fig(f, "fig4_imputed_MSN_Mpp7",dtype="png")

## Plot comparison between MERFISH and imputed

In [None]:
# Load the combined (MERFISH + scRNA-seq, Harmony) object and the raw MERFISH object, then
# build `adata_combined_merfish`: the MERFISH cells of the combined object with .X replaced
# by the raw MERFISH values.
adata_combined = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_combined_harmony.h5ad")
adata_combined = unbinarize_strings(adata_combined)
adata_raw_merfish = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_umap_allages.h5ad")
adata_raw_merfish = unbinarize_strings(adata_raw_merfish)

merfish_mask = adata_combined.obs.dtype == "merfish"
# put the raw MERFISH cells in the combined object's MERFISH cell order
adata_raw_merfish = adata_raw_merfish[adata_combined.obs.index[merfish_mask]]
# .copy(): assign .X on a real AnnData, not on a view of adata_combined
adata_combined_merfish = adata_combined[merfish_mask].copy()
# NOTE(review): .X is assigned positionally -- assumes both objects list genes in the same order.
adata_combined_merfish.X = adata_raw_merfish.X

### Plot example MERFISH genes

In [None]:
# Load the MERFISH major-cell-type DE table (glm_nb) and keep significant rows:
# |coef| > log(1.25), qval < 0.05, infinite coefficients dropped. This overwrites the
# tables loaded in the imputed-genes section above.
ttest_de_celltype_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_merfish_V2.csv")
qval_thresh = 0.05
coef_thresh_major = np.log(1.25)
coef_thresh_minor = np.log(2)
de_genes_age_major_signif = ttest_de_celltype_df[np.logical_and(np.abs(ttest_de_celltype_df.coef) > coef_thresh_major, ttest_de_celltype_df.qval<qval_thresh)]
de_genes_age_major_signif = de_genes_age_major_signif[~np.isinf(de_genes_age_major_signif.coef)]

In [None]:
# the 20 largest positive coefficients for MSN
de_genes_age_major_signif[de_genes_age_major_signif.cell_type=="MSN"].sort_values('coef').tail(20)

In [None]:
# Spatial plotter over the combined (MERFISH + scRNA-seq) object; only cells matching the
# representative batch/slice are drawn.
sp_mer = SpatialPlotter(adata_combined)

In [None]:
f = sp_mer.plot_gene_by_age("Onecut2","MSN",vmax=10, s=5, figsize=(3,12));
save_fig(f, "fig4_merfish_MSN_Onecut2",dtype="png")

In [None]:
f = sp_mer.plot_gene_by_age("Xdh","Endo",vmax=7,figsize=(3,12),s=5);
save_fig(f, "fig4_merfish_Endo_Xdh",dtype="png")

In [None]:
# NOTE(review): plotted from sp_mer (combined/MERFISH object) but saved under a
# "fig4_imputed_" name, unlike the two cells above -- confirm the intended file name.
f = sp_mer.plot_gene_by_age("C4b","Astro",vmax=5,figsize=(3,12),s=5);
save_fig(f, "fig4_imputed_Astro_C4b",dtype="png")

In [None]:
# NOTE(review): same naming question as the C4b cell above.
f = sp_mer.plot_gene_by_age("Il33","Olig",vmax=6,s=5,figsize=(3,12));
save_fig(f, "fig4_imputed_Olig_Il33",dtype="png")

In [None]:
# store the current (raw MERFISH) values as .raw
adata_combined_merfish.raw = adata_combined_merfish

In [None]:
# 4wk example section (batch 8, slice 1) of the MERFISH data; `celltypes` is the global
# list of clusters used by the example-gene plotting below.
curr_size = 5
celltypes = adata_combined_merfish.obs.clust_annot.unique()
#curr_adata = adata_combined_merfish[np.logical_and(adata_combined_merfish.obs.batch==9, adata_combined_merfish.obs.slice==1)]
curr_adata = adata_combined_merfish[np.logical_and(adata_combined_merfish.obs.batch==8, adata_combined_merfish.obs.slice==1)]


In [None]:
def plot_example_genes(A, genes, save_name=None, cell_types=None, rot=-183, xlim=None, ylim=None):
    """
    Plot the spatial expression of several genes side by side, one panel per gene.

    Each panel's colour range is clipped to that gene's 0.1 / 99.9 % quantiles.
    The default `rot`, `xlim`, `ylim` frame the 4wk example section; for the 90wk section use
    rot=35, xlim=[200, 2300], ylim=[400, 4000].

    Parameters
    ----------
    A : AnnData
        Cells to plot (typically one section). ``A.X`` may be scipy-sparse or dense.
    genes : list of str
        Genes (columns of `A`), one panel each.
    save_name : str, optional
        If given, the figure is saved as ``<save_name>.png`` in the figures_imputed folder.
    cell_types : sequence, optional
        Passed to ``plot_gene_expr``. Defaults to the notebook-level ``celltypes`` variable,
        which is what this function always used before.
    rot : rotation passed to ``plot_gene_expr`` (default -183).
    xlim, ylim : list of two numbers, optional
        Axis limits (defaults [200, 2300] and [200, 3800]).
    """
    if cell_types is None:
        cell_types = celltypes  # notebook-level global
    if xlim is None:
        xlim = [200, 2300]
    if ylim is None:
        ylim = [200, 3800]
    aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])

    # squeeze=False keeps `axes` 2-D even for a single gene. The width keeps one spare
    # panel's worth of space, as in the original layout.
    f, axes = plt.subplots(nrows=1, ncols=len(genes), figsize=(5*aspect_ratio*(len(genes)+1),5*1), squeeze=False)
    for ax, c in zip(axes[0], genes):
        print(c)
        gene_expr = A[:,c].X
        # .X may be sparse (has .toarray()) or dense, e.g. after assigning a dense layer
        if hasattr(gene_expr, 'toarray'):
            gene_expr = gene_expr.toarray()
        gene_expr = np.asarray(gene_expr).flatten()
        vmin = np.quantile(gene_expr, 0.001)
        vmax = np.quantile(gene_expr, 0.999)
        plot_gene_expr(A, cell_types, c, plt.cm.Purples, s=5,alpha=0.5, vmin=vmin,vmax=vmax, rot=rot, ax=ax, xlim=xlim, ylim=ylim,use_raw=False)
        ax.set_title(c)
    if save_name is not None:
        f.savefig(f"/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/figures_imputed/{save_name}.png",bbox_inches='tight',dpi=300)

In [None]:
# plot area specific genes
# Marker gene lists (excitatory, inhibitory, non-neuronal) for MERFISH-vs-imputed comparisons;
# common_genes keeps those present in the imputed data.
area_genes_excite = [
"Slc17a7",
 'Otof',
#"Lamp5",
'Cux2',
 'Scube1',
#    "Ptpru",
  'Fezf2',
 'Syt6',
]
area_genes_inhib = [
    "Gad2",
    "Sst",
    "Pvalb",
    "Vip",
    "Drd1",
    "Drd2"
]
area_genes_nonneuronal = [
"Aqp4", "Pdgfra", "Mbp", "Ctss", "Cldn5", "Vtn"
]
common_genes = [i for i in area_genes_excite + area_genes_inhib + area_genes_nonneuronal if i in adata_imputed.var_names]


In [None]:
# imputed data restricted to those genes
adata_imputed_common = adata_imputed[:, common_genes]

In [None]:
# reorder/subset cells to the MERFISH cell order (assumes every MERFISH cell is in adata_imputed)
adata_imputed_common = adata_imputed_common[adata_combined_merfish.obs.index]

In [None]:
#adata_imputed_common.X = adata_imputed_common.layers['sqrt_norm']

In [None]:
def compute_gene_corr(A_mer, A_imputed, gene, layer='sqrt_norm'):
    """
    Pearson correlation, across cells, between a gene's measured MERFISH expression and its
    imputed expression.

    Parameters
    ----------
    A_mer : AnnData
        MERFISH data; expression is read from ``.X`` (sparse or dense).
    A_imputed : AnnData
        Imputed data; expression is read from ``.layers[layer]`` (sparse or dense). If its
        cells are not in the same order as `A_mer`, it is first re-indexed by
        ``A_mer.obs_names`` (fails if a cell of `A_mer` is missing).
    gene : str
        Gene name present in both objects.
    layer : str, optional
        Layer of `A_imputed` to compare against (default 'sqrt_norm').

    Returns
    -------
    float
        Pearson r (NaN if either vector is constant).
    """
    def _to_1d(M):
        # expression matrices may be scipy-sparse (has .toarray()) or dense
        if hasattr(M, 'toarray'):
            M = M.toarray()
        return np.asarray(M).ravel()

    # pair cells by name, not by position
    if not A_imputed.obs_names.equals(A_mer.obs_names):
        A_imputed = A_imputed[A_mer.obs_names]
    expr_mer = _to_1d(A_mer[:, gene].X)
    expr_imputed = _to_1d(A_imputed[:, gene].layers[layer])
    return np.corrcoef(expr_mer, expr_imputed)[0, 1]

In [None]:
# MERFISH-vs-imputed correlation for each excitatory marker gene.
# NOTE(review): iterates area_genes_excite, not common_genes, so a gene missing from the
# imputed data would raise here.
for i in area_genes_excite:
    print(i, compute_gene_corr(adata_combined_merfish, adata_imputed_common, i))

In [None]:
# MERFISH measurements of the marker genes on the 4wk example section (batch 8, slice 1).
# The mask is built from adata_combined_merfish's own .obs; a mask built from another
# object (different cells / order) selects the wrong cells or raises.
curr_adata = adata_combined_merfish[np.logical_and(adata_combined_merfish.obs.batch==8, adata_combined_merfish.obs.slice==1)]

plot_example_genes(curr_adata, area_genes_nonneuronal, "figS4_nonneuronal_mer")

plot_example_genes(curr_adata, area_genes_excite, "figS4_excite_mer")
plot_example_genes(curr_adata, area_genes_inhib, "figS4_inhib_mer")

In [None]:
# Imputed values of the same marker genes on the 4wk example section (batch 8, slice 1).
curr_adata = adata_imputed_common[np.logical_and(adata_imputed_common.obs.batch==8, adata_imputed_common.obs.slice==1)]


In [None]:
# Plot the sqrt-normalized layer: copy it, densified, into .X.
# NOTE(review): curr_adata is a view, so assigning .X presumably makes anndata copy it
# (ImplicitModificationWarning); .X is dense afterwards, so any code calling
# `.X.toarray()` on this object will fail.
curr_adata.X = np.array(curr_adata.layers['sqrt_norm'].toarray()).copy()

In [None]:

plot_example_genes(curr_adata, area_genes_nonneuronal, "figS4_nonneuronal_imputed")

plot_example_genes(curr_adata, area_genes_excite, "figS4_excite_imputed")
plot_example_genes(curr_adata, area_genes_inhib, "figS4_inhib_imputed")

In [None]:
# Additional imputed genes on the 4wk example section (batch 8, slice 1).
# The mask is built from adata_imputed's own .obs: adata_imputed_common was re-indexed to
# the MERFISH cell order, so a mask from it does not line up with adata_imputed.
curr_adata = adata_imputed[np.logical_and(adata_imputed.obs.batch==8, adata_imputed.obs.slice==1)]

imputed_genes = [
    #"Col5a1", #1
    "Col12a1", #2
    "Cnp", #3
    "Chrm3", #4
    "Camk4", #5
    "Camk2a", #6
    "Exph5", #7
    "Fat3", #8
    "Ephb6", #9
    "Gabra4", # 10
    "Fstl4", # 11
    "Frmpd4", #12
    "Fhod3", # 13
    "Dpp10", # 14
    "Alcam", # 15
    "Cacna1e", # 16
    "Bdnf", #17
]
plot_example_genes(curr_adata, sorted(imputed_genes), "figS4_extra_imputed")


In [None]:
# Export one spatial expression PNG per imputed gene on the 90wk example section
# (batch 9, slice 1). Genes containing "Rik" and genes whose PNG already exists are skipped,
# so the cell can be re-run to resume.
curr_size = 2
celltypes = adata_imputed.obs.clust_annot.unique()

curr_adata = adata_imputed[np.logical_and(adata_imputed.obs.batch==9, adata_imputed.obs.slice==1)]
curr_rot = 35
xlim = [200, 2300]
ylim = [400, 4000]
aspect_ratio = (xlim[1]-xlim[0])/(ylim[1]-ylim[0])

save_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/individual_gene_plots"
os.makedirs(save_path, exist_ok=True)

all_genes = list(adata_imputed.var_names)
for i, c in enumerate(all_genes):
    curr_path = os.path.join(save_path, c + ".png")
    if "Rik" in c or os.path.exists(curr_path):
        continue
    print("Saving", c, "(",i,"/",len(all_genes), ")")
    # Fresh figure per gene: reusing one Axes overlays every previous gene's points in each
    # saved PNG (unless plot_gene_expr clears the axes itself).
    f, ax = plt.subplots(nrows=1, ncols=1, figsize=(5*aspect_ratio*1,5*1))
    gene_expr = curr_adata[:,c].X
    if hasattr(gene_expr, 'toarray'):
        gene_expr = gene_expr.toarray()
    vmin = np.quantile(gene_expr, 0.01)
    vmax = np.quantile(gene_expr, 0.99)
    plot_gene_expr(curr_adata, celltypes, c, plt.cm.Reds, s=1,alpha=0.5, vmin=vmin,vmax=vmax, rot=curr_rot, ax=ax, xlim=xlim, ylim=ylim, use_raw=False)

    f.savefig(curr_path, dpi=200, bbox_inches='tight',facecolor='white', transparent=False)
    # close each figure so thousands of them do not accumulate in memory
    plt.close(f)

## Spatial pseudoage distribution

## Spatial gene module plots

In [None]:
# Inspect the gene-module table. NOTE(review): module_df is not defined in the visible
# cells -- presumably loaded elsewhere.
module_df

In [None]:
# NOTE(review): 'module_score' is (re)written by the module-plotting cell below, so these
# two cells only show values after that cell has run (hidden state).
adata_imputed.obs.module_score

In [None]:
curr_adata.obs['module_score']

In [None]:
import glob

# For Olig, Astro and Micro: score every gene module on that cell type's cells, then save
# (a) the per-spatial-region comparison across ages (PDF) and (b) spatial maps by age (PNG)
# into <base_path>/<cell type>/. Earlier outputs for that cell type are removed first.
base_path = "/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper_v2/module_spatial_plots"
for i in ["Olig", "Astro","Micro"]:#module_df.cell_type.unique():
    curr_path = os.path.join(base_path, i)
    os.makedirs(curr_path, exist_ok=True)
    curr_module_df = module_df[module_df.cell_type==i]
    # .copy(): score_genes writes to .obs, which should not be done on a view
    curr_adata = adata_imputed[adata_imputed.obs.cell_type==i].copy()
    print(i)
    # remove outputs of previous runs (replaces a shell `rm` with a glob pattern)
    for old_file in glob.glob(os.path.join(curr_path, f"{i}_module*")):
        os.remove(old_file)

    for j in curr_module_df.sorted_module.unique():
        try:
            temp = curr_module_df[curr_module_df.sorted_module==j]
            module_name = temp.cluster.unique()[0]
            sc.tl.score_genes(curr_adata, gene_list=list(temp.gene), score_name='module_score',use_raw=False)
            # float column: other cell types stay 0.0, this cell type gets its float scores
            adata_imputed.obs['module_score'] = 0.0
            adata_imputed.obs.loc[curr_adata.obs.index, "module_score"] = curr_adata.obs['module_score']

            f = plot_age_obs_comparison(adata_imputed, "spatial_clust_annots", "module_score", i, show_pvals=False, order=spatial_order,);
            f.savefig(os.path.join(curr_path, f"{i}_module_spatialdist_{module_name}_.pdf"),bbox_inches='tight', dpi=200)
            plt.close(f)  # avoid accumulating one open figure per module

            f = sp.plot_obs_by_age_for_celltype("module_score", i,show_background=True,figsize=(2,8), curr_size=10)
            f.set_facecolor('white')

            f.savefig(os.path.join(curr_path, f"{i}_module_{module_name}_.png"),bbox_inches='tight', dpi=200)
            plt.close(f)
        except Exception as e:
            # best effort: report the failing module and continue with the next one
            print(e)

# Fig 4: Differentially expressed genes and gene modules

In [None]:
# (disabled) load of precomputed per-cluster glm_nb DE results
#ttest_de_clust_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age.csv")


## Compute DE genes for imputed data

In [None]:
# DE rows for C1qc in whichever table was last loaded into ttest_de_celltype_df.
ttest_de_celltype_df[ttest_de_celltype_df.gene=="C1qc"]

In [None]:
# Display order of clusters in the DE figures, grouped by cluster-name prefix:
# ExN, InN, MSN, OPC/Olig, Astro/Vlmc/Peri/Endo/Epen, Micro/Macro.
de_clust_order = (
    ['ExN-L2/3-1', 'ExN-L2/3-2', 'ExN-L5-1', 'ExN-L5-2', 'ExN-L5-3',
     'ExN-L6-1', 'ExN-L6-2', 'ExN-L6-3', 'ExN-Olf']
    + ['InN-Olf-1', 'InN-Olf-2', 'InN-Vip', 'InN-Lamp5',
       'InN-Pvalb-1', 'InN-Pvalb-2', 'InN-Pvalb-3', 'InN-Sst-1', 'InN-Sst-2',
       'InN-Calb2-1', 'InN-Calb2-2', 'InN-Chat', 'InN-Lhx6']
    + ['MSN-D1-1', 'MSN-D1-2', 'MSN-D2']
    + ['OPC', 'Olig-1', 'Olig-2', 'Olig-3']
    + ['Astro-1', 'Astro-2', 'Vlmc', 'Peri-1', 'Peri-2',
       'Endo-1', 'Endo-2', 'Endo-3', 'Epen']
    + ['Micro-1', 'Micro-2', 'Micro-3', 'Macro']
)

In [None]:
import diffxpy.api as de

# Per-cell-type differential expression, 4wk vs 90wk (diffxpy t-test on the sqrt-normalized
# data). For each gene, also record the fraction of young / old cells with expression above
# `expr_thresh`.
de_by_celltype = []
for i in adata_sqrt_norm.obs.cell_type.unique():
    print(i)
    curr_adata = adata_sqrt_norm[adata_sqrt_norm.obs.cell_type==i]
    # Keep only *this cell type's* 4wk and 90wk cells. (This line used to re-subset the full
    # adata_sqrt_norm, so every cell type was tested on all cells.)
    curr_adata = curr_adata[curr_adata.obs.age.isin(['4wk','90wk'])]
    yng = curr_adata[curr_adata.obs.age=='4wk']
    old = curr_adata[curr_adata.obs.age=='90wk']
    if yng.shape[0] == 0 or old.shape[0] == 0:
        # the two-group test needs cells from both ages
        print("skipping", i, "- no 4wk or no 90wk cells")
        continue
    expr_thresh = 0.0
    frac_yng = np.array((yng.X>expr_thresh).sum(0)/yng.shape[0]).flatten()
    frac_old = np.array((old.X>expr_thresh).sum(0)/old.shape[0]).flatten()
    de_res = de.test.t_test(curr_adata, 'age')

    de_by_celltype.append(pd.DataFrame({'cell_type': [i]*len(de_res.gene_ids), 'gene' : de_res.gene_ids, 'log10_fc':de_res.log10_fold_change(), 'log10_qval': de_res.log10_qval_clean(), 'frac_yng':frac_yng, 'frac_old':frac_old }))

In [None]:
# one table of per-cell-type DE results
de_by_celltype_df = pd.concat(de_by_celltype)

In [None]:
#de_by_celltype_df.to_csv("gene_lists/0710122_ttest_diffxpy_imputed_celltype.csv")

In [None]:
# drop genes with |log10 fold change| >= 1 (>= 10-fold)
# NOTE(review): presumably removes unstable fold changes from near-zero expression -- confirm intent.
de_by_celltype_df = de_by_celltype_df[de_by_celltype_df.log10_fc.abs() < 1]

In [None]:
import diffxpy.api as de

# Young (4wk) vs old (90wk) t-test per fine cluster on the sqrt-normalized imputed data,
# plus the fraction of young/old cells with expression above `expr_thresh` per gene.
# NOTE(review): clusters are enumerated from adata_subset but tested in adata_sqrt_norm;
# clusters with no 4wk or no 90wk cells there are skipped.
expr_thresh = 0.1
de_by_clust = []
for i in adata_subset.obs.clust_annot.unique():
    print(i)
    curr_adata = adata_sqrt_norm[adata_sqrt_norm.obs.clust_annot == i]
    # Keep only this cluster's 4wk/90wk cells. (Previously this line re-subset the full
    # object, so every cluster received the same pooled result.)
    curr_adata = curr_adata[curr_adata.obs.age.isin(['4wk', '90wk'])]
    yng = curr_adata[curr_adata.obs.age == '4wk']
    old = curr_adata[curr_adata.obs.age == '90wk']
    if yng.shape[0] == 0 or old.shape[0] == 0:
        print(f"skipping {i}: needs cells from both 4wk and 90wk")
        continue
    frac_yng = np.asarray((yng.X > expr_thresh).sum(0) / yng.shape[0]).ravel()
    frac_old = np.asarray((old.X > expr_thresh).sum(0) / old.shape[0]).ravel()
    de_res = de.test.t_test(curr_adata, 'age')
    de_by_clust.append(pd.DataFrame({
        'cell_type': [i] * len(de_res.gene_ids),
        'gene': de_res.gene_ids,
        'log10_fc': de_res.log10_fold_change(),
        'log10_qval': de_res.log10_qval_clean(),
        'frac_yng': frac_yng,
        'frac_old': frac_old,
    }))

In [None]:
# Stack the per-cluster DE tables.
de_by_clust_df = pd.concat(de_by_clust)

In [None]:
#de_by_clust_df.to_csv("gene_lists/0710122_ttest_diffxpy_imputed.csv")

In [None]:
#de_by_clust_df = ttest_de_clust_df 

In [None]:
# Exclude the 'T cell' cluster from the DE summaries.
de_by_clust_df = de_by_clust_df[de_by_clust_df.cell_type!='T cell']

In [None]:
# Inspect the filtered per-cluster DE table.
de_by_clust_df

In [None]:
# Significant imputed-data DE hits per cluster at two |fold-change| cutoffs (2x and 1.5x).
# Both cutoffs also require q < 0.01 and expression in > expr_frac_thresh of both young and old cells.
expr_frac_thresh = 0.2
_expressed_both = (de_by_clust_df.frac_yng > expr_frac_thresh) & (de_by_clust_df.frac_old > expr_frac_thresh)
_q_signif = de_by_clust_df.log10_qval < np.log10(0.01)
_abs_log10_fc = de_by_clust_df.log10_fc.abs()
ttest_de_clust_df_signif_thresh2 = de_by_clust_df[(_abs_log10_fc > np.log10(2)) & _q_signif & _expressed_both]
ttest_de_clust_df_signif_thresh1_5 = de_by_clust_df[(_abs_log10_fc > np.log10(1.5)) & _q_signif & _expressed_both]


In [None]:
# Palette position of every cluster in `de_clust_order` (first match in label_colors),
# used to draw the cluster color strip under the bar plots.
label_names = np.array(list(label_colors))
clust_idx = np.array([np.flatnonzero(label_names == c)[0] for c in de_clust_order])
clust_idx_cmap = mpl.colors.ListedColormap(list(label_colors.values()))



In [None]:
sns.set_style('white')
f, ax = plt.subplots(figsize=(8,3), nrows=3, gridspec_kw={'hspace':0.05, 'height_ratios':[10,1,10]})
# Dashed guides every 100 genes on the up (top) and down (bottom) panels.
for i in [100,200,300,400]:
    ax[0].axhline(i,color='gray',linestyle='--',lw=1)
    ax[2].axhline(-i,color='gray',linestyle='--',lw=1)

# Per-cluster counts of imputed-data DE genes: up with age as positive bars on top, down as
# negative bars below. The `_thresh1_5` set is drawn first in light colors, then the
# stricter `_thresh2` set on top in dark colors.
for de_df, up_color, down_color in [(ttest_de_clust_df_signif_thresh1_5, 'salmon', 'skyblue'),
                                    (ttest_de_clust_df_signif_thresh2, 'red', 'darkblue')]:
    clusters = de_df.cell_type.unique()
    n_up = [int(((de_df.cell_type == c) & (de_df.log10_fc > 0)).sum()) for c in clusters]
    n_down = [int(((de_df.cell_type == c) & (de_df.log10_fc < 0)).sum()) for c in clusters]
    sns.barplot(x='gene', y='de', data=pd.DataFrame({'gene': clusters, 'de': n_up}),
                order=de_clust_order, color=up_color, ax=ax[0])
    sns.barplot(x='gene', y='de', data=pd.DataFrame({'gene': clusters, 'de': [-n for n in n_down]}),
                order=de_clust_order, color=down_color, ax=ax[2])

ax[0].set_xticks([])
ax[0].set_xlabel("")
ax[0].set_ylim([0,500])
sns.despine(ax=ax[0],bottom=True)

# Cluster color strip. Pin the normalization to the whole palette (one integer index per
# color). Otherwise imshow rescales to clust_idx.min()..max(), and colors shift whenever
# the palette contains clusters not plotted here (e.g. 'T cell').
ax[1].imshow(clust_idx[np.newaxis, :], aspect='auto', interpolation='none', cmap=clust_idx_cmap,
             vmin=-0.5, vmax=clust_idx_cmap.N - 0.5)
ax[1].axis('off')

ax[2].tick_params(axis='x', labelrotation=90)
sns.despine(ax=ax[2])
ax[2].set_ylim([-500,0])
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_num_de_imputed.pdf",dpi=300,bbox_inches='tight')

## Use 10x DE genes for heatmap

In [None]:
# 10x-based age DE tables per fine cluster and per major cell type. Per the file names these
# are negative-binomial GLM results; the `ttest_` prefix is historical.
ttest_de_clust_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age_V3_nools.csv")
ttest_de_celltype_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_V2.csv")


In [None]:
# `coef` is compared with natural-log cutoffs, so it is presumably a natural-log age effect.
# Clusters: > 2-fold, q < 0.05. Major cell types: > 1.25-fold, q < 0.05.
ttest_de_clust_df_signif = ttest_de_clust_df[np.logical_and(np.abs(ttest_de_clust_df.coef) > np.log(2), ttest_de_clust_df.qval<0.05)]
ttest_de_celltype_df_signif = ttest_de_celltype_df[np.logical_and(np.abs(ttest_de_celltype_df.coef) > np.log(1.25), ttest_de_celltype_df.qval<0.05)]

In [None]:
# Sets for the 10x DE-count bar plot. NOTE(review): despite the names, "thresh1_5" uses a
# 2-fold and "thresh2" a 2.5-fold cutoff (the imputed-data sets used 1.5x and 2x).
ttest_de_clust_df_signif_thresh1_5 = ttest_de_clust_df[np.logical_and(np.abs(ttest_de_clust_df.coef) > np.log(2), ttest_de_clust_df.qval<0.05)]
ttest_de_clust_df_signif_thresh2 = ttest_de_clust_df[np.logical_and(np.abs(ttest_de_clust_df.coef) > np.log(2.5), ttest_de_clust_df.qval<0.05)]


In [None]:
sns.set_style('white')
f, ax = plt.subplots(figsize=(8,3), nrows=3, gridspec_kw={'hspace':0.05, 'height_ratios':[10,1,10]})
# Dashed guides every 100 genes on the up (top) and down (bottom) panels.
for i in [100,200,300,400]:
    ax[0].axhline(i,color='gray',linestyle='--',lw=1)
    ax[2].axhline(-i,color='gray',linestyle='--',lw=1)

# Per-cluster counts of 10x GLM DE genes: positive `coef` (up with age) on top, negative
# below. The `_thresh1_5` set (2-fold cutoff) is drawn first in light colors, then the
# `_thresh2` set (2.5-fold) on top in dark colors.
for de_df, up_color, down_color in [(ttest_de_clust_df_signif_thresh1_5, 'salmon', 'skyblue'),
                                    (ttest_de_clust_df_signif_thresh2, 'red', 'darkblue')]:
    clusters = de_df.cell_type.unique()
    n_up = [int(((de_df.cell_type == c) & (de_df.coef > 0)).sum()) for c in clusters]
    n_down = [int(((de_df.cell_type == c) & (de_df.coef < 0)).sum()) for c in clusters]
    sns.barplot(x='gene', y='de', data=pd.DataFrame({'gene': clusters, 'de': n_up}),
                order=de_clust_order, color=up_color, ax=ax[0])
    sns.barplot(x='gene', y='de', data=pd.DataFrame({'gene': clusters, 'de': [-n for n in n_down]}),
                order=de_clust_order, color=down_color, ax=ax[2])

ax[0].set_xticks([])
ax[0].set_xlabel("")
ax[0].set_ylim([0,500])
sns.despine(ax=ax[0],bottom=True)

# Cluster color strip with normalization pinned to the full palette, so each integer
# index maps to its own color even if some palette clusters are not plotted.
ax[1].imshow(clust_idx[np.newaxis, :], aspect='auto', interpolation='none', cmap=clust_idx_cmap,
             vmin=-0.5, vmax=clust_idx_cmap.N - 0.5)
ax[1].axis('off')

ax[2].tick_params(axis='x', labelrotation=90)
sns.despine(ax=ax[2])
ax[2].set_ylim([-500,0])
f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_num_de_10x.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Every gene significant in at least one cluster; these become the heatmap rows.
de_genes = ttest_de_clust_df_signif.gene.unique()

In [None]:
len(de_genes)

In [None]:
# Shape of the significant-hit table for each cluster.
for i in sorted(ttest_de_clust_df_signif.cell_type.unique()):
    print(i, ttest_de_clust_df_signif[ttest_de_clust_df_signif.cell_type==i].shape)

In [None]:
ttest_de_clust_df

In [None]:
# (disabled) cap number of displayed genes per cluster
#max_genes = 200
#de_genes = list(np.unique([list(ttest_de_clust_df[ttest_de_clust_df.cell_type==i].gene)[:max_genes]  for i in ttest_de_clust_df.cell_type.unique() if i != "T cell"]))

In [None]:
# Age coefficient of every DE gene (rows, `de_genes`) in every cluster (columns,
# `de_clust_order`) from the full 10x table; 0 where the pair is absent from the table.
# Build a (cluster, gene) -> coef lookup once instead of re-filtering the DataFrame for
# every matrix cell (that was O(genes x clusters x rows)). setdefault keeps the first row
# per pair, matching the previous `.coef.values[0]`.
first_coef = {}
for ct, g, c in zip(ttest_de_clust_df.cell_type, ttest_de_clust_df.gene, ttest_de_clust_df.coef):
    first_coef.setdefault((ct, g), c)

coef_mat = np.zeros((len(de_genes), len(de_clust_order)))
for i, ct in enumerate(de_clust_order):
    for j, g in enumerate(de_genes):
        if (ct, g) in first_coef:
            coef_mat[j, i] = first_coef[(ct, g)]

In [None]:
# Genes x clusters table of age coefficients, labeled for plotting.
coef_mat_df = pd.DataFrame(coef_mat, columns=np.array(de_clust_order),  index=de_genes)
#coef_mat_df = coef_mat_df.loc[:, coef_mat_df.var(0)>0]


In [None]:
import scipy.cluster.hierarchy as hc
from scipy.spatial.distance import squareform

# Complete-linkage clustering of the age-coefficient matrix with 1 - Pearson correlation as
# the distance. `row_*` clusters the clusters (columns of coef_mat_df); `col_*` clusters
# the genes (rows of coef_mat_df).
# NOTE: scipy.spatial is deliberately not bound to `sp`. Later cells call `sp.plot_gene_by_age`
# etc. (the project's spatial plotting helper), which that import used to shadow.
row_dism = 1 - np.corrcoef(coef_mat_df.values.T)
row_dism[np.isnan(row_dism)] = 0  # constant columns give NaN correlations
# linkage() treats a square 2-D array as observation vectors, not as a distance matrix.
# Pass the condensed form so the correlation distances themselves are clustered
# (clip guards against tiny negative values from floating-point error).
row_linkage = hc.linkage(squareform(np.clip(row_dism, 0, None), checks=False), method='complete')
row_den = hc.dendrogram(row_linkage,no_plot=True)
col_dism = 1 - np.corrcoef(coef_mat_df.values)
col_dism[np.isnan(col_dism)] = 0  # constant rows give NaN correlations
col_linkage = hc.linkage(squareform(np.clip(col_dism, 0, None), checks=False), method='complete')
col_den = hc.dendrogram(col_linkage,no_plot=True)


In [None]:
# Dendrogram leaf orders: col_order runs over genes (rows of coef_mat_df) and row_order
# over clusters (columns), following the row_/col_ naming of the linkage cell.
col_order = np.array(col_den['leaves'])
row_order = np.array(row_den['leaves'])

In [None]:
# Order the heatmap rows (genes) by the gene dendrogram's leaf order. The dendrogram drawn
# beside the heatmap then lines up with the rows it describes; previously this reorder was
# commented out, so the dendrogram did not match the displayed rows.
# Columns keep the curated `de_clust_order`.
coef_mat_sorted = coef_mat_df.iloc[col_order, :]

In [None]:
# Palette position of every heatmap column (cluster), for the color strip above the heatmap.
label_names = np.array(list(label_colors))
clust_idx = np.array([np.flatnonzero(label_names == c)[0] for c in coef_mat_sorted.columns])
clust_idx_cmap = mpl.colors.ListedColormap(list(label_colors.values()))



In [None]:
# Genes to label on the heatmap's y-axis.
heatmap_markers = ['Bmp6','C4b','Il18','Onecut2','Il33','Pakap','Mpp7', 'Ighm','Ildr2','Ifit3','Ifitm3', 'Xdh','Lyz2','Serpina3n', 'C1qa','C1qc','B2m','Gfap','Nfkbia','Fmo2','Hexb','Hif3a', 'Nfkbib','Tnfsf8','Tnfsf13b']

In [None]:
# Keep only markers that are actual heatmap rows.
heatmap_markers = [i for i in heatmap_markers if i in coef_mat_sorted.index]

In [None]:
coef_mat_sorted

In [None]:
# Major cell types in which each marker is significant.
for i in heatmap_markers:
    print(i, list(ttest_de_celltype_df_signif[ttest_de_celltype_df_signif.gene==i].cell_type.unique()))

In [None]:
# Fine clusters in which each marker is significant.
for i in heatmap_markers:
    print(i, list(ttest_de_clust_df_signif[ttest_de_clust_df_signif.gene==i].cell_type.unique()))

In [None]:
heatmap_markers

In [None]:
f, axes = plt.subplots(figsize=(5,15), ncols=2, nrows=2, gridspec_kw={'wspace':0.1,'width_ratios':[10,3],'height_ratios':[0.5,20],'hspace':0.01})

# Top-left: cluster color strip above the heatmap columns. Normalization is pinned to the
# full palette so each integer index maps to its own label color.
ax = axes[0,0]
ax.imshow(np.expand_dims(clust_idx,axis=0),aspect='auto',interpolation='none',cmap=clust_idx_cmap,
          vmin=-0.5, vmax=clust_idx_cmap.N - 0.5)
ax.set_xticks([])
ax.set_yticks([])
sns.despine(ax=ax,left=True,right=True, bottom=True)

# Bottom-left: age coefficients (genes x clusters), flipped vertically so row 0 is at the
# bottom. That is the same bottom-to-top convention as the dendrogram on the right.
# Only marker genes get a tick label.
ax = axes[1,0]
ax.imshow(np.flipud(coef_mat_sorted.values),cmap=plt.cm.seismic,vmin=-2.5,vmax=2.5,aspect='auto',interpolation='none',rasterized=True)
row_names = coef_mat_sorted.index[::-1]
ax.set_yticks(np.arange(len(row_names)))
ax.set_yticklabels([i if i in heatmap_markers else '' for i in row_names])
sns.despine(ax=ax, left=True, bottom=True)
ax.set_xticks([])

# Bottom-right: gene dendrogram (col_linkage, the gene-gene correlation clustering).
ax = axes[1,1]
hc.dendrogram(col_linkage, ax=ax,orientation='right',above_threshold_color='k',color_threshold=0)
ax.axis('off')
# Top-right: unused corner.
ax = axes[0,1]
ax.axis('off')

save_fig(f, "fig4_de_gene_heatmap", dtype="pdf")


In [None]:
# Rebuild the cell-type and cluster color palettes.
# NOTE(review): `adaa` looks like a typo (adata_imputed / adata_subset?) and is not defined
# anywhere in this section -- confirm which object is intended. Reassigning label_colors here
# also changes what later clust_idx computations index into.
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adaa)

In [None]:
# t-test across ages for each major cell type on adata_subset. Genes are first filtered with
# the project helper filter_2group (group key "age", reference "4wk"), then counts are
# library-size normalized and square-rooted before testing.
import diffxpy.api as diffxpy
from de import filter_2group 
ttest_dfs = []
for cell_type in adata_subset.obs.cell_type.unique():
    print(cell_type)
    ct_adata = adata_subset[adata_subset.obs.cell_type == cell_type]
    ct_adata, _ = filter_2group(ct_adata, "age", "4wk", min_pct=0.1, logfc_thresh=np.log(1))
    ct_adata.X = np.sqrt(sc.pp.normalize_total(ct_adata, inplace=False)["X"])
    res = diffxpy.test.two_sample(data=ct_adata, grouping='age', test='t_test')
    qvals = res.qval
    log10_fc = res.log10_fold_change()
    ttest_dfs.append(pd.DataFrame({
        'cell_type': [cell_type] * len(qvals),
        'fc': log10_fc,
        'qval': qvals,
        'gene': list(ct_adata.var_names),
    }))

## Quantify genes spatially upreg or downreg per cell type

In [None]:
# DE-gene pooling helpers: combine the cluster-level (minor) and cell-type-level (major) tables.
def get_upreg_with_age(df_major, df_minor):
    """Genes with coef > 0: every minor-table hit first, then every major-table hit (duplicates kept)."""
    pooled = []
    for table in (df_minor, df_major):
        pooled.extend(table.loc[table['coef'] > 0, 'gene'])
    return pooled

def get_downreg_with_age(df_major, df_minor):
    """Genes with coef <= 0: every minor-table hit first, then every major-table hit (duplicates kept)."""
    pooled = []
    for table in (df_minor, df_major):
        pooled.extend(table.loc[table['coef'] <= 0, 'gene'])
    return pooled

# Major-cell-type GLM age-DE table and its significant hits. Criteria: q < 0.01,
# |coef| > log(1.5) (presumably > 1.5-fold on a natural-log scale), and finite coefficients only.
ttest_de_celltype_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_V2.csv")

qval_thresh = 0.01
coef_thresh_major = np.log(1.5)
de_genes_age_major_signif = ttest_de_celltype_df.loc[
    lambda d: (d.coef.abs() > coef_thresh_major) & (d.qval < qval_thresh) & ~np.isinf(d.coef)
]

In [None]:
# Cluster-level hits used for the per-cell-type summaries: |coef| > log(2) and q < 0.1.
de_clust_df_signif = ttest_de_clust_df.loc[
    lambda d: (d.coef.abs() > np.log(2)) & (d.qval < 0.1)
]


In [None]:
# Major cell type of each cluster = the prefix before the first '-' (e.g. 'Astro-1' -> 'Astro').
# `.assign` returns a new frame instead of writing into a filtered slice, which avoids
# SettingWithCopyWarning.
de_clust_df_signif = de_clust_df_signif.assign(
    real_cell_type=de_clust_df_signif.cell_type.str.split("-").str[0]
)

In [None]:
# For each major cell type, pool its clusters' significant hits. Counts are the number of
# (cluster, gene) hits up (coef > 0) or down (coef <= 0) with age; gene lists hold the
# distinct genes.
upreg_counts, downreg_counts = {}, {}
upreg_genes, downreg_genes = {}, {}
for major_type in de_genes_age_major_signif.cell_type.unique():
    type_hits = de_clust_df_signif[de_clust_df_signif.real_cell_type == major_type]
    up_hits = type_hits[type_hits.coef > 0]
    down_hits = type_hits[type_hits.coef <= 0]
    upreg_counts[major_type] = up_hits.shape[0]
    downreg_counts[major_type] = down_hits.shape[0]
    upreg_genes[major_type] = list(set(up_hits.gene))
    downreg_genes[major_type] = list(set(down_hits.gene))
    

In [None]:
# Genes upregulated with age in any astrocyte cluster.
upreg_genes['Astro']

In [None]:
# Astrocyte age-DE genes at the major-cell-type level (either direction).
astro_genes = de_genes_age_major_signif[de_genes_age_major_signif.cell_type=="Astro"].gene

In [None]:
# GO annotations for gene-set scoring. download_ncbi_associations presumably fetches NCBI's
# gene2go file into the working directory (network access); Gene2GoReader reads it below.
from goatools.base import download_ncbi_associations
gene2go = download_ncbi_associations()


In [None]:
from goatools.base import download_go_basic_obo
#obo_fname = download_go_basic_obo()
from goatools.anno.genetogo_reader import Gene2GoReader

# Mouse (taxid 10090) biological-process annotations, keyed GO id -> gene ids.
objanno = Gene2GoReader("gene2go", taxids=[10090])
go2geneids_mus = objanno.get_id2gos(namespace='BP', go2geneids=True)
from goatools.go_search import GoSearch

# Search helper over the GO DAG. Needs go-basic.obo to already exist locally, because the
# download call above is commented out.
srchhelp = GoSearch("go-basic.obo", go2items=go2geneids_mus)


In [None]:
# Entrez GeneID -> gene symbol lookup (for duplicate ids the last row wins).
entrez_table = pd.read_table("/home/user/src/tithonus/analysis/merfish/entrez_gene_ids.txt")
id_to_sym = dict(zip(entrez_table['GeneID'], entrez_table['Symbol']))

In [None]:
def get_genes_for_go_term(go_id):
    """
    Gene symbols annotated to `go_id` together with its child terms.

    Expands the term via the module-level GoSearch helper `srchhelp`, collects the annotated
    gene ids, and maps them to symbols with `id_to_sym`. Ids without a symbol are dropped.
    """
    symbols = []
    expanded_terms = srchhelp.add_children_gos([go_id])
    for entrez_id in srchhelp.get_items(expanded_terms):
        if entrez_id in id_to_sym:
            symbols.append(id_to_sym[entrez_id])
    return symbols


In [None]:
            # (dedented: this stray copy of a line from the GO-scoring loop was indented and raised IndentationError)
curr_syms = get_genes_for_go_term(g)  # `g` is whichever GO id that loop last left in the kernel

In [None]:
# NOTE(review): `score_names` and `go_terms` are created by the GO-scoring loop in a later
# cell, so these displays rely on kernel state from a previous run.
score_names

In [None]:
# Score built from microglia-upregulated genes in GO:0001816 (named by the scoring loop),
# shown by age within astrocytes; `sp` is presumably the project's spatial plotting helper.
sp.plot_obs_by_age_for_celltype('Micro_GO:0001816', 'Astro', figsize=(5,15));

In [None]:
go_terms

In [None]:
# For each glial/vascular type, take the top-10 GO:BP/KEGG terms enriched among its
# age-upregulated genes. Score cells on the overlap of each term's genes with that
# upregulated set; scores are stored in adata_imputed.obs under "<celltype>_<term>".
score_names = []
for ct in ["Astro", "Micro", "Endo", "Olig"]:
    genes_pos = upreg_genes[ct]
    print(len(genes_pos))
    go_terms = sig_genes_GO_query(genes_pos, source=['GO:BP', 'KEGG']).head(10)
    for i, r in go_terms.iterrows():
        g, name = r['native'], r['name']
        if g not in go2geneids_mus:
            continue
        curr_syms = get_genes_for_go_term(g)
        syms_to_use = np.intersect1d(curr_syms, genes_pos)
        print(g, syms_to_use, name)
        if len(syms_to_use) == 0:
            continue
        curr_score_name = f"{ct}_{g}"
        sc.tl.score_genes(adata_imputed, gene_list=syms_to_use, score_name=curr_score_name, use_raw=False)
        score_names.append(curr_score_name)

In [None]:
from scipy.stats import zscore

# For each astrocyte age-DE gene, measure how spatially concentrated its expression is in
# microglia: print the max z-score, across spatial clusters, of the per-cluster mean.
curr_A = adata_imputed[adata_imputed.obs.cell_type=="Micro"]
micro_areas = curr_A.obs['spatial_clust_annots'].to_numpy()
for a in astro_genes:
    if a in curr_A.var_names:
        gene_idx = np.argwhere(curr_A.var_names == a)[0][0]
        col = curr_A.X[:, gene_idx]
        # X may be sparse or dense; flatten either to a 1-D vector.
        expr = np.asarray(col.toarray() if hasattr(col, 'toarray') else col).ravel()
        # Group a standalone frame rather than writing an 'expr' column into the obs of an
        # AnnData view (that warns and materializes a copy for every gene). Plain object keys
        # also keep unobserved categorical levels from injecting NaN means into the z-scores.
        area_means = pd.DataFrame({'expr': expr, 'area': micro_areas}).groupby('area')['expr'].mean()
        area_avgs = zscore(area_means)
        print(a, np.max(area_avgs))

In [None]:
# Interferon-response signature score, stored as obs['interferon_score'].
# NOTE(review): unlike the other score_genes calls, use_raw is not set here, so scanpy may
# score adata.raw rather than X -- confirm this is intended.
sc.tl.score_genes(adata_imputed, gene_list=["Rsad2","Ifit3", "Oas2", "Ifi204", "Usp18","Isg15","Stat2"], score_name="interferon_score")

In [None]:
# The ten astrocyte major-type hits with the largest age coefficients.
de_genes_age_major_signif[de_genes_age_major_signif.cell_type=="Astro"].sort_values('coef').tail(10)

In [None]:
# Genes whose expression-by-age maps are saved for Fig. 4.
genes_to_plot = ["C4b", "Il33","Gfap", "Xdh"]

In [None]:
# `sp` is presumably the project's spatial plotting helper (not scipy.spatial).
sp.plot_gene_by_age("B2m", celltype=None,vmin=0,vmax=15, figsize=(3,15),s=0.5);

In [None]:
sp.plot_gene_by_age("Il18", celltype=None,vmin=0,vmax=15, figsize=(3,15));

In [None]:
# Save one expression-by-age figure per gene as fig4_<gene>.png (save_fig's default type).
for i in genes_to_plot:
    f = sp.plot_gene_by_age(i, celltype=None,vmin=0,vmax=10, figsize=(3,15),s=2,alpha=1);
    save_fig(f, "fig4_"+i)

In [None]:
f = sp.plot_gene_by_age("B2m", celltype=None,vmin=0,vmax=10, figsize=(3,15));

In [None]:
f = sp.plot_gene_by_age("Gfap", celltype=None,vmin=0,vmax=10, figsize=(3,15));

In [None]:
f = sp.plot_gene_by_age("Xdh", celltype=None,vmin=0,vmax=10, figsize=(3,15));

In [None]:
f = sp.plot_gene_by_age("Il33", celltype=None,vmin=0,vmax=10, figsize=(3,15));

In [None]:
sns.set_style('white')
# Number of (cluster, gene) age-DE hits per major cell type: upregulated above zero,
# downregulated below. Each cell type is listed once ('Micro' previously appeared twice in
# `order`).
# NOTE(review): up is blue and down red here, the reverse of the other Fig. 4 panels.
ct_list = ['ExN','InN','MSN','Olig', 'Astro','Micro','OPC','Vlmc', 'Endo','Peri','Macro']
f, ax = plt.subplots()
sns.barplot(x='cell_type', y='count',
            data=pd.DataFrame({'cell_type': list(upreg_counts.keys()), 'count': list(upreg_counts.values())}),
            order=ct_list, color='blue', ax=ax)
sns.barplot(x='cell_type', y='count',
            data=pd.DataFrame({'cell_type': list(downreg_counts.keys()), 'count': -np.array(list(downreg_counts.values()))}),
            order=ct_list, color='r', ax=ax)
ax.set_ylim([-500,500])
save_fig(f, "fig4_de_major_ct", dtype="pdf")

In [None]:
# GO enrichment of the microglial age-DE genes (major-cell-type level).
sig_genes_GO_query(de_genes_age_major_signif[de_genes_age_major_signif.cell_type=="Micro"].gene)

In [None]:
# TODO: GO plots of DE genes for each cell type (supplement) -- not implemented here.
# This cell only stores sqrt(library-size-normalized counts) as a layer on adata_subset.
adata_subset.layers['sqrt_norm'] =  np.sqrt(sc.pp.normalize_total(adata_subset,inplace=False)["X"])

## Quantify number of DE genes per anatomical region by cell type

In [None]:
# NOTE: this is an alias, not a copy -- the next cell therefore overwrites adata_imputed.X too.
# Earlier DE cells also use adata_sqrt_norm, so they depend on this cell having been run.
adata_sqrt_norm = adata_imputed

In [None]:
# Use the sqrt-normalized layer as X (this also changes adata_imputed.X through the alias).
adata_sqrt_norm.X = adata_sqrt_norm.layers['sqrt_norm']

In [None]:
import diffxpy.api as de

In [None]:
# For each major cell type and each spatial cluster ("area"), count genes up/down with age
# (4wk vs 90wk t-test on the sqrt-normalized imputed data).
upreg_by_ct = {}
downreg_by_ct = {}
qval_thresh = 0.05        # q-value cutoff (linear scale)
thresh = np.log10(2)      # |log10 fold change| cutoff, i.e. 2-fold
min_detect_frac = 0.1     # gene must be detected in > 10% of the young or of the old cells
for ct in ["ExN", "InN", "MSN", "Olig", "Micro", "Astro", "Endo","Vlmc","Peri","OPC"]:
    adata_ct = adata_sqrt_norm[adata_sqrt_norm.obs.cell_type==ct]
    upreg_by_area = {}
    downreg_by_area = {}
    for i in adata_ct.obs.spatial_clust_annots.unique():
        curr_adata = adata_ct[adata_ct.obs.spatial_clust_annots==i]
        curr_adata = curr_adata[curr_adata.obs.age.isin(['4wk','90wk'])]
        yng = curr_adata[curr_adata.obs.age=='4wk']
        old = curr_adata[curr_adata.obs.age=='90wk']
        # Detection fractions as flat 1-D arrays. On a sparse X, .sum(0) is a 1 x n_genes
        # np.matrix and .flatten() keeps it 2-D, which breaks the gene mask below.
        # max(..., 1) avoids 0/0 when an age group is empty (its fractions are then all 0).
        frac_yng = np.asarray((yng.X > 0).sum(0)).ravel() / max(yng.shape[0], 1)
        frac_old = np.asarray((old.X > 0).sum(0)).ravel() / max(old.shape[0], 1)
        good_genes = (frac_yng > min_detect_frac) | (frac_old > min_detect_frac)
        curr_adata = curr_adata[:, good_genes]
        print(ct, i, curr_adata.shape)
        enough_cells = curr_adata.shape[0] > (0.05*adata_ct.shape[0])
        # The t-test needs both age groups and at least one gene.
        testable = yng.shape[0] > 0 and old.shape[0] > 0 and curr_adata.shape[1] > 0
        if enough_cells and testable:
            de_res = de.test.t_test(curr_adata, "age")
            log10_fc = de_res.log10_fold_change()
            # log10_qval_clean() is on a log10 scale, so compare with log10(qval_thresh).
            # Comparing with qval_thresh itself accepted essentially every gene (q < ~1.12).
            signif = de_res.log10_qval_clean() < np.log10(qval_thresh)
            upreg_by_area[i] = np.sum((log10_fc > thresh) & signif)
            downreg_by_area[i] = np.sum((log10_fc <= -thresh) & signif)
        else:
            upreg_by_area[i] = 0
            downreg_by_area[i] = 0
    upreg_by_ct[ct] = upreg_by_area
    downreg_by_ct[ct] = downreg_by_area

In [None]:
# Seaborn white style for the spatial DE dot plots below.
sns.set_style('white')

In [None]:
celltypes = ["ExN", "InN", "MSN", "Olig", "Micro", "Astro", "Endo","OPC"]
areas = ['Pia','L2/3','L5','L6','CC','Striatum','LatSept','Ventricle']
# Tick labels for `areas`. NOTE(review): 'LatSept' is displayed as 'Olf' (kept from the
# original figure) -- confirm this relabeling is intended.
area_labels = ['Pia','L2/3','L5','L6','CC','Striatum','Olf','Ventricle']
# Running range of nonzero counts over cell types (previously used before assignment).
cell_min = 1e6
cell_max = 0
f,ax = plt.subplots(figsize=(6,4.8))
for i, ct in enumerate(celltypes):
    # Genes upregulated with age in each area for this cell type (0 if the area was not tested).
    curr_upreg = np.array([upreg_by_ct[ct].get(area, 0) for area in areas])
    print(ct,curr_upreg)
    nonzero = curr_upreg[curr_upreg>0]
    if nonzero.size == 0:
        continue  # nothing to draw; also avoids 0/0 and min() of an empty array
    cell_min = min(cell_min, nonzero.min())
    cell_max = max(cell_max, nonzero.max())
    print(cell_max)
    # Dot area = 2 x count; dot color = count relative to this cell type's busiest area.
    curr_upreg_frac = curr_upreg/curr_upreg.max()
    vmin = curr_upreg_frac[curr_upreg_frac>0].min()
    if vmin == curr_upreg_frac.max():
        vmin = 0  # a single nonzero level would otherwise give an empty color range
    ax.scatter([i]*len(areas),np.arange(len(areas)),s=2*curr_upreg,c=curr_upreg_frac,cmap=plt.cm.Reds,vmin=vmin,vmax=curr_upreg_frac.max(),edgecolor='k',linewidth=1)

ax.set_xticks(np.arange(len(celltypes)))
ax.set_yticks(np.arange(-1,len(areas)))
ax.set_xticklabels(celltypes)
ax.set_yticklabels(['']+area_labels)
save_fig(f, "fig4_upreg_spatial",dtype='pdf')

In [None]:
# Same dot plot for genes DOWNregulated with age (blue); reuses `celltypes` from the cell above.
areas = ['Pia','L2/3','L5','L6','CC','Striatum','LatSept','Ventricle']
# Tick labels for `areas`. NOTE(review): 'LatSept' is displayed as 'Olf' -- confirm intended.
area_labels = ['Pia','L2/3','L5','L6','CC','Striatum','Olf','Ventricle']
cell_min = 1e6
cell_max = 0
f,ax = plt.subplots(figsize=(6,4.8))
for i, ct in enumerate(celltypes):
    curr_downreg = np.array([downreg_by_ct[ct].get(area, 0) for area in areas])
    print(curr_downreg)
    nonzero = curr_downreg[curr_downreg>0]
    if nonzero.size == 0:
        continue  # nothing to draw; also avoids 0/0 and min() of an empty array
    cell_min = min(cell_min, nonzero.min())
    cell_max = max(cell_max, nonzero.max())
    # Dot area = 2 x count; dot color = count relative to this cell type's busiest area.
    curr_downreg_frac = curr_downreg/curr_downreg.max()
    vmin = curr_downreg_frac[curr_downreg_frac>0].min()
    if vmin == curr_downreg_frac.max():
        vmin = 0  # a single nonzero level would otherwise give an empty color range
    ax.scatter([i]*len(areas),np.arange(len(areas)),s=2*curr_downreg,c=curr_downreg_frac,cmap=plt.cm.Blues,vmin=vmin,vmax=curr_downreg_frac.max(),edgecolor='k',linewidth=1)

ax.set_xticks(np.arange(len(celltypes)))
ax.set_yticks(np.arange(-1,len(areas)))
ax.set_xticklabels(celltypes)
ax.set_yticklabels(['']+area_labels)
save_fig(f, "fig4_downreg_spatial",dtype='pdf')

In [None]:
np.arange(20, 1200, 200)

In [None]:
# Size legend for the spatial DE dot plots: marker areas 0, 200, ..., 1000. The dot plots
# use s = 2 * count, so these correspond to 0-500 DE genes (the 0 marker is invisible).
f,ax = plt.subplots(figsize=(6,4.8))

ax.scatter(np.arange(6), np.zeros(6), s=np.arange(0, 1200, 200),color='k',linewidth=1,edgecolor='k')
ax.axis('off')
ax.set_xlim([-1,7])
save_fig(f, 'fig4_updownreg_scale',dtype='pdf')


In [None]:
len(np.arange(0, 1000, 200))

# Fig. 5: Ligand receptor analysis

In [None]:
# Re-import the project spatial helpers from the analysis repo (wildcard import; names such
# as compute_celltype_neighborhood presumably come from here).
sys.path.append("/home/user/src/tithonus/analysis/tithonus_analysis/")
from spatial_analysis import *


In [None]:
# load ligand receptor pairs
# load cellchatdb 
#adata_subset.var_names = unbinarize_list(adata_subset.var_names)
cellchat = pd.read_csv("gene_lists/cellchatdb_interactions.csv")
celltalk = pd.read_table("gene_lists/mouse_lr_pair.txt")
# Genes appearing in any ligand/receptor pair. NOTE: cellchat_genes is not used further in this section.
cellchat_genes = list(set(list(cellchat['receptor']) + list(cellchat['ligand'])))
celltalk_genes = list(set(list(celltalk['ligand_gene_symbol']) + list(celltalk['receptor_gene_symbol'])))
print(len(celltalk_genes))

# Keep CellTalkDB genes that are present in the imputed panel.
celltalk_genes = [i for i in celltalk_genes if i in adata_imputed.var_names]


In [None]:
# Imputed data restricted to ligand/receptor genes (an AnnData view).
adata_imputed_lr = adata_imputed[:, celltalk_genes]

In [None]:
adata_imputed_lr.layers['sqrt_norm']

In [None]:
# Ligand-receptor pairs where both partners are in the panel.
celltalk_filt = celltalk[np.logical_and(celltalk.ligand_gene_symbol.isin(adata_imputed_lr.var_names),
                        celltalk.receptor_gene_symbol.isin(adata_imputed_lr.var_names))] 

In [None]:
# add index column to celltalkdb
# Column position of each ligand/receptor gene in adata_imputed_lr. The first occurrence
# is kept, matching np.argwhere(...)[0][0], and the lookup is built once rather than
# scanning var_names per pair. `.assign` returns a new frame instead of writing into the
# filtered slice, which avoids SettingWithCopyWarning.
lr_var_pos = {}
for pos_idx, var_name in enumerate(adata_imputed_lr.var_names):
    lr_var_pos.setdefault(var_name, pos_idx)
celltalk_filt = celltalk_filt.assign(
    ligand_idx=[lr_var_pos[g] for g in celltalk_filt.ligand_gene_symbol],
    receptor_idx=[lr_var_pos[g] for g in celltalk_filt.receptor_gene_symbol],
)

In [None]:
# Cell-type neighborhood statistics (project helper) within radius 50 in obsm['spatial']
# units, with niter=500 (presumably label shuffles for the null).
nbors, nbor_zscore, nbor_pvals = compute_celltype_neighborhood(adata_imputed, 'cell_type', niter=500, radius=50)

In [None]:
# -log10 p-values; infinite entries (p == 0) are capped at 6.
log_nbor_pvals = -np.log10(nbor_pvals)
log_nbor_pvals[np.isinf(log_nbor_pvals)] = -np.log10(1e-6)

In [None]:
# Raw neighbor counts between cell types. Tick labels assume the helper orders cell types
# alphabetically -- TODO confirm.
f, ax = plt.subplots()
ax.imshow(nbors,cmap=plt.cm.viridis,vmin=0,vmax=100000)
ax.set_xticks(np.arange(len(adata_imputed.obs.cell_type.unique())))
ax.set_xticklabels(list(sorted(adata_imputed.obs.cell_type.unique())),rotation=90)
ax.set_yticks(np.arange(len(adata_imputed.obs.cell_type.unique())))
ax.set_yticklabels(list(sorted(adata_imputed.obs.cell_type.unique())))

In [None]:
# for LR analysis, what we want to know is which pairs of celltypes have a large fraction of cells within a radius that would actually allow communication
# we need to define what a "large fraction" is, in absolute terms -- not relative to other celltype interactions
# because oligodendrocytes, for example, could interact with many different celltypes
# what is the expected frequency if there was no interaction? There are N cells of one cell type, and K cells of another cell type, so there are N*K possible interactions
# can compute the frequency of 

In [None]:
def _compute_neighborhood_frac(pos, labels, celltypes, radius, min_count=100):
    """
    Neighbor counts between every ordered pair of cell types.

    Despite the name, the result is a raw count, not a fraction. Entry [i, j] is the sum of
    `count_nearest_neighbors(pos[labels == celltypes[i]], pos[labels == celltypes[j]],
    dist_thresh=radius)`. Rows whose cell type has `min_count` or fewer cells stay 0.

    Parameters
    ----------
    pos : per-cell coordinates, one row per cell (indexed with boolean masks).
    labels : per-cell cell-type labels aligned with `pos` (array-like or Series).
    celltypes : order of the cell types along both output axes.
    radius : distance threshold forwarded to `count_nearest_neighbors`.
    min_count : minimum number of cells (exclusive) for a row to be computed.

    Returns
    -------
    float array of shape (len(celltypes), len(celltypes)).
    """
    labels = np.asarray(labels)  # compare positionally, independent of any Series index
    neighbors = np.zeros((len(celltypes), len(celltypes)))
    for i, c1 in enumerate(celltypes):
        curr_X = pos[labels == c1]
        if curr_X.shape[0] <= min_count:
            continue
        for j, c2 in enumerate(celltypes):
            curr_Y = pos[labels == c2]
            neighbors[i, j] = np.sum(count_nearest_neighbors(curr_X, curr_Y, dist_thresh=radius))
    return neighbors

def _compute_neighbor_shuffled(pos, labels, celltypes, radius):
    """
    One draw of the label-permutation null for `_compute_neighborhood_frac`.

    Positions stay fixed and cell-type labels are permuted without replacement, so every
    cell type keeps its size. The previous `np.random.choice` sampled with replacement,
    which changes the type counts. The old code also indexed a string-indexed Series with
    integers (a positional fallback that pandas deprecates).
    Uses numpy's global RNG, so seed it for reproducible draws.
    """
    shuffled = np.asarray(labels)[np.random.permutation(len(labels))]
    return _compute_neighborhood_frac(pos, shuffled, celltypes, radius)

def compute_celltype_enrichment(A, celltype_key, celltypes=None, radius=150, niter=10):
    """
    Observed neighbor counts between every ordered pair of cell types in `A`.

    No enrichment statistic is computed: this returns the raw count matrix from
    `_compute_neighborhood_frac`. The label-shuffling null (z-scores / p-values) that used to
    sit here was entirely commented out and has been removed. `niter` is kept for backward
    compatibility but is unused.

    Parameters
    ----------
    A : AnnData-like object with coordinates in `obsm['spatial']` and labels in `obs[celltype_key]`.
    celltype_key : obs column holding cell-type labels.
    celltypes : order of cell types along both axes; defaults to the sorted unique labels.
    radius : neighborhood radius forwarded to `_compute_neighborhood_frac`.
    niter : unused.

    Returns
    -------
    (len(celltypes), len(celltypes)) array of neighbor counts.
    """
    if celltypes is None:
        celltypes = list(sorted(A.obs[celltype_key].unique()))
    pos = A.obsm['spatial']
    labels = A.obs[celltype_key]
    return _compute_neighborhood_frac(pos, labels, celltypes, radius)

def get_interacting_celltypes(A, neighbors, thresh=0.5, celltype_key='cell_type'):
    """
    Names of the cell-type pairs whose neighbor score exceeds `thresh`.

    Rows/columns of `neighbors` are assumed to follow the sorted unique labels in
    `A.obs[celltype_key]`, which is the default order of `compute_celltype_enrichment`.

    Parameters
    ----------
    A : AnnData-like object with `obs[celltype_key]`.
    neighbors : square score matrix (e.g. z-scored neighbor counts).
    thresh : strict lower bound on the score.
    celltype_key : obs column holding cell-type labels (default 'cell_type').

    Returns
    -------
    list of "<row type>_<column type>" strings in row-major order.
    """
    celltypes = list(sorted(A.obs[celltype_key].unique()))
    rows, cols = np.nonzero(np.asarray(neighbors) > thresh)
    return [celltypes[i] + "_" + celltypes[j] for i, j in zip(rows, cols)]

In [None]:
# Neighbor counts between cell types within radius 20, over ALL cells.
# NOTE(review): the trailing comment suggests a CC-only subset was intended; nbors_cc is
# overwritten further below by a CC-only computation.
nbors_cc = compute_celltype_enrichment(adata_imputed, 'cell_type', niter=500, radius=20) #[adata_imputed.obs.spatial_clust_annots=="CC"],

In [None]:
# Standardize the count matrix (zscore_mat: project helper).
nbors_zscore_cc = zscore_mat(nbors_cc)

In [None]:
# Cell-type pairs whose standardized count exceeds 0.1.
get_interacting_celltypes(adata_imputed, nbors_zscore_cc,0.1)

In [None]:
f, ax = plt.subplots()
ax.imshow(nbors_zscore_cc,cmap=plt.cm.bwr,vmin=-1,vmax=1)
ax.set_xticks(np.arange(len(adata_imputed.obs.cell_type.unique())))
ax.set_xticklabels(list(sorted(adata_imputed.obs.cell_type.unique())),rotation=90);
ax.set_yticks(np.arange(len(adata_imputed.obs.cell_type.unique())))
ax.set_yticklabels(list(sorted(adata_imputed.obs.cell_type.unique())));

In [None]:
# Neighborhood statistics restricted to cells in the "CC" spatial cluster.
nbors_cc, nbor_zscore_cc, nbor_pvals_cc = compute_celltype_neighborhood(adata_imputed[adata_imputed.obs.spatial_clust_annots=="CC"], 'cell_type', niter=500, radius=20)

In [None]:
# Z-scores for the CC-only neighborhood; labels assume alphabetical cell-type order -- TODO confirm.
f, ax = plt.subplots()
ax.imshow(nbor_zscore_cc,cmap=plt.cm.bwr,vmin=-5,vmax=5)
ax.set_xticks(np.arange(len(adata_imputed.obs.cell_type.unique())))
ax.set_xticklabels(list(sorted(adata_imputed.obs.cell_type.unique())),rotation=90)
ax.set_yticks(np.arange(len(adata_imputed.obs.cell_type.unique())))
ax.set_yticklabels(list(sorted(adata_imputed.obs.cell_type.unique())))

# Fig 6: White matter/oligodendrocyte analysis

## Score cells for activated astro/micro

In [None]:
# Activation signatures, each scored on adata_imputed.X (use_raw=False) into an obs column
# of the same name, in this order: microglia, astrocytes, endothelium, oligodendrocytes.
activate_endo = ["B2m", "Nfkbia", "Serinc3","Xdh", "Gfap", "Tap1"]
activation_signatures = {
    'activate_micro': ['B2m','Trem2', 'Ccl2', 'Apoe',  'Axl', 'Itgax', 'Cd9','C1qa','C1qc','Lyz2','Ctss'],
    'activate_astro': ['C4b', 'C3', 'Serpina3n', 'Cxcl10', 'Gfap', 'Vim', 'Il18','Hif3a'],
    'activate_endo': activate_endo,
    'activate_olig': [ "C4b", "Il18", "Il33"],
}
for sig_name, sig_gene_list in activation_signatures.items():
    sc.tl.score_genes(adata_imputed, gene_list=sig_gene_list, score_name=sig_name, use_raw=False)

In [None]:
# Center the astrocyte activation score (for all cells) on the mean of 4wk astrocytes.
adata_astro = adata_imputed[adata_imputed.obs.cell_type=="Astro"]
young_astro_mean = np.mean(adata_astro.obs.loc[adata_astro.obs.age == '4wk', 'activate_astro'])
adata_imputed.obs['activate_astro'] = adata_imputed.obs['activate_astro'] - young_astro_mean



In [None]:
# Center the microglial activation score (for all cells) on the mean of 4wk microglia.
adata_micro = adata_imputed[adata_imputed.obs.cell_type=="Micro"]
young_micro_mean = np.mean(adata_micro.obs.loc[adata_micro.obs.age == '4wk', 'activate_micro'])
adata_imputed.obs['activate_micro'] = adata_imputed.obs['activate_micro'] - young_micro_mean

In [None]:
# Center the oligodendrocyte activation score (for all cells) on the mean of 4wk oligodendrocytes.
adata_olig = adata_imputed[adata_imputed.obs.cell_type=="Olig"]
young_olig_mean = np.mean(adata_olig.obs.loc[adata_olig.obs.age == '4wk', 'activate_olig'])
adata_imputed.obs['activate_olig'] = adata_imputed.obs['activate_olig'] - young_olig_mean

## Analyze microglial activation states from the literature

In [None]:
# Microglial activation gene table (a Neuron 2021 study, per the file name).
# NOTE: not used further in this section.
microglial_activation_genes = pd.read_excel('safraiyan_neuron_2021_microglial_activation_genes.xlsx')

In [None]:
# Microglia only, as an independent copy. The following cells add obs columns; on an AnnData
# view that triggers an implicit-modification warning and a silent copy anyway.
adata_micro = adata_imputed[adata_imputed.obs.cell_type=="Micro"].copy()
# Location label: the "CC" spatial cluster vs everything else.
adata_micro.obs['spatial_location'] = [i if i == "CC" else "Other" for i in adata_micro.obs.spatial_clust_annots ]

In [None]:
# UMAP of microglia colored by age.
f, ax = plt.subplots(figsize=(5,5))

sc.pl.umap(adata_micro, color='age',s=20,ax=ax,palette=age_colors)
save_fig(f, "figS7_micro_age",dtype="png")

In [None]:
# UMAP colored by location: "CC" spatial cluster (black) vs other (light grey).
f, ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_micro, color='spatial_location',s=20,palette=sns.color_palette([ 'black','lightgrey']),ax=ax)
save_fig(f, "figS7_micro_spatial",dtype="png")

In [None]:
# UMAP colored by the (4wk-centered) microglial activation score.
f, ax = plt.subplots(figsize=(5,5))
sc.pl.umap(adata_micro, color='activate_micro',s=20,cmap=plt.cm.Reds,ax=ax)
save_fig(f, "figS7_micro_activation",dtype="png")

In [None]:
#adata_micro = adata_sqrt_norm[adata_sqrt_norm.obs.cell_type=="Micro"]

In [None]:
# Identify activated microglia and split them by location.
# A cell is "activated" when activate_micro exceeds one (population, ddof=0) standard
# deviation of the microglial scores. The scores were centered on the 4wk microglial mean
# earlier in this notebook. Activated cells are labelled by whether their spatial cluster
# is "CC"; all remaining cells are "Other".
activate_thresh = np.std(adata_micro.obs.activate_micro)
is_activated = adata_micro.obs.activate_micro > activate_thresh
in_cc = adata_micro.obs.spatial_clust_annots == "CC"
adata_micro.obs["activate_subset"] = "Other"
adata_micro.obs.loc[is_activated & in_cc, "activate_subset"] = "Activated CC"
adata_micro.obs.loc[is_activated & ~in_cc, "activate_subset"] = "Activated Non-CC"


In [None]:
# Microglia UMAP colored by activation subset: Activated CC, Activated Non-CC, Other.
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_micro,
    color='activate_subset',
    palette=sns.color_palette(['blue', 'red', 'lightgrey']),
    s=20,
    ax=ax,
)
save_fig(f, "figS7_micro_spatial_activated", dtype="png")

In [None]:
# DE: all CC microglia vs all other microglia (binary spatial_location label).
sc.tl.rank_genes_groups(
    adata_micro,
    groupby='spatial_location',
    groups=['CC', 'Other'],
    reference='Other',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_micro, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_micro_all_cc_vs_all_non", dtype="png")

In [None]:
# DE among activated microglia only, with 'Activated CC' as the reference group.
sc.tl.rank_genes_groups(
    adata_micro,
    groupby='activate_subset',
    groups=['Activated CC', 'Activated Non-CC'],
    reference='Activated CC',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_micro, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_micro_activated_other_vs_activated_non", dtype="png")

In [None]:
# DE: activated CC microglia vs non-activated ("Other") microglia.
# activate_subset only takes the values 'Activated CC', 'Activated Non-CC' and 'Other'
# (see the labelling cell above). A bare 'CC' is not a level of this column and scanpy
# rejects it, so the test group must be 'Activated CC'.
sc.tl.rank_genes_groups(adata_micro, use_raw=False, reference='Other', groups=['Activated CC', 'Other'], groupby='activate_subset', method='t-test_overestim_var')
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_micro, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_micro_activated_cc_vs_other_non", dtype="png")

In [None]:
# DE among activated microglia only, with 'Activated Non-CC' as the reference group.
sc.tl.rank_genes_groups(
    adata_micro,
    groupby='activate_subset',
    groups=['Activated CC', 'Activated Non-CC'],
    reference='Activated Non-CC',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_micro, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_micro_activated_cc_vs_activated_other", dtype="png")

In [None]:
# Violin plots of selected marker genes (Trem2, Ctsb, Apoe, Fth1, Cd63) in activated CC vs
# activated non-CC microglia ("wam" presumably = white-matter-associated microglia).
# The original first line (`f,  = plt.su`) was a truncated, syntactically invalid statement.
# sc.pl.violin with several keys and a groupby draws one panel per key. A single `ax`
# argument would only receive the first key, so scanpy lays out the panels itself and we
# save the figure it created.
activated_micro = adata_micro[adata_micro.obs.activate_subset.isin(['Activated CC', 'Activated Non-CC'])]
sc.pl.violin(activated_micro, keys=['Trem2', 'Ctsb', 'Apoe', 'Fth1', 'Cd63'], groupby='activate_subset', show=False)
f = plt.gcf()
save_fig(f, "figS7_wam_genes", dtype="png")

## Separately analyze astros in pia vs CC

In [None]:
# Astrocytes only. .copy() makes an actual AnnData so adding obs columns here and below does
# not rely on AnnData's implicit view-to-copy conversion (which emits a warning).
adata_astro = adata_imputed[adata_imputed.obs.cell_type == "Astro"].copy()
# Three-way spatial label: corpus callosum ("CC"), pia ("Pia"), or "Other".
adata_astro.obs['spatial_location'] = [annot if annot in ("CC", "Pia") else "Other" for annot in adata_astro.obs.spatial_clust_annots]

In [None]:
# Astrocyte UMAP colored by age.
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_astro,
    color='age',
    palette=age_colors,
    s=20,
    ax=ax,
)
save_fig(f, "figS7_astro_age", dtype="png")

In [None]:
# Astrocyte UMAP colored by the full spatial-cluster annotation (displayed only, not saved).
spatial_annot_palette = sns.color_palette(['gray', 'orange', 'chocolate', 'brown', 'steelblue', 'gold', 'purple', 'darkkhaki'])
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_astro,
    color='spatial_clust_annots',
    palette=spatial_annot_palette,
    s=20,
    ax=ax,
)

In [None]:
# Astrocyte UMAP colored by CC / Other / Pia location.
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_astro,
    color='spatial_location',
    palette=sns.color_palette(['black', 'lightgray', 'gold']),
    s=20,
    ax=ax,
)
save_fig(f, "figS7_astro_spatial_loc", dtype="png")

In [None]:
# Astrocyte UMAP colored by the centered astrocyte activation score.
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_astro,
    color='activate_astro',
    cmap=plt.cm.Reds,
    s=20,
    ax=ax,
)
save_fig(f, "figS7_astro_activation", dtype="png")

In [None]:
# Identify activated astrocytes located in CC or pia.
# "Activated" means activate_astro exceeds one (population, ddof=0) standard deviation of
# the astrocyte scores. Activated cells in the CC / Pia spatial clusters are labelled with
# that region name; every other cell is "Other".
activate_thresh = np.std(adata_astro.obs.activate_astro)
astro_activated = adata_astro.obs.activate_astro > activate_thresh
adata_astro.obs["activate_subset"] = "Other"
for region in ("CC", "Pia"):
    in_region = adata_astro.obs.spatial_clust_annots == region
    adata_astro.obs.loc[astro_activated & in_region, "activate_subset"] = region


In [None]:
# Astrocyte UMAP colored by activation subset (activated CC / Other / activated Pia).
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_astro,
    color='activate_subset',
    palette=sns.color_palette(['black', 'lightgray', 'gold']),
    s=20,
    ax=ax,
)
save_fig(f, "figS7_astro_spatial_loc_activated", dtype="png")

In [None]:
# Astrocyte UMAP colored by CC / Other / Pia location.
# NOTE(review): identical to the earlier spatial_location cell and overwrites the same file.
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_astro,
    color='spatial_location',
    palette=sns.color_palette(['black', 'lightgray', 'gold']),
    s=20,
    ax=ax,
)
save_fig(f, "figS7_astro_spatial_loc", dtype="png")

In [None]:
# DE across astrocyte clusters (clust_annot) against the rest, keeping the top 200 genes per
# cluster, then filter with min_fold_change=1.
# NOTE(review): in scanpy the filtered table is stored under a separate uns key
# ('rank_genes_groups_filtered' by default); confirm which key downstream code reads.
sc.tl.rank_genes_groups(
    adata_astro,
    groupby='clust_annot',
    use_raw=False,
    n_genes=200,
)
sc.tl.filter_rank_genes_groups(adata_astro, min_fold_change=1)


In [None]:
# Astrocytes labelled as activated in CC or Pia only. .copy() so the rank_genes_groups
# calls below write into an actual AnnData instead of implicitly converting a view.
adata_astro_subset = adata_astro[adata_astro.obs.activate_subset != "Other"].copy()

In [None]:
# DE between activated CC and activated pia astrocytes, pia as reference (not plotted here).
sc.tl.rank_genes_groups(
    adata_astro_subset,
    groupby='activate_subset',
    groups=['Pia', 'CC'],
    reference='Pia',
    method='t-test_overestim_var',
    use_raw=False,
)
#sc.tl.filter_rank_genes_groups(adata_astro_subset,use_raw=False,min_in_group_fraction=0,max_out_group_fraction=1)


In [None]:
# Within CC astrocytes only: activated ("CC") vs non-activated ("Other") DE.
temp = adata_astro[adata_astro.obs.spatial_clust_annots == "CC"]
sc.tl.rank_genes_groups(
    temp,
    groupby='activate_subset',
    groups=['CC'],
    reference='Other',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(temp, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_activated_cc_vs_non", dtype="png")

In [None]:
# Within pia astrocytes only: activated ("Pia") vs non-activated ("Other") DE.
temp = adata_astro[adata_astro.obs.spatial_clust_annots == "Pia"]
sc.tl.rank_genes_groups(
    temp,
    groupby='activate_subset',
    groups=['Pia'],
    reference='Other',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(temp, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_activated_pia_vs_non", dtype="png")

In [None]:
# DE: activated pia vs activated CC astrocytes (CC is the reference group).
sc.tl.rank_genes_groups(
    adata_astro_subset,
    groupby='activate_subset',
    groups=['Pia', 'CC'],
    reference='CC',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_astro_subset, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_pia_vs_cc", dtype="png")

In [None]:
# DE: activated CC vs activated pia astrocytes (Pia is the reference group).
sc.tl.rank_genes_groups(
    adata_astro_subset,
    groupby='activate_subset',
    groups=['Pia', 'CC'],
    reference='Pia',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_astro_subset, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_cc_vs_pia", dtype="png")

In [None]:
# Spatial plotting helper bound to the imputed dataset. SpatialPlotter is not defined in this
# notebook; presumably it comes from the `spatial_analysis` / `plotting` star imports.
sp = SpatialPlotter(adata_imputed)

In [None]:
# Thbs4 expression in astrocytes by age, color range fixed to [0, 40] (trailing ';' hides the repr).
sp.plot_gene_by_age('Thbs4','Astro',vmin=0,vmax=40);

In [None]:
# Adcy8 expression in astrocytes by age, color range fixed to [-1, 1].
sp.plot_gene_by_age('Adcy8','Astro',vmin=-1,vmax=1);

In [None]:
# Nrp2 expression in astrocytes by age, color range fixed to [-1, 5].
sp.plot_gene_by_age('Nrp2','Astro',vmin=-1,vmax=5);

In [None]:
# DE: activated pia astrocytes vs all non-activated ("Other") astrocytes.
sc.tl.rank_genes_groups(
    adata_astro,
    groupby='activate_subset',
    groups=['Pia'],
    reference='Other',
    method='t-test_overestim_var',
    use_raw=False,
)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_astro, ax=ax, n_genes=20, key='rank_genes_groups')
save_fig(f, "figS7_pia_vs_other", dtype="png")

In [None]:
# DE: activated CC astrocytes vs all non-activated ("Other") astrocytes, then fold-change filter.
# NOTE(review): scanpy stores the filtered table under 'rank_genes_groups_filtered' by default,
# while the violin below reads its default key, so the min_fold_change=2 filter may not affect
# this plot. Confirm intent.
sc.tl.rank_genes_groups(
    adata_astro,
    groupby='activate_subset',
    groups=['CC'],
    reference='Other',
    method='t-test_overestim_var',
    use_raw=False,
)
sc.tl.filter_rank_genes_groups(adata_astro, min_fold_change=2)
f, ax = plt.subplots(figsize=(6, 3))
sc.pl.rank_genes_groups_violin(adata_astro, ax=ax)
save_fig(f, "figS7_cc_vs_other", dtype="png")

In [None]:
# Re-draw the latest adata_astro rank_genes_groups violins on a wider canvas (not saved).
f, ax = plt.subplots(figsize=(10, 5))
sc.pl.rank_genes_groups_violin(
    adata_astro,
    ax=ax,
)

In [None]:
# Astrocyte UMAP panels for a few genes of interest.
astro_genes_of_interest = ['Il18', 'Il33', 'Trpm3', 'Slc1a3']
sc.pl.umap(adata_astro, color=astro_genes_of_interest, cmap=plt.cm.Reds)

In [None]:
# Astrocyte UMAP colored by activation subset (default palette, not saved).
f, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(
    adata_astro,
    color="activate_subset",
    ax=ax,
)

In [None]:
# Per-astrocyte scatter of normalized imputed pseudo-age vs the centered activation score.
# The axes are labelled so the figure can be read on its own.
temp = adata_imputed[adata_imputed.obs.cell_type == 'Astro']
f, ax = plt.subplots()
ax.scatter(temp.obs.pseudoage_imputed_norm, temp.obs.activate_astro, s=1)
ax.set_xlabel('pseudoage_imputed_norm')
ax.set_ylabel('activate_astro');

In [None]:
# Violin of the centered oligodendrocyte activation score across Olig-1..3 (not saved).
olig_subtypes = ['Olig-1', 'Olig-2', 'Olig-3']
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(
    adata_imputed,
    keys=["activate_olig"],
    groupby='clust_annot',
    order=olig_subtypes,
    ax=ax,
)
sns.despine(ax=ax)
ax.set_rasterized(True)


#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_violin.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Violin of the centered astrocyte activation score for the Astro-1 / Astro-2 subtypes.
astro_subtypes = ['Astro-1', 'Astro-2']
astro_subtype_cells = adata_imputed[adata_imputed.obs.clust_annot.isin(astro_subtypes)]
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(
    astro_subtype_cells,
    keys=["activate_astro"],
    groupby='clust_annot',
    order=astro_subtypes,
    ax=ax,
)
ax.set_rasterized(True)
sns.despine(ax=ax)
save_fig(f, "fig5_activate_astro_subtype", dtype="pdf")

#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_astro_violin.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Violin of the centered microglial activation score for the Micro-1..3 subtypes.
micro_subtypes = ['Micro-1', 'Micro-2', 'Micro-3']
micro_subtype_cells = adata_imputed[adata_imputed.obs.clust_annot.isin(micro_subtypes)]
f, ax = plt.subplots(figsize=(5, 3))
sc.pl.violin(
    micro_subtype_cells,
    keys=["activate_micro"],
    groupby='clust_annot',
    order=micro_subtypes,
    ax=ax,
)
sns.despine(ax=ax)
ax.set_rasterized(True)
save_fig(f, "fig5_activate_micro_subtype", dtype="pdf")

#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_micro_violin.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Display order of spatial regions on the x-axis of the per-region comparison plots.
spatial_order = ['Pia','L2/3','L5','L6','CC','LatSept','Striatum','Ventricle']

def plot_age_obs_comparison(data, x, y, cell_type, figsize=(5,3), show_pvals=False, order=None, clust_key='cell_type', age_pal=sns.color_palette(['cornflowerblue','thistle','lightcoral'])):
    """Box + strip plot of an obs column for one cell type, split by a category and by age.

    Parameters
    ----------
    data : AnnData
        Object whose ``.obs`` contains ``x``, ``y``, ``clust_key`` and an ``age`` column.
    x : str
        obs column used as the categorical x-axis (e.g. spatial region).
    y : str
        obs column holding the numeric values plotted on the y-axis.
    cell_type : str
        Value of ``data.obs[clust_key]`` selecting the cells to plot.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``. Previously ignored in favour of a
        hard-coded (5, 3), which is kept as the default.
    show_pvals : bool
        If True, annotate p-values from ``calc_pvals_for_grouping`` via ``plot_pvals``.
    order : list of str, optional
        x-axis category order; defaults to the sorted unique values of ``x`` in the subset.
    clust_key : str
        obs column holding cell-type labels.
    age_pal : seaborn palette
        Colors for the ``age`` hue levels.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If no cells have ``data.obs[clust_key] == cell_type`` (the y-limit quantiles are
        undefined for an empty selection).
    """
    curr_df = data[data.obs[clust_key] == cell_type].obs
    if curr_df.empty:
        raise ValueError(f"No cells with {clust_key} == {cell_type!r}")
    if order is None:
        order = sorted(curr_df[x].unique())

    f, ax = plt.subplots(figsize=figsize)
    # Boxes without fliers; the individual cells are drawn by the strip plot on top.
    sns.boxplot(x=x, y=y, data=curr_df, hue='age', fliersize=0, linewidth=1, palette=age_pal, ax=ax, order=order)
    sns.stripplot(data=curr_df, x=x, y=y, hue="age", ax=ax, jitter=0.15, size=0.5, dodge=True, color='k', order=order, rasterized=True)
    # Clip the y-range to the 0.1%-99.9% quantiles so extreme cells don't squash the boxes.
    ax.set_ylim([np.quantile(curr_df[y], 0.001), np.quantile(curr_df[y], 0.999)])
    sns.despine(ax=ax)
    # Hide the (duplicated box + strip) age legend.
    ax.legend([], [], frameon=False)

    if show_pvals:
        pvals = calc_pvals_for_grouping(x, y, curr_df, "age", order=order)
        plot_pvals(ax, pvals)
    return f


In [None]:
# Centered astrocyte activation score per spatial region, split by age.
f = plot_age_obs_comparison(
    adata_imputed,
    "spatial_clust_annots",
    "activate_astro",
    "Astro",
    show_pvals=False,
    order=spatial_order,
)
save_fig(f, "fig5_activate_astro_spatial_quant", dtype="pdf")
#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_spatial_merfish.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Centered microglial activation score per spatial region, split by age.
f = plot_age_obs_comparison(
    adata_imputed,
    "spatial_clust_annots",
    "activate_micro",
    "Micro",
    show_pvals=False,
    order=spatial_order,
)
save_fig(f, "fig5_activate_micro_spatial_quant", dtype="pdf")

#f.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig4_active_olig_spatial_merfish.pdf",bbox_inches='tight',dpi=300)

## Spatial plots of activation

In [None]:
# Second SpatialPlotter on the same data (the earlier `sp` wraps the same object); used by the
# spatial activation maps below.
pl = SpatialPlotter(adata_imputed)

In [None]:
# Spatial maps of the centered astrocyte activation score by age (show_background=True
# presumably draws the other cells in the background).
f = pl.plot_obs_by_age_for_celltype(
    "activate_astro",
    'Astro',
    s=5,
    figsize=(3, 15),
    cmap=plt.cm.turbo,
    show_background=True,
)
f.set_facecolor('white')
save_fig(f, "fig5_activate_astro_spatial", dtype="pdf")


In [None]:
# Spatial maps of the centered microglial activation score by age.
f = pl.plot_obs_by_age_for_celltype(
    "activate_micro",
    'Micro',
    s=5,
    figsize=(3, 15),
    cmap=plt.cm.turbo,
    show_background=True,
)
f.set_facecolor('white')
save_fig(f, "fig5_activate_micro_spatial", dtype="pdf")


## Cell-cell interactions 

In [None]:
# Distance-bin width passed to compute_binned_values in the distance-correlation cells below.
# Units are those of the distances returned by compute_celltype_obs_distance_correlation
# (presumably spatial coordinate units -- TODO confirm).
bin_size = 50


In [None]:
# 4wk: astrocyte activation score vs distance to Peri-1 / Peri-2 cells, binned by distance
# (mean line with a +/- 1 std band per pericyte subtype).
plt.figure(figsize=(3, 3))
celltypes = ["Peri-1", "Peri-2"]
for partner_type in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(adata_imputed[adata_imputed.obs.age == '4wk'], "Astro", partner_type, "activate_astro", celltype_key2='clust_annot')
    binned_mean, binned_std = compute_binned_values(dists, scores, bin_size=bin_size, max_d=100)
    # NOTE(review): x is bin index + bin_size (50, 51, ...), not index * bin_size; confirm
    # against compute_binned_values' bin edges.
    bin_x = np.arange(len(binned_mean)) + bin_size
    color = label_colors[partner_type]
    plt.plot(bin_x, binned_mean, color=color)
    plt.fill_between(bin_x, binned_mean - binned_std, binned_mean + binned_std, alpha=0.1, color=color)
plt.ylim([-0.2, 0.3])
sns.despine()
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_distance_peri_score_4wk.pdf",bbox_inches='tight',dpi=300)

In [None]:
# 90wk: astrocyte activation score vs distance to Peri-1 / Peri-2 cells, binned by distance
# (mean line with a +/- 1 std band per pericyte subtype).
# Fix: the last line previously read `sns.despine()b`, a syntax error.
plt.figure(figsize=(3, 3))
celltypes = ["Peri-1", "Peri-2"]
for partner_type in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(adata_imputed[adata_imputed.obs.age == '90wk'], "Astro", partner_type, "activate_astro", celltype_key2='clust_annot')
    binned_mean, binned_std = compute_binned_values(dists, scores, bin_size=bin_size, max_d=100)
    # NOTE(review): x is bin index + bin_size (50, 51, ...), not index * bin_size; confirm
    # against compute_binned_values' bin edges.
    x = np.arange(len(binned_mean)) + bin_size
    plt.plot(x, binned_mean, color=label_colors[partner_type])
    plt.fill_between(x, binned_mean - binned_std, binned_mean + binned_std, alpha=0.1, color=label_colors[partner_type])
plt.ylim([-0.2, 0.4])
sns.despine()
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_distance_peri_score_90wk.pdf",bbox_inches='tight',dpi=300)

In [None]:
# 4wk: microglial activation score vs distance to Endo / Vlmc / Olig cells, binned by distance.
f = plt.figure(figsize=(3, 3))
celltypes = ["Endo", "Vlmc", "Olig"]
for partner_type in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(adata_imputed[adata_imputed.obs.age == '4wk'], "Micro", partner_type, "activate_micro")
    binned_mean, binned_std = compute_binned_values(dists, scores, bin_size=bin_size, max_d=100)
    # NOTE(review): x is bin index + bin_size, not index * bin_size; confirm intent.
    bin_x = np.arange(len(binned_mean)) + bin_size
    color = celltype_colors[partner_type]
    plt.plot(bin_x, binned_mean, color=color)
    plt.fill_between(bin_x, binned_mean - binned_std, binned_mean + binned_std, alpha=0.1, color=color)
plt.ylim([-0.05, 0.12])
sns.despine()
save_fig(f, "fig5_distance_score_micro_4wk", dtype="pdf")

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_distance_micro_score_4wk.pdf",bbox_inches='tight',dpi=300)

In [None]:
# 90wk: microglial activation score vs distance to Endo / Vlmc / Olig cells, binned by distance.
# celltypes is set explicitly here; before, it silently relied on the value left over from
# the 4wk cell (hidden kernel state).
f = plt.figure(figsize=(3, 3))
celltypes = ["Endo", "Vlmc", "Olig"]
for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(adata_imputed[adata_imputed.obs.age == '90wk'], "Micro", i, "activate_micro")
    binned_mean, binned_std = compute_binned_values(dists, scores, bin_size=bin_size, max_d=100)
    # NOTE(review): x is bin index + bin_size, not index * bin_size; confirm intent.
    x = np.arange(len(binned_mean)) + bin_size
    plt.plot(x, binned_mean, color=celltype_colors[i])
    plt.fill_between(x, binned_mean - binned_std, binned_mean + binned_std, alpha=0.1, color=celltype_colors[i])
plt.ylim([-0.05, 0.12])
sns.despine()
save_fig(f, "fig5_distance_score_micro_90wk", dtype="pdf")
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_distance_micro_score_90wk.pdf",bbox_inches='tight',dpi=300)

In [None]:
# 4wk: astrocyte activation score vs distance to Endo / Vlmc / Olig cells, binned by distance.
# celltypes is set explicitly here; before, it silently relied on the value left over from
# an earlier cell (hidden kernel state).
f = plt.figure(figsize=(3, 3))
celltypes = ["Endo", "Vlmc", "Olig"]
for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(adata_imputed[adata_imputed.obs.age == '4wk'], "Astro", i, "activate_astro")
    binned_mean, binned_std = compute_binned_values(dists, scores, bin_size=bin_size, max_d=100)
    # NOTE(review): x is bin index + bin_size, not index * bin_size; confirm intent.
    x = np.arange(len(binned_mean)) + bin_size
    plt.plot(x, binned_mean, color=celltype_colors[i])
    plt.fill_between(x, binned_mean - binned_std, binned_mean + binned_std, alpha=0.1, color=celltype_colors[i])
sns.despine()
plt.ylim([-0.2, 0.5])
save_fig(f, "fig5_distance_score_astro_4wk", dtype="pdf")

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_distance_astro_score_4wk.pdf",bbox_inches='tight',dpi=300)

In [None]:
# 90wk: astrocyte activation score vs distance to Endo / Vlmc / Olig cells, binned by distance.
# celltypes is set explicitly here; before, it silently relied on the value left over from
# an earlier cell (hidden kernel state).
f = plt.figure(figsize=(3, 3))
celltypes = ["Endo", "Vlmc", "Olig"]
for i in celltypes:
    scores, dists = compute_celltype_obs_distance_correlation(adata_imputed[adata_imputed.obs.age == '90wk'], "Astro", i, "activate_astro")
    binned_mean, binned_std = compute_binned_values(dists, scores, bin_size=bin_size, max_d=100)
    # NOTE(review): x is bin index + bin_size, not index * bin_size; confirm intent.
    x = np.arange(len(binned_mean)) + bin_size
    plt.plot(x, binned_mean, color=celltype_colors[i])
    plt.fill_between(x, binned_mean - binned_std, binned_mean + binned_std, alpha=0.1, color=celltype_colors[i])
sns.despine()
plt.ylim([-0.2, 0.5])
save_fig(f, "fig5_distance_score_astro_90wk", dtype="pdf")

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_distance_astro_score_4wk.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Re-center activate_olig on the 4wk oligodendrocyte mean. This repeats the centering cell
# earlier in the notebook; once the score is already centered, the 4wk mean is ~0 and so is
# this shift.
adata_olig = adata_imputed[adata_imputed.obs.cell_type == "Olig"]
olig_4wk_mean = np.mean(adata_olig[adata_olig.obs.age == '4wk'].obs.activate_olig)
adata_imputed.obs['activate_olig'] = adata_imputed.obs['activate_olig'] - olig_4wk_mean

In [None]:
# Corpus callosum only: relate activate_olig of Olig cells to activate_micro of Micro cells.
# compute_celltype_obs_correlation is a project helper; it presumably pairs cells within
# radius=50 and returns matched score arrays (x, y) -- confirm.
x, y = compute_celltype_obs_correlation(adata_imputed[adata_imputed.obs.spatial_clust_annots == "CC"], "Olig", "Micro", "activate_olig", "activate_micro", radius=50)
f = plt.figure(figsize=(5, 5))
plt.title(f"Olig -> Micro (R={np.corrcoef(x,y)[0,1]})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
save_fig(f, "fig5_olig_to_micro_corr", dtype="pdf")
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_activation_corr_olig_micro.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Corpus callosum only: relate activate_olig of Olig cells to activate_astro of Astro cells.
# compute_celltype_obs_correlation is a project helper; it presumably pairs cells within
# radius=50 and returns matched score arrays (x, y) -- confirm.
x, y = compute_celltype_obs_correlation(adata_imputed[adata_imputed.obs.spatial_clust_annots == "CC"], "Olig", "Astro", "activate_olig", "activate_astro", radius=50)
f = plt.figure(figsize=(5, 5))
plt.title(f"Olig -> Astro (R={np.corrcoef(x,y)[0,1]})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
save_fig(f, "fig5_olig_to_astro_corr", dtype="pdf")

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_activation_corr_olig_astro.pdf",bbox_inches='tight',dpi=300)

In [None]:
# Corpus callosum only: relate activate_astro of Astro cells to activate_micro of Micro cells
# (displayed, not saved). compute_celltype_obs_correlation is a project helper; it presumably
# pairs cells within radius=50 and returns matched score arrays (x, y) -- confirm.
x, y = compute_celltype_obs_correlation(adata_imputed[adata_imputed.obs.spatial_clust_annots == "CC"], "Astro", "Micro", "activate_astro", "activate_micro", radius=50)
plt.figure(figsize=(5, 5))
plt.title(f"Astro -> Micro (R={np.corrcoef(x,y)[0,1]})")
sns.kdeplot(x=x, y=y, fill=True)
sns.despine()
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_activation_corr_olig_astro.pdf",bbox_inches='tight',dpi=300)