In [None]:
from pathlib import Path
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import splev, splprep
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
def _rim_curvature(xs: np.ndarray, ys: np.ndarray):
    """Compute the unsigned curvature at every vertex of a closed polyline.

    The vertices are treated as a closed loop. The coordinate arrays are laid
    end to end three times, so the finite-difference stencils at the first and
    last vertices reach around to the other end of the loop. Only the middle
    copy of the result is kept. Curvature is evaluated as

        |x' * y'' - x'' * y'| / (x'^2 + y'^2)^(3/2)

    with the derivatives taken by ``np.gradient`` with respect to vertex index.

    Parameters
    ----------
    xs: np.ndarray
        One-dimensional array of x-coordinates of the vertices.
    ys: np.ndarray
        One-dimensional array of y-coordinates of the vertices (same length as ``xs``).

    Returns
    -------
    np.ndarray
        One-dimensional array with the curvature of each vertex (same length as ``xs``).
    """
    n_vertices = xs.shape[0]

    # Three back-to-back copies: the outer two only supply wrap-around neighbours.
    tiled_x = np.tile(xs, 3)
    tiled_y = np.tile(ys, 3)

    vel_x, vel_y = np.gradient(tiled_x), np.gradient(tiled_y)
    acc_x, acc_y = np.gradient(vel_x), np.gradient(vel_y)

    numerator = np.abs(vel_x * acc_y - acc_x * vel_y)
    denominator = (vel_x * vel_x + vel_y * vel_y) ** 1.5

    # Keep only the middle copy, whose stencils never touch the array edges.
    return (numerator / denominator)[n_vertices : 2 * n_vertices]


def _interpolate_points_spline(points: np.ndarray, num_points: int, smoothing: float = 5.0):
    """Fit a smoothed periodic spline through a closed loop of points and resample it.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of coordinates for the points, ordered around the loop. The loop may be
        given open (last point differs from the first) or explicitly closed (last point repeats
        the first); both produce the same spline.
    num_points: int
        The number of points to return following the calculated spline.
    smoothing: float
        Smoothing factor ``s`` passed to ``scipy.interpolate.splprep``. Larger values give a
        smoother spline that follows the points less closely.

    Returns
    -------
    interpolated_points: np.ndarray
        An Ix2 Numpy array of coordinates of the interpolated points, where I is ``num_points``.
        The points are evenly spaced in the spline parameter, and the start point is not
        repeated at the end.
    """
    points = np.asarray(points, dtype=float)

    # With per=1, splprep does not use the last data point: it assumes the curve closes back onto
    # the first point. An open loop would therefore silently lose its final point, so close it
    # explicitly unless the caller already did.
    if not np.array_equal(points[0], points[-1]):
        points = np.vstack((points, points[:1]))

    tck, u = splprep(points.T, u=None, s=smoothing, per=1)

    # endpoint=False: on a periodic spline, u.min() and u.max() map to the same location.
    # Including both would emit a duplicated point, i.e. a zero-length step that distorts
    # finite-difference curvature at the seam.
    spline_params = np.linspace(u.min(), u.max(), num_points, endpoint=False)
    x_new, y_new = splev(spline_params, tck, der=0)
    interpolated_points = np.array((x_new, y_new)).T
    return interpolated_points


def interpolate_spline_and_get_curvature(points: np.ndarray, interpolation_number: int):
    """Calculate the curvature along a closed loop of points.

    The points are first fitted with a cubic spline and resampled by ``_interpolate_points_spline``,
    which reduces pixel-level anomalies. The curvature is then evaluated at every resampled point
    by ``_rim_curvature``.

    Parameters
    ----------
    points: np.ndarray
        Nx2 Numpy array of (x, y) coordinates for the points, ordered around the loop.
    interpolation_number: int
        Number of interpolation points per point. E.g. for a set of 10 points and 2 interpolation
        points, there will be 20 points in the spline.

    Returns
    -------
    interpolated_curvatures: np.ndarray
        One-dimensional array of length ``interpolation_number * N`` with the (unsigned)
        curvature at each interpolated point.
    interpolated_points: np.ndarray
        (interpolation_number * N)x2 Numpy array of interpolated points generated from the spline
        of the original points.

    Raises
    ------
    ValueError
        If ``interpolation_number`` is less than 1.
    """
    points = np.asarray(points)
    if interpolation_number < 1:
        raise ValueError(f"interpolation_number must be at least 1, got {interpolation_number}")

    # Interpolate the data using a spline (rows of `points` are individual points).
    num_points = interpolation_number * points.shape[0]
    interpolated_points = _interpolate_points_spline(points=points, num_points=num_points)

    # Curvature of the resampled loop, one value per interpolated point.
    interpolated_curvatures = _rim_curvature(interpolated_points[:, 0], interpolated_points[:, 1])

    return interpolated_curvatures, interpolated_points


def visualise_curvature_pixel_image(
    curvatures: np.ndarray,
    points: np.ndarray,
    image_size: int = 100,
    title: str = "",
    figsize=(12, 12),
):
    """Visualise the curvature of a set of points using a pixel heightmap image.

    The points are centred on their centroid. They are then scaled so that the largest deviation
    from the centroid along either axis lands 35% of ``image_size`` from the image centre, i.e.
    the curve occupies roughly the central 70% of the image. Each point's pixel is set to that
    point's curvature. If several points fall on the same pixel, the last one wins. Pixels with
    no point stay 0, which the colour map cannot distinguish from zero curvature.

    Parameters
    ----------
    curvatures: np.ndarray
        Numpy array of N curvatures, one per point.
    points: np.ndarray
        Numpy Nx2 array of (x, y) coordinates for the points.
    image_size: int
        Side length, in pixels, of the square heightmap.
    title: str
        Title of the plot.
    figsize: tuple
        Matplotlib figure size in inches.

    Returns
    -------
    None
    """
    points = np.asarray(points, dtype=float)
    curv_img = np.zeros((image_size, image_size))

    centroid = points.mean(axis=0)
    offsets = points - centroid
    # Scale by the extent around the centroid rather than by the largest raw coordinate. Scaling
    # by np.max(points) shrinks traces far from the origin to a few pixels, and pushes centred or
    # negative coordinates off the image (or wraps them via negative indices).
    max_extent = np.max(np.abs(offsets)) if offsets.size else 0.0
    scaling_factor = (image_size * 0.7 / 2) / max_extent if max_extent > 0 else 1.0

    pixels = (image_size / 2 + offsets * scaling_factor).astype(int)
    # Defensive: keep every index inside the image even at the very edge.
    pixels = np.clip(pixels, 0, image_size - 1)
    for (pixel_x, pixel_y), curvature in zip(pixels, curvatures):
        curv_img[pixel_x, pixel_y] = curvature

    fig, ax = plt.subplots(figsize=figsize)
    # Transpose so the first coordinate runs horizontally, then flip so y increases upwards.
    im = ax.imshow(np.flipud(curv_img.T), cmap="rainbow")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax, orientation="vertical")
    ax.set_title(title)
    plt.show()


def visualise_curvature_scatter(
    curvatures: np.ndarray,
    points: np.ndarray,
    title: str = "",
    marker_size: float = 0.5,
):
    """Visualise the curvature of a set of points as a scatter plot.

    Each point is drawn as a marker coloured by its curvature, with a colour bar alongside.

    Parameters
    ----------
    curvatures: np.ndarray
        Numpy array of N curvatures, one per point.
    points: np.ndarray
        Numpy Nx2 array of (x, y) coordinates for the points.
    title: str
        Title of the plot.
    marker_size: float
        Marker size passed to ``scatter`` as ``s``.

    Returns
    -------
    None
    """
    fig, ax = plt.subplots()
    scatter_plot = ax.scatter(points[:, 0], points[:, 1], c=curvatures, cmap="rainbow", s=marker_size)
    # Equal aspect so the loop is not stretched: with unequal axis scales a circle (constant
    # curvature) would be drawn as an ellipse, misrepresenting where the curve bends.
    ax.set_aspect("equal")
    ax.set_title(title)
    fig.colorbar(scatter_plot, ax=ax)
    plt.show()


def simple_curvature(points):
    """Signed curvature of an open polyline from plain finite differences.

    Derivatives are taken with ``np.gradient`` along the point index, with no wrap-around, so
    the first and last points use one-sided differences. The numerator is x''y' - x'y'', the
    negative of the conventional x'y'' - y'x'': a circle of radius r traversed anticlockwise
    (x to the right, y upwards) yields -1/r.

    Parameters
    ----------
    points: np.ndarray
        Numpy Nx2 array of (x, y) coordinates.

    Returns
    -------
    np.ndarray
        One-dimensional array of N signed curvatures.
    """
    xs, ys = points.T

    vel_x = np.gradient(xs)
    vel_y = np.gradient(ys)
    acc_x = np.gradient(vel_x)
    acc_y = np.gradient(vel_y)

    speed_cubed = (vel_x**2 + vel_y**2) ** 1.5
    return (acc_x * vel_y - vel_x * acc_y) / speed_cubed

In [None]:
# Root of the TopoStats output being analysed (machine-specific path; adjust as needed).
DIR = Path("/Users/sylvi/topo_data/hariborings/testing_workflow/output_V2/")
# Number of spline points generated per original trace point.
INTERPOLATION_NUMBER = 10

ordered_traces_file = (
    DIR / "ON_REL" / "processed" / "20220209_126bp_minicircicles_5ng_NiCl2_3mM.0_00023 - Copy_ordered_traces.pkl"
)

image_file = (
    DIR / "ON_REL" / "processed" / "20220209_126bp_minicircicles_5ng_NiCl2_3mM.0_00023 - Copy_height_thresholded.npy"
)
image = np.load(image_file)

plt.imshow(image)
plt.show()

# NOTE(review): pickle.load can execute arbitrary code; only load trace files written by a
# trusted pipeline run.
with open(ordered_traces_file, "rb") as f:
    traces = pickle.load(f)

for trace_number, trace in enumerate(traces):
    # Skip entries without a usable trace: None entries, and empty traces, which cannot be
    # fitted (and would crash the max() below).
    if trace is None or len(trace) == 0:
        continue

    trace_array = np.asarray(trace)

    # Binary mask of the traced pixels, just large enough to hold every coordinate.
    # Each trace entry is unpacked as (x, y): x indexes columns, y indexes rows.
    max_coord_x, max_coord_y = trace_array.max(axis=0)
    trace_mask = np.zeros((max_coord_y + 1, max_coord_x + 1))
    trace_mask[trace_array[:, 1], trace_array[:, 0]] = 1
    print(trace_number)

    # Get curvature
    interpolated_curvatures, interpolated_points = interpolate_spline_and_get_curvature(
        trace_array, INTERPOLATION_NUMBER
    )

    # Visualise curvature
    visualise_curvature_pixel_image(
        curvatures=interpolated_curvatures,
        points=interpolated_points,
        image_size=100,
        title=f"Trace {trace_number}",
    )

    # Plot the curvatures
    fig, ax = plt.subplots(figsize=(15, 4))
    ax.plot(interpolated_curvatures)
    ax.set_title(f"Trace {trace_number}")
    plt.show()

    plt.imshow(trace_mask)
    plt.show()