In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- Dataset MNIST ---
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_loader = DataLoader(datasets.MNIST(root='./data', train=True, download=True, transform=transform), batch_size=32, shuffle=True)

# --- Modèle cible : un MLP simple ---
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.net(x)

# --- Learned Optimizer : un petit LSTM qui apprend les deltas ---
class LearnedOptimizer(nn.Module):
    def __init__(self, hidden_size=20):
        super().__init__()
        self.lstm = nn.LSTMCell(1, hidden_size)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, grad, state):
        grad = grad.view(-1, 1)
        hx, cx = state
        hx, cx = self.lstm(grad, (hx, cx))
        delta = self.linear(hx)
        return delta.view(-1), (hx, cx)

    def init_state(self, num_params):
        device = next(self.parameters()).device
        return (torch.zeros(num_params, self.lstm.hidden_size, device=device),
                torch.zeros(num_params, self.lstm.hidden_size, device=device))

# --- Meta-Training : entraîner le learned optimizer ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
meta_optimizer = LearnedOptimizer().to(device)
meta_optim = torch.optim.Adam(meta_optimizer.parameters(), lr=1e-3)

for meta_epoch in range(5):  # Nombre d'époques meta
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Réinitialiser le modèle cible
        model = MLP().to(device)
        params = [p for p in model.parameters()]
        flat_params = torch.cat([p.data.view(-1) for p in params])

        state = meta_optimizer.init_state(flat_params.numel())

        # Inner loop : optimiser le MLP avec le learned optimizer
        for inner_step in range(20):  # Nombre d'étapes d'optimisation interne
            model.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            grads = torch.autograd.grad(loss, params, create_graph=True)
            grads_flat = torch.cat([g.view(-1) for g in grads])

            # Update avec le learned optimizer
            delta, state = meta_optimizer(grads_flat.detach(), state)
            idx = 0
            for p in params:
                numel = p.numel()
                p.data = p.data - delta[idx:idx+numel].view_as(p)
                idx += numel

        # Outer loss = perte finale après 20 steps
        final_output = model(data)
        final_loss = F.cross_entropy(final_output, target)

        # Backprop à travers l'optimizer
        meta_optim.zero_grad()
        final_loss.backward()
        meta_optim.step()

        if batch_idx % 100 == 0:
            print(f"Meta Epoch {meta_epoch}, Batch {batch_idx}, Final Loss: {final_loss.item():.4f}")

print("✅ Entraînement du learned optimizer terminé.")


100%|██████████| 9.91M/9.91M [00:00<00:00, 43.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.19MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.6MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.4MB/s]


Meta Epoch 0, Batch 0, Final Loss: 17494.3086
Meta Epoch 0, Batch 100, Final Loss: 29273.3789
Meta Epoch 0, Batch 200, Final Loss: 35057.8945
Meta Epoch 0, Batch 300, Final Loss: 33405.4297
Meta Epoch 0, Batch 400, Final Loss: 37830.6172
Meta Epoch 0, Batch 500, Final Loss: 24990.5508
Meta Epoch 0, Batch 600, Final Loss: 20747.6191
Meta Epoch 0, Batch 700, Final Loss: 36504.0938
Meta Epoch 0, Batch 800, Final Loss: 37557.1250
Meta Epoch 0, Batch 900, Final Loss: 45898.2891
Meta Epoch 0, Batch 1000, Final Loss: 39327.5078
Meta Epoch 0, Batch 1100, Final Loss: 41892.1680
Meta Epoch 0, Batch 1200, Final Loss: 43841.2383
Meta Epoch 0, Batch 1300, Final Loss: 32619.2402
Meta Epoch 0, Batch 1400, Final Loss: 22824.2227
Meta Epoch 0, Batch 1500, Final Loss: 31394.1523
Meta Epoch 0, Batch 1600, Final Loss: 22481.7812
Meta Epoch 0, Batch 1700, Final Loss: 37762.8086
Meta Epoch 0, Batch 1800, Final Loss: 37023.3203
Meta Epoch 1, Batch 0, Final Loss: 32645.6836
Meta Epoch 1, Batch 100, Final Loss

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- Dataset MNIST ---
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_loader = DataLoader(datasets.MNIST(root='./data', train=True, download=True, transform=transform), batch_size=16, shuffle=True)

# --- MLP model ---
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 64),  # Reduced hidden size
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.net(x)

# --- Learned Optimizer ---
class LearnedOptimizer(nn.Module):
    def __init__(self, hidden_size=5):  # Smaller hidden size
        super().__init__()
        self.lstm = nn.LSTMCell(1, hidden_size)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, grad, state):
        grad = grad.view(-1, 1)
        hx, cx = state
        hx, cx = self.lstm(grad, (hx, cx))
        delta = self.linear(hx)
        return delta.view(-1), (hx, cx)

    def init_state(self, num_params):
        # Store state on CPU to save GPU memory
        device = torch.device('cpu')
        return (torch.zeros(num_params, self.lstm.hidden_size, device=device),
                torch.zeros(num_params, self.lstm.hidden_size, device=device))

# --- Meta-training the optimizer ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
meta_optimizer = LearnedOptimizer().to(device)
meta_optim = torch.optim.Adam(meta_optimizer.parameters(), lr=1e-3)

for meta_epoch in range(3):  # Fewer meta-epochs
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        model = MLP().to(device)
        params = [p for p in model.parameters()]
        flat_params = torch.cat([p.data.view(-1) for p in params])

        state = meta_optimizer.init_state(flat_params.numel())

        for inner_step in range(10):  # Shorter unroll for memory
            model.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            grads = torch.autograd.grad(loss, params, create_graph=True)
            grads_flat = torch.cat([g.view(-1) for g in grads])

            # Update on CPU, send back to GPU
            delta, state = meta_optimizer(grads_flat.detach().cpu(), state)
            delta = delta.to(device)

            idx = 0
            for p in params:
                numel = p.numel()
                p.data = p.data - delta[idx:idx+numel].view_as(p)
                idx += numel

        # Outer loss = loss after unroll
        final_output = model(data)
        final_loss = F.cross_entropy(final_output, target)

        meta_optim.zero_grad()
        final_loss.backward()
        meta_optim.step()

        if batch_idx % 100 == 0:
            print(f"Meta Epoch {meta_epoch}, Batch {batch_idx}, Final Loss: {final_loss.item():.4f}")

print("✅ Meta-training of learned optimizer complete.")

# --- Evaluation: MNIST Accuracy ---
def evaluate_accuracy(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
    return correct / total

test_loader_mnist = DataLoader(datasets.MNIST(root='./data', train=False, transform=transform), batch_size=128, shuffle=False)

# Evaluate the trained MLP with the learned optimizer (on MNIST)
test_model = MLP().to(device)
params = [p for p in test_model.parameters()]
flat_params = torch.cat([p.data.view(-1) for p in params])
state = meta_optimizer.init_state(flat_params.numel())

# Run a few optimization steps on MNIST test data
for batch_idx, (data, target) in enumerate(test_loader_mnist):
    data, target = data.to(device), target.to(device)

    for inner_step in range(10):
        test_model.zero_grad()
        output = test_model(data)
        loss = F.cross_entropy(output, target)
        grads = torch.autograd.grad(loss, params, create_graph=False)
        grads_flat = torch.cat([g.view(-1) for g in grads])

        delta, state = meta_optimizer(grads_flat.detach().cpu(), state)
        delta = delta.to(device)

        idx = 0
        for p in params:
            numel = p.numel()
            p.data = p.data - delta[idx:idx+numel].view_as(p)
            idx += numel

    if batch_idx > 20:  # Just test a few batches
        break

accuracy_mnist = evaluate_accuracy(test_model, test_loader_mnist)
print(f"✅ Accuracy on MNIST with learned optimizer: {accuracy_mnist * 100:.2f}%")


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [6]:
# === Évaluation sur MNIST ===

# DataLoader MNIST Test Set
test_loader = DataLoader(datasets.MNIST(root='./data', train=False, transform=transform), batch_size=128, shuffle=False)

# Nouveau modèle à optimiser
test_model = MLP().to(device)
params = [p for p in test_model.parameters()]
flat_params = torch.cat([p.data.view(-1) for p in params])
state = meta_optimizer.init_state(flat_params.numel())

# Inner loop d'optimisation sur le test set
for batch_idx, (data, target) in enumerate(test_loader):
    data, target = data.to(device), target.to(device)

    for inner_step in range(20):  # Nombre d'étapes d'optimisation avec le learned optimizer
        test_model.zero_grad()
        output = test_model(data)
        loss = F.cross_entropy(output, target)
        grads = torch.autograd.grad(loss, params, create_graph=False)
        grads_flat = torch.cat([g.view(-1) for g in grads])

        delta, state = meta_optimizer(grads_flat.detach(), state)
        idx = 0
        for p in params:
            numel = p.numel()
            p.data = p.data - delta[idx:idx+numel].view_as(p)
            idx += numel

    if batch_idx > 30:  # Évaluer sur seulement quelques batches pour éviter d'exploser la RAM
        break

# Évaluer l'accuracy finale
def evaluate_accuracy(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
    return correct / total

accuracy = evaluate_accuracy(test_model, test_loader)
print(f"✅ Accuracy du modèle optimisé par le learned optimizer: {accuracy * 100:.2f}%")


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)