In [None]:
import sys
sys.path.append('../src')

import torch
import torchvision
from datasets import MNIST_rot
import matplotlib.pyplot as plt
import numpy as np
import models
import torch
import g_selfatt.groups as groups
import models
from torch.cuda.amp import GradScaler, autocast

import g_selfatt
from g_selfatt.nn import (
    Conv3d1x1,
    GroupLocalSelfAttention,
    GroupSelfAttention,
    LayerNorm,
    LiftLocalSelfAttention,
    LiftSelfAttention,
    TransformerBlock,
    activations,
)
from g_selfatt.utils import num_params


In [None]:
# Single-channel statistics handed to Normalize below.
data_mean, data_stddev = (0.1307,), (0.3081,)

# Train and test share the same preprocessing (tensor conversion followed by
# normalisation); two separate Compose objects are still created so each
# split keeps its own pipeline instance.
transform_train, transform_test = (
    torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(data_mean, data_stddev),
        ]
    )
    for _ in range(2)
)

## MNIST training set

In [None]:
# Fraction of the MNIST training split that is actually used for training.
# The full split is always downloaded; only the number of samples drawn from
# it is reduced.
data_fraction = 0.1  # e.g. 0.5 would keep 50% of the samples
if not 0 < data_fraction <= 1:
    raise ValueError(f"data_fraction must be in (0, 1], got {data_fraction}")

# Load (and download if needed) the complete MNIST training split.
mnist_full = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform_train)

num_samples = len(mnist_full)

# Draw the subset without replacement from a seeded generator so that a
# fresh kernel (Restart & Run All) always selects the same samples.
subset_rng = np.random.default_rng(0)
indices = subset_rng.choice(num_samples, int(data_fraction * num_samples), replace=False)

training_set = torch.utils.data.Subset(mnist_full, indices)

In [None]:
# The training set is the subset built in the previous cell; the MNIST_rot
# alternative is kept here for reference only:
# training_set = MNIST_rot(root="./data", stage="train", download=True, transform=transform_train, data_fraction=1)
evaluation_set = MNIST_rot(
    root="./data",
    stage="validation",
    download=True,
    transform=transform_train,
    data_fraction=1,
)
test_set = MNIST_rot(
    root="./data",
    stage="test",
    download=True,
    transform=transform_test,
    data_fraction=1,
)

# Options shared by all three loaders; only the training loader shuffles.
_loader_options = {"batch_size": 32, "num_workers": 4}

training_loader = torch.utils.data.DataLoader(training_set, shuffle=True, **_loader_options)
evaluation_loader = torch.utils.data.DataLoader(evaluation_set, shuffle=False, **_loader_options)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **_loader_options)

In [None]:
# simulate some little training procedure to investigate speed of different parts of the codebase
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = models.GroupTransformer(
    group=groups.SE2(num_elements=4),
    in_channels=1,
    num_channels=20,
    block_sizes=[2, 3],
    expansion_per_block=1,
    crop_per_layer=[2, 0, 2, 1, 1],
    image_size=28,
    num_classes=10,
    dropout_rate_after_maxpooling=0.0,
    maxpool_after_last_block=False,
    normalize_between_layers=False,
    patch_size=5,
    num_heads=9,
    norm_type="LayerNorm",
    activation_function="Swish",
    attention_dropout_rate=0.1,
    value_dropout_rate=0.1,
    whitening_scale=1.41421356,
)
model = torch.nn.DataParallel(model)
model = model.to(device)
num_params(model)

In [None]:
# Gradient scaler used together with autocast in the training loop.
scaler = GradScaler()

# Adam over all model parameters, optimising a cross-entropy objective.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.005)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Per-epoch history, kept so the curves can be plotted after training.
train_loss_values = []
evaluation_accuracy_values = []

for epoch in range(10):
    # Re-enter training mode at the start of EVERY epoch. The validation pass
    # at the end of each epoch switches the model to eval mode; without this
    # call all epochs after the first would train with dropout disabled.
    model.train()
    epoch_train_loss = 0.0
    for inputs, labels in training_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass and loss under autocast; backward and the optimizer
        # step go through the gradient scaler.
        with autocast():
            out = model(inputs)
            loss = criterion(out, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Weight the batch-mean loss by the batch size so the epoch average
        # below is a per-sample average even if the last batch is smaller.
        epoch_train_loss += loss.item() * inputs.size(0)
    epoch_train_loss /= len(training_loader.dataset)  # Average loss for the epoch
    print("average loss", epoch_train_loss)
    train_loss_values.append(epoch_train_loss)

    # Evaluation on the validation set (no gradients needed).
    correct = total = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in evaluation_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with autocast():
                out = model(inputs)
            # Predicted class = index of the largest output along dim 1.
            _, preds = torch.max(out, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total

    print("evaulation accucary: ", accuracy)
    print("------------------------------------------------")
    evaluation_accuracy_values.append(accuracy)  # Save validation accuracy for this epoch


# After training loop, you can plot the saved values using matplotlib or any other plotting library


In [None]:
# Final evaluation on the held-out test set.
# Put the model in eval mode explicitly instead of relying on whatever mode
# the previous cell happened to leave it in.
model.eval()
correct = total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with autocast():
            out = model(inputs)
        # Predicted class = index of the largest output along dim 1.
        _, preds = torch.max(out, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
accuracy2 = correct / total
print("final test: ", accuracy2)

