In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from topostats.io import LoadScans

In [None]:
# Root directory of the dataset. Set the TOPOSTATS_PLASMID_DIR environment variable to
# point the notebook at another location; when it is unset the original path is used.
base_dir = Path(
    os.environ.get(
        "TOPOSTATS_PLASMID_DIR",
        "/Users/sylvi/topo_data/topostats_2/datasets/topology-plasmids",
    )
)

# Every path below is checked up front so a wrong location fails here, with the
# offending path in the message, rather than somewhere further down the notebook.
results_dir = base_dir / "output_old_catsnet"
assert results_dir.exists(), f"Results directory not found: {results_dir}"
dir_processed_nicked = results_dir / "plasmid_nic/processed"
assert dir_processed_nicked.exists(), f"Processed nicked directory not found: {dir_processed_nicked}"
dir_processed_sc = results_dir / "plasmid_sup/processed"
assert dir_processed_sc.exists(), f"Processed supercoiled directory not found: {dir_processed_sc}"

file_allstats = results_dir / "all_statistics.csv"
assert file_allstats.exists(), f"Statistics file not found: {file_allstats}"
df_allstats = pd.read_csv(file_allstats)
print(df_allstats.columns)

# convert to nm
# area, total_contour_length and volume are scaled by 1e18, 1e9 and 1e27 respectively
# (presumably from metre-based units to nm^2, nm and nm^3 — confirm against the CSV).
df_allstats["area"] = df_allstats["area"] * 1e18
df_allstats["total_contour_length"] = df_allstats["total_contour_length"] * 1e9
df_allstats["volume"] = df_allstats["volume"] * 1e27

# print the writhe_string column unique values
print(df_allstats["writhe_string"].unique())


def calculate_num_char_in_string(input_string: str, character: str) -> int:
    """
    Count how many times ``character`` appears in ``input_string``.

    Parameters
    ----------
    input_string : str
        Text to search. A missing value (anything ``pd.isna`` reports as missing,
        such as ``NaN`` or ``None``) is accepted and counts as zero occurrences.
    character : str
        Substring to count; occurrences are counted with ``str.count``.

    Returns
    -------
    int
        Number of occurrences, or ``0`` when ``input_string`` is missing.
    """
    return 0 if pd.isna(input_string) else input_string.count(character)


def remove_datapoints_outside_n_std(df: pd.DataFrame, column: str, n_std: float) -> pd.DataFrame:
    """
    Keep only the rows whose ``column`` value lies within ``n_std`` standard deviations of the mean.

    The mean and standard deviation are taken over the whole column (pandas defaults,
    so NaN values are ignored), not per group. Both bounds are inclusive. Rows whose
    value is NaN are never kept.

    Parameters
    ----------
    df : pd.DataFrame
        Data to filter. It is not modified.
    column : str
        Name of the numeric column to filter on.
    n_std : float
        Half-width of the accepted band, in standard deviations.

    Returns
    -------
    pd.DataFrame
        The rows of ``df`` that pass the filter, with their original index.
    """
    values = df[column]
    std = values.std()
    if pd.isna(std):
        # Fewer than two usable values: the spread is undefined, so nothing can be
        # called an outlier. Keep the valid rows instead of comparing against NaN
        # bounds, which would discard every row.
        return df[values.notna()]
    mean = values.mean()
    # ``between`` is inclusive on both ends and evaluates to False for NaN.
    return df[values.between(mean - n_std * std, mean + n_std * std)]


# Count the "+" and "-" characters in every writhe string (missing strings count as 0),
# one new column per symbol, then add the two counts together row by row.
writhe_symbol_columns = {"num_plusses": "+", "num_minuses": "-"}
for count_column, symbol in writhe_symbol_columns.items():
    df_allstats[count_column] = df_allstats["writhe_string"].apply(calculate_num_char_in_string, character=symbol)
df_allstats["num_plusses_or_minuses"] = df_allstats["num_plusses"].add(df_allstats["num_minuses"])

In [None]:
# One violin plot per statistic, grouped by basename. Each entry is
# (column to plot, y-axis label). Outliers are removed per plot: rows further than
# 3 standard deviations from the mean of the plotted column (computed over all
# basenames together) are dropped before drawing.
# NOTE(review): the height_* columns are labelled in nm here, but the unit-conversion
# step earlier in the notebook only rescales area, total_contour_length and volume —
# confirm the heights are already in nm.
allstats_violin_specs = [
    ("total_contour_length", "Contour Length (nm)"),
    ("num_plusses_or_minuses", "Number of crossings"),
    ("volume", "Volume (nm^3)"),
    ("height_min", "Minimum Height (nm)"),
    ("height_median", "Median Height (nm)"),
    ("height_mean", "Mean Height (nm)"),
]

for column, ylabel in allstats_violin_specs:
    # A fresh figure per statistic keeps the plots independent of pyplot's current-axes state.
    fig, ax = plt.subplots()
    sns.violinplot(
        data=remove_datapoints_outside_n_std(df_allstats, column=column, n_std=3),
        x="basename",
        y=column,
        inner="point",
        ax=ax,
    )
    ax.set_ylabel(ylabel)
    plt.show()

In [None]:
# Collect the processed .topostats files of both samples, in sorted path order,
# and report how many were found for each.
files_ts_nic = sorted(dir_processed_nicked.glob("*.topostats"))
print(f"Found {len(files_ts_nic)} processed files for nicked plasmids")
files_ts_sc = sorted(dir_processed_sc.glob("*.topostats"))
print(f"Found {len(files_ts_sc)} processed files for supercoiled plasmids")

# Load each sample with its own LoadScans instance and keep the resulting image
# dictionaries (one entry per image).
# NOTE(review): channel="dummy" looks like a placeholder — presumably the channel is
# not needed when reading .topostats files; confirm against the topostats version in use.
loadscans_nic = LoadScans(img_paths=files_ts_nic, channel="dummy")
loadscans_nic.get_data()
loadscans_dicts_nic = loadscans_nic.img_dict

loadscans_sc = LoadScans(img_paths=files_ts_sc, channel="dummy")
loadscans_sc.get_data()
loadscans_dicts_sc = loadscans_sc.img_dict


In [None]:
# Tag every loaded image with the sample it came from. This writes the
# "sample_type" key into the per-image dictionaries themselves.
for image_data in loadscans_dicts_nic.values():
    image_data["sample_type"] = "nicked"
for image_data in loadscans_dicts_sc.values():
    image_data["sample_type"] = "supercoiled"

# The merge below is keyed by image name, so a name present in both samples would
# silently replace the nicked entry with the supercoiled one. Fail loudly instead.
shared_image_names = loadscans_dicts_nic.keys() & loadscans_dicts_sc.keys()
assert not shared_image_names, f"Image names present in both samples: {shared_image_names}"

# combine the dicts
loadscans_dicts = {**loadscans_dicts_nic, **loadscans_dicts_sc}

print(f"num loaded images: {len(loadscans_dicts)}")

In [None]:
# Build one summary record per molecule from the per-image curvature data.
# Images without a "grain_curvature_stats" entry are reported and skipped.
stats_curvature_list = []
for image_name, image_data in loadscans_dicts.items():
    sample_type = image_data["sample_type"]
    if "grain_curvature_stats" not in image_data:
        print(f"No grain curvature data for {image_name}")
        continue
    # Only the "above" entry of the curvature stats is used; it maps
    # grain id -> molecule id -> curvature values.
    for grain_id, grain_curvature_data in image_data["grain_curvature_stats"]["above"].items():
        for mol_id, mol_curvatures in grain_curvature_data.items():
            record = {
                "image_name": image_name,
                "sample_type": sample_type,
                "grain_id": grain_id,
                "mol_id": mol_id,
                "min_curvature": np.min(mol_curvatures),
                "max_curvature": np.max(mol_curvatures),
                "mean_curvature": np.mean(mol_curvatures),
                "std_curvature": np.std(mol_curvatures),
                "percentile_25_curvature": np.percentile(mol_curvatures, 25),
                "percentile_75_curvature": np.percentile(mol_curvatures, 75),
            }
            # Interquartile range: distance between the 75th and 25th percentiles.
            record["percentile_iqr_curvature"] = (
                record["percentile_75_curvature"] - record["percentile_25_curvature"]
            )
            record["total_curvature"] = np.sum(mol_curvatures)
            stats_curvature_list.append(record)

# One row per (image, grain, molecule); columns follow the key order of the records.
df_stats_curvature = pd.DataFrame(stats_curvature_list)
# One violin plot per curvature statistic, grouped by sample type. Each entry is
# (column to plot, y-axis label). Outliers are removed per plot: rows further than
# 3 standard deviations from the mean of the plotted column (computed over both
# sample types together) are dropped before drawing.
curvature_violin_specs = [
    ("min_curvature", "Minimum Curvature (1/nm)"),
    ("max_curvature", "Maximum Curvature (1/nm)"),
    ("mean_curvature", "Mean Curvature (1/nm)"),
    ("std_curvature", "Standard Deviation of Curvature (1/nm)"),
    ("percentile_iqr_curvature", "Interquartile Range of Curvature (1/nm)"),
    ("total_curvature", "Total Curvature (1/nm)"),
]

for column, ylabel in curvature_violin_specs:
    # A fresh figure per statistic keeps the plots independent of pyplot's current-axes state.
    fig, ax = plt.subplots()
    sns.violinplot(
        data=remove_datapoints_outside_n_std(df_stats_curvature, column=column, n_std=3),
        x="sample_type",
        y=column,
        inner="point",
        ax=ax,
    )
    ax.set_ylabel(ylabel)
    plt.show()
