In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict
from collections import defaultdict
import random


class ResidualBlock(nn.Module):
    """
    Basic two-convolution residual block.

    Main path: 3x3 conv (with ``stride``) -> BatchNorm -> ReLU -> 3x3 conv (stride 1)
    -> BatchNorm. Skip path: identity when neither the stride nor the channel count
    changes, otherwise a 1x1 strided conv + BatchNorm projection. Both paths are summed
    and passed through a final ReLU.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # Submodules are registered in a fixed order: it determines parameter
        # initialization under a seed and the state_dict / parameters() ordering.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # A projection is needed whenever the main path changes the tensor shape.
        shape_changes = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if shape_changes
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block: relu(main_path(x) + shortcut(x)).
        """
        residual = self.shortcut(x)

        hidden = self.conv1(x)
        hidden = self.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))

        return self.relu(hidden + residual)


class GeometricRegularization(nn.Module):
    """
    Geometric regularization loss for CNN feature maps.

    A feature map of shape (B, C, H, W) is treated, per sample, as a map from the 2-D
    image grid into R^C. After per-(sample, channel) instance normalization the loss is

        lambda_area * min(mean(sqrt(det g)), 10) + lambda_curv * min(mean(|Hc|), 10)

    where g is the 2x2 metric tensor (first fundamental form) built from channel-summed
    products of the first spatial derivatives, and Hc is the curvature-like quantity
    returned by ``compute_mean_curvature``.
    """

    def __init__(self, lambda_area: float = 0.001, lambda_curv: float = 0.001):
        super().__init__()
        self.lambda_area = lambda_area
        self.lambda_curv = lambda_curv
        # Stabilizer added to metric components, determinants and norms so that no
        # division by zero or sqrt(0) (infinite gradient) occurs.
        self.eps = 1e-6

        # Parameter-free normalization of each (sample, channel) slice before
        # differentiation; forward() reshapes inputs to (B*C, 1, H, W) to apply it.
        self.instance_norm = nn.InstanceNorm2d(1, affine=False)

    @staticmethod
    def _pad(feature_map: torch.Tensor, pad: int) -> torch.Tensor:
        """
        Pad both spatial dims of ``feature_map`` by ``pad`` pixels on every side.

        Reflect padding is used whenever it is valid (``pad`` smaller than both spatial
        sizes), which is the original behavior for normal-sized maps. Maps too small for
        reflect padding (which would otherwise raise) fall back to replicate padding.
        """
        height, width = feature_map.shape[-2:]
        mode = "reflect" if pad < min(height, width) else "replicate"
        return F.pad(feature_map, (pad, pad, pad, pad), mode=mode)

    def compute_derivatives(self, feature_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        First spatial derivatives via 3x3 Sobel filters applied depthwise (per channel).

        Args:
            feature_map: Tensor of shape (B, C, H, W)
        Returns:
            du, dv: Derivatives along W (u, horizontal) and along H (v, vertical), each
            of shape (B, C, H, W). Responses are divided by 8 so that a unit-slope ramp
            yields a derivative of 1.
        """
        channels = feature_map.shape[1]
        factory = dict(device=feature_map.device, dtype=feature_map.dtype)

        # One copy of each kernel per channel, for a grouped (depthwise) convolution.
        du_kernel = (
            torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], **factory)
            .view(1, 1, 3, 3)
            .repeat(channels, 1, 1, 1)
        )
        dv_kernel = (
            torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], **factory)
            .view(1, 1, 3, 3)
            .repeat(channels, 1, 1, 1)
        )

        padded = self._pad(feature_map, 1)
        du = F.conv2d(padded, du_kernel, groups=channels) / 8.0
        dv = F.conv2d(padded, dv_kernel, groups=channels) / 8.0

        return du, dv

    def compute_second_derivatives(self, feature_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Second spatial derivatives via finite differences on a 2-pixel-padded map.

        duu / dvv use the step-2 central difference (f(x+2) - 2 f(x) + f(x-2)) / 4 along
        W / H; duv uses the 4-point cross stencil at offsets +-1 divided by 4.

        Args:
            feature_map: Tensor of shape (B, C, H, W)
        Returns:
            duu, dvv, duv: Second derivatives, each of shape (B, C, H, W)
        """
        padded = self._pad(feature_map, 2)
        center = padded[:, :, 2:-2, 2:-2]

        duu = (padded[:, :, 2:-2, 4:] - 2 * center + padded[:, :, 2:-2, :-4]) / 4.0
        dvv = (padded[:, :, 4:, 2:-2] - 2 * center + padded[:, :, :-4, 2:-2]) / 4.0
        duv = (
            padded[:, :, 3:-1, 3:-1] - padded[:, :, 3:-1, 1:-3] - padded[:, :, 1:-3, 3:-1] + padded[:, :, 1:-3, 1:-3]
        ) / 4.0

        return duu, dvv, duv

    def compute_metric_tensor(
        self, du: torch.Tensor, dv: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Metric tensor (first fundamental form) components, summed over channels.

        Args:
            du, dv: First derivatives of shape (B, C, H, W)
        Returns:
            guu, gvv, guv: Tensors of shape (B, H, W); ``eps`` is added to the diagonal
            terms guu and gvv only.
        """
        guu = torch.sum(du * du, dim=1) + self.eps
        gvv = torch.sum(dv * dv, dim=1) + self.eps
        guv = torch.sum(du * dv, dim=1)

        return guu, gvv, guv

    def compute_mean_curvature(
        self,
        du: torch.Tensor,
        dv: torch.Tensor,
        duu: torch.Tensor,
        dvv: torch.Tensor,
        duv: torch.Tensor,
        guu: torch.Tensor,
        gvv: torch.Tensor,
        guv: torch.Tensor,
    ) -> torch.Tensor:
        """
        Curvature-like quantity of shape (B, H, W).

        du and dv are normalized per pixel to (approximately) unit length across
        channels; duu, dvv and duv are projected onto them (duv onto du) and combined as

            (gvv * <duu, du> + guu * <dvv, dv> - 2 * guv * <duv, du>) / (2 * sqrt(det g))

        NOTE(review): this mirrors the classical mean-curvature formula
        (e*G - 2*f*F + g*E) / (2*(EG - F^2)), but projects onto tangent directions rather
        than a surface normal and divides by sqrt(det g) instead of det g. Treat it as a
        heuristic, not the exact mean curvature.
        """
        du_norm = torch.sqrt(torch.sum(du * du, dim=1) + self.eps)
        dv_norm = torch.sqrt(torch.sum(dv * dv, dim=1) + self.eps)

        du_unit = du / (du_norm.unsqueeze(1) + self.eps)
        dv_unit = dv / (dv_norm.unsqueeze(1) + self.eps)

        det_g = guu * gvv - guv * guv + self.eps
        numerator = (
            gvv * torch.sum(duu * du_unit, dim=1)
            + guu * torch.sum(dvv * dv_unit, dim=1)
            - 2 * guv * torch.sum(duv * du_unit, dim=1)
        )
        return numerator / (2 * torch.sqrt(det_g))

    def forward(self, feature_map: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted geometric regularization loss.

        Args:
            feature_map: Tensor of shape (B, C, H, W); need not be contiguous.
        Returns:
            loss: Scalar tensor lambda_area * area_loss + lambda_curv * curvature_loss.
        """
        batch, channels, height, width = feature_map.shape

        # Normalize every (sample, channel) slice independently. reshape (not view) so
        # non-contiguous inputs such as channels_last tensors are accepted.
        normalized = self.instance_norm(feature_map.reshape(batch * channels, 1, height, width))
        normalized = normalized.reshape(batch, channels, height, width)

        du, dv = self.compute_derivatives(normalized)
        duu, dvv, duv = self.compute_second_derivatives(normalized)
        guu, gvv, guv = self.compute_metric_tensor(du, dv)

        # Area term: mean surface element sqrt(det g). clamp() caps the VALUE at 10
        # (above the cap this term contributes no gradient); it is not gradient clipping.
        det_g = guu * gvv - guv * guv + self.eps
        area_loss = torch.clamp(torch.sqrt(det_g).mean(), max=10.0)

        # Curvature term: mean absolute curvature-like quantity, capped the same way.
        curvature = self.compute_mean_curvature(du, dv, duu, dvv, duv, guu, gvv, guv)
        curvature_loss = torch.clamp(torch.abs(curvature).mean(), max=10.0)

        return self.lambda_area * area_loss + self.lambda_curv * curvature_loss


class BaseCNN(nn.Module):
    """
    Baseline classifier: three residual stages, adaptive 4x4 average pooling, and a
    two-layer fully connected head with dropout.

    ``forward`` also returns the output of every residual stage so callers can inspect
    or regularize intermediate feature maps.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Residual stages; the spatial sizes noted assume 32x32 inputs (e.g. CIFAR-10).
        self.layer1 = ResidualBlock(3, 64, stride=1)  # -> 64 channels, 32x32
        self.layer2 = ResidualBlock(64, 128, stride=2)  # -> 128 channels, 16x16
        self.layer3 = ResidualBlock(128, 256, stride=2)  # -> 256 channels, 8x8

        # Fixed-size pooling makes the head independent of the input resolution.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.dropout = nn.Dropout(0.5)

        flat_features = 256 * 4 * 4
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Run the network.

        Returns:
            A pair (logits of shape (B, num_classes), [layer1_out, layer2_out, layer3_out]).
        """
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
            stage_outputs.append(x)

        flat = torch.flatten(self.pool(x), start_dim=1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        logits = self.fc2(hidden)

        return logits, stage_outputs


class GeometricCNN(BaseCNN):
    """
    BaseCNN that additionally owns a ``GeometricRegularization`` module (``self.geo_reg``).

    The forward pass is exactly BaseCNN's: the regularizer is NOT applied inside
    ``forward``. Callers apply it to the returned feature maps themselves (the Trainer
    in this notebook sums ``model.geo_reg(fm)`` over them).
    """

    def __init__(
        self,
        num_classes: int = 10,
        lambda_area: float = 0.001,
        lambda_curv: float = 0.001,
    ):
        """
        Args:
            num_classes: Number of output classes.
            lambda_area: Weight of the area term in the geometric loss.
            lambda_curv: Weight of the curvature term in the geometric loss.
        """
        super(GeometricCNN, self).__init__(num_classes)
        self.geo_reg = GeometricRegularization(lambda_area, lambda_curv)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Forward pass for GeometricCNN.

        Returns:
            - Output logits
            - List of feature maps from the three residual stages, for geometric
              regularization by the caller
        """
        # Delegate rather than duplicate BaseCNN.forward so the two models cannot drift apart.
        return super().forward(x)


class Trainer:
    """
    Trains and evaluates a BaseCNN or GeometricCNN.

    Provides model-type-specific optimizer/scheduler setup, gradient-norm clipping, a
    linear warm-up of the geometric loss weight, cosine learning-rate annealing, and
    patience-based early stopping on validation accuracy.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: torch.device,
        base_model: nn.Module = None,
    ):
        """
        Args:
            model: Model to train; moved to ``device``.
            train_loader: Yields (inputs, labels) batches used for optimization.
            val_loader: Yields (inputs, labels) batches used for accuracy / early stopping.
            device: Device that the model and every batch are moved to.
            base_model: Optional reference model. It is only moved to ``device`` and
                stored as ``self.base_model``; nothing in this class reads it.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        # Decided once; selects the optimizer setup and the geometric-loss code paths.
        self.is_geometric = isinstance(model, GeometricCNN)

        # NOTE(review): the two branches differ in optimizer (AdamW at half LR vs Adam)
        # and in cosine T_max (15 vs 50). With the default 15 training epochs only the
        # geometric model's LR is fully annealed, so comparisons between the two models
        # are confounded by these settings.
        base_lr = 0.001
        if self.is_geometric:
            self.optimizer = optim.AdamW(
                model.parameters(),
                lr=base_lr * 0.5,
                weight_decay=1e-4,
                amsgrad=True,
            )
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=15, eta_min=1e-6)
        else:
            self.optimizer = optim.Adam(model.parameters(), lr=base_lr)
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=50, eta_min=1e-6)

        # Training state
        self.metrics = defaultdict(list)
        self.best_accuracy = 0.0
        self.patience = 0  # consecutive epochs without a new best validation accuracy
        self.max_patience = 15
        self.clip_value = 1.0  # max total gradient L2 norm

        # Geometric loss weight: 0 at epoch 0, rising linearly to 1.0 at epoch 10.
        self.geo_weight_scheduler = lambda epoch: min(1.0, epoch / 10.0)

        # `is not None` rather than truthiness: container modules that define __len__
        # (e.g. an empty nn.Sequential) would otherwise be treated as missing.
        self.base_model = base_model.to(device) if base_model is not None else None

    def train_epoch(self) -> Tuple[float, float]:
        """
        Run one optimization pass over ``train_loader``.

        Returns:
            (mean task loss, mean geometric loss) per batch. The geometric value is the
            unweighted regularizer (before the warm-up weight) and is 0.0 for
            non-geometric models.
        """
        self.model.train()
        running_task_loss = 0.0
        running_geo_loss = 0.0
        # Number of completed epochs: train() appends to metrics["epoch"] after this returns.
        current_epoch = len(self.metrics["epoch"])
        # The warm-up weight is constant within an epoch.
        geo_weight = self.geo_weight_scheduler(current_epoch)

        for inputs, labels in self.train_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()

            outputs, feature_maps = self.model(inputs)
            task_loss = self.criterion(outputs, labels)
            running_task_loss += task_loss.item()

            if self.is_geometric:
                # Regularize every intermediate feature map returned by the model.
                geo_loss = sum(self.model.geo_reg(fm) for fm in feature_maps)
                loss = task_loss + geo_weight * geo_loss
                running_geo_loss += geo_loss.item()
            else:
                loss = task_loss

            # Backward pass with gradient-norm clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_value)
            self.optimizer.step()

        num_batches = len(self.train_loader)
        avg_task_loss = running_task_loss / num_batches
        avg_geo_loss = running_geo_loss / num_batches if self.is_geometric else 0.0

        return avg_task_loss, avg_geo_loss

    def validate(self) -> Tuple[float, float]:
        """
        Evaluate on ``val_loader`` and update early-stopping state.

        Side effects: sets ``best_accuracy`` and resets ``patience`` on a strict
        improvement, otherwise increments ``patience``.

        Returns:
            (accuracy in percent, mean per-batch cross-entropy loss)
        """
        self.model.eval()
        correct = 0
        total = 0
        val_loss = 0.0

        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs, _ = self.model(inputs)

                loss = self.criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        accuracy = 100.0 * correct / total
        avg_val_loss = val_loss / len(self.val_loader)

        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.patience = 0
        else:
            self.patience += 1

        return accuracy, avg_val_loss

    def train(self, epochs: int = 15) -> Dict[str, List[float]]:
        """
        Train for up to ``epochs`` epochs, printing progress and stopping early once
        ``patience`` reaches ``max_patience``; then plot the collected metrics.

        Args:
            epochs: Maximum number of training epochs
        Returns:
            The metrics dict: lists keyed by "epoch" (0-based), "train_loss",
            "val_loss", "val_accuracy", and "geo_loss" for geometric models.
        """
        for epoch in range(epochs):
            train_loss, geo_loss = self.train_epoch()
            accuracy, val_loss = self.validate()

            # Read the LR before stepping the scheduler so the reported value is the one
            # used during this epoch (reading it afterwards reported next epoch's LR).
            epoch_lr = self.optimizer.param_groups[0]["lr"]
            self.scheduler.step()

            self.metrics["epoch"].append(epoch)
            self.metrics["train_loss"].append(train_loss)
            self.metrics["val_loss"].append(val_loss)
            self.metrics["val_accuracy"].append(accuracy)
            if self.is_geometric:
                self.metrics["geo_loss"].append(geo_loss)

            print(f"\nEpoch {epoch + 1}/{epochs}:")
            print(f"Training Loss: {train_loss:.3f}")
            if self.is_geometric:
                print(f"Geometric Loss: {geo_loss:.3f}")
            print(f"Validation Loss: {val_loss:.3f}")
            print(f"Validation Accuracy: {accuracy:.2f}%")
            print(f"Learning Rate: {epoch_lr:.6f}")

            if self.patience >= self.max_patience:
                print(f"\nEarly stopping triggered after {epoch + 1} epochs")
                print(f"Best validation accuracy: {self.best_accuracy:.2f}%")
                break

        self.plot_training_metrics()

        return self.metrics

    def plot_training_metrics(self):
        """
        Plot loss curves and validation accuracy (plus the geometric loss for
        geometric models) from ``self.metrics`` in a single row of subplots.
        """
        n_plots = 3 if self.is_geometric else 2
        fig, axes = plt.subplots(1, n_plots, figsize=(6 * n_plots, 5))
        fig.suptitle("Training Progress", fontsize=16, y=1.05)

        epochs = self.metrics["epoch"]

        # Training and validation loss
        axes[0].plot(epochs, self.metrics["train_loss"], "b-", label="Training Loss")
        axes[0].plot(epochs, self.metrics["val_loss"], "r-", label="Validation Loss")
        axes[0].set_title("Loss Curves")
        axes[0].set_xlabel("Epoch")
        axes[0].set_ylabel("Loss")
        axes[0].legend()
        axes[0].grid(True)

        # Validation accuracy
        axes[1].plot(epochs, self.metrics["val_accuracy"], "g-")
        axes[1].set_title("Validation Accuracy")
        axes[1].set_xlabel("Epoch")
        axes[1].set_ylabel("Accuracy (%)")
        axes[1].grid(True)

        # Unweighted geometric regularization loss
        if self.is_geometric:
            axes[2].plot(epochs, self.metrics["geo_loss"], "m-")
            axes[2].set_title("Geometric Regularization Loss")
            axes[2].set_xlabel("Epoch")
            axes[2].set_ylabel("Loss")
            axes[2].grid(True)

        plt.tight_layout()
        plt.show()


def plot_training_comparison(base_metrics: Dict[str, List[float]], geo_metrics: Dict[str, List[float]]) -> None:
    """
    Create side-by-side comparison of base CNN and geometric CNN performance.

    Left panel: training loss; right panel: validation accuracy. Each run is plotted
    against its own 1-based epoch axis, so runs of different length (e.g. when one
    stopped early) are supported.

    Args:
        base_metrics: Training metrics from base CNN ("train_loss", "val_accuracy" lists)
        geo_metrics: Training metrics from geometric CNN (same keys)
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    fig.suptitle("Performance Comparison: Base CNN vs Geometric CNN", fontsize=16, y=1.05)

    # Separate x-axes: sharing the base run's length raised a shape-mismatch error
    # whenever the two runs trained for a different number of epochs.
    base_epochs = range(1, len(base_metrics["train_loss"]) + 1)
    geo_epochs = range(1, len(geo_metrics["train_loss"]) + 1)

    panels = (
        ("train_loss", "Training Loss", "Loss"),
        ("val_accuracy", "Validation Accuracy", "Accuracy (%)"),
    )
    for ax, (key, title, ylabel) in zip(axes, panels):
        ax.plot(base_epochs, base_metrics[key], "b-", label="Base CNN")
        ax.plot(geo_epochs, geo_metrics[key], "r-", label="Geometric CNN")
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()


def compare_activation_maps(
    base_model: nn.Module,
    geo_model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
    num_maps: int = 5,
    *,
    mean: Tuple[float, ...] = (0.4914, 0.4822, 0.4465),
    std: Tuple[float, ...] = (0.2023, 0.1994, 0.2010),
):
    """
    Visually compare activation maps of two models on one test image.

    A random image is drawn from the FIRST batch of ``test_loader`` (with an unshuffled
    loader it therefore always comes from the same leading images). For the first two
    feature maps returned by each model ("Layer 1" and "Layer 2"), up to ``num_maps``
    randomly chosen channels are shown side by side, each min-max normalized to [0, 1]
    independently. One figure is produced per layer.

    Args:
        base_model: Trained BaseCNN model (returns (logits, feature_maps))
        geo_model: Trained GeometricCNN model (returns (logits, feature_maps))
        test_loader: DataLoader yielding (inputs, labels) batches of normalized images
        device: Computation device
        num_maps: Maximum number of activation maps to compare per layer
        mean, std: Per-channel normalization constants that produced the inputs; they
            are inverted to display the reference image. Defaults are the CIFAR-10
            constants used by this notebook's transforms.
    """
    base_model.eval()
    geo_model.eval()

    try:
        sample_inputs, _ = next(iter(test_loader))
    except StopIteration:
        print("Test loader is empty. Skipping activation maps comparison.")
        return

    sample_idx = random.randint(0, sample_inputs.size(0) - 1)
    sample_input = sample_inputs[sample_idx].unsqueeze(0).to(device)

    # Undo Normalize(mean, std): (x - (-mean/std)) / (1/std) == x * std + mean.
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    sample_img = inv_normalize(sample_input.squeeze(0)).cpu().numpy()
    sample_img = np.clip(np.transpose(sample_img, (1, 2, 0)), 0, 1)  # CHW -> HWC

    with torch.no_grad():
        _, base_feature_maps = base_model(sample_input)
        _, geo_feature_maps = geo_model(sample_input)

    # Compare the first two feature maps returned by the models.
    for layer_idx, layer_name in enumerate(["Layer 1", "Layer 2"]):
        base_fmap = base_feature_maps[layer_idx]  # (1, C, H, W)
        geo_fmap = geo_feature_maps[layer_idx]  # (1, C, H, W)

        num_channels = base_fmap.size(1)
        if num_channels < num_maps:
            selected_channels = list(range(num_channels))
        else:
            selected_channels = random.sample(range(num_channels), num_maps)

        # Size the grid by the channels actually shown (not num_maps) and keep `axes`
        # 2-D even when only the header row exists.
        n_rows = len(selected_channels) + 1
        fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
        fig.suptitle(f"{layer_name} Activation Maps Comparison", fontsize=16, y=1.02)

        # Header row: reference image plus column titles
        axes[0, 0].imshow(sample_img)
        axes[0, 0].axis("off")
        axes[0, 0].set_title("Reference Image")

        axes[0, 1].imshow(np.zeros((10, 10)), cmap="gray")  # Placeholder
        axes[0, 1].axis("off")
        axes[0, 1].set_title("Base CNN")

        axes[0, 2].imshow(np.zeros((10, 10)), cmap="gray")  # Placeholder
        axes[0, 2].axis("off")
        axes[0, 2].set_title("Geometric CNN")

        for i, channel in enumerate(selected_channels, start=1):
            base_feat = base_fmap[0, channel].cpu().numpy()
            base_feat_norm = (base_feat - base_feat.min()) / (base_feat.max() - base_feat.min() + 1e-8)
            axes[i, 1].imshow(base_feat_norm, cmap="viridis")
            axes[i, 1].axis("off")
            axes[i, 1].set_title(f"Base CNN - Channel {channel}")

            geo_feat = geo_fmap[0, channel].cpu().numpy()
            geo_feat_norm = (geo_feat - geo_feat.min()) / (geo_feat.max() - geo_feat.min() + 1e-8)
            axes[i, 2].imshow(geo_feat_norm, cmap="viridis")
            axes[i, 2].axis("off")
            axes[i, 2].set_title(f"Geometric CNN - Channel {channel}")

        # The first column only holds the reference image; blank it on the other rows.
        for row in range(1, n_rows):
            axes[row, 0].axis("off")

        plt.tight_layout()
        plt.show()

In [None]:
# Set random seed for reproducibility
# Seeds PyTorch, NumPy and Python's `random` (the latter is used by
# compare_activation_maps), and requests deterministic cuDNN kernels with autotuning off.
torch.manual_seed(17)
np.random.seed(17)
random.seed(17)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data preprocessing
# Training: random horizontal flip and a random 32x32 crop from a 4-pixel-padded image,
# then per-channel normalization. Test: normalization only.
# compare_activation_maps un-normalizes the displayed image with these same mean/std
# constants; keep them in sync.
transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Load CIFAR-10 dataset
# download=True fetches the data into ./data if needed; only the training loader shuffles.
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# Initialize Base CNN
# NOTE(review): testloader is passed as the validation loader, so early stopping and the
# reported best accuracy are selected on the test split (an optimistic estimate).
print("\nTraining Base CNN...")
base_cnn = BaseCNN()
base_trainer = Trainer(base_cnn, trainloader, testloader, device)
base_metrics = base_trainer.train(epochs=15)

# Initialize Geometric CNN with Residual Connections
# NOTE(review): Trainer configures GeometricCNN with a different optimizer and LR schedule
# (AdamW at half the base LR, cosine T_max=15) than BaseCNN (Adam, T_max=50), so accuracy
# differences are not attributable to the regularizer alone.
# base_model=base_cnn is only moved to the device and stored by Trainer; training does not use it.
print("\nTraining Geometric CNN...")
geo_cnn = GeometricCNN(lambda_area=0.01, lambda_curv=0.1)
geo_trainer = Trainer(geo_cnn, trainloader, testloader, device, base_model=base_cnn)
geo_metrics = geo_trainer.train(epochs=15)

# Plot performance comparison
plot_training_comparison(base_metrics, geo_metrics)

# Compare Activation Maps
compare_activation_maps(base_cnn, geo_cnn, testloader, device, num_maps=5)

# Print final results
print("\nFinal Results:")
print(f"Base CNN - Best Validation Accuracy: {base_trainer.best_accuracy:.2f}%")
print(f"Geometric CNN - Best Validation Accuracy: {geo_trainer.best_accuracy:.2f}%")