# Tuning curves analysis

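Analysis inspired by http://dx.doi.org/10.1371/journal.pone.0203900, using data pre-processed in `tuning_curves_generate_data.ipynb`.
For each unit, the mean firing rate of each chunk is plotted against the binned value of each kinematic variable (`s` position, speed, angular velocity, acceleration, angular acceleration), together with shuffled controls and linear, quadratic and double-exponential fits.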

In [164]:
# imports
import sys
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd
from myterial import amber_dark, blue_dark, green_dark, indigo_light, teal_light, salmon, salmon_darker, red

sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")

from fcutils.plot.figure import clean_axes
from fcutils.plot.elements import plot_mean_and_error
from fcutils.progress import track
from analysis.ephys.utils import get_recording_names, get_data, get_session_bouts, trim_bouts
from analysis.ephys.viz import outline
from fcutils.maths import derivative

save_folder = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys")



In [5]:
cache = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\tuning_curves\cache")
variables = ("s", "speed", "angular_velocity", "acceleration", "angular_acceleration")

def load_rec_preprocessed_data(rec, cache_dir=None, var_names=None) -> dict:
    """
    Load the cached grouped-chunk tables of one recording.

    For every variable and every kind in ("grouped_chunks", "grouped_chunks_means",
    "grouped_chunks_stds") this reads ``{rec}_{var}_{kind}.csv`` from the cache folder.

    Args:
        rec: recording name, used as the file-name prefix.
        cache_dir: folder holding the CSV files; defaults to the module-level `cache`.
        var_names: variables to load; defaults to the module-level `variables`.

    Returns:
        dict mapping each kind to a {var: pd.DataFrame} dict.
    """
    # resolve defaults at call time so later reassignments of `cache`/`variables` are honoured
    folder = cache if cache_dir is None else Path(cache_dir)
    names = variables if var_names is None else var_names
    kinds = ("grouped_chunks", "grouped_chunks_means", "grouped_chunks_stds")
    return {
        kind: {var: pd.read_csv(folder / f"{rec}_{var}_{kind}.csv") for var in names}
        for kind in kinds
    }

## Curve fitting functions

In [158]:
from scipy.optimize import curve_fit

# code to fit lines, polynomials and double exponentials, and plot the fits
def r_squared(y, y_hat):
    """
    Coefficient of determination R^2 of predictions `y_hat` against observations `y`.

    Args:
        y: observed values (array-like).
        y_hat: predicted values, same length as `y`.

    Returns:
        1 - SS_res / SS_tot. Returns NaN when `y` is empty or constant (SS_tot == 0),
        because R^2 is undefined there.
    """
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    if y.size == 0:
        return np.nan

    ss_tot = np.sum((y - y.mean()) ** 2)
    ss_res = np.sum((y - y_hat) ** 2)
    if ss_tot == 0:
        # avoid the 0/0 or x/0 division warning and a meaningless -inf
        return np.nan
    return 1 - (ss_res / ss_tot)


def fit_line(X, Y,  ax, bins, degree=1):
    """
    Fit a polynomial of the given degree to (X, Y) and draw it on `ax`.

    Args:
        X: 1D array of x values (here, bin values).
        Y: 1D array of observations, same length as `X`.
        ax: matplotlib Axes to draw on.
        bins: unused; kept so existing call sites keep working.
        degree: polynomial degree (1 draws a straight line).

    Returns:
        The matplotlib Line2D of the fitted curve, labelled with degree and R^2.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)

    # polyval on degree-1 coefficients evaluates exactly a*X + b, so one path serves all degrees
    coeffs = np.polyfit(X, Y, degree)
    yhat = np.polyval(coeffs, X)
    r2 = r_squared(Y, yhat)

    # draw in increasing X so the curve is one continuous line even if X is unsorted
    order = np.argsort(X, kind="stable")
    color = salmon if degree == 1 else salmon_darker
    return ax.plot(X[order], yhat[order], lw=3, color=color, label=f"{degree} degree, R2={r2:.2f}")[0]


def double_exp(x, a, b, c, d):
    """Sum of two exponentials, a * e^(b*x) + c * e^(d*x); works on scalars or arrays."""
    first_term = a * np.exp(b * x)
    second_term = c * np.exp(d * x)
    return first_term + second_term



def fit_double_exp(X, Y,  ax, bins):
    """
    Fit a double exponential a*exp(b*x) + c*exp(d*x) to (X, Y) and draw it on `ax`.

    This is best-effort. Nothing is drawn if the optimizer fails, or if the fit is
    degenerate: R^2 is non-finite, or |R^2| > 2 (in practice R^2 < -2, since R^2 <= 1).

    Args:
        X: 1D array of x values (here, bin values).
        Y: 1D array of observations, same length as `X`.
        ax: matplotlib Axes to draw on.
        bins: unused; kept so existing call sites keep working.

    Returns:
        The Line2D of the fitted curve, or None when no curve was drawn.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    try:
        popt, _ = curve_fit(double_exp, X, Y, maxfev=5000)
    except (RuntimeError, ValueError, TypeError):
        # RuntimeError: no convergence within maxfev; ValueError: NaN/inf in the data;
        # TypeError: fewer data points than the 4 free parameters
        return None
    curvey = double_exp(X, *popt)

    r2 = r_squared(Y, curvey)
    if not np.isfinite(r2) or np.abs(r2) > 2:
        return None

    # draw in increasing X so the curve is one continuous line even if X is unsorted
    order = np.argsort(X, kind="stable")
    return ax.plot(X[order], curvey[order], lw=3, color=red, label=f"double exp, R2={r2:.2f}")[0]

## Plotting

For each unit, plot the tuning curves together with the fitted curves.

In [137]:
def get_intervals_mipoints(data:pd.Series):
    """
    Return the midpoints of interval labels stored as strings such as "(0.5, 1.5]".

    For each label, the first and last characters (the brackets) are dropped. The two
    comma-separated bounds are then parsed as floats.

    Returns:
        np.ndarray holding (low + high) / 2 for each entry of `data`, in order.
    """
    def _midpoint(label):
        low = float(label[1:].split(",")[0])
        high = float(label.split(",")[1][:-1])
        return (low + high) / 2

    return np.array([_midpoint(label) for label in data.values])


def list_from_str(X):
    """
    Parse a list serialised as text, e.g. "[1.0, 2.5, nan]", into a list of floats.

    The first and last characters (the brackets) are dropped and the remainder is split
    on commas. Whitespace around items is tolerated, and "[]" yields an empty list.
    """
    inner = X[1:-1]
    # skip blank items so "[]" gives [] instead of float("") raising ValueError
    return [float(item) for item in inner.split(",") if item.strip()]


def parse_unit_data(unit_chunks_means, unit_chunks_stds):
    """
    Convert a unit's stringified per-bin chunk means and stds into lists of floats.

    Args:
        unit_chunks_means: pd.Series of list strings (as read back from CSV), one per bin.
        unit_chunks_stds: pd.Series of list strings, one per bin.

    Returns:
        (means, stds): the two Series with every entry parsed by `list_from_str`.
    """
    parsed = [series.apply(list_from_str) for series in (unit_chunks_means, unit_chunks_stds)]
    return parsed[0], parsed[1]


def get_unit_XY(unit_chunks_means, bs):
    """
    Flatten a unit's binned chunk means into paired (X, Y) arrays.

    For bin i, every chunk mean in ``unit_chunks_means.iloc[i]`` contributes one point.
    That point's X is the bin value ``bs[i]`` and its Y is the chunk mean. NaN chunk
    means are dropped together with their X.

    Args:
        unit_chunks_means: pd.Series holding one list of chunk means per bin, in the
            same order as `bs`.
        bs: sequence of bin values (e.g. interval midpoints).

    Returns:
        (X, Y): float arrays of equal length.
    """
    X_parts, Y_parts = [], []
    for i, b in enumerate(bs):
        y = np.asarray(unit_chunks_means.iloc[i], dtype=float).ravel()
        X_parts.append(np.full(y.shape, b, dtype=float))
        Y_parts.append(y)

    if not Y_parts:
        return np.array([]), np.array([])
    X, Y = np.concatenate(X_parts), np.concatenate(Y_parts)

    # drop NaN chunk means individually so valid chunks in later bins are kept
    keep = ~np.isnan(Y)
    return X[keep], Y[keep]

In [195]:
save_dir = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\tuning_curves")
save_dir.mkdir(parents=True, exist_ok=True)

# scatter color of the real unit's chunks, per variable
colors = dict(
    s=amber_dark, speed=blue_dark, angular_velocity=green_dark,
    acceleration=indigo_light, angular_acceleration=teal_light,
)

# subplot slot for each entry of `variables` (same order); slot 3 is left blank
PANEL_IDS = (0, 1, 2, 4, 5)
# plot only every SHUFFLE_STEP-th shuffled copy of a unit to keep figures readable
SHUFFLE_STEP = 10
# seeded generator so the horizontal jitter of shuffled points is reproducible
jitter_rng = np.random.default_rng(0)

# per-variable axis formatting: (title, xlabel, x padding beyond the first/last bin)
AXIS_FORMAT = dict(
    s=("S", "S position (cm)", 10),
    speed=("Speed", "Speed (cm/s)", 2),
    angular_velocity=("Ang. Vel.", "Angular velocity (deg/s)", 10),
    acceleration=("Accel.", "Acceleration (cm/s^2)", 10),
    angular_acceleration=("Ang. Acc.", "Angular acceleration (deg/s^2)", 10),
)


# one figure per unit: chunk firing rates vs. each binned variable, with shuffled
# controls and line / quadratic / double-exponential fits
for rec in get_recording_names():
    rec_data = load_rec_preprocessed_data(rec)
    print("     got recording data")

    units = [c for c in rec_data["grouped_chunks_means"]["s"].columns if "shuffle" not in c and "hairpin" in c]
    bins = {
        v: get_intervals_mipoints(rec_data["grouped_chunks_means"][v][v]) for v in variables
    }

    for unit in units:
        # NOTE(review): substring match — if unit names can be prefixes of one another
        # (e.g. "..._1" vs "..._10"), other units' shuffles are picked up too; confirm naming.
        shuffled_units = [c for c in rec_data["grouped_chunks_means"]["s"].columns if "shuffle" in c and unit in c]

        f, axes = plt.subplots(2, 3, figsize=(19, 12), sharey=True)
        f.suptitle(unit)
        axes = axes.flatten()

        unit_Y_all = []  # real-unit chunk means from every panel, for the shared y-limits
        for axid, var in zip(PANEL_IDS, variables):
            unit_chunks_means, unit_chunks_stds = parse_unit_data(rec_data["grouped_chunks_means"][var][unit], rec_data["grouped_chunks_stds"][var][unit])
            bs = bins[var]
            ax = axes[axid]

            # shuffled controls, jittered horizontally by a tenth of the first bin width
            shuff_X, shuff_Y = [], []
            for shuffled_unit in shuffled_units[::SHUFFLE_STEP]:
                shuff_unit_chunks_means, _ = parse_unit_data(rec_data["grouped_chunks_means"][var][shuffled_unit], rec_data["grouped_chunks_stds"][var][shuffled_unit])
                X, Y = get_unit_XY(shuff_unit_chunks_means, bs)
                shuff_X.extend(X)
                shuff_Y.extend(Y)
            jitter_sd = abs(np.diff(bs)[0]) / 10 if len(bs) > 1 else 0.0
            ax.scatter(np.asarray(shuff_X, dtype=float) + jitter_rng.normal(0, jitter_sd, size=len(shuff_X)), shuff_Y, alpha=.1, color="k")

            # real unit: mean firing rate in each chunk vs. that chunk's bin
            X, Y = get_unit_XY(unit_chunks_means, bs)
            unit_Y_all.extend(Y)
            pts = ax.scatter(X, Y, color=colors[var], s=50)
            outline(pts, lw=2, color="k")

            # fit and plot curves
            outline(fit_line(X, Y, ax, bs, degree=1), lw=4, color="k")
            outline(fit_line(X, Y, ax, bs, degree=2), lw=4, color="k")
            fit_double_exp(X, Y, ax, bs)
            ax.legend()

            title, xlabel, xpad = AXIS_FORMAT[var]
            ax.set(title=title, xlabel=xlabel, xlim=[bs[0] - xpad, bs[-1] + xpad])

        # sharey=True hides tick labels on inner panels; show them on every panel
        for ax in axes:
            ax.yaxis.set_tick_params(labelleft=True)
            for tk in ax.get_yticklabels():
                tk.set_visible(True)

        axes[0].set(ylabel="Firing rate (Hz)")
        axes[4].set(ylabel="Firing rate (Hz)")
        if unit_Y_all:
            # panels share y, so this sets the limits for all of them, from every variable's data
            axes[0].set(ylim=[np.min(unit_Y_all) - 5, np.max(unit_Y_all) + 10])
        axes[3].axis("off")
        clean_axes(f)

        # NOTE(review): the "_MOs" suffix is hard-coded for every unit — confirm it applies to all
        f.savefig(save_dir / f"{unit}_MOs.png",  dpi=150, transparent=False, facecolor='w', edgecolor='w')
        plt.close(f)

     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




     got recording data




In [187]:
# Scratch check of the jitter SD used for the shuffled scatter (a tenth of the first
# bin width). Relies on `bs` left over from the last iteration of the plotting loop above.
first_bin_width = bs[1] - bs[0]
first_bin_width / 10