In [None]:
# IPython magic tools: reload imported modules automatically before each cell runs
%load_ext autoreload
%autoreload 2

from aind_vr_foraging_analysis.utils.parsing import data_access

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

# --- Plot styling -----------------------------------------------------------
sns.set_context('talk')

# --- Warning policy: silence the noisy categories for the whole notebook ----
import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# --- Base palette -----------------------------------------------------------
color1, color2, color3, color4 = '#d95f02', '#1b9e77', '#7570b3', 'yellow'
odor_list_color = [color1, color2, color3, color4]

# --- Locations (absolute, machine-specific) ---------------------------------
pdf_path = r'Z:\scratch\vr-foraging\sessions'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 5 - learning\results'

# --- Label -> colour lookup -------------------------------------------------
color_dict_label = {
    # greys for the InterSite / InterPatch labels
    'InterSite': '#808080',
    'InterPatch': '#b3b3b3',
    # patch labels
    'PatchZ': color1,
    'PatchZB': color1,
    'PatchB': color1,
    'PatchA': color3,
    'PatchC': color2,
    # named labels
    'Alpha-pinene': color2,
    'Methyl Butyrate': color3,
    'Amyl Acetate': color1,
    'Fenchone': color3,
    # one-letter condition codes (the values returned by get_condition_code)
    'S': color1,
    'D': color2,
    'N': color3,
}

# Same mapping, with the two grey entries listed first.
label_dict = {
    "InterSite": '#808080',
    "InterPatch": '#b3b3b3',
    **color_dict_label,
}
import os


In [None]:
# Function to assign codes
def get_condition_code(text):
    """Map a patch label to a one-letter condition code.

    Parameters
    ----------
    text : str or any
        Patch label. Anything that is not a string (e.g. NaN or None coming
        from a missing label) is treated as "no code" instead of raising.

    Returns
    -------
    str or None
        'D' if the label contains 'delayed', 'S' if it contains 'single',
        'N' if it contains 'no_reward' or 'noreward', otherwise None.
        The checks run in that order, so a label matching several keywords
        gets the first matching code.
    """
    # Substring tests on a float/None would raise TypeError inside .apply().
    if not isinstance(text, str):
        return None
    if 'delayed' in text:
        return 'D'
    if 'single' in text:
        return 'S'
    if 'no_reward' in text or 'noreward' in text:
        return 'N'
    return None

In [None]:
# Mouse ID -> trainer; the dict's keys define which mice get loaded.
trainer_dict = {
    **dict.fromkeys(('789915', '789917'), 'Katrina'),
    '789925': 'Olivia',
}
mouse_list = trainer_dict.keys()

In [None]:
date_string = "2025-9-07"

# Per-session results are collected in lists and concatenated once after the
# loop, instead of re-copying an ever-growing DataFrame on every iteration.
patch_frames = []   # per-label patch/visit counts, one frame per session
epoch_frames = []   # annotated epoch table, one frame per session
for mouse in mouse_list:
    session_paths = data_access.find_sessions_relative_to_date(
        mouse=mouse,
        date_string=date_string,
        when='between',
        end_date_string="2025-10-16"
    )
    # Counts only sessions that loaded successfully for this mouse.
    session_n = 0
    for session_path in session_paths:
        print(mouse, session_path)
        try:
            all_epochs, stream_data, data = data_access.load_session(
                session_path
            )
        except Exception as exc:
            # Skip unreadable sessions but report why; a bare `except:` would
            # also trap KeyboardInterrupt and hide the cause.
            print(f"Error loading {session_path.name}: {exc!r}")
            continue

        session_label = session_path.name[7:17]
        all_epochs['mouse'] = mouse
        all_epochs['session'] = session_label
        all_epochs['session_n'] = session_n
        session_n += 1

        all_epochs['stage'] = data['config'].streams.tasklogic_input.data['stage_name']

        # First patch at which skipped_count has reached 10; rows up to 10
        # patches before it are flagged as engaged.
        # NOTE(review): when skipped_count never reaches 10 the fallback is the
        # last patch number, so the final 10 patches are still flagged as not
        # engaged — confirm this is intended.
        last_engaged_patch = all_epochs['patch_number'][all_epochs['skipped_count'] >= 10].min()
        if pd.isna(last_engaged_patch):
            last_engaged_patch = all_epochs['patch_number'].max()
        all_epochs['engaged'] = np.where(all_epochs['patch_number'] <= last_engaged_patch-10, 1, 0)

        # all_epochs = all_epochs[all_epochs['engaged'] == 1]
        all_epochs['time_session'] = all_epochs.index - all_epochs.index[0]
        try:
            # Block number is parsed from the "set<N>" part of the raw label.
            all_epochs['block'] = all_epochs['patch_label'].str.extract(r'set(\d+)').astype(int)
        except ValueError:
            # Some label has no "set<N>" (NaN cannot be cast to int).
            all_epochs['block'] = 0

        # Compute total and visited patches in a single step
        # (the raw label is replaced by its one-letter condition code first)
        all_epochs['patch_label'] = all_epochs['patch_label'].apply(get_condition_code)

        patch_total = all_epochs.groupby('patch_label')['patch_number'].nunique()

        visited_filter = (all_epochs.site_number == 0) & (all_epochs.is_choice == 1)
        patch_visited = all_epochs.loc[visited_filter].groupby('patch_label')['patch_number'].nunique()
        block = all_epochs.loc[visited_filter].groupby('patch_label')['block'].nunique()

        # Combine into one dataframe
        patch_df = pd.DataFrame({
            'patch_number': patch_total,
            'visited': patch_visited,
            'block': block
        }).fillna(0)  # Fill NaNs for labels that were never visited

        patch_df['fraction_visited'] = patch_df['visited'] / patch_df['patch_number']
        patch_df['mouse'] = mouse
        patch_df['session'] = session_label

        patch_frames.append(patch_df.reset_index())
        epoch_frames.append(all_epochs)

# Reversed so the row order matches the previous "prepend newest" behaviour;
# fall back to an empty frame when nothing could be loaded.
sum_df = pd.concat(patch_frames[::-1]) if patch_frames else pd.DataFrame()
summary_df = pd.concat(epoch_frames[::-1]) if epoch_frames else pd.DataFrame()


In [None]:
# Turn the existing index into a regular column, then order the rows by
# mouse, session and patch number.
summary_df = (
    summary_df
    .reset_index()
    .sort_values(by=['mouse', 'session', 'patch_number'])
)


In [None]:
# Keep the three mice of interest and only sessions with session_n above 5.
# .copy() makes `summary` an independent frame, so the columns assigned to it
# in later cells are not written into a slice of summary_df.
keep_mouse = summary_df.mouse.isin(['789915', '789917', '789925'])
keep_session = summary_df.session_n > 5
summary = summary_df.loc[keep_mouse & keep_session].copy()

In [None]:
# Largest reward_amount recorded in each session, one line per mouse.
fig, ax = plt.subplots(figsize=(6, 4))
plot = summary.groupby(['mouse', 'session_n'], as_index=False)['reward_amount'].max()
sns.lineplot(
    data=plot,
    x='session_n',
    y='reward_amount',
    hue='mouse',
    marker='o',
    palette='Set1',
)
plt.xticks(rotation=45, ha='right')
sns.despine()

In [None]:
# Broadcast each session's maximum reward_amount onto every row of that session.
session_groups = summary.groupby(['mouse', 'session_n'])
summary['reward_max'] = session_groups['reward_amount'].transform('max')

In [None]:
# summary = summary.loc[summary.is_reward == 1]
# Total is_choice count ("stops") per mouse and session, one panel per mouse,
# coloured by the session's reward_max.
plot = (
    summary
    .groupby(['mouse', 'session_n', 'reward_max'])
    .agg(stops=('is_choice', 'sum'))
    .reset_index()
)
fig, axes = plt.subplots(1, 3, figsize=(18, 4), sharey=True)
for ax, mouse in zip(axes, ['789915', '789917', '789925']):
    sns.scatterplot(data=plot[plot.mouse == mouse], x='session_n', y='stops', hue='reward_max', palette='Set1', ax=ax)
    ax.set_title(f'Mouse {mouse}')
    ax.set_ylim(0, 500)
    # No legend is created when a panel has nothing to draw, so guard the removal.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
sns.despine(fig=fig)

In [None]:
# Scale time_session by 1/60 (presumably seconds -> minutes; confirm the unit
# of the epoch index). The value is recomputed from the untouched column in
# summary_df, matched on row label, so re-running this cell does not divide
# an already-converted column a second time.
summary['time_session'] = summary_df.loc[summary.index, 'time_session'] / 60

In [None]:
# summary = summary.loc[summary.is_reward == 1]
# Largest time_session value per mouse and session, one panel per mouse,
# coloured by the session's reward_max.
plot = (
    summary
    .groupby(['mouse', 'session_n', 'reward_max'])
    .agg(time_session=('time_session', 'max'))
    .reset_index()
)
fig, axes = plt.subplots(1, 3, figsize=(18, 4), sharey=True)
for ax, mouse in zip(axes, ['789915', '789917', '789925']):
    sns.scatterplot(data=plot[plot.mouse == mouse], x='session_n', y='time_session', hue='reward_max', palette='Set1', ax=ax)
    ax.set_title(f'Mouse {mouse}')
    # ax.set_ylim(0,500)
    # No legend is created when a panel has nothing to draw, so guard the removal.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
sns.despine(fig=fig)


In [None]:
# summary = summary.loc[summary.is_reward == 1]
# Distribution of per-session is_choice totals for each mouse, split by reward_max.
plot = (
    summary
    .groupby(['mouse', 'session_n', 'reward_max'])['is_choice']
    .sum()
    .rename('stops')
    .reset_index()
)

sns.boxplot(x='mouse', y='stops', hue='reward_max', data=plot, palette='Set1')
sns.despine()
plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1), title='Reward Volume (uL)')

In [None]:
# Simulate patch-foraging data for three patch types with the same depletion rate
# but different initial reward probabilities (offsets), given by p0_list below.
import numpy as np
import pandas as pd
np.random.seed(12345)

# Params
n_mice = 10
n_sessions = 20
patches_per_session = 30
s = 1.0  # reward size
p0_list = [0.6, 0.9, 0.0]  # offsets for three patch types
k = 0.2   # exponential decay rate per visit
tau_mean = 0.1  # mean travel time between patches (s)
tau_sd = 0.01
handle_mean = 0.5  # mean handling / inter-visit time within patch (s)
handle_sd = 0.1
noise_sigma = 0.12  # perceptual noise in threshold (fraction)

rows = []
for mouse in range(1, n_mice+1):
    for session in range(1, n_sessions+1):
        for patch_idx in range(1, patches_per_session+1):
            # NOTE(review): the 0.1 floor equals tau_mean, so roughly half of
            # the draws are clamped to exactly 0.1 — confirm this is intended.
            travel_time = max(0.1, np.random.normal(tau_mean, tau_sd))
            # assign a patch type (1,2,3) cyclically
            ptype = (patch_idx - 1) % 3  # 0,1,2
            p0 = p0_list[ptype]
            visit_idx = 0
            patch_time = 0.0
            # simulate visits up to cap
            while True:
                visit_idx += 1
                # reward probability decays exponentially with visit number
                p = p0 * np.exp(-k * (visit_idx - 1))
                reward = np.random.rand() < p
                dt = max(0.05, np.random.normal(handle_mean, handle_sd))
                patch_time += dt
                rows.append({
                    "mouse": mouse,
                    "session": session,
                    "patch_number": patch_idx,
                    "patch_label": f"P{ptype+1}",
                    "site_number": visit_idx,
                    "is_reward": int(reward),
                    "p_true": p,
                    "dt": dt,
                    "travel_time": travel_time,
                    "patch_time": patch_time
                })
                # stop generating visits at 50 visits or once p is negligible
                if visit_idx >= 50:
                    break
                if p < 0.0002:
                    break

df = pd.DataFrame(rows)

# Compute pre-decision R_bar (using all generated visits)
patch_summary = df.groupby(['mouse','session','patch_number']).agg(
    rewards=('is_reward','sum'),
    patch_time=('dt','sum'),
    travel_time=('travel_time','first')
).reset_index()

total_rewards = patch_summary['rewards'].sum()
total_time = patch_summary['patch_time'].sum() + patch_summary['travel_time'].sum()
R_bar_true = total_rewards / total_time

# Decide leave per patch using noisy MVT as before
def decide_leave_for_patch(p_series, R_bar_true, sigma=noise_sigma, softness=6.0):
    """Pick the 1-based visit number at which a patch is left.

    A noisy estimate of the average reward rate is drawn as
    ``R_bar_true * (1 + N(0, sigma))``. At the first visit whose expected
    reward ``p * s`` is at or below that estimate, the patch is left with
    probability ``1 - sigmoid(softness * (p * s - R_est))`` (at least 0.5).
    Only that single visit is tested: if the draw says "stay", or no visit is
    below the estimate, the last generated visit is returned.

    Parameters
    ----------
    p_series : pandas.Series
        Reward probability per visit, indexed 0..n-1 (callers reset the index).
    R_bar_true : float
        Average reward rate used as the leaving threshold.
    sigma : float
        Relative standard deviation of the noise on the threshold.
    softness : float
        Slope of the logistic function.

    Returns
    -------
    int
        1-based visit number of the leave decision.
    """
    R_est = R_bar_true * (1.0 + np.random.normal(0, sigma))
    below = (p_series * s) <= R_est
    if below.any():
        first_idx = below.idxmax()
        # logistic of (expected reward - threshold) is the probability of staying
        p_stay = 1 / (1 + np.exp(-softness * ((p_series[first_idx] * s) - R_est)))
        leave_prob = 1 - p_stay
        if np.random.rand() < leave_prob:
            return first_idx + 1
    return p_series.index[-1] + 1

leave_rows = []
for (mouse, session, patch_number), group in df.groupby(['mouse','session','patch_number']):
    p_series = group['p_true'].reset_index(drop=True)
    leave_idx = decide_leave_for_patch(p_series, R_bar_true)
    for i in range(len(p_series)):
        leave_rows.append({
            "mouse": mouse,
            "session": session,
            "patch_number": patch_number,
            "site_number": i+1,
            "is_leave": int((i+1) == leave_idx)
        })

leave_df = pd.DataFrame(leave_rows)
df2 = df.merge(leave_df, on=['mouse','session','patch_number','site_number'], how='left')

# Trim to visits up to leave
def trim_patch(g):
    """Return the rows of one patch up to and including its leave visit."""
    if g['is_leave'].any():
        leave_idx = g.loc[g['is_leave']==1, 'site_number'].iloc[0]
        return g[g['site_number'] <= leave_idx]
    else:
        return g

# Iterate over the groups explicitly instead of groupby().apply(): each group
# frame then always contains the key columns (mouse/session/patch_number),
# which newer pandas versions no longer pass to apply() and which the
# summaries below need.
df_trim = pd.concat(
    [trim_patch(g) for _, g in df2.groupby(['mouse','session','patch_number'])]
).reset_index(drop=True)

# Recompute observed R_bar
patch_summary2 = df_trim.groupby(['mouse','session','patch_number']).agg(
    rewards=('is_reward','sum'),
    patch_time=('dt','sum'),
    travel_time=('travel_time','first'),
    leave_visit=('site_number','max'),
    patch_label=('patch_label','first')
).reset_index()

total_rewards2 = patch_summary2['rewards'].sum()
total_time2 = patch_summary2['patch_time'].sum() + patch_summary2['travel_time'].sum()
R_bar_observed = total_rewards2 / total_time2

# Estimate p_hat by visit index and by patch_type
p_hat_by_visit = df_trim.groupby('site_number')['is_reward'].mean().reindex(range(1, df_trim['site_number'].max()+1), fill_value=0)
p_hat_by_type_visit = df_trim.groupby(['patch_label','site_number'])['is_reward'].mean().unstack(level=0).reindex(range(1, df_trim['site_number'].max()+1), fill_value=0)

# Cumulative expected gain and its visit-to-visit change
G_hat = (p_hat_by_visit * s).cumsum()
marginal_hat = (p_hat_by_visit * s).diff().fillna(p_hat_by_visit * s)

df_trim['p_hat_visit'] = df_trim['site_number'].map(p_hat_by_visit.to_dict())
df_trim['G_hat_visit'] = df_trim['site_number'].map(G_hat.to_dict())
df_trim['marginal_hat'] = df_trim['site_number'].map(marginal_hat.to_dict())

# # Save CSV
# out_path = "/mnt/data/simulated_foraging_three_patches.csv"
# df_trim.to_csv(out_path, index=False)


# print(f"Saved to: {out_path}")
print(f"Pre-trim R_bar = {R_bar_true:.4f} rewards/sec")
print(f"Observed R_bar = {R_bar_observed:.4f} rewards/sec")
print("Patch-type initial probabilities used:", p0_list)
print("Number of visits (rows):", len(df_trim))


In [None]:
# Highest site_number reached per mouse, session and patch type.
plot = (
    df_trim
    .groupby(['mouse', 'session', 'patch_label'])['site_number']
    .max()
    .reset_index()
)
sns.boxplot(x='patch_label', y='site_number', data=plot, palette='Set2')

In [None]:
# Mean p_true on the leave rows, per mouse, session and patch type.
df = (
    df_trim[df_trim['is_leave'] == 1]
    .groupby(['mouse', 'session', 'patch_label'])['p_true']
    .mean()
    .reset_index()
)
sns.boxplot(x='patch_label', y='p_true', data=df, palette='Set2')