Notebook to plot a raster plot of the trial-averaged traces per animal, sorted by condition (or in a custom animal order).

It uses the `features.csv` file generated by `features-from-dlc`.

In [26]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

%matplotlib widget

In [27]:
# Parameters
# full path to the features.csv file
filepath = "/path/to/features.csv"

# exact name of the features to plot, as they appear in the csv file
# (one figure is produced per feature)
features = ["speed", "theta_body", "theta_neck"]

# map feature name to display name (used as the colorbar label);
# every entry of `features` needs a key here
clabels = {
    "speed": "speed (cm/s)",
    "theta_body": "body angle (°)",
    "theta_neck": "neck angle (°)",
}

# whether to sort by "condition" or by "animal"
# "animal" can be used to order by rostro-caudal injection level
# (boundaries between conditions are always marked with a horizontal line)
sortby = "condition"
# order (list of animal ids if sorting by animal; list of conditions if sorting by condition)
# only the values listed here are selected, in this order
sortby_order = ["cond1", "cond2", "cond3", "cond4"]

# colormap name : choose from https://matplotlib.org/stable/users/explain/colors/colormaps.html#sequential
cmap = "viridis"
# color scale limits : the upper limit is this quantile of the data and the
# lower limit is the (1 - quantile) quantile, so outliers do not dominate the scale
quantile = 0.99

In [None]:
# Read the features table exported by features-from-dlc
df_ffd = pd.read_csv(filepath_or_buffer=filepath)
# show the first rows as a quick sanity check of the parsed columns
display(df_ffd.head(n=5))

In [None]:
# Average the selected features over every (animal, time, condition) group,
# which yields one trace per animal while keeping the animal ID and condition
# as regular columns. The rows are then selected and ordered following
# `sortby_order`, applied to the column named by `sortby`.
df = (
    df_ffd.groupby(["animal", "time", "condition"])[features]
    .mean()
    .reset_index()
    .set_index(sortby)
    .loc[sortby_order]
    .reset_index()
)
display(df.head())

In [30]:
# Build a plain dict {animal: condition} from the unique (animal, condition) pairs
map_animal_cond = (
    df[["animal", "condition"]]
    .drop_duplicates(ignore_index=True)
    .set_index("animal")["condition"]
    .to_dict()
)

In [31]:
# Condition of each animal, ordered following `sortby_order`
ser = (
    df.groupby("animal")["condition"]
    .first()
    .reset_index()
    .set_index(sortby)
    .loc[sortby_order]
    .reset_index()["condition"]
)
# Positions where the condition differs from the one just before it: these are
# the indices at which a delimiter is drawn between conditions
ind_delimiter = np.flatnonzero(ser != ser.shift().bfill())

In [None]:
# get shapes
nfeatures = len(features)
time = df["time"].unique()
ntimes = len(time)
animals = df["animal"].unique()  # order of appearance = raster rows, top to bottom
nanimals = len(animals)

# the reshape below lays the values out as (animal, time): it is only valid if
# every animal contributes exactly one value per time point
samples_per_animal = df.groupby("animal", sort=False).size()
if not samples_per_animal.eq(ntimes).all():
    raise ValueError(
        "Cannot build the raster: each animal must have exactly one value per "
        f"time point ({ntimes} expected), got:\n{samples_per_animal}"
    )

# plot rows at integer positions 0..nanimals-1 so that the condition delimiters
# (which are row indices) line up whatever the type of the animal IDs
rows = np.arange(nanimals)

# y tick labels "<animal>-<condition>", built from the data rather than read
# back from the axis tick labels
ylabels = [f"{animal}-{map_animal_cond[animal]}" for animal in animals]

for feature in features:
    # get values range, clipping both tails to limit the influence of outliers
    crange = [df[feature].quantile(1 - quantile), df[feature].quantile(quantile)]

    # prepare figure
    fig, ax = plt.subplots(figsize=(8, 6))
    fig.subplots_adjust(right=0.9)

    # prepare data : one row per animal, one column per time point
    data = np.reshape(df[feature].values, (nanimals, ntimes))

    # plot raster, using the colormap chosen in the parameters
    p = ax.pcolormesh(
        time,
        rows,
        data,
        cmap=cmap,
        vmin=crange[0],
        vmax=crange[1],
    )

    # plot delimiters between conditions (on the edge between two rows)
    for loc in ind_delimiter:
        ax.axhline(loc - 0.5, linewidth=1.5, color="#ec1b8a")

    # flip plot so that first row is first condition
    ax.invert_yaxis()

    # add colorbar
    fig.colorbar(p, ax=ax, label=clabels[feature])

    # one tick per animal, labelled with the animal ID and its condition
    ax.set_yticks(rows)
    ax.set_yticklabels(ylabels)

    # xlabel
    ax.set_xlabel("time (s)")

    # stim
    ax.axvline(0, color="#1b8aec")
