In [9]:
import os
import random

import torch

from PIL import Image

from torchvision import datasets, transforms

In [16]:
canvas_size = 400  # side length (px) of the square black canvas
image_size = 100   # each 28x28 MNIST digit is resized to image_size x image_size
num_imgs = 10      # number of stimuli generated per condition
pad = 20           # margin (px) between a pasted digit and the canvas edge
save_root = './mnist_stim'
seed = 0           # seeds `random` so Restart & Run All regenerates the same stimuli

random.seed(seed)

# Digits in two adjacent corners must not overlap: pad + digit + gap + digit + pad <= canvas.
assert 2 * (image_size + pad) <= canvas_size, "digits in adjacent corners would overlap"

# 4 corner positions: (left, top) offset of a digit's upper-left pixel on the canvas
positions = {
    'top_left': (pad, pad),
    'top_right': (canvas_size - image_size - pad, pad),
    'bottom_left': (pad, canvas_size - image_size - pad),
    'bottom_right': (canvas_size - image_size - pad, canvas_size - image_size - pad)
}

# --- Load MNIST training data ---
transform = transforms.ToTensor()
mnist = datasets.MNIST(root='./mnist', train=True, download=True, transform=transform)

# Organize by label.
# Build the index directly from the dataset's raw pixel tensor (`mnist.data`) and labels
# (`mnist.targets`) instead of iterating `mnist`, which would decode every training image
# through PIL + ToTensor one at a time.
# unsqueeze(1) + cast to the default float dtype + div(255) mirrors what ToTensor produces
# for a grayscale image: a (1, H, W) float tensor in [0, 1]. Order within each label is
# preserved, so seeded sampling downstream picks the same images as the per-item loop.
# NOTE(review): this bypasses `transform`; if `transform` is changed, update this too.
label_to_imgs = {}
for label in range(10):
    label_imgs = (
        mnist.data[mnist.targets == label]
        .unsqueeze(1)
        .to(torch.get_default_dtype())
        .div(255)
    )
    label_to_imgs[label] = list(label_imgs)

# 'same_label' stimuli draw two different images from one class.
assert all(len(imgs) >= 2 for imgs in label_to_imgs.values()), "every digit class needs >= 2 images"

# Convert tensor to PIL
to_pil = transforms.ToPILImage()

In [17]:
def create_mnist_stimulus(mode='same_image'):
    """Compose one stimulus: two MNIST digits pasted into two distinct random corners.

    Args:
        mode: which pair of digits to draw:
            'same_image'      -- one randomly chosen image pasted twice;
            'same_label'      -- two images sampled without replacement from one random label;
            'different_label' -- one image from each of two distinct random labels.

    Returns:
        A grayscale ('L') PIL image of size canvas_size x canvas_size on a black background.

    Raises:
        ValueError: if `mode` is not one of the values above.
    """
    def to_digit(img):
        # Tensor -> PIL, resized to `image_size`: the corner `positions` were computed
        # for digits of exactly that size (the previous `digit_size` was never defined).
        return to_pil(img).resize((image_size, image_size))

    if mode == 'same_image':
        label = random.randint(0, 9)
        digit = to_digit(random.choice(label_to_imgs[label]))
        imgs = [digit, digit.copy()]
    elif mode == 'same_label':
        label = random.randint(0, 9)
        img1, img2 = random.sample(label_to_imgs[label], 2)
        imgs = [to_digit(img1), to_digit(img2)]
    elif mode == 'different_label':
        label1, label2 = random.sample(range(10), 2)
        img1 = random.choice(label_to_imgs[label1])
        img2 = random.choice(label_to_imgs[label2])
        imgs = [to_digit(img1), to_digit(img2)]
    else:
        raise ValueError(f"Unknown mode: {mode}")

    canvas = Image.new('L', (canvas_size, canvas_size), 0)  # black canvas

    # Place in 2 distinct random corners
    corners = random.sample(list(positions.values()), 2)
    for img, pos in zip(imgs, corners):
        canvas.paste(img, pos)

    return canvas

In [18]:
# Write `num_imgs` stimuli for each condition into save_root/<condition>/.
os.makedirs(save_root, exist_ok=True)

stimulus_modes = ('same_image', 'same_label', 'different_label')
for condition in stimulus_modes:
    condition_dir = os.path.join(save_root, condition)
    os.makedirs(condition_dir, exist_ok=True)

    for idx in range(num_imgs):
        filename = f'mnist_stimulus_{idx:02d}_{condition}.png'
        create_mnist_stimulus(condition).save(os.path.join(condition_dir, filename))