In [None]:
import pickle
from pathlib import Path
import re

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from hmmlearn.hmm import GaussianHMM

from AFMReader.topostats import load_topostats

from topostats.tracing.splining import resample_points_regular_interval
from topostats.measure.curvature import discrete_angle_difference_per_nm_circular

import numpy.typing as npt
import numpy as np
from topostats.io import LoadScans
import matplotlib.pyplot as plt
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable


def plot_line_coloured(
    intensity_ys: npt.NDArray[np.float64],
    intensity_xs: npt.NDArray[np.float64] | None = None,
    image: npt.NDArray[np.float64] | None = None,
    xy_coords: npt.NDArray[np.float64] | None = None,
    figsize: tuple[int, int] = (20, 8),
):
    """
    Plot a line plot with the colour of each point determined by the position in the array.

    Parameters
    ----------
    intensity_ys : npt.NDArray[np.float64]
        The values to plot.
    intensity_xs : npt.NDArray[np.float64], optional
        The x-coordinates of the points. If None, the indices of ``intensity_ys`` are used.
    image : npt.NDArray[np.float64], optional
        An image to plot the line on top of. If None, no image is plotted.
    xy_coords : npt.NDArray[np.float64], optional
        The coordinates of the points, for plotting the trace on top of the image. Column 0 is used as the
        vertical (row) position and column 1 as the horizontal (column) position. Required when ``image`` is given.
    figsize : tuple[int, int], optional
        Size passed to ``plt.subplots``.
    """
    num_points = len(intensity_ys)

    # "cool" runs from cyan (first point) to magenta (last point).
    cmap = plt.get_cmap("cool")
    # max(..., 0) keeps the range valid for empty input, where len - 1 would be below vmin.
    norm = Normalize(vmin=0, vmax=max(num_points - 1, 0))
    sm = ScalarMappable(norm=norm, cmap=cmap)

    # optionally create x coords or use the provided positions
    if intensity_xs is not None:
        assert len(intensity_xs) == num_points, "intensity_xs and intensity_ys must have the same length"
        x = intensity_xs
    else:
        x = np.arange(num_points)

    if image is not None:
        assert xy_coords is not None, "xy_coords must be provided if image is supplied"
        assert intensity_ys.shape[0] == xy_coords.shape[0], "intensity_ys and xy_coords must have the same length"
        fig, ax = plt.subplots(1, 2, figsize=figsize)
        ax[0].imshow(image, cmap="afmhot", interpolation="none")

        # One scatter call coloured by array position, instead of one artist per point.
        ax[0].scatter(
            xy_coords[:, 1],
            xy_coords[:, 0],
            c=np.arange(num_points),
            cmap=cmap,
            norm=norm,
            s=10,
        )
        ax[0].set_title("Image with Line Overlay")

        lineplot_ax = ax[1]
    else:
        fig, ax = plt.subplots(figsize=figsize)
        lineplot_ax = ax

    # Each segment between neighbouring points takes the colour of its first point.
    for i in range(num_points - 1):
        lineplot_ax.plot(x[i : i + 2], intensity_ys[i : i + 2], color=sm.to_rgba(i), linewidth=2)

    # Add colorbar for reference
    plt.colorbar(sm, ax=ax, label="Position in Array")
    lineplot_ax.set_xlabel("Index")
    lineplot_ax.set_ylabel("Value")
    lineplot_ax.set_title("Line Plot with Colour Gradient")
    plt.show()


# 100 uniform random samples kept as example input for plot_line_coloured; the example call is disabled.
random_values = np.random.rand(100)
# plot_line_coloured(random_values)


def defect_stats(
    curvature_defects_indexes: npt.NDArray[np.int_],
    height_defects_indexes: npt.NDArray[np.int_],
) -> dict:
    """
    Calculate defect stats (placeholder, not implemented yet).

    The body is only ``pass``, so a call currently returns ``None`` rather than the annotated ``dict`` and
    neither argument is read.

    Parameters
    ----------
    curvature_defects_indexes : npt.NDArray[np.int_]
        Indexes of the curvature defects (currently unused).
    height_defects_indexes : npt.NDArray[np.int_]
        Indexes of the height defects (currently unused).
    """

    # create a bool array for the defects
    # TODO: implement; nothing is computed here yet.
    pass

In [None]:
# NOTE(review): absolute, machine-specific location; change this to wherever the data lives.
base_dir = Path("/Users/sylvi/topo_data/picoz")
assert base_dir.exists(), f"data directory not found: {base_dir}"

processed_dir = base_dir / "output_abs_thresh"
assert processed_dir.exists(), f"processed output directory not found: {processed_dir}"

postprocess_dir = base_dir / "postprocess"
postprocess_dir.mkdir(exist_ok=True)

# Find the images. Sorting makes the file order (and therefore the running grain indexes built from it
# later on) the same on every run instead of depending on filesystem enumeration order.
topo_files = sorted(processed_dir.glob("*/**/*.topostats"))
print(f"found {len(topo_files)} topostats files")

# load the corresponding stats csv
all_stats_csv = processed_dir / "all_statistics.csv"
assert all_stats_csv.exists(), f"statistics csv not found: {all_stats_csv}"
all_stats_df = pd.read_csv(all_stats_csv)
print(all_stats_df.columns)

# Rescale by 1e9; presumably the csv stores contour length in metres and later plots label it in nm.
all_stats_df["total_contour_length"] /= 1e-9

In [None]:
def _parse_dose_and_tip(basename: str, filename: str) -> tuple[float, int]:
    """Read the dose out of ``basename`` and the tip number out of ``filename``."""
    # dose is stored in the folder name in basename, in the format "dose_XGY" where X is the dose value
    dose_match = re.search(r"dose_(\d+\.?\d*)", basename)
    assert dose_match is not None, f"could not find dose in basename {basename}"
    # tip number is an integer in the format "tip_x"
    tip_match = re.search(r"tip_(\d+)", filename)
    assert tip_match is not None, f"could not find tip in filename {filename}"
    return float(dose_match.group(1)), int(tip_match.group(1))


def _plot_grains_dictionary(grains_dictionary: dict) -> None:
    """Show every grain crop with its ordered traces, spline points and nodes, followed by its mask."""
    for grain_index, grain_data in grains_dictionary.items():
        print(f"grain {grain_index}")
        print(grain_data["filename"])
        print(grain_data["pixel_to_nm_scaling"])
        _fig, ax = plt.subplots(figsize=(20, 20))
        ax.imshow(grain_data["image"])
        # ordered traces first, so the spline points are drawn on top of them
        for molecule_data in grain_data["molecule_data"].values():
            ordered_coords = molecule_data["ordered_coords"]
            ax.plot(ordered_coords[:, 1], ordered_coords[:, 0], "r")
        for molecule_index, molecule_data in grain_data["molecule_data"].items():
            spline_coords = molecule_data["spline_coords"]
            ax.plot(spline_coords[:, 1], spline_coords[:, 0], "g.")
            spline_distances = np.linalg.norm(spline_coords[1:, :] - spline_coords[:-1, :], axis=1)
            print(
                f"molecule {molecule_index} spline distances mean {np.mean(spline_distances)} std {np.std(spline_distances)} min {np.min(spline_distances)} max {np.max(spline_distances)}"
            )
        all_node_coords = grain_data["node_coords"]
        if all_node_coords.size > 0:
            ax.plot(all_node_coords[:, 1], all_node_coords[:, 0], "b.")
        plt.show()

        _mask_fig, mask_ax = plt.subplots()
        mask_ax.imshow(grain_data["mask"][:, :, 1])
        plt.show()


def construct_grains_dictionary(
    file_list: list,
    bbox_padding: int,
    all_stats_df: pd.DataFrame,
    stop_at_index: int | None = None,
    plot: bool = False,
) -> dict:
    """
    Collect per-grain images, traces and statistics from the loaded files into one dictionary.

    Grains are skipped (with a printed message) when the stats table has no row for them, when their
    ``basename`` is NaN, or when they contain no molecules. Files without ordered traces are skipped too.

    The coordinate arrays read from the loaded file data (spline, ordered trace and node coordinates) are
    shifted in place so that they are relative to the padded bounding box.

    Parameters
    ----------
    file_list : list
        Files handed to ``LoadScans``.
    bbox_padding : int
        Padding passed to ``pad_bounding_box`` after the bounding box has been made square.
    all_stats_df : pd.DataFrame
        Statistics table. The columns read here are ``image``, ``grain_number``, ``basename``,
        ``smallest_bounding_area``, ``aspect_ratio``, ``total_contour_length`` and ``num_crossings``.
    stop_at_index : int | None, optional
        If given, stop once this many files have been visited.
    plot : bool, optional
        If True, plot every collected grain at the end.

    Returns
    -------
    dict
        Grain data keyed by a running integer index across all files. Each entry holds the cropped image and
        mask, the padded bounding box, the values taken from the stats table and a ``molecule_data``
        dictionary keyed by molecule.

    Raises
    ------
    ValueError
        If the stats rows of a single grain disagree on ``basename``.
    KeyError
        If expected data other than ``ordered_traces`` is missing from a file.
    """
    grains_dictionary: dict[int, dict] = {}

    loadscans = LoadScans(file_list, channel="dummy")
    loadscans.get_data()

    grain_index = 0
    for file_index, (filename, file_data) in enumerate(loadscans.img_dict.items()):
        if stop_at_index is not None and file_index >= stop_at_index:
            break
        try:
            try:
                nodestats_data = file_data["nodestats"]["above"]["stats"]
            except KeyError:
                nodestats_data = None

            # get the corresponding image rows from the stats csv
            all_stats_image = all_stats_df[all_stats_df["image"] == filename]
            print(f"found {len(all_stats_image)} rows in stats csv for {filename}")

            image = file_data["image"]
            ordered_trace_data = file_data["ordered_traces"]["above"]

            for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
                grain_number = int(re.sub(r"grain_", "", current_grain_index))
                all_stats_grain = all_stats_image[all_stats_image["grain_number"] == grain_number]
                if all_stats_grain.empty:
                    # without this guard the .values[0] lookups below fail with an opaque IndexError
                    print(
                        f"skipping grain {current_grain_index} in file {filename} because it has no rows in the stats csv"
                    )
                    continue

                # grab the basename, and check that it is the same on every row for this grain
                basename = all_stats_grain["basename"].values[0]
                if pd.isna(basename):
                    print(f"skipping grain {current_grain_index} in file {filename} because basename is NaN")
                    continue
                if not all(all_stats_grain["basename"].values == basename):
                    raise ValueError(
                        f"basename mismatch for grain {current_grain_index} in file {filename}: {all_stats_grain['basename'].values}"
                    )
                dose, tip = _parse_dose_and_tip(basename, filename)

                smallest_bounding_area = all_stats_grain["smallest_bounding_area"].values[0]
                aspect_ratio = all_stats_grain["aspect_ratio"].values[0]
                total_contour_length = all_stats_grain["total_contour_length"].values[0]
                num_crossings = all_stats_grain["num_crossings"].values[0]

                molecules = {}
                bbox_padded = None
                for current_molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
                    ordered_coords = molecule_ordered_trace_data["ordered_coords"]
                    bbox = molecule_ordered_trace_data["bbox"]
                    print(f"  grain {current_grain_index} molecule {current_molecule_index} bbox {bbox}")

                    splining_coords = file_data["splining"]["above"][current_grain_index][current_molecule_index][
                        "spline_coords"
                    ]
                    curvatures = file_data["grain_curvature_stats"]["above"][current_grain_index][
                        current_molecule_index
                    ]

                    # bbox will be same for all molecules so this is okay
                    bbox_square = make_bounding_box_square(bbox[0], bbox[1], bbox[2], bbox[3], image.shape)
                    bbox_padded = pad_bounding_box(
                        bbox_square[0],
                        bbox_square[1],
                        bbox_square[2],
                        bbox_square[3],
                        image.shape,
                        padding=bbox_padding,
                    )
                    added_left = bbox_padded[1] - bbox[1]
                    added_top = bbox_padded[0] - bbox[0]

                    # shift the coordinates (in place) by the offset between the padded and original bbox
                    splining_coords[:, 0] -= added_top
                    splining_coords[:, 1] -= added_left
                    ordered_coords[:, 0] -= added_top
                    ordered_coords[:, 1] -= added_left

                    molecules[current_molecule_index] = {
                        "heights": molecule_ordered_trace_data["heights"],
                        "distances": molecule_ordered_trace_data["distances"],
                        "circular": molecule_ordered_trace_data["mol_stats"]["circular"],
                        "curvatures": curvatures,
                        "spline_coords": splining_coords,
                        "ordered_coords": ordered_coords,
                    }

                if bbox_padded is None:
                    # No molecules means no bbox for this grain; never reuse the previous grain's values.
                    print(f"skipping grain {current_grain_index} in file {filename} because it has no molecules")
                    continue

                row_slice = slice(bbox_padded[0], bbox_padded[2])
                col_slice = slice(bbox_padded[1], bbox_padded[3])
                full_grain_mask = file_data["grain_tensors"]["above"]

                # grab node coordinates
                all_node_coords = []
                if nodestats_data is not None:
                    try:
                        for _node_index, node_data in nodestats_data[current_grain_index].items():
                            node_coords = node_data["node_coords"]
                            # adjust the node coords to account for the padding
                            node_coords[:, 0] -= added_top
                            node_coords[:, 1] -= added_left
                            all_node_coords.extend(node_coords)
                    except KeyError as e:
                        # A grain without nodestats is expected and skipped quietly; any other missing key
                        # is reported instead of being silently swallowed.
                        if "grain_" not in str(e):
                            print(
                                f"warning: incomplete nodestats for grain {current_grain_index} in file {filename}: missing key {e}"
                            )

                grains_dictionary[grain_index] = {
                    "molecule_data": molecules,
                    "basename": basename,
                    "dose": dose,
                    "tip": tip,
                    "image": image[row_slice, col_slice],
                    "full_image": image,
                    "bbox": bbox_padded,
                    "aspect_ratio": aspect_ratio,
                    "smallest_bounding_area": smallest_bounding_area,
                    "total_contour_length": total_contour_length,
                    "num_crossings": num_crossings,
                    "added_left": added_left,
                    "added_top": added_top,
                    "padding": bbox_padding,
                    "mask": full_grain_mask[row_slice, col_slice],
                    "filename": file_data["filename"],
                    "pixel_to_nm_scaling": file_data["pixel_to_nm_scaling"],
                    # curvature data of the last molecule visited in this grain
                    "curvature_stats": curvatures,
                    "node_coords": np.array(all_node_coords),
                }
                grain_index += 1
        except KeyError as e:
            if "ordered_traces" in str(e):
                print(f"no ordered traces found in {filename}")
                continue
            raise

    if plot:
        _plot_grains_dictionary(grains_dictionary)

    print(f"found {len(grains_dictionary)} grains in {len(file_list)} images")

    return grains_dictionary


# Build the grain dictionary from every file found above: no file limit, no plotting, and a padding of 20
# applied to each grain's bounding box.
grains_dictionary = construct_grains_dictionary(
    topo_files, bbox_padding=20, all_stats_df=all_stats_df, stop_at_index=None, plot=False
)

In [None]:
# Output folder for the defect statistics csv, the pickled grains dictionary and any defect plots.
defect_stats_dir = postprocess_dir / "defect_stats"
defect_stats_dir.mkdir(exist_ok=True)


def grains_dictionary_resample_points(
    grains_dictionary: dict,
    interval_nm: float = 2.0,
    total_index_to_stop_at: int | None = None,
    curvature_average_window_nm: float = 10.0,
) -> dict:
    new_grains_dictionary = {}
    # Iterate through the grains dictionary
    for total_index, (grain_index, grain_data) in enumerate(grains_dictionary.items()):
        if total_index_to_stop_at is not None and total_index >= total_index_to_stop_at:
            break

        pixel_to_nm_scaling = grain_data["pixel_to_nm_scaling"]

        # Iterate through each molecule in the grain
        for molecule_index, molecule_data in grain_data["molecule_data"].items():

            circular = molecule_data["circular"]
            if circular is False:
                print(f"  molecule {molecule_index} is not circular, skipping")
                continue
            if circular is None:
                print(f"  molecule {molecule_index} circluar is NONE, skipping")
                continue

            spline_coords = molecule_data["spline_coords"]
            resampled_points = resample_points_regular_interval(
                points=spline_coords, interval=interval_nm / pixel_to_nm_scaling, circular=True
            )

            # resample distances and curvatures
            resampled_distances = np.linalg.norm(resampled_points[1:, :] - resampled_points[:-1, :], axis=1)
            # add the distance from the last point to the first point
            resampled_distances = np.insert(
                resampled_distances, 0, np.linalg.norm(resampled_points[0, :] - resampled_points[-1, :])
            )
            cumulative_resampled_distances = np.cumsum(resampled_distances)
            resampled_curvatures = discrete_angle_difference_per_nm_circular(
                trace_nm=resampled_points * pixel_to_nm_scaling
            )

            # smooth curvatures using rolling window
            curvature_rolling_window_size = int(curvature_average_window_nm / interval_nm)
            smoothed_curvatures = np.convolve(
                resampled_curvatures, np.ones(curvature_rolling_window_size), mode="same"
            )
            smoothed_curvatures = smoothed_curvatures / curvature_rolling_window_size

            # grab height values at resampled points
            resampled_heights = np.zeros(resampled_points.shape[0])
            for i, point in enumerate(resampled_points):
                int_x = int(np.round(point[0]))
                int_y = int(np.round(point[1]))
                height = grain_data["image"][int_x, int_y]
                resampled_heights[i] = height

            # update the molecule data with the resampled data
            molecule_data["resampled_points"] = resampled_points
            molecule_data["cumulative_resampled_distances"] = cumulative_resampled_distances
            molecule_data["resampled_heights"] = resampled_heights
            molecule_data["resampled_curvatures"] = resampled_curvatures
            molecule_data["smoothed_curvatures"] = smoothed_curvatures
            # update the grain data with the resampled molecule data
            grain_data["molecule_data"][molecule_index] = molecule_data
        # update the grains dictionary with the resampled grain data
        new_grains_dictionary[grain_index] = grain_data
    return new_grains_dictionary


def find_defects_in_height_and_curvature(
    grains_dictionary: dict,
    curvature_threshold_radpernm: float,
    height_threshold_iqr_factor: float,
    defect_method: str,
    coindicing_defect_distance_threshold_nm: float,
    defect_stats_dir: Path,
    plot_defects: bool,
) -> tuple[dict, pd.DataFrame]:
    """
    Find defects in height and curvature for each molecule in the grains dictionary.
    """
    new_grains_dictionary = {}

    if defect_method == "threshold":
        all_heights_by_tip = {}
        all_curvatures_by_tip = {}
        for grain_index, grain_data in grains_dictionary.items():
            tip = grain_data["tip"]
            if tip not in all_heights_by_tip:
                all_heights_by_tip[tip] = []
                all_curvatures_by_tip[tip] = []
            for molecule_index, molecule_data in grain_data["molecule_data"].items():
                

    # create a list to store the defect stats
    defect_stats_list = []
    bad_grains = set()
    for grain_index, grain_data in grains_dictionary.items():
        print(f"g{grain_index}", end=" ")
        dose = grain_data["dose"]
        image = grain_data["image"]
        filename = grain_data["filename"]
        grain_bad = False
        for molecule_index, molecule_data in grain_data["molecule_data"].items():
            resampled_points = molecule_data["resampled_points"]
            resampled_heights = molecule_data["resampled_heights"]
            smoothed_curvatures = molecule_data["smoothed_curvatures"]
            cumulative_resampled_distances = molecule_data["cumulative_resampled_distances"]

            # Find curvature defects
            if defect_method == "threshold":
                curvature_defects_indexes = np.where(np.abs(smoothed_curvatures) > curvature_threshold_radpernm)[0]
            elif defect_method == "hmm":
                # Use HMM to find curvature defects
                # Reshape the data for HMM 2D input: (n_samples, n_features), we are working with 1D data so we reshape it
                # to (n_samples, 1). n_samples is just the number of points in the trace.
                smoothed_curvatures_reshaped = smoothed_curvatures.reshape(-1, 1)
                # Fit the HMM model
                hmm_model_curvature = GaussianHMM(n_components=2, covariance_type="diag", n_iter=1000)
                hmm_model_curvature.fit(smoothed_curvatures_reshaped)
                # Predict the state for each point
                curvature_states = hmm_model_curvature.predict(smoothed_curvatures_reshaped)
                # Grab the indexes of curvature defects
                curvature_defects_indexes = np.where(curvature_states == 1)[0]

            # Find height defects
            median_height = np.median(resampled_heights)
            q75, q25 = np.percentile(resampled_heights, [75, 25])
            iqr = q75 - q25
            height_threshold_nm = median_height - height_threshold_iqr_factor * iqr
            height_defects_indexes = np.where(resampled_heights < height_threshold_nm)[0]

            curvature_defects_bool = np.zeros(smoothed_curvatures.shape[0], dtype=bool)
            curvature_defects_bool[curvature_defects_indexes] = True
            height_defects_bool = np.zeros(smoothed_curvatures.shape[0], dtype=bool)
            height_defects_bool[height_defects_indexes] = True

            # find number of defect regions
            num_curvature_defects = 0
            for i in range(1, len(curvature_defects_bool)):
                if curvature_defects_bool[i] and not curvature_defects_bool[i - 1]:
                    num_curvature_defects += 1
            if curvature_defects_bool[0] and curvature_defects_bool[-1]:
                num_curvature_defects -= 1
                if num_curvature_defects < 0:
                    # This happens if the whole trace is a defect
                    print(f"\n[Bad grain: {grain_index} molecule: {molecule_index}: num_curvature_defects < 0]")
                    grain_bad = True
                    break
            num_height_defects = 0
            for i in range(1, len(height_defects_bool)):
                if height_defects_bool[i] and not height_defects_bool[i - 1]:
                    num_height_defects += 1
            if height_defects_bool[0] and height_defects_bool[-1]:
                num_height_defects -= 1

            # Check for coinciding defects within the distance threshold
            coindicing_defect_distance_threshold_px = (
                coindicing_defect_distance_threshold_nm / grain_data["pixel_to_nm_scaling"]
            )
            # for each curvature defect, check if there is a height defect within the distance threshold
            coinciding_defect_bool = np.zeros(smoothed_curvatures.shape[0], dtype=bool)
            for i in curvature_defects_indexes:
                for j in height_defects_indexes:
                    # get the distance between the two from the cumulative distances
                    distance_curv_index = cumulative_resampled_distances[i]
                    distance_height_index = cumulative_resampled_distances[j]
                    distance_curv_height_diff = np.abs(distance_curv_index - distance_height_index)
                    if distance_curv_height_diff < coindicing_defect_distance_threshold_px:
                        coinciding_defect_bool[i] = True
                        break  # just break this loop since we only need find one match for this index

            coinciding_defects_indexes = np.where(coinciding_defect_bool)[0]

            # count the number of coinciding defects
            num_coinciding_defects = 0
            for i in range(1, len(coinciding_defect_bool)):
                if coinciding_defect_bool[i] and not coinciding_defect_bool[i - 1]:
                    num_coinciding_defects += 1
            if coinciding_defect_bool[0] and coinciding_defect_bool[-1]:
                num_coinciding_defects -= 1

            # update the molecule data with the defect stats
            molecule_data["curvature_defects_indexes"] = curvature_defects_indexes
            molecule_data["height_defects_indexes"] = height_defects_indexes
            molecule_data["coinciding_defects_indexes"] = coinciding_defects_indexes
            molecule_data["num_curvature_defects"] = num_curvature_defects
            molecule_data["num_height_defects"] = num_height_defects
            molecule_data["num_coinciding_defects"] = num_coinciding_defects
            grain_data["molecule_data"][molecule_index] = molecule_data

            # save the defect stats to a list
            defect_stats_list.append(
                {
                    "grain_index": grain_index,
                    "filename": grain_data["filename"],
                    "dose": grain_data["dose"],
                    "smallest_bounding_area": grain_data["smallest_bounding_area"],
                    "aspect_ratio": grain_data["aspect_ratio"],
                    "total_contour_length": grain_data["total_contour_length"],
                    "num_crossings": grain_data["num_crossings"],
                    "molecule_index": molecule_index,
                    "num_curvature_defects": num_curvature_defects,
                    "num_height_defects": num_height_defects,
                    "num_coinciding_defects": num_coinciding_defects,
                }
            )

            # plot the defects on the image
            # curvature defects
            if plot_defects:
                fig, ax = plt.subplots(3, 2, figsize=(20, 20))
                plt.suptitle(
                    f"image {filename} \n grain {grain_index} {molecule_index} \n mean curv: {np.mean(smoothed_curvatures):.2f}",
                    fontsize=20,
                )
                # plot horizontal traces
                ax[0, 0].plot(cumulative_resampled_distances, smoothed_curvatures, "k")
                ax[0, 0].set_xlabel("Distance (nm)")
                ax[0, 0].set_ylabel("Curvature (1/nm)")
                ax[0, 0].hlines(
                    [curvature_threshold_radpernm, -curvature_threshold_radpernm],
                    0,
                    cumulative_resampled_distances[-1],
                    color="k",
                    linestyles="dashed",
                )
                ax[0, 0].set_title(f"Curvature")
                ax[0, 1].plot(cumulative_resampled_distances, resampled_heights, "k")
                ax[0, 1].set_xlabel("Distance (nm)")
                ax[0, 1].set_ylabel("Height (nm)")
                ax[0, 1].hlines(
                    [height_threshold_nm], 0, cumulative_resampled_distances[-1], color="k", linestyles="dashed"
                )
                ax[0, 1].set_title(f"Height")
                ax[1, 0].imshow(image)
                ax[1, 0].plot(resampled_points[:, 1], resampled_points[:, 0], "g.")
                for i in curvature_defects_indexes:
                    ax[1, 0].plot(resampled_points[i, 1], resampled_points[i, 0], "r.", markersize=10)
                ax[1, 0].scatter(resampled_points[0, 1], resampled_points[0, 0], color="white", s=100)
                ax[1, 0].set_title(f"Curvature Defects")
                # height defects
                ax[1, 1].imshow(image)
                ax[1, 1].plot(resampled_points[:, 1], resampled_points[:, 0], "g.")
                for i in height_defects_indexes:
                    ax[1, 1].plot(resampled_points[i, 1], resampled_points[i, 0], "r.", markersize=10)
                ax[1, 1].scatter(resampled_points[0, 1], resampled_points[0, 0], color="white", s=100)
                ax[1, 1].set_title(f"Height Defects")
                # plot coinciding defects
                ax_coincide = ax[2, 0]
                ax_coincide.imshow(image)
                ax_coincide.plot(resampled_points[:, 1], resampled_points[:, 0], "g.")
                for i in coinciding_defects_indexes:
                    ax_coincide.plot(resampled_points[i, 1], resampled_points[i, 0], "r.", markersize=10)
                ax_coincide.scatter(resampled_points[0, 1], resampled_points[0, 0], color="white", s=100)
                ax_coincide.set_title(f"Coinciding Defects")

                # save the figure
                fig.savefig(
                    defect_stats_dir / f"{dose}GY" / f"grain_{grain_index}_mol{molecule_index}_defects.png",
                    dpi=200,
                )
                plt.close(fig)

        if grain_bad:
            bad_grains.add(grain_index)
            continue

        # update the new grains dictionary with the grain data
        new_grains_dictionary[grain_index] = grain_data

    print("")
    print(f"bad grains: {bad_grains}")

    return new_grains_dictionary, pd.DataFrame(defect_stats_list)


# Resampling settings
interval_nm = 2.0  # nm, spacing of the resampled trace points
total_index_to_stop_at = None  # no limit on the number of grains
curvature_average_window_nm = 10.0  # nm, width of the curvature smoothing window
print("resamping points in grains dictionary")
grains_dictionary_resampled = grains_dictionary_resample_points(
    grains_dictionary=grains_dictionary,
    interval_nm=interval_nm,
    total_index_to_stop_at=total_index_to_stop_at,
    curvature_average_window_nm=curvature_average_window_nm,
)

# find defects in height and curvature
curvature_threshold_radpernm = 0.15  # rad/nm
height_threshold_iqr_factor = 3.0  # IQR factor for height threshold
# defect_method is a required argument of find_defects_in_height_and_curvature; "threshold" is the method
# that uses curvature_threshold_radpernm.
defect_method = "threshold"
coindicing_defect_distance_threshold_nm = 10.0  # nm, distance threshold for coinciding defects
print("finding defects in height and curvature")
grains_dictionary_defects, defect_stats_df = find_defects_in_height_and_curvature(
    grains_dictionary=grains_dictionary_resampled,
    curvature_threshold_radpernm=curvature_threshold_radpernm,
    height_threshold_iqr_factor=height_threshold_iqr_factor,
    defect_method=defect_method,
    coindicing_defect_distance_threshold_nm=coindicing_defect_distance_threshold_nm,
    defect_stats_dir=defect_stats_dir,
    plot_defects=False,
)

# save the defect stats to a csv
defect_stats_df.to_csv(defect_stats_dir / "defect_stats.csv", index=False)

# pickle the dictionary
with open(defect_stats_dir / "grains_defects_stats_dictionary.pickle", "wb") as f:
    pickle.dump(grains_dictionary_defects, f)

In [None]:
# Read the defect statistics back from disk and give a quick look at them.
defect_stats_df = pd.read_csv(defect_stats_dir / "defect_stats.csv")

print(defect_stats_df.head())

print(len(defect_stats_df))

# One histogram per defect count column.
for count_column, count_label in (
    ("num_curvature_defects", "Number of Curvature Defects"),
    ("num_height_defects", "Number of Height Defects"),
    ("num_coinciding_defects", "Number of Coinciding Defects"),
):
    defect_stats_df[count_column].hist()
    plt.xlabel(count_label)
    plt.ylabel("Count")
    plt.show()

# Aspect ratio against the number of defects, as scatter plots.
for count_column, count_label in (
    ("num_curvature_defects", "Number of Curvature Defects"),
    ("num_height_defects", "Number of Height Defects"),
):
    defect_stats_df.plot.scatter(x="aspect_ratio", y=count_column)
    plt.xlabel("Aspect Ratio")
    plt.ylabel(count_label)
    plt.show()

# Distribution of the total contour length: individual points with a violin on top.
plt.figure()
sns.stripplot(data=defect_stats_df, y="total_contour_length", color="black", alpha=0.5)
sns.violinplot(data=defect_stats_df, y="total_contour_length", color="lightgrey")
plt.ylabel("Total Contour Length (nm)")
plt.show()

# Dose against the number of curvature defects, then against the number of height defects.
for count_column, count_label in (
    ("num_curvature_defects", "Number of Curvature Defects"),
    ("num_height_defects", "Number of Height Defects"),
):
    sns.stripplot(data=defect_stats_df, x="dose", y=count_column, color="black", alpha=0.5)
    sns.violinplot(data=defect_stats_df, x="dose", y=count_column, color="lightgrey")
    plt.xlabel("Dose (Gy)")
    plt.ylabel(count_label)
    plt.show()

In [None]:
def _fit_two_state_hmm(series):
    """Fit a fresh two-state diagonal GaussianHMM to a 1D series; return the model and its decoded states."""
    observations = series.reshape(-1, 1)
    model = GaussianHMM(n_components=2, covariance_type="diag", n_iter=1000)
    model.fit(observations)
    return model, model.predict(observations)


# Fixed seed so the two synthetic samples are the same on every run.
np.random.seed(1)
X1 = np.random.rand(20) + 20
X2 = np.random.rand(20) + 20

# Copy of X2 with two artificial dips ("defects").
X2_with_defect = X2.copy()
X2_with_defect[10:12] -= 1.8
X2_with_defect[15:17] -= 5

# --- A single sample, with and without the dips ---
plt.plot(X2, ".-", label="sample x2 without defect", c="orange")
plt.plot(X2_with_defect, ".-", label="sample x2 with defect", c="cornflowerblue")
model_x2, predicted_states_x2 = _fit_two_state_hmm(X2)
model_x2_with_defect, predicted_states_x2_with_defect = _fit_two_state_hmm(X2_with_defect)

plt.plot(predicted_states_x2, ".-", label="Predicted States x2", c="orange", alpha=0.5)
# offset by 2 so the two state traces do not overlap
plt.plot(
    predicted_states_x2_with_defect + 2, ".-", label="Predicted States x2 with defect", c="cornflowerblue", alpha=0.5
)
plt.ylim(-1, np.max(X2) + 1)
plt.legend()
plt.title("HMM forces itself to find defects even when none are present.")
plt.show()


# --- Two samples joined end to end, with and without the dips ---
combined_no_defects = np.concatenate((X1, X2))
combined_with_defects = np.concatenate((X1, X2_with_defect))

model_combined_no_defects, predicted_states_combined_no_defects = _fit_two_state_hmm(combined_no_defects)
model_combined_with_defects, predicted_states_combined_with_defects = _fit_two_state_hmm(combined_with_defects)

plt.figure(figsize=(10, 5))
plt.plot(combined_no_defects, ".-", label="Combined No Defects", c="orange")
plt.plot(combined_with_defects, ".-", label="Combined With Defects", c="cornflowerblue")
plt.plot(
    predicted_states_combined_no_defects, ".-", label="Predicted States Combined No Defects", c="orange", alpha=0.5
)
plt.plot(
    predicted_states_combined_with_defects + 2,
    ".-",
    label="Predicted States Combined With Defects",
    c="cornflowerblue",
    alpha=0.5,
)
plt.ylim(-1, np.max(combined_no_defects) + 1)
plt.legend()
plt.title(
    "Concatenated data with at least one real defect can make HMM behave.\nBut total absense of defects forces it to find false defects."
)
plt.show()