## Imports

In [29]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10, SVHN
from tqdm import tqdm
import matplotlib.pyplot as plt

## Methods

In [47]:
def rotate(X, angle_min, angle_max, grey=False):
    """Rotate every image of a batch by its own random angle.

    The batch is first padded with `add_padding` (symmetric reflection) so that the
    rotated content covers the central H x W window. Each padded image is rotated by
    `rotate_img`, and the result is cropped back with `remove_padding`.

    Args:
        X: batch of images with values in [0, 1]; (N, H, W) when `grey` is True,
            otherwise (N, H, W, C). Assumed square (H == W), as `add_padding` only looks at H.
        angle_min, angle_max: bounds forwarded to `rotate_img`, which draws the angle
            with `np.random.randint(angle_min, angle_max)`.
        grey: selects the greyscale (3-D) or colour (4-D) padding layout.

    Returns:
        The rotated batch, cropped by `remove_padding`.
    """
    height, width = X.shape[1], X.shape[2]
    padded, margin = add_padding(X, height, grey)
    rotated = np.array([rotate_img(image, angle_min, angle_max) for image in tqdm(padded)])
    return remove_padding(rotated, height, width, margin, grey)

def add_padding(X, H, grey):
    """Reflect-pad the spatial axes of a batch so a rotated image still fills an H x H window.

    The margin is half the difference between the image diagonal (H * sqrt(2)) and its
    side H, rounded up. Only H is used, so the images are assumed to be square.

    Args:
        X: (N, H, W) batch when `grey` is True, otherwise (N, H, W, C).
        H: image side length.
        grey: whether `X` has a trailing channel axis (False) or not (True).

    Returns:
        (padded_batch, pad) where `pad` is the int margin added on each side of axes 1 and 2.
    """
    diagonal = H * 2 ** 0.5
    pad = int(np.ceil((diagonal - H) / 2))

    widths = [(0, 0), (pad, pad), (pad, pad)]
    if not grey:
        # channel axis is left untouched
        widths.append((0, 0))

    padded = np.pad(X, tuple(widths), mode='symmetric')
    return padded, pad

def remove_padding(X, H, W, pad, grey):
    """Crop the central H x W window out of a batch padded by `add_padding`.

    `add_padding` pads axes 1 and 2 in both layouts: (N, H, W) for greyscale and
    (N, H, W, C) for colour. The crop is therefore the same slice on axes 1 and 2 in
    either case. The previous colour branch sliced axes 2 and 3 and emptied the channel axis.

    Args:
        X: padded batch, (N, H + 2*pad, W + 2*pad[, C]).
        H, W: size of the window to keep.
        pad: margin that was added on each side.
        grey: kept for interface compatibility; it does not affect the crop.

    Returns:
        View of `X` with shape (N, H, W) or (N, H, W, C).
    """
    return X[:, pad:pad + H, pad:pad + W]

def rotate_img(X, r_min, r_max):
    """Rotate one image by a random integer angle using PIL.

    Args:
        X: image with values in [0, 1]; (H, W) for greyscale or an (H, W, C) layout
            that `PIL.Image.fromarray` accepts for uint8 data (e.g. RGB).
        r_min, r_max: the angle in degrees is drawn with `np.random.randint(r_min, r_max)`,
            i.e. from the half-open range [r_min, r_max).

    Returns:
        float32 array with values in [0, 1].
    """
    angle = np.random.randint(r_min, r_max)
    # Round (not truncate) when quantising to 8 bits: float32(k / 255) * 255 may land a
    # hair below k, and truncation would then shift that pixel down one level.
    # Clipping keeps out-of-range values from wrapping around in uint8.
    as_uint8 = np.clip(np.rint(X * 255), 0, 255).astype(np.uint8)
    rotated = Image.fromarray(as_uint8).rotate(angle)
    return np.array(rotated).astype(np.float32) / 255


def convert_to_lossless_polar(img):
    """Unroll a square image into concentric square rings, one ring per output row.

    Ring ``r`` (0 = outer border) is walked clockwise from its top-left corner ``(r, r)``:
    - top row left-to-right;
    - right column top-to-bottom;
    - bottom row right-to-left;
    - left column bottom-to-top.

    Each side stops one pixel short of the next corner. The pixel that starts each side is
    written ``2*r + 3`` times, except that ring ``r > 0`` writes its very first corner only
    ``r + 3`` times. Whatever is left of the fixed-width row is then filled by copying the
    row's own leading entries. Finally, row ``R - 1`` is overwritten with pixel ``(R-1, R-1)``.

    Args:
        img: array of shape (M, M, C).

    Returns:
        float32 array of shape (M // 2 + 1, 4 * (M + 1), C).
    """
    assert len(img.shape) == 3
    assert img.shape[0] == img.shape[1]
    M = img.shape[0]
    C = img.shape[-1]
    R = M // 2 + 1
    W = (M + 1) * 4
    new_img = np.zeros((R, W, C), dtype=np.float32)
    for channel in range(C):
        for r in range(R):
            # Start every ring at its own top-left corner so the cursor cannot leak in
            # from the previous ring or, crucially, from the previous channel (which made
            # the top row of later channels be read from the wrong image row).
            row, col = r, r
            flatidx = 0
            offset = 2 * r + 3
            # top row, left -> right (row == r)
            for col in range(r, M - r - 1):
                if col == r:
                    # first corner of the ring: repeated
                    count = offset - r * (r > 0)
                    new_img[r, flatidx:flatidx + count, channel] = img[row, col, channel]
                    flatidx += count
                else:
                    new_img[r, flatidx, channel] = img[row, col, channel]
                    flatidx += 1
            col += 1
            # right column, top -> bottom (col == M - r - 1)
            for row in range(r, M - r - 1):
                if row == r:
                    # top-right corner: repeated
                    new_img[r, flatidx:flatidx + offset, channel] = img[row, col, channel]
                    flatidx += offset
                else:
                    new_img[r, flatidx, channel] = img[row, col, channel]
                    flatidx += 1
            row += 1
            # bottom row, right -> left (row == M - r - 1)
            for col in range(M - r - 1, r, -1):
                if col == (M - r - 1):
                    # bottom-right corner: repeated
                    new_img[r, flatidx:flatidx + offset, channel] = img[row, col, channel]
                    flatidx += offset
                else:
                    new_img[r, flatidx, channel] = img[row, col, channel]
                    flatidx += 1
            col -= 1
            # left column, bottom -> top (col == r)
            for row in range(M - r - 1, r, -1):
                if row == (M - r - 1):
                    # bottom-left corner: repeated
                    new_img[r, flatidx:flatidx + offset, channel] = img[row, col, channel]
                    flatidx += offset
                else:
                    new_img[r, flatidx, channel] = img[row, col, channel]
                    flatidx += 1
            # pad the rest of the row by wrapping around to its beginning
            for e in range(W - flatidx):
                new_img[r, flatidx + e, channel] = new_img[r, e, channel]
        new_img[R - 1, ..., channel] = img[R - 1, R - 1, channel]
    return new_img

def SquareRotationalLayer(img):
    """Unroll each concentric square ring of a batch of images into one row of a strip.

    Output row ``d`` holds ring ``i = (H - 1) // 2 - d``: row 0 is the innermost ring
    and the last row is the image border. Each ring is walked clockwise from its
    top-left corner:
    - top row left-to-right;
    - right column top-to-bottom;
    - bottom row right-to-left;
    - left column bottom-to-top.

    Every corner occupies ``2*i + 1`` consecutive samples, which gives ``4*H - 4``
    samples per ring. The last 4 columns repeat the first 4 so the strip wraps around.
    A ring that is a single pixel (the centre of an odd-sized image) fills its whole row.

    Args:
        img: tensor of shape (B, C, H, W) with H == W. A NumPy array is accepted too
            and converted with ``torch.as_tensor``.

    Returns:
        torch.Tensor of shape (B, C, (H + 1) // 2, 4 * H) with the input's dtype and
        device (for even H this is (B, C, H // 2, 4 * H)).

    Raises:
        ValueError: if the images are not square.
    """
    img = torch.as_tensor(img)
    B, C, H, W = img.shape
    if H != W:
        raise ValueError(f"SquareRotationalLayer expects square images, got {H}x{W}")

    lmid = (H - 1) // 2  # index of the innermost ring (integer, usable in range())
    res = torch.zeros((B, C, lmid + 1, 4 * H), dtype=img.dtype, device=img.device)

    def hold(pixel, n):
        # Repeat a (B, C) pixel n times along a new last axis, channel by channel.
        return pixel.unsqueeze(-1).repeat(1, 1, n)

    for i in range(lmid, -1, -1):
        dif = lmid - i
        el = H - 2 * i - 1   # pixels per side of ring i, excluding the closing corner
        if el == 0:
            # single-pixel ring (centre of an odd-sized image)
            res[:, :, dif, :] = hold(img[:, :, i, i], 4 * H)
            continue
        far = i + el         # index of the ring's right column / bottom row (== H - i - 1)
        segments = [
            hold(img[:, :, i, i], i),                                   # top-left (lead-in)
            img[:, :, i, i:far],                                        # top row
            hold(img[:, :, i, far], 2 * i),                             # top-right
            img[:, :, i:far, far],                                      # right column
            hold(img[:, :, far, far], 2 * i),                           # bottom-right
            torch.flip(img[:, :, far, i + 1:far + 1], dims=(-1,)),      # bottom row
            hold(img[:, :, far, i], 2 * i),                             # bottom-left
            torch.flip(img[:, :, i + 1:far + 1, i], dims=(-1,)),        # left column
            hold(img[:, :, i, i], i),                                   # top-left (closing)
        ]
        res[:, :, dif, :8 * i + 4 * el] = torch.cat(segments, dim=-1)

    if H > 1:
        # wrap-around: the 4 spare columns repeat the start of each ring
        res[:, :, :, 4 * H - 4:] = res[:, :, :, :4]
    return res


## Create Directories

In [18]:
# exist_ok keeps "Restart & Run All" from failing once the directory exists.
os.makedirs('data', exist_ok=True)

### CIFAR-10

In [48]:
cifar_train = CIFAR10('data/cifar-10', train=True, download=True)
cifar_test = CIFAR10('data/cifar-10', train=False, download=True)

# Pixels are scaled to [0, 1]; labels are class indices and must stay integers
# (they were previously divided by 255 by mistake).
cifar_train_x = np.array(cifar_train.data).astype(np.float32) / 255.
cifar_train_y = np.array(cifar_train.targets, dtype=np.int64).ravel()
cifar_test_x  = np.array(cifar_test.data).astype(np.float32) / 255.
cifar_test_y  = np.array(cifar_test.targets, dtype=np.int64).ravel()

np.save('data/cifar-10/train_x', cifar_train_x)
np.save('data/cifar-10/train_y', cifar_train_y)
np.save('data/cifar-10/test_x', cifar_test_x)
np.save('data/cifar-10/test_y', cifar_test_y)

# Seed the global NumPy RNG used by rotate_img so the rotated test sets are reproducible.
np.random.seed(0)
test_xrot_45 = rotate(cifar_test_x, -45, 45)
test_xrot_90 = rotate(cifar_test_x, -90, 90)
test_xrot_360 = rotate(cifar_test_x, 0, 359)

np.save('data/cifar-10/test_x_45', test_xrot_45)
np.save('data/cifar-10/test_x_90', test_xrot_90)
np.save('data/cifar-10/test_x_360', test_xrot_360)


def to_sqrl(x_nhwc):
    """Run SquareRotationalLayer (a torch layer) on an (N, H, W, C) NumPy batch; return NumPy."""
    x_nchw = torch.from_numpy(x_nhwc).permute(0, 3, 1, 2)
    return SquareRotationalLayer(x_nchw).numpy()


sqrl_train_x = to_sqrl(cifar_train_x)
sqrl_test_x = to_sqrl(cifar_test_x)
sqrl_test_x_45 = to_sqrl(test_xrot_45)
sqrl_test_x_90 = to_sqrl(test_xrot_90)
sqrl_test_x_360 = to_sqrl(test_xrot_360)

np.save('data/cifar-10/sqrl_train_x', sqrl_train_x)
np.save('data/cifar-10/sqrl_test_x', sqrl_test_x)
np.save('data/cifar-10/sqrl_test_x_45', sqrl_test_x_45)
np.save('data/cifar-10/sqrl_test_x_90', sqrl_test_x_90)
np.save('data/cifar-10/sqrl_test_x_360', sqrl_test_x_360)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 10000/10000 [00:02<00:00, 4295.07it/s]
100%|██████████| 10000/10000 [00:02<00:00, 4331.09it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5146.56it/s]


### Street View House Numbers

In [42]:
svhn_train = SVHN('data/svhn', split='train', download=True)
svhn_test = SVHN('data/svhn', split='test', download=True)

# svhn.data is transposed to (N, H, W, C) like the other colour datasets.
# Pixels are scaled to [0, 1]; labels are class indices and must stay integers
# (they were previously divided by 255 by mistake).
svhn_train_x = np.transpose(np.array(svhn_train.data).astype(np.float32) / 255., (0, 2, 3, 1))
svhn_train_y = np.array(svhn_train.labels, dtype=np.int64).ravel()
svhn_test_x  = np.transpose(np.array(svhn_test.data).astype(np.float32) / 255., (0, 2, 3, 1))
svhn_test_y  = np.array(svhn_test.labels, dtype=np.int64).ravel()

np.save('data/svhn/train_x', svhn_train_x)
np.save('data/svhn/train_y', svhn_train_y)
np.save('data/svhn/test_x', svhn_test_x)
np.save('data/svhn/test_y', svhn_test_y)

# Seed the global NumPy RNG used by rotate_img so the rotated test sets are reproducible.
np.random.seed(0)
test_xrot_45 = rotate(svhn_test_x, -45, 45)
test_xrot_90 = rotate(svhn_test_x, -90, 90)
test_xrot_360 = rotate(svhn_test_x, 0, 359)

np.save('data/svhn/test_x_45', test_xrot_45)
np.save('data/svhn/test_x_90', test_xrot_90)
np.save('data/svhn/test_x_360', test_xrot_360)


Using downloaded and verified file: circle/svhn/train_32x32.mat
Using downloaded and verified file: circle/svhn/test_32x32.mat


100%|██████████| 26032/26032 [00:06<00:00, 3814.21it/s]
100%|██████████| 26032/26032 [00:06<00:00, 4078.61it/s]
100%|██████████| 26032/26032 [00:05<00:00, 4401.52it/s]


### MNIST

In [50]:
mnist_train = MNIST('data/mnist', train=True, download=True)
mnist_test = MNIST('data/mnist', train=False, download=True)

# Pixels are scaled to [0, 1]; labels are class indices and must stay integers
# (they were previously divided by 255 by mistake).
mnist_train_x = np.array(mnist_train.data).astype(np.float32) / 255.
mnist_train_y = np.array(mnist_train.targets).astype(np.int64).ravel()
mnist_test_x  = np.array(mnist_test.data).astype(np.float32) / 255.
mnist_test_y  = np.array(mnist_test.targets).astype(np.int64).ravel()

print(mnist_train_x.shape, mnist_train_y.shape)

np.save('data/mnist/train_x', mnist_train_x)
np.save('data/mnist/train_y', mnist_train_y)
np.save('data/mnist/test_x', mnist_test_x)
np.save('data/mnist/test_y', mnist_test_y)

# Seed the global NumPy RNG used by rotate_img so the rotated test sets are reproducible.
np.random.seed(0)
test_xrot_45 = rotate(mnist_test_x, -45, 45, grey=True)
test_xrot_90 = rotate(mnist_test_x, -90, 90, grey=True)
test_xrot_360 = rotate(mnist_test_x, 0, 359, grey=True)

np.save('data/mnist/test_x_45', test_xrot_45)
np.save('data/mnist/test_x_90', test_xrot_90)
np.save('data/mnist/test_x_360', test_xrot_360)


(60000, 28, 28) (60000,)


100%|██████████| 10000/10000 [00:01<00:00, 5337.52it/s]
100%|██████████| 10000/10000 [00:02<00:00, 4805.59it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6026.08it/s]


### Fashion MNIST

In [51]:
fmnist_train = FashionMNIST('data/fmnist', train=True, download=True)
fmnist_test = FashionMNIST('data/fmnist', train=False, download=True)

# Pixels are scaled to [0, 1]; labels are class indices and must stay integers
# (they were previously divided by 255 by mistake).
fmnist_train_x = np.array(fmnist_train.data).astype(np.float32) / 255.
fmnist_train_y = np.array(fmnist_train.targets).astype(np.int64).ravel()
fmnist_test_x  = np.array(fmnist_test.data).astype(np.float32) / 255.
fmnist_test_y  = np.array(fmnist_test.targets).astype(np.int64).ravel()

np.save('data/fmnist/train_x', fmnist_train_x)
np.save('data/fmnist/train_y', fmnist_train_y)
np.save('data/fmnist/test_x', fmnist_test_x)
np.save('data/fmnist/test_y', fmnist_test_y)

# Seed the global NumPy RNG used by rotate_img so the rotated test sets are reproducible.
np.random.seed(0)
test_xrot_45 = rotate(fmnist_test_x, -45, 45, grey=True)
test_xrot_90 = rotate(fmnist_test_x, -90, 90, grey=True)
test_xrot_360 = rotate(fmnist_test_x, 0, 359, grey=True)

np.save('data/fmnist/test_x_45', test_xrot_45)
np.save('data/fmnist/test_x_90', test_xrot_90)
np.save('data/fmnist/test_x_360', test_xrot_360)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to circle/fmnist/FashionMNIST/raw/train-images-idx3-ubyte.gz


26422272it [00:19, 1356405.01it/s]                              


Extracting circle/fmnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to circle/fmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to circle/fmnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz


29696it [00:00, 944199.31it/s]           


Extracting circle/fmnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to circle/fmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to circle/fmnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


4422656it [00:05, 759784.72it/s]                              


Extracting circle/fmnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to circle/fmnist/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to circle/fmnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


6144it [00:00, 5651272.76it/s]          

Extracting circle/fmnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to circle/fmnist/FashionMNIST/raw




100%|██████████| 10000/10000 [00:01<00:00, 5335.32it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5291.87it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5722.41it/s]
