In [1]:
import numpy as np
import cv2
import os
from torchvision.datasets import MNIST
from PIL import Image

In [2]:
# Create every output directory up front (no-op if they already exist).
for output_subdir in (
    "output/foreground_masks",
    "output/circular_localization",
    "output/pairwise_2x2/images",
    "output/pairwise_2x2/masks",
):
    os.makedirs(output_subdir, exist_ok=True)

In [3]:
# Load the MNIST training split; iterating it yields (image, label) pairs.
try:
    mnist = MNIST(root="./data", train=True, download=True)
except Exception as e:
    # The download step failed (e.g. no network). Previously the error was
    # swallowed and `mnist` was left undefined, so the next cell died with an
    # unrelated NameError. Report the cause, then try the copy already on disk;
    # with download=False torchvision is expected to raise if the files are
    # missing, which surfaces the real problem here.
    print(f"MNIST download error suppressed ({e!r}). Proceeding if already downloaded.")
    mnist = MNIST(root="./data", train=True, download=False)

In [4]:
def get_otsu_mask(image):
    """Binarize `image` with a global threshold chosen by Otsu's method.

    Parameters
    ----------
    image : array-like
        Image convertible with ``np.array``. NOTE(review): assumes an 8-bit
        single-channel image, which is what cv2's Otsu mode is used with here.

    Returns
    -------
    numpy.ndarray
        Binary mask of the same shape: 255 where the pixel is above the Otsu
        threshold, 0 elsewhere.
    """
    pixels = np.array(image)
    # The explicit threshold (0) is ignored when THRESH_OTSU is set.
    otsu_flags = cv2.THRESH_BINARY | cv2.THRESH_OTSU
    _chosen_threshold, binary = cv2.threshold(pixels, 0, 255, otsu_flags)
    return binary

# Compute one Otsu foreground mask per MNIST sample and persist it as a PNG.
otsu_masks = []
for sample_idx, (digit_img, _label) in enumerate(mnist):
    fg_mask = get_otsu_mask(digit_img)
    otsu_masks.append(fg_mask)
    out_path = f"output/foreground_masks/{sample_idx:05d}_fgmask.png"
    Image.fromarray(fg_mask).save(out_path)

print(f"Saved {len(otsu_masks)} foreground (Otsu) masks in 'output/foreground_masks/'")

Saved 60000 foreground (Otsu) masks in 'output/foreground_masks/'


In [5]:
def get_tight_circle(mask):
    """Return a filled disc mask that encloses the largest foreground blob.

    The largest external contour (by area) of `mask` is found, its minimum
    enclosing circle is computed, and that circle is drawn filled with value
    255 on a zero canvas.

    Parameters
    ----------
    mask : numpy.ndarray
        Binary mask, e.g. the output of ``get_otsu_mask``.

    Returns
    -------
    numpy.ndarray
        Same shape and dtype as `mask`. It is all zeros when no contour is found.
    """
    # Indexing [-2] selects the contour list under both OpenCV return
    # conventions: 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy).
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    circle_mask = np.zeros_like(mask)
    if not contours:
        return circle_mask

    largest = max(contours, key=cv2.contourArea)
    (cx, cy), radius = cv2.minEnclosingCircle(largest)

    # cv2.circle needs integer centre and radius. Truncating both with int()
    # can shrink the disc so it no longer encloses the contour. Example: the
    # points (0,0) and (0,9) give centre (0, 4.5) and r = 4.5; truncation
    # draws r = 4 around (0, 4), which misses (0, 9).
    # Fix: round the centre, then grow the radius by the distance the centre
    # moved (triangle inequality), so every contour point stays inside.
    # The small epsilon absorbs float noise in the radius, so an exact-integer
    # radius is not bumped up a whole pixel.
    center = (int(np.round(cx)), int(np.round(cy)))
    center_shift = np.hypot(cx - center[0], cy - center[1])
    safe_radius = int(np.ceil(radius + center_shift - 1e-3))
    cv2.circle(circle_mask, center, safe_radius, 255, -1)
    return circle_mask

# Derive a circular localization mask from each foreground mask and save it.
circle_masks = []
for mask_idx, fg_mask in enumerate(otsu_masks):
    disc = get_tight_circle(fg_mask)
    circle_masks.append(disc)
    disc_path = f"output/circular_localization/{mask_idx:05d}_circle.png"
    Image.fromarray(disc).save(disc_path)

print(f"Saved {len(circle_masks)} circular localization masks in 'output/circular_localization/'")

Saved 60000 circular localization masks in 'output/circular_localization/'


In [6]:
def _tile_2x2(tiles):
    """Arrange four arrays as [[t0, t1], [t2, t3]] using hstack/vstack."""
    return np.vstack((np.hstack(tiles[0:2]), np.hstack(tiles[2:4])))


def create_spatial_concat_dataset(images, masks, output_dir="output/pairwise_2x2"):
    """Build 2x2 mosaics from consecutive groups of four image/mask pairs.

    Samples ``4k .. 4k+3`` become mosaic ``k``, laid out as::

        [4k  ][4k+1]
        [4k+2][4k+3]

    Masks are tiled the same way, so each composite mask stays aligned with
    its composite image. Trailing samples that do not fill a complete group of
    four are dropped.

    Parameters
    ----------
    images : sequence
        Items convertible with ``np.array``. Tiles must have matching shapes so
        hstack/vstack can join them; presumably all samples share one size.
    masks : sequence
        Masks aligned index-for-index with `images`.
    output_dir : str or None, default "output/pairwise_2x2"
        If not None, each mosaic is written to
        ``<output_dir>/images/NNNNN_2x2.png`` and its mask to
        ``<output_dir>/masks/NNNNN_mask.png``. Both directories are created if
        needed. Pass None to build the mosaics in memory only.

    Returns
    -------
    tuple[list[numpy.ndarray], list[numpy.ndarray]]
        ``(composite_images, composite_masks)``, one entry per complete group.

    Raises
    ------
    ValueError
        If `images` and `masks` have different lengths; pairing them would
        silently misalign mosaics and masks.
    """
    if len(images) != len(masks):
        raise ValueError(
            f"images and masks must be the same length, got {len(images)} and {len(masks)}"
        )

    if output_dir is not None:
        image_dir = os.path.join(output_dir, "images")
        mask_dir = os.path.join(output_dir, "masks")
        os.makedirs(image_dir, exist_ok=True)
        os.makedirs(mask_dir, exist_ok=True)

    composite_images = []
    composite_masks = []

    # Floor division keeps only complete groups of four.
    for idx in range(len(images) // 4):
        start = 4 * idx
        new_img = _tile_2x2([np.array(images[j]) for j in range(start, start + 4)])
        new_mask = _tile_2x2([np.array(masks[j]) for j in range(start, start + 4)])

        composite_images.append(new_img)
        composite_masks.append(new_mask)

        if output_dir is not None:
            Image.fromarray(new_img).save(os.path.join(image_dir, f"{idx:05d}_2x2.png"))
            Image.fromarray(new_mask).save(os.path.join(mask_dir, f"{idx:05d}_mask.png"))

    return composite_images, composite_masks

# Drop the labels, then tile each run of four digits (and their Otsu masks)
# into one 2x2 composite.
images = [digit for digit, _label in mnist]
composite_imgs, composite_msks = create_spatial_concat_dataset(images, otsu_masks)

n_composites = len(composite_imgs)
print(f"Saved {n_composites} 2x2 spatially concatenated image-mask pairs in 'output/pairwise_2x2/'")

Saved 15000 2x2 spatially concatenated image-mask pairs in 'output/pairwise_2x2/'
