In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access
from aind_vr_foraging_analysis.utils.plotting import general_plotting_utils as plotting, plotting_friction_experiment as f

# Plotting libraries
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.backends.backend_pdf import PdfPages

import seaborn as sns
import pandas as pd
import numpy as np

# Use seaborn's 'talk' plotting context for every figure in the notebook
sns.set_context('talk')

import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
# Silence FutureWarning, UserWarning and RuntimeWarning for the whole notebook
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Base palette; color1-3 are reused by color_dict_label below
color1='#d95f02'
color2='#1b9e77'
color3='#7570b3'
color4='yellow'

# NOTE(review): pdf_path, base_path and results_path are absolute Windows paths tied to one
# machine/user; data_path is relative to the notebook's working directory.
# pdf_path and base_path are not referenced by any other visible cell.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'

# Label -> colour lookup passed as `palette=` to the seaborn plots below.
# Keys cover site types, patch names, odor names and the recoded condition labels
# ('90', '60', '0', 'slow', 'fast').
color_dict_label = {'InterSite': '#808080',
    'InterPatch': '#b3b3b3', 
    'PatchZ': '#d95f02', 'PatchZB': '#d95f02', 
    'PatchB': '#d95f02','PatchA': '#1b9e77', 
    'PatchC': '#7570b3', 
    'Alpha-pinene': '#1b9e77', 
    'Methyl Butyrate': '#7570b3', 
    'Amyl Acetate': '#d95f02', 
    'Fenchone': '#7570b3', 
    'patch_single': color1,
    'patch_delayed': color2,
    'patch_no_reward': color3,
    '90': color1,
    '60': color2,
    '0': color3,
    'slow': color1,
    'fast': color2,
    }

# color_dict_label already defines 'InterSite' and 'InterPatch' with these same values,
# so label_dict ends up with the same key/value pairs as color_dict_label.
label_dict = {**{
"InterSite": '#808080',
"InterPatch": '#b3b3b3'}, 
            **color_dict_label}

## Theory of the decay rates and reward rate

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Simulation settings
n_trials = 20         # trials simulated per patch visit
n_simulations = 1000  # simulated patch visits per condition
travel_time = 5       # in trials (not referenced by the other visible cells)

# Patch types: identical decay rate, different starting reward probability
conditions = [
    dict(label="Decay 0.12, p0=0.9", p0=0.9, decay_rate=0.12, color="orange"),
    dict(label="Decay 0.12, p0=0.6", p0=0.6, decay_rate=0.12, color="green"),
]

def simulate_patch(p0, decay_rate, n_trials, n_simulations, rng=None):
    """Simulate repeated visits to a patch whose reward probability depletes.

    On every trial a reward is delivered with the current probability ``p``.
    After each rewarded trial ``p`` is multiplied by ``exp(-decay_rate)``;
    unrewarded trials leave ``p`` unchanged.

    Parameters
    ----------
    p0 : float
        Reward probability at the start of every simulated visit.
    decay_rate : float
        Exponent of the multiplicative depletion applied after each reward.
    n_trials : int
        Number of trials per simulated visit.
    n_simulations : int
        Number of independent simulated visits.
    rng : numpy.random.Generator, optional
        Source of random numbers. When omitted, the legacy global NumPy state
        (``np.random.rand``) is used, exactly as before, so ``np.random.seed``
        still controls reproducibility.

    Returns
    -------
    reward_rates : ndarray, shape (n_simulations, n_trials)
        Cumulative rewards divided by the number of trials so far.
    p_values : ndarray, shape (n_simulations, n_trials)
        Reward probability recorded AFTER the outcome of that trial has been
        applied (i.e. the probability in effect for the following trial).
    mean, low, high : ndarray, shape (n_trials,)
        Mean and 2.5th / 97.5th percentiles of ``reward_rates`` across simulations.
    """
    draw = np.random.rand if rng is None else rng.random
    depletion = np.exp(-decay_rate)  # loop-invariant multiplicative decay

    reward_rates = np.zeros((n_simulations, n_trials))
    p_values = np.zeros((n_simulations, n_trials))  # track true p(reward)

    for sim in range(n_simulations):
        p = p0
        n_rewards = 0  # running count replaces re-running cumsum on every trial
        for trial in range(n_trials):
            if draw() < p:
                n_rewards += 1
                p *= depletion
            reward_rates[sim, trial] = n_rewards / (trial + 1)
            p_values[sim, trial] = p

    mean = reward_rates.mean(axis=0)
    low = np.percentile(reward_rates, 2.5, axis=0)
    high = np.percentile(reward_rates, 97.5, axis=0)

    return reward_rates, p_values, mean, low, high

# Step 1: run the simulation once per patch type and keep the outputs next to
# the condition's own settings (label, p0, decay_rate, color)
reward_data = []
for cond in conditions:
    rates, p_values, mean, low, high = simulate_patch(
        cond["p0"], cond["decay_rate"], n_trials, n_simulations
    )
    entry = {"rates": rates, "p_values": p_values, "mean": mean, "low": low, "high": high}
    entry.update(cond)
    reward_data.append(entry)

all_means = [entry["mean"] for entry in reward_data]

# Step 2: environmental average = per-trial mean of the conditions' mean reward-rate curves
env_avg = np.mean(all_means, axis=0)

# Step 3 & 4: Apply MVT and record leave info
def apply_mvt_and_get_stats(rates, p_values, env_avg):
    """Find, for every simulation, the first trial whose rate drops below ``env_avg``.

    Parameters
    ----------
    rates, p_values : sequence of sequences
        One row per simulation; rows are indexed by trial.
    env_avg : sequence
        Per-trial threshold compared against ``rates`` trial by trial.

    Returns
    -------
    leave_times : list
        1-based trial at which the rate first fell below the threshold, or the
        number of trials when it never did.
    leave_reward_rates, leave_p_values : list
        Values of ``rates`` / ``p_values`` at that trial (last trial when the
        threshold was never crossed).
    """
    leave_times, leave_reward_rates, leave_p_values = [], [], []

    for sim_rate, sim_p in zip(rates, p_values):
        n_steps = len(sim_rate)
        crossing = next(
            (step for step in range(n_steps) if sim_rate[step] < env_avg[step]),
            None,
        )
        if crossing is None:
            # Never dropped below the threshold: report the final trial
            trial_left, at = n_steps, -1
        else:
            trial_left, at = crossing + 1, crossing

        leave_times.append(trial_left)
        leave_reward_rates.append(sim_rate[at])
        leave_p_values.append(sim_p[at])

    return leave_times, leave_reward_rates, leave_p_values

# Apply the leaving rule to each condition and store the per-simulation results
# together with one summary number per quantity
for data in reward_data:
    leave_times, leave_rates, leave_ps = apply_mvt_and_get_stats(
        data["rates"], data["p_values"], env_avg
    )
    data.update({
        "leave_times": leave_times,
        "mean_leave_time": np.median(leave_times),  # NOTE: holds the median despite the key name
        "leave_reward_rates": leave_rates,
        "mean_leave_rate": np.mean(leave_rates),
        "leave_p_values": leave_ps,
        "mean_leave_p": np.mean(leave_ps),
    })

# Step 5: Print summary (one paragraph per condition, followed by a blank line)
for data in reward_data:
    summary_lines = [
        f"{data['label']}:",
        f"  Median leave trial: {data['mean_leave_time']:.1f}",
        f"  Mean reward rate at leaving: {data['mean_leave_rate']:.3f}",
        f"  Mean p(reward) at leaving: {data['mean_leave_p']:.3f}",
        "",
    ]
    print("\n".join(summary_lines))

# Plot 1: Reward curves and MVT threshold
plt.figure(figsize=(12, 6))
t = np.arange(1, n_trials + 1)

for data in reward_data:
    plt.plot(t, data["mean"], label=data["label"], color=data["color"])
    # Shaded band: 2.5th-97.5th percentile of the simulated reward rates
    plt.fill_between(t, data["low"], data["high"], color=data["color"], alpha=0.2)

plt.plot(t, env_avg, label="MVT Threshold (Env Avg)", color="black", linestyle="--")
plt.title("Reward Curves and MVT Threshold")
plt.xlabel("Trial")
plt.ylabel("Reward Rate")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Plot 2: Histogram of leave times
plt.figure(figsize=(10, 5))
# One bin per possible leave trial (1..n_trials), centred on the integer trial number.
# Derived from n_trials so the histogram never silently drops late leave times.
leave_bins = np.arange(0.5, n_trials + 1.5, 1)
for data in reward_data:
    plt.hist(data["leave_times"], bins=leave_bins, alpha=0.6, color=data["color"],
             label=f"{data['label']} (med leave={data['mean_leave_time']:.1f})")
    # Mark this condition's median leave trial. The previous guide line drew
    # np.mean(env_avg) here, which is a reward rate and not a value on this trial axis.
    plt.axvline(data["mean_leave_time"], color=data["color"], linestyle="--")

plt.title("Distribution of Patch Leaving Times (MVT)")
plt.xlabel("Trial at Leaving")
plt.ylabel("Number of Simulations")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Plot 3: Histogram of reward rates at leaving
plt.figure(figsize=(10, 5))
for data in reward_data:
    plt.hist(data["leave_reward_rates"], bins=30, alpha=0.6, color=data["color"],
             label=f"{data['label']} (mean rate={data['mean_leave_rate']:.3f})")

# The threshold is trial-dependent; this line shows its mean across trials
plt.axvline(np.mean(env_avg), color="black", linestyle="--", label="Env Avg Threshold")
plt.title("Distribution of Reward Rates at Leaving")
plt.xlabel("Reward Rate at Leave")
plt.ylabel("Number of Simulations")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Plot 4: Histogram of p(reward) at leaving
plt.figure(figsize=(10, 5))
for data in reward_data:
    plt.hist(data["leave_p_values"], bins=30, alpha=0.6, color=data["color"],
             label=f"{data['label']} (mean p={data['mean_leave_p']:.3f})")

plt.title("Distribution of p(Reward) at Leaving")
plt.xlabel("True p(Reward) at Leave")
plt.ylabel("Number of Simulations")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Simulation settings (longer visits than in the previous cell)
n_trials = 50         # trials simulated per patch visit
n_simulations = 1000  # simulated patch visits per condition
travel_time = 5       # in trials (not referenced by the other visible cells)

# Patch types: identical decay rate, different starting reward probability
conditions = [
    dict(label="Decay 0.12, p0=0.9", p0=0.9, decay_rate=0.12, color="orange"),
    dict(label="Decay 0.12, p0=0.6", p0=0.6, decay_rate=0.12, color="green"),
]

def simulate_patch_with_p(p0, decay_rate, n_trials, n_simulations, rng=None):
    """Simulate a depleting patch, recording the reward probability used on each trial.

    On every trial a reward is delivered with the current probability ``p``.
    After each rewarded trial ``p`` is multiplied by ``exp(-decay_rate)``;
    unrewarded trials leave ``p`` unchanged.

    Parameters
    ----------
    p0 : float
        Reward probability at the start of every simulated visit.
    decay_rate : float
        Exponent of the multiplicative depletion applied after each reward.
    n_trials : int
        Number of trials per simulated visit.
    n_simulations : int
        Number of independent simulated visits.
    rng : numpy.random.Generator, optional
        Source of random numbers. When omitted, the legacy global NumPy state
        (``np.random.rand``) is used, exactly as before, so ``np.random.seed``
        still controls reproducibility.

    Returns
    -------
    reward_rates : ndarray, shape (n_simulations, n_trials)
        Cumulative rewards divided by the number of trials so far.
    p_values : ndarray, shape (n_simulations, n_trials)
        Reward probability that was in effect ON that trial, i.e. recorded
        before any depletion caused by that trial's own outcome.
    mean, low, high : ndarray, shape (n_trials,)
        Mean and 2.5th / 97.5th percentiles of ``reward_rates`` across simulations.
    """
    draw = np.random.rand if rng is None else rng.random
    depletion = np.exp(-decay_rate)  # loop-invariant multiplicative decay

    reward_rates = np.zeros((n_simulations, n_trials))
    p_values = np.zeros((n_simulations, n_trials))

    for sim in range(n_simulations):
        p = p0
        n_rewards = 0  # running count replaces re-running cumsum on every trial
        for trial in range(n_trials):
            rewarded = draw() < p
            p_values[sim, trial] = p  # stored before this trial's depletion
            if rewarded:
                n_rewards += 1
                p *= depletion
            reward_rates[sim, trial] = n_rewards / (trial + 1)

    mean = reward_rates.mean(axis=0)
    low = np.percentile(reward_rates, 2.5, axis=0)
    high = np.percentile(reward_rates, 97.5, axis=0)

    return reward_rates, p_values, mean, low, high

# Step 1: run the simulation once per patch type and keep the outputs next to
# the condition's own settings (label, p0, decay_rate, color)
reward_data = []
for cond in conditions:
    rates, p_values, mean, low, high = simulate_patch_with_p(
        cond["p0"], cond["decay_rate"], n_trials, n_simulations
    )
    entry = {"rates": rates, "p_values": p_values, "mean": mean, "low": low, "high": high}
    entry.update(cond)
    reward_data.append(entry)

all_means = [entry["mean"] for entry in reward_data]

# Step 2: environmental average = per-trial mean of the conditions' mean reward-rate curves
env_avg = np.mean(all_means, axis=0)

# Step 3: Apply MVT and collect leave data
def apply_mvt_and_get_details(rates, p_values, env_avg):
    """Find each simulation's first below-threshold trial and what was observed there.

    Parameters
    ----------
    rates, p_values : sequence of sequences
        One row per simulation; rows are indexed by trial.
    env_avg : sequence
        Per-trial threshold compared against ``rates`` trial by trial.

    Returns
    -------
    leave_times : list
        1-based trial at which the rate first fell below the threshold, or the
        number of trials when it never did.
    leave_reward_rates, leave_p_values, leave_env_rates : list
        Values of ``rates``, ``p_values`` and ``env_avg`` at that trial (their
        last entries when the threshold was never crossed).
    """
    records = []
    for sim_rate, sim_p in zip(rates, p_values):
        n_steps = len(sim_rate)
        crossing = next(
            (step for step in range(n_steps) if sim_rate[step] < env_avg[step]),
            None,
        )
        if crossing is None:
            # Threshold never crossed: fall back to the final trial
            records.append((n_steps, sim_rate[-1], sim_p[-1], env_avg[-1]))
        else:
            records.append(
                (crossing + 1, sim_rate[crossing], sim_p[crossing], env_avg[crossing])
            )

    leave_times = [rec[0] for rec in records]
    leave_reward_rates = [rec[1] for rec in records]
    leave_p_values = [rec[2] for rec in records]
    leave_env_rates = [rec[3] for rec in records]

    return leave_times, leave_reward_rates, leave_p_values, leave_env_rates

# Apply the leaving rule to each condition and keep the four per-simulation lists
for data in reward_data:
    lt, rr, pvals, envs = apply_mvt_and_get_details(data["rates"], data["p_values"], env_avg)
    data.update({
        "leave_times": lt,
        "leave_reward_rates": rr,
        "leave_p_values": pvals,
        "leave_env_avg": envs,
    })

# Plot: p(reward) against the environmental average, both taken at the leave trial
plt.figure(figsize=(8, 6))
for data in reward_data:
    plt.scatter(
        data["leave_env_avg"],
        data["leave_p_values"],
        alpha=0.4,
        s=10,
        color=data["color"],
        label=data["label"],
    )

# Points on this line left with p(reward) equal to the threshold
plt.plot([0, 1], [0, 1], 'k--', label="Identity Line")
plt.title("p(Reward) at Leaving vs MVT Threshold")
plt.xlabel("Environmental Avg Reward Rate (at leave)")
plt.ylabel("p(Reward) at Leave")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


## Explore global rate experiments using other metrics for rate

In [None]:
# Load batch 4 (version with fixed inter-patch) and batch 3, then merge them
# batch4 = pd.read_csv(data_path + 'batch_4.csv') # if you want the original dataset
batch4 = pd.read_csv(data_path + 'batch_4_fixed_interpatch.csv')

# These mice are in the dataset but didn't perform the manipulation
mice_without_manipulation = [754573, 754572, 745300, 745306, 745307]
batch4 = batch4[~batch4['mouse'].isin(mice_without_manipulation)]

## Mice with weird behavior
# batch4 = batch4.loc[(batch4.mouse != 754577)&(batch4.mouse != 754575)]

# Batch 3, without mouse 715866
batch3 = pd.read_csv(data_path + 'batch_3.csv')
batch3 = batch3[batch3['mouse'] != 715866]

# Stack both batches into one frame with a fresh index
df = pd.concat([batch3, batch4], ignore_index=True)

# Drop rows whose patch label is one of these
excluded_patch_labels = ['patch_delayed', 'patch_no_reward', 'patch_single', 'delayed', 'single', 'no_reward', 'PatchZB']
df = df[~df['patch_label'].isin(excluded_patch_labels)]

In [None]:
# Recode odor names into the condition labels used in the plots
# ('90' / '60' / '0' and 'slow' / 'fast'), and rename the 'base' experiment to 'control'
odor_to_condition = {
    'Alpha pinene': '60',
    'Alpha-pinene': '60',
    'Methyl Butyrate': '90',
    'Ethyl Butyrate': '90',
    'Amyl Acetate': '0',
    '2,3-Butanedione': 'slow',
    '2-Heptanone': 'slow',
    'Methyl Acetate': 'fast',
    'Fenchone': '0',
}
df['patch_label'] = df['patch_label'].replace(odor_to_condition)
df['experiment'] = df['experiment'].replace({'base': 'control'})

In [None]:
# df = pd.read_csv(data_path + 'batch_3_4_window5.csv')
# df = df.loc[(df.experiment == 'control')]

In [None]:
# Two reward-rate definitions: rewards per unit of time_since_entry and rewards per stop
df['stops'] = df['site_number'] + 1
df['rate_since_entry'] = df['cumulative_rewards'].div(df['time_since_entry'])
df['rate_stops'] = df['cumulative_rewards'].div(df['stops'])
# Keep the original column order (rate_since_entry, stops, rate_stops)
_cols = [c for c in df.columns if c not in ('rate_since_entry', 'stops', 'rate_stops')]
df = df[_cols + ['rate_since_entry', 'stops', 'rate_stops']]

In [None]:
# Size of the rolling window (in rows within one patch) used for 'running_avg_rate' below.
# NOTE(review): the CSV written later is named 'batch_3_4_window5.csv' while this is 8 — confirm.
window_size = 8

In [None]:
def _add_session_rate_columns(df_session, window):
    """Return one session's rows, sorted by patch and site, with rate columns added.

    Adds 'last_reward_time', 'fixed_last_reward', 'local_average_rate' and
    'running_avg_rate'. `window` is the rolling-window size for 'running_avg_rate'.
    """
    df_sorted = df_session.sort_values(['patch_number', 'site_number'])

    # Within each patch_number, forward-fill the last reward time
    df_sorted['last_reward_time'] = (
        df_sorted
        .groupby('patch_number')['reward_onset_time']
        .transform(lambda x: x.ffill())
    )

    # Chronological order, so shift(1) picks the value carried by the previous row in time
    df_final = df_sorted.sort_values('start_time')
    df_final['fixed_last_reward'] = df_final['last_reward_time'].shift(1)
    # NOTE(review): the numerator 5 is a hard-coded constant whose meaning is not
    # visible in this notebook — confirm.
    df_final['local_average_rate'] = 5 / (df_final['start_time'] - df_final['fixed_last_reward'])

    # Back to patch/site order for the within-patch rolling mean
    df_final = df_final.sort_values(['patch_number', 'site_number'])

    # Rolling mean of is_reward within each patch, allowing fewer points at the start
    df_final['running_avg_rate'] = (
        df_final
        .groupby('patch_number')['is_reward']
        .rolling(window=window, min_periods=1)
        .mean()
        .reset_index(level=0, drop=True)
    )
    return df_final


# Collect one frame per (mouse, session) and concatenate once at the end instead of
# growing a DataFrame inside the loop, which copies all accumulated rows each time.
session_frames = []
for mouse in df.mouse.unique():
    mouse_rows = df.loc[df.mouse == mouse]
    for sn in mouse_rows.session_n.unique():
        print(mouse)
        df_final = _add_session_rate_columns(
            mouse_rows.loc[mouse_rows.session_n == sn].copy(), window_size
        )

        print(df_final.loc[(df_final.last_site == 1)&(df_final.site_number > 0)].groupby('patch_label').running_avg_rate.mean())
        session_frames.append(df_final)

# The previous implementation prepended every new session, so reverse the list
# to keep exactly the same row order in the result.
if session_frames:
    df_cum = pd.concat(session_frames[::-1], ignore_index=True)
else:
    df_cum = pd.DataFrame()

df = df_cum.copy()

In [None]:
import csv
# Persist the merged frame with the added rate columns; every field is quoted and the index
# is not written. This file is read back at the start of the experiment2 section.
# NOTE(review): the file name says 'window5' while window_size is set to 8 above — confirm.
df.to_csv(data_path + 'batch_3_4_window5.csv', index=False, quoting=csv.QUOTE_ALL )

In [None]:
# Shift the index so that its first entry becomes 0
df.index = df.index - df.index[0]

# Keep engaged rows from patches 0-20 that are odor sites of the control experiment
keep_rows = (
    (df['engaged'] == True)
    & (df['patch_number'] <= 20)
    & (df['label'] == 'OdorSite')
    & (df['experiment'] == 'control')
)
df = df[keep_rows]

In [None]:
# Metric plotted on the y-axis by the cells below and embedded in the output file names.
variable = 'running_avg_rate' # alternatives: reward_probability, 'rate_stops', 'rate_since_entry', perceived_reward_probability, running_avg_rate, local_average_rate

In [None]:
# Conditions compared in this section; rows labelled '0' are dropped
odor_labels = ['90', '60']
df = df[df['patch_label'] != '0']

# Metrics summarised at every level (order defines the output column order)
metric_cols = [
    'rate_since_entry',
    'perceived_reward_probability',
    'reward_probability',
    'rate_stops',
    'running_avg_rate',
    'local_average_rate',
]
mean_of_metrics = {col: 'mean' for col in metric_cols}
# Per patch, reward_probability is summarised by its minimum instead of its mean
per_patch_aggs = {**mean_of_metrics, 'reward_probability': 'min'}

# Rows flagged last_visit == 1, excluding site_number 0
# Alternative filter kept for reference:
# df[((df.odor_label != 'Ethyl Butyrate') & (df.site_number > 0))|((df.site_number > 1)&(df.odor_label == 'Ethyl Butyrate'))&(df['last_site'] == 1)]
leaving_rows = df.loc[(df['last_visit'] == 1)&(df.site_number > 0)]

# One row per patch
patch_df = (
    leaving_rows
    .groupby(['mouse', 'session_n', 'patch_label', 'patch_number'])
    .agg(per_patch_aggs)
    .reset_index()
)

# One row per session and patch label
session_df = (
    patch_df
    .groupby(['mouse', 'session_n', 'patch_label'])
    .agg(mean_of_metrics)
    .reset_index()
)

# One row per mouse and patch label, averaging over all of that mouse's patches.
# NOTE(review): this averages patch_df directly (sessions weighted by their number of
# patches), whereas the experiment2 section builds mouse_df from session_df — confirm
# which weighting is intended.
mouse_df = (
    patch_df
    .groupby(['mouse','patch_label'])
    .agg(mean_of_metrics)
    .reset_index()
)

In [None]:
# Within session progression: one PDF page per (mouse, session)
with PdfPages(results_path + f'/within_session_x_odor_site_y_{variable}.pdf') as pdf:
    for mouse in df.mouse.unique():
        mouse_sessions = df.loc[df.mouse == mouse].session_n.unique()
        for sn in mouse_sessions:
            # Rows of the current mouse and session
            filtered_df = df[(df['mouse'] == mouse) & (df['session_n'] == sn)]

            fig, ax = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

            # Top panel: chosen metric per odor site; marker style encodes last_visit
            axes = ax[0]
            sns.lineplot(data=filtered_df, x='odor_sites', y=variable, color='grey', zorder=1, ax=axes, linewidth=0.5)
            sns.scatterplot(
                data=filtered_df, x='odor_sites', y=variable,
                hue='patch_label', palette=color_dict_label,
                style='last_visit', markers={0: '.', 1: 'o'},
                s=50, zorder=2, edgecolor=None, ax=axes, legend=False,
            )
            # axes.set_ylim(-0.05, 1)

            # Bottom panel: which patch label each odor site belongs to
            axes = ax[1]
            sns.lineplot(data=filtered_df, x='odor_sites', y='patch_label', color='grey', zorder=1, ax=axes)
            sns.scatterplot(
                data=filtered_df, x='odor_sites', y='patch_label',
                hue='patch_label', palette=color_dict_label,
                s=50, zorder=2, edgecolor=None, ax=axes,
            )
            axes.set_xlabel('Odor Site #')
            axes.set_ylabel('Resident Patch')

            plt.suptitle(f'{mouse}-{sn}')
            plt.legend(title='Odor Site', bbox_to_anchor=(1.05, 1), loc='upper left')
            sns.despine()
            plt.tight_layout()
            pdf.savefig(fig, bbox_inches='tight')
            plt.close(fig)

In [None]:
## Difference between the two patches per mouse (one panel per mouse)
mice = df['mouse'].unique()
n_mice = len(mice)
n_cols = 5  # adjust number of columns as needed
n_rows = int(np.ceil(n_mice / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False, sharey=True)
panels = axes.ravel()  # row-major: panel k sits at (k // n_cols, k % n_cols)

for mouse, ax in zip(mice, panels):
    # Session-level summary rows of this mouse
    mouse_group = session_df[session_df['mouse'] == mouse]

    # Distribution across sessions for each condition
    sns.boxplot(
        data=mouse_group,
        x='patch_label',
        y=variable,
        order=odor_labels,
        palette=color_dict_label,
        width=0.7,
        fliersize=0,
        zorder=10,
        ax=ax,
    )

    # One line per session linking the conditions
    f.plot_lines(
        data=mouse_group,
        ax=ax,
        variable=variable,
        one_line='session_n',
        order=odor_labels
    )

    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('')
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Odor 1', 'Odor 2'])
    # ax.set_ylabel('Rate since entry')

    sns.despine(ax=ax)

# Hide the panels left over after the last mouse
for spare in panels[n_mice:]:
    spare.axis('off')

fig.tight_layout()
plt.show()
fig.savefig(results_path + f'/grid_mouse_y_{variable}_x_patch_label.pdf', bbox_inches='tight')


In [None]:
# Summary of the two patches all mice together.
# The boxplot, the per-mouse lines and the statistical test all use `variable`, so
# changing the metric in the configuration cell updates the whole figure consistently.
fig, ax = plt.subplots(figsize=(5, 5))
sns.boxplot(data=mouse_df, x='patch_label', y=variable, palette=color_dict_label, order=odor_labels, zorder=10, width=0.7, fliersize=0, ax=ax)
# One line per mouse linking the two conditions
f.plot_lines(
    data=mouse_df,
    ax=ax,
    variable=variable,
    one_line='mouse',
    order=odor_labels
)

ax.set_xlabel('')
ax.set_xticks([0, 1])
ax.set_xticklabels(['Odor 1', 'Odor 2'])
ax.set_ylim(0, 1)

from scipy import stats
# Paired t-test across mice. Pivot to one row per mouse so both samples are aligned
# on the same mouse, and drop mice that lack a value for either condition; filtering
# the two labels independently does not guarantee matching order or equal length.
paired = (
    mouse_df
    .pivot(index='mouse', columns='patch_label', values=variable)
    .reindex(columns=odor_labels)
    .dropna()
)
stats.ttest_rel(
    paired[odor_labels[0]].values,
    paired[odor_labels[1]].values
)

In [None]:
# Per-mouse regression of reward_probability on running_avg_rate, one fit per condition
for patch, shade in (('90', color1), ('60', color2)):
    sns.regplot(
        data=mouse_df.loc[mouse_df.patch_label == patch],
        x='running_avg_rate',
        y='reward_probability',
        color=shade,
    )
plt.plot([0, 1], [0, 1], linestyle='--', color='black')  # Dashed black identity line

In [None]:
# Drop mouse 715870 from the per-mouse summary, then display how many '90' rows remain.
# NOTE(review): this cell reassigns mouse_df AFTER the summary/t-test cell above in
# notebook order, so a top-to-bottom run computes those statistics with this mouse included.
mouse_df = mouse_df.loc[mouse_df.mouse != 715870]
mouse_df.loc[mouse_df['patch_label'] == '90', 'running_avg_rate'].shape[0]

In [None]:
# Grid per mouse across all sessions: one PDF page per mouse, one panel per session
with PdfPages(results_path + f'/grid_session_y_{variable}_x_patch_label.pdf') as pdf:
    for mouse in df['mouse'].unique():
        test_df = df[df['mouse'] == mouse]
        # Sessions come from this mouse's own rows. The aggregated mouse_df has no
        # 'session_n' column, so reading it here raised a KeyError.
        session_ns = test_df['session_n'].unique()

        n_sessions = len(session_ns)
        n_cols = 5  # Adjust number of columns
        n_rows = int(np.ceil(n_sessions / n_cols))

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 4*n_rows), squeeze=False, sharey=True)

        for idx, session_n in enumerate(session_ns):
            row = idx // n_cols
            col = idx % n_cols
            ax = axes[row, col]

            # Named session_sites (not session_df) so the aggregated session_df used by
            # the per-mouse grid cell is not overwritten.
            session_sites = test_df[test_df['session_n'] == session_n]
            # Rows flagged last_visit == 1 with site_number above 1
            plot_df = session_sites[(session_sites['site_number']>1)&(session_sites['last_visit'] == 1)]

            if plot_df.empty:
                ax.axis('off')
                continue

            sns.boxplot(
                data=plot_df,
                x='patch_label',
                y=variable,
                palette=color_dict_label,
                order=odor_labels,
                ax=ax
            )
            ax.set_title(f"Session {session_n}")
            sns.despine(ax=ax)

        # Turn off any unused axes
        for idx in range(n_sessions, n_rows * n_cols):
            row = idx // n_cols
            col = idx % n_cols
            axes[row, col].axis('off')

        fig.suptitle(f'Mouse {mouse}', fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()
        pdf.savefig(fig, bbox_inches='tight')
        # Release the figure; one is created per mouse
        plt.close(fig)

## Look for the rate in experiment2

In [None]:
# Reload the merged batch 3/4 table saved earlier and drop mouse 716455
df = pd.read_csv(data_path + 'batch_3_4_window5.csv')
df = df[df['mouse'] != 716455]

In [None]:
# Shift the index so that its first entry becomes 0
df.index = df.index - df.index[0]

# Engaged rows from patches 0-20 only
df = df[(df['patch_number'] <= 20) & (df['engaged'] == True)]

# Recompute the two reward-rate definitions on the filtered rows
df['rate_since_entry'] = df['cumulative_rewards'].div(df['time_since_entry'])

df['stops'] = df['site_number'].add(1)
df['rate_stops'] = df['cumulative_rewards'].div(df['stops'])

# Odor sites of experiment2 only
df = df.loc[(df['label'] == 'OdorSite') & (df.experiment == 'experiment2')]

In [None]:
# Condition labels compared in the experiment2 plots; also fixes their left-to-right order
odor_labels = ['slow', 'fast']

In [None]:
# Metric plotted on the y-axis by the experiment2 cells below and embedded in their file names.
variable = 'running_avg_rate' # alternatives: reward_probability, 'rate_stops', 'rate_since_entry', perceived_reward_probability

In [None]:
# NOTE(review): `window` is not referenced by any other visible cell (the rolling mean
# above uses `window_size`) — confirm whether this is still needed.
window = 20

In [None]:
# Rows labelled '0' are dropped before summarising
df = df[df['patch_label'] != '0']

# Metrics summarised at every level (order defines the output column order)
metric_cols = [
    'rate_since_entry',
    'perceived_reward_probability',
    'reward_probability',
    'rate_stops',
    'running_avg_rate',
]
mean_of_metrics = {col: 'mean' for col in metric_cols}
# Per patch, reward_probability is summarised by its minimum instead of its mean
per_patch_aggs = {**mean_of_metrics, 'reward_probability': 'min'}

# Rows flagged last_site == 1, excluding site_number 0.
# NOTE(review): the control section filters on 'last_visit' instead of 'last_site' — confirm.
# Alternative filter kept for reference:
# df[((df.odor_label != 'Ethyl Butyrate') & (df.site_number > 0))|((df.site_number > 1)&(df.odor_label == 'Ethyl Butyrate'))&(df['last_site'] == 1)]
leaving_rows = df.loc[(df['last_site'] == 1)&(df.site_number > 0)]

# One row per patch
patch_df = (
    leaving_rows
    .groupby(['mouse', 'session_n', 'patch_label', 'patch_number'])
    .agg(per_patch_aggs)
    .reset_index()
)

# One row per session and patch label
session_df = (
    patch_df
    .groupby(['mouse', 'session_n', 'patch_label'])
    .agg(mean_of_metrics)
    .reset_index()
)

# One row per mouse and patch label: mean of that mouse's session-level values
mouse_df = (
    session_df
    .groupby(['mouse','patch_label'])
    .agg(mean_of_metrics)
    .reset_index()
)

In [None]:
# Per-mouse regression of reward_probability on running_avg_rate, one fit per condition
for patch, shade in (('slow', color1), ('fast', color2)):
    sns.regplot(
        data=mouse_df.loc[mouse_df.patch_label == patch],
        x='running_avg_rate',
        y='reward_probability',
        color=shade,
    )

plt.plot([0, 1], [0, 1], linestyle='--', color='black')  # Dashed black identity line

In [None]:
# One raster page per (mouse, session) for experiment2
with PdfPages(results_path + '\\' + 'patches_exp2.pdf') as pdf:
    for mouse in df.mouse.unique():
        reward_sites = df.loc[(df.mouse == mouse)&(df.label == 'OdorSite')]
        # The loop variable is deliberately NOT called session_df: that name holds the
        # aggregated per-session summary that the per-mouse grid cell further down reads,
        # and reusing it here replaced the summary with the last session's raw rows.
        for session_n, session_sites in reward_sites.groupby('session_n'):
            print(session_n, mouse)
            plotting.segmented_raster_vertical(session_sites, color_dict_label=color_dict_label, save=pdf)

In [None]:
# Summary of the two experiment2 conditions, all mice together.
# The boxplot, the per-mouse lines and the statistical test all use `variable`.
fig, ax = plt.subplots(figsize=(5, 5))
sns.boxplot(data=mouse_df, x='patch_label', y=variable, palette=color_dict_label, order=odor_labels, zorder=10, width=0.7, fliersize=0, ax=ax)
# One line per mouse linking the two conditions
f.plot_lines(
    data=mouse_df,
    ax=ax,
    variable=variable,
    one_line='mouse',
    order=odor_labels
)

ax.set_xlabel('')
ax.set_xticks([0, 1])
ax.set_xticklabels(['Odor 1', 'Odor 2'])

from scipy import stats
# Paired t-test across mice. Pivot to one row per mouse so both samples are aligned
# on the same mouse, and drop mice that lack a value for either condition; filtering
# the two labels independently does not guarantee matching order or equal length.
paired = (
    mouse_df
    .pivot(index='mouse', columns='patch_label', values=variable)
    .reindex(columns=odor_labels)
    .dropna()
)
stats.ttest_rel(
    paired[odor_labels[0]].values,
    paired[odor_labels[1]].values
)

In [None]:
## Difference between the two patches per mouse (experiment2, one panel per mouse)
mice = df['mouse'].unique()
n_mice = len(mice)
n_cols = 5  # adjust number of columns as needed
n_rows = int(np.ceil(n_mice / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False, sharey=True)

for idx, mouse in enumerate(mice):
    row = idx // n_cols
    col = idx % n_cols
    ax = axes[row, col]

    # Session-level summary rows of this mouse
    mouse_group = session_df[session_df['mouse'] == mouse]

    # Distribution across sessions for each condition
    sns.boxplot(
        x='patch_label',
        y=variable,
        palette=color_dict_label,
        data=mouse_group,
        order=odor_labels,
        zorder=10,
        width=0.7,
        ax=ax,
        fliersize=0
    )

    # One line per session linking the conditions
    f.plot_lines(
        data=mouse_group,
        ax=ax,
        variable=variable,
        one_line='session_n',
        order=odor_labels
    )

    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('')
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Odor 1', 'Odor 2'])
    # ax.set_ylabel('Rate since entry')

    sns.despine(ax=ax)

# Turn off empty axes
for idx in range(n_mice, n_rows * n_cols):
    row = idx // n_cols
    col = idx % n_cols
    axes[row, col].axis('off')

fig.tight_layout()
plt.show()
# '_exp2' suffix (same convention as the per-session experiment2 PDF) so this figure
# no longer overwrites the control-experiment grid saved under the un-suffixed name.
fig.savefig(results_path + f'/grid_mouse_y_{variable}_x_patch_label_exp2.pdf', bbox_inches='tight')

In [None]:
# Grid per mouse across all experiment2 sessions: one PDF page per mouse, one panel per session
with PdfPages(results_path + f'/grid_session_y_{variable}_x_patch_label_exp2.pdf') as pdf:
    for mouse in df['mouse'].unique():
        # Named mouse_sites / session_sites so the aggregated mouse_df and session_df
        # built earlier are not overwritten by this loop.
        mouse_sites = df[df['mouse'] == mouse]
        session_ns = mouse_sites['session_n'].unique()

        n_sessions = len(session_ns)
        n_cols = 5  # Adjust number of columns
        n_rows = int(np.ceil(n_sessions / n_cols))
        try:
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 4*n_rows), squeeze=False, sharey=True)
        except Exception as err:
            # Best effort: skip this mouse but report why. Catching Exception (not a bare
            # except) lets KeyboardInterrupt / SystemExit propagate.
            print(f"Error creating subplots for mouse {mouse}: {err!r}")
            continue

        for idx, session_n in enumerate(session_ns):
            row = idx // n_cols
            col = idx % n_cols
            ax = axes[row, col]

            session_sites = mouse_sites[mouse_sites['session_n'] == session_n]
            # Rows flagged last_visit == 1, excluding site_number 0
            plot_df = session_sites[(session_sites['site_number']>0)&(session_sites['last_visit'] == 1)]

            if plot_df.empty:
                ax.axis('off')
                continue

            sns.boxplot(
                data=plot_df,
                x='patch_label',
                y=variable,
                palette=color_dict_label,
                order=odor_labels,
                ax=ax
            )
            ax.set_title(f"Session {session_n}")
            sns.despine(ax=ax)

        # Turn off any unused axes
        for idx in range(n_sessions, n_rows * n_cols):
            row = idx // n_cols
            col = idx % n_cols
            axes[row, col].axis('off')

        fig.suptitle(f'Mouse {mouse}', fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()
        pdf.savefig(fig, bbox_inches='tight')
        # Release the figure; one is created per mouse
        plt.close(fig)