In [None]:
import pickle
from pathlib import Path
import re

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from AFMReader.topostats import load_topostats

from topostats.tracing.splining import resample_points_regular_interval
from topostats.measure.curvature import discrete_angle_difference_per_nm_circular

import numpy.typing as npt
import numpy as np
from topostats.io import LoadScans
import matplotlib.pyplot as plt
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable


def plot_line_coloured(
    intensity_ys: npt.NDArray[np.float64],
    intensity_xs: npt.NDArray[np.float64] | None = None,
    image: npt.NDArray[np.float64] | None = None,
    xy_coords: npt.NDArray[np.float64] | None = None,
    figsize: tuple[int, int] = (20, 8),
):
    """
    Plot a line plot with the colour of each point determined by its position in the array.

    Parameters
    ----------
    intensity_ys : npt.NDArray[np.float64]
        The values to plot.
    intensity_xs : npt.NDArray[np.float64], optional
        The x-coordinates of the points. If None, the indices of ``intensity_ys`` are used.
    image : npt.NDArray[np.float64], optional
        An image to plot the line on top of. If None, no image is plotted and only the line plot is drawn.
    xy_coords : npt.NDArray[np.float64], optional
        The coordinates of the points, for plotting the trace on top of the image. Required when ``image`` is
        given. Column 0 is drawn on the image's vertical axis and column 1 on its horizontal axis.
    figsize : tuple[int, int], optional
        Figure size passed to ``plt.subplots``.

    Raises
    ------
    AssertionError
        If ``intensity_xs`` or ``xy_coords`` do not match ``intensity_ys`` in length, or if ``image`` is given
        without ``xy_coords``.
    """
    num_points = len(intensity_ys)

    # "cool" runs from cyan (start of the array) to magenta (end of the array).
    cmap = plt.get_cmap("cool")
    # Keep vmax >= 1 so that empty or single-point input does not give a degenerate/invalid colour scale.
    norm = Normalize(vmin=0, vmax=max(num_points - 1, 1))
    sm = ScalarMappable(norm=norm, cmap=cmap)

    # optionally create x coords or use the provided positions
    if intensity_xs is not None:
        assert len(intensity_xs) == len(intensity_ys), "intensity_xs and intensity_ys must have the same length"
        x = intensity_xs
    else:
        x = np.arange(num_points)

    if image is not None:
        assert xy_coords is not None, "xy_coords must be provided if image is supplied"
        assert intensity_ys.shape[0] == xy_coords.shape[0], "intensity_ys and xy_coords must have the same length"
        fig, ax = plt.subplots(1, 2, figsize=figsize)
        ax[0].imshow(image, cmap="afmhot", interpolation="none")

        # One scatter call for all points; the colour of point i is cmap(norm(i)), as in the line plot.
        ax[0].scatter(
            xy_coords[:, 1],
            xy_coords[:, 0],
            c=np.arange(num_points),
            cmap=cmap,
            norm=norm,
            s=10,
        )
        ax[0].set_title("Image with Line Overlay")

        lineplot_ax = ax[1]
    else:
        fig, ax = plt.subplots(figsize=figsize)
        lineplot_ax = ax

    # Plot each segment separately so it can take the colour of its starting index
    for i in range(num_points - 1):
        lineplot_ax.plot(x[i : i + 2], intensity_ys[i : i + 2], color=sm.to_rgba(i), linewidth=2)

    # Add colorbar for reference
    plt.colorbar(sm, ax=ax, label="Position in Array")
    lineplot_ax.set_xlabel("Index")
    lineplot_ax.set_ylabel("Value")
    lineplot_ax.set_title("Line Plot with Colour Gradient")
    plt.show()


# Demo input for plot_line_coloured: 100 samples drawn uniformly from [0, 1).
# NOTE(review): no seed is set, so the values differ on every run.
random_values = np.random.rand(100)
# plot_line_coloured(random_values)


def defect_stats(
    curvature_defects_indexes: npt.NDArray[np.int_],
    height_defects_indexes: npt.NDArray[np.int_],
) -> dict:
    """
    Calculate defect stats.

    NOTE(review): this is an unimplemented placeholder. The body is only ``pass``, so the function currently
    returns ``None`` rather than the annotated ``dict``, and neither argument is used.

    Parameters
    ----------
    curvature_defects_indexes : npt.NDArray[np.int_]
        Presumably the indexes of points flagged as curvature defects -- TODO confirm once implemented.
    height_defects_indexes : npt.NDArray[np.int_]
        Presumably the indexes of points flagged as height defects -- TODO confirm once implemented.
    """

    # TODO: create a bool array for the defects
    pass

In [None]:
# NOTE(review): absolute, machine-specific path -- point this at the local data directory before running.
base_dir = Path("/Users/sylvi/topo_data/picoz")
assert base_dir.exists(), f"data directory not found: {base_dir}"

processed_dir = base_dir / "output_trained_on"
assert processed_dir.exists(), f"processed output directory not found: {processed_dir}"

postprocess_dir = base_dir / "postprocess"
postprocess_dir.mkdir(exist_ok=True)

# Find the processed .topostats files. glob() yields paths in a filesystem-dependent order, so sort them to make
# the file order (and anything indexed from it downstream) reproducible between runs and machines.
topo_files = sorted(processed_dir.glob("**/*.topostats"))
print(f"found {len(topo_files)} topostats files")

# load the corresponding stats csv
all_stats_csv = processed_dir / "all_statistics.csv"
assert all_stats_csv.exists(), f"statistics csv not found: {all_stats_csv}"
all_stats_df = pd.read_csv(all_stats_csv)
print(all_stats_df.columns)

# Rescale the contour length by 1e9 (presumably metres -> nanometres; confirm against the csv's units).
# The frame is re-read from disk just above, so re-running this cell does not rescale twice.
all_stats_df["total_contour_length"] /= 1e-9


# all_data = {}
# for topo_file in topo_files:
#     file_data = load_topostats(topo_file)
#     print(file_data.keys())

#     try:
#         all_data[topo_file.name] = {
#             "image": file_data["image"],
#             "grain_tensors": file_data["grain_tensors"],
#             "p2nm": file_data["pixel_to_nm_scaling"],
#             "curvature_stats": file_data["grain_curvature_stats"],
#             "splining": file_data["splining"],
#         }
#     except KeyError as e:
#         if "curvature_stats" in str(e):
#             pass
#         elif "grain_tensors" in str(e):
#             pass
#         else:
#             raise e

In [None]:
def _plot_grains_dictionary(grains_dictionary: dict) -> None:
    """
    Show every grain crop with its traces, spline points and node coordinates, followed by its mask.

    Parameters
    ----------
    grains_dictionary : dict
        Dictionary in the layout returned by ``construct_grains_dictionary``.
    """
    for grain_index, grain_data in grains_dictionary.items():
        print(f"grain {grain_index}")
        print(grain_data["filename"])
        print(grain_data["pixel_to_nm_scaling"])

        _fig, ax = plt.subplots(figsize=(20, 20))
        ax.imshow(grain_data["image"])
        # ordered traces in red; coordinate column 1 goes on the x axis and column 0 on the y axis
        for molecule_data in grain_data["molecule_data"].values():
            ordered_coords = molecule_data["ordered_coords"]
            ax.plot(ordered_coords[:, 1], ordered_coords[:, 0], "r")
        # spline points in green, plus the spacing statistics between consecutive spline points
        for molecule_index, molecule_data in grain_data["molecule_data"].items():
            spline_coords = molecule_data["spline_coords"]
            ax.plot(spline_coords[:, 1], spline_coords[:, 0], "g.")
            spline_distances = np.linalg.norm(spline_coords[1:, :] - spline_coords[:-1, :], axis=1)
            print(
                f"molecule {molecule_index} spline distances mean {np.mean(spline_distances)} std {np.std(spline_distances)} min {np.min(spline_distances)} max {np.max(spline_distances)}"
            )
        all_node_coords = grain_data["node_coords"]
        if all_node_coords.size > 0:
            ax.plot(all_node_coords[:, 1], all_node_coords[:, 0], "b.")
        plt.show()

        # the mask always gets its own figure, regardless of whether the backend closed the previous one
        _mask_fig, mask_ax = plt.subplots()
        mask_ax.imshow(grain_data["mask"][:, :, 1])
        plt.show()


def construct_grains_dictionary(
    file_list: list,
    bbox_padding: int,
    all_stats_df: pd.DataFrame,
    stop_at_index: int | None = None,
    plot: bool = False,
) -> dict:
    """
    Collect per-grain image crops, masks, traces and statistics from processed files.

    Files are loaded with ``LoadScans``. For every grain listed under ``ordered_traces/above`` the grain's
    bounding box is made square and padded, the image and grain mask are cropped to it, and the spline, ordered
    trace and node coordinates are shifted by the offset between the padded and the original bounding box.

    Parameters
    ----------
    file_list : list
        Files to load, passed straight to ``LoadScans``.
    bbox_padding : int
        Padding passed to ``pad_bounding_box``.
    all_stats_df : pd.DataFrame
        Per-grain statistics. Must contain the columns ``image``, ``grain_number``, ``smallest_bounding_area``,
        ``aspect_ratio``, ``total_contour_length`` and ``num_crossings``.
    stop_at_index : int | None, optional
        If given, stop once this many loaded files have been visited.
    plot : bool, optional
        If True, plot every collected grain after loading.

    Returns
    -------
    dict
        Maps a running grain index (counted across all files) to a dictionary with the keys ``molecule_data``,
        ``image``, ``full_image``, ``bbox``, ``aspect_ratio``, ``smallest_bounding_area``,
        ``total_contour_length``, ``num_crossings``, ``added_left``, ``added_top``, ``padding``, ``mask``,
        ``filename``, ``pixel_to_nm_scaling``, ``curvature_stats`` and ``node_coords``.

    Raises
    ------
    IndexError
        If a grain has no matching row in ``all_stats_df``.
    KeyError
        If a file lacks an expected entry other than ``ordered_traces`` (files without it are skipped).
    """
    grains_dictionary: dict[int, dict] = {}

    loadscans = LoadScans(file_list, channel="dummy")
    loadscans.get_data()
    img_dict = loadscans.img_dict

    grain_index = 0
    for file_index, (filename, file_data) in enumerate(img_dict.items()):
        if stop_at_index is not None and file_index >= stop_at_index:
            break
        try:
            # nodestats are optional
            try:
                nodestats_data = file_data["nodestats"]["above"]["stats"]
            except KeyError:
                nodestats_data = None

            # get the corresponding image rows from the stats csv
            all_stats_image = all_stats_df[all_stats_df["image"] == filename]
            print(f"found {len(all_stats_image)} rows in stats csv for {filename}")

            image = file_data["image"]
            ordered_trace_data = file_data["ordered_traces"]["above"]

            for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
                grain_number = int(re.sub(r"grain_", "", current_grain_index))
                all_stats_grain = all_stats_image[all_stats_image["grain_number"] == grain_number]
                if all_stats_grain.empty:
                    # same exception type as the bare ``.values[0]`` lookup would raise, but with context
                    raise IndexError(f"no row in stats csv for image {filename} grain {grain_number}")
                # grab quantities from the stats csv
                smallest_bounding_area = all_stats_grain["smallest_bounding_area"].values[0]
                aspect_ratio = all_stats_grain["aspect_ratio"].values[0]
                total_contour_length = all_stats_grain["total_contour_length"].values[0]
                num_crossings = all_stats_grain["num_crossings"].values[0]

                # The bounding box and offsets below come from the molecule loop. Without any molecule they
                # would be undefined, or silently reused from the previous grain, so skip such grains.
                if len(grain_ordered_trace_data) == 0:
                    print(f"  grain {current_grain_index} has no molecules, skipping")
                    continue

                molecules = {}
                for current_molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
                    molecule_data = {}
                    ordered_coords = molecule_ordered_trace_data["ordered_coords"]
                    molecule_data["heights"] = molecule_ordered_trace_data["heights"]
                    molecule_data["distances"] = molecule_ordered_trace_data["distances"]
                    molecule_data["circular"] = molecule_ordered_trace_data["mol_stats"]["circular"]
                    bbox = molecule_ordered_trace_data["bbox"]
                    print(f"  grain {current_grain_index} molecule {current_molecule_index} bbox {bbox}")

                    splining_coords = file_data["splining"]["above"][current_grain_index][current_molecule_index][
                        "spline_coords"
                    ]

                    curvatures = file_data["grain_curvature_stats"]["above"][current_grain_index][
                        current_molecule_index
                    ]
                    molecule_data["curvatures"] = curvatures

                    # bbox is expected to be the same for every molecule of a grain, so recomputing per
                    # molecule and using the last result after the loop is fine
                    bbox_square = make_bounding_box_square(bbox[0], bbox[1], bbox[2], bbox[3], image.shape)
                    bbox_padded = pad_bounding_box(
                        bbox_square[0],
                        bbox_square[1],
                        bbox_square[2],
                        bbox_square[3],
                        image.shape,
                        padding=bbox_padding,
                    )
                    added_left = bbox_padded[1] - bbox[1]
                    added_top = bbox_padded[0] - bbox[0]

                    # Shift the coordinates by the bbox offset. This modifies the loaded arrays in place;
                    # they belong to this call's freshly loaded ``img_dict`` only.
                    splining_coords[:, 0] -= added_top
                    splining_coords[:, 1] -= added_left
                    molecule_data["spline_coords"] = splining_coords

                    ordered_coords[:, 0] -= added_top
                    ordered_coords[:, 1] -= added_left
                    molecule_data["ordered_coords"] = ordered_coords

                    molecules[current_molecule_index] = molecule_data

                image_crop = image[
                    bbox_padded[0] : bbox_padded[2],
                    bbox_padded[1] : bbox_padded[3],
                ]
                full_grain_mask = file_data["grain_tensors"]["above"]
                mask_crop = full_grain_mask[
                    bbox_padded[0] : bbox_padded[2],
                    bbox_padded[1] : bbox_padded[3],
                ]

                # grab node coordinates
                all_node_coords = []
                if nodestats_data is not None:
                    try:
                        grain_nodestats_data = nodestats_data[current_grain_index]
                        for _node_index, node_data in grain_nodestats_data.items():
                            node_coords = node_data["node_coords"]
                            # adjust the node coords to account for the padding
                            node_coords[:, 0] -= added_top
                            node_coords[:, 1] -= added_left
                            all_node_coords.extend(node_coords)
                    except KeyError as e:
                        # A grain without nodestats is expected and skipped quietly. Any other missing key is
                        # still tolerated, but reported rather than silently swallowed.
                        if "grain_" not in str(e):
                            print(
                                f"  warning: missing key {e} in nodestats for {current_grain_index}, "
                                "node coords may be incomplete"
                            )

                grains_dictionary[grain_index] = {
                    "molecule_data": molecules,
                    "image": image_crop,
                    "full_image": image,
                    "bbox": bbox_padded,
                    "aspect_ratio": aspect_ratio,
                    "smallest_bounding_area": smallest_bounding_area,
                    "total_contour_length": total_contour_length,
                    "num_crossings": num_crossings,
                    "added_left": added_left,
                    "added_top": added_top,
                    "padding": bbox_padding,
                    "mask": mask_crop,
                    "filename": file_data["filename"],
                    "pixel_to_nm_scaling": file_data["pixel_to_nm_scaling"],
                    # NOTE: this is the curvature entry of the last molecule visited in the loop above
                    "curvature_stats": curvatures,
                    "node_coords": np.array(all_node_coords),
                }

                grain_index += 1
        except KeyError as e:
            if "ordered_traces" in str(e):
                print(f"no ordered traces found in {filename}")
                continue
            raise

    if plot:
        _plot_grains_dictionary(grains_dictionary)

    print(f"found {len(grains_dictionary)} grains in {len(file_list)} images")

    return grains_dictionary


# Build the grain dictionary from every file found, padding each grain crop by 20, with no file limit and no plots.
grains_dictionary = construct_grains_dictionary(
    file_list=topo_files,
    bbox_padding=20,
    all_stats_df=all_stats_df,
    stop_at_index=None,
    plot=False,
)

In [None]:
def _count_circular_defect_regions(defects_bool: npt.NDArray[np.bool_]) -> int:
    """
    Count the contiguous runs of True in a boolean array whose two ends are joined (a closed trace).

    A run that spans the end and the start of the array counts once, a run starting at index 0 is counted,
    and an all-True array counts as a single region.
    """
    defects_bool = np.asarray(defects_bool, dtype=bool)
    if defects_bool.size == 0:
        return 0
    # A region starts wherever a True follows a False; np.roll makes index 0 follow the last index.
    num_regions = int(np.count_nonzero(defects_bool & ~np.roll(defects_bool, 1)))
    # All-True has no False -> True transition but is still one region covering the whole trace.
    if num_regions == 0 and defects_bool.all():
        return 1
    return num_regions


def _circular_moving_average(values: npt.NDArray[np.float64], window_size: int) -> npt.NDArray[np.float64]:
    """
    Moving average over a closed trace: the window wraps around the ends instead of being zero-padded.

    The window alignment matches ``np.convolve(values, np.ones(window_size), mode="same")``, so values away
    from the two ends are identical to that form.
    """
    values = np.asarray(values, dtype=float)
    if window_size <= 1 or values.size == 0:
        return values.copy()
    pad_right = (window_size - 1) // 2
    pad_left = window_size - 1 - pad_right
    padded = np.pad(values, (pad_left, pad_right), mode="wrap")
    return np.convolve(padded, np.ones(window_size), mode="valid") / window_size


def _plot_defects_on_image(ax, image, points, defect_indexes, title: str) -> None:
    """Draw the image, all trace points in green, the flagged points in red and the first point in white."""
    ax.imshow(image)
    ax.plot(points[:, 1], points[:, 0], "g.")
    ax.plot(points[defect_indexes, 1], points[defect_indexes, 0], "r.", markersize=10)
    ax.scatter(points[0, 1], points[0, 0], color="white", s=100)
    ax.set_title(title)


defect_stats_dir = postprocess_dir / "defect_stats"
defect_stats_dir.mkdir(exist_ok=True)

interval_nm = 2  # nm, spacing of the resampled trace points
curvature_average_window_nm = 10  # nm
curvature_rolling_window_size = int(curvature_average_window_nm / interval_nm)

curvature_threshold_radpernm = 0.15  # radians/nm
height_threshold_nm = 2.0  # nm

total_index_to_stop_at = 10000

coindicing_defect_distance_threshold_nm = 10  # nm

grains_defects_dictionary = {}
defect_stats_list = []
for total_index, (grain_index, grain_data) in enumerate(grains_dictionary.items()):
    if total_index >= total_index_to_stop_at:
        break
    image = grain_data["image"]
    mask = grain_data["mask"][:, :, 1]
    filename = grain_data["filename"]
    print(f"grain {grain_index} {filename}")
    pixel_to_nm_scaling = grain_data["pixel_to_nm_scaling"]

    grains_defects_dictionary[grain_index] = {}

    # only circular molecules are analysed
    for molecule_index, molecule_data in grain_data["molecule_data"].items():
        if not molecule_data["circular"]:
            print(f"  molecule {molecule_index} is not circular, skipping")
            continue
        spline_coords = molecule_data["spline_coords"]

        # resample at a regular spacing; the interval is converted from nm to pixels
        resampled_points = resample_points_regular_interval(
            points=spline_coords, interval=interval_nm / pixel_to_nm_scaling, circular=True
        )

        # Segment lengths in pixels. Entry 0 is the closing segment (last point -> first point), so the final
        # cumulative value is the full perimeter of the closed trace.
        resampled_distances = np.linalg.norm(resampled_points[1:, :] - resampled_points[:-1, :], axis=1)
        resampled_distances = np.insert(
            resampled_distances, 0, np.linalg.norm(resampled_points[0, :] - resampled_points[-1, :])
        )
        cumulative_resampled_distances = np.cumsum(resampled_distances)  # pixels
        cumulative_resampled_distances_nm = cumulative_resampled_distances * pixel_to_nm_scaling

        resampled_curvatures = discrete_angle_difference_per_nm_circular(
            trace_nm=resampled_points * pixel_to_nm_scaling
        )

        # Smooth with a rolling mean that wraps around the seam of the closed trace. Zero-padding the ends
        # would dilute the curvature of the first and last few points.
        smoothed_curvatures = _circular_moving_average(resampled_curvatures, curvature_rolling_window_size)

        # height at each resampled point, read from the nearest pixel (column 0 indexes the image's first axis)
        sample_rows = np.round(resampled_points[:, 0]).astype(int)
        sample_cols = np.round(resampled_points[:, 1]).astype(int)
        resampled_heights = image[sample_rows, sample_cols].astype(float)

        # curvature peaks and height dips, both as boolean flags per point and as index arrays
        curvature_defects_bool = np.abs(smoothed_curvatures) > curvature_threshold_radpernm
        curvature_defects_indexes = np.flatnonzero(curvature_defects_bool)
        height_defects_bool = resampled_heights < height_threshold_nm
        height_defects_indexes = np.flatnonzero(height_defects_bool)

        num_curvature_defects = _count_circular_defect_regions(curvature_defects_bool)
        num_height_defects = _count_circular_defect_regions(height_defects_bool)

        # A curvature defect "coincides" when some height defect lies within the distance threshold along the
        # trace. The trace is closed, so the separation is the shorter of the two ways round.
        coindicing_defect_distance_threshold_px = coindicing_defect_distance_threshold_nm / pixel_to_nm_scaling
        coinciding_defect_bool = np.zeros(smoothed_curvatures.shape[0], dtype=bool)
        if curvature_defects_indexes.size > 0 and height_defects_indexes.size > 0:
            perimeter_px = cumulative_resampled_distances[-1]
            separation_px = np.abs(
                cumulative_resampled_distances[curvature_defects_indexes][:, None]
                - cumulative_resampled_distances[height_defects_indexes][None, :]
            )
            separation_px = np.minimum(separation_px, perimeter_px - separation_px)
            has_nearby_height_defect = (separation_px < coindicing_defect_distance_threshold_px).any(axis=1)
            coinciding_defect_bool[curvature_defects_indexes[has_nearby_height_defect]] = True

        coinciding_defects_indexes = np.flatnonzero(coinciding_defect_bool)
        num_coinciding_defects = _count_circular_defect_regions(coinciding_defect_bool)

        grains_defects_dictionary[grain_index][molecule_index] = {
            "image": image,
            "mask": mask,
            "filename": filename,
            "pixel_to_nm_scaling": pixel_to_nm_scaling,
            "grain_index": grain_index,
            "molecule_index": molecule_index,
            "curvature_defects_indexes": curvature_defects_indexes,
            "height_defects_indexes": height_defects_indexes,
            "smoothed_curvatures": smoothed_curvatures,
            "resampled_heights": resampled_heights,
            # kept in pixels, as before; multiply by pixel_to_nm_scaling for nm
            "cumulative_resampled_distances": cumulative_resampled_distances,
            "num_curvature_defects": num_curvature_defects,
            "num_height_defects": num_height_defects,
            "num_coinciding_defects": num_coinciding_defects,
        }

        # one row per analysed molecule for the summary csv
        defect_stats_list.append(
            {
                "grain_index": grain_index,
                "smallest_bounding_area": grain_data["smallest_bounding_area"],
                "aspect_ratio": grain_data["aspect_ratio"],
                "total_contour_length": grain_data["total_contour_length"],
                "num_crossings": grain_data["num_crossings"],
                "molecule_index": molecule_index,
                "num_curvature_defects": num_curvature_defects,
                "num_height_defects": num_height_defects,
                "num_coinciding_defects": num_coinciding_defects,
            }
        )

        # summary figure: curvature and height profiles on top, defect overlays below
        fig, ax = plt.subplots(3, 2, figsize=(20, 20))
        plt.suptitle(
            f"image {filename} \n grain {grain_index} {molecule_index} \n mean curv: {np.mean(smoothed_curvatures):.2f}",
            fontsize=20,
        )
        # the profile x axes are labelled in nm, so plot the nm distances rather than the pixel ones
        ax[0, 0].plot(cumulative_resampled_distances_nm, smoothed_curvatures, "k")
        ax[0, 0].set_xlabel("Distance (nm)")
        ax[0, 0].set_ylabel("Curvature (1/nm)")
        ax[0, 0].hlines(
            [curvature_threshold_radpernm, -curvature_threshold_radpernm],
            0,
            cumulative_resampled_distances_nm[-1],
            color="k",
            linestyles="dashed",
        )
        ax[0, 0].set_title("Curvature")
        ax[0, 1].plot(cumulative_resampled_distances_nm, resampled_heights, "k")
        ax[0, 1].set_xlabel("Distance (nm)")
        ax[0, 1].set_ylabel("Height (nm)")
        ax[0, 1].hlines(
            [height_threshold_nm], 0, cumulative_resampled_distances_nm[-1], color="k", linestyles="dashed"
        )
        ax[0, 1].set_title("Height")

        _plot_defects_on_image(ax[1, 0], image, resampled_points, curvature_defects_indexes, "Curvature Defects")
        _plot_defects_on_image(ax[1, 1], image, resampled_points, height_defects_indexes, "Height Defects")
        _plot_defects_on_image(ax[2, 0], image, resampled_points, coinciding_defects_indexes, "Coinciding Defects")

        # save and close so that figures do not accumulate in memory over the loop
        fig.savefig(
            defect_stats_dir / f"grain_{grain_index}_mol{molecule_index}_defects.png",
            dpi=200,
        )
        plt.close(fig)

# save the defect stats to a csv
defect_stats_df = pd.DataFrame(defect_stats_list)
defect_stats_df.to_csv(defect_stats_dir / "defect_stats.csv", index=False)

# pickle the dictionary
with open(defect_stats_dir / "grains_defects_stats_dictionary.pickle", "wb") as f:
    pickle.dump(grains_defects_dictionary, f)

In [None]:
# Read the defect statistics back from disk (one row per analysed molecule).
defect_stats_df = pd.read_csv(defect_stats_dir / "defect_stats.csv")

print(defect_stats_df.head())

print(len(defect_stats_df))

# axis label for each defect-count column
defect_count_labels = {
    "num_curvature_defects": "Number of Curvature Defects",
    "num_height_defects": "Number of Height Defects",
    "num_coinciding_defects": "Number of Coinciding Defects",
}

# distribution of each defect count
for count_column, count_label in defect_count_labels.items():
    defect_stats_df[count_column].hist()
    plt.xlabel(count_label)
    plt.ylabel("Count")
    plt.show()

# aspect ratio against the curvature and height defect counts
for count_column in ("num_curvature_defects", "num_height_defects"):
    defect_stats_df.plot.scatter(x="aspect_ratio", y=count_column)
    plt.xlabel("Aspect Ratio")
    plt.ylabel(defect_count_labels[count_column])
    plt.show()

# contour length distribution: individual points over a violin
plt.figure()
sns.stripplot(data=defect_stats_df, y="total_contour_length", color="black", alpha=0.5)
sns.violinplot(data=defect_stats_df, y="total_contour_length", color="lightgrey")
plt.ylabel("Total Contour Length (nm)")
plt.show()