In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torch.utils.tensorboard import SummaryWriter

In [3]:
# 1. Data preparation
# Seed PyTorch's global RNG so loader shuffling, weight init and dropout are
# repeatable across "Restart & Run All" (some CUDA kernels may still be
# nondeterministic).
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 128

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)
# Shuffle only the training set; the test order stays fixed so the first test
# batch (used for the image grids below) is always the same.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# 2. Models
class MNISTMLP(nn.Module):
    """Fully connected classifier for 1x28x28 MNIST digits.

    Layer stack: Flatten -> [Linear -> ReLU -> Dropout] for each hidden size
    -> Linear(num_classes). The defaults reproduce the original architecture
    (784 -> 512 -> 256 -> 128 -> 10, dropout 0.2), so checkpoints such as the
    ``mnistmlp.pth`` saved by this notebook keep the same ``state_dict`` keys.

    Args:
        in_features: Number of features after flattening (28*28 for MNIST).
        hidden_sizes: Widths of the hidden layers, in order (may be empty).
        dropout: Dropout probability applied after every hidden ReLU.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden_sizes=(512, 256, 128),
                 dropout=0.2, num_classes=10):
        super().__init__()
        layers = [nn.Flatten()]  # [batch, 1, 28, 28] -> [batch, in_features]
        width = in_features
        for hidden in hidden_sizes:
            layers += [nn.Linear(width, hidden), nn.ReLU(), nn.Dropout(dropout)]
            width = hidden
        layers.append(nn.Linear(width, num_classes))
        # Keep the attribute name `model` and this layer order: together they
        # define the state_dict keys (model.1.weight, model.4.weight, ...).
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class scores (logits) of shape [batch, num_classes]."""
        return self.model(x)

class MNISTCNN(nn.Module):
    """Small convolutional classifier for 1x28x28 MNIST digits.

    Two conv -> ReLU -> max-pool stages halve the spatial size twice
    (28 -> 14 -> 7) while widening the channels (1 -> 32 -> 64). A two-layer
    head with dropout then maps the 64*7*7 feature map to 10 logits.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, 3, padding=1),   # 1x28x28  -> 32x28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 32x28x28 -> 32x14x14
            nn.Conv2d(32, 64, 3, padding=1),  # 32x14x14 -> 64x14x14
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 64x14x14 -> 64x7x7
        ]
        head_layers = [
            nn.Flatten(),                     # 64x7x7 -> 3136
            nn.Linear(64 * 7 * 7, 128),       # 3136 -> 128
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 10),               # 128 -> 10 class logits
        ]
        # The `conv` / `fc` attribute names and the layer order define the
        # state_dict keys, so they must not change.
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw class scores (logits) of shape [batch, 10]."""
        features = self.conv(x)
        return self.fc(features)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [5]:
# 3. Training
def train(model, loader, optimizer, loss_fn, device, epoch, prefix,
          writer=None, log_every=100):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module to optimise; it is switched to train mode (dropout on).
        loader: Iterable of ``(X, y)`` batches that also supports ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index; used to build a global TensorBoard step.
        prefix: Tag prefix; the loss is logged as ``Loss/{prefix}_train``.
        writer: Object with ``add_scalar`` (e.g. a ``SummaryWriter``). When
            omitted, the notebook-level ``writer`` is used if it exists;
            otherwise logging is skipped.
        log_every: Log the current batch loss every ``log_every`` batches.

    Returns:
        Mean of the per-batch losses, or ``nan`` when ``loader`` is empty
        (instead of raising ZeroDivisionError).
    """
    if writer is None:
        # Backward compatibility: callers used to rely on the global `writer`.
        writer = globals().get("writer")
    model.train()
    num_batches = len(loader)
    running_loss = 0.0
    for batch_idx, (X, y) in enumerate(loader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        # .item() copies the value to the host (a sync on GPU); do it once per batch.
        batch_loss = loss.item()
        running_loss += batch_loss
        if writer is not None and batch_idx % log_every == 0:
            # Global step across epochs so successive epochs form one curve.
            n_iter = epoch * num_batches + batch_idx
            writer.add_scalar(f"Loss/{prefix}_train", batch_loss, n_iter)
    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches

def test(model, loader, device, epoch, prefix):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            correct += (pred.argmax(1) == y).sum().item()
            total += y.size(0)
    acc = correct / total
    writer.add_scalar(f"Accuracy/{prefix}_test", acc, epoch)
    return acc


In [6]:
# TensorBoard writer shared by train(), test() and the training cells.
# Event files for every run are written under this directory.
TB_LOG_DIR = "runs/mnist_experiment"
writer = SummaryWriter(log_dir=TB_LOG_DIR)

In [7]:
# MLP
# Train and evaluate the MLP. At the first and last epoch, log a 4x4 grid of
# test images together with the model's predictions and the true labels.
epochs = 10
mlp = MNISTMLP().to(device)
optimizer_mlp = optim.Adam(mlp.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()  # also reused by the CNN cell
for epoch in range(epochs):
    train(mlp, train_loader, optimizer_mlp, loss_fn, device, epoch, "MLP")
    acc = test(mlp, test_loader, device, epoch, "MLP")
    print(f"MLP Epoch {epoch+1}: Test accuracy = {acc:.4f}")
    if epoch in (0, epochs - 1):
        mlp.eval()
        images, labels = next(iter(test_loader))
        # Only 16 samples are displayed, so only those go through the model.
        images, labels = images[:16].to(device), labels[:16]
        # Pure inference: no_grad avoids building an autograd graph.
        with torch.no_grad():
            preds = mlp(images).argmax(1)
        img_grid = utils.make_grid(images.cpu(), nrow=4, normalize=True)
        writer.add_image('MNIST MLP Images', img_grid, epoch)
        writer.add_text('MLP Predictions', str(preds.cpu().numpy()), epoch)
        writer.add_text('MLP Labels', str(labels.cpu().numpy()), epoch)

MLP Epoch 1: Test accuracy = 0.9331
MLP Epoch 2: Test accuracy = 0.9611
MLP Epoch 3: Test accuracy = 0.9647
MLP Epoch 4: Test accuracy = 0.9697
MLP Epoch 5: Test accuracy = 0.9723
MLP Epoch 6: Test accuracy = 0.9776
MLP Epoch 7: Test accuracy = 0.9755
MLP Epoch 8: Test accuracy = 0.9772
MLP Epoch 9: Test accuracy = 0.9750
MLP Epoch 10: Test accuracy = 0.9780


In [11]:
# CNN
# Train and evaluate the CNN with the same schedule as the MLP (reuses
# `epochs` and `loss_fn` from the MLP cell). At the first and last epoch, log a
# 4x4 grid of test images with predictions and labels, then close the writer.
cnn = MNISTCNN().to(device)
optimizer_cnn = optim.Adam(cnn.parameters(), lr=1e-3)
for epoch in range(epochs):
    train(cnn, train_loader, optimizer_cnn, loss_fn, device, epoch, "CNN")
    acc = test(cnn, test_loader, device, epoch, "CNN")
    print(f"CNN Epoch {epoch+1}: Test accuracy = {acc:.4f}")
    if epoch in (0, epochs - 1):
        cnn.eval()
        images, labels = next(iter(test_loader))
        # Only 16 samples are displayed, so only those go through the model.
        images, labels = images[:16].to(device), labels[:16]
        # Pure inference: no_grad avoids building an autograd graph.
        with torch.no_grad():
            preds = cnn(images).argmax(1)
        img_grid = utils.make_grid(images.cpu(), nrow=4, normalize=True)
        writer.add_image('MNIST CNN Images', img_grid, epoch)
        writer.add_text('CNN Predictions', str(preds.cpu().numpy()), epoch)
        writer.add_text('CNN Labels', str(labels.cpu().numpy()), epoch)
writer.close()

CNN Epoch 1: Test accuracy = 0.9834
CNN Epoch 2: Test accuracy = 0.9866
CNN Epoch 3: Test accuracy = 0.9888
CNN Epoch 4: Test accuracy = 0.9902
CNN Epoch 5: Test accuracy = 0.9908
CNN Epoch 6: Test accuracy = 0.9910
CNN Epoch 7: Test accuracy = 0.9907
CNN Epoch 8: Test accuracy = 0.9923
CNN Epoch 9: Test accuracy = 0.9914
CNN Epoch 10: Test accuracy = 0.9911


In [None]:
# 4. Save the trained weights (state_dict only, not the pickled modules).
checkpoints = (("mnistmlp.pth", mlp), ("mnistcnn.pth", cnn))
for checkpoint_path, trained_model in checkpoints:
    torch.save(trained_model.state_dict(), checkpoint_path)

In [12]:
# 5. ONNX export
# Eval mode so dropout acts as the identity in the exported graphs.
mlp.eval()
cnn.eval()
# The example input must be on the same device as the models: `device` may be
# CUDA, and tracing a CUDA model with a CPU tensor fails with a device mismatch.
dummy = torch.randn(1, 1, 28, 28, device=device)
# Mark dimension 0 as dynamic so the ONNX models accept any batch size, not only 1.
onnx_dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}}
torch.onnx.export(mlp, dummy, "mnistmlp.onnx", input_names=["input"], output_names=["output"],
                  opset_version=13, dynamic_axes=onnx_dynamic_axes)
torch.onnx.export(cnn, dummy, "mnistcnn.onnx", input_names=["input"], output_names=["output"],
                  opset_version=13, dynamic_axes=onnx_dynamic_axes)