In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Root folder of the cropped training data; every folder below hangs off it.
DATA_DIR = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/")

# Folders that are read from.
IMAGE_DIR = DATA_DIR / "images"
MASKS_DIR = DATA_DIR / "ring_masks"
MULTICLASS_MASKS_DIR = DATA_DIR / "multiclass_masks"

# Folders that are written to. mkdir() is called without parents=True, so
# DATA_DIR itself has to exist already; existing output folders are reused.
CHOPPED_MASKS_DIR = DATA_DIR / "chopped_ring_masks"
CHOPPED_MASKS_DIR.mkdir(exist_ok=True)
CHOPPED_RING_IMAGES_DIR = DATA_DIR / "chopped_ring_images"
CHOPPED_RING_IMAGES_DIR.mkdir(exist_ok=True)

# Collect the input files in sorted filename order (sorted() consumes the
# glob generator directly and returns a list).
images = sorted(IMAGE_DIR.glob("image_*.npy"))
masks = sorted(MASKS_DIR.glob("mask_*.npy"))
multiclass_masks = sorted(MULTICLASS_MASKS_DIR.glob("mask_*.npy"))

# Report how many files of each kind were found, one per line.
print(
    f"Found {len(images)} images",
    f"Found {len(masks)} masks",
    f"Found {len(multiclass_masks)} multiclass masks",
    sep="\n",
)

In [None]:
# For every image, zero out the pixels whose multiclass-mask label is 0 or 2,
# preview the inputs next to the result, and save the result as .npy and .png.
# NOTE(review): presumably label 0 is background and label 2 is the gem, which
# leaves only the ring visible -- confirm against the mask-generation code.

# zip() stops at the shortest input, so a missing file in any folder would
# silently drop or misalign image/mask pairs. Fail loudly instead.
if not (len(images) == len(masks) == len(multiclass_masks)):
    raise ValueError(
        f"Mismatched file counts: {len(images)} images, {len(masks)} masks, "
        f"{len(multiclass_masks)} multiclass masks"
    )

for mask_file, multiclass_mask_file, image_file in zip(masks, multiclass_masks, images):
    # The three lists are paired purely by sorted position. Check that the part
    # of each filename after the "image_" / "mask_" prefix is the same, so a
    # gap in one folder cannot pair an image with another sample's masks.
    # NOTE(review): assumes matching files share that suffix -- confirm.
    sample_ids = {
        path.stem.partition("_")[2]
        for path in (image_file, mask_file, multiclass_mask_file)
    }
    if len(sample_ids) != 1:
        raise ValueError(
            "Files paired by sorted order do not share an ID: "
            f"{image_file.name}, {mask_file.name}, {multiclass_mask_file.name}"
        )

    print(f"Processing {image_file.name}")
    mask = np.load(mask_file)
    multiclass_mask = np.load(multiclass_mask_file)
    image = np.load(image_file)

    # Work on a copy so the loaded image stays intact for the preview panel.
    image_ring_only = image.copy()
    image_ring_only[(multiclass_mask == 0) | (multiclass_mask == 2)] = 0

    # Side-by-side preview: original image, ring mask, multiclass mask, result.
    fig, ax = plt.subplots(1, 4)
    panels = (
        (image, "Image"),
        (mask, "Mask"),
        (multiclass_mask, "Multiclass mask"),
        (image_ring_only, "Image ring only"),
    )
    for axis, (panel_data, panel_title) in zip(ax, panels):
        axis.imshow(panel_data)
        axis.set_title(panel_title)
    plt.show()
    # Release the figure; otherwise each iteration can leave one open and
    # memory grows with the number of images.
    plt.close(fig)

    # Save the image with the ring only
    np.save(CHOPPED_RING_IMAGES_DIR / image_file.name, image_ring_only)
    # save as png too
    plt.imsave(CHOPPED_RING_IMAGES_DIR / f"{image_file.stem}.png", image_ring_only, cmap="viridis")