# Tuning curves analysis

Analysis inspired by: https://doi.org/10.1371/journal.pone.0203900
Uses the data pre-processed in `tuning_curves_generate_data.ipynb`.

In [12]:
# imports
import sys
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd
from myterial import amber_dark, blue_dark, green_dark, indigo_light, teal_light, salmon, salmon_darker, orange, cyan_light, deep_purple_light
import warnings
# silence matplotlib's tight_layout warning, presumably triggered by the inset axes drawn on the tuning-curve figures
warnings.filterwarnings(action='ignore', message="This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.")


# make the local `analysis` package importable
# NOTE(review): machine-specific absolute path — consider making it configurable
sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")

from fcutils.plot.figure import clean_axes
from analysis.ephys.utils import get_recording_names
from analysis.ephys.viz import outline

# root folder of the ephys analysis outputs (machine-specific absolute path)
# NOTE(review): not referenced by the cells visible here, which define their own paths
save_folder = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys")



## Curve fitting functions

In [13]:
from scipy.optimize import curve_fit

# helpers to fit lines, polynomials and double exponentials and to plot the fits
def r_squared(y, y_hat):
    """
        Coefficient of determination of the predictions y_hat for the data y:
            R2 = 1 - SS_res / SS_tot
    """
    residual_ss = ((y - y_hat) ** 2).sum()
    mean_y = y.mean()
    total_ss = ((y - mean_y) ** 2).sum()
    return 1 - residual_ss / total_ss


def adjusted_r_squared(y, y_hat, degree):
    """
        Adjusted R2 of a polynomial fit of the given degree.

        A degree-d polynomial has d regressors plus an intercept, so the residual degrees of
        freedom are len(y) - degree - 1 (the same count used by polynomial_variance):
            R2_adj = 1 - (1 - R2) * (n - 1) / (n - degree - 1)
    """
    n = len(y)
    r2 = r_squared(y, y_hat)
    return 1 - (1 - r2) * (n - 1) / (n - degree - 1)


def polynomial_variance(y, y_hat, degree):
    """
        Residual variance of a polynomial fit: the residual sum of squares divided by the
        residual degrees of freedom (len(y) minus the degree + 1 fitted coefficients),
        rounded to 2 decimals.

        See: https://stats.stackexchange.com/questions/261537/how-to-chose-the-order-for-polynomial-regression
    """
    dof = len(y) - (degree + 1)
    residuals = y - y_hat
    return round((residuals ** 2).sum() / dof, 2)

# def polynomial_variance_on_other_data(y, p, X, degree):
#     if degree == 1:
#         yhat = p[0]*X+p[1]
#     elif degree == 2:
#         yhat = np.polyval(p,X)
#     else:
#         yhat = double_exp(X, *p)
#     return polynomial_variance(y, yhat, degree)
    



def fit_line(X, Y,  ax=None, Y_var=None, color=None,  lw=4, degree=1, label=True, **kwargs):
    """
        Fit a polynomial of the given degree to (X, Y) and optionally plot the fit.

        Samples where X or Y are not finite are dropped. When Y_var is given, samples
        where Y_var is not finite or <= 0 are dropped too.

        Arguments:
            X, Y: 1D array-likes of equal length.
            ax: matplotlib axis to draw the fit on. If None nothing is drawn.
            Y_var: optional per-sample values used to weight the fit (weights = 1 / Y_var).
            color: line color. Defaults to salmon for degree 1 and salmon_darker otherwise.
            lw: line width.
            degree: polynomial degree.
            label: if True the line gets a legend label with R2 and residual variance.
            kwargs: forwarded to ax.plot.

        Returns:
            ax is None  -> (p, r2, var): polynomial coefficients (highest power first, as
                           returned by np.polyfit), R2 and residual variance of the fit.
            ax provided -> (line, var): the plotted Line2D and the residual variance.

            If fewer than degree + 2 valid samples remain, the residual variance cannot be
            estimated. In that case var is the sentinel 1000, so this fit loses any
            variance-based comparison. With ax=None, p is an array of NaN and r2 is NaN.
            With ax given, a flat placeholder line is drawn.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)

    idx = np.isfinite(X) & np.isfinite(Y)
    if Y_var is not None:
        Y_var = np.asarray(Y_var, dtype=float)
        idx &= np.isfinite(Y_var) & (Y_var > 0)
        Y_var = Y_var[idx]
    X, Y = X[idx], Y[idx]

    # polynomial_variance divides by len(X) - degree - 1: at least one residual degree
    # of freedom is needed. The ax=None branch keeps the 3-tuple shape callers unpack.
    if len(X) < degree + 2:
        if ax is None:
            return np.full(degree + 1, np.nan), np.nan, 1000
        return ax.plot([0, 0], [0, 0], color=salmon, linewidth=2)[0], 1000

    # NOTE(review): np.polyfit expects w = 1/sigma for Gaussian errors; if Y_var holds
    # variances rather than standard deviations the weights should be 1/sqrt(Y_var)
    weights = None if Y_var is None else 1 / Y_var
    p = np.polyfit(X, Y, degree, w=weights)
    yhat = np.polyval(p, X)

    if color is None:
        color = salmon if degree == 1 else salmon_darker

    r2 = r_squared(Y, yhat)
    var = polynomial_variance(Y, yhat, degree)
    if ax is None:
        return p, r2, var

    line = ax.plot(
        X, yhat, lw=lw, color=color,
        label=f"{degree} degree, R2={r2:.2f} | var: {var:.2f}" if label else None,
        **kwargs,
    )[0]
    return line, var



def double_exp(x, a, k1, b, k2, c):
    """
        Sum of two exponentials plus an offset: a*exp(k1*x) + b*exp(k2*x) + c.
    """
    first_term = a * np.exp(k1 * x)
    second_term = b * np.exp(k2 * x)
    return first_term + second_term + c


def fit_double_exp(X, Y,  ax=None, color="r",  lw=4, label=True, p0=(1.0, 1.0, 1.0, 1.0, 1.0), maxfev=5000, **kwargs):
    """
        Fit `double_exp` (5 free parameters: a, k1, b, k2, c) to (X, Y) and optionally plot it.

        Non-finite samples are dropped before fitting.

        Arguments:
            X, Y: 1D array-likes of equal length.
            ax: matplotlib axis to draw the fit on. If None nothing is drawn.
            color, lw: line style.
            label: if True the line gets a legend label with R2 and residual variance.
            p0: initial guess for (a, k1, b, k2, c); needs one entry per parameter.
            maxfev: maximum number of function evaluations allowed to curve_fit.
            kwargs: forwarded to ax.plot.

        Returns:
            ax is None  -> (popt, r2, var), or (None, None, 100000) if the fit fails.
            ax provided -> (line, var), or (None, 1000) if the fit fails. If the fit
                           diverges (|R2| > 2) nothing is drawn and (None, 10000) is returned.
    """
    n_params = 5
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    idx = np.isfinite(X) & np.isfinite(Y)
    X, Y = X[idx], Y[idx]

    def _failed():
        # failure sentinels, kept identical to the historical return values
        return (None, None, 100000) if ax is None else (None, 1000)

    # need more samples than parameters to fit and to estimate a residual variance
    if len(X) <= n_params:
        return _failed()

    try:
        popt, _ = curve_fit(double_exp, X, Y, p0=p0, maxfev=maxfev)
    except (RuntimeError, ValueError, np.linalg.LinAlgError):
        # RuntimeError: no convergence within maxfev; ValueError: invalid inputs/residuals
        return _failed()
    yhat = double_exp(X, *popt)

    r2 = r_squared(Y, yhat)
    # polynomial_variance assumes degree + 1 fitted coefficients: 5 parameters <-> degree 4
    var = polynomial_variance(Y, yhat, n_params - 1)
    if ax is None:
        return popt, r2, var

    # R2 <= 1 always, so |R2| > 2 means R2 < -2: the fit is far worse than a constant
    if np.abs(r2) > 2:
        return None, 10000
    line = ax.plot(X, yhat, lw=lw, color=color, label=f"double exp, R2={r2:.2f} | var: {var:.2f}" if label else None, **kwargs)[0]
    return line, var

## Tuning analysis

In [14]:
# binned data is read from `cache`; figures are written to `save_dir` (machine-specific absolute paths)
cache = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\tuning_curves\cache")
save_dir = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\tuning_curves")

# plotting color of each behavioural variable; the keys (in order) also define the variables analysed
colors = {
    "s": amber_dark,
    "sdot": orange,
    "speed": blue_dark,
    "dspeed_250ms": indigo_light,
    "dspeed_500ms": deep_purple_light,
    "angular_velocity": green_dark,
    "dangvel_250ms": teal_light,
    "dangvel_500ms": cyan_light,
}

# NOTE(review): unused here, kept for reference — presumably the binning happens in
# tuning_curves_generate_data.ipynb
# bins = dict(
#     s = np.linspace(0, 260, 21),
#     sdot = np.linspace(-80, 80, 21),
#     speed = np.linspace(10, 80, 21),
#     dspeed_250ms = np.linspace(-80, 80, 41),
#     dspeed_500ms = np.linspace(-80, 80, 41),
#     angular_velocity = np.linspace(-175, 175, 21),
#     dangvel_250ms = np.linspace(-300, 300, 21),
#     dangvel_500ms = np.linspace(-300, 300, 21),
# )
variables = list(colors)


In [15]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def make_significance_plot(X, Y, parent_ax, shuffled_Y=None, color=None, ci=99.0):
    """
        Test the leading coefficient of the best polynomial fit of (X, Y) against a null
        distribution from fits to shuffled data, and draw the result as an inset.

        The best model (linear or quadratic) is the one with the lowest residual variance;
        ties go to the linear model. Its leading coefficient is the slope for linear or the
        quadratic term for quadratic. It is significant when it falls outside the central
        `ci` % of the leading coefficients from same-degree fits to each shuffled curve.

        Arguments:
            X, Y: 1D arrays with the (sorted) bin values and the unit's firing rate.
            parent_ax: axis in which the inset is drawn.
            shuffled_Y: iterable of 1D arrays aligned with X, one per shuffle. If None, it is
                built from the notebook globals `binned_mu`, `var` and `shuffles`. The rows
                are re-ordered by increasing bin value, matching how X is sorted in the
                plotting loop.
            color: histogram/interval color, defaults to colors[var] (notebook globals).
            ci: width in % of the null-distribution interval (99 -> 0.5th-99.5th percentiles).

        Returns:
            bool: True if the best-fit leading coefficient is significant.
    """
    if shuffled_Y is None:
        order = np.argsort(np.asarray(binned_mu[var].bin))
        shuffled_Y = [binned_mu[var][shuffled_unit].values[order] for shuffled_unit in shuffles]
    if color is None:
        color = colors[var]

    # fit both models and keep the one with the lowest residual variance
    # (np.argmin returns the first minimum, so ties favour the linear model)
    p_1, _, var1 = fit_line(X, Y, degree=1)
    p_2, _, var2 = fit_line(X, Y, degree=2)
    best = int(np.argmin([var1, var2])) + 1
    pbest = (p_1 if best == 1 else p_2)[0]

    # null distribution: leading coefficient of the same-degree model fitted to each shuffle
    shuffled_fits = []
    for y in shuffled_Y:
        p, _, _ = fit_line(X, y, degree=best)
        shuffled_fits.append(p[0])
    shuffled_fits = np.asarray(shuffled_fits, dtype=float)
    shuffled_fits = shuffled_fits[np.isfinite(shuffled_fits)]

    if len(shuffled_fits):
        tail = (100 - ci) / 2
        p_ci_low, p_ci_high = np.percentile(shuffled_fits, [tail, 100 - tail])
        significant = bool(pbest < p_ci_low or pbest > p_ci_high)
    else:
        # no usable shuffles: significance cannot be assessed
        p_ci_low = p_ci_high = np.nan
        significant = False

    inset_ax = inset_axes(
        parent_ax,
        width="30%", height="20%",
        loc="upper left",
        bbox_to_anchor=(0.05, .05, 1, 1), bbox_transform=parent_ax.transAxes,
        borderpad=1,
        axes_kwargs=dict(fc=[.95, .95, .95] if significant else "w"),
    )

    if len(shuffled_fits):
        inset_ax.hist(shuffled_fits, bins=20, color=color, alpha=0.5)
        inset_ax.plot([p_ci_low, p_ci_high], [-1, -1], color=color, lw=5)
    if np.isfinite(pbest):
        inset_ax.axvline(pbest, color="k" if significant else [.6, .6, .6], lw=4)
    return significant


## Plotting

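For each unit, plot its tuning curve for every behavioural variable, together with the fitted linear and quadratic curves. Each panel also shows the firing rates of the shuffled units and an inset with the shuffle-based significance test of the best fit.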

In [16]:
# TODO better quantification of significance for quadratic fits
# TODO get which units are tuned to what
# TODO get summary numbers and plots of tuning: tuned to what, kind of tuning curve, directionality...

In [18]:
# limit on how many recordings are processed: None = all, 1 = only the first (debugging)
MAX_RECORDINGS = None

for rec_n, rec in enumerate(get_recording_names()):
    if MAX_RECORDINGS is not None and rec_n >= MAX_RECORDINGS:
        break

    # load the binned data of the recording for each variable
    binned_mu = {k: pd.read_hdf(cache / f"{rec}_{k}_mu.h5", key="hdf") for k in variables}
    # NOTE(review): loaded but currently unused (presumably meant for fit_line's Y_var weighting)
    binned_var = {k: pd.read_hdf(cache / f"{rec}_{k}_sigma.h5", key="hdf") for k in variables}

    # NOTE(review): the column offsets (+1 for units, +2 for shuffles) assume a fixed layout
    # of the cached tables — confirm against tuning_curves_generate_data.ipynb
    units = [c for c in binned_mu['s'].columns[len(variables)+1:] if "shuffle" not in str(c) and isinstance(c, int)]

    for unit in units:
        savepath = save_dir / f"{rec}_{unit}_MOs.png"
        if savepath.exists():
            continue
        # NOTE(review): substring match — e.g. unit 1 also matches shuffle columns of units
        # 10, 11, ... unless the id is delimited in the column name; confirm the naming scheme
        shuffles = [c for c in binned_mu['s'].columns[len(variables)+2:] if "shuffle" in str(c) and str(unit) in str(c)]

        f, axes = plt.subplots(3, 3, figsize=(22, 15), sharey=True)
        f.suptitle(f"{rec}_{unit}")
        axes = axes.flatten()

        # panel 2 (top right) stays empty
        for axid, var in zip((0, 1, 3, 4, 5, 6, 7, 8), variables):
            ax = axes[axid]

            # sort the bins by value; shuffled curves are re-ordered the same way below
            # so that they stay aligned with X
            bin_values = binned_mu[var].bin.values
            order = np.argsort(bin_values)
            X = bin_values[order]
            Y = binned_mu[var][unit].values[order]

            # plot firing rate of unit
            ax.scatter(X, Y, color="k", s=100)
            ax.scatter(X, Y, color=colors[var], s=80)

            # fit and plot curves, outlining the one with the lowest residual variance
            l2, v2 = fit_line(X, Y, ax, degree=2)
            l1, v1 = fit_line(X, Y, ax, degree=1)
            best = np.argmin([v1, v2])
            outline([l1, l2][best], color="k", lw=8)

            # plot firing rate of shuffled units in the background
            shuff_X, shuff_Y = [], []
            for shuffled_unit in shuffles:
                shuff_X.extend(X)
                shuff_Y.extend(binned_mu[var][shuffled_unit].values[order])
            ax.scatter(shuff_X, shuff_Y, alpha=.15, color=["k"], zorder=-100)

            # mark significance of best fit parameters
            make_significance_plot(X, Y, ax)

            ax.legend(loc="upper right")
            ax.set(ylim=[max(-1, np.nanmin(Y) - 5), np.nanmax(Y)+5])

        # style axes: show y tick labels on every panel despite sharey
        for ax in axes:
            ax.yaxis.set_tick_params(labelleft=True)
            for tk in ax.get_yticklabels():
                tk.set_visible(True)

        axes[0].set(title = "S", xlabel="S position (cm)", ylabel="Firing rate (Hz)")
        axes[1].set(title = "S dot", xlabel="S speed (cm/s)", ylabel="Firing rate (Hz)")
        axes[2].axis("off")

        # NOTE(review): the speed-change labels say "(cm)" — confirm whether "(cm/s)" is meant
        axes[3].set(title = "Speed", xlabel="Speed (cm/s)", ylabel="Firing rate (Hz)")
        axes[4].set(title = "250ms", xlabel="speed change (cm)")
        axes[5].set(title = "500ms", xlabel="speed change (cm)")

        axes[6].set(title = "Angular velocity", xlabel="Angular velocity (deg/s)", ylabel="Firing rate (Hz)")
        axes[7].set(title = "250ms", xlabel="angular velocity change (deg/s)")
        axes[8].set(title = "500ms", xlabel="angular velocity change (deg/s)")

        clean_axes(f)
        f.tight_layout()

        f.savefig(savepath, dpi=150, transparent=False, facecolor='w', edgecolor='w')
        plt.close(f)