In [None]:
from pathlib import Path
import re
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from lumicks import pylake
import seaborn as sns
from skimage.measure import label, regionprops
from scipy.signal import find_peaks
from dataclasses import dataclass

In [None]:
@dataclass
class Oscillation:
    """A single rise-and-return cycle cut out of a force-distance curve.

    ``ReducedMarker.load_curves`` splits each non-flat stretch of a curve at its maximum
    distance: the samples up to and including that maximum go into the ``increasing_*``
    arrays, and the remaining samples go into the ``decreasing_*`` arrays. Forces are in
    pN and distances in um (the units used by the plot labels below).
    """

    # force samples (pN) up to and including the maximum-distance sample
    increasing_force: npt.NDArray[np.float64]
    # distance samples (um) matching ``increasing_force``
    increasing_distance: npt.NDArray[np.float64]
    # force samples (pN) after the maximum-distance sample
    decreasing_force: npt.NDArray[np.float64]
    # distance samples (um) matching ``decreasing_force``
    decreasing_distance: npt.NDArray[np.float64]

    def _plot_against_index(self, first_part, second_part, first_label, second_label, title, y_label):
        """Plot two consecutive series against one running sample index and show the figure."""
        split = len(first_part)
        index = np.arange(split + len(second_part))
        _, axis = plt.subplots()
        axis.plot(index[:split], first_part, label=first_label)
        axis.plot(index[split:], second_part, label=second_label)
        axis.set_title(title)
        axis.set_xlabel("Index")
        axis.set_ylabel(y_label)
        axis.legend()
        plt.show()

    # plotting methods
    def plot_distances(self):
        """Show the increasing then decreasing distance samples against their sample index."""
        self._plot_against_index(
            self.increasing_distance,
            self.decreasing_distance,
            "Increasing Distance",
            "Decreasing Distance",
            "Distance vs. Index",
            "Distance (um)",
        )

    def plot_forces(self):
        """Show the increasing then decreasing force samples against their sample index."""
        self._plot_against_index(
            self.increasing_force,
            self.decreasing_force,
            "Increasing Force",
            "Decreasing Force",
            "Force vs. Index",
            "Force (pN)",
        )

    def plot_force_distance_curve(self):
        """Show force against distance for both halves of the oscillation."""
        _, axis = plt.subplots()
        halves = (
            (self.increasing_distance, self.increasing_force, "Increasing Force"),
            (self.decreasing_distance, self.decreasing_force, "Decreasing Force"),
        )
        for distances, forces, curve_label in halves:
            axis.plot(distances, forces, label=curve_label)
        axis.set_title("Force-Distance Curve")
        axis.set_xlabel("Distance (um)")
        axis.set_ylabel("Force (pN)")
        axis.legend()
        plt.show()


@dataclass
class FDCurve:
    filename: str
    curve_id: int
    all_forces: npt.NDArray[np.float64]
    all_distances: npt.NDArray[np.float64]
    oscillations: list[Oscillation]


class ReducedMarker:
    def __init__(self, filename: str | Path, curve_id: int, verbose: bool = False, plotting: bool = False):
        self.filename = filename
        self.curve_id = curve_id
        self.file = pylake.File(filename)
        self.metadata = self.get_file_metadata(filename)
        self.telereps = self.metadata["telereps"]
        self.protein_name = self.metadata["protein_name"]

    @staticmethod
    def get_file_metadata(filename: str) -> dict[str, str]:
        """
        Obtain file metadata from the filename.

        Parameters
        ----------
        filename : str
            The name of the file to extract the metadata from.

        Returns
        -------
        dict[str, str]
            A dictionary containing the metadata extracted from the filename.
        """
        metadata = {}
        # grab tel_reps
        tel_reps = re.search(r"Tel(\d+)", filename)
        if tel_reps:
            tel_reps = int(tel_reps.group(1))
        else:
            raise ValueError(f"Could not find telereps regex: Tel(\\d+) in file name {filename}")
        metadata["telereps"] = tel_reps
        # grab protein name, assumed to be before the string "Marker X"
        protein_name = re.search(r" (\w+)(?= Marker \d+)", filename)
        if protein_name:
            protein_name = protein_name.group(1)
        else:
            raise ValueError(f"Could not find protein name in file name {filename}")
        metadata["protein_name"] = protein_name

        return metadata

    @staticmethod
    def load_curves(
        filename: str, fdcurves: dict[str, pylake.FdCurve], verbose: bool = False, plotting: bool = False
    ) -> dict[str, FDCurve]:
        """
        Load the force-distance curves from the file.

        Parameters
        ----------
        fdcurves : dict[str, pylake.FdCurve]
            A dictionary of force-distance curves.
        verbose : bool, optional
            If True, print additional information about the curves being loaded. Default is False.
        plotting : bool, optional
            If True, plot the intermediate plots. Default is False.

        Returns
        -------


        """
        fd_curves = {}
        for curve_id, curve_data in fdcurves.items():
            if verbose:
                print(f"Loading curve {curve_id} with {len(curve_data.d.data)} data points.")
            force_data = curve_data.f.data
            distance_data = curve_data.d.data

            if plotting:
                plt.plot(force_data, distance_data, label=f"Curve {curve_id}")
                plt.title(f"Force-Distance Curve {curve_id}")
                plt.xlabel("Force (pN)")
                plt.ylabel("Distance (um)")
                plt.legend()
                plt.show()

            # determine the starting distance to be the first peak in frequency of the distance data
            bin_size_um = 0.1
            bin_edges = np.arange(np.min(distance_data), np.max(distance_data) + bin_size_um, bin_size_um)
            hist, _ = np.histogram(distance_data, bins=bin_edges)

            # find the largest peak in the histogram
            peak_index = np.argmax(hist)
            # get the midpoint of the bin
            base_distance = (bin_edges[peak_index] + bin_edges[peak_index + 1]) / 2
            if verbose:
                print(f"Base distance: {base_distance} um")

            # check that the peak is strong, as in that the peak bin contains a lot more counts than the other bins
            # critera: peak should be at least 2x the next highest bin
            next_highest_bin = np.partition(hist, -2)[-2]
            peak_strength = hist[peak_index] / next_highest_bin if next_highest_bin > 0 else 0
            if verbose:
                print(f"Peak strength: {peak_strength:.2f} (peak is strength * next highest bin)")
            peak_strength_threshold = 2.0
            if peak_strength < peak_strength_threshold:
                print(
                    f"Peak at {base_distance} um is not strong enough (<{peak_strength_threshold})."
                    f" skipping curve {curve_id}"
                )
                continue

            flat_distance_um = base_distance
            flat_distance_tolerance_um = 0.1

            flat_regions_bool = np.abs(distance_data - flat_distance_um) < flat_distance_tolerance_um
            flat_regions: list[tuple[int, int]] = []
            current_flat_region_start = None
            for index, is_flat in enumerate(flat_regions_bool):
                if is_flat and current_flat_region_start is None:
                    current_flat_region_start = index
                elif not is_flat and current_flat_region_start is not None:
                    flat_regions.append((current_flat_region_start, index - 1))
                    current_flat_region_start = None
            if current_flat_region_start is not None:
                flat_regions.append((current_flat_region_start, len(flat_regions_bool) - 1))

            if verbose:
                print(f"Flat regions: {flat_regions}")

            # eliminate non-flat regions at the start and end of the curve
            if flat_regions:
                if flat_regions[0][0] > 0:
                    # cut the array to start at the first flat region
                    distance_data_trimmed = distance_data[flat_regions[0][0] :]
                    force_data_trimmed = force_data[flat_regions[0][0] :]
                    flat_regions_bool_trimmed = flat_regions_bool[flat_regions[0][0] :]
                else:
                    distance_data_trimmed = distance_data
                    force_data_trimmed = force_data
                    flat_regions_bool_trimmed = flat_regions_bool
                if flat_regions[-1][1] < len(distance_data_trimmed) - 1:
                    # cut the array to end at the last flat region
                    distance_data_trimmed = distance_data_trimmed[: flat_regions[-1][1] + 1]
                    force_data_trimmed = force_data_trimmed[: flat_regions[-1][1] + 1]
                    flat_regions_bool_trimmed = flat_regions_bool_trimmed[: flat_regions[-1][1] + 1]
            else:
                distance_data_trimmed = distance_data
                force_data_trimmed = force_data
                flat_regions_bool_trimmed = flat_regions_bool

            if plotting:
                plt.plot(distance_data, label="Distance Data")
                plt.plot(distance_data_trimmed, label="Trimmed Distance Data", linestyle="--")
                plt.title(f"Distance data retrieved via file.fdcurves[{curve_id}].d.data")
                plt.xlabel("Index")
                plt.ylabel("Distance (um)")
                plt.legend()
                plt.show()

            # get the non-flat regions
            non_flat_regions_bool_trimmed = ~flat_regions_bool_trimmed
            # label the non-flat regions
            labelled_non_flat_regions = label(non_flat_regions_bool_trimmed)
            # for each non-flat region, determine the maximum distance value and set the indexes to the left of it
            # as the increasing segment and indexes to the right as the decreasing segment
            oscillations = []
            for label_index in range(1, labelled_non_flat_regions.max() + 1):
                # get the indexes of the current non-flat region
                non_flat_region_indexes = np.argwhere(labelled_non_flat_regions == label_index)
                non_flat_region_start = non_flat_region_indexes[0][0]
                non_flat_region_end = non_flat_region_indexes[-1][0]
                non_flat_region_distances = distance_data_trimmed[non_flat_region_start : non_flat_region_end + 1]
                non_flat_region_local_maximum_distance_index = np.argmax(non_flat_region_distances)
                non_flat_region_global_maximum_distance_index = (
                    non_flat_region_start + non_flat_region_local_maximum_distance_index
                )
                # set the increasing segment to the left of the maximum distance index, with the second index being
                # exclusive
                increasing_segment_start = non_flat_region_start
                # keep the largest distance value in the increasing segment
                increasing_segment_end = non_flat_region_global_maximum_distance_index + 1
                # set the decreasing segment to the right of the maximum distance index, with the second index
                # being exclusive
                decreasing_segment_start = non_flat_region_global_maximum_distance_index + 1
                decreasing_segment_end = non_flat_region_end + 1
                # get the force data for the increasing and decreasing segments
                increasing_force = force_data_trimmed[increasing_segment_start:increasing_segment_end]
                increasing_distance = distance_data_trimmed[increasing_segment_start:increasing_segment_end]
                decreasing_force = force_data_trimmed[decreasing_segment_start:decreasing_segment_end]
                decreasing_distance = distance_data_trimmed[decreasing_segment_start:decreasing_segment_end]
                # create an Oscillation object
                oscillation = Oscillation(
                    increasing_force=increasing_force,
                    increasing_distance=increasing_distance,
                    decreasing_force=decreasing_force,
                    decreasing_distance=decreasing_distance,
                )
                oscillations.append(oscillation)
            if verbose:
                print(f"Found {len(oscillations)} oscillations in curve {curve_id}.")
            # create a FDCurve object
            fd_curve = FDCurve(
                filename=filename,
                curve_id=curve_id,
                all_forces=force_data_trimmed,
                all_distances=distance_data_trimmed,
                oscillations=oscillations,
            )
            fd_curves[curve_id] = fd_curve

            if verbose:
                print(f"Loaded curve {curve_id} from {filename}.")

        return fd_curves

In [None]:
# Directories scanned for marker ``.h5`` files.
data_dirs = [Path("/Users/sylvi/optical_data/loading_markers/data")]

for data_dir in data_dirs:
    # Raise instead of `assert` so the check is not stripped under `python -O`.
    if not data_dir.exists():
        raise FileNotFoundError(f"Data directory {data_dir} does not exist.")

output_folder = Path("/Users/sylvi/optical_data/loading_markers/processed/")

# How many marker files to analyse. The earlier version of this cell stopped after the
# first glob entry (even when that entry was not a marker file); set to None to analyse
# every marker file.
max_marker_files: int | None = 1
# Show the per-oscillation distance/force plots and the eWLC fit plot.
plot_intermediate = False

# find_peaks settings for the increasing force segment.
peak_height = (0.5, 30)
peak_prominence = 0.8

# file name -> {"metadata": ..., "curves": ...}
markers = {}

# Sorted so that limiting with `max_marker_files` selects a reproducible file.
marker_files = [
    file
    for data_dir in data_dirs
    for file in sorted(data_dir.glob("*.h5"))
    if file.is_file() and "Marker" in file.name
]
if max_marker_files is not None:
    marker_files = marker_files[:max_marker_files]

for file in marker_files:
    print(f"Loading file {file}")
    marker_data = pylake.File(file)

    # telereps and protein name parsed from the file name (raises ValueError if absent)
    metadata = ReducedMarker.get_file_metadata(file.name)

    print(f"curve ids: {marker_data.fdcurves.keys()}")

    # Reuse the shared curve-splitting implementation instead of a duplicated copy.
    curves = ReducedMarker.load_curves(str(file), marker_data.fdcurves, verbose=True)
    markers[file.name] = {"metadata": metadata, "curves": curves}

    for curve_id, fd_curve in curves.items():
        for oscillation_index, oscillation in enumerate(fd_curve.oscillations):
            increasing_force = oscillation.increasing_force
            increasing_distance = oscillation.increasing_distance
            decreasing_force = oscillation.decreasing_force
            decreasing_distance = oscillation.decreasing_distance

            if plot_intermediate:
                oscillation.plot_distances()
                oscillation.plot_forces()

            # fit a ewlc model to the decreasing segment (skipped when it has no samples,
            # which happens when the maximum distance is the last non-flat sample)
            if len(decreasing_force) == 0:
                print(
                    f"Decreasing segment of oscillation {oscillation_index} for curve {curve_id} is empty,"
                    " skipping ewlc fit"
                )
            else:
                model = pylake.ewlc_odijk_force(name="ewlc_return_fit") + pylake.force_offset(
                    name="ewlc_return_fit"
                )
                fit = pylake.FdFit(model)
                fit.add_data(name="return", f=decreasing_force, d=decreasing_distance)
                fit["ewlc_return_fit/Lp"].value = 50
                fit["ewlc_return_fit/Lp"].lower_bound = 39
                fit["ewlc_return_fit/Lp"].upper_bound = 80
                fit["ewlc_return_fit/Lc"].value = 27
                fit["ewlc_return_fit/f_offset"].lower_bound = 0
                fit["ewlc_return_fit/f_offset"].upper_bound = 1

                fit.fit()
                # note that the error for some reason is the same value repeated for each point. taking the mean
                # just in case this changes.
                error = np.mean(fit.sigma)

                if plot_intermediate:
                    fit.plot()
                    plt.title(
                        f"Fit for decreasing segment of oscillation {oscillation_index} for "
                        f"curve {curve_id} after fitting. Error: {error:.2f} pN"
                    )
                    plt.xlabel("Distance (um)")
                    plt.ylabel("Force (pN)")
                    plt.legend()
                    plt.show()

            # find peaks in the increasing force-distance segment
            peaks, _ = find_peaks(increasing_force, height=peak_height, prominence=peak_prominence)
            if len(peaks) == 0:
                print(f"No peaks found in increasing segment of oscillation {oscillation_index} for curve {curve_id}")
                continue
            print(
                f"Found {len(peaks)} peaks in increasing segment of oscillation {oscillation_index} for curve {curve_id}"
            )
            # plot the increasing segment with a dashed vertical line at each peak
            plt.plot(increasing_distance, increasing_force, label="increasing")
            plt.vlines(
                increasing_distance[peaks],
                ymin=np.min(increasing_force),
                ymax=np.max(increasing_force),
                color="grey",
                label="peaks",
                linestyle="--",
            )
            plt.title(f"Peaks in increasing segment of oscillation {oscillation_index} for curve {curve_id}")
            plt.xlabel("Distance (um)")
            plt.ylabel("Force (pN)")
            plt.legend()
            plt.show()