In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from AFMReader.topostats import load_topostats

import numpy.typing as npt
import numpy as np
from topostats.io import LoadScans
import matplotlib.pyplot as plt
from topostats.unet_masking import make_bounding_box_square, pad_bounding_box

In [None]:
# NOTE(review): machine-specific absolute path; point this at your own data directory
# before running the notebook on another machine.
base_dir = Path("/Users/sylvi/topo_data/picoz")
assert base_dir.exists(), f"base directory not found: {base_dir}"

processed_dir = base_dir / "output_trained_on"
assert processed_dir.exists(), f"processed directory not found: {processed_dir}"

# Collect every .topostats file below processed_dir (recursively). Nothing is loaded here;
# only the paths are gathered. The list is sorted because glob order depends on the
# filesystem, and any later "first N files" selection should be reproducible between runs.
topo_files = sorted(processed_dir.glob("**/*.topostats"))
print(f"found {len(topo_files)} topostats files")
# all_data = {}
# for topo_file in topo_files:
#     file_data = load_topostats(topo_file)
#     print(file_data.keys())

#     try:
#         all_data[topo_file.name] = {
#             "image": file_data["image"],
#             "grain_tensors": file_data["grain_tensors"],
#             "p2nm": file_data["pixel_to_nm_scaling"],
#             "curvature_stats": file_data["grain_curvature_stats"],
#             "splining": file_data["splining"],
#         }
#     except KeyError as e:
#         if "curvature_stats" in str(e):
#             pass
#         elif "grain_tensors" in str(e):
#             pass
#         else:
#             raise e

In [None]:
def construct_grains_dictionary(
    file_list: list, bbox_padding: int, stop_at_index: int | None = None, plot: bool = False
):
    grains_dictionary: dict[any] = {}

    loadscans = LoadScans(file_list, channel="dummy")
    loadscans.get_data()
    img_dict = loadscans.img_dict

    grain_index = 0
    for file_index, (filename, file_data) in enumerate(img_dict.items()):
        if stop_at_index is not None and file_index >= stop_at_index:
            break
        try:
            try:
                nodestats_data = file_data["nodestats"]["above"]["stats"]
            except KeyError:
                nodestats_data = None

            # print(f"getting data from {filename}")
            image = file_data["image"]
            ordered_trace_data = file_data["ordered_traces"]["above"]
            for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
                # print(f"  grain {current_grain_index}")
                grains_dictionary[grain_index] = {}
                grains_dictionary[grain_index]["molecule_data"] = {}

                for current_molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
                    molecule_data = {}
                    ordered_coords = molecule_ordered_trace_data["ordered_coords"]
                    molecule_data["heights"] = molecule_ordered_trace_data["heights"]
                    molecule_data["distances"] = molecule_ordered_trace_data["distances"]
                    bbox = molecule_ordered_trace_data["bbox"]
                    print(f"  grain {current_grain_index} molecule {current_molecule_index} bbox {bbox}")
                    grains_dictionary[grain_index]["molecule_data"][current_molecule_index] = molecule_data

                    splining_coords = file_data["splining"]["above"][current_grain_index][current_molecule_index][
                        "spline_coords"
                    ]

                    curvatures = file_data["grain_curvature_stats"]["above"]

                    # bbox will be same for all molecules so this is okay
                    bbox_square = make_bounding_box_square(bbox[0], bbox[1], bbox[2], bbox[3], image.shape)
                    bbox_padded = pad_bounding_box(
                        bbox_square[0],
                        bbox_square[1],
                        bbox_square[2],
                        bbox_square[3],
                        image.shape,
                        padding=bbox_padding,
                    )
                    added_left = bbox_padded[1] - bbox[1]
                    added_top = bbox_padded[0] - bbox[0]

                    # adjust the spline coords to account for the padding
                    splining_coords[:, 0] -= added_top
                    splining_coords[:, 1] -= added_left
                    molecule_data["spline_coords"] = splining_coords

                    # adjust the ordered coords to account for the padding
                    ordered_coords[:, 0] -= added_top
                    ordered_coords[:, 1] -= added_left
                    molecule_data["ordered_coords"] = ordered_coords

                image_crop = image[
                    bbox_padded[0] : bbox_padded[2],
                    bbox_padded[1] : bbox_padded[3],
                ]
                full_grain_mask = file_data["grain_tensors"]["above"]
                grains_dictionary[grain_index]["image"] = image_crop
                grains_dictionary[grain_index]["full_image"] = image
                grains_dictionary[grain_index]["bbox"] = bbox_padded
                grains_dictionary[grain_index]["added_left"] = added_left
                grains_dictionary[grain_index]["added_top"] = added_top
                grains_dictionary[grain_index]["padding"] = bbox_padding
                mask_crop = full_grain_mask[
                    bbox_padded[0] : bbox_padded[2],
                    bbox_padded[1] : bbox_padded[3],
                ]
                grains_dictionary[grain_index]["mask"] = mask_crop
                grains_dictionary[grain_index]["filename"] = file_data["filename"]
                grains_dictionary[grain_index]["pixel_to_nm_scaling"] = file_data["pixel_to_nm_scaling"]
                grains_dictionary[grain_index]["curvature_stats"] = curvatures

                # grab node coordinates
                all_node_coords = []
                if nodestats_data is not None:
                    try:
                        grain_nodestats_data = nodestats_data[current_grain_index]
                        for _node_index, node_data in grain_nodestats_data.items():
                            node_coords = node_data["node_coords"]
                            # adjust the node coords to account for the padding
                            node_coords[:, 0] -= added_top
                            node_coords[:, 1] -= added_left
                            for node_coord in node_coords:
                                all_node_coords.append(node_coord)
                    except KeyError as e:
                        if "grain_" in str(e):
                            # grain has no nodestats data here, skip
                            pass

                grains_dictionary[grain_index]["node_coords"] = np.array(all_node_coords)

                grain_index += 1
        except KeyError as e:
            if "ordered_traces" in str(e):
                print(f"no ordered traces found in {filename}")
                continue
            raise e

    if plot:
        for grain_index, grain_data in grains_dictionary.items():
            print(f"grain {grain_index}")
            print(grain_data["filename"])
            print(grain_data["pixel_to_nm_scaling"])
            image = grain_data["image"]
            plt.imshow(image)
            for molecule_index, molecule_data in grain_data["molecule_data"].items():
                ordered_coords = molecule_data["ordered_coords"]
                plt.plot(ordered_coords[:, 1], ordered_coords[:, 0], "r")
            all_node_coords = grain_data["node_coords"]
            if all_node_coords.size > 0:
                plt.plot(all_node_coords[:, 1], all_node_coords[:, 0], "b.")
            plt.show()

            mask = grain_data["mask"][:, :, 1]
            plt.imshow(mask)
            plt.show()

    print(f"found {len(grains_dictionary)} grains in {len(file_list)} images")

    return grains_dictionary


# Build the grain dictionary from the first two loaded images only, using a bbox
# padding of 20, and plot each grain as a visual sanity check.
grains_dictionary = construct_grains_dictionary(
    file_list=topo_files,
    bbox_padding=20,
    stop_at_index=2,
    plot=True,
)