In [None]:
from pathlib import Path
import re

import numpy as np
import matplotlib.pyplot as plt

from topostats.plottingfuncs import Colormap

# Colour map from TopoStats' `Colormap` helper (constructed with its default
# arguments). `cmap` is bound as the default `cmap` argument of `plot_images`
# when that function is defined, so this cell must run before it.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
# Rename files
#
# Renumber every `file_type` file in `path` to `{prefix}_{index}{file_type}`,
# keeping the order given by the first integer in each original filename.
# Set `testing_mode = True` for a dry run that only prints the planned renames.

path = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/masks_extra_doritos_256/")
image_starting_index = 0
testing_mode = False
file_type = ".png"
prefix = "mask"


def rename_files_sequentially(
    directory: Path,
    prefix: str,
    file_type: str,
    start_index: int = 0,
    dry_run: bool = False,
) -> list:
    """Renumber the files in ``directory`` as ``{prefix}_{n}{file_type}``, starting at ``start_index``.

    Files are ordered by the first run of digits in their stem. Ties are broken by
    filename, so the order never depends on the filesystem's listing order. The
    renaming is done in two phases through temporary names. This guarantees that a
    target name which is also a not-yet-renamed source (e.g. ``mask_0 -> mask_1``
    while ``mask_1`` still exists) is never overwritten.

    Parameters
    ----------
    directory : Path
        Directory holding the files (not searched recursively).
    prefix : str
        Prefix of the new filenames.
    file_type : str
        Extension including the leading dot, e.g. ``".png"``.
    start_index : int
        Index assigned to the first file.
    dry_run : bool
        If True, only print the planned renames; no files are touched.

    Returns
    -------
    list of (Path, Path)
        The ``(old_path, new_path)`` pairs, in rename order.

    Raises
    ------
    ValueError
        If a filename has no digits to sort by, or if a target/temporary name is
        occupied by something that is not one of the files being renamed.
    """

    def _index_key(file: Path):
        match = re.search(r"\d+", file.stem)
        if match is None:
            raise ValueError(f"No numeric index found in filename: {file.name}")
        # Tie-break on the name so equal indices sort deterministically.
        return int(match.group(0)), file.name

    def _staging_name(file: Path) -> Path:
        # Does not end in `file_type`, so it can never match the glob above.
        return file.with_name(file.name + ".renaming_tmp")

    sources = sorted((f for f in directory.glob(f"*{file_type}") if f.is_file()), key=_index_key)
    plan = [(src, directory / f"{prefix}_{index}{file_type}") for index, src in enumerate(sources, start=start_index)]

    # Validate everything before touching the filesystem so a failure leaves no partial rename.
    source_set = set(sources)
    for src, dst in plan:
        if dst.exists() and dst not in source_set:
            raise ValueError(f"Refusing to overwrite {dst}: it exists and is not one of the files being renamed.")
        if _staging_name(src).exists():
            raise ValueError(f"Temporary name {_staging_name(src)} already exists; remove it and retry.")

    for src, dst in plan:
        print(f"renaming {src.name} to {dst.name}")
    if dry_run:
        return plan

    # Phase 1: move every source out of the way so no final name is still occupied.
    staged = []
    for src, dst in plan:
        tmp = _staging_name(src)
        src.rename(tmp)
        staged.append((tmp, dst))
    # Phase 2: move each staged file to its final name.
    for tmp, dst in staged:
        tmp.rename(dst)
    return plan


renamed_files = rename_files_sequentially(
    path, prefix, file_type, start_index=image_starting_index, dry_run=testing_mode
)

In [None]:
# Check that the files are renamed correctly

data_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only")
image_dir = data_dir / "images_extra_doritos_256"
mask_dir = data_dir / "masks_extra_doritos_256"


def sorted_by_filename_index(directory: Path, pattern: str) -> list:
    """Return the files in ``directory`` matching ``pattern``, sorted by the first integer in each stem.

    Ties on the integer are broken by filename so the order is deterministic.

    Parameters
    ----------
    directory : Path
        Directory to search (not recursive).
    pattern : str
        Glob pattern, e.g. ``"*.png"``.

    Returns
    -------
    list of Path
        Matching files in numeric-index order.

    Raises
    ------
    ValueError
        If a matching filename contains no digits.
    """

    def _index_key(file: Path):
        match = re.search(r"\d+", file.stem)
        if match is None:
            raise ValueError(f"No numeric index found in filename: {file.name}")
        return int(match.group(0)), file.name

    return sorted((f for f in directory.glob(pattern) if f.is_file()), key=_index_key)


image_files_png = sorted_by_filename_index(image_dir, "*.png")
image_files_npy = sorted_by_filename_index(image_dir, "*.npy")
mask_files_png = sorted_by_filename_index(mask_dir, "*.png")
mask_files_npy = sorted_by_filename_index(mask_dir, "*.npy")


def plot_images(
    images_png: list,
    images_npy: list,
    masks_png: list,
    masks_npy: list,
    width=5,
    cmap=cmap,
    vmin=-8,
    vmax=8,
    figsize=None,
):
    """Plot each sample's image and mask, in both PNG and NPY form, in one grid for visual comparison.

    Each sample occupies four consecutive panels in a row: image png, mask png,
    image npy, mask npy. ``width`` samples are placed per row.

    Parameters
    ----------
    images_png, images_npy, masks_png, masks_npy : list of array-like
        Per-sample arrays. All four lists must have the same length and be in the
        same sample order.
    width : int
        Number of samples per row (each sample uses 4 panels).
    cmap : colormap
        Colour map for the image panels. Defaults to the module-level ``cmap``
        bound when this function is defined.
    vmin, vmax : float
        Colour limits for the npy image panels. They are not applied to the png
        image panels: PNGs read with ``plt.imread`` are floats scaled to [0, 1], so
        these limits would render them as a flat colour.
    figsize : tuple of float, optional
        Figure size in inches. By default it scales with the number of rows.

    Raises
    ------
    ValueError
        If the four lists differ in length, or ``width`` is less than 1.
    """
    counts = (len(images_png), len(images_npy), len(masks_png), len(masks_npy))
    if len(set(counts)) > 1:
        # A count mismatch is exactly what this check is meant to reveal; zip would hide it.
        raise ValueError(
            f"Mismatched counts: images_png={counts[0]}, images_npy={counts[1]}, "
            f"masks_png={counts[2]}, masks_npy={counts[3]}"
        )
    if width < 1:
        raise ValueError(f"width must be at least 1, got {width}")
    num_images = counts[0]
    if num_images == 0:
        # plt.subplots rejects zero rows, so there is nothing to draw.
        print("No images to plot.")
        return

    num_rows = int(np.ceil(num_images / width))
    num_cols = width * 4
    if figsize is None:
        figsize = (1.5 * num_cols, 1.7 * num_rows)
    # squeeze=False keeps `ax` 2-D even when everything fits on a single row.
    fig, ax = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    for i, (image_png, image_npy, mask_png, mask_npy) in enumerate(zip(images_png, images_npy, masks_png, masks_npy)):
        row, first_col = divmod(i, width)
        first_col *= 4
        ax[row, first_col].imshow(image_png, cmap=cmap)
        ax[row, first_col].set_title("Image png")
        ax[row, first_col + 1].imshow(mask_png, cmap="gray")
        ax[row, first_col + 1].set_title("Mask png")
        ax[row, first_col + 2].imshow(image_npy, cmap=cmap, vmin=vmin, vmax=vmax)
        ax[row, first_col + 2].set_title("Image npy")
        ax[row, first_col + 3].imshow(mask_npy, cmap="gray")
        ax[row, first_col + 3].set_title("Mask npy")
    # Sample i uses flat panels 4*i .. 4*i+3, so everything after the last sample is unused.
    for unused_ax in ax.flat[num_images * 4 :]:
        unused_ax.axis("off")
    fig.tight_layout()
    plt.show()


# Load every PNG/NPY pair, then lay them out side by side for a visual check.
plot_images(
    images_png=list(map(plt.imread, image_files_png)),
    images_npy=list(map(np.load, image_files_npy)),
    masks_png=list(map(plt.imread, mask_files_png)),
    masks_npy=list(map(np.load, mask_files_npy)),
)