In [None]:
%load_ext autoreload
%autoreload 2
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
import os
import anndata as ad
mpl.rcParams['figure.dpi'] = 150
plt.rcParams['pdf.fonttype'] = 42

import sys
from spatial_analysis import *
from plotting import *
from utils import *

In [None]:
adata = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/112921_merged_combined_merfish_allages_lps.h5ad")

In [None]:
sc.pp.calculate_qc_metrics(adata, qc_vars=[], percent_top=None, log1p=False, inplace=True)


In [None]:
# QC overview: genes detected vs. total counts per cell, coloured by age.
# Each age label has its last two characters stripped and is cast to int.
age_values = np.array([int(label[:-2]) for label in adata.obs.age])
plt.scatter(
    adata.obs.total_counts,
    adata.obs.n_genes_by_counts,
    s=0.1,
    alpha=0.1,
    c=age_values,
)
plt.xlabel('Counts')
plt.ylabel('Genes')
# Guide lines at counts = 20 and genes = 5.
plt.axvline(20, color='k')
plt.axhline(5, color='k')

In [None]:
# use scrublet
# Doublet scoring is run separately for every batch; the per-batch score
# arrays are collected in `all_doublet_scores`, in the order returned by
# `adata.obs.batch.unique()`.
import scrublet as scr

all_doublet_scores = []
for i in adata.obs.batch.unique():
    print("Doubleting", i)
    # Subset to the cells of the current batch.
    curr_adata = adata[adata.obs.batch==i]
    scrub = scr.Scrublet(curr_adata.X)
    # Only the scores are collected; `predicted_doublets` is not read again
    # in the cells shown here.
    doublet_scores, predicted_doublets = scrub.scrub_doublets()
    all_doublet_scores.append(doublet_scores)
    # Show the score histogram for this batch.
    scrub.plot_histogram()

In [None]:
# Scrublet was run batch by batch, so the score arrays are ordered by batch
# (in the order of `adata.obs.batch.unique()`), not by the row order of
# `adata`. Write each batch's scores back through that batch's row mask so
# they stay aligned with their cells even when batches are interleaved.
doublet_scores_by_cell = np.full(adata.n_obs, np.nan)
for batch_id, batch_scores in zip(adata.obs.batch.unique(), all_doublet_scores):
    batch_rows = np.asarray(adata.obs.batch == batch_id)
    doublet_scores_by_cell[batch_rows] = batch_scores
adata.obs["doublet_scores"] = doublet_scores_by_cell

In [None]:
# Keep only the cells whose Scrublet doublet score is below 0.2.
adata = adata[adata.obs.doublet_scores<0.2]

In [None]:
# Volume filter: keep cells whose volume is at least 100 and strictly below
# three times the median volume of all cells.
median_vol = np.median(adata.obs.volume)
cell_volume = adata.obs.volume
volume_ok = (cell_volume >= 100) & (cell_volume < 3 * median_vol)
adata = adata[volume_ok]

In [None]:
# Drop cells with fewer than 5 detected genes, then cells with fewer than
# 20 total counts.
sc.pp.filter_cells(adata, min_genes=5)
sc.pp.filter_cells(adata, min_counts=20)


In [None]:
# normalize counts by volume of cell
# Done as one vectorised operation instead of a per-row Python loop that
# looked up `obs.volume[i]` through the label-based indexer.
from scipy import sparse

counts = adata.X
if not np.issubdtype(counts.dtype, np.floating):
    # Integer counts cannot hold the divided values.
    counts = counts.astype(np.float64)
cell_volumes = np.asarray(adata.obs.volume, dtype=np.float64)
if sparse.issparse(counts):
    # Left-multiplying by diag(1 / volume) scales every row and keeps X sparse.
    adata.X = sparse.diags((1.0 / cell_volumes).astype(counts.dtype)) @ counts
else:
    adata.X = np.asarray(counts) / cell_volumes[:, None].astype(counts.dtype)

# We removed the cells that had total RNA counts lower than 2% quantile or higher than 98% quantile.
# `.sum(1)` on a sparse matrix returns an (n, 1) matrix, so flatten it before
# it is used to build the boolean row mask.
norm_rna_counts = np.asarray(adata.X.sum(1)).ravel()
quantile2 = np.quantile(norm_rna_counts, 0.02)
quantile98 = np.quantile(norm_rna_counts, 0.98)
keep_cells = (norm_rna_counts >= quantile2) & (norm_rna_counts <= quantile98)
# Copy so the in-place normalisation below works on a real object, not a view.
adata = adata[keep_cells].copy()
# then by sum
sc.pp.normalize_total(adata, target_sum=250)


In [None]:
sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts'],
             jitter=0.4, multi_panel=True)


In [None]:
print(np.median(adata.obs.n_genes_by_counts), np.median(adata.obs.total_counts))

In [None]:
sc.pp.log1p(adata)


In [None]:
# Replace `adata` with the AnnData rebuilt from `adata.raw`.
# NOTE(review): no cell above this one assigns `adata.raw` (the first visible
# assignment is in the next cell), so on a fresh top-to-bottom run this relies
# on the loaded file already carrying a `.raw` — confirm, or skip this cell.
adata = adata.raw.to_adata()

In [None]:
# Snapshot the current matrix in `.raw` before it is scaled, so later plots
# can request the unscaled values with `use_raw=True`.
adata.raw = adata
#sc.pp.regress_out(adata, ['total_counts'])

# Scale the genes, passing max_value=10.
sc.pp.scale(adata, max_value=10)

# 25-component PCA on the scaled matrix.
sc.tl.pca(adata, svd_solver='arpack', n_comps=25)


In [None]:
sc.pl.pca(adata, color=['total_counts','Vtn','Csf1r','Adora2a','Slc17a7','Slc32a1','Mbp','Cx3cr1','Gfap','C3', 'age', 'batch'],use_raw=True,cmap=plt.cm.Reds)

In [None]:
# Build the neighbour graph with BBKNN using `batch` as the batch key; the
# plain scanpy neighbours call is kept commented out for reference.
#sc.pp.neighbors(adata, n_neighbors=5)
import bbknn
bbknn.bbknn(adata, 'batch')

In [None]:
sc.tl.umap(adata)


In [None]:
sc.pl.umap(adata,color=['total_counts','Vtn','Csf1r','Adora2a','Slc17a7','Slc32a1','Mbp','Cx3cr1', 'age', 'batch', 'Il1b'])

In [None]:
# Marker gene panels used to compute per-cell T-cell and B-cell scores.
tcell_markers = [
    "Tcrd", "Tcrb", "Ptprc", "Rorc", "Gata3",
    "Foxp3", "Tbx21", "Il2ra", "Il7r", "Il2rb",
    "Il2rg", "Il15ra", "Pdcd1", "Ctla4", "Cd3e",
]
bcell_markers = ["Ms4a1", "Cd19", "Prdm1"]

# Score each panel; the scores are stored under the names `tcell` and `bcell`.
for score_name, marker_genes in (("tcell", tcell_markers), ("bcell", bcell_markers)):
    sc.tl.score_genes(adata, gene_list=marker_genes, score_name=score_name)

In [None]:
sc.pl.umap(adata[adata.obs.age=='24wk'],color=['Cd3e','tcell','bcell'],size=5, cmap=plt.cm.bwr,vmin=-1,vmax=1)

In [None]:
sc.pl.umap(adata[adata.obs.age=='90wk'],color=['Cd3e','tcell','bcell'],size=5, cmap=plt.cm.bwr,vmin=-1,vmax=1)

In [None]:
sc.pl.umap(adata[adata.obs.age=='90wk'],color=['total_counts','Cx3cr1', 'Cd3e', 'Il1b','Tnf','Cxcl10','Il6','Il33','Gfap','Serpina3n','C4b','C3','Foxj1','Ctss','Aqp4','C1qa','C1qc','Vtn','Flt1'],size=5, cmap=plt.cm.Reds)

In [None]:
sc.pl.umap(adata,color=['age','total_counts','Cx3cr1', 'Cd3e', 'Il1b','Tnf','Cxcl10','Il6','Il33','Gfap','Serpina3n','C4b','C3','Foxj1','Ctss','Aqp4','C1qa','C1qc','Vtn','Flt1'],size=1, cmap=plt.cm.Reds,use_raw=True)

In [None]:
#adata.write_h5ad("/faststorage/brain_aging/merfish/exported/112921_merged_lps_merfish_with_doublet_umap_allages.h5ad")

In [None]:
# Reload the processed LPS dataset from disk and rebuild `adata` from its
# `.raw` attribute.
# NOTE(review): the cell that would write this file is commented out just
# above, so the file must already exist on disk for this cell to run.
adata = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/112921_merged_lps_merfish_with_doublet_umap_allages.h5ad")
adata = adata.raw.to_adata()

In [None]:
# integrate with existing MERFISH data


In [None]:
#adata_combined = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/103121_adata_combined_harmony.h5ad")
adata_merfish = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_combined_merfish_with_doublet_umap_allages.h5ad")#adata_combined[adata_combined.obs.dtype=="merfish"]

In [None]:
# load labels for MERFISH data
adata_labeled = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/011722_adata_combined_harmony.h5ad")


In [None]:
adata_merfish = adata_merfish[adata_labeled[adata_labeled.obs.dtype=="merfish"].obs.index]

In [None]:
adata_merfish.shape

In [None]:
adata_merfish = adata_merfish.raw.to_adata()

In [None]:
sc.pl.umap(adata_merfish,color=['age','total_counts','Cx3cr1', 'Cd3e', 'Il1b','Tnf','Cxcl10','Il6','Il33','Gfap','Serpina3n','C4b','C3','Foxj1','Ctss','Aqp4','C1qa','C1qc','Vtn','Flt1'],size=1, cmap=plt.cm.Reds)

In [None]:
adata_merfish.obs['cond'] = 'ctrl'
adata.obs['cond'] = 'lps'


In [None]:
# Offset the LPS batch ids so they start right after the largest batch id of
# the control dataset, giving every batch a unique `data_batch` value.
# NOTE(review): assumes `batch` is numeric in both objects — confirm.
adata.obs['data_batch'] = adata.obs.batch + adata_merfish.obs.batch.max() + 1

In [None]:
adata_merfish.obs['data_batch'] = adata_merfish.obs.batch

In [None]:
# Stack the control cells (first) and the LPS cells (second) into one object.
# NOTE(review): later cells strip a two-character suffix from the combined obs
# names (`[:-2]`) and append "-0" to control names; presumably that matches the
# per-object suffix `concatenate` adds — confirm for the installed anndata.
adata_combined = adata_merfish.concatenate(adata)

In [None]:
# Snapshot the unscaled combined matrix in `.raw` before scaling.
adata_combined.raw = adata_combined
#sc.pp.regress_out(adata, ['total_counts', 'volume'])

# Scale the genes, passing max_value=10.
sc.pp.scale(adata_combined, max_value=10)

# 25-component PCA on the scaled combined matrix.
sc.tl.pca(adata_combined, svd_solver='arpack', n_comps=25)


In [None]:
sc.pl.pca(adata_combined, color=['age','cond','data_batch'])

In [None]:
# Store `data_batch` as strings; later cells select batches by string id
# (e.g. `== '16'`).
adata_combined.obs["data_batch"] = list(map(str, adata_combined.obs.data_batch))

In [None]:
sc.external.pp.harmony_integrate(adata_combined, 'data_batch')

In [None]:
# Keep the uncorrected PCA under `X_pca_orig` and make the Harmony embedding
# the `X_pca` entry that later cells read.
adata_combined.obsm['X_pca_orig'] = adata_combined.obsm['X_pca']
adata_combined.obsm['X_pca'] = adata_combined.obsm['X_pca_harmony']

In [None]:
sc.pl.pca(adata_combined, color=['data_batch', 'batch','cond'])

In [None]:
# Neighbour graph for the combined data with scanpy's default settings.
# The BBKNN alternative is commented out, so the `bbknn` import is unused in
# this cell.
import bbknn
#bbknn.bbknn(adata_combined, 'data_batch')
sc.pp.neighbors(adata_combined)

In [None]:
sc.tl.umap(adata_combined)

In [None]:
adata_combined

In [None]:
sc.pl.umap(adata_combined, color=['age','cond',],size=1, cmap=plt.cm.Reds, use_raw=False)

In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.cond=='ctrl'], color=['age',],size=1, cmap=plt.cm.Reds, use_raw=False)

In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.cond=='lps'], color=['age',],size=1, cmap=plt.cm.Reds, use_raw=False)

In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.age=='90wk'], color=['age','cond','Cxcl10','Il33'],size=2, cmap=plt.cm.Reds, use_raw=False)

In [None]:
sc.tl.score_genes(adata_combined, gene_list=['B2m','Trem2', 'Ccl2', 'Apoe',  'Axl', 'Itgax', 'Cd9','C1qa','C1qc','Lyz2','Ctss'], score_name='activate_micro', use_raw=False)
sc.tl.score_genes(adata_combined, gene_list=['C4b', 'C3', 'Serpina3n', 'Cxcl10', 'Gfap', 'Vim', 'Il18','Hif3a'], score_name='activate_astro', use_raw=False)


In [None]:
sc.pl.umap(adata_combined[np.logical_and(adata_combined.obs.cond=='ctrl',adata_combined.obs.age=='90wk')], color=['age','cond','activate_micro','activate_astro', 'C3','tcell','Tnf'],size=5, cmap=plt.cm.Reds, use_raw=False)

In [None]:
sc.pl.umap(adata_combined[np.logical_and(adata_combined.obs.cond=='lps',adata_combined.obs.age=='24wk')], color=['age','cond','activate_micro','activate_astro','tcell'],size=5, cmap=plt.cm.Reds, use_raw=True)

In [None]:
# Copy the existing annotations onto the combined object. The combined obs
# names are matched after dropping their last two characters; cells with no
# match in `adata_labeled` are marked "Unlabeled".
clust_annots = dict(adata_labeled.obs.clust_annot.items())
cell_types = dict(adata_labeled.obs.cell_type.items())
stripped_names = [name[:-2] for name in adata_combined.obs.index]
adata_combined.obs['clust_annot'] = [clust_annots.get(name, "Unlabeled") for name in stripped_names]
adata_combined.obs['cell_type'] = [cell_types.get(name, "Unlabeled") for name in stripped_names]

In [None]:
# train classifier in PCA space to transfer labels
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier

In [None]:
# Train the label-transfer classifiers on the control cells only, using the
# `X_pca` embedding as features and the existing annotations as targets.
adata_ctrl = adata_combined[adata_combined.obs.cond == "ctrl"]
print(adata_ctrl.obs.clust_annot.unique())
# A fixed seed makes the randomly initialised MLPs — and therefore the labels
# transferred to the LPS cells — reproducible across re-runs, in line with the
# seeded KMeans calls elsewhere in this notebook.
mdl_cell_type = MLPClassifier(random_state=42).fit(adata_ctrl.obsm["X_pca"], adata_ctrl.obs.cell_type)
mdl_clust_annot = MLPClassifier(random_state=42).fit(adata_ctrl.obsm["X_pca"], adata_ctrl.obs.clust_annot)

In [None]:
clust_preds_proba = mdl_clust_annot.predict_proba(adata_combined[adata_combined.obs.cond=="lps"].obsm['X_pca'])
#clust_preds = mdl_clust_annot.predict(adata_combined.obsm['X_pca'])

In [None]:
cell_type_preds_proba = mdl_cell_type.predict_proba(adata_combined[adata_combined.obs.cond=="lps"].obsm['X_pca'])
#cell_type_preds = mdl_cell_type.predict(adata_combined.obsm['X_pca'])

In [None]:
adata_combined.obs['cell_type_preds'] = ['Unlabeled']*adata_combined.shape[0]
adata_combined.obs['clust_annot_preds'] = ['Unlabeled']*adata_combined.shape[0]

In [None]:
adata_combined.obs.loc[adata_combined.obs.cond == "ctrl",'cell_type_preds'] = adata_combined.obs.loc[adata_combined.obs.cond == "ctrl",'cell_type']
adata_combined.obs.loc[adata_combined.obs.cond == "ctrl",'clust_annot_preds'] = adata_combined.obs.loc[adata_combined.obs.cond == "ctrl",'clust_annot']


In [None]:
# Fill in the classifier output for the LPS cells: the most probable class
# per cell and the probability of that class.
lps_rows = adata_combined.obs.cond == "lps"
best_cell_type = np.argmax(cell_type_preds_proba, 1)
best_clust = np.argmax(clust_preds_proba, 1)
adata_combined.obs.loc[lps_rows, 'cell_type_preds'] = [mdl_cell_type.classes_[k] for k in best_cell_type]
adata_combined.obs.loc[lps_rows, 'clust_annot_preds'] = [mdl_clust_annot.classes_[k] for k in best_clust]
adata_combined.obs.loc[lps_rows, 'cell_type_preds_prob'] = cell_type_preds_proba.max(axis=1)
adata_combined.obs.loc[lps_rows, 'clust_annot_preds_prob'] = clust_preds_proba.max(axis=1)


In [None]:
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata_combined,clust_key='clust_annot_preds',cell_type_key='cell_type_preds')

In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.cond=="ctrl"], color=['age','clust_annot_preds'],size=1, palette=clust_pals, use_raw=True)

In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.cond=="lps"], color=['age','clust_annot_preds'],size=1, palette=clust_pals, use_raw=True)

In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.cond=="ctrl"], color=['age','cond','cell_type_preds','cell_type_preds_prob','clust_annot_preds','clust_annot_preds_prob'],size=1, cmap=plt.cm.Reds, use_raw=True,vmin=0.8)

In [None]:
sc.pl.umap(adata_combined[adata_combined.obs.cond=="lps"], color=['age','cond','cell_type_preds','cell_type_preds_prob','clust_annot_preds','clust_annot_preds_prob'],size=1, cmap=plt.cm.Reds, use_raw=True)

In [None]:
#adata_combined.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_lps_ctrl_allages.h5ad")

# Segment neighborhoods

In [None]:
#adata_combined = ad.read_h5ad("/faststorage/brain_aging/merfish/exported/11291_merged_lps_ctrl_allages.h5ad")

In [None]:
celltype_colors, celltype_pals, label_colors, clust_pals = generate_palettes(adata_combined,clust_key='clust_annot_preds',cell_type_key='cell_type_preds')

In [None]:
# quick check of cell type numbers: percentage of each predicted cell type
# among the ctrl cells and among the lps cells.
# The two per-condition label columns are extracted once, instead of
# re-subsetting the whole AnnData four times per loop iteration.
ctrl_types = adata_combined.obs.cell_type_preds[adata_combined.obs.cond == 'ctrl']
lps_types = adata_combined.obs.cell_type_preds[adata_combined.obs.cond == 'lps']
for i in adata_combined.obs.cell_type_preds.unique():
    print(i, 100*np.sum(ctrl_types == i)/ctrl_types.shape[0],
          100*np.sum(lps_types == i)/lps_types.shape[0])

In [None]:
# deal with spatial info
coords = np.array(adata_combined.obs[["center_x", "center_y"]]).astype(np.float64)
adata_combined.obsm['spatial'] = coords

In [None]:
adata_combined.obs.data_batch.unique()

In [None]:
# assign points to slices
# Within every batch the cell positions are split into a fixed number of
# groups with k-means on the spatial coordinates; the resulting group label
# is stored per cell in `obs["slice"]`.
from sklearn.cluster import KMeans
# number of slices for each batch (keyed by the integer value of `data_batch`)
nslices = {
    0 : 1,
    1 : 2,
    2 : 2,
    3 : 3,
    4 : 3,
    5 : 3,
    6 : 3,
    7 : 3,
    8 : 3,
    9 : 4,
    10 : 3,
    11 : 2,
    12 : 2,
    # LPS
    13 : 3,
    14 : 2, 
    15 : 2,
    16 : 2,
    17 : 2,
    18 : 2,
    19 : 3
} 
# `slice_labels` is only referenced by the commented-out lines below.
slice_labels = []
adata_combined.obs["slice"] = 0
for i in list(adata_combined.obs.data_batch.unique()):
    curr_adata = adata_combined[adata_combined.obs.data_batch==str(i)]
    pos = curr_adata.obsm['spatial']
    # Seeded k-means; labels are only meaningful within the current batch.
    lbl = KMeans(random_state=42, n_clusters=nslices[int(i)]).fit_predict(pos)
    #slice_labels.extend(lbl)
    print(pos.shape, curr_adata.shape)
    # Write the labels back by obs name so they land on the right rows.
    adata_combined.obs.loc[curr_adata.obs.index, "slice"] = lbl
    
#    plt.figure()
#    plt.scatter(curr_adata.obs.center_x, curr_adata.obs.center_y, s=1, c=lbl)
#adata_annot.obs["slice"] = slice_labels

In [None]:
for i in list(adata_combined.obs.data_batch.unique()):
    curr_adata = adata_combined[adata_combined.obs.data_batch==str(i)]
    pos = curr_adata.obsm['spatial']
    plt.figure(figsize=(10,10))
    plt.title(i)
    plt.scatter(pos[:,0], pos[:,1], s=0.1, c=curr_adata.obs.slice)


In [None]:
# adjust coordinates so that each brain section is far away from others
# (a bit of a hack for neighborhood graph computation)
# Every (batch, slice) pair is shifted by n * 1e5 along both axes, where n
# counts the sections processed so far. `coords` and `index` are filled in
# the same order so they can be re-attached together afterwards.
coords = []
index = []
n = 0
for b in adata_combined.obs.data_batch.unique():
    print('--')
    curr_adata = adata_combined[adata_combined.obs.data_batch==b]
    for s in sorted(curr_adata.obs.slice.unique()):
        print(s)
        curr_slice = curr_adata[curr_adata.obs.slice==s]
        # Build the shifted coordinates as a new array instead of applying
        # `+=` to the array handed out by the AnnData view, so nothing
        # backed by `adata_combined` is modified in place.
        curr_coords = np.array(curr_slice.obsm['spatial'], dtype=np.float64) + n*1e5
        plt.figure()
        plt.scatter(curr_coords[:,0], curr_coords[:,1], s=1)
        n += 1
        coords.append(curr_coords)
        index.extend(list(curr_slice.obs.index))
#adata_combined[index,:].obsm['spatial'] = np.vstack(coords)

In [None]:
# Reorder the cells to the order in which the shifted coordinates were
# collected, then store those coordinates. The explicit copy makes the
# reordered object standalone, so assigning to `.obsm` does not have to
# convert a view into a copy implicitly.
adata_combined = adata_combined[index].copy()
adata_combined.obsm['spatial'] = np.vstack(coords)

In [None]:
x = adata_combined.obsm['spatial'][:,0]
y = adata_combined.obsm['spatial'][:,1]
plt.plot(x,y,'k.')

In [None]:
# cluster layers

In [None]:
from spatial_analysis import compute_neighborhood_stats

In [None]:
nbor_stats = compute_neighborhood_stats(adata_combined.obsm['spatial'], adata_combined.obs.clust_annot_preds,radius=100)

In [None]:
nbor_stats[np.isnan(nbor_stats)] = 0

In [None]:
from sklearn.decomposition import PCA
xform = PCA(random_state=50).fit_transform(nbor_stats)

In [None]:
# get spatial clust annots
# Cluster the PCA-transformed neighbourhood statistics into 20 groups with a
# seeded k-means; `kmeans` holds the per-cell cluster label, which is also
# stored in `obs['kmeans']`.
#adata_combined.obs["spatial_clust_annots"] = "Unlabeled"
#adata_combined.obs.loc[adata_combined.obs.cond=="ctrl", "spatial_clust_annots"] = list(adata_labeled[adata_labeled.obs.dtype=="merfish"].obs.spatial_clust_annots)
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=20, random_state=42).fit_predict(xform)
adata_combined.obs['kmeans'] = kmeans

In [None]:
# First two PCA components of the neighbourhood statistics, coloured by
# k-means label. `dict.values()` is a view, not a sequence, and np.vstack
# only accepts sequences, so materialise it as a list first.
cluster_cmap = mpl.colors.ListedColormap(np.vstack(list(label_colors.values())))
plt.scatter(xform[:,0],xform[:,1],s=1, c=kmeans, cmap=cluster_cmap)

In [None]:
# Spatial layout of one section (batch '16', slice 0) coloured by k-means label.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.data_batch=='16', adata_combined.obs.slice==0)]
pos = curr_adata.obsm['spatial']
cluster_points = plt.scatter(pos[:,0], pos[:,1],s=1, c=curr_adata.obs.kmeans, cmap=plt.cm.nipy_spectral)
# A bare `plt.legend()` finds no labelled artists here and draws nothing;
# build one legend entry per cluster value from the scatter's colour mapping.
plt.legend(*cluster_points.legend_elements(num=None), title='kmeans', fontsize=6)

In [None]:
def plot_clust(A,clust_name, ax,s=0.1,key='kmeans'):
    """Highlight the cells of one cluster on top of all cells of `A`.

    Parameters
    ----------
    A : object exposing `obsm['spatial']` (2-column positions) and `obs[key]`.
    clust_name : value of `A.obs[key]` whose cells are highlighted; also used
        as the axes title.
    ax : matplotlib axes to draw on.
    s : marker size of the highlighted cells.
    key : name of the `obs` column that holds the cluster assignment.
    """
    # Read positions from the object that was passed in; the previous version
    # read the notebook-global `curr_adata`, so positions and the cluster mask
    # could come from two different objects.
    pos = A.obsm['spatial']
    in_clust = np.asarray(A.obs[key] == clust_name)
    # All cells in gray as background, the selected cluster in red on top.
    ax.scatter(pos[:,0], pos[:,1],s=1, c='gray')
    ax.scatter(pos[in_clust,0], pos[in_clust,1],s=s, c='r')
    ax.axis('off')
    ax.set_title(clust_name)

In [None]:
# count cell types per kmeans clust
# One row per k-means cluster, one column per (sorted) predicted cluster
# annotation; only the shape of `clust_counts` is used for the ticks below.
clust_counts = np.vstack(adata_combined.obs.groupby('kmeans').apply(lambda x: [np.sum(x.clust_annot_preds==i) for i in sorted(adata_combined.obs.clust_annot_preds.unique())]).reset_index()[0])
# Mean neighbourhood-statistics vector of every k-means cluster.
clust_avgs = np.zeros((kmeans.max()+1, nbor_stats.shape[1]))
for i in sorted(np.unique(kmeans)):
    clust_avgs[i,:] = nbor_stats[kmeans==i,:].mean(0)
# z-score every column across the clusters.
# NOTE(review): `zscore` is not imported in any cell shown here; presumably it
# is provided by one of the star imports at the top — confirm.
for i in range(clust_avgs.shape[1]):
    clust_avgs[:,i] = zscore(clust_avgs[:,i])
    
    # hierarchically cluster 
from scipy.spatial.distance import pdist
import scipy.cluster.hierarchy as hc

# Correlation distance between cluster profiles, complete linkage; the
# dendrogram leaf order is used to order the heatmap rows.
D = pdist(clust_avgs,'correlation')
Z = hc.linkage(D,'complete',optimal_ordering=True)
dn = hc.dendrogram(Z)
#lbl_order = [clust_ids[c] for c in dn['leaves']]

f, ax = plt.subplots(figsize=(5,2))
ax.imshow(clust_avgs[ dn['leaves']],aspect='auto',vmin=-5,vmax=5, cmap=plt.cm.seismic)
#for i in range(clust_counts.shape[0]):
    #ax.scatter(np.arange(clust_counts.shape[1]), i*np.ones(clust_counts.shape[1]), s=0.005*clust_counts[i,:],c='k')
ax.set_xticks(np.arange(clust_counts.shape[1]));
ax.set_yticks(np.arange(clust_counts.shape[0]));
ax.set_yticklabels(dn['leaves'],fontsize=6)
# NOTE(review): the x labels assume the columns of `nbor_stats` follow the
# sorted unique values of `clust_annot_preds` — verify against
# `compute_neighborhood_stats`.
ax.set_xticklabels(sorted(adata_combined.obs.clust_annot_preds.unique()),rotation=90,fontsize=6);

In [None]:
def crosstab_spatial_clusts(A):
    """Print, for each k-means cluster, its most frequent spatial annotation.

    `A.obs` must provide the columns `kmeans` and `spatial_clust_annots`.
    One line is printed per cluster in the form `<cluster> : "<annotation>",`
    so the output can be pasted into a dict literal.
    """
    # Local import so the definition does not depend on a later cell having
    # imported pandas.
    import pandas as pd

    temp = pd.crosstab(index=A.obs.kmeans,columns=A.obs.spatial_clust_annots, normalize=True).idxmax(axis=1)
    # Series.iteritems() was removed in pandas 2.0; items() is the equivalent.
    for clust, annot in temp.items():
        print(f"{clust} : \"{annot}\",")


In [None]:
idx = [i+"-0" for i in adata_labeled[adata_labeled.obs.dtype=='merfish'].obs.index]

In [None]:
adata_combined.obs.loc[idx, "spatial_clust_annots"] = list(adata_labeled[adata_labeled.obs.dtype=='merfish'].obs.spatial_clust_annots)

In [None]:
import pandas as pd
# Restrict the cross-tabulation to cells that actually carry a spatial
# annotation. A missing annotation stored as NaN passes a plain `!= ""` test
# (NaN != "" is True), so missing values are excluded explicitly as well.
has_annot = adata_combined.obs.spatial_clust_annots.notna() & (adata_combined.obs.spatial_clust_annots != "")
crosstab_spatial_clusts(adata_combined[has_annot])

In [None]:
spatial_clust_annots = {
0 : "L5",
1 : "L6",
2 : "LatSept",
3 : "Striatum",
4 : "L2/3",
5 : "L5",
6 : "Striatum",
7 : "CC",
8 : "Pia",
9 : "CC",
10 : "L2/3",
11 : "Ventricle",
12 : "LatSept",
13 : "L6",
14 : "L6",
15 : "L6",
16 : "Pia",
17 : "Striatum",
18 : "CC",
19 : "L5",
    
}
spatial_clust_annots_values = {
    'Pia' : 0,
    'L2/3' : 1, 
    'L5' : 2,
    'L6' : 3, 
    'LatSept' : 4,
    'CC' : 5,
    'Striatum' : 6,
    'Ventricle' : 7
    }

In [None]:
# Translate every cell's k-means label into its region name, then into the
# integer code of that region. This overwrites any existing
# `spatial_clust_annots` column; a label missing from either dict raises KeyError.
adata_combined.obs['spatial_clust_annots'] = [spatial_clust_annots[i] for i in adata_combined.obs.kmeans]
adata_combined.obs['spatial_clust_annots_value'] = [spatial_clust_annots_values[i] for i in adata_combined.obs.spatial_clust_annots]

In [None]:
# One section (batch '19', slice 1): draw one panel per region code,
# highlighting the cells assigned to that region.
curr_adata = adata_combined[np.logical_and(adata_combined.obs.data_batch=='19', adata_combined.obs.slice==1)]

plt.figure(figsize=(20,20))
# Panels are placed on a 4 x 5 grid, one per code from 0 to the largest code
# present in this section.
for i in range(curr_adata.obs.spatial_clust_annots_value.max()+1):
    ax = plt.subplot(4,5,i+1)
    plot_clust(curr_adata,i,ax,key='spatial_clust_annots_value')

In [None]:
adata_combined.write_h5ad("/faststorage/brain_aging/merfish/exported/011722_merged_lps_ctrl_allages.h5ad")

In [None]:
sc.pl.dotplot(adata_combined[np.logical_and(adata_combined.obs.age=="90wk", adata_combined.obs.cond=="ctrl")],["Cxcl10","Tnf","Il1b","Il6","Ifng","C4b","C3","Gfap","Il33",'Serpina3n','Ifit3','Xdh'],groupby='cell_type_preds')

In [None]:
sc.pl.dotplot(adata_combined[np.logical_and(adata_combined.obs.age=="90wk", adata_combined.obs.cond=="lps")],["Cxcl10","Tnf","Il1b","Il6","Ifng","C4b","C3","Gfap","Il33",'Serpina3n','Ifit3','Xdh'],groupby='cell_type_preds')

In [None]:
# Spatial expression map of a single gene in one section (batch '15', slice 1).
gene_name = 'Il33'
curr_adata = adata_combined[np.logical_and(adata_combined.obs.data_batch=='15', adata_combined.obs.slice==1)]
pos = curr_adata.obsm['spatial']
print(curr_adata.obs.age.unique())
print(curr_adata.obs.cond.unique())

# X can be sparse or dense at this point (e.g. dense after scaling), and only
# sparse matrices have `.toarray()`. Accept both and flatten to one value per cell.
gene_matrix = curr_adata[:,gene_name].X
expr = (gene_matrix.toarray() if hasattr(gene_matrix, 'toarray') else np.asarray(gene_matrix)).ravel()
plt.scatter(pos[:,0], pos[:,1],s=0.1, c=expr, cmap=plt.cm.Reds,vmin=0,vmax=3)
# A bare `plt.legend()` has no labelled artists to show; a colorbar is the
# readable key for a continuous colour scale.
plt.colorbar(label=gene_name)

In [None]:
# Integer codes for the predicted labels, numbered in order of first
# appearance in the respective obs column.
clust_encoding = {
    label: code
    for code, label in enumerate(adata_combined.obs.clust_annot_preds.unique())
}
celltype_encoding = {
    label: code
    for code, label in enumerate(adata_combined.obs.cell_type_preds.unique())
}

adata_combined.obs["clust_id"] = [clust_encoding[label] for label in adata_combined.obs.clust_annot_preds]
adata_combined.obs["celltype_id"] = [celltype_encoding[label] for label in adata_combined.obs.cell_type_preds]

# Second encoding of the cluster labels, numbered by their position among the
# keys of `label_colors`; this replaces the `clust_encoding` dict built above.
clust_encoding = {label: code for code, label in enumerate(label_colors.keys())}
adata_combined.obs['clust_encoding'] = [clust_encoding[label] for label in adata_combined.obs.clust_annot_preds]

In [None]:
curr_adata = adata_combined[np.logical_and(adata_combined.obs.data_batch=='15', adata_combined.obs.slice==1)]
print(curr_adata.obs.age.unique())
curr_cmap = mpl.colors.ListedColormap([celltype_colors[i] for i in celltype_colors.keys()])

plot_clust_subset(curr_adata, ["Micro"], curr_cmap, clust_key="cell_type_preds",s=1)

# Cell-cell interactions

In [None]:
# Settings passed to `compute_celltype_interactions` in the cells below.
# NOTE(review): exact meaning and units of these three values are defined by
# `compute_celltype_interactions` — confirm there.
niter = 1000
perturb_max = 100
dist_thresh = 20
#celltypes = adata_annot.obs.remapped_cell_type.unique()
# Cell types included in the interaction analysis; the same list is passed
# to the plotting calls.
celltypes = [
    'InN',
 'ExN',
 'MSN',
 'Astro',
 'OPC',
 'Olig',
 'Epen',
 'Endo',
 'Vlmc',
 'Peri',
 'Macro',
 'Micro',
]


In [None]:
#celltypes = sorted(adata_annot.obs.cell_type.unique())
adata_lps = adata_combined[adata_combined.obs.cond=='lps']


In [None]:
# Cell-type interaction statistics for the LPS cells, one call per age group
# ('4wk', '24wk', '90wk'); all three calls share the same keyword settings.
interaction_settings = dict(niter=niter, dist_thresh=dist_thresh, perturb_max=perturb_max)

young_interactions_clust, young_pvals_clust, young_qvals_clust = compute_celltype_interactions(
    adata_lps[adata_lps.obs.age=='4wk'], 'cell_type_preds', celltypes, **interaction_settings)
med_interactions_clust, med_pvals_clust, med_qvals_clust = compute_celltype_interactions(
    adata_lps[adata_lps.obs.age=='24wk'], 'cell_type_preds', celltypes, **interaction_settings)
old_interactions_clust, old_pvals_clust, old_qvals_clust = compute_celltype_interactions(
    adata_lps[adata_lps.obs.age=='90wk'], 'cell_type_preds', celltypes, **interaction_settings)


In [None]:
sns.set_style('white')

In [None]:
# Recompute the q-values from copies of the p-values; this replaces the
# q-values returned by `compute_celltype_interactions` above.
young_qvals_clust = fdr_correct(young_pvals_clust.copy())
med_qvals_clust = fdr_correct(med_pvals_clust.copy())
old_qvals_clust = fdr_correct(old_pvals_clust.copy())


In [None]:
young_qvals_clust[np.isnan(young_qvals_clust)] = 1
med_qvals_clust[np.isnan(med_qvals_clust)] = 1
old_qvals_clust[np.isnan(old_qvals_clust)] = 1

In [None]:
f = plot_interactions(young_qvals_clust, young_interactions_clust, celltypes,celltype_colors,cmap=plt.cm.seismic,vmax=1.5, vmin=-1.5)


In [None]:
f = plot_interactions(old_qvals_clust, old_interactions_clust, celltypes,celltype_colors,cmap=plt.cm.seismic,vmax=1.5, vmin=-1.5)
