In [None]:
import functools
import gzip
import os
import pickle
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Subset
from tqdm.notebook import tqdm

In [None]:
from scripts.dataset import *
from scripts.model import S2CNN, IcoCNN
from scripts.training import experiment, evaluate

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Use the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device}.')

# Seed the Python, NumPy and PyTorch global RNGs. They drive DataLoader shuffling, weight
# initialisation and the `random.choice` of rotations in online augmentation, so seeding
# them makes Restart & Run All repeatable.
# NOTE: this does not force deterministic cuDNN kernels on GPU.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

In [None]:
# Trained models, keyed by "Domain (training augmentation)".
# Every entry starts as None and is only instantiated when its section runs, so unused
# architectures never occupy memory.
models = dict.fromkeys([
    'Spherical (none)',
    'Spherical (s2)',
    'Spherical (ico)',
    'Icosahedral (none)',
    'Icosahedral (ico)',
    'Icosahedral (s2)',
])

In [None]:
# Model factories: the default hyper-parameters match the pre-generated datasets.
def new_spherical_model(f_in=1, b_in=48, f_out=10):
    """Return a fresh, untrained ``S2CNN``.

    Args:
        f_in: Number of input feature channels (the spherical data carries one channel).
        b_in: Forwarded to ``S2CNN`` as ``b_in``; must match the spherical input resolution.
            NOTE(review): the online-augmentation cells project with
            ``get_projection_grid(b=24)`` -- confirm that grid is what ``b_in=48`` expects.
        f_out: Number of output logits (10 MNIST classes).
    """
    return S2CNN(f_in=f_in, b_in=b_in, f_out=f_out)

def new_icosahedral_model(r=3, in_channels=1, out_channels=10, R_in=1, bias=True, smooth_vertices=True):
    """Return a fresh, untrained ``IcoCNN``.

    Every argument is forwarded to ``IcoCNN`` unchanged. The defaults match the
    icosahedral data, and ``r`` matches the ``icosahedral_grid_coordinates(r=3)`` grid
    used for online augmentation.

    Args:
        r: Forwarded as ``r``; presumably the icosahedron resolution -- must match the data.
        in_channels: Number of input channels.
        out_channels: Number of output logits (10 MNIST classes).
        R_in: Forwarded as ``R_in`` (TODO confirm meaning in scripts.model).
        bias: Forwarded as ``bias``.
        smooth_vertices: Forwarded as ``smooth_vertices``.
    """
    return IcoCNN(r=r, in_channels=in_channels, out_channels=out_channels,
                  R_in=R_in, bias=bias, smooth_vertices=smooth_vertices)

In [None]:
# Train/validation split over the 60,000 MNIST training samples:
# the first 50,000 train the model, the remaining 10,000 validate it.
n_train = 50000
n_total = 60000
train_idxs = np.arange(n_train)
val_idxs = np.arange(n_train, n_total)

# Training Loop

## Spherical CNN

In [None]:
# Spherical training data: one pre-projected dataset per training-time augmentation.
def _load_spherical_dataset(path):
    """Load a gzipped pickle with 'images'/'labels' into a TensorDataset.

    Images get a singleton channel axis (``[:, None, :, :]``) and are cast to float32;
    labels are cast to int64.
    """
    # NOTE: pickle.load can execute arbitrary code -- only open trusted, locally generated files.
    with gzip.open(path, 'rb') as fh:
        raw = pickle.load(fh)
    images = torch.from_numpy(raw['images'][:, None, :, :].astype(np.float32))
    labels = torch.from_numpy(raw['labels'].astype(np.int64))
    return TensorDataset(images, labels)


spherical_none_full = _load_spherical_dataset('./data/spherical_none_train.gz')
spherical_ico_full = _load_spherical_dataset('./data/spherical_ico_train.gz')
spherical_s2_full = _load_spherical_dataset('./data/spherical_s2_train.gz')

# Train/validation split: identical index ranges for every augmentation.
spherical_none_train = Subset(spherical_none_full, train_idxs)
spherical_none_val = Subset(spherical_none_full, val_idxs)
spherical_ico_train = Subset(spherical_ico_full, train_idxs)
spherical_ico_val = Subset(spherical_ico_full, val_idxs)
spherical_s2_train = Subset(spherical_s2_full, train_idxs)
spherical_s2_val = Subset(spherical_s2_full, val_idxs)

# Dataloaders: only the training splits are shuffled.
batch_size = 64
spherical_none_train_loader = DataLoader(spherical_none_train, batch_size=batch_size, shuffle=True)
spherical_ico_train_loader = DataLoader(spherical_ico_train, batch_size=batch_size, shuffle=True)
spherical_s2_train_loader = DataLoader(spherical_s2_train, batch_size=batch_size, shuffle=True)
spherical_none_val_loader = DataLoader(spherical_none_val, batch_size=batch_size, shuffle=False)
spherical_ico_val_loader = DataLoader(spherical_ico_val, batch_size=batch_size, shuffle=False)
spherical_s2_val_loader = DataLoader(spherical_s2_val, batch_size=batch_size, shuffle=False)

### No augmentation

In [None]:
# Fresh, untrained model on the target device (nn.Module.to returns the module itself,
# so the assignment also keeps the cell from echoing the module repr).
models['Spherical (none)'] = new_spherical_model().to(device)

In [None]:
# Train on the unaugmented spherical data and validate on the matching held-out split.
# NOTE(review): `experiment` comes from scripts.training; this cell displays whatever it returns.
experiment(
    models['Spherical (none)'],
    spherical_none_train_loader,
    spherical_none_val_loader,
    device,
    lr=5e-3,
    n_epochs=20,
    verbose=False
)

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(obj=models['Spherical (none)'].state_dict(), f='./models/spherical_none.pth')

### Ico augmentation

In [None]:
# Fresh, untrained model on the target device (nn.Module.to returns the module itself,
# so the assignment also keeps the cell from echoing the module repr).
models['Spherical (ico)'] = new_spherical_model().to(device)

In [None]:
# Train on spherical data augmented with icosahedral rotations, validated on the matching split.
experiment(
    models['Spherical (ico)'],
    spherical_ico_train_loader,
    spherical_ico_val_loader,
    device,
    lr=5e-3,
    n_epochs=20,
    verbose=False
)

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(obj=models['Spherical (ico)'].state_dict(), f='./models/spherical_ico.pth')

### $S^2$ augmentation 

In [None]:
# Fresh, untrained model on the target device (nn.Module.to returns the module itself,
# so the assignment also keeps the cell from echoing the module repr).
models['Spherical (s2)'] = new_spherical_model().to(device)

In [None]:
# Train on spherical data augmented with continuous S^2 rotations, validated on the matching split.
experiment(
    models['Spherical (s2)'],
    spherical_s2_train_loader,
    spherical_s2_val_loader,
    device,
    lr=5e-3,
    n_epochs=20,
    verbose=False
)

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(obj=models['Spherical (s2)'].state_dict(), f='./models/spherical_s2.pth')

## Icosahedral CNN

In [None]:
# Icosahedral training data: one pre-projected dataset per training-time augmentation.
def _load_icosahedral_dataset(path):
    """Load a gzipped pickle with 'images'/'labels' into a TensorDataset.

    Images are passed through ``np.array`` (so a list of per-sample arrays also works)
    and cast to float32 without adding an axis; labels are cast to int64.
    """
    # NOTE: pickle.load can execute arbitrary code -- only open trusted, locally generated files.
    with gzip.open(path, 'rb') as fh:
        raw = pickle.load(fh)
    images = torch.from_numpy(np.array(raw['images']).astype(np.float32))
    labels = torch.from_numpy(raw['labels'].astype(np.int64))
    return TensorDataset(images, labels)


icosahedral_none_full = _load_icosahedral_dataset('./data/icosahedral_none_train.gz')
icosahedral_ico_full = _load_icosahedral_dataset('./data/icosahedral_ico_train.gz')
icosahedral_s2_full = _load_icosahedral_dataset('./data/icosahedral_s2_train.gz')

# Train/validation split: identical index ranges for every augmentation.
icosahedral_none_train = Subset(icosahedral_none_full, train_idxs)
icosahedral_none_val = Subset(icosahedral_none_full, val_idxs)
icosahedral_ico_train = Subset(icosahedral_ico_full, train_idxs)
icosahedral_ico_val = Subset(icosahedral_ico_full, val_idxs)
icosahedral_s2_train = Subset(icosahedral_s2_full, train_idxs)
icosahedral_s2_val = Subset(icosahedral_s2_full, val_idxs)

# Dataloaders: only the training splits are shuffled.
batch_size = 64
icosahedral_none_train_loader = DataLoader(icosahedral_none_train, batch_size=batch_size, shuffle=True)
icosahedral_ico_train_loader = DataLoader(icosahedral_ico_train, batch_size=batch_size, shuffle=True)
icosahedral_s2_train_loader = DataLoader(icosahedral_s2_train, batch_size=batch_size, shuffle=True)
icosahedral_none_val_loader = DataLoader(icosahedral_none_val, batch_size=batch_size, shuffle=False)
icosahedral_ico_val_loader = DataLoader(icosahedral_ico_val, batch_size=batch_size, shuffle=False)
icosahedral_s2_val_loader = DataLoader(icosahedral_s2_val, batch_size=batch_size, shuffle=False)

### No augmentation

In [None]:
# Fresh, untrained model on the target device (nn.Module.to returns the module itself,
# so the assignment also keeps the cell from echoing the module repr).
models['Icosahedral (none)'] = new_icosahedral_model().to(device)

In [None]:
# Train the unaugmented icosahedral model and validate on the matching held-out split.
# NOTE(review): n_epochs is not passed here, so scripts.training's default applies --
# confirm it matches the 20 epochs used for the spherical runs.
experiment(
    models['Icosahedral (none)'],
    icosahedral_none_train_loader,
    icosahedral_none_val_loader,
    device,
    lr=1e-4,
    verbose=False
)

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(obj=models['Icosahedral (none)'].state_dict(), f='./models/icosahedral_none.pth')

### Ico augmentation

In [None]:
# Fresh, untrained model on the target device (nn.Module.to returns the module itself,
# so the assignment also keeps the cell from echoing the module repr).
models['Icosahedral (ico)'] = new_icosahedral_model().to(device)

In [None]:
# Train on icosahedral data augmented with icosahedral rotations, validated on the matching split.
# NOTE(review): n_epochs is not passed here, so scripts.training's default applies --
# confirm it matches the 20 epochs used for the spherical runs.
experiment(
    models['Icosahedral (ico)'],
    icosahedral_ico_train_loader,
    icosahedral_ico_val_loader,
    device,
    lr=1e-4,
    verbose=False
)

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(obj=models['Icosahedral (ico)'].state_dict(), f='./models/icosahedral_ico.pth')

### $S^2$ augmentation

In [None]:
# Fresh, untrained model on the target device (nn.Module.to returns the module itself,
# so the assignment also keeps the cell from echoing the module repr).
models['Icosahedral (s2)'] = new_icosahedral_model().to(device)

In [None]:
# Train on icosahedral data augmented with continuous S^2 rotations, validated on the matching split.
# NOTE(review): lr=1e-3 here vs 1e-4 for the other two icosahedral runs -- confirm intentional.
# NOTE(review): n_epochs is not passed, so scripts.training's default applies.
experiment(
    models['Icosahedral (s2)'],
    icosahedral_s2_train_loader,
    icosahedral_s2_val_loader,
    device,
    lr=1e-3,
    verbose=False
)

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
torch.save(obj=models['Icosahedral (s2)'].state_dict(), f='./models/icosahedral_s2.pth')

In [None]:
# Reload the saved weights. map_location lets a checkpoint written on another device
# (e.g. a CUDA machine) load on the current one.
models['Icosahedral (s2)'].load_state_dict(
    torch.load('./models/icosahedral_s2.pth', map_location=device)
)

In [None]:
# Sanity check: accuracy of the just-reloaded S^2-augmented icosahedral model on its own
# (S^2-augmented) validation split.
models['Icosahedral (s2)'].eval()  # inference mode for any train/eval-dependent layers
correct = 0
with torch.no_grad():  # no autograd graph is needed to evaluate
    for x, y in tqdm(icosahedral_s2_val_loader):
        x, y = x.to(device), y.to(device)
        out = models['Icosahedral (s2)'](x)
        correct += (out.argmax(dim=1) == y).sum().item()
print(correct / len(icosahedral_s2_val_loader.dataset))

# Online Augmentation Training

Training on a fixed set of pre-rotated samples doesn't always seem to work, and augmenting with all 60 icosahedral symmetries at once is simply not feasible. As a middle ground, we randomly rotate each batch at train (or test) time, so that the model gets to see more of the symmetries as training goes on.

These experiments use different data: instead of the pre-computed spherical projections, they load the raw MNIST images, so that we can rotate and then project them ourselves on the fly (during both training and evaluation).

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Raw MNIST images; rotation and projection onto the sphere happen on the fly, per batch.
# NOTE(review): assumes `MNIST` is torchvision's dataset, whose `.data`/`.targets` replace
# the deprecated `.train_data`/`.train_labels` aliases (same tensors).
train_dataset = MNIST(root='./data', train=True, download=True)
train_data = train_dataset.data.numpy().reshape(-1, 28, 28).astype(np.float64)
train_labels = train_dataset.targets.numpy()

# Reshuffle the training split every epoch, as the pre-projected training loaders above do;
# the validation split keeps a fixed order.
train_loader = DataLoader(TensorDataset(
    torch.from_numpy(train_data[train_idxs]),
    torch.from_numpy(train_labels[train_idxs])
), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(
    torch.from_numpy(train_data[val_idxs]),
    torch.from_numpy(train_labels[val_idxs])
), batch_size=64)

model = new_spherical_model()
model.to(device)
optimiser = torch.optim.Adam(model.parameters(), lr=5e-3)
criterion = torch.nn.CrossEntropyLoss()

# Spherical grid the raw images are projected onto.
grid = get_projection_grid(b=24)

In [None]:
@functools.lru_cache(maxsize=1)
def _icosahedral_rotation_matrices():
    """Rotation matrices of the icosahedral group ('I'), built once and then cached."""
    return R.create_group('I').as_matrix()


def augment_and_project(x, grid, augment):
    """Optionally rotate the sphere, then project a batch of 2-D images onto it.

    A single rotation is drawn per call, so every image in the batch shares it.

    Args:
        x: Batch of images as a CPU tensor (converted with ``.numpy()``).
        grid: Spherical projection grid, e.g. from ``get_projection_grid``.
        augment: ``'none'`` (no rotation), ``'ico'`` (a random rotation from the
            icosahedral group) or ``'s2'`` (a random continuous rotation).

    Returns:
        float32 tensor of the projections with a singleton channel axis at dim 1.

    Raises:
        ValueError: if ``augment`` is not one of the values above, so that a typo
            cannot silently fall back to no augmentation.
    """
    if augment == 's2':
        rotated_grid = rotate_grid(random_rotation_matrix(), grid)
    elif augment == 'ico':
        # The group is cached, so hand rotate_grid a copy rather than a view into the cache.
        rotation_matrix = random.choice(_icosahedral_rotation_matrices()).copy()
        rotated_grid = rotate_grid(rotation_matrix, grid)
    elif augment == 'none':
        rotated_grid = grid
    else:
        raise ValueError(f"augment must be 'none', 'ico' or 's2', got {augment!r}")

    projected = project_2d_on_sphere(x.numpy(), rotated_grid)
    return torch.from_numpy(projected[:, None, :, :].astype(np.float32))

In [None]:
def train_one_epoch(model, grid, train_loader, optimizer, criterion, device, augment='none', ico_grid=None, verbose=False):
    """Run one optimisation pass over ``train_loader`` with on-the-fly augmentation.

    Each batch is rotated/projected onto the sphere by ``augment_and_project``. When
    ``ico_grid`` is given, every projected sample is then mapped onto the icosahedron
    with ``s2_to_ico`` before the forward pass.

    Returns:
        Mean of the per-batch training losses.
    """
    model.train()
    batches = tqdm(train_loader, desc='Train.') if verbose else train_loader
    batch_losses = []

    for inputs, targets in batches:
        optimizer.zero_grad()

        inputs = augment_and_project(inputs, grid, augment)
        if ico_grid is not None:
            inputs = torch.tensor(np.array([s2_to_ico(sample, ico_grid) for sample in inputs]))

        inputs, targets = inputs.to(device), targets.to(device)
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(train_loader)

In [None]:
def evaluate_one_epoch(model, grid, test_loader, criterion, device, augment='none', ico_grid=None, verbose=False):
    """Evaluate ``model`` on ``test_loader`` with on-the-fly augmentation and no gradients.

    Batches go through ``augment_and_project`` (and ``s2_to_ico`` when ``ico_grid`` is
    given) exactly as in training.

    Returns:
        ``(mean per-batch loss, accuracy over every sample in test_loader.dataset)``.
    """
    model.eval()
    batches = tqdm(test_loader, desc='Eval.') if verbose else test_loader
    total_loss = 0.0
    n_correct = 0

    with torch.no_grad():
        for inputs, targets in batches:
            inputs = augment_and_project(inputs, grid, augment)
            if ico_grid is not None:
                inputs = torch.tensor(np.array([s2_to_ico(sample, ico_grid) for sample in inputs]))

            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)

            total_loss += criterion(logits, targets).item()
            # The prediction is the arg-max over the class logits.
            n_correct += (logits.argmax(dim=1) == targets).sum().item()

    mean_loss = total_loss / len(test_loader)
    accuracy = n_correct / len(test_loader.dataset)
    return mean_loss, accuracy

In [None]:
# Online training of the spherical model: one random icosahedral rotation per batch,
# applied at both train and validation time. Per-epoch history is kept for plotting.
train_losses, val_losses, val_accs = [], [], []

n_epochs = 10
augment = 'ico'
for epoch in tqdm(range(1, n_epochs + 1), desc='Experiment.'):
    train_loss = train_one_epoch(
        model=model, grid=grid, train_loader=train_loader, optimizer=optimiser,
        criterion=criterion, device=device, augment=augment, verbose=True,
    )
    val_loss, val_acc = evaluate_one_epoch(
        model=model, grid=grid, test_loader=val_loader, criterion=criterion,
        device=device, augment=augment, verbose=True,
    )

    print(f'Epoch {epoch}/{n_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

In [None]:
# Save under online_augment/, like the online icosahedral models, so this checkpoint does
# not overwrite the offline-augmentation model at ./models/spherical_ico.pth.
os.makedirs('./models/online_augment', exist_ok=True)
torch.save(model.state_dict(), './models/online_augment/spherical_ico.pth')

In [None]:
# Fresh icosahedral model for the online run without augmentation.
# NOTE(review): lr=5e-3 here vs 1e-4 for the other online icosahedral runs -- confirm intentional.
ico_model = new_icosahedral_model().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(ico_model.parameters(), lr=5e-3)

# Spherical projection grid and the r=3 icosahedral grid the projections are mapped onto.
grid = get_projection_grid(b=24)
ico_grid = icosahedral_grid_coordinates(r=3)

In [None]:
# Online training of the icosahedral model with no rotation (projection only).
train_losses, val_losses, val_accs = [], [], []

n_epochs = 20
augment = 'none'
for epoch in tqdm(range(1, n_epochs + 1), desc='Experiment.'):
    train_loss = train_one_epoch(
        model=ico_model, grid=grid, train_loader=train_loader, optimizer=optimiser,
        criterion=criterion, device=device, augment=augment, ico_grid=ico_grid, verbose=True,
    )
    val_loss, val_acc = evaluate_one_epoch(
        model=ico_model, grid=grid, test_loader=val_loader, criterion=criterion,
        device=device, augment=augment, ico_grid=ico_grid, verbose=True,
    )

    print(f'Epoch {epoch}/{n_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

In [None]:
# Persist only the learned parameters of the online, unaugmented icosahedral model.
torch.save(obj=ico_model.state_dict(), f='./models/online_augment/icosahedral_none.pth')

In [None]:
# Fresh icosahedral model for the online run with icosahedral-rotation augmentation.
ico_model = new_icosahedral_model().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(ico_model.parameters(), lr=1e-4)

# Spherical projection grid and the r=3 icosahedral grid the projections are mapped onto.
grid = get_projection_grid(b=24)
ico_grid = icosahedral_grid_coordinates(r=3)

In [None]:
# Online training of the icosahedral model with one random icosahedral rotation per batch.
train_losses, val_losses, val_accs = [], [], []

n_epochs = 20
augment = 'ico'
for epoch in tqdm(range(1, n_epochs + 1), desc='Experiment.'):
    train_loss = train_one_epoch(
        model=ico_model, grid=grid, train_loader=train_loader, optimizer=optimiser,
        criterion=criterion, device=device, augment=augment, ico_grid=ico_grid, verbose=False,
    )
    val_loss, val_acc = evaluate_one_epoch(
        model=ico_model, grid=grid, test_loader=val_loader, criterion=criterion,
        device=device, augment=augment, ico_grid=ico_grid, verbose=False,
    )

    print(f'Epoch {epoch}/{n_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

In [None]:
# Persist only the learned parameters of the online, ico-augmented icosahedral model.
torch.save(obj=ico_model.state_dict(), f='./models/online_augment/icosahedral_ico.pth')

In [None]:
# Fresh icosahedral model for the online run with continuous S^2-rotation augmentation.
ico_model = new_icosahedral_model().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(ico_model.parameters(), lr=1e-4)

# Spherical projection grid and the r=3 icosahedral grid the projections are mapped onto.
grid = get_projection_grid(b=24)
ico_grid = icosahedral_grid_coordinates(r=3)

In [None]:
# Online training of the icosahedral model with one random continuous rotation per batch.
train_losses, val_losses, val_accs = [], [], []

n_epochs = 20
augment = 's2'
for epoch in tqdm(range(1, n_epochs + 1), desc='Experiment.'):
    train_loss = train_one_epoch(
        model=ico_model, grid=grid, train_loader=train_loader, optimizer=optimiser,
        criterion=criterion, device=device, augment=augment, ico_grid=ico_grid, verbose=False,
    )
    val_loss, val_acc = evaluate_one_epoch(
        model=ico_model, grid=grid, test_loader=val_loader, criterion=criterion,
        device=device, augment=augment, ico_grid=ico_grid, verbose=False,
    )

    print(f'Epoch {epoch}/{n_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

In [None]:
# Persist only the learned parameters of the online, S^2-augmented icosahedral model.
torch.save(obj=ico_model.state_dict(), f='./models/online_augment/icosahedral_s2.pth')

### Evaluating

In [None]:
# Held-out MNIST test split; rotation/projection happens on the fly at evaluation time.
# NOTE(review): assumes `MNIST` is torchvision's dataset, whose `.data`/`.targets` replace
# the deprecated `.test_data`/`.test_labels` aliases (same tensors).
test_dataset = MNIST(root='./data', train=False, download=True)
test_data = test_dataset.data.numpy().reshape(-1, 28, 28).astype(np.float64)
test_labels = test_dataset.targets.numpy()

test_loader = DataLoader(TensorDataset(
    torch.from_numpy(test_data),
    torch.from_numpy(test_labels)
), batch_size=64)

In [None]:
# Evaluation setup: loss function, spherical projection grid and the r=3 icosahedral grid.
criterion = torch.nn.CrossEntropyLoss()
grid = get_projection_grid(b=24)
ico_grid = icosahedral_grid_coordinates(r=3)

In [None]:
ico_models = {
    # Indexed by the train-time augmentation the model was trained with.
    'none': new_icosahedral_model(),
    'ico' : new_icosahedral_model(),
    's2'  : new_icosahedral_model()
}

# Distinct loop names keep the notebook-level `augment` and `model` from being overwritten.
# map_location lets checkpoints written on another device load on this one.
for train_augment, ico_net in ico_models.items():
    ico_net.to(device)
    ico_net.load_state_dict(
        torch.load(f'./models/online_augment/icosahedral_{train_augment}.pth', map_location=device)
    )

In [None]:
# stats[test_augment][train_augment] = (test loss, test accuracy) from evaluate_one_epoch:
# every online icosahedral model is evaluated under every test-time augmentation.
# Comprehension variables stay local, so the notebook-level `model` is not rebound.
stats = {
    test_augment: {
        train_augment: evaluate_one_epoch(
            model=ico_net,
            grid=grid,
            test_loader=test_loader,
            criterion=criterion,
            device=device,
            augment=test_augment,
            ico_grid=ico_grid,
            verbose=True
        )
        for train_augment, ico_net in ico_models.items()
    }
    for test_augment in ('none', 'ico', 's2')
}

In [None]:
# Nested results: stats[test_augment][train_augment] = (test loss, test accuracy).
stats

In [None]:
def do_plot(save_path, results=None):
    """Draw a grouped bar chart of test accuracy by test-time and train-time augmentation.

    The x-axis has one group per test-time augmentation. Each group has one bar per
    train-time augmentation, labelled with its accuracy and name.

    Args:
        save_path: Where to save the figure (any format ``savefig`` supports), or None
            to only display it. Missing parent directories are created.
        results: Nested dict ``results[test_augment][train_augment] = (loss, accuracy)``
            as built from ``evaluate_one_epoch``. Defaults to the notebook-level ``stats``.

    Raises:
        ValueError: if ``results`` is empty.
    """
    from pathlib import Path

    if results is None:
        results = stats
    if not results:
        raise ValueError('results is empty; nothing to plot')

    # --- Data preparation ---
    test_augments = list(results)  # x-axis categories, e.g. 'none', 'ico', 's2'
    # Assumes every test-time entry holds the same set of train-time models.
    train_augments = list(results[test_augments[0]])

    # accuracies[train_aug][j] is that model's accuracy on test_augments[j];
    # index [1] picks the accuracy out of the (loss, accuracy) tuple.
    accuracies = {
        train_aug: [results[test_aug][train_aug][1] for test_aug in test_augments]
        for train_aug in train_augments
    }

    # --- Bar geometry ---
    x_indices = np.arange(len(test_augments))  # group centres 0, 1, 2, ...
    bar_width = 0.25
    # Shift so each group of bars is centred on its tick.
    offset_for_centering = (bar_width * len(train_augments) - bar_width) / 2

    fig, ax = plt.subplots(figsize=(8, 5))

    # plt.get_cmap rather than matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    colormap = plt.get_cmap('Set3')
    ax.set_prop_cycle(color=[colormap(v) for v in np.linspace(0, 0.85, len(train_augments))])

    # --- Bars ---
    for i, train_aug in enumerate(train_augments):
        label = train_aug.capitalize()
        rects = ax.bar(x_indices - offset_for_centering + i * bar_width,
                       accuracies[train_aug], bar_width, label=label)
        # Negative padding draws the "accuracy + train augmentation" label inside the bar top.
        ax.bar_label(rects, padding=-18, fmt=f'%.2f\n\n{label}')

    # --- Axes and legend ---
    ax.set_xlabel('Test-time Augmentation', fontsize=14)
    ax.set_ylabel('Test Accuracy', fontsize=14)
    ax.set_xticks(x_indices)
    ax.set_xticklabels([t.capitalize() for t in test_augments])
    ax.legend(title='Train-time\nAugmentation ', loc=(.088, .05))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    fig.tight_layout()

    if save_path is not None:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight')

    plt.show()
    
# Render the accuracy comparison and save it as a PDF for the write-up.
do_plot(save_path='./figures/performance_by_augment.pdf')