In [None]:
import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
from matplotlib import colormaps
from matplotlib.collections import LineCollection
from scipy import ndimage
from scipy.interpolate import splprep, splev
from scipy.optimize import least_squares
from scipy.spatial import distance
from scipy.spatial.distance import cdist
from skimage.draw import disk
from skimage.filters import gaussian
from skimage.filters import threshold_otsu
from skimage.measure import regionprops
from skimage.morphology import binary_erosion
from skimage.morphology import label
from skimage.morphology import remove_small_objects

import topostats.filters as topofilters
from topostats.io import LoadScans
from topostats.plottingfuncs import Colormap
from topostats.utils import get_mask

from network_stats import (
    network_density,
    interpolate_spline_and_get_curvature,
    network_area,
    polygon_perimeter,
    node_stats,
    network_feret_diameters,
    visualise_curvature_scatter,
    visualise_curvature_pixel_image,
    # signed_distance_to_outline,
)

# TopoStats' default colour map; `cmap` is used as the default colour map of `plot` below.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
def plot(
    image: np.ndarray,
    title=None,
    zlimit=True,
    cmap=cmap,
    save_path=None,
    zrange=(-3, 4),
    **kwargs,
) -> None:
    """Display a 2D image in a new 8x8 figure and optionally save it.

    Parameters
    ----------
    image : np.ndarray
        Image passed to ``ax.imshow``.
    title : str, optional
        Axes title; no title is set when ``None``.
    zlimit : bool
        When exactly ``True`` the colour scale is clipped to ``zrange``; any other
        value lets matplotlib autoscale.
    cmap
        Colour map for ``imshow``; defaults to the notebook-level TopoStats ``cmap``.
    save_path : str or Path, optional
        If given, the figure is written here before it is shown.
    zrange : tuple of float
        ``(vmin, vmax)`` used when ``zlimit`` is ``True``. Defaults to ``(-3, 4)``.
    **kwargs
        Extra keyword arguments forwarded to ``ax.imshow``.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    vmin, vmax = zrange if zlimit is True else (None, None)
    ax.imshow(image, vmin=vmin, vmax=vmax, cmap=cmap, **kwargs)
    if title is not None:
        ax.set_title(title)

    if save_path is not None:
        # Save through the figure handle instead of relying on pyplot's "current figure".
        fig.savefig(save_path)

    plt.show()

In [None]:
# Select and load the scan to analyse.
# Flat
# FILE_DIR = Path("/Users/sylvi/topo_data/Bradley/Flat/")
FILE_DIR = Path("/Users/sylvi/topo_data/Bradley/M10Digest/")
# Pristine
# file = Path('./Flat/20221213_KDNA001.0_00020.spm')
# file = Path('./flat/20230118_KPN001.0_00005.spm')
# file = Path('./Flat/20230118_KPN001.0_00011.spm')
# file = Path('./Flat/20230126_KPN005.0_00004.spm')
# file = Path(FILE_DIR / "20230118_KPN001.0_00023.spm")
# file = Path('./flat/20230125_KPN005.0_00006.spm')
# file = Path('./flat/20230125_KPN005.0_00028.spm')
# file = Path('./flat/20230126_KPN005.0_00004.spm')
# file = Path('./flat/20230126_KPN005.0_00013.spm')
# file = Path('./flat/20230126_KPN005.0_00019.spm')
# file = Path('./flat/20230126_KPN005.0_00024.spm')
# file = Path('./flat/20230217_KPN010.0_00010.spm')
# file = Path('./flat/20230314_KPN015.0_00016.spm')
# file = Path('./flat/20230417_KPN0018.0_00000.spm')
# file = Path('./flat/20230417_KPN0018.0_00002.spm')
# EcoPst
# file = Path('./EcoPst/20230406_EPN001_BlackTeflon_Standard.0_00005.spm')
# M10Digest - Bad images
# filename = "20230119_KM10-003_25mMMgCl2_Vac.0_00005.spm"  # Good one
# filename = "20230119_KM10-003_25mMMgCl2_Vac.0_00009.spm"  # Good one
# `filename` is relative to FILE_DIR and is joined onto it exactly once below.
filename = "20230119_KM10-003_25mMMgCl2_Vac.0_00011.spm"  # Best
# filename = "20230119_KM10-003_25mMMgCl2_Vac.0_00016.spm" # Second best
file = FILE_DIR / filename
if not file.is_file():
    raise FileNotFoundError(f"Scan not found: {file}")
loadscans = LoadScans([file], "Height")
loadscans.get_data()
p_to_nm = loadscans.pixel_to_nm_scaling
image_raw = loadscans.image
plt.imshow(image_raw)
plt.show()

In [None]:
# Parameters
# Upper threshold, in standard deviations, for the initial height mask (`rosette_mask`).
Rosette_Thres = 1
# Objects smaller than this many pixels are removed from the rosette mask.
Gauss_Min_Size = 1
# Objects with an area above this many pixels are removed from the rosette mask.
Gauss_Max_Size = 100
# Sigma of the gaussian blur applied to the filtered rosette mask before edge detection.
Gauss_Sigma = 15
# Pixels of the blurred mask above this value form the region whose boundary is the outline.
Gauss_Thres = 0.001
# Size limits (pixels) for objects kept in the fibril mask.
Fibril_Min_Size = 5
Fibril_Max_Size = 300
# Fibrils whose centroid is further from the outline than
# mean - Threshold_Dist * std (of the centroid-to-outline distances) are discarded.
Threshold_Dist = 0.1

In [None]:
# Flatten the image with the TopoStats Filters pipeline and save each intermediate stage.
filters = topofilters.Filters(
    image=image_raw,
    filename=file,
    pixel_to_nm_scaling=p_to_nm,
    threshold_method="std_dev",
    threshold_std_dev={"upper": 1.0, "lower": None},
    gaussian_size=1.0,
    remove_scars={"run": False},
)

filters.filter_image()

PLOT_SAVE_DIR = Path("/Users/sylvi/topo_data/Bradley/M10Digest/")

# (image key, plot title, zlimit, filename suffix) for each intermediate stage.
# Only `file.name` goes into the output path: `file` is absolute, and joining an absolute
# path onto PLOT_SAVE_DIR would silently discard PLOT_SAVE_DIR.
flattening_stages = [
    ("pixels", "pixels", False, "01_pixels"),
    ("initial_median_flatten", "initial median flatten", False, "02_initial_median_flatten"),
    ("initial_quadratic_removal", "initial quadratic removal", False, "03_initial_quadratic_removal"),
    ("mask", "mask", False, "04_mask"),
    ("masked_median_flatten", "masked median flatten", True, "05_masked_median_flatten"),
    ("masked_tilt_removal", "masked tilt removal", True, "06_masked_tilt_removal"),
]
for image_key, stage_title, stage_zlimit, suffix in flattening_stages:
    plot(
        filters.images[image_key],
        title=stage_title,
        zlimit=stage_zlimit,
        save_path=PLOT_SAVE_DIR / f"{file.name}_{suffix}.png",
    )

# Diagnostic: flag when tilt removal left the masked median-flattened image unchanged.
if np.array_equal(filters.images["masked_tilt_removal"], filters.images["masked_median_flatten"]):
    print("EQUAL: masked tilt removal did not change the masked median flattened image")

flattened = filters.images["zero_average_background"]
plot(flattened, title="flattened", zlimit=True, save_path=PLOT_SAVE_DIR / f"{file.name}_07_flattened.png")

In [None]:
# Thresholding: build the rosette mask from heights above `Rosette_Thres` standard
# deviations, then keep only objects between Gauss_Min_Size and Gauss_Max_Size pixels.
rosette_thresholds = topofilters.get_thresholds(
    flattened, threshold_method="std_dev", threshold_std_dev={"upper": Rosette_Thres, "lower": None}
)

print(f"thresholds: {rosette_thresholds}")
rosette_mask = topofilters.get_mask(image=flattened, thresholds=rosette_thresholds)
plot(rosette_mask.astype(bool), title="rosette binary mask", zlimit=False)

# Remove small objects (plot the result, not the unfiltered mask).
removed_small_objects = remove_small_objects(rosette_mask, Gauss_Min_Size)
plot(removed_small_objects.astype(bool), title="removed small objects", zlimit=False)

# Remove large objects: zero out every labelled region whose area exceeds Gauss_Max_Size.
labelled_rosette_mask = label(removed_small_objects)
plot(labelled_rosette_mask.astype(bool), title="labelled rosette mask", zlimit=False)
regions = regionprops(labelled_rosette_mask)
large_labels = [props.label for props in regions if props.area > Gauss_Max_Size]
labelled_rosette_mask[np.isin(labelled_rosette_mask, large_labels)] = 0

plot(labelled_rosette_mask.astype(bool), title="removed large objects", zlimit=False)

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(labelled_rosette_mask.astype(bool), cmap="gray")
plt.show()

In [None]:
# Get the edge of the structure: blur the filtered rosette mask, threshold the blur,
# keep the largest connected region, fill its holes, and take its one-pixel boundary.

plt.imshow(labelled_rosette_mask.astype(bool))
plt.show()
gauss = gaussian(labelled_rosette_mask.astype(bool), Gauss_Sigma)
plt.imshow(gauss)
plt.show()

# threshold = threshold_otsu(gauss)
threshold = Gauss_Thres

fig, ax = plt.subplots()
ax.hist(gauss.flatten(), bins="auto")
ax.set_ylim(0, 60000)
ax.axvline(x=threshold, color="r")
plt.show()

print(f"gaussian threshold: {threshold}")
thresholded = gauss > threshold
plt.imshow(thresholded)
plt.show()

labelled = label(thresholded)
region_props = regionprops(labelled)
if not region_props:
    raise ValueError(f"No pixels above the gaussian threshold {threshold}; try lowering Gauss_Thres.")
# Select exactly one region by label. This keeps a single region even when several have
# near-maximal area, and uses the same connectivity as `label` above.
largest_region = max(region_props, key=lambda props: props.area)
thresholded = ndimage.binary_fill_holes(labelled == largest_region.label)
plt.imshow(thresholded)
plt.show()

# Pad with one pixel of background (np.pad fills with 0) so the erosion can remove pixels
# on the image border, then crop the padding back off.
eroded = binary_erosion(np.pad(thresholded, 1))[1:-1, 1:-1]

# Edges are the pixels of the filled region that the 1-pixel erosion removed.
edges = thresholded.astype(int) - eroded.astype(int)
plt.imshow(rosette_mask)
plt.imshow(edges, alpha=0.5)
plt.show()

In [None]:
# Second thresholding: start again from the rosette mask, this time applying the fibril
# size limits (Fibril_Min_Size / Fibril_Max_Size) instead of the gaussian-stage ones.
fibril_thres = rosette_mask.copy()

# Drop objects below the minimum fibril size.
fib_removed_small_objects = remove_small_objects(fibril_thres, Fibril_Min_Size)
plot(fib_removed_small_objects.astype(bool), title="removed small objects", zlimit=False)

# Label what is left and erase every region larger than the maximum fibril size.
fibril_rosette_mask = label(fib_removed_small_objects)
regions = regionprops(fibril_rosette_mask)
oversized_labels = [region.label for region in regions if region.area > Fibril_Max_Size]
fibril_rosette_mask[np.isin(fibril_rosette_mask, oversized_labels)] = 0

plot(fibril_rosette_mask.astype(bool), title="removed large objects", zlimit=False)
plt.show()

In [None]:
plt.imshow(fibril_rosette_mask.astype(bool))
plt.imshow(edges, alpha=0.6, cmap="gray")
plt.show()

# (row, col) coordinates of every outline pixel, and the fibril regions to compare with them.
edge_positions = np.transpose(np.nonzero(edges))
region_props = regionprops(fibril_rosette_mask)
region_labels = np.array([props.label for props in region_props], dtype=int)
region_centroids = np.array([props.centroid for props in region_props], dtype=float).reshape(-1, 2)
# Distance from each fibril centroid to its nearest outline pixel, computed once for all regions.
min_distances = cdist(region_centroids, edge_positions).min(axis=1)

print(f"mean distance: {np.mean(min_distances)}")
print(f"std dev distance: {np.std(min_distances)}")
plt.hist(min_distances, bins="auto")
plt.show()
distance_threshold = np.mean(min_distances) - Threshold_Dist * np.std(min_distances)
print(f"distance threshold: {distance_threshold}")

# Discard fibrils whose centroid lies further from the outline than the threshold.
too_far_labels = region_labels[min_distances > distance_threshold]
removed_too_far_points = fibril_rosette_mask.copy()
removed_too_far_points[np.isin(removed_too_far_points, too_far_labels)] = 0

fig, ax = plt.subplots(figsize=(14, 14))
ax.imshow(removed_too_far_points.astype(bool))
plt.show()

# Find centroids of remaining points. `labelled` is reused by the stats cell (node_stats).
labelled = label(removed_too_far_points)
regions = regionprops(labelled)
points = np.array([props.centroid for props in regions], dtype=float).reshape(-1, 2)
print(f"points shape: {points.shape}")

# Mean (row, col) of the remaining fibril centroids; used as the origin for angular sorting.
centroid = points.mean(axis=0)


# Function to find angle of point from centroid
def angle(point, centroid):
    """Return the polar angle of ``point`` about ``centroid``, wrapped into [0, 2*pi).

    The first coordinate of ``point - centroid`` is used as the x component and the
    second as the y component, i.e. for (row, col) points this is atan2(col, row).
    """
    first_offset, second_offset = point - centroid
    full_turn = 2 * np.pi
    signed_angle = np.arctan2(second_offset, first_offset)
    return (signed_angle + full_turn) % full_turn


# Order the fibril centroids by their polar angle about `centroid` so they trace the outline.
# NOTE: `sorted` shadows the builtin, but later cells read this name, so it is kept.
point_angles = [angle(pt, centroid) for pt in points]
sorted = points[np.argsort(point_angles)]

# Close the loop for plotting by repeating the first point; column index is x, row index is y.
xs = np.concatenate((sorted[:, 1], [sorted[0, 1]]))
ys = np.concatenate((sorted[:, 0], [sorted[0, 0]]))

fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(flattened, vmin=-3, vmax=4)
ax.plot(xs, ys, "k--")

plt.show()

In [None]:
def point_in_polygon(point: np.ndarray, polygon: np.ndarray) -> bool:
    """Return True if ``point`` lies inside ``polygon`` (even-odd / ray-casting rule).

    A ray is cast from ``point`` towards +x (first coordinate) and the polygon edges
    crossing it are counted; an odd count means inside.

    Parameters
    ----------
    point : np.ndarray
        ``(x, y)`` coordinates, in the same order as the polygon's columns.
    polygon : np.ndarray
        ``(N, 2)`` array of vertices in traversal order. The polygon is closed
        implicitly (the last vertex connects back to the first); an explicitly closed
        polygon whose last vertex repeats the first gives the same result.

    Returns
    -------
    bool
        True when the crossing count is odd. An empty polygon contains no points.
    """
    polygon = np.asarray(polygon)
    x, y = point[0], point[1]
    num_vertices = polygon.shape[0]
    count = 0

    for index in range(num_vertices):
        x1, y1 = polygon[index, :]
        # Wrap around so the closing edge (last -> first vertex) is also tested.
        x2, y2 = polygon[(index + 1) % num_vertices, :]

        # The edge straddles the ray's y (this also guarantees y1 != y2 below).
        if (y < y1) != (y < y2):
            # Count the edge if the point is left of the edge's intersection with the ray.
            if x < (x2 - x1) * (y - y1) / (y2 - y1) + x1:
                count += 1
    return count % 2 == 1


def signed_distance_to_outline(outline_mask, point):
    """Get the signed distance from ``point`` to a closed outline drawn as a binary mask.

    The magnitude is the Euclidean distance to the nearest outline pixel. The outline
    is assumed to be a complete closed loop without gaps, so its interior can be
    found by hole-filling the mask.

    A negative value means the point is inside the outline and a positive value means
    the point is outside it. Points on the outline have distance 0 (returned as -0.0).

    Parameters
    ----------
    outline_mask : array-like
        2D mask whose non-zero pixels form the closed outline.
    point : array-like
        ``(row, col)`` position in mask index coordinates. Non-integer positions are
        rounded to the nearest pixel for the inside/outside test; positions outside the
        mask bounds count as outside.

    Returns
    -------
    float
        Signed distance in pixels.

    Raises
    ------
    ValueError
        If ``outline_mask`` has no non-zero pixels.
    """
    mask = np.asarray(outline_mask).astype(bool)
    outline_points = np.argwhere(mask)
    if outline_points.size == 0:
        raise ValueError("outline_mask contains no outline pixels")

    point = np.asarray(point, dtype=float)
    diffs = outline_points - point
    min_dist = np.sqrt(np.min(diffs[:, 0] ** 2 + diffs[:, 1] ** 2))

    # argwhere returns pixels in raster order, not polygon order, so a polygon test on them
    # is meaningless; instead fill the closed outline and look up the point's pixel.
    enclosed = ndimage.binary_fill_holes(mask)
    row, col = np.rint(point).astype(int)
    inside = 0 <= row < enclosed.shape[0] and 0 <= col < enclosed.shape[1] and bool(enclosed[row, col])
    return -min_dist if inside else min_dist


# Quick check of signed_distance_to_outline on a small hand-drawn closed outline.
outline_mask = np.array(
    [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]
)

# NOTE: (6, 4) is itself an outline pixel, so the printed distance has magnitude 0.
point = np.array([6, 4])
dist = signed_distance_to_outline(outline_mask, point)
print(f"distance: {dist}")

# Mark the query point with value 2 so it stands out against the outline.
im = outline_mask.copy()
im[tuple(point)] = 2

plt.imshow(im)

In [None]:
# Density analysis along the fibril outline.
# NOTE(review): `network_density_internal` was never imported in this notebook (only
# `network_density` is), so this cell raised NameError on a fresh kernel. It is assumed to
# live in `network_stats` alongside the other helpers — confirm.
from network_stats import network_density_internal

# Close the outline by repeating the first node at the end.
sorted_loop = np.vstack([sorted, sorted[:1]])
(
    global_density,
    internal_density,
    near_outline_density,
    dens_internal,
    dist_internal,
    dens_noutline,
    dist_noutline,
    points_internal,
) = network_density_internal(
    nodes=sorted_loop, image=flattened, px_to_nm=p_to_nm, stepsize_px=20, kernel_size=50, gaussian_sigma=10
)

# Combined (internal + near-outline) samples, in the units returned above.
all_distances = np.append(dist_internal, dist_noutline)
all_densities = np.append(dens_internal, dens_noutline)

In [None]:
# Apply unit scaling: distances from network_density_internal are presumably in pixels
# (hence p_to_nm * 0.001 -> μm). Scaled copies get new names so re-running this cell does
# not rescale already-scaled values.
px_to_um = p_to_nm * 0.001
dist_internal_um = np.asarray(dist_internal) * px_to_um
dist_noutline_um = np.asarray(dist_noutline) * px_to_um
dens_internal_arr = np.asarray(dens_internal)
dens_noutline_arr = np.asarray(dens_noutline)

# Maps image pixel coordinates of the outline onto the density-map grid.
scaling_factor = flattened.shape[0] / internal_density.shape[0]
plt.imshow(flattened, vmin=-3, vmax=4)
plt.plot(sorted_loop[:, 1], sorted_loop[:, 0])
plt.title("flattened image")
plt.show()
for density_map, map_title in (
    (global_density, "global density"),
    (internal_density, "internal density"),
    (near_outline_density, "near outline density"),
):
    plt.imshow(density_map)
    plt.plot(sorted_loop[:, 1] / scaling_factor, sorted_loop[:, 0] / scaling_factor)
    plt.colorbar()
    plt.title(map_title)
    plt.show()

print(dist_internal_um.shape)
print(dens_internal_arr.shape)
# NOTE(review): these two x labels say "distance squared", while the combined plot below says
# "distance to fibril edge (μm)"; only one can be right — confirm against network_density_internal.
plt.scatter(x=dist_internal_um, y=dens_internal_arr, marker=".")
plt.xlabel("distance squared to fibril")
plt.ylabel("local density")
plt.title("internal distance vs density")
plt.show()
plt.scatter(x=dist_noutline_um, y=dens_noutline_arr, marker=".")
plt.title("near outline distance vs density")
plt.xlabel("distance squared to fibril")
plt.ylabel("local density")
plt.show()

# Scatter plot of both internal and near outline together, density vs distance
plt.scatter(x=dist_internal_um, y=dens_internal_arr, marker=".", label="internal")
plt.scatter(x=dist_noutline_um, y=dens_noutline_arr, marker=".", label="near outline")
plt.xlabel("distance to fibril edge (μm)")
plt.ylabel("density of DNA (local median height value)")
plt.title("distance vs density")
plt.legend()
plt.show()

# Heatmap of internal density vs distance. np.histogram2d puts the minimum value in the first
# bin and the maximum in the last (a manual `floor(...) - 1` index wraps the minimum into the
# last bin via negative indexing). Rows are density bins, columns are distance bins.
x_size = 80
y_size = 40
heatmap, density_edges, distance_edges = np.histogram2d(
    dens_internal_arr, dist_internal_um, bins=(y_size, x_size)
)

plt.imshow(
    heatmap,
    cmap="cool",
    origin="lower",
    aspect="auto",
    extent=(distance_edges[0], distance_edges[-1], density_edges[0], density_edges[-1]),
)
plt.xlabel("distance to fibril (μm)")
plt.ylabel("local density")
plt.title("internal distance vs density heatmap")
plt.show()

In [None]:
# Plot a 2d histogram of distance vs density for internal and near-outline samples together.
# all_distances comes straight from network_density_internal (presumably pixels, as the
# scaling cell multiplies by p_to_nm * 0.001); convert to μm to match the axis label.
all_distances_um = np.asarray(all_distances) * p_to_nm * 0.001
all_densities_arr = np.asarray(all_densities)
df = pd.DataFrame({"distance": all_distances_um, "density": all_densities_arr})
# `x` must name a DataFrame column; the human-readable label is set separately.
ax = sns.histplot(data=df, x="distance", y="density", bins=70)
ax.set_xlabel("Distance from Fibril (μm)")

# Overlay the per-bin mean and standard deviation of density.
bin_means, bin_edges, binnumber = stats.binned_statistic(
    x=all_distances_um, values=all_densities_arr, statistic="mean", bins=70
)
bin_std, _, _ = stats.binned_statistic(x=all_distances_um, values=all_densities_arr, statistic="std", bins=70)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
ax.errorbar(bin_centers, bin_means, yerr=bin_std, fmt="k.", label="binned statistic of data", alpha=0.4)
ax.plot(bin_centers, bin_means, "k-", alpha=0.8)
ax.legend()
plt.show()

In [None]:
def visualise_curvature_scatter(curvatures: np.ndarray, points: np.ndarray, title: str = ""):
    """Scatter-plot ``points`` coloured by ``curvatures`` (rainbow map) and show the figure.

    NOTE: this definition shadows ``visualise_curvature_scatter`` imported from
    ``network_stats`` at the top of the notebook; calls after this cell use this version.

    Parameters
    ----------
    curvatures : np.ndarray
        One colour value per point.
    points : np.ndarray
        Points to plot; column 0 is drawn on the x axis and column 1 on the y axis.
    title : str
        Plot title (empty by default).
    """
    first_coords = points[:, 0]
    second_coords = points[:, 1]
    markers = plt.scatter(first_coords, second_coords, c=curvatures, cmap="rainbow", s=3)
    plt.title(title)
    plt.colorbar(markers)
    plt.show()

In [None]:
MICRON_SCALING_FACTOR = 0.001
P_TO_MICRON = p_to_nm * MICRON_SCALING_FACTOR

print("Image:", file)

print("- molecule stats -")
area = network_area(sorted) * P_TO_MICRON**2
print(f"area: {area:.2f} μm^2")
perimeter = polygon_perimeter(sorted) * P_TO_MICRON
print(f"perimeter: {perimeter:.2f} μm")
min_feret, max_feret = network_feret_diameters(sorted)
print(f"min_feret: {min_feret* P_TO_MICRON:.2f} μm | max_feret: {max_feret* P_TO_MICRON:.2f} μm")
# In Run-All order `labelled` is the relabelled, distance-filtered fibril mask; re-running the
# edge-detection cell afterwards would silently replace it with the outline mask.
regionstats = node_stats(labelled_image=labelled, image=flattened)

raw_curvatures, interpolated_points = interpolate_spline_and_get_curvature(points=sorted, interpolation_number=30)
print("- curvature stats -")
print(
    f"min curvature: {np.min(raw_curvatures)}, max curvature: {np.max(raw_curvatures)}, mean curvature: {np.mean(raw_curvatures)} sum curvature: {np.sum(raw_curvatures)}"
)

# From here on `interpolated_curvatures` holds log10 curvature; the line-segment cell relies on this.
interpolated_curvatures = np.log10(raw_curvatures)

visualise_curvature_scatter(
    curvatures=interpolated_curvatures, points=interpolated_points, title="curvature visualised scatter plot"
)
visualise_curvature_pixel_image(
    curvatures=interpolated_curvatures,
    points=interpolated_points,
    title="curvature visualised pixel image",
    image_size=250,
    figsize=(12, 12),
)

print("- node stats -")
num_perimeter_nodes = sorted.shape[0]
print(f"number of nodes in perimeter: {num_perimeter_nodes}")
node_areas_nm2 = regionstats["node areas"] * p_to_nm**2
# NOTE(review): volumes are scaled by the pixel area (p_to_nm**2), which yields nm^3 only if
# node_stats reports volume as height (nm) x pixel count — confirm.
node_volumes_nm3 = regionstats["node volumes"] * p_to_nm**2
node_mean_heights = regionstats["node mean_heights"]
node_max_heights = regionstats["node max_heights"]
print(
    f"node areas | min: {np.min(node_areas_nm2):.2f} nm^2 max: {np.max(node_areas_nm2):.2f} nm^2 mean: {np.mean(node_areas_nm2):.2f} nm^2"
)
print(
    f"node volumes | min: {np.min(node_volumes_nm3):.2f} nm^3 max: {np.max(node_volumes_nm3):.2f} nm^3 mean: {np.mean(node_volumes_nm3):.2f} nm^3"
)
print(
    f"mean node height values | min: {np.min(node_mean_heights):.2f} nm max: {np.max(node_mean_heights):.2f} nm mean: {np.mean(node_mean_heights):.2f} nm"
)
print(
    f"max node height values | min: {np.min(node_max_heights):.2f} nm max: {np.max(node_max_heights):.2f} nm mean: {np.mean(node_max_heights):.2f} nm\n"
)

# These are the log10-transformed values, so label them as such.
print("Full Curvature Values (log10):")
print(interpolated_curvatures, "\n")
print("---------------------------------------------------------------")

In [None]:
# Plot curvature as a series of coloured line segments.

x = interpolated_points[:, 0]
y = interpolated_points[:, 1]
# `interpolated_curvatures` holds log10 curvature at this point (set in the stats cell).
# NOTE(review): np.log(-v) is finite only for v < 0, i.e. raw curvature < 1 — confirm this holds.
weights = np.log(-interpolated_curvatures)

fig, ax = plt.subplots()
ax.hist(weights, bins="auto")
plt.show()

print(np.min(weights), np.max(weights))
# Rescale to [0, 1] for the colour map; guard against a constant input (zero range).
weights = weights - np.min(weights)
weight_range = np.max(weights)
weights = weights / weight_range if weight_range > 0 else np.zeros_like(weights)

# TODO: remove outliers based on standard deviation (not yet implemented).

fig, ax = plt.subplots()
ax.hist(weights, bins="auto")
plt.show()

print(np.min(weights), np.max(weights))

# One segment between each pair of consecutive interpolated points: shape (N - 1, 2, 2).
line_points = np.array([x, y]).T.reshape(-1, 1, 2)
segments = np.concatenate([line_points[:-1], line_points[1:]], axis=1)
print(segments.shape)

line_colormap = colormaps["rainbow"]
# Colour each segment by the weight at its starting point (one colour per segment).
line_colours = line_colormap(weights[:-1])
print(line_colours.shape)
lc = LineCollection(segments, colors=line_colours, linewidths=20)

fig, ax = plt.subplots()
ax.add_collection(lc)
ax.autoscale()
ax.margins(0.1)
plt.show()

In [None]:
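# Save the density values to a CSV file (disabled: re-enabling it requires `import csv` and a `header` row, neither of which is defined in this notebook)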

# data = [dist_internal,dens_internal,dist_noutline,dens_noutline]

# with open('density.csv', 'w', encoding='UTF8') as f:
#     writer = csv.writer(f)
#     writer.writerow(header)
#     for stat in data:
#         writer.writerow(np.transpose(stat))