In [None]:
%load_ext autoreload
%autoreload 2
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import os
import anndata as ad
mpl.rcParams['figure.dpi'] = 150
plt.rcParams['pdf.fonttype'] = 42

import pandas as pd
import sys
from spatial_analysis import *
from plotting import *

In [None]:
adata = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_combined_harmony.h5ad")
adata_merfish_raw = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_umap_allages.h5ad")
adata = adata.raw.to_adata()

adata_merfish_scaled = adata[adata.obs.dtype=="merfish"]

adata_merfish = adata[adata.obs.dtype=="merfish"]
adata_merfish.X = adata_merfish_raw[adata_merfish.obs.index,:].X
del adata_merfish_raw


In [None]:
adata_10x_int = adata[adata.obs.dtype=="scrnaseq"]


In [None]:
adata10x = sc.read_h5ad("/faststorage/brain_aging/rna_analysis/adata_finalclusts_annot.h5ad")
adata10x = adata10x.raw.to_adata()
adata10x = adata10x[adata10x.obs.area=="PFC"]
adata10x = adata10x[adata_10x_int.obs.index]
adata10x.obs['cell_type'] = adata_10x_int.obs.cell_type
adata10x.obs['clust_annot'] = adata_10x_int.obs.clust_annot
#del adata

In [None]:
adata_merfish_lps = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_lps_ctrl_allages.h5ad")


# Show correspondence between MERFISH and 10X cell types

In [None]:
adata.obs['clust_10x'] = ''
adata.obs.loc[adata_10x_int.obs.index,'clust_10x'] = adata10x.obs.clust_label
adata_10x_int = adata[adata.obs.dtype=="scrnaseq"]


In [None]:
from sklearn.neural_network import MLPClassifier
# predict labels of 10x cell types based on integrated clustering
pred_10x_clf = MLPClassifier().fit(adata_10x_int.obsm['X_pca'], adata_10x_int.obs.clust_10x)

In [None]:
pred_10x_lbl = pred_10x_clf.predict(adata_merfish.obsm['X_pca'])

In [None]:
pred_merfish_clf = MLPClassifier().fit(adata_merfish.obsm['X_pca'], adata_merfish.obs.clust_annot)

In [None]:
pred_merfish_lbl = pred_merfish_clf.predict(adata_10x_int.obsm['X_pca'])

In [None]:
sorted(adata.obs.clust_10x.unique())

In [None]:
pred_labs_ex = [
 'FrEx1',
 'FrEx2',
 'FrEx3',
 'FrEx4',
 'FrEx5',
 'FrEx6',
 'FrEx7',
 'FrEx8',
 'FrEx9',
 'FrEx10',
 'FrEx11',
 'FrEx12',
 'FrEx13',
 'FrEx14',
 'FrEx15',
 'FrEx16',
 'FrEx17',
 'FrEx18']
pred_labs_in = [
 'FrIn10',
 'FrIn12',
 'FrIn13',
 'FrIn14',
 'FrIn15',
 'FrIn16',
 'FrIn17',
 'FrIn18',
 'FrIn7',
 'FrIn8',
 'FrIn9']
pred_labs_str = [
 'StD1M4',
 'StD1M5',
 'StD1M6',
 'StD1M7',
 'StD1M8',
 'StD2M2',
 'StD2M3',
 'StD2M4',
 'StD2M5']
pred_labs_nn = [
 'Astro1',
 'Astro2',
 'Astro3',
 'Astro4',
 'OPC',
 'Olig1',
 'Olig2',
 'Olig3',
 'Olig4',
 'Olig5',
 'Vlmc1',
 'Vlmc2',
 'Peri1',
 'Peri2',
 'Macro',
 'Micro1',
 'Micro2',
]
combined_pred_lab_names = pred_labs_ex + pred_labs_in + pred_labs_str + pred_labs_nn

In [None]:
correspondence_map = {
    'ExN' : 'FrEx',
    'InN' : 'FrIn',
    'MSN' : "StD",
    "Astro" : "Astro",
    "OPC" : "OPC",
    "Olig" : "Olig",
    "Vlmc" : "Vlmc",
    
}

In [None]:
exn_clust_labs = [
 'ExN-L2/3-1',
 'ExN-L2/3-2',
 'ExN-L5-1',
 'ExN-L5-2',
 'ExN-L5-3',
 'ExN-L6-1',
 'ExN-L6-2',
 'ExN-L6-3',
 'ExN-Olf',
]
inn_clust_labs = [
 'InN-Calb2-1',
 'InN-Calb2-2',
 'InN-Chat',
 'InN-Lamp5',
 'InN-Lhx6',
 'InN-Olf-1',
 'InN-Olf-2',
 'InN-Pvalb-1',
 'InN-Pvalb-2',
 'InN-Pvalb-3',
 'InN-Sst-1',
 'InN-Sst-2',
 'InN-Vip',
]
msn_clust_labs = [
 'MSN-D1-1',
 'MSN-D1-2',
 'MSN-D2',
]
nn_clust_labs = [
'Astro-1',
 'Astro-2',
 'Epen',

 'OPC',
 'Olig-1',
 'Olig-2',
 'Olig-3',
    'Endo-1',
 'Endo-2',
 'Endo-3',

 'Peri-1',
 'Peri-2',
 'Vlmc',
 'Micro-1',
 'Micro-2',
 'Micro-3',
 'Macro',
 'T cell',

 ]
clust_labs_names = exn_clust_labs + inn_clust_labs + msn_clust_labs + nn_clust_labs

In [None]:
from scipy.optimize import linear_sum_assignment
def make_pred_mat(clust_labs, pred_labs, clust_lab_names, pred_lab_names):
    """Build a row-normalised correspondence matrix between two labelings.

    Parameters
    ----------
    clust_labs : iterable
        Annotated label of every cell (one entry per cell).
    pred_labs : iterable
        Predicted label of every cell, in the same cell order as ``clust_labs``.
    clust_lab_names : sequence
        Row order: every value occurring in ``clust_labs`` must be listed.
    pred_lab_names : sequence
        Column universe: every value occurring in ``pred_labs`` must be listed.

    Returns
    -------
    pred_mat : ndarray
        Matrix whose rows follow ``clust_lab_names``; each non-empty row sums
        to 1. Columns are reordered (and, when there are more columns than
        rows, reduced) to the columns selected by the optimal assignment.
    max_idx : ndarray
        Indices into ``pred_lab_names`` of the columns kept, in output order.
    cols_to_keep : ndarray
        Positions (in the reordered matrix) of columns whose maximum exceeds 0.1.

    Raises
    ------
    ValueError
        If ``clust_labs`` and ``pred_labs`` have different lengths.
    KeyError
        If a label is missing from the corresponding ``*_names`` sequence.
    """
    n_annot = len(clust_lab_names)
    n_pred = len(pred_lab_names)
    print(n_annot, n_pred)
    pred_idx = {name: k for k, name in enumerate(pred_lab_names)}
    annot_idx = {name: k for k, name in enumerate(clust_lab_names)}

    clust_annot_labels = list(clust_labs)
    predicted_labels = list(pred_labs)
    if len(clust_annot_labels) != len(predicted_labels):
        raise ValueError(
            "clust_labs and pred_labs must have one entry per cell "
            "(got %d and %d)" % (len(clust_annot_labels), len(predicted_labels))
        )

    # Count co-occurrences of (annotated label, predicted label).
    pred_mat = np.zeros((n_annot, n_pred))
    for annot, pred in zip(clust_annot_labels, predicted_labels):
        pred_mat[annot_idx[annot], pred_idx[pred]] += 1

    # Normalise each row to fractions. Rows without any cell are left at zero:
    # dividing by a zero sum would produce NaN, which linear_sum_assignment
    # rejects as an invalid cost matrix.
    for row in pred_mat:
        total = row.sum()
        if total > 0:
            row /= total

    # Reorder columns so the best one-to-one matches fall on the diagonal
    # (negated because linear_sum_assignment minimises cost).
    _, max_idx = linear_sum_assignment(-pred_mat)
    pred_mat = pred_mat[:, max_idx]
    cols_to_keep = np.argwhere(np.max(pred_mat, 0) > 0.1).flatten()
    return pred_mat, max_idx, cols_to_keep

def plot_pred_mat(clust_labs, pred_labs, clust_color_map, pred_color_map, pred_lab_names=None, clust_lab_names=None):
    """Plot the label-correspondence heatmap with colour strips on both axes.

    Parameters
    ----------
    clust_labs, pred_labs : iterable
        Annotated and predicted label per cell (same cell order).
    clust_color_map, pred_color_map : mapping
        Label -> colour, indexed with the row labels and with the kept column
        labels respectively.
    pred_lab_names, clust_lab_names : sequence, optional
        Column universe / row order. Default to the sorted unique values of
        ``pred_labs`` / ``clust_labs``.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with four panels: row colour strip (left), heatmap (centre),
        column colour strip (bottom) and colourbar (right).
    """
    if pred_lab_names is None:
        pred_lab_names = np.unique(pred_labs)
    if clust_lab_names is None:
        clust_lab_names = np.unique(clust_labs)
    pred_mat, max_idx, _ = make_pred_mat(clust_labs, pred_labs, clust_lab_names, pred_lab_names)
    n_rows, n_cols = pred_mat.shape

    # A bare figure: plt.subplots() would add a full-size axes that stays
    # underneath the GridSpec panels created below.
    f = plt.figure(figsize=(4,4))
    gs = plt.GridSpec(nrows=2, ncols=3, width_ratios=[0.5,20,1], height_ratios=[20,0.5],wspace=0.01,hspace=0.01)
    # Use the parameter, not the notebook-level `clust_labs_names` global.
    print(pred_mat.shape, max_idx.max(), len(clust_lab_names),len(pred_lab_names))

    # Left colour strip: one cell per matrix ROW, coloured by its cluster label.
    ax = plt.subplot(gs[0,0])
    curr_cmap = mpl.colors.ListedColormap([clust_color_map[i] for i in clust_lab_names])
    # vmin/vmax centre integer k on the k-th colour of the listed colormap.
    ax.imshow(np.expand_dims(np.arange(n_rows),1),cmap=curr_cmap,aspect='auto',interpolation='none',
              vmin=-0.5,vmax=n_rows-0.5)
    sns.despine(ax=ax,left=True, bottom=True)
    ax.set_xticks([])
    ax.set_yticks(np.arange(len(clust_lab_names)))
    ax.set_yticklabels(np.array(clust_lab_names),fontsize=5)

    # Centre: the row-normalised correspondence matrix.
    ax = plt.subplot(gs[0,1])
    s = ax.imshow(pred_mat,aspect='auto',interpolation='none',cmap=plt.cm.gist_heat_r,vmin=0,vmax=1, rasterized=True)
    sns.despine(ax=ax, left=True, bottom=True)
    ax.set_yticks([])
    ax.set_xticks([])

    # Bottom colour strip: one cell per kept COLUMN, in assignment order.
    ax = plt.subplot(gs[1,1])
    xlabs = np.array(pred_lab_names)[max_idx]
    curr_cmap = mpl.colors.ListedColormap([pred_color_map[i] for i in xlabs])
    ax.imshow(np.expand_dims(np.arange(n_cols),1).T,cmap=curr_cmap,aspect='auto',interpolation='none',
              vmin=-0.5,vmax=n_cols-0.5)
    ax.set_yticks([])
    ax.set_xticks(np.arange(len(xlabs)))
    ax.set_xticklabels(xlabs,fontsize=5,rotation=90)
    sns.despine(ax=ax,bottom=True, left=True)

    # Right: colourbar for the heatmap, drawn into its own GridSpec cell.
    ax = plt.subplot(gs[0,2])
    f.colorbar(s, cax=ax)
    sns.despine(ax=ax)
    return f


In [None]:
# Load the 10x cluster colour table and turn it into {column name: column values}.
clust_colors10x = pd.read_csv("/home/user/src/tithonus/analysis/merfish/10x_clust_colors.csv")
# DataFrame.iteritems() was removed in pandas 2.0; items() yields the same
# (column name, column Series) pairs on old and new pandas alike.
clust_colors10x = {col_name: col.values for col_name, col in clust_colors10x.items()}


In [None]:
# Correspondence figure: MERFISH cluster annotations (rows) against the 10x labels
# predicted for the same MERFISH cells (columns), saved as a PDF.
# NOTE(review): `label_colors` is not assigned in any cell above this one; the first
# visible assignment is the later `generate_palettes(adata10x)` cell. Unless
# `from plotting import *` provides it, this cell fails on Restart & Run All -
# confirm, and move the palette cell above this one if so.
f = plot_pred_mat(adata_merfish.obs.clust_annot,pred_10x_lbl, label_colors, clust_colors10x, clust_lab_names=clust_labs_names, pred_lab_names=combined_pred_lab_names);
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS3_mer_to_10x.pdf",bbox_inches='tight')

In [None]:
f = plot_pred_mat(adata_10x_int.obs.clust_annot,pred_merfish_lbl , label_colors, label_colors, clust_lab_names=clust_labs_names);
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS3_10x_int.pdf",bbox_inches='tight')

# Show correlation between MERFISH and 10X

In [None]:
adata_merfish_orig = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/102821_merged_combined_merfish_allages.h5ad")
adata_merfish_orig_lps = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/112921_merged_combined_merfish_allages_lps.h5ad")


In [None]:
adata_merfish_orig = adata_merfish_orig.concatenate(adata_merfish_orig_lps)

In [None]:
shared_cells = [i for i in adata_merfish_orig.obs.index if i in adata_merfish_lps.obs.index]

In [None]:
adata_merfish_orig.var.library = adata_merfish_orig_lps.var.library

In [None]:
cell_type_vars = adata_merfish_orig.var_names[adata_merfish_orig.var.library=="cell_type"]
aging_vars = adata_merfish_orig.var_names[adata_merfish_orig.var.library=="aging"]

In [None]:
adata_10x_orig = ad.read_h5ad("/faststorage/brain_aging/rna_analysis/adata_combined_nodoublet.h5ad")

In [None]:
adata_10x_orig = adata_10x_orig[adata_10x_orig.obs.area=='PFC']

In [None]:
adata_10x_orig = adata_10x_orig[adata10x.obs.index]

In [None]:
adata_10x_orig.obs['cell_type'] = adata10x.obs.cell_type

In [None]:
adata_merfish_orig = adata_merfish_orig[adata_merfish_lps.obs.index]

In [None]:
adata_merfish_orig.obs['cell_type'] = adata_merfish_lps.obs.cell_type
adata_merfish_orig.obs['spatial_clust_annots'] = adata_merfish_lps.obs.spatial_clust_annots
adata_merfish_orig.obs['data_batch'] = adata_merfish_lps.obs.data_batch

In [None]:
adata_merfish_orig.obs['cond'] = adata_merfish_lps.obs.cond

In [None]:
# do some basic filtering
sc.pp.calculate_qc_metrics(adata_10x_orig, percent_top=None, log1p=False, inplace=True)
sc.pp.filter_cells(adata_10x_orig, min_genes=1000)
sc.pp.filter_cells(adata_10x_orig, max_counts=100000)
#sc.pp.filter_genes(adata_10x_orig, min_cells=3)
sc.pp.filter_cells(adata_10x_orig, min_counts=5000)

In [None]:
curr_age = ['4wk','24wk','90wk']

In [None]:
def compute_corr(A_10x, A_merfish,curr_age=['90wk'],ct="",cutoff=1e-4):
    """Scatter mean 10x expression against mean MERFISH expression per gene.

    Cells whose ``obs.age`` is in ``curr_age`` are averaged per gene, separately
    for the genes in the notebook-level ``cell_type_vars`` and ``aging_vars``.
    Genes whose mean exceeds ``cutoff`` in both datasets are kept, and the
    Pearson correlation of their log10 means is computed per gene set and for
    both sets pooled. The correlations are printed and a log-log scatter is
    drawn on a new figure; nothing is returned.

    ``A_10x`` slices must expose ``.X.toarray()``; ``ct`` is prepended to the
    figure title.
    """
    def _log_pearson(a, b):
        # Pearson r between the log10-transformed vectors.
        return np.corrcoef(np.log10(a), np.log10(b))[0,1]

    # Age-restricted views, shared by both gene libraries.
    tenx_sel = A_10x[A_10x.obs.age.isin(curr_age)]
    merfish_sel = A_merfish[A_merfish.obs.age.isin(curr_age)]

    tenx_ct = np.array(np.mean(tenx_sel[:,cell_type_vars].X.toarray(),0))
    merfish_ct = np.array(np.mean(merfish_sel[:,cell_type_vars].X,0))
    tenx_aging = np.array(np.mean(tenx_sel[:,aging_vars].X.toarray(),0))
    merfish_aging = np.array(np.mean(merfish_sel[:,aging_vars].X,0))

    # Keep only genes expressed above the cutoff in both datasets.
    ct_mask = np.logical_and(tenx_ct>cutoff, merfish_ct>cutoff)
    ct_corr = _log_pearson(tenx_ct[ct_mask], merfish_ct[ct_mask])
    age_mask = np.logical_and(tenx_aging>cutoff, merfish_aging>cutoff)
    age_corr = _log_pearson(tenx_aging[age_mask], merfish_aging[age_mask])
    print(ct_corr, age_corr)

    # Correlation over both libraries pooled together.
    pooled_10x = np.hstack((tenx_ct[ct_mask], tenx_aging[age_mask]))
    pooled_merfish = np.hstack((merfish_ct[ct_mask], merfish_aging[age_mask]))
    print(_log_pearson(pooled_10x, pooled_merfish))

    plt.figure(figsize=(5,5))
    plt.loglog(tenx_ct[ct_mask], merfish_ct[ct_mask], 'g.')
    plt.loglog(tenx_aging[age_mask], merfish_aging[age_mask], 'b.')
    plt.legend(["Cell type library (r=%0.02f)"%ct_corr, "Aging library (r=%0.02f)"%age_corr])
    plt.grid(True, which="both")
    plt.title(ct + " " + "snRNA-seq to MERFISH")
    plt.xlabel('snRNA-seq (UMI/cell)')
    plt.ylabel('MERFISH (counts/cell)')



In [None]:
def compute_corr_bulk(bulk,A_merfish,curr_age=['4wk','90wk'],ct="",cutoff=1e-4):
    """Scatter bulk RNA-seq expression against mean MERFISH expression per gene.

    Parameters
    ----------
    bulk : DataFrame
        Must have a ``genes`` column and an ``average`` column.
    A_merfish : AnnData-like
        Must have ``obs.age`` and ``var.library`` (values ``"cell_type"`` and
        ``"aging"`` select the two gene sets).
    curr_age : list
        Values of ``obs.age`` whose cells are averaged.
    ct : str
        Prefix for the figure title.
    cutoff : float
        Genes are kept only if both the bulk value and the MERFISH mean exceed it.

    Draws a log-log scatter on a new figure with the per-library Pearson
    correlation of log10 values in the legend; nothing is returned.
    """
    cell_type_vars = A_merfish.var_names[A_merfish.var.library=="cell_type"]
    aging_vars = A_merfish.var_names[A_merfish.var.library=="aging"]
    # Age-restricted view, computed once and shared by both gene libraries.
    merfish_cells = A_merfish[A_merfish.obs.age.isin(curr_age)]

    def _paired_expression(library_vars):
        # Bulk rows for this library; MERFISH columns are taken in the same
        # gene order so the two vectors line up element by element.
        bulk_lib = bulk[bulk.genes.isin(library_vars)]
        bulk_avg = np.array(bulk_lib.average.values)
        # ravel(): a sparse / np.matrix X yields a (1, n) mean, which would
        # otherwise broadcast into a 2-D mask that cannot index bulk_avg.
        merfish_avg = np.asarray(np.mean(merfish_cells[:,bulk_lib.genes].X,0)).ravel()
        keep = np.logical_and(bulk_avg>cutoff, merfish_avg>cutoff)
        return bulk_avg[keep], merfish_avg[keep]

    bulk_ct, merfish_ct = _paired_expression(cell_type_vars)
    bulk_aging, merfish_aging = _paired_expression(aging_vars)

    ct_corr = np.corrcoef(np.log10(bulk_ct), np.log10(merfish_ct))[0,1]
    age_corr = np.corrcoef(np.log10(bulk_aging), np.log10(merfish_aging))[0,1]

    plt.figure(figsize=(5,5))
    # Note: colours are the reverse of compute_corr (cell type blue, aging green).
    plt.loglog(bulk_ct, merfish_ct, 'b.')
    plt.loglog(bulk_aging, merfish_aging, 'g.')
    plt.legend(["Cell type library (r=%0.02f)"%ct_corr, "Aging library (r=%0.02f)"%age_corr])
    plt.grid(True, which="both")
    plt.title(ct + " " + "Bulk RNA-seq to MERFISH")
    plt.xlabel('Bulk RNA-seq (FPKM)')
    plt.ylabel('MERFISH (counts/cell)')



In [None]:
bad_genes = ['Prom1',
 'Parp8',
 'Rbpj',
 'Skap2',
 'Ago3',
 'Cntnap3',
 'Meis2',
 'Arnt2',
 'Hivep2',
 'Foxn3',
 'Parp2',
 'Zfp608',
 'Fbxl7',
 'Htr2c',
 'Klf7',
 'Timp2',
 'Zbtb16',
 'Egflam',
 'Ikzf2',
 'Cdh13',
 'Cd63',
 'Marcks',
 'Parp11',
 'Herc6',
 'Cdh9',
 'Tsc22d1',
 'Lef1',
 'Shisa6',
 'St8sia6',
 'Trp53',
 'Plch1',
 'Cp',
 '9630014M24Rik',
 'Elf2',
 'Tafa1',
 'Ntn1',
 'Rarb',
 'Zfp462',
 'Sirt5',
 'Mamdc2',
 'Bach2']

In [None]:
m1_bulk = pd.read_csv("/home/user/Downloads/M1_bulk_all.csv")

curr_adata = adata_merfish_orig
compute_corr_bulk(m1_bulk, curr_adata[curr_adata.obs.spatial_clust_annots.isin(['Pia','L2/3','L5','L6'])],cutoff=0,curr_age=['24wk'])
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/bulk_m1_to_merfish.pdf",bbox_inches='tight')

In [None]:
str_bulk = pd.read_excel("/home/user/Downloads/AllGenesExpressionTable_NAC.xlsx")
str_bulk["average"] = (str_bulk.rep1 + str_bulk.rep2)/2

curr_adata = adata_merfish_orig

compute_corr_bulk(str_bulk, curr_adata[curr_adata.obs.spatial_clust_annots.isin(['Striatum'])],cutoff=0,curr_age=['24wk'])
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/bulk_striatum_to_merfish.pdf",bbox_inches='tight')

In [None]:
m1_bulk = pd.read_csv("/home/user/Downloads/M1_bulk_all.csv")

curr_adata = adata_merfish_orig#[adata_merfish_orig.obs.batch==2]
compute_corr_bulk(m1_bulk, curr_adata[curr_adata.obs.spatial_clust_annots.isin(['Pia','L2/3','L5','L6'])],cutoff=0,curr_age=['4wk','90wk'])

In [None]:
adata_merfish_young_avg = adata_merfish_orig[adata_merfish_orig.obs.age=='4wk'].X.mean(0)
adata_merfish_old_avg = adata_merfish_orig[adata_merfish_orig.obs.age=='90wk'].X.mean(0)

In [None]:
plt.figure(figsize=(5,5))
plt.loglog(adata_merfish_young_avg,adata_merfish_old_avg,'k.')
mpl.pyplot.grid(True, which="both")
plt.title("Young MERFISH to old MERFISH, r=%0.2f"%np.corrcoef(np.log10(adata_merfish_young_avg), np.log10(adata_merfish_old_avg))[0,1])
plt.xlabel('Young MERFISH (counts/cell)')
plt.ylabel('Old MERFISH (counts/cell)')

plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS2_merfish_young_old_corr.pdf", bbox_inches='tight', dpi=200)


In [None]:
batches = []
ages = []
for i in ['4wk','24wk','90wk']:
    curr_batches = list(adata_merfish_orig[np.logical_and(adata_merfish_orig.obs.age == i,
                                                          adata_merfish_orig.obs.cond=="ctrl")].obs.data_batch.unique())
    batches.extend(curr_batches)
    for j in range(len(curr_batches)):
        ages.append(i)

In [None]:
adata_merfish_orig[adata_merfish_orig.obs.data_batch==batches[1]]

In [None]:
# young MERFISH rep 1 to rep 2
adata_merfish_batch1_avg = np.array(adata_merfish_orig[adata_merfish_orig.obs.data_batch=="10"].X.mean(0))
adata_merfish_batch2_avg = np.array(adata_merfish_orig[adata_merfish_orig.obs.data_batch=="11"].X.mean(0))

In [None]:
plt.figure(figsize=(5,5))
plt.loglog(adata_merfish_batch1_avg,adata_merfish_batch2_avg,'k.')
mpl.pyplot.grid(True, which="both")
plt.title("Batch 1 MERFISH to Batch 2 MERFISH, r=%0.2f"%np.corrcoef(np.log10(adata_merfish_batch1_avg), np.log10(adata_merfish_batch2_avg))[0,1])
plt.xlabel('MERFISH batch 1 (counts/cell)')
plt.ylabel('MERFISH batch 2 (counts/cell)')

plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS2_merfish_batch_corr_example.pdf", bbox_inches='tight', dpi=200)


In [None]:
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata10x)

# Run GLM on 10X for all genes

In [None]:
from de import *

In [None]:
adata10x.obs['log_umi'] = np.log10(adata10x.obs.total_counts)


In [None]:
adata10x

In [None]:
glm_de = {}#glm_de_partial.copy()
for i in adata10x.obs.cell_type.unique():
    if i not in glm_de:
        print(i)
        try:
            glm_de[i] = run_glm_de_age(adata10x[adata10x.obs.cell_type==i],lognorm=True, family='ols')
        except Exception as e:
            print(e)

In [None]:
glm_de_df = []
for v in glm_de.values():
    df = list(v.values())[0]
    df['cell_type'] = list(v.keys())[0]
    glm_de_df.append(df)

In [None]:
#pd.concat(glm_de_df).to_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_V2.csv")

In [None]:
glm_de_clust = {}
for i in adata10x.obs.clust_annot.unique():
    print(i)
    try:
        if i not in glm_de_clust:
            glm_de_clust[i] = run_glm_de_age(adata10x[adata10x.obs.clust_annot==i],lognorm=True, family='ols', grouping='clust_annot')
    except Exception as e:
        print(e)

In [None]:
glm_de_clust_df = []
for v in glm_de_clust.values():
    df = list(v.values())[0]
    df['cell_type'] = list(v.keys())[0]
    glm_de_clust_df.append(df)

In [None]:
pd.concat(glm_de_clust_df).to_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_clust_age_V2.csv")

# Run GLM on MERFISH and 10X just for MERFISH genes

In [None]:
glm_de_clust_df = []
for v in glm_de_clust.values():
    df = list(v.values())[0]
    df['cell_type'] = list(v.keys())[0]
    glm_de_clust_df.append(df)
glm_de_clust_df = pd.concat(glm_de_clust_df)

In [None]:
from de import *

In [None]:
#sc.pp.scale(adata_merfish,max_value=10)

In [None]:
adata_merfish.obs['log_umi'] = np.log(adata_merfish.obs.total_counts)

In [None]:
adata10x_subset = adata10x[:,adata_merfish.var_names]

In [None]:
ttest_de_celltype_10x = {}
for i in adata10x_subset.obs.cell_type.unique():
    print(i)
    try:
        ttest_de_celltype_10x[i] = run_glm_de_age_merfish(adata10x_subset[adata10x_subset.obs.cell_type==i],family='ols',grouping='cell_type')
    except Exception as e:
        print(e)

In [None]:
ttest_de_celltype_merfish = {}
for i in adata_merfish.obs.cell_type.unique():
    print(i)
    try:
        #if i not in ttest_de_celltype_merfish:
        ttest_de_celltype_merfish[i] = run_glm_de_age_merfish(adata_merfish[adata_merfish.obs.cell_type==i],family='ols',grouping='cell_type')
    except Exception as e:
        print(e)

In [None]:
import pandas as pd
from statsmodels.stats.multitest import multipletests
def collapse_de_results(X):
    """Stack per-group differential-expression tables into one DataFrame.

    Parameters
    ----------
    X : dict
        Maps a group name to a dict of ``{label: DataFrame}``. Only the first
        entry of each inner dict is used; its key is written to a ``cell_type``
        column. Each DataFrame must have a ``pval`` column.

    Returns
    -------
    DataFrame
        Concatenation of all tables, with missing p-values set to 1 and
        ``qval`` recomputed per table by Benjamini-Yekutieli correction.
        An empty DataFrame is returned when there is nothing to stack.

    The input tables are copied, so the caller's DataFrames are not modified.
    """
    collapsed = []
    for v in X.values():
        if not v:
            continue
        label, table = next(iter(v.items()))
        # Work on a copy so the caller's results are not modified in place.
        df = table.copy()
        df['cell_type'] = label
        # Column assignment instead of `df.pval[mask] = 1`: chained assignment
        # may write to a temporary and does nothing under copy-on-write.
        df['pval'] = df['pval'].astype(float).fillna(1)
        # Bracket assignment: `df.qval = ...` only sets a Python attribute when
        # the column does not exist yet, and that attribute is lost by concat.
        df['qval'] = multipletests(df['pval'], method='fdr_by')[1]
        collapsed.append(df)
    if not collapsed:
        # pd.concat raises on an empty list.
        return pd.DataFrame()
    return pd.concat(collapsed)
#ttest_de_clust_df = pd.concat(ttest_de_clust_df)
#ttest_de_clust_df.pval = [float(i) for i in ttest_de_clust_df.pval]
#ttest_de_clust_df.qval = [float(i) for i in ttest_de_clust_df.qval]


In [None]:
ttest_de_celltype_10x_df = collapse_de_results(ttest_de_celltype_10x)
ttest_de_celltype_merfish_df = collapse_de_results(ttest_de_celltype_merfish)

In [None]:
#ttest_de_celltype_10x_df.to_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_10x_V2.csv")
#ttest_de_celltype_merfish_df.to_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_merfish_V2.csv")

In [None]:


#ttest_de_clust_df = pd.concat(ttest_de_clust_df)
#ttest_de_clust_df.pval = [float(i) for i in ttest_de_clust_df.pval]
#ttest_de_clust_df.qval = [float(i) for i in ttest_de_clust_df.qval]

#ttest_de_clust_df = ttest_de_clust_df[~np.isnan(ttest_de_clust_df.qval)]
#ttest_de_clust_df.to_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_ttest_de_minor_age.csv")
ttest_de_clust_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_merfish_V2.csv")
ttest_de_celltype_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age_10x_V2.csv")

In [None]:
ttest_de_clust_merfish = {}
for i in adata_merfish.obs.clust_annot.unique():
    print(i)
    try:
        if i not in ttest_de_clust_merfish:
            ttest_de_clust_merfish[i] = run_glm_de_age(adata_merfish[adata_merfish.obs.clust_annot==i],lognorm=True, family='ols',grouping='clust_annot')
    except Exception as e:
        print(e)

In [None]:

ttest_de_clust_merfish_df = []
for v in ttest_de_clust_merfish.values():
    df = list(v.values())[0]
    df['cell_type'] = list(v.keys())[0]
    ttest_de_clust_merfish_df.append(df)
ttest_de_clust_merfish_df = pd.concat(ttest_de_clust_merfish_df)
ttest_de_clust_merfish_df.pval = [float(i) for i in ttest_de_clust_merfish_df.pval]
ttest_de_clust_merfish_df.qval = [float(i) for i in ttest_de_clust_merfish_df.qval]
ttest_de_clust_merfish_df = ttest_de_clust_merfish_df[~np.isnan(ttest_de_clust_merfish_df.qval)]

In [None]:
ttest_de_clust_merfish_df.to_csv("gene_lists/ttest_de_clust_merfish_df_for_figure_nologumi.csv")

In [None]:
de_genes_age_minor_signif_merfish = ttest_de_clust_merfish_df[np.logical_and(np.log2(np.exp(np.abs(ttest_de_clust_merfish_df.coef))) > np.log2(1.75), ttest_de_clust_merfish_df.qval<0.05)]

In [None]:
merfish_de_genes = de_genes_age_minor_signif_merfish.gene.unique()

In [None]:
de_clust_order = [
 'ExN-L2/3-1',
 'ExN-L2/3-2',
 'ExN-L5-1',
 'ExN-L5-2',
 'ExN-L5-3',
 'ExN-L6-1',
 'ExN-L6-2',
 'ExN-L6-3',
 'ExN-Olf',
 'InN-Olf-1',
 'InN-Olf-2',

 'InN-Vip',

 'InN-Lamp5',

 'InN-Pvalb-1',
 'InN-Pvalb-2',
 'InN-Pvalb-3',
 'InN-Sst-1',
 'InN-Sst-2',
 'InN-Calb2-1',
 'InN-Calb2-2',
 'InN-Chat',
 'InN-Lhx6',

'MSN-D1-1',
 'MSN-D1-2',
 'MSN-D2',
 'OPC',
 'Olig-1',
 'Olig-2',
 'Olig-3',

'Astro-1',
 'Astro-2',
 'Vlmc',
 'Peri-1',
 'Peri-2',
 'Endo-1',
 'Endo-2',
 'Endo-3',
 'Epen',

 'Micro-1',
 'Micro-2',
 'Micro-3',
 'Macro',
 'T cell',
]

In [None]:
coef_mat = np.zeros((len(merfish_de_genes), len(de_clust_order)))
for i,ct in enumerate(de_clust_order):
    curr_ct = ttest_de_clust_merfish_df[ttest_de_clust_merfish_df.cell_type==ct]
    for j,g in enumerate(merfish_de_genes):
        coef = curr_ct[curr_ct.gene==g].coef.values
        if len(coef) > 0:
            coef_mat[j,i] = coef[0]

In [None]:
coef_mat = pd.DataFrame(coef_mat, columns=de_clust_order, index=merfish_de_genes)

In [None]:
#coef_mat = coef_mat.drop('T cell',axis=1)

In [None]:
sns.set_style('white')
cg = sns.clustermap(coef_mat,vmin=-3,vmax=3,cmap=plt.cm.seismic,metric='correlation')#figsize=(12,12*(coef_mat.shape[0]/coef_mat.shape[1])),metric='correlation')
row_idx = np.array(cg.dendrogram_row.reordered_ind)
col_idx = np.array(cg.dendrogram_col.reordered_ind)
#cg.ax_row_dendrogram.set_visible(False)
#cg.ax_col_dendrogram.set_visible(False)
cg.cax.set_visible(False)
cg.ax_heatmap.set_rasterized(True)
for i in cg.ax_heatmap.get_xticklabels():
    i.set_size(20)
    i.set_rasterized(False)
for i in cg.ax_heatmap.get_yticklabels():
    i.set_size(20)
    i.set_rasterized(False)

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_expr_heatmap.pdf",dpi=300,bbox_inches='tight')

In [None]:
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata_merfish)

In [None]:
coef_mat = coef_mat.iloc[row_idx,:]
coef_mat = coef_mat.drop("T cell",axis=1)

In [None]:
clust_idx = np.array([np.argwhere(np.array(list(label_colors.keys())) == i)[0][0] for i in list(np.array(coef_mat.columns))])
clust_idx_cmap = mpl.colors.ListedColormap(list(label_colors.values()))



In [None]:
f, axes = plt.subplots(figsize=(10,10*len(col_idx)/len(row_idx)), ncols=2, gridspec_kw={'width_ratios':[0.5,20],'wspace':0.05})
ax = axes[0]
ax.imshow(np.expand_dims(clust_idx,axis=0).T,aspect='auto',interpolation='none',cmap=clust_idx_cmap)
#ax.axis('off')
ax.set_xticks([])
ax.set_yticks([])
#ax.set_yticks(np.arange(len(clust_idx)))
#ax.set_yticklabels(np.array(coef_mat.columns)[clust_idx],fontsize=6)
sns.despine(ax=ax,left=True,right=True, bottom=True)
ax = axes[1]
ax.imshow(coef_mat.values.T,cmap=plt.cm.seismic,vmin=-3,vmax=3,aspect='auto',interpolation='none')
#ax.axis('off')
sns.despine(ax=ax,left=True,right=True,bottom=True)
ax.set_yticks([])
ax.set_xticks(np.arange(len(row_idx)));
ax.set_xticklabels(np.array(coef_mat.index)[:],rotation=90,fontsize=6);
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_expr_heatmap_nologumi.pdf",dpi=300,bbox_inches='tight')

In [None]:
adata_merfish.X.shape

In [None]:
de_clust_order = [i for i in de_clust_order if i != "T cell"]

In [None]:
de_genes_age_minor_signif_merfish = ttest_de_clust_merfish_df[np.logical_and(np.log2(np.exp(np.abs(ttest_de_clust_merfish_df.coef))) > np.log2(1.75), ttest_de_clust_merfish_df.qval<0.05)]
sns.set_style('white')
f, ax = plt.subplots(figsize=(8,3), nrows=3, gridspec_kw={'hspace':0.05, 'height_ratios':[10,1,10]})
for i in [50,100,150]:
    ax[0].axhline(i,color='gray',linestyle='--',lw=1)
    ax[2].axhline(-i,color='gray',linestyle='--',lw=1)

sns.barplot(x='gene', y='de',
            data=pd.DataFrame(data={'gene': de_clust_order, 
                                    'de':[de_genes_age_minor_signif_merfish[np.logical_and(de_genes_age_minor_signif_merfish.cell_type==i, de_genes_age_minor_signif_merfish.coef>0)].shape[0] for i in de_clust_order]}),
            order=de_clust_order,color='salmon',ax=ax[0])
ax[0].set_xticks([])
ax[0].set_xlabel("")
ax[0].set_ylim([0,12.5])
for i in [2.5, 5, 7.5, 10, 12.5]:
    ax[0].axhline(i,color='gray',linestyle='--')
sns.despine(ax=ax[0],bottom=True)

ax[1].imshow(np.expand_dims(clust_idx,1).T,aspect='auto',interpolation='none',cmap=clust_idx_cmap)
ax[1].axis('off')
sns.barplot(x='gene', y='de',
            data=pd.DataFrame(data={'gene': de_clust_order, 
                                    'de':[-de_genes_age_minor_signif_merfish[np.logical_and(de_genes_age_minor_signif_merfish.cell_type==i, de_genes_age_minor_signif_merfish.coef<0)].shape[0] for i in de_clust_order]}),
            order=de_clust_order,color='skyblue',ax=ax[2])
ax[2].set_xticklabels(ax[2].get_xticklabels(),rotation = 90);
sns.despine(ax=ax[2])
ax[2].set_ylim([-12.5,0])
for i in [-2.5, -5, -7.5, -10, -12.5]:
    ax[2].axhline(i,color='gray',linestyle='--')

plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_num_de_merfish_nologumi.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Build a (cluster x gene) sign matrix of significant MERFISH age-DE genes:
# +1 where the coefficient is positive, -1 otherwise, 0 where not significant.
# NOTE(review): `clust_order` is not defined anywhere in the visible notebook; the
# parallel 10x cell further down uses `de_clust_order`. Confirm whether the star
# imports provide `clust_order` or whether this is a stale name from an earlier session.
# Column index of every significant gene, in order of first appearance.
gene_idx = {}
for i,g in enumerate(de_genes_age_minor_signif_merfish.gene.unique()):
    gene_idx[g] = i

n_genes = len(de_genes_age_minor_signif_merfish.gene.unique())
n_celltypes = len(clust_order)

de_subtypes_pos = np.zeros((n_celltypes, n_genes))#pd.DataFrame(index=de_genes_age_minor_signif.cell_type.unique(), columns=de_genes_age_minor_signif.gene.unique(),data=0)
de_subtypes_neg = np.zeros((n_celltypes, n_genes))#pd.DataFrame(index=de_genes_age_minor_signif.cell_type.unique(), columns=de_genes_age_minor_signif.gene.unique(),data=0)

# One row per cluster; mark each of its significant genes by coefficient sign.
for n,i in enumerate(clust_order):
    curr_type = i#.get_text()
    curr_de = de_genes_age_minor_signif_merfish[de_genes_age_minor_signif_merfish.cell_type==curr_type]
    #curr_de.at(curr_type, curr_)
    for g,c in zip(curr_de.gene, curr_de.coef):
        curr_gene_idx = gene_idx[g]
        if c > 0:
            de_subtypes_pos[n, curr_gene_idx] = 1
        else:
            de_subtypes_neg[n, curr_gene_idx]  = -1
# Combined matrix with entries in {-1, 0, +1}.
de_subtypes = de_subtypes_pos + de_subtypes_neg#np.hstack((de_subtypes_pos, de_subtypes_neg))#pd.concat([de_subtypes_pos, de_subtypes_neg],axis=1)
pos_points = np.argwhere(de_subtypes>0)
neg_points = np.argwhere(de_subtypes<0)
# Genes ordered from the largest to the smallest column sum of the sign matrix.
sorted_gene_idx = np.argsort(np.sum(de_subtypes,0))[::-1]


In [None]:
sns.set_style('white')
f = plt.figure(figsize=(10,6))

gs = plt.GridSpec(figure=f,ncols=2, nrows=2,width_ratios=[1,10], height_ratios=[3,10],wspace=0.1,hspace=0.1)
ax = plt.subplot(gs[0])
ax.bar(np.arange(de_subtypes.shape[1]), np.sum(de_subtypes_pos[:,sorted_gene_idx],0),width=1,color='salmon',lw=0,rasterized=True)
ax.bar(np.arange(de_subtypes.shape[1]), np.sum(de_subtypes_neg[:,sorted_gene_idx],0),width=1,color='skyblue',lw=0,rasterized=True)

ax.set_xlim([0, de_subtypes.shape[1]])
ax.set_xticklabels([])
ax.set_ylim([-20,20])
sns.despine(ax=ax,left=True)

ax = plt.subplot(gs[1])
ax.imshow(de_subtypes[:,sorted_gene_idx],aspect='auto',interpolation='nearest',vmin=-1,vmax=1,cmap=mpl.colors.ListedColormap(['skyblue','white','salmon']),rasterized=True)
#ax.set_xlim([0,de_subtypes.shape[1]])
#ax.set_ylim([0,de_subtypes.shape[0]])
ax.set_yticks(np.arange(len(clust_order)))
ax.set_yticklabels(clust_order,fontsize=6);
ax.set_xlabel('Genes')
ax.set_xticks([])
sns.despine(ax=ax)

#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/fig3_shared_de_merfish.pdf",dpi=300,bbox_inches='tight')

In [None]:
ttest_de_clust = {}
for i in adata10x.obs.clust_annot.unique():
    print(i)
    try:
        if i not in ttest_de_clust:
            ttest_de_clust[i] = run_ttest_de_age(adata10x[adata10x.obs.clust_annot==i],lognorm=True, grouping='clust_annot')
    except Exception as e:
        print(e)

In [None]:
ttest_de_clust_df = []
for v in ttest_de_clust.values():
    df = list(v.values())[0]
    df['cell_type'] = list(v.keys())[0]
    ttest_de_clust_df.append(df)

ttest_de_clust_df = pd.concat(ttest_de_clust_df)
ttest_de_clust_df.pval = [float(i) for i in ttest_de_clust_df.pval]
ttest_de_clust_df.qval = [float(i) for i in ttest_de_clust_df.qval]

ttest_de_clust_df = ttest_de_clust_df[~np.isnan(ttest_de_clust_df.qval)]

In [None]:
#ttest_de_clust_df.to_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age_V3_nools.csv")

In [None]:
#glm_de_clust_df.to_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age.csv")

In [None]:
#ttest_de_clust_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age.csv")
ttest_de_celltype_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_major_age.csv")

In [None]:
ttest_de_clust_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age_V3_nools.csv")
#ttest_de_clust_df = pd.read_csv("/home/user/src/tithonus/analysis/merfish/gene_lists/glm_nb_de_minor_age.csv")


In [None]:
# for v2
de_genes_age_minor_signif = ttest_de_clust_df[np.logical_and(np.abs(ttest_de_clust_df.coef) > np.log(2.5), ttest_de_clust_df.qval<0.05)]
#de_genes_age_major_signif = ttest_de_celltype_df[np.logical_and(np.abs(ttest_de_celltype_df.coef) > np.log(1.25), ttest_de_celltype_df.qval<0.05)]

In [None]:
# Drop hits whose coefficient is +/- infinity.
coef_is_inf = np.isinf(de_genes_age_minor_signif.coef)
de_genes_age_minor_signif = de_genes_age_minor_signif[~coef_is_inf]
#de_genes_age_major_signif = de_genes_age_major_signif[~np.isinf(de_genes_age_major_signif.coef)]

In [None]:
# Column index for every gene that is significant in at least one cluster.
unique_de_genes = de_genes_age_minor_signif.gene.unique()
gene_idx = {g: col for col, g in enumerate(unique_de_genes)}

n_genes = len(unique_de_genes)
n_celltypes = len(de_clust_order)

# Rows follow `de_clust_order`; +1 marks a positive coefficient, -1 any other coefficient.
de_subtypes_pos = np.zeros((n_celltypes, n_genes))
de_subtypes_neg = np.zeros((n_celltypes, n_genes))

for n, curr_type in enumerate(de_clust_order):
    curr_de = de_genes_age_minor_signif[de_genes_age_minor_signif.cell_type == curr_type]
    for g, c in zip(curr_de.gene, curr_de.coef):
        if c > 0:
            de_subtypes_pos[n, gene_idx[g]] = 1
        else:
            de_subtypes_neg[n, gene_idx[g]] = -1

# Signed matrix, coordinates of the up/down entries, and genes ordered by
# their column sum (largest first).
de_subtypes = de_subtypes_pos + de_subtypes_neg
pos_points = np.argwhere(de_subtypes > 0)
neg_points = np.argwhere(de_subtypes < 0)
sorted_gene_idx = np.argsort(de_subtypes.sum(axis=0))[::-1]


In [None]:
#de_genes_age_minor_signif = de_genes_age_minor[np.logical_and(de_genes_age_minor.qval<0.05,
#                                                              np.log2(np.exp(np.abs(de_genes_age_minor.coef)))>1)]
# Position of every ordered cluster among the keys of `label_colors`, plus a colormap
# listing the `label_colors` values in that same key order.
label_names = np.array(list(label_colors.keys()))
clust_idx = np.array([np.flatnonzero(label_names == name)[0] for name in de_clust_order])
clust_idx_cmap = mpl.colors.ListedColormap(list(label_colors.values()))



In [None]:
# Three stacked panels: up-regulated gene counts (top), a colour strip identifying each
# cluster (middle), and down-regulated gene counts drawn as negative bars (bottom).
sns.set_style('white')
f, ax = plt.subplots(figsize=(8,3), nrows=3, gridspec_kw={'hspace':0.05, 'height_ratios':[10,1,10]})
for i in [50,100,150,200,250]:
    ax[0].axhline(i,color='gray',linestyle='--',lw=1)
    ax[2].axhline(-i,color='gray',linestyle='--',lw=1)

# Count up-/down-regulated genes once per cluster instead of re-filtering inside each call.
de_cell_types = de_genes_age_minor_signif.cell_type.unique()
coef_is_up = de_genes_age_minor_signif.coef > 0
coef_is_down = de_genes_age_minor_signif.coef < 0
n_up = [int((coef_is_up & (de_genes_age_minor_signif.cell_type == ct)).sum()) for ct in de_cell_types]
n_down = [-int((coef_is_down & (de_genes_age_minor_signif.cell_type == ct)).sum()) for ct in de_cell_types]

# The x column is called 'gene' but holds cluster names; the key is kept so the
# axis labels generated by seaborn stay the same.
sns.barplot(x='gene', y='de',
            data=pd.DataFrame(data={'gene': de_cell_types, 'de': n_up}),
            order=de_clust_order,color='salmon',ax=ax[0])
ax[0].set_xticks([])
ax[0].set_xlabel("")
ax[0].set_ylim([0,200])

sns.despine(ax=ax[0],bottom=True)

# Pin the colour limits so that index k always selects the k-th entry of the colormap.
# Without explicit limits imshow rescales to the min/max of `clust_idx`, which picks the
# wrong colours whenever the first or last label is absent from `de_clust_order`.
ax[1].imshow(np.expand_dims(clust_idx,1).T,aspect='auto',interpolation='none',cmap=clust_idx_cmap,
             vmin=-0.5,vmax=len(label_colors)-0.5)
ax[1].axis('off')
sns.barplot(x='gene', y='de',
            data=pd.DataFrame(data={'gene': de_cell_types, 'de': n_down}),
            order=de_clust_order,color='skyblue',ax=ax[2])
ax[2].set_xticklabels(ax[2].get_xticklabels(),rotation = 90);
sns.despine(ax=ax[2])
ax[2].set_ylim([-200,0])
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS4_num_de_V3_thresh2.5.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Print, for each cluster, the shape of its slice of the significant-hit table.
for ct in de_genes_age_minor_signif.cell_type.unique():
    ct_rows = de_genes_age_minor_signif.loc[de_genes_age_minor_signif.cell_type == ct]
    print(ct, ct_rows.shape)

In [None]:
# Print, for each major cell type, the shape of its slice of the significant-hit table.
# NOTE(review): in this part of the notebook `de_genes_age_major_signif` is only assigned in
# commented-out lines of the thresholding cells above; unless it is defined earlier in the
# notebook, this cell raises NameError on a fresh kernel.
for i in de_genes_age_major_signif.cell_type.unique():
    print(i,de_genes_age_major_signif[de_genes_age_major_signif.cell_type==i].shape)

In [None]:
# get coef for all the genes
# Build a (gene x cluster) matrix of coefficients taken from the full DE table
# (`ttest_de_clust_df`) for every gene that is significant in at least one cluster.
# Genes with no row for a cluster get 0.
all_de_genes = list(de_genes_age_minor_signif.gene.unique())
signif_cell_types = de_genes_age_minor_signif.cell_type.unique()
good_ct = []
all_de_coef = np.zeros((len(all_de_genes), len(signif_cell_types)))
for j, ct in enumerate(signif_cell_types):
    curr_de = ttest_de_clust_df[ttest_de_clust_df.cell_type==ct]
    if curr_de.shape[0] > 0:
        good_ct.append(ct)
    print(ct, curr_de.shape)
    # One pass over the cluster's rows instead of a boolean scan per gene.
    # setdefault keeps the first coefficient seen for a gene, matching `.values[0]`.
    coef_by_gene = {}
    for gene, coef in zip(curr_de.gene, curr_de.coef):
        coef_by_gene.setdefault(gene, coef)
    all_de_coef[:, j] = [coef_by_gene.get(gene, 0) for gene in all_de_genes]

In [None]:
from utils import order_values

In [None]:
# Replace non-finite coefficients before clustering/plotting.
# NaN -> 0; +inf -> largest finite value; -inf -> smallest finite value.
# (Using all_de_coef.max() directly would return inf whenever +inf is present, and would
# also turn -inf into a positive number.)
all_de_coef[np.isnan(all_de_coef)] = 0
_coef_values = np.asarray(all_de_coef, dtype=float)
_finite_values = _coef_values[np.isfinite(_coef_values)]
_finite_max = _finite_values.max() if _finite_values.size else 0.0
_finite_min = _finite_values.min() if _finite_values.size else 0.0
all_de_coef[all_de_coef == np.inf] = _finite_max
all_de_coef[all_de_coef == -np.inf] = _finite_min

In [None]:
# Label the coefficient matrix (rows = genes, columns = clusters) and cluster both axes
# hierarchically with complete linkage on cosine distance.
# Note: `all_de_coef` changes from an array to a DataFrame here.
all_de_coef = pd.DataFrame(
    all_de_coef,
    index=all_de_genes,
    columns=de_genes_age_minor_signif.cell_type.unique(),
)
cl = sns.clustermap(
    data=all_de_coef,
    method='complete',
    metric='cosine',
    z_score=None,
    cmap=plt.cm.seismic,
    vmin=-5,
    vmax=5,
    figsize=(10,20),
)

In [None]:
# Redraw the clustered matrix as a compact figure, reusing the row/column order
# computed by the clustermap above.
col_order = cl.dendrogram_col.reordered_ind
row_order = cl.dendrogram_row.reordered_ind
colnames = np.array(list(de_genes_age_minor_signif.cell_type.unique()))[col_order]

f,axes = plt.subplots(figsize=(20,5),ncols=2, gridspec_kw={'width_ratios':[0.25,20],'wspace':0.01})

# Left: one coloured cell per cluster, labelled with the cluster name.
ax = axes[0]
ax.set_yticks(np.arange(len(colnames)));
ax.set_yticklabels(colnames);
sns.despine(ax=ax,left=True,bottom=True)
ax.set_xticks([])
strip_cmap = mpl.colors.ListedColormap([label_colors[name] for name in colnames])
ax.imshow(np.arange(len(colnames))[:, None],aspect='auto',interpolation='none',cmap=strip_cmap,rasterized=True)

# Right: coefficients, clusters on the rows and genes on the columns, colour range clipped to [-3, 3].
ax = axes[1]
ordered_coef = all_de_coef.values[np.ix_(row_order, col_order)].T
ax.imshow(ordered_coef,aspect='auto',interpolation='nearest',vmin=-3,vmax=3,cmap=plt.cm.seismic,rasterized=True)
ax.set_yticks([])
ax.set_xticks([])
ax.set_xlabel('Genes')
sns.despine(ax=ax,left=True,bottom=True)
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS4_shared_de.pdf",dpi=300,bbox_inches='tight')

In [None]:
# Scratch check of the "no excluded parent term" expression that is also used inside
# sig_genes_GO_query: evaluates to True when none of the listed parents is in `bad_top_terms`.
bad_top_terms = ['GO:0009987', 'GO:0008150']

# NOTE(review): `temp` is not defined anywhere in the visible part of the notebook;
# presumably it was an enrichment result frame with a `parents` column — confirm or remove this cell.
~np.any([True if i in bad_top_terms else False for i in temp.parents.iloc[0] ])

In [None]:
def get_genes_for_celltype(de, name, direction=None):
    """Return the unique genes of all clusters whose label contains `name`.

    de: frame with `cell_type`, `gene` and `coef` columns.
    name: substring looked for in every distinct `cell_type` label.
    direction: "pos" keeps rows with coef >= 0, "neg" keeps rows with coef < 0,
        anything else keeps all rows.

    The set of matching cluster labels is taken from the unfiltered frame.
    """
    matching_clusters = [label for label in de.cell_type.unique() if name in label]
    if direction == "pos":
        subset = de[de.coef >= 0]
    elif direction == "neg":
        subset = de[de.coef < 0]
    else:
        subset = de
    in_matching_cluster = subset.cell_type.isin(matching_clusters)
    return subset[in_matching_cluster].gene.unique()

# code from: https://github.com/klarman-cell-observatory/inCITE-seq/blob/main/notebooks/inCITE_tools.ipynb
def parse_GO_query(gene_list, species, db_to_keep='all'):
    """Run an enrichment query and keep significant rows from selected sources.

    gene_list: iterable of gene names, passed to `sc.queries.enrich` as a list.
    species: organism identifier forwarded as `org`.
    db_to_keep: list of `source` labels to retain, or 'all' for
        ['GO:BP', 'GO:MF', 'KEGG', 'REAC', 'TF'].

    Returns the rows of the enrichment table where `significant` is True and
    `source` is one of the requested labels.
    """
    sources = ['GO:BP', 'GO:MF', 'KEGG', 'REAC', 'TF'] if db_to_keep == 'all' else db_to_keep
    enriched = sc.queries.enrich(list(gene_list), org=species)
    is_significant = enriched['significant'] == True
    significant = enriched[is_significant]
    return significant[significant['source'].isin(sources)]

def sig_genes_GO_query(sig_genes, clust_lim=1000, source=['GO:BP']):
    """Query enrichment for `sig_genes` (mouse) and drop terms directly under top-level GO terms.

    sig_genes: sequence of gene names. An empty sequence returns an empty table
        without issuing a query.
    clust_lim: maximum number of terms to return.
    source: list of `source` labels forwarded to parse_GO_query.

    Returns a DataFrame with columns source, name, p_value, description, native, parents
    (one row per retained term, in the order returned by the query, index 0..n-1).
    A term is skipped when any entry of its `parents` is in the excluded set below.
    """
    bad_top_terms = {'GO:0009987', 'GO:0008150'}
    columns = ['source','name','p_value','description','native','parents']

    # Nothing to enrich: skip the remote query entirely.
    if len(sig_genes) == 0:
        return pd.DataFrame([], columns=columns)

    GO_df = parse_GO_query(sig_genes,'mmusculus',source)

    # Collect plain records and build the frame once (avoids repeated pd.concat).
    records = []
    for _, row in GO_df.iterrows():
        if len(records) >= clust_lim:
            break
        if any(parent in bad_top_terms for parent in row['parents']):
            # exclude top level terms
            continue
        records.append({col: row[col] for col in columns})
    return pd.DataFrame(records, columns=columns)

def plot_GO_terms(df,alpha,filename,colormap='#d3d3d3',xlims=[0,20],ax=None):
    """Draw a horizontal bar chart of -log10(p_value), one bar per row of `df`.

    df: frame with `p_value` and `name` columns (and `cluster` when `colormap` is a mapping).
        The frame passed in is not modified.
    alpha: rows with p_value > alpha are not drawn.
    filename: unused; kept for compatibility with existing callers.
    colormap: a single colour string applied to every bar, or a mapping applied to
        the `cluster` column to obtain one colour per bar.
    xlims: x-axis limits.
    ax: axes to draw on; a new figure is created when None.

    Returns None. Calls sns.reset_orig(), which resets matplotlib rcParams globally.
    """
    # Filter first so that per-row colours line up with the bars that are actually drawn.
    df = df.loc[df['p_value']<=alpha]

    if isinstance(colormap, str):
        color = colormap
    else:
        color = df['cluster'].map(colormap).tolist()

    if ax is None:
        # 0.1 inch per term; keep a non-zero height when there is nothing to plot.
        fig_height = max(df.shape[0], 1)*(1/10)
        fig, ax = plt.subplots(figsize=(3,fig_height))
    y_pos = np.arange(df.shape[0])
    log10p = -np.log10(df['p_value'].tolist())

    sns.reset_orig()
    ax.barh(y_pos, log10p, align='center', color=color)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(df['name'],fontsize=6)
    # First row at the top.
    ax.invert_yaxis()
    ax.set_xlabel('-log10(P)')
    ax.set_xlim(xlims)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(1)

def get_sig_genes_for_celltype(de_df, cell_type, log2_thresh=np.log2(2), qval_thresh=0.05):
    """Return the genes of one cell type that pass the q-value and fold-change cut-offs.

    de_df: frame with `cell_type`, `gene`, `qval` and `log2fc` columns.
    cell_type: exact label to select.
    log2_thresh: a gene passes when |log2fc| is strictly greater than this.
    qval_thresh: a gene passes when qval is strictly smaller than this.

    Rows containing any missing value are dropped before filtering.
    Returns a list of gene names in table order (empty when nothing passes).
    """
    res = de_df[de_df.cell_type==cell_type].dropna()
    good_genes = (res.qval < qval_thresh) & (np.abs(res.log2fc) > log2_thresh)
    return list(res.gene[good_genes])

def get_sig_genes_for_aging(de_df, cell_type, log2_thresh=np.log2(2), pval_thresh=0.05):
    """Return the genes of one cell type that pass adjusted-p and fold-change cut-offs.

    de_df: frame with `cell_type`, `gene`, `coef` and `pval` columns.
    cell_type: exact label to select.
    log2_thresh: a gene passes when |log2(exp(coef))| is strictly greater than this.
    pval_thresh: a gene passes when its adjusted p-value is strictly smaller than this.

    Rows with missing values are dropped, `coef` is rewritten as log2(exp(coef)), and `pval`
    is replaced by the second output of multipletests(..., method='fdr_bh').
    The returned list is ordered by ascending converted coefficient.
    """
    res = de_df[de_df.cell_type==cell_type].dropna()
    res = res.assign(
        coef=np.log2(np.exp(res.coef)),
        pval=multipletests(res.pval, method='fdr_bh')[1],
    ).sort_values('coef')
    passes = (res.pval < pval_thresh) & (np.abs(res.coef) > log2_thresh)
    return list(res.gene[passes])

In [None]:
from goatools import obo_parser

In [None]:
from goatools.base import download_ncbi_associations
# Fetch the NCBI gene-to-GO association file via goatools and keep the call's return value.
# NOTE(review): the reader cell below opens the literal name "gene2go" rather than this variable.
gene2go = download_ncbi_associations()


In [None]:
from goatools.base import download_go_basic_obo
#obo_fname = download_go_basic_obo()
from goatools.anno.genetogo_reader import Gene2GoReader

# Read the associations from the local file "gene2go", restricted to taxid 10090
# (presumably mouse, matching the `_mus` suffix — confirm).
objanno = Gene2GoReader("gene2go", taxids=[10090])
# Mapping keyed by GO term (go2geneids=True), limited to the 'BP' namespace.
go2geneids_mus = objanno.get_id2gos(namespace='BP', go2geneids=True)


In [None]:
from goatools.go_search import GoSearch

# Search helper over the GO ontology file, backed by the GO -> gene-id mapping built in the
# previous cell (`go2geneids_mus`). The name used here before, `go2geneids_human`, is not
# defined in this part of the notebook.
srchhelp = GoSearch("go-basic.obo", go2items=go2geneids_mus)


In [None]:
# Lookup from the `GeneID` column to the `Symbol` column of entrez_gene_ids.txt.
# Built in one vectorised pass instead of iterating rows; for repeated GeneIDs the
# last row wins, as with per-row assignment.
_entrez_table = pd.read_table("entrez_gene_ids.txt")
id_to_sym = dict(zip(_entrez_table['GeneID'], _entrez_table['Symbol']))

In [None]:
def get_genes_for_go_term(go_id):
    """Return the symbols of the genes attached to `go_id` or any term added by add_children_gos.

    Uses the notebook-level `srchhelp` search helper and the `id_to_sym` lookup;
    gene ids that have no entry in `id_to_sym` are left out.
    """
    gos_with_children = srchhelp.add_children_gos([go_id])
    symbols = []
    for geneid in srchhelp.get_items(gos_with_children):
        if geneid in id_to_sym:
            symbols.append(id_to_sym[geneid])
    return symbols


In [None]:
# All distinct genes that appear in the minor-cluster DE table.
# NOTE(review): despite the name, this list is derived from `ttest_de_clust_df`, not from a MERFISH object.
merfish_genes = list(ttest_de_clust_df.gene.unique())

In [None]:
# Reload the minor-cluster DE table (path is relative to the working directory here) and
# keep hits with |coef| converted to log2 above log2(2) and q-value below 0.05.
ttest_de_clust_df = pd.read_csv("gene_lists/glm_nb_de_minor_age_V3_nools.csv")
abs_log2_coef = np.log2(np.exp(np.abs(ttest_de_clust_df.coef)))
passes_high_thresh = (abs_log2_coef > np.log2(2)) & (ttest_de_clust_df.qval < 0.05)
de_genes_age_minor_signif_high = ttest_de_clust_df[passes_high_thresh]



In [None]:
# Name fragments for the GO panels below; get_genes_for_celltype selects every cluster
# whose label contains the fragment.
cell_types = ["Endo", "Astro", "Micro","Olig", "ExN", "InN", "MSN","OPC"]


In [None]:
# One GO/KEGG enrichment panel per entry of `cell_types`, using genes with a
# non-negative coefficient in the high-threshold hit table. The panel count follows
# len(cell_types) instead of a hard-coded 8.
n_ct = len(cell_types)
f, axes = plt.subplots(nrows=1,ncols=n_ct, figsize=(20,5),gridspec_kw={'wspace':1})
for ax, ct in zip(np.atleast_1d(axes), cell_types):
    genes_pos = get_genes_for_celltype(de_genes_age_minor_signif_high, ct,"pos")
    go_terms = sig_genes_GO_query(genes_pos, source=['GO:BP','KEGG']).head(10)
    print(ct, len(genes_pos))
    plot_GO_terms(go_terms, 1, "",xlims=[0,15],ax=ax)
    ax.set_title(ct)
    # Colour the tick label red when the (first) term with that name comes from KEGG.
    for l in ax.get_yticklabels():
        print(l.get_text())
        if go_terms[go_terms.name==l.get_text()].source.iloc[0] == "KEGG":
            l.set_color('r')
    print("--------------")
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS5_go_pos_V3.pdf",dpi=300,bbox_inches='tight')

In [None]:
from Bio.KEGG import REST
import io
def to_df(result):
    """Parse a tab-separated text blob (no header row) into a DataFrame."""
    buffer = io.StringIO(result)
    return pd.read_table(buffer, header=None)

def load_pathway(path_id):
    """Return the gene symbols listed in the GENE section of a KEGG pathway entry.

    path_id: identifier passed to REST.kegg_get.

    Returns a list of unique symbols in order of first appearance. A GENE line is expected
    to look like "<id> <symbol>; <description>"; lines that do not have the "; " separator
    or do not have exactly two identifier fields are skipped instead of raising.
    """
    pathway_file = REST.kegg_get(path_id).read()  # query and read the pathway entry

    genes = []
    seen = set()  # constant-time duplicate check; `genes` keeps the order
    # Track which section of the flat file we are in; continuation lines have a blank
    # section field and therefore inherit the previous section.
    current_section = None
    for line in pathway_file.rstrip().split("\n"):
        section = line[:12].strip()  # section names are within 12 columns
        if section:
            current_section = section

        if current_section != "GENE":
            continue

        gene_identifiers, separator, _description = line[12:].partition("; ")
        id_fields = gene_identifiers.split()
        if not separator or len(id_fields) != 2:
            # No symbol can be identified reliably on this line.
            continue

        gene_symbol = id_fields[1]
        if gene_symbol not in seen:
            seen.add(gene_symbol)
            genes.append(gene_symbol)
    return genes



In [None]:
# Seven enrichment panels, one per entry of `cell_types` in list order (the figure has
# 7 axes, so entries beyond the seventh are not drawn).
# NOTE(review): parse_GO_query keeps rows whose `source` is in this list, and other cells in
# this notebook pass 'GO:BP'; confirm that 'GO' is an actual source label, otherwise only
# KEGG rows survive the filter here.
n_ct = len(cell_types)
f, axes = plt.subplots(nrows=1,ncols=7, figsize=(20,5),gridspec_kw={'wspace':1})
for k, ax in enumerate(axes):
    ct = cell_types[k]
    genes_pos = get_genes_for_celltype(de_genes_age_minor_signif_high, ct,"pos")
    go_terms = sig_genes_GO_query(genes_pos,source=['GO','KEGG']).head(10)
    plot_GO_terms(go_terms, 1, "",xlims=[0,10],ax=ax)
    ax.set_title(ct)
    print("--------------")
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS5_go_pos.pdf",dpi=300,bbox_inches='tight')


In [None]:
# Seven KEGG-only enrichment panels (top 15 terms each) for genes with a non-negative
# coefficient, one per entry of `cell_types` in list order.
n_ct = len(cell_types)
f, axes = plt.subplots(nrows=1,ncols=7, figsize=(20,5),gridspec_kw={'wspace':1})
for k, ax in enumerate(axes):
    ct = cell_types[k]
    genes_pos = get_genes_for_celltype(de_genes_age_minor_signif_high, ct,"pos")
    go_terms = sig_genes_GO_query(genes_pos, source=['KEGG']).head(15)
    plot_GO_terms(go_terms, 1, "",xlims=[0,20],ax=ax)
    ax.set_title(ct)
    
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS5_go_pos.pdf",dpi=300,bbox_inches='tight')
#genes_neg = get_genes_for_celltype(de_genes_age_minor_signif, cell_types[i],"neg")

In [None]:
# get genes for each GO term

In [None]:
# GO enrichment (top 15 terms) for genes with a non-negative coefficient, one panel per
# entry of `cell_types` on a 2x5 grid.
n_ct = len(cell_types)
f, axes = plt.subplots(nrows=2,ncols=5, figsize=(20,10),gridspec_kw={'wspace':1})
# Fill the grid column by column: axes[0,0], axes[1,0], axes[0,1], ...
panel_axes = axes.T.ravel()
# Pair panels with cell types; the grid has 10 slots, so indexing cell_types with a
# running counter up to 9 would raise IndexError for the 8 names defined above.
for ax, ct in zip(panel_axes, cell_types):
    genes_pos = get_genes_for_celltype(de_genes_age_minor_signif, ct,"pos")
    go_terms = sig_genes_GO_query(genes_pos).head(15)

    plot_GO_terms(go_terms, 1, "",xlims=[0,10],ax=ax)
    ax.set_title(ct)
# Blank out grid slots that have no cell type.
for ax in panel_axes[n_ct:]:
    ax.axis('off')
    
#plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS5_go_pos.pdf",dpi=300,bbox_inches='tight')
#genes_neg = get_genes_for_celltype(de_genes_age_minor_signif, cell_types[i],"neg")

In [None]:
# GO enrichment (top 15 terms) for genes with a negative coefficient, one panel per
# entry of `cell_types` on a 2x5 grid.
n_ct = len(cell_types)
f, axes = plt.subplots(nrows=2,ncols=5, figsize=(20,10),gridspec_kw={'wspace':1})
# Fill the grid column by column: axes[0,0], axes[1,0], axes[0,1], ...
panel_axes = axes.T.ravel()
# Pair panels with cell types; the grid has 10 slots, so indexing cell_types with a
# running counter up to 9 would raise IndexError for the 8 names defined above.
for ax, ct in zip(panel_axes, cell_types):
    genes_neg = get_genes_for_celltype(de_genes_age_minor_signif, ct,"neg")
    go_terms = sig_genes_GO_query(genes_neg).head(15)

    plot_GO_terms(go_terms, 1, "",xlims=[0,20],ax=ax)
    ax.set_title(ct)
# Blank out grid slots that have no cell type.
for ax in panel_axes[n_ct:]:
    ax.axis('off')
plt.savefig("/home/user/Dropbox/zhuang_lab/aging/aging_atlas_paper/figures_int/figS5_go_neg.pdf",dpi=300,bbox_inches='tight')
    

#genes_neg = get_genes_for_celltype(de_genes_age_minor_signif, cell_types[i],"neg")