In [1]:
# imports
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.signal import medfilt

sys.path.append("./")
sys.path.append(r"C:\Users\Federico\Documents\GitHub\pysical_locomotion")

from data.dbase.db_tables import Probe, Unit, Session, ValidatedSession, Recording, Tracking
from data.data_utils import convolve_with_gaussian
from data.dbase import db_tables

from fcutils.plot.elements import plot_mean_and_error
from scipy.stats import ttest_ind as ttest
from fcutils.maths.signals import get_onset_offset
from fcutils.plot.figure import clean_axes, calc_nrows_ncols
from myterial import amber_darker, green_dark, grey_darker, grey_dark


save_folder = Path(r"D:\Dropbox (UCL)\Rotation_vte\Locomotion\analysis\ephys")


# print all available recordings
print(*zip(*(Recording - "mouse_id='BAA1101192'").fetch("name", "recording_probe_configuration")), sep="\n")

Connecting root@127.0.0.1:3306


('FC_210715_AAA1110750_r5_hairpin', 'longcolumn')
('FC_210716_AAA1110750_r6_hairpin', 'longcolumn')
('FC_210720_AAA1110750_hairpin', 'longcolumn')
('FC_210721_AAA1110750_hairpin', 'longcolumn')
('FC_210722_AAA1110750_hairpin', 'longcolumn')
('FC_211022_BAA110516_hairpin', 'longcolumn')
('FC_211027_BAA110516_hairpin', 'longcolumn')
('FC_211214_BAA110517_hairpin', 'b0')
('FC_220114_BAA110517_hairpin', 'b0')
('FC_220117_BAA110517_hairpin', 'b0')
('FC_220119_BAA110517_hairpin', 'b0')
('FC_220120_BAA110517_hairpin', 'b0')
('FC_210917_BAA1110279_hairpin', 'b0')
('FC_210820_BAA1110281_hairpin', 'longcolumn')
('FC_210829_BAA1110281_hairpin', 'longcolumn')
('FC_210830_BAA1110281_hairpin', 'longcolumn')
('FC_210831_BAA1110281_hairpin', 'longcolumn')
('FC_210901_BAA1110281_hairpin', 'longcolumn')
('FC_210906_BAA1110281_hairpin', 'b0')
('FC_210917_BAA1110281_hairpin', 'b0')


In [15]:
# get the speed of each limb (somehow its not in the database)
def get_speed(x, y):
    """
        Compute speed at each frame from XY coordinates
    """
    rawspeed = np.hstack([[0], np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)]) * 60
    return convolve_with_gaussian(rawspeed, 9)

def get_data(recording):
    tracking = Tracking.get_session_tracking(recording, body_only=False)

    # units = pd.DataFrame(Unit * Unit.Spikes * Probe.RecordingSite & f'name="{recording}"')
    recording = (db_tables.Recording & f"name='{recording}'").fetch1()
    cf = recording["recording_probe_configuration"]
    units = db_tables.Unit.get_session_units(
        recording["name"],
        cf,
        spikes=True,
        firing_rate=False,
        frate_window=100,
    )
    if len(units):
        units = units.sort_values("brain_region", inplace=False).reset_index()

    left_fl = tracking.loc[tracking.bpname == "left_fl"].iloc[0]
    right_fl = tracking.loc[tracking.bpname == "right_fl"].iloc[0]
    left_hl = tracking.loc[tracking.bpname == "left_hl"].iloc[0]
    right_hl = tracking.loc[tracking.bpname == "right_hl"].iloc[0]
    body = tracking.loc[tracking.bpname == "body"].iloc[0]

    for limb in (left_fl, right_fl, left_hl, right_hl):
        limb.speed = get_speed(limb.x, limb.y)

    return units, left_fl, right_fl, left_hl, right_hl, body




## Get walking onsets

Different functions to get walking onset, based on either the paws movements or the body movement

get when all the paws are not moving: the animal is stationary. 

Get the time points at which the left/right paws start moving as walking onsets. 

In [3]:
def get_stationary(left_fl, right_fl, left_hl, right_hl, body, SPEED_TH):
    """ get when none of the paws is moving """
    limbs_moving = {l.bpname: l.speed > SPEED_TH for l in (left_fl, right_fl, left_hl, right_hl, body)}

    stationary = np.sum(
            np.vstack(list(limbs_moving.values())), axis=0
        )
    stationary = stationary == 0
    return stationary


def get_walking_from_paws(paw, SPEED_TH):
    """
        "walking" is whenthe paw is moving fast enough
    """
    above = np.argwhere(paw.speed > SPEED_TH).flatten()
    walking, _ = get_onset_offset(np.diff(above), 2)
    walking = np.where(np.diff(above) > 2)[0]
    return above[walking+1]
    # return above


def get_walking_and_stationary(left_fl, right_fl, left_hl, right_hl, body, SPEED_TH):
        

    right_walking  = get_walking_from_paws(right_fl, SPEED_TH)
    left_walking  = get_walking_from_paws(left_fl, SPEED_TH)
    stationary = get_stationary(left_fl, left_fl, left_hl, right_hl, body, SPEED_TH)

    stationary_off, stationary_on = get_onset_offset(stationary, .5)
    return stationary_off, stationary_on, left_walking, right_walking, stationary

# stationary_off, stationary_on, left_walking, right_walking, stationary = get_walking_and_stationary(left_fl, right_fl, left_hl, right_hl, body, SPEED_TH)


# T = 300
# scaled_stationary = (stationary * np.max(left_fl.speed))  # for plotting

# f, ax = plt.subplots(figsize=(20, 10))

# plt.scatter(left_walking, left_fl.speed[left_walking], color="b", s=60)
# plt.scatter(right_walking, right_fl.speed[right_walking], color="r", s=60)
# plt.scatter(stationary_on[:-2], scaled_stationary[stationary_on[:-2]+1], color="k", s=60)

# plt.plot(left_fl.speed)
# plt.plot(right_fl.speed, alpha=.4)
# plt.plot( scaled_stationary, color="k", alpha=.2)
# ax.axhline(SPEED_TH, color="k", ls="--")
# ax.set(xlim=[0, T])

Get all the times the mouse has a L/R walking onset as long as the other paw is stationary and there was a long enough period of stationary before + the bout is long enough and the mouse moves fast enough.

In [13]:
def get_trials(left_walking, right_walking, stationary_on, stationary_off):
    """
        get the trials from the walking bouts
    """
    
    trials = dict(
        onset=[],
        paw = [],
        offset = [],
    )

    for (side, walking_onset) in zip(("left", "right"), (left_walking, right_walking)):
        if side == "left":
            other = right_walking
            this = left_fl
        else:
            other = left_walking
            this = right_fl

        for onset in walking_onset:
            # get the last movement onset for othe paw 
            other_start_moving = [o for o in other if o < onset]
            other_start_moving = other_start_moving[-1] if other_start_moving else 0

            # get last stationary offset (startmoving)
            start_moving = [o for o in stationary_off if o < onset]
            if start_moving:
                start_moving = start_moving[-1]
            else:
                continue
            

            # get the next stationary onset (stop moving)
            stop_moving = [o for o in stationary_on if o > onset]
            if stop_moving:
                stop_moving = stop_moving[0]
            else:
                continue

            # get the last offset before the stationary onset
            last_stop_moving = [o for o in stationary_on if o < start_moving]
            if last_stop_moving:
                last_stop_moving = last_stop_moving[-1]
            else:
                continue
            

            assert stop_moving > start_moving
            assert last_stop_moving < start_moving

            if start_moving - other_start_moving < 15:
                continue

            # # check that the statiionary duration is long enough
            # if start_moving - last_stop_moving < MIN_PAUSE_DURATION * 60:
            #     continue

            # check that bout is long enough
            if (stop_moving - start_moving) < (MIN_WAKING_DURATION * 60):
                continue

            # check that the speed is high enough
            if not np.any(body.speed[start_moving:stop_moving] >= MIN_WALKING_SPEED):
                continue

            if np.any(this.speed[onset-int(MIN_PAUSE_DURATION*60):onset-5] > SPEED_TH):
                continue    


            # all checks are passed, keep it
            trials["onset"].append(onset - 1)
            trials["paw"].append(side)
            trials["offset"].append(stop_moving)

    
    trials = pd.DataFrame(trials)
    return trials

# trials = get_trials(left_walking, right_walking, stationary_on, stationary_off)

In [5]:
# # basic sanity checks plots

# f, ax = plt.subplots(figsize=(20, 10), nrows=2, sharex=True)

# l = trials.loc[trials.paw == "left"]
# r = trials.loc[trials.paw == "right"]

# # for i, trial in l.iterrows():
# #     ax[0].plot(left_fl.speed[trial.onset - 30:trial.onset+30])

# # for i, trial in r.iterrows():
# #     ax[1].plot(right_fl.speed[trial.onset - 30:trial.onset+30])

# ax[0].plot(left_fl.speed)
# ax[0].plot(body.speed, lw=2, color="k", alpha=.5)
# ax[0].scatter(left_walking, left_fl.speed[left_walking], color="k", s=80, alpha=.5)
# ax[0].scatter(l.onset, left_fl.speed[l.onset], color="r", s=60)
# ax[0].axhline(SPEED_TH, color="k", ls="--")
# ax[0].plot(stationary * 100, lw=4, alpha=.2, color="g")


# ax[1].plot(right_fl.speed)
# ax[1].plot(body.speed, lw=2, color="k", alpha=.5)
# ax[1].scatter(right_walking, right_fl.speed[right_walking], color="k", s=80, alpha=.5)
# ax[1].scatter(r.onset, right_fl.speed[r.onset], color="r", s=60)
# ax[1].axhline(SPEED_TH, color="k", ls="--")
# ax[1].plot(stationary * 100, lw=4, alpha=.2, color="g")

# ax[0].set(xlim=[13000, 13500])

In [6]:
# f, ax = plt.subplots(figsize=(20, 10), nrows=2, sharex=True)

# l = trials.loc[trials.paw == "left"]
# r = trials.loc[trials.paw == "right"]

# # for i, trial in l.iterrows():
#     ax[0].plot(left_fl.speed[trial.onset - 30:trial.onset+30])

# for i, trial in r.iterrows():
#     ax[1].plot(right_fl.speed[trial.onset - 30:trial.onset+30])

## Raster plots

In [7]:
def raster(ax, unit, timestamps, t_before=1, t_after=1, dt=.1):
    """
        Plot a unit spikes aligned to timestamps (in seconds).
        it also adds a firing rate visualization
    """
    ax.plot([0, 0], [0,1], lw=3, color="k", alpha=.3)
    n = len(timestamps)
    h = 1/n

    spikes = unit.spikes_ms / 1000
    perievent_spikes = []
    X, Y = [], []
    for i,t in enumerate(timestamps):
        trial_spikes = spikes[(spikes > t-t_before) & (spikes < t+t_after)]
        y = np.zeros_like(trial_spikes) + (i * h)

        Y.extend(y)
        perievent_spikes.extend(trial_spikes-t)
    ax.scatter(perievent_spikes, Y, s=4, color=grey_dark, alpha=1, marker=7)

    # add horizontal cax to axis
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('top', size='30%', pad=0.05)
    cax.axvline(0, lw=3, color="k", alpha=.3, zorder=-1)
    cax.hist(perievent_spikes, bins=np.arange(-t_before, t_after+dt, step=dt), color=unit.color, alpha=1, density=False)

    ax.set(
        xlabel="time (s)",
        ylabel="trial",
        yticks=np.arange(0, 1, 10/n),
        yticklabels=(np.arange(0, 1, 10/n) * n).astype(int),
        xlim=[-t_before, t_after],
        
    )
    cax.set(ylabel="Spike counts", xticks=[], title = f"Unit {unit.unit_id} - {unit.brain_region}", xlim=[-t_before, t_after],)
    return ax, cax


In [8]:
# f, axes = plt.subplots(np.ceil(len(units)/5).astype(int), 5, figsize=(25, 50))
# # f.suptitle(f"{REC} - locomotion onset")

# for (i, unit), ax in zip(units.iterrows(), axes.flatten()):
#     l = trials.loc[trials.paw == "left"]
#     ax, cax = raster(ax, unit, l.onset / 60, t_before=1.0, t_after=1.0, dt=.05)

#     # get average firing rate before/after movement onset
#     spikes = unit.spikes_ms / 1000
#     firing_rate_before = np.median([len(spikes[(spikes > t-1) & (spikes < t-.25)]) / .75 for t in l.onset/60])
#     firing_rate_after = np.median([len(spikes[(spikes > t+.25) & (spikes < t+1)]) / .75 for t in l.onset/60])
#     cax.plot([-1, -.25], [firing_rate_before, firing_rate_before], lw=1, color="k")
#     cax.plot([.25, 1], [firing_rate_after, firing_rate_after], lw=1, color="k")

#     if i % 5 != 0:
#         ax.set(yticks=[], ylabel=None)
#         cax.set(ylabel=None)
#     # break

# clean_axes(f)

In [9]:


# f, axes = plt.subplots(np.ceil(len(units)/5).astype(int), 5, figsize=(25, 50))
# # f.suptitle(f"{REC} - locomotion onset")

# for (i, unit), ax in zip(units.iterrows(), axes.flatten()):

#     # get average firing rate before/after movement onset
#     spikes = unit.spikes_ms / 1000
#     firing_rate_before = [len(spikes[(spikes > t-1) & (spikes < t-.25)]) / .75 for t in trials.onset/60]
#     firing_rate_after = [len(spikes[(spikes > t+.25) & (spikes < t+1)]) / .75 for t in trials.onset/60]


#     firing_rate_befor_l = [len(spikes[(spikes > t-1) & (spikes < t-.25)]) / .75 for t in trials.loc[trials.paw == "left"].onset/60]
#     firing_rate_after_l = [len(spikes[(spikes > t+.25) & (spikes < t+1)]) / .75 for t in trials.loc[trials.paw == "left"].onset/60]

#     firing_rate_befor_r = [len(spikes[(spikes > t-1) & (spikes < t-.25)]) / .75 for t in trials.loc[trials.paw == "right"].onset/60]
#     firing_rate_after_r = [len(spikes[(spikes > t+.25) & (spikes < t+1)]) / .75 for t in trials.loc[trials.paw == "right"].onset/60]

#     _, p = ttest(firing_rate_before, firing_rate_after)

#     plot_mean_and_error(np.array([np.mean(firing_rate_before), np.mean(firing_rate_after)]), np.array([np.std(firing_rate_before), np.std(firing_rate_after)]), ax=ax, color=unit.color, alpha=.5)
#     plot_mean_and_error(np.array([np.mean(firing_rate_befor_l), np.mean(firing_rate_after_l)]), np.array([np.std(firing_rate_befor_l), np.std(firing_rate_after_l)]), ax=ax, color="red", alpha=.5)
#     plot_mean_and_error(np.array([np.mean(firing_rate_befor_r), np.mean(firing_rate_after_r)]), np.array([np.std(firing_rate_befor_r), np.std(firing_rate_after_r)]), ax=ax, color="black", alpha=.5)

#     if p < 0.05:
#         sign = "decrease" if np.mean(firing_rate_before) > np.mean(firing_rate_after) else "increase"
#         ax.set(title=f"{unit.unit_id} SIGNIFICANT {sign}", ylabel=None)
        

#     # break

# clean_axes(f)

# Run on all recordings
Now that we are happy with the code, get for each recording all the units that are +/- modulated by movement onset.

In [10]:
def get_unit_firign_rate(spikes, trials):
    """
        Given spike times in seconds and trials onsets in frames, get firing rate before/after each trial onset.
    """
    firing_rate_before = [len(spikes[(spikes > t-1) & (spikes < t-.25)]) / .75 for t in trials.onset/60]
    firing_rate_after = [len(spikes[(spikes > t+.25) & (spikes < t+1)]) / .75 for t in trials.onset/60]

    return firing_rate_before, firing_rate_after

In [11]:
MIN_WAKING_DURATION = .25  # when the mouse walks < than this we ignore it (seconds)
MIN_PAUSE_DURATION = .25  # when the mouse pauses < before a walking bout than this we ignore it (seconds)
MIN_WALKING_SPEED = 20  # mouse must reach this speed during a locomotion bout
SPEED_TH = 12


recordings = (Recording - "mouse_id='BAA1101192'").fetch("name")

print(recordings)

['FC_210715_AAA1110750_r5_hairpin' 'FC_210716_AAA1110750_r6_hairpin'
 'FC_210720_AAA1110750_hairpin' 'FC_210721_AAA1110750_hairpin'
 'FC_210722_AAA1110750_hairpin' 'FC_211022_BAA110516_hairpin'
 'FC_211027_BAA110516_hairpin' 'FC_211214_BAA110517_hairpin'
 'FC_220114_BAA110517_hairpin' 'FC_220117_BAA110517_hairpin'
 'FC_220119_BAA110517_hairpin' 'FC_220120_BAA110517_hairpin'
 'FC_210917_BAA1110279_hairpin' 'FC_210820_BAA1110281_hairpin'
 'FC_210829_BAA1110281_hairpin' 'FC_210830_BAA1110281_hairpin'
 'FC_210831_BAA1110281_hairpin' 'FC_210901_BAA1110281_hairpin'
 'FC_210906_BAA1110281_hairpin' 'FC_210917_BAA1110281_hairpin']


In [24]:
all_units = dict(
    recording = [],
    mouse_id = [],
    unit_id = [],
    brain_region = [],
    frate_before = [],
    frate_after = [],
    significant = [],
    pvalue = [],
    change_sign = [],
)

tot_units = 0
for rec in recordings:
    print(f"Processing {rec}")
    units, left_fl, right_fl, left_hl, right_hl, body = get_data(rec)
    if not len(units):
        continue

    units = units.loc[units.brain_region.isin(["PRNr", "PRNc"])]
    tot_units += len(units)

    stationary_off, stationary_on, left_walking, right_walking, stationary = get_walking_and_stationary(left_fl, right_fl, left_hl, right_hl, body, SPEED_TH)
    trials = get_trials(left_walking, right_walking, stationary_on, stationary_off)
    trials = trials.loc[trials.paw == "right"]

    for (i, unit) in units.iterrows():
        spikes = unit.spikes_ms / 1000
        firing_rate_before, firing_rate_after = get_unit_firign_rate(spikes, trials)

        _, p = ttest(firing_rate_before, firing_rate_after)
        sign = "decrease" if np.mean(firing_rate_before) > np.mean(firing_rate_after) else "increase"

        all_units["recording"].append(rec)
        all_units["mouse_id"].append(unit.mouse_id)
        all_units["unit_id"].append(unit.unit_id)
        all_units["brain_region"].append(unit.brain_region)
        all_units["frate_before"].append(firing_rate_before)
        all_units["frate_after"].append(firing_rate_after)
        all_units["significant"].append(p < 0.05)
        all_units["pvalue"].append(p)
        all_units["change_sign"].append(sign)

all_units = pd.DataFrame(all_units)
print(all_units.head())

Processing FC_210715_AAA1110750_r5_hairpin
Processing FC_210716_AAA1110750_r6_hairpin
Processing FC_210720_AAA1110750_hairpin
Processing FC_210721_AAA1110750_hairpin
Processing FC_210722_AAA1110750_hairpin
Processing FC_211022_BAA110516_hairpin
Processing FC_211027_BAA110516_hairpin
Processing FC_211214_BAA110517_hairpin
Processing FC_220114_BAA110517_hairpin
Processing FC_220117_BAA110517_hairpin
Processing FC_220119_BAA110517_hairpin
Processing FC_220120_BAA110517_hairpin
Processing FC_210917_BAA1110279_hairpin
Processing FC_210820_BAA1110281_hairpin


  out=out, **kwargs)


Processing FC_210829_BAA1110281_hairpin
Processing FC_210830_BAA1110281_hairpin
Processing FC_210831_BAA1110281_hairpin
Processing FC_210901_BAA1110281_hairpin
Processing FC_210906_BAA1110281_hairpin
Processing FC_210917_BAA1110281_hairpin
                         recording    mouse_id  unit_id brain_region  \
0  FC_210715_AAA1110750_r5_hairpin  AAA1110750      152         PRNc   
1  FC_210715_AAA1110750_r5_hairpin  AAA1110750      172         PRNc   
2  FC_210715_AAA1110750_r5_hairpin  AAA1110750      215         PRNc   
3  FC_210716_AAA1110750_r6_hairpin  AAA1110750      126         PRNc   
4  FC_210716_AAA1110750_r6_hairpin  AAA1110750      246         PRNc   

                                        frate_before  \
0  [20.0, 1.3333333333333333, 12.0, 4.0, 14.66666...   
1  [9.333333333333334, 1.3333333333333333, 2.6666...   
2  [0.0, 0.0, 0.0, 5.333333333333333, 0.0, 0.0, 1...   
3  [0.0, 9.333333333333334, 5.333333333333333, 8....   
4  [13.333333333333334, 6.666666666666667, 9.33

In [25]:
print(f"Total units: {tot_units}")

Total units: 410


In [27]:
up = all_units.loc[(all_units.change_sign == "increase") & (all_units.significant == True)]
down = all_units.loc[(all_units.change_sign == "decrease") & (all_units.significant == True)]

print(f"Tot units: {len(all_units)} - +-ve {len(up)} - -ve {len(down)}")

Tot units: 410 - +-ve 49 - -ve 33
