In [None]:
"""Notebook settings and imports"""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

from collections import defaultdict

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots

import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *
from aeon.dj_pipeline import acquisition, streams
import colorsys
import plotly.io as pio
import re
import os
from aeon.analysis.block_plotting import gen_hex_grad, conv2d
from datetime import datetime, timedelta

## Set variables and fetch data

In [2]:
# Directory path for saved figures (presumably; it is not used in the cells visible here).
# NOTE(review): hard-coded absolute path — consider making it configurable.
fig_save_directory = '/ceph/aeon/aeon/code/scratchpad/Orsi/overall_plots'

In [3]:
"""Set plot styles."""

# Qualitative palettes: one for subjects, one for patches.
subject_colors = plotly.colors.qualitative.Plotly
# @NOTE: Really, we shouldn't have to explicitly set colors for each subject.
subject_colors_dict = {
    subject_id: subject_colors[position]
    for position, subject_id in enumerate(
        ["BAA-1104568", "BAA-1104569", "BAA-1104516", "BAA-1104519"]
    )
}
patch_colors = plotly.colors.qualitative.Dark2

# Plotly marker names and, position by position, the glyph used to represent each one.
patch_markers = ["circle", "bowtie", "square", "hourglass", "diamond", "cross", "x", "triangle", "star"]
patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"]
patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols))
patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"]

In [4]:
"""Set experiment key, stages"""

key = {"experiment_name": "social0.3-aeon4"}

# Phase boundaries are defined by hand and are only valid for aeon4.
# TODO: how to get these from DataJoint?
pre_social_start = datetime(2024, 6, 8)
pre_social_end = datetime(2024, 6, 17)
social_start = datetime(2024, 6, 19)
social_end = datetime(2024, 7, 4)
post_social_start = datetime(2024, 7, 4)
post_social_end = datetime(2024, 7, 13)

# `limit=1` only yields the *earliest* matching block if the rows are explicitly ordered;
# without `order_by` the row returned is whatever the database happens to deliver first.
exp_start = (Block & key & f"block_start >= '{pre_social_start}'").fetch(
    "block_start", order_by="block_start", limit=1
)[0]
exp_end = (Block & key & f"block_start >= '{post_social_end}'").fetch(
    "block_end", order_by="block_start", limit=1
)[0]



In [5]:
"""Define light cycle."""
# Light-cycle period boundaries as timedelta offsets (presumably from midnight);
# each period starts where the previous one ends, and night wraps past midnight.
twilight_start, twilight_end = timedelta(hours=18), timedelta(hours=19)
night_start, night_end = timedelta(hours=19), timedelta(hours=7)
dawn_start, dawn_end = timedelta(hours=7), timedelta(hours=8)
day_start, day_end = timedelta(hours=8), timedelta(hours=18)

In [None]:
"""Print experiment patch names and subject names"""
one_day_into_social = "2024-06-20 10:00:00"

# Fetch start and end of the SAME block in a single, explicitly ordered query. Two separate
# unordered `limit=1` fetches are not guaranteed to return the earliest block, nor the same one.
block_starts_after, block_ends_after = (
    Block & key & f"block_start >= '{one_day_into_social}'"
).fetch("block_start", "block_end", order_by="block_start", limit=1)
block_start, block_end = block_starts_after[0], block_ends_after[0]

one_block_key = key | {
    "block_start": str(pd.Timestamp(block_start)),
    "block_end": str(pd.Timestamp(block_end)),
}
chunk_restriction = acquisition.create_chunk_restriction(key["experiment_name"], block_start, block_end)

# Subject and patch names are read off the index of this one block's subject/patch table.
subj_patch_info_one_block = (BlockSubjectAnalysis.Patch & one_block_key).fetch(format="frame")
subject_names = list(subj_patch_info_one_block.index.get_level_values("subject_name").unique())
patch_names = list(subj_patch_info_one_block.index.get_level_values("patch_name").unique())
print(patch_names, subject_names)

In [None]:
"""Get foraging blocks."""
# NOTE: pre solo data seems to be missing?
# NOTE: the pellet filter is effectively disabled (threshold of 0), so all blocks are kept.
pellet_threshold = 0
all_blocks = (
    BlockAnalysis.Patch & key & f'block_start >= "{exp_start}"' & f'block_start <= "{exp_end}"'
).fetch('block_start', 'pellet_count', as_dict=True)

# Total pellets per block, summed over that block's patches.
pellet_sums = defaultdict(int)
for entry in all_blocks:
    pellet_sums[entry['block_start']] += entry['pellet_count']

# Restriction as a list of dicts, one per block that reaches the threshold.
foraging_block_starts = [
    {'block_start': block_start}
    for block_start in {start for start, total in pellet_sums.items() if total >= pellet_threshold}
]
print(len(foraging_block_starts))

patch_info = (BlockAnalysis.Patch & key & foraging_block_starts).fetch(
    'block_start', "patch_name", "patch_rate", "patch_offset", as_dict=True
)

subj_patch_info = (BlockSubjectAnalysis.Patch & key & foraging_block_starts).fetch(format="frame")
display(subj_patch_info, patch_info)

In [None]:
"""Convert `subj_patch_info` into a form amenable to plotting."""

# Turn the MultiIndex into columns and keep only the pellet-related ones.
reset_subj_patch_info = subj_patch_info.reset_index()
min_subj_patch_info = reset_subj_patch_info[
    ['block_start', "patch_name", "subject_name", "pellet_timestamps", "patch_threshold"]
]
print(min_subj_patch_info.shape)

# One row per pellet: unpack the paired timestamp/threshold arrays and drop incomplete rows.
min_subj_patch_info = (
    min_subj_patch_info
    .explode(["pellet_timestamps", "patch_threshold"], ignore_index=True)
    .dropna()
    .reset_index(drop=True)
)
print(min_subj_patch_info.shape)

# Shorter column names, time first, rows in chronological order.
min_subj_patch_info = (
    min_subj_patch_info
    .rename(columns={
        "patch_name": "patch",
        "subject_name": "subject",
        "pellet_timestamps": "time",
        "patch_threshold": "threshold",
    })
    .reindex(columns=["time", 'block_start', "patch", "threshold", "subject"])
    .sort_values(by="time")
    .reset_index(drop=True)
)

print(min_subj_patch_info.shape)
display(min_subj_patch_info)

In [None]:
"""Add patch mean values and block-normalized delivery times to pellet info."""

# NOTE(review): the cell that follows recomputes `n_patches`, `patch_mean_info` and
# `min_subj_patch_info_plus` block by block, overwriting everything produced here.

# One "mean" row per entry in `patch_info`, with the same columns as the pellet table.
n_patches = len(patch_info)
patch_mean_info = pd.DataFrame(index=np.arange(n_patches), columns=min_subj_patch_info.columns)
patch_mean_info["subject"] = "mean"
patch_mean_info["patch"] = [d["patch_name"] for d in patch_info]
# Threshold stored for a mean row: 1 / patch_rate + patch_offset.
patch_mean_info["threshold"] = [((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info]
# Every mean row is stamped with the FIRST block_start in `subj_patch_info`,
# whichever block the `patch_info` entry actually came from.
patch_mean_info["time"] = subj_patch_info.index.get_level_values("block_start")[0]
patch_mean_info["block_start"] = subj_patch_info.index.get_level_values("block_start")[0]

# Mean rows first, then the pellet rows; time is rescaled using the first and last row
# of the combined table as the 0 and 1 reference points.
min_subj_patch_info_plus = pd.concat((patch_mean_info, min_subj_patch_info)).reset_index(drop=True)
min_subj_patch_info_plus["norm_time"] = (
    (min_subj_patch_info_plus["time"] - min_subj_patch_info_plus["time"].iloc[0])
    / (min_subj_patch_info_plus["time"].iloc[-1] - min_subj_patch_info_plus["time"].iloc[0])
).round(3)

display(min_subj_patch_info_plus)

In [None]:
"""Add patch mean values and block-normalized delivery times to pellet info."""

n_patches = len(patch_names)

# Index the fetched patch parameters by block so each block is looked up once.
patch_info_by_block = defaultdict(list)
for d in patch_info:
    patch_info_by_block[pd.Timestamp(d["block_start"])].append(d)

# One "mean" row per (block, patch), only for blocks that have pellet rows. The rows are
# built from whatever patches the block actually has, instead of pre-sizing a frame with
# `n_patches` rows (which raised a length mismatch whenever a block had a different
# number of patches), and the frame is created once rather than concatenated per block.
blocks_with_pellets = sorted(
    pd.Timestamp(start) for start in min_subj_patch_info["block_start"].dropna().unique()
)
mean_records = [
    {
        "time": block_start,
        "block_start": block_start,
        "patch": d["patch_name"],
        "threshold": (1 / d["patch_rate"]) + d["patch_offset"],
        "subject": "mean",
    }
    for block_start in blocks_with_pellets
    for d in patch_info_by_block.get(block_start, [])
]
all_patch_mean_info = pd.DataFrame(mean_records, columns=min_subj_patch_info.columns)

min_subj_patch_info_plus = pd.concat([min_subj_patch_info, all_patch_mean_info], ignore_index=True)

# Normalize time within each block: 0 at the block's earliest row, 1 at its latest.
grouped = min_subj_patch_info_plus.groupby('block_start')
for block_start, group in grouped:
    block_first, block_last = group["time"].min(), group["time"].max()
    norm_time = ((group["time"] - block_first) / (block_last - block_first)).round(3)
    min_subj_patch_info_plus.loc[group.index, "norm_time"] = norm_time
min_subj_patch_info_plus = min_subj_patch_info_plus.sort_values(by="time").reset_index(drop=True)
# Display the final DataFrame
display(min_subj_patch_info_plus)

In [None]:
"""Create a cumulative pellet count (by subject) dataframe with additional pellet info."""

# Work on a time-ordered copy with a fresh 0..n-1 index.
cum_pel_ct = min_subj_patch_info_plus.copy().sort_values("time").reset_index(drop=True)

def cumsum_helper(group):
    """Return a copy of `group` with a 1-based running row counter in column "counter".

    Meant for `groupby(...).apply`. The frame passed in is left untouched, because
    mutating the object handed to an applied function is not supported by pandas.

    Args:
        group: DataFrame holding the rows of one group, in the order to be counted.

    Returns:
        A new DataFrame equal to `group` plus an integer "counter" column (1, 2, ..., len).
    """
    return group.assign(counter=np.arange(1, len(group) + 1))

# Per-block patch means: the "mean" rows carry the mean threshold of each patch.
is_mean_row = cum_pel_ct["subject"] == "mean"
patch_means = cum_pel_ct.loc[is_mean_row, ["patch", "threshold", "block_start"]].rename(
    columns={"threshold": "mean_thresh"}
)
patch_means["mean_thresh"] = patch_means["mean_thresh"].astype(float).round(1)

# Attach each patch's mean threshold to every row of the same block and patch.
cum_pel_ct = cum_pel_ct.merge(patch_means, on=["patch", "block_start"], how="left")
cum_pel_ct["subject"] = cum_pel_ct["subject"].astype(str)

# Drop the mean rows. An exact match is used (as when selecting them above) so that a
# subject whose name merely contains "mean" is not discarded by accident.
cum_pel_ct = cum_pel_ct[cum_pel_ct["subject"] != "mean"].reset_index(drop=True)

# Running pellet count per subject within each block.
cum_pel_ct = (
    cum_pel_ct.groupby(["subject", "block_start"], group_keys=False)
    .apply(cumsum_helper)
    .reset_index(drop=True)
)

# Convert columns to float
make_float_cols = ["threshold", "mean_thresh", "norm_time"]
cum_pel_ct[make_float_cols] = cum_pel_ct[make_float_cols].astype(float)

# Patch label, e.g. "<patch> μ: <mean threshold>" (mean_thresh is already float here).
cum_pel_ct["patch_label"] = cum_pel_ct["patch"] + " μ: " + cum_pel_ct["mean_thresh"].round(1).astype(str)

# Threshold rescaled to [0, 1] over the whole table.
threshold_min, threshold_max = cum_pel_ct["threshold"].min(), cum_pel_ct["threshold"].max()
cum_pel_ct["norm_thresh_val"] = (
    (cum_pel_ct["threshold"] - threshold_min) / (threshold_max - threshold_min)
).round(3)

# Sort by 'time' column
cum_pel_ct = cum_pel_ct.sort_values("time")

display(cum_pel_ct)

In [None]:
"""Get wheel timestamps for each patch."""

wheel_ts = (BlockAnalysis.Patch() & key & foraging_block_starts).fetch(
    "block_start", "patch_name", "wheel_timestamps", as_dict=True
)

# Nested lookup: block_start -> patch_name -> wheel timestamps.
wheel_ts_dict = {}
for d in wheel_ts:
    block_start, patch_name = d["block_start"], d["patch_name"]
    wheel_ts_dict.setdefault(block_start, {})[patch_name] = d["wheel_timestamps"]
display(wheel_ts)

In [None]:
"""Get subject patch data."""

# Fetch as a frame, discard the experiment_name index level, and turn the rest into columns.
subject_patch_data = (
    (BlockSubjectAnalysis.Patch() & key & foraging_block_starts)
    .fetch(format="frame")
    .reset_index(level=["experiment_name"], drop=True)
    .reset_index()
)
display(subject_patch_data)

In [None]:
"""Add experiment phase labels to data and number blocks."""
# NOTE: we need to add labels for experiment phases, here or already in the DJ table?
# Make a new column `experiment_phase` in min_subj_patch_info_plus and fill it in based on manually defined dates.
def determine_phase(date):
    """Return the experiment phase whose (inclusive) date window contains `date`.

    The windows are the module-level `*_start` / `*_end` datetimes and are tested in the
    order pre_social, social, post_social; the first match wins. When no window matches,
    a message is printed and 'unknown' is returned.
    """
    phase_windows = (
        ('pre_social', pre_social_start, pre_social_end),
        ('social', social_start, social_end),
        ('post_social', post_social_start, post_social_end),
    )
    for phase_name, window_start, window_end in phase_windows:
        if window_start <= date <= window_end:
            return phase_name

    print(f"Unknown phase for date {date}")
    return 'unknown'
# Label every row with the experiment phase that its `time` falls into.
min_subj_patch_info_plus['experiment_phase'] = min_subj_patch_info_plus['time'].apply(determine_phase)
display(min_subj_patch_info_plus)       

# Create block numbering for solo phases, grouped by subject and experiment_phase.
# "Solo" here means every phase other than 'social' (so 'unknown' rows are included);
# blocks are numbered 1, 2, ... in the order they first appear for that subject and phase.
solo_phases = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] != 'social']
solo_numbering = solo_phases[['subject', 'experiment_phase', 'block_start']].drop_duplicates()
solo_numbering['block_number'] = solo_numbering.groupby(['subject', 'experiment_phase']).cumcount() + 1

# Create block numbering for the social phase, grouped only by block_start (making sure same block_start has the same number for both subjects)
social_phase = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'social']
social_numbering = social_phase[['experiment_phase', 'block_start']].drop_duplicates()
social_numbering['block_number'] = social_numbering.groupby(['experiment_phase']).cumcount() + 1

# Merge the numbering back. The solo merge fills `block_number` for non-social rows; the
# social merge arrives as `block_number_social` and fills the remaining gaps.
# NOTE(review): re-running this cell merges a second time onto a frame that already has
# `block_number` — presumably it is only meant to be run once per kernel session.
solo_numbering = solo_numbering[['subject', 'experiment_phase', 'block_start', 'block_number']]
social_numbering = social_numbering[['experiment_phase', 'block_start', 'block_number']]
min_subj_patch_info_plus = min_subj_patch_info_plus.merge(solo_numbering, on=['subject', 'experiment_phase', 'block_start'], how='left')
min_subj_patch_info_plus = min_subj_patch_info_plus.merge(social_numbering, on=['experiment_phase', 'block_start'], how='left', suffixes=('', '_social'))
min_subj_patch_info_plus['block_number'] = min_subj_patch_info_plus['block_number'].fillna(min_subj_patch_info_plus['block_number_social'])
min_subj_patch_info_plus.drop(columns=['block_number_social'], inplace=True)
# Stored as a string; the int cast first removes the ".0" left by the NaN-filled merge.
min_subj_patch_info_plus['block_number'] = min_subj_patch_info_plus['block_number'].astype(int).astype(str)
display(min_subj_patch_info_plus)


In [None]:
"""Add pellet count to df."""

# Use the same key column names as `min_subj_patch_info_plus` so the frames can be merged.
subject_patch_data_copy = subject_patch_data.rename(
    columns={'patch_name': 'patch', 'subject_name': 'subject'}
)

# Attach the per-(block, patch, subject) pellet count to every matching row.
pellet_counts = subject_patch_data_copy[['block_start', 'patch', 'subject', 'pellet_count']]
min_subj_patch_info_plus = min_subj_patch_info_plus.merge(
    pellet_counts, on=['block_start', 'patch', 'subject'], how='left'
)
min_subj_patch_info_plus

In [16]:
"""Assign patch types (easy, medium, hard) to blocks."""
#NOTE: how to treat equal blocks?

def assign_patch_type(group):
    """Label the patches of one block with a `patch_type`.

    Rules:
      * the patch named 'PatchDummy1' is labelled 'dummy';
      * if the remaining patches do not all have different `mean_thresh` values,
        every one of them is labelled 'equal';
      * otherwise they are ranked by ascending `mean_thresh` and labelled
        'easy', 'medium', 'hard' in that order. A block with fewer than three such
        patches receives only the first labels (previously this raised IndexError);
        patches beyond the third are left as NaN.

    Args:
        group: DataFrame with at least the columns 'patch' and 'mean_thresh'
            (the rows of a single block). It is not modified.

    Returns:
        A copy of `group`, sorted by index, with the `patch_type` column set.
    """
    difficulty_labels = ['easy', 'medium', 'hard']

    labelled = group.copy()  # never mutate the frame handed in by groupby.apply
    is_dummy = (labelled['patch'] == 'PatchDummy1').to_numpy()
    real_positions = np.flatnonzero(~is_dummy)
    real_thresholds = labelled['mean_thresh'].iloc[real_positions]

    # Labels are assigned by row position, so a non-unique index cannot mislabel rows.
    patch_type = np.full(len(labelled), np.nan, dtype=object)
    patch_type[is_dummy] = 'dummy'

    if real_thresholds.nunique() < len(real_thresholds):
        patch_type[real_positions] = 'equal'
    else:
        ranked_positions = real_positions[np.argsort(real_thresholds.to_numpy(), kind='stable')]
        # zip stops at the shorter sequence: <3 patches use fewer labels, >3 leave the rest NaN.
        for label, position in zip(difficulty_labels, ranked_positions):
            patch_type[position] = label

    labelled['patch_type'] = patch_type
    return labelled.sort_index()

# Apply the function to each block (one group per block_start) and flatten the result
# back to a plain 0..n-1 index.
patch_means = patch_means.groupby('block_start').apply(assign_patch_type).reset_index(drop=True)

In [None]:
"""Assign env types (easy, hard) to block per day."""

# Calendar day on which each block started.
patch_means['block_start_day'] = pd.to_datetime(patch_means['block_start']).dt.date

# Keep only patches that are neither 'equal' nor 'dummy'.
is_ranked_patch = ~patch_means['patch_type'].isin(['equal', 'dummy'])
filtered_df = patch_means[is_ranked_patch]

# For each day, the sorted list of distinct mean thresholds seen that day.
thresholds_per_day = filtered_df.groupby('block_start_day')['mean_thresh'].apply(
    lambda day_values: sorted(day_values.unique())
)
grouped_df = thresholds_per_day.reset_index()

# Mean thresholds that identify the easy and the hard environment.
easy_env_thresholds = [175.0, 378.0, 575.0]
hard_env_thresholds = [275.0, 673.8, 1075.0]

def assign_env_type(thresholds):
    """Classify a collection of mean thresholds as 'easy', 'hard' or 'unknown'.

    'easy' is returned if any threshold is one of `easy_env_thresholds`; failing that,
    'hard' if any is one of `hard_env_thresholds`; otherwise 'unknown'.
    """
    reference_sets = (('easy', easy_env_thresholds), ('hard', hard_env_thresholds))
    for env_label, reference in reference_sets:
        for thresh in thresholds:
            if thresh in reference:
                return env_label
    return 'unknown'

# Classify each day from its list of thresholds.
grouped_df['env_type'] = grouped_df['mean_thresh'].apply(assign_env_type)


# Assign env types back to patch_means.
# Ensure patch_means has block_start_day
patch_means['block_start_day'] = pd.to_datetime(patch_means['block_start']).dt.date

# Every block inherits the env_type of the day it started on.
env_type_per_day = grouped_df[['block_start_day', 'env_type']]
patch_means = patch_means.merge(env_type_per_day, on='block_start_day', how='left')

patch_means

In [None]:
"""Get subject patch preference data."""

patch_pref = (
    (BlockSubjectAnalysis.Preference() & key & foraging_block_starts)
    .fetch(format="frame")
    .reset_index(level=["experiment_name"], drop=True)
    .reset_index()
)


def _zero_small_values(arr):
    """Return `arr` as an array with every value below 1e-3 replaced by 0."""
    values = np.array(arr)
    return np.where(values < 1e-3, 0, values)


# Replace small vals with 0
patch_pref["cumulative_preference_by_wheel"] = patch_pref["cumulative_preference_by_wheel"].apply(
    _zero_small_values
)
# NOTE: is this final preference just preference at end? not second half?
patch_pref = patch_pref.rename(columns={"patch_name": "patch", "subject_name": "subject"})

# Add the phase data to this df
patch_pref['experiment_phase'] = patch_pref['block_start'].apply(determine_phase)

# Add patch means and patch types. The merge result is used directly instead of copying
# columns back by index alignment, which silently misaligns if the merge ever returns a
# different number of rows. `validate` turns duplicate (block_start, patch) keys in
# `patch_means` into an explicit error.
patch_labels = patch_means[['block_start', 'patch', 'mean_thresh', 'patch_type', 'env_type']].rename(
    columns={'mean_thresh': 'mean_threshold'}
)
patch_pref = patch_pref.merge(
    patch_labels, on=['block_start', 'patch'], how='left', validate='many_to_one'
)
patch_pref = patch_pref.set_index(['block_start', 'patch', 'subject'])
display(patch_pref)

In [None]:
"""Calculate late and early preference by wheel as avg of preference in first/second half of block."""

# Calculate running preference by wheel
# NOTE: sometimes the lengths of the arrays for the different patches are off by 1, so had to add padding. Why can this mismatch happen?
# Function to pad arrays to the same length
def pad_array(arr, length):
    """Right-pad a 1-D array with NaN up to `length` elements.

    Arrays already at least `length` long are returned unpadded. Input that is not of a
    floating (or complex) dtype is converted to float first, because NaN cannot be stored
    in an integer array and would otherwise be written as an arbitrary integer.

    Args:
        arr: 1-D array-like of numbers.
        length: desired minimum length.

    Returns:
        numpy array of length max(len(arr), length).
    """
    values = np.asarray(arr)
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)
    missing = max(0, length - len(values))
    return np.pad(values, (0, missing), 'constant', constant_values=np.nan)

# Function to calculate running preference
def calculate_running_preference(group, pref_col, out_col):
    """Add a running-preference column to the rows of one group.

    Each array in `pref_col` is NaN-padded to the longest array in the group, the arrays
    are summed element-wise (ignoring NaN), and every array is divided by that sum.
    Positions where the division yields NaN (e.g. 0 / 0 or padded entries) become 0.

    Args:
        group: DataFrame whose `pref_col` column holds one array-like per row.
            It is not modified.
        pref_col: name of the column holding the cumulative preference arrays.
        out_col: name of the column to write the running preference arrays to.

    Returns:
        A copy of `group` with `out_col` added.

    Raises:
        ValueError: re-raised (after printing the group's index) if the padded arrays
            cannot be stacked.
    """
    result = group.copy()  # never mutate the frame handed in by groupby.apply
    max_length = result[pref_col].apply(len).max()

    try:
        # Pad arrays to the same length
        padded_arrays = result[pref_col].apply(lambda x: pad_array(x, max_length))
        # Sum pref at each timestamp
        total_pref = np.nansum(np.vstack(padded_arrays.values), axis=0)
        # Running pref. `nan=` is passed by keyword: the second positional parameter of
        # np.nan_to_num is `copy`, not the replacement value. Division warnings are
        # silenced because zero totals are expected and handled here.
        with np.errstate(divide="ignore", invalid="ignore"):
            result[out_col] = padded_arrays.apply(lambda x: np.nan_to_num(x / total_pref, nan=0.0))
    except ValueError:
        print(f"ValueError for group with indices: {group.index}")
        raise

    return result

# Group by subject and block_start to calculate running preference by wheel.
# NOTE(review): the droplevel calls below assume that `apply` prepends the group keys to
# the existing index — this depends on the pandas version / `group_keys` default; TODO confirm.
patch_pref = patch_pref.groupby(["subject", 'block_start']).apply(lambda group: 
    calculate_running_preference(group, "cumulative_preference_by_wheel", "running_preference_by_wheel")
).droplevel(0)

# Group by subject and block_start to calculate running preference by time
patch_pref = patch_pref.groupby(["subject", 'block_start']).apply(lambda group: 
    calculate_running_preference(group, "cumulative_preference_by_time", "running_preference_by_time")
).droplevel(0)
# Remove the two leading index levels left over from the two groupby/apply passes above.
patch_pref = patch_pref.droplevel([0, 1])

# Calculate early and late preference by wheel
# Function to calculate average of first and second half of an array
def calculate_half_averages(arr):
    """Return the NaN-ignoring means of the first and second half of `arr`.

    The split point is len(arr) // 2, so for odd lengths the extra element
    belongs to the second half.
    """
    midpoint = len(arr) // 2
    first_part, second_part = arr[:midpoint], arr[midpoint:]
    return np.nanmean(first_part), np.nanmean(second_part)

# Early/late preference = mean of the first/second half of each running-preference array,
# computed first for the wheel-based and then for the time-based preference.
half_average_columns = {
    'running_preference_by_wheel': ['early_preference_by_wheel', 'late_preference_by_wheel'],
    'running_preference_by_time': ['early_preference_by_time', 'late_preference_by_time'],
}
for running_col, early_late_cols in half_average_columns.items():
    patch_pref[early_late_cols] = patch_pref[running_col].apply(
        lambda running: pd.Series(calculate_half_averages(running))
    )

patch_pref = patch_pref.reset_index()
display(patch_pref)

In [20]:
"""Calculate difference between late and early preference by wheel."""
patch_pref['preference_change_by_wheel'] = patch_pref['late_preference_by_wheel'].sub(
    patch_pref['early_preference_by_wheel']
)

In [None]:
# Attach each patch's type and the day's env type to the pellet table,
# matching rows on 'patch' and 'block_start'.
patch_label_columns = patch_means[['patch', 'block_start', 'patch_type', 'env_type']]
min_subj_patch_info_plus = min_subj_patch_info_plus.merge(
    patch_label_columns, on=['patch', 'block_start'], how='left'
)
min_subj_patch_info_plus

In [None]:
# NOTE: Add dummy blocks in place of 0 pellet blocks that we have no data for.
# TODO: check this part, it might be buggy.
# We have no data from 0 pellet blocks, so made-up 0 pellet blocks are inserted into the
# gaps in the data to make the visualisation nicer.

# A gap is two consecutive block starts more than GAP_THRESHOLD apart; it is filled with
# dummy block starts every DUMMY_BLOCK_INTERVAL.
GAP_THRESHOLD = pd.Timedelta(hours=5)
DUMMY_BLOCK_INTERVAL = pd.Timedelta(hours=2)
DUMMY_KEY_COLUMNS = ['block_start', 'patch', 'subject', 'experiment_phase']
# NOTE: find better way to assign this?
PATCH_TYPE_BY_PATCH = {'Patch1': 'easy', 'Patch2': 'medium', 'Patch3': 'hard'}


def _make_dummy_block_rows(df):
    """Return one key row per dummy block needed to fill the gaps in `df`.

    Gaps are searched separately for every (patch, subject, experiment_phase) group.
    The result always has the DUMMY_KEY_COLUMNS, even when no gap is found.
    """
    rows = []
    for (patch, subject, phase), group in df.groupby(['patch', 'subject', 'experiment_phase']):
        # Repeated block starts never form a gap, so they can be dropped up front.
        starts = group['block_start'].sort_values().drop_duplicates().tolist()
        for previous_start, next_start in zip(starts, starts[1:]):
            if next_start - previous_start > GAP_THRESHOLD:
                current_time = previous_start + DUMMY_BLOCK_INTERVAL
                while current_time < next_start:
                    rows.append({
                        'block_start': current_time,
                        'patch': patch,
                        'subject': subject,
                        'experiment_phase': phase,
                    })
                    current_time += DUMMY_BLOCK_INTERVAL
    return pd.DataFrame(rows, columns=DUMMY_KEY_COLUMNS)


def _fill_dummy_columns(dummy_rows, template, zero_columns):
    """Give `dummy_rows` every column of `template`: 0 for `zero_columns`, NaN otherwise.

    `patch_type` is derived from the patch name. Unmapped patches get a real NaN; the
    nested np.where used before mixed strings with np.nan in one array, which NumPy
    either coerces to the string 'nan' or rejects, depending on the version.
    """
    other_columns = [col for col in template.columns if col not in DUMMY_KEY_COLUMNS]
    filled = dummy_rows.reindex(columns=DUMMY_KEY_COLUMNS + other_columns)
    for col in other_columns:
        if col in zero_columns:
            filled[col] = 0
    if 'patch_type' in other_columns:
        filled['patch_type'] = filled['patch'].map(PATCH_TYPE_BY_PATCH)
    return filled


def _concat_non_empty(frames):
    """Concatenate the non-empty frames (ignoring their indexes); copy the first if all are empty."""
    non_empty = [frame for frame in frames if not frame.empty]
    return pd.concat(non_empty, ignore_index=True) if non_empty else frames[0].copy()


def _first_valid(values):
    """Return the first non-missing entry of `values`, or NaN if there is none."""
    valid = [value for value in values if pd.notna(value)]
    return valid[0] if valid else np.nan


# 1. add dummy data to min_subj_patch_info_plus
min_subj_patch_info_plus_copy = min_subj_patch_info_plus.copy()
print(min_subj_patch_info_plus_copy['block_start'].nunique())
print(min_subj_patch_info_plus_copy.shape)

# Set the "mean" rows aside; they are reattached after the dummy rows are added.
is_mean_row = min_subj_patch_info_plus_copy['subject'] == 'mean'
mean_rows = min_subj_patch_info_plus_copy[is_mean_row]
min_subj_patch_info_plus_copy = min_subj_patch_info_plus_copy[~is_mean_row]

new_rows_df = _fill_dummy_columns(
    _make_dummy_block_rows(min_subj_patch_info_plus_copy),
    min_subj_patch_info_plus_copy,
    zero_columns=['pellet_count', 'norm_time'],
)

# Real rows + dummy rows + mean rows, sorted by block_start.
min_subj_patch_info_plus_extended = (
    _concat_non_empty([min_subj_patch_info_plus_copy, new_rows_df, mean_rows])
    .sort_values(by='block_start')
    .reset_index(drop=True)
)

# Display the extended DataFrame
display(min_subj_patch_info_plus_extended)
print(min_subj_patch_info_plus_extended['block_start'].nunique())
print(min_subj_patch_info_plus_extended.shape)


# 2. Add dummy data to patch_pref
patch_pref_copy = patch_pref.copy()
print(patch_pref_copy['block_start'].nunique())
print(patch_pref_copy.shape)

dummy_pref_keys = _make_dummy_block_rows(patch_pref_copy)

# Env type per (day, experiment_phase) as a dict, used to label the dummy blocks:
# the first non-missing env_type seen for that day and phase, NaN if there is none.
patch_pref_copy['block_start_date'] = patch_pref_copy['block_start'].dt.date
env_type_dict = (
    patch_pref_copy.groupby(['block_start_date', 'experiment_phase'])['env_type']
    .unique()
    .apply(_first_valid)
    .to_dict()
)

new_rows_df = _fill_dummy_columns(
    dummy_pref_keys,
    patch_pref_copy,
    zero_columns=[
        'late_preference_by_wheel', 'final_preference_by_wheel', 'final_preference_by_time',
        'early_preference_by_wheel', 'late_preference_by_time', 'early_preference_by_time',
        'preference_change_by_wheel',
    ],
)
if 'env_type' in new_rows_df.columns:
    new_rows_df['env_type'] = [
        env_type_dict.get((start.date(), phase), np.nan)
        for start, phase in zip(new_rows_df['block_start'], new_rows_df['experiment_phase'])
    ]

# Real rows + dummy rows, sorted by block_start.
patch_pref_extended = (
    _concat_non_empty([patch_pref_copy, new_rows_df])
    .sort_values(by='block_start')
    .reset_index(drop=True)
)

# Display the extended DataFrame
display(patch_pref_extended)
print(patch_pref_extended['block_start'].nunique())
print(patch_pref_extended.shape)

## Create overall plots

### 1. Pellet thresholds and count

In [None]:
""" Per patch ID """
# One frame per experiment phase.
phase_of_row = min_subj_patch_info_plus_extended['experiment_phase']
pre_social_df = min_subj_patch_info_plus_extended[phase_of_row == 'pre_social']
social_df = min_subj_patch_info_plus_extended[phase_of_row == 'social']
post_social_df = min_subj_patch_info_plus_extended[phase_of_row == 'post_social']
# One color per subject name.
box_colors = subject_colors[:len(subject_names)]

def create_phase_figure(df, phase_title):
    """Plot pellet counts and patch thresholds over block start time, per patch ID.

    Each patch gets a pellet-count row followed by a threshold row (per-subject box
    plots, with the 'mean' pseudo-subject drawn as a black line); a thin empty row
    separates consecutive patches. Day/night bands are drawn behind the traces.

    Relies on names defined in other notebook cells: ``subject_names``,
    ``subject_colors_dict``, the phase boundaries (``social_start``, ``social_end``,
    ``pre_social_start``, ``pre_social_end``, ``post_social_start``,
    ``post_social_end``) and the light-cycle offsets that are added to each day's
    start (``night_start``, ``night_end``, ``day_start``, ``day_end``,
    ``twilight_start``, ``twilight_end``, ``dawn_start``, ``dawn_end``).

    Parameters
    ----------
    df : pandas.DataFrame
        Rows for one experiment phase. Columns read here: ``patch``, ``subject``,
        ``block_start``, ``pellet_count``, ``threshold``, ``norm_time`` and
        ``block_number``.
    phase_title : str
        ``"Social"`` yields a single figure with all subjects overlaid. Any other
        value yields one figure per entry of ``subject_names``; ``"Pre social"``
        selects the pre-social shading window, anything else the post-social one.

    Returns
    -------
    plotly.graph_objs.Figure or list of plotly.graph_objs.Figure
        A single figure for ``"Social"``, otherwise a list ordered like
        ``subject_names``.

    Raises
    ------
    ValueError
        If no patch is left after dropping ``PatchDummy1``.
    """
    df = df[df['patch'] != "PatchDummy1"]
    patches = df['patch'].unique()

    if len(patches) == 0:
        raise ValueError(f"No patches found for phase: {phase_title}")

    desired_order = ["Patch1", "Patch2", "Patch3"]
    patches = sorted(patches, key=lambda x: desired_order.index(x) if x in desired_order else len(desired_order))

    # Three rows per patch (pellets, thresholds, spacer) minus the trailing spacer.
    n_rows = len(patches) * 3 - 1
    # Explicit subject -> colour mapping so boxes always match the pellet-count lines,
    # independent of the order in which subjects appear in the data.
    box_color_map = {'mean': 'black', **subject_colors_dict}

    def _new_figure():
        """Create the empty subplot grid shared by every figure of this phase."""
        return make_subplots(
            rows=n_rows, cols=1, shared_xaxes=True,
            subplot_titles=[f"{patches[r // 3]}" if r % 3 == 0 else "" for r in range(n_rows)],
            vertical_spacing=0.01,
            # Short pellet row, taller threshold row, thin spacer row between patches.
            row_heights=[0.1 if r % 3 == 0 else 0.2 if r % 3 == 1 else 0.02 for r in range(n_rows)],
        )

    def _add_patch_traces(fig, patch_df, patch_idx, subjects):
        """Add the pellet-count lines, threshold boxes and mean line of one patch."""
        pellet_row = patch_idx * 3 + 1
        threshold_row = patch_idx * 3 + 2
        subject_rows = patch_df[patch_df['subject'] != 'mean']

        for subject in subjects:
            one_subject = subject_rows[subject_rows['subject'] == subject]
            fig.add_trace(
                go.Scatter(
                    x=one_subject['block_start'],
                    y=one_subject['pellet_count'],
                    mode='lines',
                    name=f'Pellet count {subject}',
                    line=dict(color=subject_colors_dict[subject]),
                    fill='tozeroy',
                    # Pellet legend entries are attached to the last patch only.
                    showlegend=(patch_idx == len(patches) - 1)
                ),
                row=pellet_row, col=1
            )

        box_fig = px.box(
            subject_rows.sort_values("block_start"),
            x="block_start",
            y="threshold",
            color="subject",
            hover_data=["norm_time", "block_number"],
            color_discrete_map=box_color_map,
            points="all"
        )
        for trace in box_fig['data']:
            # Threshold legend entries are attached to the first patch, whatever the
            # number of patches (the old `len(patches) - 3` only worked for three).
            trace.showlegend = (patch_idx == 0)
            fig.add_trace(trace, row=threshold_row, col=1)

        # Mean threshold line. window=1 means no smoothing is applied; increase the
        # window to smooth. Computed as a standalone Series rather than assigned
        # back into a filtered slice, which avoids SettingWithCopyWarning.
        mean_rows = patch_df[patch_df['subject'] == 'mean'].sort_values("block_start")
        mean_threshold = mean_rows['threshold'].rolling(window=1, min_periods=1).mean()
        fig.add_trace(
            go.Scatter(
                x=mean_rows['block_start'],
                y=mean_threshold,
                mode='lines',
                name='Mean Threshold',
                line=dict(color='black'),
                showlegend=(patch_idx == 0)
            ),
            row=threshold_row, col=1
        )

    def _shade_light_cycle(fig, first_day, last_day, include_last_day):
        """Draw night / day / twilight / dawn bands for every day of the window."""
        current_day = first_day
        while (current_day <= last_day) if include_last_day else (current_day < last_day):
            # A night that ends "earlier" than it starts wraps past midnight.
            if night_end < night_start:
                night_end_time = current_day + timedelta(days=1) + night_end
            else:
                night_end_time = current_day + night_end
            bands = [
                (current_day + night_start, night_end_time, "darkgrey"),
                (current_day + day_start, current_day + day_end, "white"),
                (current_day + twilight_start, current_day + twilight_end, "lightgrey"),
                (current_day + dawn_start, current_day + dawn_end, "lightgrey"),
            ]
            for band_start, band_end, band_color in bands:
                fig.add_shape(
                    type="rect",
                    x0=band_start,
                    x1=band_end,
                    y0=0,
                    y1=1,
                    xref="x",
                    yref="paper",  # span the full figure height
                    fillcolor=band_color,
                    opacity=0.5,
                    layer="below",
                    line_width=0,
                )
            current_day += timedelta(days=1)

    def _finalise_layout(fig, title, pellet_max, threshold_max, x_min):
        """Set the title and give all patches common y ranges and a padded x start."""
        fig.update_layout(title=title, height=900, boxmode='group')
        for patch_idx in range(len(patches)):
            fig.update_yaxes(title_text="Pellet Count", row=patch_idx * 3 + 1, col=1, range=[0, pellet_max])
            fig.update_yaxes(title_text="Threshold (cm)", row=patch_idx * 3 + 2, col=1, range=[0, threshold_max])
        fig.update_xaxes(title_text="Block start time", row=n_rows, col=1, range=[x_min - timedelta(hours=1), None])

    if phase_title == "Social":
        fig = _new_figure()
        for patch_idx, patch in enumerate(patches):
            patch_df = df[df['patch'] == patch]
            subjects = patch_df.loc[patch_df['subject'] != 'mean', 'subject'].unique()
            _add_patch_traces(fig, patch_df, patch_idx, subjects)

        _shade_light_cycle(fig, social_start, social_end, include_last_day=False)
        _finalise_layout(
            fig,
            title=f"{phase_title} Phase: Pellet counts and patch thresholds",
            pellet_max=df['pellet_count'].max(),
            threshold_max=df['threshold'].max(),
            x_min=df['block_start'].min(),
        )
        return fig

    # Solo phases: one figure per subject, shaded over the matching phase window.
    if phase_title == "Pre social":
        first_day, last_day = pre_social_start, pre_social_end
    else:
        first_day, last_day = post_social_start, post_social_end

    figs = []
    for subject in subject_names:
        subject_data = df[df['subject'] == subject]

        # Keep only the 'mean' rows that fall inside this subject's own time span.
        min_block_start = subject_data['block_start'].min()
        max_block_start = subject_data['block_start'].max()
        mean_data = df[
            (df['subject'] == 'mean')
            & (df['block_start'] >= min_block_start)
            & (df['block_start'] <= max_block_start)
        ]
        subject_df = pd.concat([subject_data, mean_data])

        fig = _new_figure()
        for patch_idx, patch in enumerate(patches):
            _add_patch_traces(fig, subject_df[subject_df['patch'] == patch], patch_idx, [subject])

        _shade_light_cycle(fig, first_day, last_day, include_last_day=True)
        _finalise_layout(
            fig,
            title=f"{phase_title} Phase: Pellet counts and patch thresholds for {subject}",
            pellet_max=subject_data['pellet_count'].max(),
            threshold_max=subject_data['threshold'].max(),
            x_min=min_block_start,
        )
        figs.append(fig)

    return figs

# Build the figures: one combined figure for the social phase and a list with one
# figure per subject for the post-social phase (pre-social is currently disabled).
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_figs = create_phase_figure(post_social_df, "Post social")

# Render every figure inline and export it as a standalone HTML file.
#pre_social_fig.show()
#pio.write_html(pre_social_fig, file=os.path.join(fig_save_directory, "pre_social_threshold_fig.html"))

social_fig.show()
pio.write_html(
    social_fig,
    file=os.path.join(fig_save_directory, "social_threshold_fig.html"),
)

# The export file name carries the figure's position in the returned list.
for i, fig in enumerate(post_social_figs):
    fig.show()
    pio.write_html(
        fig,
        file=os.path.join(fig_save_directory, f"post_social_threshold_fig_{i}.html"),
    )

In [None]:
""" Per patch type """
# One frame per experiment phase, all taken from the extended block-level table.
pre_social_df, social_df, post_social_df = (
    min_subj_patch_info_plus_extended[
        min_subj_patch_info_plus_extended['experiment_phase'] == phase_name
    ]
    for phase_name in ('pre_social', 'social', 'post_social')
)
# First len(subject_names) entries of the qualitative palette, one per subject.
box_colors = subject_colors[:len(subject_names)]

def create_phase_figure(df, phase_title):
    """Plot pellet counts and patch thresholds over block start time, per patch type.

    Each patch type (easy / medium / hard) gets a pellet-count row followed by a
    threshold row (per-subject box plots, with the 'mean' pseudo-subject drawn as a
    black line); a thin empty row separates consecutive types. Day/night bands are
    drawn behind the traces.

    Relies on names defined in other notebook cells: ``subject_names``,
    ``subject_colors_dict``, the phase boundaries (``social_start``, ``social_end``,
    ``pre_social_start``, ``pre_social_end``, ``post_social_start``,
    ``post_social_end``) and the light-cycle offsets that are added to each day's
    start (``night_start``, ``night_end``, ``day_start``, ``day_end``,
    ``twilight_start``, ``twilight_end``, ``dawn_start``, ``dawn_end``).

    Parameters
    ----------
    df : pandas.DataFrame
        Rows for one experiment phase. Columns read here: ``patch``,
        ``patch_type``, ``subject``, ``block_start``, ``pellet_count``,
        ``threshold``, ``norm_time`` and ``block_number``.
    phase_title : str
        ``"Social"`` yields a single figure with all subjects overlaid. Any other
        value yields one figure per entry of ``subject_names``; ``"Pre social"``
        selects the pre-social shading window, anything else the post-social one.

    Returns
    -------
    plotly.graph_objs.Figure or list of plotly.graph_objs.Figure
        A single figure for ``"Social"``, otherwise a list ordered like
        ``subject_names``.

    Raises
    ------
    ValueError
        If no patch type is left after dropping ``PatchDummy1`` rows and rows whose
        ``patch_type`` is ``"equal"``.
    """
    df = df[df['patch'] != "PatchDummy1"]
    df = df[df['patch_type'] != "equal"]
    patches = df['patch_type'].unique()

    if len(patches) == 0:
        raise ValueError(f"No patches found for phase: {phase_title}")

    desired_order = ["easy", "medium", "hard"]
    patches = sorted(patches, key=lambda x: desired_order.index(x) if x in desired_order else len(desired_order))

    # Three rows per patch type (pellets, thresholds, spacer) minus the trailing spacer.
    n_rows = len(patches) * 3 - 1
    # Explicit subject -> colour mapping so boxes always match the pellet-count lines,
    # independent of the order in which subjects appear in the data.
    box_color_map = {'mean': 'black', **subject_colors_dict}

    def _new_figure():
        """Create the empty subplot grid shared by every figure of this phase."""
        return make_subplots(
            rows=n_rows, cols=1, shared_xaxes=True,
            subplot_titles=[f"{patches[r // 3]}" if r % 3 == 0 else "" for r in range(n_rows)],
            vertical_spacing=0.01,
            # Short pellet row, taller threshold row, thin spacer row between types.
            row_heights=[0.1 if r % 3 == 0 else 0.2 if r % 3 == 1 else 0.02 for r in range(n_rows)],
        )

    def _add_patch_traces(fig, patch_df, patch_idx, subjects):
        """Add the pellet-count lines, threshold boxes and mean line of one patch type."""
        pellet_row = patch_idx * 3 + 1
        threshold_row = patch_idx * 3 + 2
        subject_rows = patch_df[patch_df['subject'] != 'mean']

        for subject in subjects:
            one_subject = subject_rows[subject_rows['subject'] == subject]
            fig.add_trace(
                go.Scatter(
                    x=one_subject['block_start'],
                    y=one_subject['pellet_count'],
                    mode='lines',
                    name=f'Pellet count {subject}',
                    line=dict(color=subject_colors_dict[subject]),
                    fill='tozeroy',
                    # Pellet legend entries are attached to the last patch type only.
                    showlegend=(patch_idx == len(patches) - 1)
                ),
                row=pellet_row, col=1
            )

        box_fig = px.box(
            subject_rows.sort_values("block_start"),
            x="block_start",
            y="threshold",
            color="subject",
            hover_data=["norm_time", "block_number"],
            color_discrete_map=box_color_map,
            points="all"
        )
        for trace in box_fig['data']:
            # Threshold legend entries are attached to the first patch type, whatever
            # the number of types (the old `len(patches) - 3` only worked for three).
            trace.showlegend = (patch_idx == 0)
            fig.add_trace(trace, row=threshold_row, col=1)

        # Mean threshold line. window=1 means no smoothing is applied; increase the
        # window to smooth. Computed as a standalone Series rather than assigned
        # back into a filtered slice, which avoids SettingWithCopyWarning.
        mean_rows = patch_df[patch_df['subject'] == 'mean'].sort_values("block_start")
        mean_threshold = mean_rows['threshold'].rolling(window=1, min_periods=1).mean()
        fig.add_trace(
            go.Scatter(
                x=mean_rows['block_start'],
                y=mean_threshold,
                mode='lines',
                name='Mean Threshold',
                line=dict(color='black'),
                showlegend=(patch_idx == 0)
            ),
            row=threshold_row, col=1
        )

    def _shade_light_cycle(fig, first_day, last_day, include_last_day):
        """Draw night / day / twilight / dawn bands for every day of the window."""
        current_day = first_day
        while (current_day <= last_day) if include_last_day else (current_day < last_day):
            # A night that ends "earlier" than it starts wraps past midnight.
            if night_end < night_start:
                night_end_time = current_day + timedelta(days=1) + night_end
            else:
                night_end_time = current_day + night_end
            bands = [
                (current_day + night_start, night_end_time, "darkgrey"),
                (current_day + day_start, current_day + day_end, "white"),
                (current_day + twilight_start, current_day + twilight_end, "lightgrey"),
                (current_day + dawn_start, current_day + dawn_end, "lightgrey"),
            ]
            for band_start, band_end, band_color in bands:
                fig.add_shape(
                    type="rect",
                    x0=band_start,
                    x1=band_end,
                    y0=0,
                    y1=1,
                    xref="x",
                    yref="paper",  # span the full figure height
                    fillcolor=band_color,
                    opacity=0.5,
                    layer="below",
                    line_width=0,
                )
            current_day += timedelta(days=1)

    def _finalise_layout(fig, title, pellet_max, threshold_max, x_min):
        """Set the title and give all patch types common y ranges and a padded x start."""
        fig.update_layout(title=title, height=900, boxmode='group')
        for patch_idx in range(len(patches)):
            fig.update_yaxes(title_text="Pellet Count", row=patch_idx * 3 + 1, col=1, range=[0, pellet_max])
            fig.update_yaxes(title_text="Threshold (cm)", row=patch_idx * 3 + 2, col=1, range=[0, threshold_max])
        fig.update_xaxes(title_text="Block start time", row=n_rows, col=1, range=[x_min - timedelta(hours=1), None])

    if phase_title == "Social":
        fig = _new_figure()
        for patch_idx, patch in enumerate(patches):
            patch_df = df[df['patch_type'] == patch]
            subjects = patch_df.loc[patch_df['subject'] != 'mean', 'subject'].unique()
            _add_patch_traces(fig, patch_df, patch_idx, subjects)

        _shade_light_cycle(fig, social_start, social_end, include_last_day=False)
        _finalise_layout(
            fig,
            title=f"{phase_title} Phase: Pellet counts and patch thresholds",
            pellet_max=df['pellet_count'].max(),
            threshold_max=df['threshold'].max(),
            x_min=df['block_start'].min(),
        )
        return fig

    # Solo phases: one figure per subject, shaded over the matching phase window.
    if phase_title == "Pre social":
        first_day, last_day = pre_social_start, pre_social_end
    else:
        first_day, last_day = post_social_start, post_social_end

    figs = []
    for subject in subject_names:
        subject_data = df[df['subject'] == subject]

        # Keep only the 'mean' rows that fall inside this subject's own time span.
        min_block_start = subject_data['block_start'].min()
        max_block_start = subject_data['block_start'].max()
        mean_data = df[
            (df['subject'] == 'mean')
            & (df['block_start'] >= min_block_start)
            & (df['block_start'] <= max_block_start)
        ]
        subject_df = pd.concat([subject_data, mean_data])

        fig = _new_figure()
        for patch_idx, patch in enumerate(patches):
            _add_patch_traces(fig, subject_df[subject_df['patch_type'] == patch], patch_idx, [subject])

        _shade_light_cycle(fig, first_day, last_day, include_last_day=True)
        _finalise_layout(
            fig,
            title=f"{phase_title} Phase: Pellet counts and patch thresholds for {subject}",
            pellet_max=subject_data['pellet_count'].max(),
            threshold_max=subject_data['threshold'].max(),
            x_min=min_block_start,
        )
        figs.append(fig)

    return figs

# Build the figures: one combined figure for the social phase and a list with one
# figure per subject for the post-social phase (pre-social is currently disabled).
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_figs = create_phase_figure(post_social_df, "Post social")

# Render every figure inline and export it as a standalone HTML file.
#pre_social_fig.show()
#pio.write_html(pre_social_fig, file=os.path.join(fig_save_directory, "pre_social_threshold_fig.html"))

social_fig.show()
pio.write_html(
    social_fig,
    file=os.path.join(fig_save_directory, "social_threshold_fig_type.html"),
)

# The export file name carries the figure's position in the returned list.
for i, fig in enumerate(post_social_figs):
    fig.show()
    pio.write_html(
        fig,
        file=os.path.join(fig_save_directory, f"post_social_threshold_fig_type_{i}.html"),
    )

Variations/improvements:
- can use block number instead of block start for more equal spacing
- can remove low-pellet blocks; currently there is no foraging threshold
- should we leave the mean threshold line connected through empty blocks or not?
- could add easy/hard env info here too
- in this experiment there are obvious pellet assignment issues...

In [None]:
""" Pellet count only, per patch ID. """
# One frame per experiment phase, all taken from the extended block-level table.
pre_social_df, social_df, post_social_df = (
    min_subj_patch_info_plus_extended[
        min_subj_patch_info_plus_extended['experiment_phase'] == phase_name
    ]
    for phase_name in ('pre_social', 'social', 'post_social')
)
# First len(subject_names) entries of the qualitative palette, one per subject.
box_colors = subject_colors[:len(subject_names)]

# Subject -> colour lookup, pairing subjects with palette entries by position.
color_map = dict(zip(subject_names, box_colors))

# Function to create a figure with subplots for each patch within a phase
def create_phase_figure(df, phase_title):
    """Plot per-subject pellet counts over block start time, one panel per patch ID.

    Uses the notebook-level ``color_map`` (subject -> colour) for the line colours.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows for one experiment phase. Columns read here: ``patch``, ``subject``,
        ``block_start`` and ``pellet_count``. ``PatchDummy1`` rows and rows of the
        'mean' pseudo-subject are dropped before plotting.
    phase_title : str
        Phase name used in the figure title and in the error message.

    Returns
    -------
    plotly.graph_objs.Figure
        One subplot row per patch, sharing both axes.

    Raises
    ------
    ValueError
        If no patch is left after filtering.
    """
    df = df[df['patch'] != "PatchDummy1"]
    df = df[df['subject'] != "mean"]
    patches = df['patch'].unique()

    if len(patches) == 0:
        raise ValueError(f"No patches found for phase: {phase_title}")

    desired_order = ["Patch1", "Patch2", "Patch3"]
    patches = sorted(patches, key=lambda x: desired_order.index(x) if x in desired_order else len(desired_order))

    fig = make_subplots(
        rows=len(patches), cols=1, shared_yaxes=True, shared_xaxes=True,
        subplot_titles=[f"{patch}" for patch in patches],
        vertical_spacing=0.05
    )

    # Each subject gets exactly one legend entry, on the first panel where it has
    # data. The previous `i == len(patches) - 3` test only produced a legend when
    # there were exactly three patches.
    subjects_in_legend = set()
    for i, patch in enumerate(patches):
        patch_df = df[df['patch'] == patch]
        for subject in patch_df['subject'].unique():
            subject_df = patch_df[patch_df['subject'] == subject]
            fig.add_trace(
                go.Scatter(
                    x=subject_df['block_start'],
                    y=subject_df['pellet_count'],
                    mode='lines',
                    name=subject,
                    line=dict(color=color_map[subject]),
                    fill='tozeroy',
                    # Group a subject's traces so its single legend entry toggles
                    # the subject in every panel.
                    legendgroup=subject,
                    showlegend=subject not in subjects_in_legend
                ),
                row=i+1, col=1
            )
            subjects_in_legend.add(subject)

    fig.update_layout(
        title=f"{phase_title} Phase: Pellet Counts",
        legend_title_text='Subject',
        height=900
    )
    # Force every panel onto the same y scale so patches are directly comparable.
    fig.update_yaxes(title_text="Pellet Count", matches='y')
    fig.update_xaxes(title_text="Block start time", row=len(patches), col=1)

    return fig

# Build one pellet-count figure per phase (pre-social is currently disabled).
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post social")

# Render every figure inline and export it as a standalone HTML file.
#pre_social_fig.show()
# pio.write_html(pre_social_fig, file=os.path.join(fig_save_directory, "pre_social_pellet_count_fig.html"))

social_fig.show()
pio.write_html(
    social_fig,
    file=os.path.join(fig_save_directory, "social_pellet_count_fig.html"),
)

post_social_fig.show()
pio.write_html(
    post_social_fig,
    file=os.path.join(fig_save_directory, "post_social_pellet_count_fig.html"),
)

In [None]:
"""Pellet count only, per patch type."""
# Split the per-block table by experiment phase.
pre_social_df = min_subj_patch_info_plus_extended[
    min_subj_patch_info_plus_extended['experiment_phase'].eq('pre_social')
]
social_df = min_subj_patch_info_plus_extended[
    min_subj_patch_info_plus_extended['experiment_phase'].eq('social')
]
post_social_df = min_subj_patch_info_plus_extended[
    min_subj_patch_info_plus_extended['experiment_phase'].eq('post_social')
]
# One Plotly colour per subject, taken in `subject_names` order.
box_colors = subject_colors[:len(subject_names)]

# Subject -> colour lookup used when drawing each subject's trace.
color_map = dict(zip(subject_names, box_colors))

# Function to create a figure with subplots for each patch within a phase
def create_phase_figure(df, phase_title):
    """Plot per-subject pellet counts against block start, one subplot row per patch type.

    Rows belonging to "PatchDummy1", to the "equal" patch type or to the
    aggregate "mean" subject are dropped before plotting.

    Args:
        df: DataFrame with 'patch', 'patch_type', 'subject', 'block_start' and
            'pellet_count' columns.
        phase_title: Phase name used in the figure title and in the error message.

    Returns:
        A Plotly figure with one filled line per subject in each patch-type row.

    Raises:
        ValueError: If no patch types remain after filtering.
    """
    keep = (
        (df['patch'] != "PatchDummy1")
        & (df['patch_type'] != "equal")
        & (df['subject'] != "mean")
    )
    df = df[keep]
    patches = df['patch_type'].unique()

    if len(patches) == 0:
        raise ValueError(f"No patches found for phase: {phase_title}")

    # Known difficulty levels first, in a fixed order; anything else follows.
    desired_order = ["easy", "medium", "hard"]
    patches = sorted(patches, key=lambda x: desired_order.index(x) if x in desired_order else len(desired_order))

    fig = make_subplots(
        rows=len(patches), cols=1, shared_yaxes=True, shared_xaxes=True,
        subplot_titles=[f"{patch}" for patch in patches],
        vertical_spacing=0.05
    )

    # Each subject gets exactly one legend entry (the first trace drawn for it),
    # whatever the number of rows. The previous `i == len(patches) - 3` test only
    # worked when there were exactly three rows.
    subjects_in_legend = set()
    for row, patch in enumerate(patches, start=1):
        patch_df = df[df['patch_type'] == patch]
        for subject in patch_df['subject'].unique():
            # Lines connect points in row order, so sort by time to avoid zig-zags.
            subject_df = patch_df[patch_df['subject'] == subject].sort_values('block_start')
            fig.add_trace(
                go.Scatter(
                    x=subject_df['block_start'],
                    y=subject_df['pellet_count'],
                    mode='lines',
                    name=subject,
                    legendgroup=subject,  # one legend click toggles the subject in every row
                    line=dict(color=color_map[subject]),
                    fill='tozeroy',
                    showlegend=subject not in subjects_in_legend
                ),
                row=row, col=1
            )
            subjects_in_legend.add(subject)

    fig.update_layout(
        title=f"{phase_title} Phase: Pellet Counts",
        legend_title_text='Subject',
        height=900
    )
    fig.update_yaxes(title_text="Pellet Count", matches='y')
    # Only the bottom row carries the x-axis title.
    fig.update_xaxes(title_text="Block start time", row=len(patches), col=1)

    return fig

# Build one figure per phase (the pre-social phase is currently disabled).
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post social")

# Show the figures and write them to HTML files. File names carry a "_type"
# suffix so they do not overwrite the per-patch exports.
#pre_social_fig.show()
# pio.write_html(pre_social_fig, file=os.path.join(fig_save_directory, "pre_social_pellet_count_type_fig.html"))

social_fig.show()
# File name fixed: it previously had a doubled underscore ("count__type").
pio.write_html(social_fig, file=os.path.join(fig_save_directory, "social_pellet_count_type_fig.html"))

post_social_fig.show()
pio.write_html(post_social_fig, file=os.path.join(fig_save_directory, "post_social_pellet_count_type_fig.html"))

### 2. Final patch preference

In [None]:
"""Late preference per patch ID, with smoothed bars."""
# Split the preference table by experiment phase.
pre_social_df = patch_pref[patch_pref['experiment_phase'].eq('pre_social')]
social_df = patch_pref[patch_pref['experiment_phase'].eq('social')]
post_social_df = patch_pref[patch_pref['experiment_phase'].eq('post_social')]
# One colour per patch, limited to the first three patch names.
box_colors = patch_colors[:len(patch_names[:3])]

# Function to create subplots for a phase
def create_phase_subplots(df, phase_title):
    """Plot late patch preference per block as markers plus smoothed bars.

    Builds a grid with one row per environment type ('easy' / 'hard') and one
    column per subject. Rows for "PatchDummy1" and for any other environment
    type are dropped. The input frame is not modified.

    Args:
        df: DataFrame with 'patch', 'env_type', 'subject', 'block_start' and
            'late_preference_by_wheel' columns.
        phase_title: Phase name for the figure title; x-axes are shared only
            when it equals "Social".

    Returns:
        A Plotly figure.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    separation_strength = 700  # x-offset between consecutive subjects, in seconds
    jitter_strength = 300  # maximum random x-jitter, in seconds

    # `isin` already excludes missing / "nan" environment types.
    keep = (df['patch'] != "PatchDummy1") & df['env_type'].isin(['easy', 'hard'])
    df = df[keep].copy()  # copy: the helper column must not touch the caller's frame
    if df.empty:
        raise ValueError(f"No data found for phase: {phase_title}")

    unique_subjects = df['subject'].unique()
    present_env_types = set(df['env_type'])
    unique_env_types = [env for env in ('easy', 'hard') if env in present_env_types]
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(unique_subjects)}

    # Shift each subject by a fixed offset plus random jitter so that markers of
    # different subjects in the same block do not sit on top of each other.
    shift_seconds = (
        df['subject'].map(subject_offsets).astype(float).to_numpy()
        + np.random.uniform(-jitter_strength, jitter_strength, size=len(df))
    )
    df['block_start_separated'] = pd.to_datetime(df['block_start']) + pd.to_timedelta(shift_seconds, unit='s')

    # Fixed patch -> colour / symbol lookup so a patch looks the same in every
    # subplot and its bars match its markers.
    patch_order = sorted(df['patch'].dropna().unique())
    patch_color_map = {patch: box_colors[k % len(box_colors)] for k, patch in enumerate(patch_order)}
    marker_symbols = ["circle", "diamond", "square", "x", "cross"]
    patch_symbol_map = {patch: marker_symbols[k % len(marker_symbols)] for k, patch in enumerate(patch_order)}

    # Shared y-axes always; shared x-axes only for the social phase.
    fig = make_subplots(
        rows=len(unique_env_types),
        cols=len(unique_subjects),
        subplot_titles=[f"{env_type} environment - {subject}" for env_type in unique_env_types for subject in unique_subjects],
        shared_xaxes=(phase_title == "Social"),
        shared_yaxes=True,
        vertical_spacing=0.1,
        horizontal_spacing=0.05
    )

    patches_in_legend = set()
    for i, env_type in enumerate(unique_env_types):
        for j, subject in enumerate(unique_subjects):
            filtered_df = df[(df['subject'] == subject) & (df['env_type'] == env_type)]
            if filtered_df.empty:
                continue
            scatter = px.scatter(
                filtered_df.sort_values("block_start"),
                x="block_start_separated",
                y="late_preference_by_wheel",
                color="patch",
                symbol="patch",
                color_discrete_map=patch_color_map,
                color_discrete_sequence=box_colors,
                symbol_map=patch_symbol_map,
                labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
            )
            for trace in scatter.data:
                # One legend entry per patch, taken from the first subplot that has it.
                trace.showlegend = trace.name not in patches_in_legend
                patches_in_legend.add(trace.name)
                fig.add_trace(trace, row=i+1, col=j+1)

            # Smoothed (centred 5-block rolling mean) bars under the markers.
            for patch in filtered_df['patch'].dropna().unique():
                patch_df = filtered_df[filtered_df['patch'] == patch].sort_values("block_start_separated")
                smoothed = patch_df['late_preference_by_wheel'].rolling(window=5, min_periods=1, center=True).mean()
                patch_color = patch_color_map[patch]
                fig.add_trace(
                    go.Bar(
                        x=patch_df['block_start_separated'],
                        y=smoothed,
                        name=f'Smoothed {patch}',
                        legendgroup=str(patch),
                        marker=dict(
                            color=patch_color,
                            line=dict(color=patch_color, width=6)
                        ),
                        # NOTE(review): on a date axis this width is in milliseconds,
                        # so 540000 is 9 minutes, not the "3 hours" the previous
                        # comment claimed. Value left unchanged - confirm intent.
                        width=540000,
                        opacity=0.5,
                        showlegend=False
                    ),
                    row=i+1, col=j+1
                )

    # Resize markers on the scatter traces only (bars are left untouched).
    fig.update_traces(marker=dict(size=6), selector=dict(type='scatter'))

    fig.update_layout(
        height=300 * len(unique_env_types),
        title_text=f"{phase_title} Phase: Late preference by wheel",
        legend_title_text='Patch'
    )

    return fig

# Create subplots for each phase (the pre-social phase is currently disabled)
#pre_social_fig = create_phase_subplots(pre_social_df, "Pre social")
social_fig = create_phase_subplots(social_df, "Social")
post_social_fig = create_phase_subplots(post_social_df, "Post social")

# Show the subplots; HTML export is left disabled for this bar variant
#pre_social_fig.show()
social_fig.show()
#pio.write_html(social_fig, file=os.path.join(fig_save_directory, "social_pref_patch_fig.html"))

post_social_fig.show()
#pio.write_html(post_social_fig, file=os.path.join(fig_save_directory, "post_social_pref_patch_fig.html"))

In [None]:
"""Late preference per patch ID, with smoothed filled lines."""
# Split the preference table by experiment phase.
pre_social_df = patch_pref[patch_pref['experiment_phase'].eq('pre_social')]
social_df = patch_pref[patch_pref['experiment_phase'].eq('social')]
post_social_df = patch_pref[patch_pref['experiment_phase'].eq('post_social')]
# One colour per patch, limited to the first three patch names.
box_colors = patch_colors[:len(patch_names[:3])]

# Function to create subplots for a phase
def create_phase_subplots(df, phase_title):
    """Plot late patch preference per block as markers plus smoothed filled lines.

    Builds a grid with one row per environment type ('easy' / 'hard') and one
    column per subject. Rows for "PatchDummy1" and for any other environment
    type are dropped. The smoothed line is broken wherever consecutive blocks
    are a day or more apart. The input frame is not modified.

    Args:
        df: DataFrame with 'patch', 'env_type', 'subject', 'block_start' and
            'late_preference_by_wheel' columns.
        phase_title: Phase name for the figure title; x-axes are shared only
            when it equals "Social".

    Returns:
        A Plotly figure.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    separation_strength = 700  # x-offset between consecutive subjects, in seconds
    jitter_strength = 300  # maximum random x-jitter, in seconds
    max_gap_seconds = 86400  # a gap of 24 h or more starts a new line segment

    # `isin` already excludes missing / "nan" environment types.
    keep = (df['patch'] != "PatchDummy1") & df['env_type'].isin(['easy', 'hard'])
    df = df[keep].copy()  # copy: the helper column must not touch the caller's frame
    if df.empty:
        raise ValueError(f"No data found for phase: {phase_title}")

    unique_subjects = df['subject'].unique()
    present_env_types = set(df['env_type'])
    unique_env_types = [env for env in ('easy', 'hard') if env in present_env_types]
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(unique_subjects)}

    # Shift each subject by a fixed offset plus random jitter so that markers of
    # different subjects in the same block do not sit on top of each other.
    shift_seconds = (
        df['subject'].map(subject_offsets).astype(float).to_numpy()
        + np.random.uniform(-jitter_strength, jitter_strength, size=len(df))
    )
    df['block_start_separated'] = pd.to_datetime(df['block_start']) + pd.to_timedelta(shift_seconds, unit='s')

    # Fixed patch -> colour / symbol lookup so a patch looks the same in every
    # subplot and its smoothed line matches its markers.
    patch_order = sorted(df['patch'].dropna().unique())
    patch_color_map = {patch: box_colors[k % len(box_colors)] for k, patch in enumerate(patch_order)}
    marker_symbols = ["circle", "diamond", "square", "x", "cross"]
    patch_symbol_map = {patch: marker_symbols[k % len(marker_symbols)] for k, patch in enumerate(patch_order)}

    # Shared y-axes always; shared x-axes only for the social phase.
    fig = make_subplots(
        rows=len(unique_env_types),
        cols=len(unique_subjects),
        subplot_titles=[f"{env_type} environment - {subject}" for env_type in unique_env_types for subject in unique_subjects],
        shared_xaxes=(phase_title == "Social"),
        shared_yaxes=True,
        vertical_spacing=0.1,
        horizontal_spacing=0.05
    )

    patches_in_legend = set()
    for i, env_type in enumerate(unique_env_types):
        for j, subject in enumerate(unique_subjects):
            filtered_df = df[(df['subject'] == subject) & (df['env_type'] == env_type)]
            if filtered_df.empty:
                continue
            scatter = px.scatter(
                filtered_df.sort_values("block_start"),
                x="block_start_separated",
                y="late_preference_by_wheel",
                color="patch",
                symbol="patch",
                color_discrete_map=patch_color_map,
                color_discrete_sequence=box_colors,
                symbol_map=patch_symbol_map,
                labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
            )
            for trace in scatter.data:
                # One legend entry per patch, taken from the first subplot that has it.
                trace.showlegend = trace.name not in patches_in_legend
                patches_in_legend.add(trace.name)
                fig.add_trace(trace, row=i+1, col=j+1)

            # Smoothed (centred 5-block rolling mean) line with shading underneath.
            for patch in filtered_df['patch'].dropna().unique():
                patch_df = filtered_df[filtered_df['patch'] == patch].sort_values("block_start_separated")
                patch_df = patch_df.assign(
                    smoothed_preference=patch_df['late_preference_by_wheel'].rolling(window=5, min_periods=1, center=True).mean()
                )

                # Start a new segment at every gap of `max_gap_seconds` or more, so
                # the line is not drawn across long pauses. The first row's gap is
                # NaN, which compares False and therefore opens segment 0.
                gap_seconds = patch_df['block_start_separated'].diff().dt.total_seconds()
                segment_ids = (gap_seconds >= max_gap_seconds).cumsum().to_numpy()

                for _, segment in patch_df.groupby(segment_ids):
                    fig.add_trace(
                        go.Scatter(
                            x=segment['block_start_separated'],
                            y=segment['smoothed_preference'],
                            mode='lines',
                            name=f'Smoothed {patch}',
                            legendgroup=str(patch),
                            line=dict(color=patch_color_map[patch]),
                            fill='tozeroy',
                            showlegend=False
                        ),
                        row=i+1, col=j+1
                    )

    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=300 * len(unique_env_types),
        title_text=f"{phase_title} Phase: Late preference by wheel",
        legend_title_text='Patch'
    )

    return fig

# Build one figure per phase (the pre-social phase is currently disabled).
# pre_social_fig = create_phase_subplots(pre_social_df, "Pre social")
social_fig = create_phase_subplots(social_df, "Social")
post_social_fig = create_phase_subplots(post_social_df, "Post social")

# Display each figure inline and export it as a standalone HTML file.
# pre_social_fig.show()
social_fig.show()
social_fig.write_html(os.path.join(fig_save_directory, "social_pref_patch_fig.html"))

post_social_fig.show()
post_social_fig.write_html(os.path.join(fig_save_directory, "post_social_pref_patch_fig.html"))

In [None]:
"""Late preference per patch type."""
# Split the preference table by experiment phase.
pre_social_df = patch_pref[patch_pref['experiment_phase'].eq('pre_social')]
social_df = patch_pref[patch_pref['experiment_phase'].eq('social')]
post_social_df = patch_pref[patch_pref['experiment_phase'].eq('post_social')]
# Up to three colours from the patch palette (one per plotted category).
box_colors = patch_colors[:len(patch_names[:3])]

# Function to create subplots for a phase
def create_phase_subplots(df, phase_title):
    """Plot late preference per block by patch type, as markers plus smoothed filled lines.

    Builds a grid with one row per environment type ('easy' / 'hard') and one
    column per subject. Rows for the "equal" patch type, for "PatchDummy1" and
    for any other environment type are dropped. The smoothed line is broken
    wherever consecutive blocks are a day or more apart. The input frame is
    not modified.

    Args:
        df: DataFrame with 'patch', 'patch_type', 'env_type', 'subject',
            'block_start' and 'late_preference_by_wheel' columns.
        phase_title: Phase name for the figure title; x-axes are shared only
            when it equals "Social".

    Returns:
        A Plotly figure.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    separation_strength = 700  # x-offset between consecutive subjects, in seconds
    jitter_strength = 300  # maximum random x-jitter, in seconds
    max_gap_seconds = 86400  # a gap of 24 h or more starts a new line segment

    # `isin` already excludes missing / "nan" environment types.
    keep = (
        (df['patch_type'] != "equal")
        & (df['patch'] != "PatchDummy1")
        & df['env_type'].isin(['easy', 'hard'])
    )
    df = df[keep].copy()  # copy: the helper column must not touch the caller's frame
    if df.empty:
        raise ValueError(f"No data found for phase: {phase_title}")

    unique_subjects = df['subject'].unique()
    present_env_types = set(df['env_type'])
    unique_env_types = [env for env in ('easy', 'hard') if env in present_env_types]
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(unique_subjects)}

    # Shift each subject by a fixed offset plus random jitter so that markers of
    # different subjects in the same block do not sit on top of each other.
    shift_seconds = (
        df['subject'].map(subject_offsets).astype(float).to_numpy()
        + np.random.uniform(-jitter_strength, jitter_strength, size=len(df))
    )
    df['block_start_separated'] = pd.to_datetime(df['block_start']) + pd.to_timedelta(shift_seconds, unit='s')

    # Fixed patch-type -> colour / symbol lookup (easy, medium, hard first) so a
    # type looks the same in every subplot and its line matches its markers.
    type_rank = {"easy": 0, "medium": 1, "hard": 2}
    type_order = sorted(
        df['patch_type'].dropna().unique(),
        key=lambda t: (type_rank.get(t, len(type_rank)), str(t)),
    )
    type_color_map = {ptype: box_colors[k % len(box_colors)] for k, ptype in enumerate(type_order)}
    marker_symbols = ["circle", "diamond", "square", "x", "cross"]
    type_symbol_map = {ptype: marker_symbols[k % len(marker_symbols)] for k, ptype in enumerate(type_order)}

    # Shared y-axes always; shared x-axes only for the social phase.
    fig = make_subplots(
        rows=len(unique_env_types),
        cols=len(unique_subjects),
        subplot_titles=[f"{env_type} environment - {subject}" for env_type in unique_env_types for subject in unique_subjects],
        shared_xaxes=(phase_title == "Social"),
        shared_yaxes=True,
        vertical_spacing=0.1,
        horizontal_spacing=0.05
    )

    types_in_legend = set()
    for i, env_type in enumerate(unique_env_types):
        for j, subject in enumerate(unique_subjects):
            filtered_df = df[(df['subject'] == subject) & (df['env_type'] == env_type)]
            if filtered_df.empty:
                continue
            scatter = px.scatter(
                filtered_df.sort_values("block_start"),
                x="block_start_separated",
                y="late_preference_by_wheel",
                color="patch_type",
                symbol="patch_type",
                color_discrete_map=type_color_map,
                color_discrete_sequence=box_colors,
                symbol_map=type_symbol_map,
                labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
            )
            for trace in scatter.data:
                # One legend entry per patch type, taken from the first subplot that has it.
                trace.showlegend = trace.name not in types_in_legend
                types_in_legend.add(trace.name)
                fig.add_trace(trace, row=i+1, col=j+1)

            # Smoothed (centred 5-block rolling mean) line with shading underneath.
            for patch_type in filtered_df['patch_type'].dropna().unique():
                patch_df = filtered_df[filtered_df['patch_type'] == patch_type].sort_values("block_start_separated")
                patch_df = patch_df.assign(
                    smoothed_preference=patch_df['late_preference_by_wheel'].rolling(window=5, min_periods=1, center=True).mean()
                )

                # Start a new segment at every gap of `max_gap_seconds` or more, so
                # the line is not drawn across long pauses. The first row's gap is
                # NaN, which compares False and therefore opens segment 0.
                gap_seconds = patch_df['block_start_separated'].diff().dt.total_seconds()
                segment_ids = (gap_seconds >= max_gap_seconds).cumsum().to_numpy()

                for _, segment in patch_df.groupby(segment_ids):
                    fig.add_trace(
                        go.Scatter(
                            x=segment['block_start_separated'],
                            y=segment['smoothed_preference'],
                            mode='lines',
                            name=f'Smoothed {patch_type}',
                            legendgroup=str(patch_type),
                            line=dict(color=type_color_map[patch_type]),
                            fill='tozeroy',
                            showlegend=False
                        ),
                        row=i+1, col=j+1
                    )

    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=300 * len(unique_env_types),
        title_text=f"{phase_title} Phase: Late preference by wheel",
        legend_title_text='Patch type'
    )

    return fig

# Build one figure per phase (the pre-social phase is currently disabled).
#pre_social_fig = create_phase_subplots(pre_social_df, "Pre social")
social_fig = create_phase_subplots(social_df, "Social")
post_social_fig = create_phase_subplots(post_social_df, "Post social")

# Show the subplots and export them as "<phase>_pref_type_fig.html".
#pre_social_fig.show()
social_fig.show()
pio.write_html(social_fig, file=os.path.join(fig_save_directory, "social_pref_type_fig.html"))

post_social_fig.show()
# File name fixed: it was "post_social_type_patch_fig.html", which did not
# follow the naming of the social-phase export above.
pio.write_html(post_social_fig, file=os.path.join(fig_save_directory, "post_social_pref_type_fig.html"))

In [None]:
"""Same, but with easy and hard environments on the same plot."""
# Split the preference table by experiment phase.
pre_social_df = patch_pref[patch_pref['experiment_phase'].eq('pre_social')]
social_df = patch_pref[patch_pref['experiment_phase'].eq('social')]
post_social_df = patch_pref[patch_pref['experiment_phase'].eq('post_social')]
# One colour per patch, limited to the first three patch names.
box_colors = patch_colors[:len(patch_names[:3])]


# Function to convert rgb to hex
def rgb_to_hex(rgb):
    """Convert a CSS-style 'rgb(r, g, b)' string to a lowercase '#rrggbb' string.

    Whitespace around the numbers is tolerated.

    Args:
        rgb: Colour string such as 'rgb(27,158,119)'.

    Returns:
        The colour as a 7-character hex string, e.g. '#1b9e77'.

    Raises:
        ValueError: If the string is not an 'rgb(...)' triple or a channel is
            outside 0-255.
    """
    match = re.match(r'\s*rgb\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)', rgb)
    if match is None:
        raise ValueError(f"Invalid rgb color format: {rgb}")
    channels = [int(value) for value in match.groups()]
    # A channel above 255 would format to three hex digits and yield a malformed colour.
    if any(channel > 255 for channel in channels):
        raise ValueError(f"rgb channel out of range 0-255: {rgb}")
    return '#' + ''.join(f'{channel:02x}' for channel in channels)

# Function to adjust color brightness
def adjust_color_brightness(color, factor):
    """Scale the HLS lightness of a colour and return it as '#rrggbb'.

    Args:
        color: Either a '#rrggbb' hex string or an 'rgb(r, g, b)' string.
        factor: Multiplier applied to the lightness; the result is clamped to
            [0, 1]. Values above 1 lighten, values below 1 darken.

    Returns:
        The adjusted colour as a lowercase 7-character hex string.

    Raises:
        ValueError: If the colour is not a '#rrggbb' string (after any rgb
            conversion) or contains non-hex digits.
    """
    if color.startswith('rgb'):
        color = rgb_to_hex(color)
    if not color.startswith('#') or len(color) != 7:
        raise ValueError(f"Invalid color format: {color}")
    red, green, blue = (int(color[i:i+2], 16) for i in (1, 3, 5))
    hue, lightness, saturation = colorsys.rgb_to_hls(red / 255.0, green / 255.0, blue / 255.0)
    lightness = max(0, min(1, lightness * factor))
    adjusted = colorsys.hls_to_rgb(hue, lightness, saturation)
    # round() rather than int(): truncation could turn e.g. 26.9999 into 26, so
    # even factor=1.0 could darken a channel by one step. The clamp guards
    # against tiny floating-point overshoot outside 0-255.
    return '#' + ''.join(f'{min(255, max(0, round(channel * 255))):02x}' for channel in adjusted)

# Create lighter and darker hues for 'easy' and 'hard' environments
# Lighter hues are used for the 'easy' environment, darker hues for 'hard'.
box_colors_light = [adjust_color_brightness(base_color, 1.1) for base_color in box_colors]
box_colors_dark = [adjust_color_brightness(base_color, 0.85) for base_color in box_colors]

# Function to create subplots for a phase
def create_phase_subplots(df, phase_title):
    """Plot late patch preference per block, one row per subject, both environments overlaid.

    'easy' blocks use the light palette and 'hard' blocks the dark palette; a
    given patch keeps the same base hue and marker symbol in both. Rows for
    "PatchDummy1" and for any other environment type are dropped. The input
    frame is not modified.

    Args:
        df: DataFrame with 'patch', 'env_type', 'subject', 'block_start' and
            'late_preference_by_wheel' columns.
        phase_title: Phase name for the figure title; x-axes are shared only
            when it equals "Social".

    Returns:
        A Plotly figure.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    separation_strength = 700  # x-offset between consecutive subjects, in seconds
    jitter_strength = 300  # maximum random x-jitter, in seconds

    # `isin` already excludes missing / "nan" environment types.
    keep = (df['patch'] != "PatchDummy1") & df['env_type'].isin(['easy', 'hard'])
    df = df[keep].copy()  # copy: the helper column must not touch the caller's frame
    if df.empty:
        raise ValueError(f"No data found for phase: {phase_title}")

    unique_subjects = df['subject'].unique()
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(unique_subjects)}

    # Shift each subject by a fixed offset plus random jitter so that markers of
    # different subjects in the same block do not sit on top of each other.
    shift_seconds = (
        df['subject'].map(subject_offsets).astype(float).to_numpy()
        + np.random.uniform(-jitter_strength, jitter_strength, size=len(df))
    )
    df['block_start_separated'] = pd.to_datetime(df['block_start']) + pd.to_timedelta(shift_seconds, unit='s')

    # Fixed patch -> colour lookups (one per environment) and a shared symbol
    # lookup, so a patch does not change hue between environments or subjects.
    patch_order = sorted(df['patch'].dropna().unique())
    env_palettes = {'easy': box_colors_light, 'hard': box_colors_dark}
    env_color_maps = {
        env: {patch: palette[k % len(palette)] for k, patch in enumerate(patch_order)}
        for env, palette in env_palettes.items()
    }
    marker_symbols = ["circle", "diamond", "square", "x", "cross"]
    patch_symbol_map = {patch: marker_symbols[k % len(marker_symbols)] for k, patch in enumerate(patch_order)}

    # Shared y-axes always; shared x-axes only for the social phase.
    fig = make_subplots(
        rows=len(unique_subjects),
        cols=1,
        subplot_titles=[f"{subject}" for subject in unique_subjects],
        shared_xaxes=(phase_title == "Social"),
        shared_yaxes=True,
        vertical_spacing=0.1
    )

    legend_entries = set()
    for i, subject in enumerate(unique_subjects):
        for env_type, color_set in env_palettes.items():
            filtered_df = df[(df['subject'] == subject) & (df['env_type'] == env_type)]
            if filtered_df.empty:
                continue
            scatter = px.scatter(
                filtered_df.sort_values("block_start"),
                x="block_start_separated",
                y="late_preference_by_wheel",
                color="patch",
                symbol="patch",
                color_discrete_map=env_color_maps[env_type],
                color_discrete_sequence=color_set,
                symbol_map=patch_symbol_map,
                labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
            )
            for trace in scatter.data:
                # Tag the environment so light and dark entries can be told apart
                # in the legend; show each patch/environment pair once.
                trace.name = f"{trace.name} ({env_type})"
                trace.legendgroup = trace.name
                trace.showlegend = trace.name not in legend_entries
                legend_entries.add(trace.name)
                fig.add_trace(trace, row=i+1, col=1)

    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=400 * len(unique_subjects),
        title_text=f"{phase_title} Phase: Late preference by wheel",
        legend_title_text='Patch (environment)'
    )
    fig.update_yaxes(title_text="Late preference by wheel")
    return fig

# Create subplots for each phase (the pre-social phase is currently disabled)
#pre_social_fig = create_phase_subplots(pre_social_df, "Pre social")
social_fig = create_phase_subplots(social_df, "Social")
post_social_fig = create_phase_subplots(post_social_df, "Post social")

# Show the subplots; saving is disabled. NOTE: the commented-out file names below
# are the same ones the per-patch line figures are written to, so re-enabling
# them as-is would overwrite those exports.
#pre_social_fig.show()
social_fig.show()
#pio.write_html(social_fig, file=os.path.join(fig_save_directory, "social_pref_patch_fig.html"))

post_social_fig.show()
#pio.write_html(post_social_fig, file=os.path.join(fig_save_directory, "post_social_pref_patch_fig.html"))

### 3. Patch preference change

In [None]:
"""Plot preference change by patch type (easy, medium, hard), per block start time, grouped by subject."""
# NOTE: do we need this plot? Is it informative or not?
# TODO: make this better - separate subjects and add shaded lines
# Split the preference table by experiment phase, dropping the dummy patch.
pre_social_df = patch_pref[
    patch_pref['experiment_phase'].eq('pre_social') & patch_pref['patch'].ne('PatchDummy1')
]
social_df = patch_pref[
    patch_pref['experiment_phase'].eq('social') & patch_pref['patch'].ne('PatchDummy1')
]
post_social_df = patch_pref[
    patch_pref['experiment_phase'].eq('post_social') & patch_pref['patch'].ne('PatchDummy1')
]
# One Plotly colour per subject, taken in `subject_names` order.
box_colors = subject_colors[:len(subject_names)]

# Function to create a single figure for a phase
def create_phase_figure(df, phase_title):
    """Scatter the late-minus-early preference change per block for one phase.

    Points are coloured by subject and given a marker symbol by patch type.
    The input frame is not modified.

    Args:
        df: DataFrame with 'subject', 'patch_type', 'block_start' and
            'preference_change_by_wheel' columns.
        phase_title: Phase name used in the figure title.

    Returns:
        A Plotly Express scatter figure with the y-axis fixed to [-1, 1].
    """
    separation_strength = 700  # x-offset between consecutive subjects, in seconds
    jitter_strength = 300  # maximum random x-jitter, in seconds

    # Work on a copy: the caller's frame is itself a filtered slice, and adding
    # a column to it in place raised SettingWithCopyWarning and mutated it.
    df = df.copy()
    unique_subjects = df['subject'].unique()
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(unique_subjects)}

    # Shift each subject by a fixed offset plus random jitter so that markers of
    # different subjects in the same block do not sit on top of each other.
    shift_seconds = (
        df['subject'].map(subject_offsets).astype(float).to_numpy()
        + np.random.uniform(-jitter_strength, jitter_strength, size=len(df))
    )
    df['block_start_separated'] = pd.to_datetime(df['block_start']) + pd.to_timedelta(shift_seconds, unit='s')

    fig = px.scatter(
        df.sort_values("block_start"),
        x="block_start_separated",
        y="preference_change_by_wheel",
        color="subject",
        symbol="patch_type",
        # Pin subject colours to the `subject_names` order so they do not depend
        # on which subject appears first; other subjects fall back to the sequence.
        color_discrete_map=dict(zip(subject_names, box_colors)),
        color_discrete_sequence=box_colors,
        title=f"{phase_title} Phase: Late-early preference by wheel",
        labels={"preference_change_by_wheel": "Late-early preference by wheel", "block_start_separated": "Block Start"}
    )
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=400,
        legend_title_text='Subject and Patch Type',
        yaxis=dict(range=[-1, 1])
    )

    return fig

# Create figures for each phase (the pre-social phase is currently disabled)
#pre_social_fig = create_phase_figure(pre_social_df, "Pre Solo")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post Solo")

# Show the figures inline (nothing is written to disk here)
#pre_social_fig.show()
social_fig.show()
post_social_fig.show()

In [32]:
# TODO: could add subjective value plots as well as preference here