In [None]:
import psutil
import pickle
import re
from pathlib import Path
from typing import Generator

from hashlib import sha256
import matplotlib.pyplot as plt
from pydantic import BaseModel, ConfigDict
import numpy as np
import numpy.typing as npt
import pandas as pd
import seaborn as sns
from AFMReader.topostats import load_topostats
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

from topostats.damage.damage import (
    Defect,
    DefectGap,
    OrderedDefectGapList,
    calculate_indirect_defect_gaps,
    get_defects_and_gaps_from_bool_array,
)
from topostats.io import LoadScans
from topostats.measure.curvature import discrete_angle_difference_per_nm_circular, total_turn_in_region_radians
from topostats.tracing.splining import resample_points_regular_interval
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box_cutting_off_at_image_bounds

In [None]:
def clear_output():
    """Clear the notebook output by calling ``IPython.display.clear_output`` with its default arguments.

    IPython is imported inside the function rather than at the top of the notebook, and it is accessed
    through its module so the call cannot be confused with this wrapper, which has the same name.
    """
    from IPython import display as _ipy_display

    _ipy_display.clear_output()

In [None]:
# Resolve the directory layout of this analysis. Every input directory must already exist; only the
# results directory is created on demand.
dir_base = Path("/Volumes/shared/pyne_group/Shared/AFM_Data/dna_damage/Cs137_irradiations")
assert dir_base.exists()
dir_this_analysis = dir_base.joinpath("20260204-analysis-getting-back-into-the-project")
assert dir_this_analysis.exists()
dir_processed_data = dir_this_analysis.joinpath("output")
assert dir_processed_data.exists()
dir_results = dir_this_analysis.joinpath("analysis_results")
dir_results.mkdir(exist_ok=True)
assert dir_results.exists()

# Collect the paths of all .topostats files at least one directory below the processed-data folder.
# Only the paths are gathered here; no file contents are read yet.
topo_files = [topo_path for topo_path in dir_processed_data.glob("*/**/*.topostats")]
print(f"found {len(topo_files)} topo files")

# Read the grain statistics table that sits next to the processed data
csv_grain_stats: Path = dir_processed_data.joinpath("grain_statistics.csv")
assert csv_grain_stats.exists(), f"could not find grain stats csv at {csv_grain_stats}"
df_grain_stats: pd.DataFrame = pd.read_csv(csv_grain_stats)
print(f"grain stats columns: {df_grain_stats.columns}")

# Rescale the contour length to nanometres (dividing by 1e-9; presumably the CSV stores metres)
df_grain_stats["total_contour_length"] = df_grain_stats["total_contour_length"] / 1e-9

In [None]:
def _plot_contour_length_distributions(df: pd.DataFrame, title: str) -> None:
    """Draw a strip plot overlaid with a violin plot of contour length per ``basename``.

    Parameters
    ----------
    df : pd.DataFrame
        Table with at least the ``basename`` and ``total_contour_length`` columns.
    title : str
        Title of the figure, so that several figures of the same kind can be told apart.
    """
    _, ax = plt.subplots()
    sns.stripplot(data=df, x="basename", y="total_contour_length", s=2, ax=ax)
    sns.violinplot(data=df, x="basename", y="total_contour_length", inner=None, ax=ax)
    # the folder names are long, so turn the tick labels sideways
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_title(title)
    plt.show()


# plot contour length distributions before any filtering
_plot_contour_length_distributions(df_grain_stats, "Contour length distributions (before length filter)")

# drop any rows with contour length less than a threshold (same unit as the column, i.e. nm)
threshold_contour_length = 300

n_rows_before = len(df_grain_stats)
# note: rows whose contour length is NaN fail the comparison and are dropped as well
df_grain_stats = df_grain_stats[df_grain_stats["total_contour_length"] >= threshold_contour_length]
n_rows_after = len(df_grain_stats)
print(
    f"dropped {n_rows_before - n_rows_after} rows with contour length < {threshold_contour_length} nm. remaining rows: {n_rows_after}"
)

# plot the distributions again so the effect of the filter is visible
_plot_contour_length_distributions(
    df_grain_stats, f"Contour length distributions (contour length >= {threshold_contour_length} nm)"
)

In [None]:
# Report how much memory the process running this notebook is using
def notebook_ram_usage():
    """Print the psutil process handle, its memory info and its resident set size in GB.

    ``psutil.Process()`` is created without a pid; the resident set size (``rss``) it reports is
    divided by 1024**3 before printing.
    """
    bytes_per_gb = 1024**3
    this_process = psutil.Process()
    print(f"process: {this_process}")
    memory = this_process.memory_info()
    print(f"memory info: {memory}")
    print(f"RAM usage: {memory.rss / bytes_per_gb:.2f} GB")


notebook_ram_usage()

In [None]:
# Models
class BaseDamageAnalysis(BaseModel):
    """Common pydantic base class for the data models defined in this notebook.

    It declares no fields of its own; it only sets the model configuration that its subclasses inherit.
    """

    # ``arbitrary_types_allowed`` lets subclasses declare fields of types pydantic has no built-in support for
    # (the models in this notebook hold numpy arrays).
    model_config = ConfigDict(arbitrary_types_allowed=True)


class GrainModel(BaseDamageAnalysis):
    """Everything stored about a single grain: identifiers, grain statistics, image crop and trace data.

    Instances are built by ``load_grain_models_from_topo_files``; the comments on the fields describe what that
    loader puts into them.
    """

    # id within a GrainModelCollection; stays None until GrainModelCollection.add_grain assigns it
    global_grain_id: int | None = None
    # grain number of this grain within its source file (the "grain_number" value of the grain stats table)
    file_grain_id: int
    # image name of the source file (the "image" value of the grain stats table)
    filename: str
    # "pixel_to_nm_scaling" value read from the source file
    pixel_to_nm_scaling: float
    # sample type string of the folder the grain came from
    folder: str
    # value returned by get_dose_from_sample_type for the sample type
    percent_damage: float
    # padded, squared bounding box used to crop ``image`` and ``mask``; sliced as rows [0]:[2] and columns [1]:[3]
    bbox: tuple[int, int, int, int]
    # crop of the full image to ``bbox``
    image: npt.NDArray[np.float64]
    # the next four values are copied from the grain's row in the grain stats table
    aspect_ratio: float
    smallest_bounding_area: float
    total_contour_length: float
    num_crossings: int
    # per-molecule data keyed by integer molecule id; each entry holds "heights", "distances", "circular",
    # "curvature_data" (None when unavailable), "spline_coords" and "ordered_coords"
    molecule_data: dict[int, dict]
    # column / row difference between the padded bounding box and the grain's original bounding box; the loader
    # subtracts these from the stored spline, ordered-trace and node coordinates
    added_left: int
    added_top: int
    # padding that was requested when the bounding box was padded
    padding: int
    # crop of the file's "above" grain tensor to ``bbox`` (presumably boolean - the dtype is not checked here)
    mask: npt.NDArray[np.bool_]
    # coordinates of all nodes of the grain stacked into one array (empty when there is no nodestats data)
    node_coords: npt.NDArray[np.float64]
    # number of nodestats entries found for the grain (0 when there is no nodestats data)
    num_nodes: int


class GrainModelCollection(BaseDamageAnalysis):
    """Dict-like container of ``GrainModel`` objects keyed by a collection-wide ("global") grain id.

    Ids are handed out sequentially by ``add_grain`` and are never reused, so removing a grain leaves a gap
    in the id range.
    """

    grains: dict[int, GrainModel]
    current_global_grain_id: int = 0

    def __str__(self) -> str:
        """Summarise how many grains are held and which already-issued ids are no longer present."""
        omitted_ids = [
            candidate_id for candidate_id in range(self.current_global_grain_id) if candidate_id not in self.grains
        ]
        summary = f"GrainModelCollection with {len(self.grains)} grains, with {len(omitted_ids)} "
        return summary + f"omitted grains: {omitted_ids}"

    def __getitem__(self, key: int) -> GrainModel:
        """Return the grain stored under ``key``; raises ``KeyError`` when it is absent."""
        return self.grains[key]

    def __iter__(self) -> Generator[tuple[int, GrainModel], None, None]:
        """Iterate over ``(global_grain_id, grain)`` pairs."""
        return (pair for pair in self.grains.items())

    def __len__(self) -> int:
        """Return the number of grains currently held."""
        return len(self.grains)

    def __contains__(self, key: int) -> bool:
        """Return whether a grain is stored under ``key``."""
        return key in self.grains

    def items(self) -> Generator[tuple[int, GrainModel], None, None]:
        """Generate ``(global_grain_id, grain)`` pairs."""
        return (pair for pair in self.grains.items())

    def keys(self) -> Generator[int, None, None]:
        """Generate the global grain ids."""
        return (grain_id for grain_id in self.grains.keys())

    def values(self) -> Generator[GrainModel, None, None]:
        """Generate the stored grains."""
        return (grain for grain in self.grains.values())

    def get(self, key: int, default: GrainModel | None = None) -> GrainModel | None:
        """Return the grain stored under ``key``, or ``default`` when there is none."""
        return self.grains.get(key, default)

    def add_grain(self, grain: GrainModel) -> None:
        """Store ``grain`` under the next free global id and write that id onto the grain itself.

        Any id the grain carried before (for instance from another collection) is overwritten.
        """
        assigned_id = self.current_global_grain_id
        grain.global_grain_id = assigned_id
        self.grains[assigned_id] = grain
        self.current_global_grain_id = assigned_id + 1

    def remove_grain(self, global_grain_id: int) -> None:
        """Delete the grain stored under ``global_grain_id``; raises ``KeyError`` when it is absent."""
        if global_grain_id in self.grains:
            del self.grains[global_grain_id]
            return
        raise KeyError(f"grain with global id {global_grain_id} not found in collection, cannot remove")


def combine_grain_model_collections(collections: list[GrainModelCollection]) -> GrainModelCollection:
    """Merge several collections into one new collection with freshly numbered grains.

    The grains are visited collection by collection, in each collection's own order, and added to a new
    ``GrainModelCollection``. The grain objects themselves are reused, not copied, so ``add_grain``
    overwrites their ``global_grain_id`` with the id they get in the merged collection.
    """
    merged = GrainModelCollection(grains={})
    every_grain = (grain for source in collections for grain in source.values())
    for grain in every_grain:
        merged.add_grain(grain)
    return merged


def get_dose_from_sample_type(sample_type: str) -> float:
    """Extract the damage percentage encoded in a sample type name.

    Parameters
    ----------
    sample_type : str
        Name of the sample type, e.g. ``"10_percent_damage"`` or a name containing ``"control"``.

    Returns
    -------
    float
        ``0.0`` for any name containing "control" (case-insensitive), otherwise the number that directly
        precedes ``_percent_damage``.

    Raises
    ------
    ValueError
        If the name neither contains "control" nor a ``<number>_percent_damage`` part.
    """
    if "control" in sample_type.lower():
        return 0.0
    # Accept an optional decimal part: with a plain ``\d+`` a name such as "0.5_percent_damage" would be read
    # as 5.0. Matching ignores case, in line with the control check above.
    match = re.search(r"(\d+(?:\.\d+)?)_percent_damage", sample_type, flags=re.IGNORECASE)
    if match is None:
        raise ValueError(f"Could not extract dose from sample type: {sample_type}")
    return float(match.group(1))

In [None]:
# Load the files
def load_grain_models_from_topo_files(
    topo_files: list[Path],
    df_grain_stats: pd.DataFrame,
    bbox_padding: int,
    sample_type: str,
) -> GrainModelCollection:
    """Build a ``GrainModelCollection`` from ``.topostats`` files and their grain statistics rows.

    For every (basename, image) pair in ``df_grain_stats`` the matching loaded file is looked up by image name
    and one ``GrainModel`` is created per grain number listed for that image. Grains that are listed in the
    dataframe but missing from the file, and grains without any molecule data, are skipped with a warning.

    The spline, ordered-trace and node coordinate arrays taken from the loaded file data are shifted in place by
    the offset between the padded bounding box and the grain's original bounding box.

    Parameters
    ----------
    topo_files : list[Path]
        Paths of the ``.topostats`` files to load.
    df_grain_stats : pd.DataFrame
        Grain statistics; the columns ``basename``, ``image``, ``grain_number``, ``aspect_ratio``,
        ``total_contour_length``, ``num_crossings`` and ``smallest_bounding_area`` are read.
    bbox_padding : int
        Padding passed to ``pad_bounding_box_cutting_off_at_image_bounds`` for every grain crop.
    sample_type : str
        Sample type name; stored as the grain's ``folder`` and parsed for the damage percentage.

    Returns
    -------
    GrainModelCollection
        Collection holding one ``GrainModel`` per successfully loaded grain.

    Raises
    ------
    ValueError
        If the damage percentage cannot be parsed from ``sample_type``, or if an image has no rows in the
        dataframe.
    KeyError
        If an image named in the dataframe is not among the loaded files (or expected file keys are missing).
    AssertionError
        If a grain does not have exactly one row in the dataframe.
    """
    # The dose only depends on the sample type, so parse it once (and fail before loading any file).
    dose_percentage = get_dose_from_sample_type(sample_type)

    grain_model_collection = GrainModelCollection(grains={})
    loadscans = LoadScans(img_paths=topo_files, channel="dummy")
    loadscans.get_data()
    loadscans_img_dict = loadscans.img_dict

    # One (basename, image) pair per source file, so each file's grains are handled together.
    # ``drop_duplicates`` keeps first-appearance order, so grains are visited - and global ids handed out - in
    # the same order on every run. Iterating a ``set`` of strings is not stable between interpreter sessions.
    unique_dir_file_combinations: list[tuple] = list(
        df_grain_stats[["basename", "image"]].drop_duplicates().itertuples(index=False, name=None)
    )

    print(f"unique directory and file combinations: {unique_dir_file_combinations}")

    for basename, filename in unique_dir_file_combinations:
        print(f"extracting data for file {filename} in folder {basename}")
        # locate the rows of the dataframe that belong to this file
        df_grain_stats_image = df_grain_stats[
            (df_grain_stats["image"] == filename) & (df_grain_stats["basename"] == basename)
        ]
        if len(df_grain_stats_image) == 0:
            raise ValueError(
                f"could not find any rows in the grain stats dataframe for image {filename} and "
                f"basename {basename}. this should not happen, debug this!"
            )

        # fetch the loaded data of the corresponding image file
        try:
            file_data = loadscans_img_dict[filename]
        except KeyError as error:
            print(f"keys: {list(loadscans_img_dict.keys())}")
            raise KeyError(
                f"could not find file data for image {filename} in loaded scans. debug this!"
            ) from error

        full_image = file_data["image"]
        full_mask = file_data["grain_tensors"]["above"]
        pixel_to_nm_scaling = file_data["pixel_to_nm_scaling"]
        ordered_trace_data = file_data["ordered_traces"]["above"]

        # nodestats data is optional
        try:
            nodestats_data = file_data["nodestats"]["above"]["stats"]
        except KeyError:
            nodestats_data = None

        # grain indexes local to the file (0-N), once from the dataframe and once from the file itself
        grain_indexes_from_df = df_grain_stats_image["grain_number"].unique()
        grain_indexes_from_df_set = set(grain_indexes_from_df)
        grain_indexes_from_file = {int(grain_id.replace("grain_", "")) for grain_id in ordered_trace_data.keys()}

        if grain_indexes_from_df_set != grain_indexes_from_file:
            print(
                f"WARN: grain indexes from dataframe and file mismatch. extra indexes in dataframe: "
                f"{grain_indexes_from_df_set - grain_indexes_from_file} extra indexes in "
                f"file: {grain_indexes_from_file - grain_indexes_from_df_set}"
            )

        # The dataframe decides which grains are loaded.
        for grain_index in grain_indexes_from_df:
            grain_id_str = f"grain_{grain_index}"
            if grain_id_str not in ordered_trace_data:
                print(
                    f"WARN: grain id {grain_id_str} from dataframe not found in file data. THIS SHOULD NOT HAPPEN, DEBUG THIS!"
                )
                continue
            grain_ordered_trace_data = ordered_trace_data[grain_id_str]
            df_grain_stats_grain = df_grain_stats_image[df_grain_stats_image["grain_number"] == grain_index]
            # Both halves of the message must sit inside the parentheses; a second f-string on its own line
            # would be a separate statement and never reach the AssertionError.
            assert len(df_grain_stats_grain) == 1, (
                f"expected exactly one row in the grain stats dataframe for grain {grain_index} in "
                f"image {filename}, but found {len(df_grain_stats_grain)}. DEBUG THIS!"
            )

            aspect_ratio = df_grain_stats_grain["aspect_ratio"].values[0]
            total_contour_length = df_grain_stats_grain["total_contour_length"].values[0]
            num_crossings = df_grain_stats_grain["num_crossings"].values[0]
            smallest_bounding_area = df_grain_stats_grain["smallest_bounding_area"].values[0]

            # get the molecule data
            all_molecule_data = {}
            # Reset per grain so a grain without molecules cannot silently reuse the previous grain's box.
            bbox_padded = None
            added_left = 0
            added_top = 0
            for current_molecule_id_str, molecule_ordered_trace_data in grain_ordered_trace_data.items():
                print(f"-- processing molecule {current_molecule_id_str}")
                molecule_data = {}
                molecule_id = int(re.sub(r"mol_", "", current_molecule_id_str))
                ordered_coords = molecule_ordered_trace_data["ordered_coords"]
                molecule_data["heights"] = molecule_ordered_trace_data["heights"]
                molecule_data["distances"] = molecule_ordered_trace_data["distances"]
                molecule_data["circular"] = molecule_ordered_trace_data["mol_stats"]["circular"]
                bbox = molecule_ordered_trace_data["bbox"]

                splining_coords = file_data["splining"]["above"][grain_id_str][current_molecule_id_str]["spline_coords"]
                # curvature data is optional; any missing level of the lookup means "not available"
                try:
                    curvature_data_grains = file_data["grain_curvature_stats"]["above"]["grains"]
                    curvature_data_molecules = curvature_data_grains[grain_id_str]["molecules"]
                    molecule_data["curvature_data"] = curvature_data_molecules[current_molecule_id_str]
                except KeyError:
                    print(
                        f"could not find curvature data for grain {grain_id_str} "
                        f"molecule {current_molecule_id_str}, setting to None. file keys: {list(file_data.keys())}"
                    )
                    molecule_data["curvature_data"] = None

                # The box is recomputed per molecule and the last one is used for the grain crop below; this
                # assumes every molecule of a grain reports the same bbox.
                bbox_square = make_bounding_box_square(
                    crop_min_row=bbox[0],
                    crop_min_col=bbox[1],
                    crop_max_row=bbox[2],
                    crop_max_col=bbox[3],
                    image_shape=full_image.shape,
                )
                bbox_padded = pad_bounding_box_cutting_off_at_image_bounds(
                    crop_min_row=bbox_square[0],
                    crop_min_col=bbox_square[1],
                    crop_max_row=bbox_square[2],
                    crop_max_col=bbox_square[3],
                    image_shape=full_image.shape,
                    padding=bbox_padding,
                )
                added_left = bbox_padded[1] - bbox[1]
                added_top = bbox_padded[0] - bbox[0]

                # shift the spline coords (in place) to account for squaring and padding of the box
                splining_coords[:, 0] -= added_top
                splining_coords[:, 1] -= added_left
                molecule_data["spline_coords"] = splining_coords

                # shift the ordered coords (in place) the same way
                ordered_coords[:, 0] -= added_top
                ordered_coords[:, 1] -= added_left
                molecule_data["ordered_coords"] = ordered_coords

                all_molecule_data[molecule_id] = molecule_data

            if bbox_padded is None:
                # without a molecule there is no bounding box to crop the grain with
                print(f"WARN: grain {grain_id_str} in image {filename} has no molecule data, skipping it")
                continue

            mask = full_mask[
                bbox_padded[0] : bbox_padded[2],
                bbox_padded[1] : bbox_padded[3],
            ]

            image = full_image[
                bbox_padded[0] : bbox_padded[2],
                bbox_padded[1] : bbox_padded[3],
            ]

            all_node_coords = []
            num_nodes: int = 0
            if nodestats_data is not None:
                try:
                    grain_nodestats_data = nodestats_data[grain_id_str]
                    num_nodes = len(grain_nodestats_data)
                    for _node_index, node_data in grain_nodestats_data.items():
                        node_coords = node_data["coords"]
                        # shift the node coords (in place) to account for squaring and padding of the box
                        node_coords[:, 0] -= added_top
                        node_coords[:, 1] -= added_left
                        all_node_coords.extend(node_coords)
                except KeyError as e:
                    # A missing "grain_N" key just means the grain has no nodestats data. Any other missing
                    # key is still tolerated (best effort), but it is reported instead of silently swallowed.
                    if "grain_" not in str(e):
                        print(f"WARN: unexpected missing key {e} in nodestats data for {grain_id_str}")
            all_node_coords_array = np.array(all_node_coords)

            grain_model = GrainModel(
                file_grain_id=grain_index,
                filename=filename,
                pixel_to_nm_scaling=pixel_to_nm_scaling,
                folder=str(sample_type),
                percent_damage=dose_percentage,
                bbox=bbox_padded,
                image=image,
                mask=mask,
                aspect_ratio=aspect_ratio,
                total_contour_length=total_contour_length,
                num_crossings=num_crossings,
                molecule_data=all_molecule_data,
                added_left=added_left,
                added_top=added_top,
                padding=bbox_padding,
                node_coords=all_node_coords_array,
                num_nodes=num_nodes,
                smallest_bounding_area=smallest_bounding_area,
            )

            grain_model_collection.add_grain(grain_model)
    return grain_model_collection

In [None]:
# Extract the grains from the topostats files


def _sha256_of_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of a file, read in chunks so the whole file is never held in memory."""
    digest = sha256()
    with open(path, "rb") as file_handle:
        for chunk in iter(lambda: file_handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def construct_grains_collection_from_topostats_files(
    bbox_padding: int, force_reload: bool = False
) -> GrainModelCollection:
    """Build the grain collection for every folder in the grain stats table, using an on-disk cache.

    The notebook-level variables ``df_grain_stats``, ``dir_results`` and ``dir_processed_data`` are read.
    For every ``basename`` folder in ``df_grain_stats`` a hash over the contents of its ``.topostats`` files and
    ``bbox_padding`` is computed. If it equals the hash stored next to a previously pickled collection, that
    pickle is loaded; otherwise the folder is loaded from the ``.topostats`` files and the pickle and hash are
    (re)written under ``dir_results / "loaded_datasets" / <sample type>``.

    Parameters
    ----------
    bbox_padding : int
        Padding used for the grain crops; part of the cache key.
    force_reload : bool
        When True, ignore any cached data and load every folder from the ``.topostats`` files.

    Returns
    -------
    GrainModelCollection
        All grains of all folders, renumbered into one collection.

    Raises
    ------
    AssertionError
        If a folder or ``.topostats`` file referenced by the dataframe does not exist.
    """
    grain_model_collection = GrainModelCollection(grains={})
    dir_loaded_datasets = dir_results / "loaded_datasets"
    # Created on demand: on a first run there is no cache yet, so requiring the directory to exist would
    # make the very first load impossible.
    dir_loaded_datasets.mkdir(parents=True, exist_ok=True)

    # every topostats image file that remains in the grain stats dataframe, with the folder it belongs to
    unique_file_folder_combinations = df_grain_stats[["image", "basename"]].drop_duplicates()
    print(f"found {len(unique_file_folder_combinations)} unique topostats files in the grain stats dataframe")

    # Split the files by folder
    unique_file_folder_combinations_grouped = unique_file_folder_combinations.groupby("basename")
    print(unique_file_folder_combinations_grouped.groups.keys())

    for local_folder, group in unique_file_folder_combinations_grouped:
        sample_type = local_folder.replace("../all_data/", "").replace("../test_subset_data/", "")
        dir_loaded_sample_type = dir_loaded_datasets / sample_type
        print(f"checking folder {sample_type} with {len(group)} files")
        print("calculating hashes for topostats files...")

        # Every row of the group has ``basename == local_folder``, so the folder is resolved once from the
        # already prefix-stripped sample type. This also covers basenames with neither known prefix, which
        # previously left the directory variable unset or pointing at the previous file's folder.
        dir_topo_file = dir_processed_data / sample_type / "processed"
        assert dir_topo_file.exists(), f"could not find folder at {dir_topo_file}"

        file_paths_and_hashes_topostats: dict[Path, str] = {}
        for filename in group["image"]:
            file_topostats = dir_topo_file / f"{filename}.topostats"
            assert file_topostats.exists(), f"could not find topostats file at {file_topostats}"
            file_paths_and_hashes_topostats[file_topostats] = _sha256_of_file(file_topostats)

        # Hash for the folder: the sorted file hashes plus the padding, because the cached crops and
        # coordinates depend on the padding as well.
        # NOTE(review): the rows selected in ``df_grain_stats`` are not part of the key, so changing the
        # contour-length filter still needs ``force_reload=True``.
        combined_hash_string = "".join(sorted(file_paths_and_hashes_topostats.values()))
        combined_hash_string += f"|bbox_padding={bbox_padding}"
        folder_hash = sha256(combined_hash_string.encode()).hexdigest()

        # load the previous hash from text file if it exists
        previous_hash_file = dir_loaded_sample_type / "hash.txt"
        file_cached_data = dir_loaded_sample_type / "data.pkl"
        print(f"checking previous hash for folder {sample_type} at {previous_hash_file}")
        previous_hash = previous_hash_file.read_text().strip() if previous_hash_file.exists() else None

        print(
            f"folder {sample_type} hash calculated: {folder_hash}. previous hash: {previous_hash} force reload: {force_reload}"
        )

        hash_matches = previous_hash == folder_hash
        use_cached_data = hash_matches and not force_reload and file_cached_data.exists()

        grain_model_collection_folder: GrainModelCollection
        if use_cached_data:
            print(
                f"folder {sample_type} has not changed since last load, skipping loading it and using previous saved data"
            )
            # NOTE: unpickling can execute arbitrary code; only load cache files written by this notebook.
            with open(file_cached_data, "rb") as f:
                grain_model_collection_folder = pickle.load(f)
        else:
            if force_reload:
                print(f"forcing reload, loading the data from the topostats files for folder {sample_type}")
            elif hash_matches:
                # The hash is there but the pickle is gone: reload rather than silently leaving the folder
                # out of the result.
                print(
                    f"hash for folder {sample_type} matched previous, but could not locate "
                    f"saved data: {file_cached_data}. loading the data from the topostats files"
                )
            else:
                print(
                    f"folder {sample_type} hash has changed since last load, loading the data from the topostats files"
                )

            # subset of the dataframe for just this folder
            df_grain_stats_folder = df_grain_stats[df_grain_stats["basename"] == local_folder]
            grain_model_collection_folder = load_grain_models_from_topo_files(
                topo_files=list(file_paths_and_hashes_topostats.keys()),
                df_grain_stats=df_grain_stats_folder,
                bbox_padding=bbox_padding,
                sample_type=sample_type,
            )

            print(
                "loaded topostats file data into grain model collection. "
                "Saving loaded data to .pkl file and saving hash."
            )

            # Save the data first and the hash second, so a failed dump never leaves a hash behind that
            # claims the cache is valid.
            dir_loaded_sample_type.mkdir(parents=True, exist_ok=True)
            with open(file_cached_data, "wb") as f:
                pickle.dump(grain_model_collection_folder, f)
            print(f"saving hash for folder {sample_type} to {previous_hash_file}")
            previous_hash_file.write_text(folder_hash)

        # combine the grain model collection for this folder with the main grain model collection
        grain_model_collection = combine_grain_model_collections(
            [grain_model_collection, grain_model_collection_folder]
        )

    return grain_model_collection


# Build the collection of grains for every folder in the grain stats dataframe (cached folders are loaded from
# disk instead of being re-read), using bbox_padding=1, then print the collection's summary string.
grain_model_collection = construct_grains_collection_from_topostats_files(bbox_padding=1)
print(grain_model_collection)