In [None]:
# IPython magig  tools
%load_ext autoreload
%autoreload 2

import os
from typing import Dict
from os import PathLike
from pathlib import Path

from aind_vr_foraging_analysis.utils import parse, processing, plotting_utils as plotting, AddExtraColumns

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np
import datetime

sns.set_context('talk')

import warnings
# Globally silence pandas' SettingWithCopyWarning (chained assignment) and every
# Future/User/Runtime warning. NOTE(review): this also hides genuine numerical
# warnings (e.g. means of empty slices) — consider narrowing it.
pd.set_option('mode.chained_assignment', None)
for _category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter('ignore', category=_category)
del _category

import ipywidgets as widgets
from IPython.display import display
from matplotlib.patches import Rectangle

# Patch/odor colours, assigned to patches in task-schema order (see the
# color_dict_label loops); color1 is also used to shade the pre-epoch window.
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = 'yellow'
odor_list_color = [color1, color2, color3, color4]

# Data and output locations. The defaults are the original lab paths; set the
# environment variables below to run the notebook elsewhere without editing it.
pdf_path = os.environ.get('VR_FORAGING_PDF_PATH', r'Z:\scratch\vr-foraging\sessions')
base_path = os.environ.get('VR_FORAGING_DATA_PATH', r'Z:/scratch/vr-foraging/data/')
results_path = os.environ.get(
    'VR_FORAGING_RESULTS_PATH',
    r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results',
)

from scipy.optimize import curve_fit

**To-do**
- Accumulate velocity, torque, and brake data across multiple animals.

In [None]:
def plot_velocity_across_sessions(cum_velocity, save=False):
    """Plot epoch-aligned running speed across experiments and sessions.

    The first panel overlays the mean speed trace of every experiment. Each
    following panel shows one experiment, with one line per
    ``within_session_number``.

    Parameters
    ----------
    cum_velocity : pandas.DataFrame
        Concatenated velocity traces with columns ``times``, ``speed``,
        ``experiment``, ``within_session_number`` and the boolean ``cropped``
        flag. Only rows with ``cropped == True`` are plotted.
    save : object with a ``savefig(fig)`` method, or False
        For example a ``PdfPages`` instance. When truthy, the figure is
        appended to it before being shown.
    """
    # One sequential palette per experiment. Cycled, so experiments beyond the
    # list are still drawn (zip() used to drop them silently).
    palettes = ['Blues', 'Oranges', 'Greens', 'Reds', 'Purples', 'Greys']

    cropped = cum_velocity.loc[cum_velocity['cropped'].eq(True)]
    x_max = cropped['times'].max()
    experiments = cum_velocity['experiment'].unique()

    # Two panels per row. Keep the original 5-row layout unless more rows are needed.
    n_rows = max(5, -(-(len(experiments) + 1) // 2))
    fig = plt.figure(figsize=(12, 22 * n_rows / 5))

    def _shade(ax):
        # Shade -1..0 s (before inter-patch start) and 0..15 s after it.
        ax.fill_betweenx([-5, 50], -1, 0, color=color1, alpha=0.2)
        ax.fill_betweenx([-5, 50], 0, 15, color='grey', alpha=0.2)

    ax = fig.add_subplot(n_rows, 2, 1)
    sns.lineplot(data=cropped, x='times', y='speed', hue='experiment',
                 errorbar=None, legend=True, ax=ax)
    ax.set_xlim(-1, x_max)
    ax.set_ylim(0, 50)
    _shade(ax)
    ax.set_xlabel('Time from inter-patch start (s)')
    ax.set_ylabel('Velocity (cm/s)')

    for i, experiment in enumerate(experiments):
        ax = fig.add_subplot(n_rows, 2, i + 2)
        sns.lineplot(data=cropped.loc[cropped['experiment'] == experiment], x='times', y='speed',
                     hue='within_session_number', palette=palettes[i % len(palettes)],
                     errorbar=None, alpha=0.8, ax=ax)
        ax.set_xlim(-1, x_max)
        ax.set_ylim(0, 50)
        _shade(ax)
        ax.set_xlabel('Time from inter-patch start (s)')
        ax.set_ylabel('Velocity (cm/s)')
        ax.set_title(experiment)
        ax.legend(borderaxespad=0., title='Session')

    fig.tight_layout()
    sns.despine(fig=fig)
    # Save before show(): some backends close or clear the figure on show().
    if save:
        save.savefig(fig)
    plt.show()

In [None]:
def torque_plots(cum_torque, limits: tuple = (1500, 2400), save=False):
    """Plot mean Torque aligned to epoch start (left) and epoch end (right), one line per experiment.

    Parameters
    ----------
    cum_torque : pandas.DataFrame
        Torque traces with columns ``times``, ``Torque``, ``experiment`` and
        ``align`` (``'onset'`` or ``'offset'``).
    limits : sequence of two numbers
        y-axis limits, also used as the vertical extent of the shaded bands.
        Defaults to a tuple, avoiding a shared mutable default list. Lists
        are still accepted.
    save : object with a ``savefig(fig)`` method, or False
        For example a ``PdfPages`` instance. When truthy, the figure is
        appended to it before being shown.
    """
    limits = tuple(limits)
    fig, (ax_onset, ax_offset) = plt.subplots(1, 2, figsize=(12, 4))

    sns.lineplot(data=cum_torque.loc[cum_torque['align'] == 'onset'], x='times', y='Torque',
                 hue='experiment', errorbar=None, legend=False, alpha=0.7, ax=ax_onset)
    ax_onset.set_xlim(-1, 15)
    ax_onset.set_ylim(limits)
    # Shade the second before inter-patch start and the following 15 s.
    ax_onset.fill_betweenx(limits, -1, 0, color=color1, alpha=0.2)
    ax_onset.fill_betweenx(limits, 0, 15, color='grey', alpha=0.2)
    ax_onset.set_xlabel('Time from inter-patch start (s)')

    sns.lineplot(data=cum_torque.loc[cum_torque['align'] == 'offset'], x='times', y='Torque',
                 hue='experiment', errorbar=None, alpha=0.7, ax=ax_offset)
    ax_offset.set_xlim(-5, 2)
    ax_offset.set_ylim(limits)
    ax_offset.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    ax_offset.fill_betweenx(limits, -15, 0, color='grey', alpha=0.2)
    ax_offset.fill_betweenx(limits, 0, 2, color=color1, alpha=0.2)
    ax_offset.set_xlabel('Time from interpatch end (s)')

    sns.despine(fig=fig)
    fig.tight_layout()
    # Save before show(): some backends close or clear the figure on show().
    if save:
        save.savefig(fig)
    plt.show()

In [None]:
def epoch_duration_plot(cum_active_site, mouse, save=False):
    """Box plots of epoch duration per session (left) and pooled per experiment (right).

    Parameters
    ----------
    cum_active_site : pandas.DataFrame
        One row per epoch, with columns ``session_n``, ``epoch_duration`` and
        ``experiment``.
    mouse : str
        Mouse ID, shown as the figure title.
    save : object with a ``savefig(fig)`` method, or False
        For example a ``PdfPages`` instance. When truthy, the figure is
        appended to it before being shown.
    """
    fig, (ax_session, ax_experiment) = plt.subplots(
        1, 2, figsize=(14, 6), sharey=True, gridspec_kw={'width_ratios': [2, 1]})

    sns.boxplot(data=cum_active_site, x='session_n', y='epoch_duration', hue='experiment',
                showfliers=False, legend=False, ax=ax_session)
    ax_session.set_ylabel('Epoch \n duration (s)')
    ax_session.set_ylim(0, 45)

    sns.boxplot(data=cum_active_site, x='experiment', y='epoch_duration', hue='experiment',
                showfliers=False, ax=ax_experiment)
    ax_experiment.tick_params(axis='x', labelrotation=45)
    ax_experiment.set_ylabel('Epoch \n duration (s)')
    ax_experiment.set_ylim(0, 45)
    ax_experiment.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    fig.suptitle(mouse)
    sns.despine(fig=fig)
    fig.tight_layout()
    # Save before show(): some backends close or clear the figure on show().
    if save:
        save.savefig(fig)
    plt.show()

## How does the animals' behavior change across subsequent sessions?

In [None]:
# NOTE(review): placeholder only; reassigned to stream_data.torque_data in the session loop.
torque_data = {}
# Session date (mm/dd/yyyy) and mouse analysed by the single-mouse cell below.
date_string = "12/05/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()
mouse = '754580'

In [None]:
# Single-mouse pass. Walk `mouse`'s session folders in creation order, keep only
# the session(s) recorded on `date`, and accumulate:
#   * cum_active_site - one row per analysed epoch (InterPatch or PostPatch site)
#   * cum_velocity    - running speed aligned to epoch start
#   * cum_torque      - Torque aligned to epoch start ('onset') and end ('offset')
# Per-session frames are collected in lists and concatenated once at the end.
session_n = 0
within_session_number = 0
control_experiment = 0
previous_experiment = None
active_site_list = []
velocity_list = []
torque_list = []
# Control-minus-friction mean Torque in the first second after epoch onset, keyed by mouse.
torque_difference = {}

directory = os.path.join(base_path, mouse)
files = os.listdir(directory)
sorted_files = sorted(files, key=lambda name: os.path.getctime(os.path.join(directory, name)))

for file_name in sorted_files:
    # Characters [-15:-7] of the folder name are parsed as YYYYMMDD
    # (presumably followed by a time stamp; confirm). Skip entries that do not match.
    session = file_name[-15:-7]
    try:
        session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
    except ValueError:
        continue
    if session_date != date:
        continue

    # Recover data streams
    session_path = Path(os.path.join(base_path, mouse, file_name))
    data = parse.load_session_data(session_path)

    # Parse data into a dataframe with the main features; report failures instead of hiding them.
    try:
        reward_sites, active_site, config = parse.parse_dataframe(data)
    except Exception as exc:
        print(f'{mouse} {file_name}: could not parse session ({exc!r}); skipping')
        continue

    # Expand with extra columns
    reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
    active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

    # Continuous streams are loaded separately.
    stream_data = parse.ContinuousData(data)
    encoder_data = stream_data.encoder_data
    torque_data = stream_data.torque_data
    brake_data = stream_data.brake_data

    tasklogic = data['config'].streams.tasklogic_input.data
    experiment = tasklogic['stage_name']
    # Friction of the first patch's post-patch treadmill.
    friction = (tasklogic['task_parameters']['environment_statistics']['patches'][0]
                ['virtual_site_generation']['post_patch']['treadmill_specification']
                ['friction']['distribution_parameters']['value'])

    # Colour per patch label, following the patch order in the task schema.
    color_dict_label = {}
    dict_odor = {}
    list_patches = parse.TaskSchemaProperties(data).patches
    for i, patches in enumerate(list_patches):
        color_dict_label[patches['label']] = odor_list_color[i]
        dict_odor[i] = patches['label']

    # When the shortest InterPatch site has length 50, analyse the PostPatch site instead.
    # active_patch is shifted up one row so each PostPatch row carries the next row's value.
    # NOTE(review): the meaning of the 50 threshold is not recorded here; confirm.
    if active_site.loc[active_site.label == 'InterPatch'].length.min() == 50:
        section = 'PostPatch'
        active_site['active_patch'] = active_site['active_patch'].shift(-1)
    else:
        section = 'InterPatch'

    # Each epoch ends where the next site starts (the index is presumably the site start time in s).
    active_site['end_epoch'] = active_site.index.to_series().shift(-1)
    active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index
    # .copy() so the column assignments below modify an independent frame, not a slice of active_site.
    new_active_site = active_site[active_site['label'] == section].copy()

    session_n += 1
    new_active_site['session_n'] = session_n
    new_active_site['experiment'] = experiment

    # Count consecutive sessions of the same experiment; 'control' keeps its own running count.
    if previous_experiment != experiment:
        within_session_number = 0
        previous_experiment = experiment
    else:
        within_session_number += 1
    if experiment == 'control':
        control_experiment += 1
        within_session_number = control_experiment
    new_active_site['within_session_number'] = within_session_number
    active_site_list.append(new_active_site)

    # Running speed around epoch start (window=[-1, 10]), cropped to each epoch's length.
    velocity = plotting.trial_collection(new_active_site,
                                         encoder_data,
                                         mouse,
                                         session,
                                         window=[-1, 10],
                                         cropped_to_length='epoch',
                                         taken_col='filtered_velocity')
    # Flag samples earlier than the shortest trial's last time point, i.e. times every trial covers.
    velocity['cropped'] = velocity.times < velocity.groupby('active_patch').times.max().min()
    velocity_list.append(velocity)

    # Torque aligned to epoch start and to epoch end.
    torque_onset = plotting.trial_collection(new_active_site,
                                             torque_data,
                                             mouse,
                                             session,
                                             window=[-1, 10],
                                             cropped_to_length='epoch',
                                             taken_col=['Torque'])
    torque_offset = plotting.trial_collection(new_active_site,
                                              torque_data,
                                              mouse,
                                              session,
                                              aligned='end_epoch',
                                              window=[-5, 2],
                                              taken_col=['Torque'])
    torque_onset['align'] = 'onset'
    torque_offset['align'] = 'offset'
    torque_list.extend([torque_onset, torque_offset])

# pd.concat([]) raises, so keep an empty DataFrame when no session matched `date`.
cum_active_site = pd.concat(active_site_list) if active_site_list else pd.DataFrame()
cum_velocity = pd.concat(velocity_list) if velocity_list else pd.DataFrame()
cum_torque = pd.concat(torque_list) if torque_list else pd.DataFrame()

# Store the control-vs-friction torque difference in a dedicated dict.
# The old code wrote it into the loaded session container `data`, where it was lost.
if not cum_torque.empty:
    first_second = cum_torque.loc[(cum_torque['align'] == 'onset')
                                  & (cum_torque.times > 0) & (cum_torque.times < 1)]
    torque_difference[mouse] = (first_second.loc[first_second.experiment == 'control'].Torque.mean()
                                - first_second.loc[first_second.experiment == 'friction'].Torque.mean())

In [None]:
# Display the friction value read from the last processed session's config.
friction

In [None]:
# Scratch arithmetic; the meaning of 175 and 20 is not recorded here. TODO document or remove.
175/20

In [None]:
# Sum of reward_delivered in the last processed session, times 5.
# NOTE(review): assumes reward_delivered is 0/1 and 5 is the amount per delivery (e.g. µL); confirm units.
reward_sites.reward_delivered.sum() * 5

In [None]:
# Left: epoch duration per active patch in the last processed session.
# Right: speed traces per active patch for that session's experiment.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# No hue on this plot, so no palette (seaborn ignores a palette without hue).
sns.lineplot(data=new_active_site, x='active_patch', y='epoch_duration', marker='o', ax=axes[0])

ax = axes[1]
cropped_velocity = cum_velocity.loc[cum_velocity['cropped'].eq(True)]
sns.lineplot(data=cropped_velocity.loc[cropped_velocity['experiment'] == experiment], x='times', y='speed',
             hue='active_patch', errorbar=None, alpha=0.8, ax=ax)
ax.set_xlim(-1, cropped_velocity['times'].max())
ax.set_ylim(-15, 60)
ax.fill_betweenx([-15, 60], -1, 0, color=color1, alpha=0.2)
ax.fill_betweenx([-15, 60], 0, 15, color='grey', alpha=0.2)
ax.set_xlabel('Time from inter-patch start (s)')
ax.set_ylabel('Velocity (cm/s)')
ax.set_title(experiment)
# Lines are coloured by active_patch, so the legend is titled accordingly.
ax.legend(bbox_to_anchor=(1, 0.9), title='Active patch')
fig.suptitle(mouse)
sns.despine(fig=fig)
fig.tight_layout()

In [None]:
# Friction of the first patch's post-patch treadmill, read back from the session config.
friction = (
    data['config']
    .streams.tasklogic_input.data['task_parameters']
    ['environment_statistics']['patches'][0]
    ['virtual_site_generation']['post_patch']
    ['treadmill_specification']['friction']
    ['distribution_parameters']['value']
)

In [None]:
# Use pandas' NaN-skipping min/max. Python's builtin min()/max() give order-dependent
# results when the Torque column contains NaN.
torque_plots(cum_torque, limits=[cum_torque['Torque'].min(), cum_torque['Torque'].max()])

In [None]:
# Velocity overview for the single-mouse session(s); shown only, not saved.
plot_velocity_across_sessions(cum_velocity=cum_velocity, save=False)

In [None]:
# Epoch-duration overview for the single-mouse session(s); shown only, not saved.
epoch_duration_plot(cum_active_site=cum_active_site, mouse=mouse, save=False)

## Loop over several animals

In [None]:
# Earliest session date (mm/dd/yyyy, inclusive) processed by the multi-mouse loop below.
date_string = "10/20/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()

In [None]:
import time

In [None]:
import time

# Multi-mouse pass. For each mouse, process every session recorded on or after
# `date`, collect epoch tables and epoch-aligned speed traces, and write one PDF
# per mouse (epoch durations + velocity) to `results_path`.
torque_dict = {}
across_mice_velocity = []
across_mice_torque = []

mouse_ids = ['754570', '754579', '754567', '754580', '754559', '754560', '754577', '754566',
             '754570', '754571', '754574', '754575', '754582', '745302', '745305', '745301']
# '754570' appears twice above. dict.fromkeys drops the repeat (keeping order), so that
# mouse is not processed twice and its PDF is not rewritten by a second identical pass.
for mouse in dict.fromkeys(mouse_ids):
    print(mouse)
    session_n = 0
    active_site_list = []
    velocity_list = []
    torque_list = []
    within_session_number = 0
    control_experiment = 0
    previous_experiment = None

    directory = os.path.join(base_path, mouse)
    files = os.listdir(directory)
    sorted_files = sorted(files, key=lambda x: os.path.getctime(os.path.join(directory, x)))

    for file_name in sorted_files:
        start_time = time.time()
        # Characters [-15:-7] of the folder name are parsed as YYYYMMDD; skip non-matching entries.
        session = file_name[-15:-7]
        try:
            session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
        except ValueError:
            continue
        if session_date < date:
            continue
        # Incremented before loading, so sessions that later fail or are empty still use up a number.
        session_n += 1

        # Recover data streams
        session_path = Path(os.path.join(base_path, mouse, file_name))
        try:
            data = parse.load_session_data(session_path)
        except Exception as exc:
            print(f'{mouse} {file_name}: could not load session ({exc!r}); skipping')
            continue

        tasklogic = data['config'].streams.tasklogic_input.data
        experiment = tasklogic['stage_name']

        # Parse data into a dataframe with the main features
        try:
            reward_sites, active_site, config = parse.parse_dataframe(data)
        except Exception as exc:
            print(f'{mouse} {file_name}: could not parse session ({exc!r}); skipping')
            continue
        print(f"load and parse time: {time.time() - start_time:.6f} seconds")

        if reward_sites.empty:
            continue

        # Expand with extra columns
        reward_sites = AddExtraColumns(reward_sites, active_site, run_on_init=True).reward_sites
        active_site = AddExtraColumns(reward_sites, active_site, run_on_init=False).add_time_previous_intersite_interpatch()

        # Load the encoder data separately
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data
        print(f"extra and stream time: {time.time() - start_time:.6f} seconds")

        # Colour per patch label, following the patch order in the task schema.
        color_dict_label = {}
        dict_odor = {}
        list_patches = parse.TaskSchemaProperties(data).patches
        for i, patches in enumerate(list_patches):
            color_dict_label[patches['label']] = odor_list_color[i]
            dict_odor[i] = patches['label']

        # When the shortest InterPatch site has length 50, analyse the PostPatch site instead
        # and shift active_patch up one row. NOTE(review): meaning of 50 not recorded; confirm.
        if active_site.loc[active_site.label == 'InterPatch'].length.min() == 50:
            section = 'PostPatch'
            active_site['active_patch'] = active_site['active_patch'].shift(-1)
        else:
            section = 'InterPatch'

        # Each epoch ends where the next site starts.
        active_site['end_epoch'] = active_site.index.to_series().shift(-1)
        active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index

        # .copy() so the column assignments below modify an independent frame.
        new_active_site = active_site[active_site['label'] == section].copy()
        new_active_site['session_n'] = session_n
        new_active_site['experiment'] = experiment

        # Friction of the first patch's post-patch treadmill.
        friction = (tasklogic['task_parameters']['environment_statistics']['patches'][0]
                    ['virtual_site_generation']['post_patch']['treadmill_specification']
                    ['friction']['distribution_parameters']['value'])
        new_active_site['friction'] = friction

        # Count consecutive sessions of the same experiment; 'control' keeps its own running count.
        if previous_experiment != experiment:
            within_session_number = 0
            previous_experiment = experiment
        else:
            within_session_number += 1
        if experiment == 'control':
            control_experiment += 1
            within_session_number = control_experiment
        new_active_site['within_session_number'] = within_session_number

        active_site_list.append(new_active_site)

        velocity = plotting.trial_collection(new_active_site,
                                             encoder_data,
                                             mouse,
                                             session,
                                             window=[-1, 2],
                                             cropped_to_length='epoch',
                                             taken_col='filtered_velocity')
        # Flag samples earlier than the shortest trial's last time point.
        velocity['cropped'] = velocity.times < velocity.groupby('active_patch').times.max().min()
        print(f"crop: {time.time() - start_time:.6f} seconds")

        velocity_list.append(velocity)
        print(f"append: {time.time() - start_time:.6f} seconds")

        torque_data = stream_data.torque_data
        brake_data = stream_data.brake_data

        # Torque alignment is currently disabled, so torque_list stays empty:
        # velocity = plotting.trial_collection(new_active_site,
        #                                                 torque_data,
        #                                                 mouse,
        #                                                 session,
        #                                                 window=[-1,10],
        #                                                 cropped_to_length='epoch',
        #                                                 taken_col=['Torque'])

        # velocity_end = plotting.trial_collection(new_active_site,
        #                                                 torque_data,
        #                                                 mouse,
        #                                                 session,
        #                                                 aligned='end_epoch',
        #                                                 window=[-5,2],
        #                                                 taken_col=['Torque'])

        # # velocity['align'] = 'onset'
        # velocity_end['align'] = 'offset'
        # velocity_end['friction'] = friction
        # torque_list.append(velocity_end)

        # torque_dict[cum_torque.mouse.unique()[0]] = {'friction_10' : (cum_torque.loc[(cum_torque['align'] =='onset')&(cum_torque.times>0)&(cum_torque.times<1)&(cum_torque.experiment == 'control')].Torque.mean() - cum_torque.loc[(cum_torque['align'] =='onset')&(cum_torque.times>0)&(cum_torque.times<1)&(cum_torque.experiment == 'friction')].Torque.mean()),
        #                                             'friction_15': (cum_torque.loc[(cum_torque['align'] =='onset')&(cum_torque.times>0)&(cum_torque.times<1)&(cum_torque.experiment == 'control')].Torque.mean() - cum_torque.loc[(cum_torque['align'] =='onset')&(cum_torque.times>0)&(cum_torque.times<1)&(cum_torque.experiment == 'friction_15')].Torque.mean())}

    # A mouse with no usable session would make pd.concat([]) raise; skip it instead.
    if not active_site_list:
        print(f'{mouse}: no usable sessions on or after {date}; skipping')
        continue
    cum_active_site = pd.concat(active_site_list)
    cum_velocity = pd.concat(velocity_list)
    # torque_list is only filled if the torque block above is re-enabled.
    # Fall back to an empty frame so the loop does not crash after the first mouse.
    cum_torque = pd.concat(torque_list) if torque_list else pd.DataFrame()

    with PdfPages(os.path.join(results_path, f'{mouse}_torque_velocity_across_sessions_experiments.pdf')) as pdf:
        print('epoch')
        epoch_duration_plot(cum_active_site, mouse, save=pdf)
        print('velocity')
        plot_velocity_across_sessions(cum_velocity, save=pdf)
        # print('torque')
        # torque_plots(cum_torque, limits=[min(cum_torque.Torque), max(cum_torque.Torque)], save=pdf)

    # cum_velocity = cum_velocity.loc[(cum_velocity.times > 0.5)&(cum_velocity.times < 1.5)].groupby('friction').speed.mean().reset_index()
    # cum_velocity['mouse'] = mouse
    # across_mice_velocity.append(cum_velocity)

    # cum_torque = cum_torque.loc[(cum_torque.times <0)&(cum_torque.times > -1.5)].groupby('friction').Torque.mean().reset_index()
    # cum_torque['mouse'] = mouse
    # across_mice_torque.append(cum_torque)


In [None]:
# across_mice_torque / across_mice_velocity are only appended to by the per-mouse
# aggregation at the end of the multi-mouse loop, which is commented out there.
# Fail with an explicit message instead of pandas' "No objects to concatenate".
if not across_mice_torque or not across_mice_velocity:
    raise ValueError(
        'across_mice_torque / across_mice_velocity are empty: enable the per-mouse '
        'aggregation in the multi-mouse loop before running this cell'
    )
cum_torque = pd.concat(across_mice_torque)
cum_velocity = pd.concat(across_mice_velocity)

In [None]:
# Compute the mean Torque for rows with friction == 0
new_col = cum_torque.loc[cum_torque.friction == 0].groupby('mouse').Torque.mean()

# Map the computed mean Torque back to the original DataFrame for matching mouse IDs
cum_torque['new_torque'] = cum_torque['mouse'].map(new_col)
cum_torque['norm_torque'] = cum_torque.new_torque - cum_torque.Torque

In [None]:
# Three-step yellow -> amber palette used for the friction levels in the plots below.
yellow_palette = sns.color_palette(
    ["#FFFFCC", "#FFCC00", "#FF9900"]  # light pastel yellow, golden yellow, amber
)

In [None]:
# Express velocity-table friction as a percentage (its plot's x-axis is 'Friction (%)').
# The unscaled value is kept in 'friction_raw', so re-running this cell does not
# multiply by 100 a second time.
# cum_torque['friction'] = cum_torque['friction']*100
if 'friction_raw' not in cum_velocity.columns:
    cum_velocity['friction_raw'] = cum_velocity['friction']
cum_velocity['friction'] = cum_velocity['friction_raw'] * 100

In [None]:
# Load every brake-calibration file (two header-less columns) and tag each row with
# the wheel ID taken from the first two characters of the file name.
# NOTE(review): 'calibratrion' looks like a typo; left as-is in case it matches the folder on disk.
path = r'Z:\scratch\vr-foraging\brake_calibratrion'
calibration_frames = []
for calibrations in os.listdir(path):
    df = pd.read_csv(os.path.join(path, calibrations), names=['input_torque', 'output_torque'])
    df['wheel'] = calibrations[:2]
    calibration_frames.append(df)
# Concatenate once (not inside the loop); keep an empty frame when the folder is empty.
df_cum = pd.concat(calibration_frames, axis=0) if calibration_frames else pd.DataFrame()

In [None]:
# Keep wheel '4B' only and rescale input_torque from raw counts to percent.
# 65535 = 2**16 - 1, presumably the full-scale value of a 16-bit command; confirm.
# The raw value is kept in 'input_torque_raw', so re-running this cell does not rescale twice.
df_cum = df_cum.loc[df_cum['wheel'] == '4B'].copy()
if 'input_torque_raw' not in df_cum.columns:
    df_cum['input_torque_raw'] = df_cum['input_torque']
df_cum['input_torque'] = (df_cum['input_torque_raw'] / 65535) * 100

In [None]:
# Brake calibration: measured output_torque against input_torque (rescaled to %).
fig, ax = plt.subplots(figsize=(6, 5))
sns.lineplot(data=df_cum, x='input_torque', y='output_torque', color=yellow_palette[2], ax=ax)

sns.despine(ax=ax)
ax.set_ylabel('Torque')
ax.set_xlabel('Friction (%)')
fig.tight_layout()
fig.savefig(os.path.join(results_path, 'torque.svg'))

In [None]:
# Speed per friction level across mice. Boxes are drawn first and the swarm on top,
# so individual points are not hidden behind the filled boxes.
fig, ax = plt.subplots(figsize=(10, 4))
sns.boxplot(data=cum_velocity, x='friction', y='speed', hue='friction', palette=yellow_palette,
            legend=False, width=0.5, showfliers=False, ax=ax)
# The swarm has no hue, so a `dodge` argument would have no effect; it is omitted.
sns.swarmplot(data=cum_velocity, x='friction', y='speed', color='black', ax=ax)

sns.despine(ax=ax)
ax.set_ylabel('Velocity (cm/s)')
ax.set_xlabel('Friction (%)')
fig.tight_layout()
# fig.savefig(os.path.join(results_path, 'velocity_across_mice.svg'))

In [None]:
# Persist the across-mice velocity table next to the figures.
cum_velocity.to_csv(path_or_buf=os.path.join(results_path, 'velocity_across_mice.csv'))

In [None]:
# Torque per friction level across mice. cum_torque['friction'] is not multiplied
# by 100 above (that line is commented out), so the axis shows the unscaled value.
# Boxes first, swarm on top; fliers hidden because the swarm already shows every point.
fig, ax = plt.subplots(figsize=(4, 5))
sns.boxplot(data=cum_torque, x='friction', y='Torque', hue='friction', palette=yellow_palette,
            legend=False, width=0.5, showfliers=False, ax=ax)
sns.swarmplot(data=cum_torque, x='friction', y='Torque', color='black', ax=ax)
sns.despine(ax=ax)
sns.despine()

In [None]:
    # Rebuild cum_velocity from the per-mouse summaries collected in the multi-mouse loop.
if not across_mice_velocity:
    raise ValueError('across_mice_velocity is empty: enable the per-mouse aggregation in the multi-mouse loop first')
cum_velocity = pd.concat(across_mice_velocity)

In [None]:
# Epoch duration per active patch for the last processed session.
# No hue is used, so no palette is passed (seaborn would ignore it).
fig, ax = plt.subplots()
sns.lineplot(data=new_active_site, x='active_patch', y='epoch_duration', marker='o', ax=ax)
ax.set_ylim(5, 40)
ax.set_xlabel('Active patch')
ax.set_ylabel('Epoch duration (s)')
sns.despine(ax=ax)

In [None]:
# FIXME: `friction_df` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel; it only works from earlier hidden kernel state.
# Running it also overwrites friction_df.csv, which the next cell reads back.
friction_df.to_csv(os.path.join(results_path, 'friction_df.csv'))

In [None]:
# Torque per mouse and friction level, read back from the saved CSV.
friction_df = pd.read_csv(os.path.join(results_path, 'friction_df.csv'))
fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(data=friction_df, x='Mouse', y='Torque', hue='Friction', ax=ax)
# Rotate via the axes API; plt.xticks() as the last line leaked its tuple repr into the output.
ax.tick_params(axis='x', labelrotation=45)
sns.despine(ax=ax)