In [None]:
import torch
import torchvision
from datasets import MNIST_rot
import matplotlib.pyplot as plt
import numpy as np
import models
import torch
import g_selfatt.groups as groups
import models

In [None]:
# Single-channel normalisation statistics passed to Normalize below
# (presumably the usual MNIST mean / std -- confirm they suit the rotated variant).
data_mean = (0.1307,)
data_stddev = (0.3081,)


def _make_transform():
    """Return a fresh ``ToTensor -> Normalize`` pipeline using the statistics above."""
    to_tensor = torchvision.transforms.ToTensor()
    normalize = torchvision.transforms.Normalize(data_mean, data_stddev)
    return torchvision.transforms.Compose([to_tensor, normalize])


# Train and test currently share the same preprocessing (no augmentation), but
# each gets its own pipeline object so they can diverge later.
transform_train = _make_transform()
transform_test = _make_transform()

In [None]:
# Loader settings shared by every split, kept in one place so they cannot drift apart.
BATCH_SIZE = 16
NUM_WORKERS = 4

# The three MNIST_rot splits. Only the training split uses the training transform;
# evaluation and test use the test transform so that any augmentation added to
# transform_train later does not leak into held-out data.
training_set = MNIST_rot(
    root="./data", stage="train", download=True, transform=transform_train, data_fraction=1
)
evaluation_set = MNIST_rot(
    root="./data", stage="evaluation", download=True, transform=transform_test, data_fraction=1
)
test_set = MNIST_rot(
    root="./data", stage="test", download=True, transform=transform_test, data_fraction=1
)

# Training batches are reshuffled every epoch; evaluation order is kept fixed.
training_loader = torch.utils.data.DataLoader(
    training_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
evaluation_loader = torch.utils.data.DataLoader(
    evaluation_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# Displays (number of batches per epoch, number of training samples).
# The loader is created with batch_size=16 and no drop_last argument, so the
# first value should be ceil(len(training_set) / 16).
len(training_loader), len(training_set)  # batch size is 16, not 4

In [None]:


# Visualize some samples taken directly from the training set
num_samples_to_visualize = 4
fig, axes = plt.subplots(1, num_samples_to_visualize, figsize=(12, 3))
for i, ax in enumerate(axes):
    image, label = training_set[i]
    # squeeze drops the singleton channel axis so imshow receives a 2-D image
    ax.imshow(np.squeeze(image), cmap='gray')
    ax.set_title(f"Label: {label}")
    ax.axis('off')
plt.show()

# Visualize one batch of images loaded through the training loader.
# Use the built-in next(): DataLoader iterators in recent PyTorch releases no
# longer provide a .next() method.
dataiter = iter(training_loader)
images, labels = next(dataiter)

# Plot every image in the batch on a grid of at most 4 columns. The grid is
# sized from the actual batch, so it works for any batch size (a fixed 1x4
# layout raises ValueError as soon as the batch holds more than 4 images).
batch_len = images.size(0)
num_cols = min(4, batch_len)
num_rows = (batch_len + num_cols - 1) // num_cols  # ceiling division
batch_fig, batch_axes = plt.subplots(
    num_rows, num_cols, figsize=(10, 3 * num_rows), squeeze=False
)
for idx, ax in enumerate(batch_axes.flat):
    ax.axis('off')
    if idx >= batch_len:
        # Unused cell in the last row when the batch does not fill the grid.
        continue
    ax.imshow(images[idx].squeeze(), cmap='gray')
    ax.set_title(f"Label: {labels[idx].item()}")
plt.show()

In [None]:
# Build a small GroupTransformer so that a short training run can be used to
# investigate the speed of different parts of the codebase.
model = models.GroupTransformer(
    # Symmetry group (4 elements -- presumably 90-degree rotations; confirm in g_selfatt.groups).
    group=groups.SE2(num_elements=4),
    # Input / output interface: single-channel 28x28 images, 10 classes.
    in_channels=1,
    image_size=28,
    num_classes=10,
    # Network layout.
    num_channels=20,
    block_sizes=[2, 3],
    expansion_per_block=1,
    crop_per_layer=[2, 0, 2, 1, 1],
    maxpool_after_last_block=False,
    normalize_between_layers=False,
    # Attention settings.
    patch_size=5,
    num_heads=9,
    # Normalisation / activation.
    norm_type="LayerNorm",
    activation_function="Swish",
    # Regularisation.
    dropout_rate_after_maxpooling=0.0,
    attention_dropout_rate=0.1,
    value_dropout_rate=0.1,
    # Initialisation scale (numerically ~sqrt(2)).
    whitening_scale=1.41421356,
)

In [None]:
# Experiment notes
# - Let's try lr = 0.1: this is a really small model.
# - Swish is an unusual choice; also try ReLU later (open question: does that
#   break equivariance?).
# - It should be possible to train in fewer than 5 epochs.
# - NOTE(review): the original note says the epoch count was set to 4, but the
#   loop below runs a single epoch -- be patient when judging the loss.
# - Target loss: 0.2 (a normal run reaches it after 50 epochs).
learning_rate = 0.1
num_epochs = 1

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Plain supervised loop: one optimiser step per batch, with the batch loss
# printed after every step.
model.train()
for epoch in range(num_epochs):
    for inputs, labels in training_loader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        out = model(inputs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        print(loss.item())

# Observations
# - First run: lr = 0.1 doesn't seem to be too high; might even try higher.
# - Don't forget to experiment with ReLU.

In [None]:
# Quick sanity check: classification accuracy on about a quarter of the batches
# produced by the training loader. The loader shuffles, so this is a random
# subset of the training data, not a held-out score.
num_eval_batches = len(training_loader) // 4
correct = total = 0

model.eval()
with torch.no_grad():  # inference only: skip building the autograd graph
    # A DataLoader cannot be sliced; zipping with a range stops the iteration
    # after num_eval_batches batches without fetching an extra one.
    for _, (inputs, labels) in zip(range(num_eval_batches), training_loader):
        out = model(inputs)
        preds = out.argmax(dim=1)  # index of the highest-scoring class per sample
        correct += (preds == labels).sum().item()
        total += labels.size(0)

# Cell output: fraction of correct predictions (NaN if no batch was evaluated,
# e.g. when the loader has fewer than 4 batches).
correct / total if total else float("nan")