In [None]:
import sys
import os
from icecream import ic

from pathlib import Path

import utils_behavior

from utils_behavior import Ballpushing_utils
from utils_behavior import Utils
from utils_behavior import Processing
from utils_behavior import HoloviewsTemplates

import pandas as pd
import hvplot.pandas
import numpy as np

from scipy import stats
from statsmodels.stats.multitest import multipletests
from scipy.signal import find_peaks

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import pandas as pd
import pyarrow.feather as feather

import pandas as pd
import numpy as np
import holoviews as hv
import hvplot.pandas
from scipy.optimize import curve_fit
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


import matplotlib.pyplot as plt
import seaborn as sns


import importlib

import holoviews as hv

import pwlf

# Activate the Bokeh backend so HoloViews / hvplot objects render in the notebook.
hv.extension("bokeh")

In [None]:
# Get the root path of the experiment data from the project's Utils helper.
# NOTE(review): presumably a pathlib.Path, since the next cells iterate/glob on it.

Datapath = Utils.get_data_path()

In [None]:
# Load already existing dataset (the feather file written by the commented-out
# feather.write_feather call in the subset cell further down).
# NOTE(review): when the notebook is run top to bottom, Subset is rebuilt from the
# raw experiments a few cells later and this loaded copy is overwritten; run this
# cell *instead of* the dataset-generation cells to skip that slow step.

Subset = feather.read_feather("/mnt/upramdya_data/MD/MultiMazeRecorder/Datasets/Learning/240809_BallposData.feather")

In [None]:
# Find folders with "Learning" or "learning" in the name, as a sorted list.
# Path.glob is case-sensitive on Linux, so a "*Learning*" pattern alone silently
# skips lowercase "learning" folders; match on the lowercased name instead.
# Sorting makes the experiment order (and everything derived from it) reproducible.

folders = sorted(
    f for f in Datapath.iterdir() if f.is_dir() and "learning" in f.name.lower()
)

folders

In [None]:
#Generate the experiments
# One Ballpushing_utils.Experiment per matched folder.

Exps = [Ballpushing_utils.Experiment(f) for f in folders]

In [None]:
# Combine all experiments into a single Ballpushing_utils.Dataset.
Data = Ballpushing_utils.Dataset(Exps)

In [None]:
# Build the per-frame table, exposed afterwards as Data.data.
# NOTE(review): success_cutoff=False presumably keeps frames after the first
# success, which the multi-trial analysis needs — confirm in Ballpushing_utils.
Data.generate_dataset(success_cutoff=False)

In [None]:
# List the columns of the generated dataset, to choose the subset kept below.
  
Data.data.columns

In [None]:
# Make a subset of the data keeping only the columns used below: frame/time,
# the ball-position variants, fly identity and metadata (Peak, Date).
# .copy() makes Subset independent of Data.data, so adding columns to it later
# neither raises SettingWithCopyWarning nor risks writing into Data.data.

Subset = Data.data[
    ["Frame", "time", "yball", "yball_smooth", "yball_relative", "fly", "Peak", "Date"]
].copy()

# Save this subset to a feather file (the file read by the "load existing dataset" cell)

#feather.write_feather(Subset, "/mnt/upramdya_data/MD/MultiMazeRecorder/Datasets/Learning/240809_BallposData.feather")

In [None]:
# Inspect the subset (Jupyter truncates the display of large frames).
Subset

In [None]:
# Get the name of the first fly and store it to use it for the next steps

fly = Subset["fly"].iloc[0]

# TestData is a filtered *copy* of Subset for that fly: columns added to Subset
# afterwards (e.g. the derivative below) do not appear in it unless it is rebuilt.
TestData = Subset[Subset["fly"] == fly]

# Now take one fly and plot the yball_relative

#Subset[Subset["fly"] == fly].hvplot(x="time", y="yball_relative", kind="scatter")

In [None]:
# Compute the frame-to-frame derivative of yball_relative *within each fly*.
# A plain Subset["yball_relative"].diff() subtracts the last frame of one fly from
# the first frame of the next, creating a spurious jump at every fly boundary.
# Assumes rows are in temporal order within each fly (as the plain diff did).
Subset["yball_relative_derivative"] = Subset.groupby("fly")["yball_relative"].diff()

# TestData was filtered from Subset before this column existed, so rebuild it to
# pick up the derivative used by the single-fly peak-detection cells below.
TestData = Subset[Subset["fly"] == fly]

In [None]:


# Plot the derivative of yball_relative over time for the selected fly; the
# next cells detect trial boundaries from its large negative spikes.

Subset[Subset["fly"] == fly].hvplot(x="time", y="yball_relative_derivative", kind="scatter")



In [None]:
# Do Negative peak detection on the derivative: find_peaks looks for maxima, so the
# derivative is negated. Peaks[0] holds *positional* indices into TestData and
# Peaks[1]["peak_heights"] the heights of the negated signal.
# NOTE(review): height=0.23 here, whereas the trial splitting (process_group)
# uses height=0.3 — this cell is only an exploratory check on one fly.
# TestData must already contain "yball_relative_derivative", i.e. it must have
# been rebuilt after that column was added to Subset.

Peaks = find_peaks(-TestData["yball_relative_derivative"], height=0.23, distance=500)

#Peaks

In [None]:
# find_peaks returns *positional* indices into TestData, not Frame numbers, so the
# peak times are looked up by position. (A Frame->time mapping only coincides with
# positions when Frame starts at 0 and has no gaps; otherwise it mislabels peaks
# or raises KeyError.)
x_peaks = Peaks[0].tolist()
x_peaks_time = TestData["time"].iloc[x_peaks].tolist()

# peak_heights are measured on the *negated* derivative (positive values); plot the
# actual derivative values so the markers sit on the curve they were detected in.
y_peaks = TestData["yball_relative_derivative"].iloc[x_peaks].tolist()

# Plot the derivative and the peaks
scatter_plot = TestData.hvplot(x="time", y="yball_relative_derivative", kind="scatter")
peaks_plot = hv.Scatter(
    (x_peaks_time, y_peaks),
    kdims=["time"],
    vdims=["yball_relative_derivative"],
    label="Peaks",
)

# Combine the plots
combined_plot = scatter_plot * peaks_plot
#combined_plot

In [None]:
import pandas as pd
from scipy.signal import find_peaks

# Group by individual flies (trials are split separately for each fly)
grouped = Subset.groupby("fly")


# Split one fly's recording into trials at the negative peaks of the derivative.
def process_group(group, height=0.3, distance=500, verbose=False):
    """Assign a 1-based ``Trial`` number to every row of one fly's data.

    Trial boundaries are the negative peaks of ``yball_relative_derivative``,
    found with ``scipy.signal.find_peaks`` on the negated derivative. Rows up to
    and including the first peak are trial 1, rows after it up to and including
    the second peak are trial 2, and so on; rows after the last peak get
    ``len(peaks) + 1`` (a fly without any peak is a single trial 1).

    Parameters
    ----------
    group : pandas.DataFrame
        Rows of a single fly, in temporal order, with a
        ``yball_relative_derivative`` column (plus ``fly`` and ``Frame`` when
        ``verbose`` is True).
    height : float, default 0.3
        Minimum height of a peak in the negated derivative.
    distance : int, default 500
        Minimum number of rows between two peaks.
    verbose : bool, default False
        Print the detected peaks and the first assigned trial numbers (the
        debug output earlier versions always printed).

    Returns
    -------
    pandas.DataFrame
        A copy of ``group`` with an integer ``Trial`` column; the input frame
        is not modified.
    """
    peaks, _ = find_peaks(
        -group["yball_relative_derivative"], height=height, distance=distance
    )

    if verbose:
        print(f"Fly: {group['fly'].iloc[0]}, Peaks: {peaks}")

    # Trial of row i = 1 + number of peaks strictly before i, so each peak row
    # still closes its own trial (same boundaries as a per-peak slicing loop).
    trials = np.searchsorted(peaks, np.arange(len(group)), side="left") + 1

    # Work on a copy: groupby.apply may hand over a view of the parent frame.
    group = group.copy()
    group["Trial"] = trials

    if verbose:
        print(group[["fly", "Frame", "Trial"]].head(10))

    return group


# Apply the function to each group and combine the results; reset_index(drop=True)
# discards any group-key index levels added by apply, leaving a flat RangeIndex.
Trials = grouped.apply(process_group).reset_index(drop=True)

In [None]:
# Find how many unique trials there are grouped by fly
# (number of detected peaks + 1 for each fly).
Trials.groupby("fly")["Trial"].nunique()

In [None]:
# Now we can drop any fly that has less than 2 trials
# (no detected reset, hence no trial-to-trial comparison is possible).

Filtered = Trials.groupby("fly").filter(lambda x: x["Trial"].nunique() > 1)

Filtered

In [None]:
# Sanity check: every remaining fly has at least 2 trials.
Filtered.groupby("fly")["Trial"].nunique()

In [None]:
# Group by Trial and time, then compute the mean of yball_relative
# NOTE(review): "time" is the recording time, not time since trial start, so rows
# are only averaged across flies sharing the same timestamp within the same trial
# number — confirm this is the intended alignment.
averaged_data = (
    Filtered.groupby(["Trial", "time"])["yball_relative"].mean().reset_index()
)

# Plot the averaged yball_relative values
# averaged_data.hvplot(
#     x="time", y="yball_relative", by="Trial", kind="line", legend="top_left"
# )

In [None]:
# Among the flies that have more than 1 trial, drop the frames that are not part of
# the push: the first frames of each trial, and the plateau reached once the ball
# is at its furthest position (everything after the first maximum of yball_relative).


# Function to clean each trial
def clean_trial(group, skip_frames=500):
    """Trim one (fly, Trial) segment to the active push phase.

    Drops the first ``skip_frames`` rows, then keeps rows only up to and
    including the first row where ``yball_relative`` reaches its maximum, so
    the plateau after the ball has reached its furthest position is removed.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows of one (fly, Trial) segment in temporal order, with a
        ``yball_relative`` column.
    skip_frames : int, default 500
        Number of leading rows to discard.

    Returns
    -------
    pandas.DataFrame
        The trimmed rows, with their original index labels. Empty (same
        columns) when no row is left after skipping or ``yball_relative`` is
        entirely NaN. In that case ``idxmax`` used to raise and abort the
        whole groupby-apply, e.g. for a trial of 500 frames or fewer.
    """
    group = group.iloc[skip_frames:]

    values = group["yball_relative"].to_numpy(dtype=float)
    if values.size == 0 or np.isnan(values).all():
        return group.iloc[:0]

    # Position of the first maximum (NaNs ignored), so the cut does not depend on
    # the index labels being sorted the way label-based .loc slicing requires.
    last = int(np.nanargmax(values))
    return group.iloc[: last + 1]


# Apply the function to each (fly, Trial) segment and combine the results;
# reset_index(drop=True) discards the group-key index levels added by apply.
cleaned_data = (
    Filtered.groupby(["fly", "Trial"]).apply(clean_trial).reset_index(drop=True)
)

# Display the cleaned dataset
#cleaned_data

In [None]:
# Duration of one trial segment
def compute_trial_duration(group):
    """Return the elapsed time of one trial as a one-field Series.

    The value, stored under ``"duration"``, is the span between the earliest
    and the latest ``time`` entry of ``group``.
    """
    times = group["time"]
    elapsed = times.max() - times.min()
    return pd.Series({"duration": elapsed})


# Apply the function to each group and compute the trial durations
# -> one row per (fly, Trial) with a "duration" column (seconds, per the plot labels).
trial_durations = (
    cleaned_data.groupby(["fly", "Trial"]).apply(compute_trial_duration).reset_index()
)

In [None]:
# Box plot with jittered points of trial duration per trial number (project template).
Jitterbox = HoloviewsTemplates.jitter_boxplot(trial_durations, kdims = "Trial", metric="duration", plot_options=HoloviewsTemplates.hv_slides).opts(invert_axes=False, xlabel="Trial Number", ylabel="Duration (s)")

In [None]:
# Display the trial-duration plot.
Jitterbox

In [None]:
# Save the plot as a PNG file (the SVG version is exported in the following cell)
hv.save(
    Jitterbox,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_TrialDuration.png",
)

In [None]:
# Method to save as svg

from bokeh.io.export import export_svgs


def export_svg(obj, filename):
    """Render a HoloViews object with Bokeh and write it to ``filename`` as SVG.

    The underlying Bokeh model is switched to the SVG output backend first,
    because ``bokeh.io.export.export_svgs`` only exports SVG-backend plots.
    """
    bokeh_model = hv.renderer("bokeh").get_plot(obj).state
    bokeh_model.output_backend = "svg"
    export_svgs(bokeh_model, filename=filename)


# Export the trial-duration jitter box plot as SVG.
export_svg(
    Jitterbox,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_TrialDuration.svg",
)

In [None]:
# Assuming trial_durations is your DataFrame
# Ensure 'Trial' is treated as a categorical variable with ordered categories
# (so trial numbers are laid out and sorted as ordered categories on the x axis)
trial_durations["Trial"] = pd.Categorical(trial_durations["Trial"], ordered=True)

# Create the jitterboxplot
jitterbox = HoloviewsTemplates.jitter_boxplot(
    trial_durations,
    kdims="Trial",
    metric="duration",
    plot_options=HoloviewsTemplates.hv_slides,
).opts(invert_axes=False, xlabel="Trial Number", ylabel="Duration (s)")

# Create a line plot to track each "fly" across trial numbers
# (one gray Curve per fly, overlaid into a single layer)
line_plot = (
    hv.Curve(trial_durations, kdims=["Trial"], vdims=["duration", "fly"])
    .groupby("fly")
    .opts(color="gray", line_width=1, alpha=0.5)
    .overlay()
)

# Overlay the line plot on top of the jitterboxplot
combined_plot = (jitterbox * line_plot).opts(width = 1500,
                                             height = 1000,
                                             fontscale = 2.5,
                                             show_grid = True,
                                             )

combined_plot

In [None]:
# Save the plot as png (jitter box plot with per-fly trajectories)

hv.save(combined_plot, "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_TrialDuration_Lines.png")

In [None]:
# Get how many individual flies we have for each trial number
# (sample size behind each point of the per-trial summaries below)

trial_durations.groupby("Trial")["fly"].nunique()

In [None]:
# Assuming trial_durations is your DataFrame
# Ensure 'Trial' is treated as a categorical variable with ordered categories
# (sort_values(by="Trial") in the helpers below then follows the category order)
trial_durations["Trial"] = pd.Categorical(trial_durations["Trial"], ordered=True)


# Function to calculate improvement
def calculate_improvement(df):
    """Flag, for one fly, the trials that were faster than the previous one.

    Sorts ``df`` by ``Trial``. ``improved`` is 1.0 when ``duration`` decreased
    relative to the preceding trial and 0.0 otherwise. The first trial has no
    predecessor, so its ``improved`` is NaN and later ``.mean()`` aggregations
    skip it. Previously it was counted as "did not improve", which pinned the
    trial-1 proportion at 0.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single fly with ``Trial`` and ``duration`` columns.

    Returns
    -------
    pandas.DataFrame
        A sorted copy of ``df`` with a float ``improved`` column.
    """
    df = df.sort_values(by="Trial")
    delta = df["duration"].diff()
    df["improved"] = delta.lt(0).astype(float).where(delta.notna())
    return df


# Apply the function to each group of 'fly'.
# group_keys=False keeps the flat index: otherwise pandas adds "fly" as an index
# level next to the "fly" column, and every later trial_durations.groupby("fly")
# raises "'fly' is both an index level and a column label, which is ambiguous".
# reset_index(drop=True) also clears such a level left by a previous run.
trial_durations = (
    trial_durations.reset_index(drop=True)
    .groupby("fly", group_keys=False)
    .apply(calculate_improvement)
)

# Calculate the proportion of flies that improved in each trial
improvement_proportions = trial_durations.groupby("Trial")["improved"].mean()

# Print the results
print(improvement_proportions)

# Plot the results
import holoviews as hv

hv.extension("bokeh")

improvement_plot = (
    hv.Curve(improvement_proportions)
    .opts(
        xlabel="Trial Number",
        ylabel="Proportion of Flies Improved",
        title="Proportion of Flies that Improved in Each Trial",
    )
    .options(**HoloviewsTemplates.hv_slides["plot"])
    .opts(invert_axes=False, show_legend=True)
)

improvement_plot

In [None]:
# Filter to only include the trials where at least 5 flies participated
# (the boolean Series is indexed by Trial and aligns with improvement_proportions)
improvement_proportions_filtered = improvement_proportions[
    trial_durations.groupby("Trial")["fly"].nunique() >= 5
]

# Plot the filtered results
improvement_plot_filtered = (
    hv.Curve(improvement_proportions_filtered)
    .opts(
        xlabel="Trial Number",
        ylabel="Proportion of Flies Improved",
        title="Proportion of Flies that Improved in Each Trial",
    )
    .options(**HoloviewsTemplates.hv_slides["plot"])
    .opts(invert_axes=False, show_legend=True)
)

improvement_plot_filtered

In [None]:
# Save the filtered improvement curve as a static PNG and an interactive HTML page.
hv.save(
    improvement_plot_filtered,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240819_Improvement.png",
)

hv.save(
    improvement_plot_filtered,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240819_Improvement.html",
)


In [None]:
# Function to calculate the inverted duration ratio (previous / current)
def calculate_duration_ratio(df):
    """Compute, for one fly, ``previous duration / current duration`` per trial.

    Sorts ``df`` by ``Trial``. A ratio above 1 means the trial was faster than
    the preceding one. The first trial has no predecessor and gets NaN. A trial
    of zero duration also gets NaN instead of ``inf``, which would otherwise
    turn the per-trial mean into ``inf``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single fly with ``Trial`` and ``duration`` columns.

    Returns
    -------
    pandas.DataFrame
        A sorted copy of ``df`` with a ``duration_ratio`` column.
    """
    df = df.sort_values(by="Trial")
    current = df["duration"].where(df["duration"] != 0)
    df["duration_ratio"] = df["duration"].shift(1) / current
    return df


# Apply the function to each group of 'fly'.
# group_keys=False (plus reset_index) keeps the index flat, so "fly" does not
# become both an index level and a column and later groupby("fly") calls work.
trial_durations = (
    trial_durations.reset_index(drop=True)
    .groupby("fly", group_keys=False)
    .apply(calculate_duration_ratio)
)

# Calculate the average inverted duration ratio for each trial
average_inverted_ratio = trial_durations.groupby("Trial")["duration_ratio"].mean()

# Print the results
print(average_inverted_ratio)

# Plot the results
improvement_magnitude_plot = hv.Curve(average_inverted_ratio).opts(
    xlabel="Trial Number",
    ylabel="Average Inverted Duration Ratio",
    title="Average Inverted Duration Ratio in Each Trial",
).options(**HoloviewsTemplates.hv_slides["plot"]).opts(invert_axes=False, show_legend=True)

improvement_magnitude_plot

In [None]:
# Filter the data for the 10 first trials and bootstrap a 95% confidence interval
# of the mean duration of each of them.
max_trial = 10
n_samples = 1000  # number of bootstrap resamples per trial
rng = np.random.default_rng(42)  # fixed seed: the CI band is reproducible on re-run

# "Trial" may be an ordered Categorical here: comparing it with 10 raises when 10 is
# not one of its categories, and grouping it keeps empty categories. Use integers.
trial_numbers = trial_durations["Trial"].astype(int)
filtered_trial_durations = trial_durations.assign(Trial=trial_numbers)[
    trial_numbers <= max_trial
]

# Calculate the average duration for each trial
average_duration = filtered_trial_durations.groupby("Trial")["duration"].mean()

# Bootstrap: resample each trial's durations with replacement n_samples times
# (one resample per row), then keep the 2.5% / 97.5% quantiles of the means.
ci_bounds = {}
for trial, durations in filtered_trial_durations.groupby("Trial")["duration"]:
    values = durations.to_numpy()
    resampled_means = rng.choice(
        values, size=(n_samples, len(values)), replace=True
    ).mean(axis=1)
    ci_bounds[trial] = np.quantile(resampled_means, [0.025, 0.975])

# Index = trial number, columns = the 0.025 / 0.975 quantiles
confidence_intervals = pd.DataFrame.from_dict(
    ci_bounds, orient="index", columns=[0.025, 0.975]
)

# Ticks for exactly the trials that are plotted
trial_ticks = list(average_duration.index)

# Plot the average duration and confidence intervals
average_duration_plot = (
    hv.Curve(average_duration)
    .opts(
        xlabel="Trial Number",
        ylabel="Average Duration (s)",
        title="Average Duration in Each Trial",
        xticks=trial_ticks,
    )
    .options(**HoloviewsTemplates.hv_slides["plot"])
    .opts(invert_axes=False, show_legend=False)
)

confidence_interval_plot = hv.Area(
    (
        confidence_intervals.index,
        confidence_intervals[0.025],
        confidence_intervals[0.975],
    ),
    vdims=["lower", "upper"],
    label="95% Confidence Interval",
).opts(color="blue", alpha=0.2, xticks=trial_ticks).opts(show_legend=False)

CombinedPlot = (average_duration_plot * confidence_interval_plot).opts(xlabel="Trial Number", ylabel="Average Duration (s)")

CombinedPlot

In [None]:
# Save the first-trials average duration + 95% CI plot as PNG and HTML.
hv.save(CombinedPlot, "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240819_AverageDuration.png")
hv.save(CombinedPlot, "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240819_AverageDuration.html")

In [None]:
# Do the same for the whole dataset (every trial number, not only the first 10)

n_samples = 1000  # number of bootstrap resamples per trial
rng = np.random.default_rng(42)  # fixed seed: the CI band is reproducible on re-run

# Plain integer trial numbers, so only trials that actually occur are grouped.
all_trial_durations = trial_durations.assign(
    Trial=trial_durations["Trial"].astype(int)
)

# Calculate the average duration for each trial
average_duration = all_trial_durations.groupby("Trial")["duration"].mean()

# Bootstrap every trial present in the data. A range(1, n_unique_trials) loop
# stops one short and never bootstraps the last trial, so the CI band would end
# before the mean curve.
ci_bounds = {}
for trial, durations in all_trial_durations.groupby("Trial")["duration"]:
    values = durations.to_numpy()
    resampled_means = rng.choice(
        values, size=(n_samples, len(values)), replace=True
    ).mean(axis=1)
    ci_bounds[trial] = np.quantile(resampled_means, [0.025, 0.975])

# Index = trial number, columns = the 0.025 / 0.975 quantiles
confidence_intervals = pd.DataFrame.from_dict(
    ci_bounds, orient="index", columns=[0.025, 0.975]
)

# Plot the average duration and confidence intervals
average_duration_plot = (
    hv.Curve(average_duration)
    .opts(
        xlabel="Trial Number",
        ylabel="Average Duration",
        title="Average Duration in Each Trial",
    )
    .options(**HoloviewsTemplates.hv_slides["plot"])
    .opts(invert_axes=False, show_legend=False)
)

confidence_interval_plot = hv.Area(
    (
        confidence_intervals.index,
        confidence_intervals[0.025],
        confidence_intervals[0.975],
    ),
    vdims=["lower", "upper"],
    label="95% Confidence Interval",
).opts(color="blue", alpha=0.3)

average_duration_plot * confidence_interval_plot

In [None]:
# Assuming trial_durations is your DataFrame
# Ensure 'Trial' is treated as a categorical variable with ordered categories
# (re-applied here so this cell also works when run on its own)
trial_durations["Trial"] = pd.Categorical(trial_durations["Trial"], ordered=True)


# Function to calculate improvement and worsening
def calculate_improvement_worsening(df):
    """Flag, for one fly, trials faster (``improved``) or slower (``worsened``) than the previous one.

    Sorts ``df`` by ``Trial``. Each flag is 1.0 or 0.0. A trial with the same
    duration as its predecessor is 0.0 in both. The first trial has no
    predecessor, so both flags are NaN and ``.mean()`` skips it, instead of
    counting it as neither improved nor worsened.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single fly with ``Trial`` and ``duration`` columns.

    Returns
    -------
    pandas.DataFrame
        A sorted copy of ``df`` with float ``improved`` and ``worsened`` columns.
    """
    df = df.sort_values(by="Trial")
    delta = df["duration"].diff()
    has_previous = delta.notna()
    df["improved"] = delta.lt(0).astype(float).where(has_previous)
    df["worsened"] = delta.gt(0).astype(float).where(has_previous)
    return df


# Apply the function to each group of 'fly'.
# group_keys=False (plus reset_index) keeps the index flat, so "fly" does not
# become both an index level and a column and later groupby("fly") calls work.
trial_durations = (
    trial_durations.reset_index(drop=True)
    .groupby("fly", group_keys=False)
    .apply(calculate_improvement_worsening)
)

# Calculate the proportion of flies that improved and worsened in each trial
improvement_proportions = trial_durations.groupby("Trial")["improved"].mean()
worsening_proportions = trial_durations.groupby("Trial")["worsened"].mean()

# Calculate the difference between improvement and worsening proportions
difference_proportions = improvement_proportions - worsening_proportions

# Print the results
print(difference_proportions)

# Plot the results

difference_plot = hv.Curve(difference_proportions).opts(
    xlabel="Trial Number",
    ylabel="Proportion Difference (Improved - Worsened)",
    title="Difference in Proportion of Flies Improved vs Worsened in Each Trial",
).options(**HoloviewsTemplates.hv_slides["plot"]).opts(invert_axes=False, show_legend=True)

difference_plot

In [None]:
# Assuming trial_durations is your DataFrame
# Ensure 'Trial' is treated as a categorical variable with ordered categories
# (re-applied here so this cell also works when run on its own)
trial_durations["Trial"] = pd.Categorical(trial_durations["Trial"], ordered=True)


# Function to calculate improvement, worsening, and change in duration
def calculate_changes(df):
    """Per-trial change of duration for one fly, plus improved/worsened flags.

    Sorts ``df`` by ``Trial`` and adds:

    * ``change``: ``duration`` minus the previous trial's duration (NaN for
      the first trial);
    * ``improved`` / ``worsened``: 1.0 when ``change`` is negative / positive,
      0.0 otherwise. Both are NaN for the first trial, which has no
      predecessor, so ``.mean()`` skips it.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single fly with ``Trial`` and ``duration`` columns.

    Returns
    -------
    pandas.DataFrame
        A sorted copy of ``df`` with the three new columns.
    """
    df = df.sort_values(by="Trial")
    delta = df["duration"].diff()
    has_previous = delta.notna()
    df["improved"] = delta.lt(0).astype(float).where(has_previous)
    df["worsened"] = delta.gt(0).astype(float).where(has_previous)
    df["change"] = delta
    return df


# Apply the function to each group of 'fly'.
# group_keys=False (plus reset_index) keeps the index flat, so "fly" does not
# become both an index level and a column and later groupby("fly") calls work.
trial_durations = (
    trial_durations.reset_index(drop=True)
    .groupby("fly", group_keys=False)
    .apply(calculate_changes)
)

# Calculate the proportion of flies that improved and worsened in each trial
improvement_proportions = trial_durations.groupby("Trial")["improved"].mean()
worsening_proportions = trial_durations.groupby("Trial")["worsened"].mean()

# Calculate the difference between improvement and worsening proportions
difference_proportions = improvement_proportions - worsening_proportions

# Calculate the average change in duration for each trial
average_change = trial_durations.groupby("Trial")["change"].mean()

# Combine the proportion difference with the average change
combined_metric = pd.DataFrame(
    {"Proportion Difference": difference_proportions, "Average Change": average_change}
)

# Print the results
print(combined_metric)

# Plot the results
# NOTE(review): the proportion difference lies in [-1, 1] while the average change
# is in seconds; on a shared y axis the proportion curve will look flat.
proportion_plot = hv.Curve(
    combined_metric["Proportion Difference"], label="Proportion Difference"
).opts(
    xlabel="Trial Number",
    ylabel="Proportion Difference (Improved - Worsened)",
    width=800,
    height=400,
)

change_plot = hv.Curve(combined_metric["Average Change"], label="Average Change").opts(
    xlabel="Trial Number",
    ylabel="Average Change in Duration",
    width=800,
    height=400,
)

combined_plot = proportion_plot * change_plot.opts(ylabel="Combined Metric")

# Save the plot as png
# hv.save(
#     combined_plot,
#     "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_Combined_Metric.png",
# )

combined_plot

In [None]:
# Save the plot as png
# NOTE(review): improvement_plot is whichever object was bound last under that name
# (the per-trial curve from the improvement cell, or the per-Peak Layout if that
# cell already ran); also note the 240809 prefix vs 240819 for the filtered version.

hv.save(improvement_plot, "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_Improvement.png")

In [None]:
# For flies for which "Peak" is NaN, give them the value "morning"
Filtered["Peak"] = Filtered["Peak"].fillna("morning")
cleaned_data["Peak"] = cleaned_data["Peak"].fillna("morning")

# Ensure 'Trial' is treated as a categorical variable with ordered categories
cleaned_data["Trial"] = pd.Categorical(cleaned_data["Trial"], ordered=True)


# Compute the duration of every observed (Peak, fly, Trial) segment.
# observed=True is needed because Trial is now categorical: with observed=False,
# pandas can also emit unobserved Peak x fly x Trial combinations as extra rows
# with NaN durations.
trial_durations_peak = (
    cleaned_data.groupby(["Peak", "fly", "Trial"], observed=True)
    .apply(compute_trial_duration)
    .reset_index()
)


Jitterbox_peak = HoloviewsTemplates.jitter_boxplot(
    trial_durations_peak,
    kdims="Trial",
    metric="duration",
    plot_options=HoloviewsTemplates.hv_slides,
    groupby="Peak",
    render="grouped"
).opts(invert_axes=False, xlabel="Trial Number", ylabel="Duration (s)")

Jitterbox_peak

In [None]:
# Save as html (interactive version of the per-Peak trial-duration plot)

hv.save(Jitterbox_peak, "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_TrialDuration_Peak.html")

In [None]:
# Redo the improvement but grouped by Peak

# Assuming trial_durations is your DataFrame
# Ensure 'Trial' is treated as a categorical variable with ordered categories
# (here on trial_durations_peak, the per-Peak table built above)
trial_durations_peak["Trial"] = pd.Categorical(
    trial_durations_peak["Trial"], ordered=True
)


# Function to calculate improvement (redefined here for the per-Peak analysis)
def calculate_improvement(df):
    """Flag, for one fly, the trials that were faster than the previous one.

    Sorts ``df`` by ``Trial``. ``improved`` is 1.0 when ``duration`` decreased
    relative to the preceding trial and 0.0 otherwise. The first trial has no
    predecessor, so its ``improved`` is NaN and ``.mean()`` skips it instead of
    counting it as "did not improve".

    Parameters
    ----------
    df : pandas.DataFrame
        Rows of a single fly with ``Trial`` and ``duration`` columns.

    Returns
    -------
    pandas.DataFrame
        A sorted copy of ``df`` with a float ``improved`` column.
    """
    df = df.sort_values(by="Trial")
    delta = df["duration"].diff()
    df["improved"] = delta.lt(0).astype(float).where(delta.notna())
    return df

# Apply the function to each group of 'fly' and 'Peak'.
# group_keys=False keeps the flat index: otherwise pandas adds "fly" and "Peak"
# index levels next to the columns of the same name, and the
# groupby(["Peak", "Trial"]) below raises as ambiguous. reset_index(drop=True)
# also clears such levels left by a previous run of this cell.
trial_durations_peak = (
    trial_durations_peak.reset_index(drop=True)
    .groupby(["fly", "Peak"], group_keys=False)
    .apply(calculate_improvement)
)


# Calculate the proportion of flies that improved in each trial for each Peak.
# observed=True: only (Peak, Trial) pairs that occur, not every Peak paired with
# every Trial category.
improvement_proportions = (
    trial_durations_peak.groupby(["Peak", "Trial"], observed=True)["improved"]
    .mean()
    .reset_index()
)

# Print the results
print(improvement_proportions)

# Plot the results
import holoviews as hv

hv.extension("bokeh")

# Create a separate plot for each Peak
plots = []
for peak in improvement_proportions["Peak"].unique():
    peak_data = improvement_proportions[improvement_proportions["Peak"] == peak]
    plot = hv.Curve(peak_data, kdims="Trial", vdims="improved").opts(
        xlabel="Trial Number",
        ylabel="Proportion of Flies Improved",
        title=f"Proportion of Flies that Improved in Each Trial for Peak {peak}",
        width=800,
        height=400,
    )
    plots.append(plot)

# Combine all plots into a single layout
improvement_plot = hv.Layout(plots).cols(1)

improvement_plot

In [None]:
# Convert Trial back to integers (it was made categorical for the Peak analysis)
cleaned_data["Trial"] = cleaned_data["Trial"].astype(int)

# Last timestamp of one trial segment
def compute_trial_end_time(group):
    """Return when a trial ended, as a one-field Series under ``"end_time"``.

    The end time is the largest ``time`` value found in ``group``.
    """
    last_time = group["time"].max()
    return pd.Series({"end_time": last_time})


# Apply the function to each group and compute the end times
# -> one row per (fly, Trial) with the time at which that trial segment ends.
# NOTE(review): the last segment of each fly ends at its yball maximum, not
# necessarily at a reset, so it may not be a "solved" trial — confirm.
trial_end_times = (
    cleaned_data.groupby(["fly", "Trial"]).apply(compute_trial_end_time).reset_index()
)

# Sort by end_time
trial_end_times = trial_end_times.sort_values(by="end_time")

# Compute the cumulative count of solved trials
# (after sorting, rank(method="first") is simply 1, 2, ..., n, pooled over flies)
trial_end_times["cumulative_solved_trials"] = (
    trial_end_times["end_time"].rank(method="first").astype(int)
)

# Plot the cumulative count of solved trials over time
# (the .opts() labels override those passed to hvplot.step)
cumulcurve = trial_end_times.hvplot.step(
    x="end_time",
    y="cumulative_solved_trials",
    title="Cumulative Solved Trials Over Time",
    xlabel="Time",
    ylabel="Cumulative Solved Trials",
).opts(
        height=1000,
        width=1200,
        alpha=1,
        line_width=2,
        xlabel="Time(s)",
        ylabel="Cumulative Solved Trials",
        show_grid=True,
        fontscale=3,
        title="Cumulative Solved Trials Over Time",
    )

In [None]:
# Save it as a PNG (plain pooled cumulative curve, without fits)

hv.save(cumulcurve, "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_CumulativeSolvedTrials.png")

In [None]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from scipy import stats
import holoviews as hv
import hvplot.pandas
from holoviews import opts  # Import opts from holoviews


# Define logistic function
def logistic(x, L, k, x0):
    """Three-parameter logistic curve ``L / (1 + exp(-k * (x - x0)))``.

    ``L`` is the upper asymptote, ``k`` the steepness and ``x0`` the midpoint
    (where the curve equals ``L / 2``). Works element-wise on scalars, NumPy
    arrays and pandas Series.
    """
    exponent = -k * (x - x0)
    return L / (1 + np.exp(exponent))


# Filter data to include only relevant subset
# (all flies are pooled: GroupData_morning is simply trial_end_times; no Peak
# filtering is applied despite the name)
GroupData_morning = trial_end_times

# Calculate the number of unique 'flies' in the dataset
num_replicates = GroupData_morning["fly"].nunique()

# Initial guesses for L, k, x0
p0 = [
    max(GroupData_morning["cumulative_solved_trials"]),
    1,
    np.median(GroupData_morning["end_time"]),
]

# Fit logistic function to data
params, _ = curve_fit(
    logistic,
    GroupData_morning["end_time"],
    GroupData_morning["cumulative_solved_trials"],
    p0,
    maxfev=10000,
)

# params contains the fitted values for L, k, x0
L, k, x0 = params

# Create a new DataFrame for the logistic fit curve
logistic_fit = pd.DataFrame(
    {
        "end_time": np.linspace(
            GroupData_morning["end_time"].min(),
            GroupData_morning["end_time"].max(),
            100,
        ),
    }
)

# Calculate y values for the logistic fit curve
logistic_fit["cumulative_solved_trials"] = logistic(logistic_fit["end_time"], L, k, x0)

# Calculate R-squared for logistic fit curve
y_pred_logistic = logistic(GroupData_morning["end_time"], L, k, x0)
r_squared_logistic = r2_score(
    GroupData_morning["cumulative_solved_trials"], y_pred_logistic
)

# Calculate linear fit for the first half of the data
# (rows are sorted by end_time, so this is the earliest half of the solved trials)
first_half = GroupData_morning.iloc[: len(GroupData_morning) // 2]
slope, intercept, _, _, _ = stats.linregress(
    first_half["end_time"], first_half["cumulative_solved_trials"]
)

# Create a new DataFrame for the linear fit line
linear_fit = pd.DataFrame(
    {
        "end_time": np.linspace(
            GroupData_morning["end_time"].min(),
            GroupData_morning["end_time"].max(),
            100,
        ),
    }
)

# Calculate y values for the linear fit line
linear_fit["cumulative_solved_trials"] = slope * linear_fit["end_time"] + intercept

# Calculate R-squared for linear fit curve (evaluated on the whole dataset)
y_pred_linear = slope * GroupData_morning["end_time"] + intercept
r_squared_linear = r2_score(
    GroupData_morning["cumulative_solved_trials"], y_pred_linear
)

# Horizontal midpoint of the time range, used to place the R² annotations.
# (max - min) / 2 is half the *range*, which equals the midpoint only when the
# time axis starts at 0.
center_x = (
    GroupData_morning["end_time"].max() + GroupData_morning["end_time"].min()
) / 2
max_y = max(
    max(GroupData_morning["cumulative_solved_trials"]),
    max(logistic_fit["cumulative_solved_trials"]),
)

max_x = GroupData_morning["end_time"].max()
min_y = GroupData_morning["cumulative_solved_trials"].min()

# Create your plot using GroupData_morning and add the fits and annotations
cumulcurve_pool = (
    hv.Curve(
        data=GroupData_morning,
        kdims=["end_time"],
        vdims=["cumulative_solved_trials"],
    )
    .opts(
        height=1000,
        width=1200,
        alpha=1,
        line_width=2,
        xlabel="Time(s)",
        ylabel="Cumulative Solved Trials",
        show_grid=True,
        fontscale=3,
        title="Cumulative Solved Trials Over Time",
    )
    * hv.Curve(
        data=logistic_fit, kdims=["end_time"], vdims=["cumulative_solved_trials"]
    ).opts(color="green")
    * hv.Curve(
        data=linear_fit, kdims=["end_time"], vdims=["cumulative_solved_trials"]
    ).opts(color="red")
    # Text annotations: goodness of fit of both models and the number of flies
    * hv.Text(center_x, max_y - 7, f"Logistic fit R-squared: {r_squared_logistic:.2f}").opts(text_color="green")
    * hv.Text(center_x, max_y, f"Linear fit R-squared: {r_squared_linear:.2f}").opts(text_color="red")
    * hv.Text(max_x - 100, min_y * num_replicates, f"N = {num_replicates}")
).opts(
    opts.Area(
        show_legend=False,
        show_frame=False,
        fill_color="blue",
        line_color="black",
        line_width=0,
    )
)

cumulcurve_pool

In [None]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
import holoviews as hv
import hvplot.pandas
from holoviews import opts


# Define double sigmoid function
def double_sigmoid(x, L1, k1, x01, L2, k2, x02):
    """Sum of two logistic steps, for a cumulative curve with two growth phases.

    Each step ``i`` has plateau ``Li``, steepness ``ki`` and midpoint ``x0i``.
    Works element-wise on scalars, NumPy arrays and pandas Series.
    """
    first_step = L1 / (1 + np.exp(-k1 * (x - x01)))
    second_step = L2 / (1 + np.exp(-k2 * (x - x02)))
    return first_step + second_step


# Fit a double sigmoid to the pooled cumulative curve.
# NOTE(review): this cell relies on GroupData_morning defined in the logistic-fit
# cell above; it is not redefined here.

# Improved initial guesses for L1, k1, x01, L2, k2, x02
# (each step gets half the final count; midpoints at 0.5x and 1.5x the median time)
p0 = [
    max(GroupData_morning["cumulative_solved_trials"]) / 2,
    0.1,
    np.median(GroupData_morning["end_time"]) / 2,
    max(GroupData_morning["cumulative_solved_trials"]) / 2,
    0.1,
    np.median(GroupData_morning["end_time"]) * 1.5,
]

# Increase maxfev
maxfev_value = 10000

# Fit double sigmoid function to data
params, _ = curve_fit(
    double_sigmoid,
    GroupData_morning["end_time"],
    GroupData_morning["cumulative_solved_trials"],
    p0,
    maxfev=maxfev_value,
)

# params contains the fitted values for L1, k1, x01, L2, k2, x02
L1, k1, x01, L2, k2, x02 = params

# Calculate fitted values
fitted_values = double_sigmoid(GroupData_morning["end_time"], L1, k1, x01, L2, k2, x02)

# Calculate residuals
residuals = GroupData_morning["cumulative_solved_trials"] - fitted_values

# Calculate SS_res and SS_tot
SS_res = np.sum(residuals**2)
SS_tot = np.sum(
    (
        GroupData_morning["cumulative_solved_trials"]
        - np.mean(GroupData_morning["cumulative_solved_trials"])
    )
    ** 2
)

# Calculate R-squared (1 - SS_res / SS_tot, computed by hand)
r_squared = 1 - (SS_res / SS_tot)
print(f"R-squared: {r_squared}")

# Create a new DataFrame for the double sigmoid fit curve
double_sigmoid_fit = pd.DataFrame(
    {
        "end_time": np.linspace(
            GroupData_morning["end_time"].min(),
            GroupData_morning["end_time"].max(),
            100,
        ),
    }
)

# Calculate y values for the double sigmoid fit curve
double_sigmoid_fit["cumulative_solved_trials"] = double_sigmoid(
    double_sigmoid_fit["end_time"], L1, k1, x01, L2, k2, x02
)

# Plot the original data and the double sigmoid fit curve
cumulcurve_pool = hv.Curve(
    data=GroupData_morning,
    kdims=["end_time"],
    vdims=["cumulative_solved_trials"],
).opts(
    height=1000,
    width=1200,
    alpha=1,
    line_width=2,
    xlabel="Time(s)",
    ylabel="Cumulative Solved Trials",
    show_grid=True,
    fontscale=3,
    title="Cumulative Solved Trials Over Time",
) * hv.Curve(
    data=double_sigmoid_fit, kdims=["end_time"], vdims=["cumulative_solved_trials"]
).opts(
    color="green"
)

#cumulcurve_pool

# Double sigmoid + Linear fit

In [None]:
# Filter data to include only relevant subset
# (all flies pooled: this is simply the sorted per-trial end times)
GroupData_morning = trial_end_times

# Define double sigmoid function (same formula as in the previous cell)
def double_sigmoid(x, L1, k1, x01, L2, k2, x02):
    """Two logistic steps added together.

    The parameters come in (plateau, steepness, midpoint) triples, one per
    step: ``(L1, k1, x01)`` and ``(L2, k2, x02)``.
    """
    early = L1 / (1 + np.exp(-k1 * (x - x01)))
    late = L2 / (1 + np.exp(-k2 * (x - x02)))
    total = early + late
    return total


# Double sigmoid fit (same as the previous cell) plus an ordinary linear regression
# over the whole time range, plotted together and saved to disk.

# Improved initial guesses for L1, k1, x01, L2, k2, x02
# (each step gets half the final count; midpoints at 0.5x and 1.5x the median time)
p0 = [
    max(GroupData_morning["cumulative_solved_trials"]) / 2,
    0.1,
    np.median(GroupData_morning["end_time"]) / 2,
    max(GroupData_morning["cumulative_solved_trials"]) / 2,
    0.1,
    np.median(GroupData_morning["end_time"]) * 1.5,
]

# Increase maxfev
maxfev_value = 10000

# Fit double sigmoid function to data
params, _ = curve_fit(
    double_sigmoid,
    GroupData_morning["end_time"],
    GroupData_morning["cumulative_solved_trials"],
    p0,
    maxfev=maxfev_value,
)

# params contains the fitted values for L1, k1, x01, L2, k2, x02
L1, k1, x01, L2, k2, x02 = params

# Calculate fitted values
fitted_values = double_sigmoid(GroupData_morning["end_time"], L1, k1, x01, L2, k2, x02)

# Calculate residuals
residuals = GroupData_morning["cumulative_solved_trials"] - fitted_values

# Calculate SS_res and SS_tot
SS_res = np.sum(residuals**2)
SS_tot = np.sum(
    (
        GroupData_morning["cumulative_solved_trials"]
        - np.mean(GroupData_morning["cumulative_solved_trials"])
    )
    ** 2
)

# Calculate R-squared
r_squared = 1 - (SS_res / SS_tot)
print(f"R-squared: {r_squared}")

# Create a new DataFrame for the double sigmoid fit curve
double_sigmoid_fit = pd.DataFrame(
    {
        "end_time": np.linspace(
            GroupData_morning["end_time"].min(),
            GroupData_morning["end_time"].max(),
            100,
        ),
    }
)

# Calculate y values for the double sigmoid fit curve
double_sigmoid_fit["cumulative_solved_trials"] = double_sigmoid(
    double_sigmoid_fit["end_time"], L1, k1, x01, L2, k2, x02
)

# Fit a linear regression model (scikit-learn; end_time passed as a 2-D frame)
linear_model = LinearRegression()
linear_model.fit(
    GroupData_morning[["end_time"]], GroupData_morning["cumulative_solved_trials"]
)

# Calculate fitted values for the linear model
linear_fitted_values = linear_model.predict(GroupData_morning[["end_time"]])

# Calculate R-squared for the linear model
linear_r2 = r2_score(
    GroupData_morning["cumulative_solved_trials"], linear_fitted_values
)
print(f"Linear R-squared: {linear_r2}")

# Create a new DataFrame for the linear fit curve (evaluated at the data points)
linear_fit = pd.DataFrame(
    {
        "end_time": GroupData_morning["end_time"],
        "cumulative_solved_trials": linear_fitted_values,
    }
)

# Plot the original data, the double sigmoid fit curve, and the linear fit curve
# (the R² values are shown in the legend labels)
cumulcurve_pool = (
    hv.Curve(
        data=GroupData_morning,
        kdims=["end_time"],
        vdims=["cumulative_solved_trials"],
    ).opts(
        height=1000,
        width=1200,
        alpha=1,
        line_width=2,
        xlabel="Time(s)",
        ylabel="Cumulative Solved Trials",
        show_grid=True,
        fontscale=3,
        title="Cumulative Solved Trials Over Time",
    )
    * hv.Curve(
        data=double_sigmoid_fit, kdims=["end_time"], vdims=["cumulative_solved_trials"]
    )
    .opts(
        color="green",
    )
    .relabel(f"Double Sigmoid Fit (R²={r_squared:.2f})")
    * hv.Curve(data=linear_fit, kdims=["end_time"], vdims=["cumulative_solved_trials"])
    .opts(
        color="red",
        line_dash="dashed",
    )
    .relabel(f"Linear Fit (R²={linear_r2:.2f})")
)

# Apply the project's slide styling to the whole overlay
cumulcurve_pool = cumulcurve_pool.options(**HoloviewsTemplates.hv_slides["plot"]).opts(
    invert_axes=False, show_legend=True)

hv.save(
    cumulcurve_pool,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_CumulativeSolvedTrials_sigmoid_linear.png",
)

In [None]:
# Save as PNG under the fit-specific name. "240809_CumulativeSolvedTrials.png"
# already holds the plain cumulative curve saved earlier; writing this figure to
# that path silently replaced it with the sigmoid + linear fit figure.

hv.save(cumulcurve_pool, "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240809_CumulativeSolvedTrials_sigmoid_linear.png")

In [None]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from scipy import stats
import holoviews as hv
import hvplot.pandas
from holoviews import opts  # Import opts from holoviews


# Define logistic (sigmoid) function
def sigmoid(x, L, k, x0):
    """Three-parameter logistic curve ``L / (1 + exp(-k * (x - x0)))``.

    Parameters
    ----------
    x : array-like
        Abscissa (here: trial end times).
    L : float
        Upper asymptote of the curve.
    k : float
        Steepness, in units of 1 / x.
    x0 : float
        Midpoint: the x value at which the curve reaches L / 2.

    Returns
    -------
    array-like
        Curve values evaluated at ``x``.
    """
    return L / (1 + np.exp(-k * (x - x0)))


# Pooled per-trial end times with their cumulative solved-trial rank.
# NOTE: no filtering happens here - this is a plain alias of `trial_end_times`.
# In this part of the notebook that table is only built by a later cell
# (the `compute_trial_end_time` cell); make sure it has been run first.
GroupData_morning = trial_end_times

# Calculate the number of unique 'flies' in the dataset
num_replicates = GroupData_morning["fly"].nunique()

pooled_end_times = GroupData_morning["end_time"]
pooled_cumulative = GroupData_morning["cumulative_solved_trials"]

# Initial guesses for L, k, x0.
# k starts at 1 / (time span) instead of 1: on an axis spanning thousands of
# seconds, k = 1 makes the starting curve an almost perfect step (exp overflow,
# vanishing gradients), which hampers convergence. This matches the initial
# guess used by the logistic fit on the per-time averages.
time_span = pooled_end_times.max() - pooled_end_times.min()
p0 = [
    pooled_cumulative.max(),
    1 / time_span if time_span > 0 else 1,
    np.median(pooled_end_times),
]

# Fit sigmoid function to data
params, _ = curve_fit(
    sigmoid,
    pooled_end_times,
    pooled_cumulative,
    p0,
    maxfev=5000,
)

# params contains the fitted values for L, k, x0
L, k, x0 = params

# 100 evenly spaced times over the observed range, used to draw both fits
fit_times = np.linspace(pooled_end_times.min(), pooled_end_times.max(), 100)

# Sigmoid fit curve
sigmoid_fit = pd.DataFrame({"end_time": fit_times})
sigmoid_fit["cumulative_solved_trials"] = sigmoid(sigmoid_fit["end_time"], L, k, x0)

# R-squared of the sigmoid fit on the observed points
y_pred_sigmoid = sigmoid(pooled_end_times, L, k, x0)
r_squared_sigmoid = r2_score(pooled_cumulative, y_pred_sigmoid)

# Straight line fitted on the first half of the rows only, then extended over the
# whole time range. "First half" means earliest trials only if the rows are sorted
# by end_time (trial_end_times is sorted upstream).
first_half = GroupData_morning.iloc[: len(GroupData_morning) // 2]
slope, intercept, _, _, _ = stats.linregress(
    first_half["end_time"], first_half["cumulative_solved_trials"]
)

# Linear fit line
linear_fit = pd.DataFrame({"end_time": fit_times})
linear_fit["cumulative_solved_trials"] = slope * linear_fit["end_time"] + intercept

# R-squared of the (first-half) linear fit evaluated on all observed points
y_pred_linear = slope * pooled_end_times + intercept
r_squared_linear = r2_score(pooled_cumulative, y_pred_linear)

# Annotation anchor positions.
# center_x is the midpoint of the observed time range (correct even when the
# axis does not start at 0).
center_x = (pooled_end_times.max() + pooled_end_times.min()) / 2
max_y = max(pooled_cumulative.max(), sigmoid_fit["cumulative_solved_trials"].max())
max_x = pooled_end_times.max()
min_y = pooled_cumulative.min()

# Observed curve, both fits, and R² / sample-size annotations
cumulcurve_pool = (
    hv.Curve(
        data=GroupData_morning,
        kdims=["end_time"],
        vdims=["cumulative_solved_trials"],
    )
    .opts(
        height=1000,
        width=1200,
        alpha=1,
        line_width=2,
        xlabel="Time(s)",
        ylabel="Cumulative Solved Trials",
        show_grid=True,
        fontscale=3,
        title="Cumulative Solved Trials Over Time",
    )
    * hv.Curve(
        data=sigmoid_fit, kdims=["end_time"], vdims=["cumulative_solved_trials"]
    ).opts(color="blue")
    * hv.Curve(
        data=linear_fit, kdims=["end_time"], vdims=["cumulative_solved_trials"]
    ).opts(color="red")
    * hv.Text(center_x, max_y - 10, f"Sigmoid fit R-squared: {r_squared_sigmoid:.2f}").opts(text_color="blue")
    * hv.Text(center_x, max_y, f"Linear fit R-squared: {r_squared_linear:.2f}").opts(text_color="red")
    # NOTE(review): the y position scales the smallest cumulative count by the
    # number of flies - confirm this placement is intended.
    * hv.Text(max_x - 100, min_y * num_replicates, f"N = {num_replicates}")
)

cumulcurve_pool

In [None]:
# "solved_trials" = number of trials completed before the current one
# (the column is written into Filtered in place)
Filtered["solved_trials"] = Filtered["Trial"] - 1

# Mean "solved_trials" per timestamp, pooled over every fly.
# Note: this uses the `Filtered` table, not the cleaned dataset.
average_solved = Filtered.groupby(by="time")["solved_trials"].mean()

# Line plot of the per-timestamp mean (the Series index, time, is the x axis)
average_solved.hvplot(
    width=800,
    height=400,
    ylabel="Average Solved Trials",
    xlabel="Time",
    title="Average Solved Trials Over Time",
)

In [None]:
import pandas as pd
import hvplot.pandas

# Convert Trial back to integers
cleaned_data["Trial"] = cleaned_data["Trial"].astype(int)


# Function to compute the end time of each trial
def compute_trial_end_time(group):
    """Return the last ``time`` value of one (fly, Trial) group as ``end_time``.

    Kept so that cells calling it directly keep working. The table below is
    built with an equivalent vectorised ``groupby().max()``.
    """
    end_time = group["time"].max()
    return pd.Series({"end_time": end_time})


# End time of each trial = last timestamp recorded for that (fly, Trial) pair.
# A column-wise max gives the same table as groupby().apply(compute_trial_end_time).
# It avoids one Python call per group and the pandas >= 2.2 DeprecationWarning
# about apply() operating on the grouping columns.
trial_end_times = (
    cleaned_data.groupby(["fly", "Trial"])["time"]
    .max()
    .rename("end_time")
    .reset_index()
)

# Sort by end_time
trial_end_times = trial_end_times.sort_values(by="end_time")

# Cumulative count of solved trials, pooled over ALL flies: the k-th trial to end
# (in time order) gets rank k.
trial_end_times["cumulative_solved_trials"] = (
    trial_end_times["end_time"].rank(method="first").astype(int)
)

# Mean pooled count per end_time. Rows only collapse when several trials share an
# end_time. NOTE(review): this is the pooled cumulative count, not a per-fly
# average, despite the plot title.
average_solved_trials = (
    trial_end_times.groupby("end_time")["cumulative_solved_trials"].mean().reset_index()
)

# Plot the average number of trials solved over time
# (titles and labels are set once, in .opts)
average_plot = average_solved_trials.hvplot.line(
    x="end_time",
    y="cumulative_solved_trials",
).opts(
    height=1000,
    width=1200,
    alpha=1,
    line_width=2,
    xlabel="Time(s)",
    ylabel="Average Number of Solved Trials",
    show_grid=True,
    fontscale=3,
    title="Average Number of Solved Trials Over Time",
)

average_plot

In [None]:
import pandas as pd
import hvplot.pandas
import numpy as np

# Convert Trial back to integers
cleaned_data["Trial"] = cleaned_data["Trial"].astype(int)


# Function to compute the end time of each trial
def compute_trial_end_time(group):
    """Return the last ``time`` value of one (fly, Trial) group as ``end_time``.

    Kept so that cells calling it directly keep working. The table below is
    built with an equivalent vectorised ``groupby().max()``.
    """
    end_time = group["time"].max()
    return pd.Series({"end_time": end_time})


# End time of each trial = last timestamp recorded for that (fly, Trial) pair
# (same values as groupby().apply(compute_trial_end_time), without per-group calls)
trial_end_times = (
    cleaned_data.groupby(["fly", "Trial"])["time"]
    .max()
    .rename("end_time")
    .reset_index()
)

# Sort by end_time
trial_end_times = trial_end_times.sort_values(by="end_time")

# Cumulative count of solved trials, pooled over all flies (k-th trial to end -> k)
trial_end_times["cumulative_solved_trials"] = (
    trial_end_times["end_time"].rank(method="first").astype(int)
)

# Mean and standard deviation of the pooled count at each end_time, plus the
# number of rows behind each mean (needed for the standard error).
cumulative_by_time = trial_end_times.groupby("end_time")["cumulative_solved_trials"]
average_solved_trials = cumulative_by_time.agg(
    cumulative_solved_trials_mean="mean",
    cumulative_solved_trials_std="std",
).reset_index()
n_per_time = cumulative_by_time.size().to_numpy()  # same group order as agg()

# 95% normal-approximation CI: mean ± z * std / sqrt(n), where n is the size of
# each end_time group. This replaces a hard-coded sqrt(80), which was unrelated
# to the rows averaged. NOTE(review): with rank(method="first") most end_time
# values are probably unique, so std (and the band) is NaN wherever n == 1.
z_score = 1.96  # for 95% confidence interval
sem = average_solved_trials["cumulative_solved_trials_std"] / np.sqrt(n_per_time)
average_solved_trials["upper_bound"] = (
    average_solved_trials["cumulative_solved_trials_mean"] + z_score * sem
)
average_solved_trials["lower_bound"] = (
    average_solved_trials["cumulative_solved_trials_mean"] - z_score * sem
)

# Plot the average number of trials solved over time with confidence intervals
average_plot = average_solved_trials.hvplot.line(
    x="end_time",
    y="cumulative_solved_trials_mean",
    color="blue",
    label="Mean",
).opts(
    height=1000,
    width=1200,
    alpha=1,
    line_width=2,
    xlabel="Time(s)",
    ylabel="Average Number of Solved Trials",
    show_grid=True,
    fontscale=3,
    title="Average Number of Solved Trials Over Time",
)

confidence_plot = average_solved_trials.hvplot.area(
    x="end_time",
    y="upper_bound",
    y2="lower_bound",
    color="blue",
    alpha=0.3,
    label="95% CI",
)

average_plot * confidence_plot

# Average trials solved per frame

In [None]:
# "solved_trials" counts the trials finished before the current one
# (column added to Filtered in place)
Filtered["solved_trials"] = Filtered["Trial"] - 1

# Mean "solved_trials" per timestamp. Grouping is by time only, so all flies are
# pooled; the result is a flat table with columns "time" and "solved_trials".
average_solved = Filtered.groupby("time", as_index=False)["solved_trials"].mean()

In [None]:
# Line plot of the per-timestamp mean "solved_trials" from `average_solved`
avgplot = average_solved.hvplot.line(
    x="time",
    y="solved_trials",
    width=800,
    height=400,
    title="Average Solved Trials Over Time",
    ylabel="Average Solved Trials",
    xlabel="Time",
)

# Export the figure as a PNG
hv.save(
    avgplot,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240815_AverageSolvedTrials.png",
)

In [None]:
# Trials already solved = current trial index minus one (column added in place)
cleaned_data["solved_trials"] = cleaned_data["Trial"] - 1

# Mean "solved_trials" per timestamp over the cleaned data. Grouping is by time
# only, pooling every fly.
clean_average_solved = cleaned_data.groupby("time", as_index=False)["solved_trials"].mean()

In [None]:
# Line plot of the cleaned per-timestamp mean, exported as PNG
clean_avgplot = clean_average_solved.hvplot.line(
    x="time",
    y="solved_trials",
    width=800,
    height=400,
    title="Average Solved Trials Over Time",
    ylabel="Average Solved Trials",
    xlabel="Time",
)

hv.save(
    clean_avgplot,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240815_CleanAverageSolvedTrials.png",
)

In [None]:
# Per-timestamp mean and standard error (SEM) of "solved_trials", pooled over all
# rows of cleaned_data
grouped = cleaned_data.groupby("time")["solved_trials"].agg(["mean", "sem"]).reset_index()

# Keep only times between 0 and 7100 (inclusive). `.copy()` makes `grouped` an
# independent frame, so the column assignments below never act on a filtered view.
grouped = grouped[grouped["time"].between(0, 7100)].copy()

# Normal-approximation confidence band around the mean.
# The two-sided z critical value is derived from confidence_level
# (≈ 1.96 for 0.95), so changing the level actually changes the band.
confidence_level = 0.95
z_score = stats.norm.ppf(0.5 + confidence_level / 2)
grouped["ci_lower"] = grouped["mean"] - z_score * grouped["sem"]
grouped["ci_upper"] = grouped["mean"] + z_score * grouped["sem"]

# Rename columns for clarity
grouped = grouped.rename(columns={"mean": "solved_trials"})

# Plot the data with confidence intervals
clean_avgplot = grouped.hvplot(
    x="time",
    y="solved_trials",
    kind="line",
    title="Average Solved Trials Over Time",
    xlabel="Time",
    ylabel="Average Solved Trials",
    height=400,
    width=800,
)

# Shaded band between the lower and upper CI limits
ci_area = hv.Area(
    (grouped["time"], grouped["ci_lower"], grouped["ci_upper"]),
    vdims=["ci_lower", "ci_upper"],
).opts(alpha=0.3, color="blue")

# Overlay the confidence interval area plot with the line plot
combined_plot = ci_area * clean_avgplot

# Save the combined plot.
# NOTE: this is the same path the previous (no-CI) cell writes to, so that PNG
# is overwritten here.
hv.save(
    combined_plot,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240815_CleanAverageSolvedTrials.png",
)

In [None]:
# Per-timestamp mean and standard error (SEM) of "solved_trials", pooled over all rows
grouped = (
    cleaned_data.groupby("time")["solved_trials"].agg(["mean", "sem"]).reset_index()
)

# Keep only times between 0 and 7100 (inclusive); copy so later assignments
# modify an independent frame
grouped = grouped[grouped["time"].between(0, 7100)].copy()

# Normal-approximation confidence band; z is derived from confidence_level
# (≈ 1.96 for 0.95)
confidence_level = 0.95
z_score = stats.norm.ppf(0.5 + confidence_level / 2)
grouped["ci_lower"] = grouped["mean"] - z_score * grouped["sem"]
grouped["ci_upper"] = grouped["mean"] + z_score * grouped["sem"]

# Rename columns for clarity
grouped = grouped.rename(columns={"mean": "solved_trials"})

# Fit a linear regression model
linear_model = LinearRegression()
linear_model.fit(grouped[["time"]], grouped["solved_trials"])
grouped["linear_fit"] = linear_model.predict(grouped[["time"]])
linear_r2 = r2_score(grouped["solved_trials"], grouped["linear_fit"])


# Define logistic function
def logistic(x, L, x0, k):
    """Three-parameter logistic curve ``L / (1 + exp(-k * (x - x0)))``.

    L is the upper asymptote, x0 the midpoint (where the curve reaches L / 2)
    and k the steepness in units of 1 / x.
    """
    return L / (1 + np.exp(-k * (x - x0)))


# Initial guesses: asymptote at the observed maximum, midpoint at the median time,
# and a steepness on the scale of the whole time window
L_initial = max(grouped["solved_trials"])
x0_initial = np.median(grouped["time"])
k_initial = 1 / (max(grouped["time"]) - min(grouped["time"]))

# Fit the logistic curve by non-linear least squares (curve_fit)
popt, _ = curve_fit(
    logistic,
    grouped["time"],
    grouped["solved_trials"],
    p0=[L_initial, x0_initial, k_initial],
    maxfev=5000,
)
grouped["logistic_fit"] = logistic(grouped["time"], *popt)
logistic_r2 = r2_score(grouped["solved_trials"], grouped["logistic_fit"])

# Plot the data with confidence intervals
clean_avgplot = grouped.hvplot(
    x="time",
    y="solved_trials",
    kind="line",
    title="Average Solved Trials Over Time",
    xlabel="Time",
    ylabel="Average Solved Trials",
    height=400,
    width=800,
)

# Shaded band between the lower and upper CI limits
ci_area = hv.Area(
    (grouped["time"], grouped["ci_lower"], grouped["ci_upper"]),
    vdims=["ci_lower", "ci_upper"],
).opts(alpha=0.3, color="gray")

# Create the linear fit plot
linear_fit_plot = grouped.hvplot(
    x="time",
    y="linear_fit",
    kind="line",
    color="red",
    line_dash="dashed",
    label=f"Linear Fit (R²={linear_r2:.2f})",
)

# Create the logistic fit plot
logistic_fit_plot = grouped.hvplot(
    x="time",
    y="logistic_fit",
    kind="line",
    color="blue",
    line_dash="dotted",
    label=f"Logistic Fit (R²={logistic_r2:.2f})",
)

# Overlay the confidence interval area plot with the line plot and fits
combined_plot = ci_area * clean_avgplot * linear_fit_plot * logistic_fit_plot

# Save the combined plot.
# NOTE: same path as the two previous cells, so their PNG is overwritten.
hv.save(
    combined_plot,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240815_CleanAverageSolvedTrials.png",
)

In [None]:
# Per-timestamp mean and standard error (SEM) of "solved_trials", pooled over all rows
grouped = (
    cleaned_data.groupby("time")["solved_trials"].agg(["mean", "sem"]).reset_index()
)

# Keep only times between 0 and 7100 (inclusive); copy so later assignments
# modify an independent frame
grouped = grouped[grouped["time"].between(0, 7100)].copy()

# Normal-approximation confidence band; z is derived from confidence_level
# (≈ 1.96 for 0.95)
confidence_level = 0.95
z_score = stats.norm.ppf(0.5 + confidence_level / 2)
grouped["ci_lower"] = grouped["mean"] - z_score * grouped["sem"]
grouped["ci_upper"] = grouped["mean"] + z_score * grouped["sem"]

# Rename columns for clarity
grouped = grouped.rename(columns={"mean": "solved_trials"})

# Fit a linear regression model
linear_model = LinearRegression()
linear_model.fit(grouped[["time"]], grouped["solved_trials"])
grouped["linear_fit"] = linear_model.predict(grouped[["time"]])
linear_r2 = r2_score(grouped["solved_trials"], grouped["linear_fit"])

# Piecewise linear regression. pwlf locates the breakpoints with a stochastic
# global optimiser, so an explicit seed is passed to make the breakpoints
# (and R²) reproducible across re-runs.
PWLF_SEED = 42
piecewise_model = pwlf.PiecewiseLinFit(
    grouped["time"], grouped["solved_trials"], seed=PWLF_SEED
)

# Define the number of line segments
num_segments = 4  # Adjust this number based on your data

# Fit the model; `breaks` holds the fitted breakpoint locations on the time axis
breaks = piecewise_model.fit(num_segments)

# Predict the piecewise linear fit
grouped["piecewise_fit"] = piecewise_model.predict(grouped["time"])
piecewise_r2 = r2_score(grouped["solved_trials"], grouped["piecewise_fit"])

# Plot the data with confidence intervals
clean_avgplot = grouped.hvplot(
    x="time",
    y="solved_trials",
    kind="line",
    title="Average Solved Trials Over Time",
    xlabel="Time",
    ylabel="Average Solved Trials",
    line_width=3,
    height=400,
    width=800,
)

# Shaded band between the lower and upper CI limits
ci_area = hv.Area(
    (grouped["time"], grouped["ci_lower"], grouped["ci_upper"]),
    vdims=["ci_lower", "ci_upper"],
).opts(alpha=0.1, color="blue")

# Create the linear fit plot
linear_fit_plot = grouped.hvplot(
    x="time",
    y="linear_fit",
    kind="line",
    line_width=3,
    color="red",
    line_dash="dashed",
    label=f"Linear Fit (R²={linear_r2:.2f})",
)

# Create the piecewise linear fit plot
piecewise_fit_plot = grouped.hvplot(
    x="time",
    y="piecewise_fit",
    kind="line",
    line_width=3,
    color="green",
    line_dash="dotdash",
    label=f"Piecewise Linear Fit (R²={piecewise_r2:.2f})",
)

# Overlay the confidence interval area plot with the line plot and fits
combined_plot = ci_area * clean_avgplot * linear_fit_plot * piecewise_fit_plot

# Save the combined plot
hv.save(
    combined_plot,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240815_CleanAverageSolvedTrials_Linear.png",
)

In [None]:
# Restyle the fits overlay with the shared slide template (horizontal axes,
# legend shown) and overwrite the PNG written by the previous cell
combined_plot = (
    combined_plot
    .options(**HoloviewsTemplates.hv_slides["plot"])
    .opts(show_legend=True, invert_axes=False)
)

hv.save(
    combined_plot,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240815_CleanAverageSolvedTrials_Linear.png",
)

In [None]:
# Mean curve plus CI band only (no fits), styled with the slide template and saved
Subplot = (
    (clean_avgplot * ci_area)
    .options(**HoloviewsTemplates.hv_slides["plot"])
    .opts(show_legend=True, invert_axes=False)
)

hv.save(
    Subplot,
    "/mnt/upramdya_data/MD/MultiMazeRecorder/Plots/Learning/240815_CleanAverageSolvedTrials_Linear_Subplot.png",
)