In [6]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.onnx

# MNIST train/test splits, stored under ./data and converted to tensors on access.
training_data = datasets.MNIST(
    "data",
    train=True,
    transform=ToTensor(),
    download=True,
)
test_data = datasets.MNIST(
    "data",
    train=False,
    transform=ToTensor(),
    download=True,
)

# Mini-batches of 64 samples; only the training loader shuffles.
train_dataloader = DataLoader(dataset=training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

# Simple model
class SimpleNN(nn.Module):
    """Small fully connected classifier: flatten -> Linear -> ReLU -> Linear.

    With the default arguments the layer sizes are 28*28 -> 128 -> 10, i.e.
    exactly the previously hard-coded architecture, so ``SimpleNN()`` and the
    parameter names in its ``state_dict`` are unchanged.

    Args:
        in_features: Number of values per sample once the input is flattened.
        hidden_features: Width of the single hidden layer.
        num_classes: Number of output logits (one per class).
    """

    def __init__(self, in_features=28 * 28, hidden_features=128, num_classes=10):
        super().__init__()
        # nn.Flatten() with its defaults keeps dim 0 (the batch) and collapses the rest.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return the raw (unnormalized) class scores for the batch ``x``."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# Seed PyTorch's global RNG before any random numbers are drawn in this step,
# so the weight initialization below (and anything else that later draws from
# the global RNG, such as DataLoader shuffling) is repeatable across
# "Restart Kernel -> Run All".
torch.manual_seed(0)

model = SimpleNN()

# Quick training (1 epoch)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

def train(dataloader, model, loss_fn, optimizer, log_interval=100):
    """Run one optimization pass over every batch yielded by ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.
        model: Module to optimize; it is switched to training mode first.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor that supports ``backward()``.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called
            once per batch.
        log_interval: Print the current loss every ``log_interval`` batches,
            starting with batch 0. The default (100) reproduces the previous
            hard-coded behavior; ``0`` or ``None`` disables printing.

    Returns:
        None. The model's parameters are updated in place.
    """
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        # Drop gradients accumulated by the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Truthiness check first: avoids a modulo by zero when logging is off.
        if log_interval and batch % log_interval == 0:
            print(f"Loss: {loss.item():.4f}")

# Single pass over the training set.
train(train_dataloader, model, loss_fn, optimizer)

# ONNX export
# train() leaves the model in training mode; switch to inference mode before
# the model is traced for export.
model.eval()
# Example input used to trace the model: one single-channel 28x28 image.
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    model, dummy_input,
    "mnist_model.onnx",
    input_names=["input"],
    output_names=["output"],
    # Declare dim 0 as dynamic so the exported graph accepts any batch size
    # instead of being pinned to the batch size of dummy_input (1).
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=11
)
print("Modèle exporté en mnist_model.onnx")


Loss: 2.3159
Loss: 2.1214
Loss: 1.8669
Loss: 1.4860
Loss: 1.1755
Loss: 0.9900
Loss: 0.7872
Loss: 0.7310
Loss: 0.5367
Loss: 0.4734
Modèle exporté en mnist_model.onnx


In [7]:
import base64
import io
from torchvision import datasets
from torchvision.transforms import ToTensor, ToPILImage

# Load the MNIST test split (same ./data root as the training cell).
test_data = datasets.MNIST("data", train=False, transform=ToTensor(), download=True)

# First test sample: the image tensor and its ground-truth digit.
img_tensor, label = test_data[0]  # sample 0
print("Label réel:", label)

# Turn the tensor back into a PIL image.
img_pil = ToPILImage()(img_tensor)

# Write the image as PNG into an in-memory buffer, then base64-encode the bytes.
buffer = io.BytesIO()
img_pil.save(buffer, "PNG")
img_base64 = str(base64.b64encode(buffer.getvalue()), "utf-8")

# Print an <img> tag with an inline data URI, ready to paste into an HTML page.
print(f'<img id="mnist-img" src="data:image/png;base64,{img_base64}" alt="MNIST test image"/>')


Label réel: 7

