## Imports

In [3]:
import os
import numpy as np
from PIL import Image
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10, SVHN
from tqdm import tqdm
import matplotlib.pyplot as plt

## Methods

In [55]:
def rotate(X, angle_min, angle_max):
    """Rotate every image of a batch by its own random angle.

    The batch is first enlarged with ``add_padding`` so that the rotation
    does not pull empty corners into the original frame, each sample is
    rotated independently by ``rotate_img``, and the result is cropped back
    to the original height and width with ``remove_padding``.

    Args:
        X: batch of images; axes 1 and 2 are taken as height and width.
        angle_min: lower bound forwarded to ``rotate_img``.
        angle_max: upper bound forwarded to ``rotate_img``.

    Returns:
        Array of rotated images cropped to the input's spatial size.
    """
    height, width = X.shape[1:3]
    padded, margin = add_padding(X, height)
    # tqdm only wraps the iteration to display a progress bar.
    rotated = np.array(
        [rotate_img(sample, angle_min, angle_max) for sample in tqdm(padded)]
    )
    return remove_padding(rotated, height, width, margin)

def add_padding(X, H):
    """Mirror-pad the two spatial axes of a batch ahead of a rotation.

    The margin is chosen so that the padded side is at least the diagonal
    of an ``H`` x ``H`` square (``H * sqrt(2)``). The images are assumed to
    be square.

    Args:
        X: batch shaped (N, H, W); any other rank is treated as
            (N, H, W, C) with the last axis left unpadded.
        H: side length of the (square) images.

    Returns:
        Tuple ``(padded_batch, pad)`` where ``pad`` is the number of pixels
        added on each side of both spatial axes.
    """
    diagonal = H / 2 * 2**0.5 * 2
    pad = int(np.ceil((diagonal - H) / 2))

    border = (pad, pad)
    untouched = (0, 0)
    if len(X.shape) == 3:
        paddings = (untouched, border, border)
    else:
        paddings = (untouched, border, border, untouched)

    return np.pad(X, paddings, mode='symmetric'), pad

def remove_padding(X, H, W, pad):
    """Crop a padded batch back to its original ``H`` x ``W`` window.

    Args:
        X: batch whose axes 1 and 2 carry ``pad`` extra pixels at the start.
        H: height of the window to keep.
        W: width of the window to keep.
        pad: offset of the window along both spatial axes.

    Returns:
        ``X`` restricted to rows ``pad .. pad+H`` and columns
        ``pad .. pad+W``; any trailing axes are kept whole.
    """
    rows = slice(pad, H + pad)
    cols = slice(pad, W + pad)
    return X[:, rows, cols]

def rotate_img(X, r_min, r_max):
    """Rotate one image by a random integer angle.

    The angle is drawn with ``np.random.randint(r_min, r_max)``, so
    ``r_max`` itself is never selected. The image is scaled by 255 and cast
    to uint8 for PIL, rotated, and scaled back to float32.

    Args:
        X: single image array; presumably float values in [0, 1], given
            the multiplication by 255 -- TODO confirm with callers.
        r_min: lowest angle that can be drawn.
        r_max: exclusive upper bound for the drawn angle.

    Returns:
        The rotated image as a float32 array divided by 255.
    """
    angle = np.random.randint(r_min, r_max)
    as_bytes = (X*255).astype(np.uint8)
    turned = Image.fromarray(as_bytes).rotate(angle)
    return np.array(turned).astype(np.float32)/255


def convert_to_lossless_polar(img):
    """Unroll the concentric square rings of an image into rows.

    Output row ``r`` holds the ring at distance ``r`` from the border,
    walked clockwise starting at its top-left pixel: top edge left to
    right, right edge top to bottom, bottom edge right to left, left edge
    bottom to top. The first pixel of the top edge is written ``r + 3``
    times and the first pixel of each other edge ``2 * r + 3`` times, which
    makes every ring occupy ``W - r`` slots; the remaining ``r`` slots are
    filled by wrapping around to the beginning of the row. The last output
    row is filled entirely with the pixel at ``img[R-1, R-1]``.

    All channels are processed together. The previous implementation kept
    ``row``/``col`` alive across the per-channel loop, so every channel
    after the first started ring 0 from a stale row index and produced a
    different (wrong) top edge.

    Args:
        img: square image shaped (M, M, C).

    Returns:
        float32 array shaped (M // 2 + 1, (M + 1) * 4, C).

    Raises:
        AssertionError: if ``img`` is not 3-dimensional or not square.
    """
    assert len(img.shape) == 3
    assert img.shape[0] == img.shape[1]
    M = img.shape[0]
    C = img.shape[-1]
    R = M//2 + 1
    W = (M+1)*4
    new_img = np.zeros((R, W, C), dtype=np.float32)
    for r in range(R):
        last = M - r - 1
        offset = 2*r + 3

        # Clockwise walk over the ring as (row, col, repeat) triples. All
        # four ranges are empty for the innermost row (last <= r).
        walk = []
        for col in range(r, last):            # top edge
            walk.append((r, col, offset - r if col == r else 1))
        for row in range(r, last):            # right edge
            walk.append((row, last, offset if row == r else 1))
        for col in range(last, r, -1):        # bottom edge
            walk.append((last, col, offset if col == last else 1))
        for row in range(last, r, -1):        # left edge
            walk.append((row, r, offset if row == last else 1))

        flatidx = 0
        for row, col, repeat in walk:
            # The (C,) pixel broadcasts over the repeated slots.
            new_img[r, flatidx:flatidx+repeat] = img[row, col]
            flatidx += repeat

        # Fill the remaining slots by wrapping around to the row start.
        for k in range(W - flatidx):
            new_img[r, flatidx+k] = new_img[r, k]

    # The innermost row is constant: the pixel at (R-1, R-1).
    new_img[R-1] = img[R-1, R-1]
    return new_img

def SquareRotationalLayer(img):
    """Unroll concentric square rings of a batch of images into rows.

    Output row ``dif`` holds the ring whose top-left pixel is at ``(i, i)``
    with ``i = lmid - dif``, so row 0 is the innermost ring and the last row
    is the image border. Each ring is walked clockwise (top edge, right
    edge, bottom edge reversed, left edge reversed); the corner pixel in
    front of each edge is repeated so that every ring fills ``4*H - 4``
    columns. The last four columns of every row are then overwritten with
    a copy of the first four columns.

    NOTE(review): the right-hand column is addressed through ``W`` while
    everything else is derived from ``H``; this appears to assume square
    inputs (H == W) -- confirm with callers.

    Args:
        img: array shaped (B, C, H, W).

    Returns:
        Array shaped (B, C, ceil(H / 2), 4 * H) with ``img``'s dtype.
    """
    (B, C, H, W) = img.shape
    res = np.zeros((B, C, int(np.ceil(H / 2)), H * 4), dtype=img.dtype)
    lmid = int(np.floor((H - 1) / 2))
    odd_side = H % 2 == 1

    # Each iteration writes a different output row, so order is irrelevant.
    for dif in range(lmid + 1):
        i = lmid - dif                        # top-left coordinate of the ring
        el = 2 * dif if odd_side else 2 * dif + 1
        far = i + el                          # bottom / last index of the ring
        right = W - i - 1                     # column of the right edge
        ring = res[:, :, dif]                 # (B, C, 4*H) view into the output
        pos = 0                               # write cursor inside the row

        # top row: corner repeated i times, then the edge itself
        ring[:, :, pos:pos + i] = img[:, :, i, i][:, :, None]
        pos += i
        ring[:, :, pos:pos + el] = img[:, :, i, i:far]
        pos += el

        # right column: corner repeated 2*i times, then the edge
        ring[:, :, pos:pos + 2 * i] = img[:, :, i, right][:, :, None]
        pos += 2 * i
        ring[:, :, pos:pos + el] = img[:, :, i:far, right]
        pos += el

        # bottom row: corner repeated 2*i times, then the edge right-to-left
        ring[:, :, pos:pos + 2 * i] = img[:, :, far, far][:, :, None]
        pos += 2 * i
        if el != 0:
            ring[:, :, pos:pos + el] = img[:, :, far, i + 1:far + 1][:, :, ::-1]
        pos += el

        # left column: corner repeated 2*i times, then the edge bottom-to-top
        ring[:, :, pos:pos + 2 * i] = img[:, :, far, i][:, :, None]
        pos += 2 * i
        if el != 0:
            ring[:, :, pos:pos + el] = img[:, :, i + 1:far + 1, i][:, :, ::-1]
        pos += el

        # close the loop with the initial corner, repeated i times
        ring[:, :, pos:pos + i] = img[:, :, i, i][:, :, None]

    # Wrap-around: copy the first four columns into the tail of every row
    # (the slice is clipped to the array width).
    res[..., 4 * H - 4:4 * H + 4] = res[..., 0:4]
    return res


In [None]:
# Visual smoke test: print an integer batch and its ring-unrolled version.
# Wide lines keep each array row on a single printed line.
np.set_printoptions(linewidth=2000)
z = np.arange(1, 10 * 3 * 32 * 32 + 1).reshape(10, 3, 32, 32)
# Small 5x5 example; defined here but not passed to the layer below.
x = np.arange(1, 5 * 5 + 1).reshape(1, 1, 5, 5)
print(z)
res = SquareRotationalLayer(z)
print(res)

## Create Directories

In [4]:
# Create the dataset root; exist_ok makes re-running this cell harmless
# when the directory is already there (a bare makedirs raises FileExistsError).
os.makedirs('data', exist_ok=True)

### CIFAR-10

In [None]:
# Download CIFAR-10 and export plain, rotated and ring-unrolled arrays.
cifar_train = CIFAR10('data/cifar-10', train=True, download=True)
cifar_test = CIFAR10('data/cifar-10', train=False, download=True)
cifar_train_x = cifar_train.data
cifar_train_y = cifar_train.targets
cifar_test_x = cifar_test.data
cifar_test_y = cifar_test.targets

# Images are converted to float32 and divided by 255.
cifar_train_x = np.array(cifar_train_x).astype(np.float32)/255.
cifar_test_x  = np.array(cifar_test_x).astype(np.float32)/255.
# NOTE(review): the labels are divided by 255 as well. That looks like a
# copy-paste of the image scaling; confirm that the consumers of
# train_y / test_y expect scaled labels before changing it.
cifar_train_y = (np.array(cifar_train_y).astype(np.float32)/255.).ravel()
cifar_test_y  = (np.array(cifar_test_y).astype(np.float32)/255.).ravel()

for _path, _array in (
    ('data/cifar-10/train_x', cifar_train_x),
    ('data/cifar-10/train_y', cifar_train_y),
    ('data/cifar-10/test_x', cifar_test_x),
    ('data/cifar-10/test_y', cifar_test_y),
):
    np.save(_path, _array)

# Randomly rotated copies of the test images, one set per angle range.
test_xrot_45 = rotate(cifar_test_x, -45, 45)
test_xrot_90 = rotate(cifar_test_x, -90, 90)
test_xrot_360 = rotate(cifar_test_x, 0, 359)

for _path, _array in (
    ('data/cifar-10/test_x_45', test_xrot_45),
    ('data/cifar-10/test_x_90', test_xrot_90),
    ('data/cifar-10/test_x_360', test_xrot_360),
):
    np.save(_path, _array)

# SquareRotationalLayer expects (B, C, H, W); the arrays above are
# channel-last, so the channel axis is moved to position 1 first.
sqrl_train_x = SquareRotationalLayer(cifar_train_x.transpose(0, 3, 1, 2))
sqrl_test_x = SquareRotationalLayer(cifar_test_x.transpose(0, 3, 1, 2))
sqrl_test_x_45 = SquareRotationalLayer(test_xrot_45.transpose(0, 3, 1, 2))
sqrl_test_x_90 = SquareRotationalLayer(test_xrot_90.transpose(0, 3, 1, 2))
sqrl_test_x_360 = SquareRotationalLayer(test_xrot_360.transpose(0, 3, 1, 2))

for _path, _array in (
    ('data/cifar-10/sqrl_train_x', sqrl_train_x),
    ('data/cifar-10/sqrl_test_x', sqrl_test_x),
    ('data/cifar-10/sqrl_test_x_45', sqrl_test_x_45),
    ('data/cifar-10/sqrl_test_x_90', sqrl_test_x_90),
    ('data/cifar-10/sqrl_test_x_360', sqrl_test_x_360),
):
    np.save(_path, _array)

### Street View House Numbers

In [None]:
# Download SVHN and export plain and rotated arrays.
svhn_train = SVHN('data/svhn', split='train', download=True)
svhn_test = SVHN('data/svhn', split='test', download=True)
svhn_train_x = svhn_train.data
svhn_train_y = svhn_train.labels
svhn_test_x = svhn_test.data
svhn_test_y = svhn_test.labels

# Images: float32 divided by 255, then axis 1 is moved to the end
# (presumably channel-first -> channel-last; confirm against the dataset).
svhn_train_x = (np.array(svhn_train_x).astype(np.float32)/255.).transpose(0, 2, 3, 1)
svhn_test_x  = (np.array(svhn_test_x).astype(np.float32)/255.).transpose(0, 2, 3, 1)
# NOTE(review): the labels are divided by 255 as well. That looks like a
# copy-paste of the image scaling; confirm that the consumers of
# train_y / test_y expect scaled labels before changing it.
svhn_train_y = (np.array(svhn_train_y).astype(np.float32)/255.).ravel()
svhn_test_y  = (np.array(svhn_test_y).astype(np.float32)/255.).ravel()

for _path, _array in (
    ('data/svhn/train_x', svhn_train_x),
    ('data/svhn/train_y', svhn_train_y),
    ('data/svhn/test_x', svhn_test_x),
    ('data/svhn/test_y', svhn_test_y),
):
    np.save(_path, _array)

# Randomly rotated copies of the test images, one set per angle range.
test_xrot_45 = rotate(svhn_test_x, -45, 45)
test_xrot_90 = rotate(svhn_test_x, -90, 90)
test_xrot_360 = rotate(svhn_test_x, 0, 359)

for _path, _array in (
    ('data/svhn/test_x_45', test_xrot_45),
    ('data/svhn/test_x_90', test_xrot_90),
    ('data/svhn/test_x_360', test_xrot_360),
):
    np.save(_path, _array)


### MNIST

In [None]:
# Download MNIST and export plain and rotated arrays.
mnist_train = MNIST('data/mnist', train=True, download=True)
mnist_test = MNIST('data/mnist', train=False, download=True)
mnist_train_x = mnist_train.data
mnist_train_y = mnist_train.targets
mnist_test_x = mnist_test.data
mnist_test_y = mnist_test.targets

# Images are converted to float32 and divided by 255.
mnist_train_x = np.array(mnist_train_x).astype(np.float32)/255.
mnist_test_x  = np.array(mnist_test_x).astype(np.float32)/255.
# NOTE(review): the labels are divided by 255 as well. That looks like a
# copy-paste of the image scaling; confirm that the consumers of
# train_y / test_y expect scaled labels before changing it.
mnist_train_y = (np.array(mnist_train_y).astype(np.float32)/255.).ravel()
mnist_test_y  = (np.array(mnist_test_y).astype(np.float32)/255.).ravel()

print(mnist_train_x.shape, mnist_train_y.shape)

for _path, _array in (
    ('data/mnist/train_x', mnist_train_x),
    ('data/mnist/train_y', mnist_train_y),
    ('data/mnist/test_x', mnist_test_x),
    ('data/mnist/test_y', mnist_test_y),
):
    np.save(_path, _array)

# Randomly rotated copies of the test images, one set per angle range.
test_xrot_45 = rotate(mnist_test_x, -45, 45)
test_xrot_90 = rotate(mnist_test_x, -90, 90)
test_xrot_360 = rotate(mnist_test_x, 0, 359)

for _path, _array in (
    ('data/mnist/test_x_45', test_xrot_45),
    ('data/mnist/test_x_90', test_xrot_90),
    ('data/mnist/test_x_360', test_xrot_360),
):
    np.save(_path, _array)


### Fashion MNIST

In [None]:
# Download Fashion-MNIST and export plain and rotated arrays.
fmnist_train = FashionMNIST('data/fmnist', train=True, download=True)
fmnist_test = FashionMNIST('data/fmnist', train=False, download=True)
fmnist_train_x = fmnist_train.data
fmnist_train_y = fmnist_train.targets
fmnist_test_x = fmnist_test.data
fmnist_test_y = fmnist_test.targets

# Images are converted to float32 and divided by 255.
fmnist_train_x = np.array(fmnist_train_x).astype(np.float32)/255.
fmnist_test_x  = np.array(fmnist_test_x).astype(np.float32)/255.
# NOTE(review): the labels are divided by 255 as well. That looks like a
# copy-paste of the image scaling; confirm that the consumers of
# train_y / test_y expect scaled labels before changing it.
fmnist_train_y = (np.array(fmnist_train_y).astype(np.float32)/255.).ravel()
fmnist_test_y  = (np.array(fmnist_test_y).astype(np.float32)/255.).ravel()

for _path, _array in (
    ('data/fmnist/train_x', fmnist_train_x),
    ('data/fmnist/train_y', fmnist_train_y),
    ('data/fmnist/test_x', fmnist_test_x),
    ('data/fmnist/test_y', fmnist_test_y),
):
    np.save(_path, _array)

# Randomly rotated copies of the test images, one set per angle range.
test_xrot_45 = rotate(fmnist_test_x, -45, 45)
test_xrot_90 = rotate(fmnist_test_x, -90, 90)
test_xrot_360 = rotate(fmnist_test_x, 0, 359)

for _path, _array in (
    ('data/fmnist/test_x_45', test_xrot_45),
    ('data/fmnist/test_x_90', test_xrot_90),
    ('data/fmnist/test_x_360', test_xrot_360),
):
    np.save(_path, _array)
