# GLM data prep

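Create a table of data for each recording.
Each row is one 5 ms time step (200 Hz sampling), using data from bouts only.
Variables include speed and angular velocity plus their future changes, the curvature of the track ahead of the animal, and unit firing rates (with shuffled controls).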


In [None]:
# imports
import sys
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd
from scipy import interpolate
from fcutils.progress import track
from fcutils.maths import derivative
from scipy import stats
import warnings
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")


from analysis.ephys.utils import get_recording_names, get_data, get_session_bouts

# Root folder for the ephys analysis outputs.
save_folder = Path(
    r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys"
)

# Folder holding the per-bout and per-recording GLM tables (.h5 files).
cache = Path(
    r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys\GLM\data"
)

# Names of all recordings to process.
recordings = get_recording_names()


## Parameters

In [None]:
# How far ahead along the track to sample curvature, in track S units
# (the resulting columns are labelled "curv_<k>cm").
curvature_horizon = 30
# Distance between consecutive curvature sample points.
curvature_sampling_spacing = 5
# 0, 5, ..., 30. Using `curvature_horizon + 1` as the stop means no sample point can
# exceed the horizon, so every point has a matching k_<k> column in track_data.
curv_sample_points = np.arange(0, curvature_horizon + 1, curvature_sampling_spacing)

# Keep every 25th row of the track description.
track_downsample_factor = 25

# Firing-rate smoothing parameter (ms): calc_firing_rate uses a Gaussian kernel
# whose standard deviation is firing_rate_gaussian / 2.
firing_rate_gaussian = 250

### Track curvature
For each track position, sample the track curvature at positions 0 to `curvature_horizon` ahead along the track.

In [None]:
# Track description, keeping every `track_downsample_factor`-th row; the index is
# reset so row labels equal row positions (0..n-1).
track_json_path = r"C:\Users\Federico\Documents\GitHub\pysical_locomotion\analysis\ephys\track.json"
track_data = (
    pd.read_json(track_json_path)
    .iloc[::track_downsample_factor]
    .reset_index(drop=True)
)

# S value of the last (downsampled) track row; used as the track end when
# sampling look-ahead curvature.
S_f = track_data["S"].to_numpy()[-1]
track_data

In [None]:
# For every track position s, look up the first track row at or beyond s + k for
# each k in 0..curvature_horizon, and store its curvature (k_<k>) and row index
# (idx_<k>). Positions with no track left k ahead get NaN.
k_shifts = np.arange(curvature_horizon + 1)
track_S = track_data["S"].to_numpy()
track_curvature = track_data["curvature"].to_numpy()

curv_shifted = {
    **{f"k_{k}": [] for k in k_shifts},
    **{f"idx_{k}": [] for k in k_shifts},
}
for s in track_S:
    for k in k_shifts:
        if s + k < S_f:
            # argmax of the boolean mask returns the first True position, i.e. the
            # same row pandas' `.loc[S >= s + k].index[0]` selected (labels equal
            # positions after reset_index). A True always exists because
            # s + k < S_f == S[-1].
            first = int(np.argmax(track_S >= s + k))
            curv_shifted[f"idx_{k}"].append(first)
            curv_shifted[f"k_{k}"].append(track_curvature[first])
        else:
            curv_shifted[f"k_{k}"].append(np.nan)
            curv_shifted[f"idx_{k}"].append(np.nan)

# insert at column 2, as before (later keys end up further left)
for column_name, column_values in curv_shifted.items():
    track_data.insert(2, column_name, column_values)
track_data.head()

## Process data

In [None]:

def upsample_frames_to_ms(var, fps=60, target_fps=200):
    """
        Linearly interpolate a per-frame variable onto a finer, uniform time grid.

        Despite the name, the default output rate is 200 Hz (one sample every 5 ms),
        not 1 kHz; the other 200 Hz arrays in this notebook (firing rates, bout start
        indices) rely on that rate.

        Args:
            var: 1D sequence of values sampled at `fps` (default: 60 fps video frames).
            fps: sampling rate of `var`, in Hz.
            target_fps: sampling rate of the returned values, in Hz.

        Returns:
            1D np.ndarray of `var` interpolated at times k / target_fps for every k with
            k / target_fps < (len(var) - 1) / fps (the last source time is excluded).
            Empty when `var` has fewer than 2 samples.
    """
    n_src = len(var)
    if n_src < 2:
        # interp1d needs at least two points; there is no interval to sample.
        return np.empty(0)

    t_src = np.arange(n_src) / fps
    f = interpolate.interp1d(t_src, var)

    # Count output samples with exact integer arithmetic (ceil of
    # (n_src - 1) * target_fps / fps). A float-step
    # np.arange(0, t_src[-1], 1 / target_fps) can emit a final time a rounding error
    # above t_src[-1], which makes interp1d raise "above the interpolation range".
    n_out = int(-((-(n_src - 1) * target_fps) // fps))
    t_target = np.arange(n_out) / target_fps
    return f(t_target)


In [None]:

def gaussian(x, s):
    """Evaluate the zero-mean normal probability density with standard deviation `s` at `x`."""
    variance = s ** 2
    normalization = 1.0 / np.sqrt(2.0 * np.pi * variance)
    return normalization * np.exp(-(x ** 2) / (2.0 * variance))


def calc_firing_rate(spikes_train: np.ndarray, dt: int = 10, downsample: int = 5):
    """
        Computes the firing rate (Hz) given a spikes train (whether there is a spike or not at each ms).
        Uses a gaussian kernel with standard deviation = dt/2 [dt is in ms], sampled every 1 ms
        over +/- 2*dt ms (+/- 4 standard deviations) and normalized to sum to 1.

        Args:
            spikes_train: 1D array with one entry per ms, 1 where a spike occurred, 0 otherwise.
            dt: kernel width in ms; the Gaussian standard deviation is dt / 2.
            downsample: keep every `downsample`-th sample of the smoothed rate. The default
                of 5 turns the 1 ms rate into a 200 Hz one.

        Returns:
            1D np.ndarray of firing rates in Hz, of length ceil(len(spikes_train) / downsample).
    """
    spikes_train = np.asarray(spikes_train, dtype=float)
    if spikes_train.size == 0:
        return spikes_train[::downsample]

    # One kernel sample per ms, matching the 1 ms bins of the spike train. Sampling
    # the kernel more coarsely than the train would shrink its effective width in ms.
    sigma = dt / 2
    x = np.arange(-2 * dt, 2 * dt + 1)
    # The Gaussian normalizing constant cancels once the kernel is divided by its sum;
    # a sum-to-1 kernel makes the convolution a rate in spikes per ms.
    kernel = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    kernel /= kernel.sum()

    # Full convolution, then the window centered on the train. This equals
    # mode="same" for trains longer than the kernel, and still returns exactly
    # len(spikes_train) aligned samples when the train is shorter.
    full = np.convolve(spikes_train, kernel, mode="full")
    offset = (len(kernel) - 1) // 2
    frate = full[offset : offset + len(spikes_train)] * 1000  # spikes/ms -> spikes/s
    return frate[::downsample]


def make_shuffled_units(units, n_shuffles=100, min_shift=10 * 200, max_shift=30 * 200, rng=None):
    """
        For each unit make shuffled copies in which the firing rate is circularly
        offset, looping around the start/end of the session.

        Each copy rotates `firing_rate_ms` left by a random number of samples drawn
        uniformly from [min_shift, max_shift). With 200 Hz firing rates the defaults
        correspond to 10-30 seconds.

        Args:
            units: DataFrame with columns "unit_id" and "firing_rate_ms" (1D arrays).
            n_shuffles: number of shuffled copies per unit.
            min_shift: smallest shift, in samples (inclusive).
            max_shift: largest shift, in samples (exclusive).
            rng: optional np.random.Generator for reproducible shifts. If None the
                global np.random state is used, as before.

        Returns:
            New DataFrame with the original rows followed by the shuffled rows (ids
            "<unit_id>_shuffle_<n>"), with a fresh 0..n-1 index.
    """
    draw_shift = np.random.randint if rng is None else rng.integers

    shuffled_units = dict(unit_id=[], firing_rate_ms=[])
    for unit_id, frate in zip(units["unit_id"], units["firing_rate_ms"]):
        for n in range(n_shuffles):
            shift = draw_shift(min_shift, max_shift)
            shuffled_units["unit_id"].append(f"{unit_id}_shuffle_{n}")
            shuffled_units["firing_rate_ms"].append(np.hstack([frate[shift:], frate[:shift]]))

    # merge units and shuffled units in a single dataframe
    shuffled_units = pd.DataFrame(shuffled_units)
    return pd.concat([units, shuffled_units], ignore_index=True)

#### main data loader

In [None]:
# Load one recording and build the per-recording inputs for the GLM tables.
def _future_delta(x, n_samples):
    """Return x[t + n_samples] - x[t] for every t, holding the last value of x past the end."""
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x.copy()
    future = np.concatenate([x[n_samples:], np.full(min(n_samples, x.size), x[-1])])
    return future - x


def load_get_recording_data(REC):
    """
        Load a recording and assemble kinematics, bouts and firing rates for the GLM tables.

        Args:
            REC: recording name (as returned by get_recording_names()).

        Returns:
            Tuple (units, body, bouts, v, omega, v_250ms, omega_250ms, v_500ms,
            omega_500ms, v_1000ms, omega_1000ms):
              - units: DataFrame with "unit_id" and "firing_rate_ms" (Hz, sampled at
                200 Hz despite the name) for MOs units, followed by shuffled copies
                from make_shuffled_units.
              - body: body tracking data as returned by get_data.
              - bouts: outbound then inbound bouts, concatenated, via reset_index().
              - v, omega: body.speed and body.thetadot upsampled to 200 Hz.
              - v_<N>ms, omega_<N>ms: x[t + N] - x[t], with N counted in 200 Hz
                samples. Near the end of the recording the last value of x is held.
                NOTE(review): 250/500/1000 samples at 200 Hz are 1.25/2.5/5 s, not
                ms as the names suggest; confirm the intended horizon.
    """
    units, _, _, _, _, body = get_data(REC)

    out_bouts = get_session_bouts(REC, complete=None)
    in_bouts = get_session_bouts(REC, direction="inbound", complete=None)

    v = upsample_frames_to_ms(body.speed)
    omega = upsample_frames_to_ms(body.thetadot)

    v_250ms = _future_delta(v, 250)
    v_500ms = _future_delta(v, 500)
    v_1000ms = _future_delta(v, 1000)

    omega_250ms = _future_delta(omega, 250)
    omega_500ms = _future_delta(omega, 500)
    omega_1000ms = _future_delta(omega, 1000)

    # smoothed firing rate of every MOs unit, on the same 200 Hz grid as v
    mos_regions = ["MOs", "MOs1", "MOs2/3", "MOs5", "MOs6a", "MOs6b"]
    # copy: a column is added below, and the filtered frame must not be a view
    units = units.loc[units.brain_region.isin(mos_regions)].copy()
    n_ms = len(v) * 5  # 5 ms per 200 Hz sample
    frates = []
    for _, unit in units.iterrows():
        spike_train = np.zeros(n_ms)
        spike_ms = np.int64(np.round(unit.spikes_ms))
        # drop spikes outside the tracked period (negative indices would wrap around)
        spike_ms = spike_ms[(spike_ms >= 0) & (spike_ms < n_ms)]
        spike_train[spike_ms] = 1
        frates.append(calc_firing_rate(spike_train, dt=firing_rate_gaussian))  # 200 Hz
    units["firing_rate_ms"] = frates
    units = units[["unit_id", "firing_rate_ms"]]  # discard unnecessary columns

    # add shuffled units
    units = make_shuffled_units(units)

    bouts = pd.concat([out_bouts, in_bouts]).reset_index()
    return units, body, bouts, v, omega, v_250ms, omega_250ms, v_500ms, omega_500ms, v_1000ms, omega_1000ms

### Merge the per-bout files into one table per recording

In [None]:
# Merge the per-bout .h5 files into a single table per recording, then drop rows
# with outlying kinematics.

# Only these kinematic columns are screened for outliers. Z-scoring every
# firing-rate column as well (one per unit plus its shuffled copies) would likely
# flag most rows in some column just by chance.
outlier_columns = [
    "s", "sdot",
    "v", "dv_250ms", "dv_500ms", "dv_1000ms",
    "omega", "domega_250ms", "domega_500ms", "domega_1000ms",
]

for REC in recordings:
    savepath = cache / f"{REC}_bouts.h5"
    if savepath.exists():
        print(f"{REC}_bouts.h5 already exists")
        continue

    print(f"Doing   {REC}")
    out_bouts = get_session_bouts(REC, complete=None)
    in_bouts = get_session_bouts(REC, direction="inbound", complete=None)
    allbouts = pd.concat([out_bouts, in_bouts]).reset_index()
    if allbouts.empty:
        print(f"    No bouts found for {REC}")
        continue

    bouts_files = list(cache.glob(f"{REC}_bout_*.h5"))
    bout_paths = [cache / f"{REC}_bout_{bout.start_frame}.h5" for _, bout in allbouts.iterrows()]
    # Skip instead of saving a partial table: once savepath exists this recording
    # is never revisited.
    if len(bouts_files) < len(allbouts) or not all(path.exists() for path in bout_paths):
        print(f"    Not all bouts were saved for {REC}")
        continue

    bouts_data = pd.concat([pd.read_hdf(path, key="data") for path in bout_paths])

    print(f" Got all data ({bouts_data.shape}), removing outliers")
    cols = [c for c in outlier_columns if c in bouts_data.columns]
    if cols:
        z = stats.zscore(bouts_data[cols].to_numpy(dtype=float), nan_policy="omit")
        # NaN z-scores (missing values, or zero-variance columns) are not outliers
        is_inlier = ((np.abs(z) < 3) | np.isnan(z)).all(axis=1)
        bouts_data = bouts_data[is_inlier]
        print(f" Kept {int(is_inlier.sum())} of {len(is_inlier)} rows")

    print(" Saving data")
    bouts_data.to_hdf(savepath, key="data")
    print(" Saved all data")


In [None]:
# First, for each recording and each bout, save a .h5 file.
# The per-bout files are then loaded back and merged together; writing them one at
# a time avoids memory issues.

SAMPLES_PER_SECOND = 200  # rate of upsample_frames_to_ms output and of firing_rate_ms
FRAMES_PER_SECOND = 60  # video frame rate of bout.start_frame and bout.s

# Track positions and look-ahead curvature columns as arrays, for fast lookup.
track_S = track_data.S.to_numpy()
track_curv = {k_cm: track_data[f"k_{k_cm}"].to_numpy() for k_cm in curv_sample_points}

for REC in recordings:
    if (cache / f"{REC}_bouts.h5").exists():
        print(f"{REC}_bouts.h5 already exists")
        continue
    print(f"Processing {REC}")

    (
        units, body, bouts,
        v, omega,
        v_250ms, omega_250ms,
        v_500ms, omega_500ms,
        v_1000ms, omega_1000ms,
    ) = load_get_recording_data(REC)
    print("     got all data")

    for _, bout in track(bouts.iterrows(), total=len(bouts), description=REC):
        bout_savepath = cache / f"{REC}_bout_{bout.start_frame}.h5"
        if bout_savepath.exists():
            print(f"{REC}_bout_{bout.start_frame}.h5 already exists")
            continue

        S = upsample_frames_to_ms(bout.s)
        # bout start as an index into the 200 Hz arrays from load_get_recording_data
        start = int(bout.start_frame / FRAMES_PER_SECOND * SAMPLES_PER_SECOND)
        end = start + len(S)

        data = {
            "s": S,
            # derivative() is taken per sample of S, and S is sampled at
            # SAMPLES_PER_SECOND, so scaling by that rate gives S units per second
            "sdot": derivative(S) * SAMPLES_PER_SECOND,
            "v": v[start:end],
            "dv_250ms": v_250ms[start:end],
            "dv_500ms": v_500ms[start:end],
            "dv_1000ms": v_1000ms[start:end],
            "omega": omega[start:end],
            "domega_250ms": omega_250ms[start:end],
            "domega_500ms": omega_500ms[start:end],
            "domega_1000ms": omega_1000ms[start:end],
        }

        # Curvature ahead of the track point nearest (in S) to each sample. The
        # nearest point is found once per sample and reused for every look-ahead
        # distance (argmin keeps the first of any ties, as before).
        nearest = np.argmin((track_S[None, :] - S[:, None]) ** 2, axis=1)
        for k_cm in curv_sample_points:
            data[f"curv_{k_cm}cm"] = track_curv[k_cm][nearest]

        # firing rate of every unit (and shuffled copy) over the bout
        for unit_id, frate in zip(units.unit_id.values, units.firing_rate_ms.values):
            data[unit_id] = frate[start:end]

        # ensure all entries have the same number of samples
        lengths = {key: len(values) for key, values in data.items()}
        if len(set(lengths.values())) > 1:
            raise ValueError(f"Lengths of data are not the same:\n{lengths}")

        pd.DataFrame(data).to_hdf(bout_savepath, key="data")
        del data

    del units
    del body
    del bouts

    

