In [None]:
from pathlib import Path
import re
import pickle

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import distance_transform_edt
from skimage.measure import label, regionprops
from skimage.graph import route_through_array
from skimage.morphology import binary_dilation, binary_erosion

from topostats.plottingfuncs import Colormap

# Colormap helper imported from topostats; CMAP is the object the plotting code hands to imshow as `cmap`
colormap = Colormap()
CMAP = colormap.get_cmap()

# Colour limits handed to imshow as vmin / vmax when grain images are drawn
VMIN = 0
VMAX = 4

In [None]:
# --- Configuration -----------------------------------------------------------------------------------------------
SAMPLE_TYPE = "ON_SC"
DATE = "2024-03-22"
# NOTE(review): absolute, user-specific location; adjust when running on another machine
DATA_DIR = Path(f"/users/sylvi/topo_data/hariborings/extracted_grains/cas9_{SAMPLE_TYPE}/{DATE}")
MAX_PX_TO_NM = 10.0
PLOT_RESULTS = True
# Upper bound on the number of entries copied into grain_dicts_sample
MAX_GRAIN_NUMBER = 100

# --- Load the pickled grain dictionary ---------------------------------------------------------------------------
FILE_PATH = DATA_DIR / "grain_dict.pkl"
# SECURITY: pickle.load can execute arbitrary code stored in the file; only open pickles from a trusted source
with open(FILE_PATH, "rb") as f:
    grain_dicts = pickle.load(f)

print(f"number of images for sample type [{SAMPLE_TYPE}] : {len(grain_dicts.keys())}")

# Cut off the number of grains: keep the first MAX_GRAIN_NUMBER entries in the dictionary's iteration order.
# (The previous `i > MAX_GRAIN_NUMBER` test let one entry too many through.)
grain_dicts_sample = {
    key: grain_dictionary
    for position, (key, grain_dictionary) in enumerate(grain_dicts.items())
    if position < MAX_GRAIN_NUMBER
}

print(f"number of images in sample for {SAMPLE_TYPE} : {len(grain_dicts_sample.keys())}")

grain_dict = grain_dicts_sample

In [None]:
def plot_images(
    images: list, masks: list, grain_indexes: list, px_to_nms: list, width=5, VMIN=VMIN, VMAX=VMAX, cmap=CMAP
):
    """Draw every grain image next to its mask on a grid holding ``width`` image/mask pairs per row.

    Parameters
    ----------
    images : list
        Arrays drawn with ``imshow`` using ``cmap`` and the ``VMIN`` / ``VMAX`` colour limits.
    masks : list
        Arrays drawn (cast to ``int``) in the panel to the right of the matching image.
    grain_indexes : list
        Identifiers shown in each image title.
    px_to_nms : list
        Scale values shown verbatim in each image title.
    width : int
        Number of image/mask pairs per row.
    VMIN, VMAX
        Colour limits passed to ``imshow`` for the images.
    cmap
        Colormap passed to ``imshow`` for the images.
    """
    panels_per_grain = 2  # one panel for the image, one for its mask
    num_images = len(images)
    # Ceiling division: a completely filled last row must not add an extra empty row; always keep one row
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` two-dimensional for a single row, so axes[row, col] never fails
    fig, axes = plt.subplots(
        num_rows, width * panels_per_grain, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Switch every panel off up front so that unused grid slots stay blank
    for ax in axes.ravel():
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm) in enumerate(zip(images, masks, grain_indexes, px_to_nms)):
        row = i // width
        first_col = (i % width) * panels_per_grain
        # Plot image (honours the `cmap` argument instead of always using the module-level CMAP)
        im_ax = axes[row, first_col]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask
        mask_ax = axes[row, first_col + 1]
        mask_ax.imshow(mask.astype(int))

    fig.tight_layout()
    plt.show()


# Overview figure of every loaded grain; skipped entirely when PLOT_RESULTS is False
if PLOT_RESULTS:
    grain_indexes = list(grain_dicts)
    images = [record["image"] for record in grain_dicts.values()]
    masks = [record["predicted_mask"] for record in grain_dicts.values()]
    px_to_nms = [record["p_to_nm"] for record in grain_dicts.values()]
    plot_images(images=images, masks=masks, grain_indexes=grain_indexes, px_to_nms=px_to_nms)

# Pathfinding

In [None]:
def flip_if_anticlockwise(trace: np.ndarray):
    """Reverse ``trace`` when the summed turning of its consecutive steps is positive.

    Each step between neighbouring points is normalised to unit length and the scalar cross product
    ``a[0] * b[1] - a[1] * b[0]`` of every pair of consecutive steps is summed.

    * sum > 0: the point order is reversed. For (row, col) points drawn over ``imshow`` (row axis pointing
      down) this is a trace that turns anticlockwise on screen.
    * sum < 0 or sum == 0 (straight line, or fewer than three points): the trace is returned unchanged.

    Parameters
    ----------
    trace : np.ndarray
        Array of shape (N, 2) holding the ordered points of the trace. Integer arrays are accepted.

    Returns
    -------
    np.ndarray
        ``trace`` itself, or ``np.flip(trace, axis=0)`` when the summed turning is positive.
    """
    # Work on a float copy: the input is never modified and integer traces can be normalised safely
    points = np.asarray(trace, dtype=float)
    if len(points) < 3:
        return trace

    steps = np.diff(points, axis=0)
    step_lengths = np.linalg.norm(steps, axis=1)
    # Repeated points give zero-length steps; normalising them would be 0/0 and turn the whole sum into NaN
    moving = step_lengths > 0
    unit_steps = steps[moving] / step_lengths[moving, np.newaxis]

    # Scalar (z-component) cross product of each unit step with the next one, written out explicitly
    # because np.cross on 2-element vectors is deprecated in NumPy 2
    turning = unit_steps[:-1, 0] * unit_steps[1:, 1] - unit_steps[:-1, 1] * unit_steps[1:, 0]
    cross_sum = turning.sum()

    if cross_sum > 0:
        # Reverse the trace
        return np.flip(trace, axis=0)
    return trace


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Draw every grain image next to its mask, with the grain's path overlaid on the mask panel.

    Parameters
    ----------
    images : list
        Arrays drawn with ``imshow`` using ``cmap`` and the ``VMIN`` / ``VMAX`` colour limits.
    masks : list
        Arrays drawn (cast to ``int``) in the panel to the right of the matching image.
    grain_indexes : list
        Identifiers shown in each image title.
    px_to_nms : list
        Scale values shown verbatim in each image title.
    paths : list
        One array per grain; column 0 is plotted on the vertical axis and column 1 on the horizontal axis.
    width : int
        Number of image/mask pairs per row.
    VMIN, VMAX
        Colour limits passed to ``imshow`` for the images.
    cmap
        Colormap passed to ``imshow`` for the images.
    """
    panels_per_grain = 2  # one panel for the image, one for its mask
    num_images = len(images)
    # Ceiling division: a completely filled last row must not add an extra empty row; always keep one row
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` two-dimensional for a single row, so axes[row, col] never fails
    fig, axes = plt.subplots(
        num_rows, width * panels_per_grain, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Switch every panel off up front so that unused grid slots stay blank
    for ax in axes.ravel():
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path) in enumerate(zip(images, masks, grain_indexes, px_to_nms, paths)):
        row = i // width
        first_col = (i % width) * panels_per_grain
        # Plot image (honours the `cmap` argument instead of always using the module-level CMAP)
        im_ax = axes[row, first_col]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, first_col + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


def _peak_distance_pixel(region, distance_map: np.ndarray) -> tuple:
    """Return the (row, col) pixel of ``region`` at which ``distance_map`` is largest (first one on ties)."""
    coords = region.coords
    values = distance_map[coords[:, 0], coords[:, 1]]
    return tuple(coords[np.argmax(values)])


grain_dict_paths = {}

for index, grain_dict in grain_dicts.items():
    mask = grain_dict["predicted_mask"]
    intersection_labels = grain_dict["intersection_labels"]

    # Get distance transform of everything with a positive mask value, then zero it wherever mask == 2
    distance_transform = distance_transform_edt(mask > 0)
    distance_transform[mask == 2] = 0

    # The route runs between the first two connected intersection regions
    intersection_regions = regionprops(label(intersection_labels))
    if len(intersection_regions) < 2:
        raise ValueError(
            f"grain {index}: need at least 2 intersection regions to trace a path, found {len(intersection_regions)}"
        )

    # Start / end at the pixel of each region where the distance transform is largest
    start_point = _peak_distance_pixel(intersection_regions[0], distance_transform)
    end_point = _peak_distance_pixel(intersection_regions[1], distance_transform)

    # Turn the distance map into a cost map: large distance -> cheap to cross
    inverted_distance_transform = np.max(distance_transform) - distance_transform
    # Pixels at the maximum cost (distance transform of 0) get a much larger cost so the route avoids them
    inverted_distance_transform[inverted_distance_transform == np.max(inverted_distance_transform)] = 1000

    route, _route_cost = route_through_array(inverted_distance_transform, start_point, end_point)
    route = flip_if_anticlockwise(np.array(route).astype(float))

    # NOTE: this stores the same dict object that lives in grain_dicts, so "path" is added there as well
    grain_dict_paths[index] = grain_dict
    grain_dict_paths[index]["path"] = route

# Path overlay plot, currently disabled:
# plot_images(
#     [grain_dict_paths[i]["image"] for i in grain_dict_paths],
#     [grain_dict_paths[i]["predicted_mask"] for i in grain_dict_paths],
#     [i for i in grain_dict_paths],
#     [grain_dict_paths[i]["p_to_nm"] for i in grain_dict_paths],
#     [grain_dict_paths[i]["path"] for i in grain_dict_paths],
# )

# Pooling

In [None]:
# Pool the traces


def _pool_path(path: np.ndarray, window: int) -> np.ndarray:
    """Smooth ``path`` with a moving average over ``window`` consecutive points, keeping both end points.

    For every centre index ``i`` in ``range(window // 2, len(path) - window // 2)`` the points
    ``path[i - window // 2 : i - window // 2 + window]`` are averaged. The first and last point of ``path``
    are then added back at the two ends. A path too short for any window yields just those two points.
    """
    half = window // 2
    # Slicing from i - half (never negative because i >= half) avoids the wrap-around to the end of the
    # path that `range(-window // 2, ...)` produced for odd window sizes
    averaged = [path[i - half : i - half + window].mean(axis=0) for i in range(half, len(path) - half)]
    # reshape keeps a (0, n_coords) shape when there are no windows, so the concatenation below still works
    averaged = np.array(averaged, dtype=float).reshape(-1, path.shape[1])
    # Add start and end points to the path
    return np.concatenate([path[:1], averaged, path[-1:]])


pooled_path_grain_dict = {}
# Number of points to pool
n = 6
for index, grain_dict in grain_dict_paths.items():
    path = grain_dict["path"]

    pooled_path = flip_if_anticlockwise(_pool_path(path, n))

    # NOTE: this stores the same dict object that lives in grain_dict_paths, so the key is added there as well
    pooled_path_grain_dict[index] = grain_dict
    pooled_path_grain_dict[index]["pooled_path"] = pooled_path

In [None]:
# Calculate binding angle using the vector of the last n nanometres of the path


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    binding_vector_positions: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Draw every grain image next to its mask, with the grain's path overlaid on the mask panel.

    Parameters
    ----------
    images : list
        Arrays drawn with ``imshow`` using ``cmap`` and the ``VMIN`` / ``VMAX`` colour limits.
    masks : list
        Arrays drawn (cast to ``int``) in the panel to the right of the matching image.
    grain_indexes : list
        Identifiers shown in each image title.
    px_to_nms : list
        Scale values shown verbatim in each image title.
    paths : list
        One array per grain; column 0 is plotted on the vertical axis and column 1 on the horizontal axis.
    binding_vector_positions : list
        One entry per grain. The entries are not drawn yet; the list only limits how many grains are
        plotted because it takes part in the ``zip``.
    width : int
        Number of image/mask pairs per row.
    VMIN, VMAX
        Colour limits passed to ``imshow`` for the images.
    cmap
        Colormap passed to ``imshow`` for the images.
    """
    panels_per_grain = 2  # one panel for the image, one for its mask
    num_images = len(images)
    # Ceiling division: a completely filled last row must not add an extra empty row; always keep one row
    num_rows = max(1, (num_images + width - 1) // width)
    # squeeze=False keeps `axes` two-dimensional for a single row, so axes[row, col] never fails
    fig, axes = plt.subplots(
        num_rows, width * panels_per_grain, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Switch every panel off up front so that unused grid slots stay blank
    for ax in axes.ravel():
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path, _binding_vector_position) in enumerate(
        zip(images, masks, grain_indexes, px_to_nms, paths, binding_vector_positions)
    ):
        row = i // width
        first_col = (i % width) * panels_per_grain
        # Plot image (honours the `cmap` argument instead of always using the module-level CMAP)
        im_ax = axes[row, first_col]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} p/nm")
        # Plot mask with the path on top
        mask_ax = axes[row, first_col + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")
        # TODO: draw the binding vectors here (the previous placeholder plotted an empty line)

    fig.tight_layout()
    plt.show()


def _index_after_distance(path: np.ndarray, start_index: int, step: int, min_distance_nm: float, p_to_nm: float) -> int:
    """Walk along ``path`` from ``start_index`` and return the first index at least ``min_distance_nm`` away.

    The travelled length is the sum of the point-to-point distances multiplied by ``p_to_nm``. ``step`` is
    +1 to walk towards the end of the path and -1 to walk towards its start. The walk stops at the last
    valid index when the path is shorter than ``min_distance_nm`` instead of running past the array bounds.
    """
    final_index = len(path) - 1 if step > 0 else 0
    current = start_index
    travelled = 0
    while travelled < min_distance_nm and current != final_index:
        current += step
        travelled += np.linalg.norm(path[current] - path[current - step]) * p_to_nm
    return current


def _angle_0_to_2pi(vector: np.ndarray) -> float:
    """Return ``arctan2(vector[0], vector[1])`` shifted into the range [0, 2*pi)."""
    angle = np.arctan2(vector[0], vector[1])
    return angle + np.pi * 2 if angle < 0 else angle


binding_angle_grain_dict = {}

# Path length (same unit as the p_to_nm-scaled distances) that each end vector has to span
n = 2

for index, grain_dict in pooled_path_grain_dict.items():
    pooled_path = grain_dict["pooled_path"]
    p_to_nm = grain_dict["p_to_nm"]

    # Get start vector: from the first point to the first point at least n along the path
    start_vector_start = pooled_path[0]
    start_vector_end = pooled_path[_index_after_distance(pooled_path, 0, 1, n, p_to_nm)]
    start_vector = start_vector_end - start_vector_start

    # Get end vector: from the last point back to the first point at least n along the path
    end_vector_start = pooled_path[-1]
    end_vector_end = pooled_path[_index_after_distance(pooled_path, len(pooled_path) - 1, -1, n, p_to_nm)]
    end_vector = end_vector_end - end_vector_start

    end_angle = _angle_0_to_2pi(end_vector)
    start_angle = _angle_0_to_2pi(start_vector)

    # Calculate angle difference
    # NOTE(review): the difference of two angles in [0, 2*pi) lies in (-2*pi, 2*pi); confirm whether it
    # should be wrapped before it is used as a binding angle
    angle_diff = end_angle - start_angle

    # Keep the result: a shallow copy of the grain entry plus the angles (radians), so the entries of
    # pooled_path_grain_dict are left untouched
    binding_angle_grain_dict[index] = {
        **grain_dict,
        "start_angle": start_angle,
        "end_angle": end_angle,
        "angle_diff": angle_diff,
    }

    # Show the image with the pooled path, both end points and both end vectors
    plt.imshow(grain_dict["image"], cmap=CMAP, vmin=VMIN, vmax=VMAX)
    plt.scatter(start_vector_start[1], start_vector_start[0], c="blue", s=100)
    plt.scatter(end_vector_start[1], end_vector_start[0], c="red", s=100)
    plt.plot(pooled_path[:, 1], pooled_path[:, 0], color="white")
    plt.plot([start_vector_start[1], start_vector_end[1]], [start_vector_start[0], start_vector_end[0]], color="blue")
    plt.arrow(
        start_vector_start[1],
        start_vector_start[0],
        start_vector[1],
        start_vector[0],
        color="blue",
        head_width=5,
        head_length=10,
    )
    plt.arrow(
        end_vector_start[1],
        end_vector_start[0],
        end_vector[1],
        end_vector[0],
        color="red",
        head_width=5,
        head_length=10,
    )
    plt.plot([end_vector_start[1], end_vector_end[1]], [end_vector_start[0], end_vector_end[0]], color="red")
    plt.title(
        f"end angle: {np.degrees(end_angle):.2f} start angle: {np.degrees(start_angle):.2f} angle diff: {np.degrees(angle_diff):.2f}"
    )
    plt.show()

In [None]:
# Characterize the binding angle through intersection region position around perimeter

perimeter_intersection_grain_dict = {}

# The iterable was missing here (syntax error). pooled_path_grain_dict is used because its entries carry
# every key read below, including "pooled_path".
for index, grain_dict in pooled_path_grain_dict.items():
    image = grain_dict["image"]
    mask = grain_dict["predicted_mask"]
    p_to_nm = grain_dict["p_to_nm"]
    pooled_path = grain_dict["pooled_path"]
    intersection_labels = grain_dict["intersection_labels"]

    # Get perimeter of the gem region: the one-pixel ring that dilation adds around the mask == 2 pixels
    gem_region = mask == 2
    gem_outline = binary_dilation(gem_region) ^ gem_region

    # The right-hand side was missing here (syntax error). NOTE(review): assumed to be the mean (row, col)
    # of the gem pixels - confirm. None when the mask holds no gem pixels.
    gem_pixels = np.argwhere(gem_region)
    gem_centroid = gem_pixels.mean(axis=0) if len(gem_pixels) else None

    # Outline pixels are 1, intersection pixels are 2, everything else is 0
    gem_region_with_intersections = np.copy(gem_outline).astype(int)
    gem_region_with_intersections[intersection_labels > 0] = 2

    plt.imshow(gem_region_with_intersections)

    # Work in progress: only the first grain is previewed
    break