In [None]:
from pathlib import Path
import re
import pickle

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import distance_transform_edt, gaussian_filter1d, binary_fill_holes
from skimage.measure import label, regionprops
from skimage.graph import route_through_array
from skimage.morphology import binary_dilation, binary_erosion, skeletonize

from topostats.plottingfuncs import Colormap

# Shared colour map used to display every grain image in this notebook
colormap = Colormap()
CMAP = colormap.get_cmap()

# Colour limits passed to imshow for the grain images (presumably heights; units not shown here)
VMIN, VMAX = 0, 4

In [None]:
SAMPLE_TYPE = "ON_SC"
DATE = "2024-03-22"
# NOTE(review): absolute, user-specific path; consider making the data root configurable
DATA_DIR = Path(f"/users/sylvi/topo_data/hariborings/extracted_grains/cas9_{SAMPLE_TYPE}/{DATE}")
MAX_PX_TO_NM = 10.0
PLOT_RESULTS = True
# Maximum number of entries kept in `grain_dicts_sample`
MAX_GRAIN_NUMBER = 100

FILE_PATH = DATA_DIR / "grain_dict.pkl"
# pickle.load can execute arbitrary code: only load files produced by our own pipeline
with open(FILE_PATH, "rb") as f:
    grain_dicts = pickle.load(f)

print(f"number of images for sample type [{SAMPLE_TYPE}] : {len(grain_dicts)}")

# Cut off the number of grains: keep exactly the first MAX_GRAIN_NUMBER entries
grain_dicts_sample = {
    key: grain_dictionary
    for i, (key, grain_dictionary) in enumerate(grain_dicts.items())
    if i < MAX_GRAIN_NUMBER
}

print(f"number of images in sample for {SAMPLE_TYPE} : {len(grain_dicts_sample)}")

# NOTE(review): the cells below iterate over `grain_dicts`, not this sample, so the cap does not
# currently limit the analysis.
grain_dict = grain_dicts_sample

In [None]:
def plot_images(
    images: list, masks: list, grain_indexes: list, px_to_nms: list, width=5, VMIN=VMIN, VMAX=VMAX, cmap=CMAP
):
    """Plot each grain image next to its predicted mask in a grid.

    Every grain takes two adjacent columns (image, then mask), with ``width`` grains per row.

    Parameters
    ----------
    images : list
        Grain images, drawn with ``cmap`` and colour limits ``VMIN``/``VMAX``.
    masks : list
        Predicted masks in the same order as ``images``.
    grain_indexes : list
        Identifiers shown in each image title.
    px_to_nms : list
        Pixel-to-nm scaling factors (nm per pixel) shown in each image title.
    width : int
        Number of grains per row.
    VMIN, VMAX : float
        Colour limits for the images.
    cmap :
        Colour map for the images.
    """
    num_images_in_batch = 2
    # Ceiling division, with at least one row so plt.subplots always gets a valid grid
    num_rows = max(1, -(-len(images) // width))
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] always works
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide axes everywhere, including unused trailing panels
    for ax in axes.flat:
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm) in enumerate(zip(images, masks, grain_indexes, px_to_nms)):
        row, col = divmod(i, width)
        # Plot image
        im_ax = axes[row, col * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} nm/px")
        # Plot mask
        mask_ax = axes[row, col * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))

    fig.tight_layout()
    plt.show()


if PLOT_RESULTS:
    # Overview of every loaded grain: image beside its predicted mask
    grain_indexes = list(grain_dicts)
    images = [grain_dicts[key]["image"] for key in grain_indexes]
    masks = [grain_dicts[key]["predicted_mask"] for key in grain_indexes]
    px_to_nms = [grain_dicts[key]["p_to_nm"] for key in grain_indexes]
    plot_images(images, masks, grain_indexes, px_to_nms)

# Pathfinding

In [None]:
def flip_if_anticlockwise(trace: np.ndarray):
    """Return the trace in clockwise order, reversing it if it runs anticlockwise.

    Orientation is judged in (row, col) image coordinates as displayed by ``imshow`` (rows increase
    downwards). The trace's turning direction is the sum of the 2-D cross products of consecutive unit
    segment vectors. In these coordinates a positive sum means the trace turns anticlockwise on screen;
    the trace is then reversed. A zero sum (e.g. a straight line) leaves it unchanged.

    Parameters
    ----------
    trace : np.ndarray
        Ordered (N, 2) array of (row, col) points. Integer or float dtype.

    Returns
    -------
    np.ndarray
        ``trace`` itself, or a reversed view of it along axis 0.
    """
    # Fewer than three points have no turn to measure
    if len(trace) < 3:
        return trace

    segments = np.diff(np.asarray(trace, dtype=float), axis=0)
    lengths = np.linalg.norm(segments, axis=1)
    # Zero-length segments (repeated points) have no direction; drop them instead of dividing by zero
    has_length = lengths > 0
    unit_segments = segments[has_length] / lengths[has_length][:, np.newaxis]

    # z-component of the cross product between each pair of consecutive unit segments
    # (written out explicitly: np.cross on 2-element vectors is deprecated in NumPy 2.0)
    cross_sum = np.sum(
        unit_segments[:-1, 0] * unit_segments[1:, 1] - unit_segments[:-1, 1] * unit_segments[1:, 0]
    )

    if cross_sum > 0:
        # Anticlockwise on screen: reverse the trace
        return np.flip(trace, axis=0)
    return trace


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Plot each grain image next to its predicted mask, with a path drawn over the mask.

    Every grain takes two adjacent columns (image, then mask plus path), with ``width`` grains per row.

    Parameters
    ----------
    images : list
        Grain images, drawn with ``cmap`` and colour limits ``VMIN``/``VMAX``.
    masks : list
        Predicted masks in the same order as ``images``.
    grain_indexes : list
        Identifiers shown in each image title.
    px_to_nms : list
        Pixel-to-nm scaling factors (nm per pixel) shown in each image title.
    paths : list
        (N, 2) arrays of (row, col) points, drawn in white over the mask.
    width : int
        Number of grains per row.
    VMIN, VMAX : float
        Colour limits for the images.
    cmap :
        Colour map for the images.
    """
    num_images_in_batch = 2
    # Ceiling division, with at least one row so plt.subplots always gets a valid grid
    num_rows = max(1, -(-len(images) // width))
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] always works
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide axes everywhere, including unused trailing panels
    for ax in axes.flat:
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path) in enumerate(zip(images, masks, grain_indexes, px_to_nms, paths)):
        row, col = divmod(i, width)
        # Plot image
        im_ax = axes[row, col * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} nm/px")
        # Plot mask with the path on top (plot takes x=col, y=row)
        mask_ax = axes[row, col * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


grain_dict_paths = {}

for index, grain_dict in grain_dicts.items():
    mask = grain_dict["predicted_mask"]

    # Distance of each molecule pixel (mask > 0) from the background; the gem region (mask == 2) is
    # zeroed so it ends up with the penalty cost below
    distance_transform = distance_transform_edt(mask > 0)
    distance_transform[mask == 2] = 0

    # Route between the first two intersection regions, each entered at its pixel with the largest
    # distance-transform value
    intersection_regions = regionprops(label(grain_dict["intersection_labels"]))
    if len(intersection_regions) < 2:
        print(f"Grain {index}: found {len(intersection_regions)} intersection region(s), need 2; skipping")
        continue
    start_point, end_point = [
        tuple(region.coords[np.argmax(distance_transform[region.coords[:, 0], region.coords[:, 1]])])
        for region in intersection_regions[:2]
    ]

    # Cheapest along the molecule's ridge; pixels with zero distance (background and gem) get a large cost
    inverted_distance_transform = np.max(distance_transform) - distance_transform
    inverted_distance_transform[inverted_distance_transform == np.max(inverted_distance_transform)] = 1000

    route, _ = route_through_array(inverted_distance_transform, start_point, end_point)
    route = flip_if_anticlockwise(np.array(route).astype(float))

    # New dict per grain so the loaded `grain_dicts` entries are not mutated
    grain_dict_paths[index] = {**grain_dict, "path": route}

plot_images(
    [entry["image"] for entry in grain_dict_paths.values()],
    [entry["predicted_mask"] for entry in grain_dict_paths.values()],
    list(grain_dict_paths),
    [entry["p_to_nm"] for entry in grain_dict_paths.values()],
    [entry["path"] for entry in grain_dict_paths.values()],
)

# Pooling

In [None]:
# Pool the traces


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Plot each grain image next to its predicted mask, with a (pooled) path drawn over the mask.

    Every grain takes two adjacent columns (image, then mask plus path), with ``width`` grains per row.

    Parameters
    ----------
    images : list
        Grain images, drawn with ``cmap`` and colour limits ``VMIN``/``VMAX``.
    masks : list
        Predicted masks in the same order as ``images``.
    grain_indexes : list
        Identifiers shown in each image title.
    px_to_nms : list
        Pixel-to-nm scaling factors (nm per pixel) shown in each image title.
    paths : list
        (N, 2) arrays of (row, col) points, drawn in white over the mask.
    width : int
        Number of grains per row.
    VMIN, VMAX : float
        Colour limits for the images.
    cmap :
        Colour map for the images.
    """
    num_images_in_batch = 2
    # Ceiling division, with at least one row so plt.subplots always gets a valid grid
    num_rows = max(1, -(-len(images) // width))
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] always works
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide axes everywhere, including unused trailing panels
    for ax in axes.flat:
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path) in enumerate(zip(images, masks, grain_indexes, px_to_nms, paths)):
        row, col = divmod(i, width)
        # Plot image
        im_ax = axes[row, col * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} nm/px")
        # Plot mask with the path on top (plot takes x=col, y=row)
        mask_ax = axes[row, col * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


pooled_path_grain_dict = {}
# Number of points to pool (window length of the moving average)
n = 6
half_window = n // 2

for index, grain_dict in grain_dict_paths.items():
    path = grain_dict["path"]

    # Moving average: the pooled point for position i is the mean of path[i - n // 2 : i - n // 2 + n].
    # Only positions with a full window inside the path are pooled (for odd n the window is centred on i).
    pooled_points = [
        path[i - half_window : i - half_window + n].mean(axis=0)
        for i in range(half_window, len(path) - half_window)
    ]
    # reshape keeps a (0, 2) shape when the path is too short to pool, so the concatenate below still works
    pooled_path = np.array(pooled_points).reshape(-1, path.shape[1])

    # Keep the original end points so the pooled path still spans the whole route
    pooled_path = np.concatenate([path[:1], pooled_path, path[-1:]])

    pooled_path = flip_if_anticlockwise(pooled_path)

    # New dict per grain so the input dicts are not mutated
    pooled_path_grain_dict[index] = {**grain_dict, "pooled_path": pooled_path}

plot_images(
    [entry["image"] for entry in pooled_path_grain_dict.values()],
    [entry["predicted_mask"] for entry in pooled_path_grain_dict.values()],
    list(pooled_path_grain_dict),
    [entry["p_to_nm"] for entry in pooled_path_grain_dict.values()],
    [entry["pooled_path"] for entry in pooled_path_grain_dict.values()],
)

In [None]:
# Calculate binding angle using the vector of the last n nanometres of the path


def angle_diff_signed(v1: np.ndarray, v2: np.ndarray):
    """Calculate the signed angle difference between two 2-D vectors.

    Vectors are in (row, col) image coordinates. The magnitude is the unsigned angle between them, in
    radians, in [0, pi]. The sign is negative when the 2-D cross product ``v1 x v2`` is positive. On an
    image displayed with rows increasing downwards, that is a turn from ``v1`` to ``v2`` that goes
    anticlockwise. Otherwise the sign is positive (or the angle is zero).

    Parameters
    ----------
    v1: np.ndarray
        The first (incoming) vector, of length 2.
    v2: np.ndarray
        The second (outgoing) vector, of length 2.

    Returns
    -------
    float
        The signed angle difference between the two vectors, in radians.
    """
    # Unsigned angle; clip guards arccos against rounding just outside [-1, 1]
    cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))

    # z-component of the 2-D cross product (explicit: np.cross on 2-element vectors is deprecated in NumPy 2.0)
    cross = v1[0] * v2[1] - v1[1] * v2[0]

    # Positive cross product: v2 is rotated anticlockwise (on screen) from v1 -> negative angle
    if cross > 0:
        angle = -angle

    return angle


def angle_per_nm_non_closed(
    trace: np.ndarray, p_to_nm: float, plot: bool = False
) -> tuple[np.ndarray, np.ndarray]:
    """Calculate the signed turning angle, and turning angle per nm, at each point of a non-closed trace.

    Parameters
    ----------
    trace: np.ndarray
        Ordered (N, 2) array of (row, col) pixel coordinates.
    p_to_nm: float
        The pixel to nm scaling factor; pixel distances are multiplied by it to get nm.
    plot: bool
        If True, draw the trace with the incoming (red) and outgoing (blue) vectors and the angle in
        degrees at each point.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(angles_per_nm, angle_diffs)``, both of length N.
        ``angle_diffs`` is the signed angle (radians, from ``angle_diff_signed``) between the incoming and
        outgoing segments at each interior point.
        ``angles_per_nm`` is that angle divided by the length of the incoming segment in nm.
        Both are 0 at the two end points, because a non-closed trace does not turn there.
    """
    num_points = len(trace)
    angles_per_nm = np.zeros(num_points)
    angle_diffs = np.zeros(num_points)

    if plot:
        _, ax = plt.subplots(1, 1, figsize=(10, 10))

    for index, point in enumerate(trace):
        # Incoming/outgoing segment vectors. At the two ends they wrap around to the other end of the
        # trace; there they are only drawn, never used for the angle.
        v_prev = point - trace[index - 1]
        v_next = trace[(index + 1) % num_points] - point

        if index == 0 or index == num_points - 1:
            angle = 0
        else:
            angle = angle_diff_signed(v_prev, v_next)
            # Turning per nm of path, measured over the incoming segment
            angles_per_nm[index] = angle / (np.linalg.norm(v_prev) * p_to_nm)
        angle_diffs[index] = angle

        if plot:
            # Short arrows along the normalised vectors (plot coordinates are x=col, y=row)
            norm_v_prev = v_prev / np.linalg.norm(v_prev) * 0.1
            norm_v_next = v_next / np.linalg.norm(v_next) * 0.1
            ax.scatter(point[1], point[0], c="purple", s=20)
            ax.arrow(
                point[1], point[0], norm_v_prev[1], norm_v_prev[0], head_width=0.01, head_length=0.2, fc="r", ec="r"
            )
            ax.arrow(
                point[1], point[0], norm_v_next[1], norm_v_next[0], head_width=0.01, head_length=0.2, fc="b", ec="b"
            )
            ax.text(point[1], point[0], f"{np.degrees(angle):.2f}", fontsize=12, color="black")

    if plot:
        ax.plot(trace[:, 1], trace[:, 0], "k")
        plt.show()

    return angles_per_nm, angle_diffs


def plot_images(
    images: list,
    masks: list,
    grain_indexes: list,
    px_to_nms: list,
    paths: list,
    binding_vector_positions: list,
    width=5,
    VMIN=VMIN,
    VMAX=VMAX,
    cmap=CMAP,
):
    """Plot each grain image next to its predicted mask, with a path drawn over the mask.

    Every grain takes two adjacent columns (image, then mask plus path), with ``width`` grains per row.

    Parameters
    ----------
    images : list
        Grain images, drawn with ``cmap`` and colour limits ``VMIN``/``VMAX``.
    masks : list
        Predicted masks in the same order as ``images``.
    grain_indexes : list
        Identifiers shown in each image title.
    px_to_nms : list
        Pixel-to-nm scaling factors (nm per pixel) shown in each image title.
    paths : list
        (N, 2) arrays of (row, col) points, drawn in white over the mask.
    binding_vector_positions : list
        One entry per grain. Not drawn yet; it only limits how many grains are plotted, as with the
        other lists (``zip`` stops at the shortest).
        TODO: draw these once their format is settled.
    width : int
        Number of grains per row.
    VMIN, VMAX : float
        Colour limits for the images.
    cmap :
        Colour map for the images.
    """
    num_images_in_batch = 2
    # Ceiling division, with at least one row so plt.subplots always gets a valid grid
    num_rows = max(1, -(-len(images) // width))
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[row, col] always works
    fig, axes = plt.subplots(
        num_rows, width * num_images_in_batch, figsize=(width * 4, num_rows * 4), squeeze=False
    )
    # Hide axes everywhere, including unused trailing panels
    for ax in axes.flat:
        ax.axis("off")

    for i, (image, mask, grain_index, p_to_nm, path, _binding_vector_position) in enumerate(
        zip(images, masks, grain_indexes, px_to_nms, paths, binding_vector_positions)
    ):
        row, col = divmod(i, width)
        # Plot image
        im_ax = axes[row, col * num_images_in_batch]
        im_ax.imshow(image, cmap=cmap, vmin=VMIN, vmax=VMAX)
        im_ax.set_title(f"Grain {grain_index} {p_to_nm} nm/px")
        # Plot mask with the path on top (plot takes x=col, y=row)
        mask_ax = axes[row, col * num_images_in_batch + 1]
        mask_ax.imshow(mask.astype(int))
        mask_ax.plot(path[:, 1], path[:, 0], color="white")

    fig.tight_layout()
    plt.show()


binding_angle_grain_dict = {}

# Resample each pooled path so that consecutive kept points are more than this many nm apart
nm_sample_distance = 1

for index, grain_dict in pooled_path_grain_dict.items():
    pooled_path = grain_dict["pooled_path"]
    p_to_nm = grain_dict["p_to_nm"]
    pixel_sample_distance = nm_sample_distance / p_to_nm

    # Greedy resampling: keep a point once it is further than pixel_sample_distance from the last kept point
    nm_sampled_path = list(pooled_path[:1])
    for point in pooled_path[1:]:
        if np.linalg.norm(point - nm_sampled_path[-1]) > pixel_sample_distance:
            nm_sampled_path.append(point)
    nm_sampled_path = np.array(nm_sampled_path)

    # Signed turning angle (radians) at each interior point of the resampled path
    _, angle_diffs = angle_per_nm_non_closed(nm_sampled_path, p_to_nm, plot=False)

    # Binding angle: total turning along the path minus a half turn (pi)
    angle_diff_between_start_end = np.sum(angle_diffs) - np.pi

    # New dict per grain so the input dicts are not mutated
    binding_angle_grain_dict[index] = {**grain_dict, "binding_angle": angle_diff_between_start_end}

# Plot binding angle
binding_angles = np.array([entry["binding_angle"] for entry in binding_angle_grain_dict.values()])
fig, ax = plt.subplots()
sns.kdeplot(binding_angles, ax=ax)
ax.set_title(f"Binding angles for {SAMPLE_TYPE} (n = {len(binding_angles)})")
ax.set_xlabel("Binding angle (rad)")
plt.show()

In [None]:
# Get connection regions as a proportion of perimeter.
# For each grain:
# - trace the one-pixel outline around the gem region (mask == 2) as an ordered loop;
# - locate the skeletonised intersection regions along it;
# - compare the two perimeter arcs between the intersection midpoints.
# Grains whose outline cannot be traced unambiguously are reported and skipped.


def _intersection_region(start: int, end: int) -> dict:
    """Describe a run of intersection pixels by its start, end and midpoint positions along the trace."""
    return {"start": start, "end": end, "midpoint": start + (end - start) // 2}


perimeter_intersection_grain_dict = {}

for index, grain_dict in binding_angle_grain_dict.items():
    print(f"\ngrain index: {index}")

    try:
        mask = grain_dict["predicted_mask"]
        intersection_labels = grain_dict["intersection_labels"]

        # Get perimeter of the gem region
        gem_region = mask == 2
        gem_region = binary_fill_holes(gem_region.astype(bool))

        # Outline: pixels added by a one-pixel dilation, i.e. the ring just outside the gem region
        gem_outline = binary_dilation(gem_region) ^ gem_region
        gem_outline[gem_region == 1] = 0

        # Skeletonize to get rid of any rogue corners
        gem_outline = skeletonize(gem_outline, method="lee").astype(bool)
        intersection_labels = label(skeletonize(intersection_labels.astype(bool), method="lee").astype(bool))

        # 0 = background, 1 = outline pixel, 2 = intersection skeleton pixel
        gem_region_with_intersections = np.copy(gem_outline).astype(int)
        gem_region_with_intersections[intersection_labels > 0] = 2

        # Trace the outline from the first plain outline pixel (value 1) in row-major order.
        # Visited pixels are set to -1 so they are never stepped onto again.
        gem_tracing_hist = np.copy(gem_region_with_intersections)
        start_point = np.argwhere(gem_tracing_hist == 1)[0]
        print(f"start point: {start_point}")

        point_coords = start_point
        gem_trace = []
        intersection_regions = []
        in_intersection = False
        intersection_start = None
        len_gem_outline = np.count_nonzero(gem_tracing_hist > 0)
        print(f"number of > 0 pixels in gem tracing hist : {len_gem_outline}")

        if PLOT_RESULTS:
            fig, ax = plt.subplots(1, 1, figsize=(10, 10))
            ax.imshow(gem_tracing_hist)
            plt.show()

        for tracing_index in range(len_gem_outline - 1):
            # gem_trace[tracing_index] is the current point
            gem_trace.append(point_coords)

            # 3x3 neighbourhood; a view, so the -1 written next also shows up in it
            neighbours = gem_tracing_hist[
                point_coords[0] - 1 : point_coords[0] + 2, point_coords[1] - 1 : point_coords[1] + 2
            ]
            gem_tracing_hist[point_coords[0], point_coords[1]] = -1
            # Unvisited outline/intersection pixels around the current point
            nonzero_neighbours = np.argwhere(neighbours > 0)

            # The start of a loop has two unvisited neighbours (either will do); after that the path must not
            # branch. An empty neighbourhood (broken outline or image edge) raises IndexError, caught below.
            if tracing_index > 0 and len(nonzero_neighbours) > 1:
                raise ValueError(
                    f"Point {point_coords} has more than one neighbour: {neighbours} {nonzero_neighbours}"
                )

            # Window offset (0..2) -> step (-1..1)
            point_coords = point_coords + nonzero_neighbours[0] - 1
            point_value = gem_tracing_hist[point_coords[0], point_coords[1]]

            # The new point_coords becomes gem_trace[tracing_index + 1]
            if point_value == 2 and not in_intersection:
                in_intersection = True
                intersection_start = tracing_index + 1
            elif point_value == 1 and in_intersection:
                # The run ended at the current point, gem_trace[tracing_index]
                intersection_regions.append(_intersection_region(intersection_start, tracing_index))
                in_intersection = False

        # The loop steps onto the final outline pixel without recording it; add it so the trace is complete
        gem_trace.append(point_coords)

        if in_intersection:
            # A run still open at the end of the trace finishes at its last point
            intersection_regions.append(_intersection_region(intersection_start, len(gem_trace) - 1))
            in_intersection = False

        if len(intersection_regions) < 2:
            raise ValueError(f"Fewer than two intersection regions: {intersection_regions}")
        if len(intersection_regions) > 2:
            raise ValueError(f"More than two intersection regions: {intersection_regions}")

        print(f"intersection regions: {intersection_regions}")

        # Arc lengths (in outline pixels) between the two intersection midpoints: the inner arc does not pass
        # the trace's start point, the outer arc does
        inner_distance = intersection_regions[1]["midpoint"] - intersection_regions[0]["midpoint"]
        outer_distance = len(gem_trace) - inner_distance
        print(f"inner distance: {inner_distance}")
        print(f"outer distance: {outer_distance}")
        # Shorter arc / longer arc; 1 means the two arcs are equal
        midpoint_distance_ratio = min(inner_distance, outer_distance) / max(inner_distance, outer_distance)

        perimeter_intersection_grain_dict[index] = {
            **grain_dict,
            "intersection_regions": intersection_regions,
            "midpoint_distance_ratio": midpoint_distance_ratio,
        }
    except (ValueError, IndexError) as error:
        # Untraceable grains (branching or broken outline, wrong number of intersections) are reported and
        # skipped instead of aborting the whole run
        print(f"Index {index} failed: {error}")