In [None]:
from pathlib import Path
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import splev, splprep
from mpl_toolkits.axes_grid1 import make_axes_locatable

from topostats.plottingfuncs import Colormap

# Build the Colormap helper and keep only the colormap object it provides.
cmap = Colormap().get_cmap()

In [None]:
def _rim_curvature(xs: np.ndarray, ys: np.ndarray):
    """Calculate the curvature of a set of points. Uses the standard curvature definition of the derivative of the
    tangent vector.

    Parameters:
    ----------
    xs: np.ndarray
        One dimensional numpy array of x-coordinates of the points
    ys: np.ndarray
        One dimensional numpy array of y-coordinates of the points
    Returns:
    -------
    np.ndarray
        One-dimensional numpy array of curvatures for the spline.
    """
    extension_length = xs.shape[0]
    xs_extended = np.append(xs, xs)
    xs_extended = np.append(xs_extended, xs)

    # # Plot the extended points
    # plt.plot(xs_extended)
    # plt.title("Extended x")
    # plt.show()

    ys_extended = np.append(ys, ys)
    ys_extended = np.append(ys_extended, ys)
    dx = np.gradient(xs_extended)

    # # Plot the extended dx
    # plt.plot(dx)
    # plt.title("Extended dx")
    # plt.show()

    dy = np.gradient(ys_extended)
    d2x = np.gradient(dx)

    # # Plot the extended d2x
    # plt.plot(d2x)
    # plt.title("Extended d2x")
    # plt.show()

    d2y = np.gradient(dy)
    curv = np.abs(dx * d2y - d2x * dy) / (dx * dx + dy * dy) ** 1.5

    # # Plot the extended curvature
    # plt.plot(curv)
    # plt.title("Extended curvature")
    # plt.show()

    curv = curv[extension_length : (len(curv) - extension_length)]

    return curv


def _interpolate_points_spline(points: np.ndarray, num_points: int, smoothing: float = 0.0):
    """Interpolate a set of points using a spline.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points.
    num_points: int
        The number of points to return following the calculated spline.

    Returns
    -------
    interpolated_points: np.ndarray
        An Ix2 Numpy array of coordinates of the interpolated points, where I is the number of points
        specified.
    """

    x, y = splprep(points.T, u=None, s=smoothing, per=1)
    x_spline = np.linspace(y.min(), y.max(), num_points)
    x_new, y_new = splev(x_spline, x, der=0)
    interpolated_points = np.array((x_new, y_new)).T
    return interpolated_points


def interpolate_spline_and_get_curvature(points: np.ndarray, interpolation_number: int, smoothing: float = 0.0):
    """Resample a closed loop of points along a spline and compute the curvature of the resampled loop.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points (one row per point).
    interpolation_number: int
        Number of interpolation points per point. Eg: for a set of 10 points and 2 interpolation points,
        there will be 20 points in the spline.
    smoothing: float
        Smoothing value forwarded unchanged to ``_interpolate_points_spline``.

    Returns
    -------
    interpolated_curvatures: np.ndarray
        One dimensional Numpy array with one curvature per interpolated point.
    interpolated_points: np.ndarray
        Mx2 Numpy array of the interpolated points, where M is ``interpolation_number`` times the number of
        input points.
    """
    # The spline is resampled at `interpolation_number` points for every input point.
    interpolated_points = _interpolate_points_spline(
        points=points,
        num_points=interpolation_number * points.shape[0],
        smoothing=smoothing,
    )

    # Curvature is evaluated on the resampled x and y columns.
    interpolated_curvatures = _rim_curvature(interpolated_points[:, 0], interpolated_points[:, 1])

    return interpolated_curvatures, interpolated_points


def visualise_curvature_pixel_image(
    curvatures: np.ndarray,
    points: np.ndarray,
    image_size: int = 100,
    title: str = "",
    figsize=(12, 12),
):
    """Visualise the curvature of a set of points using a pixel heightmap image.

    The points are centred on their centroid and scaled so that the point furthest from the centroid (along
    either axis) lands 70% of the way from the image centre to its edge. Each point's pixel is then set to
    its curvature and the image is shown with a colour bar.

    Parameters
    ----------
    curvatures: np.ndarray
        Numpy Nx1 array of curvatures for the points.
    points: np.ndarray
        Numpy Nx2 array of coordinates for the points.
    image_size: int
        Side length, in pixels, of the square image that is drawn.
    title: str
        Title placed above the image.
    figsize: tuple
        Figure size handed to ``plt.subplots``.

    Returns
    -------
    None
    """

    curv_img = np.zeros((image_size, image_size))

    points = np.asarray(points, dtype=float)
    if points.size:
        # Work with offsets from the centroid so the drawing is centred on the image.
        offsets = points - points.mean(axis=0)
        max_offset = np.max(np.abs(offsets))

        # The scale has to come from the offsets, not from the raw coordinates: scaling by np.max(points)
        # pushed indices past the image edge (or negative, wrapping silently) whenever the points were not
        # spread between 0 and their maximum. For points that are, this gives the same scale as before.
        scaling_factor = (0.7 * image_size / 2) / max_offset if max_offset > 0 else 0.0

        pixel_coords = (image_size / 2 + offsets * scaling_factor).astype(int)
        for (pixel_x, pixel_y), curvature in zip(pixel_coords, curvatures):
            curv_img[pixel_x, pixel_y] = curvature

    fig, ax = plt.subplots(figsize=figsize)
    # Transpose and flip so that x runs left to right and y runs bottom to top in the displayed image.
    im = ax.imshow(np.flipud(curv_img.T), cmap="rainbow")
    # Attach a colour bar axis of matching height to the right of the image.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax, orientation="vertical")
    ax.set_title(title)
    plt.show()


def visualise_curvature_scatter(curvatures: np.ndarray, points: np.ndarray, title: str = ""):
    """Draw the points as a scatter plot in which each marker's colour encodes its curvature.

    Parameters
    ----------
    curvatures: np.ndarray
        One curvature value per point, used as the marker colours.
    points: np.ndarray
        Numpy Nx2 array of coordinates for the points.
    title: str
        Title of the plot.

    Returns
    -------
    None
    """
    x_coords, y_coords = points[:, 0], points[:, 1]

    markers = plt.scatter(x_coords, y_coords, c=curvatures, cmap="rainbow", s=0.5)
    plt.title(title)
    plt.colorbar(markers)

    # Equal scaling on both axes so the shape is not stretched.
    current_axes = plt.gca()
    current_axes.set_aspect("equal", adjustable="box")
    plt.show()


def simple_curvature(points):
    """Compute a signed curvature for a sequence of points using finite differences.

    No periodic padding is applied, so the values closest to either end of the array come from one-sided
    differences.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points.

    Returns
    -------
    np.ndarray
        One dimensional Numpy array of ``(x'' y' - x' y'') / (x'^2 + y'^2)^(3/2)`` evaluated at each point.
    """
    xs, ys = points.T
    first_x, first_y = np.gradient(xs), np.gradient(ys)
    second_x, second_y = np.gradient(first_x), np.gradient(first_y)
    speed_cubed = (first_x**2 + first_y**2) ** 1.5
    return (second_x * first_y - first_x * second_y) / speed_cubed

In [None]:
# Sample a circle of radius 5 at 200 angles spanning [0, 2*pi]
num_points = 200
radius = 5
theta = np.linspace(0, 2 * np.pi, num_points)
x, y = radius * np.cos(theta), radius * np.sin(theta)
points = np.array((x, y)).T

# Show the raw samples, with equal axis scaling so the circle is not distorted
plt.scatter(points[:, 0], points[:, 1], s=0.5)
plt.gca().set_aspect("equal", adjustable="box")
plt.show()

# Resample the circle along a smoothed spline and compute the curvature of the resampled loop
curvatures, interpolated_points = interpolate_spline_and_get_curvature(
    points=points,
    interpolation_number=10,
    smoothing=10.0,
)

# Show the resampled points, again with equal axis scaling
plt.scatter(interpolated_points[:, 0], interpolated_points[:, 1], s=0.5)
plt.gca().set_aspect("equal", adjustable="box")
plt.show()

# Curvature against index along the resampled loop
fig, ax = plt.subplots(figsize=(12, 12))
ax.plot(curvatures)
ax.set_title("Curvature")
plt.show()

In [None]:
from scipy.interpolate import UnivariateSpline
import numpy as np


def calculate_curvature_from_points(x_points, y_points, error=0.1, k=4):
    """Calculate the signed curvature of a sequence of points from weighted smoothing splines.

    x and y are each fitted as a function of the point index with a ``UnivariateSpline``. The curvature is
    then built from the first and second derivatives of the two splines.

    Parameters
    ----------
    x_points: np.ndarray
        1D numpy array of x coordinates of the points.
    y_points: np.ndarray
        1D numpy array of y coordinates of the points.
    error: float
        Error in the points. Every point receives the same spline weight, ``1 / sqrt(error)``.
    k: int
        Degree of the splines. The default of 4 was chosen so the spline is smooth enough to be
        differentiated twice.

    Returns
    -------
    curvatures: np.ndarray
        ``(x' y'' - y' x'') / (x'^2 + y'^2)^(3/2)`` evaluated at each point index.
    spline_x: np.ndarray
        The fitted x spline evaluated at each point index.
    spline_y: np.ndarray
        The fitted y spline evaluated at each point index.

    Raises
    ------
    ValueError
        If ``x_points`` and ``y_points`` do not hold the same number of points.
    """
    n_x, n_y = x_points.shape[0], y_points.shape[0]
    if n_x != n_y:
        raise ValueError(
            f"x_points and y_points must have the same number of points. x_points has {x_points.shape[0]} points and y_points has {y_points.shape[0]} points."
        )

    # The point index acts as the monotonically increasing independent variable of both splines.
    param = np.arange(n_x)

    # Uniform weights: the larger the stated error, the less weight each point carries.
    weights = 1 / np.sqrt(error * np.ones_like(x_points))

    spline_of_x = UnivariateSpline(param, x_points, k=k, w=weights)
    spline_of_y = UnivariateSpline(param, y_points, k=k, w=weights)

    first_x = spline_of_x.derivative(1)(param)
    second_x = spline_of_x.derivative(2)(param)
    first_y = spline_of_y.derivative(1)(param)
    second_y = spline_of_y.derivative(2)(param)

    numerator = first_x * second_y - first_y * second_x
    denominator = np.power(first_x**2 + first_y**2, 3 / 2)

    return numerator / denominator, spline_of_x(param), spline_of_y(param)


def calculate_curvature_periodic_boundary(x_points, y_points, error=0.1, periods=2, k=4):
    """Take a set of points that form a loop and calculate the curvature. Uses periodic boundary conditions, so
    the first and last points are connected. This reduces the error in the curvature calculation at the
    boundaries.

    The points are repeated ``periods`` times before and ``periods`` times after the original set, the
    curvature is calculated for the whole extended sequence, and only the values belonging to the central
    copy are returned.

    Parameters
    ----------
    x_points: np.ndarray
        1D numpy array of x coordinates of the points.
    y_points: np.ndarray
        1D numpy array of y coordinates of the points.
    error: float
        Error in the points. Used to weight the points in the spline calculation.
    periods: int
        Number of times to repeat the points either side of the original points to reduce the error at the
        boundaries.
    k: int
        Degree of the spline, forwarded to ``calculate_curvature_from_points``.

    Returns
    -------
    curvature: np.ndarray
        1D numpy array of the curvature for each point.
    spline_x: np.ndarray
        1D numpy array of the fitted spline x value for each point.
    spline_y: np.ndarray
        1D numpy array of the fitted spline y value for each point.

    Raises
    ------
    ValueError
        If ``x_points`` and ``y_points`` do not hold the same number of points.
    """

    # Check that the number of points is the same for both x and y
    n_points = x_points.shape[0]
    if n_points != y_points.shape[0]:
        raise ValueError(
            f"x_points and y_points must have the same number of points. x_points has {x_points.shape[0]} points and y_points has {y_points.shape[0]} points."
        )

    # `periods` copies before the original points, the original points, then `periods` copies after.
    n_copies = 2 * periods + 1
    extended_points_x = np.tile(np.ravel(x_points), n_copies)
    extended_points_y = np.tile(np.ravel(y_points), n_copies)

    # Calculate the curvature over the whole extended sequence
    extended_curvature, spline_x, spline_y = calculate_curvature_from_points(
        extended_points_x, extended_points_y, error=error, k=k
    )

    # Keep the central copy only, which has `periods` copies of padding on each side. The slice used to
    # start at copy `periods // 2`, which is off-centre and, for periods=1, is the unpadded first copy.
    central = slice(n_points * periods, n_points * (periods + 1))
    return extended_curvature[central], spline_x[central], spline_y[central]


def defect_stats(height_trace: np.ndarray, threshold: float):
    """Find the contiguous runs of a trace that lie below a threshold and summarise them.

    Parameters
    ----------
    height_trace: np.ndarray
        One dimensional sequence of values to scan.
    threshold: float
        Values strictly below this threshold belong to a defect region.

    Returns
    -------
    dict
        Dictionary with the keys:

        ``defect_number``
            Number of regions found.
        ``defect_largest_region``
            The region with the largest area (the first one on ties), or None if no region was found.
        ``defect_regions``
            List of regions in order of appearance. Each region is a dictionary with ``start`` (index of
            its first sample below the threshold), ``end`` (index of the first sample after the region, or
            the last index of the trace if the region runs to the end of the trace), ``area`` (sum of
            ``threshold - value`` over every sample in the region), ``deepest_point_index`` (index of the
            lowest sample in the region) and ``defect_threshold`` (the threshold used).
    """

    def _region(start, end, area, deepest_point_index):
        """Bundle the bookkeeping of one finished region into the dictionary handed back to callers."""
        return {
            "start": start,
            "end": end,
            "area": area,
            "deepest_point_index": deepest_point_index,
            "defect_threshold": threshold,
        }

    regions = []
    in_region = False
    region_start = deepest_point_index = index = 0
    area = deepest_point = 0

    for index, value in enumerate(height_trace):
        if value < threshold:
            if not in_region:
                # Open a new region at this sample.
                in_region = True
                region_start = index
                area = 0
                deepest_point = value
                deepest_point_index = index
            # Every sample of the region, including its first and last, contributes to the area.
            area += threshold - value
            if value < deepest_point:
                deepest_point = value
                deepest_point_index = index
        elif in_region:
            # First sample back at or above the threshold closes the region.
            regions.append(_region(region_start, index, area, deepest_point_index))
            in_region = False

    # A region still open once the trace is exhausted is closed at the last index. Doing this after the
    # loop, rather than on reaching the last sample, means that sample has already been accounted for and
    # that a region can start on the last sample.
    if in_region:
        regions.append(_region(region_start, index, area, deepest_point_index))

    # Every region now has a strictly positive area, so a largest region exists whenever any region does.
    largest_region_below_threshold = max(regions, key=lambda region: region["area"], default=None)

    return {
        "defect_number": len(regions),
        "defect_largest_region": largest_region_below_threshold,
        "defect_regions": regions,
    }

In [None]:
# Create a coarse set of points for a circle
num_points = 10
radius = 5
theta = np.linspace(0, 2 * np.pi, num_points)
x = radius * np.cos(theta)
y = radius * np.sin(theta)
points = np.array((x, y)).T

# calculate_curvature_from_points returns (curvatures, spline_x, spline_y). Unpack it so that the curvature
# array itself is plotted below, rather than the whole tuple.
curvatures, spline_x, spline_y = calculate_curvature_from_points(x_points=points[:, 0], y_points=points[:, 1])

plt.scatter(x=points[:, 0], y=points[:, 1], s=0.5)
# Set square aspect ratio
plt.gca().set_aspect("equal", adjustable="box")
# Join consecutive points with straight segments to show the polygon the spline is fitted to
for i in range(len(points) - 1):
    plt.plot([points[i, 0], points[i + 1, 0]], [points[i, 1], points[i + 1, 1]], c="k")
plt.show()

# Curvature alongside the raw x and y coordinates, all against point index
plt.plot(curvatures, label="curvature")
plt.plot(points[:, 0], label="x")
plt.plot(points[:, 1], label="y")
plt.legend()
plt.show()

In [None]:
# NOTE(review): absolute, machine-specific path - the cell only finds data on a machine with this layout.
DIR = Path("/Users/sylvi/topo_data/hariborings/testing_workflow/output_dna_only_crops_traces_ON_REL/processed/")

# Every per-grain trace-info pickle directly inside DIR, in sorted path order
data_files = sorted(DIR.glob("*_grain_image_trace_info.pkl"))
print(f"num data files: {len(data_files)}")

# for data_file in data_files:
#     # load the dict
#     with open(data_file, "rb") as f:
#         data_dict = pickle.load(f)
#     all_cropped_images = data_dict["all_cropped_images"]
#     all_ordered_traces = data_dict["all_ordered_traces"]
#     all_trace_heights = data_dict["all_trace_heights"]
#     all_fitted_traces = data_dict["all_fitted_traces"]

#     print(f"all cropped images shape: {len(all_cropped_images)}")
#     print(f"all ordered traces shape: {len(all_ordered_traces)}")
#     print(f"all trace heights shape: {len(all_trace_heights)}")
#     print(f"all fitted traces shape: {len(all_fitted_traces)}")

#     print(f'\n=============\n')

#     for index in range(len(all_cropped_images)):

#         print(f"\n ------------\n")

#         cropped_image = all_cropped_images[index]
#         ordered_trace = all_ordered_traces[index]
#         trace_heights = all_trace_heights[index]
#         fitted_trace = all_fitted_traces[index]

#         plt.imshow(cropped_image)
#         plt.plot(ordered_trace[:, 1], ordered_trace[:, 0], c="r")
#         # plt.plot(fitted_trace[:, 1], fitted_trace[:, 0], c="b")


#         plt.show()


#     break

# Analyse the first grain of the first data file: show its trace, its height profile and its curvature
# profile with the defect regions marked.
# NOTE(review): data_files[0] raises IndexError when the glob above found no files.
data_file = data_files[0]
# load the dict
# NOTE(review): pickle.load can execute arbitrary code - only open files from a trusted source.
with open(data_file, "rb") as f:
    data_dict = pickle.load(f)
all_cropped_images = data_dict["all_cropped_images"]
all_ordered_traces = data_dict["all_ordered_traces"]
all_trace_heights = data_dict["all_trace_heights"]
all_fitted_traces = data_dict["all_fitted_traces"]
px_2_nm = data_dict["px_2_nm"]

# Only the first entry of each list is looked at in this cell
cropped_image = all_cropped_images[0]
ordered_trace = all_ordered_traces[0]
trace_heights = all_trace_heights[0]
fitted_trace = all_fitted_traces[0]

# Number of trace points scaled by px_2_nm (presumably the size of one pixel in nm - confirm)
trace_length_nm = len(ordered_trace) * px_2_nm
print(f"trace length nm: {trace_length_nm}")

# Overlay the trace on the cropped image; column 1 is drawn on the horizontal axis and column 0 on the
# vertical axis
plt.imshow(cropped_image)
plt.plot(ordered_trace[:, 1], ordered_trace[:, 0], c="r")
plt.show()

# Height profile along the trace
fig, ax = plt.subplots(figsize=(20, 3))
plt.plot(trace_heights)
# NOTE(review): the limits are 0 to 4e-9 while the y label below says nm - confirm the units of trace_heights.
ax.set_ylim([0e-9, 4e-9])
ax.set_title("trace heights")
# rescale to be in nm
# multiply each x tick by pixel to nm scaling factor
# NOTE(review): the tick labels are replaced without fixing the tick positions, so they are only right for
# the ticks present at this moment.
ax.set_xticklabels([int(x * px_2_nm) for x in ax.get_xticks()])
ax.set_xlabel("position along trace (nm)")
ax.set_ylabel("height (nm)")
plt.show()

# Get the curvature
# The trace is first resampled to 1000 points along a periodic spline; the curvature is then calculated from
# the resampled coordinates with 4 repeats of padding and degree-5 splines.
interpolated_points = _interpolate_points_spline(points=ordered_trace, num_points=1000, smoothing=0.0)
curvature, spline_x, spline_y = calculate_curvature_periodic_boundary(
    x_points=interpolated_points[:, 0],
    y_points=interpolated_points[:, 1],
    error=0.1,
    periods=4,
    k=5,
)

# get curvature defect stats
# defect_stats is reused on the curvature profile: runs of curvature below the threshold count as defects.
curvature_defect_threshold = -0.1
curvature_defect_stats = defect_stats(height_trace=curvature, threshold=curvature_defect_threshold)
print(curvature_defect_stats)

# NOTE(review): "defect_largest_region" is None when defect_stats finds no region, in which case these
# lookups raise TypeError.
largest_defect_start = curvature_defect_stats["defect_largest_region"]["start"]
largest_defect_end = curvature_defect_stats["defect_largest_region"]["end"]
largest_defect_deepest_point_index = curvature_defect_stats["defect_largest_region"]["deepest_point_index"]

# Plot the curvature with the largest defect highlighted
fig, ax = plt.subplots(figsize=(20, 3))
ax.plot(curvature)
# draw line at zero
ax.axhline(y=0, c="k", ls="--")
ax.set_ylim([-0.3, 0.3])
ax.set_title(
    f"curvature with defects (threshold {curvature_defect_threshold}) marked in red regions & major defect marked with dotted red line. "
)
# draw a line at the defect threshold
ax.axhline(y=curvature_defect_threshold, c="g", ls="--", label="defect threshold")
# highlight the largest defect
ax.axvline(x=largest_defect_deepest_point_index, c="r", ls="--", label="primary defect location")
# colour in all defect regions
# NOTE(review): every span is given the same label, so the legend lists "defect site" once per region.
for region in curvature_defect_stats["defect_regions"]:
    ax.axvspan(region["start"], region["end"], alpha=0.5, color="red", label="defect site")

# rescale to be in nm
# The curvature has a different number of samples from the trace, so tick positions are first converted from
# curvature index to trace index and then scaled by px_2_nm.
num_curvature_points = len(curvature)
num_trace_points = len(trace_heights)
curvature_to_trace_ratio = num_trace_points / num_curvature_points
# multiply each x tick by pixel to nm scaling factor
ax.set_xticklabels([int(x * px_2_nm * curvature_to_trace_ratio) for x in ax.get_xticks()])

ax.set_xlabel("position along trace (nm)")
ax.set_ylabel("curvature (nm^-1)")
ax.legend()
plt.show()

# Plot curvature on the image using the spline points
plt.imshow(cropped_image)
# plot curvature spline points where the colour represents the curvature
# spline_y goes on the horizontal axis and spline_x on the vertical axis, matching the trace overlay above
plt.scatter(spline_y, spline_x, c=curvature, cmap="rainbow", s=0.5)
plt.colorbar()
plt.show()

In [None]:
# NOTE(review): absolute, machine-specific path - the cell only finds data on a machine with this layout.
DIR = Path("/Users/sylvi/topo_data/hariborings/testing_workflow/output_V2/")

# Thresholded height image and the ordered traces saved for the same scan
image_file = DIR.joinpath(
    "ON_REL",
    "processed",
    "20220209_126bp_minicircicles_5ng_NiCl2_3mM.0_00023 - Copy_height_thresholded.npy",
)

ordered_traces_file = DIR.joinpath(
    "ON_REL",
    "processed",
    "20220209_126bp_minicircicles_5ng_NiCl2_3mM.0_00023 - Copy_ordered_traces.pkl",
)

# image_file = (
#     DIR
#     / "OT2_REL"
#     / "processed"
#     / "20231116_minicircle_126bp_OT2_rel_1_5ng_NiCl2_3mM_HEPES_3mM.0_00025_height_thresholded.npy"
# )

# ordered_traces_file = (
#     DIR
#     / "OT2_REL"
#     / "processed"
#     / "20231116_minicircle_126bp_OT2_rel_1_5ng_NiCl2_3mM_HEPES_3mM.0_00025_ordered_traces.pkl"
# )

# Load the thresholded height image and display it with a fixed colour range of 0 to 3
image = np.load(image_file)

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(image, cmap=cmap, vmin=0, vmax=3)
plt.show()


# Load the ordered traces saved alongside the image.
# NOTE(review): unpickling can execute arbitrary code - only open files from a trusted source.
with ordered_traces_file.open("rb") as f:
    traces = pickle.load(f)

for trace_number, trace in enumerate(traces):
    # Entries that are None carry no coordinates, so there is nothing to draw for them.
    if trace is None:
        continue

    # Rasterise the trace into a mask just large enough to hold its largest x and y coordinate.
    trace_xs = [x for x, y in trace]
    trace_ys = [y for x, y in trace]
    max_coord_x = max(trace_xs)
    max_coord_y = max(trace_ys)
    trace_mask = np.zeros((max_coord_y + 1, max_coord_x + 1))
    for x, y in trace:
        trace_mask[y, x] = 1
    print(trace_number)

    # Trace outline drawn over its mask
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(trace_mask, vmin=0, vmax=3)
    ax.plot(trace_xs, trace_ys, c="k")
    ax.set_title(f"Trace {trace_number}")
    plt.show()

    # Method 1: resample along a smoothed periodic spline, then take finite-difference curvature.
    # Only the resampled points are used further down; the curvatures from this method are not plotted.
    interpolated_curvatures, interpolated_points = interpolate_spline_and_get_curvature(
        np.array(trace), interpolation_number=10, smoothing=5.0
    )

    # Method 2: weighted smoothing splines fitted to the raw trace with periodic padding.
    curvatures, spline_x, spline_y = calculate_curvature_periodic_boundary(
        x_points=trace[:, 0],
        y_points=trace[:, 1],
        error=0.08,
        periods=2,
        k=5,
    )

    # Curvature from method 2 against point index, with enlarged title and tick labels
    fig, ax = plt.subplots(figsize=(15, 4))
    ax.plot(curvatures, linewidth=3)
    ax.set_title(f"Curvature")
    ax.title.set_size(30)
    ax.tick_params(axis="both", which="major", labelsize=20)
    ax.set_xlim([0, len(curvatures)])
    plt.show()

    # Points from both methods over the mask: method 1 in green, method 2 in orange
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(trace_mask, vmin=0, vmax=3)
    ax.scatter(
        interpolated_points[:, 0],
        interpolated_points[:, 1],
        s=0.5,
        c="green",
        label="interpolated points",
    )
    ax.scatter(spline_x, spline_y, s=0.5, c="orange", label="alt spline")
    ax.set_title(f"Interpolated points")
    plt.show()

    # The bare mask on its own, with default colour scaling
    plt.imshow(trace_mask)
    plt.show()