In [None]:
from pathlib import Path
import PIL
import matplotlib.pyplot as plt
import numpy as np
from ruamel.yaml import YAML
from skimage.morphology import label
import importlib.resources as pkg_resources

from topostats.io import save_topostats_file
from topostats.grains import Grains, GrainCrop, GrainCropsDirection, ImageGrainCrops
from topostats.processing import run_grainstats
from topostats.io import read_yaml
import topostats

# --- Data locations -------------------------------------------------------
# NOTE(review): machine-specific absolute path; point this at your own copy of
# the data before running the notebook.
base_dir = Path("/Users/sylvi/topo_data/perovskites-mountains")
assert base_dir.exists(), f"Base data directory not found: {base_dir}"
data_dir = base_dir / "freqsplit"
assert data_dir.exists(), f"Frequency-split data directory not found: {data_dir}"

testdir = data_dir / "cutoff_freq_nm-225" / "1st samples- just retrace - 1-2um pyramids" / "C60 data"
# Checked like the directories above so that a wrong path fails here instead of
# just producing an empty list of config files.
assert testdir.exists(), f"Sample directory not found: {testdir}"

# One "<name>_config.yaml" per sample. Sorted so the processing order does not
# depend on the order in which the filesystem lists the directory.
config_yaml_files = sorted(testdir.glob("*_config.yaml"))

yaml = YAML()

# --- TopoStats configuration ----------------------------------------------
# Locate the YAML files next to the imported topostats package rather than
# under "./topostats", so the notebook does not depend on the kernel's current
# working directory.
topostats_package_dir = Path(topostats.__file__).resolve().parent
config_topostats = read_yaml(topostats_package_dir / "default_config.yaml")
plotting_dictionary = topostats_package_dir / "plotting_dictionary.yaml"
config_topostats["plotting"]["plot_dict"] = read_yaml(plotting_dictionary)
config_plotting = config_topostats["plotting"]

In [None]:
# Build a .topostats file (and run grainstats) for each sample that has a
# "<name>_config.yaml" in `testdir`. Each sample is expected to provide:
#   <name>_original.npy, <name>_high_pass.npy, <name>_mask.npy, <name>_mask.yaml

# Loop-invariant setup: all outputs go next to the inputs, and a single YAML
# loader is reused for every metadata file.
file_directory = testdir
yaml = YAML()

for config_yaml_file in config_yaml_files:
    # The glob that produced this path guarantees the stem ends in "_config".
    filename = config_yaml_file.stem.removesuffix("_config")
    print(filename)

    # List every file belonging to this sample (purely informational).
    for filename_file in testdir.glob(f"{filename}*"):
        print(filename_file)

    raw_image_file = testdir / f"{filename}_original.npy"
    assert raw_image_file.exists(), f"Missing raw image: {raw_image_file}"
    processed_image_file = testdir / f"{filename}_high_pass.npy"
    assert processed_image_file.exists(), f"Missing processed image: {processed_image_file}"
    mask_file = testdir / f"{filename}_mask.npy"
    assert mask_file.exists(), f"Missing mask: {mask_file}"
    mask_yaml_file = testdir / f"{filename}_mask.yaml"
    assert mask_yaml_file.exists(), f"Missing mask metadata: {mask_yaml_file}"

    raw_image = np.load(raw_image_file)
    processed_image = np.load(processed_image_file)
    # The stored mask is inverted before use.
    # NOTE(review): presumably the file marks background rather than grains - confirm.
    mask = np.invert(np.load(mask_file).astype(bool))

    # Quick visual check of the 4-connected regions in the mask.
    fig, ax = plt.subplots()
    ax.imshow(label(mask, connectivity=1))
    ax.set_title(f"{filename}: labelled mask regions")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across samples

    # Two-channel mask tensor: channel 0 starts empty, channel 1 is the mask.
    grain_mask_tensor = np.stack(
        [np.zeros_like(mask), mask],
        axis=-1,
    )
    # NOTE(review): the return value is discarded, so this relies on
    # update_background_class modifying the tensor in place - confirm.
    Grains.update_background_class(grain_mask_tensor)

    # Load the YAML side-car files that are stored as metadata.
    with mask_yaml_file.open("r", encoding="utf-8") as f:
        mask_yaml = yaml.load(f)
    with config_yaml_file.open("r", encoding="utf-8") as f:
        config_yaml = yaml.load(f)

    # Contents of the .topostats file.
    topostats_dict = {
        "image_original": raw_image,
        "image": processed_image,
        "grain_tensors": {"above": grain_mask_tensor},
        "mask_metadata": mask_yaml,
        "image_metadata": config_yaml,
    }

    pixel_to_nm_scaling = config_yaml["pixel_to_nm_scaling"]

    graincrops = Grains.extract_grains_from_full_image_tensor(
        image=processed_image,
        full_mask_tensor=grain_mask_tensor,
        padding=10,
        pixel_to_nm_scaling=pixel_to_nm_scaling,
        filename=filename,
    )

    print(f"Extracted {len(graincrops)} grains")

    # Only the "above" direction is populated.
    imagegraincrops = ImageGrainCrops(
        above=GrainCropsDirection(
            crops=graincrops,
            full_mask_tensor=grain_mask_tensor,
        ),
        below=None,
    )

    # Run grainstats; results stay in the kernel and are not written to the
    # .topostats file by this cell.
    grainstats_df, height_profiles = run_grainstats(
        image_grain_crops=imagegraincrops,
        filename=filename,
        basename=file_directory,
        grainstats_config={"run": True},
        plotting_config=config_plotting,
        grain_out_path=file_directory,
    )

    save_topostats_file(output_dir=testdir, filename=f"{filename}.topostats", topostats_object=topostats_dict)

    # NOTE(review): only the first sample is processed; remove this `break` to
    # run over every config file.
    break