In [None]:
from pathlib import Path
import h5py
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from scipy.optimize import curve_fit

from topostats.io import hdf5_to_dict
import h5glance

In [None]:
# Load one TopoStats output file, show the image, and inspect the spline of one molecule.
data_dir = Path("/Users/sylvi/topo_data/pleng/data/")
assert data_dir.exists()
filename = "20231218_2ngSCcats.0_00003.topostats"
file_path = data_dir / filename
assert file_path.exists()

with h5py.File(file_path, "r") as f:
    image_data = hdf5_to_dict(f, group_path="/")
    print(image_data.keys())
    p_to_nm = image_data["pixel_to_nm_scaling"]
    print(f"pixel to nm scaling: {p_to_nm}")

    image = image_data["image"]
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title(filename)
    plt.show()

    # Only the first molecule of the first grain is inspected here.
    spline_data = image_data["splining"]["above"]["grain_0"]["mol_0"]
    print(spline_data.keys())
    # NOTE(review): presumably pixel coordinates -- the batch cell below multiplies the same
    # field by p_to_nm before analysis.
    spline_coords = spline_data["spline_coords"]

    # Column 1 on x and column 0 on y: the same convention dusanpleng/correlations plot with.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(spline_coords[:, 1], spline_coords[:, 0], s=1)
    ax.set_aspect("equal")
    ax.set_title("spline coordinates, grain_0 / mol_0")
    plt.show()

In [None]:
def dusanpleng(
    points: np.ndarray,
    maximum_length_nm: float,
    plot=False,
    log_angle_warnings=True,
    log_details=False,
    resample_step_nm: float = 0.1,
) -> tuple[float, np.ndarray]:
    """Average the cos(angle) decay over fixed-length sections of a trace and sum it.

    The trace is walked segment by segment. Each segment's unit vector is dotted with the
    reference unit vector of the current section (cos of the angle between them). The
    cumulative length since the section started is recorded alongside it. Once that length
    exceeds ``maximum_length_nm``, the section is closed, keyed by the point index where it
    ended. A new section then starts with the current segment as its reference direction.
    The final, unfinished section is discarded, and so is the last point of the trace.

    Only non-empty sections are kept, and only when all their cos(angle) values are finite
    and strictly positive; zero or negative values cannot be fitted on a log scale. The kept
    sections are linearly resampled onto a common grid covering the distance range they all
    share, then averaged element-wise.

    Parameters
    ----------
    points : np.ndarray
        (N, 2) array (or nested sequence) of coordinates; column 1 is plotted as x and
        column 0 as y. Distances are in the units of these coordinates, so pass nm
        coordinates for ``maximum_length_nm`` to mean nanometres.
    maximum_length_nm : float
        Cumulative length after which a section is closed.
    plot : bool
        Show diagnostic figures.
    log_angle_warnings : bool
        Print a warning for every orthogonal or negative cos(angle).
    log_details : bool
        Print the shape of the common distance grid.
    resample_step_nm : float
        Spacing of the common distance grid (default 0.1, the previously hard-coded value).

    Returns
    -------
    tuple[float, np.ndarray]
        The plain sum of the averaged cos(angle) values (not multiplied by the grid step)
        and the common distance grid.

    Raises
    ------
    ValueError
        If fewer than two points are given, no usable section remains, or the sections
        share no common distance range.
    """
    # Float array so integer coordinates (or plain lists) work with the divisions below.
    points = np.asarray(points, dtype=float)
    if len(points) < 2:
        raise ValueError("at least two points are needed to define a direction vector")

    if plot:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(points[:, 1], points[:, 0], c="red")
        ax.scatter(points[0, 1], points[0, 0], c="blue")

    current_distance = 0
    distance_cos_angle_pairs = {}
    distances = []
    cos_angles = []

    # Reference direction of the current section, starting with the first segment.
    first_vector = points[1] - points[0]
    first_vector = first_vector / np.linalg.norm(first_vector)
    # Point 0 has no incoming segment; the last point is deliberately excluded.
    for point_index in range(1, len(points) - 1):
        point = points[point_index]
        vector = point - points[point_index - 1]
        vector_distance = np.linalg.norm(vector)
        current_distance += vector_distance
        vector = vector / vector_distance
        cos_angle = np.dot(first_vector, vector)
        if log_angle_warnings:
            if cos_angle == 0:
                print("[warning] cos(angle) is orthogonal to first vector, can't be fitted to log plot")
            elif cos_angle < 0:
                print("[warning] cos(angle) is negative, can't be fitted to log plot")

        if current_distance > maximum_length_nm:
            # Close this section and restart with the current segment as the reference.
            distance_cos_angle_pairs[point_index] = {
                "distances": distances,
                "cos_angles": cos_angles,
            }
            first_vector = vector
            current_distance = 0
            distances = []
            cos_angles = []
            if plot:
                ax.scatter(point[1], point[0], c="green")
        else:
            if plot:
                ax.scatter(point[1], point[0], c="blue", s=1)
            distances.append(current_distance)
            cos_angles.append(cos_angle)
    # Anything still in `distances` / `cos_angles` is an unfinished section and is dropped.

    if plot:
        ax.set_aspect("equal")
        plt.show()

        fig, ax = plt.subplots()
        for point_index, pair in distance_cos_angle_pairs.items():
            ax.plot(pair["distances"], pair["cos_angles"], label=f"point {point_index}")
        # Legend to the right, wrapped into columns; matplotlib requires ncol >= 1.
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), ncol=max(1, len(distance_cos_angle_pairs) // 4))
        plt.show()

    # Keep sections that can be fitted on a log scale. A section is empty when its first
    # segment alone exceeds the maximum length; it has no distance range, so it is dropped
    # too (np.min on it would raise).
    distance_cos_angle_pairs = {
        point_index: pair
        for point_index, pair in distance_cos_angle_pairs.items()
        if len(pair["cos_angles"]) > 0
        and np.all(np.isfinite(pair["cos_angles"]))
        and np.all(np.asarray(pair["cos_angles"]) > 0)
    }
    if len(distance_cos_angle_pairs) == 0:
        raise ValueError("no points found")

    # Sections cover different distances, so resample onto the range they all share.
    largest_minimum = max(np.min(pair["distances"]) for pair in distance_cos_angle_pairs.values())
    smallest_maximum = min(np.max(pair["distances"]) for pair in distance_cos_angle_pairs.values())
    common_distances = np.arange(largest_minimum, smallest_maximum, resample_step_nm)
    if common_distances.size == 0:
        raise ValueError("no common range of distances found")
    if log_details:
        print(f"common distances shape: {common_distances.shape}")

    interpolated_sections = [
        np.interp(common_distances, pair["distances"], pair["cos_angles"])
        for pair in distance_cos_angle_pairs.values()
    ]

    if plot:
        fig, ax = plt.subplots(figsize=(5, 5))
        for point_index, interpolated_cos_angles in zip(distance_cos_angle_pairs, interpolated_sections):
            ax.plot(common_distances, interpolated_cos_angles, label=f"point {point_index}")
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), ncol=max(1, len(distance_cos_angle_pairs) // 4))
        plt.show()

    average_cos_angles = np.mean(interpolated_sections, axis=0)

    if plot:
        fig, ax = plt.subplots()
        ax.plot(common_distances, average_cos_angles)
        ax.set_title("average cos angles")
        plt.show()

    # "Integrate" by plain summation (not scaled by resample_step_nm).
    sum_cos_angles = np.sum(average_cos_angles)
    print(f"sum of cos angles: {sum_cos_angles}")

    if plot:
        fig, ax = plt.subplots()
        ax.plot(common_distances, np.cumsum(average_cos_angles))
        ax.set_title("cumsum cos angles")
        plt.show()

    return sum_cos_angles, common_distances


# NOTE: spline_coords from the loading cell are presumably pixels (the batch cell multiplies
# the same field by p_to_nm), so convert to nm to match the nm-valued length threshold.
_, _ = dusanpleng(spline_coords * p_to_nm, maximum_length_nm=70, plot=True, log_details=True)

In [None]:
def correlations(points, maximum_length_nm, log_angle_warnings=False, log_details=False, plot=False):
    """Collect (distance, cos(angle)) samples along a trace split into fixed-length sections.

    The trace is walked segment by segment. Each segment's unit vector is dotted with the
    reference unit vector of the current section, and the cumulative length since the
    section started is recorded. Once that length exceeds ``maximum_length_nm``, the section
    is closed and a new one starts with the current segment as its reference direction.
    The final, unfinished section and the last point of the trace are discarded.

    Parameters
    ----------
    points : array-like
        (N, 2) coordinates; column 1 is plotted as x and column 0 as y. Distances are in the
        units of these coordinates.
    maximum_length_nm : float
        Cumulative length after which a section is closed.
    log_angle_warnings : bool
        Print a warning for every orthogonal or negative cos(angle).
    log_details : bool
        Accepted for interface compatibility; currently unused.
    plot : bool
        Show the trace with section boundaries and the per-section cos(angle) curves.

    Returns
    -------
    tuple[list, list]
        Distances and cos(angle) values from every completed section, concatenated in order.
        Both lists are empty when fewer than two points are supplied.
    """
    # Float array so integer coordinates (or plain lists) work with the divisions below.
    points = np.asarray(points, dtype=float)
    if len(points) < 2:
        # No direction can be defined; return no samples instead of raising IndexError.
        return [], []

    if plot:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(points[:, 1], points[:, 0], c="red")
        ax.scatter(points[0, 1], points[0, 0], c="blue")

    current_distance = 0
    distance_cos_angle_pairs = {}
    distances = []
    cos_angles = []

    # Reference direction of the current section, starting with the first segment.
    first_vector = points[1] - points[0]
    first_vector = first_vector / np.linalg.norm(first_vector)
    # Point 0 has no incoming segment; the last point is deliberately excluded.
    for point_index in range(1, len(points) - 1):
        point = points[point_index]
        vector = point - points[point_index - 1]
        vector_distance = np.linalg.norm(vector)
        current_distance += vector_distance
        vector = vector / vector_distance
        cos_angle = np.dot(first_vector, vector)
        if log_angle_warnings:
            if cos_angle == 0:
                print("[warning] cos(angle) is orthogonal to first vector, can't be fitted to log plot")
            elif cos_angle < 0:
                print("[warning] cos(angle) is negative, can't be fitted to log plot")

        if current_distance > maximum_length_nm:
            # Close this section (keyed by the index where it ended) and restart.
            distance_cos_angle_pairs[point_index] = {
                "distances": distances,
                "cos_angles": cos_angles,
            }
            first_vector = vector
            current_distance = 0
            distances = []
            cos_angles = []
            if plot:
                ax.scatter(point[1], point[0], c="green")
        else:
            if plot:
                ax.scatter(point[1], point[0], c="blue", s=1)
            distances.append(current_distance)
            cos_angles.append(cos_angle)

    if plot:
        ax.set_aspect("equal")
        plt.show()

        fig, ax = plt.subplots()
        for point_index, pair in distance_cos_angle_pairs.items():
            ax.plot(pair["distances"], pair["cos_angles"], label=f"point {point_index}")
        # Legend to the right, wrapped into columns; matplotlib requires ncol >= 1.
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), ncol=max(1, len(distance_cos_angle_pairs) // 4))
        plt.show()

    # Flatten the completed sections into two parallel lists.
    all_distances = []
    all_cos_angles = []
    for pair in distance_cos_angle_pairs.values():
        all_distances.extend(pair["distances"])
        all_cos_angles.extend(pair["cos_angles"])
    return all_distances, all_cos_angles

### Calculate average over all molecules before plotting

In [None]:
# Batch-process every .topostats file of each sample type: scale each molecule's spline with
# pixel_to_nm_scaling and store the (distance, cos(angle)) samples from `correlations`.
sample_types = [
    "magnesium-unknot-plasmid-nicked",
    "magnesium-unknot-plasmid-supercoiled",
    "nickel-relaxed",
    "nickel-supercoiled",
]
data_dir_prefix = "/Users/sylvi/topo_data/pleng/topology-data-check/processed-data-"
maximum_length_nm = 150
# Files whose pixel_to_nm_scaling exceeds this are skipped.
maximum_px_to_nm = 1.0
use_grains_with_only_one_molecule = True
# Check the shared parent directory rather than a `data_dir` left over from an earlier cell.
assert Path(data_dir_prefix).parent.exists()

# results_dict[sample_type][file][grain][molecule] -> {"distances": [...], "cos_angles": [...]}
results_dict = {}

for sample_type in sample_types:
    results_dict[sample_type] = {}
    data_dir = Path(data_dir_prefix + sample_type)
    assert data_dir.exists()
    # Sorted so files are processed (and later plotted) in a reproducible order.
    topostats_files = sorted(data_dir.glob("*.topostats"))
    for topostats_file in topostats_files:
        with h5py.File(topostats_file, "r") as f:
            image_data = hdf5_to_dict(f, group_path="/")
            if "splining" not in image_data:
                continue
            p_to_nm = image_data["pixel_to_nm_scaling"]
            if p_to_nm > maximum_px_to_nm:
                print(f"skipping {topostats_file} due to pixel to nm scaling of {p_to_nm} too high")
                continue
            spline_data = image_data["splining"]["above"]
            file_results = {}
            for grain_index, grain_data in spline_data.items():
                if use_grains_with_only_one_molecule and len(grain_data) != 1:
                    continue
                grain_results = {}
                for molecule_index, molecule_data in grain_data.items():
                    # Scale so distances are comparable with maximum_length_nm.
                    spline_coords = molecule_data["spline_coords"] * p_to_nm
                    distances, cos_angles = correlations(
                        spline_coords, maximum_length_nm, log_angle_warnings=False, plot=False
                    )
                    grain_results[molecule_index] = {"distances": distances, "cos_angles": cos_angles}
                file_results[grain_index] = grain_results
            results_dict[sample_type][topostats_file] = file_results


# For each sample type: bin every (distance, cos(angle)) sample, average per bin, then
# estimate a decay length by integrating the binned curve and by fitting exp(-x/p).
distance_bin_delta = 0.2

# NumPy 2.0 renamed np.trapz to np.trapezoid; use whichever this NumPy provides.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def exp_fit(x, p):
    """Return exp(-x / p), the single-exponential decay fitted to the averaged cos(angle).

    NOTE(review): for a 2D worm-like chain <cos(theta)(s)> = exp(-s / (2 P)), which would make
    `p` twice the persistence length P -- confirm which convention is intended.
    """
    return np.exp(-x / p)


for sample_type, sample_type_results in results_dict.items():
    print(f"sample type: {sample_type}")

    # Bins are labelled by their left edge; a sample at distance d goes to bin d // delta.
    distance_bins = np.arange(0, maximum_length_nm, distance_bin_delta)
    cos_angle_sums = np.zeros_like(distance_bins)
    cos_angle_counts = np.zeros_like(distance_bins)

    # Scatter every cos(angle) vs distance while accumulating the bins.
    fig, ax = plt.subplots()
    for topostats_file_data in sample_type_results.values():
        for grain_data in topostats_file_data.values():
            for molecule_data in grain_data.values():
                distances = np.asarray(molecule_data["distances"], dtype=float)
                cos_angles = np.asarray(molecule_data["cos_angles"], dtype=float)
                ax.scatter(distances, cos_angles, s=0.5)

                bin_indices = (distances // distance_bin_delta).astype(int)
                # Drop anything that would index outside the bin array.
                in_range = (bin_indices >= 0) & (bin_indices < len(distance_bins))
                # np.add.at accumulates correctly when the same bin index repeats.
                np.add.at(cos_angle_sums, bin_indices[in_range], cos_angles[in_range])
                np.add.at(cos_angle_counts, bin_indices[in_range], 1)
    ax.set_xlabel("distance (nm)")
    ax.set_ylabel("cos(angle)")
    ax.set_title(f"cos angles for {sample_type}")
    plt.show()

    # Mean per bin; empty bins become NaN without a 0/0 RuntimeWarning.
    average_cos_angles = np.divide(
        cos_angle_sums,
        cos_angle_counts,
        out=np.full_like(cos_angle_sums, np.nan),
        where=cos_angle_counts > 0,
    )
    fig, ax = plt.subplots()
    ax.scatter(distance_bins, average_cos_angles, s=0.5)
    ax.set_xlabel("distance (nm)")
    ax.set_ylabel("average cos(angle)")
    ax.set_title(f"average cos angles, {sample_type}")
    plt.show()

    # The exponential fit needs finite, strictly positive values (this also drops empty bins).
    valid = np.isfinite(average_cos_angles) & (average_cos_angles > 0)
    average_cos_angles = average_cos_angles[valid]
    distance_bins = distance_bins[valid]
    if average_cos_angles.size == 0:
        print(f"no positive average cos(angle) values for {sample_type}, skipping")
        continue

    # Trapezoidal integral over the (noisy) binned averages, without fitting a curve.
    integral = _trapezoid(average_cos_angles, x=distance_bins)
    print(f"integral: {integral}")

    # For exp(-x/p) the integral from 0 to infinity equals p, so the (truncated) integral is a
    # much more reasonable starting guess than curve_fit's default p0 = 1.
    p0 = [integral] if integral > 0 else None
    try:
        popt, pcov = curve_fit(exp_fit, distance_bins, average_cos_angles, p0=p0)
    except RuntimeError as error:
        # Report the failed fit but keep going with the remaining statistics / sample types.
        print(f"exponential fit failed for {sample_type}: {error}")
    else:
        print(f"popt: {popt}")
        print(f"pcov: {pcov}")

        fig, ax = plt.subplots()
        ax.scatter(distance_bins, average_cos_angles, s=0.5, label="binned average")
        ax.plot(distance_bins, exp_fit(distance_bins, *popt), label="fit")
        ax.set_xlabel("distance (nm)")
        ax.set_ylabel("average cos(angle)")
        ax.set_title(f"average cos angles, {sample_type}")
        ax.legend()
        plt.show()

    # Plain sum / mean / cumulative sum of the binned averages (not scaled by the bin width).
    sum_cos_angles = np.sum(average_cos_angles)
    print(f"sum of cos angles: {sum_cos_angles}")
    average_sum_cos_angles = sum_cos_angles / len(average_cos_angles)
    print(f"average sum of cos angles: {average_sum_cos_angles}")

    cumsum_cos_angles = np.cumsum(average_cos_angles)
    fig, ax = plt.subplots()
    ax.plot(distance_bins, cumsum_cos_angles)
    ax.set_xlabel("distance (nm)")
    ax.set_ylabel("cumulative sum of cos(angle)")
    ax.set_title(f"cumulative sum of cos angles, {sample_type}")
    plt.show()