In [1]:
import psutil
import pickle
import re
from pathlib import Path
from typing import Generator

from hashlib import sha256
import matplotlib.pyplot as plt
from pydantic import BaseModel, ConfigDict
import numpy as np
import numpy.typing as npt
import pandas as pd
import seaborn as sns
from AFMReader.topostats import load_topostats
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

from topostats.damage.damage import (
    Defect,
    DefectGap,
    OrderedDefectGapList,
    calculate_indirect_defect_gaps,
    get_defects_and_gaps_from_bool_array,
)
from topostats.io import LoadScans
from topostats.measure.curvature import discrete_angle_difference_per_nm_circular, total_turn_in_region_radians
from topostats.tracing.splining import resample_points_regular_interval
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box_cutting_off_at_image_bounds

In [None]:
def clear_output():
    """Clear the output shown for the current notebook cell by delegating to IPython."""
    # Imported locally under an alias: this wrapper deliberately reuses the name ``clear_output``.
    import IPython.display as ipy_display

    ipy_display.clear_output()

In [None]:
# Set up the data directories. Every input directory must already exist (the asserts fail fast if
# the shared volume is not mounted); only the results directory is created on demand.
dir_base = Path("/Volumes/shared/pyne_group/Shared/AFM_Data/dna_damage/Cs137_irradiations")
assert dir_base.exists()
dir_this_analysis = dir_base / "20260204-analysis-getting-back-into-the-project"
assert dir_this_analysis.exists()
dir_processed_data = dir_this_analysis / "output"
assert dir_processed_data.exists()
dir_results = dir_this_analysis / "analysis_results"
dir_results.mkdir(exist_ok=True)
assert dir_results.exists()

# Collect the paths of the processed .topostats files. Only paths are gathered here; no file
# contents are read in this cell. The leading "*/" in the pattern means files sitting directly in
# dir_processed_data are not matched, only those at least one directory level below it.
topo_files = list(dir_processed_data.glob("*/**/*.topostats"))
print(f"found {len(topo_files)} topo files")

# Load the corresponding grain statistics csv file (one table for the whole processed-data folder)
csv_grain_stats = dir_processed_data / "grain_statistics.csv"
assert csv_grain_stats.exists(), f"could not find grain stats csv at {csv_grain_stats}"
df_grain_stats = pd.read_csv(csv_grain_stats)
print(f"grain stats columns: {df_grain_stats.columns}")

# Convert the contour length column to nanometres. Dividing by 1e-9 scales the values by 1e9.
# NOTE(review): this assumes the CSV stores total_contour_length in metres - confirm.
df_grain_stats["total_contour_length"] /= 1e-9

In [None]:
# plot contour length distributions (same figure recipe is used before and after filtering)
def _show_contour_length_distributions(grain_stats: pd.DataFrame) -> None:
    """Draw a strip plot and a violin plot of total contour length per basename, then show the figure."""
    shared_kwargs = {"data": grain_stats, "x": "basename", "y": "total_contour_length"}
    sns.stripplot(**shared_kwargs, s=2)
    sns.violinplot(**shared_kwargs, inner=None)
    plt.xticks(rotation=90)
    plt.title("Contour length distributions")
    plt.show()


_show_contour_length_distributions(df_grain_stats)

# keep only the rows whose contour length reaches the threshold
threshold_contour_length = 300

n_rows_before = len(df_grain_stats)
df_grain_stats = df_grain_stats.loc[df_grain_stats["total_contour_length"] >= threshold_contour_length]
n_rows_after = len(df_grain_stats)
print(
    f"dropped {n_rows_before - n_rows_after} rows with contour length < {threshold_contour_length} nm. remaining rows: {n_rows_after}"
)

_show_contour_length_distributions(df_grain_stats)

In [None]:
# Function to check ram usage of the notebook
def notebook_ram_usage():
    """Print the current process, its memory info, and its resident set size (rss / 1024**3)."""
    this_process = psutil.Process()
    print(f"process: {this_process}")
    memory = this_process.memory_info()
    print(f"memory info: {memory}")
    bytes_per_unit = 1024**3
    resident_size = memory.rss / bytes_per_unit
    print(f"RAM usage: {resident_size:.2f} GB")


notebook_ram_usage()

In [None]:
# Models
class BaseDamageAnalysis(BaseModel):
    """
    Shared pydantic base class for the data models defined in this notebook.

    It declares no fields of its own; it only sets the model configuration that subclasses inherit.
    """

    # arbitrary_types_allowed lets subclasses annotate fields with types pydantic has no built-in
    # validation for (the models in this notebook use numpy array annotations).
    model_config = ConfigDict(arbitrary_types_allowed=True)


class GrainModel(BaseDamageAnalysis):
    """
    Everything the analysis keeps about a single grain: identity, cropped image and mask, statistics and traces.

    Instances are built by ``load_grain_models_from_topo_files``; the field comments below describe what that
    loader stores in each field.
    """

    # Integer parsed from the "grain_<n>" key of the ordered-trace data.
    grain_id: int
    # Key of the loaded image dictionary this grain came from.
    filename: str
    # Scaling value read from the loaded file's "pixel_to_nm_scaling" entry.
    pixel_to_nm_scaling: float
    # str() of the sample_type passed to the loader.
    folder: str
    # Dose percentage derived from sample_type: 0 for "Controls", else the number in "<n>_percent_damage".
    percent_damage: float
    # Padded bounding box; the loader crops with [bbox[0]:bbox[2], bbox[1]:bbox[3]], i.e.
    # (min_row, min_col, max_row, max_col).
    bbox: tuple[int, int, int, int]
    # Crop of the full image at bbox.
    image: npt.NDArray[np.float64]
    # The next four values are copied from the grain's row in the grain statistics dataframe.
    aspect_ratio: float
    smallest_bounding_area: float
    total_contour_length: float
    num_crossings: int
    # Per-molecule data keyed by integer molecule id. The loader stores the keys "heights", "distances",
    # "circular", "curvatures", "spline_coords" and "ordered_coords".
    # NOTE(review): "circular" and "curvatures" are copied straight from the loaded file - confirm that
    # they really are arrays, as this annotation states.
    molecule_data: dict[int, dict[str, npt.NDArray[np.float64]]]
    # Column / row difference between the padded bbox and the molecule bbox
    # (padded min col - bbox min col, padded min row - bbox min row).
    added_left: int
    added_top: int
    # The bbox_padding value passed to the loader.
    padding: int
    # Crop of the "above" grain tensor at bbox.
    mask: npt.NDArray[np.bool_]
    # Coordinates of all nodes of this grain stacked into one array (empty when there is no node data),
    # shifted by added_top / added_left in the same way as the trace coordinates.
    node_coords: npt.NDArray[np.float64]
    # Number of entries in the grain's nodestats data (0 when there is none).
    num_nodes: int


class GrainModelCollection(BaseDamageAnalysis):
    """Dictionary-like container of ``GrainModel`` objects keyed by integer grain ID."""

    grains: dict[int, GrainModel]

    def __getitem__(self, key: int) -> GrainModel:
        """Return the grain stored under ``key`` (``KeyError`` if it is absent)."""
        grain = self.grains[key]
        return grain

    def __iter__(self) -> Generator[tuple[int, GrainModel], None, None]:
        """Iterate over ``(grain_id, grain)`` pairs - note: pairs, not bare keys."""
        return self.items()

    def __len__(self) -> int:
        """Return the number of grains held."""
        n_grains = len(self.grains)
        return n_grains

    def __contains__(self, key: int) -> bool:
        """Return whether a grain with ID ``key`` is held."""
        return key in self.grains

    def items(self) -> Generator[tuple[int, GrainModel], None, None]:
        """Return a generator over ``(grain_id, grain)`` pairs."""
        return ((grain_id, grain) for grain_id, grain in self.grains.items())

    def keys(self) -> Generator[int, None, None]:
        """Return a generator over the grain IDs."""
        return (grain_id for grain_id in self.grains)

    def values(self) -> Generator[GrainModel, None, None]:
        """Return a generator over the grains."""
        return (grain for grain in self.grains.values())

    def get(self, key: int, default: GrainModel | None = None) -> GrainModel | None:
        """Return the grain stored under ``key``, or ``default`` when there is none."""
        return self.grains.get(key, default)

    def add_grain(self, grain: GrainModel) -> None:
        """Store ``grain`` under its own ``grain_id``; raise ``ValueError`` if that ID is already taken."""
        already_present = grain.grain_id in self.grains
        if already_present:
            raise ValueError(f"Grain with ID {grain.grain_id} already exists in the collection.")
        self.grains[grain.grain_id] = grain

    def remove_grain(self, grain_id: int) -> None:
        """Delete the grain with ID ``grain_id``; raise ``KeyError`` if there is no such grain."""
        if grain_id in self.grains:
            del self.grains[grain_id]
            return
        raise KeyError(f"Grain with ID {grain_id} does not exist in the collection.")


def combine_grain_model_collections(collections: list[GrainModelCollection]) -> GrainModelCollection:
    """
    Merge several collections into one new collection.

    The grain objects themselves are shared with the inputs, not copied. A ``ValueError`` is raised as soon
    as a grain ID is met that is already present in the merged result.
    """
    merged = GrainModelCollection(grains={})
    for source in collections:
        for grain_id, grain in source.items():
            if grain_id not in merged:
                merged.add_grain(grain)
                continue
            raise ValueError(f"Duplicate grain ID {grain_id} found in multiple collections.")
    return merged

In [None]:
# Load the files
def _parse_dose_percentage(sample_type: str) -> int:
    """
    Extract the irradiation dose percentage encoded in a sample-type / folder string.

    Parameters
    ----------
    sample_type : str
        Sample label or folder path. Anything containing ``"Controls"`` counts as 0 % dose; otherwise the
        string must contain ``"<digits>_percent_damage"``.

    Returns
    -------
    int
        The dose percentage.

    Raises
    ------
    ValueError
        If the string neither contains ``"Controls"`` nor matches ``"<digits>_percent_damage"``.
    """
    # Normalise once so path-like objects work too (the substring test below needs a real ``str``).
    sample_type_str = str(sample_type)
    if "Controls" in sample_type_str:
        return 0
    dose_percentage_match = re.search(r"(\d+)_percent_damage", sample_type_str)
    if dose_percentage_match is None:
        raise ValueError(f"could not extract dose percentage from folder path {sample_type}")
    return int(dose_percentage_match.group(1))


def _collect_grain_node_coords(
    nodestats_data: dict | None,
    grain_key: str,
    added_top: int,
    added_left: int,
) -> tuple[np.ndarray, int]:
    """
    Gather the node coordinates of one grain, shifted into the padded-crop frame.

    Parameters
    ----------
    nodestats_data : dict | None
        Per-grain nodestats dictionary of one file, or ``None`` when the file has no nodestats.
    grain_key : str
        Key of the grain (``"grain_<n>"``).
    added_top : int
        Row offset subtracted from column 0 of every node's coordinates.
    added_left : int
        Column offset subtracted from column 1 of every node's coordinates.

    Returns
    -------
    tuple[np.ndarray, int]
        All node coordinates of the grain stacked row-wise and the number of nodes. When the file or the
        grain has no nodestats the result is ``(np.array([]), 0)``.

    Raises
    ------
    KeyError
        If a node entry lacks ``"coords"``. Only a *missing grain* is treated as "no data"; any other
        missing key points at malformed input and is no longer swallowed silently.
    """
    if nodestats_data is None or grain_key not in nodestats_data:
        # No nodestats for this file / grain: that is expected and simply means "no nodes".
        return np.array([]), 0

    grain_nodestats_data = nodestats_data[grain_key]
    all_node_coords = []
    for node_data in grain_nodestats_data.values():
        node_coords = node_data["coords"]
        # Shift in place, exactly like the trace coordinates (the loaded data is local to the loader call).
        node_coords[:, 0] -= added_top
        node_coords[:, 1] -= added_left
        all_node_coords.extend(node_coords)
    return np.array(all_node_coords), len(grain_nodestats_data)


def load_grain_models_from_topo_files(
    topo_files: list[Path],
    df_grain_stats: pd.DataFrame,
    bbox_padding: int,
    sample_type: str,
) -> GrainModelCollection:
    """
    Load ``.topostats`` files and build one ``GrainModel`` per grain.

    For every grain in each file's "above" ordered traces, the grain's statistics are looked up in
    ``df_grain_stats`` (matching the ``"image"`` and ``"grain_number"`` columns), the molecule traces are
    shifted into the frame of a squared + padded bounding box, and the image and mask are cropped to it.

    Parameters
    ----------
    topo_files : list[Path]
        Paths of the ``.topostats`` files to load.
    df_grain_stats : pd.DataFrame
        Grain statistics table. Must provide the columns ``image``, ``grain_number``,
        ``smallest_bounding_area``, ``aspect_ratio``, ``total_contour_length`` and ``num_crossings``.
    bbox_padding : int
        Padding passed to ``pad_bounding_box_cutting_off_at_image_bounds``.
    sample_type : str
        Sample label / folder path the dose percentage is parsed from (``"Controls"`` means 0 %).

    Returns
    -------
    GrainModelCollection
        The grains keyed by their grain number. Grains that have no row in ``df_grain_stats`` (for example
        because the table was filtered beforehand) or that have no molecule traces are skipped; the number
        of skipped grains is printed.

    Raises
    ------
    ValueError
        If no dose percentage can be parsed from ``sample_type`` (raised before any file is read), or if
        ``add_grain`` meets a grain ID that is already in the collection.
    KeyError
        If a loaded file lacks one of the sections read here. Only missing nodestats are tolerated.

    Notes
    -----
    Coordinate arrays taken from the loaded files are shifted in place.
    NOTE(review): grains are keyed by the number parsed from ``"grain_<n>"`` only. If those numbers restart
    in every file, loading several files in one call will raise the duplicate-ID ``ValueError`` - confirm.
    """
    # The dose depends only on sample_type, so resolve it once and fail before the expensive load.
    dose_percentage = _parse_dose_percentage(sample_type)

    grain_model_collection = GrainModelCollection(grains={})
    loadscans = LoadScans(img_paths=topo_files, channel="dummy")
    loadscans.get_data()

    n_skipped_no_stats = 0
    n_skipped_no_molecules = 0
    for filename, file_data in loadscans.img_dict.items():
        df_grain_stats_image = df_grain_stats[df_grain_stats["image"] == filename]
        full_image = file_data["image"]
        full_mask = file_data["grain_tensors"]["above"]
        pixel_to_nm_scaling = file_data["pixel_to_nm_scaling"]
        ordered_trace_data = file_data["ordered_traces"]["above"]
        try:
            nodestats_data = file_data["nodestats"]["above"]["stats"]
        except KeyError:
            # Files without nodestats are fine; their grains simply get no node coordinates.
            nodestats_data = None

        for current_grain_id_str, grain_ordered_trace_data in ordered_trace_data.items():
            grain_id = int(re.sub(r"grain_", "", current_grain_id_str))
            df_grain_stats_grain = df_grain_stats_image[df_grain_stats_image["grain_number"] == grain_id]
            if df_grain_stats_grain.empty:
                # No statistics row (e.g. removed by an upstream filter): indexing ``.values[0]`` would
                # raise a bare IndexError, so leave the grain out instead.
                n_skipped_no_stats += 1
                continue
            if len(grain_ordered_trace_data) == 0:
                # Without a molecule there is no bounding box; carrying on would reuse the previous
                # grain's box and offsets (or fail on the very first grain).
                n_skipped_no_molecules += 1
                continue

            smallest_bounding_area = df_grain_stats_grain["smallest_bounding_area"].values[0]
            aspect_ratio = df_grain_stats_grain["aspect_ratio"].values[0]
            total_contour_length = df_grain_stats_grain["total_contour_length"].values[0]
            num_crossings = df_grain_stats_grain["num_crossings"].values[0]

            all_molecule_data = {}
            for current_molecule_id_str, molecule_ordered_trace_data in grain_ordered_trace_data.items():
                molecule_data = {}
                molecule_id = int(re.sub(r"mol_", "", current_molecule_id_str))
                ordered_coords = molecule_ordered_trace_data["ordered_coords"]
                molecule_data["heights"] = molecule_ordered_trace_data["heights"]
                molecule_data["distances"] = molecule_ordered_trace_data["distances"]
                molecule_data["circular"] = molecule_ordered_trace_data["mol_stats"]["circular"]
                bbox = molecule_ordered_trace_data["bbox"]

                splining_coords = file_data["splining"]["above"][current_grain_id_str][current_molecule_id_str][
                    "spline_coords"
                ]
                curvatures = file_data["grain_curvature_stats"]["above"][current_grain_id_str][current_molecule_id_str]
                molecule_data["curvatures"] = curvatures

                # bbox will be the same for all molecules so this is okay
                bbox_square = make_bounding_box_square(
                    crop_min_row=bbox[0], crop_min_col=bbox[1], crop_max_row=bbox[2], crop_max_col=bbox[3]
                )
                bbox_padded = pad_bounding_box_cutting_off_at_image_bounds(
                    crop_min_row=bbox_square[0],
                    crop_min_col=bbox_square[1],
                    crop_max_row=bbox_square[2],
                    crop_max_col=bbox_square[3],
                    image_shape=full_image.shape,
                    padding=bbox_padding,
                )
                added_left = bbox_padded[1] - bbox[1]
                added_top = bbox_padded[0] - bbox[0]

                # adjust the splining coords to account for padding
                splining_coords[:, 0] -= added_top
                splining_coords[:, 1] -= added_left
                molecule_data["spline_coords"] = splining_coords

                # adjust the ordered coords to account for padding
                ordered_coords[:, 0] -= added_top
                ordered_coords[:, 1] -= added_left
                molecule_data["ordered_coords"] = ordered_coords

                all_molecule_data[molecule_id] = molecule_data

            # Copy the crops: a plain slice of a numpy array is a view that keeps the whole full-size
            # array alive for as long as the grain exists.
            mask = full_mask[
                bbox_padded[0] : bbox_padded[2],
                bbox_padded[1] : bbox_padded[3],
            ].copy()
            image = full_image[
                bbox_padded[0] : bbox_padded[2],
                bbox_padded[1] : bbox_padded[3],
            ].copy()

            all_node_coords_array, num_nodes = _collect_grain_node_coords(
                nodestats_data=nodestats_data,
                grain_key=current_grain_id_str,
                added_top=added_top,
                added_left=added_left,
            )

            grain_model = GrainModel(
                grain_id=grain_id,
                filename=filename,
                pixel_to_nm_scaling=pixel_to_nm_scaling,
                folder=str(sample_type),
                percent_damage=dose_percentage,
                bbox=bbox_padded,
                image=image,
                mask=mask,
                aspect_ratio=aspect_ratio,
                total_contour_length=total_contour_length,
                num_crossings=num_crossings,
                molecule_data=all_molecule_data,
                added_left=added_left,
                added_top=added_top,
                padding=bbox_padding,
                node_coords=all_node_coords_array,
                num_nodes=num_nodes,
                smallest_bounding_area=smallest_bounding_area,
            )

            grain_model_collection.add_grain(grain_model)

    if n_skipped_no_stats or n_skipped_no_molecules:
        # Make the skips visible so that an unexpectedly small collection can be traced back.
        print(
            f"skipped {n_skipped_no_stats} grains without a grain statistics row "
            f"and {n_skipped_no_molecules} grains without molecule traces"
        )
    return grain_model_collection