In [None]:
from pathlib import Path
import pickle
import numpy as np
import matplotlib.pyplot as plt
from marker_analyser.classes import OscillationCollection
from pydantic import BaseModel, model_validator, ValidationError

from lumicks import pylake
from typing import Self

In [None]:
# Load the data
# NOTE: unpickling can execute arbitrary code, so only load pickle files that
# come from a trusted source.
with open("./experimenting_data/marina_data_loaded_oscillations.pkl", "rb") as file:
    oscillations: OscillationCollection = pickle.load(file)

# Validate with an explicit raise rather than ``assert`` (asserts are stripped
# when Python runs with -O); ``isinstance`` also accepts subclasses.
if not isinstance(oscillations, OscillationCollection):
    raise TypeError(
        f"expected an OscillationCollection in the pickle file, got {type(oscillations).__name__}"
    )

# get only the first 10 oscillations for testing
first_10_oscillation_ids = list(oscillations.oscillations)[:10]
oscillations = OscillationCollection(
    oscillations={osc_id: oscillations.oscillations[osc_id] for osc_id in first_10_oscillation_ids}
)

print(f"Loaded {len(oscillations.oscillations)} oscillations for testing.")

In [None]:
# Segment selector passed to ``fit_global_model_to_each``, which forwards it to
# each oscillation's ``get_segment``.
# NOTE(review): "both" presumably selects both halves of an oscillation - confirm
# against the accepted values of ``get_segment``.
segment = "both"


class ParamConfig(BaseModel):
    """Settings for one fit parameter: optional start value, optional bounds and flags."""

    # Lower bound to set on the fit parameter; None leaves the fit's own bound untouched.
    lower_bound: float | None = None
    # Upper bound to set on the fit parameter; None leaves the fit's own bound untouched.
    upper_bound: float | None = None
    # Starting value to set on the fit parameter; None leaves the fit's own value untouched.
    initial_value: float | None = None
    # True: one parameter shared by every oscillation. False: one parameter per oscillation.
    global_param: bool = False
    # True: the parameter is held at its value instead of being fitted.
    fixed: bool = False

    @model_validator(mode="after")
    def check_bounds(self) -> Self:
        """Reject a configuration whose lower bound exceeds its upper bound.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: if both bounds are given and ``lower_bound > upper_bound``.
                pydantic reports this to the code constructing the model as a
                ``ValidationError``.
        """
        both_bounds_given = self.lower_bound is not None and self.upper_bound is not None
        if both_bounds_given and self.lower_bound > self.upper_bound:
            # pydantic's ValidationError cannot be built from a plain message;
            # validators signal failure with ValueError and pydantic wraps it.
            raise ValueError("lower bound cannot be greater than upper bound")
        return self


class FitConfig(BaseModel):
    """Settings for a fit over a collection of oscillations."""

    # Per-parameter settings keyed by the bare parameter name (e.g. "Lp"), without
    # the model-name prefix. pydantic copies this default for each instance, so the
    # mutable ``{}`` is not shared between models.
    params_config: dict[str, ParamConfig] = {}
    # True: estimate a force offset per oscillation from the data and hold it fixed.
    auto_calculate_and_fix_f_offset: bool
    # (min, max) distance window, inclusive; the median force of the points inside
    # it is used as the force offset. Units are micrometres per the field name.
    f_offset_auto_detect_distance_range_um: tuple[float, float] = (10, 12)
    # Name given to the model; it prefixes every parameter name ("<model_name>/<param>").
    # NOTE(review): some pydantic v2 releases warn about field names starting with
    # "model_" - confirm whether that warning shows up with the installed version.
    model_name: str = "fit"

    @model_validator(mode="after")
    def check_f_offset_not_global_if_auto_calculating(self) -> Self:
        """Reject a global ``f_offset`` when the offset is auto-calculated per oscillation.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: if ``auto_calculate_and_fix_f_offset`` is True and
                ``params_config["f_offset"]`` is marked as a global parameter.
                pydantic reports this to the code constructing the model as a
                ``ValidationError``.
        """
        if not self.auto_calculate_and_fix_f_offset:
            return self

        f_offset_config = self.params_config.get("f_offset")
        if f_offset_config is not None and f_offset_config.global_param:
            # pydantic's ValidationError cannot be built from a plain message;
            # validators signal failure with ValueError and pydantic wraps it.
            raise ValueError(
                "f_offset cannot be a global parameter if auto_calculate_and_fix_f_offset is True"
            )
        return self


def _apply_param_config(fit: "pylake.FdFit", param_name: str, param_config: "ParamConfig") -> None:
    """Copy the optional value/bounds and the fixed flag of one config onto ``fit[param_name]``."""
    if param_config.initial_value is not None:
        fit[param_name].value = param_config.initial_value
    if param_config.lower_bound is not None:
        fit[param_name].lower_bound = param_config.lower_bound
    if param_config.upper_bound is not None:
        fit[param_name].upper_bound = param_config.upper_bound
    fit[param_name].fixed = param_config.fixed


def fit_global_model_to_each(
    oscillations: OscillationCollection,
    segment: str,
    fit_config: FitConfig,
) -> None:
    """Fit one Odijk eWLC + force-offset model to the chosen segment of every oscillation.

    All oscillations are added to a single ``pylake.FdFit``. Parameters configured
    with ``global_param=False`` are renamed to ``<model>/<param>_<oscillation_id>``
    so each oscillation gets its own copy; parameters configured with
    ``global_param=True`` keep the plain ``<model>/<param>`` name. The fit is then
    run and printed.

    Args:
        oscillations: Collection whose ``oscillations`` mapping provides the
            (id, oscillation) pairs to fit.
        segment: Segment selector forwarded to each oscillation's ``get_segment``.
        fit_config: Model name, per-parameter settings and force-offset options.

    Raises:
        ValueError: if the force offset is auto-calculated and an oscillation has
            no data point inside ``f_offset_auto_detect_distance_range_um``.
    """
    fit_name = fit_config.model_name
    model = pylake.ewlc_odijk_force(name=fit_name) + pylake.force_offset(name=fit_name)
    fit = pylake.FdFit(model)

    # Split the configuration once instead of re-testing the flag inside the loops.
    individual_configs = {
        param: config for param, config in fit_config.params_config.items() if not config.global_param
    }
    global_configs = {
        param: config for param, config in fit_config.params_config.items() if config.global_param
    }
    auto_f_offset = fit_config.auto_calculate_and_fix_f_offset
    range_min, range_max = fit_config.f_offset_auto_detect_distance_range_um

    for oscillation_id, oscillation in oscillations.oscillations.items():
        print(f"adding data for oscillation {oscillation_id}")

        # Map each model parameter that must be individual to a per-oscillation name.
        individual_params = {
            f"{fit_name}/{param}": f"{fit_name}/{param}_{oscillation_id}" for param in individual_configs
        }
        # An auto-calculated force offset differs per oscillation, so it must be individual too.
        if auto_f_offset:
            individual_params[f"{fit_name}/f_offset"] = f"{fit_name}/f_offset_{oscillation_id}"

        # Add data to the model
        distances, forces = oscillation.get_segment(segment)
        fit.add_data(
            name=f"Oscillation {oscillation_id}, segment {segment}",
            f=forces,
            d=distances,
            params=individual_params,
        )

        # The renamed parameters only exist once the data has been added.
        for param, param_config in individual_configs.items():
            _apply_param_config(fit, f"{fit_name}/{param}_{oscillation_id}", param_config)

        # optionally auto-detect the force offset and fix that value
        if auto_f_offset:
            in_range = (distances >= range_min) & (distances <= range_max)
            baseline_forces = forces[in_range]
            # The median of an empty selection is NaN, which would be fixed as the
            # offset and break the fit later with an unrelated-looking error.
            if np.size(baseline_forces) == 0:
                raise ValueError(
                    f"cannot auto-calculate the force offset for oscillation {oscillation_id}: "
                    f"no data points with distance in [{range_min}, {range_max}]"
                )

            # the median force in the window is used as the force offset
            calculated_f_offset = np.median(baseline_forces)
            print(f"calculated force offset for oscillation {oscillation_id}: {calculated_f_offset}")

            f_offset_name = f"{fit_name}/f_offset_{oscillation_id}"
            fit[f_offset_name].value = calculated_f_offset
            # set it to not be fitted
            fit[f_offset_name].fixed = True

    # set the global params
    for param, param_config in global_configs.items():
        _apply_param_config(fit, f"{fit_name}/{param}", param_config)

    fit.fit()

    print(fit)


# Lp is configured per oscillation and left free; Lc is a single shared value held fixed.
# The force offset is auto-calculated from the points with distance between 10 and 12.
fit_config = FitConfig(
    model_name="fit",
    auto_calculate_and_fix_f_offset=True,
    f_offset_auto_detect_distance_range_um=(10, 12),
    params_config={
        "Lp": ParamConfig(initial_value=5, lower_bound=0, fixed=False, global_param=False),
        "Lc": ParamConfig(initial_value=10, lower_bound=0, fixed=True, global_param=True),
    },
)

# Run the fit over the loaded oscillations; the function prints the resulting fit.
fit_global_model_to_each(
    oscillations=oscillations,
    segment=segment,
    fit_config=fit_config,
)