In [None]:
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as T
import torch.onnx
import onnx

In [3]:
train_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=T.ToTensor(),
)
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=T.ToTensor(),
)

In [4]:
# DataLoader
train_loader = DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=False,
    num_workers=4,
)

In [9]:
class ConvMLP(nn.Module):
    def __init__(self):
        super(ConvMLP, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x



In [6]:
device = torch.device('mps' if torch.mps.is_available() else 'cpu') #x10 -> 30
model = MLP().to(device)
criterion = nn.CrossEntropyLoss() #fonction d'erreur
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) #Agent qui met à jour les poids du modèle

In [7]:
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        ypred = model(images)  # Forward pass
        loss = criterion(ypred, labels)  # Compute loss

        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        optimizer.zero_grad()  # Reset gradients

def test(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            ypred = model(images)  # Forward pass
            loss = criterion(ypred, labels)  # Compute loss

            total_loss += loss.item()
            preds = ypred.argmax(dim=1)
            correct += (preds == labels).sum().item()
    avg_loss = total_loss / len(test_loader)
    accuracy = correct / len(test_loader.dataset)
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}")

In [10]:
for epoch in range(5):
    print(f"Epoch {epoch+1}")
    train(model, train_loader, criterion, optimizer, device)
    test(model, test_loader, criterion, device)

Epoch 1
Test Loss: 0.0826, Test Accuracy: 0.9771
Epoch 2
Test Loss: 0.0779, Test Accuracy: 0.9774
Epoch 3
Test Loss: 0.0802, Test Accuracy: 0.9757
Epoch 4
Test Loss: 0.0840, Test Accuracy: 0.9777
Epoch 5
Test Loss: 0.0974, Test Accuracy: 0.9732


In [11]:
model.eval()

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

In [19]:
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(
    model, dummy_input, "mnist.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=11
)


In [18]:
%pip install onnx

Collecting onnx
  Downloading onnx-1.18.0-cp313-cp313-macosx_12_0_universal2.whl.metadata (6.9 kB)
Downloading onnx-1.18.0-cp313-cp313-macosx_12_0_universal2.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: onnx
Successfully installed onnx-1.18.0
Note: you may need to restart the kernel to use updated packages.
