# Tonotopy Analysis

We would like to know whether it is feasible to assign a cell to a tonotopic region based on its RNA expression.
[Waldhaus et al., 2015](https://www.sciencedirect.com/science/article/pii/S2211124715004945?via%3Dihub) have shown that different cell types seem to have differential expression of some genes according to their position in the Apex to Base axis.
Our first attempt at this determination is going to be:

1. Choose a cell type
2. Take the most DE genes across Apex vs Base
3. Plot pair plots showing how the expression of these genes changes along the Apex to Base axis

In [None]:
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import seaborn as sns

Remember to download the data from [Waldhaus et al., 2015](https://www.cell.com/cms/10.1016/j.celrep.2015.04.062/attachment/7d19c04f-89c9-4f8d-8c7d-670de495351a/mmc2.xls) to ```../data/Waldhaus_2015```

In [None]:
# Folder holding the input data, relative to the notebook's working directory.
DATA_DIR = Path("../data")
# Location of the Waldhaus et al. (2015) master table (a file, despite the "_dir" suffix).
master_table_dir = DATA_DIR.joinpath("Waldhaus_2015/mmc2.xls")

In [None]:
# Load the master table; the first two spreadsheet rows are skipped before the header is read.
df = pd.read_excel(io=master_table_dir, skiprows=2)
# Bare expression so the notebook displays the loaded table.
df

We can now make a dictionary of all the genes of interest per region. This list includes genes that are differentially expressed between Middle region and Apex or Base, but not necessarily between Apex and Base.

In [None]:
# Genes of interest per cell type.
# Keys are matched against the `Subpopulation` column of the master table and the
# listed gene names are used as column names of that same table.
# Entries that are commented out were not found in the master table.
genes_of_interest = {
    "GER": [
        "Gapdh",
        "Egfp",
        "Acvr2b",
        "Cacna1d",
        "Cdh2",
        "Cdk7",
        "Clrn1",
        "Dach1",
        "Dkk3",
        "Egfr",
        "Fzd6",
        "Gata3",
        "Gli3",
        "Gpr98",
        "Hes5",
        "Hmga2",
        "Jag1",
        "Lamb1",
        # "Ncam",
        "Ntf3",
        "Sox2"
        ],
    "IHC": [
        "Actb",
        "Bmpr2",
        "Calb1",
        "Calb2",
        "Calm1",
        "Cdkn1a",
        "Csnk1e",
        "Dach1",
        "Dkk3",
        "Dner",
        "Eya2",
        # "Fgfr1",
        "Fzd3",
        "Gata3",
        "Grb10",
        "Isl1",
        "Lfng",
        "Ntf3",
        "Otof",
        "Pax2",
        "Pdgfra",
        "Ptprq",
        "Pvalb",
        "Sox2",
        "Sox9",
        ],
    "IPH": [
        "Cdk6",
        "Fbxo2",
        "Fgf10",
        "Fos",
        "Inhba",
        ],
    "IPC": [
        "Actb",
        "Calb2",
        "Ccnd1",
        "Cdh1",
        "Fbxo2",
        "Fgf10",
        "Gjb1",
        "Hmga2",
        "Inhba",
        "Itpr2",
        "Kcnj10",
        "Lgr5",
        "Pax2",
        "Pdgfra",
        ],
    "OPC": [
        "Actb",
        "Ai14-tdTomato",
        "Acvr2b",
        "Akt1",
        "Aqp4",
        "Bmp2",
        "Calm1",
        "Cdh2",
        "Cdkn1a",
        "Dach1",
        "Dkk3",
        "Efna1",
        "Epha7",
        "Espn",
        "Etv4",
        "Etv5",
        "Fbxo2",
        "Fgf10",
        # "Fgfr1",
        # "Fgfr3",
        "Fos",
        "Fst",
        "Fzd3",
        "Gdf10",
        "Gjb1",
        "Gli2",
        "Gli3",
        "Gpr98",
        "Hes5",
        "Hmga2",
        "Inhba",
        "Itpr2",
        "Kcnj10",
        "Lgr5",
        "Myo6",
        "Myo7a",
        "Nes",
        "Otol1",
        "Pdgfra",
        "Ptc1",
        "Six4",
        "Smo",
        "Spry1",
        "Tcf7l1",
        "Tgfbr2",
        "Vegfc",
        "Wnt2b",
        "Wnt7a",
        "Wnt7b",
    ],
    "OHC": [
        "Actb",
        "Ai14-tdTomato",
        "Acvr2b",
        "Akt1",
        "Bmpr2",
        "Cacna1d",
        "Calb1",
        "Calb2",
        "Calm1",
        "Ccne1",
        "Cdh1",
        "Cdkn1a",
        "Cdkn2d",
        "Chrd",
        "Clrn1",
        "Csnk1e",
        "Dach1",
        "Dkk3",
        "Dll1",
        "Dnmt1",
        "Dnmt3a",
        "Etv5",
        "Eya1",
        "Eya2",
        "Fbxo2",
        # "Fgfr1",
        # "Fgfr3",
        "Fgfr4",
        "Fos",
        "Fst",
        "Fzd3",
        "Gata3",
        "Grb10",
        "Hey1",
        "Hmga2",
        "Isl1",
        "Jag2",
        "Lfng",
        "Lgr5",
        "Mcam",
        "Myo15",
        "Myo6",
        "Myo7a",
        "Ntf3",
        "Ocm",
        "Otof",
        "Pax2",
        "Pcdh15",
        "Pcgf2",
        "Pdgfra",
        "Prox1",
        "Ptc1",
        "Ptprq",
        "Six4",
        "Slc26a5",
        "Sox2",
        "Sox9",
        "Spna2",
        "Spry2",
        "Vegfa",
        "Wnt2b",
        "Wnt7a",
    ],
    "DC12": [
        "Cdh1",
        "Csnk1e",
        "Egfr",
        "Gjb1",
        "Hmga2",
        "Inhba",
        "Isl1",
        "Itpr2",
        "Kcnj10",
        "Lamb1",
        "Map2k1",
        "Mcam",
        "Myo6",
        "Otor",
        "Tcf7l1",
    ],
    "DC3": [
        "Cdh1",
        "Cdk6",
        "Fgf10",
        "Fgfr4",
        "Hmga2",
        "Inhba",
        "Kcnj10",
        "Myo7a",
        "Tcf7l1",
    ]
}

Now that we have the genes of interest per cell type, we can make pairplots and try to see if gene expression correlates with tonotopy in some clear way.

Commented out genes were not found in the master table. Careful, because it takes a while to generate the following plots.

In [None]:
# For every cell type, draw a pair plot of its genes of interest coloured by
# anatomical origin and save it as a PNG.
for cell_type, cell_type_de_genes in genes_of_interest.items():
    print(cell_type)
    cols = ["Anatomical Origin"] + cell_type_de_genes
    selected_df = df.query(f"Subpopulation == '{cell_type}'")[cols]

    sns.pairplot(
        data=selected_df,
        hue="Anatomical Origin",
        plot_kws={"rasterized": True},
        diag_kws={"rasterized": True},
    )
    # pairplot creates a grid of axes: plt.title would only label the current
    # axes (a single panel), so the title is set on the whole figure instead.
    plt.suptitle(cell_type)
    plt.savefig(DATA_DIR.parent.parent / f"results/waldhaus_2015_analysis/{cell_type}.png", format="png")
    # Close the figure so that open figures do not pile up across iterations.
    plt.close()

We should save a dictionary showing, for each cell type, in which region each gene shows its highest expression.
We will use this dictionary later to name the groups assigned in our dataset and to corroborate that each of them corresponds to one of these regions, with the same pattern of expression repeated in both datasets.
We will have to skip the cells assigned to Middle, as they cannot be assigned to either Apex or Base.

In [None]:
# For each cell type, store a Series (indexed by gene) holding the anatomical
# origin whose mean expression is the highest for that gene.
waldhaus_high_expression_dict = {}

for cell_type, cell_type_de_genes in genes_of_interest.items():
    print(cell_type)

    # Keep cells with a known origin that belong to this cell type.
    with_origin = df.dropna(subset="Anatomical Origin")
    of_this_type = with_origin[with_origin["Subpopulation"] == cell_type]
    expression = of_this_type[["Anatomical Origin", *cell_type_de_genes]]

    # Cells whose origin mentions "Middle" are left out of the comparison.
    is_middle = expression["Anatomical Origin"].str.contains("Middle")
    region_means = expression[~is_middle].groupby("Anatomical Origin").mean()

    waldhaus_high_expression_dict[cell_type] = region_means.idxmax()
    


# Load our dataset

We can load expression of landmark genes in our particular cell types and look into separating them with a Gaussian Mixture Model.
This will yield two separate populations of cells.
We can then assess which genes have higher expression in each subpopulation and then assign Apex or Base to each class.
Finally, we can obtain the probability of belonging to each of these classes and use this to classify cells.

After looking into gene expression in our datasets, we noticed that gene expression of OHC depended on collection date so we kept only datasets collected in P1.
We also noticed that some OHC cells showed a very low expression of Calm1 (<4.5), which we also filtered out. Something similar was happening with IHC cells, which expressed low Pvalb (<2).

In [None]:
# Load one sheet per cell type and keep only the cells that pass the filters.
datasets = dict()
sheets_dict = pd.read_excel(DATA_DIR / "expression_tonogenes_trex.xlsx", sheet_name=None)

for this_cell_type, this_df in sheets_dict.items():
    print(this_cell_type)

    # The first column holds the cell identifier; its last "_"-separated token
    # is parsed as the dataset number.
    this_df["dataset"] = this_df[this_df.columns[0]].apply(lambda x: int(x.split("_")[-1]))
    # Keep datasets 3, 4 and 5 only (presumably the P1 collections mentioned in
    # the text above - confirm).
    this_df = this_df[this_df.dataset.isin([3, 4, 5])]

    # Drop cells with very low expression of the marker gene of their type.
    if this_cell_type == "OHC":
        this_df = this_df.query("Calm1 > 4.5")
    if this_cell_type == "IHC":
        this_df = this_df.query("Pvalb > 2")

    # Store an independent copy: later cells add columns to these frames in
    # place, which is ambiguous (SettingWithCopyWarning) on a filtered slice.
    datasets[this_cell_type] = this_df.copy()
datasets

# Cluster cells into Apex and Base

The goal is to use Gaussian Mixture Models to segregate between Apex and Base for each cell type.
We will use the gene expression found in Waldhaus to initiate the GMMs near the expected expression pattern for each gene from each region for each cell type.
Briefly, for each cell type, we will look into each gene and find the minimum and maximum expression level in our dataset.
We will initialize the GMM with the means so that one group should start near what would be expected for an Apex cell expressing the maximal values and then the analogous for a Base cell.
Once GMM converges for each cell type, we will assess that the expression distribution of each cell type matches the Waldhaus paper.

In [None]:
from sklearn.mixture import GaussianMixture

We made a list of all the genes that have a clear differential expression in Waldhaus's dataset and also tagged the region in which the highest expression was happening.

In [None]:
# Landmark genes per cell type: genes with a clear differential expression in
# Waldhaus' dataset, tagged with the region ("Apex" or "Base") in which their
# expression is the highest. Gene names must match the column names used in the
# datasets, otherwise a gene is silently never flagged as a landmark.
high_expression_gene_dict = {
    "GER": {
        "Dkk3": "Base",
        "Jag1": "Apex",
    },
    "IPH": {
        "Fbxo2": "Apex",
        "Fos": "Base",
    },
    "IHC": {
        "Actb": "Apex",
        "Calb2": "Base",
        "Dach1": "Apex",
        "Gata3": "Base",
        "Fzd3": "Apex",
        "Isl1": "Apex",
        "Ntf3": "Apex",
        "Otof": "Base",
        "Pax2": "Apex",
        "Ptprq": "Base",
        "Pvalb": "Base",
        "Sox2": "Apex",
    },
    "IPC": {
        "Ccnd1": "Apex",
        "Hmga2": "Apex",
        "Inhba": "Base",
        "Kcnj10": "Base",
    },
    "OPC": {
        "Actb": "Apex",
        "Fbxo2": "Apex",
        "Hes5": "Apex",
        "Itpr2": "Base",
        "Tcf7l1": "Base",
    },
    "OHC": {
        "Actb": "Apex",
        "Cacna1d": "Base",
        "Calb1": "Base",
        "Calb2": "Base",
        # Was misspelled "Clm1", which matches no gene of the panel.
        "Calm1": "Base",
        "Eya1": "Apex",
        "Myo6": "Base",
        "Otof": "Base",
        "Pcdh15": "Apex",
        "Ptprq": "Base",
    },
    "DC12": {
        "Inhba": "Base",
        "Map2k1": "Apex",
    },
    "DC3": {
        # Was misspelled "Fg10", which matches no gene of the panel.
        "Fgf10": "Apex",
        "Hmga2": "Apex",
        "Inhba": "Base",
        "Tcf7l1": "Apex",
    },
}

In [None]:
# Fit one two-component Gaussian Mixture Model per cell type, starting each
# component at the expression extremes expected from Waldhaus' data.
gmms = dict()

# Genes used to fit the GMM of each cell type. Cell types that are not listed
# use every gene column of their sheet (IPH only has the 4 dark green genes).
gene_list_overrides = {
    "GER": ["Dach1", "Dkk3"],
    "IHC": ["Actb", "Calb1", "Isl1", "Otof", "Pvalb"],
    "IPC": ["Fbxo2", "Hmga2", "Inhba", "Itpr2", "Kcnj10"],
    "OPC": ["Fzd3", "Inhba", "Itpr2", "Tcf7l1"],
}
# Genes removed from the "every gene column" default of a cell type.
excluded_genes = {"OHC": {"Fbxo2", "Gata3"}}
# Columns that never hold gene expression. "predicted" (and the "proba_*"
# columns filtered below) are added by this notebook, so they must be skipped
# for this cell to give the same result when it is run again.
non_gene_columns = {"cellID", "dataset", "predicted"}

for this_cell_type, this_df in datasets.items():
    print(f"Processing cell type: {this_cell_type}")

    if this_cell_type in gene_list_overrides:
        this_gene_list = gene_list_overrides[this_cell_type]
    else:
        skipped = non_gene_columns | excluded_genes.get(this_cell_type, set())
        this_gene_list = [
            column
            for column in this_df.columns
            if column not in skipped and not str(column).startswith("proba_")
        ]

    print("We will use the following genes to fit GMMs")
    print(this_gene_list)

    # Row 0 holds the per-gene minimum and row 1 the per-gene maximum. String
    # names are used because passing the builtins min/max to agg is deprecated.
    gene_min_max = this_df[this_gene_list].agg(["min", "max"]).to_numpy(copy=True)

    # Genes whose highest Waldhaus expression is not in Apex get their two
    # values swapped, so component 1 starts at the maximum of the Apex-high
    # genes and at the minimum of the others (and component 0 at the opposite).
    should_swap = (waldhaus_high_expression_dict[this_cell_type][this_gene_list] != "Apex").to_numpy()
    gene_min_max[:, should_swap] = gene_min_max[::-1, should_swap]

    gmms[this_cell_type] = GaussianMixture(n_components=2, means_init=gene_min_max, max_iter=1000, random_state=42)
    gmms[this_cell_type].fit(this_df[this_gene_list])
    if not gmms[this_cell_type].converged_:
        print("GMM has not converged")

    # Component index (0 or 1) of every cell, stored in the dataset itself.
    this_df["predicted"] = gmms[this_cell_type].predict(this_df.loc[:, this_gene_list])

    sns.pairplot(this_df, hue="predicted")
    plt.suptitle(this_cell_type)
    plt.savefig(DATA_DIR.parent / f"results/gmm_initialized_{this_cell_type}.png", format="png")
    plt.close()

Once we have GMMs trained to predict which region the different cell types belong to, we have to assess whether the gene expression of these predictions matches the gene pattern observed in Waldhaus' datasets.

Let's make a function to parse the group number into the most likely region.
We can then check which genes seem to not follow the pattern and assess if these clearly had a differential gene expression or if they could be mislabeled.

In [None]:
def get_group_dict(assignment_counts):
    """Map the two GMM component numbers (0 and 1) to "Apex" and "Base".

    The most frequent ``(component, region)`` pair decides the mapping: that
    component is given that region, and the other component is given the
    other region.

    Parameters
    ----------
    assignment_counts : collections.Counter
        Counts of ``(component, region)`` tuples, where ``component`` is 0 or 1
        and ``region`` is ``"Apex"`` or ``"Base"``.

    Returns
    -------
    dict
        ``{component: region}`` with one entry per component.

    Raises
    ------
    ValueError
        If ``assignment_counts`` is empty, or if the most frequent pair holds
        a component other than 0/1 or a region other than "Base"/"Apex".
    """
    if not assignment_counts:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError("No assignments to derive the groups from")

    (number, region), _count = assignment_counts.most_common(1)[0]

    if number not in [0, 1]:
        raise ValueError("Number not in 0 or 1")
    if region not in ["Base", "Apex"]:
        raise ValueError("Region not in Base or Apex")

    other_number = 0 if number == 1 else 1
    other_region = "Apex" if region == "Base" else "Base"

    return {number: region, other_number: other_region}

In [None]:
# For each cell type, store a Series (indexed by the genes the GMM was fitted
# on) with the number of the component whose fitted mean is the highest.
detected_high_expression_dict = {}

for this_cell_type, this_gmm in gmms.items():
    print(this_cell_type)
    # argmax along axis 0 compares the component means gene by gene.
    top_component = this_gmm.means_.argmax(axis=0)
    detected_high_expression_dict[this_cell_type] = pd.Series(
        top_component,
        index=this_gmm.feature_names_in_,
    )

detected_high_expression_dict

In [None]:
# Decide which GMM component is Apex and which is Base for each cell type, then
# print a per-gene comparison of the detected and the Waldhaus regions.
group_dicts = {}

for this_cell_type, this_gmm in gmms.items():
    print(this_cell_type)

    detected_components = detected_high_expression_dict[this_cell_type]
    this_gene_list = detected_components.index
    waldhaus_regions = waldhaus_high_expression_dict[this_cell_type][this_gene_list]

    # Count how often each (component, region) pair occurs across the genes.
    assignment_counts = Counter(zip(detected_components, waldhaus_regions))
    component_to_region = get_group_dict(assignment_counts)
    group_dicts[this_cell_type] = component_to_region

    landmarked_genes = list(high_expression_gene_dict[this_cell_type].keys())

    this_high_df = detected_components.rename("detected").to_frame()
    this_high_df = this_high_df.join(waldhaus_regions.rename("waldhaus"))
    # Translate component numbers into region names.
    this_high_df["detected"] = [component_to_region[component] for component in this_high_df["detected"]]
    # Flag the genes that were tagged as clear landmarks.
    this_high_df["landmarked_gene"] = [gene in landmarked_genes for gene in this_high_df.index]

    print(this_high_df)

From comparing these, not all genes seem to follow the same pattern in a straightforward way.
We should also take into account that not all genes show an expression different enough for us to be confident that one is higher than the other; for this reason we have tagged some genes that should be clear landmark genes.

By looking into this list, GER is the only cell type whose assignment we cannot be sure about.
For the other cell types, there are very few (2) mismatches in expression, but these are in genes that are not really differentially expressed.

Finally, let's apply the trained GMMs to predict region and also state how certain we are of belonging to such region.

In [None]:
# Label every cell with its predicted region and the probability of each region,
# then write one sheet per cell type.
with pd.ExcelWriter(DATA_DIR.parent / "results/tonotopy_initialized_prediction_2.xlsx") as writer:
    for this_cell_type, this_df in datasets.items():
        print(this_cell_type)

        this_gene_list = detected_high_expression_dict[this_cell_type].index

        this_gmm = gmms[this_cell_type]
        this_group_dict = group_dicts[this_cell_type]

        features = this_df.loc[:, this_gene_list]

        # Plain column assignment replaces the whole column. Writing the region
        # names through .loc[:, "predicted"] would try to store strings inside
        # the existing integer "predicted" column, which pandas deprecates.
        this_df["predicted"] = [this_group_dict[x] for x in this_gmm.predict(features)]
        # predict_proba returns one column per component, in component order,
        # so the column names follow the same 0, 1 order.
        this_df[[f"proba_{this_group_dict[0]}", f"proba_{this_group_dict[1]}"]] = this_gmm.predict_proba(features)

        this_df.to_excel(excel_writer=writer, sheet_name=this_cell_type)

# Plots

Having updated dataframes with all the predictions and probabilities, we can plot:

1. A pair plot separated by regions, so that we can visually assess expression of each of the genes.

2. A pair plot where probability is color coded, so that we can see how fast probability changes across expression of one of the genes.

In [None]:
# One pair plot per cell type, coloured by the predicted region.
for this_cell_type, this_df in datasets.items():
    print(this_cell_type)
    figure_path = DATA_DIR.parent.parent / f"results/waldhaus_2015_analysis/initialized_prediction_{this_cell_type}.png"

    sns.pairplot(data=this_df, hue="predicted")
    plt.suptitle(this_cell_type)
    plt.savefig(figure_path, format="png")
    # Close the figure before moving on to the next cell type.
    plt.close()

In [None]:
# One pair plot per cell type, coloured by the probability of being Apex.
for this_cell_type, this_df in datasets.items():
    print(this_cell_type)
    figure_path = DATA_DIR.parent.parent / f"results/waldhaus_2015_analysis/proba_initialized_prediction_{this_cell_type}.png"

    sns.pairplot(data=this_df, hue="proba_Apex")
    plt.suptitle(this_cell_type)
    plt.savefig(figure_path, format="png")
    # Close the figure before moving on to the next cell type.
    plt.close()

The question arose whether we could filter out the cells whose region we were unsure about and obtain a clearly segregated gene expression distribution for each cell type.

In [None]:
# Minimum probability (for either region) a cell needs in order to be plotted.
probability_threshold = 0.9999

for this_cell_type, this_df in datasets.items():
    print(this_cell_type)
    figure_path = DATA_DIR.parent.parent / f"results/waldhaus_2015_analysis/certain_proba_initialized_prediction_{this_cell_type}.png"

    # Keep only the cells confidently assigned to one of the two regions.
    confident_cells = this_df.query(f"proba_Apex > {probability_threshold} or proba_Base > {probability_threshold}")

    sns.pairplot(data=confident_cells, hue="proba_Apex", plot_kws={"alpha": 0.6})
    plt.suptitle(this_cell_type)
    plt.savefig(figure_path, format="png")
    plt.close()

# Validation

To validate our pipeline, we can train GMMs on Waldhaus' datasets, predict which region each cell belongs to and assess our accuracy.

In [None]:
# Fit the same kind of GMMs on Waldhaus' cells (whose region is known), using
# for each cell type the genes the GMM of our own dataset was fitted on.
waldhaus_ggms = dict()
waldhaus_results = dict()

for this_cell_type, other_gmm in gmms.items():
    print(this_cell_type)
    this_gene_list = other_gmm.feature_names_in_

    cols = ["Anatomical Origin"] + this_gene_list.tolist()
    this_df = df.query(f"Subpopulation == '{this_cell_type}'")[cols]
    this_df = this_df.dropna(subset="Anatomical Origin")
    this_df = this_df.query("`Anatomical Origin` != 'Middle'")

    if this_cell_type == "OHC":
        # The result of this filter used to be discarded, so it never applied.
        # NOTE(review): results for OHC change now that it takes effect.
        this_df = this_df.query("Cacna1d > 2")

    # Work on an independent copy, since a column is added to it below.
    this_df = this_df.copy()

    # Row 0 holds the per-gene minimum and row 1 the per-gene maximum. String
    # names are used because passing the builtins min/max to agg is deprecated.
    gene_min_max = this_df[this_gene_list].agg(["min", "max"]).to_numpy(copy=True)

    # Genes whose highest Waldhaus expression is not in Apex get their two
    # values swapped, so component 1 starts at the maximum of the Apex-high
    # genes and at the minimum of the others (and component 0 at the opposite).
    should_swap = (waldhaus_high_expression_dict[this_cell_type][this_gene_list] != "Apex").to_numpy()
    gene_min_max[:, should_swap] = gene_min_max[::-1, should_swap]

    waldhaus_ggms[this_cell_type] = GaussianMixture(n_components=2, means_init=gene_min_max, max_iter=1000, random_state=42)
    waldhaus_ggms[this_cell_type].fit(this_df[this_gene_list])
    if not waldhaus_ggms[this_cell_type].converged_:
        print("GMM has not converged")

    # Component index (0 or 1) of every cell.
    this_df["predicted"] = waldhaus_ggms[this_cell_type].predict(this_df.loc[:, this_gene_list])
    waldhaus_results[this_cell_type] = this_df


In [None]:
# Decide which component of each Waldhaus-fitted GMM is Apex and which is Base.
waldhaus_group_dicts = {}

for this_cell_type, this_gmm in waldhaus_ggms.items():
    print(this_cell_type)

    this_gene_list = this_gmm.feature_names_in_
    # Component with the highest fitted mean, gene by gene.
    high_expression_list = pd.Series(this_gmm.means_.argmax(axis=0), index=this_gene_list)
    waldhaus_regions = waldhaus_high_expression_dict[this_cell_type][this_gene_list]

    # Count how often each (component, region) pair occurs across the genes.
    assignment_counts = Counter(zip(high_expression_list, waldhaus_regions))
    waldhaus_group_dicts[this_cell_type] = get_group_dict(assignment_counts)

In [None]:
# Compare, for each cell type, the known anatomical origin of Waldhaus' cells
# with the region predicted by the GMM fitted on them.
for this_cell_type, this_result in waldhaus_results.items():
    print(this_cell_type)

    # Use the component -> region mapping derived from the Waldhaus-fitted GMMs.
    # The mapping in group_dicts belongs to the GMMs fitted on our own dataset,
    # whose component numbering need not agree with these models.
    this_group_dict = waldhaus_group_dicts[this_cell_type]
    this_gmm = waldhaus_ggms[this_cell_type]
    this_gene_list = this_gmm.feature_names_in_

    # Plain column assignment replaces the integer "predicted" column; writing
    # region names through .loc[:, "predicted"] would try to store strings in
    # that integer column, which pandas deprecates.
    this_result["predicted"] = [this_group_dict[x] for x in this_gmm.predict(this_result.loc[:, this_gene_list])]
    print(pd.crosstab(this_result["Anatomical Origin"], this_result["predicted"]))
    print("=========================================")

In [None]:
# One pair plot per cell type of Waldhaus' cells, coloured by the prediction.
for this_cell_type, this_df in waldhaus_results.items():
    print(this_cell_type)
    figure_path = DATA_DIR.parent.parent / f"results/waldhaus_2015_analysis/algorithm_validation_{this_cell_type}.png"

    sns.pairplot(data=this_df, hue="predicted")
    plt.suptitle(this_cell_type)
    plt.savefig(figure_path, format="png")
    # Close the figure before moving on to the next cell type.
    plt.close()