In [None]:
import os
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.ndimage import gaussian_filter1d

from topostats.io import LoadScans
from topostats.tracing.splining import resample_points_regular_interval, windowTrace
from topostats.measure.curvature import (
    discrete_angle_difference_per_nm_circular,
    discrete_angle_difference_per_nm_linear,
)

In [None]:
# Root of the topology-plasmids dataset. Set the TOPOSTATS_PLASMID_DATA_DIR environment variable to run this
# notebook on another machine; otherwise the original author's local path is used.
base_dir = Path(
    os.environ.get(
        "TOPOSTATS_PLASMID_DATA_DIR",
        "/Users/sylvi/topo_data/topostats_2/datasets/topology-plasmids",
    )
)
results_dir = base_dir / "output_old_catsnet"
dir_processed_nicked = results_dir / "plasmid_nic/processed"
dir_processed_sc = results_dir / "plasmid_sup/processed"
file_allstats = results_dir / "all_statistics.csv"

# Fail early, naming the missing path (a bare assert would not say which one is absent).
for required_path in (results_dir, dir_processed_nicked, dir_processed_sc, file_allstats):
    assert required_path.exists(), f"Required path not found: {required_path}"

df_allstats = pd.read_csv(file_allstats)
print(df_allstats.columns)

# Rescale to nanometre-based units: area m^2 -> nm^2, length m -> nm, volume m^3 -> nm^3.
# NOTE(review): assumes all_statistics.csv stores these columns in metres (implied by the scale factors) — confirm.
df_allstats["area"] = df_allstats["area"] * 1e18
df_allstats["total_contour_length"] = df_allstats["total_contour_length"] * 1e9
df_allstats["volume"] = df_allstats["volume"] * 1e27

# print the writhe_string column unique values
print(df_allstats["writhe_string"].unique())


def calculate_num_char_in_string(input_string: str, character: str) -> int:
    """Count the occurrences of ``character`` in ``input_string``.

    Missing values (``None`` or ``NaN``, as pandas produces for empty CSV cells) are treated as containing
    no characters and yield 0.
    """
    return 0 if pd.isna(input_string) else input_string.count(character)


def remove_datapoints_outside_n_std(df: pd.DataFrame, column: str, n_std: float) -> pd.DataFrame:
    """Keep the rows of ``df`` whose ``column`` value lies within ``n_std`` standard deviations of the mean.

    The mean and sample standard deviation (pandas default, ddof=1) are computed over the non-null values of
    ``column``. Both bounds are inclusive, and rows whose value is NaN are always dropped.

    If ``column`` has fewer than two non-null values, the standard deviation is undefined and no value can be
    judged an outlier. In that case all non-null rows are kept instead of returning an empty frame.

    Parameters
    ----------
    df : pd.DataFrame
        Data to filter; it is not modified.
    column : str
        Name of the numeric column used for the outlier test.
    n_std : float
        Half-width of the accepted band, in standard deviations.

    Returns
    -------
    pd.DataFrame
        A new DataFrame containing the retained rows with their original index.
    """
    values = df[column]
    mean = values.mean()
    std = values.std()
    if pd.isna(std):
        # 0 or 1 non-null values: nothing to compare against, so only drop missing values.
        return df[values.notna()]
    lower_bound = mean - n_std * std
    upper_bound = mean + n_std * std
    # between() is inclusive on both ends and evaluates to False for NaN, matching the explicit >= / <= mask.
    return df[values.between(lower_bound, upper_bound)]


# Count the "+" and "-" characters in each molecule's writhe string. Their sum is what the violin plot in the
# next cell labels "Number of crossings".
for sign_count_column, sign_character in (("num_plusses", "+"), ("num_minuses", "-")):
    df_allstats[sign_count_column] = df_allstats["writhe_string"].apply(
        calculate_num_char_in_string, character=sign_character
    )
df_allstats["num_plusses_or_minuses"] = df_allstats["num_plusses"].add(df_allstats["num_minuses"])

# Show which endpoint counts occur, then optionally keep only grains that have no endpoints.
print(df_allstats["grain_endpoints"].unique())
remove_grains_with_endpoints = True
if remove_grains_with_endpoints:
    num_before_removal = len(df_allstats)
    has_no_endpoints = df_allstats["grain_endpoints"].eq(0)
    df_allstats = df_allstats.loc[has_no_endpoints]
    num_after_removal = len(df_allstats)
    dropped_row_count = num_before_removal - num_after_removal
    print(f"Removed {dropped_row_count} rows with grain_endpoints != 0")

In [None]:
# One violin plot per statistic, split by image basename. Points beyond 3 standard deviations of the pooled
# mean are dropped before plotting. Each entry is (column, y-axis label).
allstats_violin_specs = [
    ("total_contour_length", "Contour Length (nm)"),
    ("num_plusses_or_minuses", "Number of crossings"),
    ("volume", "Volume (nm^3)"),
    ("height_min", "Minimum Height (nm)"),
    ("height_median", "Median Height (nm)"),
    ("height_mean", "Mean Height (nm)"),
]
for column_name, axis_label in allstats_violin_specs:
    sns.violinplot(
        data=remove_datapoints_outside_n_std(df_allstats, column=column_name, n_std=3),
        x="basename",
        y=column_name,
        inner="point",
    )
    plt.ylabel(axis_label)
    plt.show()

In [None]:
# Gather the processed .topostats files for each sample type, sorted so the load order is reproducible.
files_ts_nic = sorted(dir_processed_nicked.glob("*.topostats"))
print(f"Found {len(files_ts_nic)} processed files for nicked plasmids")
files_ts_sc = sorted(dir_processed_sc.glob("*.topostats"))
print(f"Found {len(files_ts_sc)} processed files for supercoiled plasmids")

# Load both sets. Each loader's img_dict is used below as a mapping of image name -> per-image data.
loadscans_nic = LoadScans(img_paths=files_ts_nic, channel="dummy")
loadscans_nic.get_data()
loadscans_sc = LoadScans(img_paths=files_ts_sc, channel="dummy")
loadscans_sc.get_data()
loadscans_dicts_nic, loadscans_dicts_sc = loadscans_nic.img_dict, loadscans_sc.img_dict

In [None]:
# Tag every image with its sample type so molecules can still be grouped after the two sets are merged.
for image_data in loadscans_dicts_nic.values():
    image_data["sample_type"] = "nicked"
for image_data in loadscans_dicts_sc.values():
    image_data["sample_type"] = "supercoiled"

# With {**a, **b}, a supercoiled image silently replaces any nicked image of the same name, dropping it from
# the analysis. Refuse to merge if the two sets share image names.
shared_image_names = loadscans_dicts_nic.keys() & loadscans_dicts_sc.keys()
assert not shared_image_names, (
    f"Image names present in both nicked and supercoiled sets: {sorted(shared_image_names)}"
)

# combine the dicts
loadscans_dicts = {**loadscans_dicts_nic, **loadscans_dicts_sc}

print(f"num loaded images: {len(loadscans_dicts)}")

In [None]:
trace_resampling_distance_nm = 2.0
plotting = False
verbose = False
smoothing_window_size_nm = 5
curvature_gaussian_sigma_nm = 2.0
# gaussian_filter1d accepts a float sigma, so keep the nm -> points conversion exact. Truncating with int()
# would silently shrink fractional sigmas, and would give sigma = 0 whenever the requested sigma is smaller
# than one resampling interval.
curvature_gaussian_sigma_points = curvature_gaussian_sigma_nm / trace_resampling_distance_nm


def _sample_image_along_trace(image: np.ndarray, trace_coords_px: np.ndarray) -> np.ndarray:
    """Return the value of the nearest pixel of ``image`` for each (row, col) point in ``trace_coords_px``.

    Points are rounded to the nearest pixel and clipped to the image bounds. A point lying just past an edge
    therefore neither raises an IndexError nor wraps round to the opposite side via negative indexing.
    """
    coords_int = np.round(trace_coords_px).astype(int)
    rows = np.clip(coords_int[:, 0], 0, image.shape[0] - 1)
    cols = np.clip(coords_int[:, 1], 0, image.shape[1] - 1)
    return image[rows, cols]


# NOTE(review): every molecule below goes through the *circular* smoothing/resampling/curvature functions, i.e. it
# is treated as a closed loop. The grain_endpoints filter earlier applies only to df_allstats, not to these
# per-image dicts — confirm that no open (linear) molecules reach this loop.
stats_curvature_list = []
for image_name, image_data in loadscans_dicts.items():
    sample_type = image_data["sample_type"]
    if verbose:
        print(image_data.keys())
    if "grain_curvature_stats" not in image_data:
        print(f"No grain curvature data for {image_name}")
        continue

    # Per-image values: look them up once rather than once per molecule.
    image = image_data["image"]
    p2nm = image_data["pixel_to_nm_scaling"]
    if plotting:
        plt.imshow(image)
        plt.show()

    curvature_data = image_data["grain_curvature_stats"]["above"]
    for grain_id, grain_curvature_data in curvature_data.items():
        for mol_id, mol_curvatures in grain_curvature_data.items():
            # --- summary statistics of the curvature values stored in the .topostats file ---
            min_curvature = np.min(mol_curvatures)
            max_curvature = np.max(mol_curvatures)
            mean_curvature = np.mean(mol_curvatures)
            std_curvature = np.std(mol_curvatures)
            percentile_25_curvature = np.percentile(mol_curvatures, 25)
            percentile_75_curvature = np.percentile(mol_curvatures, 75)
            percentile_iqr_curvature = percentile_75_curvature - percentile_25_curvature
            total_curvature = np.sum(mol_curvatures)

            if plotting:
                # The ordered (pixel) trace is only used for these diagnostic plots.
                ordered_traces_data = image_data["ordered_traces"]["above"][grain_id][mol_id]
                ordered_coords = ordered_traces_data["ordered_coords"]
                ordered_heights = ordered_traces_data["heights"]
                plt.imshow(image)
                plt.plot(ordered_coords[:, 1], ordered_coords[:, 0], "o-", label="Trace", markersize=1)
                plt.show()
                plt.plot(ordered_heights)
                plt.show()

            # --- spline trace: shift by the bbox origin so the coordinates index into the full image ---
            smoothed_traces_data = image_data["splining"]["above"][grain_id][mol_id]
            smoothed_trace_bbox = smoothed_traces_data["bbox"]
            smoothed_trace_coords = smoothed_traces_data["spline_coords"] + np.array(
                [smoothed_trace_bbox[0], smoothed_trace_bbox[1]]
            )
            smoothed_trace_heights = _sample_image_along_trace(image, smoothed_trace_coords)
            if plotting:
                plt.imshow(image)
                plt.plot(
                    smoothed_trace_coords[:, 1],
                    smoothed_trace_coords[:, 0],
                    "o-",
                    label="Smoothed Trace",
                    markersize=1,
                )
                plt.plot(smoothed_trace_coords[0, 1], smoothed_trace_coords[0, 0], "o", markersize=5, color="yellow")
                plt.show()
                plt.plot(smoothed_trace_heights)
                plt.show()

            # NOTE(review): np.diff omits the closing segment (last point -> first point). Unless the spline repeats
            # its first point, this slightly underestimates the length of a closed loop — confirm.
            diffs_px = np.diff(smoothed_trace_coords, axis=0)
            distances_px = np.linalg.norm(diffs_px, axis=1)
            total_distance_nm = np.sum(distances_px) * p2nm

            heights_percentile_10 = np.percentile(smoothed_trace_heights, 10)
            heights_percentile_25 = np.percentile(smoothed_trace_heights, 25)

            # --- re-calculate curvature from the spline trace: re-smooth, resample at a fixed spacing, then
            # take angle differences per nm ---
            redone_smoothed_trace_coords_px = windowTrace.pool_trace_circular(
                pixel_trace=smoothed_trace_coords,
                rolling_window_size=smoothing_window_size_nm,
                pixel_to_nm_scaling=p2nm,
            )
            resampled_path_px = resample_points_regular_interval(
                points=redone_smoothed_trace_coords_px,
                interval=trace_resampling_distance_nm / p2nm,
                circular=True,
            )
            resampled_distances_nm = np.linalg.norm(np.diff(resampled_path_px, axis=0), axis=1) * p2nm
            redone_total_distance_nm = np.sum(resampled_distances_nm)

            redone_curvatures = discrete_angle_difference_per_nm_circular(trace_nm=resampled_path_px * p2nm)
            # The trace is closed, so its first and last samples are neighbours. mode="wrap" smooths across that
            # seam instead of padding each end with its own edge value ("nearest").
            redone_curvatures = gaussian_filter1d(
                redone_curvatures,
                sigma=curvature_gaussian_sigma_points,
                mode="wrap",
            )
            redone_curvatures = np.abs(redone_curvatures)
            redone_total_curvature = np.sum(redone_curvatures)
            redone_percentile_25 = np.percentile(redone_curvatures, 25)
            redone_percentile_75 = np.percentile(redone_curvatures, 75)

            stats_curvature_list.append(
                {
                    "image_name": image_name,
                    "sample_type": sample_type,
                    "grain_id": grain_id,
                    "mol_id": mol_id,
                    "min_curvature": min_curvature,
                    "max_curvature": max_curvature,
                    "mean_curvature": mean_curvature,
                    "std_curvature": std_curvature,
                    "percentile_25_curvature": percentile_25_curvature,
                    "percentile_75_curvature": percentile_75_curvature,
                    "percentile_iqr_curvature": percentile_iqr_curvature,
                    "total_curvature": total_curvature,
                    "total_distance_nm": total_distance_nm,
                    "heights_percentile_10": heights_percentile_10,
                    "heights_percentile_25": heights_percentile_25,
                    "total_curvature_over_total_length": total_curvature / total_distance_nm,
                    "redone_total_distance_nm": redone_total_distance_nm,
                    "redone_total_curvature": redone_total_curvature,
                    "redone_total_curvature_over_total_length": redone_total_curvature / redone_total_distance_nm,
                    "redone_curvatures_mean": np.mean(redone_curvatures),
                    "redone_curvatures_std": np.std(redone_curvatures),
                    "redone_curvatures_min": np.min(redone_curvatures),
                    "redone_curvatures_max": np.max(redone_curvatures),
                    "redone_curvatures_percentile_25": redone_percentile_25,
                    "redone_curvatures_percentile_75": redone_percentile_75,
                    "redone_curvatures_percentile_iqr": redone_percentile_75 - redone_percentile_25,
                }
            )


df_stats_curvature = pd.DataFrame(stats_curvature_list)

## Original curvature statistics (as stored in the .topostats files)

In [None]:
# Violin plots comparing the stored curvature statistics of nicked vs supercoiled molecules.
# Each entry is (column, y-axis label, number of standard deviations used for outlier removal).
original_curvature_violin_specs = [
    ("min_curvature", "Minimum Curvature (1/nm)", 3),
    ("max_curvature", "Maximum Curvature (1/nm)", 3),
    ("mean_curvature", "Mean Curvature (1/nm)", 3),
    ("std_curvature", "Standard Deviation of Curvature (1/nm)", 3),
    ("percentile_iqr_curvature", "Interquartile Range of Curvature (1/nm)", 3),
    ("total_curvature", "Total Curvature (1/nm)", 3),
    ("total_curvature_over_total_length", "Total Curvature / total length (1/nm^2)", 5),
]
for column_name, axis_label, n_std in original_curvature_violin_specs:
    sns.violinplot(
        data=remove_datapoints_outside_n_std(df_stats_curvature, column=column_name, n_std=n_std),
        x="sample_type",
        y=column_name,
        inner="point",
    )
    plt.ylabel(axis_label)
    plt.show()

# Split the supercoiled molecules at a total-curvature threshold to compare the trace heights of the two groups.
high_total_curvature_threshold = 50
is_supercoiled = df_stats_curvature["sample_type"] == "supercoiled"
is_high_total_curvature = df_stats_curvature["total_curvature"] > high_total_curvature_threshold
df_sc_high_total_curvature = df_stats_curvature[is_supercoiled & is_high_total_curvature]
df_sc_low_total_curvature = df_stats_curvature[
    is_supercoiled & (df_stats_curvature["total_curvature"] <= high_total_curvature_threshold)
]

# Remove outliers within each group, then plot both groups from one long-form frame with a labelled group column.
# This replaces two violinplot calls with scalar x values (x=1 / x=2), which relied on seaborn broadcasting a
# scalar and on how it places categories across separate calls on the same axes.
df_sc_heights_by_curvature = pd.concat(
    [
        remove_datapoints_outside_n_std(df_sc_high_total_curvature, column="heights_percentile_10", n_std=5).assign(
            curvature_group="high curvature"
        ),
        remove_datapoints_outside_n_std(df_sc_low_total_curvature, column="heights_percentile_10", n_std=5).assign(
            curvature_group="low curvature"
        ),
    ],
    ignore_index=True,
)
fig, ax = plt.subplots()
sns.violinplot(
    data=df_sc_heights_by_curvature,
    x="curvature_group",
    y="heights_percentile_10",
    order=["high curvature", "low curvature"],
    inner="point",
    ax=ax,
)
ax.set_xlabel(f"Supercoiled molecules split at total curvature = {high_total_curvature_threshold}")
ax.set_ylabel("10th percentile of trace height")
plt.show()

## Re-computed curvature statistics (re-smoothed, resampled spline traces)

In [None]:
# Violin plots of the re-computed curvature statistics, split by sample type. The y-axis labels and units match
# the original-curvature section so the two sets of plots can be compared directly.
# Each entry is (column, y-axis label, number of standard deviations used for outlier removal).
redone_curvature_violin_specs = [
    ("redone_curvatures_min", "Minimum Curvature, re-computed (1/nm)", 3),
    ("redone_curvatures_max", "Maximum Curvature, re-computed (1/nm)", 3),
    ("redone_curvatures_mean", "Mean Curvature, re-computed (1/nm)", 3),
    ("redone_curvatures_std", "Standard Deviation of Curvature, re-computed (1/nm)", 3),
    ("redone_curvatures_percentile_iqr", "Interquartile Range of Curvature, re-computed (1/nm)", 3),
    ("redone_curvatures_percentile_25", "25th Percentile of Curvature, re-computed (1/nm)", 3),
    ("redone_curvatures_percentile_75", "75th Percentile of Curvature, re-computed (1/nm)", 3),
    ("redone_total_curvature", "Total Curvature, re-computed (1/nm)", 3),
    ("redone_total_curvature_over_total_length", "Total Curvature / total length, re-computed (1/nm^2)", 5),
    ("redone_total_distance_nm", "Total Contour Length, re-computed (nm)", 3),
]
for column_name, axis_label, n_std in redone_curvature_violin_specs:
    sns.violinplot(
        data=remove_datapoints_outside_n_std(df_stats_curvature, column=column_name, n_std=n_std),
        x="sample_type",
        y=column_name,
        inner="point",
    )
    plt.ylabel(axis_label)
    plt.show()
