# Bouts raster

For each unit in a session, make raster plots of its spikes aligned to locomotion bouts, plotted as a function of track position rather than time.

In [1]:
# imports
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from myterial import grey_dark
from mpl_toolkits.axes_grid1 import make_axes_locatable

sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")

from fcutils.plot.elements import plot_mean_and_error
from fcutils.plot.figure import clean_axes

from analysis.ephys.utils import get_recording_names, get_data, get_session_bouts, curves
from analysis.ephys.viz import bouts_raster

# Root folder for the output figures. Set the BOUTS_RASTER_SAVE_FOLDER environment
# variable to run this notebook on a machine without the default Dropbox path.
save_folder = Path(
    os.environ.get(
        "BOUTS_RASTER_SAVE_FOLDER",
        r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys",
    )
)

# print all available recordings
print(get_recording_names())


Connecting root@127.0.0.1:3306


['FC_220408_BAA1101192_hairpin' 'FC_220409_BAA1101192_hairpin'
 'FC_220410_BAA1101192_hairpin' 'FC_220411_BAA1101192_hairpin'
 'FC_220412_BAA1101192_hairpin' 'FC_220413_BAA1101192_hairpin'
 'FC_220414_BAA1101192_hairpin' 'FC_220415_BAA1101192_hairpin'
 'FC_220432_BAA1101192_hairpin' 'FC_220433_BAA1101192_hairpin'
 'FC_220434_BAA1101192_hairpin' 'FC_220435_BAA1101192_hairpin']


## Get data

In [2]:
# Recording explored interactively in the cells below.
REC = "FC_220410_BAA1101192_hairpin"

units, left_fl, right_fl, left_hl, right_hl, body = get_data(REC)

# Bouts split by direction (outbound/inbound) and completeness, fetched in this order:
# complete outbound, complete inbound, incomplete outbound, incomplete inbound.
out_bouts, in_bouts, incomplete_bouts, in_incomplete_bouts = (
    get_session_bouts(REC, **bout_filter)
    for bout_filter in (
        {},
        {"direction": "inbound"},
        {"complete": "false"},
        {"direction": "inbound", "complete": "false"},
    )
)

## Plot

In [3]:

def bouts_raster(ax, unit, bouts, color="k", ds=5, label="", track_length=260):
    """
        Plot a unit's spikes aligned to bouts. Unlike time_aligned_raster, this function
        plots spikes as a function of track progression, not time!

        One raster row is drawn per bout, with red ticks at the bout start and end
        positions. An axis is appended above ``ax`` showing the mean +/- std firing
        rate of the unit as a function of track position.

        Arguments:
            ax: matplotlib axis for the raster.
            unit: row of the units table. ``spikes`` (spike frame indices),
                ``firing_rate`` and ``color`` are read.
            bouts: DataFrame of bouts with ``start_frame``, ``end_frame`` and ``s``
                (track position at each frame of the bout).
            color: color of the spike ticks.
            ds: spacing (track-position units) of the firing-rate sampling points.
            label: y-axis label of the raster.
            track_length: upper x limit and last firing-rate sampling point
                (default 260, the original hard-coded value).
    """
    # make frate ax
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("top", size="50%", pad=0.05)

    n = len(bouts)
    if n == 0:
        # Nothing to plot, but keep the panels labelled so the empty panel is identifiable.
        ax.set(xlim=[0, track_length], ylabel=label)
        cax.set(ylabel="Firing rate", xticks=[], xlim=[0, track_length])
        return

    h = 1 / n  # vertical spacing between raster rows
    positions = np.arange(0, track_length + ds, step=ds)
    frates = {s: [] for s in positions}

    # get and plot the position (s) at each spike
    S, Y = [], []
    # Use the positional row, not the DataFrame index label: the index of `bouts`
    # is not guaranteed to be 0..n-1, which would push rows outside [0, 1).
    for row, (_, bout) in enumerate(bouts.iterrows()):
        y = row * h
        bout_s = np.asarray(bout.s)

        # get spikes during bout (based on frames)
        trial_spikes = unit.spikes[
            (unit.spikes >= bout.start_frame) & (unit.spikes < bout.end_frame)
        ]

        # bout_s[k] is the position at frame bout.start_frame + k
        S.extend(bout_s[trial_spikes - bout.start_frame])
        Y.extend(np.zeros_like(trial_spikes) + y)

        # mark the start/end of the bout
        ax.scatter(
            [bout_s[0], bout_s[-1]],
            [y, y],
            s=24,
            color="red",
            alpha=1,
            marker="|",
        )

        # Firing rate at the bout frame closest to each sampling position,
        # kept only when that frame is within 1 position unit (d is squared distance).
        for s in positions:
            d = (bout_s - s) ** 2
            idx = np.argmin(d)
            if d[idx] < 1:
                # idx is relative to the bout; offset it by start_frame, as for spikes above.
                # NOTE(review): assumes unit.firing_rate is indexed by session frame — confirm.
                frates[s].append(unit.firing_rate[bout.start_frame + idx])

    ax.scatter(S, Y, s=4, color=color, alpha=1, marker="|")

    # Positions never reached by any bout are plotted as 0.
    mean_frate = np.array([np.mean(v) if v else 0.0 for v in frates.values()])
    std_frate = np.array([np.std(v) if v else 0.0 for v in frates.values()])
    plot_mean_and_error(mean_frate, std_frate, cax, x=positions, color=unit.color)

    # Style axes: one labelled tick every 10 bouts. The integer labels are computed
    # directly so float error cannot truncate them (e.g. 29.999 -> 29).
    tick_rows = np.arange(0, n, 10)
    ax.set(
        yticks=tick_rows * h,
        yticklabels=tick_rows,
        xlim=[0, track_length],
        ylabel=label,
    )
    cax.set(ylabel="Firing rate", xticks=[], xlim=[0, track_length])


In [4]:
def make_figure(unit, out_bouts, incomplete_bouts, in_bouts, in_incomplete_bouts, tracking=None):
    """
    Build the summary figure for one unit.

    Panels (subplot mosaic):
        A: body trajectory of every complete outbound bout, with the body
           position at each of the unit's spikes overlaid.
        G: speed vs track position for every bout (outbound in black,
           inbound in red and negated), with track curves marked along y=0.
        B-E: position-aligned rasters (via ``bouts_raster``) for complete
           outbound, incomplete outbound, complete inbound and incomplete
           inbound bouts.

    Arguments:
        unit: row of the units table (``spikes`` and ``color`` are read).
        out_bouts, incomplete_bouts, in_bouts, in_incomplete_bouts: bouts
            DataFrames, as fetched with ``get_session_bouts``.
        tracking: tracking data whose ``x``/``y`` are indexable by frame.
            Defaults to the notebook-level ``body`` returned by ``get_data``
            (the original behaviour).

    Returns:
        The matplotlib Figure.
    """
    if tracking is None:
        # NOTE: hidden dependency on the global `body` assigned by the get_data cells.
        tracking = body
    track_length = 260  # cm; x-extent shared by the position-aligned panels

    fig = plt.figure(figsize=(22, 16))
    axes = fig.subplot_mosaic(
    """
        AAAGGGG
        AAABBBB
        AAACCCC
        AAADDDD
        AAAEEEE
    """
    )
    fig.tight_layout()

    # A: complete outbound trajectories, with the position at each spike
    for _, bout in out_bouts.iterrows():
        axes["A"].plot(
            tracking.x[bout.start_frame : bout.end_frame],
            tracking.y[bout.start_frame : bout.end_frame],
            "-", color="k",
            alpha=1, lw=.5, zorder=-1
        )

        trial_spikes = unit.spikes[
            (unit.spikes >= bout.start_frame) & (unit.spikes < bout.end_frame)
        ]
        axes["A"].scatter(
            tracking.x[trial_spikes],
            tracking.y[trial_spikes],
            s=50,
            color=unit.color,
            alpha=.5,
        )

    # G: mark the extent of each track curve
    for curve in curves.values():
        axes["G"].plot(
            [curve.s0, curve.sf], [0, 0],
            color=curve.color,
            alpha=1,
            lw=6,
            zorder=100,
        )

    # G: speed profile of every bout (inbound mirrored below zero)
    for bouts, color, sign in zip([out_bouts, in_bouts], ["black", "red"], [1, -1]):
        for _, bout in bouts.iterrows():
            axes["G"].plot(
                bout.s,
                sign * np.array(bout.speed),
                color=color,
                alpha=.25, lw=.5, zorder=-1
            )

    # B-E: position-aligned rasters (uses whichever `bouts_raster` is in scope)
    bouts_raster(axes["B"], unit, out_bouts, color="k", ds=2, label="outbound")
    bouts_raster(axes["C"], unit, incomplete_bouts, color=[.4, .4, .4], ds=2, label="incomplete out")
    bouts_raster(axes["D"], unit, in_bouts, color="r", ds=2, label="inbound")
    bouts_raster(axes["E"], unit, in_incomplete_bouts, color=[.8, .3, .3], ds=2, label="incomplete in")

    axes['G'].axhline(0, color='k', lw=2, zorder=-1)

    # shared x range; only the bottom panel carries the x label
    for ax in "GBCD":
        axes[ax].set(xlim=[0, track_length], xticks=[])
    axes['E'].set(xlim=[0, track_length], xlabel="track position (cm)")

    axes["A"].axis("equal")
    axes["A"].axis("off")
    clean_axes(fig)
    return fig

### Make a figure for every recording/unit

In [5]:
skipped = []  # recordings for which get_data returned no units

for REC in get_recording_names():
    # NOTE: `body` is assigned here at notebook level; make_figure reads it as a global.
    units, left_fl, right_fl, left_hl, right_hl, body = get_data(REC)
    if units is None:
        skipped.append(REC)
        continue

    # only create an output folder for recordings that actually have units
    dest = save_folder / REC / "bouts_rasters"
    dest.mkdir(exist_ok=True, parents=True)

    out_bouts = get_session_bouts(REC)
    in_bouts = get_session_bouts(REC, direction="inbound")
    incomplete_bouts = get_session_bouts(REC, complete="false")
    in_incomplete_bouts = get_session_bouts(REC, direction="inbound", complete="false")

    for _, unit in units.iterrows():
        region = unit.brain_region[:3]
        savepath = dest / f"{region}_{unit.unit_id}.png"
        if savepath.exists():
            continue  # already rendered in a previous run

        fig = make_figure(unit, out_bouts, incomplete_bouts, in_bouts, in_incomplete_bouts)
        try:
            # Use unit_id (as in the file name); `unit.name` is only the iterrows index label.
            fig.suptitle(f"{REC} {unit.unit_id} {len(out_bouts)} bouts")
            fig.savefig(savepath, dpi=300, transparent=False, facecolor='w', edgecolor='w')
        finally:
            # always release the figure, even if saving fails
            plt.close(fig)

In [6]:
# Recordings skipped by the loop above because get_data returned no units.
skipped