# Rotated MNIST: Model Training Notebook (Standalone)

This notebook trains and compares:
- Logistic Regression
- Standard CNN
- Rotation-Equivariant CNN (E2CNN) on a class-balanced subset of the Rotated MNIST dataset.

In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
from e2cnn import gspaces
from e2cnn import nn as enn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random, os
    

In [36]:

# Set seed and device
seed = 42
random.seed(seed)
torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    

In [37]:

# Data preparation with balanced validation split
def get_dataloaders(batch_size=64, subset_fraction=1.0, seed=42):
    transform = transforms.Compose([
        transforms.RandomRotation(degrees=(-180, 180)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    
    if subset_fraction < 1.0:
        class_indices = {i: [] for i in range(10)}
        for idx, (_, label) in enumerate(full_dataset):
            class_indices[label].append(idx)
        samples_per_class = int((len(full_dataset) * subset_fraction) / 10)
        balanced_indices = [random.sample(class_indices[i], samples_per_class) for i in range(10)]
        balanced_indices = [item for sublist in balanced_indices for item in sublist]
        random.shuffle(balanced_indices)
        subset = Subset(full_dataset, balanced_indices)
    else:
        subset = full_dataset

    train_size = int(0.8 * len(subset))
    val_size = len(subset) - train_size
    train_set, val_set = random_split(subset, [train_size, val_size])

    test_set = datasets.MNIST('data', train=False, download=True, transform=transform)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
        DataLoader(test_set, batch_size=batch_size)
    )
    

In [38]:

# Simple Logistic Regression model
class LogisticRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.linear(x)
    

In [39]:
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [40]:
class RotEquivariantCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        r2_act = gspaces.Rot2dOnR2(N=8)

        in_type = enn.FieldType(r2_act, [r2_act.trivial_repr])
        self.input_type = in_type

        self.block1 = enn.SequentialModule(
            enn.R2Conv(
                in_type,
                enn.FieldType(r2_act, 8 * [r2_act.regular_repr]),
                kernel_size=5,
                padding=2,
                bias=False,
            ),
            enn.ReLU(enn.FieldType(r2_act, 8 * [r2_act.regular_repr]), inplace=True),
            enn.PointwiseMaxPool(
                enn.FieldType(r2_act, 8 * [r2_act.regular_repr]), kernel_size=2
            ),
        )

        self.block2 = enn.SequentialModule(
            enn.R2Conv(
                self.block1.out_type,
                enn.FieldType(r2_act, 16 * [r2_act.regular_repr]),
                kernel_size=5,
                padding=2,
                bias=False,
            ),
            enn.ReLU(enn.FieldType(r2_act, 16 * [r2_act.regular_repr]), inplace=True),
            enn.PointwiseMaxPool(
                enn.FieldType(r2_act, 16 * [r2_act.regular_repr]), kernel_size=2
            ),
        )

        c = self.block2.out_type.size
        self.fc1 = nn.Linear(c * 7 * 7, num_classes)

    def forward(self, x):
        x = enn.GeometricTensor(x, self.input_type)
        x = self.block1(x)
        x = self.block2(x)
        x = x.tensor
        x = x.view(x.size(0), -1)
        return self.fc1(x)


In [44]:
def train_model(model, train_loader, val_loader, device, epochs=10, model_name='model', save_best=True):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    best_f1 = 0
    best_model_state = None

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_precision": [],
        "val_recall": [],
        "val_f1": [],
    }

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0  # running loss added
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # accumulate loss

        train_loss = running_loss / len(train_loader)  # compute average
        history["train_loss"].append(train_loss)

        # Validation
        model.eval()
        val_running_loss = 0.0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_running_loss += loss.item() * inputs.size(0)

                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_loss = val_running_loss / len(val_loader.dataset)
        acc = accuracy_score(all_targets, all_preds)
        precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
        recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
        f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(acc)
        history["val_precision"].append(precision)
        history["val_recall"].append(recall)
        history["val_f1"].append(f1)

        if save_best and f1 > best_f1:
            best_f1 = f1
            best_model_state = model.state_dict()

        print(f"📘 Epoch {epoch+1}/{epochs}")
        print(f"   🔹 Train Loss: {train_loss:.4f}")
        print(f"   🔸 Val Loss: {val_loss:.4f} | Acc: {acc:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")

    print("✅ Training complete")

    if save_best and best_model_state is not None:
        model.load_state_dict(best_model_state)
        save_path = os.path.join(f"{model_name}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"✅ Model saved as {save_path}")

    return model, history


In [46]:
# Run training
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = LogisticRegression()
trained_model = train_model(model, train_loader, val_loader, device, model_name="logistic")

📘 Epoch 1/10
   🔹 Train Loss: 1.9640
   🔸 Val Loss: 1.7161 | Acc: 0.3783 | Precision: 0.3912 | Recall: 0.3770 | F1: 0.3499
📘 Epoch 2/10
   🔹 Train Loss: 1.6196
   🔸 Val Loss: 1.5756 | Acc: 0.4350 | Precision: 0.4429 | Recall: 0.4405 | F1: 0.4116
📘 Epoch 3/10
   🔹 Train Loss: 1.5235
   🔸 Val Loss: 1.5415 | Acc: 0.4300 | Precision: 0.4561 | Recall: 0.4308 | F1: 0.3986
📘 Epoch 4/10
   🔹 Train Loss: 1.5027
   🔸 Val Loss: 1.4884 | Acc: 0.4433 | Precision: 0.4151 | Recall: 0.4412 | F1: 0.4127
📘 Epoch 5/10
   🔹 Train Loss: 1.4608
   🔸 Val Loss: 1.5241 | Acc: 0.4450 | Precision: 0.4284 | Recall: 0.4434 | F1: 0.4175
📘 Epoch 6/10
   🔹 Train Loss: 1.4606
   🔸 Val Loss: 1.4721 | Acc: 0.4617 | Precision: 0.4331 | Recall: 0.4584 | F1: 0.4350
📘 Epoch 7/10
   🔹 Train Loss: 1.4356
   🔸 Val Loss: 1.4293 | Acc: 0.4933 | Precision: 0.4924 | Recall: 0.4887 | F1: 0.4825
📘 Epoch 8/10
   🔹 Train Loss: 1.4400
   🔸 Val Loss: 1.4326 | Acc: 0.4700 | Precision: 0.4427 | Recall: 0.4635 | F1: 0.4420
📘 Epoch 9/10
   

In [47]:
# Run training
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = SimpleCNN()
trained_model = train_model(model, train_loader, val_loader, device, model_name="cnn")

📘 Epoch 1/10
   🔹 Train Loss: 1.7102
   🔸 Val Loss: 1.4398 | Acc: 0.4917 | Precision: 0.5110 | Recall: 0.4971 | F1: 0.4715
📘 Epoch 2/10
   🔹 Train Loss: 1.1767
   🔸 Val Loss: 1.0385 | Acc: 0.6317 | Precision: 0.6621 | Recall: 0.6227 | F1: 0.6195
📘 Epoch 3/10
   🔹 Train Loss: 0.9058
   🔸 Val Loss: 0.8513 | Acc: 0.6967 | Precision: 0.7347 | Recall: 0.6956 | F1: 0.6979
📘 Epoch 4/10
   🔹 Train Loss: 0.7650
   🔸 Val Loss: 0.7672 | Acc: 0.7383 | Precision: 0.7727 | Recall: 0.7381 | F1: 0.7320
📘 Epoch 5/10
   🔹 Train Loss: 0.6641
   🔸 Val Loss: 0.6436 | Acc: 0.7867 | Precision: 0.7885 | Recall: 0.7809 | F1: 0.7788
📘 Epoch 6/10
   🔹 Train Loss: 0.6045
   🔸 Val Loss: 0.5613 | Acc: 0.8183 | Precision: 0.8236 | Recall: 0.8156 | F1: 0.8162
📘 Epoch 7/10
   🔹 Train Loss: 0.5239
   🔸 Val Loss: 0.5405 | Acc: 0.8250 | Precision: 0.8367 | Recall: 0.8252 | F1: 0.8261
📘 Epoch 8/10
   🔹 Train Loss: 0.4917
   🔸 Val Loss: 0.4793 | Acc: 0.8433 | Precision: 0.8558 | Recall: 0.8398 | F1: 0.8405
📘 Epoch 9/10
   

In [48]:
# Run training
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = RotEquivariantCNN()
trained_model = train_model(model, train_loader, val_loader, device, model_name="rotcnn")

📘 Epoch 1/10
   🔹 Train Loss: 1.8935
   🔸 Val Loss: 1.0749 | Acc: 0.6300 | Precision: 0.6881 | Recall: 0.6369 | F1: 0.6278
📘 Epoch 2/10
   🔹 Train Loss: 0.8379
   🔸 Val Loss: 0.6432 | Acc: 0.7933 | Precision: 0.7987 | Recall: 0.7925 | F1: 0.7903
📘 Epoch 3/10
   🔹 Train Loss: 0.6159
   🔸 Val Loss: 0.5488 | Acc: 0.8333 | Precision: 0.8398 | Recall: 0.8309 | F1: 0.8327
📘 Epoch 4/10
   🔹 Train Loss: 0.4823
   🔸 Val Loss: 0.4544 | Acc: 0.8733 | Precision: 0.8758 | Recall: 0.8750 | F1: 0.8722
📘 Epoch 5/10
   🔹 Train Loss: 0.4233
   🔸 Val Loss: 0.4435 | Acc: 0.8533 | Precision: 0.8584 | Recall: 0.8533 | F1: 0.8504
📘 Epoch 6/10
   🔹 Train Loss: 0.3671
   🔸 Val Loss: 0.3496 | Acc: 0.8817 | Precision: 0.8809 | Recall: 0.8814 | F1: 0.8806
📘 Epoch 7/10
   🔹 Train Loss: 0.3683
   🔸 Val Loss: 0.3613 | Acc: 0.8917 | Precision: 0.8972 | Recall: 0.8887 | F1: 0.8895
📘 Epoch 8/10
   🔹 Train Loss: 0.3240
   🔸 Val Loss: 0.3498 | Acc: 0.8900 | Precision: 0.8927 | Recall: 0.8895 | F1: 0.8891
📘 Epoch 9/10
   