# Tuning curves

Plot tuning curves with respect to `s` (position along the track), `v` (speed) and `omega` (angular velocity), and their derivatives, for all cells.

## Collect data
Across all bouts in both directions. 

In [None]:
# imports
# Standard library and third-party packages used throughout the notebook.
import sys
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd
from myterial import amber_dark, blue_dark, green_dark

# Extend the module search path BEFORE the project imports below.
# NOTE(review): presumably this is what lets `analysis.ephys.utils` resolve;
# the second entry is an absolute, machine-specific path - confirm it exists
# on the machine running the notebook.
sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")

# Plotting/progress helpers (fcutils) and project-specific data loaders.
from fcutils.plot.figure import clean_axes
from fcutils.plot.elements import plot_mean_and_error
from fcutils.progress import track
from analysis.ephys.utils import get_recording_names, get_data, get_session_bouts

# NOTE(review): not used by any of the cells visible here; presumably the
# destination for saved figures/results - machine-specific path, confirm.
save_folder = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys")



In [None]:
def trim_bouts(bouts:pd.DataFrame, speed_th:float=10):
    """ 
        Trim bouts to times where the speed is high enough.

        For each bout the trimmed window starts at the first frame whose
        speed is above `speed_th` and stops at the first later frame whose
        speed is below it. If the speed never drops back below the
        threshold the window runs to the end of the bout. If the speed
        never exceeds the threshold the window is empty
        (`trim_start == trim_end == session_start_frame`).

        Arguments:
            bouts: one row per bout. Must have a `speed` column holding the
                per-frame speed trace of the bout and a
                `session_start_frame` column.
            speed_th: speed threshold, in the same units as `speed`.

        Returns:
            The input dataframe (modified in place) with two extra columns,
            `trim_start` and `trim_end`, obtained by adding the within-bout
            frame indices to `session_start_frame`.
    """

    starts, ends = [], []
    for _, bout in bouts.iterrows():
        speed = np.asarray(bout["speed"])

        above = np.flatnonzero(speed > speed_th)
        if above.size == 0:
            # never fast enough: keep the row but give it an empty window
            starts.append(0)
            ends.append(0)
            continue
        start = int(above[0])

        # first frame after `start` where the speed falls below threshold;
        # if there is none, the bout ends while still above threshold
        below = np.flatnonzero(speed[start:] < speed_th)
        end = start + int(below[0]) if below.size else len(speed)

        starts.append(start)
        ends.append(end)

    # dtype=int keeps the frame columns integer even when `bouts` is empty
    session_start = bouts["session_start_frame"].values
    bouts["trim_start"] = np.array(starts, dtype=int) + session_start
    bouts["trim_end"] = np.array(ends, dtype=int) + session_start
    return bouts





In [None]:
def get_rec_data(recording:str, bins:dict) -> tuple:
    """
        Get all data for a recording.

        Loads the units and body tracking of `recording`, keeps only the
        units whose `brain_region` is one of the MOs labels listed below,
        and concatenates, over every trimmed bout, the kinematic variables
        named in `bins` together with the firing rate of each kept unit.

        Arguments:
            recording: name of the recording to load.
            bins: dictionary whose keys name the kinematic variables to
                collect. Supported keys: `s`, `speed`, `angular_velocity`,
                `acceleration`, `angular_acceleration`. Only the keys are
                used here.

        Returns:
            A `(data, units)` tuple. `data` is a dataframe with one row per
            frame, one column per key of `bins` and one column per unit
            (named by `unit_id`) holding its firing rate. `units` is the
            filtered units dataframe.

        Raises:
            KeyError: if `bins` contains a variable that is not supported.
    """
    # only the units and the body tracking are needed; limb data is ignored
    units, *_, body = get_data(recording)
    units = units.loc[units.brain_region.isin(["MOs", "MOs1", "MOs2/3", "MOs5", "MOs6a", "MOs6b"])]

    # NOTE(review): the first call relies on get_session_bouts' default
    # direction - presumably outbound; confirm against its signature.
    out_bouts = trim_bouts(get_session_bouts(recording, complete=None))
    in_bouts = trim_bouts(get_session_bouts(recording, direction="inbound", complete=None))

    rec_data = {
        **{k:[] for k in bins},
        **{unit.unit_id:[] for _, unit in units.iterrows()},
    }

    # get data into a single big dataframe
    for bouts in (out_bouts, in_bouts):
        for _, bout in bouts.iterrows():
            start, end = bout.trim_start, bout.trim_end

            # `bout.s` is indexed relative to the bout, everything else is
            # indexed in session frames
            offset = bout.session_start_frame
            traces = dict(
                s = bout.s[start - offset : end - offset],
                speed = body.speed[start:end],
                angular_velocity = body.angular_velocity[start:end],
                acceleration = body.acceleration[start:end],
                angular_acceleration = body.angular_acceleration[start:end],
            )

            # kinematics are added once per bout (not once per unit) so that
            # every column ends up with the same number of frames
            for var in bins:
                rec_data[var].extend(traces[var])

            for _, unit in units.iterrows():
                rec_data[unit.unit_id].extend(unit.firing_rate[start:end])

    return pd.DataFrame(rec_data), units

In [None]:
"""
Store a dictionary with a dataframe for each of the variables. 
In these dataframes each row is a unit and each column is a bin
with the associated mean/std of the firing rate.

"""

# bin edges for each variable
bins = dict(
    s = np.arange(0, 260, step=10),
    speed = np.arange(10, 100, step=10),
    angular_velocity = np.arange(-250, 250, step=25),
    acceleration = np.arange(-100, 100, step=20),
    angular_acceleration = np.arange(-500, 500, step=50),
)

# mid-point of every bin, used to name the result columns and as x values
bins_centers = {
    key: (bins[key][1:] + bins[key][:-1])/2 for key in bins
}

# for each variable: a `unit_id` column plus `<center>_mean` and
# `<center>_std` columns, filled one unit at a time
results = {
    v : {
        **dict(unit_id = []),
        **{f"{b}_mean":[] for b in bins_centers[v]},
        **{f"{b}_std":[] for b in bins_centers[v]},
    } for v in bins.keys()
} 

for REC in get_recording_names():
    data, units = get_rec_data(REC, bins)
    unit_ids = units["unit_id"].tolist()

    # every variable in `bins` is processed so that each entry of `results`
    # ends up with one row per unit
    for var in bins:
        # assign each frame to a bin of `var`; frames outside the bin range
        # get NaN and are dropped by groupby
        varbins = pd.cut(data[var], bins=bins[var])

        # group by bin and get average/std of the firing rates.
        # `data` itself is left untouched so the next variable can reuse it;
        # observed=False keeps empty bins (as NaN rows) so that row n of
        # mu/sigma always corresponds to bins_centers[var][n]
        grouped = data[unit_ids].groupby(varbins, observed=False)
        mu = grouped.mean()
        sigma = grouped.std()

        for unit_id in unit_ids:
            # f-string also works when unit ids are not strings
            results[var]["unit_id"].append(f"{REC}_{unit_id}")

            unit_mu = mu[unit_id].to_numpy()
            unit_sigma = sigma[unit_id].to_numpy()
            for b, m, sd in zip(bins_centers[var], unit_mu, unit_sigma):
                results[var][f"{b}_mean"].append(m)
                results[var][f"{b}_std"].append(sd)
    # NOTE(review): only the first recording is processed because of this
    # `break`; it looks like a debugging leftover - remove to run on all
    # recordings.
    break

results = {k:pd.DataFrame(v) for k, v in results.items()}

results['speed'].head()
    


## Plot

For each unit, plot the tuning curve for each variable.

In [None]:

all_units = results["speed"]["unit_id"].unique()

# NOTE(review): defined but not used by the plotting calls in this cell.
colors = dict(
    s = amber_dark, speed = blue_dark, angular_velocity = green_dark
)

# x-axis label of each variable; keyed by name so labels cannot drift out of
# sync with the order/number of subplots
xlabels = dict(
    s = "S position (cm)",
    speed = "Speed (cm/s)",
    angular_velocity = "Angular velocity (deg/s)",
    acceleration = "Acceleration (cm/s^2)",
    angular_acceleration = "Angular acceleration (deg/s^2)",
)

for unit in track(all_units):
    # get unit data for each variable
    unit_data = {v:results[v][results[v]["unit_id"] == unit] for v in results.keys()}

    # plot tuning curves, one row per variable
    f, axes = plt.subplots(len(bins.keys()), 1 , figsize=(16, 10))

    # plt.subplots returns a bare Axes when there is a single row
    for ax, var in zip(np.atleast_1d(axes), bins.keys()):
        ax.set(
            title=var,
            xlabel=xlabels.get(var, var),
            ylabel="Firing rate (Hz)",
            xticks=bins_centers[var],
        )

        rows = unit_data[var]
        if rows.empty:
            # nothing stored for this unit/variable: leave the axis empty
            continue

        # collect the per-bin columns, in bin order, into one curve
        mean_cols = [f"{b}_mean" for b in bins_centers[var]]
        std_cols = [f"{b}_std" for b in bins_centers[var]]
        mu = rows[mean_cols].to_numpy(dtype=float)[0]
        sigma = rows[std_cols].to_numpy(dtype=float)[0]
        plot_mean_and_error(mu, sigma, x=bins_centers[var], ax=ax)

    # clean_axes works on a whole figure, not on a single Axes
    clean_axes(f)
    # NOTE(review): only the first unit is plotted because of this `break`;
    # it looks like a debugging leftover - remove to plot every unit.
    break

