# Rotated MNIST: Model Training Notebook (Standalone)

This notebook trains and compares three models on a class-balanced subset of the Rotated MNIST dataset:
- Logistic Regression
- Standard CNN
- Rotation-Equivariant CNN (E2CNN)

In [12]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
from e2cnn import gspaces
from e2cnn import nn as enn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random, os
    

In [13]:

# Reproducibility and hardware selection.
# One shared seed is applied to both PyTorch's and Python's global RNGs;
# the compute device is CUDA when available and the CPU otherwise.
seed = 42
torch.manual_seed(seed)
random.seed(seed)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    

In [14]:

# Data preparation: optional class-balanced subset, then a seeded random 80/20 train/validation split
def get_dataloaders(batch_size=64, subset_fraction=1.0, seed=42):
    """Build train, validation and test DataLoaders for randomly rotated MNIST.

    Every image (train, validation and test alike) goes through the same
    transform: a random rotation in [-180, 180] degrees, conversion to a
    tensor, and normalisation with mean 0.1307 / std 0.3081. Because the
    rotation is drawn on every access, evaluation batches are not identical
    between passes.

    Args:
        batch_size: Batch size used by all three loaders.
        subset_fraction: Fraction of the MNIST training set to keep. Values
            below 1.0 draw the same number of samples from each of the 10
            classes; 1.0 (or more) uses the full training set unchanged.
        seed: Seed for the subset sampling and the train/validation split,
            so repeated calls with the same arguments select the same
            images for each split.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the train
        loader shuffles. The test loader always covers the full MNIST test set.
    """
    transform = transforms.Compose([
        transforms.RandomRotation(degrees=(-180, 180)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)

    # Local RNG: honours `seed` without depending on, or disturbing, the
    # global `random` state.
    rng = random.Random(seed)

    if subset_fraction < 1.0:
        # Read labels straight from `.targets`; indexing the dataset would
        # decode, rotate and normalise every image just to get its label.
        class_indices = {i: [] for i in range(10)}
        for idx, label in enumerate(full_dataset.targets.tolist()):
            class_indices[label].append(idx)

        samples_per_class = int((len(full_dataset) * subset_fraction) / 10)
        # MNIST classes are not all the same size, so cap the request at the
        # smallest class: the subset stays balanced and `sample` cannot be
        # asked for more items than a class contains.
        smallest_class = min(len(indices) for indices in class_indices.values())
        samples_per_class = min(samples_per_class, smallest_class)

        balanced_indices = [
            idx
            for digit in range(10)
            for idx in rng.sample(class_indices[digit], samples_per_class)
        ]
        rng.shuffle(balanced_indices)
        subset = Subset(full_dataset, balanced_indices)
    else:
        subset = full_dataset

    train_size = int(0.8 * len(subset))
    val_size = len(subset) - train_size
    # The split is random (not stratified); seeding its generator makes it repeatable.
    split_generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(
        subset, [train_size, val_size], generator=split_generator
    )

    test_set = datasets.MNIST('data', train=False, download=True, transform=transform)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
        DataLoader(test_set, batch_size=batch_size)
    )
    

In [15]:

# Simple Logistic Regression model
class LogisticRegression(nn.Module):
    """Multinomial logistic regression: one linear layer over flattened inputs.

    Args:
        input_dim: Number of features per sample after flattening.
            Defaults to 28 * 28.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, input_dim=28 * 28, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``.

        All dimensions after the first are flattened into one feature axis.
        """
        # flatten (unlike view) also handles non-contiguous and empty batches.
        x = x.flatten(start_dim=1)
        return self.linear(x)
    

In [16]:
class SimpleCNN(nn.Module):
    """Baseline CNN: two 3x3 convolutions, one 2x2 max-pool, two linear layers.

    The first linear layer expects 9216 input features, which is
    64 channels * 12 * 12 for a single-channel 28x28 input
    (28 -> 26 -> 24 through the unpadded convolutions, then 12 after pooling).

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``."""
        # Convolutional feature extractor followed by spatial down-sampling.
        feature_map = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
        # Flatten per sample, then classify.
        flat = feature_map.view(feature_map.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [17]:
class RotEquivariantCNN(nn.Module):
    """CNN built from e2cnn steerable convolutions over ``Rot2dOnR2(N=8)``.

    Two blocks of (R2Conv 5x5 with padding 2, ReLU, pointwise 2x2 max-pool)
    with 8 and then 16 regular-representation fields, followed by an ordinary
    linear layer on the flattened feature map.

    The linear layer is sized for a 7x7 spatial map, which presumably
    corresponds to 28x28 inputs halved by each pooling step (TODO confirm for
    other input sizes).

    NOTE(review): the classifier head is a plain ``nn.Linear`` over the
    flattened features, so only the convolutional blocks use equivariant layers.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        gspace = gspaces.Rot2dOnR2(N=8)

        # Field types: a single trivial (scalar) input field, then 8 and 16
        # regular fields for the two hidden stages.
        self.input_type = enn.FieldType(gspace, [gspace.trivial_repr])
        stage1_type = enn.FieldType(gspace, 8 * [gspace.regular_repr])
        stage2_type = enn.FieldType(gspace, 16 * [gspace.regular_repr])

        self.block1 = enn.SequentialModule(
            enn.R2Conv(self.input_type, stage1_type, kernel_size=5, padding=2, bias=False),
            enn.ReLU(stage1_type, inplace=True),
            enn.PointwiseMaxPool(stage1_type, kernel_size=2),
        )

        self.block2 = enn.SequentialModule(
            enn.R2Conv(self.block1.out_type, stage2_type, kernel_size=5, padding=2, bias=False),
            enn.ReLU(stage2_type, inplace=True),
            enn.PointwiseMaxPool(stage2_type, kernel_size=2),
        )

        # Channel count of the final feature map, as reported by its field type.
        out_channels = self.block2.out_type.size
        self.fc1 = nn.Linear(out_channels * 7 * 7, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``."""
        # Wrap the plain tensor so the equivariant layers know its field type.
        geometric = enn.GeometricTensor(x, self.input_type)
        features = self.block2(self.block1(geometric)).tensor
        return self.fc1(features.view(features.size(0), -1))


In [18]:
def train_model(model, train_loader, val_loader, device, epochs=5, save_best=None):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Args:
        model: The ``nn.Module`` to train; it is moved to ``device`` and
            updated in place.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        device: Device on which the model and batches are placed.
        epochs: Number of passes over ``train_loader``.
        save_best: If truthy, a snapshot of the weights from the epoch with
            the highest validation macro-F1 is kept in memory and loaded back
            into ``model`` before returning. Nothing is written to disk.

    Returns:
        A ``(model, history)`` tuple. ``history`` maps ``"train_loss"``,
        ``"val_loss"``, ``"val_accuracy"``, ``"val_precision"``,
        ``"val_recall"`` and ``"val_f1"`` to lists with one entry per epoch.
        Losses are means over batches; precision, recall and F1 are
        macro-averaged.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    best_f1 = 0
    best_model_state = None

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_precision": [],
        "val_recall": [],
        "val_f1": [],
    }

    for epoch in range(epochs):
        # ---- Training pass ----
        model.train()
        train_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        history["train_loss"].append(train_loss)

        # ---- Validation pass (no gradients, eval-mode layers) ----
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_loss /= len(val_loader)
        acc = accuracy_score(all_targets, all_preds)
        precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
        recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(acc)
        history["val_precision"].append(precision)
        history["val_recall"].append(recall)
        history["val_f1"].append(f1)

        if save_best and f1 > best_f1:
            best_f1 = f1
            # state_dict() hands back references to the live parameter
            # tensors, which later optimizer steps overwrite. Clone them so
            # the snapshot really holds this epoch's weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        print(f"📘 Epoch {epoch+1}/{epochs}")
        print(f"   🔹 Train Loss: {train_loss:.4f}")
        print(f"   🔸 Val Loss: {val_loss:.4f} | Acc: {acc:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")

    print("✅ Training complete")

    if save_best and best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("💾 Loaded best model based on validation F1")

    return model, history

    

In [19]:

# Run training: logistic regression baseline on a 5% class-balanced subset
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = LogisticRegression()
# train_model returns (model, history); unpack so trained_model is the module, not a tuple.
trained_model, history = train_model(model, train_loader, val_loader, device)
    

📘 Epoch 1/5
   🔹 Train Loss: 1.9761
   🔸 Val Loss: 1.6987 | Acc: 0.3950 | Precision: 0.3892 | Recall: 0.4003 | F1: 0.3688
📘 Epoch 2/5
   🔹 Train Loss: 1.6057
   🔸 Val Loss: 1.5364 | Acc: 0.4417 | Precision: 0.4205 | Recall: 0.4466 | F1: 0.4100
📘 Epoch 3/5
   🔹 Train Loss: 1.5110
   🔸 Val Loss: 1.5069 | Acc: 0.4517 | Precision: 0.4419 | Recall: 0.4537 | F1: 0.4346
📘 Epoch 4/5
   🔹 Train Loss: 1.4745
   🔸 Val Loss: 1.4629 | Acc: 0.4650 | Precision: 0.4526 | Recall: 0.4711 | F1: 0.4502
📘 Epoch 5/5
   🔹 Train Loss: 1.4576
   🔸 Val Loss: 1.4529 | Acc: 0.4833 | Precision: 0.4676 | Recall: 0.4922 | F1: 0.4606
✅ Training complete


In [20]:
# Run training: standard CNN on a 5% class-balanced subset
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = SimpleCNN()
# train_model returns (model, history); unpack so trained_model is the module, not a tuple.
trained_model, history = train_model(model, train_loader, val_loader, device)

📘 Epoch 1/5
   🔹 Train Loss: 1.6908
   🔸 Val Loss: 1.3060 | Acc: 0.5483 | Precision: 0.5788 | Recall: 0.5451 | F1: 0.5334
📘 Epoch 2/5
   🔹 Train Loss: 1.0683
   🔸 Val Loss: 0.9979 | Acc: 0.6567 | Precision: 0.7178 | Recall: 0.6512 | F1: 0.6597
📘 Epoch 3/5
   🔹 Train Loss: 0.8534
   🔸 Val Loss: 0.7704 | Acc: 0.7500 | Precision: 0.7647 | Recall: 0.7486 | F1: 0.7452
📘 Epoch 4/5
   🔹 Train Loss: 0.7083
   🔸 Val Loss: 0.6781 | Acc: 0.7983 | Precision: 0.7955 | Recall: 0.7951 | F1: 0.7921
📘 Epoch 5/5
   🔹 Train Loss: 0.6191
   🔸 Val Loss: 0.6913 | Acc: 0.7850 | Precision: 0.8065 | Recall: 0.7774 | F1: 0.7789
✅ Training complete


In [21]:
# Run training: rotation-equivariant CNN on a 5% class-balanced subset
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = RotEquivariantCNN()
# train_model returns (model, history); unpack so trained_model is the module, not a tuple.
trained_model, history = train_model(model, train_loader, val_loader, device)

📘 Epoch 1/5
   🔹 Train Loss: 1.9331
   🔸 Val Loss: 1.1153 | Acc: 0.6133 | Precision: 0.6772 | Recall: 0.6093 | F1: 0.6013
📘 Epoch 2/5
   🔹 Train Loss: 0.8500
   🔸 Val Loss: 0.6679 | Acc: 0.7800 | Precision: 0.8047 | Recall: 0.7793 | F1: 0.7727
📘 Epoch 3/5
   🔹 Train Loss: 0.5973
   🔸 Val Loss: 0.5984 | Acc: 0.7983 | Precision: 0.8223 | Recall: 0.7974 | F1: 0.7958
📘 Epoch 4/5
   🔹 Train Loss: 0.4996
   🔸 Val Loss: 0.4819 | Acc: 0.8467 | Precision: 0.8488 | Recall: 0.8468 | F1: 0.8457
📘 Epoch 5/5
   🔹 Train Loss: 0.4532
   🔸 Val Loss: 0.4243 | Acc: 0.8767 | Precision: 0.8836 | Recall: 0.8776 | F1: 0.8777
✅ Training complete
