In [None]:
"""Notebook settings and imports"""

%load_ext autoreload
%autoreload 2
# %flow mode reactive

from collections import defaultdict

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots

import datajoint as dj
from aeon.dj_pipeline.analysis.block_analysis import *
from aeon.dj_pipeline import acquisition, streams
from aeon.analysis.block_plotting import gen_hex_grad, conv2d

## Set vars and fetch data that will be shared across multiple plots

In [None]:
"""Set plot styles."""

# Subjects are coloured from plotly's default qualitative palette.
subject_colors = plotly.colors.qualitative.Plotly
# @NOTE: Really, we shouldn't have to explicitly set colors for each subject.
subject_colors_dict = dict(
    zip(["BAA-1104516", "BAA-1104519", "BAA-1104568", "BAA-1104569"], subject_colors)
)
# Patches: colour palette, plotly marker names, and the equivalent unicode glyph for each marker.
patch_colors = plotly.colors.qualitative.Dark2
patch_markers = ["circle", "bowtie", "square", "hourglass", "diamond", "cross", "x", "triangle", "star"]
patch_markers_symbols = ["●", "⧓", "■", "⧗", "♦", "✖", "×", "▲", "★"]
patch_markers_dict = dict(zip(patch_markers, patch_markers_symbols))
patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"]

In [None]:
"""Set experiment key, stages"""

key = {"experiment_name": "social0.3-aeon4"}

# Experiment phase boundaries (midnight of each date).
# TODO: get these from DataJoint instead of hard-coding them; these values are only valid for aeon4.
pre_social_start = datetime(2024, 6, 8)
pre_social_end = datetime(2024, 6, 17)
social_start = datetime(2024, 6, 19)
social_end = datetime(2024, 7, 4)
post_social_start = datetime(2024, 7, 4)
post_social_end = datetime(2024, 7, 13)

# `limit=1` on its own returns whichever row the database yields first; order by block_start so these
# really are the earliest matching blocks.
exp_start = (Block & key & f"block_start >= '{pre_social_start}'").fetch(
    "block_start", order_by="block_start", limit=1
)[0]
# NOTE(review): this is the end of the first block starting at/after `post_social_end`, so blocks that start
# between post_social_end and exp_end are fetched downstream but fall outside every phase window.
exp_end = (Block & key & f"block_start >= '{post_social_end}'").fetch(
    "block_end", order_by="block_start", limit=1
)[0]


In [None]:
"""Print experiment patch names and subject names"""
one_day_into_social = "2024-06-20 10:00:00"
# Fetch start and end of the earliest block after `one_day_into_social` in ONE ordered query, so both
# values are guaranteed to come from the same block row.
block_starts, block_ends = (Block & key & f"block_start >= '{one_day_into_social}'").fetch(
    "block_start", "block_end", order_by="block_start", limit=1
)
block_start, block_end = block_starts[0], block_ends[0]
one_block_key = key | {"block_start": str(pd.Timestamp(block_start)), "block_end": str(pd.Timestamp(block_end))}
chunk_restriction = acquisition.create_chunk_restriction(key["experiment_name"], block_start, block_end)
subj_patch_info_one_block = (BlockSubjectAnalysis.Patch & one_block_key).fetch(format="frame")
# Subject and patch names as seen in this example block.
subject_names = list(subj_patch_info_one_block.index.get_level_values("subject_name").unique())
patch_names = list(subj_patch_info_one_block.index.get_level_values("patch_name").unique())
print(patch_names, subject_names)

In [None]:
"""Get foraging blocks."""
# NOTE: pre-social (solo) data seems to be missing?

# A block counts as a foraging block when all of its patches together delivered at least this many pellets.
MIN_FORAGING_PELLETS = 10

all_blocks = (
    BlockAnalysis.Patch
    & key
    & f'block_start >= "{exp_start}"'
    & f'block_start <= "{exp_end}"'
).fetch("block_start", "pellet_count", as_dict=True)

# Total pellets per block, summed over patches.
pellet_sums = defaultdict(int)
for entry in all_blocks:
    pellet_sums[entry["block_start"]] += entry["pellet_count"]
# Restriction list of dicts in chronological order (building it from a set gave a run-dependent order).
foraging_block_starts = [
    {"block_start": block_start}
    for block_start, total in sorted(pellet_sums.items())
    if total >= MIN_FORAGING_PELLETS
]
# An empty restriction list silently empties every query below; fail here instead.
assert foraging_block_starts, "No foraging blocks found between exp_start and exp_end."
print(len(foraging_block_starts))

patch_info = (
    BlockAnalysis.Patch
    & key
    & foraging_block_starts
).fetch("block_start", "patch_name", "patch_rate", "patch_offset", as_dict=True)

subj_patch_info = (
    BlockSubjectAnalysis.Patch
    & key
    & foraging_block_starts
).fetch(format="frame")
display(subj_patch_info, patch_info)

In [None]:
"""Convert `subj_patch_info` into a form amenable to plotting."""

# Long format: one row per entry of `pellet_timestamps`, carrying its block, patch, threshold and subject.
reset_subj_patch_info = subj_patch_info.reset_index()  # MultiIndex levels -> columns
min_subj_patch_info = (
    reset_subj_patch_info[["block_start", "patch_name", "subject_name", "pellet_timestamps", "patch_threshold"]]
    .explode(["pellet_timestamps", "patch_threshold"], ignore_index=True)
    .dropna()
    .rename(
        columns={
            "patch_name": "patch",
            "subject_name": "subject",
            "pellet_timestamps": "time",
            "patch_threshold": "threshold",
        }
    )[["time", "block_start", "patch", "threshold", "subject"]]
    .sort_values(by="time")
    .reset_index(drop=True)
)

display(min_subj_patch_info)

In [None]:
"""Add patch mean values and block-normalized delivery times to pellet info."""
# NOTE(review): leftover single-block version of the next cell. The `min_subj_patch_info_plus` built here is
# reassigned by the next cell, so this cell has no downstream effect. With several foraging blocks it is also
# misleading: every "mean" row is stamped with the *first* block's start time, and `norm_time` is scaled
# over the whole frame instead of per block. Consider deleting this cell.

# One "mean" pseudo-subject row per entry of `patch_info` (one entry per fetched block x patch).
n_patches = len(patch_info)
patch_mean_info = pd.DataFrame(index=np.arange(n_patches), columns=min_subj_patch_info.columns)
patch_mean_info["subject"] = "mean"
patch_mean_info["patch"] = [d["patch_name"] for d in patch_info]
# Patch mean threshold from the patch parameters: 1 / patch_rate + patch_offset.
patch_mean_info["threshold"] = [((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info]
patch_mean_info["time"] = subj_patch_info.index.get_level_values("block_start")[0]
patch_mean_info["block_start"] = subj_patch_info.index.get_level_values("block_start")[0]

min_subj_patch_info_plus = pd.concat((patch_mean_info, min_subj_patch_info)).reset_index(drop=True)
# Rescale time relative to the first and last rows of the combined frame, rounded to 3 decimals.
min_subj_patch_info_plus["norm_time"] = (
    (min_subj_patch_info_plus["time"] - min_subj_patch_info_plus["time"].iloc[0])
    / (min_subj_patch_info_plus["time"].iloc[-1] - min_subj_patch_info_plus["time"].iloc[0])
).round(3)

display(min_subj_patch_info_plus)

In [None]:
"""Add patch mean values and block-normalized delivery times to pellet info."""

# One "mean" pseudo-subject row per (block, patch) for every block that has pellet data. Its threshold is the
# patch mean 1 / patch_rate + patch_offset and its time is the block start.
# The rows of each block are built from that block's own `patch_info` entries. Previously the row count came
# from the example block's `patch_names`, which raised as soon as a block had a different number of patches.
mean_records = []
for block_start, _ in min_subj_patch_info.groupby('block_start'):
    for d in patch_info:
        if d['block_start'] == block_start:
            mean_records.append(
                {
                    "time": block_start,
                    "block_start": block_start,
                    "patch": d["patch_name"],
                    "threshold": (1 / d["patch_rate"]) + d["patch_offset"],
                    "subject": "mean",
                }
            )
# Build the frame once (concatenating inside the loop is quadratic).
all_patch_mean_info = pd.DataFrame(mean_records, columns=min_subj_patch_info.columns)

min_subj_patch_info_plus = pd.concat([min_subj_patch_info, all_patch_mean_info], ignore_index=True)

# Normalise time within each block: earliest row -> 0, latest row -> 1 (rounded to 3 decimals).
for _, block_rows in min_subj_patch_info_plus.groupby('block_start'):
    block_times = block_rows["time"]
    t_first = block_times.min()
    min_subj_patch_info_plus.loc[block_rows.index, "norm_time"] = (
        (block_times - t_first) / (block_times.max() - t_first)
    ).round(3)
min_subj_patch_info_plus = min_subj_patch_info_plus.sort_values(by="time").reset_index(drop=True)
# Display the final DataFrame
display(min_subj_patch_info_plus)

In [None]:
"""Create a cumulative pellet count (by subject) dataframe with additional pellet info."""

cum_pel_ct = min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True)


def cumsum_helper(group):
    """Add a 1-based ``counter`` column numbering the rows of ``group`` in their current order.

    Mutates and returns ``group``.
    """
    group["counter"] = np.arange(len(group)) + 1
    return group


# Calculate patch means for each block from the "mean" pseudo-subject rows.
patch_means = cum_pel_ct.loc[cum_pel_ct["subject"] == "mean", ["patch", "threshold", "block_start"]]
patch_means = patch_means.rename(columns={"threshold": "mean_thresh"})
patch_means["mean_thresh"] = patch_means["mean_thresh"].astype(float).round(1)

# Merge patch means back into the main DataFrame
cum_pel_ct = cum_pel_ct.merge(patch_means, on=["patch", "block_start"], how="left")

# Remove the mean pseudo-subject rows (exact match) and reset index
cum_pel_ct = cum_pel_ct[cum_pel_ct["subject"] != "mean"].reset_index(drop=True)

# Number each subject's pellets within each block, in time order. Iterating the groups explicitly gives the
# same result as groupby(...).apply(cumsum_helper) with group_keys=False. pandas >= 2.2 deprecates `apply`
# operating on the grouping columns and will drop them, which would lose `subject`/`block_start` here.
cum_pel_ct = pd.concat(
    [cumsum_helper(group.copy()) for _, group in cum_pel_ct.groupby(["subject", "block_start"])]
).reset_index(drop=True)

# Convert columns to float
make_float_cols = ["threshold", "mean_thresh", "norm_time"]
cum_pel_ct[make_float_cols] = cum_pel_ct[make_float_cols].astype(float)

# Create patch label, e.g. "Patch1 μ: 300.0"
cum_pel_ct["patch_label"] = cum_pel_ct["patch"] + " μ: " + cum_pel_ct["mean_thresh"].round(1).astype(str)

# Normalise thresholds to [0, 1] over all blocks
cum_pel_ct["norm_thresh_val"] = (
    (cum_pel_ct["threshold"] - cum_pel_ct["threshold"].min())
    / (cum_pel_ct["threshold"].max() - cum_pel_ct["threshold"].min())
).round(3)

# Sort by 'time' column
cum_pel_ct = cum_pel_ct.sort_values("time")

display(cum_pel_ct)

In [None]:
"""Get wheel timestamps for each patch."""

wheel_ts = (BlockAnalysis.Patch() & key & foraging_block_starts).fetch(
    "block_start", "patch_name", "wheel_timestamps", as_dict=True
)
# Nested lookup: wheel_ts_dict[block_start][patch_name] -> that patch's wheel timestamps in that block.
wheel_ts_dict = {}
for d in wheel_ts:
    block_start, patch_name = d["block_start"], d["patch_name"]
    wheel_ts_dict.setdefault(block_start, {})[patch_name] = d["wheel_timestamps"]
display(wheel_ts)

In [None]:
"""Get subject patch data."""

# One row per index entry of BlockSubjectAnalysis.Patch, with experiment_name dropped and the rest as columns.
subject_patch_data = (
    (BlockSubjectAnalysis.Patch() & key & foraging_block_starts)
    .fetch(format="frame")
    .reset_index(level=["experiment_name"], drop=True)
    .reset_index()
)
display(subject_patch_data)

In [None]:
"""Add experiment phase labels to data and number blocks."""
# NOTE: we need to add labels for experiment phases, where to do this?
# Make new column experiment_phase in min_subj_patch_info_plus and fill it in based on manually defined dates.
def determine_phase(date):
    """Return the name of the phase whose inclusive date window contains ``date``.

    Windows are checked in the order pre_social, social, post_social, so a date on a shared boundary
    (e.g. social_end == post_social_start) gets the earlier phase. A date outside every window is
    reported with a print and labelled 'unknown'.
    """
    phase_windows = (
        ('pre_social', pre_social_start, pre_social_end),
        ('social', social_start, social_end),
        ('post_social', post_social_start, post_social_end),
    )
    for phase_name, window_start, window_end in phase_windows:
        if window_start <= date <= window_end:
            return phase_name
    print(f"Unknown phase for date {date}")
    return 'unknown'
# Label each row with the phase of its *block* (via block_start), not of the individual pellet time. Pellets of
# one block can straddle a phase boundary, which split a block across phases and block numberings.
# `patch_pref` below is also labelled by block_start.
min_subj_patch_info_plus['experiment_phase'] = min_subj_patch_info_plus['block_start'].apply(determine_phase)
display(min_subj_patch_info_plus)

# Drop numbering left by a previous run of this cell, so the merges below don't create block_number_x/_y.
min_subj_patch_info_plus = min_subj_patch_info_plus.drop(columns=['block_number'], errors='ignore')

# Create block numbering for solo phases, grouped by subject and experiment_phase
solo_phases = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] != 'social']
solo_numbering = solo_phases[['subject', 'experiment_phase', 'block_start']].drop_duplicates().copy()
solo_numbering['block_number'] = solo_numbering.groupby(['subject', 'experiment_phase']).cumcount() + 1

# Create block numbering for the social phase, grouped only by block_start (so the same block_start has the
# same number for every subject)
social_phase = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'social']
social_numbering = social_phase[['experiment_phase', 'block_start']].drop_duplicates().copy()
social_numbering['block_number'] = social_numbering.groupby(['experiment_phase']).cumcount() + 1

# Merge the numbering back: solo rows match on subject + phase + block, social rows on phase + block only
min_subj_patch_info_plus = min_subj_patch_info_plus.merge(
    solo_numbering, on=['subject', 'experiment_phase', 'block_start'], how='left'
)
min_subj_patch_info_plus = min_subj_patch_info_plus.merge(
    social_numbering, on=['experiment_phase', 'block_start'], how='left', suffixes=('', '_social')
)
min_subj_patch_info_plus['block_number'] = (
    min_subj_patch_info_plus['block_number']
    .fillna(min_subj_patch_info_plus['block_number_social'])
    .astype(int)
    .astype(str)
)
min_subj_patch_info_plus = min_subj_patch_info_plus.drop(columns=['block_number_social'])
display(min_subj_patch_info_plus)


## Create overall plots

### 1. Pellet thresholds

In [None]:
"""Plot threshold distribution (boxplot) per block start time, grouped by subject."""

# Filter the DataFrame for each experiment phase
pre_social_df = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'pre_social']
social_df = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'social']
post_social_df = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'post_social']
box_colors = ["#0A0A0A"] + subject_colors[0 : len(subject_names)]  # mean color + subject colors
# Pin every subject (and the "mean" pseudo-subject) to one colour. With only a colour sequence, plotly assigns
# colours in order of first appearance within each patch, so colours could differ between subplots.
box_color_map = dict(zip(["mean"] + list(subject_names), box_colors))


def create_phase_figure(df, phase_title):
    """Box plots of pellet thresholds per block start for one phase, one subplot row per patch.

    Args:
        df: rows of ``min_subj_patch_info_plus`` for one phase (including the "mean" pseudo-subject).
        phase_title: phase name used in the figure title.

    Returns:
        A plotly figure with one row per patch, sharing x and y axes.

    Raises:
        ValueError: if ``df`` contains no patches.
    """
    patches = df['patch'].unique()

    if len(patches) == 0:
        raise ValueError(f"No patches found for phase: {phase_title}")

    desired_order = ["Patch1", "Patch2", "Patch3", "PatchDummy1"]
    patches = sorted(patches, key=lambda x: desired_order.index(x) if x in desired_order else len(desired_order))

    fig = make_subplots(
        rows=len(patches), cols=1, shared_yaxes=True, shared_xaxes=True,
        subplot_titles=[f"{patch}" for patch in patches],
        vertical_spacing=0.05
    )

    # Show each subject in the legend exactly once, the first time it appears. Previously only the
    # second-to-last patch fed the legend, so it was empty for a single patch and missed absent subjects.
    names_in_legend = set()
    for row, patch in enumerate(patches, start=1):
        patch_df = df[df['patch'] == patch]
        patch_fig = px.box(
            patch_df.sort_values("block_start"),
            x="block_start",
            y="threshold",
            color="subject",
            hover_data=["norm_time", "block_number"],
            color_discrete_map=box_color_map,
            color_discrete_sequence=box_colors,
            points="all"
        )
        for trace in patch_fig['data']:
            trace.legendgroup = trace.name  # one legend entry toggles the subject in every subplot
            trace.showlegend = trace.name not in names_in_legend
            names_in_legend.add(trace.name)
            fig.add_trace(trace, row=row, col=1)

    fig.update_layout(
        title=f"{phase_title} Phase: Patch Means and Sampled Threshold Values",
        yaxis_title="Threshold (cm)",
        height=900,
        boxmode='group'
    )
    fig.update_yaxes(matches='y')
    fig.update_xaxes(title_text="Block start time", row=len(patches), col=1)

    return fig


# Create figures for each phase
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post social")

# Show the figures
#pre_social_fig.show()
social_fig.show()
post_social_fig.show()

In [None]:
"""Plot threshold distribution (boxplot) per block number, grouped by subject."""
# Filter the DataFrame for each experiment phase
pre_social_df = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'pre_social']
social_df = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'social']
post_social_df = min_subj_patch_info_plus[min_subj_patch_info_plus['experiment_phase'] == 'post_social']
box_colors = ["#0A0A0A"] + subject_colors[0 : len(subject_names)]  # mean color + subject colors
# Pin every subject (and the "mean" pseudo-subject) to one colour so colours agree across patch subplots.
box_color_map = dict(zip(["mean"] + list(subject_names), box_colors))


def create_phase_figure(df, phase_title):
    """Box plots of pellet thresholds per block number for one phase, one subplot row per patch.

    Args:
        df: rows of ``min_subj_patch_info_plus`` for one phase; ``block_number`` holds integer strings.
            Not modified.
        phase_title: phase name used in the figure title.

    Returns:
        A plotly figure with one row per patch, sharing x and y axes.

    Raises:
        ValueError: if ``df`` contains no patches.
    """
    patches = df['patch'].unique()

    if len(patches) == 0:
        raise ValueError(f"No patches found for phase: {phase_title}")

    desired_order = ["Patch1", "Patch2", "Patch3", "PatchDummy1"]
    patches = sorted(patches, key=lambda x: desired_order.index(x) if x in desired_order else len(desired_order))

    fig = make_subplots(
        rows=len(patches), cols=1, shared_yaxes=True, shared_xaxes=True,
        subplot_titles=[f"{patch}" for patch in patches],
        vertical_spacing=0.05
    )

    # Show each subject in the legend exactly once, the first time it appears in any patch.
    names_in_legend = set()
    for row, patch in enumerate(patches, start=1):
        patch_df = df[df['patch'] == patch]
        # Order block numbers numerically ("10" after "9"). Use assign: writing into the filtered slice
        # raised SettingWithCopyWarning.
        numeric_order = sorted(patch_df['block_number'].unique(), key=int)
        patch_df = patch_df.assign(
            block_number=pd.Categorical(patch_df['block_number'], ordered=True, categories=numeric_order)
        )
        patch_fig = px.box(
            patch_df.sort_values("block_number"),
            x="block_number",
            y="threshold",
            color="subject",
            hover_data=["norm_time", "block_start"],
            color_discrete_map=box_color_map,
            color_discrete_sequence=box_colors,
            points="all"
        )
        for trace in patch_fig['data']:
            trace.legendgroup = trace.name  # one legend entry toggles the subject in every subplot
            trace.showlegend = trace.name not in names_in_legend
            names_in_legend.add(trace.name)
            fig.add_trace(trace, row=row, col=1)

    fig.update_layout(
        title=f"{phase_title} Phase: Patch Means and Sampled Threshold Values",
        yaxis_title="Threshold (cm)",
        height=900,
        boxmode='group'
    )
    fig.update_yaxes(matches='y')
    fig.update_xaxes(title_text="Block Number", row=len(patches), col=1)

    return fig


# Create figures for each phase
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post social")

# Show the figures
#pre_social_fig.show()
social_fig.show()
post_social_fig.show()

### 2. Final patch preference

In [None]:
"""Assign patch types (easy, medium, hard) to blocks."""
# NOTE: is this correct? How to treat blocks with equal patch means?
# NOTE: also, how to assign whether the environment is easy or hard?

def assign_patch_type(group):
    """Label the patches of one block by difficulty, ranked by ``mean_thresh``.

    'PatchDummy1' rows get 'dummy'. The remaining patches, sorted by ascending ``mean_thresh``, get
    'easy', 'medium', 'hard' in that order. A fourth or later patch keeps its previous ``patch_type``
    (NaN on first use). If any two non-dummy patches share a ``mean_thresh``, all of them get 'equal'.

    Args:
        group: rows of one block with at least ``patch`` and ``mean_thresh`` columns. Not modified.

    Returns:
        A copy of ``group`` with a ``patch_type`` column, in the original index order.
    """
    group = group.copy()  # never write into the caller's frame
    if "patch_type" not in group.columns:
        group["patch_type"] = pd.Series(np.nan, index=group.index, dtype=object)

    # Assign 'dummy' to patch_type where patch is 'PatchDummy1'
    is_dummy = group["patch"] == "PatchDummy1"
    group.loc[is_dummy, "patch_type"] = "dummy"
    dummy_rows = group[is_dummy]

    # Rank the real patches by mean threshold (lowest threshold = easiest)
    ranked = group[~is_dummy].sort_values(by="mean_thresh")

    if ranked["mean_thresh"].nunique() < len(ranked):
        # Ties make the ranking ambiguous
        ranked["patch_type"] = "equal"
    else:
        # Works for any number of patches; the old fixed 3-way np.select raised IndexError for < 3 patches.
        difficulty = ["easy", "medium", "hard"]
        ranked["patch_type"] = [
            difficulty[rank] if rank < len(difficulty) else previous
            for rank, previous in enumerate(ranked["patch_type"])
        ]

    # Combine dummy and non-dummy rows back together, in the original row order
    return pd.concat([ranked, dummy_rows]).sort_index()

# Apply the function to each block. The groups are iterated explicitly rather than via groupby.apply, so each
# group keeps its `block_start` column: pandas >= 2.2 deprecates passing grouping columns to `apply`.
patch_means = pd.concat(
    [assign_patch_type(block_means) for _, block_means in patch_means.groupby('block_start')]
).reset_index(drop=True)

In [None]:
"""Get subject patch preference data."""

patch_pref = (
    (BlockSubjectAnalysis.Preference() & key & foraging_block_starts)
    .fetch(format="frame")
    .reset_index(level=["experiment_name"], drop=True)
    .reset_index()
)
# Replace values below 1e-3 (including negatives) with 0
patch_pref["cumulative_preference_by_wheel"] = patch_pref["cumulative_preference_by_wheel"].apply(
    lambda arr: np.where(np.array(arr) < 1e-3, 0, np.array(arr))
)
# NOTE: is this the final preference we want? Check how it is calculated.
patch_pref = patch_pref.rename(columns={"patch_name": "patch", "subject_name": "subject"})

# Add the experiment phase of each block
patch_pref['experiment_phase'] = patch_pref['block_start'].apply(determine_phase)

# Add patch mean thresholds and patch types with a single merge. Previously the merged columns were copied
# back by index alignment, which silently misaligns rows if patch_means has duplicate (block_start, patch)
# keys. validate= turns that case into an error.
patch_pref = patch_pref.merge(
    patch_means[['block_start', 'patch', 'mean_thresh', 'patch_type']].rename(
        columns={'mean_thresh': 'mean_threshold'}
    ),
    on=['block_start', 'patch'],
    how='left',
    validate='many_to_one',
).set_index(['block_start', 'patch', 'subject'])
display(patch_pref)

In [None]:
"""Calculate late and early preference by wheel as avg of preference in first/second half of block."""

# Calculate running preference by wheel
# NOTE: sometimes the lengths of the arrays for the different patches are off by 1, so had to add padding. Why can this mismatch happen?
def pad_array(arr, length):
    """Return ``arr`` as a float array right-padded with NaN up to ``length`` (never truncated).

    Casting to float first lets integer arrays be padded with NaN instead of a garbage integer.
    """
    arr = np.asarray(arr, dtype=float)
    return np.pad(arr, (0, max(0, length - len(arr))), 'constant', constant_values=np.nan)


def calculate_running_preference(group, pref_col, out_col):
    """Turn per-patch cumulative preference arrays into each patch's share of the total at every sample.

    Arrays are NaN-padded to a common length. The total at each sample is the NaN-ignoring sum over the
    group's rows, and each row's array is divided by it. NaNs (padding, or 0/0 when the total is 0) become 0.

    Args:
        group: rows of one (subject, block), one per patch. Not modified.
        pref_col: column holding the cumulative preference arrays.
        out_col: name of the column to write the running preference arrays to.

    Returns:
        A copy of ``group`` with ``out_col`` added.

    Raises:
        ValueError: re-raised after printing the group's index if stacking or division fails.
    """
    group = group.copy()  # don't write into the frame handed over by groupby
    max_length = group[pref_col].apply(len).max()

    try:
        padded_arrays = group[pref_col].apply(lambda x: pad_array(x, max_length))  # Pad arrays to the same length
        total_pref = np.nansum(np.vstack(padded_arrays.values), axis=0)  # Sum pref at each ts
        with np.errstate(divide="ignore", invalid="ignore"):
            # nan=0.0 must be a keyword: the 2nd positional argument of nan_to_num is `copy`.
            group[out_col] = padded_arrays.apply(lambda x: np.nan_to_num(x / total_pref, nan=0.0))
    except ValueError:
        print(f"ValueError for group with indices: {group.index}")
        raise

    return group

# Group by subject and block_start (index levels) to calculate running preference by wheel, then by time.
# group_keys=False keeps the (block_start, patch, subject) index as is. Prepending the group keys produced
# duplicate `block_start` levels, which broke the second groupby and needed fragile manual droplevel calls.
for pref_kind in ("wheel", "time"):
    patch_pref = patch_pref.groupby(["subject", "block_start"], group_keys=False).apply(
        lambda group: calculate_running_preference(
            group, f"cumulative_preference_by_{pref_kind}", f"running_preference_by_{pref_kind}"
        )
    )

# Calculate early and late preference by wheel
def calculate_half_averages(arr):
    """Return the NaN-ignoring means of the first and second halves of ``arr``.

    For an odd length the middle element belongs to the second half.
    """
    midpoint = len(arr) // 2
    early, late = arr[:midpoint], arr[midpoint:]
    return np.nanmean(early), np.nanmean(late)

# Split each running-preference array into early (first-half) and late (second-half) averages, for both the
# wheel-based and the time-based preference.
for pref_kind in ("wheel", "time"):
    patch_pref[[f"early_preference_by_{pref_kind}", f"late_preference_by_{pref_kind}"]] = patch_pref[
        f"running_preference_by_{pref_kind}"
    ].apply(lambda x: pd.Series(calculate_half_averages(x)))
patch_pref = patch_pref.reset_index()
display(patch_pref)

In [None]:
"""Plot late preference by patch ID, per block start time, grouped by subject."""

# Filter the DataFrame for each experiment phase
pre_social_df = patch_pref[patch_pref['experiment_phase'] == 'pre_social']
social_df = patch_pref[patch_pref['experiment_phase'] == 'social']
post_social_df = patch_pref[patch_pref['experiment_phase'] == 'post_social']
box_colors = subject_colors[0 : len(subject_names)]  # subject colors


def create_phase_figure(df, phase_title, seed=None):
    """Scatter each block's late preference by wheel for one phase, coloured by subject, symbol by patch.

    Points are spread horizontally so subjects sharing a block don't overlap. Each subject is shifted by
    ``separation_strength`` seconds times its order of appearance, plus uniform jitter in
    [-jitter_strength, jitter_strength) seconds.

    Args:
        df: ``patch_pref`` rows for one phase (needs block_start, subject, patch and
            late_preference_by_wheel). Not modified.
        phase_title: phase name used in the title.
        seed: optional seed for the jitter, to make the figure reproducible.

    Returns:
        The plotly figure.
    """
    separation_strength = 700  # seconds between neighbouring subjects
    jitter_strength = 300  # maximum random shift, in seconds
    rng = np.random.default_rng(seed)
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(df['subject'].unique())}

    # Vectorised offset + jitter. Work on a new frame: assigning into the phase slice raised
    # SettingWithCopyWarning and mutated the caller's data.
    shift_s = df['subject'].map(subject_offsets) + rng.uniform(-jitter_strength, jitter_strength, size=len(df))
    df = df.assign(block_start_separated=df['block_start'] + pd.to_timedelta(shift_s, unit='s'))

    fig = px.scatter(
        df.sort_values("block_start"),
        x="block_start_separated",
        y="late_preference_by_wheel",
        color="subject",
        symbol="patch",
        color_discrete_sequence=box_colors,
        title=f"{phase_title} Phase: Late preference by wheel",
        # Key must match the plotted column (was "final_preference_by_wheel", so the label never applied).
        labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
    )
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=400,
        legend_title_text='Subject and Patch'
    )

    return fig


# Create figures for each phase
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post social")

# Show the figures
#pre_social_fig.show()
social_fig.show()
post_social_fig.show()

In [None]:
"""Plot late preference by patch mean threshold, per block start time, grouped by subject."""

# Filter the DataFrame for each experiment phase and remove the dummy patch
pre_social_df = patch_pref[(patch_pref['experiment_phase'] == 'pre_social') & (patch_pref['patch'] != 'PatchDummy1')]
social_df = patch_pref[(patch_pref['experiment_phase'] == 'social') & (patch_pref['patch'] != 'PatchDummy1')]
post_social_df = patch_pref[(patch_pref['experiment_phase'] == 'post_social') & (patch_pref['patch'] != 'PatchDummy1')]
box_colors = subject_colors[0 : len(subject_names)]  # subject colors


def create_phase_figure(df, phase_title, seed=None):
    """Scatter each block's late preference by wheel for one phase, symbol by patch mean threshold.

    Points are spread horizontally so subjects sharing a block don't overlap. Each subject is shifted by
    ``separation_strength`` seconds times its order of appearance, plus uniform jitter in
    [-jitter_strength, jitter_strength) seconds.

    Args:
        df: ``patch_pref`` rows for one phase (needs block_start, subject, mean_threshold and
            late_preference_by_wheel). Not modified.
        phase_title: phase name used in the title.
        seed: optional seed for the jitter, to make the figure reproducible.

    Returns:
        The plotly figure.
    """
    separation_strength = 700  # seconds between neighbouring subjects
    jitter_strength = 300  # maximum random shift, in seconds
    rng = np.random.default_rng(seed)
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(df['subject'].unique())}

    # Vectorised offset + jitter on a new frame (assigning into the phase slice mutated the caller's data).
    shift_s = df['subject'].map(subject_offsets) + rng.uniform(-jitter_strength, jitter_strength, size=len(df))
    df = df.assign(block_start_separated=df['block_start'] + pd.to_timedelta(shift_s, unit='s'))

    fig = px.scatter(
        df.sort_values("block_start"),
        x="block_start_separated",
        y="late_preference_by_wheel",
        color="subject",
        symbol="mean_threshold",
        color_discrete_sequence=box_colors,
        title=f"{phase_title} Phase: Late preference by wheel",
        labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
    )
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=400,
        legend_title_text='Subject and Patch Mean Threshold'
    )

    return fig


# Create figures for each phase
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post social")

# Show the figures
#pre_social_fig.show()
social_fig.show()
post_social_fig.show()

In [None]:
"""Plot late preference by patch type (easy, medium, hard), per block start time, grouped by subject."""

# Filter the DataFrame for each experiment phase and remove the dummy patch
pre_social_df = patch_pref[(patch_pref['experiment_phase'] == 'pre_social') & (patch_pref['patch'] != 'PatchDummy1')]
social_df = patch_pref[(patch_pref['experiment_phase'] == 'social') & (patch_pref['patch'] != 'PatchDummy1')]
post_social_df = patch_pref[(patch_pref['experiment_phase'] == 'post_social') & (patch_pref['patch'] != 'PatchDummy1')]
box_colors = subject_colors[0 : len(subject_names)]  # subject colors


def create_phase_figure(df, phase_title, seed=None):
    """Scatter each block's late preference by wheel for one phase, symbol by patch type.

    Points are spread horizontally so subjects sharing a block don't overlap. Each subject is shifted by
    ``separation_strength`` seconds times its order of appearance, plus uniform jitter in
    [-jitter_strength, jitter_strength) seconds.

    Args:
        df: ``patch_pref`` rows for one phase (needs block_start, subject, patch_type and
            late_preference_by_wheel). Not modified.
        phase_title: phase name used in the title.
        seed: optional seed for the jitter, to make the figure reproducible.

    Returns:
        The plotly figure.
    """
    separation_strength = 700  # seconds between neighbouring subjects
    jitter_strength = 300  # maximum random shift, in seconds
    rng = np.random.default_rng(seed)
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(df['subject'].unique())}

    # Vectorised offset + jitter on a new frame (assigning into the phase slice mutated the caller's data).
    shift_s = df['subject'].map(subject_offsets) + rng.uniform(-jitter_strength, jitter_strength, size=len(df))
    df = df.assign(block_start_separated=df['block_start'] + pd.to_timedelta(shift_s, unit='s'))

    fig = px.scatter(
        df.sort_values("block_start"),
        x="block_start_separated",
        y="late_preference_by_wheel",
        color="subject",
        symbol="patch_type",
        color_discrete_sequence=box_colors,
        title=f"{phase_title} Phase: Late preference by wheel",
        labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
    )
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=400,
        legend_title_text='Subject and Patch Type'
    )

    return fig


# Create figures for each phase
#pre_social_fig = create_phase_figure(pre_social_df, "Pre social")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post social")

# Show the figures
#pre_social_fig.show()
social_fig.show()
post_social_fig.show()

In [None]:
"""Same but with lineplot."""

# Patch types and their line styles. (This cell no longer redefines `patch_names` /
# `patch_markers_linestyles`: those redefinitions were unused here and overwrote the data-derived
# `patch_names` from the top of the notebook.)
patch_types = ['easy', 'medium', 'hard', 'equal']
patch_types_linestyles = ['solid', 'dash', 'dot', 'dashdot']
patch_type_linestyles = dict(zip(patch_types, patch_types_linestyles))

# Filter the DataFrame for each experiment phase and remove the dummy patch
pre_social_df = patch_pref[(patch_pref['experiment_phase'] == 'pre_social') & (patch_pref['patch'] != 'PatchDummy1')]
social_df = patch_pref[(patch_pref['experiment_phase'] == 'social') & (patch_pref['patch'] != 'PatchDummy1')]
post_social_df = patch_pref[(patch_pref['experiment_phase'] == 'post_social') & (patch_pref['patch'] != 'PatchDummy1')]
box_colors = subject_colors[0 : len(subject_names)]  # subject colors


def create_phase_figure(df, phase_title, seed=None):
    """Line plot of each block's late preference by wheel for one phase; dash style and symbol by patch type.

    Points are spread horizontally so subjects sharing a block don't overlap. Each subject is shifted by
    ``separation_strength`` seconds times its order of appearance, plus uniform jitter in
    [-jitter_strength, jitter_strength) seconds.

    Args:
        df: ``patch_pref`` rows for one phase (needs block_start, subject, patch_type and
            late_preference_by_wheel). Not modified.
        phase_title: phase name used in the title.
        seed: optional seed for the jitter, to make the figure reproducible.

    Returns:
        The plotly figure, or None if ``df`` is empty.
    """
    if df.empty:
        return None

    separation_strength = 700  # seconds between neighbouring subjects
    jitter_strength = 300  # maximum random shift, in seconds
    rng = np.random.default_rng(seed)
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(df['subject'].unique())}

    # Vectorised offset + jitter on a new frame (assigning into the phase slice mutated the caller's data).
    shift_s = df['subject'].map(subject_offsets) + rng.uniform(-jitter_strength, jitter_strength, size=len(df))
    df = df.assign(block_start_separated=df['block_start'] + pd.to_timedelta(shift_s, unit='s'))

    fig = px.line(
        df.sort_values("block_start"),
        x="block_start_separated",
        y="late_preference_by_wheel",
        color="subject",
        symbol="patch_type",
        # Map each patch type to its dash style explicitly. A `line_dash` column of dash names is only used
        # for grouping, and plotly then picks dashes from its own default sequence.
        line_dash="patch_type",
        line_dash_map=patch_type_linestyles,
        color_discrete_sequence=box_colors,
        title=f"{phase_title} Phase: Late preference by wheel",
        labels={"late_preference_by_wheel": "Late preference by wheel", "block_start_separated": "Block Start"}
    )
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=400,
        legend_title_text='Subject and Patch Type'
    )

    return fig


# Create figures for each phase
#pre_social_fig = create_phase_figure(pre_social_df, "Pre Solo")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post Solo")

# Show the figures (None when the phase has no data)
#pre_social_fig.show()
if social_fig:
    social_fig.show()
if post_social_fig:
    post_social_fig.show()

### 3. Patch preference change

In [None]:
"""Calculate difference between late and early preference by wheel."""
# Positive values: the late (second-half) preference is higher than the early (first-half) one.
patch_pref['preference_change_by_wheel'] = patch_pref['late_preference_by_wheel'].sub(
    patch_pref['early_preference_by_wheel']
)

In [None]:
"""Plot preference change by patch type (easy, medium, hard), per block start time, grouped by subject."""

# Filter the DataFrame for each experiment phase and remove the dummy patch
pre_social_df = patch_pref[(patch_pref['experiment_phase'] == 'pre_social') & (patch_pref['patch'] != 'PatchDummy1')]
social_df = patch_pref[(patch_pref['experiment_phase'] == 'social') & (patch_pref['patch'] != 'PatchDummy1')]
post_social_df = patch_pref[(patch_pref['experiment_phase'] == 'post_social') & (patch_pref['patch'] != 'PatchDummy1')]
box_colors = subject_colors[0 : len(subject_names)]  # subject colors


def create_phase_figure(df, phase_title, seed=None):
    """Scatter each block's late-minus-early preference by wheel for one phase, symbol by patch type.

    Points are spread horizontally so subjects sharing a block don't overlap. Each subject is shifted by
    ``separation_strength`` seconds times its order of appearance, plus uniform jitter in
    [-jitter_strength, jitter_strength) seconds. The y axis is fixed to [-1, 1].

    Args:
        df: ``patch_pref`` rows for one phase (needs block_start, subject, patch_type and
            preference_change_by_wheel). Not modified.
        phase_title: phase name used in the title.
        seed: optional seed for the jitter, to make the figure reproducible.

    Returns:
        The plotly figure.
    """
    separation_strength = 700  # seconds between neighbouring subjects
    jitter_strength = 300  # maximum random shift, in seconds
    rng = np.random.default_rng(seed)
    subject_offsets = {subject: i * separation_strength for i, subject in enumerate(df['subject'].unique())}

    # Vectorised offset + jitter on a new frame (assigning into the phase slice mutated the caller's data).
    shift_s = df['subject'].map(subject_offsets) + rng.uniform(-jitter_strength, jitter_strength, size=len(df))
    df = df.assign(block_start_separated=df['block_start'] + pd.to_timedelta(shift_s, unit='s'))

    fig = px.scatter(
        df.sort_values("block_start"),
        x="block_start_separated",
        y="preference_change_by_wheel",
        color="subject",
        symbol="patch_type",
        color_discrete_sequence=box_colors,
        title=f"{phase_title} Phase: Late-early preference by wheel",
        labels={"preference_change_by_wheel": "Late-early preference by wheel", "block_start_separated": "Block Start"}
    )
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(
        height=400,
        legend_title_text='Subject and Patch Type',
        yaxis=dict(range=[-1, 1])
    )

    return fig


# Create figures for each phase
#pre_social_fig = create_phase_figure(pre_social_df, "Pre Solo")
social_fig = create_phase_figure(social_df, "Social")
post_social_fig = create_phase_figure(post_social_df, "Post Solo")

# Show the figures
#pre_social_fig.show()
social_fig.show()
post_social_fig.show()