In [None]:
from pathlib import Path
import re
import pickle
from datetime import datetime

import numpy as np
import numpy.typing as npt
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy.ndimage import distance_transform_edt, gaussian_filter1d, binary_fill_holes
from skimage.measure import label, regionprops
from skimage.graph import route_through_array
from skimage.morphology import binary_dilation, binary_erosion, skeletonize

from topostats.plottingfuncs import Colormap
from topostats.tracing.splining import resample_points_regular_interval, windowTrace
from topostats.measure.feret import get_feret_from_mask
from topostats.measure.curvature import discrete_angle_difference_per_nm_linear
from topostats.damage.damage import get_defects_and_gaps_from_bool_array, Defect, DefectGap

# Colour maps: TopoStats' default map for image panels, and a diverging map for values drawn along paths.
colormap = Colormap()
CMAP = colormap.get_cmap()
COLOUR_PATH_CMAP = mpl.cm.coolwarm

# Fixed colour-scale limits (vmin, vmax) passed to imshow for image panels.
VMIN, VMAX = 0, 4

In [None]:
TODAY_DATE = datetime.now().strftime("%Y-%m-%d")
DATE_TO_READ_FROM = "2024-05-21"
# NOTE: absolute, machine-specific paths — adjust these for your own environment.
BASE_DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/extracted-grains-new-names-20250814/")
BASE_SAVE_DIR = Path("/Users/sylvi/topo_data/hariborings/processed_grains/")
EXT_FIG_DATA_SAVE_DIR = Path(
    "/Users/sylvi/topo_data/cas9-paper-our-response-to-reviewers/20250805-quentin-reviewer-response/figure_ext_3_images/high_res_curvature_graphs"
)
# Grains with a pixel size (nm per px) above this are dropped on load as too coarse.
MAX_PX_TO_NM = 10.0

# Single source of truth for the sample types; input/output directories are derived from it.
sample_types = ["ON_SC", "OT1_SC", "OT2_SC"]
data_dirs = {name: BASE_DATA_DIR / f"cas9_{name}/date_{DATE_TO_READ_FROM}" for name in sample_types}
save_dirs = {name: BASE_SAVE_DIR / f"cas9_{name}/{TODAY_DATE}" for name in sample_types}
for sample_type in sample_types:
    save_dirs[sample_type].mkdir(exist_ok=True, parents=True)

# Load each sample type's grain dictionary, tagging every grain with its sample type and index.
grains_dicts = {}
for sample_type in sample_types:
    grains_dicts[sample_type] = {}
    file_path = data_dirs[sample_type] / "ferets_dict.pkl"
    # SECURITY: pickle.load can execute arbitrary code — only load trusted, locally produced files.
    with open(file_path, "rb") as f:
        sample_grains_dicts = pickle.load(f)
    for grain_index, single_grain_dict in sample_grains_dicts.items():
        p_to_nm = single_grain_dict["p_to_nm"]
        single_grain_dict["sample_type"] = sample_type
        single_grain_dict["grain_index"] = grain_index
        if p_to_nm > MAX_PX_TO_NM:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to p_to_nm too large: {p_to_nm} > {MAX_PX_TO_NM}"
            )
        else:
            grains_dicts[sample_type][grain_index] = single_grain_dict


# print num grains for each sample type
for sample_type, grain_dict in grains_dicts.items():
    print(f"Number of grains for sample type [{sample_type}]: {len(grain_dict)}")

In [None]:
def pad_image_to_set_nm_size(
    image: np.ndarray, p_to_nm: float, target_size_nm: float
) -> tuple[npt.NDArray, tuple[int, int, int, int]]:
    """Crop and/or zero-pad a 2D image so it is a square ``target_size_nm`` wide.

    The target side length in pixels is ``int(target_size_nm / p_to_nm)``. Each dimension is
    handled independently:

    * a dimension longer than the target is cropped, keeping the first ``target_size_px``
      rows/columns (top-left anchored, so kept pixels keep their coordinates);
    * a dimension shorter than the target is zero-padded, split as evenly as possible between
      both sides (any odd extra pixel goes to the bottom/right).

    Parameters
    ----------
    image : np.ndarray
        2D image or mask to resize.
    p_to_nm : float
        Pixel size in nm per pixel.
    target_size_nm : float
        Desired side length in nm.

    Returns
    -------
    tuple[npt.NDArray, tuple[int, int, int, int]]
        The resized image and ``(pad_top, pad_bottom, pad_left, pad_right)``. Cropping adds no
        padding, so a kept input pixel at ``(row, col)`` lands at ``(row + pad_top, col + pad_left)``.
        If the input is already exactly the target size it is returned as-is (same object).
    """
    target_size_px = int(target_size_nm / p_to_nm)
    if image.shape[0] == target_size_px and image.shape[1] == target_size_px:
        return image, (0, 0, 0, 0)

    # Crop first, then pad, so an image that is too large in one dimension and too small in the
    # other still ends up square (and the returned padding is always a non-negative 4-tuple).
    cropped_image = image[:target_size_px, :target_size_px]
    pad_height = target_size_px - cropped_image.shape[0]
    pad_width = target_size_px - cropped_image.shape[1]
    if pad_height == 0 and pad_width == 0:
        return cropped_image, (0, 0, 0, 0)

    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left
    padded_image = np.pad(
        cropped_image, ((pad_top, pad_bottom), (pad_left, pad_right)), mode="constant", constant_values=0
    )
    return padded_image, (pad_top, pad_bottom, pad_left, pad_right)

In [None]:
def _padding_to_row_col_offset(padding: tuple[int, ...]) -> npt.NDArray[np.int_]:
    """Return the (row, col) shift mapping unpadded pixel coordinates into the padded image.

    A 4-tuple is read as ``(pad_top, pad_bottom, pad_left, pad_right)``, giving a shift of
    ``(pad_top, pad_left)``. Any other length is treated as no shift (an image that was only
    cropped from its top-left corner keeps its pixel coordinates).
    """
    if len(padding) == 4:
        return np.array([padding[0], padding[2]])
    return np.zeros(2, dtype=int)


def _build_grain_title(
    grain_index, grain_data: dict, grain_p_to_nm: float, stats_to_include_in_title: list[str] | None
) -> str:
    """Build the image-panel title: grain index, pixel size and any requested per-grain stats."""
    title = f"Grain {grain_index}\np_to_nm: {grain_p_to_nm:.2f} nm/px"
    if stats_to_include_in_title is not None:
        for stat in stats_to_include_in_title:
            stat_value = grain_data[stat]
            # np.floating also covers float32 etc., which are not subclasses of Python float.
            if isinstance(stat_value, (float, np.floating)):
                title += f"\n{stat}: {stat_value:.2f}"
            else:
                title += f"\n{stat}: {stat_value}"
        title += "\n"
    return title


def _draw_path_overlays(
    mask_ax,
    grain_data: dict,
    offset: np.ndarray,
    plot_paths: list[str] | None,
    plot_colour_paths: list[tuple[str, str]] | None,
    plot_colour_norm_bounds: tuple[float, float],
) -> None:
    """Overlay red paths and colour-mapped paths on ``mask_ax``, shifted by ``offset`` (row, col).

    Paths that are missing, ``None``, or have two or fewer points are not drawn.
    """
    if plot_paths is not None:
        for path_name in plot_paths:
            path = grain_data.get(path_name)
            if path is None or len(path) <= 2:
                continue
            path = path + offset
            mask_ax.plot(path[:, 1], path[:, 0], color="red", linewidth=2, label=path_name)

    if plot_colour_paths is not None:
        norm_low, norm_high = plot_colour_norm_bounds
        for path_name, colour_array_name in plot_colour_paths:
            path = grain_data.get(path_name)
            if path is None or len(path) <= 2:
                continue
            # Map values in [norm_low, norm_high] onto the colormap's [0, 1] input range.
            normalised_cvals = (np.asarray(grain_data[colour_array_name]) - norm_low) / (norm_high - norm_low)
            path = path + offset
            # One segment per consecutive point pair, coloured by the value at the segment's end point.
            for point_index in range(1, len(path)):
                previous_point, point = path[point_index - 1], path[point_index]
                mask_ax.plot(
                    [previous_point[1], point[1]],
                    [previous_point[0], point[0]],
                    color=COLOUR_PATH_CMAP(normalised_cvals[point_index]),
                    linewidth=2,
                )


def _draw_defect_markers(mask_ax, grain_data: dict, offset: np.ndarray) -> None:
    """Mark each Defect's start (blue) and end (red) point on ``resampled_path_px``, shifted by ``offset``."""
    if "ordered_defect_list" not in grain_data:
        return
    for defect_or_gap in grain_data["ordered_defect_list"].defect_gap_list:
        if not isinstance(defect_or_gap, Defect):
            continue
        path_px = grain_data["resampled_path_px"]
        start_point = path_px[defect_or_gap.start_index] + offset
        end_point = path_px[defect_or_gap.end_index] + offset
        mask_ax.scatter(start_point[1], start_point[0], color="blue", s=100)
        mask_ax.scatter(end_point[1], end_point[0], color="red", s=100)


def _draw_lineplots(
    fig,
    inner,
    grain_data: dict,
    plot_lineplots: list,
    plot_lineplot_hlines: list[tuple[float, str]] | None,
    fontsize: int,
) -> None:
    """Draw one full-width line plot per ``plot_lineplots`` entry in rows 1.. of the inner grid."""
    for lineplot_index, lineplot_spec in enumerate(plot_lineplots):
        x_name, y_name, xlabel, ylabel, lineplot_title, ylim = lineplot_spec
        lineplot_ax = fig.add_subplot(inner[1 + lineplot_index, :])
        lineplot_ax.plot(grain_data[x_name], grain_data[y_name])
        lineplot_ax.set_xlabel(xlabel, fontsize=fontsize)
        lineplot_ax.set_ylabel(ylabel, fontsize=fontsize)
        lineplot_ax.set_title(lineplot_title, fontsize=fontsize, pad=2)
        lineplot_ax.tick_params(labelsize=6)
        if ylim is not None:
            lineplot_ax.set_ylim(ylim)
        if plot_lineplot_hlines is not None:
            # Labels are accepted for compatibility but are not drawn.
            for hline_value, _hline_label in plot_lineplot_hlines:
                lineplot_ax.axhline(y=hline_value, color="grey", linestyle="--", linewidth=0.5)


def plot_all_grains_dictionary(
    grain_dict: dict,
    num_cols: int = 3,
    vmin: float = VMIN,
    vmax: float = VMAX,
    cmap: mpl.colors.Colormap = CMAP,
    fontsize: int = 12,
    image_plot_size_nm: float | None = 40,
    plot_paths: list[str] | None = None,
    plot_defects: bool = False,
    plot_colour_paths: list[tuple[str, str]] | None = None,
    plot_colour_norm_bounds: tuple[float, float] = (-0.3, 0.3),
    plot_lineplots: list[tuple[str, str, str, str, str, tuple[float, float] | None]] | None = None,
    plot_lineplot_hlines: list[tuple[float, str]] | None = None,
    save_path: Path | None = None,
    show_plot: bool = True,
    dpi=200,
    figsize_multiplier: float = 4,
    wspace: float = 0.3,
    hspace: float = 0.6,
    wspace_inner: float = 0.2,
    hspace_inner: float = 0.3,
    num_grains_to_limit_to: int | None = None,
    stats_to_include_in_title: list[str] | None = None,
) -> None:
    """Plot the grains of a grain dictionary in a grid layout.

    Each grain gets a cell with its image and mask side by side. Optional overlays on the mask
    include plain paths, colour-mapped paths and defect markers; optional line plots span the cell
    below them.

    Parameters
    ----------
    grain_dict : dict
        Mapping of grain index -> grain data. Every grain needs ``"image"``, ``"predicted_mask"``
        and ``"p_to_nm"``; other keys are read only when the matching option is used.
    num_cols : int
        Number of grains per grid row.
    vmin, vmax, cmap :
        Colour scaling for the image panel.
    fontsize : int
        Font size for titles and axis labels.
    image_plot_size_nm : float | None
        Crop/pad image and mask to this side length (nm) so grains share a scale; ``None`` disables.
    plot_paths : list[str] | None
        Grain-data keys of (N, 2) row/col paths to draw in red on the mask.
    plot_defects : bool
        Mark defect start/end points from ``"ordered_defect_list"`` on ``"resampled_path_px"``.
    plot_colour_paths : list[tuple[str, str]] | None
        ``(path_key, value_key)`` pairs; each path segment is coloured by its value.
    plot_colour_norm_bounds : tuple[float, float]
        Value range mapped onto the full path colormap.
    plot_lineplots : list | None
        ``(x_key, y_key, xlabel, ylabel, title, ylim_or_None)`` per line plot row.
    plot_lineplot_hlines : list[tuple[float, str]] | None
        ``(y_value, label)`` horizontal reference lines added to every line plot.
    save_path : Path | None
        If given, the figure is saved here (parent directories are created).
    show_plot : bool
        Show the figure; otherwise it is closed after (optional) saving.
    num_grains_to_limit_to : int | None
        Plot at most this many grains (in dict order).
    stats_to_include_in_title : list[str] | None
        Grain-data keys whose values are appended to the image-panel title.
    """
    num_grains = len(grain_dict)
    if num_grains_to_limit_to is not None:
        num_grains = min(num_grains, num_grains_to_limit_to)
    if num_grains <= 0:
        # GridSpec needs at least one row; there is nothing to draw.
        print("No grains to plot")
        return

    num_rows = num_grains // num_cols + (num_grains % num_cols > 0)
    num_inner_rows = 1 if plot_lineplots is None else 1 + len(plot_lineplots)
    num_inner_cols = 2
    fig = plt.figure(figsize=(num_cols * figsize_multiplier, num_rows * figsize_multiplier), dpi=dpi)
    outer = gridspec.GridSpec(nrows=num_rows, ncols=num_cols, wspace=wspace, hspace=hspace)

    for grid_index, (grain_index, grain_data) in enumerate(list(grain_dict.items())[:num_grains]):
        grain_image = grain_data["image"]
        grain_mask = grain_data["predicted_mask"]
        grain_p_to_nm = grain_data["p_to_nm"]

        mask_padding = (0, 0, 0, 0)
        if image_plot_size_nm is not None:
            grain_image, _ = pad_image_to_set_nm_size(grain_image, grain_p_to_nm, image_plot_size_nm)
            grain_mask, mask_padding = pad_image_to_set_nm_size(grain_mask, grain_p_to_nm, image_plot_size_nm)
        # Paths are in the unpadded mask's (row, col) frame: rows shift by pad_top, columns by pad_left.
        path_offset = _padding_to_row_col_offset(mask_padding)

        outer_coords = np.unravel_index(grid_index, (num_rows, num_cols))
        inner = gridspec.GridSpecFromSubplotSpec(
            nrows=num_inner_rows,
            ncols=num_inner_cols,
            subplot_spec=outer[outer_coords],
            wspace=wspace_inner,
            hspace=hspace_inner,
        )

        image_ax = fig.add_subplot(inner[0, 0])
        image_ax.imshow(grain_image, vmin=vmin, vmax=vmax, cmap=cmap)
        image_ax.set_title(
            _build_grain_title(grain_index, grain_data, grain_p_to_nm, stats_to_include_in_title),
            fontsize=fontsize,
            pad=2,
        )
        image_ax.axis("off")

        mask_ax = fig.add_subplot(inner[0, 1])
        mask_ax.imshow(grain_mask)
        _draw_path_overlays(
            mask_ax, grain_data, path_offset, plot_paths, plot_colour_paths, plot_colour_norm_bounds
        )
        mask_ax.set_title("Mask", fontsize=fontsize, pad=2)
        mask_ax.axis("off")

        if plot_defects:
            _draw_defect_markers(mask_ax, grain_data, path_offset)

        if plot_lineplots is not None:
            _draw_lineplots(fig, inner, grain_data, plot_lineplots, plot_lineplot_hlines, fontsize)

    if save_path is not None:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"saving figure to {save_path}")
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    if show_plot:
        plt.show()
    else:
        plt.close(fig)

In [None]:
# Overview of the first 10 grains per sample type (disabled; switch to `if True:` to draw it).
if False:
    for sample_type in grains_dicts:
        grain_dict = grains_dicts[sample_type]
        print(f"Plotting grains for sample type [{sample_type}]")
        plot_all_grains_dictionary(
            grain_dict,
            num_cols=5,
            num_grains_to_limit_to=10,
            wspace=0,
            hspace=0,
            vmin=VMIN,
            vmax=VMAX,
            cmap=CMAP,
            show_plot=True,
        )

In [None]:
# Remove anomalous grains by index, per sample type.
grains_to_remove_on_sc = []
# OT1_SC grains 21 and 22 have the wrong pixel-to-nm scaling; rather than guess the right value, drop them.
grains_to_remove_ot1_sc = [21, 22]
grains_to_remove_ot2_sc = []

# Indices that are not present (e.g. already removed) are silently ignored.
for sample_type, grains_to_remove in {
    "ON_SC": grains_to_remove_on_sc,
    "OT1_SC": grains_to_remove_ot1_sc,
    "OT2_SC": grains_to_remove_ot2_sc,
}.items():
    for grain_index in grains_to_remove:
        if grain_index not in grains_dicts[sample_type]:
            continue
        print(f"Removing grain {grain_index} from sample type [{sample_type}]")
        del grains_dicts[sample_type][grain_index]

# Report the remaining grain counts.
for sample_type in ["ON_SC", "OT1_SC", "OT2_SC"]:
    remaining = len(grains_dicts[sample_type])
    print(f"Number of grains in sample type [{sample_type}]: {remaining}")

In [None]:
# pathfinding
# For each grain, find a minimum-cost route between its first two intersection regions. Cost is
# low where the distance transform is high, so the route prefers pixels far from background.
MIN_ROUTE_POINTS = 5


def _max_distance_pixel(region, distance_transform: np.ndarray) -> tuple:
    """Return the (row, col) of the region pixel with the largest distance-transform value (first on ties)."""
    values = distance_transform[region.coords[:, 0], region.coords[:, 1]]
    row, col = region.coords[np.argmax(values)]
    return (row, col)


grains_dicts_paths = {}
bad_grains = []
for sample_type, grains_dict_sample in grains_dicts.items():
    grains_dicts_paths[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        grain_mask = grain_data["predicted_mask"]

        # Distance to the nearest background pixel; pixels with mask value 2 are zeroed so the route avoids them.
        distance_transform = distance_transform_edt(grain_mask > 0)
        distance_transform[grain_mask == 2] = 0

        intersection_regions = regionprops(label(grain_data["intersection_labels"]))
        if len(intersection_regions) < 2:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to fewer than 2 intersection regions: {len(intersection_regions)}"
            )
            bad_grains.append((sample_type, grain_index))
            continue

        # Start/end at the deepest point (max distance transform) of intersection regions 0 and 1.
        start_point = _max_distance_pixel(intersection_regions[0], distance_transform)
        end_point = _max_distance_pixel(intersection_regions[1], distance_transform)

        # Invert the distance transform into a cost map, then make the most expensive pixels
        # (those with zero distance transform) effectively impassable.
        cost_map = np.max(distance_transform) - distance_transform
        cost_map[cost_map == np.max(cost_map)] = 1e4

        route, _ = route_through_array(array=cost_map, start=start_point, end=end_point)
        route = np.array(route)

        if len(route) < MIN_ROUTE_POINTS:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to route being too short: {len(route)} < {MIN_ROUTE_POINTS}"
            )
            bad_grains.append((sample_type, grain_index))
            continue

        # grain_data is stored and updated in place, so the "path" key is also visible via grains_dicts.
        grains_dicts_paths[sample_type][grain_index] = grain_data
        grain_data["path"] = route

print(f"Bad grains: {bad_grains}")

# Grains with their traced paths, first 10 per sample type (disabled; switch to `if True:` to draw).
if False:
    for sample_type in grains_dicts_paths:
        grain_dict = grains_dicts_paths[sample_type]
        print(f"Plotting grains with paths for sample type [{sample_type}]")
        plot_all_grains_dictionary(
            grain_dict=grain_dict,
            plot_paths=["path"],
            num_grains_to_limit_to=10,
            vmin=VMIN,
            vmax=VMAX,
            cmap=CMAP,
            save_path=save_dirs[sample_type] / f"grains_with_paths_{sample_type}.png",
        )

In [None]:
# resampling the paths
# Smooth each traced pixel path with a rolling window, then resample it so consecutive points are
# trace_resampling_distance_nm apart along the trace.
trace_resampling_distance_nm = 1.0
smoothing_window_size_nm = 3.0
verbose = False

grains_dicts_resampled_paths = {}
for sample_type, grains_dict_sample in grains_dicts_paths.items():
    grains_dicts_resampled_paths[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        p_to_nm = grain_data["p_to_nm"]
        path = grain_data["path"]

        if verbose:
            print(f"grain: {grain_index}")
            print(f"p2nm: {p_to_nm}")
            print(f"path shape: {path.shape}")

        smoothed_path_px = windowTrace.pool_trace_linear(
            pixel_trace=path,
            rolling_window_size=smoothing_window_size_nm,
            pixel_to_nm_scaling=p_to_nm,
        )
        if verbose:
            print(f"smoothed path shape: {smoothed_path_px.shape}")

        # The interval is in pixels: trace_resampling_distance_nm / (nm per px).
        resampled_path_px = resample_points_regular_interval(
            points=smoothed_path_px, interval=trace_resampling_distance_nm / p_to_nm, circular=False
        )
        if verbose:
            print(f"final resampled path shape: {resampled_path_px.shape}")
            print()

        if len(resampled_path_px) <= 2:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to resampled path too short: {len(resampled_path_px)}"
            )
            continue

        # Step length from each point to its predecessor (0 for the first point), plus running totals.
        resampled_distances_px = np.insert(np.linalg.norm(np.diff(resampled_path_px, axis=0), axis=1), 0, 0)
        resampled_distances_nm = resampled_distances_px * p_to_nm
        cumulative_resampled_distances_px = np.cumsum(resampled_distances_px)
        cumulative_resampled_distances_nm = np.cumsum(resampled_distances_nm)

        # grain_data is updated in place (the same dict object is shared with the earlier stage dicts).
        grains_dicts_resampled_paths[sample_type][grain_index] = grain_data
        grain_data.update(
            {
                "resampled_path_px": resampled_path_px,
                "resampled_distances_px": resampled_distances_px,
                "resampled_distances_nm": resampled_distances_nm,
                "cumulative_resampled_distances_px": cumulative_resampled_distances_px,
                "cumulative_resampled_distances_nm": cumulative_resampled_distances_nm,
            }
        )

# Grains with their smoothed + resampled paths, saved without display (disabled; switch to `if True:`).
if False:
    for sample_type in grains_dicts_resampled_paths:
        grain_dict = grains_dicts_resampled_paths[sample_type]
        print(f"Plotting grains with smoothed paths for sample type [{sample_type}]")
        plot_all_grains_dictionary(
            grain_dict=grain_dict,
            plot_paths=["resampled_path_px"],
            show_plot=False,
            vmin=VMIN,
            vmax=VMAX,
            cmap=CMAP,
            save_path=save_dirs[sample_type] / f"grains_with_resampled_paths_{sample_type}.png",
        )

In [None]:
# calculate the feret diameter for each grain
# Ferets are measured on the mask with value-2 pixels removed, then scaled by p_to_nm (nm per px) to nm.
grains_dicts_feret_diameter = {}
for sample_type, grains_dict_sample in grains_dicts_resampled_paths.items():
    grains_dicts_feret_diameter[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        grain_mask = grain_data["predicted_mask"]
        p_to_nm = grain_data["p_to_nm"]

        grain_mask_dna_only = grain_mask.copy()
        grain_mask_dna_only[grain_mask == 2] = 0

        feret_diameters = get_feret_from_mask(mask_im=grain_mask_dna_only)
        min_feret = feret_diameters["min_feret"] * p_to_nm
        max_feret = feret_diameters["max_feret"] * p_to_nm

        # grain_data is updated in place (shared with the earlier stage dicts).
        grains_dicts_feret_diameter[sample_type][grain_index] = grain_data
        grain_data["min_feret"] = min_feret
        grain_data["max_feret"] = max_feret

In [None]:
# calculate curvature of the path

curvature_gaussian_sigma_nm = 1.0
# Resampled points are trace_resampling_distance_nm apart, so the smoothing sigma in points is
# sigma_nm / spacing_nm. Kept as a float: gaussian_filter1d accepts a fractional sigma, whereas
# int() truncation would silently shrink it (e.g. 1.5 -> 1) or reduce it to 0 when sigma_nm < spacing_nm.
curvature_gaussian_sigma_points = curvature_gaussian_sigma_nm / trace_resampling_distance_nm

grains_dicts_curvature = {}
for sample_type, grains_dict_sample in grains_dicts_feret_diameter.items():
    grains_dicts_curvature[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        p_to_nm = grain_data["p_to_nm"]
        resampled_path_px = grain_data["resampled_path_px"]

        # Curvature along the trace (converted from px to nm), then Gaussian-smoothed along the path.
        path_curvatures = discrete_angle_difference_per_nm_linear(trace_nm=resampled_path_px * p_to_nm)
        path_curvatures = gaussian_filter1d(
            path_curvatures,
            sigma=curvature_gaussian_sigma_points,
            mode="nearest",
        )
        abs_path_curvatures = np.abs(path_curvatures)

        # grain_data is updated in place (shared with the earlier stage dicts).
        grains_dicts_curvature[sample_type][grain_index] = grain_data
        grain_data["curvature"] = path_curvatures
        grain_data["abs_curvature"] = abs_path_curvatures


# Grains with their path coloured by curvature and the |curvature| profile underneath (first 10 per sample type).
for sample_type in grains_dicts_curvature:
    grain_dict = grains_dicts_curvature[sample_type]
    print(f"Plotting grains with curvature for sample type [{sample_type}]")
    plot_all_grains_dictionary(
        grain_dict=grain_dict,
        num_cols=5,
        num_grains_to_limit_to=10,
        figsize_multiplier=5,
        vmin=VMIN,
        vmax=VMAX,
        cmap=CMAP,
        plot_colour_paths=[("resampled_path_px", "curvature")],
        plot_lineplots=[
            [
                "cumulative_resampled_distances_nm",  # x values
                "abs_curvature",  # y values
                "distance (nm)",
                "curvature (rad/nm)",
                "Curvature",
                (0.0, 0.5),  # y-axis limits
            ]
        ],
        save_path=save_dirs[sample_type] / f"grains_with_curvature_{sample_type}.png",
        show_plot=True,
    )

In [None]:
# detect defects
# A point belongs to a defect when its absolute curvature exceeds the threshold below; runs of such
# points are grouped into an ordered list of defects and gaps along the (open) trace.
defect_curvature_threshold_radpernm = 0.3

grains_dicts_curvature_defects = {}
for sample_type, grains_dict_sample in grains_dicts_curvature.items():
    grains_dicts_curvature_defects[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        path_curvatures = grain_data["curvature"]
        resampled_distances_nm = grain_data["resampled_distances_nm"]

        defect_mask = np.abs(path_curvatures) > defect_curvature_threshold_radpernm
        ordered_defect_list = get_defects_and_gaps_from_bool_array(
            defects_bool=defect_mask,
            distance_to_previous_points_nm=resampled_distances_nm,
            circular=False,
        )

        # grain_data is updated in place (shared with the earlier stage dicts).
        grains_dicts_curvature_defects[sample_type][grain_index] = grain_data
        grain_data["ordered_defect_list"] = ordered_defect_list

In [None]:
# Grains with curvature-coloured paths, defect start/end markers and threshold lines
# (first 30 per sample type; disabled — switch to `if True:` to draw).
if False:
    for sample_type in grains_dicts_curvature_defects:
        grain_dict = grains_dicts_curvature_defects[sample_type]
        print(f"Plotting grains with defects for sample type [{sample_type}]")
        threshold_lines = [
            (defect_curvature_threshold_radpernm, "upper curvature threshold"),
            (-defect_curvature_threshold_radpernm, "lower curvature threshold"),
        ]
        plot_all_grains_dictionary(
            grain_dict=grain_dict,
            num_cols=5,
            num_grains_to_limit_to=30,
            figsize_multiplier=5,
            vmin=VMIN,
            vmax=VMAX,
            cmap=CMAP,
            plot_paths=["resampled_path_px"],
            plot_colour_paths=[("resampled_path_px", "curvature")],
            plot_defects=True,
            plot_lineplots=[
                [
                    "cumulative_resampled_distances_nm",
                    "abs_curvature",
                    "distance (nm)",
                    "curvature (rad/nm)",
                    "Curvature",
                    (0.0, 0.5),
                ]
            ],
            plot_lineplot_hlines=threshold_lines,
            save_path=save_dirs[sample_type] / f"grains_with_defects_{sample_type}.png",
            show_plot=True,
        )

In [None]:
# plot the positions of the defects for each sample type
# Collect every Defect's position along its trace (nm), per sample type.
all_defect_positions = {}
maximum_defect_position = 0.0
for sample_type, grain_dict in grains_dicts_curvature_defects.items():
    defect_positions_nm = [
        defect_or_gap.position_along_trace_nm
        for grain_data in grain_dict.values()
        for defect_or_gap in grain_data["ordered_defect_list"].defect_gap_list
        if isinstance(defect_or_gap, Defect)
    ]
    maximum_defect_position = max([maximum_defect_position, *defect_positions_nm])
    all_defect_positions[sample_type] = defect_positions_nm
    print(f"Number of defects for sample type [{sample_type}]: {len(defect_positions_nm)}")

# Strip + violin plot, one panel per sample type.
# squeeze=False keeps a 2D axes array even with a single sample type, so indexing always works.
fig_positions, axes_positions = plt.subplots(1, len(sample_types), figsize=(15, 5), sharey=True, squeeze=False)
for i, (sample_type, ax) in enumerate(zip(sample_types, axes_positions[0])):
    defect_positions_nm = all_defect_positions[sample_type]
    if defect_positions_nm:  # nothing to draw for a sample type without defects
        x_labels = [sample_type] * len(defect_positions_nm)
        sns.stripplot(x=x_labels, y=defect_positions_nm, color="grey", ax=ax)
        sns.violinplot(
            x=x_labels,
            y=defect_positions_nm,
            color="lightblue",
            inner=None,
            linewidth=0.5,
            alpha=0.5,
            ax=ax,
        )
    ax.set_title(f"{sample_type}\n({len(defect_positions_nm)} defects)")
    ax.set_xlabel("")
    ax.set_ylabel("Position along trace (nm)" if i == 0 else "")
    # remove x-axis tick labels to clean up appearance
    ax.set_xticklabels([])
fig_positions.suptitle("Defect Positions by Sample Type", fontsize=16)
fig_positions.tight_layout()

# Histograms of defect positions, sharing one x-range across sample types.
fig_histograms, axes_histograms = plt.subplots(1, len(sample_types), figsize=(15, 5), sharey=True, squeeze=False)
for i, (sample_type, ax) in enumerate(zip(sample_types, axes_histograms[0])):
    defect_positions_nm = all_defect_positions[sample_type]
    ax.hist(defect_positions_nm, bins="auto", color="lightblue", edgecolor="black", alpha=0.7)
    ax.set_title(f"{sample_type}\n({len(defect_positions_nm)} defects)")
    ax.set_xlabel("Defect position along trace (nm)")
    ax.set_ylabel("Count" if i == 0 else "")
    # With no defects anywhere the limits would be (0, 0), which matplotlib rejects with a warning.
    if maximum_defect_position > 0:
        ax.set_xlim(0, maximum_defect_position * 1.1)
fig_histograms.tight_layout()

# Save both figures explicitly (plt.savefig would only save whichever figure is current).
combined_save_path = BASE_SAVE_DIR / f"combined_defect_positions_{TODAY_DATE}.png"
combined_save_path.parent.mkdir(exist_ok=True, parents=True)
fig_positions.savefig(
    BASE_SAVE_DIR / f"defect_positions_strip_violin_{TODAY_DATE}.png", dpi=200, bbox_inches="tight"
)
fig_histograms.savefig(combined_save_path, dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# For each sample type, quantify how sharply each trace turns inside a fixed-length window of contour
# (``middle_distance_nm`` nm) centred on the trace's half-way point, and compare it with the rest of the
# trace (the "outer" region: everything before and after that window).
middle_distance_nm = 10
# Windows containing fewer resampled points than this are considered too short to summarise.
MIN_MIDDLE_PATH_POINTS = 5
# Set to True to render per-grain plots of the middle region (writes PNGs into save_dirs).
PLOT_MIDDLE_CURVATURE_GRAINS = False


def get_middle_window_indices(
    cumulative_distances_nm: npt.NDArray, total_distance_nm: float, window_nm: float
) -> tuple[int, int]:
    """Return half-open slice indices ``(start, end)`` of a ``window_nm``-long window centred on the trace midpoint.

    The midpoint is the first cumulative-distance entry at or beyond half of ``total_distance_nm``.
    ``start`` and ``end`` are ``np.searchsorted`` insertion points for midpoint -/+ ``window_nm / 2``,
    so they can be used directly as a slice. An empty trace yields ``(0, 0)``.
    """
    if len(cumulative_distances_nm) == 0:
        return 0, 0
    # Clamp to the last entry: if the total slightly exceeds the final cumulative value (e.g. rounding,
    # or a closing segment not present in the cumulative array) searchsorted would point past the end.
    middle_index = min(
        int(np.searchsorted(cumulative_distances_nm, total_distance_nm / 2.0)),
        len(cumulative_distances_nm) - 1,
    )
    middle_nm = cumulative_distances_nm[middle_index]
    start_index = int(np.searchsorted(cumulative_distances_nm, middle_nm - window_nm / 2))
    end_index = int(np.searchsorted(cumulative_distances_nm, middle_nm + window_nm / 2))
    return start_index, end_index


def summarise_curvatures(curvatures: npt.NDArray) -> tuple[float, float, float]:
    """Return ``(mean, std, sum)`` of ``curvatures``.

    For an empty input this returns ``(nan, nan, 0.0)`` explicitly, which is what numpy would give,
    but without numpy's "mean of empty slice" RuntimeWarning.
    """
    if len(curvatures) == 0:
        return np.nan, np.nan, 0.0
    return float(np.mean(curvatures)), float(np.std(curvatures)), float(np.sum(curvatures))


# (sample_type, grain_index) pairs: grain indices repeat across sample types, so the index alone is ambiguous.
bad_grains_middle_curvature_analysis = []
grains_dicts_middle_curvature_analysis = {}
for sample_type, grains_dict_sample in grains_dicts_curvature_defects.items():
    grains_dicts_middle_curvature_analysis[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        resampled_path_px = grain_data["resampled_path_px"]
        path_abs_curvatures = grain_data["abs_curvature"]
        cumulative_resampled_distances_nm = grain_data["cumulative_resampled_distances_nm"]
        trace_total_distance_nm = np.sum(grain_data["resampled_distances_nm"])

        middle_start_index, middle_end_index = get_middle_window_indices(
            cumulative_resampled_distances_nm, trace_total_distance_nm, middle_distance_nm
        )
        middle_slice = slice(middle_start_index, middle_end_index)

        middle_path_px = resampled_path_px[middle_slice]
        if len(middle_path_px) < MIN_MIDDLE_PATH_POINTS:
            bad_grains_middle_curvature_analysis.append((sample_type, grain_index))
            continue

        middle_path_curvatures = path_abs_curvatures[middle_slice]
        # Outer region = everything before and after the middle window.
        outer_path_curvatures = np.concatenate(
            (path_abs_curvatures[:middle_start_index], path_abs_curvatures[middle_end_index:])
        )

        middle_mean_curvature, middle_std_curvature, middle_sum_curvature = summarise_curvatures(
            middle_path_curvatures
        )
        outer_mean_curvature, outer_std_curvature, outer_sum_curvature = summarise_curvatures(
            outer_path_curvatures
        )

        # Shallow copy of the grain record extended with the middle/outer summaries.
        grains_dicts_middle_curvature_analysis[sample_type][grain_index] = {
            **grain_data,
            "middle_path_mean_curvature": middle_mean_curvature,
            "middle_path_std_curvature": middle_std_curvature,
            "middle_path_px": middle_path_px,
            "middle_path_curvatures": middle_path_curvatures,
            "middle_cumulative_resampled_distances_nm": cumulative_resampled_distances_nm[middle_slice],
            "middle_path_sum_curvature": middle_sum_curvature,
            "outer_path_mean_curvature": outer_mean_curvature,
            "outer_path_std_curvature": outer_std_curvature,
            "outer_path_sum_curvature": outer_sum_curvature,
        }

print(f"Bad grains for middle curvature analysis: {bad_grains_middle_curvature_analysis}")

if PLOT_MIDDLE_CURVATURE_GRAINS:
    # plot the middle curvature analysis
    for sample_type, grain_dict in grains_dicts_middle_curvature_analysis.items():
        plot_all_grains_dictionary(
            grain_dict=grain_dict,
            vmin=VMIN,
            vmax=VMAX,
            cmap=CMAP,
            plot_paths=["middle_path_px"],
            plot_colour_paths=[("middle_path_px", "middle_path_curvatures")],
            plot_lineplots=[
                [
                    "middle_cumulative_resampled_distances_nm",
                    "middle_path_curvatures",
                    "distance (nm)",
                    "curvature (rad/nm)",
                    "Curvature",
                    (0.0, 0.5),
                ]
            ],
            # NOTE(review): the lower (negative) threshold lies outside the (0, 0.5) y-range of this
            # absolute-curvature plot, so it will not be visible.
            plot_lineplot_hlines=[
                (defect_curvature_threshold_radpernm, "upper curvature threshold"),
                (-defect_curvature_threshold_radpernm, "lower curvature threshold"),
            ],
            save_path=save_dirs[sample_type] / f"grains_middle_path_curvature_{sample_type}.png",
            show_plot=True,
            figsize_multiplier=5,
            num_grains_to_limit_to=30,
            num_cols=5,
        )

In [None]:
# Stats tests functions
# test for normality to decide between parametric (ANOVA) vs non-parametric (Kruskal-Wallis) tests
from scipy.stats import shapiro, levene, kruskal, f_oneway, ttest_ind
import scipy.stats as scipy_stats
from pingouin import welch_anova


def test_normality_and_homogeneity(data_df: pd.DataFrame, column: str, sample_differentiator_column: str):
    unique_samples = data_df[sample_differentiator_column].unique()
    print(f"Normality and Homogeneity Tests for {column} for {unique_samples}")

    # test normality for each column using shaprio-wilk test
    normality_results = {}
    for sample_type in unique_samples:
        sample_data = data_df[data_df[sample_differentiator_column] == sample_type][column].values
        # shapiro-wilk test for n<5000
        shapiro_stat, shapiro_p_value = shapiro(sample_data)
        is_normal = shapiro_p_value > 0.05
        normality_results[sample_type] = is_normal
        print(f"{sample_type}: n={len(sample_data)}, Shapiro-Wilk p={shapiro_p_value:.4f}, Normal: {is_normal}")

    # test homogeneity of variances using Levene's test
    levene_stat, levene_p_value = levene(
        *[
            data_df[data_df[sample_differentiator_column] == sample_type][column].values
            for sample_type in unique_samples
        ]
    )
    is_homogeneous = levene_p_value > 0.05
    print(f"Levene's test: p={levene_p_value:.4f}, Homogeneous: {is_homogeneous}")

    return normality_results, is_homogeneous


def decide_test(normality_results, is_homogeneous):
    """Choose an omnibus test from the assumption-check results.

    Any non-normal group means the rank-based Kruskal-Wallis test is used. Otherwise classic ANOVA is
    used when variances are homogeneous, and Welch's ANOVA when they are not.
    """
    every_group_normal = all(normality_results.values())
    if not every_group_normal:
        return "kruskal_wallis"
    return "anova" if is_homogeneous else "welch_anova"


def qq_plots_visual_normality_test(data_df: pd.DataFrame, column: str, sample_differentiator_column: str):
    """Show a normal Q-Q plot of ``column`` for each group, side by side, for a visual normality check.

    Points near the reference line suggest the group is approximately normal.

    Returns
    -------
    matplotlib.figure.Figure | None
        The figure, or ``None`` when there are no groups to plot.
    """
    unique_samples = data_df[sample_differentiator_column].unique()
    if len(unique_samples) == 0:
        print(f"No groups found in [{sample_differentiator_column}]; skipping Q-Q plots for {column}")
        return None

    # squeeze=False always gives a 2-D axes array, so a single group (1x1 grid) is indexable too.
    fig, axes = plt.subplots(1, len(unique_samples), figsize=(15, 5), sharey=True, squeeze=False)
    for ax, sample_type in zip(axes[0], unique_samples):
        scipy_stats.probplot(
            data_df.loc[data_df[sample_differentiator_column] == sample_type, column],
            dist="norm",
            plot=ax,
        )
        ax.set_title(f"Q-Q Plot: {sample_type}")
        ax.grid(True, alpha=0.3)
    fig.suptitle(f"Q-Q Plots for {column} (should be linear if normal)", fontsize=16)
    fig.tight_layout()
    plt.show()
    return fig


def run_tests(
    recommended_test: str,
    data_df: pd.DataFrame,
    column: str,
    sample_differentiator_column: str,
    alpha: float = 0.05,
):
    """Run the named omnibus test on ``column`` across groups and print the outcome.

    Parameters
    ----------
    recommended_test : str
        One of ``"anova"``, ``"welch_anova"`` or ``"kruskal_wallis"`` (as returned by ``decide_test``).
    data_df : pd.DataFrame
        Long-format data, one row per observation.
    column : str
        Numeric column to compare.
    sample_differentiator_column : str
        Column whose unique values define the groups.
    alpha : float
        Significance level used for the printed verdict.

    Returns
    -------
    The test result: a scipy result object for ANOVA / Kruskal-Wallis, or the pingouin
    ``welch_anova`` DataFrame for Welch's ANOVA.

    Raises
    ------
    ValueError
        If ``recommended_test`` is not one of the supported names. An unsupported name used to be
        silently ignored.
    """
    if recommended_test not in ("anova", "welch_anova", "kruskal_wallis"):
        raise ValueError(
            f"Unknown test {recommended_test!r}; expected 'anova', 'welch_anova' or 'kruskal_wallis'."
        )

    def _group_values() -> list:
        # Values of `column` for each group, in order of first appearance.
        return [
            data_df.loc[data_df[sample_differentiator_column] == sample_type, column].values
            for sample_type in data_df[sample_differentiator_column].unique()
        ]

    if recommended_test == "anova":
        print("Running ANOVA test")
        anova_result = f_oneway(*_group_values())
        print(f"ANOVA result: F={anova_result.statistic:.4f}, p={anova_result.pvalue:.4f}")
        if anova_result.pvalue < alpha:
            print(
                "Significant differences found, can proceed with a post-hoc test like Tukey HSD (used for when assumptions of ANOVA are met)"
            )
        else:
            print("No significant differences found")
        return anova_result

    if recommended_test == "welch_anova":
        print("Running Welch's ANOVA test")
        welch_result = welch_anova(dv=column, between=sample_differentiator_column, data=data_df)
        print(f"Welch's ANOVA result: F={welch_result['F'].values[0]:.4f}, p={welch_result['pval'].values[0]:.4f}")
        if welch_result["pval"].values[0] < alpha:
            print(
                "Significant differences found, can proceed with a post-hoc test like Games-Howell (used for when assumptions of ANOVA are not met)"
            )
        else:
            print("No significant differences found")
        return welch_result

    print("Running Kruskal-Wallis test")
    kruskal_result = kruskal(*_group_values())
    print(f"Kruskal-Wallis result: H={kruskal_result.statistic:.4f}, p={kruskal_result.pvalue:.4f}")
    if kruskal_result.pvalue < alpha:
        print(
            "Significant differences found, can proceed with a post-hoc test like Dunn's test (used for non-parametric tests)"
        )
    else:
        print("No significant differences found")
    return kruskal_result


def decide_and_run_stats_tests(data_df: pd.DataFrame, column: str, sample_differentiator_column: str):
    """Full omnibus pipeline for one column.

    The steps are: check assumptions, pick a test, show Q-Q plots, then run the chosen test. Every
    step prints its own results; nothing is returned.
    """
    print(f"\n\n=== running stats tests for column [{column}]")

    shared_arguments = dict(
        data_df=data_df,
        column=column,
        sample_differentiator_column=sample_differentiator_column,
    )

    normality_results, is_homogeneous = test_normality_and_homogeneity(**shared_arguments)
    recommended_test = decide_test(normality_results, is_homogeneous)
    print(f"Recommended test: {recommended_test}")

    qq_plots_visual_normality_test(**shared_arguments)
    run_tests(recommended_test=recommended_test, **shared_arguments)


def t_test_between_two_samples(
    data_df: pd.DataFrame,
    sample_type_1: str,
    sample_type_2: str,
    column: str,
    sample_differentiator_column: str = "sample_type",
    alpha: float = 0.05,
):
    """Welch's two-sample t-test (unequal variances) on ``column`` between two groups; prints the result.

    Parameters
    ----------
    data_df : pd.DataFrame
        Long-format data, one row per observation.
    sample_type_1, sample_type_2 : str
        Group labels to compare.
    column : str
        Numeric column to compare.
    sample_differentiator_column : str
        Column holding the group labels (defaults to ``"sample_type"``, as before).
    alpha : float
        Significance level used for the printed verdict.

    Returns
    -------
    tuple[float, float]
        ``(t_statistic, p_value)``. Both are NaN if either group has fewer than 2 values.
    """
    group_labels = data_df[sample_differentiator_column]
    data1 = data_df.loc[group_labels == sample_type_1, column]
    data2 = data_df.loc[group_labels == sample_type_2, column]
    if len(data1) < 2 or len(data2) < 2:
        print(
            f"\n\nT-test between [{sample_type_1}] and [{sample_type_2}] for [{column}]: "
            f"not enough data (n={len(data1)}, n={len(data2)})"
        )
        return float("nan"), float("nan")

    t_stat, p_value = ttest_ind(data1, data2, equal_var=False)
    print(
        f"\n\nT-test between [{sample_type_1}] and [{sample_type_2}] for [{column}]: T={t_stat:.4f}, p={p_value:.4f}"
    )
    if p_value < alpha:
        print("Significant differences found")
    else:
        print("No significant differences found")
    return t_stat, p_value


def mann_whitney_u_test_between_two_samples(
    data_df: pd.DataFrame,
    sample_type_1: str,
    sample_type_2: str,
    column: str,
    sample_differentiator_column: str = "sample_type",
    alpha: float = 0.05,
):
    """Two-sided Mann-Whitney U test on ``column`` between two groups; prints the result.

    Parameters
    ----------
    data_df : pd.DataFrame
        Long-format data, one row per observation.
    sample_type_1, sample_type_2 : str
        Group labels to compare.
    column : str
        Numeric column to compare.
    sample_differentiator_column : str
        Column holding the group labels (defaults to ``"sample_type"``, as before).
    alpha : float
        Significance level used for the printed verdict.

    Returns
    -------
    tuple[float, float]
        ``(U_statistic, p_value)``, where U is for ``sample_type_1``. Both are NaN if either group is
        empty, because scipy raises on empty input.
    """
    group_labels = data_df[sample_differentiator_column]
    data1 = data_df.loc[group_labels == sample_type_1, column]
    data2 = data_df.loc[group_labels == sample_type_2, column]
    if len(data1) == 0 or len(data2) == 0:
        print(
            f"\n\nMann-Whitney U test between [{sample_type_1}] and [{sample_type_2}] for [{column}]: "
            f"not enough data (n={len(data1)}, n={len(data2)})"
        )
        return float("nan"), float("nan")

    u_stat, p_value = scipy_stats.mannwhitneyu(data1, data2, alternative="two-sided")
    print(
        f"\n\nMann-Whitney U test between [{sample_type_1}] and [{sample_type_2}] for [{column}]: U={u_stat:.4f}, p={p_value:.4f}"
    )
    if p_value < alpha:
        print("Significant differences found")
    else:
        print("No significant differences found")
    return u_stat, p_value

In [None]:
# Flatten the per-grain curvature summaries into two tidy tables (one row per grain): one for the
# middle window of each trace and one for the outer remainder. These feed the plots and tests below.
all_middle_curvature_stats = [
    {
        "sample_type": sample_type,
        "grain_index": grain_index,
        "mean_curvature": grain_data["middle_path_mean_curvature"],
        "std_curvature": grain_data["middle_path_std_curvature"],
        "sum_curvature": grain_data["middle_path_sum_curvature"],
    }
    for sample_type, grain_dict in grains_dicts_middle_curvature_analysis.items()
    for grain_index, grain_data in grain_dict.items()
]
all_outer_curvature_stats = [
    {
        "sample_type": sample_type,
        "grain_index": grain_index,
        "mean_curvature": grain_data["outer_path_mean_curvature"],
        "std_curvature": grain_data["outer_path_std_curvature"],
        "sum_curvature": grain_data["outer_path_sum_curvature"],
    }
    for sample_type, grain_dict in grains_dicts_middle_curvature_analysis.items()
    for grain_index, grain_data in grain_dict.items()
]

df_all_middle_curvature_stats = pd.DataFrame(all_middle_curvature_stats)
df_all_outer_curvature_stats = pd.DataFrame(all_outer_curvature_stats)

# Distribution of each per-grain curvature statistic by sample type, for the middle and outer regions.
CURVATURE_STATS_TO_PLOT = ("mean", "std", "sum")


def plot_curvature_stat_violin(stats_df: pd.DataFrame, stat: str, region: str) -> None:
    """Overlay a strip plot and a translucent violin plot of ``{stat}_curvature`` per sample type.

    Parameters
    ----------
    stats_df : pd.DataFrame
        One row per grain with a ``sample_type`` column and a ``{stat}_curvature`` column.
    stat : str
        Statistic name: ``"mean"``, ``"std"`` or ``"sum"``.
    region : str
        Trace region the statistic was computed on (``"middle"`` or ``"outer"``); used in the labels.
    """
    column = f"{stat}_curvature"
    fig, ax = plt.subplots()
    sns.stripplot(ax=ax, x="sample_type", y=column, data=stats_df, color="grey")
    sns.violinplot(
        ax=ax,
        x="sample_type",
        y=column,
        data=stats_df,
        color="lightblue",
        inner=None,
        linewidth=0.5,
        alpha=0.5,
    )
    ax.set_title(f"{region.capitalize()} {stat} curvature by sample type")
    ax.set_ylabel(f"{region.capitalize()} {stat} curvature (rad/nm)")
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    plt.show()


for region, curvature_stats_df in (
    ("middle", df_all_middle_curvature_stats),
    ("outer", df_all_outer_curvature_stats),
):
    print(f"\n\n===== {region} curvatures plot =====")
    for stat in CURVATURE_STATS_TO_PLOT:
        plot_curvature_stat_violin(curvature_stats_df, stat, region)

# Omnibus tests (assumption checks -> ANOVA / Welch's ANOVA / Kruskal-Wallis) for every curvature
# statistic in both regions, then two-sample tests between every pair of sample types.
# Each pair is listed exactly once, so no comparison is duplicated or missed.
# NOTE: the pairwise p-values are NOT corrected for multiple comparisons; apply e.g. a Holm or
# Bonferroni correction before reporting them.
CURVATURE_STAT_COLUMNS = ["sum_curvature", "mean_curvature", "std_curvature"]
PAIRWISE_STAT_COLUMNS = ["mean_curvature", "sum_curvature"]
PAIRWISE_SAMPLE_COMPARISONS = [("OT1_SC", "OT2_SC"), ("ON_SC", "OT1_SC"), ("ON_SC", "OT2_SC")]

for region, curvature_stats_df in (
    ("middle", df_all_middle_curvature_stats),
    ("outer", df_all_outer_curvature_stats),
):
    print(f"\n\n===== {region} curvatures stats =====")

    for column in CURVATURE_STAT_COLUMNS:
        decide_and_run_stats_tests(
            data_df=curvature_stats_df, column=column, sample_differentiator_column="sample_type"
        )

    for column in PAIRWISE_STAT_COLUMNS:
        for sample_type_1, sample_type_2 in PAIRWISE_SAMPLE_COMPARISONS:
            t_test_between_two_samples(
                data_df=curvature_stats_df, sample_type_1=sample_type_1, sample_type_2=sample_type_2, column=column
            )
            mann_whitney_u_test_between_two_samples(
                data_df=curvature_stats_df, sample_type_1=sample_type_1, sample_type_2=sample_type_2, column=column
            )

In [None]:
# Show representative molecules with high and low mean curvature in the middle region:
# find the grain indexes of the most/least sharply turning molecules for OT2 and OT1.


def get_mols_with_high_low_stat(
    grains_dict: dict[int, dict], stat_name: str, number: int, highest_lowest: str
) -> dict[int, dict]:
    """Get the ``number`` molecules with the highest or lowest value of ``stat_name``.

    Parameters
    ----------
    grains_dict : dict[int, dict]
        Grain index -> grain record. Each record must contain ``stat_name``.
    stat_name : str
        Key of the scalar statistic to rank by.
    number : int
        Maximum number of molecules to return.
    highest_lowest : str
        ``"highest"`` or ``"lowest"``.

    Returns
    -------
    dict[int, dict]
        Subset of ``grains_dict`` (the same record objects), ordered from most to least extreme.
        Records whose statistic is NaN are skipped.

    Raises
    ------
    ValueError
        If ``highest_lowest`` is neither ``"highest"`` nor ``"lowest"``.
    """
    if highest_lowest not in ("highest", "lowest"):
        raise ValueError("Invalid value for highest_lowest. Use 'highest' or 'lowest'.")

    # NaN compares False against everything, which makes `sorted` order arbitrary, so drop NaN stats first.
    grain_stats = [
        (grain_index, grain_data[stat_name])
        for grain_index, grain_data in grains_dict.items()
        if not np.isnan(grain_data[stat_name])
    ]
    grain_stats_sorted = sorted(grain_stats, key=lambda item: item[1], reverse=highest_lowest == "highest")

    return {grain_index: grains_dict[grain_index] for grain_index, _ in grain_stats_sorted[:number]}


# For OT2 and OT1, plot the NUM_EXAMPLE_MOLECULES molecules with the highest and the lowest middle
# mean curvature. Each plot shows the whole-trace curvature colouring plus a curvature-vs-distance line plot.
NUM_EXAMPLE_MOLECULES = 5
WHOLE_TRACE_CURVATURE_LINEPLOT = [
    "cumulative_resampled_distances_nm",
    "abs_curvature",
    "distance (nm)",
    "curvature (rad/nm)",
    "Curvature",
    (0.0, 0.5),
]

extreme_middle_curvature_molecules = {}
for sample_type in ("OT2_SC", "OT1_SC"):
    for highest_lowest in ("highest", "lowest"):
        molecules = get_mols_with_high_low_stat(
            grains_dict=grains_dicts_middle_curvature_analysis[sample_type],
            stat_name="middle_path_mean_curvature",
            number=NUM_EXAMPLE_MOLECULES,
            highest_lowest=highest_lowest,
        )
        extreme_middle_curvature_molecules[(sample_type, highest_lowest)] = molecules
        plot_all_grains_dictionary(
            grain_dict=molecules,
            # NOTE(review): colours the path by "curvature" while the line plot shows "abs_curvature";
            # confirm this is intended.
            plot_colour_paths=[("resampled_path_px", "curvature")],
            plot_lineplots=[list(WHOLE_TRACE_CURVATURE_LINEPLOT)],
            stats_to_include_in_title=[
                "middle_path_mean_curvature",
            ],
        )

# Per-case names kept for backward compatibility with any later cells that reference them.
highest_middle_curvature_mean_ot2_molecules = extreme_middle_curvature_molecules[("OT2_SC", "highest")]
lowest_middle_curvature_mean_ot2_molecules = extreme_middle_curvature_molecules[("OT2_SC", "lowest")]
highest_middle_curvature_mean_ot1_molecules = extreme_middle_curvature_molecules[("OT1_SC", "highest")]
lowest_middle_curvature_mean_ot1_molecules = extreme_middle_curvature_molecules[("OT1_SC", "lowest")]

### Extended figure: central (middle) curvature

- Add what we currently have in the slides to the extended figure as a picture:
	- Two AFM images of molecules with sharp turns (OT2 grain 34, OT1 grain 52; **TODO:** the code below plots OT1 grain 32, so confirm which index is intended). Add scale bars, and a dotted line along the DNA to mark the middle region.
	- Show only the OT1 and OT2 violin plots, labelled "OT1" and "OT2".
	- Report the number of molecules (N) per sample in the caption.
	- Add a scale bar, or state the image width.
- Call the metric "central mean curvature".

In [None]:
# Example molecules with a sharp central turn for the extended figure: OT2 grain 34 and OT1 grain 32.
# NOTE(review): the planning notes above mention OT1 grain 52; confirm which index is intended.
ot2_grain, ot1_grain = (
    grains_dicts_middle_curvature_analysis[example_sample_type][example_grain_index]
    for example_sample_type, example_grain_index in (("OT2_SC", 34), ("OT1_SC", 32))
)

print(ot2_grain.keys())

# Shared styling for the extended-figure panels (dpi is also read by plot_grain_for_ext_fig).
figsize = (4, 2.5)
axlabel_font_size = 14
dpi = 500
legend_font_size = 8


def plot_grain_for_ext_fig(
    grain_dict: dict,
    coloured_lines_norm_bounds: tuple[float, float] = (0.0, 0.3),
) -> None:
    """Render one grain for the extended figure and save it as a PNG.

    The figure shows the AFM image, the middle path coloured by absolute curvature, and a 20 nm scale bar.

    Parameters
    ----------
    grain_dict : dict
        A grain record. It must contain:
        - ``"image"`` and ``"p_to_nm"``;
        - ``"middle_path_px"``: an (N, 2) array of (row, col) points;
        - ``"middle_path_curvatures"``;
        - ``"sample_type"`` and ``"grain_index"``, used in the filename.
    coloured_lines_norm_bounds : tuple[float, float]
        Curvature values (rad/nm) mapped to the light and dark ends of the ``Blues`` colormap.
        Values outside the range are clipped to the colormap ends.

    Saves to ``EXT_FIG_DATA_SAVE_DIR/grain_central_curvature_plot_<sample_type>_<grain_index>.png``
    at the module-level ``dpi``, creating the directory if needed.
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    image = grain_dict["image"]
    middle_path_px = grain_dict["middle_path_px"]
    middle_path_curvatures = np.asarray(grain_dict["middle_path_curvatures"], dtype=float)
    p_to_nm = grain_dict["p_to_nm"]

    ax.imshow(image, cmap=CMAP, vmin=VMIN, vmax=VMAX)

    # Wide light-grey underlay so the thinner curvature-coloured line stands out from the image.
    ax.plot(middle_path_px[:, 1], middle_path_px[:, 0], color="lightgrey", linewidth=15, label="Middle Path")

    norm_low, norm_high = coloured_lines_norm_bounds
    normalised_curvature = (middle_path_curvatures - norm_low) / (norm_high - norm_low)
    # One segment per consecutive pair of points, coloured by the curvature at the segment's end point.
    for point_index in range(1, len(middle_path_px)):
        previous_point = middle_path_px[point_index - 1]
        point = middle_path_px[point_index]
        ax.plot(
            [previous_point[1], point[1]],
            [previous_point[0], point[0]],
            color=mpl.cm.Blues(normalised_curvature[point_index]),
            linewidth=5,
        )

    # White 20 nm scale bar inset 8% from the bottom-right corner.
    # Assumes p_to_nm is nanometres per pixel; confirm against the grain-extraction code.
    scale_bar_length_nm = 20
    scale_bar_length_px = scale_bar_length_nm / p_to_nm
    scale_bar_margin_fraction = 0.08
    scale_bar_r = image.shape[1] * (1 - scale_bar_margin_fraction)
    scale_bar_y = image.shape[0] * (1 - scale_bar_margin_fraction)
    scale_bar_l = scale_bar_r - scale_bar_length_px
    ax.plot(
        [scale_bar_l, scale_bar_r],
        [scale_bar_y, scale_bar_y],
        color="white",
        linewidth=15,
        label=f"Scale Bar: {scale_bar_length_nm} nm",
    )

    ax.axis("off")
    fig.tight_layout(pad=0)
    # The output directory may not exist yet on a fresh machine.
    EXT_FIG_DATA_SAVE_DIR.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        EXT_FIG_DATA_SAVE_DIR
        / f"grain_central_curvature_plot_{grain_dict['sample_type']}_{grain_dict['grain_index']}.png",
        dpi=dpi,
        bbox_inches="tight",
        pad_inches=0,
    )
    plt.show()


# Render and save the two example molecules (OT2 first, then OT1).
for example_grain in (ot2_grain, ot1_grain):
    plot_grain_for_ext_fig(example_grain)

# Violin + strip plot of central (middle) mean curvature for OT1 and OT2 only, for the extended figure.
# An explicit category order pins each tick label to its group. Otherwise the order follows first
# appearance in the data frame, and the hard-coded labels could end up on the wrong violin.
OT1_OT2_ORDER = ["OT1_SC", "OT2_SC"]
OT1_OT2_TICK_LABELS = ["ot1 sc", "ot2 sc"]

df_middle_curvature_stats_ot1_ot2 = df_all_middle_curvature_stats[
    df_all_middle_curvature_stats["sample_type"].isin(OT1_OT2_ORDER)
]
fig, ax = plt.subplots(figsize=figsize)
sns.stripplot(
    ax=ax,
    x="sample_type",
    y="mean_curvature",
    data=df_middle_curvature_stats_ot1_ot2,
    order=OT1_OT2_ORDER,
    color="grey",
    s=4,
)
sns.violinplot(
    ax=ax,
    x="sample_type",
    y="mean_curvature",
    data=df_middle_curvature_stats_ot1_ot2,
    order=OT1_OT2_ORDER,
    hue="sample_type",
    palette={"OT1_SC": "#3C7FE6", "OT2_SC": "#0C5ACE"},
    inner=None,
    alpha=0.5,
)
ax.set_ylabel("Central mean\ncurvature (rad/nm)", fontsize=axlabel_font_size)
# Set tick positions, labels and font size in one place so the label size is not lost when relabelling.
ax.set_xticks(range(len(OT1_OT2_ORDER)))
ax.set_xticklabels(OT1_OT2_TICK_LABELS, fontsize=axlabel_font_size)
ax.tick_params(axis="y", labelsize=axlabel_font_size)
ax.set_xlabel("")
fig.tight_layout()
EXT_FIG_DATA_SAVE_DIR.mkdir(parents=True, exist_ok=True)
fig.savefig(EXT_FIG_DATA_SAVE_DIR / "ot1_ot2_central_mean_curvature_violin_plot.png", dpi=dpi)
plt.show()

# Number of molecules per sample (N), for the figure caption.
samples = df_middle_curvature_stats_ot1_ot2["sample_type"].value_counts()
print(samples)