In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import time

from aind_vr_foraging_analysis.utils import parse, plotting_utils as plotting, AddExtraColumns

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime

# Apply seaborn's 'talk' context to every figure drawn in this notebook
sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
# Silence FutureWarning, UserWarning and RuntimeWarning for the rest of the kernel session.
# NOTE(review): these filters are global, so genuine numerical problems surfaced as
# RuntimeWarning will not be shown either.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Palette: one colour per patch, looked up by patch index via odor_list_color
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Input locations: the calibration table lives under data_path, raw sessions under base_path
data_path = r'../../../data/'
base_path = r'Z:/scratch/vr-foraging/data/'

# Output locations for generated figures
pdf_path = r'Z:\scratch\vr-foraging\sessions'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'

from scipy.interpolate import griddata
from matplotlib.colors import TwoSlopeNorm
import plotting_friction_experiment as f

from statsmodels.formula.api import glm
from statsmodels.genmod.generalized_linear_model import GLM
from statsmodels.genmod.families import Binomial
from statsmodels.genmod.families.links import logit
from sklearn.preprocessing import StandardScaler

## **One shot evaluation of time and speed**

In [66]:
# Selection for the one-shot evaluation below: which animal and which session day to load.
torque_data = {}
# The session day is given as MM/DD/YYYY and parsed into a datetime.date so it can be
# compared with the date stamp embedded in each session folder name.
# (The previous `date = datetime.date.today()` was overwritten immediately and has been dropped.)
date_string = "12/03/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
mouse = '754571'

In [8]:
# Per-wheel calibration table; later cells read its `wheel`, `a`, `b` and `c` columns.
params_df = pd.read_csv(f'{data_path}torque_calibration.csv')

In [None]:
# One-shot evaluation: load every session of `mouse` recorded on `date`, collect the
# epoch table plus velocity/torque traces, and draw one summary figure per session.
session_n = 0
cum_active_site = pd.DataFrame()
cum_velocity = pd.DataFrame()
cum_torque = pd.DataFrame()
within_session_number = 0
control_experiment = 0
previous_experiment = None

directory = os.path.join(base_path, mouse)
files = os.listdir(directory)

# Ascending order of os.path.getctime, so sessions are visited oldest first
sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)))

# All this segment is to find the correct session without having the specific path
for file_name in sorted_files:
    # The session date is read from a fixed position in the folder name (YYYYMMDD)
    session = file_name[-15:-7]
    try:
        session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
    except ValueError:
        # Directory entry that does not follow the session naming scheme; ignore it
        continue
    if session_date != date:
        continue

    # Recover data streams
    session_path = Path(os.path.join(base_path, mouse, file_name))
    data = parse.load_session_data(session_path)

    # Parse data into a dataframe with the main features
    try:
        reward_sites, active_site, config = parse.parse_dataframe(data)
    except Exception as exc:
        # Best effort: an unparsable session is reported and skipped. `Exception` (rather
        # than a bare except) keeps Ctrl-C / kernel interrupt working.
        print(f'Skipping {file_name}: could not parse session ({exc!r})')
        continue
    # -- At this step you can save the data into a csv file

    rig_name = data['config'].streams.rig_input.data['rig_name']
    # Expand with extra columns
    reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
    active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

    # Load the encoder data separately
    stream_data = parse.ContinuousData(data)
    encoder_data = stream_data.encoder_data
    odor_triggers = stream_data.odor_triggers
    software_tone = data['software_events'].streams['ChoiceFeedback'].data.index
    choice_tone = stream_data.choice_feedback.index

    experiment = data['config'].streams.tasklogic_input.data['stage_name']

    # Recover color palette
    color_dict_label = {}
    dict_odor = {}
    list_patches = parse.TaskSchemaProperties(data).patches
    for i, patches in enumerate(list_patches):
        color_dict_label[patches['label']] = odor_list_color[i]
        dict_odor[i] = patches['label']

    # Decide which epoch type to analyse: when the shortest InterPatch is exactly 50 long
    # the PostPatch epochs are used, otherwise the InterPatch epochs themselves.
    if active_site.loc[active_site.label == 'InterPatch'].length.min() == 50:
        section = 'PostPatch'
    else:
        print(experiment)
        section = 'InterPatch'

    if section == 'PostPatch':
        # Shift the patch number up by one row for PostPatch analysis
        # NOTE(review): presumably so a PostPatch row carries the number of the patch that follows — confirm
        active_site['active_patch'] = active_site['active_patch'].shift(-1)

    # An epoch ends where the next row starts (the index holds the start times)
    active_site['end_epoch'] = active_site.index.to_series().shift(-1)
    active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index
    try:
        friction = data['config'].streams.tasklogic_input.data['task_parameters']['environment_statistics']['patches'][0]['virtual_site_generation']['post_patch']['treadmill_specification']['friction']['distribution_parameters']['value']
    except TypeError:
        # Same fallback as the multi-session cells: a missing (None) specification along
        # this path is treated as zero friction instead of aborting the cell.
        friction = 0
    # Explicit copy: columns are added below and must not alias `active_site`
    new_active_site = active_site[active_site['label'] == section].copy()

    # What was the friction applied if we have the friction of the schema? (We have the friction in the schema, we want the reality)
    wheel = rig_name
    wheel_params = params_df.loc[params_df.wheel == wheel]  # calibration row(s) for this rig
    coef_a = wheel_params.a.values[0]
    coef_b = wheel_params.b.values[0]
    coef_c = wheel_params.c.values[0]
    # NOTE(review): 65535 * friction presumably rescales the schema value to a 16-bit command — confirm
    resolved_torque = f.quadratic_model(65535 * friction, coef_a, coef_b, coef_c)
    # Torque drop relative to the calibration constant `c`, expressed in percent
    actual_friction = (coef_c - resolved_torque) / coef_c * 100

    session_n += 1
    new_active_site['session_n'] = session_n
    new_active_site['experiment'] = experiment

    # Count consecutive sessions that share a stage name; control sessions are numbered
    # by their own running counter instead.
    if previous_experiment != experiment:
        within_session_number = 0
        previous_experiment = experiment
    else:
        within_session_number += 1

    if experiment == 'control':
        control_experiment += 1
        within_session_number = control_experiment

    new_active_site['within_session_number'] = within_session_number

    cum_active_site = pd.concat([cum_active_site, new_active_site])

    velocity = plotting.trial_collection(new_active_site,
                                         encoder_data,
                                         mouse,
                                         session,
                                         window=[-1, 10],
                                         cropped_to_length='epoch',
                                         taken_col='filtered_velocity')

    # Flag the samples that fall before the earliest per-patch maximum time
    velocity['cropped'] = velocity.times < min(velocity.groupby('active_patch').times.max())
    cum_velocity = pd.concat([cum_velocity, velocity])

    torque_data = stream_data.torque_data
    brake_data = stream_data.brake_data

    # From here on `velocity` / `velocity_end` hold torque traces (names kept as in the original)
    velocity = plotting.trial_collection(new_active_site,
                                         torque_data,
                                         mouse,
                                         session,
                                         window=[-1, 10],
                                         cropped_to_length='epoch',
                                         taken_col=['Torque'])

    velocity_end = plotting.trial_collection(new_active_site,
                                             torque_data,
                                             mouse,
                                             session,
                                             aligned='end_epoch',
                                             window=[-5, 2],
                                             taken_col=['Torque'])

    velocity['align'] = 'onset'
    velocity_end['align'] = 'offset'
    cum_torque = pd.concat([cum_torque, velocity])
    cum_torque = pd.concat([cum_torque, velocity_end])

    # Mean onset-aligned torque in the first second: control minus friction.
    # NOTE(review): this stores the value inside the session object `data` under the mouse
    # id; the `torque_data = {}` defined earlier suggests a results dict was intended, but
    # that name is rebound to the torque stream above. Left as is — confirm the target.
    onset_first_second = (cum_torque['align'] == 'onset') & (cum_torque.times > 0) & (cum_torque.times < 1)
    data[cum_torque.mouse.unique()[0]] = (
        cum_torque.loc[onset_first_second & (cum_torque.experiment == 'control')].Torque.mean()
        - cum_torque.loc[onset_first_second & (cum_torque.experiment == 'friction')].Torque.mean()
    )
    # plt.ylim(0, 40)

    # Left: epoch duration per patch. Right: velocity traces for this session's experiment.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    ax = axes[0]
    sns.lineplot(data=new_active_site, x='active_patch', y='epoch_duration', hue='active_patch', marker='o', ax=ax, legend=False)

    # Explicit Axes/Figure calls replace the pyplot state machine; they target the same
    # right-hand panel that was the implicit "current axes" before.
    ax = axes[1]
    sns.lineplot(data=cum_velocity.loc[(cum_velocity.cropped == True) & (cum_velocity.experiment == experiment)], x='times', y='speed',
                 hue='active_patch', errorbar=None, alpha=0.8, ax=ax)
    ax.set_xlim(-1, max(cum_velocity.loc[cum_velocity.cropped == True].times))
    ax.set_ylim(-15, 60)
    ax.fill_betweenx([-15, 60], -1, 0, color=color1, alpha=0.2)
    ax.fill_betweenx([-15, 60], 0, 15, color='grey', alpha=0.2)
    ax.set_xlabel('Time from inter-patch start (s)')
    ax.set_ylabel('Velocity (cm/s)')
    ax.set_title(f'{experiment} {friction} {np.around(actual_friction,2)}')
    ax.legend(bbox_to_anchor=(1, 0.9), title='Patch #')
    fig.suptitle(mouse)
    sns.despine()
    fig.tight_layout()

In [None]:
# Same one-shot evaluation as the single-mouse cell, repeated for every animal in the cohort.
# The id list names '754570' twice; dict.fromkeys keeps first-seen order and drops the
# repeat so that animal is not loaded and plotted a second time.
for mouse in dict.fromkeys(['754570','754579','754567','754580','754559','754560','754577','754566','754570','754571','754574','754575', '754582','745302','745305','745301']):
    session_n = 0
    cum_active_site = pd.DataFrame()
    cum_velocity = pd.DataFrame()
    cum_torque = pd.DataFrame()
    within_session_number = 0
    control_experiment = 0
    previous_experiment = None

    directory = os.path.join(base_path, mouse)
    files = os.listdir(directory)

    # Ascending order of os.path.getctime, so sessions are visited oldest first
    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)))

    # All this segment is to find the correct session without having the specific path
    for file_name in sorted_files:
        # The session date is read from a fixed position in the folder name (YYYYMMDD)
        session = file_name[-15:-7]
        try:
            session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
        except ValueError:
            # Directory entry that does not follow the session naming scheme; ignore it
            continue
        if session_date != date:
            continue

        # Recover data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))
        data = parse.load_session_data(session_path)

        # Parse data into a dataframe with the main features
        try:
            reward_sites, active_site, config = parse.parse_dataframe(data)
        except Exception as exc:
            # Best effort: report and skip unparsable sessions without swallowing Ctrl-C
            print(f'Skipping {file_name}: could not parse session ({exc!r})')
            continue
        # -- At this step you can save the data into a csv file

        rig_name = data['config'].streams.rig_input.data['rig_name']

        # Expand with extra columns
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

        # Load the encoder data separately
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data
        odor_triggers = stream_data.odor_triggers
        software_tone = data['software_events'].streams['ChoiceFeedback'].data.index
        choice_tone = stream_data.choice_feedback.index

        try:
            friction = data['config'].streams.tasklogic_input.data['task_parameters']['environment_statistics']['patches'][0]['virtual_site_generation']['post_patch']['treadmill_specification']['friction']['distribution_parameters']['value']
        except TypeError:
            # Same fallback as the multi-session cells: a missing (None) specification
            # along this path is treated as zero friction.
            friction = 0
        # (The original filtered `active_site` by `section` here, before `section` was set
        # for this session: stale state from a previous iteration, or a NameError on a fresh
        # kernel. That frame was never used before being rebuilt below, so it is removed.)

        # What was the friction applied if we have the friction of the schema? (We have the friction in the schema, we want the reality)
        wheel = rig_name
        wheel_params = params_df.loc[params_df.wheel == wheel]  # calibration row(s) for this rig
        coef_a = wheel_params.a.values[0]
        coef_b = wheel_params.b.values[0]
        coef_c = wheel_params.c.values[0]
        # NOTE(review): 65535 * friction presumably rescales the schema value to a 16-bit command — confirm
        resolved_torque = f.quadratic_model(65535 * friction, coef_a, coef_b, coef_c)
        # Torque drop relative to the calibration constant `c`, expressed in percent
        actual_friction = (coef_c - resolved_torque) / coef_c * 100

        # `stage_name` is the raw stage from the task config (used for session counting);
        # `experiment` is the label written to the tables and used to filter the plots.
        stage_name = data['config'].streams.tasklogic_input.data['stage_name']
        experiment = stage_name

        if stage_name in ('friction', 'friction_15', 'friction_optimized'):
            # Bin friction stages by the resolved friction: <8 low, 8 to <16 medium, >=16 high.
            # (Exactly 8 used to fall through to 'friction_high'.)
            if actual_friction < 8:
                experiment = 'friction_low'
            elif actual_friction < 16:
                experiment = 'friction_med'
            else:
                experiment = 'friction_high'

        # Recover color palette
        color_dict_label = {}
        dict_odor = {}
        list_patches = parse.TaskSchemaProperties(data).patches
        for i, patches in enumerate(list_patches):
            color_dict_label[patches['label']] = odor_list_color[i]
            dict_odor[i] = patches['label']

        # Decide which epoch type to analyse: when the shortest InterPatch is exactly 50
        # long the PostPatch epochs are used, otherwise the InterPatch epochs themselves.
        if active_site.loc[active_site.label == 'InterPatch'].length.min() == 50:
            section = 'PostPatch'
        else:
            print(experiment)
            section = 'InterPatch'

        if section == 'PostPatch':
            # NOTE(review): presumably so a PostPatch row carries the number of the patch that follows — confirm
            active_site['active_patch'] = active_site['active_patch'].shift(-1)

        # An epoch ends where the next row starts (the index holds the start times)
        active_site['end_epoch'] = active_site.index.to_series().shift(-1)
        active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index
        # Explicit copy: columns are added below and must not alias `active_site`
        new_active_site = active_site[active_site['label'] == section].copy()

        session_n += 1
        new_active_site['session_n'] = session_n
        new_active_site['experiment'] = experiment

        # Count consecutive sessions that share a raw stage name; control sessions are
        # numbered by their own running counter instead.
        if previous_experiment != stage_name:
            within_session_number = 0
            previous_experiment = stage_name
        else:
            within_session_number += 1

        if stage_name == 'control':
            control_experiment += 1
            within_session_number = control_experiment

        new_active_site['within_session_number'] = within_session_number
        new_active_site['actual_friction'] = actual_friction
        new_active_site['friction'] = friction

        cum_active_site = pd.concat([cum_active_site, new_active_site])

        velocity = plotting.trial_collection(new_active_site,
                                             encoder_data,
                                             mouse,
                                             session,
                                             window=[-1, 10],
                                             cropped_to_length='epoch',
                                             taken_col='filtered_velocity')

        # Flag the samples that fall before the earliest per-patch maximum time
        velocity['cropped'] = velocity.times < min(velocity.groupby('active_patch').times.max())
        cum_velocity = pd.concat([cum_velocity, velocity])

        torque_data = stream_data.torque_data
        brake_data = stream_data.brake_data

        # From here on `velocity` / `velocity_end` hold torque traces (names kept as in the original)
        velocity = plotting.trial_collection(new_active_site,
                                             torque_data,
                                             mouse,
                                             session,
                                             window=[-1, 10],
                                             cropped_to_length='epoch',
                                             taken_col=['Torque'])

        velocity_end = plotting.trial_collection(new_active_site,
                                                 torque_data,
                                                 mouse,
                                                 session,
                                                 aligned='end_epoch',
                                                 window=[-5, 2],
                                                 taken_col=['Torque'])

        velocity['align'] = 'onset'
        velocity_end['align'] = 'offset'
        cum_torque = pd.concat([cum_torque, velocity])
        cum_torque = pd.concat([cum_torque, velocity_end])

        # Mean onset-aligned torque in the first second: control minus friction.
        # NOTE(review): this stores the value inside the session object `data` under the
        # mouse id, and compares against the raw label 'friction' although this cell
        # relabels friction sessions as friction_low/med/high. Left as is — confirm intent.
        onset_first_second = (cum_torque['align'] == 'onset') & (cum_torque.times > 0) & (cum_torque.times < 1)
        data[cum_torque.mouse.unique()[0]] = (
            cum_torque.loc[onset_first_second & (cum_torque.experiment == 'control')].Torque.mean()
            - cum_torque.loc[onset_first_second & (cum_torque.experiment == 'friction')].Torque.mean()
        )
        # plt.ylim(0, 40)

        # Left: epoch duration per patch. Right: velocity traces for this session's experiment.
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        ax = axes[0]
        sns.lineplot(data=new_active_site, x='active_patch', y='epoch_duration', hue='active_patch', marker='o', ax=ax, legend=False)

        # `experiment` is the same label that was written to the tables, so the filter
        # below selects this session's rows (the raw stage name matched nothing for
        # relabelled friction sessions). The title shows that label as well.
        ax = axes[1]
        sns.lineplot(data=cum_velocity.loc[(cum_velocity.cropped == True) & (cum_velocity.experiment == experiment)], x='times', y='speed',
                     hue='active_patch', errorbar=None, alpha=0.8, ax=ax)
        ax.set_xlim(-1, max(cum_velocity.loc[cum_velocity.cropped == True].times))
        ax.set_ylim(-15, 60)
        ax.fill_betweenx([-15, 60], -1, 0, color=color1, alpha=0.2)
        ax.fill_betweenx([-15, 60], 0, 15, color='grey', alpha=0.2)
        ax.set_xlabel('Time from inter-patch start (s)')
        ax.set_ylabel('Velocity (cm/s)')
        ax.set_title(f'{experiment} {friction} {np.around(actual_friction,2)}')
        ax.legend(bbox_to_anchor=(1, 0.9), title='Patch #')
        fig.suptitle(mouse)
        sns.despine()
        fig.tight_layout()

In [None]:
# Torque overview for the frames left by the loop above; `limits` is the smallest and
# largest Torque sample in cum_torque (presumably the axis range — confirm in
# plotting_friction_experiment).
f.torque_plots(
    cum_torque,
    limits=[min(cum_torque['Torque']), max(cum_torque['Torque'])],
)

In [None]:
# Velocity figure from the velocity traces accumulated by the loop above.
# NOTE(review): cum_velocity is reset per mouse there, so this shows the last mouse only.
f.plot_velocity_across_sessions(cum_velocity)

In [None]:
# Epoch-duration figure for the epoch table accumulated above; `mouse` is whatever id the
# preceding loop finished on.
f.epoch_duration_plot(cum_active_site, mouse)

## **Parse velocity, time and speed for different sessions and animals**

In [34]:
# Earliest session day to include in the multi-session cells below (they keep sessions
# dated on or after `date`). Given as MM/DD/YYYY and parsed into a datetime.date.
# (The previous `date = datetime.date.today()` was overwritten immediately and has been dropped.)
date_string = "08/28/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
# Per-wheel calibration table; the loops read its `wheel`, `a`, `b` and `c` columns
params_df = pd.read_csv(data_path + 'torque_calibration.csv')

#### **Plot the data but don't save**

In [None]:
# For every mouse: gather all sessions dated on or after `date` whose stage is in
# `list_experiments`, collect epoch/velocity/torque tables, and write one PDF of summary
# figures per mouse into `results_path`.
sum_df = pd.DataFrame()
list_experiments = ['control', 'friction', 'friction_15', 'friction_optimized', 'distance_short', 'distance_long', 'distance_extra_short', 'distance_extra_long']
# The id list names '754570' twice; dict.fromkeys keeps first-seen order and drops the
# repeat so that animal's sessions are not reloaded just to overwrite the same PDF.
for mouse in dict.fromkeys(['754570','754579','754567','754580','754559','754560','754577','754566','754570','754571','754574','754575', '754582','745302','745305','745301']):
    print(mouse)
    session_n = 0
    active_site_list = []
    velocity_list = []
    velocity_list_end = []
    torque_list = []
    within_session_number = 0
    control_experiment = 0
    previous_experiment = None

    directory = os.path.join(base_path, mouse)
    files = os.listdir(directory)

    # Ascending order of os.path.getctime, so sessions are visited oldest first
    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)))

    # All this segment is to find the correct session without having the specific path
    for file_name in sorted_files:
        start_time = time.time()
        # The session date is read from a fixed position in the folder name (YYYYMMDD)
        session = file_name[-15:-7]
        try:
            session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
        except ValueError:
            # Directory entry that does not follow the session naming scheme; ignore it
            continue
        if session_date < date:
            continue

        # Recover data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))
        try:
            data = parse.load_session_data(session_path)
        except Exception as exc:
            # Best effort: report and skip unreadable sessions without swallowing Ctrl-C
            print(f'Skipping {file_name}: could not load session ({exc!r})')
            continue

        # `stage_name` is the raw stage from the task config (used for filtering and
        # session counting); `experiment` is the label written to the tables.
        stage_name = data['config'].streams.tasklogic_input.data['stage_name']
        experiment = stage_name
        if stage_name not in list_experiments:
            print(experiment)
            continue
        try:
            friction = data['config'].streams.tasklogic_input.data['task_parameters']['environment_statistics']['patches'][0]['virtual_site_generation']['post_patch']['treadmill_specification']['friction']['distribution_parameters']['value']
        except TypeError:
            # A missing (None) specification along this path is treated as zero friction
            friction = 0
        rig_name = data['config'].streams.rig_input.data['rig_name']

        # What was the friction applied if we have the friction of the schema? (We have the friction in the schema, we want the reality)
        wheel = rig_name
        wheel_params = params_df.loc[params_df.wheel == wheel]  # calibration row(s) for this rig
        coef_a = wheel_params.a.values[0]
        coef_b = wheel_params.b.values[0]
        coef_c = wheel_params.c.values[0]
        # NOTE(review): 65535 * friction presumably rescales the schema value to a 16-bit command — confirm
        resolved_torque = f.quadratic_model(65535 * friction, coef_a, coef_b, coef_c)
        # Torque drop relative to the calibration constant `c`: in percent, and in raw torque units
        actual_friction = (coef_c - resolved_torque) / coef_c * 100
        torque_friction = coef_c - resolved_torque

        if stage_name in ('friction', 'friction_15', 'friction_optimized'):
            # Bin friction stages by the resolved friction: <8 low, 8 to <16 medium, >=16 high.
            # (Exactly 8 used to fall through to 'friction_high'.)
            if actual_friction < 8:
                experiment = 'friction_low'
            elif actual_friction < 16:
                experiment = 'friction_med'
            else:
                experiment = 'friction_high'
        print(experiment, actual_friction)

        # Parse data into a dataframe with the main features
        try:
            reward_sites, active_site, config = parse.parse_dataframe(data)
        except Exception as exc:
            print(f'Skipping {file_name}: could not parse session ({exc!r})')
            continue
        # -- At this step you can save the data into a csv file

        if reward_sites.empty:
            continue

        # Expand with extra columns
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

        # Load the encoder data separately
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data

        # Recover color palette
        color_dict_label = {}
        dict_odor = {}
        list_patches = parse.TaskSchemaProperties(data).patches
        for i, patches in enumerate(list_patches):
            color_dict_label[patches['label']] = odor_list_color[i]
            dict_odor[i] = patches['label']

        # Decide which epoch type to analyse: when the shortest InterPatch is exactly 50
        # long the PostPatch epochs are used, otherwise the InterPatch epochs themselves.
        if active_site.loc[active_site.label == 'InterPatch'].length.min() == 50:
            section = 'PostPatch'
        else:
            section = 'InterPatch'

        if section == 'PostPatch':
            # NOTE(review): presumably so a PostPatch row carries the number of the patch that follows — confirm
            active_site['active_patch'] = active_site['active_patch'].shift(-1)

        # An epoch ends where the next row starts (the index holds the start times)
        active_site['end_epoch'] = active_site.index.to_series().shift(-1)
        active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index

        # Explicit copy: columns are added below and must not alias `active_site`
        new_active_site = active_site[active_site['label'] == section].copy()

        session_n += 1
        new_active_site['session_n'] = session_n
        new_active_site['experiment'] = experiment
        # (`friction` was already resolved above from the same config path; the second,
        # identical lookup that used to sit here has been removed.)

        # Count consecutive sessions that share a raw stage name; control sessions are
        # numbered by their own running counter instead.
        if previous_experiment != stage_name:
            within_session_number = 0
            previous_experiment = stage_name
        else:
            within_session_number += 1

        if stage_name == 'control':
            control_experiment += 1
            within_session_number = control_experiment

        new_active_site['within_session_number'] = within_session_number
        actual_friction = np.around(actual_friction, 2)

        new_active_site['actual_friction'] = actual_friction
        new_active_site['torque_friction'] = torque_friction

        new_active_site['friction'] = friction
        new_active_site['mouse'] = mouse
        new_active_site['session'] = session
        new_active_site['wheel'] = wheel

        # Collected in a list and concatenated once per mouse (cheaper than growing a frame)
        active_site_list.append(new_active_site)

        # Velocity traces aligned to epoch start ...
        velocity = plotting.trial_collection(new_active_site,
                                             encoder_data,
                                             mouse,
                                             session,
                                             window=[-1, 2],
                                             cropped_to_length='epoch',
                                             taken_col='filtered_velocity')

        # ... and to epoch end
        velocity_end = plotting.trial_collection(new_active_site,
                                                 encoder_data,
                                                 mouse,
                                                 session,
                                                 window=[-1, 2],
                                                 aligned='end_epoch',
                                                 cropped_to_length='epoch',
                                                 taken_col='filtered_velocity')

        # NOTE(review): the epoch table of this session was already appended above, so a
        # session without velocity samples still contributes to the epoch-duration plot.
        if velocity.empty:
            continue

        # Flag the samples that fall before the earliest per-patch maximum time
        velocity['cropped'] = velocity.times < min(velocity.groupby('active_patch').times.max())
        velocity_end['cropped'] = velocity_end.times < min(velocity_end.groupby('active_patch').times.max())

        velocity_list.append(velocity)
        velocity_list_end.append(velocity_end)

        torque_data = stream_data.torque_data
        brake_data = stream_data.brake_data

        torque = plotting.trial_collection(new_active_site,
                                           torque_data,
                                           mouse,
                                           session,
                                           window=[-1, 10],
                                           cropped_to_length='epoch',
                                           taken_col=['Torque'])

        torque_end = plotting.trial_collection(new_active_site,
                                               torque_data,
                                               mouse,
                                               session,
                                               aligned='end_epoch',
                                               window=[-2, 1],
                                               taken_col=['Torque'])

        # Tag each torque table with its alignment and the rounded resolved friction
        torque_end['align'] = 'offset'
        torque_end['friction'] = actual_friction
        torque['align'] = 'onset'
        torque['friction'] = actual_friction
        torque_list.append(torque)
        torque_list.append(torque_end)

    # pd.concat raises ValueError on an empty list, so a mouse with no usable session is
    # reported and skipped instead of aborting the whole cohort run.
    if not (active_site_list and velocity_list and velocity_list_end and torque_list):
        print(f'{mouse}: no usable sessions found, skipping PDF')
        continue

    cum_active_site = pd.concat(active_site_list)
    cum_velocity = pd.concat(velocity_list)
    cum_velocity_end = pd.concat(velocity_list_end)
    cum_torque = pd.concat(torque_list)

    # One multi-page PDF per mouse; each plotting helper receives the open PdfPages via `save`
    with PdfPages(os.path.join(results_path, f'{mouse}_torque_velocity_across_sessions_experiments.pdf')) as pdf:
        f.epoch_duration_plot(cum_active_site, mouse, save=pdf)
        f.plot_velocity_across_sessions(cum_velocity, save=pdf)
        f.plot_velocity_across_sessions(cum_velocity_end, save=pdf, xlim = [-2,2])
        f.torque_plots(cum_torque, limits=[min(cum_torque.Torque), max(cum_torque.Torque)], save=pdf)
        

### **Save the data but don't plot**

In [None]:
sum_df = pd.DataFrame()
# list_experiments = ['control', 'friction', 'friction_15', 'friction_optimized', 'distance_short', 'distance_long', 'distance_extra_short', 'distance_extra_long']
for mouse in ['754570','754579','754567','754580','754559','754560','754577','754566','754570','754571','754574','754575', '754582','745302','745305','745301']:
    print(mouse)
    session_n = 0
    active_site_list = []
    velocity_list = []
    velocity_list_end = []
    torque_list = []
    within_session_number = 0
    control_experiment = 0
    previous_experiment = None

    directory = os.path.join(base_path, mouse)
    files = os.listdir(os.path.join(base_path, mouse))

    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)), reverse=False)

    # All this segment is to find the correct session without having the specific path
    for file_name in sorted_files:
        start_time = time.time()
        # Find specific session sorted by date
        session = file_name[-15:-7]
        if datetime.datetime.strptime(session, "%Y%m%d").date() < date:
            continue
        else:
            pass
            
        # Recover data streams
        session_path = os.path.join(base_path, mouse, file_name)
        session_path = Path(session_path)
        try:
            data = parse.load_session_data(session_path)
        except:
            continue
        
        experiment = data['config'].streams.tasklogic_input.data['stage_name']
        
        try:
            friction = data['config'].streams.tasklogic_input.data['task_parameters']['environment_statistics']['patches'][0]['virtual_site_generation']['post_patch']['treadmill_specification']['friction']['distribution_parameters']['value']
        except TypeError:
            friction = 0
        rig_name = data['config'].streams.rig_input.data['rig_name']

        # What was the friction applied if we have the friction of the schema? (We have the friction in the schema, we want the reality)
        wheel = rig_name
        resolved_torque = f.quadratic_model(65535 * friction, params_df.loc[params_df.wheel == wheel].a.values[0], params_df.loc[params_df.wheel == wheel].b.values[0], params_df.loc[params_df.wheel == wheel].c.values[0])    
        torque_friction = params_df.loc[params_df.wheel == wheel].c.values[0] - resolved_torque
    
        if experiment == 'friction' or experiment == 'friction_15' or experiment == 'friction_optimized':
            if torque_friction < 120:
                experiment = 'friction_low'
            elif torque_friction > 120 and torque_friction < 240:
                experiment = 'friction_med'
            else:
                experiment = 'friction_high'
        print(experiment, torque_friction)
        
        # Parse data into a dataframe with the main features
        try:
            reward_sites, active_site, config = parse.parse_dataframe(data)
        except:
            continue
        # -- At this step you can save the data into a csv file
        
        if reward_sites.empty:
            continue
        
        # Expand with extra columns
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

        # Load the encoder data separately
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data
        
        if active_site.loc[active_site.label == 'InterPatch'].length.unique()[0] == 50:
            section = 'PostPatch'
        else:
            section = 'InterPatch'

        if section == 'PostPatch':
            active_site['active_patch'] = active_site['active_patch'].shift(-1)

        try:
            friction = data['config'].streams.tasklogic_input.data['task_parameters']['environment_statistics']['patches'][0]['virtual_site_generation']['post_patch']['treadmill_specification']['friction']['distribution_parameters']['value']
        except:
            friction = 0
            
        active_site['end_epoch'] = active_site.index.to_series().shift(-1)
        active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index

        new_active_site = active_site[active_site['label'] == section]
        new_active_site['experiment'] = experiment
        new_active_site['torque_friction'] = torque_friction
        new_active_site['friction'] = friction
        new_active_site['mouse'] = mouse
        new_active_site['session'] = session
        new_active_site['wheel'] = wheel
        
        # cum_active_site = pd.concat([cum_active_site, new_active_site])
        active_site_list.append(new_active_site)
        
        velocity = plotting.trial_collection(new_active_site, 
                                                        encoder_data, 
                                                        mouse, 
                                                        session, 
                                                        window=[-1,2],  
                                                        cropped_to_length='epoch',
                                                        taken_col='filtered_velocity')
        
        if velocity.empty:
            continue
        
        velocity_list.append(velocity)

        torque_data = stream_data.torque_data
        brake_data = stream_data.brake_data
        
        torque = plotting.trial_collection(new_active_site, 
                                                        torque_data, 
                                                        mouse, 
                                                        session, 
                                                        window=[-2,10],  
                                                        cropped_to_length='epoch',
                                                        taken_col=['Torque'])

        
        torque_list.append(torque)

    cum_active_site = pd.concat(active_site_list)
    cum_velocity = pd.concat(velocity_list)
    cum_torque = pd.concat(torque_list)
        
    group_list = ['mouse','session', 'experiment', 'friction', 'torque_friction', 'active_patch', 'wheel']
    acc_df = pd.DataFrame()
    
    temp_df = cum_torque.loc[(cum_torque['times']> 0)].groupby(group_list).Torque.mean().reset_index()
    temp_df.rename(columns={'Torque':'torque_interpatch'}, inplace=True)
    acc_df = temp_df
    
    temp_df = cum_torque.loc[(cum_torque['times']< 0)&(cum_torque['times'] > -2)].groupby(group_list).Torque.mean().reset_index()
    temp_df.rename(columns={'Torque':'torque_baseline'}, inplace=True)
    acc_df = acc_df.merge(temp_df, on=group_list)

    temp_df = cum_velocity.loc[(cum_velocity['times'] > 0)].groupby(group_list).speed.mean().reset_index()
    temp_df.rename(columns={'speed':'speed_average'}, inplace=True)
    acc_df = acc_df.merge(temp_df, on=group_list)

    temp_df = cum_active_site.groupby(group_list).agg({"epoch_duration":"mean", "length":"mean"}).reset_index()
    acc_df = acc_df.merge(temp_df, on=group_list)

    sum_df = pd.concat([acc_df, sum_df])

In [39]:
# Order the summary by mouse and session, then number each mouse's sessions
# with a dense rank of the session identifier (1, 2, 3, ...).
sum_df = (
    sum_df
    .sort_values(by=['mouse', 'session'])
    .reset_index(drop=True)
)
session_rank = sum_df.groupby('mouse')['session'].rank(method='dense')
sum_df['session_n'] = session_rank.astype(int)

# Persist the table so the plotting section can reload it without re-parsing sessions.
summary_csv = os.path.join(results_path, 'batch4_velocity_torque_duration_summary.csv')
sum_df.to_csv(summary_csv)

In [28]:
# Sorted unique baseline-torque values found in the summary table
torques = np.sort(sum_df['torque_baseline'].unique())

# One tab20 colour per torque value
custom_palette = sns.color_palette("tab20", len(torques))

# Lookup table: baseline torque value -> colour
# (the variable keeps its historical name even though the keys are torques here)
distance_color_map = dict(zip(torques, custom_palette))

In [None]:
# One panel per wheel: per-session mean baseline torque, split by experiment.
# NOTE(review): 'session_enum' is not created in any visible cell (only
# 'session_n' is) -- confirm the column exists before running on a fresh kernel.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
for ax, wheel in zip(axes.flatten(), sum_df.wheel.unique()):
    wheel_rows = sum_df.loc[sum_df.wheel == wheel]
    per_session = (
        wheel_rows
        .groupby(['session_enum', 'experiment'])
        .torque_baseline.mean()
        .reset_index()
    )
    sns.boxplot(
        data=per_session,
        x='session_enum',
        y='torque_baseline',
        hue='experiment',
        palette='tab20',
        legend=False,
        ax=ax,
    )
    ax.set_ylim(0, 2500)
    ax.set_title(wheel)

plt.tight_layout()
sns.despine()
plt.show()

In [None]:
# One panel per wheel: per-session baseline torque pooled across mice
# (each box summarises the per-mouse means of one session).
# NOTE(review): 'session_enum' is not created in any visible cell (only
# 'session_n' is) -- confirm the column exists before running on a fresh kernel.
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
for ax, wheel in zip(axes.flatten(), sum_df.wheel.unique()):
    wheel_rows = sum_df.loc[sum_df.wheel == wheel]
    per_mouse = (
        wheel_rows
        .groupby(['session_enum', 'mouse'])
        .torque_baseline.mean()
        .reset_index()
    )
    sns.boxplot(
        data=per_mouse,
        x='session_enum',
        y='torque_baseline',
        palette='tab20',
        legend=False,
        ax=ax,
    )
    ax.set_ylim(1500, 2500)
    ax.set_title(wheel)
    # Label only every 10th session to keep the x axis readable
    n_sessions = len(per_mouse.session_enum.unique())
    ax.set_xticks(np.arange(0, n_sessions, 10))

plt.tight_layout()
sns.despine()
plt.show()

## **Retrieve and plot results**

In [307]:
# Experiments kept for the analysis; everything else is filtered out on load.
list_experiments = ['control', 'friction_med', 'friction_low', 'friction_high', 'distance_short', 'distance_long', 'distance_extra_short', 'distance_extra_long']

# Reload the per-patch summary written by the parsing section.
summary_csv = os.path.join(results_path, 'batch4_velocity_torque_duration_summary.csv')
sum_df = pd.read_csv(summary_csv, index_col=0)
sum_df = sum_df.loc[sum_df.experiment.isin(list_experiments)].assign(
    torque_friction=lambda d: d['torque_friction'].round(2),
    mouse=lambda d: d['mouse'].astype(int),
)

# Re-base the session counter so that each mouse's first retained session is 0.
first_session = sum_df.groupby('mouse')['session_n'].transform('min')
sum_df['session_n'] = sum_df['session_n'] - first_session

In [308]:
# Define the distances in your dataset
distances = sum_df['length'].unique()
distances.sort()

# Create a custom palette using tab20
custom_palette = sns.color_palette("tab20", len(distances))

# Create a dictionary to map distances to colors
distance_color_map = {distance: color for distance, color in zip(distances, custom_palette)}

In [None]:
fig, axes = plt.subplots(4, 4, figsize=(26, 20))

# One panel per mouse: session-average running speed over sessions, with the
# points coloured by the torque applied in that session.
for ax, mouse in zip(axes.flatten(), sum_df.mouse.unique()):
    mouse_rows = sum_df.loc[sum_df.mouse == mouse]
    test_df = (
        mouse_rows
        .groupby(['torque_friction', 'session', 'session_n'])
        .speed_average.mean()
        .reset_index()
    )

    shared_kwargs = dict(data=test_df, x='session_n', y='speed_average', ax=ax, legend=False)
    sns.scatterplot(hue='torque_friction', palette='magma', zorder=5, **shared_kwargs)
    sns.lineplot(color='k', alpha=0.5, **shared_kwargs)

    ax.set_ylim(0, 60)
    ax.set_title(mouse)
    # NOTE(review): both layers are drawn with legend=False, so this legend has
    # no labelled artists to show -- confirm whether a torque legend was intended.
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    sns.despine()

plt.tight_layout()
plt.show()
fig.savefig(os.path.join(results_path, 'batch4_velocity_across_sessions.pdf'))

In [309]:
# Per-session behavioural table, restricted to the analysed experiments
session_df = pd.read_csv(data_path + 'batch_4_session_df.csv', index_col=0)
session_df = session_df.loc[session_df.experiment.isin(list_experiments)]
# Re-base the session counter so that each mouse starts at 0
session_start = session_df.groupby('mouse')['session_n'].transform('min')
session_df['session_n'] = session_df['session_n'] - session_start

# Second behavioural table, same experiment filter. The columns that already
# exist in sum_df are dropped so the later merge does not duplicate them.
mouse_df = pd.read_csv(data_path + 'batch_4_mouse_df.csv', index_col=0)
mouse_df = (
    mouse_df
    .loc[mouse_df.experiment.isin(list_experiments)]
    .drop(columns=['session_n', 'experiment', 'friction'])
)

In [284]:
# Keys shared by the torque/velocity summary and the behavioural table.
# Inner join: only rows whose key combination is present in both tables survive.
group_list = ['mouse', 'session', 'active_patch']
sum_df = pd.merge(sum_df, mouse_df, how='inner', on=group_list)

### **Summary torque and distance sessions**

In [None]:
# Travel length against travel time, one panel per mouse, coloured by torque.
fig, axes = plt.subplots(4, 4, figsize=(26, 20))
legend_kwargs = dict(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
for ax, mouse in zip(axes.flatten(), sum_df.mouse.unique()):
    mouse_rows = sum_df.loc[sum_df.mouse == mouse]
    sns.scatterplot(
        data=mouse_rows,
        x='epoch_duration',
        y='length',
        hue='torque_friction',
        palette='viridis',
        alpha=0.8,
        ax=ax,
    )
    ax.legend(**legend_kwargs)
    sns.despine()
plt.tight_layout()
plt.show()

### **Distribution of interpatch durations per experiment**

In [None]:
## Distribution of epoch durations across conditions
# One PDF page per mouse; each panel is the histogram of median epoch duration
# (per session and patch) for one experiment, with the overall median in red.
n_cols = 3
with PdfPages(os.path.join(results_path, 'distribution of epoch_durations.pdf')) as pdf:
    for mouse in sum_df.mouse.unique():
        test_df = (
            sum_df.loc[sum_df.mouse == mouse]
            .groupby(['experiment', 'session', 'active_patch'])
            .epoch_duration.median()
            .reset_index()
        )
        mouse_experiments = test_df.experiment.unique()

        # Size the grid from the number of experiments so none is silently
        # dropped by zip(); never smaller than the original 2x3 layout.
        n_rows = max(2, int(np.ceil(len(mouse_experiments) / n_cols)))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows))

        for experiment, ax in zip(mouse_experiments, axes.flatten()):
            experiment_df = test_df.loc[test_df.experiment == experiment]
            adjust = sns.histplot(data=experiment_df, x='epoch_duration', bins=np.arange(0,100,3), ax=ax, legend=False)
            # Tallest bar sets the height of the median marker
            max_count = max((patch.get_height() for patch in adjust.patches), default=0)
            ax.vlines(experiment_df.epoch_duration.median(), 0, max_count, color='red')
            ax.set_title(experiment)
        sns.despine()
        plt.suptitle(mouse)
        plt.tight_layout()
        # Write the page before displaying, then release the figure so the loop
        # does not accumulate open figures.
        fig.savefig(pdf, format='pdf')
        plt.show()
        plt.close(fig)

In [None]:
# Two-panel summary, one PDF page per mouse:
#   top    - epoch duration per session (median bars coloured by travel length,
#            grey markers for the mean per session/experiment/torque);
#   bottom - `reward_probability` per session from session_df.
# The cell relies on the pyplot "current axes", so statement order matters.
with PdfPages(os.path.join(results_path, 'batch4_velocity_torque_duration_summary.pdf')) as pdf:
    for mouse in sum_df.mouse.unique():
        fig = plt.figure(figsize=(16,8))
        # --- Top panel: epoch duration across sessions ---
        fig.add_subplot(2,1,1)
        sns.barplot(data=sum_df.loc[sum_df.mouse == mouse], x='session_n', y='epoch_duration', estimator='median', hue='length', palette=distance_color_map)
        plot_df = sum_df.loc[sum_df.mouse == mouse].groupby(['session_n', 'experiment', 'torque_friction']).agg({'epoch_duration':'mean'}).reset_index()
        sns.scatterplot(data=plot_df, x='session_n', y='epoch_duration', style='torque_friction', color='grey', zorder=5)
        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=2)
        # Keep only every 5th of the existing x ticks
        plt.xticks(ticks=plt.xticks()[0][::5])
        plt.title(f'{mouse}')
        plt.xlabel('')
        # x range is taken from session_df (not sum_df) so both panels span the same sessions
        plt.xlim(-1, session_df.loc[session_df.mouse == mouse, 'session_n'].max()+2)
        plt.ylabel('Epoch \n duration (s)')
        # Dashed guide line every 5 sessions
        for i in range(0, session_df.loc[session_df.mouse == mouse, 'session_n'].max(), 5):
            plt.axvline(x=i, color='black', linestyle='--', alpha=0.5)
            
        # --- Bottom panel: reward probability across sessions ---
        fig.add_subplot(2,1,2)
        # The colour/marker mappings below are built from the whole session_df,
        # so they are the same for every mouse (they do not depend on `mouse`).
        experiments = session_df['experiment'].unique()
        palette = sns.color_palette("tab10", len(experiments))
        color_dict_experiment = dict(zip(experiments, palette))
        variable = 'reward_probability'
        
        # Create a style dictionary for each odor label
        # (zip stops at the shorter input: at most 10 odor labels receive a marker)
        odor_labels = session_df['odor_label'].unique()
        styles = ['o', 's', 'D', '^', 'v', '<', '>', 'p', '*', 'h']
        style_dict_odor_label = dict(zip(odor_labels, styles))
        
        # y limits come from all mice so that pages share one scale
        min_value = session_df[variable].min()
        max_value = session_df[variable].max()
        ax = sns.scatterplot(session_df.loc[(session_df.mouse == mouse)], x='session_n', size="visit_number", hue='experiment', style='odor_label', sizes=(30, 500), y=variable, 
                palette=color_dict_experiment,  alpha=0.7,
                markers=style_dict_odor_label)
        handles, labels = ax.get_legend_handles_labels()
        # Keep only the first len(color_dict_experiment) legend entries.
        # NOTE(review): presumably these are the experiment (hue) entries; depending on
        # the seaborn version the first handle may be a section title -- confirm the
        # slice does not cut off the last experiment.
        plt.legend(handles=handles[:len(color_dict_experiment)], labels=labels[:len(color_dict_experiment)], bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1, title='Experiment')
        # Same dashed guide lines as in the top panel
        for i in range(0, session_df.loc[session_df.mouse == mouse, 'session_n'].max(), 5):
            plt.axvline(x=i, color='black', linestyle='--', alpha=0.5)
        plt.ylim(min_value, max_value)
        plt.xlabel('Session number')
        plt.xlim(session_df.loc[session_df.mouse == mouse, 'session_n'].min()-1, session_df.loc[session_df.mouse == mouse, 'session_n'].max()+2)
        plt.ylabel('Reward probability')
        plt.tight_layout()
        sns.despine()
        # Append this mouse's figure as a new page of the PDF
        pdf.savefig(fig)
        plt.show()

In [None]:
# Per mouse: spread of session-mean epoch duration within each experiment.
fig, axes = plt.subplots(4, 4, figsize=(26, 20), sharey=True, sharex=True)
for ax, mouse in zip(axes.flatten(), sum_df.mouse.unique()):
    df_results = (
        sum_df.loc[sum_df.mouse == mouse]
        .groupby(['session', 'experiment'])['epoch_duration']
        .mean()
        .reset_index()
    )

    sns.boxplot(
        data=df_results,
        x='experiment',
        y='epoch_duration',
        hue='experiment',
        palette='viridis',
        ax=ax,
    )
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    sns.despine()

### **Explore relationship between torque, distance and time**

In [261]:
# Calculate mean of epoch duration for control sessions
sum_df['normalized_epoch_duration'] = sum_df['epoch_duration']
for mouse in sum_df['mouse'].unique():
    control_mean = sum_df.loc[(sum_df['mouse'] == mouse) & (sum_df['experiment'] == 'control')].groupby('session_n')['epoch_duration'].median()
    mean = np.mean(control_mean)
    
    # Normalize the epoch duration values
    sum_df['normalized_epoch_duration'] = sum_df.apply(
        lambda row: (row['epoch_duration'] / mean) if row['mouse'] == mouse else row['normalized_epoch_duration'],
        axis=1
    )                                                                                                                                                                                                                              

**How does the velocity change depending on the inserted torque and distance in the session**

In [None]:
fig, axes = plt.subplots(4, 3, figsize=(16, 16))

# Per-session averages that feed every panel
session_stats = {'length': 'mean', 'torque_friction': 'mean', 'speed_average': 'mean'}

with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_velocity.pdf')) as pdf:
    for mouse, ax in zip(sum_df.mouse.unique(), axes.flatten()):
        mouse_rows = sum_df.loc[sum_df.mouse == mouse]
        loop_df = mouse_rows.groupby('session_n').agg(session_stats).reset_index()

        # Midpoint of the colour scale: mean over this mouse's control sessions
        # of the session-average speed.
        control_rows = mouse_rows.loc[mouse_rows.experiment == 'control']
        control_speed = np.mean(control_rows.groupby('session_n').agg(session_stats)['speed_average'])

        # Scattered samples: one (distance, torque, speed) point per session
        distance = loop_df['length'].values
        torque = loop_df['torque_friction'].values
        speed = loop_df['speed_average'].values

        # Regular 50x50 grid spanning the sampled distances (X) and torques (Y)
        X, Y = np.meshgrid(
            np.linspace(distance.min(), distance.max(), 50),
            np.linspace(torque.min(), torque.max(), 50),
        )

        # Linear interpolation of speed over the grid (NaN outside the sampled hull)
        Z = griddata((distance, torque), speed, (X, Y), method='linear')

        # Diverging colour scale centred on the control speed
        norm = TwoSlopeNorm(vmin=np.nanmin(Z), vcenter=control_speed, vmax=np.nanmax(Z))
        heatmap = ax.contourf(X, Y, Z, levels=100, cmap='coolwarm', norm=norm)
        cbar = plt.colorbar(heatmap, ax=ax)

        # Colorbar ticks: lowest, middle and highest contour level only
        levels = heatmap.levels
        selected_ticks = [levels[idx] for idx in (0, len(levels) // 2, -1)]
        cbar.set_ticks(selected_ticks)
        cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])
        cbar.set_label("Velocity (cm/s)")

        ax.set_xlabel("Distance (cm)")
        ax.set_ylabel("Torque (a.u.)")
        ax.set_title(mouse)

    # Finish, display and store the whole grid as a single page
    sns.despine()
    plt.tight_layout()
    plt.show()
    pdf.savefig(fig)


In [None]:
# Velocity as a function of distance and torque, pooling the sessions of all mice.
with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_velocity_all.pdf')) as pdf:
    test_df = sum_df.groupby(['mouse', 'session_n']).agg({'epoch_duration':'mean', 'length':'mean', 'torque_friction':'mean', 'speed_average':'mean'}).reset_index()

    # Midpoint of the colour scale: mean over control session numbers (mice
    # pooled) of the average speed.
    control_speed = np.mean(sum_df.loc[(sum_df.experiment == 'control')].groupby('session_n').agg({'length':'mean', 'torque_friction':'mean', 'speed_average':'mean'})['speed_average'])

    # Scattered samples: one (distance, torque, speed) point per mouse and session
    distance = test_df['length'].values
    torque = test_df['torque_friction'].values
    speed = test_df['speed_average'].values

    fig, axes = plt.subplots(1, 1, figsize=(6, 5))

    # Regular 50x50 grid spanning the sampled distances (X) and torques (Y)
    X, Y = np.meshgrid(np.linspace(distance.min(), distance.max(), 50), np.linspace(torque.min(), torque.max(), 50))

    # Linear interpolation of speed over the grid (NaN outside the sampled hull)
    Z = griddata((distance, torque), speed, (X, Y), method='linear')

    # Diverging colour scale centred on the control speed
    vmin = np.nanmin(Z)
    vmax = np.nanmax(Z)
    vcenter = control_speed

    norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    heatmap = axes.contourf(X, Y, Z, levels=50, cmap='coolwarm', norm=norm)
    cbar = plt.colorbar(heatmap)

    # Colorbar ticks: lowest, middle and highest contour level.
    # The previous np.arange(..., 3.05, 0.4) upper bound was carried over from
    # the normalised-duration plot; it yields no ticks at all once the lowest
    # level exceeds 3.05, which is the case for speeds in cm/s.
    levels = heatmap.levels
    selected_ticks = [levels[0], levels[len(levels) // 2], levels[-1]]
    cbar.set_ticks(selected_ticks)
    cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])

    plt.xlabel("Distance (cm)")
    plt.ylabel("Torque (a.u.)")
    cbar.set_label("Velocity (cm/s)")

    sns.despine()
    # Display, then append the figure to the PDF
    plt.show()
    pdf.savefig(fig)


**How does the time it takes to travel change depending on the torque and the distance manipulation**

In [None]:
fig, axes = plt.subplots(4, 3, figsize=(16, 16))

# Per-mouse maps of travel time as a function of distance and torque.
# The plotted quantity is `normalized_epoch_duration`, i.e. the epoch duration
# divided by the mouse's control baseline, so it is a unitless ratio (1 = control).
with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_time.pdf')) as pdf:
    for mouse, ax in zip(sum_df.mouse.unique(), axes.flatten()):
        loop_df = sum_df.loc[sum_df.mouse == mouse].groupby('session_n').agg({'length':'mean', 'torque_friction':'mean', 'normalized_epoch_duration':'mean'}).reset_index()

        # Scattered samples: one (distance, torque, normalised duration) point per session
        distance = loop_df['length'].values
        torque = loop_df['torque_friction'].values
        duration = loop_df['normalized_epoch_duration'].values

        # Regular 50x50 grid spanning the sampled distances (X) and torques (Y)
        X, Y = np.meshgrid(np.linspace(distance.min(), distance.max(), 50), np.linspace(torque.min(), torque.max(), 50))

        # Linear interpolation over the grid (NaN outside the sampled hull)
        Z = griddata((distance, torque), duration, (X, Y), method='linear')

        # Diverging colour scale centred on 1, the control level
        vmin = np.nanmin(Z)
        vmax = np.nanmax(Z)
        vcenter = 1

        norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

        heatmap = ax.contourf(X, Y, Z, levels=100, cmap='coolwarm', norm=norm)
        cbar = plt.colorbar(heatmap, ax=ax)

        # Colorbar ticks: lowest, middle and highest contour level only
        levels = heatmap.levels
        selected_ticks = [levels[0], levels[len(levels) // 2], levels[-1]]
        cbar.set_ticks(selected_ticks)
        cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])

        ax.set_xlabel("Distance (cm)")
        ax.set_ylabel("Torque (a.u.)")
        ax.set_title(mouse)
        # The values are ratios to control, not seconds
        cbar.set_label("Epoch duration (normalized to control)")

    # Finish, display and store the whole grid as a single page
    sns.despine()
    plt.tight_layout()
    plt.show()
    pdf.savefig(fig)


In [None]:
# Travel time as a function of distance and torque, pooling the sessions of all mice.
# The plotted quantity is `normalized_epoch_duration`, i.e. the epoch duration
# divided by each mouse's control baseline, so it is a unitless ratio (1 = control).
with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_time_all.pdf')) as pdf:
    test_df = sum_df.groupby(['mouse', 'session_n']).agg({'epoch_duration':'mean', 'length':'mean', 'torque_friction':'mean', 'normalized_epoch_duration':'median'}).reset_index()

    # Scattered samples: one (distance, torque, normalised duration) point per mouse and session
    distance = test_df['length'].values
    torque = test_df['torque_friction'].values
    duration = test_df['normalized_epoch_duration'].values

    fig, axes = plt.subplots(1, 1, figsize=(6, 5))

    # Regular 50x50 grid spanning the sampled distances (X) and torques (Y)
    X, Y = np.meshgrid(np.linspace(distance.min(), distance.max(), 50), np.linspace(torque.min(), torque.max(), 50))

    # Linear interpolation over the grid (NaN outside the sampled hull)
    Z = griddata((distance, torque), duration, (X, Y), method='linear')

    # Diverging colour scale centred on 1, the control level
    vmin = np.nanmin(Z)
    vmax = np.nanmax(Z)
    vcenter = 1

    norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    heatmap = axes.contourf(X, Y, Z, levels=50, cmap='coolwarm', norm=norm)
    cbar = plt.colorbar(heatmap)

    # Colorbar ticks: lowest, middle and highest contour level only
    levels = heatmap.levels
    selected_ticks = [levels[0], levels[len(levels) // 2], levels[-1]]
    cbar.set_ticks(selected_ticks)
    cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])

    plt.xlabel("Distance (cm)")
    plt.ylabel("Torque (a.u.)")
    # The values are ratios to control, not seconds
    cbar.set_label("Epoch duration (normalized to control)")

    sns.despine()
    plt.tight_layout()
    # Display, then append the figure to the PDF
    plt.show()
    pdf.savefig(fig)


**How does the reward probability when leaving change depending on the torque and the distance manipulation**

In [None]:
fig, axes = plt.subplots(4, 3, figsize=(16, 16))

# Per-session averages that feed every panel
session_stats = {'length': 'mean', 'torque_friction': 'mean', 'reward_probability': 'mean'}

with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_preward.pdf')) as pdf:
    for mouse, ax in zip(sum_df.mouse.unique(), axes.flatten()):
        mouse_rows = sum_df.loc[sum_df.mouse == mouse]
        loop_df = mouse_rows.groupby('session_n').agg(session_stats).reset_index()

        # Midpoint of the colour scale: mean over this mouse's control sessions
        # of the session-average reward probability.
        control_rows = mouse_rows.loc[mouse_rows.experiment == 'control']
        control_preward = np.mean(control_rows.groupby('session_n').agg(session_stats)['reward_probability'])

        # Scattered samples: one (distance, torque, P(reward)) point per session
        distance = loop_df['length'].values
        torque = loop_df['torque_friction'].values
        preward = loop_df['reward_probability'].values

        # Regular 50x50 grid spanning the sampled distances (X) and torques (Y)
        X, Y = np.meshgrid(
            np.linspace(distance.min(), distance.max(), 50),
            np.linspace(torque.min(), torque.max(), 50),
        )

        # Linear interpolation of P(reward) over the grid (NaN outside the sampled hull)
        Z = griddata((distance, torque), preward, (X, Y), method='linear')

        # Diverging colour scale centred on the control reward probability
        norm = TwoSlopeNorm(vmin=np.nanmin(Z), vcenter=control_preward, vmax=np.nanmax(Z))
        heatmap = ax.contourf(X, Y, Z, levels=100, cmap='coolwarm', norm=norm)
        cbar = plt.colorbar(heatmap, ax=ax)

        # Colorbar ticks: lowest, middle and highest contour level only
        levels = heatmap.levels
        selected_ticks = [levels[idx] for idx in (0, len(levels) // 2, -1)]
        cbar.set_ticks(selected_ticks)
        cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])
        cbar.set_label("P(reward)")

        ax.set_xlabel("Distance (cm)")
        ax.set_ylabel("Torque (a.u.)")
        ax.set_title(mouse)

    # Finish, display and store the whole grid as a single page
    sns.despine()
    plt.tight_layout()
    plt.show()
    pdf.savefig(fig)


In [None]:
# Reward probability as a function of distance and torque, pooling all mice.
with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_preward_all.pdf')) as pdf:
    pooled_stats = {'epoch_duration': 'mean', 'length': 'mean', 'torque_friction': 'mean', 'reward_probability': 'median'}
    test_df = sum_df.groupby(['mouse', 'session_n']).agg(pooled_stats).reset_index()

    # Midpoint of the colour scale: mean over control session numbers (mice
    # pooled) of the median reward probability.
    control_rows = sum_df.loc[sum_df.experiment == 'control']
    control_preward = np.mean(
        control_rows
        .groupby('session_n')
        .agg({'length': 'mean', 'torque_friction': 'mean', 'reward_probability': 'median'})['reward_probability']
    )

    # Scattered samples: one (distance, torque, P(reward)) point per mouse and session
    distance = test_df['length'].values
    torque = test_df['torque_friction'].values
    preward = test_df['reward_probability'].values

    fig, axes = plt.subplots(1, 1, figsize=(6, 5))

    # Regular 50x50 grid spanning the sampled distances (X) and torques (Y)
    X, Y = np.meshgrid(
        np.linspace(distance.min(), distance.max(), 50),
        np.linspace(torque.min(), torque.max(), 50),
    )

    # Linear interpolation of P(reward) over the grid (NaN outside the sampled hull)
    Z = griddata((distance, torque), preward, (X, Y), method='linear')

    # Diverging colour scale centred on the control reward probability
    norm = TwoSlopeNorm(vmin=np.nanmin(Z), vcenter=control_preward, vmax=np.nanmax(Z))
    heatmap = axes.contourf(X, Y, Z, levels=50, cmap='coolwarm', norm=norm)
    cbar = plt.colorbar(heatmap)

    # Colorbar ticks: lowest, middle and highest contour level only
    levels = heatmap.levels
    selected_ticks = [levels[idx] for idx in (0, len(levels) // 2, -1)]
    cbar.set_ticks(selected_ticks)
    cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])

    plt.xlabel("Distance (cm)")
    plt.ylabel("Torque (a.u.)")
    cbar.set_label("P(reward)")

    sns.despine()
    plt.tight_layout()
    # Display, then append the figure to the PDF
    plt.show()
    pdf.savefig(fig)

## **Logistic models: what predicts the moment you will leave?**

In [None]:
# Per-session logistic regression: main effects and pairwise interactions of
# torque, distance and travel time as predictors of the reward target.
# Imported here because these scikit-learn helpers are not among the imports
# visible at the top of the notebook.
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

features = ["torque_friction", "length", "epoch_duration"]
# NOTE(review): `merge_df` is not defined in the visible cells -- confirm it is
# the merged summary table. The loop previously iterated over `merged_df`,
# leaving this copy unused; it now iterates over the copy.
test_df = merge_df.copy()

# Collect per-session results in lists and build the tables once at the end
# instead of growing DataFrames with concat inside the loop.
weight_frames = []
cv_records = []

# The per-session frame is named `session_rows` so that the loop no longer
# overwrites the global `mouse_df` table used by the merge cell.
for (mouse, session), session_rows in test_df.groupby(['mouse', 'session']):
    print(f"Mouse: {mouse}, Session: {session}")

    # Select features and target variable
    X_mouse = session_rows[features]
    # NOTE(review): other cells use the column 'reward_probability' (singular),
    # and casting a probability to int truncates it -- confirm the intended target.
    y_mouse = session_rows['reward_probabilities'].astype(int)

    # Pipeline: interaction terms -> standardisation -> logistic regression
    poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
    scaler = StandardScaler()
    log_reg = LogisticRegression(C=1, max_iter=1000)
    pipeline = make_pipeline(poly, scaler, log_reg)

    # 5-fold cross-validation (works on clones; `pipeline` itself stays unfitted)
    cv_scores = cross_val_score(pipeline, X_mouse, y_mouse, cv=5)

    # Fit on the whole session to read out the weights
    pipeline.fit(X_mouse, y_mouse)

    # Names of the expanded features, taken from the already-fitted pipeline step
    poly_features = pipeline.named_steps['polynomialfeatures'].get_feature_names_out(features)

    # One weight per expanded feature
    log_reg_model = pipeline.named_steps['logisticregression']
    feature_weights = pd.DataFrame({
        'regressors': poly_features,
        'weights': log_reg_model.coef_[0],
        'mouse': mouse,
        'session': session,
    })
    weight_frames.append(feature_weights)
    cv_records.append({'mouse': mouse, 'cv_score': cv_scores.mean(), 'session': session})

    # Print the mean cross-validation score
    print(f"Mean cross-validation score: {cv_scores.mean():.2f}")
    print('\n')

# Assemble the result tables (empty tables keep the expected columns)
if weight_frames:
    weights_df = pd.concat(weight_frames, ignore_index=True)
else:
    weights_df = pd.DataFrame(columns=['regressors', 'weights', 'mouse', 'session'])
cv_results_df = pd.DataFrame(cv_records, columns=['mouse', 'cv_score', 'session'])

weights_df['mouse'] = weights_df['mouse'].round(0).astype(str)

In [None]:
# Convert to pandas DataFrame
y = merge_df['reward_probability']
features = ["torque_friction", "length", "epoch_duration"]
design_matrix = merge_df.copy()

# Standardize predictors
scaler = StandardScaler()
design_matrix[features] = scaler.fit_transform(design_matrix[features])
# Fit the logistic regression model with standardized predictors
formula = "reward_probability ~ torque_friction * length + epoch_duration"
model = glm(formula=formula, data=design_matrix, family=Binomial(link=logit())).fit()

# Print model summary
print(model.summary())

# Predicted values
design_matrix["predicted_p_reward"] = model.predict(design_matrix)

# Evaluate predictions
print("\nFirst few predictions:")
print(design_matrix[["reward_probability", "predicted_p_reward"]].head())
