In [1]:
import gzip
import pickle
import numpy as np
import argparse
import lie_learn.spaces.S2 as S2
from torchvision import datasets

In [2]:
# Small offset added to the z coordinate of the default projection origin
# (0, 0, 2 + NORTHPOLE_EPSILON) in project_2d_on_sphere. The shifted sphere
# (z + 1) reaches z = 2 at its top, so the offset keeps the denominator
# (z - sz) in project_sphere_on_xy_plane away from zero.
NORTHPOLE_EPSILON = 1e-3





def rotate_grid(rot, grid):
    """Apply a linear map (rotation matrix) to every point of a grid.

    Parameters
    ----------
    rot : array-like, shape (3, 3)
        Matrix applied to each (x, y, z) point.
    grid : sequence of three 2-D arrays
        The x, y and z coordinates of the grid points; all three arrays
        must share the same shape.

    Returns
    -------
    tuple of three 2-D arrays
        The transformed x, y and z coordinates, same shape as the inputs.
    """
    coord_x, coord_y, coord_z = grid
    stacked = np.array((coord_x, coord_y, coord_z))

    # contract the matrix columns with the leading (coordinate) axis
    rotated = np.einsum('ij,jab->iab', rot, stacked)

    new_x, new_y, new_z = rotated
    return new_x, new_y, new_z


def get_projection_grid(b, grid_type="Driscoll-Healy"):
    """Return the spherical sampling grid in Euclidean coordinates.

    The angles come from ``S2.meshgrid``; the first one is used as the
    polar angle and the second as the azimuth in the conversion below.
    The resulting points lie on the unit sphere centred at the origin.
    No shift is applied here: the move of the sphere's centre to
    (0, 0, 1) happens in ``project_sphere_on_xy_plane`` (``z + 1``).

    Parameters
    ----------
    b : int
        Bandwidth forwarded to ``S2.meshgrid``.
    grid_type : str
        Grid type forwarded to ``S2.meshgrid``.

    Returns
    -------
    tuple
        The x, y and z coordinate arrays of the grid points.
    """
    polar, azimuth = S2.meshgrid(b=b, grid_type=grid_type)
    sin_polar = np.sin(polar)

    # standard spherical -> Cartesian conversion on the unit sphere
    return (
        sin_polar * np.cos(azimuth),
        sin_polar * np.sin(azimuth),
        np.cos(polar),
    )





def project_sphere_on_xy_plane(grid, projection_origin):
    """Project grid points onto the plane z = 0 and normalise the result.

    Every grid point is first lifted by one unit in z. The line from
    ``projection_origin`` through the lifted point is then intersected
    with the plane z = 0, and the x/y coordinates of the intersection are
    rescaled so that ``lower`` maps to 0 and ``lower + 2 * |lower|`` maps
    to 1 (``lower`` is derived from the origin coordinate, see below).

    Parameters
    ----------
    grid : sequence of three arrays
        x, y and z coordinates of the grid points (z is not modified).
    projection_origin : sequence of three numbers
        The point (sx, sy, sz) the projection rays start from.

    Returns
    -------
    tuple of two arrays
        Normalised x and y plane coordinates, same shape as the grid.
    """
    origin_x, origin_y, origin_z = projection_origin
    grid_x, grid_y, grid_z = grid

    # lift the sphere by one unit; work on a copy so the input stays intact
    height = grid_z.copy() + 1

    # ray parameter at which the origin -> point line reaches z = 0
    ray_param = -height / (height - origin_z)

    def _normalised_hit(coord, origin_coord):
        # intersection coordinate on the plane
        hit = ray_param * (coord - origin_coord) + coord
        # lower bound of the window that gets mapped onto [0, 1]
        lower = 1/2 * (-1 - origin_coord) + -1
        return (hit - lower) / (2 * np.abs(lower))

    return _normalised_hit(grid_x, origin_x), _normalised_hit(grid_y, origin_y)


def sample_within_bounds(signal, x, y, bounds):
    """Gather ``signal`` values at integer positions, zero outside bounds.

    Parameters
    ----------
    signal : array
        Either a single 2-D image or a stack of images whose first axis
        is the batch axis.
    x, y : 2-D integer arrays
        Row and column indices to read, both of the same shape.
    bounds : tuple
        ``(xmin, xmax, ymin, ymax)``; an index pair is read only when
        ``xmin <= x < xmax`` and ``ymin <= y < ymax``.

    Returns
    -------
    array of float64
        Shape ``x.shape`` for a 2-D signal, ``(batch,) + x.shape`` for a
        stacked signal. Positions outside the bounds hold 0.
    """
    x_lo, x_hi, y_lo, y_hi = bounds
    inside = (x >= x_lo) & (x < x_hi) & (y >= y_lo) & (y < y_hi)
    n_rows, n_cols = x.shape[0], x.shape[1]

    # single image: plain two-index gather
    if len(signal.shape) <= 2:
        gathered = np.zeros((n_rows, n_cols))
        gathered[inside] = signal[x[inside], y[inside]]
        return gathered

    # stack of images: gather the same positions from every image
    gathered = np.zeros((signal.shape[0], n_rows, n_cols))
    gathered[:, inside] = signal[:, x[inside], y[inside]]
    return gathered


def sample_bilinear(signal, rx, ry):
    """Sample a batch of images at fractional positions by blending neighbours.

    Parameters
    ----------
    signal : array, shape (batch, dim_x, dim_y)
        Stack of images to sample from.
    rx, ry : 2-D float arrays
        Sample positions in normalised coordinates; they are multiplied by
        ``dim_x`` / ``dim_y`` to obtain pixel coordinates. The arrays passed
        in are NOT modified (the original implementation scaled them in
        place, silently changing the caller's data).

    Returns
    -------
    array of float64, shape (batch,) + rx.shape
        Weighted blend of the four neighbouring samples. Neighbours that
        fall outside the image contribute 0.

    Notes
    -----
    * The neighbours are taken at ``ix - 1`` and ``ix + 1`` (two pixels
      apart), so the two weights per axis add up to 2 and the result is a
      blend scaled by 4 rather than a textbook bilinear interpolation.
      This is kept as is so that generated data does not change;
      ``project_2d_on_sphere`` rescales each sample to its own min/max
      afterwards.
    * ``astype(int)`` truncates towards zero, so positions in (-1, 0) are
      discretised to 0 rather than -1.
    """

    signal_dim_x = signal.shape[1]
    signal_dim_y = signal.shape[2]

    # scale to pixel units into new arrays; never touch the caller's inputs
    rx = rx * signal_dim_x
    ry = ry * signal_dim_y

    # discretize sample position
    ix = rx.astype(int)
    iy = ry.astype(int)

    # obtain four sample coordinates
    ix0 = ix - 1
    iy0 = iy - 1
    ix1 = ix + 1
    iy1 = iy + 1

    bounds = (0, signal_dim_x, 0, signal_dim_y)

    # sample signal at each four positions
    signal_00 = sample_within_bounds(signal, ix0, iy0, bounds)
    signal_10 = sample_within_bounds(signal, ix1, iy0, bounds)
    signal_01 = sample_within_bounds(signal, ix0, iy1, bounds)
    signal_11 = sample_within_bounds(signal, ix1, iy1, bounds)

    # linear interpolation in x-direction
    fx1 = (ix1 - rx) * signal_00 + (rx - ix0) * signal_10
    fx2 = (ix1 - rx) * signal_01 + (rx - ix0) * signal_11

    # linear interpolation in y-direction
    return (iy1 - ry) * fx1 + (ry - iy0) * fx2


def project_2d_on_sphere(signal, grid, projection_origin=None):
    """Project a batch of 2-D images onto a spherical grid.

    Parameters
    ----------
    signal : array, shape (batch, dim_x, dim_y)
        Images to project.
    grid : sequence of three 2-D arrays
        x, y and z coordinates of the spherical grid points.
    projection_origin : sequence of three numbers, optional
        Origin of the projection rays. Defaults to
        ``(0, 0, 2 + NORTHPOLE_EPSILON)``.

    Returns
    -------
    array of uint8, shape (batch,) + grid[0].shape
        Each sample is rescaled independently so that its minimum maps to
        0 and its maximum to 255. A sample whose values are all equal is
        returned as all zeros instead of dividing by zero.
    """
    if projection_origin is None:
        projection_origin = (0, 0, 2 + NORTHPOLE_EPSILON)

    rx, ry = project_sphere_on_xy_plane(grid, projection_origin)
    sample = sample_bilinear(signal, rx, ry)

    # Intended to keep only the southern hemisphere.
    # NOTE(review): if `grid` holds unit-sphere coordinates (as returned by
    # get_projection_grid), z never exceeds 1 and this mask keeps every
    # point; a real hemisphere cut would be `grid[2] <= 0`. Left unchanged
    # so the generated data stays identical -- confirm the intent.
    sample *= (grid[2] <= 1).astype(np.float64)

    # rescale signal to [0,1], one min/max pair per sample
    sample_min = sample.min(axis=(1, 2)).reshape(-1, 1, 1)
    sample_max = sample.max(axis=(1, 2)).reshape(-1, 1, 1)

    # a constant sample has zero range; divide by 1 there so it becomes
    # all zeros instead of NaN (which has no defined uint8 value)
    value_range = sample_max - sample_min
    value_range[value_range == 0] = 1

    sample = (sample - sample_min) / value_range
    sample *= 255
    sample = sample.astype(np.uint8)

    return sample

In [3]:
# Directory passed as `root` to torchvision's MNIST loader (download target).
mnist_data_folder = "MNIST_data"
# Bandwidth b of the spherical grid; projected images are (2b, 2b).
bandwidth = 30
# Number of images projected per call to project_2d_on_sphere.
chunk_size = 1000

In [4]:
# Fetch (or reuse) the raw MNIST splits from the configured folder.
trainset = datasets.MNIST(root=mnist_data_folder, train=True, download=True)
testset = datasets.MNIST(root=mnist_data_folder, train=False, download=True)

# Keep plain NumPy copies of the images and labels of each split.
mnist_train = {
    'images': trainset.data.numpy(),
    'labels': trainset.targets.numpy(),
}

mnist_test = {
    'images': testset.data.numpy(),
    'labels': testset.targets.numpy(),
}

In [5]:
grid = get_projection_grid(b=bandwidth)

# Fixed linear map applied to the grid before projecting:
# (x, y, z) -> (-z, -y, x).
rot = np.array([[0, 0, -1], [0, -1, 0], [1, 0, 0]])
rotated_grid = rotate_grid(rot, grid)

train_set_sizes = [10000, 20000, 30000, 40000, 50000, 60000]


def _project_range(images, projections, start, stop):
    """Project images[start:stop] and store the result in projections[start:stop].

    Works chunk by chunk (``chunk_size`` images at a time) and converts only
    the current chunk to float64, so the full set is never duplicated in
    memory. Progress is printed as ``done/stop`` on a single line.
    """
    for lo in range(start, stop, chunk_size):
        hi = min(lo + chunk_size, stop)
        chunk = images[lo:hi].astype(np.float64)
        projections[lo:hi] = project_2d_on_sphere(chunk, rotated_grid)
        # report the clipped upper index so the counter never overshoots
        print("\r{0}/{1}".format(hi, stop), end="")
    print("")


def _write_dataset(output_file, images, labels):
    """Pickle ``{'images': ..., 'labels': ...}`` into a gzip file and return the dict."""
    dataset = {
        'images': images,
        'labels': labels
    }
    with gzip.open(output_file, 'wb') as f:
        pickle.dump(dataset, f)

    print(output_file, 'written')
    return dataset


# --- training subsets -------------------------------------------------------
# Every subset is a prefix of the training set and each image is projected
# independently of the others, so each image is projected exactly once and
# the growing prefix is written out after each size has been reached.
train_images = mnist_train['images'].reshape(-1, 28, 28)
train_labels = mnist_train['labels']
n_train = train_images.shape[0]
train_projections = np.empty((n_train, 2 * bandwidth, 2 * bandwidth), dtype=np.uint8)

n_projected = 0
for n_samples in train_set_sizes:
    # slicing past the end clips, exactly like images[:n_samples]
    n_signals = min(n_samples, n_train)
    _project_range(train_images, train_projections, n_projected, n_signals)
    n_projected = max(n_projected, n_signals)

    _write_dataset(
        "s2_mnist_train_sphere_center_" + str(n_samples) + ".gz",
        train_projections[:n_signals],
        train_labels[:n_samples],
    )


# --- test set ---------------------------------------------------------------
test_images = mnist_test['images'].reshape(-1, 28, 28)
projections = np.empty((test_images.shape[0], 2 * bandwidth, 2 * bandwidth), dtype=np.uint8)
_project_range(test_images, projections, 0, test_images.shape[0])

output_file = "s2_mnist_test_sphere_center.gz"
dataset = _write_dataset(output_file, projections, mnist_test['labels'])

10000/10000
s2_mnist_train_sphere_center_10000.gz written
20000/20000
s2_mnist_train_sphere_center_20000.gz written
30000/30000
s2_mnist_train_sphere_center_30000.gz written
40000/40000
s2_mnist_train_sphere_center_40000.gz written
50000/50000
s2_mnist_train_sphere_center_50000.gz written
60000/60000
s2_mnist_train_sphere_center_60000.gz written
10000/10000
s2_mnist_test_sphere_center.gz written
