# Training — Digits classification with PyTorch (MLP)

## Goal
Train a simple neural network to classify digits (0–9) from `sklearn.datasets.load_digits()`.

We will:
1. Load the dataset
2. Split into train/test sets (stratified)
3. Normalize pixel values to the range [0, 1]
4. Convert data to PyTorch tensors
5. Create DataLoaders (mini-batches)
6. Define a simple MLP model
7. Train the model with a training loop
8. Evaluate accuracy on the test set
9. Plot training curves (loss over epochs)
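The plan above can be sketched end to end as follows. This is a minimal sketch, not the notebook's own implementation: the helper name `run_pipeline`, the 80/20 split, the single 64-unit hidden layer, Adam with a learning rate of 1e-3, the batch size of 64 and the 20 epochs are all illustrative assumptions. It stays on the CPU to keep the code short.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


def run_pipeline(epochs: int = 20, batch_size: int = 64, seed: int = 42):
    """Hypothetical helper: steps 1-8 of the plan; hyperparameters are assumptions."""
    np.random.seed(seed)
    torch.manual_seed(seed)

    # 1. Load the dataset: X is (1797, 64) with pixel values from 0 to 16
    X, y = load_digits(return_X_y=True)

    # 2. Stratified split keeps class proportions equal in train and test
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=seed
    )

    # 3. Normalize to [0, 1]; 16 is the maximum pixel value in this dataset
    X_train, X_test = X_train / 16.0, X_test / 16.0

    # 4. Tensors: float32 features, int64 labels as CrossEntropyLoss expects
    X_train_t = torch.tensor(X_train, dtype=torch.float32)
    y_train_t = torch.tensor(y_train, dtype=torch.long)
    X_test_t = torch.tensor(X_test, dtype=torch.float32)

    # 5. Mini-batches; only the training data is shuffled
    train_loader = DataLoader(TensorDataset(X_train_t, y_train_t),
                              batch_size=batch_size, shuffle=True)

    # 6. A small MLP: 64 inputs -> 64 hidden units -> 10 class logits
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # 7. Training loop; store the mean loss per epoch for step 9
    losses = []
    for _ in range(epochs):
        model.train()
        total = 0.0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()
            total += loss.item() * xb.size(0)
        losses.append(total / len(train_loader.dataset))

    # 8. Evaluate: predicted class = index of the largest logit
    model.eval()
    with torch.no_grad():
        preds = model(X_test_t).argmax(dim=1).numpy()
    return losses, accuracy_score(y_test, preds)


losses, acc = run_pipeline()
print(f"test accuracy: {acc:.3f}")
```

Step 9 then amounts to plotting the returned list, for example `plt.plot(range(1, len(losses) + 1), losses)` followed by `plt.xlabel("epoch")` and `plt.ylabel("training loss")`.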

## Required libraries

Let's start by importing the libraries we need for:
- data handling (NumPy, scikit-learn)
- model training (PyTorch)
- evaluation and plotting (scikit-learn metrics, Matplotlib)

```python
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
```
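A quick look at the data helps justify steps 2 and 3. The sketch below (the helper name `inspect_digits` and the 20% test size are illustrative assumptions) prints the feature shape and the pixel range, which is why dividing by 16 maps values into [0, 1], and checks that `stratify=y` keeps the class proportions aligned between the two splits.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split


def inspect_digits(test_size: float = 0.2, seed: int = 42):
    """Hypothetical helper: inspect the data and verify the stratified split."""
    X, y = load_digits(return_X_y=True)
    # Each sample is an 8x8 grayscale image flattened into 64 features
    print("X shape:", X.shape)  # X shape: (1797, 64)
    # Intensities run from 0 to 16, so X / 16.0 lies in [0, 1]
    print("pixel range:", X.min(), X.max())

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=seed
    )
    # Fraction of each digit class in each split
    train_frac = np.bincount(y_train, minlength=10) / len(y_train)
    test_frac = np.bincount(y_test, minlength=10) / len(y_test)
    print("largest class-proportion gap:", np.abs(train_frac - test_frac).max())
    return X, train_frac, test_frac


X, train_frac, test_frac = inspect_digits()
```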

Set a random seed so that the data split, weight initialization and shuffling are reproducible across runs, then select the compute device (CUDA GPU, Apple MPS, or CPU).

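```python
import random

def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch so that runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the generators of all CUDA devices

set_seed(42)

# Prefer an NVIDIA GPU, then Apple Silicon (MPS), then fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Device:", device)
```

```
Device: cpu
```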
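A quick way to check that seeding works, and to see how the selected device is used afterwards. This is a minimal sketch: it only distinguishes CUDA from CPU to stay short, and the layer sizes (64 inputs, 10 outputs, matching the digits data) are illustrative. Seeding alone does not guarantee bit-identical results on GPUs; PyTorch additionally provides `torch.use_deterministic_algorithms(True)` for that.

```python
import torch

# Re-seeding with the same value replays exactly the same random draws
torch.manual_seed(42)
a = torch.randn(3)
torch.manual_seed(42)
b = torch.randn(3)
print(torch.equal(a, b))  # True

# Model parameters and inputs must live on the same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(64, 10).to(device)
x = torch.rand(5, 64, device=device)
print(model(x).shape)  # torch.Size([5, 10])
```

In the training loop, the same rule applies to every mini-batch: move it with `xb.to(device)` and `yb.to(device)` before the forward pass.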
