In [None]:
import os
from glob import glob
from random import sample

import matplotlib.pyplot as plt
from PIL import Image


# Create dataset samples

In [None]:
# Grid layout: number of tiles per row (COLUMNS) and per column (ROWS)
COLUMNS = 7
ROWS = 7
# Figure size handed to matplotlib as figsize=(width, height)
FIGSIZE = (16, 12)


In [None]:
# Parameters
# Glob patterns of the image file types to look for
IMAGES_EXT = ["*.gif", "*.jpg", "*.jpeg", "*.png", "*.webp"]

# Dataset
DATASET = "UGallery"

# Images path of every supported dataset
# (the "~" durations are the author's timing notes; what was timed is not stated here)
_IMAGES_DIR_BY_DATASET = {
    # ~35s
    "UGallery": os.path.join("/", "mnt", "workspace", "Ugallery", "images"),
    # ~1h 10m
    "Wikimedia": os.path.join("/", "mnt", "data2", "wikimedia", "images", "img"),
    # ~2h
    "Pinterest": os.path.join("/", "mnt", "data2", "pinterest_iccv", "images"),
}
assert DATASET in _IMAGES_DIR_BY_DATASET

# Images path of the selected dataset
IMAGES_DIR = _IMAGES_DIR_BY_DATASET[DATASET]


In [None]:
import PIL
from PIL import ImageFile


# Needed for some images in the Pinterest and Wikimedia datasets
# (presumably those above Pillow's default pixel limit): allow up to 3 billion pixels
Image.MAX_IMAGE_PIXELS = 3_000_000_000
# Let Pillow load the truncated ("broken") files found in the Wikimedia dataset
ImageFile.LOAD_TRUNCATED_IMAGES = True


In [None]:
# Random sample of up to COLUMNS * ROWS images from DATASET
all_images = []
for ext in sorted(IMAGES_EXT):
    # List images in folder by pattern
    pattern = os.path.join(IMAGES_DIR, ext)
    # Use glob over iglob to sort and calculate length
    all_images.extend(sorted(glob(pattern)))

# Fail with a readable message instead of random.sample's generic ValueError
if not all_images:
    raise ValueError(f"No images matching {IMAGES_EXT} found in {IMAGES_DIR}")

# random.sample raises ValueError when asked for more items than are available,
# so cap the sample size at the number of images found
images = sample(all_images, min(len(all_images), COLUMNS * ROWS))


Code to create grid figure taken from [Image t-SNE, available at Nextjournal](https://nextjournal.com/ml4a/image-t-sne):

In [None]:
# Create grid figure
# (https://nextjournal.com/ml4a/image-t-sne)

# Size in pixels of each tile of the grid
tile_width = 72
tile_height = 56

full_width = tile_width * COLUMNS
full_height = tile_height * ROWS
# Width / height ratio every tile is cropped to
aspect_ratio = tile_width / tile_height

# Tiles that receive no image keep the default black fill of Image.new
grid_image = Image.new("RGB", (full_width, full_height))

# Grid cells in column-major order: a column is filled top to bottom before the next one
positions = [(x, y) for x in range(COLUMNS) for y in range(ROWS)]
for img, (idx_x, idx_y) in zip(images, positions):
    x, y = tile_width * idx_x, tile_height * idx_y
    # The context manager closes the file handle once the tile has been built
    with Image.open(img) as source:
        # center-crop the tile to match aspect_ratio
        tile_ar = source.width / source.height
        if tile_ar > aspect_ratio:
            # Image is too wide: trim the same amount from left and right
            crop_width = aspect_ratio * source.height
            margin = 0.5 * (source.width - crop_width)
            tile = source.crop((margin, 0, margin + crop_width, source.height))
        else:
            # Image is too tall (or already matches): trim top and bottom
            crop_height = source.width / aspect_ratio
            margin = 0.5 * (source.height - crop_height)
            tile = source.crop((0, margin, source.width, margin + crop_height))
        # Pillow resizes palette ("P", e.g. GIF) and "1" images with nearest-neighbour
        # whatever filter is requested, so convert to the grid's mode first
        if tile.mode != "RGB":
            tile = tile.convert("RGB")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        tile = tile.resize((tile_width, tile_height), Image.LANCZOS)
    grid_image.paste(tile, (x, y))

plt.figure(figsize=FIGSIZE)
plt.axis("off")
plt.imshow(grid_image)
