In [1]:
import anndata as ad
import pandas as pd
import scanpy as sc
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns

In [2]:
# Load the annotated MERFISH atlas (the TH_ZI_only subset named in the file path)
merfishPath = "../data/merfish_609882_AIT17.1_annotated_TH_ZI_only_2023-02-16_00-00-00/atlas_brain_609882_AIT17_1_annotated_TH_ZI_only.h5ad"
merData = ad.read_h5ad(merfishPath)

# Keep only cells belonging to one of the neuronal divisions
neuronalDivisions = ["3 PAL-sAMY-TH-HY-MB-HB neuronal", "2 Subpallium GABAergic", "4 CBX-MOB-other neuronal"]
isNeuronal = merData.obs["division_id_label"].isin(neuronalDivisions)
merData = merData[isNeuronal]

In [3]:
# Load the post-QC single-cell sequencing data
seqPath = "../data/rnaseq_AIT17.2_2022-12-15_12-00-00/rnaseq/processed.U19_TH-EPI.postQC.h5ad"
seqData = ad.read_h5ad(seqPath)

## Explore sequencing data

In [4]:
# List every subclass label present in the sequencing data
subclassNames = seqData.obs["Level2_id_label"].unique()
for subclassName in subclassNames:
    print(subclassName)

In [4]:
def getColorDict(annData,level):
    """ Map each group of ``annData.obs[level]`` to the RGB color stored in ``annData.uns``.

    Colors are read from ``annData.uns[level + "_colors"]``. Its i-th entry belongs to the
    i-th category of ``annData.obs[level]``, which is scanpy's convention for stored palettes.
    An entry may be a plain color string or a length-1 container holding one.

    Parameters
    ----------
    annData : AnnData
        Data whose ``.obs[level]`` holds the group labels and whose ``.uns`` holds the colors.
    level : str
        Name of the ``.obs`` column to build the palette for.

    Returns
    -------
    dict
        ``{group: (r, g, b)}`` for every non-missing group present in ``annData.obs[level]``.
    """
    labels = annData.obs[level]
    storedColors = annData.uns[level + "_colors"]

    # Stored colors follow category order, not the order in which groups appear in the cells
    if isinstance(labels.dtype, pd.CategoricalDtype):
        categories = list(labels.cat.categories)
    else:
        # A plain string column gets lexically sorted categories when converted
        categories = sorted(labels.dropna().unique())
    position = {category: i for i, category in enumerate(categories)}

    colorDict = {}
    for group in labels.unique():
        if pd.isna(group):
            continue
        color = storedColors[position[group]]
        if not isinstance(color, str):
            color = color[0]  # unwrap colors stored as length-1 containers
        colorDict[group] = matplotlib.colors.to_rgb(color)
    return colorDict

In [6]:
# Color the sequencing embedding by subclass using the colors stored with the data
seqPalette = getColorDict(seqData, "Level2_id_label")
sc.pl.embedding(seqData, basis="EMBED2", color="Level2_id_label", palette=seqPalette)

In [7]:
# Keep only the cells of the Prkcd subclass
prkcdSubclasses = ["69_TH Prkcd Grin2c Glut"]
isPrkcd = seqData.obs["Level2_id_label"].isin(prkcdSubclasses)
prkcdData = seqData[isPrkcd]

In [8]:
# Prkcd cells colored by supertype
sc.pl.embedding(prkcdData, basis="EMBED2", color="supertype_id_label")

In [9]:
# Prkcd cells colored by cluster, using the stored cluster colors
prkcdPalette = getColorDict(prkcdData, "cluster_label")
sc.pl.embedding(prkcdData, basis="EMBED2", color="cluster_label", palette=prkcdPalette)

In [10]:
sc.pl.embedding(prkcdData, basis="EMBED2", color="Cbln1")  # Cbln1 within the Prkcd subclass

In [11]:
sc.pl.embedding(seqData, basis="EMBED2", color="Cbln2")  # Cbln2 across all sequenced cells

## Explore MERFISH data

In [5]:
def countTypes(annData,label):
    """ Count how many cells carry each value of ``annData.obs[label]``.

    Parameters
    ----------
    annData : AnnData
        Any object whose ``.obs[label]`` is a pandas Series of labels.
    label : str
        Name of the ``.obs`` column to count.

    Returns
    -------
    pandas.DataFrame
        One row per label value present in the column (index), with the number of cells
        in the "Count" column, sorted from most to least frequent. A missing label, if
        any, is listed with a Count of 0.
    """
    labels = annData.obs[label]
    names = labels.unique()
    # One vectorized pass over all cells instead of a Python-level comparison per name
    counts = labels.value_counts()
    countDF = pd.DataFrame({"Count": [int(counts.get(name, 0)) for name in names]}, index=names)
    return countDF.sort_values(by="Count", ascending=False)

def plotGenes(genes, section = '1198980080'):
    """ Plot spatial expression of one or more genes in a single MERFISH section.

    One panel is drawn per gene. Cells are colored with viridis from 0 up to that gene's
    maximum expression within the section, and the colorbar is labeled in expression units.

    Parameters
    ----------
    genes : str or iterable
        A single gene or several genes. Each must be selectable as ``merData[:, gene]``.
    section : str or int
        A value of ``merData.obs["section"]``. A value that is not a known section ID is
        treated as an index into ``merData.obs["section"].unique()``.

    Raises
    ------
    ValueError
        If `genes` is empty.
    Exception
        If `section` is neither a known section ID nor a valid index.
    """
    # Accept a single gene as well as any iterable of genes (list, tuple, array, ...)
    if isinstance(genes, str) or not np.iterable(genes):
        genes = [genes]
    else:
        genes = list(genes)
    if not genes:
        raise ValueError("No genes requested")

    # Functionality to call sections by name or sequentially iterate through them
    uniqueSections = merData.obs["section"].unique()
    if section not in uniqueSections:
        if int(section) < len(uniqueSections):
            section = uniqueSections[int(section)]
        else:
            raise Exception("Unrecognized section number")

    # Subset to requested section
    merSection = merData[merData.obs["section"] == section]
    xy = merSection.obsm["spatial_cirro"]

    # Panel grid: nRow rows by nCol columns
    nGenes = len(genes)
    nCol = int(np.floor(np.sqrt(nGenes)))
    nRow = int(np.ceil(nGenes / nCol))

    # figsize is (width, height), so width follows the columns and height the rows.
    # Use smaller panels when many genes are requested.
    panelWidth, panelHeight = (6.4, 4.8) if nGenes <= 4 else (3.2, 2.4)
    plt.figure(figsize=(nCol * panelWidth, nRow * panelHeight))

    for count, gene in enumerate(genes, 1):
        # .X may be dense or scipy-sparse; flatten to one value per cell
        expr = merSection[:, gene].X
        expr = np.ravel(expr.toarray() if hasattr(expr, "toarray") else np.asarray(expr))
        maxExpr = expr.max()

        plt.subplot(nRow, nCol, count)
        # vmin=0 / vmax=max normalizes to the maximal expression like before. An all-zero
        # gene maps to the bottom of the colormap instead of producing NaN colors.
        points = plt.scatter(xy[:, 0], xy[:, 1], c=expr, cmap="viridis", vmin=0, vmax=maxExpr,
                             s=15, edgecolors="black", linewidth=.2)

        # Colorbar ticks show actual expression levels
        cbar = plt.colorbar(points)
        ticks = np.linspace(0, maxExpr, 6)
        cbar.set_ticks(ticks)
        cbar.set_ticklabels([str(v) for v in np.round(ticks, 1)])

        plt.title(gene)
        plt.axis('off')

    plt.tight_layout()
        
def plotAll(level = "cluster_label",group = "1180 RE-Xi Nox4 Glut_3", color = 'cyan'):
    """ Highlight one group's cells on top of every MERFISH cell from all sections.

    Parameters
    ----------
    level : str
        Column of ``merData.obs`` that `group` is a value of.
    group : str
        Value of ``merData.obs[level]`` to highlight; also used as the title.
    color : matplotlib color
        Color of the highlighted cells. All cells are drawn gray underneath.

    Raises
    ------
    Exception
        If `group` is not a value of ``merData.obs[level]``.
    """
    levelValues = merData.obs[level]
    if group not in levelValues.unique():
        raise Exception("Requested group is not within the specified annData object .obs level.")

    allXY = merData.obsm["spatial_cirro"]
    groupXY = merData[levelValues == group].obsm["spatial_cirro"]
    pointStyle = dict(s = 15, edgecolors = "black", linewidth = .2)

    plt.figure(figsize = (6.4 * 9, 4.8 * 4))
    # Every cell in gray at zorder 0, then the requested group on top
    plt.scatter(allXY[:, 0], allXY[:, 1], color = 'gray', zorder = 0, **pointStyle)
    plt.scatter(groupXY[:, 0], groupXY[:, 1], color = color, **pointStyle)

    plt.title(group, fontsize = 48)
    plt.axis('off')
    plt.tight_layout()

In [13]:
plotGenes(genes=["Gad2", "Cbln2", "Prkcd", "Calb2"], section=8)  # section index 8

In [14]:
plotAll(group="68_CM-IAD-CL-PCN Glut", level="Level2_id_label")  # CM-IAD-CL-PCN subclass across all sections

In [15]:
# Cells from the section at index 8
sectionIds = merData.obs["section"].unique()
merSection = merData[merData.obs["section"] == sectionIds[8]]

# Ten most abundant subclasses in this section
countTypes(merSection, "Level2_id_label").head(10)

In [16]:
# CM-IAD-CL-PCN (midline) cells within the section at index 8
sectionIds = merData.obs["section"].unique()
inSection = merData.obs["section"] == sectionIds[8]
merSection = merData[inSection]
merSection = merSection[merSection.obs["Level2_id_label"] == "68_CM-IAD-CL-PCN Glut"]

# Twenty most abundant midline clusters
countTypes(merSection, "cluster_label").head(20)

In [6]:
def plotSectionClusters(subclass = "68_CM-IAD-CL-PCN Glut", sections = ('1198980086', '1198980092', '1198980095', '1198980098'), palette = "Spectral_r", subset = 0):
    """ Plot the clusters of one subclass across several MERFISH sections, one panel each.

    Clusters of `subclass` get colors spread over the seaborn palette `palette`, assigned
    in sorted cluster-label order so they are reproducible between sessions. Every other
    cell in the section is drawn black underneath.

    Parameters
    ----------
    subclass : str
        Value of ``merData.obs["Level2_id_label"]`` whose clusters are colored.
    sections : sequence of str
        Section IDs (values of ``merData.obs["section"]``) to plot.
    palette : str
        Palette name passed to ``seaborn.color_palette``.
    subset : int
        If non-zero, color only the `subset` most abundant clusters of `subclass`,
        counted over the requested sections. Its other clusters are drawn black.

    Returns
    -------
    dict
        Maps every cluster label present in the requested sections to its color: an RGB
        tuple for highlighted clusters, ``'black'`` otherwise.

    Raises
    ------
    KeyError
        If `subclass` has no cells in the requested sections.
    """
    sections = list(sections)

    # Cells from every requested section; used to decide which clusters get colors
    inSections = merData[merData.obs["section"].isin(sections)]
    clusterLabels = inSections.obs["cluster_label"].to_numpy()
    subclassLabels = inSections.obs["Level2_id_label"].to_numpy()

    # Sort the subclass's clusters: iterating a set gives a hash-dependent order that
    # changes between kernel sessions, which made the color assignment unstable
    subclassClusters = sorted(set(clusterLabels[subclassLabels == subclass]))
    if not subclassClusters:
        raise KeyError(subclass)

    # Optionally keep only the `subset` most abundant of those clusters
    if subset:
        onlySubclass = inSections[inSections.obs["cluster_label"].isin(subclassClusters)]
        toColor = sorted(countTypes(onlySubclass, "cluster_label").head(subset).index)
    else:
        toColor = subclassClusters

    # Every cluster present defaults to black; highlighted clusters get palette colors
    colorBlend = sns.color_palette(palette, n_colors=len(toColor))
    colorDict = dict.fromkeys(pd.unique(clusterLabels), 'black')
    colorDict.update(zip(toColor, colorBlend))

    # Panel grid: nRow rows by nCol columns
    nSections = len(sections)
    nCol = int(np.floor(np.sqrt(nSections)))
    nRow = int(np.ceil(nSections / nCol))

    # figsize is (width, height), so width follows the columns and height the rows.
    # Use smaller panels when many sections are requested.
    panelWidth, panelHeight = (6.4, 4.8) if nSections <= 4 else (3.2, 2.4)
    plt.figure(figsize=(nCol * panelWidth, nRow * panelHeight))

    highlighted = set(toColor)
    for count, tissue in enumerate(sections, 1):
        plt.subplot(nRow, nCol, count)

        sectionData = merData[merData.obs["section"] == tissue]
        xy = sectionData.obsm["spatial_cirro"]
        labels = sectionData.obs["cluster_label"].to_numpy()
        isHighlighted = np.array([label in highlighted for label in labels], dtype=bool)

        # Non-highlighted cells in black underneath (zorder 0); highlighted clusters on top
        plt.scatter(xy[~isHighlighted, 0], xy[~isHighlighted, 1], color='black',
                    s=15, edgecolors="black", linewidth=.2, zorder=0)
        plt.scatter(xy[isHighlighted, 0], xy[isHighlighted, 1],
                    color=[colorDict[label] for label in labels[isHighlighted]],
                    s=15, edgecolors="black", linewidth=.2)
        plt.title("Section#: " + str(tissue))
        plt.axis('off')

    plt.tight_layout()
    return colorDict

In [23]:
plotSectionClusters("68_CM-IAD-CL-PCN Glut");  # CM-IAD-CL-PCN clusters in the default sections

In [20]:
# CM-IAD-CL-PCN clusters seen in these sections, as (cluster id, label suffix)
clusterIds = [(1129, 5), (1116, 1), (1123, 3), (1118, 1), (1119, 2), (1120, 2), (1121, 2), (1128, 5),
              (1125, 4), (1117, 1), (1115, 1), (1126, 4), (1124, 3), (1127, 4), (1122, 3)]
clList = ["{} CM-IAD-CL-PCN Glut_{}".format(clusterId, suffix) for clusterId, suffix in clusterIds]

In [21]:
# Look at the sequencing data for the CM-IAD-CL-PCN clusters noted in the previous cell (clList)
ilmData = seqData[seqData.obs["Level2_id_label"] == "68_CM-IAD-CL-PCN Glut"]
ilmPalette = getColorDict(ilmData, "cluster_label")  # colors stored with the sequencing data

sc.pl.embedding(seqData, "EMBED2", color="Level2_id_label", palette=seqPalette)
sc.pl.embedding(ilmData, "EMBED2", color="cluster_label", palette=ilmPalette)

## Plot spatial arrangement of clusters

In [7]:
def plotSpatial(level = "cluster_label",group = "1180 RE-Xi Nox4 Glut_3", section = False, color = 'cyan'):
    """ Plot the spatial layout of a requested group within one MERFISH section.

    Parameters
    ----------
    level : str
        Column of ``merData.obs`` that `group` is a value of (e.g. "cluster_label").
    group : str
        Value of ``merData.obs[level]`` to highlight.
    section : False, None, str or int
        A section ID, or an index resolved by ``checkSectionNumber``. ``False`` or
        ``None`` (the default) selects the section holding the most cells of `group`.
    color : matplotlib color
        Color of the highlighted cells. Other cells in the section are drawn gray.

    Raises
    ------
    Exception
        If `group` is not a value of ``merData.obs[level]``, or the section cannot be resolved.
    """
    # Check that level and group pairings match, e.g. did not ask for a cluster within subclass .obs
    if group not in merData.obs[level].unique():
        raise Exception("Requested group is not within the specified annData object .obs level.")

    # Test for the "not requested" sentinels explicitly. section=0 is a valid index and
    # must not fall through to the automatic choice the way a truthiness test did.
    if section is False or section is None:
        section = merData[merData.obs[level] == group].obs["section"].mode()[0]
    else:
        section = checkSectionNumber(merData, section)

    # Subset to requested section, then to the requested group within it
    sectionData = merData[merData.obs["section"] == section]
    groupData = sectionData[sectionData.obs[level] == group]

    plt.figure(dpi=120)
    plt.scatter(sectionData.obsm["spatial_cirro"][:,0],sectionData.obsm["spatial_cirro"][:,1],color = 'gray', s = 15, edgecolors = "black", linewidth = .2,zorder = 0)
    plt.scatter(groupData.obsm["spatial_cirro"][:,0],groupData.obsm["spatial_cirro"][:,1],color = color, s = 15, edgecolors = "black", linewidth = .2)

    plt.title(group + ": section " + str(section))
    
def checkSectionNumber(annData,section):
    """ Resolve a section given either by its ID or by its position.

    If `section` is one of ``annData.obs["section"].unique()`` it is returned unchanged.
    Otherwise it is converted with ``int`` and used as an index into that array.

    Raises
    ------
    Exception
        If the converted index is not below the number of unique sections.
    """
    sectionIds = annData.obs["section"].unique()
    if section in sectionIds:
        return section

    position = int(section)
    if position >= len(sectionIds):
        raise Exception("Unrecognized section number")
    return sectionIds[position]

In [25]:
# Each subclass in the section at index 5, with its own highlight color
subclassHighlights = {
    "69_TH Prkcd Grin2c Glut": "cyan",
    "81_RT ZI Gnb3 Gaba": "red",
    "68_CM-IAD-CL-PCN Glut": "orange",
    "70_RE-Xi Nox4 Glut": "green",
}
for subclassName, highlight in subclassHighlights.items():
    plotSpatial(level="Level2_id_label", group=subclassName, section=5, color=highlight)

## Relate MERFISH clusters to the sequencing data

In [11]:
plotSpatial(group="70_RE-Xi Nox4 Glut", level="Level2_id_label", color="lightgreen")  # section chosen automatically

### Spatial arrangement of reuniens clusters

In [12]:
# Cluster composition of the reuniens subclass in the MERFISH data
isReuniens = merData.obs["Level2_id_label"] == "70_RE-Xi Nox4 Glut"
reData = merData[isReuniens]
reDF = countTypes(reData, "cluster_label")
reDF

In [31]:
# The nine sections holding the most reuniens cells, reordered by section ID
reSectionLabels = merData[merData.obs["Level2_id_label"] == "70_RE-Xi Nox4 Glut"].obs["section"]
reTopSections = reSectionLabels.value_counts().head(9).sort_index()

# Plot those sections (assumes section IDs increase from anterior to posterior)
reColorDict = plotSectionClusters(subclass="70_RE-Xi Nox4 Glut", sections=reTopSections.index, palette="Spectral")

In [32]:
# Show each reuniens cluster in the section where it is most abundant (plotSpatial's default),
# reusing the colors returned by plotSectionClusters in the previous cell
for clusterName in reDF.index:
    plotSpatial(level="cluster_label", group=clusterName, color=reColorDict[clusterName])

### UMAPs

In [33]:
# Subclass embedding: reuniens in black, every other subclass light gray
subclassColorDict = dict.fromkeys(seqData.obs["Level2_id_label"].unique(), 'lightgray')
subclassColorDict["70_RE-Xi Nox4 Glut"] = "black"
sc.pl.embedding(seqData, "EMBED2", color="Level2_id_label", palette=subclassColorDict, size=1)

# Reuniens clusters on the full embedding; clusters not listed in reDF become NaN
seqData.obs["RE_clusters"] = seqData.obs["cluster_label"].cat.set_categories(reDF.index)
sc.pl.embedding(seqData, "EMBED2", color="RE_clusters", palette=reColorDict, size=1)

# Reuniens cells only
reData = seqData[seqData.obs["Level2_id_label"] == "70_RE-Xi Nox4 Glut"]
sc.pl.embedding(reData, "EMBED2", color="RE_clusters", palette=reColorDict)

### Differential expression

In [34]:
# Marker genes for each reuniens cluster against the remaining reuniens cells
sc.tl.rank_genes_groups(reData, groupby="cluster_label", reference="rest")
sc.pl.rank_genes_groups(reData, n_genes=10, fontsize=16, sharey=False)

In [47]:
# Top-ranked marker gene of every reuniens cluster
rankedNames = reData.uns["rank_genes_groups"]["names"]
geneList = [rankedNames[clusterName][0] for clusterName in rankedNames.dtype.names]
geneList

In [58]:
sc.pl.stacked_violin(reData, var_names=geneList, groupby="cluster_label", figsize=[12, 6])  # top marker per cluster

In [62]:
sc.pl.correlation_matrix(reData, groupby="cluster_label", show_correlation_numbers=True)  # cluster-to-cluster correlation

### Anteroventral (AV) nucleus clusters

In [8]:
# Cluster composition of the anteroventral (AV) subclass in the MERFISH data
isAV = merData.obs["Level2_id_label"] == "66_AV Col27a1 Glut"
avData = merData[isAV]
avDF = countTypes(avData, "cluster_label")
avDF

In [14]:
# Marker genes for each AV cluster in the MERFISH data
sc.tl.rank_genes_groups(avData, groupby="cluster_label", reference="rest", pts=True)
sc.pl.rank_genes_groups(avData, n_genes=10, fontsize=16, sharey=False)

In [15]:
# Top three MERFISH markers per AV cluster, flattened into one array
avNames = avData.uns["rank_genes_groups"]["names"]
geneList = np.concatenate([avNames[clusterName][:3] for clusterName in avNames.dtype.names])
geneList

In [129]:
# MERFISH markers in the MERFISH data: stacked violins, then a dot plot
for plotter in (sc.pl.stacked_violin, sc.pl.dotplot):
    plotter(avData, var_names=geneList, groupby="cluster_label", figsize=[6, 4])

In [19]:
# Marker genes for each AV cluster in the sequencing data
isAVSeq = seqData.obs["Level2_id_label"] == "66_AV Col27a1 Glut"
avSeqData = seqData[isAVSeq]
sc.tl.rank_genes_groups(avSeqData, groupby="cluster_label", reference="rest", pts=True)
sc.pl.rank_genes_groups(avSeqData, n_genes=20, fontsize=16, sharey=False)

In [20]:
# Top three sequencing markers per AV cluster, flattened into one array
avSeqNames = avSeqData.uns["rank_genes_groups"]["names"]
seqGeneList = np.concatenate([avSeqNames[clusterName][:3] for clusterName in avSeqNames.dtype.names])
seqGeneList

In [134]:
# Sequencing markers in the sequencing data: stacked violins, then a dot plot
for plotter in (sc.pl.stacked_violin, sc.pl.dotplot):
    plotter(avSeqData, var_names=seqGeneList, groupby="cluster_label", figsize=[6, 4])

In [135]:
# MERFISH-derived markers evaluated in the sequencing data
for plotter in (sc.pl.stacked_violin, sc.pl.dotplot):
    plotter(avSeqData, var_names=geneList, groupby="cluster_label", figsize=[6, 4])

In [9]:
def plotROCs(annData, group, geneList = "Prkcd", groupby = "Level2_id_label", density = True, xRange = 8):
    """ Compare one group's expression of each gene against the whole population.

    For every gene, the left panel overlays expression histograms for all cells of
    `annData` and for the cells of `group`. The right panel shows the ROC curve of a
    single-gene logistic-regression classifier for "cell belongs to `group`".

    Parameters
    ----------
    annData : AnnData
        Cells to analyze; each gene must be selectable as ``annData[:, gene]``.
    group : str
        Value of ``annData.obs[groupby]`` treated as the positive class.
    geneList : str or iterable of str
        Gene(s) to analyze, one row of panels each.
    groupby : str
        ``.obs`` column holding the group labels.
    density : bool
        Plot normalized histograms (True) or raw counts (False).
    xRange : float
        Upper edge of the histogram bins (bins are 0.1 wide, starting at 0).

    Raises
    ------
    ValueError
        If `group` contains none or all of the cells (the classifier needs two classes).

    Notes
    -----
    The classifier is fit and scored on the same cells, so the AUC is an in-sample value.
    With a single feature the decision function is monotone in expression, so the AUC
    equals that of the raw expression (or its negation).
    """
    from sklearn.metrics import roc_curve, auc
    from sklearn.linear_model import LogisticRegression

    def _denseColumn(X):
        # .X may be a scipy sparse matrix or a dense array; return an (n_cells, 1) array
        return (X.toarray() if hasattr(X, "toarray") else np.asarray(X)).reshape(-1, 1)

    if isinstance(geneList, str):
        geneList = [geneList]
    else:
        geneList = list(geneList)

    inGroup = annData.obs[groupby].to_numpy() == group
    if inGroup.all() or not inGroup.any():
        raise ValueError("Group must contain some, but not all, cells of annData for ROC analysis.")
    groupData = annData[inGroup]

    nGenes = len(geneList)
    plt.figure(figsize=(12.8, 2.4 * nGenes))
    bins = np.arange(0, xRange, .1)

    for count, gene in enumerate(geneList, 1):
        allExpr = _denseColumn(annData[:, gene].X)
        groupExpr = _denseColumn(groupData[:, gene].X)

        ### Expression histograms: whole population vs requested group ###
        plt.subplot(nGenes, 2, count * 2 - 1)
        plt.hist(allExpr.ravel(), bins=bins, alpha=.7, color='black', label="All cells", density=density)
        plt.hist(groupExpr.ravel(), bins=bins, edgecolor='black', alpha=.7, label=group, density=density)
        if density:
            # NOTE: with 0.1-wide bins, densities above 1 are possible and get clipped here
            plt.ylim([0, 1])
        plt.legend()
        plt.xlabel('Log2p Expression')
        plt.ylabel('PDF' if density else 'Count')
        plt.title(gene)

        ### ROC of a logistic-regression classifier on this gene ###
        plt.subplot(nGenes, 2, count * 2)
        clf = LogisticRegression().fit(allExpr, inGroup)
        fpr, tpr, _ = roc_curve(inGroup, clf.decision_function(allExpr))
        roc_auc = auc(fpr, tpr)

        plt.plot(fpr, tpr, label="AUC = " + str(np.round(roc_auc, 3)))
        plt.legend(loc="lower right")
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title("ROC")

    plt.tight_layout()

In [144]:
plotROCs(annData=avSeqData, group="1097 AV Col27a1 Glut", geneList=list(geneList), groupby="cluster_label", xRange=10)  # MERFISH markers

In [145]:
plotROCs(annData=avSeqData, group="1097 AV Col27a1 Glut", geneList=list(seqGeneList), groupby="cluster_label", xRange=14)  # sequencing markers

In [78]:
plotGenes(section=3, genes=["Sorcs3", "C1ql2", "Prkcd", "C1ql3"])

## Anterodorsal (AD) nucleus clusters

In [10]:
# Marker genes for each AD cluster in the sequencing data
isADSeq = seqData.obs["Level2_id_label"] == "65_AD Serpinb7 Glut"
adSeqData = seqData[isADSeq]
sc.tl.rank_genes_groups(adSeqData, groupby="cluster_label", reference="rest", pts=True)
sc.pl.rank_genes_groups(adSeqData, n_genes=20, fontsize=16, sharey=False)

In [158]:
# ROC curves for the top three sequencing markers of every AD cluster
adSeqNames = adSeqData.uns["rank_genes_groups"]["names"]
adSeqGeneList = np.concatenate([adSeqNames[clusterName][:3] for clusterName in adSeqNames.dtype.names])
plotROCs(adSeqData, "1095 AD Serpinb7 Glut", list(adSeqGeneList), groupby="cluster_label", xRange=13)

In [212]:
# pts difference between the two AD clusters for a few genes of interest. ptDF is built
# here so this cell does not depend on the later cell that defines it; a copy is used so
# the stored rank_genes_groups results stay untouched.
ptDF = adSeqData.uns["rank_genes_groups"]["pts"].copy()
ptDF["diff"] = ptDF["1095 AD Serpinb7 Glut"] - ptDF["1096 AD Serpinb7 Glut"]
ptDF.loc[["Scn4b","Slc17a7","Nkain2","Zcchc12","Prkcd"]].sort_values("diff",ascending = False)

In [178]:
# pts difference between the two AD clusters (1095 minus 1096). Work on a copy so that
# adding the "diff" column does not modify adSeqData.uns in place.
ptDF = adSeqData.uns["rank_genes_groups"]["pts"].copy()
ptDF['diff'] = ptDF['1095 AD Serpinb7 Glut'] - ptDF['1096 AD Serpinb7 Glut']
ptDF.sort_values("diff")

In [154]:
# Marker genes for each AD cluster in the MERFISH data
isAD = merData.obs["Level2_id_label"] == "65_AD Serpinb7 Glut"
adData = merData[isAD]
sc.tl.rank_genes_groups(adData, groupby="cluster_label", reference="rest", pts=True)
sc.pl.rank_genes_groups(adData, n_genes=10, fontsize=16, sharey=False)

In [156]:
# Top three MERFISH markers per AD cluster, evaluated by ROC in the sequencing data
adNames = adData.uns["rank_genes_groups"]["names"]
adGeneList = np.concatenate([adNames[clusterName][:3] for clusterName in adNames.dtype.names])
plotROCs(adSeqData, "1095 AD Serpinb7 Glut", list(adGeneList), groupby="cluster_label", xRange=10)

In [157]:
plotGenes(section=3, genes=["Slc17a7", "C1ql2", "Zcchc12", "Scn4b"])

## Discriminate anterior nuclei

In [11]:
# Putative clusters for each anterior thalamic nucleus
nucleusClusters = {
    "Anterodorsal nucleus": ["1095 AD Serpinb7 Glut", "1096 AD Serpinb7 Glut"],
    "Anteroventral nucleus": ["1097 AV Col27a1 Glut", "1098 AV Col27a1 Glut"],
    "Anteromedial nucleus": ["1169 TH Prkcd Grin2c Glut_9", "1171 TH Prkcd Grin2c Glut_9"],
    "Lateral dorsal nucleus of thalamus": ["1153 TH Prkcd Grin2c Glut_5"],
}
anteriorNuclei = [cluster for clusters in nucleusClusters.values() for cluster in clusters]
antSeqData = seqData[seqData.obs["cluster_label"].isin(anteriorNuclei)]

# Marker genes separating these clusters from one another
sc.tl.rank_genes_groups(antSeqData, groupby="cluster_label", reference="rest", pts=True)
sc.pl.rank_genes_groups(antSeqData, n_genes=10, fontsize=16, sharey=False)

## Modality differences
Compare the two datasets, e.g. the number of cells collected for each subclass in each modality.

In [123]:
# Ratio of sequenced to MERFISH cells per subclass. A subclass with no MERFISH cells
# (e.g. because merData was restricted to neuronal divisions) gets NaN instead of raising KeyError.
seqCounts = countTypes(seqData, "Level2_id_label")["Count"]
merCounts = countTypes(merData, "Level2_id_label")["Count"]
ratios = {
    group: seqCounts.loc[group] / merCounts.loc[group] if group in merCounts.index else np.nan
    for group in seqCounts.index
}

pd.Series(ratios, name="seq / mer").to_frame().sort_values("seq / mer", ascending=False)

## Misc

In [131]:
# LGN clusters of the Prkcd subclass (section #16, ID 1198980123)
plotSectionClusters("69_TH Prkcd Grin2c Glut", sections=["1198980123"]);

In [126]:
plotGenes(section=16, genes=["Gad2", "Kirrel3", "Prkcd", "C1ql3"])