In [1]:
import os
import zipfile

import numpy as np
import pytorch_lightning as pl
import requests
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from tqdm import tqdm

In [2]:
def convert_to_rgb(images: np.ndarray) -> np.ndarray:
    """
    Converts grayscale images to RGB by replicating the single gray plane into three channels.
    Accepts either NxHxWx1 or NxHxW input, where N is the number of images,
    H is the height and W the width.
    :param images: Grayscale images of shape (NxHxWx1) or (NxHxW).
    :return: Images in RGB format of shape (NxHxWx3).
    :raises ValueError: If the input is neither of shape (NxHxW) nor (NxHxWx1).
    """
    dims = np.shape(images)
    has_channel_axis = len(dims) == 4
    if not ((has_channel_axis and dims[-1] == 1) or len(dims) == 3):
        raise ValueError("Unexpected shape for grayscale images:" + str(dims))

    if has_channel_axis:
        # Drop the explicit channel axis. The decision is based on the rank, not on the size of the
        # last axis, so a channel-less batch whose width happens to be 1 (NxHx1) keeps its width axis.
        gray_images = np.squeeze(images, axis=-1)
    else:
        gray_images = images

    # Stack three copies of the gray plane along a new trailing channel axis.
    rgb_images = np.stack((gray_images,) * 3, axis=-1)

    return rgb_images
    
def create_sprite(images: np.ndarray) -> np.ndarray:
    """
    Creates a sprite of provided images by tiling them, row by row, into one square grid.
    Pixel values are multiplied by 255, clipped to [0, 255] and stored as uint8, so float input
    is expected to be scaled to [0, 1]. Single-channel input is expanded to three channels first.
    :param images: Images to construct the sprite, of shape (NxHxW) or (NxHxWxC).
    :return: An image array containing the sprite, of shape (n*H, n*W, C) with n = ceil(sqrt(N)).
    :raises ValueError: If the input does not have 3 or 4 dimensions.
    """
    shape = np.shape(images)

    if len(shape) < 3 or len(shape) > 4:
        raise ValueError("Images provided for sprite have wrong dimensions " + str(len(shape)))

    if len(shape) == 3:
        # Check to see if it's MNIST type of images and add axis to show image is gray-scale
        images = np.expand_dims(images, axis=3)
        shape = np.shape(images)

    # Change black and white images to RGB
    if shape[3] == 1:
        images = convert_to_rgb(images)

    # Smallest square grid (n x n) that can hold every image; unused cells are zero-filled.
    n = int(np.ceil(np.sqrt(images.shape[0])))
    padding = ((0, n ** 2 - images.shape[0]), (0, 0), (0, 0)) + ((0, 0),) * (images.ndim - 3)
    images = np.pad(images, padding, mode="constant", constant_values=0)

    # Tile the individual thumbnails into an image:
    # (n*n, H, W, C) -> (n, n, H, W, C) -> (n, H, n, W, C) -> (n*H, n*W, C)
    images = images.reshape((n, n) + images.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, images.ndim + 1)))
    images = images.reshape((n * images.shape[1], n * images.shape[3]) + images.shape[4:])

    # Clip before the cast: an out-of-range float -> uint8 conversion wraps around (or is
    # platform-dependent) instead of saturating. In-range values keep the original truncation.
    sprite = np.clip(images * 255, 0, 255).astype(np.uint8)

    return sprite

In [6]:
# Per-channel statistics passed to T.Normalize (presumably the CIFAR-10 training-set
# mean/std -- TODO confirm the source of these numbers).
mean = (0.4914, 0.4822, 0.4465)
std = (0.2471, 0.2435, 0.2616)

# `transform` has to be defined before the dataset is built: with the definition commented out
# this cell raised NameError on a fresh kernel (Restart & Run All).
transform = T.Compose(
    [
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
)
dataset = CIFAR10(root="data", train=True, transform=transform, download=True)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data


In [3]:
# Build the training split first, then the test split, both without a transform.
train_dataset, test_dataset = (
    CIFAR10(root="data", train=is_train, download=True) for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Join the train and test image arrays along the first axis, then divide every value by 255.
data = np.true_divide(
    np.concatenate([train_dataset.data, test_dataset.data]),
    255,
)

In [7]:
# Sanity check: display the largest and the smallest value of the rescaled array.
(np.max(data), np.min(data))

(1.0, 0.0)

In [8]:
# Tile every image of `data` into a single sprite array.
sprite = create_sprite(images=data)

In [9]:
# Write the sprite array to disk as a PNG file in the current working directory.
file_name = "cifar10.png"
image = Image.fromarray(sprite)
image.save(file_name)