In [None]:
from pathlib import Path
import re
import pickle
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import seaborn as sns
from scipy.ndimage import distance_transform_edt, gaussian_filter1d, binary_fill_holes
from skimage.measure import label, regionprops
from skimage.graph import route_through_array
from skimage.morphology import binary_dilation, binary_erosion, skeletonize

from topostats.plottingfuncs import Colormap
from topostats.tracing.splining import resample_points_regular_interval, windowTrace
from topostats.measure.feret import get_feret_from_mask
from topostats.measure.curvature import discrete_angle_difference_per_nm_linear
from topostats.damage.damage import get_defects_and_gaps_from_bool_array, Defect, DefectGap

# Colormap for the grain images: built from the TopoStats Colormap helper, constructed with
# no arguments.
colormap = Colormap()
CMAP = colormap.get_cmap()
# Colormap used by plot_all_grains_dictionary to colour path segments by a per-point value.
COLOUR_PATH_CMAP = mpl.cm.coolwarm

# Default colour-scale limits for the grain images (passed to imshow as vmin / vmax).
# NOTE(review): presumably heights in nm — confirm against the image units.
VMIN = 0
VMAX = 4

In [None]:
# --- Run configuration ---------------------------------------------------------------
# Inputs are read from a fixed extraction date; outputs go to a folder stamped with today's date.
TODAY_DATE = datetime.now().strftime("%Y-%m-%d")
DATE_TO_READ_FROM = "2024-03-22"
# Grains whose pixel-to-nm scaling exceeds this value are dropped while loading.
MAX_PX_TO_NM = 10.0

sample_types = ["ON_SC", "OT1_SC", "OT2_SC"]

BASE_DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/extracted_grains/")
BASE_SAVE_DIR = Path("/Users/sylvi/topo_data/hariborings/processed_grains/")

# One input and one output directory per sample type.
data_dirs = {name: BASE_DATA_DIR / f"cas9_{name}/{DATE_TO_READ_FROM}" for name in sample_types}
save_dirs = {name: BASE_SAVE_DIR / f"cas9_{name}/{TODAY_DATE}" for name in sample_types}
for save_dir in save_dirs.values():
    save_dir.mkdir(exist_ok=True, parents=True)

# --- Load the grain dictionaries ------------------------------------------------------
grains_dicts = {}
for sample_type in sample_types:
    file_path = data_dirs[sample_type] / "grain_dict.pkl"
    # NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
    with open(file_path, "rb") as f:
        sample_grains_dicts = pickle.load(f)

    # Keep only the grains whose scaling is within the allowed limit.
    kept_grains = {}
    for grain_index, single_grain_dict in sample_grains_dicts.items():
        p_to_nm = single_grain_dict["p_to_nm"]
        if p_to_nm > MAX_PX_TO_NM:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to p_to_nm too large: {p_to_nm} > {MAX_PX_TO_NM}"
            )
            continue
        kept_grains[grain_index] = single_grain_dict
    grains_dicts[sample_type] = kept_grains


# Report how many grains survived the filter for each sample type.
for sample_type, grain_dict in grains_dicts.items():
    print(f"Number of grains for sample type [{sample_type}]: {len(grain_dict)}")

In [None]:
def _plot_plain_paths(ax, grain_data: dict, path_names: list[str]) -> None:
    """Overlay each named path found in ``grain_data`` on ``ax`` as a red line."""
    for path_name in path_names:
        path = grain_data.get(path_name)
        # Missing, ``None`` or very short (<= 2 points) paths are silently skipped.
        if path is None or len(path) <= 2:
            continue
        # Paths are indexed (row, col); matplotlib wants (x, y) = (col, row).
        ax.plot(path[:, 1], path[:, 0], color="red", linewidth=2, label=path_name)


def _plot_colour_paths(
    ax, grain_data: dict, colour_paths: list[tuple[str, str]], norm_bounds: tuple[float, float]
) -> None:
    """Draw each named path segment by segment, coloured by a per-point value array."""
    lower, upper = norm_bounds
    for path_name, colour_array_name in colour_paths:
        if path_name not in grain_data:
            continue
        path = grain_data[path_name]
        if path is None or len(path) <= 2:
            continue
        # Map the values linearly so that the bounds land on 0 and 1 for the colormap.
        normalised_cvals = (np.asarray(grain_data[colour_array_name]) - lower) / (upper - lower)
        # The segment ending at point i takes the colour of the value at point i.
        for point_index in range(1, len(path)):
            previous_point = path[point_index - 1]
            point = path[point_index]
            ax.plot(
                [previous_point[1], point[1]],
                [previous_point[0], point[0]],
                color=COLOUR_PATH_CMAP(normalised_cvals[point_index]),
                linewidth=2,
            )


def _plot_defect_markers(ax, grain_data: dict) -> None:
    """Mark the start (blue) and end (red) of every ``Defect`` on ``ax``, if defects are present."""
    if "ordered_defect_list" not in grain_data:
        return
    for defect_or_gap in grain_data["ordered_defect_list"].defect_gap_list:
        # The list interleaves defects and gaps; only defects are marked.
        if not isinstance(defect_or_gap, Defect):
            continue
        resampled_path_px = grain_data["resampled_path_px"]
        defect_start_point = resampled_path_px[defect_or_gap.start_index]
        defect_end_point = resampled_path_px[defect_or_gap.end_index]
        ax.scatter(defect_start_point[1], defect_start_point[0], color="blue", s=100)
        ax.scatter(defect_end_point[1], defect_end_point[0], color="red", s=100)


def _plot_lineplots(fig, inner, grain_data: dict, lineplots: list) -> None:
    """Add one full-width line plot per specification underneath the image / mask row."""
    for lineplot_index, lineplot_spec in enumerate(lineplots):
        x_name, y_name, xlabel, ylabel, title, ylim = lineplot_spec
        # Row 0 holds the image and mask; each line plot spans all columns of its own row.
        lineplot_ax = plt.Subplot(fig, inner[1 + lineplot_index, :])
        lineplot_ax.plot(grain_data[x_name], grain_data[y_name])
        lineplot_ax.set_xlabel(xlabel)
        lineplot_ax.set_ylabel(ylabel)
        lineplot_ax.set_title(title)
        if ylim is not None:
            lineplot_ax.set_ylim(ylim)
        # Every line-plot axis must be attached to the figure, not only the last one.
        fig.add_subplot(lineplot_ax)


def plot_all_grains_dictionary(
    grain_dict: dict,
    num_cols: int = 3,
    vmin: float = VMIN,
    vmax: float = VMAX,
    cmap: mpl.cm = CMAP,
    plot_paths: list[str] = None,
    plot_defects: bool = False,
    plot_colour_paths: list[tuple[str, str]] = None,
    plot_colour_norm_bounds: tuple[float, float] = (-0.3, 0.3),
    plot_lineplots: list[tuple[str, str, str, str, str, tuple[float, float] | None]] = None,
    save_path: Path | None = None,
    show_plot: bool = True,
    dpi=200,
    figsize_multiplier: float = 4,
    wspace: float = 0.1,
    hspace: float = 0.1,
    num_grains_to_limit_to: int | None = None,
) -> None:
    """
    Plot the grains of a grain dictionary in a grid layout.

    Each grain occupies one cell of an outer grid. Inside the cell the image and the mask are
    shown side by side, optionally followed by one full-width line plot per entry of
    ``plot_lineplots``.

    Parameters
    ----------
    grain_dict : dict
        Mapping of grain index to a per-grain dictionary. Each entry must provide ``"image"``,
        ``"predicted_mask"`` and ``"p_to_nm"``, plus any keys named by the options below.
    num_cols : int
        Number of grains per row of the outer grid.
    vmin, vmax : float
        Colour-scale limits for the image.
    cmap : matplotlib colormap
        Colormap for the image.
    plot_paths : list[str], optional
        Keys of ``(row, col)`` path arrays to overlay on the mask in red. Missing keys, ``None``
        paths and paths with two or fewer points are skipped.
    plot_defects : bool
        If ``True`` and the grain has ``"ordered_defect_list"``, mark defect starts (blue) and
        ends (red) on the mask using ``"resampled_path_px"``.
    plot_colour_paths : list[tuple[str, str]], optional
        Pairs of ``(path key, value-array key)``; each path is drawn segment by segment and
        coloured by the value array through ``COLOUR_PATH_CMAP``.
    plot_colour_norm_bounds : tuple[float, float]
        Values mapped to the two ends of ``COLOUR_PATH_CMAP``.
    plot_lineplots : list[tuple], optional
        One entry per line plot: ``(x key, y key, x label, y label, title, y-limits or None)``.
    save_path : Path, optional
        If given, the figure is saved here; parent directories are created as needed.
    show_plot : bool
        Show the figure if ``True``, otherwise close it.
    dpi : int
        Resolution used for the figure and for saving.
    figsize_multiplier : float
        Width and height, in inches, allotted to each outer grid cell.
    wspace, hspace : float
        Spacing used for both the outer and the inner grids.
    num_grains_to_limit_to : int, optional
        Plot at most this many grains (the first ones in dictionary order).
    """
    grain_items = list(grain_dict.items())
    if num_grains_to_limit_to is not None:
        grain_items = grain_items[: max(num_grains_to_limit_to, 0)]
    num_grains = len(grain_items)
    if num_grains == 0:
        # A grid with zero rows cannot be built, so there is nothing to draw or save.
        print("No grains to plot.")
        return

    num_rows = (num_grains + num_cols - 1) // num_cols  # ceiling division
    num_inner_rows = 1 if plot_lineplots is None else 1 + len(plot_lineplots)
    num_inner_cols = 2
    fig = plt.figure(figsize=(num_cols * figsize_multiplier, num_rows * figsize_multiplier), dpi=dpi)
    outer = gridspec.GridSpec(nrows=num_rows, ncols=num_cols, wspace=wspace, hspace=hspace)

    for plot_index, (grain_index, grain_data) in enumerate(grain_items):
        grain_image = grain_data["image"]
        grain_mask = grain_data["predicted_mask"]
        grain_p_to_nm = grain_data["p_to_nm"]

        # Fill the outer grid row by row.
        outer_row, outer_col = divmod(plot_index, num_cols)
        inner = gridspec.GridSpecFromSubplotSpec(
            nrows=num_inner_rows,
            ncols=num_inner_cols,
            subplot_spec=outer[outer_row, outer_col],
            wspace=wspace,
            hspace=hspace,
        )

        # Image panel.
        image_ax = plt.Subplot(fig, inner[0, 0])
        image_ax.imshow(grain_image, vmin=vmin, vmax=vmax, cmap=cmap)
        image_ax.set_title(f"Grain {grain_index}\np_to_nm: {grain_p_to_nm:.2f} nm/px")
        image_ax.axis("off")
        fig.add_subplot(image_ax)

        # Mask panel with optional overlays.
        mask_ax = plt.Subplot(fig, inner[0, 1])
        mask_ax.imshow(grain_mask)
        if plot_paths is not None:
            _plot_plain_paths(mask_ax, grain_data, plot_paths)
        if plot_colour_paths is not None:
            _plot_colour_paths(mask_ax, grain_data, plot_colour_paths, plot_colour_norm_bounds)
        mask_ax.set_title("Mask")
        mask_ax.axis("off")
        fig.add_subplot(mask_ax)

        if plot_defects:
            _plot_defect_markers(mask_ax, grain_data)

        if plot_lineplots is not None:
            _plot_lineplots(fig, inner, grain_data, plot_lineplots)

    if save_path is not None:
        # Create the parent directory before saving.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"saving figure to {save_path}")
        fig.savefig(save_path, dpi=dpi)

    if show_plot:
        plt.show()
    else:
        plt.close(fig)

In [None]:
# plot all grains
# Quick visual check of the loaded data: image and predicted mask for the first few grains of
# every sample type. Set the toggle to False to skip this preview.
show_raw_grain_preview = True
if show_raw_grain_preview:
    raw_preview_kwargs = dict(
        vmin=VMIN,
        vmax=VMAX,
        cmap=CMAP,
        show_plot=True,
        num_grains_to_limit_to=10,
        wspace=0,
        hspace=0,
        num_cols=5,
    )
    for sample_type, grain_dict in grains_dicts.items():
        print(f"Plotting grains for sample type [{sample_type}]")
        plot_all_grains_dictionary(grain_dict, **raw_preview_kwargs)

In [None]:
# pathfinding
# For every grain, find a least-cost route through the mask between the first two labelled
# intersection regions. The cost is the inverted distance transform of the mask (with label 2
# zeroed), so the route prefers pixels far from the mask edge. Grains that cannot be traced
# are recorded in `bad_grains` and left out of `grains_dicts_paths`.

# Routes with fewer points than this are rejected.
MIN_ROUTE_POINTS = 5
# Two intersection regions are needed: one for the start and one for the end of the route.
NUM_REQUIRED_INTERSECTION_REGIONS = 2


def _deepest_pixel_in_region(region, distance_map):
    """Return the (row, col) of the region's pixel with the largest ``distance_map`` value."""
    coords = region.coords
    values = distance_map[coords[:, 0], coords[:, 1]]
    best_pixel = coords[np.argmax(values)]
    return (best_pixel[0], best_pixel[1])


grains_dicts_paths = {}
bad_grains = []
for sample_type, grains_dict_sample in grains_dicts.items():
    grains_dicts_paths[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        grain_mask = grain_data["predicted_mask"]
        intersection_labels = grain_data["intersection_labels"]

        # Distance to the mask edge for every mask pixel; label-2 pixels are forced to zero.
        distance_transform = distance_transform_edt(grain_mask > 0)
        distance_transform[grain_mask == 2] = 0

        # Label the connected intersection regions; the first two are the route's endpoints.
        intersection_labels = label(intersection_labels)
        intersection_regions = regionprops(intersection_labels)
        if len(intersection_regions) < NUM_REQUIRED_INTERSECTION_REGIONS:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to too few intersection regions: "
                f"{len(intersection_regions)} < {NUM_REQUIRED_INTERSECTION_REGIONS}"
            )
            bad_grains.append((sample_type, grain_index))
            continue

        # Start / end at the pixel of each region that lies deepest inside the mask.
        start_point = _deepest_pixel_in_region(intersection_regions[0], distance_transform)
        end_point = _deepest_pixel_in_region(intersection_regions[1], distance_transform)

        # Invert the distance transform to get the cost map.
        cost_map = np.max(distance_transform) - distance_transform
        # Pixels with the highest cost (distance transform of zero) are made prohibitively expensive.
        cost_map[cost_map == np.max(cost_map)] = 1e4

        route, _ = route_through_array(array=cost_map, start=start_point, end=end_point)
        route = np.array(route)

        # Skip routes that are too short to be useful.
        if len(route) < MIN_ROUTE_POINTS:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to route being too short: {len(route)} < {MIN_ROUTE_POINTS}"
            )
            bad_grains.append((sample_type, grain_index))
            continue

        # Store the grain data (the same dict object, not a copy) and add the route to it.
        grains_dicts_paths[sample_type][grain_index] = grain_data
        grains_dicts_paths[sample_type][grain_index]["path"] = route

print(f"Bad grains: {bad_grains}")

# plot all grains with paths
if False:
    for sample_type, grain_dict in grains_dicts_paths.items():
        print(f"Plotting grains with paths for sample type [{sample_type}]")
        plot_all_grains_dictionary(
            grain_dict=grain_dict,
            vmin=VMIN,
            vmax=VMAX,
            cmap=CMAP,
            plot_paths=["path"],
            save_path=save_dirs[sample_type] / f"grains_with_paths_{sample_type}.png",
        )

In [None]:
# resampling the paths
# Each route is resampled at a regular spacing, smoothed with a rolling window, and resampled
# once more so the final points are evenly spaced again. The per-point step lengths and the
# cumulative distance along the path are stored alongside the resampled path.
trace_resampling_distance_nm = 2.0
smoothing_window_size_nm = 3.0
verbose = False

grains_dicts_resampled_paths = {}
for sample_type, grains_dict_sample in grains_dicts_paths.items():
    resampled_grains = {}
    grains_dicts_resampled_paths[sample_type] = resampled_grains
    for grain_index, grain_data in grains_dict_sample.items():
        p_to_nm = grain_data["p_to_nm"]
        path = grain_data["path"]
        # The resampling interval is specified in nm but applied in pixels.
        resampling_interval_px = trace_resampling_distance_nm / p_to_nm

        if verbose:
            print(f"grain: {grain_index}")
            print(f"p2nm: {p_to_nm}")
            print(f"path shape: {path.shape}")

        # First pass: regular spacing along the raw route.
        resampled_path_px = resample_points_regular_interval(
            points=path, interval=resampling_interval_px, circular=False
        )
        if verbose:
            print(f"resampled path shape: {resampled_path_px.shape}")

        # Smooth the path a bit.
        smoothed_path_px = windowTrace.pool_trace_linear(
            pixel_trace=resampled_path_px,
            rolling_window_size=smoothing_window_size_nm,
            pixel_to_nm_scaling=p_to_nm,
        )
        if verbose:
            print(f"smoothed path shape: {smoothed_path_px.shape}")

        # Second pass: make the smoothed path regular again.
        resampled_path_px = resample_points_regular_interval(
            points=smoothed_path_px, interval=resampling_interval_px, circular=False
        )
        if verbose:
            print(f"final resampled path shape: {resampled_path_px.shape}")
            print()

        if len(resampled_path_px) <= 2:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to resampled path too short: {len(resampled_path_px)}"
            )
            continue

        # Distance from each point to the previous one; the first point gets a distance of 0
        # so the array has the same length as the path.
        step_vectors_px = np.diff(resampled_path_px, axis=0)
        resampled_distances_px = np.insert(np.linalg.norm(step_vectors_px, axis=1), 0, 0)
        resampled_distances_nm = resampled_distances_px * p_to_nm

        # Add the results to the grain's dict (the same object, not a copy) and register it.
        grain_data.update(
            {
                "resampled_path_px": resampled_path_px,
                "resampled_distances_px": resampled_distances_px,
                "resampled_distances_nm": resampled_distances_nm,
                "cumulative_resampled_distances_px": np.cumsum(resampled_distances_px),
                "cumulative_resampled_distances_nm": np.cumsum(resampled_distances_nm),
            }
        )
        resampled_grains[grain_index] = grain_data

# plot all grains with smoothed paths
if False:
    for sample_type, grain_dict in grains_dicts_resampled_paths.items():
        print(f"Plotting grains with smoothed paths for sample type [{sample_type}]")
        plot_all_grains_dictionary(
            grain_dict=grain_dict,
            vmin=VMIN,
            vmax=VMAX,
            cmap=CMAP,
            plot_paths=["resampled_path_px"],
            save_path=save_dirs[sample_type] / f"grains_with_resampled_paths_{sample_type}.png",
            show_plot=False,
        )

In [None]:
# calculate the feret diameter for each grain
# The min / max Feret diameters are measured on a copy of the mask with label 2 set to zero
# (NOTE(review): label 2 is presumably the non-DNA class — confirm) and converted from pixels
# to nm with the grain's p_to_nm.
grains_dicts_feret_diameter = {}
for sample_type, grains_dict_sample in grains_dicts_resampled_paths.items():
    feret_grains = {}
    grains_dicts_feret_diameter[sample_type] = feret_grains
    for grain_index, grain_data in grains_dict_sample.items():
        grain_mask = grain_data["predicted_mask"]
        p_to_nm = grain_data["p_to_nm"]

        # Work on a copy so the stored mask is left untouched.
        grain_mask_dna_only = grain_mask.copy()
        grain_mask_dna_only[grain_mask == 2] = 0

        feret_diameters = get_feret_from_mask(mask_im=grain_mask_dna_only)
        min_feret = feret_diameters["min_feret"] * p_to_nm
        max_feret = feret_diameters["max_feret"] * p_to_nm

        # Add the results to the grain's dict (the same object, not a copy) and register it.
        grain_data["min_feret"] = min_feret
        grain_data["max_feret"] = max_feret
        feret_grains[grain_index] = grain_data

In [None]:
# calculate curvature of the path
# The resampled path is converted to nm and handed to the TopoStats curvature routine; the
# result is stored under "curvature".
grains_dicts_curvature = {}
for sample_type, grains_dict_sample in grains_dicts_feret_diameter.items():
    curvature_grains = {}
    grains_dicts_curvature[sample_type] = curvature_grains
    for grain_index, grain_data in grains_dict_sample.items():
        trace_nm = grain_data["resampled_path_px"] * grain_data["p_to_nm"]

        # Add the result to the grain's dict (the same object, not a copy) and register it.
        grain_data["curvature"] = discrete_angle_difference_per_nm_linear(trace_nm=trace_nm)
        curvature_grains[grain_index] = grain_data


# plot all grains with curvature
# One line plot per grain: curvature against distance along the path, with fixed y-limits.
curvature_lineplot_spec = (
    "cumulative_resampled_distances_nm",
    "curvature",
    "distance (nm)",
    "curvature (rad/nm)",
    "Curvature",
    (-0.5, 0.5),
)
for sample_type, grain_dict in grains_dicts_curvature.items():
    print(f"Plotting grains with curvature for sample type [{sample_type}]")
    plot_all_grains_dictionary(
        grain_dict=grain_dict,
        vmin=VMIN,
        vmax=VMAX,
        cmap=CMAP,
        plot_colour_paths=[("resampled_path_px", "curvature")],
        plot_lineplots=[curvature_lineplot_spec],
        save_path=save_dirs[sample_type] / f"grains_with_curvature_{sample_type}.png",
        show_plot=True,
        figsize_multiplier=5,
        num_grains_to_limit_to=10,
        num_cols=5,
    )

In [None]:
# detect defects
# A path point is flagged as defective when the magnitude of its curvature exceeds the
# threshold below; the flags are then converted into an ordered defect / gap list.

defect_curvature_threshold_radpernm = 0.3

grains_dicts_curvature_defects = {}
for sample_type, grains_dict_sample in grains_dicts_curvature.items():
    defect_grains = {}
    grains_dicts_curvature_defects[sample_type] = defect_grains
    for grain_index, grain_data in grains_dict_sample.items():
        resampled_distances_nm = grain_data["resampled_distances_nm"]
        path_curvatures = grain_data["curvature"]

        # Flag the points whose |curvature| is above the threshold.
        defect_mask = np.abs(path_curvatures) > defect_curvature_threshold_radpernm
        ordered_defect_list = get_defects_and_gaps_from_bool_array(
            defects_bool=defect_mask,
            distance_to_previous_points_nm=resampled_distances_nm,
            circular=False,
        )

        # Add the result to the grain's dict (the same object, not a copy) and register it.
        grain_data["ordered_defect_list"] = ordered_defect_list
        defect_grains[grain_index] = grain_data

In [None]:
# plot all grains with defects
# Per grain: the mask with the resampled path (red), the curvature-coloured path and the defect
# start / end markers on top, plus a curvature-versus-distance line plot underneath.
defect_plot_kwargs = dict(
    vmin=VMIN,
    vmax=VMAX,
    cmap=CMAP,
    plot_paths=["resampled_path_px"],
    plot_defects=True,
    plot_colour_paths=[("resampled_path_px", "curvature")],
    plot_lineplots=[
        (
            "cumulative_resampled_distances_nm",
            "curvature",
            "distance (nm)",
            "curvature (rad/nm)",
            "Curvature",
            (-0.5, 0.5),
        )
    ],
    show_plot=True,
    figsize_multiplier=5,
    num_grains_to_limit_to=30,
    num_cols=5,
)
for sample_type, grain_dict in grains_dicts_curvature_defects.items():
    print(f"Plotting grains with defects for sample type [{sample_type}]")
    plot_all_grains_dictionary(
        grain_dict=grain_dict,
        save_path=save_dirs[sample_type] / f"grains_with_defects_{sample_type}.png",
        **defect_plot_kwargs,
    )

In [None]:
# plot the positions of the defects for each sample type
# Step 1: gather the position along the trace (nm) of every defect, per sample type.
all_defect_positions = {}
for sample_type, grain_dict in grains_dicts_curvature_defects.items():
    defect_positions_nm = [
        entry.position_along_trace_nm
        for single_grain in grain_dict.values()
        for entry in single_grain["ordered_defect_list"].defect_gap_list
        if isinstance(entry, Defect)
    ]
    all_defect_positions[sample_type] = defect_positions_nm
    print(f"Number of defects for sample type [{sample_type}]: {len(defect_positions_nm)}")

# Step 2: one figure, one panel per sample type, sharing the y axis.
fig, axes = plt.subplots(1, len(sample_types), figsize=(15, 5), sharey=True)

for i, sample_type in enumerate(sample_types):
    ax = axes[i]
    defect_positions_nm = all_defect_positions[sample_type]
    # Every point belongs to the same category so both plots sit at one x position.
    category_labels = [sample_type] * len(defect_positions_nm)

    # Strip plot of the individual defects with a violin plot drawn over it on the same axis.
    sns.stripplot(x=category_labels, y=defect_positions_nm, jitter=0.1, color="grey", ax=ax)
    sns.violinplot(
        x=category_labels,
        y=defect_positions_nm,
        color="lightblue",
        inner=None,
        linewidth=0.5,
        alpha=0.5,
        ax=ax,
    )

    ax.set_title(f"{sample_type}\n({len(defect_positions_nm)} defects)")
    ax.set_xlabel("Sample Type")
    # Only the left-most panel carries the shared y-axis label.
    ax.set_ylabel("Position along trace (nm)" if i == 0 else "")

    # Remove x-axis tick labels to clean up the appearance.
    ax.set_xticklabels([])

fig.suptitle("Defect Positions by Sample Type", fontsize=16)
fig.tight_layout()

# Step 3: save the combined plot next to the per-sample output folders.
combined_save_path = BASE_SAVE_DIR / f"combined_defect_positions_{TODAY_DATE}.png"
combined_save_path.parent.mkdir(exist_ok=True, parents=True)
fig.savefig(combined_save_path, dpi=200, bbox_inches="tight")
plt.show()