# Decoding Responsivity (imbalanced learn)
**Can we predict whether a stimulus will be detected based on the neurons' responsivity?**

Célien Vandromme
18/04/2024

---

In [None]:
from unittest import result

import numpy as np
import pandas as pd


import percephone.core.recording as pc
import percephone.plts.stats as ppt
import os
import matplotlib
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count, pool
import warnings
import seaborn as sns
import copy

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier
from scipy.stats import mannwhitneyu
import imblearn as imb

plt.rcParams['font.size'] = 10
plt.rcParams['axes.linewidth'] = 2
plt.switch_backend("Qt5Agg")
matplotlib.use("Qt5Agg")
%matplotlib inline
warnings.filterwarnings('ignore')
fontsize = 30
%config InlineBackend.figure_format = 'retina'

In [None]:
# NOTE(review): absolute, machine-specific paths; adapt them before running elsewhere.
directory = "C:/Users/cvandromme/Desktop/Data/"
roi_path = "C:/Users/cvandromme/Desktop/FmKO_ROIs&inhibitory.xlsx"
files = os.listdir(directory)
# Only the folders whose name ends with "synchro" are loaded as recordings.
files_ = [file for file in files if file.endswith("synchro")]


def opening_rec(fil, i):
    """Load one recording folder and return the recording object.

    Parameters
    ----------
    fil : str
        Name of the recording folder inside `directory`.
    i : int
        Position of the folder in the list of recordings (currently unused, kept so the
        call signature stays unchanged).

    Returns
    -------
    The `pc.RecordingAmplDet` instance built from the folder, after its
    `peak_delay_amp()` method has been called.
    """
    rec = pc.RecordingAmplDet(directory + fil + "/", 0, roi_path)
    rec.peak_delay_amp()
    return rec


workers = cpu_count()
# The pool is bound to its own name so the imported `pool` module is not shadowed
# (shadowing it made this cell fail when executed a second time), and the context manager
# releases the worker threads once every recording has been collected.
with pool.ThreadPool(processes=workers) as thread_pool:
    async_results = [thread_pool.apply_async(opening_rec, args=(file, i)) for i, file in enumerate(files_)]
    loaded_recs = [ar.get() for ar in async_results]
# Recordings keyed by their `filename` attribute, in the order of `files_`.
recs = {loaded.filename: loaded for loaded in loaded_recs}

In [None]:
# One summary row per recording: identifier, genotype, stimulus counts and neuron counts.
names = ["ID", "genotype", "stim", "stim_True", "stim_False", "neuron_exc", "neuron_inh"]
matrix = []
for rec in recs.values():
    n_stim = len(rec.detected_stim)
    n_detected = rec.detected_stim.sum()
    matrix.append([
        rec.filename,
        rec.genotype,
        n_stim,
        n_detected,
        n_stim - n_detected,
        rec.matrices["EXC"]["Responsivity"].shape[0],
        rec.matrices["INH"]["Responsivity"].shape[0],
    ])
# The rows are handed to pandas directly: going through np.array() first would coerce the
# mixed row contents to a single dtype and the count columns would no longer be numeric.
summary = pd.DataFrame(matrix, columns=names)
summary

In [None]:
def resp_heatmap(record_dict, sort=False):
    """Plot the balanced training matrix of a record as a heatmap (neurons x samples).

    Parameters
    ----------
    record_dict : dict
        Record holding "X_train", "y_train", "X_bal", "y_bal", "filename" and "genotype".
    sort : bool, default False
        When True, the original training samples are regrouped (label True first, then
        label False) and followed by the samples added by the resampler. This assumes the
        resampler appended its new samples after the original training rows.
        A dashed line then marks the True/False boundary.

    A solid vertical line is always drawn at the number of original training samples.
    The function only draws the figure and returns None.
    """
    X_train = record_dict["X_train"]
    n_train = X_train.shape[0]

    if sort:
        detected = np.asarray(record_dict["y_train"]).astype(bool)
        X_train_true = X_train[detected]
        X_train_false = X_train[~detected]
        # Everything after the original training rows was generated by the resampler.
        X_new = record_dict["X_bal"][n_train:]
        y_new = record_dict["y_bal"][n_train:]
        # np.vstack replaces np.row_stack, which is deprecated since NumPy 2.0.
        X = np.vstack((X_train_true, X_train_false, X_new))
        y = np.concatenate((np.ones(X_train_true.shape[0], dtype=bool),
                            np.zeros(X_train_false.shape[0], dtype=bool),
                            y_new))
    else:
        X = record_dict["X_bal"]
        y = record_dict["y_bal"]

    # Looked up beforehand so the f-string below needs no nested quotes
    # (same-quote nesting is a syntax error before Python 3.12).
    filename = record_dict["filename"]
    genotype = record_dict["genotype"]

    plt.figure(figsize=(12, 6))
    sns.heatmap(X.T, cmap='plasma', xticklabels=y, cbar_kws={'label': 'Responsivity'})
    plt.axvline(x=n_train, color="black", linewidth=1)
    if sort:
        plt.axvline(x=X_train_true.shape[0], color="black", linewidth=1, linestyle='dashed')
    plt.title(f"{filename}({genotype}) Neuron Responses to Stimuli")
    plt.xticks(fontsize=7)
    plt.yticks(fontsize=7)
    plt.xlabel("Detected stimulation")
    plt.ylabel("Neurons")
    plt.show()

In [None]:
def get_rec_info(rec, parameter, exc_neurons=True, inh_neurons=False):
    """Collect the decoding inputs of one recording into a dict.

    Parameters
    ----------
    rec : recording object
        Must expose `genotype`, `filename`, `threshold`, `stim_ampl`, `detected_stim`
        and `matrices[<"EXC"|"INH">][parameter]`.
    parameter : str
        Key of the matrix to read from `rec.matrices["EXC"]` / `rec.matrices["INH"]`.
    exc_neurons, inh_neurons : bool
        Which populations to include; with both, the excitatory rows come first.

    Returns
    -------
    dict
        Keys "genotype", "filename", "threshold", "stim_ampl", "y" (the detected_stim
        labels), "X" (the selected matrix, transposed) and "neurons" (a label naming
        the selected population).

    Raises
    ------
    ValueError
        If neither population is selected (the record would otherwise silently lack "X").
    """
    if not (exc_neurons or inh_neurons):
        raise ValueError("At least one of exc_neurons / inh_neurons must be True.")

    result = {
        "genotype": rec.genotype,
        "filename": rec.filename,
        "threshold": rec.threshold,
        "stim_ampl": rec.stim_ampl,
        "y": rec.detected_stim,
    }
    if exc_neurons and inh_neurons:
        # np.vstack replaces np.row_stack, which is deprecated since NumPy 2.0.
        result["X"] = np.vstack((rec.matrices["EXC"][parameter], rec.matrices["INH"][parameter])).T
        result["neurons"] = "EXC & INH"
    elif exc_neurons:
        result["X"] = rec.matrices["EXC"][parameter].T
        result["neurons"] = "EXC"
    else:
        result["X"] = rec.matrices["INH"][parameter].T
        result["neurons"] = "INH"
    return result

In [None]:
def stim_ampl_filter(record, stim_ampl="all"):
    """Keep only the stimulations of a record that match the requested amplitudes.

    Parameters
    ----------
    record : dict
        Record with "X", "y", "stim_ampl" and "threshold" entries. It is modified in
        place and also returned.
    stim_ampl : {"all", "supra", "sub"} or sequence of amplitudes
        "all" returns the record untouched; "supra" keeps the amplitudes of
        `np.arange(0, 14, 2)` that are >= the record threshold; "sub" keeps those below
        it; any non-string value is taken as an explicit collection of amplitudes.

    Returns
    -------
    dict
        The same record, with "X", "y" and "stim_ampl" restricted to the selected
        stimulations so the three entries stay aligned.

    Raises
    ------
    ValueError
        If `stim_ampl` is a string other than "all", "supra" or "sub".
    """
    all_ampl = np.arange(0, 14, 2)
    # Only strings are compared with the keywords: comparing a NumPy array with a string
    # is ambiguous inside an `if`.
    if isinstance(stim_ampl, str):
        if stim_ampl == "all":
            return record
        if stim_ampl == "supra":
            amplitudes = all_ampl[all_ampl >= record["threshold"]]
        elif stim_ampl == "sub":
            amplitudes = all_ampl[all_ampl < record["threshold"]]
        else:
            raise ValueError(f"Unknown stim_ampl {stim_ampl!r}; expected 'all', 'supra', 'sub' or a list of amplitudes.")
    else:
        amplitudes = np.asarray(stim_ampl)

    selected_stim = np.isin(record["stim_ampl"], amplitudes)
    record["X"] = record["X"][selected_stim]
    record["y"] = record["y"][selected_stim]
    # The amplitudes are filtered too; otherwise they no longer line up with X / y and
    # filtering the same record a second time raises an IndexError.
    record["stim_ampl"] = np.asarray(record["stim_ampl"])[selected_stim]
    return record

In [None]:
def split_data(record_dict, train_ratio=0.8, stratify=False, seed=None):
    """Split the "X" / "y" entries of a record into train and test sets.

    Parameters
    ----------
    record_dict : dict
        Record holding the samples under "X" and the labels under "y".
    train_ratio : float, default 0.8
        Forwarded to `train_test_split` as `train_size`.
    stratify : bool, default False
        When truthy the split is stratified on the labels, otherwise it is not.
    seed : int or None
        Forwarded to `train_test_split` as `random_state`.

    Returns
    -------
    dict
        The same record, completed with "X_train", "X_test", "y_train" and "y_test".
    """
    labels = record_dict["y"]
    strat_labels = labels if stratify else None
    X_train, X_test, y_train, y_test = train_test_split(record_dict["X"], labels,
                                                        train_size=train_ratio,
                                                        stratify=strat_labels,
                                                        random_state=seed)
    record_dict["X_train"] = X_train
    record_dict["X_test"] = X_test
    record_dict["y_train"] = y_train
    record_dict["y_test"] = y_test
    return record_dict

In [None]:
def resample(record_dict, resampler):
    """Balance the training set of a record with the given resampler.

    `resampler.fit_resample` is called on the "X_train" / "y_train" entries and its two
    outputs are stored under "X_bal" and "y_bal". The record is modified in place and
    returned.
    """
    X_balanced, y_balanced = resampler.fit_resample(record_dict["X_train"], record_dict["y_train"])
    record_dict["X_bal"] = X_balanced
    record_dict["y_bal"] = y_balanced
    return record_dict

In [None]:
def apply_model(model, parameter, resampler=None, exc_neurons=True, inh_neurons=True, amplitudes="all", train_ratio=0.8, stratify=True, cv=10, seed=None):
    """Train `model` on every recording of the global `recs` and plot its decoding metrics.

    Each recording is turned into a record dict (`get_rec_info`), filtered by stimulus
    amplitude (`stim_ampl_filter`), split into train/test sets (`split_data`) and, when a
    resampler is given, balanced (`resample`). The model is fitted on the (balanced)
    training set, cross-validated on that same set and evaluated on the held-out test set.
    One row of the summary figure is drawn per recording (WT in the left column, every
    other genotype in the right column), followed by a per-genotype boxplot of
    sensitivity, specificity and accuracy (WT vs KO-Hypo) drawn by
    `ppt.boxplot_3_conditions`.

    Parameters
    ----------
    model : estimator with `fit` / `predict`, usable by `cross_val_score`
        The same instance is refitted for every recording.
    parameter : str
        Matrix name passed to `get_rec_info`.
    resampler : object with `fit_resample`, optional
        Applied to the training set only.
    exc_neurons, inh_neurons : bool
        Neuron populations used as features.
    amplitudes : str or sequence
        Passed to `stim_ampl_filter`.
    train_ratio : float
        Fraction of the samples used for training.
    stratify : bool
        Whether the train/test split is stratified on the labels.
    cv : int or cross-validation splitter
        Passed to `cross_val_score`.
    seed : int or None
        Random state of the train/test split.

    Raises
    ------
    ValueError
        If both `exc_neurons` and `inh_neurons` are False.

    Recordings whose genotype is unknown, or for which the split / resampling / fit raises
    a ValueError, are skipped and reported with a printed message.
    """
    if not (exc_neurons or inh_neurons):
        raise ValueError("At least one of exc_neurons / inh_neurons must be True.")

    # Index used to group the metrics by genotype.
    genotype_ids = {"WT": 0, "KO-Hypo": 1, "KO": 2}
    mean_cv_scores = [[], [], []]
    accuracies = [[], [], []]
    sensitivities = [[], [], []]
    specificities = [[], [], []]
    # NOTE(review): the fitted models are collected here but never returned or used.
    models_dict = {}

    # At least 8 rows, and more when one column has to host more recordings, so the
    # axes grid can never be indexed out of bounds.
    n_wt = sum(rec.genotype == "WT" for rec in recs.values())
    n_rows = max(8, n_wt, len(recs) - n_wt)
    fig, ax = plt.subplots(nrows=n_rows, ncols=2, figsize=(15, 10), sharex=True)
    next_row = [0, 0]  # next free row of the WT column (0) and of the other column (1)

    for rec in recs.values():
        genotype_id = genotype_ids.get(rec.genotype)
        if genotype_id is None:
            # Without this guard the metrics would be filed under the genotype of the
            # previous recording (or raise a NameError on the first iteration).
            print(f"Skipping {rec.filename}: unknown genotype {rec.genotype!r}")
            continue

        try:
            record = get_rec_info(rec, parameter, exc_neurons=exc_neurons, inh_neurons=inh_neurons)
            record = stim_ampl_filter(record, stim_ampl=amplitudes)
            record = split_data(record, train_ratio=train_ratio, stratify=stratify, seed=seed)
            test_size = record["y_test"].shape[0]

            if resampler is not None:
                record = resample(record, resampler)
                new_samples = record["y_bal"].shape[0] - record["y_train"].shape[0]
                model.fit(record["X_bal"], record["y_bal"])
                # NOTE(review): the folds are drawn from the already resampled set, so a
                # resampled sample and its source can end up on both sides of a fold and
                # inflate the score. Resampling inside each fold (e.g. an imblearn
                # pipeline) would avoid this; left unchanged to keep reported numbers.
                cv_scores = cross_val_score(model, record["X_bal"], record["y_bal"], cv=cv)
                # Label of the last balanced sample, shown in the title as the generated label.
                gen_lab = str(record["y_bal"][-1])
            else:
                new_samples = 0
                model.fit(record["X_train"], record["y_train"])
                cv_scores = cross_val_score(model, record["X_train"], record["y_train"], cv=cv)
        except ValueError as err:
            print(f"Skipping {rec.filename} ({rec.genotype}): {err}")
            continue

        # Saving model for plotting the weights. A deep copy is taken because the same
        # `model` instance is refitted on the next recording.
        models_dict[f"{record['filename']} ({record['genotype']})"] = copy.deepcopy(model)

        # Cross-validation
        mean_cv_scores[genotype_id].append(cv_scores.mean())

        # Metrics on test data
        y_pred = model.predict(record["X_test"])
        conf_matrix = confusion_matrix(record["y_test"], y_pred, labels=[False, True])
        TP = conf_matrix[1, 1]
        TN = conf_matrix[0, 0]
        FP = conf_matrix[0, 1]
        FN = conf_matrix[1, 0]

        # A class absent from the test set gives an undefined rate, reported as NaN.
        sensitivity = TP / (TP + FN) if (TP + FN) > 0 else np.nan
        specificity = TN / (TN + FP) if (TN + FP) > 0 else np.nan
        accuracy = (TP + TN) / (TP + TN + FP + FN)
        accuracies[genotype_id].append(accuracy)
        sensitivities[genotype_id].append(sensitivity)
        specificities[genotype_id].append(specificity)

        # Boxplot of each recording: WT on the left, every other genotype on the right.
        j = 0 if rec.genotype == "WT" else 1
        i = next_row[j]
        next_row[j] += 1

        ax[i, j].boxplot(cv_scores, vert=False, widths=.4)
        # One point per fold; sized from the scores so `cv` may also be a splitter object.
        ax[i, j].scatter(cv_scores, np.ones_like(cv_scores), s=10)
        ax[i, j].axvline(x=specificity, color="green", linewidth=1, linestyle=":")
        ax[i, j].axvline(x=sensitivity, color="orange", linewidth=1, linestyle=":")
        ax[i, j].axvline(x=accuracy, color="red", linewidth=1, linestyle=":")
        if resampler is not None:
            ax[i, j].set_title(f"{rec.filename} ({rec.genotype})[N:{new_samples}({gen_lab}) T:{test_size}] - CV({cv_scores.mean():.1%}) Ac({accuracy:.1%}) Sp({specificity:.1%}) Se({sensitivity:.1%})", size=10)
        else:
            ax[i, j].set_title(f"{rec.filename} ({rec.genotype})[T:{test_size}] - CV({cv_scores.mean():.1%}) Ac({accuracy:.1%}) Sp({specificity:.1%}) Se({sensitivity:.1%})", size=10)
        ax[i, j].spines["left"].set_visible(True)
    plt.suptitle(f"Decoding {parameter} [{resampler}/{model}] - Train size: {train_ratio:.1%} - CV: {cv} fold")
    plt.show()

    # Boxplots of sensitivity, specificity and accuracy by genotype
    data_wt = [sensitivities[0], specificities[0], accuracies[0]]
    data_koh = [sensitivities[1], specificities[1], accuracies[1]]
    ppt.boxplot_3_conditions(data_wt, data_koh, cond_labels=["Detected", "Undetected", "All"], y_percent=True,
                             title=f"Decoding {parameter} [{resampler}/{model}]", filename="model_metrics", label_y="Accuracy")

In [None]:
seed = 42

# Models
dum = DummyClassifier()
lr = LogisticRegression(penalty="l2")

# Resamplers: all three are stochastic, so they receive the seed to make the run reproducible.
ros = imb.over_sampling.RandomOverSampler(sampling_strategy='auto', shrinkage=None, random_state=seed)
smote = imb.over_sampling.SMOTE(sampling_strategy='auto', random_state=seed)
adasyn = imb.over_sampling.ADASYN(sampling_strategy='auto', random_state=seed)

apply_model(lr, "Responsivity", resampler=ros, exc_neurons=True, inh_neurons=True, amplitudes="all", train_ratio=0.8, cv=5, stratify=True, seed=seed)

In [None]:
# Visual comparison of four over-sampling strategies on a single recording.
rec = recs[4756]
test = get_rec_info(rec, "Responsivity", exc_neurons=True, inh_neurons=False)

# Every resampler is created here (no hidden dependency on an earlier cell) and seeded,
# like the train/test split, so the heatmaps are reproducible.
ros = imb.over_sampling.RandomOverSampler(sampling_strategy='auto', shrinkage=None, random_state=seed)
smote = imb.over_sampling.SMOTE(sampling_strategy='auto', random_state=seed)
bd_smote = imb.over_sampling.BorderlineSMOTE(sampling_strategy='auto', random_state=seed)
adasyn = imb.over_sampling.ADASYN(sampling_strategy="auto", random_state=seed)
test = split_data(test, train_ratio=0.8, stratify=True, seed=seed)

# The shallow copies share the same train/test split; only "X_bal" / "y_bal" differ.
test = resample(test, ros)
test_2 = resample(copy.copy(test), smote)
test_3 = resample(copy.copy(test), bd_smote)
test_4 = resample(copy.copy(test), adasyn)
for resampled in (test, test_2, test_3, test_4):
    resp_heatmap(resampled, sort=True)

In [18]:
# from sklearn.metrics.pairwise import cosine_similarity
import percephone.plts.behavior as pbh

def neuron_mean_std_corr(array, estimator):
    """Reduce `array` along its first axis with the requested estimator.

    Parameters
    ----------
    array : array-like
        Data to reduce; the reduction runs over axis 0.
    estimator : {"Mean", "Std"}
        "Mean" applies `np.mean`, "Std" applies `np.std`.

    Raises
    ------
    ValueError
        If `estimator` is not one of the supported names (instead of silently
        returning None).
    """
    if estimator == "Mean":
        return np.mean(array, axis=0)
    if estimator == "Std":
        return np.std(array, axis=0)
    raise ValueError(f"Unknown estimator {estimator!r}; expected 'Mean' or 'Std'.")
    

def get_zscore(rec, exc_neurons=True, inh_neurons=False):
    """Return the transposed z-score matrix of the selected neuron populations.

    Parameters
    ----------
    rec : recording object
        Must expose `zscore_exc` and/or `zscore_inh` as 2-D arrays.
    exc_neurons, inh_neurons : bool
        Populations to include; with both, the excitatory rows are stacked first.

    Returns
    -------
    numpy.ndarray
        The selected matrix (or the row-wise stack of both), transposed.

    Raises
    ------
    ValueError
        If neither population is selected.
    """
    if exc_neurons and inh_neurons:
        # np.vstack replaces np.row_stack, which is deprecated since NumPy 2.0.
        return np.vstack((rec.zscore_exc, rec.zscore_inh)).T
    if exc_neurons:
        return rec.zscore_exc.T
    if inh_neurons:
        return rec.zscore_inh.T
    raise ValueError("At least one of exc_neurons / inh_neurons must be True.")

def get_zscore_estimator(rec, estimator, exc_neurons=True, inh_neurons=True):
    """Summarise the z-score of every neuron during each stimulation.

    For stimulation `i` the rows `stim_time[i] : stim_time[i] + int(stim_durations[i])`
    of the matrix returned by `get_zscore` are reduced with `neuron_mean_std_corr`.

    Parameters
    ----------
    rec : recording object
        Must expose `stim_time`, `stim_durations` and the z-score attributes read by
        `get_zscore`.
    estimator : str
        Estimator name forwarded to `neuron_mean_std_corr`.
    exc_neurons, inh_neurons : bool
        Populations forwarded to `get_zscore`.

    Returns
    -------
    numpy.ndarray
        2-D array with one row per stimulation. It stays 2-D with a single stimulation
        and has zero rows when there is none.
    """
    zscore = get_zscore(rec, exc_neurons=exc_neurons, inh_neurons=inh_neurons)
    rows = []
    for i, start in enumerate(rec.stim_time):
        end = start + int(rec.stim_durations[i])
        rows.append(neuron_mean_std_corr(zscore[start:end], estimator))
    if not rows:
        return np.empty((0, zscore.shape[1]))
    # Stacking once at the end avoids re-copying the growing matrix on every stimulation.
    return np.vstack(rows)

# def cosine_matrix(ax, rec, amplitude, resp_mat):
#     print("Cosine similarity computation")
# 
#     sim_mat = cosine_similarity(resp_mat)
#     ax.imshow(sim_mat, cmap="seismic", vmin=-1, vmax=+1, interpolation="none")
#     ax.set_xlabel("Trial i")
#     ax.set_ylabel("Trial j")
#     ax.set_title(str(rec.filename) + " " + rec.genotype + " " + str(amplitude))

In [19]:
wt_col, ko_col = 0, 0  # next free column of the WT row (top) and of the other row (bottom)
amps = [2, 6,6, 4, 4, 4, 8, 4, 4, 12, 8, 6, 6, 12, 12]  # manual selection of the threshold amp for each animal from psychometric curves
fig, ax = plt.subplots(2, 8, figsize=(35, 20))
estimator = "Mean"
roi_info = pd.read_excel(roi_path)

# NOTE(review): `amp` is not used by the plots, but zipping with `amps` limits the loop to
# len(amps) recordings, so the pairing is kept. The per-trial z-score matrix that used to
# be computed here for every recording was never used and has been dropped.
for rec, amp in zip(recs.values(), amps):
    if rec.genotype == "WT":
        pbh.psycho_like_plot(rec, roi_info, ax[0, wt_col])
        wt_col += 1
    else:
        pbh.psycho_like_plot(rec, roi_info, ax[1, ko_col])
        ko_col += 1

# The figure shows the output of `psycho_like_plot`, not cosine similarities.
fig.suptitle('Psychometric-like plots for all recordings', fontsize=26)
plt.show()

Text(0.5, 0.98, 'Cosine similarity for all trials')