In [None]:
from pathlib import Path
import re
import pickle
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from scipy.ndimage import distance_transform_edt, gaussian_filter1d, binary_fill_holes
from skimage.measure import label, regionprops
from skimage.graph import route_through_array
from skimage.morphology import binary_dilation, binary_erosion, skeletonize

from topostats.plottingfuncs import Colormap

# Build the TopoStats colormap once; CMAP is the default `cmap` used when plotting grain images below.
colormap = Colormap()
CMAP = colormap.get_cmap()

# Colour-scale limits passed as vmin / vmax when grain images are drawn with imshow.
# NOTE(review): presumably in the same units as the grain image values (nm height?) — confirm.
VMIN = 0
VMAX = 4

In [None]:
# --- Configuration -----------------------------------------------------------
TODAY_DATE = datetime.now().strftime("%Y-%m-%d")
DATE_TO_READ_FROM = "2024-03-22"
# NOTE(review): absolute, machine-specific paths — point these at your own data
# root before running the notebook on another machine.
BASE_DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/extracted_grains/")
BASE_SAVE_DIR = Path("/Users/sylvi/topo_data/hariborings/processed_grains/")
# Grains whose "p_to_nm" value is larger than this are dropped while loading.
MAX_PX_TO_NM = 10.0

# Single source of truth for the sample keys: both directory maps are derived
# from this list so they cannot drift out of sync with each other.
sample_types = ["ON_SC", "OT1_SC", "OT2_SC"]
data_dirs = {
    sample_type: BASE_DATA_DIR / f"cas9_{sample_type}/{DATE_TO_READ_FROM}" for sample_type in sample_types
}
save_dirs = {sample_type: BASE_SAVE_DIR / f"cas9_{sample_type}/{TODAY_DATE}" for sample_type in sample_types}
for save_dir in save_dirs.values():
    save_dir.mkdir(exist_ok=True, parents=True)

# --- Load grains -------------------------------------------------------------
# grains_dicts maps sample type -> {grain index -> grain dictionary}.
grains_dicts = {}
for sample_type in sample_types:
    grains_dicts[sample_type] = {}
    file_path = data_dirs[sample_type] / "grain_dict.pkl"
    # SECURITY: unpickling can execute arbitrary code. Only load grain
    # dictionaries that were produced by a trusted run of the extraction step.
    with open(file_path, "rb") as f:
        sample_grains_dicts = pickle.load(f)

    # Filtering happens after the file has been closed; the handle is only
    # needed for the load itself.
    for grain_index, single_grain_dict in sample_grains_dicts.items():
        p_to_nm = single_grain_dict["p_to_nm"]
        if p_to_nm > MAX_PX_TO_NM:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}] due to p_to_nm too large: {p_to_nm} > {MAX_PX_TO_NM}"
            )
        else:
            grains_dicts[sample_type][grain_index] = single_grain_dict


# Report how many grains survived the filter for each sample type
for sample_type, grain_dict in grains_dicts.items():
    print(f"Number of grains for sample type [{sample_type}]: {len(grain_dict)}")

In [None]:
def plot_all_grains_dictionary(
    grain_dict: dict,
    num_cols: int = 3,
    vmin: float = VMIN,
    vmax: float = VMAX,
    cmap: mpl.colors.Colormap | str = CMAP,
    plot_paths: bool = False,
    save_path: Path | None = None,
) -> None:
    """
    Plot every grain in a grain dictionary as image / mask pairs in a grid layout.

    Each grain takes two neighbouring axes: its image on the left and its
    predicted mask on the right, so the figure has ``2 * num_cols`` axes per row.

    Parameters
    ----------
    grain_dict : dict
        Maps a grain index to that grain's dictionary. Each grain dictionary must
        provide the keys ``"image"``, ``"predicted_mask"`` and ``"p_to_nm"``, and may
        provide ``"path"`` (an array indexed as ``path[:, 0]`` = row, ``path[:, 1]`` = column).
    num_cols : int
        Number of grains per row.
    vmin : float
        Lower colour-scale limit for the grain image.
    vmax : float
        Upper colour-scale limit for the grain image.
    cmap : mpl.colors.Colormap | str
        Colormap used for the grain image.
    plot_paths : bool
        If True, draw each grain's ``"path"`` over its mask when one is present and not None.
    save_path : Path | None
        If given, the figure is written to this file (parent directories are
        created as needed) before it is shown.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``; nothing is returned.
    """
    num_grains = len(grain_dict)
    if num_grains == 0:
        # plt.subplots() rejects nrows=0, so there is no grid to build.
        print("No grains to plot")
        return

    # Ceiling division: a partially filled last row still needs a row of axes
    num_rows = (num_grains + num_cols - 1) // num_cols
    num_images_per_grain = 2  # image and mask
    fig, ax = plt.subplots(
        nrows=num_rows,
        ncols=num_cols * num_images_per_grain,
        figsize=(num_cols * 4, num_rows * 4),
        constrained_layout=True,
        dpi=200,
        # Always return a 2-D axes array; without this a single-row grid comes
        # back 1-D and the ax[row, col] indexing below raises IndexError.
        squeeze=False,
    )
    for i, (grain_index, grain_data) in enumerate(grain_dict.items()):
        grain_image = grain_data["image"]
        grain_mask = grain_data["predicted_mask"]
        grain_p_to_nm = grain_data["p_to_nm"]

        row = i // num_cols
        col = (i % num_cols) * num_images_per_grain

        image_ax = ax[row, col]
        image_ax.imshow(grain_image, vmin=vmin, vmax=vmax, cmap=cmap)
        image_ax.set_title(f"Grain {grain_index}\np_to_nm: {grain_p_to_nm:.2f} nm/px")
        image_ax.axis("off")

        # plot mask
        mask_ax = ax[row, col + 1]
        mask_ax.imshow(grain_mask)
        # plot paths if present and required
        if plot_paths and "path" in grain_data:
            path = grain_data["path"]
            if path is not None:
                # path is (row, col); matplotlib wants x = col, y = row
                mask_ax.plot(path[:, 1], path[:, 0], color="white", linewidth=2)
        mask_ax.set_title("Mask")
        mask_ax.axis("off")

    # Grain i occupies flat axes 2*i and 2*i + 1, so everything after the last
    # grain is an unused slot in the final row: hide its empty frame and ticks.
    for unused_ax in ax.flat[num_grains * num_images_per_grain :]:
        unused_ax.axis("off")

    if save_path is not None:
        # create the parent directory
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300)

    plt.show()

In [None]:
# plot all grains
for sample_type, grain_dict in grains_dicts.items():
    print(f"Plotting grains for sample type [{sample_type}]")
    # Let the plotting function write the file itself. It finishes with
    # plt.show(), and a plt.savefig() issued after show() can end up saving a
    # brand-new empty figure instead of the grid that was just drawn.
    plot_all_grains_dictionary(
        grain_dict,
        vmin=VMIN,
        vmax=VMAX,
        cmap=CMAP,
        save_path=save_dirs[sample_type] / f"grains_{sample_type}.png",
    )
    # Release the figure if the active backend kept it open after show()
    plt.close()

In [None]:
# pathfinding
def _max_distance_pixel(region, distance_transform: np.ndarray) -> tuple[int, int]:
    """Return the (row, col) pixel of ``region`` that has the largest distance-transform value."""
    coords = region.coords
    region_values = distance_transform[coords[:, 0], coords[:, 1]]
    # argmax returns the first maximum, i.e. the first such pixel in coords order
    row, col = coords[np.argmax(region_values)]
    return int(row), int(col)


# grains_dicts_paths mirrors grains_dicts, with a "path" entry added to every grain
grains_dicts_paths = {}
for sample_type, grains_dict_sample in grains_dicts.items():
    grains_dicts_paths[sample_type] = {}
    for grain_index, grain_data in grains_dict_sample.items():
        grain_mask = grain_data["predicted_mask"]

        # Distance of every mask pixel from the nearest non-mask pixel, with
        # pixels labelled 2 in the mask forced to zero.
        # NOTE(review): the meaning of mask class 2 is not visible here — confirm.
        distance_transform = distance_transform_edt(grain_mask > 0)
        distance_transform[grain_mask == 2] = 0

        # The path runs between the first two connected intersection regions
        intersection_regions = regionprops(label(grain_data["intersection_labels"]))
        if len(intersection_regions) < 2:
            print(
                f"Skipping grain {grain_index} for sample type [{sample_type}]: "
                f"found {len(intersection_regions)} intersection region(s), need at least 2"
            )
            continue

        # Start / end at the pixel of each region with the largest distance-transform value
        start_point = _max_distance_pixel(intersection_regions[0], distance_transform)
        end_point = _max_distance_pixel(intersection_regions[1], distance_transform)

        # invert the distance transform to get the cost map
        cost_map = np.max(distance_transform) - distance_transform
        # Pixels with the highest cost (distance transform of zero) are made
        # prohibitively expensive so the route avoids them.
        cost_map[cost_map == np.max(cost_map)] = 1e4

        route, _ = route_through_array(array=cost_map, start=start_point, end=end_point)

        # Store a shallow copy with the route added, so the dictionaries held in
        # grains_dicts are left untouched and this cell can be re-run safely.
        grains_dicts_paths[sample_type][grain_index] = {**grain_data, "path": np.array(route)}

# Draw each sample type's grid again, this time with the traced path over every
# mask, and let the plotting function write the figure to the save directory.
for sample_type, grain_dict in grains_dicts_paths.items():
    print(f"Plotting grains with paths for sample type [{sample_type}]")
    paths_figure_file = save_dirs[sample_type] / f"grains_with_paths_{sample_type}.png"
    plot_all_grains_dictionary(
        grain_dict,
        save_path=paths_figure_file,
        plot_paths=True,
        cmap=CMAP,
        vmax=VMAX,
        vmin=VMIN,
    )