In [None]:
# Fetch the project code. Presumably the local modules imported further down
# (datasets, models, g_selfatt) come from this repository — confirm.
# NOTE(review): `git clone` fails if the target directory already exists, so
# this cell is not safe to re-run in the same session.
!git clone https://github.com/WouterBant/GEVit-DL2-Project.git

In [None]:
# Work from the repository root, presumably so its local packages resolve on import.
# NOTE(review): this is a relative `cd`; running the cell a second time tries to
# enter GEVit-DL2-Project/GEVit-DL2-Project/ — run it once per kernel.
%cd GEVit-DL2-Project/

In [None]:
!pip install einops
!pip install wandb
!pip install ml_collections

In [None]:
import torch
import torchvision
from datasets import MNIST_rot
import matplotlib.pyplot as plt
import numpy as np
import models
import torch
import g_selfatt.groups as groups
import models
from torch.cuda.amp import GradScaler, autocast

In [None]:
# Single-channel normalization statistics (presumably the standard MNIST mean/std — confirm
# they suit the rotated variant).
data_mean = (0.1307,)
data_stddev = (0.3081,)


def make_digit_transform():
    """Build a fresh preprocessing pipeline: image -> tensor -> normalized tensor."""
    steps = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(data_mean, data_stddev),
    ]
    return torchvision.transforms.Compose(steps)


# Train and test currently share identical preprocessing (no augmentation);
# separate objects keep them independently editable.
transform_train = make_digit_transform()
transform_test = make_digit_transform()

In [None]:
BATCH_SIZE = 32
NUM_WORKERS = 4

# data_fraction=1 presumably means "use the whole split" — confirm in datasets.MNIST_rot.
training_set = MNIST_rot(root="./data", stage="train", download=True, transform=transform_train, data_fraction=1)
# Held-out splits use the evaluation-time transform, never the training one,
# so any future train-time augmentation does not leak into validation.
evaluation_set = MNIST_rot(root="./data", stage="validation", download=True, transform=transform_test, data_fraction=1)
test_set = MNIST_rot(root="./data", stage="test", download=True, transform=transform_test, data_fraction=1)

training_loader = torch.utils.data.DataLoader(
    training_set,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle each epoch
    num_workers=NUM_WORKERS,
)
evaluation_loader = torch.utils.data.DataLoader(
    evaluation_set,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order for evaluation
    num_workers=NUM_WORKERS,
)

In [None]:
# Small model for a short training run used to profile the speed of the
# different parts of the codebase.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.GroupTransformer(
    group=groups.SE2(num_elements=4),  # presumably 4 discrete rotations — confirm
    in_channels=1,
    num_channels=20,
    block_sizes=[2, 3],
    expansion_per_block=1,
    # NOTE(review): 5 entries == sum(block_sizes); presumably one crop per layer — confirm.
    crop_per_layer=[2, 0, 2, 1, 1],
    image_size=28,
    num_classes=10,
    dropout_rate_after_maxpooling=0.0,
    maxpool_after_last_block=False,
    normalize_between_layers=False,
    patch_size=5,
    num_heads=9,
    norm_type="LayerNorm",
    activation_function="Swish",
    attention_dropout_rate=0.1,
    value_dropout_rate=0.1,
    whitening_scale=1.41421356,  # ~sqrt(2)
).to(device)

In [None]:
LEARNING_RATE = 0.005

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Loss scaling is only used with CUDA mixed precision; make the CPU fallback
# explicit instead of relying on GradScaler's "CUDA not available" runtime warning.
scaler = GradScaler(enabled=device.type == "cuda")

In [None]:
# Short training run to gauge speed and convergence.
# Experiment notes:
# - This is a small model, so a much larger learning rate (e.g. 0.1) is worth trying;
#   the optimizer cell currently uses lr=0.005. An earlier run with lr=0.1 did not
#   look too high, so even higher may be worth a try.
# - TODO: compare Swish with ReLU (check whether the swap affects equivariance).
# - Goal: train in fewer than 5 epochs; raise NUM_EPOCHS before judging the loss.
# - Target loss: 0.2 (a normal run reaches it after 50 epochs).
NUM_EPOCHS = 1
LOG_EVERY = 50  # batches between loss printouts

# torch.cuda.amp.autocast targets CUDA; disable it explicitly on CPU.
use_amp = device.type == "cuda"

model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for step, (inputs, labels) in enumerate(training_loader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass under mixed precision; backward and optimizer step stay outside.
        with autocast(enabled=use_amp):
            out = model(inputs)
            loss = criterion(out, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Accumulate on-device; converting to a Python float only when logging
        # avoids a host sync on every batch.
        running_loss += loss.detach()
        if step % LOG_EVERY == 0:
            print(f"epoch {epoch} step {step}: mean loss {float(running_loss) / LOG_EVERY:.4f}")
            running_loss = 0.0

In [None]:
# Accuracy on the held-out validation split; the training split would only
# measure fit to data the model has already seen, not generalization.
correct = total = 0

model.eval()  # put the attention/value dropout layers into inference mode
with torch.no_grad(), autocast(enabled=device.type == "cuda"):
    for inputs, labels in evaluation_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

# NaN instead of ZeroDivisionError if the loader yielded no samples.
correct / total if total else float("nan")