In [31]:
# Imports as always...
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

from icoCNN.tools import random_icosahedral_rotation_matrix, rotate_signal

In [2]:
from scipts import datasets, models, training

In [3]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio

plt.style.use('ggplot')
pio.renderers.default = 'notebook'

In [4]:
import warnings
warnings.filterwarnings('ignore')

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device}.')

Using cuda.


# Experiments

In [6]:
def experiment(model_type, r, augment, train=True, save_model=False, save_stats=False, batch_size=128, n_epochs=10, lr=1e-3, verbose=False, print_interval=1):
    assert model_type in ['ico', 'spherical'], 'Unrecognised model type.'
    
    # Output dictionary.
    output = {}
    
    # Dataset.
    if model_type == 'ico':
        dataset = datasets.PrecomputedIcosahedralMNIST(r, augment)
    else:
        # TODO: Pre-computed spherical dataset.
        dataset = None
    
    # Train-val-test split (deterministic). We'll do 2/3, 1/6, 1/6.
    N = len(dataset)
    train_indices = list(range(0, 2*N//3))
    val_indices = list(range(2*N//3, 5*N//6))
    test_indices = list(range(5*N//6, N))
    
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    test_dataset = Subset(dataset, test_indices)
    
    # Dataloaders.
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)
    
    # Model.
    if model_type == 'ico':
        model = models.IcoCNNClassifierNoPooling(
            r,
            in_channels=1, 
            out_channels=10, 
            R_in=1, 
            bias=True, 
            smooth_vertices=True
        )
    else:
        # TODO: Spherical CNN.
        model = None
        
    model.to(device)
    output['Model'] = model
        
    # Train.
    if train:
        output['Train Stats'] = training.experiment(
            model,
            train_loader,
            val_loader,
            device,
            n_epochs,
            lr,
            verbose,
            print_interval
        )
        
        # Save.
        if save_stats:
            # TODO: Save train stats.
            pass
        
        if save_model:
            if augment == 'none':
                torch.save(model.state_dict(), f'./models/IcoCNN/ico{r}.pth')
            elif augment == 'ico':
                torch.save(model.state_dict(), f'./models/IcoCNN/ico{r}_augmented.pth')
            elif augment == 'all ico':
                torch.save(model.state_dict(), f'./models/IcoCNN/ico{r}_all_symmetries.pth')
            else:
                # TODO: The others.
                pass
        
    # Or load.
    else:
        if augment == 'none':
            model.load_state_dict(torch.load(f'./models/IcoCNN/ico{r}.pth'))
        elif augment == 'ico':
            model.load_state_dict(torch.load(f'./models/IcoCNN/ico{r}_augmented.pth'))
        elif augment == 'all ico':
            model.load_state_dict(torch.load(f'./models/IcoCNN/ico{r}_all_symmetries.pth'))
        else:
            # TODO: The others.
            pass
        
    # Evaluate on test set.
    output['Test Stats'] = training.evaluate(model, test_loader, torch.nn.CrossEntropyLoss(), device)
    
    return output

In [8]:
output = experiment(
    model_type='ico',
    r=2,
    augment='all ico',
    train=True,
    save_model=True,
    save_stats=False,
    batch_size=128,
    n_epochs=10,
    lr=1e-3,
    verbose=True,
    print_interval=1
)

All Icosahedral Symmetries Augmentation.


Experiment.:   0%|          | 0/10 [00:00<?, ?it/s]

Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 1/10, Train Loss: 0.7635, Val Loss: 0.6480, Val Acc: 0.7904


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 2/10, Train Loss: 0.5820, Val Loss: 0.6011, Val Acc: 0.8059


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 3/10, Train Loss: 0.5424, Val Loss: 0.5617, Val Acc: 0.8191


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 4/10, Train Loss: 0.5214, Val Loss: 0.5529, Val Acc: 0.8228


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 5/10, Train Loss: 0.5074, Val Loss: 0.5388, Val Acc: 0.8270


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 6/10, Train Loss: 0.4976, Val Loss: 0.5269, Val Acc: 0.8314


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 7/10, Train Loss: 0.4901, Val Loss: 0.5232, Val Acc: 0.8314


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 8/10, Train Loss: 0.4836, Val Loss: 0.5260, Val Acc: 0.8315


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 9/10, Train Loss: 0.4786, Val Loss: 0.5136, Val Acc: 0.8354


Train.:   0%|          | 0/18750 [00:00<?, ?it/s]

Eval.:   0%|          | 0/4688 [00:00<?, ?it/s]

Epoch 10/10, Train Loss: 0.4752, Val Loss: 0.5147, Val Acc: 0.8346


## IcoCNN Experiments

In [8]:
# Pre-defined model (models are trained for the below architecture).
def blank_model(r):
    return models.IcoCNNClassifierNoPooling(
        r,
        in_channels=1, 
        out_channels=10, 
        R_in=1, 
        bias=True, 
        smooth_vertices=True
    )

In [139]:
a = -torch.rand((10,))
b = torch.rand((10,))

torch.norm(F.softmax(a) - F.softmax(b))

tensor(0.1392)

In [147]:
# Measuring invariance error for a given model and sample.
def measure_invariance_error(model, x, x_rot):
    return F.kl_div(F.log_softmax(model(x)), F.log_softmax(model(x_rot)), log_target=True).item()

    #return torch.norm(F.softmax(model(x)) - F.softmax(model(x_rot))).item()

In [150]:
# Test datasets for evaluation.
full_dataset = datasets.PrecomputedIcosahedralMNIST(r=2, augment='none')
pre_aug_dataset = datasets.PrecomputedIcosahedralMNIST(r=2, augment='ico')
test_indices = list(range(5*len(full_dataset)//6, len(full_dataset)))
test_dataset = Subset(full_dataset, test_indices)
test_aug_dataset = Subset(pre_aug_dataset, test_indices)

No Augmentation.
Random Icosahedral Rotation Augmentation.


In [78]:
# Load in the non-augmented models.
ico2_none = blank_model(r=2)
ico2_none.load_state_dict(torch.load('./models/IcoCNN/ico2.pth'))
ico2_augmented = blank_model(r=2)
ico2_augmented.load_state_dict(torch.load('./models/IcoCNN/ico2_augmented.pth'))
ico2_all_symmetries = blank_model(r=2)
ico2_all_symmetries.load_state_dict(torch.load('./models/IcoCNN/ico2_all_symmetries.pth'))

# Eval mode.
ico2_none.eval()
ico2_augmented.eval()
ico2_all_symmetries.eval()

print()



In [151]:
# Tracking loss, accuracy, and invariance error (discrepancy between f(x) and Rf(x)).
stats = {
    # TODO: SO(3) rotations.
    # Key convention: 'model - dataset'; e.g. 'none - ico' means the model trained for no augmentations evaluated on random icosahedral rotation. 
    'none - none': {'loss' : .0, 'acc': .0, 'ico err': .0},
    'none - ico': {'loss' : .0, 'acc': .0, 'ico err': .0},
    #'none - all ico': {'loss' : .0, 'acc': .0, 'ico err': .0},
    'ico - none': {'loss' : .0, 'acc': .0, 'ico err': .0},
    'ico - ico': {'loss' : .0, 'acc': .0, 'ico err': .0},
    #'ico - all ico': {'loss' : .0, 'acc': .0, 'ico err': .0},
    'all ico - none': {'loss' : .0, 'acc': .0, 'ico err': .0},
    'all ico - ico': {'loss' : .0, 'acc': .0, 'ico err': .0},
    #'all ico - all ico': {'loss' : .0, 'acc': .0, 'ico err': .0}
}

# Loss.
loss_fn = nn.CrossEntropyLoss()
with torch.no_grad():
    for i in tqdm(range(len(test_dataset))):
        x, y = test_dataset[i]
        x_augmented = test_aug_dataset[i][0]
        
        # Unsqueeze to yield a batch dimension.
        x_batch, y_batch = x.unsqueeze(1), y.unsqueeze(-1)
        x_augmented_batch = x_augmented.unsqueeze(1)
        
        # --- None augmentation ---
        
        # Evaluate models on non-augmented sample.
        out_none = ico2_none(x_batch)
        out_ico = ico2_augmented(x_batch)
        out_all_ico = ico2_all_symmetries(x_batch)
        
        # Add the stats.
        stats['none - none']['loss'] += loss_fn(out_none, y_batch).item()
        stats['ico - none']['loss'] += loss_fn(out_ico, y_batch).item()
        stats['all ico - none']['loss'] += loss_fn(out_all_ico, y_batch).item()
        stats['none - none']['acc'] += int(out_none.argmax() == y)
        stats['ico - none']['acc'] += int(out_ico.argmax() == y)
        stats['all ico - none']['acc'] += int(out_all_ico.argmax() == y)
        
        # Invariance error.
        # Sample random (icosahedral) rotation, apply to the input signal, and pass to model.
        x_batch_ico_rot = rotate_signal(x, random_icosahedral_rotation_matrix()).unsqueeze(1)
        
        stats['none - none']['ico err'] += measure_invariance_error(ico2_none, x_batch, x_batch_ico_rot)
        stats['ico - none']['ico err'] += measure_invariance_error(ico2_augmented, x_batch, x_batch_ico_rot)
        stats['all ico - none']['ico err'] += measure_invariance_error(ico2_all_symmetries, x_batch, x_batch_ico_rot)
        
        # --- Random Icosahedral Rotation Augmentation ---
        
        # Evaluate models on random icosahedral rotations.
        out_none_aug = ico2_none(x_augmented_batch)
        out_ico_aug = ico2_augmented(x_augmented_batch)
        out_all_ico_aug = ico2_all_symmetries(x_augmented_batch)
        
        # Add the stats.
        stats['none - ico']['loss'] += loss_fn(out_none_aug, y_batch).item()
        stats['ico - ico']['loss'] += loss_fn(out_ico_aug, y_batch).item()
        stats['all ico - ico']['loss'] += loss_fn(out_all_ico_aug, y_batch).item()
        stats['none - ico']['acc'] += int(out_none_aug.argmax() == y)
        stats['ico - ico']['acc'] += int(out_ico_aug.argmax() == y)
        stats['all ico - ico']['acc'] += int(out_all_ico_aug.argmax() == y)
        
        # Invariance error.
        x_aug_batch_ico_rot = rotate_signal(x_augmented, random_icosahedral_rotation_matrix()).unsqueeze(1)
        
        stats['none - ico']['ico err'] += measure_invariance_error(ico2_none, x_augmented_batch, x_aug_batch_ico_rot)
        stats['ico - ico']['ico err'] += measure_invariance_error(ico2_augmented, x_augmented_batch, x_aug_batch_ico_rot)
        stats['all ico - ico']['ico err'] += measure_invariance_error(ico2_all_symmetries, x_augmented_batch, x_aug_batch_ico_rot)

    
# Averaging.
stats['none - none']['loss'] /= len(test_dataset)
stats['ico - none']['loss'] /= len(test_dataset)
stats['all ico - none']['loss'] /= len(test_dataset)
stats['none - none']['acc'] /= len(test_dataset)
stats['ico - none']['acc'] /= len(test_dataset)
stats['all ico - none']['acc'] /= len(test_dataset)
stats['none - none']['ico err'] /= len(test_dataset)
stats['ico - none']['ico err'] /= len(test_dataset)
stats['all ico - none']['ico err'] /= len(test_dataset)

stats['none - ico']['loss'] /= len(test_dataset)
stats['ico - ico']['loss'] /= len(test_dataset)
stats['all ico - ico']['loss'] /= len(test_dataset)
stats['none - ico']['acc'] /= len(test_dataset)
stats['ico - ico']['acc'] /= len(test_dataset)
stats['all ico - ico']['acc'] /= len(test_dataset)
stats['none - ico']['ico err'] /= len(test_dataset)
stats['ico - ico']['ico err'] /= len(test_dataset)
stats['all ico - ico']['ico err'] /= len(test_dataset)

# Rounding.
for i, a in stats.items():
    for j, b in a.items():
        stats[i][j] = round(b, 4)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [152]:
display(stats)

{'none - none': {'loss': 0.0968, 'acc': 0.9765, 'ico err': 2.1749},
 'none - ico': {'loss': 16.1872, 'acc': 0.1205, 'ico err': 1.2364},
 'ico - none': {'loss': 1.0416, 'acc': 0.6823, 'ico err': 0.1232},
 'ico - ico': {'loss': 1.0829, 'acc': 0.6691, 'ico err': 0.2337},
 'all ico - none': {'loss': 0.4079, 'acc': 0.868, 'ico err': 0.0471},
 'all ico - ico': {'loss': 0.4328, 'acc': 0.8618, 'ico err': 0.2028}}