In [1]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

from aind_vr_foraging_analysis.utils import parse, plotting_utils as plotting, AddExtraColumns

# Plotting libraries
import matplotlib.pyplot as plt

import seaborn as sns
import pandas as pd
import datetime

sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

# Odor colour palette; later cells assign these to patches by their position in the task schema
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Input / output locations (network share for data, local folder for figures)
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = 'Z:/scratch/vr-foraging/data/'
foraging_figures = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\results'



Experiment dates:
* 10/04 - 716458: 4 forces, no post-patch
* 10/11 - 716458 (actually mouse ...866): post-patch enabled, but with the logic reversed from what was intended
* 10/12 - 716455, 716458, 715866 (logic fixed)
* 10/14 - 716458, 715866
* 10/15 - 716458, 715866

In [2]:
# Session to inspect: edit the date string (MM/DD/YYYY) and mouse ID
date_string = "10/22/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
mouse = '754574'

In [None]:
session_found = False

directory = os.path.join(base_path, mouse)
files = os.listdir(directory)

# Newest entries first (by ctime), so recent dates are found quickly
sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=True)

# Find the session folder for `date` without knowing its exact name:
# characters [-15:-7] of the folder name hold the date as YYYYMMDD
for file_name in sorted_files:
    print(file_name)
    session = file_name[-15:-7]
    try:
        session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
    except ValueError:
        # Not a session folder (no embedded date); skip instead of aborting the scan
        continue
    if session_date == date:
        print('correct date found')
        session_found = True
        break

if not session_found:
    # Fail loudly here rather than with NameErrors in the cells below
    raise FileNotFoundError(f'No session for mouse {mouse} dated {date} in {directory}')

# Recover data streams
session_path = Path(os.path.join(base_path, mouse, file_name))
data = parse.load_session_data(session_path)

# Parse data into a dataframe with the main features
reward_sites, active_site, config = parse.parse_dataframe(data)
# -- At this step you can save the data into a csv file

# Expand with extra columns
reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

# Load the continuous streams (encoder, odor valves, tones)
stream_data = parse.ContinuousData(data)
encoder_data = stream_data.encoder_data
odor_triggers = stream_data.odor_triggers
software_tone = data['software_events'].streams['ChoiceFeedback'].data.index
choice_tone = stream_data.choice_feedback.index

# Recover color palette: one colour per patch label, in task-schema order
color_dict_label = {}
dict_odor = {}
list_patches = parse.TaskSchemaProperties(data).patches
for i, patches in enumerate(list_patches):
    # Cycle the palette if the task defines more patches than colours
    color_dict_label[patches['label']] = odor_list_color[i % len(odor_list_color)]
    dict_odor[i] = patches['label']
    

In [5]:
# Epoch-type colours (greys) followed by the per-odor colours; odor entries win on a key clash
label_dict = {
    "InterSite": '#808080',
    "InterPatch": '#b3b3b3',
    "PostPatch": '#d9d9d9',
    **color_dict_label,
}

In [None]:
def update_plot(x_start):
    """Plot a 20 s window of the session timeline starting at ``x_start``.

    The main axis does two things:

    * It shades every epoch of ``active_site``, with colours taken from ``label_dict``.
    * It overlays event rasters: the hardware choice tone, software
      ``ChoiceFeedback`` events, lick onsets, valve pulses, software
      ``GiveReward`` events, and odor on/offsets.

    A twin axis shows the filtered wheel velocity and a dashed stop-threshold line.

    Reads the notebook globals ``active_site``, ``label_dict``, ``stream_data``,
    ``data``, ``odor_triggers`` and ``config``.

    Parameters
    ----------
    x_start : float
        Left edge of the displayed window, in seconds relative to the first
        epoch in ``active_site``.
    """
    zero_index = active_site.index[0]

    fig, axs = plt.subplots(1, 1, figsize=(20, 4))

    _legend = {}
    # Epoch background: each rectangle spans from one epoch start to the next
    for idx, (_, site) in enumerate(active_site.iloc[:-1].iterrows()):
        site_label = site["label"]
        if site_label == "Reward":
            site_label = f"Odor {site['odor']['index'] + 1}"
        elif site_label == "RewardSite":
            site_label = site['odor_label']
        elif site_label not in ("InterPatch", "InterSite"):
            # Any other epoch type is drawn with the PostPatch colour
            site_label = "PostPatch"
        facecolor = label_dict[site_label]

        p = Rectangle(
            (active_site.index[idx] - zero_index, -2), active_site.index[idx + 1] - active_site.index[idx], 8,
            linewidth=0, facecolor=facecolor, alpha=.5)
        _legend[site_label] = p
        axs.add_patch(p)

    s, lw = 400, 2
    # Row at y = -0.4: hardware choice tone vs software ChoiceFeedback events
    y_idx = -0.4
    _legend["Choice Tone"] = axs.scatter(stream_data.choice_feedback.index - zero_index,
            stream_data.choice_feedback.index * 0 + y_idx,
            marker="s", s=100, lw=lw, c='darkblue',
            label="Choice Tone")
    software_tone = data['software_events'].streams['ChoiceFeedback'].data.index
    _legend["SoftTone"] = axs.scatter(software_tone - zero_index,
            software_tone * 0 + y_idx,
            marker=".", s=s, lw=lw, c='green',
            label="SoftwareTone")

    # Row at y = 0.6: licks, valve pulses and software GiveReward events
    y_idx += 1
    _legend["Lick"] = axs.scatter(stream_data.lick_onset.index - zero_index,
            stream_data.lick_onset.index * 0 + y_idx,
            marker="|", s=s, lw=lw, c='k',
            label="Lick")
    _legend["Reward"] = axs.scatter(stream_data.valve_output_pulse.index - zero_index,
            stream_data.valve_output_pulse.index * 0 + y_idx,
            marker=".", s=s, lw=lw, c='deepskyblue',
            label="Reward")

    software_water = data['software_events'].streams['GiveReward'].data.index
    _legend["SoftReward"] = axs.scatter(software_water - zero_index,
            software_water * 0 + y_idx,
            marker="x", s=s, lw=lw, c='red',
            label="SoftwareReward")

    # Odor valve onsets / offsets at y = 2.5
    _legend["Odor_on"] = axs.scatter(odor_triggers.odor_onset - zero_index,
        odor_triggers.odor_onset * 0 + 2.5,
        marker="|", s=s, lw=lw, c='pink',
        label="ON")
    _legend["Odor_off"] = axs.scatter(odor_triggers.odor_offset - zero_index,
        odor_triggers.odor_offset * 0 + 2.5,
        marker="|", s=s, lw=lw, c='purple',
        label="OFF")

    # Raster axis: hidden y axis, fixed vertical range
    axs.set_yticklabels([])
    axs.set_xlabel("Time(s)")
    axs.set_ylim(bottom=-1, top=4)
    axs.grid(False)
    axs.yaxis.set_visible(False)

    # Velocity on a twin axis sharing the time axis
    ax2 = axs.twinx()
    _legend["Velocity"] = ax2.plot(stream_data.encoder_data.index - zero_index, stream_data.encoder_data.filtered_velocity, c="k", label="Encoder", alpha=0.8)[0]
    try:
        v_thr = config.streams.TaskLogic.data["operationControl"]["positionControl"]["stopResponseConfig"]["velocityThreshold"]
    except Exception:
        # Best-effort lookup; fall back to 8 cm/s.
        # NOTE(review): other cells read the task config via
        # data['config'].streams['tasklogic_input'] -- confirm this path exists,
        # otherwise the fallback is always used.
        v_thr = 8
    _legend["Stop Threshold"] = ax2.plot(ax2.get_xlim(), (v_thr, v_thr), c="k", label="Encoder", alpha=0.5, lw=2, ls="--")[0]
    ax2.grid(False)
    ax2.set_ylim((-5, 70))
    ax2.set_ylabel("Velocity (cm/s)")
    ax2.hlines(0, 0, active_site.index[-1] - zero_index, lw=1)
    axs.legend(_legend.values(), _legend.keys(), bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0.)

    axs.set_yticks([0, 3])
    axs.yaxis.tick_right()
    axs.set_xlim([x_start, x_start + 20])
    # Render inside the interactive_output callback so each redraw replaces the last
    plt.show()

# Arrow buttons scroll the timeline by one window width (update_plot draws 20 s)
def on_left_button_clicked(button):
    """Shift the plotted window 20 s earlier."""
    x_start_widget.value = x_start_widget.value - 20

def on_right_button_clicked(button):
    """Shift the plotted window 20 s later."""
    x_start_widget.value = x_start_widget.value + 20

# Controls: two arrow buttons above a numeric box holding the window start
left_button, right_button = widgets.Button(description='◄'), widgets.Button(description='►')
x_start_widget = widgets.FloatText(value=0.0, description='X start:', continuous_update=False)

for button_widget, handler in ((left_button, on_left_button_clicked),
                               (right_button, on_right_button_clicked)):
    button_widget.on_click(handler)

button_box = widgets.HBox([left_button, right_button])
ui = widgets.VBox([button_box, x_start_widget])

# Re-run update_plot whenever the window start changes
interactive_plot = widgets.interactive_output(update_plot, {'x_start': x_start_widget})
display(ui, interactive_plot)

In [None]:
# Wheel torque and brake streams for the loaded session
torque_data, brake_data = stream_data.torque_data, stream_data.brake_data

In [8]:
# Epoch type to analyse ('PostPatch' or 'InterPatch')
section = 'PostPatch'

# Work on a copy: the original mutated `active_site` in place, so re-running this
# cell shifted `active_patch` again each time (hidden-state bug)
active_site_epochs = active_site.copy()

# Friction level = last token of the odor label two epochs ahead
active_site_epochs['interpatch_friction'] = active_site_epochs['odor_label'].shift(-2).str.split(' ').str[-1]
if section == 'InterPatch':
    active_site_epochs['odor_friction'] = active_site_epochs['odor_label'].shift(-2)
else:
    active_site_epochs['odor_friction'] = active_site_epochs['odor_label'].shift(2)
    active_site_epochs['active_patch'] = active_site_epochs['active_patch'].shift(-1)

# Each epoch ends where the next one begins (index holds epoch start times)
active_site_epochs['end_epoch'] = active_site_epochs.index.to_series().shift(-1)
active_site_epochs['epoch_duration'] = active_site_epochs['end_epoch'] - active_site_epochs.index

new_active_site = active_site_epochs[active_site_epochs['label'] == section]

In [9]:
# Per-epoch speed traces: default alignment (presumably epoch start), cropped to the epoch length
velocity = plotting.trial_collection(
    new_active_site, encoder_data, mouse, session,
    window=[-1, 10],
    cropped_to_length='epoch',
    taken_col='filtered_velocity',
)

# Same traces aligned to the end of each epoch
velocity_end = plotting.trial_collection(
    new_active_site, encoder_data, mouse, session,
    aligned='end_epoch',
    window=[-5, 2],
    taken_col='filtered_velocity',
)

In [None]:
# X-axis category order; sessions without the 0.2 condition use the 0.1 set instead
if 'Methyl Butyrate 0.2' in velocity['odor_friction'].unique():
    order_hue = ['Alpha-pinene 0', 'Methyl Butyrate 0.05',
                 'Alpha-pinene 0.15', 'Methyl Butyrate 0.2']
else:
    order_hue = ['Alpha-pinene 0', 'Methyl Butyrate 0.05',
                 'Methyl Butyrate 0.1', 'Alpha-pinene 0.15']

# One point per (condition, patch): mean speed and mean epoch duration
df_results = (
    velocity
    .groupby(['odor_friction', 'active_patch'])
    .agg({'speed': 'mean', 'epoch_duration': 'mean'})
    .reset_index()
)

fig = plt.figure(figsize=(12, 4))
for position, metric in ((1, 'speed'), (2, 'epoch_duration')):
    ax = fig.add_subplot(1, 2, position)
    sns.swarmplot(data=df_results, x='odor_friction', y=metric, hue='active_patch',
                  order=order_hue, palette='viridis', legend=False, ax=ax)
    sns.boxplot(data=df_results, x='odor_friction', y=metric, order=order_hue,
                palette=color_dict_label, legend=False, fliersize=0, ax=ax)
    plt.xticks(rotation=45)
    sns.despine()
plt.ylim(0, 40)

In [None]:
# Mean speed per condition, aligned to epoch start (left) and epoch end (right)
fig, (ax_start, ax_end) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: no legend (the right panel carries it); the original called
# plt.legend() here anyway, producing an empty legend plus a warning
sns.lineplot(data=velocity, x='times', y='speed', hue='odor_friction',
             palette=color_dict_label, errorbar=None, legend=False, ax=ax_start)
ax_start.set_xlim(-1, 15)
ax_start.set_ylim(0, 45)
# Shade the pre-epoch second and the epoch itself
ax_start.fill_betweenx([-5, 50], -1, 0, color=color1, alpha=0.2)
ax_start.fill_betweenx([-5, 50], 0, 15, color='grey', alpha=0.2)
ax_start.set_xlabel(f'Time from {section} start (s)')

sns.lineplot(data=velocity_end, x='times', y='speed', hue='odor_friction',
             palette=color_dict_label, errorbar=None, ax=ax_end)
ax_end.set_xlim(-5, 2)
ax_end.set_ylim(0, 45)
ax_end.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
ax_end.fill_betweenx([-5, 50], -15, 0, color='grey', alpha=0.2)
ax_end.fill_betweenx([-5, 50], 0, 2, color=color1, alpha=0.2)
ax_end.set_xlabel(f'Time from {section} end (s)')

sns.despine(fig=fig)

In [None]:
# Vertical segmented raster of reward sites, coloured by odor label
task_logic = data['config'].streams['tasklogic_input'].data
plotting.segmented_raster_vertical(reward_sites, task_logic, color_dict_label=color_dict_label)

### **Look at the torque and the brake force during the epochs**

In [609]:
# Epoch-aligned torque and brake set-point traces.
# NOTE: the torque traces are stored in `velocity` / `velocity_end` because the
# plotting cells below read those names.
velocity = plotting.trial_collection(
    new_active_site, torque_data, mouse, session,
    window=[-1, 10], cropped_to_length='epoch', taken_col=['Torque'],
)
velocity_end = plotting.trial_collection(
    new_active_site, torque_data, mouse, session,
    aligned='end_epoch', window=[-5, 2], taken_col=['Torque'],
)
brake_plot = plotting.trial_collection(
    new_active_site, brake_data, mouse, session,
    window=[-1, 10], cropped_to_length='epoch', taken_col=['BrakeCurrentSetPoint'],
)

In [None]:
# X-axis category order; sessions without the 0.2 condition use the 0.1 set instead
if 'Methyl Butyrate 0.2' in velocity['odor_friction'].unique():
    order_hue = ['Alpha-pinene 0', 'Methyl Butyrate 0.05',
                 'Alpha-pinene 0.15', 'Methyl Butyrate 0.2']
else:
    order_hue = ['Alpha-pinene 0', 'Methyl Butyrate 0.05',
                 'Methyl Butyrate 0.1', 'Alpha-pinene 0.15']

# Left: mean torque per (condition, patch); right: mean brake set-point.
# df_results ends up holding the brake summary, as before.
fig = plt.figure(figsize=(12, 4))
panels = ((velocity, 'Torque'), (brake_plot, 'BrakeCurrentSetPoint'))
for position, (frame, metric) in enumerate(panels, start=1):
    df_results = (
        frame
        .groupby(['odor_friction', 'active_patch'])
        .agg({metric: 'mean', 'epoch_duration': 'mean'})
        .reset_index()
    )
    ax = fig.add_subplot(1, 2, position)
    sns.swarmplot(data=df_results, x='odor_friction', y=metric, hue='active_patch',
                  order=order_hue, palette='viridis', legend=False, ax=ax)
    sns.boxplot(data=df_results, x='odor_friction', y=metric, order=order_hue,
                palette=color_dict_label, legend=False, fliersize=0, ax=ax)
    plt.xticks(rotation=45)
    sns.despine()

In [None]:
# Mean torque per condition, aligned to epoch start (left) and epoch end (right)
torque_ylim = (1500, 2200)  # y range shared by both panels and their shaded bands

fig, (ax_start, ax_end) = plt.subplots(1, 2, figsize=(16, 4))

# Left panel: no legend (the right panel carries it); the original called
# plt.legend() here anyway, producing an empty legend plus a warning
sns.lineplot(data=velocity, x='times', y='Torque', hue='odor_friction',
             palette=color_dict_label, errorbar=None, legend=False, ax=ax_start)
ax_start.set_xlim(-1, 15)
ax_start.set_ylim(*torque_ylim)
ax_start.fill_betweenx(torque_ylim, -1, 0, color=color1, alpha=0.2)
ax_start.fill_betweenx(torque_ylim, 0, 15, color='grey', alpha=0.2)
ax_start.set_xlabel(f'Time from {section} start (s)')

sns.lineplot(data=velocity_end, x='times', y='Torque', hue='odor_friction',
             palette=color_dict_label, errorbar=None, ax=ax_end)
ax_end.set_xlim(-5, 2)
ax_end.set_ylim(*torque_ylim)
ax_end.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
ax_end.fill_betweenx(torque_ylim, -15, 0, color='grey', alpha=0.2)
ax_end.fill_betweenx(torque_ylim, 0, 2, color=color1, alpha=0.2)
ax_end.set_xlabel(f'Time from {section} end (s)')

sns.despine(fig=fig)
fig.tight_layout()

## How does the animals' behavior progress across subsequent sessions?

In [13]:

# First session date (MM/DD/YYYY) and mouse for the across-session analysis
date_string = "10/10/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
# NOTE(review): the notes above list mouse 715866; confirm '713866' is not a typo
mouse = '713866'

In [None]:
# Collect `section` epochs from every session on or after `date`, numbering sessions chronologically
session_n = 0
collect_velocity = False  # True also gathers per-epoch speed traces into cum_velocity (slower)
site_frames = []
velocity_frames = []

directory = os.path.join(base_path, mouse)
files = os.listdir(directory)

# Oldest entries first (by ctime) so session_n counts up through training
sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=False)

for file_name in sorted_files:
    print(file_name)
    # Characters [-15:-7] of a session folder name hold its date as YYYYMMDD
    session = file_name[-15:-7]
    try:
        session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
    except ValueError:
        # Not a session folder; skip instead of aborting the whole scan
        continue
    if session_date < date:
        continue
    session_n += 1
    print('correct date found')

    # Recover data streams
    session_path = Path(os.path.join(base_path, mouse, file_name))
    data = parse.load_session_data(session_path)

    # Parse data into a dataframe with the main features
    reward_sites, active_site, config = parse.parse_dataframe(data)

    # Expand with extra columns
    reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
    active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

    # Load the continuous streams
    stream_data = parse.ContinuousData(data)
    encoder_data = stream_data.encoder_data
    odor_triggers = stream_data.odor_triggers
    software_tone = data['software_events'].streams['ChoiceFeedback'].data.index
    choice_tone = stream_data.choice_feedback.index

    # Recover color palette (cycling if there are more patches than colours)
    color_dict_label = {}
    dict_odor = {}
    list_patches = parse.TaskSchemaProperties(data).patches
    for i, patches in enumerate(list_patches):
        color_dict_label[patches['label']] = odor_list_color[i % len(odor_list_color)]
        dict_odor[i] = patches['label']

    section = 'PostPatch'

    # active_site is freshly loaded each iteration, so mutating it here is safe
    active_site['interpatch_friction'] = active_site['odor_label'].shift(-2).str.split(' ').str[-1]
    if section == 'InterPatch':
        active_site['odor_friction'] = active_site['odor_label'].shift(-2)
    else:
        active_site['odor_friction'] = active_site['odor_label'].shift(2)
        active_site['active_patch'] = active_site['active_patch'].shift(-1)
    active_site['end_epoch'] = active_site.index.to_series().shift(-1)
    active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index

    # .assign returns a new frame instead of setting a column on a slice
    new_active_site = active_site[active_site['label'] == section].assign(session_n=session_n)
    site_frames.append(new_active_site)

    if collect_velocity:
        velocity = plotting.trial_collection(new_active_site,
                                             encoder_data,
                                             mouse,
                                             session,
                                             window=[-1, 10],
                                             cropped_to_length='epoch',
                                             taken_col='filtered_velocity')
        velocity['session_n'] = session_n
        velocity_frames.append(velocity)

# Concatenate once at the end (repeated concat inside the loop is quadratic)
cum_active_site = pd.concat(site_frames) if site_frames else pd.DataFrame()
cum_velocity = pd.concat(velocity_frames) if velocity_frames else pd.DataFrame()


In [None]:
# Speed traces per friction condition, one line per session
frictions = ['Alpha-pinene 0', 'Methyl Butyrate 0.05',
             'Alpha-pinene 0.15', 'Methyl Butyrate 0.2']

if cum_velocity.empty:
    # The original cell crashed with AttributeError on an empty frame
    print('cum_velocity is empty - collect per-session speed traces in the loop above before plotting.')
else:
    fig = plt.figure(figsize=(12, 8))
    for i, friction in enumerate(frictions):
        ax = fig.add_subplot(2, 2, i + 1)
        sns.lineplot(data=cum_velocity.loc[cum_velocity['odor_friction'] == friction],
                     x='times', y='speed', hue='session_n', errorbar=None, legend=True, ax=ax)
        ax.set_xlim(-1, 15)
        ax.set_ylim(0, 45)
        ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        ax.set_title(friction)
        ax.fill_betweenx([-5, 50], -1, 0, color=color1, alpha=0.2)
        ax.fill_betweenx([-5, 50], 0, 15, color='grey', alpha=0.2)
        ax.set_xlabel(f'Time from {section} start (s)')
    sns.despine(fig=fig)
    # Lay out once, after all panels exist
    fig.tight_layout()


In [None]:
# X-axis category order; sessions without the 0.2 condition use the 0.1 set instead
order_hue = (
    ['Alpha-pinene 0', 'Methyl Butyrate 0.05', 'Alpha-pinene 0.15', 'Methyl Butyrate 0.2']
    if 'Methyl Butyrate 0.2' in cum_active_site['odor_friction'].unique()
    else ['Alpha-pinene 0', 'Methyl Butyrate 0.05', 'Methyl Butyrate 0.1', 'Alpha-pinene 0.15']
)

# Mean epoch duration per (condition, patch, session)
df_results = (
    cum_active_site
    .groupby(['odor_friction', 'active_patch', 'session_n'])
    .agg({'epoch_duration': 'mean'})
    .reset_index()
)

fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(1, 2, 2)
sns.boxplot(data=df_results, x='odor_friction', y='epoch_duration', hue='session_n',
            order=order_hue, palette='viridis', ax=ax)
plt.xticks(rotation=45)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
sns.despine()
plt.ylim(0, 40)

In [None]:
cum_active_site.groupby('session_n')['epoch_duration'].median()  # median epoch duration per session

In [None]:
cum_active_site['epoch_duration'].median()  # median epoch duration across all sessions

In [None]:
later_sessions = cum_active_site[cum_active_site['session_n'].ne(1)]  # drop the first session
later_sessions.groupby('odor_friction')['length'].mean()

## Torque measurements before running the experiment

In [164]:
# Running row counter for the per-session torque table (the loop increments it before each write)
session_n = -1
df = pd.DataFrame(columns=['rig_name', 'torque', 'distance_m'])

In [None]:
# Survey wheel torque per rig for every session on or after this date (MM/DD/YYYY)
date_string = "10/1/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()

mouse_list = ['754579','754567','754580','754559','754560','754577','754566','754570','754571','754572','754573','754574','754575', '754582','745302','745305','745301']

for mouse in mouse_list:
    internal_session = 0
    directory = os.path.join(base_path, mouse)
    files = os.listdir(directory)

    # Oldest entries first (by ctime)
    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=False)

    for file_name in sorted_files:
        # Characters [-15:-7] of a session folder name hold its date as YYYYMMDD
        session = file_name[-15:-7]
        try:
            session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
        except ValueError:
            # Not a session folder; skip instead of aborting the survey
            continue
        if session_date < date:
            continue
        print('correct date found')

        # Recover data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))
        data = parse.load_session_data(session_path)

        # Parse data into a dataframe with the main features
        try:
            reward_sites, active_site, config = parse.parse_dataframe(data)
        except Exception as err:
            # Best-effort: skip unparsable sessions, but say so instead of hiding it
            print(f'Skipping {session_path}: {err!r}')
            continue

        # Expand with extra columns
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

        # Load the continuous streams
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data
        odor_triggers = stream_data.odor_triggers
        software_tone = data['software_events'].streams['ChoiceFeedback'].data.index
        choice_tone = stream_data.choice_feedback.index

        torque_data = stream_data.torque_data
        brake_data = stream_data.brake_data

        rig_name = data['config'].streams.rig_input.data['rig_name']
        torque_data.plot()
        plt.title(rig_name)
        plt.show()

        internal_session += 1
        session_n += 1
        df.at[session_n, 'rig_name'] = rig_name
        # Most frequent torque reading in the session
        df.at[session_n, 'torque'] = torque_data['Torque'].mode().iloc[0]
        df.at[session_n, 'internal_session'] = internal_session
        # .max() returns a Series keyed by column name; take the first value
        # positionally (plain [0] relies on a deprecated positional fallback)
        df.at[session_n, 'distance_m'] = data['operation_control'].streams.CurrentPosition.data.max().iloc[0] / 100

In [None]:
# Modal torque per session, grouped by rig and coloured by session order within mouse
ax = sns.swarmplot(data=df, x='rig_name', y='torque', hue='internal_session', palette='magma')
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
sns.despine()