# Curves firing rate
Plot each unit's average firing rate (frate) as a function of position *s* relative to the curve apex, for every curve and every recorded cell.

In [1]:
# imports
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")

from fcutils.plot.figure import clean_axes
from fcutils.plot.elements import plot_mean_and_error
from fcutils.maths.signals import rolling_mean
from analysis.ephys.utils import get_recording_names, get_data, get_session_bouts, curves, get_roi_crossings

# Root folder for the saved figures. Defaults to the original author's Dropbox path; set the
# CURVES_FRATE_SAVE_FOLDER environment variable to write somewhere else on another machine.
save_folder = Path(
    os.environ.get("CURVES_FRATE_SAVE_FOLDER", r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys")
)

# print all available recordings
print(get_recording_names())

Connecting root@127.0.0.1:3306


['FC_220408_BAA1101192_hairpin' 'FC_220409_BAA1101192_hairpin'
 'FC_220410_BAA1101192_hairpin' 'FC_220411_BAA1101192_hairpin'
 'FC_220412_BAA1101192_hairpin' 'FC_220413_BAA1101192_hairpin'
 'FC_220414_BAA1101192_hairpin' 'FC_220415_BAA1101192_hairpin'
 'FC_220432_BAA1101192_hairpin' 'FC_220433_BAA1101192_hairpin'
 'FC_220434_BAA1101192_hairpin' 'FC_220435_BAA1101192_hairpin']


In [2]:
REC = "FC_220410_BAA1101192_hairpin"

# units of the selected recording plus the other per-session arrays returned by get_data
units, left_fl, right_fl, left_hl, right_hl, body = get_data(REC)

# bouts in the default direction (outbound) and inbound; complete=None is forwarded unchanged
out_bouts, in_bouts = (
    get_session_bouts(REC, complete=None),
    get_session_bouts(REC, direction="inbound", complete=None),
)


In [3]:
def frate(spikes, s, bins, fps=60):
    """
    Firing rate of one unit as a function of binned position during a single ROI crossing.

    Args:
        spikes: integer indices into `s` of the frames in which the unit fired.
        s: position at every frame of the crossing (e.g. delta-s relative to the apex).
            Values are rounded to the nearest integer before binning.
        bins: bin edges, passed straight to `np.histogram`.
        fps: frames per second, converting spikes/frame into spikes/second (Hz).
            Defaults to 60, the value previously hard-coded.

    Returns:
        Array of length ``len(bins) - 1`` with the firing rate (Hz) in each bin.
        Bins the crossing never occupied are NaN.
    """
    s = np.round(np.asarray(s)).astype(np.int64)
    # integer index array; an empty list would otherwise become a float array and fail to index
    spikes = np.asarray(spikes, dtype=np.int64)

    occupancy = np.histogram(s, bins=bins)[0]            # frames spent in each bin
    nspikes_per_bin = np.histogram(s[spikes], bins=bins)[0]

    # spikes per frame; unoccupied bins stay NaN without triggering 0/0 RuntimeWarnings
    rate = np.full(occupancy.shape, np.nan)
    np.divide(nspikes_per_bin, occupancy, out=rate, where=occupancy > 0)
    return rate * fps

In [4]:
def plot_crossings(axes, crossings, curve, bouts, unit, frate_ax_id, raster_ax_id, color, yshift=0, sign=-1):
    """
    Plot a unit's spike raster and its mean firing rate vs. delta-s across a set of ROI crossings.

    NOTE: the two axis-id parameters are named the other way round from what is drawn on them.
    The spike raster goes on ``axes[str(frate_ax_id)]`` and the mean firing-rate curve goes on
    ``axes[str(raster_ax_id)]``. The names are kept for compatibility with existing callers.

    Args:
        axes: mapping from axis id (string) to matplotlib Axes, e.g. the dict returned by
            ``fig.subplot_mosaic``.
        crossings: DataFrame with one row per ROI crossing. Uses ``bout_idx``, ``enter_frame``,
            ``exit_frame``, ``session_start_frame`` and ``session_end_frame``.
        curve: object whose ``s`` attribute is the reference position delta-s is measured from.
        bouts: DataFrame of bouts, indexed positionally by ``bout_idx``, with a per-frame ``s`` column.
        unit: unit record whose ``spikes`` holds spike times as session frame numbers.
        frate_ax_id: id of the axis that receives the raster (see NOTE).
        raster_ax_id: id of the axis that receives the mean firing rate (see NOTE).
        color: matplotlib color for both plots.
        yshift: added to each crossing's raster row, so several crossing sets can be stacked.
        sign: multiplies ``curve.s - bout.s`` to orient delta-s.
    """
    # edges -30..33; frate() returns one value per bin, i.e. one fewer than the edges
    bins = np.arange(-30, 30+4, step=1)
    spike_frames = np.asarray(unit.spikes)
    per_crossing_frates = []

    for n, cross in crossings.iterrows():
        s = sign * (curve.s - np.array(bouts.iloc[cross.bout_idx].s[cross.enter_frame:cross.exit_frame]))

        # Spikes fired during the crossing, as indices into `s` (index 0 == session_start_frame).
        # `>=` keeps a spike on the first frame of the crossing, which a strict `>` would drop.
        in_window = (spike_frames >= cross.session_start_frame) & (spike_frames < cross.session_end_frame)
        spikes = spike_frames[in_window] - cross.session_start_frame
        # guard against the session frame window being longer than the bout slice
        spikes = spikes[spikes < len(s)]

        axes[str(frate_ax_id)].plot(
            s[spikes],
            np.ones_like(spikes)*(n+yshift),
            ".",
            color=color
        )

        per_crossing_frates.append(rolling_mean(frate(spikes, s, bins), 9))

    # only summarise when there are enough crossings for a meaningful average
    if len(crossings) > 5:
        frates = np.vstack(per_crossing_frates)  # (n_crossings, n_bins)
        # NOTE(review): plain mean/std. If any crossing is NaN in a bin (e.g. bin never visited),
        # the average is NaN there too. Consider nanmean if gaps show up in the plots.
        mean_frate = frates.mean(axis=0)
        sem_frate = frates.std(axis=0) / np.sqrt(len(crossings))
        plot_mean_and_error(
            mean_frate, sem_frate, axes[str(raster_ax_id)], x=bins[:-1], color=color
        )


def plot_avg_speed_wrt_s(axes, crossings, bouts, curve, ax_id, color, sign=-1, angvel = False):
    """
    Plot the per-position mean of a smoothed kinematic trace across ROI crossings.

    For each crossing, delta-s = sign * (curve.s - bout s) is rounded to the nearest integer. Each
    frame's smoothed value (11-frame `rolling_mean`) is pooled into the matching integer bin in
    [-30, 30]. Frames whose rounded delta-s falls outside that range are ignored. The error band
    is the per-bin std divided by sqrt(number of crossings); empty bins are NaN.

    Args:
        axes: mapping from axis id (string) to matplotlib Axes.
        crossings: DataFrame of ROI crossings (uses ``bout_idx``, ``enter_frame``, ``exit_frame``).
        bouts: DataFrame of bouts, indexed positionally by ``bout_idx``.
        curve: object whose ``s`` attribute is the reference position.
        ax_id: id of the axis to draw on.
        color: matplotlib color.
        sign: multiplies ``curve.s - bout.s`` to orient delta-s.
        angvel: if True use the bout's ``angvel`` trace, otherwise its ``speed`` trace.
    """
    bins = np.arange(-30, 30 + 1, step=1)
    pooled = {edge: [] for edge in bins}
    trace_name = "angvel" if angvel else "speed"

    for _, crossing in crossings.iterrows():
        bout = bouts.iloc[crossing.bout_idx]
        start, stop = crossing.enter_frame, crossing.exit_frame
        delta_s = np.round(sign * (curve.s - np.array(bout.s[start:stop]))).astype(np.int64)
        smoothed = rolling_mean(getattr(bout, trace_name)[start:stop], 11)

        for position, value in zip(delta_s, smoothed):
            bucket = pooled.get(position)
            if bucket is not None:
                bucket.append(value)

    means, spreads = [], []
    for values in pooled.values():
        if len(values):
            means.append(np.mean(values))
            spreads.append(np.std(values))
        else:
            means.append(np.nan)
            spreads.append(np.nan)
    sem = np.array(spreads) / np.sqrt(len(crossings))

    plot_mean_and_error(means, sem, axes[str(ax_id)], x=bins, color=color)


In [5]:
def _scatter_crossings_xy(ax, crossings, bouts, cmap, x_offset=0, y_offset=0):
    """Scatter each crossing's x/y path on `ax`, shaded by frame index within the crossing.

    `x_offset` / `y_offset` are subtracted from the coordinates before plotting.
    """
    for _, cross in crossings.iterrows():
        bout = bouts.iloc[cross.bout_idx]
        ax.scatter(
            np.array(bout.x[cross.enter_frame:cross.exit_frame]) - x_offset,
            np.array(bout.y[cross.enter_frame:cross.exit_frame]) - y_offset,
            c=np.arange(cross.exit_frame - cross.enter_frame), cmap=cmap, s=5, vmin=-5, vmax=100
        )


def make_figure(out_bouts, in_bouts, unit):
    """
    Build the summary figure for one unit, with one column per curve in `curves`.

    Rows (mosaic letters):
        O M N P  x/y paths of the crossings (outbound black, inbound red)
        1 2 3 4  mean firing rate vs. delta S
        5 6 7 8  spike raster, one row per crossing
        A B C D  speed vs. delta S
        E F G H  angular velocity vs. delta S

    The layout has four columns, so `curves` must not have more than four entries.

    Args:
        out_bouts: outbound bouts of the session.
        in_bouts: inbound bouts of the session.
        unit: unit record (uses ``spikes``, ``unit_id`` and ``brain_region``).

    Returns:
        The matplotlib Figure. It is not closed here; the caller must close it.
    """
    fig = plt.figure(figsize=(20, 14))
    axes = fig.subplot_mosaic(
    """
        OMNP
        1234
        5678
        ABCD
        EFGH
    """
    )

    frate_axes = "1234"
    raster_axes = "5678"
    speeds_axes = "ABCD"
    avel_axes = "EFGH"
    xy_axes = "OMNP"
    # Per-curve offsets subtracted from the outbound paths only.
    # NOTE(review): presumably hand-tuned so outbound (black) and inbound (red) paths are drawn
    # side by side instead of on top of each other. Confirm before applying them to both.
    x_shift = [15, 22, 15, 5]
    y_shift = [0, 0, 0, 10]

    for i, (curve_name, curve) in enumerate(curves.items()):
        crossings = get_roi_crossings(out_bouts, curve_name, direction="out")
        in_crossings = get_roi_crossings(in_bouts, curve_name, direction="in")

        # raster on the 5678 row, mean firing rate on the 1234 row; inbound rows stacked above outbound
        plot_crossings(axes, crossings, curve, out_bouts, unit, raster_axes[i], frate_axes[i], "k", yshift=0, sign=-1)
        plot_crossings(axes, in_crossings, curve, in_bouts, unit, raster_axes[i], frate_axes[i], "r", yshift=len(crossings), sign=-1)

        plot_avg_speed_wrt_s(axes, crossings, out_bouts, curve, speeds_axes[i], "k", sign=-1)
        plot_avg_speed_wrt_s(axes, in_crossings, in_bouts, curve, speeds_axes[i], "r", sign=-1)
        plot_avg_speed_wrt_s(axes, crossings, out_bouts, curve, avel_axes[i], "k", sign=-1, angvel=True)
        plot_avg_speed_wrt_s(axes, in_crossings, in_bouts, curve, avel_axes[i], "r", sign=-1, angvel=True)

        _scatter_crossings_xy(axes[xy_axes[i]], crossings, out_bouts, "Greys", x_shift[i], y_shift[i])
        _scatter_crossings_xy(axes[xy_axes[i]], in_crossings, in_bouts, "Reds")

    panel_styles = [
        (frate_axes, dict(xlim=[-30, 30], xticks=[], ylabel="frate (Hz)")),
        (raster_axes, dict(xlim=[-30, 30], xticks=[], ylabel="crossing #")),
        (speeds_axes, dict(xlim=[-30, 30], ylabel="speed (cm/s)")),
        (avel_axes, dict(xlim=[-30, 30], xlabel="delta S (cm)", ylabel="ang. vel. (deg/s)")),
    ]
    for letters, style in panel_styles:
        for ax in letters:
            axes[ax].set(**style)
            # vertical reference line at delta S = 0
            axes[ax].axvline(0, color="k", lw=2, alpha=.8)

    for ax in xy_axes:
        axes[ax].axis("equal")
        axes[ax].axis("off")

    clean_axes(fig)
    fig.suptitle(f"{unit.unit_id} - {unit.brain_region}")
    return fig

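## Plot all sessions/units
For every recording that has units, save one figure per unit under `save_folder / REC / "curves_rasters"`.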

In [6]:
# set to True to skip units whose figure already exists from a previous run
SKIP_EXISTING = False

for REC in get_recording_names():
    units, left_fl, right_fl, left_hl, right_hl, body = get_data(REC)
    if units is None:
        continue

    # create the output folder only for recordings that actually have units
    dest = save_folder / REC / "curves_rasters"
    dest.mkdir(exist_ok=True, parents=True)

    out_bouts = get_session_bouts(REC, complete=None)
    in_bouts = get_session_bouts(REC, direction="inbound", complete=None)

    for _, unit in units.iterrows():
        region = unit.brain_region[:3]
        savepath = dest / f"{region}_{unit.unit_id}.png"
        if SKIP_EXISTING and savepath.exists():
            continue

        fig = make_figure(out_bouts, in_bouts, unit)
        try:
            fig.savefig(savepath, dpi=200, transparent=False, facecolor='w', edgecolor='w')
        finally:
            # close every figure, even if saving fails, so memory doesn't grow across units
            plt.close(fig)
        

In [7]:
# `out_bouts` holds whatever was assigned to it last. After the loop above, that is the last
# recording that had units, not necessarily the one selected in the `REC` cell.
first_out_bout = out_bouts.iloc[0]
print(first_out_bout)

mouse_id                                              BAA1101192
name                                FC_220435_BAA1101192_hairpin
start_frame                                                 4167
end_frame                                                   4544
duration                                                 6.28333
direction                                               outbound
complete                                                    true
s              [4.605061300745047, 4.743562081698704, 4.83088...
x              [21.953143118212125, 22.07040525670981, 22.231...
y              [28.402032563898022, 28.27276050166269, 28.194...
speed          [0.0, 5.235982016743666, 7.06930054863444, 13....
angvel         [-1.676693317527222, -1.6582054323921707, -1.6...
Name: 0, dtype: object
