In [None]:
from pathlib import Path
import re
import shutil

import numpy as np
from skimage.morphology import dilation
import matplotlib.pyplot as plt
from PIL import Image

from topostats.plottingfuncs import Colormap

# Build the colormap object once; `cmap` is used as the default colormap of the
# image panels drawn by the first `plot_images` definition in this notebook.
# NOTE(review): presumably this is the standard TopoStats height colormap — confirm
# against topostats.plottingfuncs.Colormap.
colormap = Colormap()
cmap = colormap.get_cmap()

In [None]:
def plot_images(images: list, masks: list, width=5, cmap=cmap, vmin=-8, vmax=8):
    """Show image/mask pairs in a grid: image, mask, and mask overlaid on the image.

    Every pair occupies three neighbouring panels and ``width`` pairs are
    placed on each row of the figure.

    Parameters
    ----------
    images : list
        Arrays drawn with ``cmap`` and clipped to ``[vmin, vmax]``.
    masks : list
        Arrays drawn with the "binary" colormap and overlaid (alpha 0.2,
        "viridis") on the image. Pairs are formed with ``zip``, so surplus
        entries in the longer list are ignored.
    width : int
        Number of image/mask pairs per row.
    cmap
        Colormap for the image panels.
    vmin, vmax
        Colour limits for the image panels.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``; nothing is drawn when
        ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots rejects a grid with zero rows, so there is nothing to draw.
        return

    num_rows = int(np.ceil(num_images / width))
    # squeeze=False keeps `ax` two-dimensional even when everything fits on a
    # single row; otherwise ax[row, col] fails for num_images <= width.
    fig, ax = plt.subplots(num_rows, width * 3, figsize=(30, 30), squeeze=False)

    num_drawn = 0
    for i, (image, mask) in enumerate(zip(images, masks)):
        row, col = divmod(i, width)
        image_ax, mask_ax, overlay_ax = ax[row, col * 3 : col * 3 + 3]

        image_ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        image_ax.axis("off")
        mask_ax.imshow(mask, cmap="binary")
        overlay_ax.imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        overlay_ax.imshow(mask, cmap="viridis", alpha=0.2)
        num_drawn = i + 1

    # Pair i fills flat panel indices 3*i .. 3*i + 2, so everything after the
    # last drawn pair is an empty slot; hide its frame and ticks.
    for unused_ax in ax.ravel()[num_drawn * 3 :]:
        unused_ax.axis("off")

    fig.tight_layout()
    plt.show()

In [None]:
# dna only rename images

# Input folders (task-numbered png / npy masks), the folder that receives the
# renumbered copies, and the folder with the images shown next to them.
mask_dir_png = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/masks_256_sharper_png")
mask_dir_npy = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/masks_256_sharper_npy")
mask_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/masks_256_sharper")
image_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/images_256")
for required_dir in (mask_dir, mask_dir_png, mask_dir_npy, image_dir):
    assert required_dir.exists()


def _task_index(path):
    """Return the integer found between "task-" and the following "-" in the path."""
    return int(re.search(r"task-(\d+)-", str(path)).group(1))


# Collect the mask files of both formats
masks_png = list(mask_dir_png.glob("*.png"))
masks_npy = list(mask_dir_npy.glob("*.npy"))
print(f"len masks_png: {len(masks_png)}")
print(f"len masks_npy: {len(masks_npy)}")

# Copy (never move) the png masks into mask_dir as mask_{index}.png, where the
# index is the position of the file once ordered by task number.
masks_png_sorted = sorted(masks_png, key=_task_index)
for i, source_path in enumerate(masks_png_sorted):
    new_name = f"mask_{i}.png"
    shutil.copy(source_path, mask_dir / new_name)

# Same again for the npy masks: mask_{index}.npy
masks_npy_sorted = sorted(masks_npy, key=_task_index)
for i, source_path in enumerate(masks_npy_sorted):
    new_name = f"mask_{i}.npy"
    shutil.copy(source_path, mask_dir / new_name)

# Gallery of every image next to its mask: find the files ...
images = list(image_dir.glob("*.npy"))
masks = list(mask_dir.glob("*.npy"))

# ... order both lists by the number in the file name ...
images_sorted = list(images)
images_sorted.sort(key=lambda path: int(re.search(r"image_(\d+)", str(path)).group(1)))
masks_sorted = list(masks)
masks_sorted.sort(key=lambda path: int(re.search(r"mask_(\d+)", str(path)).group(1)))

# ... then load the arrays and plot them
images = list(map(np.load, images_sorted))
masks = list(map(np.load, masks_sorted))
plot_images(images, masks)

In [None]:
# dna only view images
image_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/images_256/")
mask_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/masks_256/")
for required_dir in (image_dir, mask_dir):
    assert required_dir.exists()

# Pair the files by their position in the alphabetically sorted listings
image_paths = sorted(image_dir.glob("*.npy"))
mask_paths = sorted(mask_dir.glob("*.npy"))

# One figure per pair: image on the left, mask on the right
for image_path, mask_path in zip(image_paths, mask_paths):
    image = np.load(image_path)
    mask = np.load(mask_path)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for panel, array in zip(axes, (image, mask)):
        panel.imshow(array, cmap="gray")
    plt.show()

In [None]:
# Folder with the dna_cas9 image .npy files. Later cells glob and load from
# `image_dir` (resize, gallery and per-image viewer) until it is reassigned.
image_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_cas9/big_dataset/images_all_extra/")

In [None]:
# Resize the images_all to 256x256

images_256_dir = Path(
    "/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_cas9/big_dataset/images_all_extra_256/"
)

# Create the output directory if it doesn't exist. exist_ok=True replaces the
# separate exists() check; a missing parent folder still raises.
images_256_dir.mkdir(exist_ok=True)

# `image_dir` is the source folder defined in an earlier cell
images = list(image_dir.glob("*.npy"))
print(f"Found {len(images)} images")

for image_path in images:
    # Round-trip through PIL to resample the array to 256x256
    original = np.load(image_path)
    image = np.array(Image.fromarray(original).resize((256, 256)))
    print(f"image: {image_path} resized to {image.shape}")

    # Write a png preview and the resized array, both named after the source file
    plt.imsave(images_256_dir / f"{image_path.stem}.png", image)
    np.save(images_256_dir / f"{image_path.stem}.npy", image)

In [None]:
# Folder with the dna_cas9 gem / ring mask files (.npy and .png). The following
# cells rename, inspect and plot the files found in `mask_dir`.
mask_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_cas9/big_dataset/masks_all_extra_256/")

In [None]:
# Get all the masks numpy and png files
mask_gem_npy_paths = list(mask_dir.glob("*gem*.npy"))
mask_ring_npy_paths = list(mask_dir.glob("*ring*.npy"))
mask_gem_png_paths = list(mask_dir.glob("*gem*.png"))
mask_ring_png_paths = list(mask_dir.glob("*ring*.png"))

print(
    f"gem npy: {len(mask_gem_npy_paths)} ring npy: {len(mask_ring_npy_paths)} gem png: {len(mask_gem_png_paths)} ring png: {len(mask_ring_png_paths)}"
)


def _mask_task_number(path):
    """Return the integer between "task-" and the next "-" in ``path``.

    Raises
    ------
    ValueError
        If the path contains no task number, e.g. because the file was
        already renamed to gem_{i} / ring_{i} by an earlier run of this cell.
    """
    match = re.search(r"task-(\d+)-", str(path))
    if match is None:
        raise ValueError(f"No 'task-<number>-' found in {path}")
    return int(match.group(1))


# Sort the paths based on the number in the name, where the number comes after "task-" and before "-"
mask_gem_npy_paths = sorted(mask_gem_npy_paths, key=_mask_task_number)
mask_ring_npy_paths = sorted(mask_ring_npy_paths, key=_mask_task_number)
mask_gem_png_paths = sorted(mask_gem_png_paths, key=_mask_task_number)
mask_ring_png_paths = sorted(mask_ring_png_paths, key=_mask_task_number)

# The loop below uses one index for all four lists and renames files on disk,
# so check the counts before touching anything: a mismatch would otherwise
# fail midway and leave the folder half-renamed.
path_counts = {
    "gem npy": len(mask_gem_npy_paths),
    "ring npy": len(mask_ring_npy_paths),
    "gem png": len(mask_gem_png_paths),
    "ring png": len(mask_ring_png_paths),
}
if len(set(path_counts.values())) != 1:
    raise ValueError(f"Mask file counts differ, nothing was renamed: {path_counts}")

for i in range(len(mask_gem_npy_paths)):
    # New names are gem_{i} / ring_{i} in the same folder, i being the rank by task number
    mask_gem_npy_new_name = mask_gem_npy_paths[i].parent / f"gem_{i}.npy"
    mask_ring_npy_new_name = mask_ring_npy_paths[i].parent / f"ring_{i}.npy"
    mask_gem_png_new_name = mask_gem_png_paths[i].parent / f"gem_{i}.png"
    mask_ring_png_new_name = mask_ring_png_paths[i].parent / f"ring_{i}.png"

    print(f"{mask_gem_npy_paths[i].stem} -> {mask_gem_npy_new_name.stem}")
    print(f"{mask_ring_npy_paths[i].stem} -> {mask_ring_npy_new_name.stem}")
    print(f"{mask_gem_png_paths[i].stem} -> {mask_gem_png_new_name.stem}")
    print(f"{mask_ring_png_paths[i].stem} -> {mask_ring_png_new_name.stem}")

    # Rename the files, reusing the targets computed above
    mask_gem_npy_paths[i].rename(mask_gem_npy_new_name)
    mask_ring_npy_paths[i].rename(mask_ring_npy_new_name)
    mask_gem_png_paths[i].rename(mask_gem_png_new_name)
    mask_ring_png_paths[i].rename(mask_ring_png_new_name)

In [None]:
# Check data type of a mask
mask = np.load(mask_dir / "ring_0.npy")

# Print dtype, shape, minimum and maximum, one value per line
for mask_property in (mask.dtype, mask.shape, mask.min(), mask.max()):
    print(mask_property)

plt.imshow(mask)

In [None]:
# Combine the ring and gem masks into one mask

# Folder holding the separate gem_{i}.npy / ring_{i}.npy masks; the combined
# mask_{i}.npy / mask_{i}.png files are written into the same folder.
# This definition was commented out, which made the cell fail with a NameError
# on a fresh kernel.
mask_dir_separate_gem_ring = Path(
    "/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_cas9/big_dataset/masks_all_extra_256/"
)

# DILATION_STRENGTH = 8

# Combine all ring and gem masks into one mask
# Find how many gem masks there are
num_gem_masks = len(list(mask_dir_separate_gem_ring.glob("gem*.npy")))
print(f"num gem masks: {num_gem_masks}")

for i in range(num_gem_masks):
    # Load the gem and ring masks
    gem_mask = np.load(mask_dir_separate_gem_ring / f"gem_{i}.npy")
    ring_mask = np.load(mask_dir_separate_gem_ring / f"ring_{i}.npy")

    # Disabled: optional dilation of the ring mask before combining
    # ring_mask_dilated = ring_mask > 0
    # for _ in range(DILATION_STRENGTH):
    # ring_mask_dilated = dilation(ring_mask_dilated)

    # Label image: 0 = background, 1 = ring, 2 = gem. Gem is written last, so
    # it wins wherever the two masks overlap.
    # A boolean array cannot hold the label 2 (it would be stored as True), so
    # boolean masks are promoted to uint8; any other dtype is kept as it is.
    label_dtype = np.uint8 if gem_mask.dtype == np.bool_ else gem_mask.dtype
    combined_mask = np.zeros_like(gem_mask, dtype=label_dtype)
    combined_mask[ring_mask > 0] = 1
    combined_mask[gem_mask > 0] = 2

    # Save the combined mask
    np.save(mask_dir_separate_gem_ring / f"mask_{i}.npy", combined_mask)

    # Save a png of the combined mask
    plt.imsave(mask_dir_separate_gem_ring / f"mask_{i}.png", combined_mask, cmap="viridis")

In [None]:
# Plot them all in a gallery

# Alphabetically sorted listings of every .npy file in the two folders
image_files = sorted(image_dir.glob("*.npy"))
mask_files = sorted(mask_dir.glob("*.npy"))


def plot_images(image_files, mask_files, width=5):
    """Load image/mask files and draw them side by side in a gallery.

    Every pair occupies two neighbouring panels (image, then mask) and
    ``width`` pairs are placed on each row. Axes are switched off on every panel.

    NOTE(review): this redefines ``plot_images`` from earlier in the notebook,
    which takes already-loaded arrays; once this cell has run, the earlier
    call ``plot_images(images, masks)`` resolves to this version.

    Parameters
    ----------
    image_files, mask_files
        Sequences of paths readable by ``np.load``. Pairs are formed with
        ``zip``, so surplus entries in the longer sequence are ignored.
    width : int
        Number of image/mask pairs per row.

    Returns
    -------
    None
        Nothing is drawn when ``image_files`` is empty.
    """
    num_files = len(image_files)
    if num_files == 0:
        # plt.subplots rejects a grid with zero rows, so there is nothing to draw.
        return

    num_rows = int(np.ceil(num_files / width))
    # squeeze=False keeps `ax` two-dimensional even when everything fits on a
    # single row; otherwise ax[row, col] fails for num_files <= width.
    fig, ax = plt.subplots(num_rows, width * 2, figsize=(20, 40), squeeze=False)

    for i, (image_file, mask_file) in enumerate(zip(image_files, mask_files)):
        row, col = divmod(i, width)
        ax[row, col * 2].imshow(np.load(image_file), cmap="viridis")
        ax[row, col * 2 + 1].imshow(np.load(mask_file), cmap="viridis")

    # Switch the axes off everywhere, including the unused slots of the last row
    for panel in ax.ravel():
        panel.axis("off")


# Only the image_* / mask_* files take part in the gallery
gallery_image_files = sorted(image_dir.glob("image_*.npy"))
gallery_mask_files = sorted(mask_dir.glob("mask_*.npy"))
plot_images(image_files=gallery_image_files, mask_files=gallery_mask_files)

In [None]:
# The listing is only used to know how many image_{i}.npy files to step through
image_paths = sorted(image_dir.glob("*.npy"))

for i, _ in enumerate(image_paths):
    image = np.load(image_dir / f"image_{i}.npy")
    gem_mask = np.load(mask_dir / f"gem_{i}.npy")
    ring_mask = np.load(mask_dir / f"ring_{i}.npy")

    # One row per index: image, gem mask, ring mask
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    for panel, array in zip(ax, (image, gem_mask, ring_mask)):
        panel.imshow(array)
    fig.suptitle(f"Image {i}")
    plt.show()

In [None]:
file_dir = Path("/Users/sylvi/topo_data/hariborings/training_data/cropped/dna_only/")

image_dir = file_dir / "images_256/"
print(f"image_dir: {image_dir}")
label_dir = file_dir / "masks_256/"
print(f"label_dir: {label_dir}")


def _filename_number(f):
    """Return the integer formed by all digit characters of the file name, eg: image_1.npy -> 1."""
    return int("".join(character for character in f.name if character.isdigit()))


# Collect the .npy files of both folders, ordered by the number in their file name
image_paths = sorted(image_dir.glob("*.npy"), key=_filename_number)
label_paths = sorted(label_dir.glob("*.npy"), key=_filename_number)

print(f"Number of images: {len(image_paths)}")
print(f"Number of labels: {len(label_paths)}")

for image_path, label_path in zip(image_paths, label_paths):
    # Array and its png sibling (same path, .png suffix) for the image ...
    image = np.load(image_path)
    image_png_path = image_path.with_suffix(".png")
    image_png = np.array(Image.open(image_png_path))

    # ... and for the label
    label = np.load(label_path)
    label_png_path = label_path.with_suffix(".png")
    label_png = np.array(Image.open(label_png_path))

    # Four panels side by side, then the overlay in the fifth
    fig, axes = plt.subplots(1, 5, figsize=(12, 4))
    side_by_side = (
        (image, "Image"),
        (image_png, "Image PNG"),
        (label, "Label"),
        (label_png, "Label PNG"),
    )
    for panel, (array, title) in zip(axes, side_by_side):
        panel.imshow(array)
        panel.set_title(title)

    overlay_panel = axes[4]
    overlay_panel.imshow(label)
    overlay_panel.imshow(image, alpha=0.9)
    overlay_panel.set_title("Overlay")

    # Figure title: the two file names
    fig.suptitle(f"{image_path.name} {label_path.name}")
    plt.show()