In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

from matplotlib import patches as mpatches
import sys
sys.path.append('../../../src/')

import os

# Plotting libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
import pandas as pd
import numpy as np

from aind_vr_foraging_analysis.utils.plotting import plotting_friction_experiment as f

import warnings
# Silence pandas' chained-assignment check and three warning categories for the whole notebook.
# NOTE(review): these filters are global, so genuine warnings raised by later cells are hidden too.
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Absolute locations on the Z: drive.
# NOTE(review): neither name is referenced in the cells visible in this section — confirm they are still needed.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'


from scipy.interpolate import griddata
from matplotlib.colors import TwoSlopeNorm

from statsmodels.formula.api import glm
from sklearn.preprocessing import StandardScaler

# Modelling libraries
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, roc_curve, auc

# Base colours (hex codes) shared by the plots in this notebook.
color1 = '#d95f02'
color2 = '#1b9e77'
color3 = '#7570b3'
color4 = '#e7298a'

# The first three colours, available both as a list and keyed by integer index.
odor_list_color = [color1, color2, color3]
color_dict = dict(enumerate(odor_list_color))

# Label -> colour lookup; several labels share the same colour.
color_dict_label = {
    'Ethyl Butyrate': color1,
    'Alpha-pinene': color2,
    'Amyl Acetate': color3,
    '2-Heptanone': color2,
    'Methyl Acetate': color1,
    'Fenchone': color3,
    '2,3-Butanedione': color4,
    'Methyl Butyrate': color1,
}

# Define exponential function
def exponential_func(x, a, b):
    """Evaluate the exponential curve ``a * exp(b * x)``.

    Parameters
    ----------
    x : scalar or array-like accepted by ``np.exp``
        Point(s) at which the curve is evaluated.
    a : scalar
        Multiplicative amplitude.
    b : scalar
        Rate inside the exponent.

    Returns
    -------
    The value(s) of ``a * np.exp(b * x)``.
    """
    growth = np.exp(b * x)
    return a * growth

def format_func(value, tick_number):
    """Render ``value`` as a string with no decimal places.

    ``tick_number`` is accepted but not used.
    NOTE(review): the two-argument signature looks like a matplotlib tick-formatter
    callback — confirm against the caller.
    """
    return format(value, ".0f")

# Input and output locations used by the cells below:
#   velocity_path - folder holding the velocity/torque/duration summary CSV
#   data_path     - folder holding the batch-4 session and mouse CSVs (relative path)
#   results_path  - folder where the heatmap PDFs are written
# NOTE(review): velocity_path and results_path are absolute paths inside one user's OneDrive;
# the notebook cannot run on another machine without editing them.
velocity_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'
data_path = r'../../../data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\Meeting presentations\SAC\SAC2025-May\figures'

# Experiment name -> plotting colour (insertion order is preserved).
palette = dict(
    # Control sessions
    control='grey',
    # Friction manipulations: purple shades, darkest for the highest friction
    friction_high='#6a51a3',
    friction_med='#807dba',
    friction_low='#9e9ac8',
    # Shortened distances: red/pink
    distance_extra_short='crimson',
    distance_short='pink',
    # Lengthened distances: orange shades
    distance_extra_long='#fd8d3c',
    distance_long='#fdae6b',
)

sns.set_context('talk')

### **Explore relationship between torque, distance and time**

In [None]:
## This dataset needs to be obtained from ANALYSYS_velocity_traces
sum_df = pd.read_csv(
    os.path.join(velocity_path, 'batch4_velocity_torque_duration_summary.csv'),
    index_col=0,
)

# Experiments kept for this analysis
list_experiments = ['control', 'friction_med', 'friction_low', 'friction_high', 'distance_short', 'distance_long', 'distance_extra_short', 'distance_extra_long']

# Keep only those experiments, round the torque to 2 decimals and store mouse ids as integers
sum_df = (
    sum_df[sum_df['experiment'].isin(list_experiments)]
    .assign(
        torque_friction=lambda d: d['torque_friction'].round(2),
        mouse=lambda d: d['mouse'].astype(int),
    )
)

# Re-base the session counter so that each mouse's first retained session is 0
sum_df['session_n'] = sum_df['session_n'] - sum_df.groupby('mouse')['session_n'].transform('min')

In [None]:
# Sorted unique values of the 'length' column (the distances in the dataset)
distances = np.sort(sum_df['length'].unique())

# One tab20 colour per distinct distance
custom_palette = sns.color_palette("tab20", len(distances))

# Lookup table: distance -> colour
distance_color_map = dict(zip(distances, custom_palette))

In [None]:
## These datasets have to be obtained from ANALYSIS_friction_and_distance
session_df = pd.read_csv(data_path + 'batch_4_session_df.csv', index_col=0)

# Keep the experiments of interest and re-base the session number so each mouse starts at 0
session_df = session_df.loc[session_df.experiment.isin(list_experiments)].copy()
session_df['session_n'] = session_df.groupby('mouse')['session_n'].transform(lambda x: x - x.min())

mouse_df = pd.read_csv(data_path + 'batch_4_mouse_df.csv', index_col=0)
mouse_df = mouse_df.loc[mouse_df.experiment.isin(list_experiments)]

# 'session_n' is not normalised here: it is dropped immediately, so the work would be discarded.
# NOTE(review): these columns are presumably removed so the merge with sum_df does not
# create suffixed duplicate columns — confirm.
mouse_df = mouse_df.drop(columns=['session_n', 'experiment', 'friction'])

In [None]:
# Keys shared by the two tables
group_list = ['mouse', 'session', 'active_patch']

# Inner join: only rows whose (mouse, session, active_patch) exist in both tables survive
sum_df = pd.merge(sum_df, mouse_df, how='inner', on=group_list)

In [None]:

def _control_baseline(df, value_col, session_stat):
    """Per-mouse baseline of ``value_col`` taken from control sessions.

    For every mouse, ``session_stat`` (e.g. 'median' or 'max') is computed within each
    control session, and those per-session values are then averaged.

    Returns a Series indexed by mouse. Mice without control sessions are absent from it,
    so mapping them yields NaN.
    """
    control_rows = df.loc[df['experiment'] == 'control']
    per_session = control_rows.groupby(['mouse', 'session_n'])[value_col].agg(session_stat)
    return per_session.groupby(level='mouse').mean()


def _dense_rank(values):
    """Return the zero-based rank of each entry among the sorted unique values of ``values``."""
    levels = np.sort(values.unique())
    codes = pd.factorize(levels)[0]
    return codes[np.searchsorted(levels, values)]


# Epoch duration relative to the mouse's control baseline
# (mean over control sessions of the per-session median).
sum_df['normalized_epoch_duration'] = (
    sum_df['epoch_duration']
    / sum_df['mouse'].map(_control_baseline(sum_df, 'epoch_duration', 'median'))
)

# Torque relative to the mouse's control baseline
# (mean over control sessions of the per-session maximum).
sum_df['normalized_torque_friction'] = (
    sum_df['torque_friction']
    / sum_df['mouse'].map(_control_baseline(sum_df, 'torque_friction', 'max'))
)

# Rank torque values per mouse
sum_df["torque_norm"] = sum_df.groupby("mouse")["torque_friction"].transform(_dense_rank)

# Rank epoch-duration values per mouse
sum_df["duration_norm"] = sum_df.groupby("mouse")["epoch_duration"].transform(_dense_rank)

In [None]:
# Per (session_n, experiment, mouse): mean torque rank and median reward probability
(
    sum_df
    .groupby(['session_n', 'experiment', 'mouse'])
    .agg(
        torque_norm=('torque_norm', 'mean'),
        reward_probability=('reward_probability', 'median'),
    )
    .reset_index()
)

In [None]:
# Column names used by the heatmap cells below.
duration_label = 'epoch_duration'  # column plotted as the colour value in the travel-time heatmap
torque_label = 'torque_norm'  # per-mouse torque rank column, used as the heatmap y-axis

In [None]:
# sum_df = sum_df.loc[sum_df['mouse'] != 754579]

**How does the velocity change depending on the inserted torque and the distance in the session?**

In [None]:
# Per-mouse heatmaps of average speed over the (distance, torque) plane.
# Each panel's colour scale is centred on that mouse's control-session speed.
mice = sum_df.mouse.unique()
n_cols = 3
# Keep the 4x3 layout, but add rows instead of silently dropping mice beyond the 12th.
n_rows = max(4, int(np.ceil(len(mice) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))

with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_velocity.pdf')) as pdf:
    for mouse, ax in zip(mice, axes.flatten()):
        mouse_rows = sum_df.loc[sum_df.mouse == mouse]

        # One point per session: mean distance, mean torque rank and mean speed
        loop_df = mouse_rows.groupby('session_n').agg({'length':'mean', torque_label:'mean', 'speed_average':'mean'}).reset_index()

        # Mean over control sessions of the per-session mean speed
        control_speed = (
            mouse_rows.loc[mouse_rows.experiment == 'control']
            .groupby('session_n')['speed_average'].mean()
            .mean()
        )

        distance = loop_df['length'].values        # x coordinate of each session
        torque = loop_df[torque_label].values      # y coordinate of each session
        speed = loop_df['speed_average'].values    # value interpolated over the plane

        # Regular 50x50 grid spanning the sessions of this mouse
        X, Y = np.meshgrid(np.linspace(distance.min(), distance.max(), 50), np.linspace(torque.min(), torque.max(), 50))

        # Linear interpolation of speed onto the grid (NaN outside the convex hull of the sessions)
        Z = griddata((distance, torque), speed, (X, Y), method='linear')

        vmin = np.nanmin(Z)
        vmax = np.nanmax(Z)
        vcenter = control_speed
        if np.isnan(vcenter):
            # Mouse without control sessions: fall back to the mean of the interpolated surface
            vcenter = np.nanmean(Z)
        # TwoSlopeNorm requires vmin < vcenter < vmax; widen the range minimally when the
        # centre falls on or outside the interpolated range instead of raising ValueError.
        pad = 1e-6 * max(abs(vcenter), 1.0)
        if not vmin < vcenter:
            vmin = vcenter - pad
        if not vmax > vcenter:
            vmax = vcenter + pad

        norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

        heatmap = ax.contourf(X, Y, Z, levels=100, cmap='coolwarm', norm=norm)
        cbar = fig.colorbar(heatmap, ax=ax)

        # Label only the first, middle and last contour level on the colorbar
        levels = heatmap.levels
        selected_ticks = [levels[0], levels[len(levels) // 2], levels[-1]]
        cbar.set_ticks(selected_ticks)
        cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])

        ax.set_xlabel("Distance (cm)")
        ax.set_ylabel("Torque (a.u.)")
        ax.set_title(mouse)
        cbar.set_label("Velocity (cm/s)")

    # Hide the panels that did not receive a mouse
    for unused_ax in axes.flatten()[len(mice):]:
        unused_ax.set_visible(False)

    sns.despine(fig=fig)
    fig.tight_layout()
    # Save before showing so the output does not depend on how the backend treats shown figures
    pdf.savefig(fig)
    plt.show()


**How does the time it takes to travel change depending on the torque and the distance manipulation?**

In [None]:
# Per-mouse heatmaps of the duration metric (duration_label) over the (distance, torque) plane.
# Each panel's colour scale is centred on the mean of that mouse's per-session values.
mice = sum_df.mouse.unique()
n_cols = 3
# Keep the 4x3 layout, but add rows instead of silently dropping mice beyond the 12th.
n_rows = max(4, int(np.ceil(len(mice) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))

with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_time.pdf')) as pdf:
    for mouse, ax in zip(mice, axes.flatten()):
        # One point per session: mean distance, mean torque rank and mean duration
        loop_df = sum_df.loc[sum_df.mouse == mouse].groupby('session_n').agg({'length':'mean', torque_label:'mean', duration_label:'mean'}).reset_index()

        distance = loop_df['length'].values         # x coordinate of each session
        torque = loop_df[torque_label].values       # y coordinate of each session
        duration = loop_df[duration_label].values   # value interpolated over the plane

        # Regular 50x50 grid spanning the sessions of this mouse
        X, Y = np.meshgrid(np.linspace(distance.min(), distance.max(), 50), np.linspace(torque.min(), torque.max(), 50))

        # Linear interpolation of duration onto the grid (NaN outside the convex hull of the sessions)
        Z = griddata((distance, torque), duration, (X, Y), method='linear')

        vmin = np.nanmin(Z)
        vmax = np.nanmax(Z)
        vcenter = np.mean(duration)
        if np.isnan(vcenter):
            # A NaN session value makes the plain mean NaN: fall back to the interpolated surface
            vcenter = np.nanmean(Z)
        # TwoSlopeNorm requires vmin < vcenter < vmax; widen the range minimally when the
        # centre falls on or outside the interpolated range instead of raising ValueError.
        pad = 1e-6 * max(abs(vcenter), 1.0)
        if not vmin < vcenter:
            vmin = vcenter - pad
        if not vmax > vcenter:
            vmax = vcenter + pad

        norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

        heatmap = ax.contourf(X, Y, Z, levels=100, cmap='coolwarm', norm=norm)
        cbar = fig.colorbar(heatmap, ax=ax)

        # Label only the first, middle and last contour level on the colorbar
        levels = heatmap.levels
        selected_ticks = [levels[0], levels[len(levels) // 2], levels[-1]]
        cbar.set_ticks(selected_ticks)
        cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])

        ax.set_xlabel("Distance (cm)")
        ax.set_ylabel("Torque (a.u.)")
        ax.set_title(mouse)
        cbar.set_label("Duration (seconds)")

    # Hide the panels that did not receive a mouse
    for unused_ax in axes.flatten()[len(mice):]:
        unused_ax.set_visible(False)

    sns.despine(fig=fig)
    fig.tight_layout()
    # Save before showing so the output does not depend on how the backend treats shown figures
    pdf.savefig(fig)
    plt.show()


**How does P(reward) when leaving change depending on the torque and the distance manipulation?**

In [None]:
# Distinct distances seen in each experiment.
# Built from sum_df: on a top-to-bottom run, loop_df here is the session-aggregated frame
# from the previous cell, which has no 'experiment' column, so grouping it raises KeyError.
sum_df.groupby('experiment')['length'].unique()

In [None]:
# Distinct per-mouse torque ranks seen in each experiment.
# Built from sum_df: on a top-to-bottom run, loop_df here is the session-aggregated frame
# from the travel-time heatmap cell, which has no 'experiment' column, so grouping it raises KeyError.
sum_df.groupby('experiment')['torque_norm'].unique()

In [None]:
# Per-mouse heatmaps of reward probability over the (distance, torque) plane, expressed
# as the difference from that mouse's control sessions (symmetric colour scale around 0).
mice = sum_df.mouse.unique()
n_cols = 3
# Keep the 4x3 layout, but add rows instead of silently dropping mice beyond the 12th.
n_rows = max(4, int(np.ceil(len(mice) / n_cols)))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))

with PdfPages(os.path.join(results_path, 'batch4_heatmap_distance_torque_preward.pdf')) as pdf:
    for mouse, ax in zip(mice, axes.flatten()):
        mouse_rows = sum_df.loc[sum_df.mouse == mouse]

        # One point per (session, experiment): mean distance, torque rank and reward probability
        loop_df = mouse_rows.groupby(['session', 'experiment']).agg({'length':'mean', torque_label:'mean', 'reward_probability':'mean'}).reset_index()

        # Mean over control sessions of the per-session mean reward probability
        control_preward = (
            mouse_rows.loc[mouse_rows.experiment == 'control']
            .groupby('session')['reward_probability'].mean()
            .mean()
        )

        loop_df['reward_probability_centered'] = loop_df['reward_probability'] - control_preward

        distance = loop_df['length'].values                         # x coordinate of each session
        torque = loop_df[torque_label].values                       # y coordinate of each session
        centered = loop_df['reward_probability_centered'].values    # value interpolated over the plane

        # Sanity check: the control sessions should average to ~0 after centring
        print(loop_df.loc[loop_df.experiment == 'control','reward_probability_centered'].mean())

        # Regular 50x50 grid spanning the sessions of this mouse
        X, Y = np.meshgrid(np.linspace(distance.min(), distance.max(), 50), np.linspace(torque.min(), torque.max(), 50))

        # Linear interpolation onto the grid (NaN outside the convex hull of the sessions)
        Z = griddata((distance, torque), centered, (X, Y), method='linear')

        # Symmetric limits around zero. TwoSlopeNorm needs vmin < 0 < vmax, so fall back to
        # a tiny range when the surface is identically zero or NaN instead of raising.
        vmax = np.nanmax(np.abs(Z))
        if not vmax > 0:
            vmax = 1e-6
        vmin = -vmax
        vcenter = 0
        norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)

        heatmap = ax.contourf(X, Y, Z, levels=100, cmap='coolwarm', norm=norm)
        cbar = fig.colorbar(heatmap, ax=ax)

        # Label only the first and last contour level on the colorbar
        levels = heatmap.levels
        selected_ticks = [levels[0], levels[-1]]
        cbar.set_ticks(selected_ticks)
        cbar.set_ticklabels([f"{tick:.2f}" for tick in selected_ticks])

        ax.set_xlabel("Distance (cm)")
        ax.set_ylabel("Torque (a.u.)")
        ax.set_title(mouse)
        cbar.set_label("P(reward)")

    # Hide the panels that did not receive a mouse
    for unused_ax in axes.flatten()[len(mice):]:
        unused_ax.set_visible(False)

    sns.despine(fig=fig)
    fig.tight_layout()
    # Save before showing so the output does not depend on how the backend treats shown figures
    pdf.savefig(fig)
    plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from scipy.interpolate import griddata

# === Settings ===
z_var = 'reward_probability'   # Column mapped to colour: 'reward_probability', 'epoch_duration', etc.
metric = 'median' if z_var == 'epoch_duration' else 'mean'   # Per-session aggregation of z_var
normalize = True               # Rescale each mouse's per-session values to [0, 1]
subtract_control = True        # Shift each mouse's values so its control mean sits at 0

# Common grid shared by all mice so the interpolated maps can be averaged
x_vals = np.linspace(sum_df['length'].min(), sum_df['length'].max(), 50)
y_vals = np.linspace(sum_df[torque_label].min(), sum_df[torque_label].max(), 50)
X, Y = np.meshgrid(x_vals, y_vals)

Z_list = []

# The loop variable is deliberately not called `mouse_df`: that name holds the table loaded
# earlier in the notebook and must not be overwritten by the last group.
for mouse_id, mouse_rows in sum_df.groupby("mouse"):
    # Mean of z_var over this mouse's control rows (mask taken from the group itself)
    control_mean = mouse_rows.loc[mouse_rows['experiment'] == 'control', z_var].mean()

    grouped = mouse_rows.groupby("session_n").agg({
        'length': 'mean',
        torque_label: 'mean',
        z_var: metric
    }).reset_index()

    z = grouped[z_var].values

    if normalize:
        z_min, z_max = np.nanmin(z), np.nanmax(z)
        if z_max - z_min == 0 or np.isnan(z_min) or np.isnan(z_max):
            continue  # Skip invalid or constant
        z = (z - z_min) / (z_max - z_min)

        if subtract_control:
            control_scaled = (control_mean - z_min) / (z_max - z_min)
            z = z - control_scaled  # Align control to 0 in normalized scale
    else:
        if np.any(np.isnan(z)):
            continue
        if subtract_control:
            z = z - control_mean  # Raw values shifted by this mouse's control mean

    x = grouped['length'].values
    y = grouped[torque_label].values

    Z_interp = griddata((x, y), z, (X, Y), method='linear')
    Z_list.append(Z_interp)

if not Z_list:
    raise ValueError(f"No mouse produced a usable '{z_var}' map; nothing to average.")

# Average heatmap across mice (grid points outside a mouse's convex hull are NaN and ignored)
Z_stack = np.stack(Z_list)
Z_avg = np.nanmean(Z_stack, axis=0)

# Plotting
fig, ax = plt.subplots(figsize=(6, 5))

vmin = np.nanmin(Z_avg)
vmax = np.nanmax(Z_avg)

if subtract_control:
    vcenter = 0
    # TwoSlopeNorm requires vmin < vcenter < vmax; widen minimally when the averaged map
    # does not straddle zero instead of raising ValueError.
    pad = 1e-6
    if not vmin < vcenter:
        vmin = vcenter - pad
    if not vmax > vcenter:
        vmax = vcenter + pad
    norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
else:
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

heatmap = ax.contourf(X, Y, Z_avg, levels=50, cmap='coolwarm', norm=norm)
cbar = fig.colorbar(heatmap, ax=ax)

# Label depends on toggles
label = z_var.replace('_', ' ').capitalize()
if normalize:
    label = "Normalized " + label
if subtract_control:
    label += "\n (centered on control)"

cbar.set_label(label)

# Reduce ticks by half
ticks = cbar.get_ticks()
cbar.set_ticks(ticks[::2])
cbar.set_ticklabels([f"{tick:.2f}" for tick in ticks[::2]])

# NOTE(review): the y-limits are hard-coded — confirm 6 covers the largest torque rank.
ax.set_ylim(-0.1, 6)
ax.set_xlabel("Distance (cm)")
ax.set_ylabel("Torque (a.u.)")
ax.set_title(f"Heatmap of {label}")
sns.despine(fig=fig)
fig.tight_layout()
# Save before showing so the output does not depend on how the backend treats shown figures
fig.savefig(os.path.join(results_path, f'batch4_heatmap_distance_torque_{z_var}_norm{normalize}_subscontrol{subtract_control}.pdf'), bbox_inches='tight')
plt.show()
