In [None]:
from pathlib import Path
import h5py
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from scipy.optimize import curve_fit

from topostats.io import hdf5_to_dict
import h5glance

In [None]:
# Directory containing the .topostats file that the next cell loads.
# NOTE: this is a machine-specific absolute path; point it at your local copy of the data.
data_dir = Path("/Users/sylvi/topo_data/pleng/data/")
# Fail early, and say which path is missing, rather than erroring later inside h5py.
assert data_dir.exists(), f"data directory not found: {data_dir}"

In [None]:
# Load one .topostats file, show its image and plot the spline trace of a single molecule.
filename = "20231218_2ngSCcats.0_00003.topostats"
file_path = data_dir / filename
assert file_path.exists()

with h5py.File(file_path, "r") as f:
    # Convert the file contents, starting at the root group "/", into a dictionary
    # (presumably nested, given the chained key lookups below - confirm against topostats.io).
    image_data = hdf5_to_dict(f, group_path="/")
    print(image_data.keys())
    spline_data = image_data["splining"]["above"]
    # Scaling factor stored in the file; multi_file_pleng multiplies spline coordinates by this value
    # before fitting.
    p_to_nm = image_data["pixel_to_nm_scaling"]
    print(f"pixel to nm scaling: {p_to_nm}")

    image = image_data["image"]
    plt.imshow(image)
    plt.show()

    # Narrow the spline data down to one molecule: grain_0 / mol_0.
    # Note that this rebinds spline_data from the per-grain mapping to a single molecule's entry.
    spline_data = spline_data["grain_0"]["mol_0"]
    print(spline_data.keys())
    # spline_coords is reused by later cells as the input trace for betterpleng.
    spline_coords = spline_data["spline_coords"]

    # plot spline coords
    fig, ax = plt.subplots(figsize=(10, 10))
    # ax.plot(spline_coords[:, 0], spline_coords[:, 1], marker="o", markersize=1)
    ax.scatter(spline_coords[:, 0], spline_coords[:, 1], s=1)
    plt.show()

In [None]:
def eucledian_dist(x, pleng):
    """Mean squared end-to-end distance of a two-dimensional worm-like chain.

    Evaluates ``4 * P * L * (1 - (2 * P / L) * (1 - exp(-L / (2 * P))))`` with ``L = x`` and ``P = pleng``.
    Despite the function name, the result is a *squared* distance (it has units of length squared); take
    the square root to obtain a distance.

    Parameters
    ----------
    x : float or np.ndarray
        Contour length(s) of the chain. Must be non-zero.
    pleng : float or np.ndarray
        Persistence length(s), in the same units as ``x``. Must be non-zero.

    Returns
    -------
    float or np.ndarray
        Mean squared end-to-end distance, broadcast over ``x`` and ``pleng``. It tends to ``x ** 2`` for a
        stiff chain (``x << pleng``) and to ``4 * pleng * x`` for a flexible one (``x >> pleng``).
    """
    # -expm1(-u) equals 1 - exp(-u) but keeps its precision when u = x / (2 * pleng) is tiny.
    one_minus_exp = -np.expm1(-x / (2 * pleng))
    # The exponential term belongs inside the (2P / L) product; multiplying the two brackets together
    # instead gives negative "squared distances" whenever x < 2 * pleng.
    return 4 * pleng * x * (1 - (2 * pleng / x) * one_minus_exp)


def betterpleng(
    points: np.ndarray, maximum_length_nm: float, plot=False, log_angle_warnings=True, log_details=False
) -> tuple[float, float]:
    """Estimate a persistence length from how step directions decorrelate along a trace.

    The trace is walked point by point and cut into consecutive sections. Inside a section, the direction
    of every step is compared (as the cosine of the angle) with a reference direction - the first step of
    the trace, and afterwards the step that closed the previous section - and recorded against the distance
    travelled since the section started. A section is closed as soon as that distance exceeds
    ``maximum_length_nm``; the trailing, unfinished section is discarded. Sections containing a zero,
    negative or non-finite cosine are dropped (their logarithm cannot be fitted). The remaining sections
    are resampled onto a shared distance grid with a spacing of 0.1, averaged, and a straight line is
    fitted to ``-log(mean cos)`` against distance. The slope is interpreted as ``1 / (2 * pleng)``.

    Parameters
    ----------
    points : np.ndarray
        Ordered trace coordinates, one point per row. Any numeric dtype is accepted; the array is not
        modified. When plotting, column 1 is drawn on the x-axis and column 0 on the y-axis.
    maximum_length_nm : float
        Distance after which a section is closed. Must be in the same units as ``points``.
    plot : bool
        If True, show diagnostic figures and print the fitted values.
    log_angle_warnings : bool
        If True, print a warning for every step whose cosine is zero or negative.
    log_details : bool
        If True, print the shapes of the arrays that go into the fit.

    Returns
    -------
    tuple[float, float]
        The persistence length ``1 / (2 * slope)`` and the root mean squared error of the linear fit.

    Raises
    ------
    ValueError
        With the message "no points found" if there are fewer than two points or no usable section, and
        with "no common range of distances found" if the usable sections do not share at least two
        resampling positions.
    """
    # Work in floating point so the in-place normalisations below also accept integer coordinates.
    points = np.asarray(points, dtype=float)
    if len(points) < 2:
        raise ValueError("no points found")

    if plot:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(points[:, 1], points[:, 0], c="red")
        ax.scatter(points[0, 1], points[0, 0], c="blue")

    current_distance = 0

    distance_cos_angle_pairs = {}
    distances = []
    cos_angles = []

    # Reference direction of the first section. A zero-length first step makes this NaN; the resulting
    # NaN cosines cause that section to be dropped during vetting further down.
    first_vector = points[1] - points[0]
    first_vector /= np.linalg.norm(first_vector)

    # Index 0 has no incoming step, and the last point is left out because the section it belongs to is
    # never completed.
    for point_index in range(1, len(points) - 1):
        point = points[point_index]
        vector = point - points[point_index - 1]
        vector_distance = np.linalg.norm(vector)
        if vector_distance == 0:
            # Repeated point: it has no direction and adds no distance.
            continue
        current_distance += vector_distance
        vector /= vector_distance
        cos_angle = np.dot(first_vector, vector)
        # Check for bad angles (for cos(angle))
        if log_angle_warnings:
            if cos_angle == 0:
                print("[warning] cos(angle) is orthogonal to first vector, can't be fitted to log plot")
            elif cos_angle < 0:
                print("[warning] cos(angle) is negative, can't be fitted to log plot")

        if current_distance > maximum_length_nm:
            # Close this section; the step that crossed the limit becomes the next reference direction.
            first_vector = vector
            current_distance = 0
            distance_cos_angle_pairs[point_index] = {
                "distances": distances,
                "cos_angles": cos_angles,
            }
            distances = []
            cos_angles = []
            if plot:
                ax.scatter(point[1], point[0], c="green")
        else:
            if plot:
                ax.scatter(point[1], point[0], c="blue", s=1)
            distances.append(current_distance)
            cos_angles.append(cos_angle)

    if plot:
        ax.set_aspect("equal")
        plt.show()

        # Raw cos(angle) against distance for every section
        fig, ax = plt.subplots()
        for point_index, distance_cos_angle_pair in distance_cos_angle_pairs.items():
            ax.plot(
                distance_cos_angle_pair["distances"],
                distance_cos_angle_pair["cos_angles"],
                label=f"point {point_index}",
            )
        # Legend to the right of the plot, wrapped into columns; a legend needs at least one column.
        ax.legend(
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            ncol=max(1, len(distance_cos_angle_pairs) // 4),
        )
        plt.show()

    # Keep only sections that can be fitted on a log scale: non-empty, all cosines finite and positive.
    vetted_distance_cos_angle_pairs = {}
    for point_index, distance_cos_angle_pair in distance_cos_angle_pairs.items():
        section_cos_angles = np.asarray(distance_cos_angle_pair["cos_angles"], dtype=float)
        if section_cos_angles.size == 0:
            continue
        if not np.all(np.isfinite(section_cos_angles)) or np.any(section_cos_angles <= 0):
            continue
        vetted_distance_cos_angle_pairs[point_index] = distance_cos_angle_pair

    distance_cos_angle_pairs = vetted_distance_cos_angle_pairs

    if len(distance_cos_angle_pairs) == 0:
        raise ValueError("no points found")

    # The sections were sampled at different distances, so resample them onto the range they all share:
    # from the largest of the section minima to the smallest of the section maxima.
    largest_minimum = max(np.min(pair["distances"]) for pair in distance_cos_angle_pairs.values())
    smallest_maximum = min(np.max(pair["distances"]) for pair in distance_cos_angle_pairs.values())
    common_distances = np.arange(largest_minimum, smallest_maximum, 0.1)
    if log_details:
        print(f"common distances shape: {common_distances.shape}")
    # A two-parameter line needs at least two samples; report this the same way as a missing range so
    # callers that skip on that message keep working.
    if common_distances.size < 2:
        raise ValueError("no common range of distances found")

    if plot:
        fig, ax = plt.subplots(figsize=(5, 5))
    resampled_cos_angles = []
    for point_index, distance_cos_angle_pair in distance_cos_angle_pairs.items():
        interpolated_cos_angles = np.interp(
            common_distances,
            distance_cos_angle_pair["distances"],
            distance_cos_angle_pair["cos_angles"],
        )
        resampled_cos_angles.append(interpolated_cos_angles)
        if plot:
            ax.plot(common_distances, interpolated_cos_angles, label=f"point {point_index}")

    if plot:
        ax.legend(
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            ncol=max(1, len(distance_cos_angle_pairs) // 4),
        )
        plt.show()

    average_cos_angles = np.mean(resampled_cos_angles, axis=0)

    if plot:
        fig, ax = plt.subplots()
        ax.plot(common_distances, average_cos_angles)
        plt.title("average cos angles")
        plt.show()

    inverse_log_cos_angles = -np.log(average_cos_angles)

    if plot:
        fig, ax = plt.subplots()
        ax.plot(common_distances, inverse_log_cos_angles)
        plt.title("inverse log cos angles")
        plt.show()

    def linear(x, a, b):
        return a * x + b

    if log_details:
        print(f"common distances shape: {common_distances.shape}")
        print(f"inverse log cos angles shape: {inverse_log_cos_angles.shape}")

    popt, _ = curve_fit(linear, common_distances, inverse_log_cos_angles)

    if plot:
        fig, ax = plt.subplots()
        ax.plot(common_distances, linear(common_distances, *popt), label="fit")
        ax.plot(common_distances, inverse_log_cos_angles, label="data")
        plt.title("fit")
        plt.legend()
        plt.show()

    # the slope is 1/(2p) where p is the pleng
    pleng = 1 / (2 * popt[0])

    # calculate residuals
    residuals = inverse_log_cos_angles - linear(common_distances, *popt)
    if plot:
        fig, ax = plt.subplots()
        ax.scatter(x=common_distances, y=residuals)
        plt.title("residuals")
        plt.show()
        print(f"pleng: {pleng}")

    # calculate root mean squared error
    rmse = np.sqrt(np.mean(residuals**2))
    if plot:
        print(f"rmse: {rmse}")

    return pleng, rmse


# betterpleng measures distances against maximum_length_nm, so the trace must be in nanometres. The
# spline coordinates loaded above are used unscaled here, whereas multi_file_pleng multiplies them by the
# pixel-to-nm scaling before fitting; apply the same scaling so the two agree.
betterpleng(spline_coords * p_to_nm, maximum_length_nm=10, plot=True, log_details=True)

### Find the maximum segment length that gives the lowest RMSE

In [None]:
# Try several maximum segment lengths on the single molecule loaded above and keep the one whose linear
# fit has the lowest RMSE.

# betterpleng compares distances against maximum_length_nm, so scale the trace by the pixel-to-nm factor
# first (multi_file_pleng does the same before fitting).
spline_coords_nm = spline_coords * p_to_nm

best_pleng = None
best_rmse = None
best_max_length = None
for maximum_length_nm in [10, 20, 30, 40, 50]:
    try:
        pleng, rmse = betterpleng(spline_coords_nm, maximum_length_nm, plot=False)
    except ValueError as e:
        # betterpleng raises ValueError when a length leaves nothing to fit; one unusable candidate
        # should not abort the whole search.
        print(f"skipping maximum length {maximum_length_nm}: {e}")
        continue
    # check if better rmse
    if best_rmse is None or rmse < best_rmse:
        best_rmse = rmse
        best_pleng = pleng
        best_max_length = maximum_length_nm

if best_max_length is None:
    raise ValueError("none of the candidate maximum lengths produced a fit")

# Re-run the winning length with plotting enabled to show its diagnostics.
pleng, rmse = betterpleng(spline_coords_nm, maximum_length_nm=best_max_length, plot=True)

print(f"best pleng: {pleng} with rmse: {rmse} at maximum length: {best_max_length}")

In [None]:
# Explore what happens as p (persistence length) and s (contour length) are changed over many values.
p_values = np.linspace(0.01, 0.1, 100)
s_values = np.linspace(0.1, 1, 100)

# eucledian_dist is declared as (x, pleng) and is called elsewhere with x as the contour length. Pass both
# by keyword so they cannot be swapped. Broadcasting a column of p against a row of s fills the grid in
# one call: rows follow p_values, columns follow s_values.
rsquared_values = eucledian_dist(x=s_values[np.newaxis, :], pleng=p_values[:, np.newaxis])

fig, ax = plt.subplots()
cax = ax.matshow(rsquared_values, cmap="viridis")
ax.set_xlabel("s index")
ax.set_ylabel("p index")
fig.colorbar(cax)
plt.show()

In [None]:
# Plot the output of eucledian_dist against contour length for a fixed persistence length.
pleng = 50
contour_lengths = np.linspace(10, 150, 100)
# eucledian_dist only uses element-wise numpy operations, so it can take the whole array at once.
euclidian_distances = eucledian_dist(x=contour_lengths, pleng=pleng)

fig, ax = plt.subplots()
ax.plot(contour_lengths, euclidian_distances)
ax.set_xlabel("contour length")
# The expression in eucledian_dist is proportional to pleng * x, i.e. a length squared, so label it as a
# squared distance rather than a distance.
ax.set_ylabel("mean squared end-to-end distance")
plt.show()

In [None]:
# Empty container for per-sample summaries; later cells store one entry per sample type in it.
results = dict()

In [None]:
def multi_file_pleng(
    data_dir: Path,
    maximum_length_nm: float,
    maximum_px_to_nm: float,
    use_grains_with_only_one_molecule: bool,
    quiet=True,
) -> tuple[list, list, list, int]:
    """Fit a persistence length to every molecule in every ``.topostats`` file of a directory.

    Each file directly inside ``data_dir`` is opened and its ``splining -> above`` entry is walked grain
    by grain and molecule by molecule. Spline coordinates are multiplied by the file's
    ``pixel_to_nm_scaling`` and handed to ``betterpleng``.

    Parameters
    ----------
    data_dir : Path
        Directory searched (non-recursively) for ``*.topostats`` files.
    maximum_length_nm : float
        Passed through to ``betterpleng``.
    maximum_px_to_nm : float
        Files whose ``pixel_to_nm_scaling`` is larger than this are skipped.
    use_grains_with_only_one_molecule : bool
        If True, grains that do not contain exactly one molecule are skipped.
    quiet : bool
        If True, only the file count and the final molecule count are printed.

    Returns
    -------
    tuple[list, list, list, int]
        Fitted persistence lengths, the matching fit RMSEs, the matching pixel-to-nm scalings, and the
        number of molecules that were fitted successfully.

    Raises
    ------
    ValueError
        Any ``ValueError`` from ``betterpleng`` other than its "no points found" and
        "no common range of distances found" cases, which are logged and skipped.
    """

    def qprint(*args, **kwargs):
        # Verbose-only print
        if quiet:
            return
        print(*args, **kwargs)

    # betterpleng error fragments that mean "this molecule cannot be fitted", mapped to their log line.
    # Checked in this order.
    skippable_errors = {
        "no points found": "[ERROR] no points found",
        "no common range of distances found": "[ERROR] no common range of distances found",
    }

    plengs = []
    rmses = []
    p2nms = []
    num_molecules_processed = 0

    # get all .topostats files
    topostats_files = list(data_dir.glob("*.topostats"))
    print(f"Found {len(topostats_files)} .topostats files")

    for file in topostats_files:
        qprint(f"file: {file.name}")
        with h5py.File(file, "r") as f:
            image_data = hdf5_to_dict(f, group_path="/")
            if "splining" not in image_data:
                qprint("[ERROR] no splining data")
                continue

            p_to_nm = image_data["pixel_to_nm_scaling"]
            spline_data = image_data["splining"]["above"]

            if p_to_nm > maximum_px_to_nm:
                qprint(f"skipping file, pixel to nm scaling is too high: {p_to_nm}")
                continue

            for grain_index, grain_data in spline_data.items():
                molecule_count = len(grain_data)
                qprint(f"  grain: {grain_index}")
                qprint(f"  number of molecules: {molecule_count}")
                if use_grains_with_only_one_molecule and molecule_count != 1:
                    qprint(f"    skipping grain with {molecule_count} molecules")
                    continue

                for molecule_index, molecule_data in grain_data.items():
                    qprint(f"    molecule: {molecule_index}")
                    # Scale the trace by the pixel-to-nm factor before fitting
                    spline_coords = molecule_data["spline_coords"] * p_to_nm
                    try:
                        pleng, rmse = betterpleng(
                            points=spline_coords,
                            maximum_length_nm=maximum_length_nm,
                            plot=False,
                            log_angle_warnings=False,
                        )
                    except ValueError as e:
                        error_text = str(e)
                        log_line = next(
                            (line for fragment, line in skippable_errors.items() if fragment in error_text),
                            None,
                        )
                        if log_line is None:
                            # Not one of the expected "cannot fit" cases: let it propagate
                            raise e
                        qprint(log_line)
                        continue

                    plengs.append(pleng)
                    rmses.append(rmse)
                    p2nms.append(p_to_nm)
                    num_molecules_processed += 1

    print(f"processed {num_molecules_processed} molecules")

    return plengs, rmses, p2nms, num_molecules_processed


# Run the multi-file fit for one sample type, plot persistence length against fit error, and store
# the summary in `results`.
sample_type = "magnesium-unknot-plasmid-nicked"
data_dir = Path(f"/Users/sylvi/topo_data/pleng/topology-data-check/processed-data-{sample_type}")
maximum_length_nm = 10
maximum_px_to_nm = 1.0
use_grains_with_only_one_molecule = True
assert data_dir.exists()

plengs, rmses, p2nms, num_molecules_processed = multi_file_pleng(
    data_dir=data_dir,
    maximum_length_nm=maximum_length_nm,
    maximum_px_to_nm=maximum_px_to_nm,
    use_grains_with_only_one_molecule=use_grains_with_only_one_molecule,
)

# One dot per molecule, coloured by the pixel-to-nm scaling of the image it came from
p2nms = np.array(p2nms)
fig, ax = plt.subplots(figsize=(5, 5))
cbar = fig.colorbar(ax.scatter(x=plengs, y=rmses, s=4, c=p2nms, cmap="viridis"), ax=ax)

ax.set_xlabel("pleng")
ax.set_ylabel("rmse")
ax.set_title(f"pleng vs rmse for {sample_type}, n={num_molecules_processed} L={maximum_length_nm}")
plt.show()

# Average weighted by 1 / rmse, so better fits count for more
weighted_average = np.average(plengs, weights=1 / np.array(rmses))
print(f"rmse weighted average: {weighted_average}")

# Plain mean for comparison
non_weighted_average = np.mean(plengs)
print(f"non weighted average: {non_weighted_average}")

results[sample_type] = dict(
    plengs=plengs,
    rmses=rmses,
    p2nms=p2nms,
    weighted_average=weighted_average,
    non_weighted_average=non_weighted_average,
)

In [None]:
# Sweep the maximum segment length for one sample type and track the mean fitted persistence length.
# sample_type = "magnesium-unknot-plasmid-nicked"
sample_type = "magnesium-unknot-plasmid-supercoiled"
data_dir = Path(f"/Users/sylvi/topo_data/pleng/topology-data-check/processed-data-{sample_type}")
maximum_px_to_nm = 1.0
use_grains_with_only_one_molecule = True
assert data_dir.exists()

# Candidate lengths 8, 9, ..., 29
maximum_lengths = list(range(8, 30))
average_plengs = []

for maximum_length_nm in maximum_lengths:
    # Every iteration re-reads all files in data_dir with a different maximum length
    plengs, rmses, p2nms, num_molecules_processed = multi_file_pleng(
        data_dir=data_dir,
        maximum_length_nm=maximum_length_nm,
        maximum_px_to_nm=maximum_px_to_nm,
        use_grains_with_only_one_molecule=use_grains_with_only_one_molecule,
    )
    non_weighted_average = np.mean(plengs)
    average_plengs.append(non_weighted_average)

fig, ax = plt.subplots()
ax.plot(maximum_lengths, average_plengs)
ax.set_xlabel("maximum length")
ax.set_ylabel("average pleng")
ax.set_title(f"average pleng vs maximum length for {sample_type}")
plt.show()

In [None]:
# Run the multi-file fit for every sample type, plot each one, and compare the distributions of fitted
# persistence lengths.
sample_types = [
    "magnesium-unknot-plasmid-nicked",
    "magnesium-unknot-plasmid-supercoiled",
    "nickel-relaxed",
    "nickel-supercoiled",
]
data_dir_prefix = "/Users/sylvi/topo_data/pleng/topology-data-check/processed-data-"
maximum_length_nm = 20
maximum_px_to_nm = 1.0
use_grains_with_only_one_molecule = True

# Check the directories this cell is actually going to read. Asserting on `data_dir` at this point would
# only test whatever value an earlier cell happened to leave behind.
sample_dirs = {sample_type: Path(data_dir_prefix + sample_type) for sample_type in sample_types}
for sample_type, data_dir in sample_dirs.items():
    assert data_dir.exists(), f"data directory for {sample_type} not found: {data_dir}"

# Start from an empty dictionary so only the sample types listed above are compared.
results = {}

for sample_type, data_dir in sample_dirs.items():
    plengs, rmses, p2nms, num_molecules_processed = multi_file_pleng(
        data_dir=data_dir,
        maximum_length_nm=maximum_length_nm,
        maximum_px_to_nm=maximum_px_to_nm,
        use_grains_with_only_one_molecule=use_grains_with_only_one_molecule,
    )

    # One dot per molecule, coloured by the pixel-to-nm scaling of its source image
    fig, ax = plt.subplots(figsize=(5, 5))
    p2nms = np.array(p2nms)
    ax.scatter(x=plengs, y=rmses, s=4, c=p2nms, cmap="viridis")
    cbar = plt.colorbar(ax.collections[0], ax=ax)

    plt.xlabel("pleng")
    plt.ylabel("rmse")
    plt.title(f"pleng vs rmse for {sample_type}, n={num_molecules_processed} L={maximum_length_nm}")
    plt.show()

    # calculate a weighted average, inversely proportional to the rmse
    weighted_average = np.average(plengs, weights=1 / np.array(rmses))
    print(f"rmse weighted average: {weighted_average}")

    non_weighted_average = np.mean(plengs)
    print(f"non weighted average: {non_weighted_average}")

    results[sample_type] = {
        "plengs": plengs,
        "rmses": rmses,
        "p2nms": p2nms,
        "weighted_average": weighted_average,
        "non_weighted_average": non_weighted_average,
    }


# plot kde of plengs in results
fig, ax = plt.subplots(figsize=(5, 5))
for sample_type, data in results.items():
    plengs = data["plengs"]
    sns.kdeplot(
        plengs,
        # label=f" {sample_type}\n   | n: {len(plengs)}\n   | mean: {np.mean(plengs):.2f} | std: {np.std(plengs):.2f} |\n",
        label=f" {sample_type}\n    n: {len(plengs)} ",
        ax=ax,
    )
# force the legend to be outside the plot
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.title(f"plengs L={maximum_length_nm}")
plt.xlabel("pleng (nm)")
plt.ylabel("density")
plt.show()

# plot another kde but using samples that contain nickel and samples that contain magnesium as the two groups
fig, ax = plt.subplots(figsize=(5, 5))
# "plengs" entries are plain lists, so + concatenates the two samples of each group
nickel_plengs = results["nickel-relaxed"]["plengs"] + results["nickel-supercoiled"]["plengs"]
magnesium_plengs = (
    results["magnesium-unknot-plasmid-nicked"]["plengs"] + results["magnesium-unknot-plasmid-supercoiled"]["plengs"]
)
sns.kdeplot(
    nickel_plengs,
    # label=f"nickel\n n:{len(nickel_plengs)}\n mean: {np.mean(nickel_plengs):.2f} std: {np.std(nickel_plengs):.2f}",
    label=f"nickel\n n:{len(nickel_plengs)}",
    ax=ax,
)
sns.kdeplot(
    magnesium_plengs,
    # label=f"magnesium\n n:{len(magnesium_plengs)}\n mean: {np.mean(magnesium_plengs):.2f} std: {np.std(magnesium_plengs):.2f}",
    label=f"magnesium\n n:{len(magnesium_plengs)}",
    ax=ax,
)
plt.legend()
plt.title(f"nickel vs magnesium plengs L={maximum_length_nm}")
plt.xlabel("pleng (nm)")
plt.ylabel("density")
plt.show()