# Tuning curves

Plot tuning curves with respect to position `s`, speed `v` and angular velocity `omega` (and their derivatives) for all MOs units.

## Collect data
Data are pooled across all locomotion bouts in both directions (outbound and inbound).

In [1]:
# imports: standard library, third-party, then project helpers
import sys
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd
from myterial import amber_dark, blue_dark, green_dark, indigo_dark, teal_dark

# presumably needed so that the project package `analysis` (imported below) can be found;
# NOTE(review): the repository path is machine-specific — consider making it configurable
sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")

from fcutils.plot.figure import clean_axes
from fcutils.plot.elements import plot_mean_and_error
from fcutils.progress import track
from analysis.ephys.utils import get_recording_names, get_data, get_session_bouts, trim_bouts
from fcutils.maths import derivative

# NOTE(review): hard-coded absolute path, and not used by the cells shown in this notebook
# (later cells define their own `cache` / `save_dir` paths)
save_folder = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys")



Connecting root@127.0.0.1:3306


In [2]:
def get_rec_data(recording:str, bins:dict, n_shuffles:int=100, seed=None) -> tuple:
    """
        Get all data for a recording.

        Pools every frame of every trimmed locomotion bout (outbound and
        inbound) and returns, for each frame, the behavioural variables
        (`s`, speed, angular velocity, acceleration, angular acceleration)
        together with the firing rate of each MOs unit.

        For each unit, `n_shuffles` control columns named
        "<unit_id>_shuffle_<rep>" are also filled: for every bout they hold
        the unit's firing rate over a window of the same length as the bout,
        starting at a random frame between 1200 and n_frames - 1200. Windows
        that would run past the end of the recording wrap around to its start,
        so every column has the same length.

        Arguments:
            recording: recording name (as returned by `get_recording_names`).
            bins: dict of bin edges; only its keys are used here, to create the
                behavioural columns (they must be the five variables above).
            n_shuffles: number of shuffled copies per unit (default 100).
            seed: optional int. If given, shuffles are drawn from a dedicated
                `np.random.RandomState(seed)` so runs are reproducible; if None
                numpy's global random state is used.

        Returns:
            (data, units): `data` is a DataFrame with one row per frame whose
            first `len(bins)` columns are the behavioural variables, followed by
            one column per unit and then the shuffle columns; `units` is the
            units table restricted to MOs regions.
    """
    units, left_fl, right_fl, left_hl, right_hl, body = get_data(recording)
    units = units.loc[units.brain_region.isin(["MOs", "MOs1", "MOs2/3", "MOs5", "MOs6a", "MOs6b"])]
    rng = np.random if seed is None else np.random.RandomState(seed)

    out_bouts = trim_bouts(get_session_bouts(recording, complete=None))
    in_bouts = trim_bouts(get_session_bouts(recording, direction="inbound", complete=None))

    # (unit_id, firing rate, firing rate as array) for each unit, built once instead of per bout
    unit_rates = [
        (unit.unit_id, unit.firing_rate, np.asarray(unit.firing_rate))
        for _, unit in units.iterrows()
    ]

    # column order matters downstream: variables first, then units, then shuffles
    rec_data = {k: [] for k in bins.keys()}
    rec_data.update({unit_id: [] for unit_id, _, _ in unit_rates})
    for unit_id, _, _ in unit_rates:
        for rep in range(n_shuffles):
            rec_data[f"{unit_id}_shuffle_{rep}"] = []

    n_frames = len(body.speed)
    # derivative of the whole trace does not depend on the bout: compute it once
    angular_acceleration_full = derivative(body.angular_velocity)

    # get data into a single big dataframe
    for bouts in (out_bouts, in_bouts):
        for _, bout in bouts.iterrows():
            start, end = bout.trim_start, bout.trim_end
            n_bout_frames = end - start

            rec_data['s'].extend(bout.s[start - bout.start_frame : end - bout.start_frame])
            rec_data['speed'].extend(body.speed[start : end])
            rec_data['angular_velocity'].extend(body.angular_velocity[start : end] / 60)
            rec_data['acceleration'].extend(body.acceleration[start : end] * 60)
            rec_data['angular_acceleration'].extend(angular_acceleration_full[start : end] / 60)

            for unit_id, frate, frate_array in unit_rates:
                rec_data[unit_id].extend(frate[start : end])

                for rep in range(n_shuffles):
                    # random window with the bout's length; mode="wrap" avoids truncation
                    # (and thus mismatched column lengths) when the window runs past the end
                    shuffle_start = rng.randint(1200, n_frames - 1200)
                    idx = np.arange(shuffle_start, shuffle_start + n_bout_frames)
                    rec_data[f"{unit_id}_shuffle_{rep}"].extend(np.take(frate_array, idx, mode="wrap"))

    return pd.DataFrame(rec_data), units

In [3]:
# Bin edges for each behavioural variable: 21 edges -> 20 bins over the range of interest.
# The keys must match the behavioural columns created by get_rec_data.
bins = dict(
    s = np.linspace(0, 260, 21),
    speed = np.linspace(10, 80, 21),
    angular_velocity = np.linspace(-175, 175, 21),
    acceleration = np.linspace(-175, 175, 21),
    angular_acceleration = np.linspace(-250, 250, 21),
)

# centre of each bin (midpoint of consecutive edges), used as x values when plotting
bins_centers = {
    key: (edges[1:] + edges[:-1]) / 2 for key, edges in bins.items()
}

# recording name -> (per-frame DataFrame, MOs units table), as returned by get_rec_data
recs_data = {REC: get_rec_data(REC, bins) for REC in get_recording_names()}

In [4]:
def split(a, n_chunks=10):
    """
    Split the sequence `a` into `n_chunks` contiguous chunks of (almost) equal length.

    The first len(a) % n_chunks chunks get one extra element; when `a` has
    fewer than `n_chunks` elements the trailing chunks are empty.

    Raises:
        ValueError: if n_chunks is not positive.
    """
    if n_chunks <= 0:
        raise ValueError(f"n_chunks must be positive, got {n_chunks}")
    k, m = divmod(len(a), n_chunks)
    return [a[i*k + min(i, m):(i+1)*k + min(i+1, m)] for i in range(n_chunks)]

def chunks_means(a):
    """Return the mean of every chunk in `a`; an empty chunk yields NaN."""
    means = []
    for chunk in a:
        means.append(np.mean(chunk) if len(chunk) > 0 else np.nan)
    return means

def chunks_stds(a):
    """Return the standard deviation of every chunk in `a`; an empty chunk yields NaN."""
    stds = []
    for chunk in a:
        stds.append(np.std(chunk) if len(chunk) > 0 else np.nan)
    return stds



# For each variable, a list (one entry per recording) of DataFrames whose rows are the
# bins of that variable and whose columns are units (and their shuffles):
#   *_grouped: firing-rate samples of each bin split into chunks
#   *_means / *_stds: mean / std of each chunk
results = {
    **{f"{v}_grouped" : [] for v in bins.keys()},
    **{f"{v}_means" : [] for v in bins.keys()},
    **{f"{v}_stds" : [] for v in bins.keys()},
}
cache = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\tuning_curves\cache")
cache.mkdir(parents=True, exist_ok=True)

# cache-file suffix for each kind of table
CACHE_SUFFIXES = dict(grouped="", means="_means", stds="_stds")


def _compute_tuning_tables(data: pd.DataFrame, var: str, units_names, colnames: dict) -> dict:
    """Bin frames by `var` and return chunked firing rates plus per-chunk means/stds (rows: bins, columns: units)."""
    # group entries based on variable bin, get a list of firing rates for all units at each bin
    var_bins = pd.cut(data[var], bins[var])
    grouped = data.groupby(var_bins).agg(list)[units_names].rename(columns=colnames)

    # split lists into chunks of equal length and get the average and std of each chunk
    grouped_chunks = grouped.applymap(split)
    return dict(
        grouped=grouped_chunks,
        means=grouped_chunks.applymap(chunks_means),
        stds=grouped_chunks.applymap(chunks_stds),
    )


for REC in get_recording_names():
    print(f"Processing {REC}")
    data, units = recs_data[REC]

    # get_rec_data puts the behavioural variables first (one column per key of `bins`),
    # followed by one column per unit and per shuffle
    units_names = data.columns[len(bins):]
    colnames = {name: f"{REC}_{name}" for name in units_names}

    for var in bins.keys():
        paths = {
            kind: cache / f"{REC}_{var}_grouped_chunks{suffix}.pkl"
            for kind, suffix in CACHE_SUFFIXES.items()
        }
        if all(path.exists() for path in paths.values()):
            # pickles are written by this notebook itself (trusted input); CSV could not
            # round-trip the list-valued cells, which came back as strings
            tables = {kind: pd.read_pickle(path) for kind, path in paths.items()}
        else:
            print(f"    doing {var}")
            tables = _compute_tuning_tables(data, var, units_names, colnames)
            print("         saving...")
            for kind, table in tables.items():
                table.to_pickle(paths[kind])

        for kind, table in tables.items():
            results[f"{var}_{kind}"].append(table)

# one DataFrame per key: rows are the bins of the variable, columns are the units
# (and shuffles) of all recordings, so recordings are concatenated along columns
results = {k: pd.concat(v, axis=1) for k, v in results.items() if v}


Processing FC_220408_BAA1101192_hairpin
Processing FC_220409_BAA1101192_hairpin
Processing FC_220410_BAA1101192_hairpin
Processing FC_220411_BAA1101192_hairpin
Processing FC_220412_BAA1101192_hairpin
Processing FC_220413_BAA1101192_hairpin
    doing angular_velocity
         saving...
    doing acceleration
         saving...
    doing angular_acceleration
         saving...
Processing FC_220414_BAA1101192_hairpin
    doing s
         saving...
    doing speed
         saving...
    doing angular_velocity
         saving...
    doing acceleration
         saving...
    doing angular_acceleration
         saving...
Processing FC_220415_BAA1101192_hairpin
    doing s
         saving...
    doing speed
         saving...
    doing angular_velocity
         saving...
    doing acceleration
         saving...
    doing angular_acceleration
         saving...
Processing FC_220432_BAA1101192_hairpin
    doing s
         saving...
    doing speed
         saving...
    doing angular_velocity
 

In [None]:
# for k, d in results.items():
#     d.to_hdf(Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\tuning_curves\cache") / f"{k}.h5", key="hdf")

## Evaluate tuning
For each unit, measure the variance of the tuning curve for each variable and compare it to the variance of tuning curves computed from time-shifted (shuffled) firing rates of the same unit.

In [None]:
from scipy.optimize import curve_fit

# code to fit lines, polynomials and double exponentials & plot
def r_squared(y, y_hat):
    """
    Coefficient of determination R^2 = 1 - SS_res / SS_tot.

    Accepts any array-likes of equal length (lists included). Returns NaN
    when `y` is constant (SS_tot == 0), where R^2 is undefined.
    """
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    ss_tot = ((y - y.mean()) ** 2).sum() if y.size else 0.0
    if ss_tot == 0:
        return np.nan
    ss_res = ((y - y_hat) ** 2).sum()
    return 1 - (ss_res / ss_tot)


def fit_line(unit_data:pd.Series,  ax, bins, degree=1):
    """
    Fit a polynomial of the given degree to a unit's tuning data and plot it on `ax`.

    Arguments:
        unit_data: one entry per bin, each a sequence of firing-rate values;
            bins whose values contain any NaN are skipped.
        ax: matplotlib axes to draw the fitted curve on.
        bins: x value (e.g. bin centre) for each entry of `unit_data`.
        degree: polynomial degree (1 = straight line).

    Returns:
        R^2 of the fit, or NaN (nothing plotted) when there are not enough
        points to fit a polynomial of this degree.
    """
    X, Y = [], []
    for b, y in zip(bins, unit_data.values):
        if np.any(np.isnan(y)):
            continue
        X.extend(np.ones_like(y) * b)
        Y.extend(y)
    X = np.array(X, dtype=float)
    Y = np.array(Y, dtype=float)

    # np.polyfit needs at least degree + 1 points
    if len(X) <= degree:
        return np.nan

    p = np.polyfit(X, Y, degree)
    yhat = np.polyval(p, X)
    r2 = r_squared(Y, yhat)
    # lines drawn lighter than higher-degree fits, as before
    color = [.6, .6, .6] if degree == 1 else [.4, .4, .4]
    ax.plot(X, yhat, lw=3, color=color, label=f"{degree} degree, R2={r2:.2f}")
    return r2


def double_exp(x, a, b, c, d):
    """Sum of two exponentials, a*e^(b*x) + c*e^(d*x), evaluated element-wise over `x`."""
    first_term = a * np.exp(b * x)
    second_term = c * np.exp(d * x)
    return first_term + second_term



def fit_double_exp(unit_data:pd.Series,  ax, bins):
    """
    Fit a double exponential (see `double_exp`) to a unit's tuning data and plot it on `ax`.

    Arguments:
        unit_data: one entry per bin, each a sequence of firing-rate values;
            bins whose values contain any NaN are skipped.
        ax: matplotlib axes to draw the fitted curve on.
        bins: x value (e.g. bin centre) for each entry of `unit_data`.

    Returns:
        R^2 of the fit, or NaN (nothing plotted) when there are too few points,
        the fit fails, or |R^2| > 2 (treated as a diverged fit).
    """
    X, Y = [], []
    for b, y in zip(bins, unit_data.values):
        if np.any(np.isnan(y)):
            continue
        X.extend(np.ones_like(y) * b)
        Y.extend(y)
    X = np.array(X, dtype=float)
    Y = np.array(Y, dtype=float)

    # curve_fit needs at least as many points as parameters (4)
    if len(X) < 4:
        return np.nan

    try:
        popt, _ = curve_fit(double_exp, X, Y, p0=(1.0, 1.0, 1.0, 1.0), maxfev=5000)
    except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
        # non-convergence / bad input: report "no fit" instead of stopping the plotting loop
        return np.nan
    curvey = double_exp(X, *popt)

    r2 = r_squared(Y, curvey)
    # |R^2| > 2 indicates a diverged fit: do not plot it
    if np.abs(r2) > 2:
        return np.nan
    ax.plot(X, curvey, lw=3, color=[0, 0, 0], label=f"double exp, R2={r2:.2f}")
    return r2

## Plot

For each unit, plot its tuning curve for each variable and save the figure.

In [None]:
save_dir = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\tuning_curves")
N_UNITS = None  # set to a small int (e.g. 1) for a quick look; None plots and saves every unit

# unit columns of the per-variable tables built above, excluding the shuffled controls
all_units = [c for c in results["speed_means"].columns if "shuffle" not in str(c)]
if N_UNITS is not None:
    all_units = all_units[:N_UNITS]

colors = dict(
    s = amber_dark, speed = blue_dark, angular_velocity = green_dark,
        acceleration = indigo_dark, angular_acceleration = teal_dark,
)

for unit in track(all_units):

    # plot tuning curves: one panel per variable, panel 3 (bottom-left) is left empty
    f, axes = plt.subplots(2, 3 , figsize=(19, 12), sharey=True)
    f.suptitle(unit)
    axes = axes.flatten()
    for axid, var in zip((0, 1, 2, 4, 5), bins.keys()):
        ax = axes[axid]
        bs = bins_centers[var]

        # one entry per bin: the list of chunk means of this unit's firing rate
        mus = results[f"{var}_means"][unit].values
        if len(mus) != len(bs):
            raise ValueError(f"{unit} / {var}: {len(mus)} bins in results but {len(bs)} bin centers")

        # plot mean and error: mean and spread of the chunk means in each bin
        mu = np.array([np.mean(x) for x in mus])
        sigma = np.array([np.std(x) for x in mus])
        plot_mean_and_error(mu, sigma, ax, x=bs, color=colors[var], alpha=.3, err_alpha=.1)

        # plot all data: every chunk mean at its bin centre
        X = np.concatenate([np.full(len(y), b, dtype=float) for b, y in zip(bs, mus)])
        Y = np.concatenate([np.asarray(y, dtype=float) for y in mus])
        ax.scatter(X, Y, color=colors[var], label=f"{var}")

    # sharey hides inner tick labels: show them on every panel
    for ax in axes:
        ax.yaxis.set_tick_params(labelleft=True)
        for tk in ax.get_yticklabels():
            tk.set_visible(True)

    axes[0].set(xlabel="S position (cm)", ylabel="Firing rate (Hz)", xlim=[bins['s'][0], bins['s'][-1]])
    axes[1].set(xlabel="Speed (cm/s)", xlim=[bins['speed'][0], bins['speed'][-1]])
    axes[2].set(xlabel="Angular velocity (deg/s)", xlim=[bins['angular_velocity'][0], bins['angular_velocity'][-1]])
    axes[4].set(xlabel="Acceleration (cm/s^2)", xlim=[bins['acceleration'][0], bins['acceleration'][-1]], ylabel="Firing rate (Hz)")
    axes[5].set(xlabel="Angular acceleration (deg/s^2)", xlim=[bins['angular_acceleration'][0], bins['angular_acceleration'][-1]])
    axes[3].axis("off")
    clean_axes(f)

    f.savefig(save_dir / f"{unit}_MOs.png",  dpi=150, transparent=False, facecolor='w', edgecolor='w')
    plt.close(f)

