In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import re
import shutil

In [None]:
# Input and output locations for the mask renaming/combining step.
base_dir = Path("../data-preprocessing/")
assert base_dir.exists(), f"Base directory not found: {base_dir}"
dir_labelled_raw = base_dir / "raw-masks"
assert dir_labelled_raw.exists(), f"Raw mask directory not found: {dir_labelled_raw}"
dir_labelled_renamed = base_dir / "masks-renamed"
dir_labelled_renamed.mkdir(exist_ok=True)

# Collect the raw files. Sorted so the listing is deterministic across runs
# (glob order depends on the filesystem).
files_npy = sorted(dir_labelled_raw.glob("*.npy"))
# `files_png` was referenced below without ever being defined, which raised a
# NameError; gather the PNG files the same way as the .npy files.
files_png = sorted(dir_labelled_raw.glob("*.png"))

files = files_npy + files_png
print(f"Found {len(files)} files")

# Raw files are named "task-<N>-<rest of name>.npy", e.g. "task-54-restofname.npy".
# Every task is written out as a single "mask-<K>.npy" (plus a PNG rendering),
# where K is the task number re-based so the smallest task number becomes 0.
# A task can have several raw masks (one or more per class), so they are
# merged into one multi-class mask.

# Integer value written into the combined mask for each class. Classes are
# painted in this order, so a later class overwrites an earlier one wherever
# their masks overlap.
class_labels = {
    "grain": 1,
    "damage": 2,
}
# Classes that are allowed to have no mask for a task (only a warning is printed).
allowed_missing_classes = ["damage"]
# Shape of every combined mask; loaded masks are used to index into it.
shape = (512, 512)

task_number_regex = re.compile(r"task-(\d+)")


def _task_number_of(path):
    """Return the integer task number embedded in the file name of ``path``.

    Only the file name is inspected, so a "task-<digits>" fragment in a parent
    directory name cannot be picked up by accident.

    Raises:
        ValueError: if the file name contains no "task-<digits>" component.
    """
    match = task_number_regex.search(Path(path).name)
    if match is None:
        raise ValueError(f"Cannot extract a task number from file name: {path}")
    return int(match.group(1))


unique_task_numbers = {_task_number_of(file) for file in files}
print(unique_task_numbers)
if not unique_task_numbers:
    raise ValueError(f"No task files found in {dir_labelled_raw}.")
minimum_task_number = min(unique_task_numbers)

# Sorted so tasks are processed (and logged) in a reproducible order.
for task_number in sorted(unique_task_numbers):
    combined_class_mask = np.zeros(shape, dtype=np.uint8)
    for class_label, class_number in class_labels.items():
        pattern = f"task-{task_number}*{class_label}*.npy"
        print(pattern)
        # The glob alone also matches longer task numbers that share the same
        # prefix (e.g. "task-5*" matches "task-54-..."), so keep only the files
        # whose parsed task number is exactly this task.
        class_mask_npy_match = sorted(
            path
            for path in dir_labelled_raw.glob(pattern)
            if _task_number_of(path) == task_number
        )
        if not class_mask_npy_match:
            if class_label in allowed_missing_classes:
                print(f"Warning: {class_label} mask not found for task {task_number}.")
                continue
            raise FileNotFoundError(f"Mask not found for task {task_number} and class {class_label}.")

        # One or more masks exist for this class: every positive pixel of any
        # of them is labelled with the class number.
        for mask_path in class_mask_npy_match:
            class_mask_npy = np.load(mask_path)
            combined_class_mask[class_mask_npy > 0] = class_number

    # Save the combined mask as .npy (the data) and .png (rendered by matplotlib).
    filename_npy = f"mask-{task_number - minimum_task_number}.npy"
    filename_png = f"mask-{task_number - minimum_task_number}.png"
    np.save(dir_labelled_renamed / filename_npy, combined_class_mask)
    plt.imsave(dir_labelled_renamed / filename_png, combined_class_mask)