# Rotated MNIST: Model Training Notebook (Standalone)

This notebook trains and compares:
- Logistic Regression
- Standard CNN
- Rotation-Equivariant CNN (E2CNN) on a class-balanced subset of the Rotated MNIST dataset.

In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
from e2cnn import gspaces
from e2cnn import nn as enn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random, os
    

In [36]:

# Set seed and device
seed = 42
random.seed(seed)
torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    

In [37]:

# Data preparation with balanced validation split
def get_dataloaders(batch_size=64, subset_fraction=1.0, seed=42):
    transform = transforms.Compose([
        transforms.RandomRotation(degrees=(-180, 180)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    
    if subset_fraction < 1.0:
        class_indices = {i: [] for i in range(10)}
        for idx, (_, label) in enumerate(full_dataset):
            class_indices[label].append(idx)
        samples_per_class = int((len(full_dataset) * subset_fraction) / 10)
        balanced_indices = [random.sample(class_indices[i], samples_per_class) for i in range(10)]
        balanced_indices = [item for sublist in balanced_indices for item in sublist]
        random.shuffle(balanced_indices)
        subset = Subset(full_dataset, balanced_indices)
    else:
        subset = full_dataset

    train_size = int(0.8 * len(subset))
    val_size = len(subset) - train_size
    train_set, val_set = random_split(subset, [train_size, val_size])

    test_set = datasets.MNIST('data', train=False, download=True, transform=transform)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
        DataLoader(test_set, batch_size=batch_size)
    )
    

In [38]:

# Simple Logistic Regression model
class LogisticRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.linear(x)
    

In [39]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [40]:
class RotEquivariantCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        r2_act = gspaces.Rot2dOnR2(N=8)

        in_type = enn.FieldType(r2_act, [r2_act.trivial_repr])
        self.input_type = in_type

        self.block1 = enn.SequentialModule(
            enn.R2Conv(
                in_type,
                enn.FieldType(r2_act, 8 * [r2_act.regular_repr]),
                kernel_size=5,
                padding=2,
                bias=False,
            ),
            enn.ReLU(enn.FieldType(r2_act, 8 * [r2_act.regular_repr]), inplace=True),
            enn.PointwiseMaxPool(
                enn.FieldType(r2_act, 8 * [r2_act.regular_repr]), kernel_size=2
            ),
        )

        self.block2 = enn.SequentialModule(
            enn.R2Conv(
                self.block1.out_type,
                enn.FieldType(r2_act, 16 * [r2_act.regular_repr]),
                kernel_size=5,
                padding=2,
                bias=False,
            ),
            enn.ReLU(enn.FieldType(r2_act, 16 * [r2_act.regular_repr]), inplace=True),
            enn.PointwiseMaxPool(
                enn.FieldType(r2_act, 16 * [r2_act.regular_repr]), kernel_size=2
            ),
        )

        c = self.block2.out_type.size
        self.fc1 = nn.Linear(c * 7 * 7, num_classes)

    def forward(self, x):
        x = enn.GeometricTensor(x, self.input_type)
        x = self.block1(x)
        x = self.block2(x)
        x = x.tensor
        x = x.view(x.size(0), -1)
        return self.fc1(x)


In [None]:
def train_model(model, train_loader, val_loader, device, epochs=10, model_name='model', save_best=True):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    best_f1 = 0
    best_model_state = None

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_precision": [],
        "val_recall": [],
        "val_f1": [],
    }

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0  # running loss added
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # accumulate loss

        train_loss = running_loss / len(train_loader)  # compute average
        history["train_loss"].append(train_loss)

        # Validation
        model.eval()
        val_running_loss = 0.0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_running_loss += loss.item() * inputs.size(0)

                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_loss = val_running_loss / len(val_loader.dataset)
        acc = accuracy_score(all_targets, all_preds)
        precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
        recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(acc)
        history["val_precision"].append(precision)
        history["val_recall"].append(recall)
        history["val_f1"].append(f1)

        if save_best and f1 > best_f1:
            best_f1 = f1
            best_model_state = model.state_dict()

        print(f"📘 Epoch {epoch+1}/{epochs}")
        print(f"   🔹 Train Loss: {train_loss:.4f}")
        print(f"   🔸 Val Loss: {val_loss:.4f} | Acc: {acc:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")

    print("✅ Training complete")

    if save_best and best_model_state is not None:
        model.load_state_dict(best_model_state)
        save_path = os.path.join(f"{model_name}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"✅ Model saved as {save_path}")

    return model, history


In [43]:
# Run training
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = LogisticRegression()
trained_model = train_model(model, train_loader, val_loader, device, model_name="logistic")

UnboundLocalError: local variable 'val_running_loss' referenced before assignment

In [None]:
# Run training
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = SimpleCNN()
trained_model = train_model(model, train_loader, val_loader, device, model_name="cnn")

📘 Epoch 1/10
   🔹 Train Loss: 1.7399
   🔸 Val Loss: 1.4236 | Acc: 0.5300 | Precision: 0.5403 | Recall: 0.5307 | F1: 0.5218
📘 Epoch 2/10
   🔹 Train Loss: 1.1725
   🔸 Val Loss: 1.0356 | Acc: 0.6683 | Precision: 0.6843 | Recall: 0.6668 | F1: 0.6533
📘 Epoch 3/10
   🔹 Train Loss: 0.8677
   🔸 Val Loss: 0.8877 | Acc: 0.7200 | Precision: 0.7284 | Recall: 0.7165 | F1: 0.7163
📘 Epoch 4/10
   🔹 Train Loss: 0.7379
   🔸 Val Loss: 0.7421 | Acc: 0.7583 | Precision: 0.7577 | Recall: 0.7567 | F1: 0.7555
📘 Epoch 5/10
   🔹 Train Loss: 0.6461
   🔸 Val Loss: 0.7313 | Acc: 0.7600 | Precision: 0.7634 | Recall: 0.7568 | F1: 0.7570
📘 Epoch 6/10
   🔹 Train Loss: 0.5801
   🔸 Val Loss: 0.6998 | Acc: 0.7500 | Precision: 0.7671 | Recall: 0.7481 | F1: 0.7481
📘 Epoch 7/10
   🔹 Train Loss: 0.5265
   🔸 Val Loss: 0.6234 | Acc: 0.8067 | Precision: 0.8056 | Recall: 0.8056 | F1: 0.8045
📘 Epoch 8/10
   🔹 Train Loss: 0.4776
   🔸 Val Loss: 0.5591 | Acc: 0.8083 | Precision: 0.8074 | Recall: 0.8081 | F1: 0.8062
📘 Epoch 9/10
   

In [None]:
# Run training
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = RotEquivariantCNN()
trained_model = train_model(model, train_loader, val_loader, device, model_name="rotcnn")

📘 Epoch 1/10
   🔹 Train Loss: 2.8032
   🔸 Val Loss: 1.4244 | Acc: 0.4950 | Precision: 0.6331 | Recall: 0.5080 | F1: 0.4667
📘 Epoch 2/10
   🔹 Train Loss: 1.1412
   🔸 Val Loss: 0.9092 | Acc: 0.6867 | Precision: 0.7324 | Recall: 0.7026 | F1: 0.6728
📘 Epoch 3/10
   🔹 Train Loss: 0.8364
   🔸 Val Loss: 0.7002 | Acc: 0.7733 | Precision: 0.7917 | Recall: 0.7833 | F1: 0.7657
📘 Epoch 4/10
   🔹 Train Loss: 0.6760
   🔸 Val Loss: 0.7131 | Acc: 0.7583 | Precision: 0.8140 | Recall: 0.7699 | F1: 0.7603
📘 Epoch 5/10
   🔹 Train Loss: 0.5724
   🔸 Val Loss: 0.5077 | Acc: 0.8400 | Precision: 0.8532 | Recall: 0.8449 | F1: 0.8391
📘 Epoch 6/10
   🔹 Train Loss: 0.4944
   🔸 Val Loss: 0.4629 | Acc: 0.8533 | Precision: 0.8623 | Recall: 0.8549 | F1: 0.8556
📘 Epoch 7/10
   🔹 Train Loss: 0.4452
   🔸 Val Loss: 0.4394 | Acc: 0.8733 | Precision: 0.8801 | Recall: 0.8769 | F1: 0.8747
📘 Epoch 8/10
   🔹 Train Loss: 0.4179
   🔸 Val Loss: 0.4177 | Acc: 0.8583 | Precision: 0.8676 | Recall: 0.8643 | F1: 0.8571
📘 Epoch 9/10
   