In [None]:
import os
import pickle
import re
from pathlib import Path
import pytest

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from hmmlearn.hmm import GaussianHMM

from AFMReader.topostats import load_topostats

from topostats.tracing.splining import resample_points_regular_interval
from topostats.measure.curvature import discrete_angle_difference_per_nm_circular
from topostats.damage.damage import calculate_defects_and_gap_lengths

import numpy.typing as npt
import numpy as np
from topostats.io import LoadScans
import matplotlib.pyplot as plt
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable


def plot_line_coloured(
    intensity_ys: npt.NDArray[np.float64],
    intensity_xs: npt.NDArray[np.float64] | None = None,
    image: npt.NDArray[np.float64] | None = None,
    xy_coords: npt.NDArray[np.float64] | None = None,
    figsize: tuple[int, int] = (20, 8),
):
    """
    Plot a line plot with the colour of each point determined by the position in the array.

    Parameters
    ----------
    intensity_ys : npt.NDArray[np.float64]
        The values to plot.
    intensity_xs : npt.NDArray[np.float64], optional
        The x-coordinates of the points. If None, the indices of ``intensity_ys`` are used.
    image : npt.NDArray[np.float64], optional
        An image to draw the trace on top of, shown in a panel to the left of the line plot. If None, only the line
        plot is drawn.
    xy_coords : npt.NDArray[np.float64], optional
        (N, 2) array of (row, col) coordinates of the trace, one per value in ``intensity_ys``. Required when
        ``image`` is supplied.
    figsize : tuple[int, int], optional
        Matplotlib figure size. Default (20, 8).

    Raises
    ------
    AssertionError
        If ``intensity_xs`` or ``xy_coords`` do not match the length of ``intensity_ys``, or if ``image`` is given
        without ``xy_coords``.
    """
    # Local import: only this function needs it.
    from matplotlib.collections import LineCollection

    intensity_ys = np.asarray(intensity_ys)
    num_points = len(intensity_ys)
    position_index = np.arange(num_points)

    # "cool" runs from cyan (start of the array) to magenta (end of the array)
    cmap = plt.get_cmap("cool")
    norm = Normalize(vmin=0, vmax=num_points - 1)
    sm = ScalarMappable(norm=norm, cmap=cmap)

    # optionally create x coords or use the provided positions
    if intensity_xs is not None:
        assert len(intensity_xs) == len(intensity_ys), "intensity_xs and intensity_ys must have the same length"
        x = np.asarray(intensity_xs)
    else:
        x = position_index

    if image is not None:
        assert xy_coords is not None, "xy_coords must be provided if image is supplied"
        xy_coords = np.asarray(xy_coords)
        assert intensity_ys.shape[0] == xy_coords.shape[0], "intensity_ys and xy_coords must have the same length"
        fig, ax = plt.subplots(1, 2, figsize=figsize)
        ax[0].imshow(image, cmap="afmhot", interpolation="none")
        # single vectorised scatter; point i gets the same colour as segment i of the line plot
        ax[0].scatter(xy_coords[:, 1], xy_coords[:, 0], c=position_index, cmap=cmap, norm=norm, s=10)
        ax[0].set_title("Image with Line Overlay")
        lineplot_ax = ax[1]
    else:
        fig, ax = plt.subplots(figsize=figsize)
        lineplot_ax = ax

    # Draw all segments at once; segment i (point i -> point i+1) is coloured by index i.
    if num_points > 1:
        points = np.column_stack([np.asarray(x, dtype=float), intensity_ys.astype(float)])
        segments = np.stack([points[:-1], points[1:]], axis=1)
        line_collection = LineCollection(segments, cmap=cmap, norm=norm, linewidths=2)
        line_collection.set_array(position_index[:-1])
        lineplot_ax.add_collection(line_collection)
        # collections do not trigger autoscaling on their own
        lineplot_ax.autoscale_view()

    # Add colorbar for reference (spans all panels when an image is shown)
    fig.colorbar(sm, ax=ax, label="Position in Array")
    lineplot_ax.set_xlabel("Index")
    lineplot_ax.set_ylabel("Value")
    lineplot_ax.set_title("Line Plot with Colour Gradient")
    plt.show()

In [None]:
# Root data directory. Set the TOPO_DATA_DIR environment variable to point elsewhere instead of editing this cell.
base_dir = Path(os.environ.get("TOPO_DATA_DIR", "/Users/sylvi/topo_data/picoz"))
assert base_dir.exists(), f"base data directory not found: {base_dir}"

# TopoStats output folder: contains the .topostats files (in sub-folders) and all_statistics.csv
processed_dir = base_dir / "output_abs_thresh"
assert processed_dir.exists(), f"processed directory not found: {processed_dir}"

postprocess_dir = base_dir / "postprocess"
postprocess_dir.mkdir(exist_ok=True)

# load the images; sorted so file order (and therefore grain numbering downstream) is reproducible
topo_files = sorted(processed_dir.glob("*/**/*.topostats"))
print(f"found {len(topo_files)} topostats files")

# load the corresponding stats csv
all_stats_csv = processed_dir / "all_statistics.csv"
assert all_stats_csv.exists(), f"statistics csv not found: {all_stats_csv}"
all_stats_df = pd.read_csv(all_stats_csv)
print(all_stats_df.columns)

# Dividing by 1e-9 converts metres -> nm (presumably stored in metres; the downstream 300 nm contour cut-off expects nm).
all_stats_df["total_contour_length"] = all_stats_df["total_contour_length"] / 1e-9

In [None]:
def construct_grains_dictionary(
    file_list: list,
    bbox_padding: int,
    all_stats_df: pd.DataFrame,
    stop_at_index: int | None = None,
    plot: bool = False,
):
    """
    Build a flat dictionary of per-grain data from ``.topostats`` files and the matching statistics csv.

    Grains are numbered consecutively (0, 1, 2, ...) across all files, in the order the files are loaded, counting
    only grains that are kept. Each entry holds a square, padded crop of the image and mask around the grain, the
    per-molecule traces shifted into the crop's coordinate frame, node coordinates (when nodestats data exists), and
    selected columns from the statistics csv.

    Parameters
    ----------
    file_list : list
        Paths of ``.topostats`` files to load.
    bbox_padding : int
        Padding passed to ``pad_bounding_box`` around each grain's squared bounding box (presumably pixels).
    all_stats_df : pd.DataFrame
        Statistics table with "image", "grain_number", "basename", "smallest_bounding_area", "aspect_ratio",
        "total_contour_length" and "num_crossings" columns.
    stop_at_index : int | None, optional
        If given, stop after this many files.
    plot : bool, optional
        If True, plot every kept grain's crop with its traces, spline points and nodes, and its mask.

    Returns
    -------
    dict
        Consecutive grain index -> grain data dictionary.

    Raises
    ------
    ValueError
        If the stats rows of one grain disagree on their basename.
    AssertionError
        If the dose cannot be parsed from the basename, or the tip number from the filename.
    """
    grains_dictionary: dict[int, dict] = {}

    # (grain number within its file, filename, reason) for each grain that was skipped
    bad_grains: list[tuple[int, str, str]] = []

    loadscans = LoadScans(file_list, channel="dummy")
    loadscans.get_data()
    img_dict = loadscans.img_dict

    # consecutive index across all files; only incremented for grains that are kept
    grain_index = 0
    for file_index, (filename, file_data) in enumerate(img_dict.items()):
        if stop_at_index is not None and file_index >= stop_at_index:
            break
        try:
            # nodestats are optional: files without them simply get no node coordinates
            try:
                nodestats_data = file_data["nodestats"]["above"]["stats"]
            except KeyError:
                nodestats_data = None

            # get the corresponding image rows from the stats csv
            all_stats_image = all_stats_df[all_stats_df["image"] == filename]
            print(f"found {len(all_stats_image)} rows in stats csv for {filename}")

            image = file_data["image"]
            ordered_trace_data = file_data["ordered_traces"]["above"]

            for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
                # keys look like "grain_<n>"; <n> matches the "grain_number" column of the stats csv
                grain_number = int(re.sub(r"grain_", "", current_grain_index))
                all_stats_grain = all_stats_image[all_stats_image["grain_number"] == grain_number]
                if all_stats_grain.empty:
                    print(
                        f"skipping grain {current_grain_index} in file {filename} because it has no rows in the stats csv"
                    )
                    bad_grains.append((grain_number, filename, "no rows in stats csv"))
                    continue

                # grab the basename, skip the grain if it is NaN, and check it is the same for every row of the grain
                basename = all_stats_grain["basename"].values[0]
                if pd.isna(basename):
                    print(f"skipping grain {current_grain_index} in file {filename} because basename is NaN")
                    continue
                if not all(all_stats_grain["basename"].values == basename):
                    raise ValueError(
                        f"basename mismatch for grain {current_grain_index} in file {filename}: {all_stats_grain['basename'].values}"
                    )
                # dose is stored in the folder name in basename, in the format "dose_X" where X is the dose value
                dose_match = re.search(r"dose_(\d+\.?\d*)", basename)
                assert dose_match is not None, f"could not find dose in basename {basename}"
                dose = float(dose_match.group(1))
                # tip number is an integer in the filename, in the format "tip_X"
                tip_match = re.search(r"tip_(\d+)", filename)
                assert tip_match is not None, f"could not find tip in filename {filename}"
                tip = int(tip_match.group(1))

                smallest_bounding_area = all_stats_grain["smallest_bounding_area"].values[0]
                aspect_ratio = all_stats_grain["aspect_ratio"].values[0]
                total_contour_length = all_stats_grain["total_contour_length"].values[0]

                if total_contour_length < 300:
                    print(
                        f"skipping grain {current_grain_index} in file {filename} because total contour length is too small: {total_contour_length}"
                    )
                    bad_grains.append((grain_number, filename, "contour length too small"))
                    continue

                if len(grain_ordered_trace_data) == 0:
                    # the crop box comes from the molecules' bbox; without one we would reuse the previous grain's box
                    print(f"skipping grain {current_grain_index} in file {filename} because it has no molecules")
                    bad_grains.append((grain_number, filename, "no molecules"))
                    continue

                num_crossings = all_stats_grain["num_crossings"].values[0]

                molecules: dict = {}
                for current_molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
                    ordered_coords = molecule_ordered_trace_data["ordered_coords"]
                    bbox = molecule_ordered_trace_data["bbox"]
                    print(f"  grain {current_grain_index} molecule {current_molecule_index} bbox {bbox}")

                    spline_coords = file_data["splining"]["above"][current_grain_index][current_molecule_index][
                        "spline_coords"
                    ]
                    curvatures = file_data["grain_curvature_stats"]["above"][current_grain_index][
                        current_molecule_index
                    ]

                    # bbox will be the same for all molecules of a grain, so recomputing the padded box is okay
                    bbox_square = make_bounding_box_square(bbox[0], bbox[1], bbox[2], bbox[3], image.shape)
                    bbox_padded = pad_bounding_box(
                        bbox_square[0],
                        bbox_square[1],
                        bbox_square[2],
                        bbox_square[3],
                        image.shape,
                        padding=bbox_padding,
                    )
                    # offset between the original bbox origin and the padded crop origin
                    added_left = bbox_padded[1] - bbox[1]
                    added_top = bbox_padded[0] - bbox[0]

                    # shift coordinates into the padded crop's frame (modifies the loaded arrays in place)
                    spline_coords[:, 0] -= added_top
                    spline_coords[:, 1] -= added_left
                    ordered_coords[:, 0] -= added_top
                    ordered_coords[:, 1] -= added_left

                    molecules[current_molecule_index] = {
                        "heights": molecule_ordered_trace_data["heights"],
                        "distances": molecule_ordered_trace_data["distances"],
                        "circular": molecule_ordered_trace_data["mol_stats"]["circular"],
                        "curvatures": curvatures,
                        "spline_coords": spline_coords,
                        "ordered_coords": ordered_coords,
                    }

                crop_rows = slice(bbox_padded[0], bbox_padded[2])
                crop_cols = slice(bbox_padded[1], bbox_padded[3])
                full_grain_mask = file_data["grain_tensors"]["above"]

                # grab node coordinates, shifted into the crop frame like the traces
                all_node_coords = []
                if nodestats_data is not None and current_grain_index in nodestats_data:
                    for node_index, node_data in nodestats_data[current_grain_index].items():
                        if "node_coords" not in node_data:
                            print(f"  grain {current_grain_index} node {node_index} has no node_coords, skipping node")
                            continue
                        node_coords = node_data["node_coords"]
                        node_coords[:, 0] -= added_top
                        node_coords[:, 1] -= added_left
                        all_node_coords.extend(node_coords)

                grains_dictionary[grain_index] = {
                    "molecule_data": molecules,
                    "basename": basename,
                    "dose": dose,
                    "tip": tip,
                    "image": image[crop_rows, crop_cols],
                    "full_image": image,
                    "bbox": bbox_padded,
                    "aspect_ratio": aspect_ratio,
                    "smallest_bounding_area": smallest_bounding_area,
                    "total_contour_length": total_contour_length,
                    "num_crossings": num_crossings,
                    "added_left": added_left,
                    "added_top": added_top,
                    "padding": bbox_padding,
                    "mask": full_grain_mask[crop_rows, crop_cols],
                    "filename": file_data["filename"],
                    "pixel_to_nm_scaling": file_data["pixel_to_nm_scaling"],
                    # NOTE(review): this is the curvature data of the grain's LAST molecule only — confirm intended
                    "curvature_stats": curvatures,
                    "node_coords": np.array(all_node_coords),
                }

                grain_index += 1
        except KeyError as e:
            if "ordered_traces" in str(e):
                print(f"no ordered traces found in {filename}")
                continue
            raise e

    if plot:
        for plot_grain_index, grain_data in grains_dictionary.items():
            print(f"grain {plot_grain_index}")
            print(grain_data["filename"])
            print(grain_data["pixel_to_nm_scaling"])
            fig, ax = plt.subplots(figsize=(20, 20))
            ax.imshow(grain_data["image"])
            for molecule_data in grain_data["molecule_data"].values():
                ordered_coords = molecule_data["ordered_coords"]
                ax.plot(ordered_coords[:, 1], ordered_coords[:, 0], "r")
            # spline coords
            for molecule_index, molecule_data in grain_data["molecule_data"].items():
                spline_coords = molecule_data["spline_coords"]
                ax.plot(spline_coords[:, 1], spline_coords[:, 0], "g.")
                spline_distances = np.linalg.norm(spline_coords[1:, :] - spline_coords[:-1, :], axis=1)
                print(
                    f"molecule {molecule_index} spline distances mean {np.mean(spline_distances)} std {np.std(spline_distances)} min {np.min(spline_distances)} max {np.max(spline_distances)}"
                )
            all_node_coords = grain_data["node_coords"]
            if all_node_coords.size > 0:
                ax.plot(all_node_coords[:, 1], all_node_coords[:, 0], "b.")
            plt.show()

            # NOTE(review): assumes the mask has a trailing class axis with the grain class at index 1 — confirm
            mask = grain_data["mask"][:, :, 1]
            plt.imshow(mask)
            plt.show()

    print(f"found {len(grains_dictionary)} grains in {len(file_list)} images")

    for bad_grain in bad_grains:
        print(f"bad grain {bad_grain[0]} in file {bad_grain[1]}: {bad_grain[2]}")

    return grains_dictionary


# Padding handed to pad_bounding_box around each grain's squared bounding box when cropping.
GRAIN_BBOX_PADDING = 20

grains_dictionary = construct_grains_dictionary(
    file_list=topo_files,
    all_stats_df=all_stats_df,
    bbox_padding=GRAIN_BBOX_PADDING,
    stop_at_index=None,  # process every file
    plot=False,
)

In [None]:
# Output folder for defect statistics, nested in the postprocess directory.
defect_stats_dir = postprocess_dir.joinpath("defect_stats")
defect_stats_dir.mkdir(exist_ok=True)


def _circular_moving_average(values: npt.NDArray[np.float64], window_size: int) -> npt.NDArray[np.float64]:
    """Moving average of a closed-loop signal that wraps around the ends instead of zero-padding them."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    # Pad with wrapped values so output[i] averages values[i - window_size // 2 : i + (window_size - 1) // 2 + 1],
    # the same alignment np.convolve(..., mode="same") uses, but without attenuating the first/last points.
    padded = np.pad(values, (window_size // 2, (window_size - 1) // 2), mode="wrap")
    return np.convolve(padded, np.ones(window_size), mode="valid") / window_size


def grains_dictionary_resample_points(
    grains_dictionary: dict,
    interval_nm: float = 2.0,
    total_index_to_stop_at: int | None = None,
    curvature_average_window_nm: float = 10.0,
) -> dict:
    """
    Resample each circular molecule's spline at a regular spacing and derive heights and curvatures along it.

    For every circular molecule the following keys are added to its molecule dictionary:
    "resampled_points_px", "cumulative_resampled_distances_px", "distances_to_previous_points_nm",
    "resampled_heights", "resampled_curvatures" and "smoothed_curvatures".

    Parameters
    ----------
    grains_dictionary : dict
        Grain index -> grain data, as built by ``construct_grains_dictionary``.
    interval_nm : float, optional
        Target spacing between resampled points in nm. Default 2.0.
    total_index_to_stop_at : int | None, optional
        If given, only the first this-many grains are processed and returned.
    curvature_average_window_nm : float, optional
        Length of the moving-average window applied to curvature, in nm. Converted to a number of samples as
        ``int(curvature_average_window_nm / interval_nm)``, with a minimum of 1.

    Returns
    -------
    dict
        Grain index -> grain data for the processed grains.

    Notes
    -----
    The input dictionaries are modified in place; the returned dictionary shares the same grain dictionaries.
    Molecules that are not circular, or whose circularity is None, are left without the resampled keys.
    Distances are measured to the previous point around the closed loop, so index 0 holds the closing segment and
    the cumulative distance therefore does not start at 0.
    """
    # number of samples per moving-average window; at least 1 so the kernel is never empty
    curvature_rolling_window_size = max(1, int(curvature_average_window_nm / interval_nm))

    new_grains_dictionary = {}
    for total_index, (grain_index, grain_data) in enumerate(grains_dictionary.items()):
        if total_index_to_stop_at is not None and total_index >= total_index_to_stop_at:
            break

        pixel_to_nm_scaling = grain_data["pixel_to_nm_scaling"]
        image = grain_data["image"]

        for molecule_index, molecule_data in grain_data["molecule_data"].items():
            circular = molecule_data["circular"]
            # None first, then truthiness: values loaded from file may be numpy bools, for which `is False` never holds
            if circular is None:
                print(f"  molecule {molecule_index} circluar is NONE, skipping")
                continue
            if not circular:
                print(f"  molecule {molecule_index} is not circular, skipping")
                continue

            resampled_points_px = resample_points_regular_interval(
                points=molecule_data["spline_coords"], interval=interval_nm / pixel_to_nm_scaling, circular=True
            )

            # distance from each point to the previous one around the loop (point 0's predecessor is the last point)
            resampled_distances_px = np.linalg.norm(
                resampled_points_px - np.roll(resampled_points_px, 1, axis=0), axis=1
            )
            resampled_distances_nm = resampled_distances_px * pixel_to_nm_scaling
            cumulative_resampled_distances_px = np.cumsum(resampled_distances_px)
            resampled_curvatures = discrete_angle_difference_per_nm_circular(
                trace_nm=resampled_points_px * pixel_to_nm_scaling
            )

            # the trace is a closed loop, so smooth with a wrapping window rather than zero-padded edges
            smoothed_curvatures = _circular_moving_average(resampled_curvatures, curvature_rolling_window_size)

            # nearest-pixel heights; clip so points on the crop border can neither wrap (negative) nor overflow
            rows = np.clip(np.rint(resampled_points_px[:, 0]).astype(int), 0, image.shape[0] - 1)
            cols = np.clip(np.rint(resampled_points_px[:, 1]).astype(int), 0, image.shape[1] - 1)
            resampled_heights = np.asarray(image[rows, cols], dtype=float)

            molecule_data["resampled_points_px"] = resampled_points_px
            molecule_data["cumulative_resampled_distances_px"] = cumulative_resampled_distances_px
            molecule_data["distances_to_previous_points_nm"] = resampled_distances_nm
            molecule_data["resampled_heights"] = resampled_heights
            molecule_data["resampled_curvatures"] = resampled_curvatures
            molecule_data["smoothed_curvatures"] = smoothed_curvatures

        new_grains_dictionary[grain_index] = grain_data
    return new_grains_dictionary


def fetch_all_curvatures_by_tip(
    grains_dictionary: dict,
) -> dict[int, list[float]]:
    """
    Collect the smoothed curvature values of every molecule, grouped by AFM tip.

    Parameters
    ----------
    grains_dictionary : dict
        Grain index -> grain data. Each grain needs a "tip" key and a "molecule_data" dict of molecule dicts.

    Returns
    -------
    dict[int, list[float]]
        Tip -> flat list of smoothed curvatures from all grains imaged with that tip, in iteration order. Every tip
        seen appears as a key, possibly with an empty list.

    Notes
    -----
    Molecules without a "smoothed_curvatures" entry (e.g. non-circular molecules, which resampling skips) are
    ignored instead of raising ``KeyError``.
    """
    all_curvatures_by_tip: dict[int, list[float]] = {}
    for grain_data in grains_dictionary.values():
        tip_curvatures = all_curvatures_by_tip.setdefault(grain_data["tip"], [])
        for molecule_data in grain_data["molecule_data"].values():
            smoothed_curvatures = molecule_data.get("smoothed_curvatures")
            if smoothed_curvatures is None:
                continue
            tip_curvatures.extend(smoothed_curvatures)
    return all_curvatures_by_tip


def fetch_all_heights_by_tip(
    grains_dictionary: dict,
) -> dict[int, list[float]]:
    """
    Collect the resampled height values of every molecule, grouped by AFM tip.

    Parameters
    ----------
    grains_dictionary : dict
        Grain index -> grain data. Each grain needs a "tip" key and a "molecule_data" dict of molecule dicts.

    Returns
    -------
    dict[int, list[float]]
        Tip -> flat list of resampled heights from all grains imaged with that tip, in iteration order. Every tip
        seen appears as a key, possibly with an empty list.

    Notes
    -----
    Molecules without a "resampled_heights" entry (e.g. non-circular molecules, which resampling skips) are
    ignored instead of raising ``KeyError``.
    """
    all_heights_by_tip: dict[int, list[float]] = {}
    for grain_data in grains_dictionary.values():
        tip_heights = all_heights_by_tip.setdefault(grain_data["tip"], [])
        for molecule_data in grain_data["molecule_data"].values():
            resampled_heights = molecule_data.get("resampled_heights")
            if resampled_heights is None:
                continue
            tip_heights.extend(resampled_heights)
    return all_heights_by_tip


def median_iqr_thresholds(
    values: npt.NDArray[np.float64],
    iqr_factor: float,
) -> tuple[float, float]:
    """
    Calculate outlier thresholds as median +/- ``iqr_factor`` * IQR.

    The band is centred on the median (not on the quartiles as in Tukey's fences).

    Parameters
    ----------
    values : npt.NDArray[np.float64]
        Values to compute the thresholds from.
    iqr_factor : float
        Multiple of the interquartile range added to / subtracted from the median.

    Returns
    -------
    tuple[float, float]
        ``(upper_threshold, lower_threshold)``. For empty input both are NaN; any comparison against NaN is False,
        so no value is flagged by such thresholds.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        # np.percentile raises on empty input; return NaN thresholds instead so callers can carry on
        print("no values supplied, returning NaN thresholds")
        return (np.nan, np.nan)
    median_value = np.median(values)
    print(f"median: {median_value}")
    q75, q25 = np.percentile(values, [75, 25])
    iqr = q75 - q25
    print(f"{iqr_factor} * {iqr} = {iqr_factor * iqr}")
    return (median_value + iqr_factor * iqr, median_value - iqr_factor * iqr)


def count_defects_in_binary_array(
    binary_array: npt.NDArray[np.bool_],
    circular: bool = True,
) -> int:
    """
    Count the number of defect regions in a binary array.

    A defect region is a maximal run of consecutive True values.

    Parameters
    ----------
    binary_array : npt.NDArray[np.bool_]
        1D array flagging defective points along a trace.
    circular : bool, optional
        If True (default), the array is treated as a closed loop, so a run touching both ends counts as a single
        region. If False, the two ends are not joined.

    Returns
    -------
    int
        Number of defect regions. An empty array has 0 regions; an all-True array has 1.
    """
    flags = np.asarray(binary_array, dtype=bool)
    if flags.size == 0:
        return 0
    if flags.all():
        # the whole trace is one defect; unusual, but reported as a single region
        return 1
    if circular:
        # a region starts wherever True follows False, with index 0 following the last index
        return int(np.count_nonzero(flags & ~np.roll(flags, 1)))
    # linear: a region starting at index 0, plus every rising edge after it
    return int(flags[0]) + int(np.count_nonzero(flags[1:] & ~flags[:-1]))


def find_curvature_defects_grains_dictionary(
    grains_dictionary: dict,
    defect_method: str,
    curvature_threshold_abs_radpernm: float,
    curvature_threshold_iqr_factor: float,
) -> tuple[dict, dict]:
    """Find curvature defects in grains.

    Adds "curvature_defects_bool" and "curvature_defects_indexes" to every molecule that has smoothed curvatures,
    and "num_curvature_defects" to every grain. The input dictionaries are modified in place.

    Parameters
    ----------
    grains_dictionary : dict
        Dictionary containing grain data.
    defect_method : str
        Method to use for finding defects.
        Options:
        - "per_tip_median_iqr": a point is a defect if its smoothed curvature lies outside
          median +/- factor * IQR of all curvatures from the same tip.
        - "threshold_abs" (alias "abs"): a point is a defect if ``|curvature|`` exceeds the absolute threshold.
    curvature_threshold_abs_radpernm : float
        Absolute threshold for curvature defects in rad/nm.
    curvature_threshold_iqr_factor : float
        IQR factor for calculating curvature thresholds.

    Returns
    -------
    tuple[dict, dict]
        - new_grains_dictionary: Updated grains dictionary with curvature defects.
        - curvature_thresholds_by_tip: Dictionary with (upper, lower) curvature thresholds by tip. Empty for the
          absolute-threshold method.

    Raises
    ------
    ValueError
        If the defect_method is not supported.
    """
    print(f"finding curvature defects using method {defect_method}")
    curvature_thresholds_by_tip: dict[int, tuple[float, float]] = {}

    if defect_method == "per_tip_median_iqr":
        all_curvatures_by_tip = fetch_all_curvatures_by_tip(grains_dictionary=grains_dictionary)
        for tip, all_curvatures_for_tip in all_curvatures_by_tip.items():
            print(f"tip: {tip}")
            curvature_thresholds_by_tip[tip] = median_iqr_thresholds(
                values=np.array(all_curvatures_for_tip), iqr_factor=curvature_threshold_iqr_factor
            )
        print(f"curvature thresholds by tip: {curvature_thresholds_by_tip}")

        def is_defect(curvatures: npt.NDArray[np.float64], tip: int) -> npt.NDArray[np.bool_]:
            upper, lower = curvature_thresholds_by_tip[tip]
            return np.logical_or(curvatures > upper, curvatures < lower)

    elif defect_method in ("threshold_abs", "abs"):

        def is_defect(curvatures: npt.NDArray[np.float64], tip: int) -> npt.NDArray[np.bool_]:
            return np.abs(curvatures) > curvature_threshold_abs_radpernm

    else:
        raise ValueError(f"defect_method {defect_method} not supported")

    new_grains_dictionary = {}
    for grain_index, grain_data in grains_dictionary.items():
        tip = grain_data["tip"]
        total_curvature_defects_grain = 0
        for molecule_data in grain_data["molecule_data"].values():
            smoothed_curvatures = molecule_data.get("smoothed_curvatures")
            if smoothed_curvatures is None:
                # molecule was not resampled (e.g. not circular): nothing to test
                continue
            curvature_defects_bool = is_defect(np.asarray(smoothed_curvatures), tip)
            total_curvature_defects_grain += count_defects_in_binary_array(binary_array=curvature_defects_bool)
            molecule_data["curvature_defects_bool"] = curvature_defects_bool
            molecule_data["curvature_defects_indexes"] = np.where(curvature_defects_bool)[0]

        grain_data["num_curvature_defects"] = total_curvature_defects_grain
        new_grains_dictionary[grain_index] = grain_data

    return new_grains_dictionary, curvature_thresholds_by_tip


def find_height_defects_grains_dictionary(
    grains_dictionary: dict,
    defect_method: str,
    height_threshold_iqr_factor: float,
    height_threshold_abs_nm: float,
) -> tuple[dict, dict]:
    """Find height defects (dips below a threshold) in grains.

    Adds "height_defects_bool" and "height_defects_indexes" to every molecule that has resampled heights, and
    "num_height_defects" to every grain. The input dictionaries are modified in place.

    Parameters
    ----------
    grains_dictionary : dict
        Dictionary containing grain data.
    defect_method : str
        Method to use for finding defects.
        Options:
        - "per_tip_median_iqr": a point is a defect if its height is below median - factor * IQR of all heights
          from the same tip (only the lower threshold is used).
        - "abs" (alias "threshold_abs"): a point is a defect if its height is below the absolute threshold.
    height_threshold_iqr_factor : float
        IQR factor for calculating height thresholds.
    height_threshold_abs_nm : float
        Absolute threshold for height defects in nm.

    Returns
    -------
    tuple[dict, dict]
        - new_grains_dictionary: Updated grains dictionary with height defects.
        - height_thresholds_by_tip: Dictionary with (upper, lower) height thresholds by tip. Empty for the
          absolute-threshold method.

    Raises
    ------
    ValueError
        If the defect_method is not supported.
    """
    print(f"finding height defects using method {defect_method}")
    height_thresholds_by_tip: dict[int, tuple[float, float]] = {}

    if defect_method == "per_tip_median_iqr":
        all_heights_by_tip = fetch_all_heights_by_tip(grains_dictionary=grains_dictionary)
        for tip, heights in all_heights_by_tip.items():
            print(f"tip: {tip}")
            height_thresholds_by_tip[tip] = median_iqr_thresholds(
                values=np.array(heights), iqr_factor=height_threshold_iqr_factor
            )
        print(f"height thresholds by tip: {height_thresholds_by_tip}")

        def is_defect(values: npt.NDArray[np.float64], tip: int) -> npt.NDArray[np.bool_]:
            # only dips count: points below the tip's lower threshold
            return values < height_thresholds_by_tip[tip][1]

    elif defect_method in ("abs", "threshold_abs"):

        def is_defect(values: npt.NDArray[np.float64], tip: int) -> npt.NDArray[np.bool_]:
            return values < height_threshold_abs_nm

    else:
        raise ValueError(f"defect_method {defect_method} not supported")

    new_grains_dictionary = {}
    for grain_index, grain_data in grains_dictionary.items():
        tip = grain_data["tip"]
        total_height_defects_grain = 0
        for molecule_data in grain_data["molecule_data"].values():
            resampled_heights = molecule_data.get("resampled_heights")
            if resampled_heights is None:
                # molecule was not resampled (e.g. not circular): nothing to test
                continue
            height_defects_bool = is_defect(np.asarray(resampled_heights), tip)
            total_height_defects_grain += count_defects_in_binary_array(binary_array=height_defects_bool)
            molecule_data["height_defects_bool"] = height_defects_bool
            molecule_data["height_defects_indexes"] = np.where(height_defects_bool)[0]

        grain_data["num_height_defects"] = total_height_defects_grain
        new_grains_dictionary[grain_index] = grain_data

    return new_grains_dictionary, height_thresholds_by_tip


def _mean_or_nan(values) -> float:
    """Return the mean of ``values`` as a float, or NaN when ``values`` is empty (without a RuntimeWarning)."""
    values_array = np.asarray(values, dtype=float)
    if values_array.size == 0:
        return np.nan
    return float(np.mean(values_array))


def _find_coinciding_defects(
    curvature_defects_indexes: npt.NDArray[np.int64],
    height_defects_indexes: npt.NDArray[np.int64],
    cumulative_distances_px: npt.NDArray[np.float64],
    distance_threshold_px: float,
    num_points: int,
) -> npt.NDArray[np.bool_]:
    """
    Flag curvature-defect points that lie within a distance threshold of any height-defect point.

    Parameters
    ----------
    curvature_defects_indexes : npt.NDArray[np.int64]
        Indexes of points flagged as curvature defects.
    height_defects_indexes : npt.NDArray[np.int64]
        Indexes of points flagged as height defects.
    cumulative_distances_px : npt.NDArray[np.float64]
        Cumulative distance along the trace for every point, in pixels.
    distance_threshold_px : float
        Two defects coincide when their separation along the trace is strictly less than this (pixels).
    num_points : int
        Length of the returned boolean array (number of points in the trace).

    Returns
    -------
    npt.NDArray[np.bool_]
        Boolean array of length ``num_points``; True at each curvature-defect index that has a height defect
        closer than ``distance_threshold_px``.
    """
    coinciding_defects_bool = np.zeros(num_points, dtype=bool)
    if len(curvature_defects_indexes) == 0 or len(height_defects_indexes) == 0:
        return coinciding_defects_bool
    cumulative_distances_px = np.asarray(cumulative_distances_px, dtype=float)
    # Pairwise |d_curvature - d_height|, shape (n_curvature_defects, n_height_defects).
    # NOTE(review): separations are measured along the cumulative distance and do not wrap around; if the
    # molecules are closed loops, defects either side of the trace start point will not be paired — confirm.
    separations = np.abs(
        cumulative_distances_px[curvature_defects_indexes][:, np.newaxis]
        - cumulative_distances_px[height_defects_indexes][np.newaxis, :]
    )
    coinciding_defects_bool[curvature_defects_indexes] = np.any(separations < distance_threshold_px, axis=1)
    return coinciding_defects_bool


def _plot_molecule_defects(
    output_path: Path,
    title: str,
    image: npt.NDArray,
    distances_nm: npt.NDArray[np.float64],
    smoothed_curvatures: npt.NDArray[np.float64],
    resampled_heights: npt.NDArray[np.float64],
    resampled_points_px: npt.NDArray[np.float64],
    curvature_defects_indexes: npt.NDArray[np.int64],
    height_defects_indexes: npt.NDArray[np.int64],
    coinciding_defects_indexes: npt.NDArray[np.int64],
    curvature_thresholds: tuple[float, float] | None = None,
    height_threshold: float | None = None,
) -> None:
    """
    Save a 3x2 diagnostic figure for one molecule and close it.

    Top row: curvature and height profiles against distance along the trace (with optional threshold lines).
    Lower panels: the image with the trace overlaid and curvature / height / coinciding defects marked in red.
    """
    fig, ax = plt.subplots(3, 2, figsize=(20, 20))
    fig.suptitle(title, fontsize=20)

    # curvature profile along the trace
    ax[0, 0].plot(distances_nm, smoothed_curvatures, "k")
    if curvature_thresholds is not None:
        threshold_above, threshold_below = curvature_thresholds
        ax[0, 0].axhline(threshold_above, color="red", linestyle="--", label="threshold above")
        ax[0, 0].axhline(threshold_below, color="blue", linestyle="--", label="threshold below")
        ax[0, 0].legend()
    ax[0, 0].set_ylim(-0.5, 0.5)
    ax[0, 0].set_xlabel("Distance (nm)")
    ax[0, 0].set_ylabel("Curvature (1/nm)")
    ax[0, 0].set_title("Curvature")

    # height profile along the trace
    ax[0, 1].plot(distances_nm, resampled_heights, "k")
    if height_threshold is not None:
        ax[0, 1].axhline(height_threshold, color="red", linestyle="--", label="threshold")
        ax[0, 1].legend()
    ax[0, 1].set_ylim(0, 4)
    ax[0, 1].set_xlabel("Distance (nm)")
    ax[0, 1].set_ylabel("Height (nm)")
    ax[0, 1].set_title("Height")

    # Defect locations on the image: column 1 of the points is plotted as x and column 0 as y.
    for panel, defect_indexes, panel_title in (
        (ax[1, 0], curvature_defects_indexes, "Curvature Defects"),
        (ax[1, 1], height_defects_indexes, "Height Defects"),
        (ax[2, 0], coinciding_defects_indexes, "Coinciding Defects"),
    ):
        panel.imshow(image)
        panel.plot(resampled_points_px[:, 1], resampled_points_px[:, 0], "g.")
        panel.plot(
            resampled_points_px[defect_indexes, 1],
            resampled_points_px[defect_indexes, 0],
            "r.",
            markersize=10,
        )
        # white marker on the first trace point shows where the trace (and the distance axis) starts
        panel.scatter(resampled_points_px[0, 1], resampled_points_px[0, 0], color="white", s=100)
        panel.set_title(panel_title)
    ax[2, 1].axis("off")  # unused panel

    fig.savefig(output_path, dpi=200)
    plt.close(fig)


def find_defects_in_height_and_curvature(
    grains_dictionary: dict,
    curvature_defect_method: str,
    curvature_threshold_abs_radpernm: float,
    curvature_threshold_iqr_factor: float,
    height_defect_method: str,
    height_threshold_iqr_factor: float,
    height_threshold_abs_nm: float,
    coindicing_defect_distance_threshold_nm: float,
    defect_stats_dir: Path,
    plot_defects: bool,
) -> tuple[dict, pd.DataFrame, dict, dict]:
    """
    Find defects in height and curvature for each molecule in the grains dictionary.

    Curvature and height defects are flagged per molecule by the two threshold helpers. A curvature-defect
    point is then marked as a "coinciding" defect when a height-defect point lies closer than
    ``coindicing_defect_distance_threshold_nm`` along the same trace.

    One statistics row is produced per grain. The defect and gap lengths (and their means) are pooled over
    all molecules in the grain, which matches the per-grain defect counts. Gap lists are concatenated
    across molecules in the same way.

    Parameters
    ----------
    grains_dictionary : dict
        Dictionary containing grain data.
    curvature_defect_method : str
        Method to use for finding curvature defects.
        Options: "per_tip_median_iqr", "threshold_abs"
    curvature_threshold_abs_radpernm : float
        Absolute threshold for curvature defects in rad/nm.
    curvature_threshold_iqr_factor : float
        IQR factor for calculating curvature thresholds.
    height_defect_method : str
        Method to use for finding height defects.
        Options: "per_tip_median_iqr", "abs"
    height_threshold_iqr_factor : float
        IQR factor for calculating height thresholds.
    height_threshold_abs_nm : float
        Absolute threshold for height defects in nm.
    coindicing_defect_distance_threshold_nm : float
        Distance threshold in nm to consider defects coinciding.
    defect_stats_dir : Path
        Directory to save defect statistics. Figures go into ``defect_stats_dir / f"{dose}GY"``.
    plot_defects : bool
        Whether to save one diagnostic figure per molecule.

    Returns
    -------
    tuple[dict, pd.DataFrame, dict, dict]
        - grains_dictionary: Updated grains dictionary with defects. Each molecule gains
          "coinciding_defects_bool" and "coinciding_defects_indexes".
        - defect_stats_df: DataFrame with one row per grain. Its "molecule_index" column holds the list of
          molecule indexes in that grain.
        - curvature_thresholds_by_tip: Dictionary with curvature thresholds by tip.
        - height_thresholds_by_tip: Dictionary with height thresholds by tip.
    """
    defect_stats_list = []
    bad_grains = set()
    new_grains_dictionary = {}

    # Find curvature defects. Below, each molecule's "curvature_defects_bool" and each grain's
    # "num_curvature_defects" are read; this helper is expected to have added them.
    grains_dictionary, curvature_thresholds_by_tip = find_curvature_defects_grains_dictionary(
        grains_dictionary=grains_dictionary,
        defect_method=curvature_defect_method,
        curvature_threshold_abs_radpernm=curvature_threshold_abs_radpernm,
        curvature_threshold_iqr_factor=curvature_threshold_iqr_factor,
    )

    # Find height defects (adds "height_defects_bool" per molecule and "num_height_defects" per grain).
    grains_dictionary, height_thresholds_by_tip = find_height_defects_grains_dictionary(
        grains_dictionary=grains_dictionary,
        defect_method=height_defect_method,
        height_threshold_iqr_factor=height_threshold_iqr_factor,
        height_threshold_abs_nm=height_threshold_abs_nm,
    )

    for grain_index, grain_data in grains_dictionary.items():
        print(f"g{grain_index}", end=" ")
        dose = grain_data["dose"]
        tip = grain_data["tip"]
        image = grain_data["image"]
        filename = grain_data["filename"]
        pixel_to_nm_scaling = grain_data["pixel_to_nm_scaling"]
        # The threshold is given in nm but cumulative trace distances are stored in px. Dividing by the
        # scaling converts nm -> px; multiplying converts px -> nm (used for the plot axes).
        coinciding_defect_distance_threshold_px = coindicing_defect_distance_threshold_nm / pixel_to_nm_scaling
        grain_bad = False  # placeholder: nothing currently marks a grain as bad
        total_num_coinciding_defects = 0

        # Per-grain accumulators so the stats row describes every molecule, not just the last one.
        molecule_indexes = []
        curvature_defects_lengths_grain: list[float] = []
        curvature_defect_gaps_lengths_grain: list[float] = []
        height_defects_lengths_grain: list[float] = []
        height_defect_gaps_lengths_grain: list[float] = []

        for molecule_index, molecule_data in grain_data["molecule_data"].items():
            resampled_points_px = molecule_data["resampled_points_px"]
            resampled_heights = molecule_data["resampled_heights"]
            smoothed_curvatures = molecule_data["smoothed_curvatures"]
            cumulative_resampled_distances_px = molecule_data["cumulative_resampled_distances_px"]
            distances_to_previous_points_nm = molecule_data["distances_to_previous_points_nm"]
            curvature_defects_bool = molecule_data["curvature_defects_bool"]
            height_defects_bool = molecule_data["height_defects_bool"]

            curvature_defects_lengths, curvature_defect_gaps_lengths = calculate_defects_and_gap_lengths(
                points_distance_to_previous_nm=distances_to_previous_points_nm,
                defects_bool=curvature_defects_bool,
            )
            height_defects_lengths, height_defect_gaps_lengths = calculate_defects_and_gap_lengths(
                points_distance_to_previous_nm=distances_to_previous_points_nm,
                defects_bool=height_defects_bool,
            )
            curvature_defects_indexes = np.where(curvature_defects_bool)[0]
            height_defects_indexes = np.where(height_defects_bool)[0]

            coinciding_defects_bool = _find_coinciding_defects(
                curvature_defects_indexes=curvature_defects_indexes,
                height_defects_indexes=height_defects_indexes,
                cumulative_distances_px=cumulative_resampled_distances_px,
                distance_threshold_px=coinciding_defect_distance_threshold_px,
                num_points=smoothed_curvatures.shape[0],
            )
            coinciding_defects_indexes = np.where(coinciding_defects_bool)[0]
            total_num_coinciding_defects += count_defects_in_binary_array(binary_array=coinciding_defects_bool)

            # molecule_data is the dict stored in grain_data["molecule_data"], so this updates the grain in place
            molecule_data["coinciding_defects_bool"] = coinciding_defects_bool
            molecule_data["coinciding_defects_indexes"] = coinciding_defects_indexes

            molecule_indexes.append(molecule_index)
            curvature_defects_lengths_grain.extend(np.asarray(curvature_defects_lengths).tolist())
            curvature_defect_gaps_lengths_grain.extend(np.asarray(curvature_defect_gaps_lengths).tolist())
            height_defects_lengths_grain.extend(np.asarray(height_defects_lengths).tolist())
            height_defect_gaps_lengths_grain.extend(np.asarray(height_defect_gaps_lengths).tolist())

            if plot_defects:
                output_dir = defect_stats_dir / f"{dose}GY"
                output_dir.mkdir(parents=True, exist_ok=True)
                _plot_molecule_defects(
                    output_path=output_dir / f"grain_{grain_index}_mol{molecule_index}_defects.png",
                    title=(
                        f"image {filename} \n grain {grain_index} {molecule_index} \n "
                        f"mean curv: {np.mean(smoothed_curvatures):.2f}"
                    ),
                    image=image,
                    distances_nm=np.asarray(cumulative_resampled_distances_px) * pixel_to_nm_scaling,
                    smoothed_curvatures=smoothed_curvatures,
                    resampled_heights=resampled_heights,
                    resampled_points_px=resampled_points_px,
                    curvature_defects_indexes=curvature_defects_indexes,
                    height_defects_indexes=height_defects_indexes,
                    coinciding_defects_indexes=coinciding_defects_indexes,
                    curvature_thresholds=(
                        curvature_thresholds_by_tip[tip] if curvature_defect_method == "per_tip_median_iqr" else None
                    ),
                    height_threshold=(
                        height_thresholds_by_tip[tip][1] if height_defect_method == "per_tip_median_iqr" else None
                    ),
                )

        # one stats row per grain, aggregated over all of its molecules
        defect_stats_list.append(
            {
                "grain_index": grain_index,
                "filename": filename,
                "dose": dose,
                "smallest_bounding_area": grain_data["smallest_bounding_area"],
                "aspect_ratio": grain_data["aspect_ratio"],
                "total_contour_length": grain_data["total_contour_length"],
                "num_crossings": grain_data["num_crossings"],
                "molecule_index": molecule_indexes,
                "num_curvature_defects": grain_data["num_curvature_defects"],
                "num_height_defects": grain_data["num_height_defects"],
                "num_coinciding_defects": total_num_coinciding_defects,
                "mean_curvature_defect_length": _mean_or_nan(curvature_defects_lengths_grain),
                "mean_height_defect_length": _mean_or_nan(height_defects_lengths_grain),
                "height_defects_lengths": height_defects_lengths_grain,
                "curvature_defects_lengths": curvature_defects_lengths_grain,
                "mean_curvature_defect_gap_length": _mean_or_nan(curvature_defect_gaps_lengths_grain),
                "mean_height_defect_gap_length": _mean_or_nan(height_defect_gaps_lengths_grain),
                "curvature_defect_gaps_lengths": curvature_defect_gaps_lengths_grain,
                "height_defect_gaps_lengths": height_defect_gaps_lengths_grain,
            }
        )

        if grain_bad:
            bad_grains.add(grain_index)
            continue

        new_grains_dictionary[grain_index] = grain_data

    print("")
    print(f"bad grains: {bad_grains}")

    return (
        new_grains_dictionary,
        pd.DataFrame(defect_stats_list),
        curvature_thresholds_by_tip,
        height_thresholds_by_tip,
    )


# Resample every molecule trace at a regular spacing and compute windowed curvature along it.
interval_nm = 2.0  # nm
total_index_to_stop_at = None  # NOTE(review): presumably None means "process every grain" — confirm
curvature_average_window_nm = 10.0  # nm
print("resampling points in grains dictionary")
grains_dictionary_resampled = grains_dictionary_resample_points(
    grains_dictionary=grains_dictionary,
    interval_nm=interval_nm,
    total_index_to_stop_at=total_index_to_stop_at,
    curvature_average_window_nm=curvature_average_window_nm,
)

# Find defects in height and curvature
curvature_defect_method = "per_tip_median_iqr"  # options: "per_tip_median_iqr", "threshold_abs"
curvature_threshold_abs_radpernm = 0.15  # rad/nm, absolute threshold (for the "threshold_abs" method)
curvature_threshold_iqr_factor = 3.0  # IQR factor for curvature threshold
height_defect_method = "per_tip_median_iqr"  # options: "per_tip_median_iqr", "abs"
height_threshold_iqr_factor = 3.0  # IQR factor for height threshold
height_threshold_abs_nm = 10.0  # nm, absolute height threshold; only used when height_defect_method == "abs"
coindicing_defect_distance_threshold_nm = 10.0  # nm, distance threshold for coinciding defects
plot_defects = True  # save defect diagnostic figures into defect_stats_dir / f"{dose}GY"
print("finding defects in height and curvature")
# The per-dose figure folders, the csv and the pickle are all written inside this directory.
defect_stats_dir.mkdir(parents=True, exist_ok=True)
grains_dictionary_defects, defect_stats_df, curvature_thresholds_by_tip, height_thresholds_by_tip = (
    find_defects_in_height_and_curvature(
        grains_dictionary=grains_dictionary_resampled,
        curvature_defect_method=curvature_defect_method,
        curvature_threshold_abs_radpernm=curvature_threshold_abs_radpernm,
        curvature_threshold_iqr_factor=curvature_threshold_iqr_factor,
        height_defect_method=height_defect_method,
        height_threshold_iqr_factor=height_threshold_iqr_factor,
        height_threshold_abs_nm=height_threshold_abs_nm,
        coindicing_defect_distance_threshold_nm=coindicing_defect_distance_threshold_nm,
        defect_stats_dir=defect_stats_dir,
        plot_defects=plot_defects,
    )
)

# save the defect stats to a csv
defect_stats_df.to_csv(defect_stats_dir / "defect_stats.csv", index=False)

# pickle the dictionary (only ever unpickle this file from a trusted source)
with open(defect_stats_dir / "grains_defects_stats_dictionary.pickle", "wb") as f:
    pickle.dump(grains_dictionary_defects, f)

In [None]:
# load the csv
defect_stats_df = pd.read_csv(defect_stats_dir / "defect_stats.csv")

print(defect_stats_df.head())

print(len(defect_stats_df))


def plot_violin_with_points(
    data: pd.DataFrame,
    y: str,
    ylabel: str,
    x: str | None = None,
    xlabel: str | None = None,
    title: str | None = None,
) -> None:
    """
    Draw a violin plot of column ``y`` (optionally split by column ``x``) with the individual rows overlaid.

    The violins are drawn before the points. Matplotlib draws artists with equal z-order in the order they
    are added, so this keeps the points on top of the violin bodies rather than hidden underneath them.
    """
    fig, ax = plt.subplots()
    sns.violinplot(data=data, x=x, y=y, color="lightgrey", ax=ax)
    sns.stripplot(data=data, x=x, y=y, color="black", alpha=0.5, ax=ax)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)
    plt.show()


# histograms of the number of defects per grain
for column, label in [
    ("num_curvature_defects", "Number of Curvature Defects"),
    ("num_height_defects", "Number of Height Defects"),
    ("num_coinciding_defects", "Number of Coinciding Defects"),
]:
    fig, ax = plt.subplots()
    defect_stats_df[column].hist(ax=ax)
    ax.set_xlabel(label)
    ax.set_ylabel("Count")
    plt.show()

# plot aspect ratio against number of defects as a scatter
for column, label in [
    ("num_curvature_defects", "Number of Curvature Defects"),
    ("num_height_defects", "Number of Height Defects"),
]:
    fig, ax = plt.subplots()
    defect_stats_df.plot.scatter(x="aspect_ratio", y=column, ax=ax)
    ax.set_xlabel("Aspect Ratio")
    ax.set_ylabel(label)
    plt.show()

# distribution of total contour length
contour_lengths = defect_stats_df["total_contour_length"]
median_total_contour_length = contour_lengths.median()
iqr_total_contour_length = contour_lengths.quantile(0.75) - contour_lengths.quantile(0.25)
print(f"median total contour length: {median_total_contour_length: .2f} nm")
print(f"IQR total contour length: {iqr_total_contour_length: .2f} nm")
plot_violin_with_points(
    data=defect_stats_df,
    y="total_contour_length",
    ylabel="Total Contour Length (nm)",
    title=(
        f"Total contour length. Median: {median_total_contour_length:.2f} nm, IQR: {iqr_total_contour_length:.2f} nm"
    ),
)

# per-dose distributions of defect counts and mean defect lengths
for column, label in [
    ("num_curvature_defects", "Number of Curvature Defects"),
    ("num_height_defects", "Number of Height Defects"),
    ("num_coinciding_defects", "Number of Coinciding Defects"),
    ("mean_height_defect_length", "Mean Height Defect Length (nm)"),
    ("mean_curvature_defect_length", "Mean Curvature Defect Length (nm)"),
]:
    plot_violin_with_points(data=defect_stats_df, x="dose", y=column, xlabel="Dose (Gy)", ylabel=label)

In [None]:
def extract_list_from_lists_in_df(df: pd.DataFrame, column_name: str, ignore_single_elements: bool = False) -> list:
    """
    Flatten a DataFrame column whose entries are lists (or string representations of lists) into one list.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the column.
    column_name : str
        Name of the column to flatten. Entries may already be lists or numpy arrays (DataFrame built in
        memory) or their string representation such as ``"[1.0, 2.5]"`` (DataFrame read back from CSV).
    ignore_single_elements : bool, optional
        If True, entries containing exactly one element are skipped. Default False.

    Returns
    -------
    list
        The elements of every accepted entry, concatenated in row order.

    Notes
    -----
    Strings are parsed with ``ast.literal_eval``, which only accepts Python literals. ``eval`` would instead
    execute arbitrary code found in the CSV. Missing entries (None / NaN) are skipped silently. Entries
    that cannot be parsed, or that do not hold a list, are reported and skipped.
    """
    import ast

    extracted_list = []
    for item in df[column_name]:
        # empty CSV cells come back as NaN floats; nothing to extract
        if item is None or (isinstance(item, float) and np.isnan(item)):
            continue
        if isinstance(item, str):
            try:
                list_item = ast.literal_eval(item)
            except (ValueError, SyntaxError, TypeError) as e:
                print(f"Error converting item {item} to list: {e}")
                continue
        elif isinstance(item, np.ndarray):
            list_item = item.tolist()
        else:
            list_item = item
        if not isinstance(list_item, list):
            print(f"Item {item} is not a list, skipping.")
            continue
        if ignore_single_elements and len(list_item) == 1:
            continue
        extracted_list.extend(list_item)

    return extracted_list

In [None]:
# Distributions of individual defect lengths and gap lengths, pooled over all rows of defect_stats_df.
# For the gap columns, rows whose list holds a single gap are skipped (ignore_single_elements=True).
height_defects_lengths = extract_list_from_lists_in_df(defect_stats_df, "height_defects_lengths")
height_defect_gaps_lengths = extract_list_from_lists_in_df(
    defect_stats_df, "height_defect_gaps_lengths", ignore_single_elements=True
)
curvature_defects_lengths = extract_list_from_lists_in_df(defect_stats_df, "curvature_defects_lengths")
curvature_defect_gaps_lengths = extract_list_from_lists_in_df(
    defect_stats_df, "curvature_defect_gaps_lengths", ignore_single_elements=True
)

for lengths, label in [
    (height_defects_lengths, "Height Defect Lengths"),
    (height_defect_gaps_lengths, "Height Defect Gap Lengths"),
    (curvature_defects_lengths, "Curvature Defect Lengths"),
    (curvature_defect_gaps_lengths, "Curvature Defect Gap Lengths"),
]:
    fig, ax = plt.subplots()
    # violin first, then points, so the points are drawn on top instead of being hidden by the violin
    sns.violinplot(y=lengths, color="lightgrey", ax=ax)
    sns.stripplot(y=lengths, color="black", alpha=0.5, ax=ax)
    ax.set_ylabel(f"{label} (nm)")
    ax.set_title(label)
    plt.show()

In [None]:
# Toy demonstration: a 2-state Gaussian HMM fitted to a trace without any defect still splits it into two
# states. No random_state is passed to GaussianHMM.
# NOTE(review): fitting may therefore depend on the global NumPy RNG seeded below — run this cell top to bottom.
np.random.seed(1)
X1 = np.random.rand(20) + 20
X2 = np.random.rand(20) + 20

# Add two synthetic dips ("defects") into a copy of X2
X2_with_defect = X2.copy()
X2_with_defect[10:12] -= 1.8
X2_with_defect[15:17] -= 5


def fit_predict_hmm_states(trace: np.ndarray, n_components: int = 2) -> tuple[GaussianHMM, np.ndarray]:
    """
    Fit a diagonal-covariance GaussianHMM (at most 1000 EM iterations) to a 1-D trace.

    Returns the fitted model and the predicted hidden-state index for every sample of ``trace``.
    """
    observations = np.asarray(trace).reshape(-1, 1)
    model = GaussianHMM(n_components=n_components, covariance_type="diag", n_iter=1000)
    model.fit(observations)
    return model, model.predict(observations)


# Fit order is kept as before (x2, x2 with defect, combined without, combined with) so that any draws from
# the global RNG happen in the same sequence.
model_x2, predicted_states_x2 = fit_predict_hmm_states(X2)
model_x2_with_defect, predicted_states_x2_with_defect = fit_predict_hmm_states(X2_with_defect)

fig, ax = plt.subplots()
ax.plot(X2, ".-", label="sample x2 without defect", c="orange")
ax.plot(X2_with_defect, ".-", label="sample x2 with defect", c="cornflowerblue")
ax.plot(predicted_states_x2, ".-", label="Predicted States x2", c="orange", alpha=0.5)
# offset by 2 so the two state sequences do not overlap
ax.plot(
    predicted_states_x2_with_defect + 2, ".-", label="Predicted States x2 with defect", c="cornflowerblue", alpha=0.5
)
ax.set_ylim(-1, np.max(X2) + 1)
ax.legend()
ax.set_title("HMM forces itself to find defects even when none are present.")
plt.show()


combined_no_defects = np.concatenate((X1, X2))
combined_with_defects = np.concatenate((X1, X2_with_defect))

model_combined_no_defects, predicted_states_combined_no_defects = fit_predict_hmm_states(combined_no_defects)
model_combined_with_defects, predicted_states_combined_with_defects = fit_predict_hmm_states(combined_with_defects)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(combined_no_defects, ".-", label="Combined No Defects", c="orange")
ax.plot(combined_with_defects, ".-", label="Combined With Defects", c="cornflowerblue")
ax.plot(
    predicted_states_combined_no_defects, ".-", label="Predicted States Combined No Defects", c="orange", alpha=0.5
)
ax.plot(
    predicted_states_combined_with_defects + 2,
    ".-",
    label="Predicted States Combined With Defects",
    c="cornflowerblue",
    alpha=0.5,
)
ax.set_ylim(-1, np.max(combined_no_defects) + 1)
ax.legend()
ax.set_title(
    "Concatenated data with at least one real defect can make HMM behave.\nBut total absence of defects forces it to find false defects."
)
plt.show()