In [None]:
# IPython magic tools: automatically reload edited project modules
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import time

import seaborn as sns
import pandas as pd

import datetime
from aind_vr_foraging_analysis.utils.plotting import plotting_friction_experiment as f
from aind_vr_foraging_analysis.utils.parsing import  parse, AddExtraColumns
import aind_vr_foraging_analysis.utils.plotting as plotting

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:/scratch/vr-foraging/data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'
data_path = r'../../../data/'

## **Parse speed, torque and epoch duration for different sessions and animals**

In [None]:
# Only sessions recorded on or after this date are included in the summary.
date_string = "08/28/2024"
date = datetime.datetime.strptime(date_string, "%m/%d/%Y").date()

# Per-wheel torque calibration coefficients; the session loop reads the columns
# `wheel`, `a`, `b` and `c`. os.path.join works whether or not data_path ends
# with a separator.
params_df = pd.read_csv(os.path.join(data_path, 'torque_calibration.csv'))

### **Build and save the per-patch summary**

In [None]:
# Build one summary row per (mouse, session, patch) group:
#   1. load every session recorded on or after `date`;
#   2. label it by the effective friction torque;
#   3. collect trial-aligned speed and torque traces around the travel epoch;
#   4. average them.
# Sessions that fail to load or parse are reported and skipped, not silently dropped.

FRICTION_STAGES = ('friction', 'friction_15', 'friction_optimized')


def _friction_torque(friction, wheel):
    """Convert the schema friction value into the torque applied on `wheel`.

    Uses the per-wheel quadratic calibration in `params_df` (columns a, b, c).
    The returned value is ``c - quadratic_model(65535 * friction, a, b, c)``.

    Raises:
        ValueError: if `params_df` has no calibration row for `wheel`.
    """
    calibration = params_df.loc[params_df.wheel == wheel]
    if calibration.empty:
        raise ValueError(f'No torque calibration for wheel {wheel!r} in params_df')
    a = calibration.a.values[0]
    b = calibration.b.values[0]
    c = calibration.c.values[0]
    resolved_torque = f.quadratic_model(65535 * friction, a, b, c)
    return c - resolved_torque


def _classify_friction(experiment, torque_friction):
    """Split friction-stage sessions into low/med/high bins by applied torque.

    Stages listed in FRICTION_STAGES are mapped to:
      - 'friction_low' when torque < 120;
      - 'friction_med' when 120 <= torque < 240;
      - 'friction_high' otherwise (including NaN).
    Any other stage name is returned unchanged.
    """
    if experiment not in FRICTION_STAGES:
        return experiment
    if torque_friction < 120:
        return 'friction_low'
    # 120 itself belongs to the medium bin; the low bin is strictly below 120.
    if torque_friction < 240:
        return 'friction_med'
    return 'friction_high'


# List each mouse only once; a repeated ID would duplicate its rows in sum_df.
mice = ['754570', '754579', '754567', '754580', '754559', '754560', '754577', '754566',
        '754571', '754574', '754575', '754582', '745302', '745305', '745301']
group_list = ['mouse', 'session', 'experiment', 'friction', 'torque_friction', 'patch_number', 'wheel']

summaries = []
for mouse in mice:
    print(mouse)
    active_site_list = []
    velocity_list = []
    torque_list = []

    directory = os.path.join(base_path, mouse)
    # Visit session folders oldest-first by filesystem creation time.
    sorted_files = sorted(os.listdir(directory),
                          key=lambda name: os.path.getctime(os.path.join(directory, name)))

    for file_name in sorted_files:
        # Characters [-15:-7] of the folder name hold the session date as YYYYMMDD.
        session = file_name[-15:-7]
        try:
            session_date = datetime.datetime.strptime(session, "%Y%m%d").date()
        except ValueError:
            print(f'  skipping {file_name}: no YYYYMMDD date in name')
            continue
        if session_date < date:
            continue

        # Recover data streams.
        session_path = Path(directory) / file_name
        try:
            data = parse.load_session_data(session_path)
        except Exception as exc:
            print(f'  skipping {file_name}: failed to load ({exc!r})')
            continue

        task_logic = data['config'].streams.tasklogic_input.data
        experiment = task_logic['stage_name']
        try:
            friction = task_logic['task_parameters']['environment_statistics']['patches'][0][
                'virtual_site_generation']['post_patch']['treadmill_specification']['friction'][
                'distribution_parameters']['value']
        except TypeError:
            # Presumably some level of this path is None when no friction was configured.
            friction = 0
        wheel = data['config'].streams.rig_input.data['rig_name']

        # The schema stores the requested friction; convert it to the torque actually applied.
        torque_friction = _friction_torque(friction, wheel)
        experiment = _classify_friction(experiment, torque_friction)
        print(experiment, torque_friction)

        # Parse data into a dataframe with the main features.
        try:
            all_epochs = parse.parse_dataframe(data)
        except Exception as exc:
            print(f'  skipping {file_name}: failed to parse ({exc!r})')
            continue
        if all_epochs.empty:
            continue

        # Expand with extra columns, then load the encoder data separately.
        active_site = AddExtraColumns(all_epochs).get_all_epochs
        stream_data = parse.ContinuousData(data)
        encoder_data = stream_data.encoder_data

        interpatch_lengths = active_site.loc[active_site.label == 'InterPatch', 'length'].unique()
        if len(interpatch_lengths) == 0:
            print(f'  skipping {file_name}: no InterPatch epochs')
            continue
        # When the InterPatch length is 50, the PostPatch epoch is analysed instead
        # (TODO confirm rationale). patch_number is shifted up one row, so each kept
        # PostPatch row carries the following row's patch_number.
        if interpatch_lengths[0] == 50:
            section = 'PostPatch'
            active_site['patch_number'] = active_site['patch_number'].shift(-1)
        else:
            section = 'InterPatch'

        # Each epoch lasts until the next row's index value.
        active_site['end_epoch'] = active_site.index.to_series().shift(-1)
        active_site['epoch_duration'] = active_site['end_epoch'] - active_site.index

        new_active_site = active_site[active_site['label'] == section].assign(
            experiment=experiment,
            torque_friction=torque_friction,
            friction=friction,
            mouse=mouse,
            session=session,
            wheel=wheel,
        )
        active_site_list.append(new_active_site)

        velocity = plotting.trial_collection(new_active_site, encoder_data, mouse, session,
                                             window=[-1, 2], cropped_to_length='epoch',
                                             taken_col='filtered_velocity')
        if velocity.empty:
            continue
        velocity_list.append(velocity)

        torque = plotting.trial_collection(new_active_site, stream_data.torque_data, mouse, session,
                                           window=[-2, 10], cropped_to_length='epoch',
                                           taken_col=['Torque'])
        torque_list.append(torque)

    # pd.concat([]) raises ValueError, so skip mice with no usable session after `date`.
    # A non-empty velocity_list implies the other two lists are non-empty as well.
    if not velocity_list:
        print(f'  no usable sessions for {mouse} on or after {date}')
        continue

    cum_active_site = pd.concat(active_site_list)
    cum_velocity = pd.concat(velocity_list)
    cum_torque = pd.concat(torque_list)

    # `times` is presumably relative to epoch onset (same units as `window`); TODO confirm.
    torque_interpatch = (
        cum_torque.loc[cum_torque['times'] > 0]
        .groupby(group_list).Torque.mean().reset_index()
        .rename(columns={'Torque': 'torque_interpatch'})
    )
    torque_baseline = (
        cum_torque.loc[(cum_torque['times'] < 0) & (cum_torque['times'] > -2)]
        .groupby(group_list).Torque.mean().reset_index()
        .rename(columns={'Torque': 'torque_baseline'})
    )
    speed_average = (
        cum_velocity.loc[cum_velocity['times'] > 0]
        .groupby(group_list).speed.mean().reset_index()
        .rename(columns={'speed': 'speed_average'})
    )
    epoch_stats = (
        cum_active_site.groupby(group_list)
        .agg({"epoch_duration": "mean", "length": "mean"}).reset_index()
    )
    # Inner merges: only groups present in every table survive.
    acc_df = (
        torque_interpatch
        .merge(torque_baseline, on=group_list)
        .merge(speed_average, on=group_list)
        .merge(epoch_stats, on=group_list)
    )
    summaries.append(acc_df)

# Concatenate once at the end instead of re-copying sum_df on every mouse.
sum_df = pd.concat(summaries) if summaries else pd.DataFrame()

In [None]:
# Order rows by mouse, then by session string (YYYYMMDD, so lexical order is
# chronological). Number each mouse's sessions 1, 2, ... with a dense rank, so all
# patches from one session share a number. Then write the summary table.
summary_csv = os.path.join(results_path, 'batch4_velocity_torque_duration_summary.csv')
sum_df = sum_df.sort_values(by=['mouse', 'session'])
sum_df = sum_df.reset_index(drop=True)
session_rank = sum_df.groupby('mouse')['session'].rank(method='dense')
sum_df['session_n'] = session_rank.astype(int)
sum_df.to_csv(summary_csv)