# Rotated MNIST: Model Training Notebook (Standalone)

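This notebook trains and compares three classifiers on a class-balanced subset of MNIST in which every image is rotated by a random angle in [-180°, 180°] each time it is loaded (an on-the-fly "Rotated MNIST"):
- Logistic regression
- A standard CNN
- A rotation-equivariant CNN (built with E2CNN)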

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
from e2cnn import gspaces
from e2cnn import nn as enn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random, os
    

In [None]:
# Reproducibility: seed Python's `random` (used for subset sampling) and torch.
seed = 42
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(seed)

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    

In [3]:

# Data preparation with balanced validation split
def get_dataloaders(batch_size=64, subset_fraction=1.0, seed=42):
    transform = transforms.Compose([
        transforms.RandomRotation(degrees=(-180, 180)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    
    if subset_fraction < 1.0:
        class_indices = {i: [] for i in range(10)}
        for idx, (_, label) in enumerate(full_dataset):
            class_indices[label].append(idx)
        samples_per_class = int((len(full_dataset) * subset_fraction) / 10)
        balanced_indices = [random.sample(class_indices[i], samples_per_class) for i in range(10)]
        balanced_indices = [item for sublist in balanced_indices for item in sublist]
        random.shuffle(balanced_indices)
        subset = Subset(full_dataset, balanced_indices)
    else:
        subset = full_dataset

    train_size = int(0.8 * len(subset))
    val_size = len(subset) - train_size
    train_set, val_set = random_split(subset, [train_size, val_size])

    test_set = datasets.MNIST('data', train=False, download=True, transform=transform)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
        DataLoader(test_set, batch_size=batch_size)
    )
    

In [4]:

# Simple Logistic Regression model
class LogisticRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.linear(x)
    

In [5]:
class SimpleCNN(nn.Module):
    """Two unpadded 3x3 convolutions and one 2x2 max-pool, then a two-layer MLP head.

    The head's input size assumes 1x28x28 images: 28 -> 26 -> 24 after the two
    convolutions, 12 after pooling, i.e. 64 * 12 * 12 = 9216 features.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=64 * 12 * 12, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Map a batch of images to unnormalised class logits."""
        features = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
        hidden = F.relu(self.fc1(features.view(len(features), -1)))
        return self.fc2(hidden)

In [6]:
class RotEquivariantCNN(nn.Module):
    """Rotation-equivariant CNN (e2cnn, cyclic group of order 8) with a linear classifier.

    Two blocks of (5x5 R2Conv with padding 2 -> ReLU -> 2x2 pointwise max-pool) act on
    regular-representation fields (8 fields, then 16). The resulting feature map is
    flattened and fed to a single ``nn.Linear``. The ``7 * 7`` factor in ``fc1``
    assumes 28x28 inputs halved by each of the two pools.

    NOTE(review): the convolutional blocks are equivariant, but flattening the fields
    into a plain Linear layer is not rotation-invariant. An ``enn.GroupPooling`` before
    the head would make predictions invariant; consider it if invariance is the goal
    (it changes ``fc1``'s input size, so existing checkpoints would no longer load).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        gspace = gspaces.Rot2dOnR2(N=8)

        def regular_fields(count):
            # `count` copies of the group's regular representation.
            return enn.FieldType(gspace, count * [gspace.regular_repr])

        # Input images are a single scalar (trivial-representation) field.
        self.input_type = enn.FieldType(gspace, [gspace.trivial_repr])

        hidden1 = regular_fields(8)
        self.block1 = enn.SequentialModule(
            enn.R2Conv(self.input_type, hidden1, kernel_size=5, padding=2, bias=False),
            enn.ReLU(hidden1, inplace=True),
            enn.PointwiseMaxPool(hidden1, kernel_size=2),
        )

        hidden2 = regular_fields(16)
        self.block2 = enn.SequentialModule(
            enn.R2Conv(self.block1.out_type, hidden2, kernel_size=5, padding=2, bias=False),
            enn.ReLU(hidden2, inplace=True),
            enn.PointwiseMaxPool(hidden2, kernel_size=2),
        )

        self.fc1 = nn.Linear(self.block2.out_type.size * 7 * 7, num_classes)

    def forward(self, x):
        """Wrap the raw tensor as a GeometricTensor, run both blocks and classify."""
        geometric = self.block2(self.block1(enn.GeometricTensor(x, self.input_type)))
        raw = geometric.tensor
        return self.fc1(raw.view(raw.size(0), -1))


In [7]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-sample training loss."""
    model.train()
    total_loss, total_samples = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the average is per sample, like the validation loss
        # (the last batch may be smaller than the others).
        total_loss += loss.item() * inputs.size(0)
        total_samples += inputs.size(0)
    return total_loss / max(total_samples, 1)


def _evaluate(model, loader, criterion, device):
    """Return (mean loss, accuracy, macro precision, macro recall, macro F1) on ``loader``."""
    model.eval()
    total_loss, total_samples = 0.0, 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item() * inputs.size(0)
            total_samples += inputs.size(0)
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_targets.extend(targets.cpu().tolist())

    return (
        total_loss / max(total_samples, 1),
        accuracy_score(all_targets, all_preds),
        precision_score(all_targets, all_preds, average='macro', zero_division=0),
        recall_score(all_targets, all_preds, average='macro', zero_division=0),
        f1_score(all_targets, all_preds, average='macro', zero_division=0),
    )


def train_model(model, train_loader, val_loader, device, epochs=10, model_name='model',
                save_best=True, lr=1e-3, save_dir=""):
    """Train ``model`` with Adam and cross-entropy, tracking validation metrics every epoch.

    After each epoch the model is evaluated on ``val_loader``. It records loss, accuracy,
    and macro-averaged precision, recall and F1. When ``save_best`` is true, the weights from
    the epoch with the highest validation F1 are restored into ``model`` after training and
    saved to ``os.path.join(save_dir, f"{model_name}.pth")``.

    Args:
        model: The ``nn.Module`` to train; it is moved to ``device`` in place.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        device: Device to train and evaluate on.
        epochs: Number of passes over ``train_loader``.
        model_name: File stem of the saved checkpoint.
        save_best: Restore and save the best-validation-F1 weights after training.
        lr: Adam learning rate.
        save_dir: Directory for the checkpoint; ``""`` means the current directory.

    Returns:
        ``(model, history)``, where ``history`` maps each metric name to a per-epoch list.
    """
    import copy

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # -inf so the first epoch is always recorded, even if its F1 is exactly 0.
    best_f1 = float("-inf")
    best_model_state = None

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_precision": [],
        "val_recall": [],
        "val_f1": [],
    }

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, acc, precision, recall, f1 = _evaluate(model, val_loader, criterion, device)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(acc)
        history["val_precision"].append(precision)
        history["val_recall"].append(recall)
        history["val_f1"].append(f1)

        if save_best and f1 > best_f1:
            best_f1 = f1
            # state_dict() returns references to the live parameter tensors; without a
            # deep copy, later optimizer steps would silently overwrite this "best" snapshot.
            best_model_state = copy.deepcopy(model.state_dict())

        print(f"📘 Epoch {epoch+1}/{epochs}")
        print(f"   🔹 Train Loss: {train_loss:.4f}")
        print(f"   🔸 Val Loss: {val_loss:.4f} | Acc: {acc:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")

    print("✅ Training complete")

    if save_best and best_model_state is not None:
        model.load_state_dict(best_model_state)
        save_path = os.path.join(save_dir, f"{model_name}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"✅ Model saved as {save_path}")

    return model, history


In [8]:
# Build the loaders once; the CNN cells below reuse them, so all three models are
# trained and validated on exactly the same balanced subset.
# TODO: test_loader is not used yet; add a final test-set evaluation for each model.
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.1)

# train_model returns (model, history); keep both under model-specific names.
logreg_model, logreg_history = train_model(
    LogisticRegression(), train_loader, val_loader, device, model_name="logistic"
)

ðŸ“˜ Epoch 1/10
   ðŸ”¹ Train Loss: 1.7828
   ðŸ”¸ Val Loss: 1.5783 | Acc: 0.4192 | Precision: 0.4191 | Recall: 0.4299 | F1: 0.4018
ðŸ“˜ Epoch 2/10
   ðŸ”¹ Train Loss: 1.5121
   ðŸ”¸ Val Loss: 1.4967 | Acc: 0.4592 | Precision: 0.4813 | Recall: 0.4663 | F1: 0.4588
ðŸ“˜ Epoch 3/10
   ðŸ”¹ Train Loss: 1.4583
   ðŸ”¸ Val Loss: 1.4479 | Acc: 0.4783 | Precision: 0.4806 | Recall: 0.4875 | F1: 0.4681
ðŸ“˜ Epoch 4/10
   ðŸ”¹ Train Loss: 1.4246
   ðŸ”¸ Val Loss: 1.4288 | Acc: 0.4883 | Precision: 0.4739 | Recall: 0.4955 | F1: 0.4741
ðŸ“˜ Epoch 5/10
   ðŸ”¹ Train Loss: 1.4200
   ðŸ”¸ Val Loss: 1.4425 | Acc: 0.4775 | Precision: 0.4751 | Recall: 0.4872 | F1: 0.4646
ðŸ“˜ Epoch 6/10
   ðŸ”¹ Train Loss: 1.4063
   ðŸ”¸ Val Loss: 1.4478 | Acc: 0.4758 | Precision: 0.4849 | Recall: 0.4855 | F1: 0.4674
ðŸ“˜ Epoch 7/10
   ðŸ”¹ Train Loss: 1.4048
   ðŸ”¸ Val Loss: 1.4169 | Acc: 0.5033 | Precision: 0.5022 | Recall: 0.5094 | F1: 0.4953
ðŸ“˜ Epoch 8/10
   ðŸ”¹ Train Loss: 1.4014
   ðŸ”¸ Val Loss: 1.4279 | Acc: 0

In [9]:
# Standard CNN, trained on the same loaders as logistic regression so the comparison is fair
# (calling get_dataloaders again could draw a different random subset).
cnn_model, cnn_history = train_model(
    SimpleCNN(), train_loader, val_loader, device, model_name="cnn"
)

ðŸ“˜ Epoch 1/10
   ðŸ”¹ Train Loss: 1.5425
   ðŸ”¸ Val Loss: 1.0307 | Acc: 0.6525 | Precision: 0.6817 | Recall: 0.6555 | F1: 0.6350
ðŸ“˜ Epoch 2/10
   ðŸ”¹ Train Loss: 0.8182
   ðŸ”¸ Val Loss: 0.6679 | Acc: 0.7833 | Precision: 0.7864 | Recall: 0.7808 | F1: 0.7797
ðŸ“˜ Epoch 3/10
   ðŸ”¹ Train Loss: 0.6235
   ðŸ”¸ Val Loss: 0.5569 | Acc: 0.8225 | Precision: 0.8285 | Recall: 0.8215 | F1: 0.8208
ðŸ“˜ Epoch 4/10
   ðŸ”¹ Train Loss: 0.5302
   ðŸ”¸ Val Loss: 0.4547 | Acc: 0.8575 | Precision: 0.8672 | Recall: 0.8576 | F1: 0.8585
ðŸ“˜ Epoch 5/10
   ðŸ”¹ Train Loss: 0.4568
   ðŸ”¸ Val Loss: 0.4279 | Acc: 0.8575 | Precision: 0.8640 | Recall: 0.8578 | F1: 0.8574
ðŸ“˜ Epoch 6/10
   ðŸ”¹ Train Loss: 0.4144
   ðŸ”¸ Val Loss: 0.3802 | Acc: 0.8833 | Precision: 0.8885 | Recall: 0.8826 | F1: 0.8837
ðŸ“˜ Epoch 7/10
   ðŸ”¹ Train Loss: 0.3704
   ðŸ”¸ Val Loss: 0.3356 | Acc: 0.8958 | Precision: 0.8965 | Recall: 0.8958 | F1: 0.8952
ðŸ“˜ Epoch 8/10
   ðŸ”¹ Train Loss: 0.3376
   ðŸ”¸ Val Loss: 0.3053 | Acc: 0

In [10]:
# Rotation-equivariant CNN, trained on the same loaders as the two models above
# (calling get_dataloaders again could draw a different random subset).
rotcnn_model, rotcnn_history = train_model(
    RotEquivariantCNN(), train_loader, val_loader, device, model_name="rotcnn"
)

  full_mask[mask] = norms.to(torch.uint8)


ðŸ“˜ Epoch 1/10
   ðŸ”¹ Train Loss: 2.2907
   ðŸ”¸ Val Loss: 0.9097 | Acc: 0.6783 | Precision: 0.7261 | Recall: 0.6763 | F1: 0.6751
ðŸ“˜ Epoch 2/10
   ðŸ”¹ Train Loss: 0.6852
   ðŸ”¸ Val Loss: 0.5948 | Acc: 0.7992 | Precision: 0.8465 | Recall: 0.7980 | F1: 0.8067
ðŸ“˜ Epoch 3/10
   ðŸ”¹ Train Loss: 0.5152
   ðŸ”¸ Val Loss: 0.4631 | Acc: 0.8683 | Precision: 0.8744 | Recall: 0.8692 | F1: 0.8667
ðŸ“˜ Epoch 4/10
   ðŸ”¹ Train Loss: 0.4348
   ðŸ”¸ Val Loss: 0.3777 | Acc: 0.8825 | Precision: 0.8863 | Recall: 0.8833 | F1: 0.8830
ðŸ“˜ Epoch 5/10
   ðŸ”¹ Train Loss: 0.4089
   ðŸ”¸ Val Loss: 0.4439 | Acc: 0.8567 | Precision: 0.8803 | Recall: 0.8564 | F1: 0.8588
ðŸ“˜ Epoch 6/10
   ðŸ”¹ Train Loss: 0.3271
   ðŸ”¸ Val Loss: 0.3170 | Acc: 0.8925 | Precision: 0.8962 | Recall: 0.8934 | F1: 0.8929
ðŸ“˜ Epoch 7/10
   ðŸ”¹ Train Loss: 0.3092
   ðŸ”¸ Val Loss: 0.2969 | Acc: 0.9167 | Precision: 0.9220 | Recall: 0.9168 | F1: 0.9178
ðŸ“˜ Epoch 8/10
   ðŸ”¹ Train Loss: 0.2981
   ðŸ”¸ Val Loss: 0.2981 | Acc: 0