In [1]:
import gzip, pickle
import numpy as np, torch.nn.functional as F
from torchvision import datasets

In [2]:
# Directory handed to torchvision as the dataset root (downloaded there if missing).
mnist_data_folder = "MNIST_data"
# Target canvas size as (rows, cols); the digits are zero-padded up to this size.
img_size = (60,60)

# 28 is the side length of a raw MNIST digit. The per-side padding below is
# (img_size - 28) // 2, so the canvas must be even (for symmetric padding) and
# at least 28 pixels (otherwise the padding would be negative and the digit
# would be cropped instead of padded).
assert (img_size[0] % 2 == 0 and img_size[1] % 2 == 0), 'img_size must be even in both dimensions.'
assert (img_size[0] >= 28 and img_size[1] >= 28), 'img_size must be at least 28 in both dimensions.'
# Per-side padding: pd[0] applies to rows (height), pd[1] to columns (width).
pd = [(img_size[0] - 28)//2, (img_size[1] - 28)//2]

In [3]:
# Load (downloading on first use) the raw MNIST train and test splits.
trainset = datasets.MNIST(root=mnist_data_folder, train=True, download=True)
testset = datasets.MNIST(root=mnist_data_folder, train=False, download=True)

# F.pad consumes its padding tuple starting from the LAST dimension:
# (left, right, top, bottom). The last dimension of the image tensors is the
# column axis, so the width padding pd[1] must come first and the height
# padding pd[0] second. Passing them the other way round only works for square
# canvases; for a non-square img_size it yields a transposed canvas shape.
canvas_pad = (pd[1], pd[1], pd[0], pd[0])

# Zero-pad every digit symmetrically up to img_size and convert to NumPy.
mnist_train = {}
mnist_train['images'] = F.pad(trainset.data, canvas_pad, mode='constant', value=0).numpy()
mnist_train['labels'] = trainset.targets.numpy()

mnist_test = {}
mnist_test['images'] = F.pad(testset.data, canvas_pad, mode='constant', value=0).numpy()
mnist_test['labels'] = testset.targets.numpy()


# train_set_sizes = [10000, 20000, 30000, 40000, 50000, 60000]
# Number of augmented samples to generate; one output file is written per entry.
train_set_sizes = [80000, 100000, 120000, 240000, 600000]

# Loop-invariant source data, prepared once. The images are kept in their
# original dtype: np.roll only moves pixels around, and the result is stored in
# a uint8 buffer anyway, so a float64 copy would only cost memory.
images = mnist_train['images'].reshape(-1, *img_size)
labels = mnist_train['labels']
n_source = len(images)

# NOTE(review): the global NumPy RNG is not seeded, so every run produces a
# different dataset. Seed it in a config cell if reproducibility is required.
for n_samples in train_set_sizes:

    # Output buffers; every element is overwritten in the loop below.
    aug_images = np.empty((n_samples, *img_size), dtype=np.uint8)
    aug_labels = np.empty(n_samples, dtype=np.uint8)

    for i in range(n_samples):
        # randint's upper bound is exclusive, so this covers every source
        # image, including the last one.
        img_ind = np.random.randint(0, n_source)
        # Independent circular shift along rows and columns.
        rand_shift = (np.random.randint(0, img_size[0]), np.random.randint(0, img_size[1]))

        # np.roll wraps pixels around the border, i.e. a cyclic translation.
        aug_images[i] = np.roll(images[img_ind], rand_shift, (0, 1))
        aug_labels[i] = labels[img_ind]

    dataset = {
        'images': aug_images,
        'labels': aug_labels
    }

    output_file = f"flat_mnist_train_aug_{img_size[0]}x{img_size[1]}_{n_samples}.gz"

    # Store the dataset as a gzip-compressed pickle.
    with gzip.open(output_file, 'wb') as f:
        pickle.dump(dataset, f)

    print(output_file, 'written')


# images = mnist_test['images'].reshape(-1, *img_size).astype(np.float64)
# labels = mnist_test['labels']

# n_samples = len(images)
# aug_images = np.ndarray((n_samples, *img_size), dtype=np.uint8)
# aug_labels = np.ndarray(n_samples, dtype=np.uint8)

# for i in range(n_samples):
#     rand_shift = np.random.randint(0, img_size[0], 1).item(), np.random.randint(0, img_size[1], 1).item()

#     aug_images[i] = np.roll(images[i], rand_shift, (0,1))
#     aug_labels[i] = labels[i]

# dataset = {
#     'images': aug_images,
#     'labels': aug_labels
# }

# output_file = "flat_mnist_test_aug_" + str(img_size[0]) + "x" + str(img_size[1]) + ".gz"

# with gzip.open(output_file, 'wb') as f:
#     pickle.dump(dataset, f)
    
# print(output_file, 'written')

flat_mnist_train_aug_60x60_80000.gz written
flat_mnist_train_aug_60x60_100000.gz written
flat_mnist_train_aug_60x60_120000.gz written
flat_mnist_train_aug_60x60_240000.gz written
flat_mnist_train_aug_60x60_600000.gz written
