In [None]:
from pathlib import Path
import re
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from lumicks import pylake
from lumicks.pylake.file import FdCurve
import seaborn as sns
from skimage.measure import label
from scipy.signal import find_peaks
from dataclasses import dataclass
import pandas as pd

In [None]:
@dataclass
class ForcePeak:
    """A single force peak detected in the increasing segment of an oscillation."""

    # Distance value at the peak (plots in this file label distance axes in um).
    distance: float
    # Force value at the peak (plots in this file label force axes in pN).
    force: float
    # Position of the peak within the increasing-segment arrays it was detected in.
    index: int


@dataclass
class Oscillation:
    """
    One stretch/relax cycle of a force-distance curve.

    Attributes
    ----------
    increasing_force, increasing_distance : npt.NDArray[np.float64]
        Force and distance samples up to and including the maximum distance of the cycle.
    decreasing_force, decreasing_distance : npt.NDArray[np.float64]
        Force and distance samples after the maximum distance of the cycle.
    force_peaks : list[ForcePeak] | None
        Peaks detected in the increasing segment, or None if peak finding has not been run yet.
    num_peaks : int | None
        Optional peak count. Nothing in this class sets it.
    """

    increasing_force: npt.NDArray[np.float64]
    increasing_distance: npt.NDArray[np.float64]
    decreasing_force: npt.NDArray[np.float64]
    decreasing_distance: npt.NDArray[np.float64]
    force_peaks: list[ForcePeak] | None = None
    num_peaks: int | None = None

    def plot(self, title: str = "", plot_peaks=True):
        """
        Plot the increasing and decreasing segments, optionally marking detected peaks.

        Parameters
        ----------
        title : str, optional
            Title of the figure. When empty, "Force-Distance Curve" is used.
        plot_peaks : bool, optional
            If True, draw a dashed vertical line at the distance of each detected peak.
        """
        plt.figure(figsize=(10, 5))
        plt.plot(self.increasing_distance, self.increasing_force, label="Increasing Segment")
        plt.plot(self.decreasing_distance, self.decreasing_force, label="Decreasing Segment")
        if plot_peaks:
            # force_peaks stays None until peak finding has run; treat that the same as "no peaks"
            # instead of failing on len(None).
            peaks = self.force_peaks if self.force_peaks is not None else []
            if len(peaks) > 0:
                for peak_number, peak in enumerate(peaks):
                    plt.axvline(
                        x=peak.distance,
                        color="grey",
                        linestyle="--",
                        # only the first line gets a label so the legend shows a single entry
                        label="Detected Peak" if peak_number == 0 else "",
                    )
            else:
                print("No peaks to plot.")
        plt.xlabel("Distance (um)")
        plt.ylabel("Force (pN)")
        plt.title(title if title else "Force-Distance Curve")
        plt.legend()
        plt.show()

    def find_peaks(
        self,
        peak_height: tuple[float, float] | None = None,
        peak_prominence: float | None = None,
        verbose: bool | None = None,
        plotting: bool | None = None,
        oscillation_index: int | None = None,
        curve_id: str | None = None,
    ):
        """
        Find the peaks of just this oscillation.

        Every argument left as None is omitted from the call, so the defaults of
        ``find_peaks_individual_oscillation`` apply. The peaks are returned, not stored on the instance.

        Returns
        -------
        list[ForcePeak]
            Whatever ``find_peaks_individual_oscillation`` returns for this oscillation.
        """
        overrides = {
            "peak_height": peak_height,
            "peak_prominence": peak_prominence,
            "verbose": verbose,
            "plotting": plotting,
            "oscillation_index": oscillation_index,
            "curve_id": curve_id,
        }
        # allow overrides but keep the original defaults in the sub function
        kwargs = {name: value for name, value in overrides.items() if value is not None}
        return find_peaks_individual_oscillation(oscillation=self, **kwargs)


@dataclass
class ReducedFDCurve:
    """Trimmed force/distance data of one curve together with the oscillations found in it."""

    # Name of the file the curve was loaded from.
    filename: str
    # Key of the curve in the source fdcurves mapping; load_curves annotates those keys as str.
    curve_id: str
    # Force samples after trimming to the span between the first and last flat region.
    all_forces: npt.NDArray[np.float64]
    # Distance samples after the same trimming.
    all_distances: npt.NDArray[np.float64]
    # Oscillations extracted from the non-flat regions of the trimmed data.
    oscillations: list[Oscillation]
    # rupture_force_stats skips curves for which this is False.
    include_in_processing: bool = True


class ReducedMarker:
    def __init__(self, file_path: Path, verbose: bool = False, plotting: bool = False):
        """
        Open a marker file, parse the metadata encoded in its name and load its curves.

        Parameters
        ----------
        file_path : Path
            Path of the marker file to open with pylake.
        verbose : bool, optional
            Forwarded to ``load_curves``. Default is False.
        plotting : bool, optional
            Forwarded to ``load_curves``. Default is False.
        """
        self.file_path = file_path
        self.file_name = file_path.name
        self.file = pylake.File(self.file_path)

        # metadata (telomere repeats, protein name) is parsed from the file name
        parsed_metadata = self.get_file_metadata(self.file_name)
        self.metadata = parsed_metadata
        self.telereps = parsed_metadata["telereps"]
        self.protein_name = parsed_metadata["protein_name"]

        # markers are included in downstream processing unless switched off afterwards
        self.include_in_processing = True
        self.fd_curves = self.load_curves(
            filename=self.file_name,
            fdcurves=self.file.fdcurves,
            verbose=verbose,
            plotting=plotting,
        )

    @staticmethod
    def get_file_metadata(filename: str) -> dict[str, str]:
        """
        Obtain file metadata from the filename.

        Parameters
        ----------
        filename : str
            The name of the file to extract the metadata from.

        Returns
        -------
        dict[str, str]
            A dictionary containing the metadata extracted from the filename.
        """
        metadata = {}
        # grab tel_reps
        tel_reps = re.search(r"Tel(\d+)", filename)
        if tel_reps:
            tel_reps = int(tel_reps.group(1))
        else:
            raise ValueError(f"Could not find telereps regex: Tel(\\d+) in file name {filename}")
        metadata["telereps"] = tel_reps
        # grab protein name, assumed to be before the string "Marker X"
        protein_name = re.search(r" (\w+)(?= Marker \d+)", filename)
        if protein_name:
            protein_name = protein_name.group(1)
        else:
            raise ValueError(f"Could not find protein name in file name {filename}")
        metadata["protein_name"] = protein_name

        return metadata

    @staticmethod
    def load_curves(
        filename: str, fdcurves: dict[str, FdCurve], verbose: bool = False, plotting: bool = False
    ) -> dict[str, ReducedFDCurve]:
        """
        Load the force-distance curves from the file.

        Parameters
        ----------
        fdcurves : dict[str, pylake.FdCurve]
            A dictionary of force-distance curves.
        verbose : bool, optional
            If True, print additional information about the curves being loaded. Default is False.
        plotting : bool, optional
            If True, plot the intermediate plots. Default is False.

        Returns
        -------
        dict[str, ReducedFDCurve]
            A dictionary of reduced force-distance curves, where the keys are the curve IDs and the values are
            instances of ReducedFDCurve containing the force and distance data, as well as the oscillations found in
            the curves.
        """
        fd_curves = {}
        for curve_id, curve_data in fdcurves.items():
            if verbose:
                print(f"Loading curve {curve_id} with {len(curve_data.d.data)} data points.")
            force_data = curve_data.f.data
            distance_data = curve_data.d.data

            if plotting:
                plt.plot(distance_data, force_data, label=f"Curve {curve_id}")
                plt.title(f"Force-Distance Curve {curve_id}")
                plt.ylabel("Force (pN)")
                plt.xlabel("Distance (um)")
                plt.legend()
                plt.show()

            # determine the starting distance to be the first peak in frequency of the distance data
            bin_size_um = 0.1
            bin_edges = np.arange(np.min(distance_data), np.max(distance_data) + bin_size_um, bin_size_um)
            hist, _ = np.histogram(distance_data, bins=bin_edges)

            if False:
                # plot the histogram
                plt.hist(distance_data, bins=bin_edges, alpha=0.5, label=f"Curve {curve_id}")
                plt.title(f"Distance Histogram {curve_id}")
                plt.ylabel("Counts")
                plt.xlabel("Distance (um)")
                plt.legend()
                plt.show()

            # find the largest peak in the histogram
            peak_index = np.argmax(hist)
            # get the midpoint of the bin
            base_distance = (bin_edges[peak_index] + bin_edges[peak_index + 1]) / 2
            if verbose:
                print(f"Base distance: {base_distance} um")

            # check that the peak is strong, as in that the peak bin contains a lot more counts than the other bins
            # critera: peak should be at least 2x the next highest bin
            next_highest_bin = np.partition(hist, -2)[-2]
            peak_strength = hist[peak_index] / next_highest_bin if next_highest_bin > 0 else 0
            if verbose:
                print(f"Peak strength: {peak_strength:.2f} (peak is strength * next highest bin)")
            peak_strength_threshold = 2.0
            if peak_strength < peak_strength_threshold:
                print(
                    f"Peak at {base_distance} um is not strong enough (<{peak_strength_threshold})."
                    f" skipping curve {curve_id}"
                )
                continue

            flat_distance_um = base_distance
            flat_distance_tolerance_um = 0.1

            flat_regions_bool = np.abs(distance_data - flat_distance_um) < flat_distance_tolerance_um
            flat_regions: list[tuple[int, int]] = []
            current_flat_region_start = None
            for index, is_flat in enumerate(flat_regions_bool):
                if is_flat and current_flat_region_start is None:
                    current_flat_region_start = index
                elif not is_flat and current_flat_region_start is not None:
                    flat_regions.append((current_flat_region_start, index - 1))
                    current_flat_region_start = None
            if current_flat_region_start is not None:
                flat_regions.append((current_flat_region_start, len(flat_regions_bool) - 1))

            if verbose:
                print(f"Flat regions: {flat_regions}")

            # eliminate non-flat regions at the start and end of the curve
            if flat_regions:
                if flat_regions[0][0] > 0:
                    # cut the array to start at the first flat region
                    distance_data_trimmed = distance_data[flat_regions[0][0] :]
                    force_data_trimmed = force_data[flat_regions[0][0] :]
                    flat_regions_bool_trimmed = flat_regions_bool[flat_regions[0][0] :]
                else:
                    distance_data_trimmed = distance_data
                    force_data_trimmed = force_data
                    flat_regions_bool_trimmed = flat_regions_bool
                if flat_regions[-1][1] < len(distance_data_trimmed) - 1:
                    # cut the array to end at the last flat region
                    distance_data_trimmed = distance_data_trimmed[: flat_regions[-1][1] + 1]
                    force_data_trimmed = force_data_trimmed[: flat_regions[-1][1] + 1]
                    flat_regions_bool_trimmed = flat_regions_bool_trimmed[: flat_regions[-1][1] + 1]
            else:
                distance_data_trimmed = distance_data
                force_data_trimmed = force_data
                flat_regions_bool_trimmed = flat_regions_bool

            if plotting:
                plt.plot(distance_data, label="Distance Data")
                plt.plot(distance_data_trimmed, label="Trimmed Distance Data", linestyle="--")
                plt.title(f"Distance data retrieved via file.fdcurves[{curve_id}].d.data")
                plt.xlabel("Index")
                plt.ylabel("Distance (um)")
                plt.legend()
                plt.show()

            # get the non-flat regions
            non_flat_regions_bool_trimmed = ~flat_regions_bool_trimmed
            # label the non-flat regions
            labelled_non_flat_regions = label(non_flat_regions_bool_trimmed)
            # for each non-flat region, determine the maximum distance value and set the indexes to the left of it
            # as the increasing segment and indexes to the right as the decreasing segment
            oscillations: list[Oscillation] = []
            for label_index in range(1, labelled_non_flat_regions.max() + 1):
                # get the indexes of the current non-flat region
                non_flat_region_indexes = np.argwhere(labelled_non_flat_regions == label_index)
                non_flat_region_start = non_flat_region_indexes[0][0]
                non_flat_region_end = non_flat_region_indexes[-1][0]
                non_flat_region_distances = distance_data_trimmed[non_flat_region_start : non_flat_region_end + 1]
                non_flat_region_local_maximum_distance_index = np.argmax(non_flat_region_distances)
                non_flat_region_global_maximum_distance_index = (
                    non_flat_region_start + non_flat_region_local_maximum_distance_index
                )
                # set the increasing segment to the left of the maximum distance index, with the second index being
                # exclusive
                increasing_segment_start = non_flat_region_start
                # keep the largest distance value in the increasing segment
                increasing_segment_end = non_flat_region_global_maximum_distance_index + 1
                # set the decreasing segment to the right of the maximum distance index, with the second index
                # being exclusive
                decreasing_segment_start = non_flat_region_global_maximum_distance_index + 1
                decreasing_segment_end = non_flat_region_end + 1
                # get the force data for the increasing and decreasing segments
                increasing_force = force_data_trimmed[increasing_segment_start:increasing_segment_end]
                increasing_distance = distance_data_trimmed[increasing_segment_start:increasing_segment_end]
                decreasing_force = force_data_trimmed[decreasing_segment_start:decreasing_segment_end]
                decreasing_distance = distance_data_trimmed[decreasing_segment_start:decreasing_segment_end]

                # create an Oscillation object
                oscillation = Oscillation(
                    increasing_force=increasing_force,
                    increasing_distance=increasing_distance,
                    decreasing_force=decreasing_force,
                    decreasing_distance=decreasing_distance,
                )
                oscillations.append(oscillation)

            # Check the oscillations by fitting an ewlc model to the return curve

            for oscillation_index, oscillation in enumerate(oscillations):
                increasing_force = oscillation.increasing_force
                increasing_distance = oscillation.increasing_distance
                decreasing_force = oscillation.decreasing_force
                decreasing_distance = oscillation.decreasing_distance

                if plotting:
                    xs = np.arange(len(increasing_distance) + len(decreasing_distance))
                    # plot the distance data
                    plt.plot(xs[: len(increasing_distance)], increasing_distance, label="increasing")
                    plt.plot(xs[len(increasing_distance) :], decreasing_distance, label="decreasing")
                    plt.title(f"Oscillation {oscillation_index} distance segments for curve {curve_id}")
                    plt.xlabel("Index")
                    plt.ylabel("Distance (um)")
                    plt.legend()
                    plt.show()
                    # plot the force data
                    plt.plot(xs[: len(increasing_force)], increasing_force, label="increasing")
                    plt.plot(xs[len(increasing_force) :], decreasing_force, label="decreasing")
                    plt.title(f"Oscillation {oscillation_index} force segments for curve {curve_id}")
                    plt.xlabel("Index")
                    plt.ylabel("Force (pN)")
                    plt.legend()
                    plt.show()

                # fit a ewlc model to the decreasing segment
                model = pylake.ewlc_odijk_force(name="ewlc_return_fit") + pylake.force_offset(name="ewlc_return_fit")
                fit = pylake.FdFit(model)
                fit.add_data(name="return", f=decreasing_force, d=decreasing_distance)
                fit["ewlc_return_fit/Lp"].value = 50
                fit["ewlc_return_fit/Lp"].lower_bound = 39
                fit["ewlc_return_fit/Lp"].upper_bound = 80
                fit["ewlc_return_fit/Lc"].value = 27
                fit["ewlc_return_fit/f_offset"].lower_bound = 0
                fit["ewlc_return_fit/f_offset"].upper_bound = 1

                fit.fit()
                # note that the error for some reason is the same value repeated for each point. taking the mean
                # just in case this changes.
                error = np.mean(fit.sigma)

                if plotting:
                    fit.plot()
                    plt.title(
                        f"Fit for decreasing segment of oscillation {oscillation_index} for "
                        f"curve {curve_id} after fitting. Error: {error:.2f} pN"
                    )
                    plt.xlabel("Distance (um)")
                    plt.ylabel("Force (pN)")
                    plt.legend()
                    plt.show()

            if verbose:
                print(f"Found {len(oscillations)} oscillations in curve {curve_id}.")
            # create a FDCurve object
            fd_curve = ReducedFDCurve(
                filename=filename,
                curve_id=curve_id,
                all_forces=force_data_trimmed,
                all_distances=distance_data_trimmed,
                oscillations=oscillations,
            )
            fd_curves[curve_id] = fd_curve

            if verbose:
                print(f"Loaded curve {curve_id} from {filename}.")

        return fd_curves

    def find_peaks_marker(
        self,
        peak_height: tuple[float, float] = (0.5, 30),
        prominence=0.8,
        verbose: bool = False,
        plotting: bool = False,
    ):
        """Find peaks in the oscillations"""
        # iterate over the oscillations and find peaks, adding them to the oscillation objects
        for curve_id, fd_curve in self.fd_curves.items():
            if verbose:
                print(f"Finding peaks in curve {curve_id} with {len(fd_curve.oscillations)} oscillations.")
            for oscillation_index, oscillation in enumerate(fd_curve.oscillations):
                if verbose:
                    print(f"  Oscillation {oscillation_index}: ", end="")
                force_peaks = find_peaks_individual_oscillation(
                    oscillation=oscillation,
                    peak_height=peak_height,
                    peak_prominence=prominence,
                    verbose=verbose,
                    plotting=plotting,
                    oscillation_index=oscillation_index,
                    curve_id=curve_id,
                )

                if len(force_peaks) == 0:
                    if verbose:
                        print(f"No peaks found in oscillation {oscillation_index} for curve {curve_id}.")

                if verbose:
                    # print force peak distances formatted to .2f
                    force_peak_distances = [f"{peak.distance:.2f}" for peak in force_peaks]
                    print(
                        f"Found {len(force_peaks)} peaks in oscillation {oscillation_index} for curve {curve_id}: "
                        f"{', '.join(force_peak_distances)}"
                    )
                # add the force peaks to the oscillation object
                oscillation.force_peaks = force_peaks


def find_peaks_individual_oscillation(
    oscillation: Oscillation,
    peak_height: tuple[float, float] = (0.5, 30),
    peak_prominence=0.8,
    verbose: bool = False,
    plotting: bool = False,
    oscillation_index: int | None = None,
    curve_id: str | None = None,
) -> list[ForcePeak]:
    """
    Find peaks in an individual oscillation's increasing distance-force curve.

    Note that I think we are making a huge assumption that the distance changes at a constant rate, since we don't
    fit the peaks on the 2d data, but rather only on the force data, ignoring the distance components.

    Parameters
    ----------
    oscillation : Oscillation
        The oscillation object containing force and distance data.
    peak_height : tuple[float, float], optional
        The minimum and maximum height of the peaks to be detected.
    peak_prominence : float, optional
        The prominence of the peaks to be detected.
    verbose : bool, optional
        If True, print additional information about the peaks found.
    plotting : bool, optional
        If True, plot the increasing segment with the detected peaks.
    oscillation_index : int | None, optional
        Index of the oscillation within its curve. Only used in printed messages and plot titles.
    curve_id : str | None, optional
        Identifier of the curve the oscillation belongs to. Only used in printed messages and plot titles.

    Returns
    -------
    list[ForcePeak]
        A list of ForcePeak objects, empty if no peak was found.
    """
    # find peaks in the increasing force-distance segment
    increasing_force = oscillation.increasing_force
    increasing_distance = oscillation.increasing_distance
    peak_indexes, _ = find_peaks(increasing_force, height=peak_height, prominence=peak_prominence)
    if verbose:
        print(
            f"Found {len(peak_indexes)} peaks in increasing segment of oscillation {oscillation_index} for curve {curve_id}."
        )

    # Check if the peaks meet the criteria for having at least a certain ratio of the data preceding it being
    # increasing. This is to hopefully avoid noise spikes being detected as peaks.
    vetted_peak_indexes = []
    for peak_number, peak_index in enumerate(peak_indexes):
        if peak_number == 0:
            # always keep the first
            vetted_peak_indexes.append(peak_index)
            continue
        # the comparison is against the previously detected peak, whether or not that one survived vetting
        previous_peak_index = peak_indexes[peak_number - 1]
        if verbose:
            print(f"vetting peak {peak_number} at index {peak_index}")
        if should_vet_peak_increasing_force_criteria(
            peak_index=peak_index,
            previous_peak_index=previous_peak_index,
            oscillation_force_data=increasing_force,
            oscillation_distance_data=increasing_distance,
        ):
            print(
                f"Peak {peak_number} at index {peak_index} ({increasing_distance[peak_index]:.2f} um) in oscillation {oscillation_index} for curve {curve_id} "
                f"did not meet the increasing force criteria and was removed."
            )
        else:
            vetted_peak_indexes.append(peak_index)

    if len(vetted_peak_indexes) == 0:
        # keep the documented return type: an empty list rather than an empty array
        return []
    if plotting:
        # plot the increasing segment with the peaks
        plt.plot(increasing_distance, increasing_force, label="increasing")
        # vlines for the peaks
        plt.vlines(
            increasing_distance[vetted_peak_indexes],
            ymin=np.min(increasing_force),
            ymax=np.max(increasing_force),
            color="grey",
            label="peaks",
            linestyle="--",
        )
        plt.title(f"Peaks in increasing segment of oscillation {oscillation_index} for curve {curve_id}")
        plt.xlabel("Distance (um)")
        plt.ylabel("Force (pN)")
        plt.legend()
        plt.show()
    return [
        ForcePeak(
            distance=increasing_distance[peak_index],
            force=increasing_force[peak_index],
            index=peak_index,
        )
        for peak_index in vetted_peak_indexes
    ]


def should_vet_peak_increasing_force_criteria(
    peak_index: int,
    previous_peak_index: int,
    oscillation_force_data: npt.NDArray[np.float64],
    oscillation_distance_data: npt.NDArray[np.float64],
    minimum_increasing_decreasing_ratio: float = 0.5,
    verbose: bool = False,
) -> bool:
    """
    Check if a peak meets the criteria for having at least a certain ratio of the data preceding it being increasing.

    This is to hopefully avoid noise spikes being detected as peaks. The lowest force between the previous peak
    and this peak splits the stretch between them in two: the distance covered before the minimum and the
    distance covered after it. The peak fails when the "after" distance divided by the "before" distance is
    below ``minimum_increasing_decreasing_ratio``.

    Parameters
    ----------
    peak_index : int
        The index of the peak to be checked.
    previous_peak_index : int
        The index of the previous peak.
    oscillation_force_data : npt.NDArray[np.float64]
        The force data of the oscillation.
    oscillation_distance_data : npt.NDArray[np.float64]
        The distance data of the oscillation.
    minimum_increasing_decreasing_ratio : float, optional
        The minimum ratio of the distance after the minimum force to the distance before the minimum force.
        Default is 0.5.
    verbose : bool, optional
        If True, print the indexes, distances and ratio used for the decision. Default is False.

    Returns
    -------
    bool
        True if the peak does not meet the criteria and should be deleted, False otherwise. A peak whose
        "before" distance is exactly zero is kept (False), since the ratio is undefined.
    """
    # find the minimum force between the previous peak and this peak
    between_peak_minimum_force_index = (
        np.argmin(oscillation_force_data[previous_peak_index : peak_index + 1]) + previous_peak_index
    )
    if verbose:
        print(
            f"indexes: previous peak {previous_peak_index}, current peak {peak_index}, minimum force {between_peak_minimum_force_index}"
        )
    distance_at_previous_peak = oscillation_distance_data[previous_peak_index]
    distance_at_minimum_force = oscillation_distance_data[between_peak_minimum_force_index]
    distance_at_current_peak = oscillation_distance_data[peak_index]
    distance_before_minimum = distance_at_minimum_force - distance_at_previous_peak
    distance_after_minimum = distance_at_current_peak - distance_at_minimum_force
    if verbose:
        print(
            f"distance at: previous peak: {distance_at_previous_peak}, minimum force: {distance_at_minimum_force}, current peak: {distance_at_current_peak}"
        )
        print(f"Distance before minimum: {distance_before_minimum}")
        print(f"Distance after minimum: {distance_after_minimum}")

    if distance_before_minimum == 0:
        # The ratio is undefined. Dividing anyway gives inf or nan, neither of which is below the threshold,
        # so the peak was effectively kept; return that outcome without the divide-by-zero warning.
        if verbose:
            print("Distance before minimum is zero, ratio is undefined; keeping the peak.")
        return False

    distance_increasing_decreasing_ratio = distance_after_minimum / distance_before_minimum
    if verbose:
        print(f"Distance increasing/decreasing ratio: {distance_increasing_decreasing_ratio:.2f}")
    return bool(distance_increasing_decreasing_ratio < minimum_increasing_decreasing_ratio)

In [None]:
def rupture_force_stats(
    markers: dict[str, ReducedMarker],
    verbose: bool = False,
    plotting: bool = False,
) -> pd.DataFrame:
    """
    Collect the force peaks of all included markers and curves into one table.

    Markers and curves whose ``include_in_processing`` flag is False are skipped, as are oscillations whose
    ``force_peaks`` is still None (peak finding has not been run on them).

    Parameters
    ----------
    markers : dict[str, ReducedMarker]
        Markers to collect the peaks from, keyed by the name stored in the ``marker`` column.
    verbose : bool, optional
        If True, print progress information. Default is False.
    plotting : bool, optional
        If True, show force-vs-distance scatter plots coloured by protein name and by telomere repeats.
        Default is False.

    Returns
    -------
    pd.DataFrame
        One row per force peak with the columns ``marker``, ``curve_id``, ``oscillation_index``, ``distance``,
        ``force``, ``tel_reps`` and ``protein_name``. The columns are present even when no peak was found.
    """
    columns = ["marker", "curve_id", "oscillation_index", "distance", "force", "tel_reps", "protein_name"]
    all_peak_stats = []
    for marker_filename, marker in markers.items():
        if verbose:
            print(f"Calculating rupture force stats for marker {marker.file_name}")
        if not marker.include_in_processing:
            print(f"Ignoring non-included marker: {marker.file_name}.")
            continue
        for curve_id, fd_curve in marker.fd_curves.items():
            if verbose:
                print(f"  Curve ID: {curve_id}")
            if not fd_curve.include_in_processing:
                print(f"  Ignoring non-included curve: {curve_id}.")
                continue
            for oscillation_index, oscillation in enumerate(fd_curve.oscillations):
                if oscillation.force_peaks is None:
                    if verbose:
                        print(f"    Oscillation {oscillation_index}: Hasn't been processed for peaks yet.")
                    continue
                # add the force peaks to the list
                all_peak_stats.extend(
                    {
                        "marker": marker_filename,
                        "curve_id": curve_id,
                        "oscillation_index": oscillation_index,
                        "distance": force_peak.distance,
                        "force": force_peak.force,
                        "tel_reps": marker.telereps,
                        "protein_name": marker.protein_name,
                    }
                    for force_peak in oscillation.force_peaks
                )
    if verbose:
        print(f"Found {len(all_peak_stats)} force peaks in total across all markers.")
    # explicit columns keep the schema stable when no peak was found at all
    df = pd.DataFrame(all_peak_stats, columns=columns)
    if plotting:
        if df.empty:
            # nothing to draw; the scatter plots need at least the distance/force columns filled
            print("No force peaks to plot.")
        else:
            # one scatter plot coloured by protein name, one coloured by tel reps
            for hue_column, legend_title in (("protein_name", "protein"), ("tel_reps", "tel reps")):
                sns.scatterplot(data=df, x="distance", y="force", hue=hue_column)
                plt.title("Force Peaks")
                plt.xlabel("Distance (um)")
                plt.ylabel("Force (pN)")
                plt.legend(title=legend_title)
                plt.show()
    return df

In [None]:
# Locations of the raw marker files and of the processed output.
data_dirs = [Path("/Users/sylvi/optical_data/loading_markers/data")]
for data_dir in data_dirs:
    assert data_dir.exists(), f"Data directory {data_dir} does not exist."
output_folder = Path("/Users/sylvi/optical_data/loading_markers/processed/")

# Load every "*Marker*.h5" file and run peak finding on it straight away.
markers: dict[str, ReducedMarker] = {}
for data_dir in data_dirs:
    for file in data_dir.glob("*.h5"):
        if not file.is_file() or "Marker" not in file.name:
            continue
        print(f"Loading marker from file: {file}")
        marker = ReducedMarker(file_path=file, verbose=False, plotting=False)
        markers[file.name] = marker

        # find peaks
        marker.find_peaks_marker(
            peak_height=(0.5, 30),
            prominence=0.8,
            verbose=False,
            plotting=False,
        )

print("-" * 100)

# print all peaks found
for marker_name, marker in markers.items():
    print(f"Marker: {marker_name}")
    for curve_id, fd_curve in marker.fd_curves.items():
        print(f"  Curve ID: {curve_id}")
        for oscillation_index, oscillation in enumerate(fd_curve.oscillations):
            peak_positions = ", ".join(f"{peak.distance:.2f} um" for peak in oscillation.force_peaks)
            print(
                f"    Oscillation {oscillation_index}: "
                f"{len(oscillation.force_peaks)} peaks found at "
                f"{peak_positions}"
            )

# Calculate and print rupture force stats
print("-" * 100)
rupture_force_df = rupture_force_stats(markers=markers, verbose=False, plotting=True)

In [None]:
# Just analyse one curve
plot_individual_rupture_region_subtractions = False
plot_rupture_regions = False
print(f"markers: {markers.keys()}")
# marker_filename = "20250611-151042  Tel5 1nM Trf2dTRFH Marker 1.h5"
for marker_filename in markers.keys():
    print(f"Analysing marker {marker_filename}")
    marker = markers[marker_filename]
    marker_fdcurves = marker.fd_curves
    print(f"fdcurves: {marker_fdcurves.keys()}")
    for curve_id in marker_fdcurves.keys():
        print(f"curve {curve_id} in marker {marker_filename}")
        marker_fdcurve = marker_fdcurves[curve_id]
        fdcurve_oscillations = marker_fdcurve.oscillations
        print(f"oscillations list: {np.arange(len(fdcurve_oscillations))}")
        for oscillation_index, oscillation in enumerate(fdcurve_oscillations):

            # fit the return curve
            return_curve_model = pylake.ewlc_odijk_force(name="ewlc_return_fit") + pylake.force_offset(
                name="ewlc_return_fit"
            )
            return_curve_fit = pylake.FdFit(return_curve_model)
            return_curve_fit.add_data(name="return", f=oscillation.decreasing_force, d=oscillation.decreasing_distance)
            return_curve_fit["ewlc_return_fit/Lp"].value = 50
            return_curve_fit["ewlc_return_fit/Lp"].lower_bound = 39
            return_curve_fit["ewlc_return_fit/Lp"].upper_bound = 80
            return_curve_fit["ewlc_return_fit/Lc"].value = 27
            return_curve_fit["ewlc_return_fit/f_offset"].lower_bound = 0
            return_curve_fit["ewlc_return_fit/f_offset"].upper_bound = 1

            return_curve_fit.fit()
            # note that the error for some reason is the same value repeated for each point. taking the mean
            # just in case this changes.
            return_curve_fit_error = np.mean(return_curve_fit.sigma)

            # subtract the increasing curve by the fitted model decreasing curve, assuming the two are similar
            modelled_increasing_force = return_curve_model(
                independent=oscillation.increasing_distance, params=return_curve_fit.params
            )

            # find the peaks in the oscillation
            oscillation.plot(title=f"Oscillation {oscillation_index} for curve {curve_id} in marker {marker_filename}")

            if len(oscillation.force_peaks) < 2:
                print(
                    f"Oscillation {oscillation_index} for curve {curve_id} in marker {marker_filename} has less than 2 peaks, skipping."
                )
                continue

            # Grab the rupture region preceding the main curve, ie the data between the penultimate and last force peak
            penultimate_peak_position = oscillation.force_peaks[-2].index
            last_peak_position = oscillation.force_peaks[-1].index
            last_rupture_region_force = oscillation.increasing_force[
                penultimate_peak_position : last_peak_position + 1
            ]
            last_rupture_region_distance = oscillation.increasing_distance[
                penultimate_peak_position : last_peak_position + 1
            ]
            lowest_force_index = np.argmin(last_rupture_region_force)
            last_rupture_region_force_trimmed = last_rupture_region_force[lowest_force_index:]
            last_rupture_region_distance_trimmed = last_rupture_region_distance[lowest_force_index:]

            # Split the increasing segment into rupture regions. Region i covers the samples from
            # force peak i up to and including force peak i + 1, and is divided at its force minimum
            # into a "decrease" part (samples before the minimum) and an "increase" part (from the
            # minimum onwards). Each region is stored as a dict; the model/fit entries are filled in
            # by the fitting stage that follows.
            #
            # Index conventions (all relative to the full increasing segment):
            #   - decrease_region_end_index == increase_region_start_index == position of the minimum
            #     (an exclusive end for the decrease slice),
            #   - increase_region_end_index is the position of peak i + 1 and is inclusive.
            rupture_regions = []
            first_rupture_start_position_index = oscillation.force_peaks[0].index
            if plot_rupture_regions:
                fig, ax = plt.subplots(figsize=(20, 8))
                # plot the force-distance data up to the first peak, i.e. before any rupture region
                ax.plot(
                    oscillation.increasing_distance[: first_rupture_start_position_index + 1],
                    oscillation.increasing_force[: first_rupture_start_position_index + 1],
                    label="non rupture region data",
                    alpha=0.5,
                )

            consecutive_peaks = zip(oscillation.force_peaks, oscillation.force_peaks[1:])
            for i, (region_start_peak, region_end_peak) in enumerate(consecutive_peaks):
                rupture_start_position = region_start_peak.index
                rupture_end_position = region_end_peak.index
                region_slice = slice(rupture_start_position, rupture_end_position + 1)
                rupture_region_force = oscillation.increasing_force[region_slice]
                rupture_region_distance = oscillation.increasing_distance[region_slice]

                # position of the force minimum *within* the region (not within the full segment)
                lowest_force_index = np.argmin(rupture_region_force)
                rupture_region_force_increase_region = rupture_region_force[lowest_force_index:]
                rupture_region_force_decrease_region = rupture_region_force[:lowest_force_index]
                rupture_region_distance_increase_region = rupture_region_distance[lowest_force_index:]
                rupture_region_distance_decrease_region = rupture_region_distance[:lowest_force_index]
                rupture_regions.append(
                    {
                        "decrease_region_distance": rupture_region_distance_decrease_region,
                        "decrease_region_force": rupture_region_force_decrease_region,
                        "increase_region_distance": rupture_region_distance_increase_region,
                        "increase_region_force": rupture_region_force_increase_region,
                        "decrease_region_start_index": rupture_start_position,
                        "decrease_region_end_index": rupture_start_position + lowest_force_index,
                        "increase_region_start_index": rupture_start_position + lowest_force_index,
                        "increase_region_end_index": rupture_end_position,
                        "model": None,
                        "fit_params": None,
                        "fitted_curve": None,
                        "fit_error": None,
                    }
                )

                if plot_rupture_regions:
                    # plot the whole region and, on top of it, its increasing part
                    ax.plot(rupture_region_distance, rupture_region_force, label=f"rupture region {i}")
                    ax.plot(
                        rupture_region_distance_increase_region,
                        rupture_region_force_increase_region,
                        label=f"increasing rupture region {i}",
                    )

                    # Mark the peak that opens this region. Only the first line is labelled so the
                    # legend shows a single "peak" entry instead of one duplicate per region
                    # (same approach as Oscillation.plot).
                    current_peak_distance = region_start_peak.distance
                    ax.axvline(
                        x=current_peak_distance,
                        color="grey",
                        linestyle="--",
                        label="peak" if i == 0 else "",
                    )

            if plot_rupture_regions:
                # axis labels only need to be set once, not on every loop iteration
                ax.set_xlabel("Distance (um)")
                ax.set_ylabel("Force (pN)")
                ax.legend()
                ax.set_title(
                    f"Rupture regions for oscillation {oscillation_index} for curve {curve_id} in marker {marker_filename}"
                )
                plt.show()

            # Fit the rupture regions one at a time, walking backwards from the last region to the
            # first. For each region the previously fitted curve (the return-curve fit for the last
            # region, otherwise the fit of the region that follows it) is subtracted from the region's
            # increasing data, and a fresh eWLC + force-offset model is fitted to the remainder.
            # Results are written back into the region's dict ("model", "fit_params", "fit_error",
            # "fitted_curve", "subtracted_increase_region_force", "work_done_fJ", "peak_force_pN").
            if plot_rupture_regions:
                fig, ax = plt.subplots(figsize=(20, 8))
            for rupture_region_index in range(len(rupture_regions) - 1, -1, -1):
                print(f"Fitting rupture region {rupture_region_index}")
                rupture_region = rupture_regions[rupture_region_index]
                is_last_region = rupture_region_index == len(rupture_regions) - 1

                if plot_rupture_regions:
                    ax.plot(
                        rupture_region["increase_region_distance"],
                        rupture_region["increase_region_force"],
                        label=f"rupture region {rupture_region_index} data",
                        alpha=0.5,
                    )
                if plot_individual_rupture_region_subtractions:
                    # one dedicated figure per region showing the raw decrease / increase data
                    fig_individual, ax_individual = plt.subplots(figsize=(10, 5))
                    ax_individual.plot(
                        rupture_region["decrease_region_distance"],
                        rupture_region["decrease_region_force"],
                        label=f"rupture region {rupture_region_index} decrease data",
                        alpha=0.5,
                    )
                    ax_individual.plot(
                        rupture_region["increase_region_distance"],
                        rupture_region["increase_region_force"],
                        label=f"rupture region {rupture_region_index} data",
                        alpha=0.5,
                    )

                if is_last_region:
                    # the last region is referenced against the return-curve fit, evaluated over the
                    # whole increasing segment so it can be indexed with full-segment indices
                    previous_fit_model = return_curve_model
                    previous_fit_params = return_curve_fit.params
                    previous_fitted_curve = return_curve_model(
                        independent=oscillation.increasing_distance, params=return_curve_fit.params
                    )
                else:
                    # regions are processed backwards, so the "previous" fit is that of region index + 1
                    previous_rupture_region = rupture_regions[rupture_region_index + 1]
                    previous_fit_model = previous_rupture_region["model"]
                    previous_fit_params = previous_rupture_region["fit_params"]
                    previous_fitted_curve = previous_rupture_region["fitted_curve"]

                # Take the previous fit's forces at this region's increasing samples. Every fitted
                # curve starts at sample 0 of the increasing segment, so full-segment indices apply.
                previous_fit_forces_for_this_region = previous_fitted_curve[
                    rupture_region["increase_region_start_index"] : rupture_region["increase_region_end_index"] + 1
                ]
                # subtract the previous fit's curve from the current region's increasing force data
                rupture_region["subtracted_increase_region_force"] = (
                    rupture_region["increase_region_force"] - previous_fit_forces_for_this_region
                )
                if plot_rupture_regions:
                    ax.plot(
                        rupture_region["increase_region_distance"],
                        previous_fit_forces_for_this_region,
                        label=f"rupture region {rupture_region_index} previous fit forces",
                    )
                    ax.plot(
                        rupture_region["increase_region_distance"],
                        rupture_region["subtracted_increase_region_force"],
                        label=f"rupture region {rupture_region_index} subtracted force",
                    )

                # The per-region figure only draws on its own axes, so it must not depend on
                # `plot_rupture_regions`; otherwise the figure created above is left unfinished and
                # never shown when only the individual plots are requested.
                if plot_individual_rupture_region_subtractions:
                    ax_individual.plot(
                        rupture_region["increase_region_distance"],
                        previous_fit_forces_for_this_region,
                        label=f"rupture region {rupture_region_index} previous fit forces",
                    )
                    ax_individual.plot(
                        rupture_region["increase_region_distance"],
                        rupture_region["subtracted_increase_region_force"],
                        label=f"rupture region {rupture_region_index} subtracted force",
                    )
                    ax_individual.set_xlabel("Distance (um)")
                    ax_individual.set_ylabel("Force (pN)")
                    ax_individual.legend()
                    fig_individual.show()

                if plot_rupture_regions:
                    # Mark the peak that closes this region. Only the first line drawn is labelled so
                    # the legend holds a single "peak" entry rather than one duplicate per region.
                    peak_distance = oscillation.force_peaks[rupture_region_index + 1].distance
                    ax.axvline(
                        x=peak_distance,
                        color="grey",
                        linestyle="--",
                        label="peak" if is_last_region else "",
                    )

                # Fit a curve to this region. The model/parameter name "ewlc_return_fit" is reused
                # from the return-curve fit; it determines the parameter keys used below.
                rupture_region_model = pylake.ewlc_odijk_force(name="ewlc_return_fit") + pylake.force_offset(
                    name="ewlc_return_fit"
                )
                rupture_region_fit = pylake.FdFit(rupture_region_model)
                rupture_region_fit.add_data(
                    name="rupture_region",
                    f=rupture_region["subtracted_increase_region_force"],
                    d=rupture_region["increase_region_distance"],
                )
                # same initial values and bounds as used for the return-curve fit
                rupture_region_fit["ewlc_return_fit/Lp"].value = 50
                rupture_region_fit["ewlc_return_fit/Lp"].lower_bound = 39
                rupture_region_fit["ewlc_return_fit/Lp"].upper_bound = 80
                rupture_region_fit["ewlc_return_fit/Lc"].value = 27
                rupture_region_fit["ewlc_return_fit/f_offset"].lower_bound = 0
                rupture_region_fit["ewlc_return_fit/f_offset"].upper_bound = 1
                rupture_region_fit.fit()
                rupture_region_error = np.mean(rupture_region_fit.sigma)
                rupture_region["model"] = rupture_region_model
                rupture_region["fit_params"] = rupture_region_fit.params
                rupture_region["fit_error"] = rupture_region_error

                # Evaluate the fit from the start of the increasing segment up to and including this
                # region's closing peak, so that earlier regions can index it with full-segment indices.
                all_distances_until_region_end = oscillation.increasing_distance[
                    : rupture_region["increase_region_end_index"] + 1
                ]
                fitted_rupture_region_curve = rupture_region_model(
                    independent=all_distances_until_region_end, params=rupture_region_fit.params
                )
                rupture_region["fitted_curve"] = fitted_rupture_region_curve

                # Integrate the fitted curve (from the start of the increasing segment to the end of
                # this region) to estimate the work done to rupture the region.
                work_done = np.trapezoid(y=fitted_rupture_region_curve, x=all_distances_until_region_end)  # in pN*um
                # 1 pN * 1 um = 1e-12 N * 1e-6 m = 1e-18 J = 1 aJ = 1e-3 fJ, so convert before
                # storing under the femtojoule key.
                rupture_region["work_done_fJ"] = work_done * 1e-3  # in fJ
                # calculate the peak force from the fitted curve
                peak_force = np.max(fitted_rupture_region_curve)
                rupture_region["peak_force_pN"] = peak_force  # in pN

            if plot_rupture_regions:
                ax.legend()
                ax.set_title(
                    f"Rupture regions fitted for oscillation {oscillation_index} in curve {curve_id}, marker {marker_filename}"
                )
                plt.show()

            # Summary figure: every region's baseline-subtracted data together with its fitted curve
            # (regions drawn from last to first), plus the trimmed return curve and its fit.
            fig, ax = plt.subplots(figsize=(20, 8))
            for rupture_region_index in reversed(range(len(rupture_regions))):
                rupture_region = rupture_regions[rupture_region_index]
                # the fitted curve spans the increasing segment from its start to this region's end
                region_end = rupture_region["increase_region_end_index"] + 1
                all_distances_until_region_end = oscillation.increasing_distance[:region_end]
                ax.plot(
                    rupture_region["increase_region_distance"],
                    rupture_region["subtracted_increase_region_force"],
                    label=f"rupture region {rupture_region_index} subtracted data",
                    alpha=0.5,
                )
                ax.plot(
                    all_distances_until_region_end,
                    rupture_region["fitted_curve"],
                    label=f"rupture region {rupture_region_index} fitted curve, error: {rupture_region['fit_error']:.2f} pN",
                )

            # Cap the return curve at the highest force seen in any region's increasing data so the
            # figure is not dominated by the high-force end of the return curve.
            return_curve_force_cap = max(region["increase_region_force"].max() for region in rupture_regions)
            # boolean mask selecting the return-curve samples below the cap; note that the return
            # data is in reverse order
            trimmed_return_curve_mask = oscillation.decreasing_force < return_curve_force_cap
            trimmed_return_curve_distance = oscillation.decreasing_distance[trimmed_return_curve_mask]
            trimmed_return_curve_force = oscillation.decreasing_force[trimmed_return_curve_mask]
            ax.plot(
                trimmed_return_curve_distance,
                trimmed_return_curve_force,
                color="black",
                alpha=0.2,
                label="return curve data",
            )
            # the return-curve fit evaluated at the trimmed distances
            modelled_return_curve = return_curve_model(
                independent=trimmed_return_curve_distance, params=return_curve_fit.params
            )
            ax.plot(
                trimmed_return_curve_distance,
                modelled_return_curve,
                color="black",
                alpha=0.5,
                label="return curve fitted",
            )

            ax.set_xlabel("Distance (um)")
            ax.set_ylabel("Force (pN)")
            ax.legend()
            ax.set_title(
                f"Fitted rupture regions for oscillation {oscillation_index} in curve {curve_id} marker {marker_filename}"
            )
            plt.show()

            # Report the work done, peak force and fit error of every rupture region, from the last
            # region to the first.
            for rupture_region_index in reversed(range(len(rupture_regions))):
                rupture_region = rupture_regions[rupture_region_index]
                region_summary = (
                    f"Rupture region {rupture_region_index}: "
                    f"Work done: {rupture_region['work_done_fJ']:.2f} fJ, "
                    f"Peak force: {rupture_region['peak_force_pN']:.2f} pN, "
                    f"Fit error: {rupture_region['fit_error']:.2f} pN"
                )
                print(region_summary)