# Rotated MNIST: Model Training Notebook (Standalone)

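This notebook trains and compares three classifiers on a class-balanced subset of the Rotated MNIST dataset:
- Logistic Regression
- Standard CNN
- Rotation-Equivariant CNN (E2CNN)

Here "Rotated MNIST" means the standard MNIST images with a fresh random rotation in [-180°, 180°] applied every time an image is loaded (training, validation and test data alike).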

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
from e2cnn import gspaces
from e2cnn import nn as enn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random, os
    

In [2]:

# Compute device: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Reproducibility: seed Python's `random` and PyTorch's global RNG so stochastic
# steps (e.g. weight initialisation, DataLoader shuffling) repeat across runs.
seed = 42
random.seed(seed)
torch.manual_seed(seed)
    

In [3]:

# Data preparation: class-balanced subset, then a random (non-stratified) 80/20 train/val split.
def get_dataloaders(batch_size=64, subset_fraction=1.0, seed=42):
    """Build train/val/test DataLoaders for MNIST with random rotations applied on load.

    Every image (train, val and test) goes through ``RandomRotation((-180, 180))``,
    so a new rotation is drawn each time a sample is accessed. Validation and test
    metrics therefore vary slightly from pass to pass.

    Args:
        batch_size: Batch size for all three loaders.
        subset_fraction: Fraction of the MNIST training set to keep. Samples are drawn
            so that each of the 10 classes contributes the same number of images.
            Values >= 1.0 use the full training set. Values too large for the
            smallest class make ``random.Random.sample`` raise ``ValueError``.
        seed: Seed for the class-balanced subsampling and the train/val split.
            Repeated calls with the same arguments return the same subset and split,
            independent of the global ``random``/``torch`` RNG state.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the train loader shuffles.

    Raises:
        ValueError: if ``subset_fraction`` is not positive.
    """
    if subset_fraction <= 0:
        raise ValueError(f"subset_fraction must be > 0, got {subset_fraction}")

    transform = transforms.Compose([
        transforms.RandomRotation(degrees=(-180, 180)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    full_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)

    if subset_fraction < 1.0:
        # Read labels from the `targets` tensor rather than iterating the dataset.
        # Iterating would decode, rotate and normalise every image just to get its label.
        labels = full_dataset.targets.tolist()
        class_indices = {c: [] for c in range(10)}
        for idx, label in enumerate(labels):
            class_indices[label].append(idx)

        samples_per_class = int((len(full_dataset) * subset_fraction) / 10)
        # Local RNG: the subset depends only on `seed`, not on global `random` state.
        rng = random.Random(seed)
        balanced_indices = [
            idx
            for c in range(10)
            for idx in rng.sample(class_indices[c], samples_per_class)
        ]
        rng.shuffle(balanced_indices)
        subset = Subset(full_dataset, balanced_indices)
    else:
        subset = full_dataset

    train_size = int(0.8 * len(subset))
    val_size = len(subset) - train_size
    # Dedicated generator so the split is reproducible from `seed` alone.
    split_generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(subset, [train_size, val_size], generator=split_generator)

    test_set = datasets.MNIST('data', train=False, download=True, transform=transform)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
        DataLoader(test_set, batch_size=batch_size)
    )
    

In [4]:

# Simple Logistic Regression model
class LogisticRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.linear(x)
    

In [5]:
class SimpleCNN(nn.Module):
    """Small convolutional baseline for single-channel 28x28 images.

    Shape trace for a (B, 1, 28, 28) input:
        conv 3x3 (no padding) -> (B, 32, 26, 26)
        conv 3x3 (no padding) -> (B, 64, 24, 24)
        2x2 max-pool          -> (B, 64, 12, 12)
        flatten               -> (B, 9216)
        fc 9216 -> 128 -> num_classes logits
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Creation order (conv1, conv2, fc1, fc2) fixes parameter init order under a seed.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        feature_map = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = F.max_pool2d(feature_map, 2)
        hidden = F.relu(self.fc1(pooled.view(len(pooled), -1)))
        return self.fc2(hidden)

In [6]:
class RotEquivariantCNN(nn.Module):
    """Rotation-equivariant CNN (e2cnn, 8 discrete rotations) with a linear classifier.

    Two equivariant blocks, each an R2Conv (5x5, padding 2, no bias), then ReLU,
    then a 2x2 pointwise max-pool, map a single-channel image to 8 and then 16
    regular-representation fields. The 5x5/padding-2 convs keep the spatial size
    and each pool halves it. ``fc1`` is sized for a 7x7 map, which corresponds to
    28x28 inputs. The flattened feature map is fed to one ``nn.Linear`` that
    produces ``num_classes`` logits.

    NOTE(review): the head flattens orientation channels and spatial positions
    straight into a linear layer, so the logits are not constrained to be
    rotation-invariant. An invariant head (e.g. group pooling followed by global
    spatial pooling) would be needed for that.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        gspace = gspaces.Rot2dOnR2(N=8)

        self.input_type = enn.FieldType(gspace, [gspace.trivial_repr])
        hidden_type_1 = enn.FieldType(gspace, 8 * [gspace.regular_repr])
        hidden_type_2 = enn.FieldType(gspace, 16 * [gspace.regular_repr])

        # Built in the same order as before: block1 conv, block2 conv, then fc1.
        self.block1 = self._make_block(self.input_type, hidden_type_1)
        self.block2 = self._make_block(self.block1.out_type, hidden_type_2)

        flat_features = self.block2.out_type.size * 7 * 7
        self.fc1 = nn.Linear(flat_features, num_classes)

    @staticmethod
    def _make_block(in_type, out_type):
        # One equivariant stage: 5x5 conv (size-preserving), ReLU, 2x2 max-pool.
        return enn.SequentialModule(
            enn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            enn.ReLU(out_type, inplace=True),
            enn.PointwiseMaxPool(out_type, kernel_size=2),
        )

    def forward(self, x):
        features = enn.GeometricTensor(x, self.input_type)
        for block in (self.block1, self.block2):
            features = block(features)
        raw = features.tensor
        return self.fc1(raw.view(raw.size(0), -1))


In [None]:
def _run_train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without gradient tracking.

    Returns ``(loss, accuracy, precision, recall, f1)``. ``loss`` is the mean
    per-batch criterion value. The other four are computed over all samples,
    with precision/recall/F1 macro-averaged (``zero_division=0`` sets ill-defined
    per-class scores to 0).
    """
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()

            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    mean_loss = running_loss / len(loader)
    acc = accuracy_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
    return mean_loss, acc, precision, recall, f1


def train_model(model, train_loader, val_loader, device, epochs=5, model_name='model', save_best=True, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, tracking validation metrics per epoch.

    Args:
        model: Classifier returning raw logits.
        train_loader: Loader of ``(inputs, targets)`` training batches.
        val_loader: Loader of ``(inputs, targets)`` validation batches.
        device: Device the model and batches are moved to.
        epochs: Number of passes over ``train_loader``.
        model_name: Base name of the checkpoint file (``f"{model_name}.pth"``).
        save_best: If true, restore the weights from the epoch with the highest
            validation macro-F1 at the end, and save them to the checkpoint file
            in the current working directory.
        lr: Adam learning rate (default 1e-3, the previously hard-coded value).

    Returns:
        ``(model, history)``. ``history`` maps ``train_loss``, ``val_loss``,
        ``val_accuracy``, ``val_precision``, ``val_recall`` and ``val_f1`` to
        per-epoch lists.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Start below any possible F1 so the first epoch is always recorded, even if its F1 is 0.
    best_f1 = float("-inf")
    best_model_state = None

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_accuracy": [],
        "val_precision": [],
        "val_recall": [],
        "val_f1": [],
    }

    for epoch in range(epochs):
        train_loss = _run_train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, acc, precision, recall, f1 = _evaluate(model, val_loader, criterion, device)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(acc)
        history["val_precision"].append(precision)
        history["val_recall"].append(recall)
        history["val_f1"].append(f1)

        if save_best and f1 > best_f1:
            best_f1 = f1
            # state_dict() returns references to the live tensors, which later optimizer
            # steps update in place. Clone them to freeze this epoch's weights.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        print(f"📘 Epoch {epoch+1}/{epochs}")
        print(f"   🔹 Train Loss: {train_loss:.4f}")
        print(f"   🔸 Val Loss: {val_loss:.4f} | Acc: {acc:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")

    print("✅ Training complete")

    if save_best and best_model_state is not None:
        model.load_state_dict(best_model_state)
        save_path = f"{model_name}.pth"
        torch.save(model.state_dict(), save_path)
        print(f"✅ Model saved as {save_path}")

    return model, history

    

In [8]:
# Run training: logistic-regression baseline on a 5% class-balanced subset.
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = LogisticRegression()
# train_model returns (model, history). Unpack it so `trained_model` is the model
# itself rather than a tuple, and keep the per-epoch metrics in `history`.
trained_model, history = train_model(model, train_loader, val_loader, device, model_name="logistic")

📘 Epoch 1/10
   🔹 Train Loss: 1.9761
   🔸 Val Loss: 1.6987 | Acc: 0.3950 | Precision: 0.3892 | Recall: 0.4003 | F1: 0.3688
📘 Epoch 2/10
   🔹 Train Loss: 1.6057
   🔸 Val Loss: 1.5364 | Acc: 0.4417 | Precision: 0.4205 | Recall: 0.4466 | F1: 0.4100
📘 Epoch 3/10
   🔹 Train Loss: 1.5110
   🔸 Val Loss: 1.5069 | Acc: 0.4517 | Precision: 0.4419 | Recall: 0.4537 | F1: 0.4346
📘 Epoch 4/10
   🔹 Train Loss: 1.4745
   🔸 Val Loss: 1.4629 | Acc: 0.4650 | Precision: 0.4526 | Recall: 0.4711 | F1: 0.4502
📘 Epoch 5/10
   🔹 Train Loss: 1.4576
   🔸 Val Loss: 1.4529 | Acc: 0.4833 | Precision: 0.4676 | Recall: 0.4922 | F1: 0.4606
📘 Epoch 6/10
   🔹 Train Loss: 1.4291
   🔸 Val Loss: 1.4382 | Acc: 0.4633 | Precision: 0.4517 | Recall: 0.4758 | F1: 0.4413
📘 Epoch 7/10
   🔹 Train Loss: 1.4082
   🔸 Val Loss: 1.4190 | Acc: 0.4900 | Precision: 0.4698 | Recall: 0.4930 | F1: 0.4672
📘 Epoch 8/10
   🔹 Train Loss: 1.3988
   🔸 Val Loss: 1.4237 | Acc: 0

In [9]:
# Run training: standard CNN on a 5% class-balanced subset.
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = SimpleCNN()
# train_model returns (model, history). Unpack it so `trained_model` is the model
# itself rather than a tuple, and keep the per-epoch metrics in `history`.
trained_model, history = train_model(model, train_loader, val_loader, device, model_name="cnn")

📘 Epoch 1/10
   🔹 Train Loss: 1.7399
   🔸 Val Loss: 1.4236 | Acc: 0.5300 | Precision: 0.5403 | Recall: 0.5307 | F1: 0.5218
📘 Epoch 2/10
   🔹 Train Loss: 1.1725
   🔸 Val Loss: 1.0356 | Acc: 0.6683 | Precision: 0.6843 | Recall: 0.6668 | F1: 0.6533
📘 Epoch 3/10
   🔹 Train Loss: 0.8677
   🔸 Val Loss: 0.8877 | Acc: 0.7200 | Precision: 0.7284 | Recall: 0.7165 | F1: 0.7163
📘 Epoch 4/10
   🔹 Train Loss: 0.7379
   🔸 Val Loss: 0.7421 | Acc: 0.7583 | Precision: 0.7577 | Recall: 0.7567 | F1: 0.7555
📘 Epoch 5/10
   🔹 Train Loss: 0.6461
   🔸 Val Loss: 0.7313 | Acc: 0.7600 | Precision: 0.7634 | Recall: 0.7568 | F1: 0.7570
📘 Epoch 6/10
   🔹 Train Loss: 0.5801
   🔸 Val Loss: 0.6998 | Acc: 0.7500 | Precision: 0.7671 | Recall: 0.7481 | F1: 0.7481
📘 Epoch 7/10
   🔹 Train Loss: 0.5265
   🔸 Val Loss: 0.6234 | Acc: 0.8067 | Precision: 0.8056 | Recall: 0.8056 | F1: 0.8045
📘 Epoch 8/10
   🔹 Train Loss: 0.4776
   🔸 Val Loss: 0.5591 | Acc: 0

In [10]:
# Run training: rotation-equivariant CNN (e2cnn) on a 5% class-balanced subset.
train_loader, val_loader, test_loader = get_dataloaders(subset_fraction=0.05)
model = RotEquivariantCNN()
# train_model returns (model, history). Unpack it so `trained_model` is the model
# itself rather than a tuple, and keep the per-epoch metrics in `history`.
trained_model, history = train_model(model, train_loader, val_loader, device, model_name="rotcnn")

  full_mask[mask] = norms.to(torch.uint8)


📘 Epoch 1/10
   🔹 Train Loss: 2.8032
   🔸 Val Loss: 1.4244 | Acc: 0.4950 | Precision: 0.6331 | Recall: 0.5080 | F1: 0.4667
📘 Epoch 2/10
   🔹 Train Loss: 1.1412
   🔸 Val Loss: 0.9092 | Acc: 0.6867 | Precision: 0.7324 | Recall: 0.7026 | F1: 0.6728
📘 Epoch 3/10
   🔹 Train Loss: 0.8364
   🔸 Val Loss: 0.7002 | Acc: 0.7733 | Precision: 0.7917 | Recall: 0.7833 | F1: 0.7657
📘 Epoch 4/10
   🔹 Train Loss: 0.6760
   🔸 Val Loss: 0.7131 | Acc: 0.7583 | Precision: 0.8140 | Recall: 0.7699 | F1: 0.7603
📘 Epoch 5/10
   🔹 Train Loss: 0.5724
   🔸 Val Loss: 0.5077 | Acc: 0.8400 | Precision: 0.8532 | Recall: 0.8449 | F1: 0.8391
📘 Epoch 6/10
   🔹 Train Loss: 0.4944
   🔸 Val Loss: 0.4629 | Acc: 0.8533 | Precision: 0.8623 | Recall: 0.8549 | F1: 0.8556
📘 Epoch 7/10
   🔹 Train Loss: 0.4452
   🔸 Val Loss: 0.4394 | Acc: 0.8733 | Precision: 0.8801 | Recall: 0.8769 | F1: 0.8747
📘 Epoch 8/10
   🔹 Train Loss: 0.4179
   🔸 Val Loss: 0.4177 | Acc: 0