In [5]:
import os
import torch
from typing import List, Callable
from torchvision.datasets import MNIST
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from torch.utils.data import TensorDataset, DataLoader
from math import factorial
from itertools import product
from tqdm import trange

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Generate basis functions b_i(x, y) = x^p y^q / r! with p+q=r <= MAX_DEGREE
def generate_homogeneous_basis(max_degree: int):
    basis = []
    degrees = []  # (r = p + q)
    for r in range(max_degree + 1):
        for p in range(r + 1):
            q = r - p
            def monomial(p=p, q=q, r=r):
                return lambda X, Y: (X**p * Y**q) / factorial(r)
            basis.append(monomial())
            degrees.append(r)
    return basis, degrees

# Compute basis vector from image phi
def compute_basis_vector(phi: torch.Tensor) -> torch.Tensor:
    width, height = phi.shape
    k = torch.arange(width, device=phi.device) - width // 2
    j = torch.arange(height, device=phi.device) - height // 2
    X, Y = torch.meshgrid(k, j, indexing="ij")
    return torch.stack([(b(X, Y) * phi).sum() for b in BASIS])

# Precompute and cache dataset
def get_or_create_augmented_tensor_dataset(train: bool = True, limit = None, overwrite_cache: bool = False) -> TensorDataset:
    split_name = "train" if train else "test"
    limit_str = f"{limit}" if limit is not None else "all"
    CACHE_DIR = f"./data"
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"scaled_mnist_limit{limit_str}_{split_name}_BaseDegree{MAX_DEGREE}.pt")

    if os.path.exists(cache_path) and not overwrite_cache:
        print(f"Loading cached dataset from: {cache_path}")
        dataset = torch.load(cache_path, weights_only=False)
        return TensorDataset(dataset.tensors[0].to(device), dataset.tensors[1].to(device))

    print(f"Creating augmented dataset for MNIST ({split_name})...")
    mnist_ds = MNIST(root="./data", train=train, download=True, transform=transforms.ToTensor())

    X_list, y_list = [], []

    num_samples = len(mnist_ds) if limit is None else min(limit, len(mnist_ds))
    total_iterations = num_samples * len(SCALES) * len(ANGLES)

    with trange(num_samples, desc=f"Augmenting {split_name}") as t:
        for idx in t:

            tensor_img, label = mnist_ds[idx]

            for scale in SCALES:
                for angle in ANGLES:
                    transformed = F.affine(
                        tensor_img,
                        angle=angle,
                        translate=(0, 0),
                        scale=scale,
                        shear=(0.0, 0.0),
                        interpolation=InterpolationMode.NEAREST,
                        fill=0.0,
                        center=None,
                    )
                    phi = transformed.squeeze(0).to(device)  # (28, 28)
                    vec = compute_basis_vector(phi)
                    X_list.append(vec)
                    y_list.append(label)

    X_tensor = torch.stack(X_list)
    y_tensor = torch.tensor(y_list, dtype=torch.long, device=device)

    dataset = TensorDataset(X_tensor, y_tensor)
    torch.save(dataset, cache_path)
    print(f"Saved dataset to: {cache_path}")

    return dataset


# Create data loaders
def get_data_loaders(batch_size: int = 512, limit = None, overwrite_cache: bool = False):
    train_ds = get_or_create_augmented_tensor_dataset(train=True, limit=limit, overwrite_cache=overwrite_cache)
    test_ds = get_or_create_augmented_tensor_dataset(train=False, limit=limit, overwrite_cache=overwrite_cache)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    return train_loader, test_loader


cuda


In [6]:
# Configuration
SCALES: List[float] = [0.3 + 0.1 * i for i in range(13)]  # 0.3 to 1.5
ANGLES: List[int] = [0]
MAX_DEGREE = 8 # Degree of the basis functions

BASIS, DEGREE_INFO = generate_homogeneous_basis(MAX_DEGREE)

print(len(BASIS))

45


In [None]:
train_loader, test_loader = get_data_loaders(batch_size=4096, limit=1000)

Creating augmented dataset for MNIST (train)...


Augmenting train: 100%|██████████| 1000/1000 [00:35<00:00, 28.45it/s]


Saved dataset to: ./data/scaled_mnist_limit1000_train_BaseDegree8.pt
Creating augmented dataset for MNIST (test)...


Augmenting test:  60%|██████    | 603/1000 [00:21<00:13, 28.49it/s]

## Train util

In [10]:
import torch
from tqdm import trange
import matplotlib.pyplot as plt

def compute_accuracy(model, loader):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            outputs = model(x)
            pred_labels = torch.argmax(outputs, dim=1)
            correct += (pred_labels == y).sum().item()
            total += y.size(0)
    return correct / total

# --- Training loop
def train_models(num_epochs=1000, early_stop_patience=200):
    best_hnn_acc, best_mlp_acc = 0.0, 0.0
    patience = 0

    train_loss_h, test_loss_h = [], []
    train_loss_m, test_loss_m = [], []
    train_acc_h, test_acc_h = [], []
    train_acc_m, test_acc_m = [], []

    for epoch in trange(num_epochs, desc="Training Epochs"):
        model_hnn.train()
        model_mlp.train()
        epoch_loss_h, epoch_loss_m = 0.0, 0.0

        for x_batch, y_batch in train_loader:
            # HNN
            opt_hnn.zero_grad()
            out_hnn = model_hnn(x_batch)
            loss_h = criterion(out_hnn, y_batch)
            loss_h.backward()
            opt_hnn.step()
            epoch_loss_h += loss_h.item()

            # MLP
            opt_mlp.zero_grad()
            out_mlp = model_mlp(x_batch)
            loss_m = criterion(out_mlp, y_batch)
            loss_m.backward()
            opt_mlp.step()
            epoch_loss_m += loss_m.item()

        # Eval phase
        def eval_model(model):
            model.eval()
            total_loss, correct, total = 0.0, 0, 0
            with torch.no_grad():
                for x, y in test_loader:
                    preds = model(x)
                    loss = criterion(preds, y)
                    total_loss += loss.item()
                    correct += (preds.argmax(dim=1) == y).sum().item()
                    total += y.size(0)
            return total_loss / len(test_loader), correct / total

        # Record losses/accuracies
        train_loss_h.append(epoch_loss_h / len(train_loader))
        train_loss_m.append(epoch_loss_m / len(train_loader))

        val_loss_h, acc_h = eval_model(model_hnn)
        val_loss_m, acc_m = eval_model(model_mlp)

        test_loss_h.append(val_loss_h)
        test_loss_m.append(val_loss_m)
        test_acc_h.append(acc_h)
        test_acc_m.append(acc_m)

        train_acc_h.append(compute_accuracy(model_hnn, train_loader))
        train_acc_m.append(compute_accuracy(model_mlp, train_loader))

        # Logging
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1}: "
                  f"HNN Loss: {train_loss_h[-1]:.4f} / {val_loss_h:.4f}, "
                  f"Acc: {train_acc_h[-1]*100:.2f}% / {acc_h*100:.2f}% || "
                  f"MLP Loss: {train_loss_m[-1]:.4f} / {val_loss_m:.4f}, "
                  f"Acc: {train_acc_m[-1]*100:.2f}% / {acc_m*100:.2f}%")

        # Early stopping
        if acc_h > best_hnn_acc or acc_m > best_mlp_acc:
            best_hnn_acc = max(best_hnn_acc, acc_h)
            best_mlp_acc = max(best_mlp_acc, acc_m)
            patience = 0
        else:
            patience += 1
            if patience >= early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return {
        "train_loss_h": train_loss_h, "test_loss_h": test_loss_h,
        "train_loss_m": train_loss_m, "test_loss_m": test_loss_m,
        "train_acc_h": train_acc_h, "test_acc_h": test_acc_h,
        "train_acc_m": train_acc_m, "test_acc_m": test_acc_m
    }

# --- Plot results
def plot_metrics(results):
    epochs = range(1, len(results["train_loss_h"]) + 1)

    # Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, results["train_loss_h"], label="HNN Train")
    plt.plot(epochs, results["test_loss_h"], label="HNN Test")
    plt.plot(epochs, results["train_loss_m"], label="MLP Train", linestyle="--")
    plt.plot(epochs, results["test_loss_m"], label="MLP Test", linestyle="--")
    plt.xlabel("Epoch")
    plt.ylabel("CrossEntropy Loss")
    plt.title("Loss Curves")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Accuracy
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, results["train_acc_h"], label="HNN Train")
    plt.plot(epochs, results["test_acc_h"], label="HNN Test")
    plt.plot(epochs, results["train_acc_m"], label="MLP Train", linestyle="--")
    plt.plot(epochs, results["test_acc_m"], label="MLP Test", linestyle="--")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Accuracy Curves")
    plt.legend()
    plt.grid(True)
    plt.show()



## Train

In [15]:
from hnn import HomogeneousNN
from hnn_utils import initialize_weights
import torch.nn as nn
import torch.optim as optim

r_list = torch.tensor(DEGREE_INFO, dtype=torch.float32, device=device)
nu = 0.0
Gd = torch.diag(r_list + 2)
input_dim = len(BASIS)
P = torch.eye(input_dim, device=device)

# --- Define models ---
hidden_layers = 5
hidden_dim = 10
output_dim = 10

# 1) HomogeneousNN
model_hnn = HomogeneousNN(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim,
                           P=P, Gd=Gd, nu=nu, hidden_layers = hidden_layers).to(device)

model_hnn.apply(initialize_weights)

# MLP de référence pour la comparaison
class SimpleMLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, hidden_layers = hidden_layers):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.Tanh()]
        for _ in range(hidden_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.Tanh()])
        layers.append(nn.Linear(hidden_dim, output_dim, bias=False))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Instantiate MLP model
model_mlp = SimpleMLP(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
model_mlp.apply(initialize_weights)

# --- Loss and optimizer
criterion = nn.CrossEntropyLoss()
opt_hnn = optim.Adam(model_hnn.parameters(), lr=1e-4, weight_decay=1e-4)
opt_mlp = optim.Adam(model_mlp.parameters(), lr=1e-4, weight_decay=1e-4)


In [16]:
# --- Run training
results = train_models(num_epochs=1000)
plot_metrics(results)

Training Epochs:   0%|          | 1/1000 [00:08<2:29:24,  8.97s/it]

Epoch 1: HNN Loss: 2.3165 / 2.3142, Acc: 9.82% / 9.99% || MLP Loss: 2.4151 / 2.4161, Acc: 10.85% / 10.47%


Training Epochs:   1%|          | 10/1000 [01:29<2:28:03,  8.97s/it]

Epoch 10: HNN Loss: 2.3004 / 2.3020, Acc: 10.36% / 9.79% || MLP Loss: 2.2682 / 2.2684, Acc: 17.87% / 16.73%


Training Epochs:   2%|▏         | 20/1000 [02:59<2:26:39,  8.98s/it]

Epoch 20: HNN Loss: 2.2896 / 2.2905, Acc: 14.07% / 13.81% || MLP Loss: 2.1626 / 2.1627, Acc: 20.55% / 19.27%


Training Epochs:   3%|▎         | 30/1000 [04:29<2:25:08,  8.98s/it]

Epoch 30: HNN Loss: 2.2383 / 2.2328, Acc: 22.43% / 21.76% || MLP Loss: 2.0845 / 2.0877, Acc: 24.42% / 23.16%


Training Epochs:   4%|▍         | 40/1000 [05:59<2:24:05,  9.01s/it]

Epoch 40: HNN Loss: 2.1272 / 2.1153, Acc: 21.78% / 21.13% || MLP Loss: 2.0216 / 2.0271, Acc: 26.50% / 25.20%


Training Epochs:   5%|▌         | 50/1000 [07:29<2:22:14,  8.98s/it]

Epoch 50: HNN Loss: 2.0792 / 2.0668, Acc: 22.84% / 22.01% || MLP Loss: 1.9683 / 1.9753, Acc: 27.73% / 27.30%


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f0574b146d0>>
Traceback (most recent call last):
  File "/home/elanzera/HNN/venv/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
Training Epochs:   5%|▌         | 53/1000 [08:00<2:23:07,  9.07s/it]


KeyboardInterrupt: 