In [None]:
import psutil
import pickle
import re
from pathlib import Path
from typing import Generator, Generic

from hashlib import sha256
import matplotlib.pyplot as plt
import matplotlib as mpl
from pydantic import BaseModel, ConfigDict, Field, computed_field
import numpy as np
import numpy.typing as npt
import pandas as pd
import seaborn as sns
from AFMReader.topostats import load_topostats
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

from topostats.damage.damage import (
    calculate_indirect_defect_gaps,
    get_defects_and_gaps_from_bool_array,
)
from topostats.io import LoadScans
from topostats.measure.curvature import discrete_angle_difference_per_nm_circular, total_turn_in_region_radians
from topostats.tracing.splining import resample_points_regular_interval
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box_cutting_off_at_image_bounds
from topostats.plottingfuncs import Colormap

# Shared display settings for AFM images: the grain `plot` methods unpack IMGPLOTARGS into
# `plt.imshow` so that every image is drawn with the same colour map and colour range.
VMIN, VMAX = -3, 4  # colour-scale limits (presumably heights in nm — TODO confirm units)
colormap = Colormap()
CMAP = colormap.get_cmap()
IMGPLOTARGS = dict(cmap=CMAP, vmin=VMIN, vmax=VMAX)

In [None]:
def clear_output():
    """Clear the output of the notebook cell that is currently executing.

    Thin wrapper around ``IPython.display.clear_output``. IPython is imported inside
    the function, so the import only happens when the helper is actually called.
    """
    import IPython.display as ipython_display

    ipython_display.clear_output()

# Loading

In [None]:
# --- Locate the input directories for this analysis ---
dir_base = Path("/Volumes/shared/pyne_group/Shared/AFM_Data/dna_damage/Cs137_irradiations")
dir_this_analysis = dir_base / "20260204-analysis-getting-back-into-the-project"
dir_processed_data = dir_this_analysis / "output"
# Fail fast, outermost directory first, if any expected input directory is missing.
for required_dir in (dir_base, dir_this_analysis, dir_processed_data):
    assert required_dir.exists()

# The results directory is the only one created here if it does not exist yet.
dir_results = dir_this_analysis / "analysis_results"
dir_results.mkdir(exist_ok=True)
assert dir_results.exists()

# Only the *paths* of the processed .topostats files are collected; no file is read here.
topo_files = [topo_path for topo_path in dir_processed_data.glob("*/**/*.topostats")]
print(f"found {len(topo_files)} topo files")

# Per-grain statistics table that sits alongside the processed files.
csv_grain_stats: Path = dir_processed_data / "grain_statistics.csv"
assert csv_grain_stats.exists(), f"could not find grain stats csv at {csv_grain_stats}"
df_grain_stats: pd.DataFrame = pd.read_csv(csv_grain_stats)
print(f"grain stats columns: {df_grain_stats.columns}")

# Rescale contour length to nanometres (presumably stored in metres — TODO confirm).
df_grain_stats["total_contour_length"] = df_grain_stats["total_contour_length"] / 1e-9

In [None]:
def plot_contour_length_distributions(df: pd.DataFrame, title: str) -> None:
    """Show the per-folder contour length distribution as a strip plot over a violin plot.

    Parameters
    ----------
    df : pd.DataFrame
        Grain statistics; must contain the columns ``basename`` and ``total_contour_length``.
    title : str
        Title of the figure.
    """
    sns.stripplot(data=df, x="basename", y="total_contour_length", s=2)
    sns.violinplot(data=df, x="basename", y="total_contour_length", inner=None)
    plt.xticks(rotation=90)
    plt.title(title)
    plt.show()


# plot contour length distributions before any filtering
plot_contour_length_distributions(df_grain_stats, "Contour length distributions")

# drop any rows with contour length less than a threshold (nm)
threshold_contour_length = 300

n_rows_before = len(df_grain_stats)
df_grain_stats = df_grain_stats[df_grain_stats["total_contour_length"] >= threshold_contour_length]
n_rows_after = len(df_grain_stats)
print(
    f"dropped {n_rows_before - n_rows_after} rows with contour length < {threshold_contour_length} nm. remaining rows: {n_rows_after}"
)

# same plot after filtering; the title says so, so the two figures can be told apart
plot_contour_length_distributions(
    df_grain_stats, f"Contour length distributions (>= {threshold_contour_length} nm)"
)

# Un-analysed data models (input data format)

In [None]:
# Models
class BaseDamageAnalysis(BaseModel):
    """Common pydantic base class for the data models defined in this notebook.

    ``arbitrary_types_allowed=True`` lets subclasses declare fields whose types
    pydantic has no built-in schema for (the models below use numpy arrays).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)


class UnanalysedMoleculeData(BaseDamageAnalysis):
    """Per-molecule trace data as read from a processed ``.topostats`` file.

    Instances are built by ``load_grain_models_from_topo_files``.
    """

    # integer parsed from the file's "mol_<n>" key
    molecule_id: int
    # "heights" entry of the molecule's ordered trace data
    heights: npt.NDArray[np.float64]
    # "distances" entry of the molecule's ordered trace data
    distances: npt.NDArray[np.float64]
    # "circular" flag from the molecule's "mol_stats"
    circular: bool
    # spline coordinates, shifted by the loader to account for the crop padding
    spline_coords: npt.NDArray[np.float64]
    # ordered trace coordinates, shifted by the loader to account for the crop padding
    ordered_coords: npt.NDArray[np.float64]
    # curvature statistics for this molecule, or None when the file has none
    # (the "curvatures" key is read when plotting)
    curvature_data: dict | None


class UnanalysedMoleculeDataCollection(BaseDamageAnalysis):
    """Dict-like container of molecules, keyed by molecule id.

    Iterating the collection directly yields ``(molecule_id, molecule)`` pairs, i.e.
    it behaves like ``dict.items()`` rather than like iterating a plain dict.
    """

    molecules: dict[int, UnanalysedMoleculeData]

    # ---- container protocol ----

    def __len__(self) -> int:
        """Number of molecules held."""
        return len(self.molecules)

    def __contains__(self, key: int) -> bool:
        """Whether a molecule with id ``key`` is present."""
        return key in self.molecules

    def __getitem__(self, key: int) -> UnanalysedMoleculeData:
        """Return the molecule with id ``key``; raises ``KeyError`` if it is absent."""
        return self.molecules[key]

    def __iter__(self) -> Generator[tuple[int, UnanalysedMoleculeData], None, None]:
        """Yield ``(molecule_id, molecule)`` pairs."""
        return self.items()

    # ---- dict-style accessors (each returns a fresh generator) ----

    def items(self) -> Generator[tuple[int, UnanalysedMoleculeData], None, None]:
        """Yield ``(molecule_id, molecule)`` pairs."""
        return (pair for pair in self.molecules.items())

    def keys(self) -> Generator[int, None, None]:
        """Yield the molecule ids."""
        return (molecule_id for molecule_id in self.molecules)

    def values(self) -> Generator[UnanalysedMoleculeData, None, None]:
        """Yield the molecules."""
        return (molecule for molecule in self.molecules.values())

    def get(self, key: int, default: UnanalysedMoleculeData | None = None) -> UnanalysedMoleculeData | None:
        """Return the molecule with id ``key``, or ``default`` when it is absent."""
        try:
            return self.molecules[key]
        except KeyError:
            return default

    # ---- mutation ----

    def add_molecule(self, molecule: UnanalysedMoleculeData) -> None:
        """Store ``molecule`` under its own ``molecule_id``, replacing any existing entry."""
        self.molecules[molecule.molecule_id] = molecule

    def remove_molecule(self, molecule_id: int) -> None:
        """Delete the molecule with the given id; raises ``KeyError`` if it is absent."""
        if molecule_id in self.molecules:
            del self.molecules[molecule_id]
            return
        raise KeyError(f"molecule with id {molecule_id} not found in collection, cannot remove")


class UnanalysedGrain(BaseDamageAnalysis):
    """One grain cropped out of a processed image, with its traces and summary statistics.

    Instances are built by ``load_grain_models_from_topo_files``.
    """

    # id within an UnanalysedGrainCollection; assigned by the collection's `add_grain`
    global_grain_id: int | None = None
    # grain number within its source file
    file_grain_id: int
    filename: str
    pixel_to_nm_scaling: float
    # the loader stores the sample type string here
    folder: str
    # value parsed from the sample type by `get_dose_from_sample_type`
    percent_damage: float
    # the squared and padded bounding box used to crop `image` and `mask`
    bbox: tuple[int, int, int, int]
    image: npt.NDArray[np.float64]
    # the next four values are copied from the grain statistics dataframe
    aspect_ratio: float
    smallest_bounding_area: float
    total_contour_length: float
    num_crossings: int
    molecule_data_collection: UnanalysedMoleculeDataCollection
    # column / row offset between the padded bounding box and the original bounding box
    added_left: int
    added_top: int
    # padding requested when the grain was cropped
    padding: int
    # crop of the grain tensor; `plot` displays index 1 of its last axis
    mask: npt.NDArray[np.bool_]
    # node coordinates from nodestats (empty when the grain has none)
    node_coords: npt.NDArray[np.float64]
    num_nodes: int

    def __str__(self) -> str:
        """Short one-line description: global id, percent damage and source file."""
        headline = f"GrainModel(global_grain_id={self.global_grain_id}), {self.percent_damage}% damage"
        return f"{headline}, from file {self.filename}."

    def plot(self, mask_alpha: float = 0.3) -> None:
        """Show the grain image with its mask overlaid at opacity ``mask_alpha``."""
        plt.imshow(self.image, **IMGPLOTARGS)
        overlay_style = {"alpha": mask_alpha, "cmap": "gray"}
        plt.imshow(self.mask[:, :, 1], **overlay_style)
        figure_title = f"grain {self.global_grain_id}, {self.percent_damage}% damage"
        plt.title(figure_title)
        plt.show()


class UnanalysedGrainCollection(BaseDamageAnalysis):
    """Dict-like container of grains, keyed by a collection-wide ("global") grain id.

    Ids are handed out sequentially by ``add_grain`` and are never reused, so removed
    grains leave gaps. Iterating the collection directly yields ``(global_grain_id, grain)``
    pairs, i.e. it behaves like ``dict.items()``.
    """

    grains: dict[int, UnanalysedGrain]
    # the id that the next call to `add_grain` will assign
    current_global_grain_id: int = 0

    def __str__(self) -> str:
        """Summary listing how many grains are held and which ids below the counter are absent."""
        omitted_ids = [
            candidate_id for candidate_id in range(self.current_global_grain_id) if candidate_id not in self.grains
        ]
        prefix = f"GrainModelCollection with {len(self.grains)} grains, with {len(omitted_ids)} "
        return prefix + f"omitted grains: {omitted_ids}"

    # ---- container protocol ----

    def __len__(self) -> int:
        """Number of grains held."""
        return len(self.grains)

    def __contains__(self, key: int) -> bool:
        """Whether a grain with global id ``key`` is present."""
        return key in self.grains

    def __getitem__(self, key: int) -> UnanalysedGrain:
        """Return the grain with global id ``key``; raises ``KeyError`` if it is absent."""
        return self.grains[key]

    def __iter__(self) -> Generator[tuple[int, UnanalysedGrain], None, None]:
        """Yield ``(global_grain_id, grain)`` pairs."""
        return self.items()

    # ---- dict-style accessors (each returns a fresh generator) ----

    def items(self) -> Generator[tuple[int, UnanalysedGrain], None, None]:
        """Yield ``(global_grain_id, grain)`` pairs."""
        return (pair for pair in self.grains.items())

    def keys(self) -> Generator[int, None, None]:
        """Yield the global grain ids."""
        return (grain_id for grain_id in self.grains)

    def values(self) -> Generator[UnanalysedGrain, None, None]:
        """Yield the grains."""
        return (grain for grain in self.grains.values())

    def get(self, key: int, default: UnanalysedGrain | None = None) -> UnanalysedGrain | None:
        """Return the grain with global id ``key``, or ``default`` when it is absent."""
        try:
            return self.grains[key]
        except KeyError:
            return default

    # ---- mutation ----

    def add_grain(self, grain: UnanalysedGrain) -> None:
        """Store ``grain`` under the next free global id.

        The grain's own ``global_grain_id`` is overwritten with the new id, even if it
        already carried one from another collection.
        """
        assigned_id = self.current_global_grain_id
        grain.global_grain_id = assigned_id
        self.grains[assigned_id] = grain
        self.current_global_grain_id = assigned_id + 1

    def remove_grain(self, global_grain_id: int) -> None:
        """Delete the grain with the given global id; raises ``KeyError`` if it is absent."""
        if global_grain_id in self.grains:
            del self.grains[global_grain_id]
            return
        raise KeyError(f"grain with global id {global_grain_id} not found in collection, cannot remove")


def combine_unanalysed_grain_collections(collections: list[UnanalysedGrainCollection]) -> UnanalysedGrainCollection:
    """Merge several grain collections into a single new collection.

    Grains are added in the order of ``collections`` and, within each, in its iteration
    order. The grain objects themselves are reused rather than copied, and ``add_grain``
    overwrites each grain's ``global_grain_id`` with its id in the merged collection.
    """
    merged = UnanalysedGrainCollection(grains={})
    every_grain = (grain for source in collections for grain in source.values())
    for grain in every_grain:
        merged.add_grain(grain)
    return merged


def get_dose_from_sample_type(sample_type: str) -> float:
    """Extract the percent-damage value encoded in a sample type name.

    Parameters
    ----------
    sample_type : str
        Sample type / folder name, e.g. ``"control"`` or ``"10_percent_damage"``.

    Returns
    -------
    float
        ``0.0`` if the name contains "control" (case-insensitive), otherwise the number
        directly preceding ``_percent_damage``. Decimal values such as
        ``"0.5_percent_damage"`` are supported.

    Raises
    ------
    ValueError
        If the name is not a control and contains no ``<number>_percent_damage`` part.
    """
    if "control" in sample_type.lower():
        return 0.0
    # Optional fractional part: with a bare `\d+`, "0.5_percent_damage" would parse as 5.0.
    # Case-insensitive to match the case-insensitive control check above.
    match = re.search(r"(\d+(?:\.\d+)?)_percent_damage", sample_type, flags=re.IGNORECASE)
    if match is None:
        raise ValueError(f"Could not extract dose from sample type: {sample_type}")
    return float(match.group(1))

In [None]:
# Load the files
def load_grain_models_from_topo_files(
    topo_files: list[Path],
    df_grain_stats: pd.DataFrame,
    bbox_padding: int,
    sample_type: str,
) -> UnanalysedGrainCollection:
    """Build an ``UnanalysedGrainCollection`` from processed ``.topostats`` files.

    For every (basename, image) pair in ``df_grain_stats`` the matching loaded file is
    looked up. For each grain listed for that image, the padded image / mask crop, the
    per-molecule trace data, the node coordinates and the summary statistics are gathered
    into an ``UnanalysedGrain``.

    Parameters
    ----------
    topo_files : list[Path]
        Paths of the ``.topostats`` files to load.
    df_grain_stats : pd.DataFrame
        Grain statistics restricted to the grains wanted. Reads the columns ``basename``,
        ``image``, ``grain_number``, ``aspect_ratio``, ``total_contour_length``,
        ``num_crossings`` and ``smallest_bounding_area``.
    bbox_padding : int
        Padding passed to ``pad_bounding_box_cutting_off_at_image_bounds`` when cropping.
    sample_type : str
        Sample type name; stored as the grain's ``folder`` and parsed for its percent damage.

    Returns
    -------
    UnanalysedGrainCollection
        One grain per dataframe row whose grain is present in the file and has at least
        one molecule.

    Raises
    ------
    ValueError
        If a (basename, image) pair unexpectedly selects no dataframe rows.
    KeyError
        If an image named in the dataframe is not among the loaded files.
    """
    grain_model_collection = UnanalysedGrainCollection(grains={})
    loadscans = LoadScans(img_paths=topo_files, channel="dummy")
    loadscans.get_data()
    loadscans_img_dict = loadscans.img_dict

    # Group the dataframe by image, since we will want to load all the grains from each image at once to avoid loading
    # the same image multiple times.
    # Sorted so grains are always visited, and therefore numbered, in the same order: iterating the
    # bare set would depend on Python's per-process string hash seed.
    unique_dir_file_combinations = sorted(
        set(zip(df_grain_stats["basename"], df_grain_stats["image"])),
        key=lambda combination: (str(combination[0]), str(combination[1])),
    )

    print(f"unique directory and file combinations: {unique_dir_file_combinations}")

    for basename, filename in unique_dir_file_combinations:
        print(f"extracting data for file {filename} in folder {basename}")
        # locate the corresponding rows in the dataframe
        df_grain_stats_image = df_grain_stats[
            (df_grain_stats["image"] == filename) & (df_grain_stats["basename"] == basename)
        ]
        if len(df_grain_stats_image) == 0:
            raise ValueError(
                f"could not find any rows in the grain stats dataframe for image {filename} and "
                f"basename {basename}. this should not happen, debug this!"
            )

        # load the corresponding image file
        try:
            file_data = loadscans_img_dict[filename]
        except KeyError as error:
            print(f"keys: {list(loadscans_img_dict.keys())}")
            raise KeyError(f"could not find file data for image {filename} in loaded scans. debug this!") from error

        full_image = file_data["image"]
        full_mask = file_data["grain_tensors"]["above"]
        pixel_to_nm_scaling = file_data["pixel_to_nm_scaling"]
        ordered_trace_data = file_data["ordered_traces"]["above"]

        # Grab nodestats data (optional: not every file has it)
        try:
            nodestats_data = file_data["nodestats"]["above"]["stats"]
        except KeyError:
            nodestats_data = None

        # grab individual grain data, based on each row of the dataframe

        # grab the grain indexes (local) ie, 0-N for the file
        grain_indexes_from_df = df_grain_stats_image["grain_number"].unique()
        grain_indexes_from_file = {int(grain_id.replace("grain_", "")) for grain_id in ordered_trace_data.keys()}

        if not set(grain_indexes_from_df) == grain_indexes_from_file:
            print(
                f"WARN: grain indexes from dataframe and file mismatch. extra indexes in dataframe: "
                f"{set(grain_indexes_from_df) - grain_indexes_from_file} extra indexes in "
                f"file: {grain_indexes_from_file - set(grain_indexes_from_df)}"
            )

        # The dataframe decides which grains are wanted; grains only present in the file are ignored.
        for grain_index in grain_indexes_from_df:
            grain_id_str = f"grain_{grain_index}"
            if grain_id_str not in ordered_trace_data:
                print(
                    f"WARN: grain id {grain_id_str} from dataframe not found in file data. THIS SHOULD NOT HAPPEN, DEBUG THIS!"
                )
                continue
            grain_ordered_trace_data = ordered_trace_data[grain_id_str]
            df_grain_stats_grain = df_grain_stats_image[df_grain_stats_image["grain_number"] == grain_index]
            # The whole message must sit inside one parenthesised expression, otherwise the
            # second f-string is a separate no-op statement and never shown.
            assert len(df_grain_stats_grain) == 1, (
                f"expected exactly one row in the grain stats dataframe for grain {grain_index} in "
                f"image {filename}, but found {len(df_grain_stats_grain)}. DEBUG THIS!"
            )

            dose_percentage = get_dose_from_sample_type(sample_type)
            aspect_ratio = df_grain_stats_grain["aspect_ratio"].values[0]
            total_contour_length = df_grain_stats_grain["total_contour_length"].values[0]
            num_crossings = df_grain_stats_grain["num_crossings"].values[0]
            smallest_bounding_area = df_grain_stats_grain["smallest_bounding_area"].values[0]

            # Reset per grain: these are only assigned inside the molecule loop, and must never
            # leak from the previous grain if this grain turns out to have no molecules.
            bbox_padded = None
            added_left = 0
            added_top = 0

            # get the molecule data
            molecule_data_collection = UnanalysedMoleculeDataCollection(molecules={})
            for current_molecule_id_str, molecule_ordered_trace_data in grain_ordered_trace_data.items():
                print(f"-- processing molecule {current_molecule_id_str}")
                molecule_id = int(re.sub(r"mol_", "", current_molecule_id_str))
                ordered_coords = molecule_ordered_trace_data["ordered_coords"]
                molecule_data_heights = molecule_ordered_trace_data["heights"]
                molecule_data_distances = molecule_ordered_trace_data["distances"]
                molecule_data_circular = molecule_ordered_trace_data["mol_stats"]["circular"]
                bbox = molecule_ordered_trace_data["bbox"]

                splining_coords = file_data["splining"]["above"][grain_id_str][current_molecule_id_str]["spline_coords"]
                try:
                    curvature_data_grains = file_data["grain_curvature_stats"]["above"]["grains"]
                    curvature_data_grain = curvature_data_grains[grain_id_str]
                    curvature_data_molecules = curvature_data_grain["molecules"]
                    molecule_data_curvature_data = curvature_data_molecules[current_molecule_id_str]
                except KeyError:
                    print(
                        f"could not find curvature data for grain {grain_id_str} "
                        f"molecule {current_molecule_id_str}, setting to None. file keys: {list(file_data.keys())}"
                    )
                    molecule_data_curvature_data = None

                # bbox will be the same for all molecules so this is okay
                bbox_square = make_bounding_box_square(
                    crop_min_row=bbox[0],
                    crop_min_col=bbox[1],
                    crop_max_row=bbox[2],
                    crop_max_col=bbox[3],
                    image_shape=full_image.shape,
                )
                bbox_padded = pad_bounding_box_cutting_off_at_image_bounds(
                    crop_min_row=bbox_square[0],
                    crop_min_col=bbox_square[1],
                    crop_max_row=bbox_square[2],
                    crop_max_col=bbox_square[3],
                    image_shape=full_image.shape,
                    padding=bbox_padding,
                )
                added_left = bbox_padded[1] - bbox[1]
                added_top = bbox_padded[0] - bbox[0]

                # Shift the coordinates to account for padding. Work on copies so that the arrays
                # held in the loaded file data are not modified (a second pass over the same file
                # data would otherwise shift them twice).
                molecule_data_spline_coords = np.array(splining_coords)
                molecule_data_spline_coords[:, 0] -= added_top
                molecule_data_spline_coords[:, 1] -= added_left

                molecule_data_ordered_coords = np.array(ordered_coords)
                molecule_data_ordered_coords[:, 0] -= added_top
                molecule_data_ordered_coords[:, 1] -= added_left

                molecule_data = UnanalysedMoleculeData(
                    molecule_id=molecule_id,
                    heights=molecule_data_heights,
                    distances=molecule_data_distances,
                    circular=molecule_data_circular,
                    curvature_data=molecule_data_curvature_data,
                    spline_coords=molecule_data_spline_coords,
                    ordered_coords=molecule_data_ordered_coords,
                )
                molecule_data_collection.add_molecule(molecule_data)

            if bbox_padded is None:
                # Without a molecule there is no bounding box, so the grain cannot be cropped.
                print(f"WARN: grain id {grain_id_str} in image {filename} has no molecules, skipping this grain")
                continue

            mask = full_mask[
                bbox_padded[0] : bbox_padded[2],
                bbox_padded[1] : bbox_padded[3],
            ]

            image = full_image[
                bbox_padded[0] : bbox_padded[2],
                bbox_padded[1] : bbox_padded[3],
            ]

            all_node_coords = []
            num_nodes: int = 0
            # A grain without nodestats is expected and simply keeps zero nodes.
            if nodestats_data is not None and grain_id_str in nodestats_data:
                grain_nodestats_data = nodestats_data[grain_id_str]
                num_nodes = len(grain_nodestats_data)
                for node_index, node_data in grain_nodestats_data.items():
                    try:
                        raw_node_coords = node_data["coords"]
                    except KeyError:
                        # Report instead of silently swallowing: a node without coords is unexpected.
                        print(
                            f"WARN: node {node_index} of grain {grain_id_str} in image {filename} "
                            f"has no 'coords' entry, skipping this node"
                        )
                        continue
                    # copy, then adjust the node coords to account for padding
                    node_coords = np.array(raw_node_coords)
                    node_coords[:, 0] -= added_top
                    node_coords[:, 1] -= added_left
                    all_node_coords.extend(node_coords)
            all_node_coords_array = np.array(all_node_coords)

            grain_model = UnanalysedGrain(
                file_grain_id=grain_index,
                filename=filename,
                pixel_to_nm_scaling=pixel_to_nm_scaling,
                folder=str(sample_type),
                percent_damage=dose_percentage,
                bbox=bbox_padded,
                image=image,
                mask=mask,
                aspect_ratio=aspect_ratio,
                total_contour_length=total_contour_length,
                num_crossings=num_crossings,
                molecule_data_collection=molecule_data_collection,
                added_left=added_left,
                added_top=added_top,
                padding=bbox_padding,
                node_coords=all_node_coords_array,
                num_nodes=num_nodes,
                smallest_bounding_area=smallest_bounding_area,
            )

            grain_model_collection.add_grain(grain_model)
    return grain_model_collection

In [None]:
# Extract the grains from the topostats files


def _sha256_of_file(file_path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the SHA-256 hex digest of a file, reading it in chunks to bound memory use."""
    digest = sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


def construct_grains_collection_from_topostats_files(
    bbox_padding: int, force_reload: bool = False
) -> UnanalysedGrainCollection:
    """Load the grains of every sample folder in ``df_grain_stats``, using an on-disk cache.

    For each folder (``basename``) a hash is computed from the contents of its ``.topostats``
    files and from ``bbox_padding``. If that hash equals the one stored next to a previously
    pickled collection, the pickle is loaded. Otherwise the folder is loaded from the
    ``.topostats`` files, and the result and its hash are saved for next time.

    Reads the module-level ``df_grain_stats``, ``dir_results`` and ``dir_processed_data``.

    Parameters
    ----------
    bbox_padding : int
        Padding used when cropping each grain; part of the cache key.
    force_reload : bool
        If True, ignore any cached data and always load from the ``.topostats`` files.

    Returns
    -------
    UnanalysedGrainCollection
        All grains of all folders, renumbered with consecutive global ids.

    Raises
    ------
    ValueError
        If a ``basename`` contains neither ``all_data`` nor ``test_subset_data``, so the
        location of its files cannot be reconstructed.
    """
    grain_model_collection = UnanalysedGrainCollection(grains={})
    dir_loaded_datasets = dir_results / "loaded_datasets"
    # Create the cache root if needed: nothing else creates it, so on a first run an
    # existence assertion here could never pass.
    dir_loaded_datasets.mkdir(parents=True, exist_ok=True)

    # iterate through each topostats image file that remains in the grain stats dataframe and extract the grain data
    # from the topostats files and store them.
    unique_file_folder_combinations = df_grain_stats[["image", "basename"]].drop_duplicates()
    print(f"found {len(unique_file_folder_combinations)} unique topostats files in the grain stats dataframe")

    # Split the files by folder
    unique_file_folder_combinations_grouped = unique_file_folder_combinations.groupby("basename")
    print(unique_file_folder_combinations_grouped.groups.keys())

    for local_folder, group in unique_file_folder_combinations_grouped:
        sample_type = local_folder.replace("../all_data/", "").replace("../test_subset_data/", "")
        dir_loaded_sample_type = dir_loaded_datasets / sample_type
        print(f"checking folder {sample_type} with {len(group)} files")
        print("calculating hashes for topostats files...")
        file_paths_and_hashes_topostats: dict[Path, str] = {}
        # construct the file paths for each unique combination of image and folder and check if it exists.
        for _, row in group.iterrows():
            filename = row["image"]
            basename = str(row["basename"])
            # reconstruct the path to the file using the basename, image name and structure of directories.
            if "all_data" not in basename and "test_subset_data" not in basename:
                # Fail loudly rather than reusing the directory of a previous row.
                raise ValueError(
                    f"cannot locate files for basename {basename}: expected it to contain "
                    f"'all_data' or 'test_subset_data'"
                )
            relative_folder = basename.replace("../all_data/", "").replace("../test_subset_data/", "")
            dir_topo_file = dir_processed_data / Path(relative_folder) / "processed"
            assert dir_topo_file.exists(), f"could not find folder at {dir_topo_file}"
            file_topostats = dir_topo_file / f"{filename}.topostats"
            assert file_topostats.exists(), f"could not find topostats file at {file_topostats}"

            # calculate the hash of the file
            file_paths_and_hashes_topostats[file_topostats] = _sha256_of_file(file_topostats)

        # Create the hash for the folder by combining the hashes of its files with the padding:
        # data cached with a different padding must not be reused.
        # NOTE(review): the grain selection (the rows of df_grain_stats) is not part of the key, so
        # changing the contour length threshold without changing the files reuses the old cache.
        combined_hash_string = "".join(sorted(file_paths_and_hashes_topostats.values()))
        combined_hash_string += f"|bbox_padding={bbox_padding}"
        folder_hash = sha256(combined_hash_string.encode()).hexdigest()

        # load the previous hash from text file if it exists
        previous_hash_file = dir_loaded_sample_type / "hash.txt"
        print(f"checking previous hash for folder {sample_type} at {previous_hash_file}")
        previous_hash = None
        if previous_hash_file.exists():
            previous_hash = previous_hash_file.read_text().strip()

        print(
            f"folder {sample_type} hash calculated: {folder_hash}. previous hash: {previous_hash} force reload: {force_reload}"
        )

        file_cached_data = dir_loaded_sample_type / "data.pkl"
        hash_matches = previous_hash == folder_hash
        use_cache = hash_matches and not force_reload and file_cached_data.exists()

        if use_cache:
            print(
                f"folder {sample_type} has not changed since last load, skipping loading it and using previous saved data"
            )
            # NOTE: unpickling can execute arbitrary code; only load cache files written by this notebook.
            with open(file_cached_data, "rb") as f:
                grain_model_collection_folder: UnanalysedGrainCollection = pickle.load(f)
        else:
            if force_reload:
                print(f"forcing reload, loading the data from the topostats files for folder {sample_type}")
            elif hash_matches:
                # The hash file survived but the data did not: reload instead of dropping the folder.
                print(
                    f"hash for folder {sample_type} matched previous, but could not locate "
                    f"saved data: {file_cached_data}. loading the data from the topostats files"
                )
            else:
                print(
                    f"folder {sample_type} hash has changed since last load, loading the data from the topostats files"
                )

            # calculate a subset of the dataframe for just this folder
            df_grain_stats_folder = df_grain_stats[df_grain_stats["basename"] == local_folder]
            grain_model_collection_folder = load_grain_models_from_topo_files(
                topo_files=list(file_paths_and_hashes_topostats.keys()),
                df_grain_stats=df_grain_stats_folder,
                bbox_padding=bbox_padding,
                sample_type=sample_type,
            )

            print(
                "loaded topostats file data into grain model collection. "
                "Saving loaded data to .pkl file and saving hash."
            )

            # After loading all the grains for the folder, save the model to a pickle and then the hash.
            # The hash is written last so that a failed dump never leaves a valid-looking cache.
            file_cached_data.parent.mkdir(parents=True, exist_ok=True)
            with open(file_cached_data, "wb") as f:
                pickle.dump(grain_model_collection_folder, f)
            print(f"saving hash for folder {sample_type} to {previous_hash_file}")
            with open(previous_hash_file, "w") as f:
                f.write(folder_hash)

        # combine the grain model collection for this folder with the main grain model collection
        grain_model_collection = combine_unanalysed_grain_collections(
            [grain_model_collection, grain_model_collection_folder]
        )

    return grain_model_collection


# Build the un-analysed grain collection for every sample folder left in `df_grain_stats`.
# NOTE(review): force_reload=True skips the on-disk cache, so every run re-reads all the
# .topostats files; pass False to reuse the pickled data when the folder hash is unchanged.
grain_collection = construct_grains_collection_from_topostats_files(bbox_padding=1, force_reload=True)
print(grain_collection)

# Analysis

## Analysed data models (analysis output format)

In [None]:
class Defect(BaseDamageAnalysis):
    """A defect region along a molecule trace.

    NOTE(review): the field meanings are inferred from their names — confirm against the
    code that creates defects.
    """

    # presumably indexes into the trace marking where the defect starts and ends
    start_index: int
    end_index: int
    length_nm: float
    position_along_trace_nm: float
    # pair of turn values in radians; what each of the two elements means is not visible here
    total_turn_radians: tuple[float, float]

    def __eq__(self, other: object) -> bool:
        """Compare with another ``Defect``: indexes exactly, float fields within a tight tolerance.

        Returns ``NotImplemented`` for any other type, so ``defect == gap`` evaluates to
        ``False`` instead of raising. This matters because defects and gaps are stored
        together in one list and compared element by element.
        """
        if not isinstance(other, Defect):
            return NotImplemented
        # bool(...) so callers get a plain Python bool rather than numpy.bool_
        return bool(
            self.start_index == other.start_index
            and self.end_index == other.end_index
            and np.isclose(self.length_nm, other.length_nm, rtol=1e-9, atol=1e-12)
            and np.isclose(self.position_along_trace_nm, other.position_along_trace_nm, rtol=1e-9, atol=1e-12)
            and np.isclose(self.total_turn_radians[0], other.total_turn_radians[0], rtol=1e-9, atol=1e-12)
            and np.isclose(self.total_turn_radians[1], other.total_turn_radians[1], rtol=1e-9, atol=1e-12)
        )


class DefectGap(BaseDamageAnalysis):
    """A gap region along a molecule trace; stored alongside ``Defect`` items in ``OrderedDefectsGaps``.

    NOTE(review): the field meanings are inferred from their names — confirm against the
    code that creates gaps.
    """

    # presumably indexes into the trace marking where the gap starts and ends
    start_index: int
    end_index: int
    length_nm: float
    position_along_trace_nm: float


class OrderedDefectsGaps(BaseDamageAnalysis):
    """A list of defects and gaps that is kept sorted by ``start_index``."""

    defect_gap_list: list[Defect | DefectGap] = Field(default_factory=list)

    # post-init to sort the list by start index
    def model_post_init(self, __context: dict | None = None) -> None:
        """Sort the list once the model has been constructed."""
        self.sort_defect_gap_list()

    def sort_defect_gap_list(self) -> None:
        """Sort the list in place by ``start_index``."""
        self.defect_gap_list.sort(key=lambda x: x.start_index)

    def add_item(self, item: Defect | DefectGap) -> None:
        """Append ``item`` and restore the sort order."""
        self.defect_gap_list.append(item)
        self.sort_defect_gap_list()

    def __eq__(self, other: object) -> bool:
        """Two instances are equal when their lists have equal items in the same order.

        Returns ``NotImplemented`` for other types, so ``==`` evaluates to ``False``
        instead of raising.
        """
        if not isinstance(other, OrderedDefectsGaps):
            return NotImplemented
        if len(self.defect_gap_list) != len(other.defect_gap_list):
            return False

        for item_self, item_other in zip(self.defect_gap_list, other.defect_gap_list):
            # Check the types first: a Defect and a DefectGap are never equal, and this avoids
            # asking a Defect to compare itself with a DefectGap.
            if type(item_self) is not type(item_other) or item_self != item_other:
                return False
        # Every pair matched. Without this return the method fell through and returned None,
        # so equal instances never compared equal.
        return True


class DefectData(BaseDamageAnalysis):
    """Defects and gaps of a grain, plus counts of each derived from the list."""

    ordered_defects_and_gaps: OrderedDefectsGaps

    @computed_field
    @property
    def num_defects(self) -> int:
        """Number of ``Defect`` entries in the ordered list."""
        entries = self.ordered_defects_and_gaps.defect_gap_list
        return len([entry for entry in entries if isinstance(entry, Defect)])

    @computed_field
    @property
    def num_gaps(self) -> int:
        """Number of ``DefectGap`` entries in the ordered list."""
        entries = self.ordered_defects_and_gaps.defect_gap_list
        return len([entry for entry in entries if isinstance(entry, DefectGap)])


class MoleculeData(UnanalysedMoleculeData):
    """Analysis-side counterpart of ``UnanalysedMoleculeData`` (currently the same fields)."""

    @classmethod
    def from_unanalysed_molecule_data(cls, unanalysed_data: UnanalysedMoleculeData) -> "MoleculeData":
        """Create an instance carrying over every field of ``unanalysed_data``.

        Declared as a classmethod so that it also works when called on an instance;
        as a plain function the instance itself would be passed as ``unanalysed_data``.
        The field values are passed on as they are, not copied.
        """
        return cls(
            molecule_id=unanalysed_data.molecule_id,
            heights=unanalysed_data.heights,
            distances=unanalysed_data.distances,
            circular=unanalysed_data.circular,
            spline_coords=unanalysed_data.spline_coords,
            ordered_coords=unanalysed_data.ordered_coords,
            curvature_data=unanalysed_data.curvature_data,
        )


class MoleculeDataCollection(UnanalysedMoleculeDataCollection):
    """Analysis-side counterpart of ``UnanalysedMoleculeDataCollection``, holding ``MoleculeData``."""

    @classmethod
    def from_unanalysed_molecule_data_collection(
        cls,
        unanalysed_collection: UnanalysedMoleculeDataCollection,
    ) -> "MoleculeDataCollection":
        """Create a collection with each molecule converted to ``MoleculeData``, keeping the same ids.

        Declared as a classmethod so that it also works when called on an instance.
        """
        molecule_data_dict = {
            molecule_id: MoleculeData.from_unanalysed_molecule_data(unanalysed_molecule_data)
            for molecule_id, unanalysed_molecule_data in unanalysed_collection.molecules.items()
        }
        return cls(molecules=molecule_data_dict)


class GrainModel(UnanalysedGrain):
    """An ``UnanalysedGrain`` extended with analysis results."""

    # None until set (presumably by the defect detection step — TODO confirm)
    defect_data: DefectData | None = None

    @classmethod
    def from_unanalysed_grain(cls, unanalysed_grain: UnanalysedGrain) -> "GrainModel":
        """Create an instance carrying over every ``UnanalysedGrain`` field.

        The molecules are converted to ``MoleculeData``; all other values are passed on as
        they are, not copied. ``defect_data`` is not carried over and starts as ``None``.
        Declared as a classmethod so that it also works when called on an instance.
        """
        # Create the new molecule data collection
        molecule_data_collection = MoleculeDataCollection.from_unanalysed_molecule_data_collection(
            unanalysed_grain.molecule_data_collection
        )
        return cls(
            global_grain_id=unanalysed_grain.global_grain_id,
            file_grain_id=unanalysed_grain.file_grain_id,
            filename=unanalysed_grain.filename,
            pixel_to_nm_scaling=unanalysed_grain.pixel_to_nm_scaling,
            folder=unanalysed_grain.folder,
            percent_damage=unanalysed_grain.percent_damage,
            bbox=unanalysed_grain.bbox,
            image=unanalysed_grain.image,
            aspect_ratio=unanalysed_grain.aspect_ratio,
            smallest_bounding_area=unanalysed_grain.smallest_bounding_area,
            total_contour_length=unanalysed_grain.total_contour_length,
            num_crossings=unanalysed_grain.num_crossings,
            molecule_data_collection=molecule_data_collection,
            added_left=unanalysed_grain.added_left,
            added_top=unanalysed_grain.added_top,
            padding=unanalysed_grain.padding,
            mask=unanalysed_grain.mask,
            node_coords=unanalysed_grain.node_coords,
            num_nodes=unanalysed_grain.num_nodes,
        )

    def plot(self, mask_alpha: float = 0.3, linemode: str = "") -> None:
        """Show the grain image with its mask overlaid, optionally with the molecule splines.

        Parameters
        ----------
        mask_alpha : float
            Opacity of the mask overlay.
        linemode : str
            ``"spline"`` draws each molecule's spline as a line. ``"curvature"`` draws each
            spline segment coloured by the curvature at its end point, clipped to
            [-0.1, 0.1]; molecules without curvature data are not drawn. Any other value
            draws only the image and mask.
        """
        plt.imshow(self.image, **IMGPLOTARGS)
        plt.imshow(self.mask[:, :, 1], alpha=mask_alpha, cmap="gray")
        if linemode == "spline":
            for molecule_data in self.molecule_data_collection.values():
                spline_coords = molecule_data.spline_coords
                # coordinates are (row, col), so column is x and row is y
                plt.plot(spline_coords[:, 1], spline_coords[:, 0])
        elif linemode == "curvature":
            for molecule_data in self.molecule_data_collection.values():
                spline_coords = molecule_data.spline_coords
                curvature_data = molecule_data.curvature_data
                if curvature_data is None:
                    continue
                curvature_values = curvature_data["curvatures"]
                # plot the curvature values as a colormap along the spline coords
                assert len(curvature_values) == len(spline_coords), (
                    f"length of curvature values {len(curvature_values)} does not match "
                    f"length of spline coords {len(spline_coords)}"
                )
                curvature_norm_bounds_lower = -0.1
                curvature_norm_bounds_upper = 0.1
                curvature_values_clipped = np.clip(
                    curvature_values, curvature_norm_bounds_lower, curvature_norm_bounds_upper
                )
                # rescale the clipped values to [0, 1] for the colormap lookup
                curvature_values_normalised = (curvature_values_clipped - curvature_norm_bounds_lower) / (
                    curvature_norm_bounds_upper - curvature_norm_bounds_lower
                )
                curvature_cmap = mpl.cm.coolwarm
                # one short line per segment, from the previous point to the current one
                for index in range(1, len(spline_coords)):
                    previous_point = spline_coords[index - 1]
                    point = spline_coords[index]
                    plt.plot(
                        [previous_point[1], point[1]],
                        [previous_point[0], point[0]],
                        color=curvature_cmap(curvature_values_normalised[index]),
                        linewidth=1,
                    )
        plt.show()


class GrainCollection(UnanalysedGrainCollection):
    """Grain collection used for the analysis steps in this notebook.

    Subclasses ``UnanalysedGrainCollection`` and adds an alternate constructor that converts every
    unanalysed grain into a ``GrainModel``.
    """

    @classmethod
    def from_unanalysed_grain_collection(
        cls,
        unanalysed_collection: UnanalysedGrainCollection,
    ) -> "GrainCollection":
        """Build an analysis collection from an unanalysed collection.

        Declared as a classmethod so that it behaves the same whether it is called on the class or on an
        instance (without a decorator, an instance call would bind the instance to ``unanalysed_collection``
        and fail with a ``TypeError``).

        Parameters
        ----------
        unanalysed_collection : UnanalysedGrainCollection
            Source collection; its ``grains`` mapping and ``current_global_grain_id`` are read.

        Returns
        -------
        GrainCollection
            New collection holding one ``GrainModel`` per source grain, keyed by the same global grain ids and
            carrying over ``current_global_grain_id``.
        """
        # Convert each grain via GrainModel.from_unanalysed_grain, keeping the original keys.
        grain_dict = {
            global_grain_id: GrainModel.from_unanalysed_grain(unanalysed_grain)
            for global_grain_id, unanalysed_grain in unanalysed_collection.grains.items()
        }
        return cls(grains=grain_dict, current_global_grain_id=unanalysed_collection.current_global_grain_id)

In [None]:
# Create current analysis models from the loaded models, which allows us to add more things in this notebook without
# having to re-load the data each time to initialise fresh versions of the unanalysed models.
# NOTE(review): this rebinds `grain_collection` to the converted collection, so re-running this cell passes the
# already-converted collection back into the converter - confirm that is harmless, or re-run from the loading cell.

grain_collection = GrainCollection.from_unanalysed_grain_collection(grain_collection)
print(grain_collection)

In [None]:
# defect detection
# Print the first grain yielded by the collection, as a quick look at a grain model before any filtering.
first_grain = next(iter(grain_collection.values()))
print(first_grain)


def find_curvature_defects(
    grain_model_collection: GrainCollection,
    curvature_defect_method: str,
    curvature_threshold_iqr_multiplier: float,
    curvature_threshold_absolute_pernm: float,
) -> set[int]:
    """Collect the ids of grains that cannot go through curvature defect detection.

    Only the ``"iqr"`` method has a code path, and that path currently just validates the data: a grain is
    flagged when any of its molecules has ``curvature_data`` set to ``None``. No thresholding is applied yet.

    Parameters
    ----------
    grain_model_collection : GrainCollection
        Collection whose ``items()`` yields ``(global_grain_id, grain_model)`` pairs.
    curvature_defect_method : str
        Detection method; any value other than ``"iqr"`` yields an empty result.
    curvature_threshold_iqr_multiplier : float
        Accepted but not used by the current implementation.
    curvature_threshold_absolute_pernm : float
        Accepted but not used by the current implementation.

    Returns
    -------
    set[int]
        Global ids of the flagged grains.
    """
    flagged_grain_ids = set()
    if curvature_defect_method != "iqr":
        # No other method is implemented here, so nothing is flagged.
        return flagged_grain_ids

    for global_grain_id, grain_model in grain_model_collection.items():
        for molecule_id, molecule_data in grain_model.molecule_data_collection.items():
            curvature_record = molecule_data.curvature_data
            if curvature_record is not None:
                # Sanity-check the stored curvatures; nothing further is done with them yet.
                curvatures = curvature_record["curvatures"]
                assert isinstance(curvatures, np.ndarray), (
                    f"expected curvatures to be a numpy array, but got {type(curvatures)}"
                )
                continue
            print(
                f"no curvature data for grain {global_grain_id} molecule {molecule_id}, skipping curvature defect detection for this molecule"
            )
            flagged_grain_ids.add(global_grain_id)

    return flagged_grain_ids


def find_defects_in_height_and_curvature(
    grain_model_collection: GrainCollection,
    height_defect_method: str,
    height_threshold_iqr_multiplier: float,
    height_threshold_absolute_nm: float,
    curvature_defect_method: str,
    curvature_threshold_iqr_multiplier: float,
    curvature_threshold_absolute_pernm: float,
) -> set[int]:
    """Gather the ids of grains flagged by defect detection.

    Only the curvature check is run at present: the curvature arguments are forwarded to
    ``find_curvature_defects``, while the three ``height_*`` arguments are accepted but not used.

    Parameters
    ----------
    grain_model_collection : GrainCollection
        Collection of grain models to inspect.
    height_defect_method : str
        Accepted but not used by the current implementation.
    height_threshold_iqr_multiplier : float
        Accepted but not used by the current implementation.
    height_threshold_absolute_nm : float
        Accepted but not used by the current implementation.
    curvature_defect_method : str
        Forwarded to ``find_curvature_defects``.
    curvature_threshold_iqr_multiplier : float
        Forwarded to ``find_curvature_defects``.
    curvature_threshold_absolute_pernm : float
        Forwarded to ``find_curvature_defects``.

    Returns
    -------
    set[int]
        A new set containing the global ids of all flagged grains.
    """
    curvature_flagged = find_curvature_defects(
        grain_model_collection=grain_model_collection,
        curvature_defect_method=curvature_defect_method,
        curvature_threshold_iqr_multiplier=curvature_threshold_iqr_multiplier,
        curvature_threshold_absolute_pernm=curvature_threshold_absolute_pernm,
    )

    # Copy into a fresh set so the returned object is independent of the helper's result.
    return set(curvature_flagged)


# Run defect detection with the current settings.
# NOTE: as defined above, find_defects_in_height_and_curvature only delegates to find_curvature_defects, so the
# height_* arguments are accepted but not used yet, and the "iqr" curvature branch only flags grains that have a
# molecule without curvature data (no thresholding is applied).
bad_grains = find_defects_in_height_and_curvature(
    grain_model_collection=grain_collection,
    height_defect_method="iqr",
    height_threshold_iqr_multiplier=1.5,
    height_threshold_absolute_nm=0.8,
    curvature_defect_method="iqr",
    curvature_threshold_iqr_multiplier=1.5,
    curvature_threshold_absolute_pernm=0.1,
)

# remove bad grains
# NOTE(review): presumably remove_grain mutates grain_collection in place (its return value is discarded) - confirm.
for bad_grain_id in bad_grains:
    grain_collection.remove_grain(bad_grain_id)

print(grain_collection)

# Plot one remaining grain with linemode="curvature" as a visual check.
# next(iter(...)) would raise StopIteration if the collection were empty after filtering.
sample_grain = next(iter(grain_collection.values()))
sample_grain.plot(mask_alpha=0.1, linemode="curvature")
print(sample_grain)