In [18]:
import os
import torch
from typing import List, Callable
from torchvision.datasets import MNIST
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from torch.utils.data import TensorDataset, DataLoader
from math import factorial
from itertools import product

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate basis functions b_i(x, y) = x^p y^q / r! with p+q=r <= MAX_DEGREE
def generate_homogeneous_basis(max_degree: int):
    basis = []
    degrees = []  # (r = p + q)
    for r in range(max_degree + 1):
        for p in range(r + 1):
            q = r - p
            def monomial(p=p, q=q, r=r):
                return lambda X, Y: (X**p * Y**q) / factorial(r)
            basis.append(monomial())
            degrees.append(r)
    return basis, degrees

# Compute basis vector from image phi
def compute_basis_vector(phi: torch.Tensor) -> torch.Tensor:
    width, height = phi.shape
    k = torch.arange(width, device=phi.device) - width // 2
    j = torch.arange(height, device=phi.device) - height // 2
    X, Y = torch.meshgrid(k, j, indexing="ij")
    return torch.stack([(b(X, Y) * phi).sum() for b in BASIS])

# Precompute and cache dataset
def get_or_create_augmented_tensor_dataset(train: bool = True, limit: int | None = None, overwrite_cache: bool = False) -> TensorDataset:
    split_name = "train" if train else "test"
    limit_str = f"{limit}" if limit is not None else "all"
    CACHE_DIR = f"./data/scaled_mnist_limit{limit_str}_{split_name}_BaseDegree{MAX_DEGREE}"
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"augmented_mnist_{split_name}.pt")

    if os.path.exists(cache_path) and not overwrite_cache:
        print(f"Loading cached dataset from: {cache_path}")
        dataset = torch.load(cache_path)
        return TensorDataset(dataset.tensors[0].to(device), dataset.tensors[1].to(device))

    print(f"Creating augmented dataset for MNIST ({split_name})...")
    mnist_ds = MNIST(root="./data", train=train, download=True, transform=transforms.ToTensor())

    X_list, y_list = [], []

    for idx, (tensor_img, label) in enumerate(mnist_ds):
        if limit is not None and idx >= limit:
            break

        for scale in SCALES:
            for angle in ANGLES:
                transformed = F.affine(
                    tensor_img,
                    angle=angle,
                    translate=(0, 0),
                    scale=scale,
                    shear=(0.0, 0.0),
                    interpolation=InterpolationMode.NEAREST,
                    fill=0.0,
                    center=None,
                )
                phi = transformed.squeeze(0).to(device)  # (28, 28)
                vec = compute_basis_vector(phi)
                X_list.append(vec)
                y_list.append(label)

    X_tensor = torch.stack(X_list)
    y_tensor = torch.tensor(y_list, dtype=torch.long, device=device)

    dataset = TensorDataset(X_tensor, y_tensor)
    torch.save(dataset, cache_path)
    print(f"Saved dataset to: {cache_path}")

    return dataset

# Create data loaders
def get_data_loaders(batch_size: int = 512, limit: int | None = None, overwrite_cache: bool = False):
    train_ds = get_or_create_augmented_tensor_dataset(train=True, limit=limit, overwrite_cache=overwrite_cache)
    test_ds = get_or_create_augmented_tensor_dataset(train=False, limit=limit, overwrite_cache=overwrite_cache)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    return train_loader, test_loader


In [19]:
# Configuration
SCALES: List[float] = [0.3 + 0.1 * i for i in range(13)]  # 0.3 to 1.5
ANGLES: List[int] = [0]
MAX_DEGREE = 6 # Degree of the basis functions

BASIS, DEGREE_INFO = generate_homogeneous_basis(MAX_DEGREE)

print(len(BASIS))

28


In [13]:
train_loader, test_loader = get_data_loaders(batch_size=512, limit=10000)

Creating augmented dataset for MNIST (train)...
Saved dataset to: ./data/scaled_mnist_limit10000_train_BaseDegree6/augmented_mnist_train.pt
Creating augmented dataset for MNIST (test)...
Saved dataset to: ./data/scaled_mnist_limit10000_test_BaseDegree6/augmented_mnist_test.pt


## Train util

In [22]:
import torch
from tqdm import trange
import matplotlib.pyplot as plt

def compute_accuracy(model, loader):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            outputs = model(x)
            pred_labels = torch.argmax(outputs, dim=1)
            correct += (pred_labels == y).sum().item()
            total += y.size(0)
    return correct / total

# --- Training loop
def train_models(num_epochs=1000, early_stop_patience=200):
    best_hnn_acc, best_mlp_acc = 0.0, 0.0
    patience = 0

    train_loss_h, test_loss_h = [], []
    train_loss_m, test_loss_m = [], []
    train_acc_h, test_acc_h = [], []
    train_acc_m, test_acc_m = [], []

    for epoch in trange(num_epochs, desc="Training Epochs"):
        model_hnn.train()
        model_mlp.train()
        epoch_loss_h, epoch_loss_m = 0.0, 0.0

        for x_batch, y_batch in train_loader:
            # HNN
            opt_hnn.zero_grad()
            out_hnn = model_hnn(x_batch)
            loss_h = criterion(out_hnn, y_batch)
            loss_h.backward()
            opt_hnn.step()
            epoch_loss_h += loss_h.item()

            # MLP
            opt_mlp.zero_grad()
            out_mlp = model_mlp(x_batch)
            loss_m = criterion(out_mlp, y_batch)
            loss_m.backward()
            opt_mlp.step()
            epoch_loss_m += loss_m.item()

        # Eval phase
        def eval_model(model):
            model.eval()
            total_loss, correct, total = 0.0, 0, 0
            with torch.no_grad():
                for x, y in test_loader:
                    preds = model(x)
                    loss = criterion(preds, y)
                    total_loss += loss.item()
                    correct += (preds.argmax(dim=1) == y).sum().item()
                    total += y.size(0)
            return total_loss / len(test_loader), correct / total

        # Record losses/accuracies
        train_loss_h.append(epoch_loss_h / len(train_loader))
        train_loss_m.append(epoch_loss_m / len(train_loader))

        val_loss_h, acc_h = eval_model(model_hnn)
        val_loss_m, acc_m = eval_model(model_mlp)

        test_loss_h.append(val_loss_h)
        test_loss_m.append(val_loss_m)
        test_acc_h.append(acc_h)
        test_acc_m.append(acc_m)

        train_acc_h.append(compute_accuracy(model_hnn, train_loader))
        train_acc_m.append(compute_accuracy(model_mlp, train_loader))

        # Logging
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1}: "
                  f"HNN Loss: {train_loss_h[-1]:.4f} / {val_loss_h:.4f}, "
                  f"Acc: {train_acc_h[-1]*100:.2f}% / {acc_h*100:.2f}% || "
                  f"MLP Loss: {train_loss_m[-1]:.4f} / {val_loss_m:.4f}, "
                  f"Acc: {train_acc_m[-1]*100:.2f}% / {acc_m*100:.2f}%")

        # Early stopping
        if acc_h > best_hnn_acc or acc_m > best_mlp_acc:
            best_hnn_acc = max(best_hnn_acc, acc_h)
            best_mlp_acc = max(best_mlp_acc, acc_m)
            patience = 0
        else:
            patience += 1
            if patience >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return {
        "train_loss_h": train_loss_h, "test_loss_h": test_loss_h,
        "train_loss_m": train_loss_m, "test_loss_m": test_loss_m,
        "train_acc_h": train_acc_h, "test_acc_h": test_acc_h,
        "train_acc_m": train_acc_m, "test_acc_m": test_acc_m
    }

# --- Plot results
def plot_metrics(results):
    epochs = range(1, len(results["train_loss_h"]) + 1)

    # Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, results["train_loss_h"], label="HNN Train")
    plt.plot(epochs, results["test_loss_h"], label="HNN Test")
    plt.plot(epochs, results["train_loss_m"], label="MLP Train", linestyle="--")
    plt.plot(epochs, results["test_loss_m"], label="MLP Test", linestyle="--")
    plt.xlabel("Epoch")
    plt.ylabel("CrossEntropy Loss")
    plt.title("Loss Curves")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Accuracy
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, results["train_acc_h"], label="HNN Train")
    plt.plot(epochs, results["test_acc_h"], label="HNN Test")
    plt.plot(epochs, results["train_acc_m"], label="MLP Train", linestyle="--")
    plt.plot(epochs, results["test_acc_m"], label="MLP Test", linestyle="--")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Accuracy Curves")
    plt.legend()
    plt.grid(True)
    plt.show()



## Train

In [25]:
from hnn import HomogeneousNN
from hnn_utils import initialize_weights
import torch.nn as nn
import torch.optim as optim

r_list = torch.tensor(DEGREE_INFO, dtype=torch.float32, device=device)
nu = 0.0
Gd = torch.diag(r_list + 2)
input_dim = len(BASIS)
P = torch.eye(input_dim, device=device)

# --- Define models ---
hidden_layers = 5
hidden_dim = 10
output_dim = 10

# 1) HomogeneousNN
model_hnn = HomogeneousNN(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim,
                           P=P, Gd=Gd, nu=nu, hidden_layers = hidden_layers).to(device)

model_hnn.apply(initialize_weights)

# MLP de référence pour la comparaison
class SimpleMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, hidden_layers = hidden_layers):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.Tanh()]
        for _ in range(hidden_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.Tanh()])
        layers.append(nn.Linear(hidden_dim, output_dim, bias=False))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Instantiate MLP model
model_mlp = SimpleMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
model_mlp.apply(initialize_weights)

# --- Loss and optimizer
criterion = nn.CrossEntropyLoss()
opt_hnn = optim.Adam(model_hnn.parameters(), lr=1e-4, weight_decay=1e-4)
opt_mlp = optim.Adam(model_mlp.parameters(), lr=1e-4, weight_decay=1e-4)


In [26]:
# --- Run training
results = train_models(num_epochs=300)
plot_metrics(results)

Training Epochs:   0%|                          | 1/300 [00:08<42:14,  8.48s/it]

Epoch 1: HNN Loss: 2.3063 / 2.3020, Acc: 9.38% / 8.96% || MLP Loss: 2.3804 / 2.3200, Acc: 14.32% / 14.56%


Training Epochs:   3%|▊                        | 10/300 [01:23<40:13,  8.32s/it]

Epoch 10: HNN Loss: 2.0319 / 2.0174, Acc: 24.96% / 23.48% || MLP Loss: 2.0504 / 2.0610, Acc: 23.09% / 22.08%


Training Epochs:   7%|█▋                       | 20/300 [02:46<38:37,  8.28s/it]

Epoch 20: HNN Loss: 1.8367 / 1.8362, Acc: 34.57% / 33.25% || MLP Loss: 1.8829 / 1.9001, Acc: 33.34% / 32.57%


Training Epochs:  10%|██▌                      | 30/300 [04:08<36:52,  8.19s/it]

Epoch 30: HNN Loss: 1.5495 / 1.5589, Acc: 45.26% / 44.48% || MLP Loss: 1.7408 / 1.7686, Acc: 38.26% / 36.70%


Training Epochs:  13%|███▎                     | 40/300 [05:32<36:03,  8.32s/it]

Epoch 40: HNN Loss: 1.4270 / 1.4497, Acc: 49.49% / 48.74% || MLP Loss: 1.6739 / 1.7061, Acc: 40.16% / 38.57%


Training Epochs:  17%|████▏                    | 50/300 [06:56<35:11,  8.44s/it]

Epoch 50: HNN Loss: 1.3598 / 1.3892, Acc: 52.33% / 51.43% || MLP Loss: 1.6477 / 1.6821, Acc: 41.85% / 40.27%


Training Epochs:  20%|█████                    | 60/300 [08:18<32:57,  8.24s/it]

Epoch 60: HNN Loss: 1.3064 / 1.3356, Acc: 54.79% / 53.86% || MLP Loss: 1.6219 / 1.6593, Acc: 43.01% / 41.27%


Training Epochs:  21%|█████▎                   | 64/300 [08:54<32:51,  8.36s/it]


KeyboardInterrupt: 