In [None]:
from pathlib import Path
import re
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from lumicks import pylake
import seaborn as sns

In [None]:
# Directories that are searched for raw marker recordings (*.h5 files).
data_dirs = [Path("/Users/sylvi/optical_data/loading_markers/data")]

# Fail early if an input directory is missing. An explicit exception is used
# instead of `assert` so the check still runs when Python is started with -O.
for data_dir in data_dirs:
    if not data_dir.exists():
        raise FileNotFoundError(f"Data directory {data_dir} does not exist.")

# Destination for processed output (not written to by the code in this cell).
output_folder = Path("/Users/sylvi/optical_data/loading_markers/processed/")

# Container for per-marker results (not populated by the code in this cell).
markers = {}

# For every marker recording: read the metadata encoded in the file name, then
# for each force-distance curve
#   1. find the most frequently visited distance (taken as the "flat" distance),
#   2. locate the index ranges where the distance stays at that value, and
#   3. trim the curve so that it starts and ends on a flat region.
for data_dir in data_dirs:
    for file in data_dir.glob("*.h5"):
        if file.is_file() and "Marker" in file.name:
            print(f"Loading file {file}")
            marker_data = pylake.File(file)
            print(type(marker_data))

            # get metadata from the file name
            tel_reps_match = re.search(r"Tel(\d+)", file.name)
            if tel_reps_match is None:
                raise ValueError(f"Could not find telereps in file name {file.name}")
            tel_reps = int(tel_reps_match.group(1))

            # the protein name is the word directly before " Marker <number>"
            protein_name_match = re.search(r" (\w+)(?= Marker \d+)", file.name)
            if protein_name_match is None:
                raise ValueError(f"Could not find protein name in file name {file.name}")
            protein_name = protein_name_match.group(1)

            # extract the curves
            print(f"curve ids: {marker_data.fdcurves.keys()}")

            for curve_id, curve_data in marker_data.fdcurves.items():

                # print(f"Curve ID: {curve_id}")
                force_data = curve_data.f.data
                # curve_data.plot_scatter()
                # plt.title(f"Force-distance plot from built-in plot_scatter() method")
                # plt.show()
                # print(curve_data.f)
                # plt.plot(force_data)
                # plt.title(f"Force data retrieved via file.fdcurves[{curve_id}].f.data")
                # plt.show()

                distance_data = curve_data.d.data

                # np.min / np.max below cannot handle an empty array.
                if distance_data.size == 0:
                    print(f"Curve {curve_id} contains no distance data, skipping")
                    continue

                # segment types:
                # 1: increasing
                # 0: flat
                # -1: decreasing

                # determine the oscillations in the curve.
                # curves may start with any segment type
                # data has noise so can be hard to determine which segment we are in

                # determine the starting distance to be the most populated bin of a
                # histogram of the distance data
                bin_size_um = 0.1
                bin_edges = np.arange(np.min(distance_data), np.max(distance_data) + bin_size_um, bin_size_um)
                hist, _ = np.histogram(distance_data, bins=bin_edges)

                # print(hist)

                # With fewer than two bins the distance never changes by more than one
                # bin width, so there is no second bin to compare the peak against
                # (np.partition / np.argmax would raise).
                if hist.size < 2:
                    print(f"Distance data spans fewer than two bins, skipping curve {curve_id}")
                    continue

                # find the largest peak in the histogram
                peak_index = np.argmax(hist)
                # get the midpoint of the bin
                peak_distance = (bin_edges[peak_index] + bin_edges[peak_index + 1]) / 2
                print(f"Peak distance: {peak_distance} um")

                # check that the peak is strong, as in that the peak bin contains a lot more counts than the other bins
                # criteria: peak should be at least 2x the next highest bin
                next_highest_bin = np.partition(hist, -2)[-2]
                if hist[peak_index] < 2 * next_highest_bin:
                    print(f"Peak at {peak_distance} um is not strong enough, skipping curve {curve_id}")
                    continue

                flat_distance_um = peak_distance
                flat_distance_tolerance_um = 0.1

                # True wherever the distance is within tolerance of the flat distance.
                flat_regions_bool = np.abs(distance_data - flat_distance_um) < flat_distance_tolerance_um

                # Collect the runs of consecutive True values as inclusive
                # (start, end) index pairs. Padding with False on both sides makes
                # every run produce exactly one rising and one falling edge, so the
                # edge positions alternate start, end + 1, start, end + 1, ...
                padded_flat = np.concatenate(([False], flat_regions_bool, [False]))
                run_edges = np.flatnonzero(padded_flat[1:] != padded_flat[:-1])
                flat_regions: list[tuple[int, int]] = list(
                    zip(run_edges[::2].tolist(), (run_edges[1::2] - 1).tolist())
                )

                print(f"Flat regions: {flat_regions}")

                # eliminate non-flat regions at the start and end of the curve.
                # Both bounds are indices into the *untrimmed* arrays, so the slice is
                # taken from the original data in a single step (cutting the start
                # first and then re-using the original end index would shift the end).
                if flat_regions:
                    trim_start = flat_regions[0][0]
                    trim_stop = flat_regions[-1][1] + 1
                    distance_data_trimmed = distance_data[trim_start:trim_stop]
                    force_data_trimmed = force_data[trim_start:trim_stop]
                else:
                    trim_start = 0
                    distance_data_trimmed = distance_data
                    force_data_trimmed = force_data

                # Plot the trimmed trace against its original sample indices so that
                # it overlays the untrimmed trace instead of being shifted to x = 0.
                trimmed_indices = np.arange(trim_start, trim_start + len(distance_data_trimmed))

                plt.plot(distance_data, label="Distance Data")
                plt.plot(trimmed_indices, distance_data_trimmed, label="Trimmed Distance Data", linestyle='--')
                plt.title(f"Distance data retrieved via file.fdcurves[{curve_id}].d.data")
                plt.xlabel("Index")
                plt.ylabel("Distance (um)")
                plt.legend()
                plt.show()

                # TODO: planned next step, kept for reference. `oscillations` is not
                # computed yet, so this stays disabled.
                # # iterate over the oscillations and plot them and fit the return curve
                # for oscillation_index, oscillation in enumerate(oscillations):
                #     # plot the decreasing segment
                #     decreasing_force = oscillation["decreasing"]["force"]
                #     decreasing_distance = oscillation["decreasing"]["distance"]

                #     plt.plot(decreasing_distance, decreasing_force, label="Decreasing Segment")
                #     plt.title(f"Decreasing segment of oscillation {oscillation_index} for curve {curve_id}")
                #     plt.xlabel("Distance (um)")
                #     plt.ylabel("Force (pN)")
                #     plt.legend()
                #     plt.show()

                #     # fit a ewlc model to it
                #     model = pylake.ewlc_odijk_force(name="ewlc_return_fit") + pylake.force_offset(
                #         name="ewlc_return_fit"
                #     )
                #     fit = pylake.FdFit(model)
                #     fit.add_data(name="return", f=decreasing_force, d=decreasing_distance)
                #     fit["ewlc_return_fit/Lp"].value = 50
                #     fit["ewlc_return_fit/Lp"].lower_bound = 39
                #     fit["ewlc_return_fit/Lp"].upper_bound = 80
                #     fit["ewlc_return_fit/Lc"].value = 27
                #     fit["ewlc_return_fit/f_offset"].lower_bound = 0
                #     fit["ewlc_return_fit/f_offset"].upper_bound = 1

                #     error = np.mean(fit.sigma)

                #     # plot the fit before fitting to the data
                #     fit.plot()
                #     plt.title(
                #         f"Fit for decreasing segment of oscillation {oscillation_index} for"
                #         f"curve {curve_id} before fitting. Error: {error:.2f} pN"
                #     )
                #     plt.xlabel("Distance (um)")
                #     plt.ylabel("Force (pN)")
                #     plt.legend()
                #     plt.show()

                #     # actually fit the model now
                #     fit.fit()
                #     # note that the error for some reason is the same value repeated for each point. taking the mean
                #     # just in case this changes.
                #     error = np.mean(fit.sigma)
                #     print(fit.params)

                #     fit.plot()
                #     plt.title(
                #         f"Fit for decreasing segment of oscillation {oscillation_index} for"
                #         f"curve {curve_id} after fitting. Error: {error:.2f} pN"
                #     )
                #     plt.xlabel("Distance (um)")
                #     plt.ylabel("Force (pN)")
                #     plt.legend()
                #     plt.show()

        # NOTE: development limiter - stops after the first file returned by glob
        # (whether or not it was a marker file) ...
        break
    # ... and after the first data directory.
    break