In [1]:
# Imports as always...
import numpy as np

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader

from scipy.spatial import cKDTree

from icoCNN.tools import icosahedral_grid_coordinates, random_icosahedral_rotation_matrix, rotate_signal

In [2]:
from scipts.datasets import IcosahedralMNIST

In [3]:
import warnings
warnings.filterwarnings('ignore')

# Datasets

This notebook produces pre-computed datasets. Projecting MNIST images onto the sphere and the icosahedral grid, and augmenting them with icosahedral rotations, are expensive operations. Computing them once and saving the results to disk yields the same data as doing it on the fly, while saving a lot of time during model training and evaluation.

In [7]:
def ico_symmetry_augment(ico_signal, idx=None):
    """
    Rotates an icosahedral signal by one of the 60 icosahedral symmetries.

    Args:
        - ico_signal (torch.Tensor): Shape [1, 1, 5, 2**r, 2**(r+1)].
        - idx: Index of the symmetry to apply. If None, one is chosen at random.

    Returns:
        The signal produced by `rotate_signal` for the selected rotation matrix.
    """
    # The rotation matrix is built first, then applied to the signal.
    return rotate_signal(ico_signal, random_icosahedral_rotation_matrix(idx))

In [5]:
def precompute_projected_mnist(coord_tensor, save_path='./ico_projected_mnist.pt', train=True, augment=False):
    """
    Precomputes the icosahedral projections of the MNIST dataset and saves it.

    Every image is rescaled to [0, 1], laid out on an equiangular grid over the
    unit sphere, and sampled at each icosahedral grid point by nearest-neighbour
    lookup.

    Args:
        coord_tensor (torch.Tensor): The icosahedral grid coordinates. The last
            dimension holds (x, y, z) points; the leading dimensions give the
            shape of each projected signal.
        save_path (str): Path to save the projected dataset. Missing parent
            directories are created.
        train (bool): Whether to use training or test MNIST data.
        augment (bool): If True, each projected image is rotated by one randomly
            chosen icosahedral symmetry via `ico_symmetry_augment`.

    Saves:
        A dict with 'data' (the per-image tensors stacked along a new first
        dimension; without augmentation each is float32 of shape
        [1, 1, *coord_tensor.shape[:-1]]) and 'labels' (tensor of shape [N]).
    """
    import os

    # Load MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    mnist_data = datasets.MNIST(root='./data', train=train, download=True, transform=transform)

    grid_shape = coord_tensor.shape[:-1]
    input_coords = coord_tensor.reshape(-1, 3).numpy()

    # The sphere sampling and KD-tree lookup depend only on the image size, not on
    # pixel values, so compute them once per size instead of once per image.
    nearest_by_size = {}

    def _nearest_indices(rows, cols):
        """Indices into image.ravel() of the sphere sample nearest each grid point."""
        if (rows, cols) not in nearest_by_size:
            theta = np.linspace(0, np.pi, rows)  # Polar angle
            phi = np.linspace(0, 2 * np.pi, cols)  # Azimuthal angle
            # NOTE(review): np.meshgrid defaults to indexing='xy', so these grids have
            # shape (cols, rows) and flattened entry k sits at (theta[k % rows], phi[k // rows]).
            # For square images this places image rows along the azimuth and columns along
            # the polar angle (the transpose of what the names suggest); for non-square
            # images the layouts would not line up. Kept unchanged so the output matches
            # datasets already generated by this notebook.
            theta, phi = np.meshgrid(theta, phi)

            # Convert to Cartesian coordinates on the unit sphere.
            x_proj = np.sin(theta) * np.cos(phi)
            y_proj = np.sin(theta) * np.sin(phi)
            z_proj = np.cos(theta)
            proj_coords = np.vstack([x_proj.ravel(), y_proj.ravel(), z_proj.ravel()]).T

            _, nearest_by_size[(rows, cols)] = cKDTree(proj_coords).query(input_coords)
        return nearest_by_size[(rows, cols)]

    projected_data = []
    labels = []

    for idx in tqdm(range(len(mnist_data)), desc='Projecting data.'):
        image, label = mnist_data[idx]
        image = image.squeeze().numpy()  # Convert to 28x28 numpy array.

        # Rescale to [0, 1]; a constant image would otherwise divide by zero and yield NaNs.
        lo, hi = image.min(), image.max()
        image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)

        # Map MNIST pixels to the icosahedral grid.
        nearest_indices = _nearest_indices(*image.shape)
        interpolated_values = image.ravel()[nearest_indices].reshape(grid_shape)

        # Shape: [1, 1, *grid_shape], e.g. [1, 1, 5, 2**r, 2**(r+1)].
        output_tensor = torch.tensor(interpolated_values, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        if augment:
            output_tensor = ico_symmetry_augment(output_tensor)

        projected_data.append(output_tensor)
        labels.append(label)

    # Save as a dictionary, creating the target directory if it does not exist yet.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save({'data': torch.stack(projected_data), 'labels': torch.tensor(labels)}, save_path)
    print(f"Precomputed dataset saved at {save_path}")

In [6]:
# Pre-computing augmented datasets.
# Only r = 4 is generated here; add 2 and 3 to sweep the full range(2, 5).
resolutions = [4]
for r in resolutions:
    ico_grid = torch.tensor(icosahedral_grid_coordinates(r))
    precompute_projected_mnist(ico_grid, save_path=f'./data/IcoMNIST/ico_projected_mnist_augmented_(r={r}).pt', train=True, augment=True)

Projecting data.:   0%|          | 0/60000 [00:00<?, ?it/s]

Precomputed dataset saved at ./data/IcoMNIST/ico_projected_mnist_augmented_(r=4).pt


In [8]:
def precompute_projected_mnist_all_symmetries(coord_tensor, save_path='./ico_projected_mnist.pt', train=True):
    """
    Precomputes icosahedral projections of MNIST under all 60 icosahedral symmetries and saves them.

    Every image is rescaled to [0, 1], laid out on an equiangular grid over the
    unit sphere, sampled at each icosahedral grid point by nearest-neighbour
    lookup, and then rotated by each of the 60 icosahedral symmetries. Entry
    `60 * n + i` of the saved data is image `n` rotated by symmetry `i`.

    The whole result is held in memory (60x the size of a single projected
    dataset), so high resolutions need a lot of RAM.

    Args:
        coord_tensor (torch.Tensor): The icosahedral grid coordinates. The last
            dimension holds (x, y, z) points; the leading dimensions give the
            shape of each projected signal.
        save_path (str): Path to save the projected dataset. Missing parent
            directories are created.
        train (bool): Whether to use training or test MNIST data.

    Saves:
        A dict with 'data' (the rotated signals stacked along a new first
        dimension of size 60 * N) and 'labels' (tensor of shape [60 * N]).

    Raises:
        RuntimeError: If the MNIST split contains no images.
    """
    import os

    num_symmetries = 60  # Number of icosahedral rotations.

    # Load MNIST dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    mnist_data = datasets.MNIST(root='./data', train=train, download=True, transform=transform)

    grid_shape = coord_tensor.shape[:-1]
    input_coords = coord_tensor.reshape(-1, 3).numpy()

    # The sphere sampling and KD-tree lookup depend only on the image size, not on
    # pixel values, so compute them once per size instead of once per image.
    nearest_by_size = {}

    def _nearest_indices(rows, cols):
        """Indices into image.ravel() of the sphere sample nearest each grid point."""
        if (rows, cols) not in nearest_by_size:
            theta = np.linspace(0, np.pi, rows)  # Polar angle
            phi = np.linspace(0, 2 * np.pi, cols)  # Azimuthal angle
            # NOTE(review): np.meshgrid defaults to indexing='xy', so these grids have
            # shape (cols, rows) and flattened entry k sits at (theta[k % rows], phi[k // rows]).
            # For square images this places image rows along the azimuth and columns along
            # the polar angle. Kept unchanged so outputs match previously generated datasets.
            theta, phi = np.meshgrid(theta, phi)

            # Convert to Cartesian coordinates on the unit sphere.
            x_proj = np.sin(theta) * np.cos(phi)
            y_proj = np.sin(theta) * np.sin(phi)
            z_proj = np.cos(theta)
            proj_coords = np.vstack([x_proj.ravel(), y_proj.ravel(), z_proj.ravel()]).T

            _, nearest_by_size[(rows, cols)] = cKDTree(proj_coords).query(input_coords)
        return nearest_by_size[(rows, cols)]

    num_images = len(mnist_data)
    data = None  # Allocated once the shape/dtype of a rotated signal is known.
    labels = []

    for idx in tqdm(range(num_images), desc='Projecting data.'):
        image, label = mnist_data[idx]
        image = image.squeeze().numpy()  # Convert to 28x28 numpy array.

        # Rescale to [0, 1]; a constant image would otherwise divide by zero and yield NaNs.
        lo, hi = image.min(), image.max()
        image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)

        # Map MNIST pixels to the icosahedral grid.
        nearest_indices = _nearest_indices(*image.shape)
        interpolated_values = image.ravel()[nearest_indices].reshape(grid_shape)

        # Shape: [1, 1, *grid_shape], e.g. [1, 1, 5, 2**r, 2**(r+1)].
        output_tensor = torch.tensor(interpolated_values, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        # Produce an output tensor for all icosahedral symmetries.
        for i in range(num_symmetries):
            augmented_tensor = ico_symmetry_augment(output_tensor, i)
            if data is None:
                # Preallocate the full result rather than stacking a list at the end,
                # which would briefly hold two complete copies in memory.
                data = torch.empty(
                    (num_images * num_symmetries, *augmented_tensor.shape),
                    dtype=augmented_tensor.dtype,
                    device=augmented_tensor.device,
                )
            data[idx * num_symmetries + i] = augmented_tensor
        labels.extend([label] * num_symmetries)

    if data is None:
        raise RuntimeError('MNIST split is empty; nothing to save.')

    # Save as a dictionary, creating the target directory if it does not exist yet.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save({'data': data, 'labels': torch.tensor(labels)}, save_path)
    print(f"Precomputed dataset saved at {save_path}")

In [10]:
# Pre-computing augmented datasets with ALL icosahedral symmetries.
# One output file per resolution r; currently only r = 3 is generated.
for r in [3]:
    ico_grid = torch.tensor(icosahedral_grid_coordinates(r))
    precompute_projected_mnist_all_symmetries(
        ico_grid,
        save_path=f'./data/IcoMNIST/ico_projected_mnist_all_symmetries_(r={r}).pt',
        train=True,
    )

Projecting data.:   0%|          | 0/60000 [00:00<?, ?it/s]

Precomputed dataset saved at ./data/IcoMNIST/ico_projected_mnist_all_symmetries(r=2).pt


Projecting data.:   0%|          | 0/60000 [00:00<?, ?it/s]

KeyboardInterrupt: 