In [None]:
import os
import pickle
from datetime import datetime
from pathlib import Path

# clear cell output command
from IPython.display import clear_output

from beaks import construct_grains_dictionary

# Minute-resolution timestamp of this run (e.g. "D-2024-01-31-T-09-05"); it is
# embedded in the name of the pickle file written at the end of the notebook.
timestamp_format = "D-%Y-%m-%d-T-%H-%M"
today = datetime.now().strftime(timestamp_format)
print(today)

In [None]:
# Root directory of the beaks data set. Defaults to the original local path; set the
# BEAKS_BASE_DIR environment variable to run the notebook on another machine.
base_dir = Path(os.environ.get("BEAKS_BASE_DIR") or "/Users/sylvi/topo_data/beaks")
assert base_dir.exists(), f"base directory not found: {base_dir}"
beak_topo_data_dir = base_dir / "output-beaks-topostats-unet-good"
assert beak_topo_data_dir.exists(), f"topostats output directory not found: {beak_topo_data_dir}"

hummingbird_dir = beak_topo_data_dir / "hummingbird/processed/"
assert hummingbird_dir.exists(), f"hummingbird directory not found: {hummingbird_dir}"
magpie_dir = beak_topo_data_dir / "magpie/processed/"
assert magpie_dir.exists(), f"magpie directory not found: {magpie_dir}"

# Grab the files ending in .topostats from both directories. They are materialised
# as lists so the per-directory names stay usable after being merged (a bare
# glob() generator would be exhausted by the merge).
hummingbird_files = list(hummingbird_dir.glob("*.topostats"))
magpie_files = list(magpie_dir.glob("*.topostats"))

# Merge: hummingbird files first, then magpie files.
# NOTE(review): glob() yields paths in whatever order the file system returns them.
# Grain indices are assigned in file-iteration order further down and a later cell
# selects grains by hard-coded index, so the ordering is deliberately left untouched
# here. Consider sorting these lists and re-validating those indices to make the
# notebook reproducible across machines.
all_files = hummingbird_files + magpie_files
print(f"Found {len(all_files)} files")

## construct grain dictionary

In [None]:
# Build the per-grain dictionary with the packaged helper (plotting disabled).
# NOTE(review): a later cell assigns `grains_dictionary` again from scratch, so the
# value produced here is replaced once that cell has run.
grains_dictionary = construct_grains_dictionary(file_list=all_files, plot=False)

# Drop whatever the helper wrote to the cell output and keep only the summary line.
clear_output()
n_grains_found = len(grains_dictionary)
n_files_used = len(all_files)
print(f"found {n_grains_found} grains in {n_files_used} files")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from topostats.io import LoadScans
from beaks import make_bounding_box_square, pad_bounding_box

# Build the per-grain dictionary directly from the loaded .topostats files and,
# optionally, plot every grain for visual inspection.
# NOTE(review): `grains_dictionary` is also assigned by the
# `construct_grains_dictionary(...)` call earlier in this notebook; the dictionary
# built here replaces that one.

plot = False  # set to True to plot every grain once the dictionary is built

# Running grain index (int, unique across all files) -> data of that grain.
grains_dictionary: dict[int, dict] = {}

# NOTE(review): channel="dummy" is presumably ignored for .topostats input — confirm
# against the topostats LoadScans documentation.
loadscans = LoadScans(all_files, channel="dummy")
loadscans.get_data()
img_dict = loadscans.img_dict

bbox_padding = 10  # forwarded as `padding` to pad_bounding_box
grain_index = 0  # next free key in grains_dictionary; keeps counting across files


def _collect_node_coords(nodestats_data, grain_key):
    """Gather every node coordinate of one grain into a single array.

    Parameters
    ----------
    nodestats_data : mapping or None
        Nodestats of one image keyed by grain, or None when the image has none.
    grain_key : hashable
        Key of the grain inside ``nodestats_data``.

    Returns
    -------
    np.ndarray
        ``np.array`` of all items found under ``"node_coords"`` of the grain's nodes.
        Empty when the image or the grain has no nodestats.
    """
    collected = []
    if nodestats_data is None:
        return np.array(collected)
    try:
        grain_nodestats = nodestats_data[grain_key]
    except KeyError:
        # Grain has no nodestats data here, nothing to collect.
        return np.array(collected)
    for node_key, node_data in grain_nodestats.items():
        try:
            collected.extend(node_data["node_coords"])
        except KeyError as err:
            # Report malformed nodes instead of silently dropping them.
            print(f"  warning: {node_key} of {grain_key} has no {err} entry, node skipped")
    return np.array(collected)


for filename, file_data in img_dict.items():
    print(f"image {filename}")

    # Nodestats are optional: images without them simply get empty node arrays.
    try:
        nodestats_data = file_data["nodestats"]["above"]["stats"]
    except KeyError:
        nodestats_data = None

    image = file_data["image"]

    # Images without ordered traces contribute no grains; any other missing key is
    # a real problem and is re-raised.
    try:
        ordered_trace_data = file_data["ordered_traces"]["above"]
    except KeyError as err:
        if "ordered_traces" in str(err):
            print(f"no ordered traces found in {filename}")
            continue
        raise

    for current_grain_index, grain_ordered_trace_data in ordered_trace_data.items():
        molecules = {}
        bbox = None
        for current_molecule_index, molecule_ordered_trace_data in grain_ordered_trace_data.items():
            # The bbox of the last molecule is the one used for the crop below.
            # NOTE(review): presumably all molecules of a grain share one bbox — confirm.
            bbox = molecule_ordered_trace_data["bbox"]
            molecules[current_molecule_index] = {
                "ordered_coords": molecule_ordered_trace_data["ordered_coords"],
                "heights": molecule_ordered_trace_data["heights"],
                "distances": molecule_ordered_trace_data["distances"],
                "spline_coords": file_data["splining"]["above"][current_grain_index][current_molecule_index][
                    "spline_coords"
                ],
            }

        if not molecules:
            # Without a molecule there is no bounding box to crop with, so the
            # grain is skipped and no index is consumed.
            print(f"  {current_grain_index} in {filename} has no molecules, skipping")
            continue

        bbox_square = make_bounding_box_square(bbox[0], bbox[1], bbox[2], bbox[3], image.shape)
        bbox_padded = pad_bounding_box(
            bbox_square[0], bbox_square[1], bbox_square[2], bbox_square[3], image.shape, padding=bbox_padding
        )
        crop_rows = slice(bbox_padded[0], bbox_padded[2])
        crop_cols = slice(bbox_padded[1], bbox_padded[3])
        full_grain_mask = file_data["grain_masks"]["above"]

        # The entry is assembled completely before it is registered, so an error
        # part-way through never leaves a half-filled grain in the dictionary.
        grains_dictionary[grain_index] = {
            "molecule_data": molecules,
            "image": image[crop_rows, crop_cols],
            "full_image": image,
            "bbox": bbox_padded,
            # Offsets of the padded box relative to the molecule's own bbox.
            "added_left": bbox_padded[1] - bbox[1],
            "added_top": bbox_padded[0] - bbox[0],
            "padding": bbox_padding,
            "mask": full_grain_mask[crop_rows, crop_cols],
            "filename": file_data["filename"],
            "pixel_to_nm_scaling": file_data["pixel_to_nm_scaling"],
            "node_coords": _collect_node_coords(nodestats_data, current_grain_index),
        }
        grain_index += 1

if plot:
    # A separate loop variable keeps the running `grain_index` counter intact.
    for plotted_grain_index, grain_data in grains_dictionary.items():
        print(f"grain {plotted_grain_index}")
        print(grain_data["filename"])
        print(grain_data["pixel_to_nm_scaling"])
        plt.imshow(grain_data["image"])
        # Traces in red; column 1 is plotted on x and column 0 on y.
        for molecule_data in grain_data["molecule_data"].values():
            ordered_coords = molecule_data["ordered_coords"]
            plt.plot(ordered_coords[:, 1], ordered_coords[:, 0], "r")
        # Nodes as blue dots (array is empty when the grain has no nodestats).
        all_node_coords = grain_data["node_coords"]
        if all_node_coords.size > 0:
            plt.plot(all_node_coords[:, 1], all_node_coords[:, 0], "b.")
        plt.show()

        # NOTE(review): index 1 on the last mask axis is presumably the grain class
        # channel — confirm against the mask layout.
        plt.imshow(grain_data["mask"][:, :, 1])
        plt.show()

# save the grains-with-beaks dictionary

In [33]:
# Keys of `grains_dictionary` for the grains that were hand-picked as having beaks.
# NOTE(review): these are running indices assigned in file-iteration order when the
# dictionary was built, so they are only meaningful for the same set and order of
# input files — re-check the selection whenever the inputs change.
grains_with_beaks = [7, 13, 16, 26, 28, 30, 34, 36, 37, 38, 43, 45, 54, 55, 58, 65, 66, 71, 72, 79, 82, 89, 91, 92, 93]

# Fail early, naming every missing index at once, before anything is written.
missing_grain_indices = [index for index in grains_with_beaks if index not in grains_dictionary]
if missing_grain_indices:
    raise KeyError(f"grain indices not present in grains_dictionary: {missing_grain_indices}")

grains_with_beaks_dictionary = {index: grains_dictionary[index] for index in grains_with_beaks}

# Save the grains with beaks dictionary; the file name carries the run timestamp.
# The output directory is created on demand so the results of a long run are not
# lost just because the folder is missing.
grains_with_beaks_save_dir = base_dir / "grain-dictionaries"
grains_with_beaks_save_dir.mkdir(parents=True, exist_ok=True)
with open(grains_with_beaks_save_dir / f"grains_with_beaks_{today}.pkl", "wb") as f:
    pickle.dump(grains_with_beaks_dictionary, f)