In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../../../src/')

import os

from aind_vr_foraging_analysis.utils.plotting import general_plotting_utils as plotting, plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import parse, AddExtraColumns, data_access

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import datetime
from pathlib import Path

sns.set_context('talk')

import warnings
# Silence pandas' SettingWithCopyWarning raised by column assignment on sliced frames.
pd.set_option('mode.chained_assignment', None)
# Keep cell output readable: hide Future-, User- and Runtime-warnings globally.
for _ignored_category in (FutureWarning, UserWarning, RuntimeWarning):
    warnings.simplefilter('ignore', category=_ignored_category)

pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'

# Shared palette
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = '#e7298a'
odor_list_color = [color1, color2, color3]
color_dict = {0: color1, 1: color2, 2: color3}
# Plot colour per odour label -- single source of truth for odour colours.
color_dict_label = {'Ethyl Butyrate': color1, 'Alpha-pinene': color2, 'Amyl Acetate': color3,
                    '2-Heptanone': color2, 'Methyl Acetate': color1, 'Fenchone': color3, '2,3-Butanedione': color4,
                    'Methyl Butyrate': color1}

# Per-odour parameters: every odour shares `rate`; `offset` is 0.6 except for Methyl Butyrate.
# Colours are taken from color_dict_label so an odour gets the same colour in every figure
# (previously 2-Heptanone was hard-coded to color3 here, clashing with Fenchone).
rate = -0.12
offset = 0.6
odor_offsets = {
    'Ethyl Butyrate': offset,
    'Methyl Butyrate': 0.9,
    'Alpha-pinene': offset,
    'Amyl Acetate': offset,
    'Methyl Acetate': offset,
    '2,3-Butanedione': offset,
    'Fenchone': offset,
    '2-Heptanone': offset,
}
dict_odor = {odor: {'rate': rate, 'offset': odor_offset, 'color': color_dict_label[odor]}
             for odor, odor_offset in odor_offsets.items()}

def exponential_func(x, a, b):
    """Exponential model ``a * exp(b * x)``; evaluates element-wise on NumPy arrays."""
    scaled_exponent = b * x
    return a * np.exp(scaled_exponent)


def format_func(value, tick_number):
    """Render ``value`` rounded to zero decimals; ``tick_number`` is accepted but unused.

    The (value, tick_number) signature is presumably for matplotlib's FuncFormatter.
    """
    return '{:.0f}'.format(value)


results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'



### Generating the dataset

In [None]:
today = datetime.date.today()
date = today
# Earliest session date passed to find_sessions_relative_to_date (when='on_or_after').
date_string = '2024-8-25'

In [None]:
# Torque calibration table; per-wheel rows with quadratic coefficients a, b, c.
params_df = pd.read_csv(os.path.join(data_path, 'torque_calibration.csv'))

In [None]:
# Animals included in batch 4.
mouse_list = (
    '754570 754579 754567 754580 754559 754560 754577 '
    '754566 754571 754574 754575 '
    '754582 745302 745305 745301'
).split()

In [None]:
# Friction-level bins on the calibrated torque difference:
# (-inf, 120) -> low, [120, 240) -> med, [240, inf) -> high.
TORQUE_LOW_MAX = 120
TORQUE_HIGH_MIN = 240
FRICTION_STAGES = ('friction', 'friction_15', 'friction_optimized')


def get_schema_friction(config):
    """Return the post-patch treadmill friction value declared in the task-logic schema.

    Returns 0 when the nested entry is absent (missing key, empty patch list, or a
    ``None`` somewhere along the path).
    """
    try:
        task_parameters = config.streams.tasklogic_input.data['task_parameters']
        post_patch = (task_parameters['environment_statistics']['patches'][0]
                      ['virtual_site_generation']['post_patch'])
        return post_patch['treadmill_specification']['friction']['distribution_parameters']['value']
    except (KeyError, IndexError, TypeError):
        return 0


def classify_friction(torque_friction):
    """Map a calibrated torque difference to 'friction_low', 'friction_med' or 'friction_high'.

    NaN compares False everywhere and therefore ends up in 'friction_high'.
    """
    if torque_friction < TORQUE_LOW_MAX:
        return 'friction_low'
    if torque_friction < TORQUE_HIGH_MIN:
        return 'friction_med'
    return 'friction_high'


list_errors = []
session_frames = []  # one frame per usable session; concatenated once after the loop
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='on_or_after',
    )
    session_n = 0
    print('Mouse: ', mouse)
    # All this segment is to find the correct session without having the specific path
    for session_path in session_paths:
        print(session_path)
        try:
            data = parse.load_session_data(session_path)
        except Exception as e:
            print('Error loading data: ', e)
            list_errors.append(session_path)
            continue
        # Parse data into a dataframe with the main features
        try:
            all_epochs = parse.parse_dataframe(data)
        except Exception as e:
            print('Error parsing data: ', e)
            list_errors.append(session_path)
            continue

        # -- At this step you can save the data into a csv file

        if 'OdorSite' not in all_epochs.label.unique():
            print('No odor site data')
            continue

        rig = data['config'].streams.rig_input.data['rig_name']
        experiment = data['config'].streams.tasklogic_input.data['stage_name']

        # The schema only holds the commanded friction; the rig's calibration converts it
        # into the torque actually applied.
        friction = get_schema_friction(data['config'])
        calibration = params_df.loc[params_df.wheel == rig]
        if calibration.empty:
            print('No torque calibration for rig: ', rig)
            list_errors.append(session_path)
            continue
        coef_a = calibration['a'].values[0]
        coef_b = calibration['b'].values[0]
        coef_c = calibration['c'].values[0]
        # NOTE(review): 65535 presumably scales the schema friction to the 16-bit brake command -- confirm.
        resolved_torque = f.quadratic_model(65535 * friction, coef_a, coef_b, coef_c)
        torque_friction = coef_c - resolved_torque

        if experiment in FRICTION_STAGES:
            experiment = classify_friction(torque_friction)

        all_epochs['duration_epoch'] = all_epochs.index.to_series().diff().shift(-1)
        all_epochs['mouse'] = mouse
        # NOTE(review): last 25 characters of the path, presumably the session folder name -- confirm.
        all_epochs['session'] = str(session_path)[-25:]
        all_epochs['rig'] = rig
        all_epochs['torque_friction'] = torque_friction
        all_epochs['experiment'] = experiment

        session_n += 1
        all_epochs['session_n'] = session_n

        # Expand with extra columns
        extra_columns = AddExtraColumns(all_epochs, run_on_init=True)
        all_epochs = extra_columns.get_all_epochs()
        odor_sites = extra_columns.get_odor_sites()

        odor_sites['perceived_reward_probability'] = odor_sites['after_choice_cumulative_rewards'] / (odor_sites['site_number'] + 1)

        # Load the encoder data separately (only the last session's stream remains in scope after the loop)
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data

        # Remove segments where the mouse was disengaged: the first patch with skipped_count >= 3
        # is the last engaged one (all patches if there is none). Patches >= 20 are always kept.
        last_engaged_patch = odor_sites['patch_number'][odor_sites['skipped_count'] >= 3].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = odor_sites['patch_number'].max()
        odor_sites['engaged'] = np.where((odor_sites['patch_number'] <= last_engaged_patch) | (odor_sites['patch_number'] >= 20), 1, 0)

        all_epochs = pd.concat([all_epochs.loc[all_epochs.label != 'OdorSite'], odor_sites]).sort_index()

        session_frames.append(all_epochs)

# A single concat avoids the quadratic cost of growing the frame inside the loop.
summary_df = pd.concat(session_frames) if session_frames else pd.DataFrame()

In [None]:
# NOTE(review): session_n is already assigned per mouse inside the dataset-generation loop;
# the rank-based recomputation below is kept commented out for reference.
# summary_df = summary_df.sort_values(by=['mouse', 'session']).reset_index(drop=True)
# summary_df['session_n'] = summary_df.groupby('mouse')['session'].rank(method='dense').astype(int)
# summary_df['within_session_n'] = summary_df.groupby(['mouse', 'experiment'])['session'].rank(method='dense').astype(int)

batch_csv_path = os.path.join(data_path, 'batch_4.csv')
summary_df.to_csv(batch_csv_path)

### Retrieving data from the calibration

In [None]:
# Reload the torque calibration table (same file as read before dataset generation).
params_df = pd.read_csv(os.path.join(data_path, 'torque_calibration.csv'))

In [None]:
# What was the friction applied if we have the friction of the schema? (We have the friction in the schema, we want the reality)
# The dataset-generation loop already stores, per session,
#   torque_friction = c - quadratic_model(65535 * friction, a, b, c)
# using that session's rig calibration. Both derived columns follow from it directly.
# Duplicate wheels keep their first row, matching the `.values[0]` lookups used elsewhere.
c_by_wheel = params_df.drop_duplicates('wheel').set_index('wheel')['c']
summary_df['torque_difference'] = summary_df['torque_friction']
# Torque difference as a fraction of the rig's calibration coefficient `c` (NaN for uncalibrated rigs).
summary_df['actual_friction'] = summary_df['torque_friction'] / summary_df['rig'].map(c_by_wheel)

In [None]:
# Mean normalised friction (torque difference / calibration coefficient `c` of the rig) per
# session with friction applied, and its lower-tercile boundary.
# NOTE(review): sessions without a schema friction get friction=0 in the loop; filtering on
# torque_friction != 0 assumes quadratic_model(0, a, b, c) == c -- confirm.
c_by_wheel = params_df.drop_duplicates('wheel').set_index('wheel')['c']
new_df = (summary_df.loc[summary_df.torque_friction != 0]
          .assign(actual_friction=lambda d: d['torque_friction'] / d['rig'].map(c_by_wheel))
          .groupby(['mouse', 'session_n', 'experiment'])
          .agg(actual_friction=('actual_friction', 'mean'))
          .reset_index())
np.quantile(new_df.actual_friction, 0.33)

### Parsing the data

In [None]:
## Reclassifying friction sections depending on the actual friction. Adding a column for the average length
c_by_wheel = params_df.drop_duplicates('wheel').set_index('wheel')['c']
post_patch = (summary_df.loc[summary_df.label == 'PostPatch']
              .assign(actual_friction=lambda d: d['torque_friction'] / d['rig'].map(c_by_wheel)))
to_merge = (post_patch.groupby(['mouse', 'session'])
            .agg(length_average=('length', 'mean'),
                 interpatch_friction=('actual_friction', 'max'),
                 torque_friction=('torque_friction', 'max'))
            .reset_index())
# summary_df already carries torque_friction (and, on re-run, the other merged columns);
# drop them first so the merge neither produces `_x`/`_y` suffixes nor piles up columns.
merged_columns = ['length_average', 'interpatch_friction', 'torque_friction']
summary_df = pd.merge(summary_df.drop(columns=merged_columns, errors='ignore'), to_merge, on=['mouse', 'session'])

list_frictions = ['friction', 'friction_15', 'friction_optimized']
summary_df['experiment'] = np.where(summary_df['experiment'].isin(list_frictions), 'friction', summary_df['experiment'])
summary_df['experiment_torque'] = summary_df['experiment']

## Change the experiment name to match the real friction applied
# summary_df['experiment'] = np.where((summary_df['interpatch_friction'] < 8) & (summary_df['experiment'] == 'friction'), 'friction_low', summary_df['experiment'])
# summary_df['experiment'] = np.where((summary_df['interpatch_friction'] >= 13) & (summary_df['experiment'] == 'friction'), 'friction_high', summary_df['experiment'])
# summary_df['experiment'] = np.where((summary_df['interpatch_friction'] < 13) & (summary_df['experiment'] == 'friction'), 'friction_med', summary_df['experiment'])

# Torque bins: (-inf, 120) low, [240, inf) high, [120, 240) med. Rows with NaN torque keep 'friction'.
is_friction = summary_df['experiment_torque'] == 'friction'
torque = summary_df['torque_friction']
summary_df['experiment_torque'] = np.select(
    [is_friction & (torque < 120),
     is_friction & (torque >= 240),
     is_friction & (torque < 240)],
    ['friction_low', 'friction_high', 'friction_med'],
    default=summary_df['experiment_torque'],
)

In [None]:
lists = ['friction_low', 'friction_med', 'friction_high']
# One row per (mouse, experiment, session, experiment_torque) group, restricted to friction experiments
test_df = (summary_df.loc[summary_df.experiment.isin(lists)]
           .groupby(['mouse', 'experiment', 'session', 'experiment_torque'])
           .agg({'torque_friction': 'mean'})
           .reset_index())
# Count those groups for every (experiment, experiment_torque) pair
heatmap_data = test_df.pivot_table(index='experiment', columns='experiment_torque',
                                   values='mouse', aggfunc='count', fill_value=0)

fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(heatmap_data, annot=True, fmt='d', cmap='viridis', ax=ax)
ax.set_title('Heatmap of Experiment and Experiment Torque Allocation')
ax.set_xlabel('Experiment Torque')
ax.set_ylabel('Experiment')
plt.show()