# Entraînement du Modèle Smart Recycle avec PyTorch & MPS (ResNet50)

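Ce notebook entraîne un modèle de classification d'images (ResNet50 pré-entraîné sur ImageNet, par transfert d'apprentissage) sur le dataset TrashNet.
Il utilise l'accélération MPS (Metal Performance Shaders) des Mac Apple Silicon (M1/M2/M3) lorsqu'elle est disponible.
Les images doivent être rangées au format `ImageFolder` : `../data/train/<classe>/` et `../data/val/<classe>/`.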

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
import time
import os
import copy

# Pick the compute backend: Apple MPS first (this notebook targets Apple Silicon),
# then NVIDIA CUDA, falling back to the CPU when neither accelerator is available.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Utilisation du device : {device}")

Utilisation du device : mps


## 1. Préparation des Données

In [2]:
# Standard ImageNet channel mean/std used by torchvision's pretrained models;
# both splits must be normalised exactly like the data the backbone was trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
BATCH_SIZE = 32
NUM_WORKERS = 4
PHASES = ['train', 'val']

# Training: data augmentation for robustness + normalisation.
# Validation: deterministic resize/crop + the same normalisation.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # milder zoom than the default scale=(0.08, 1.0)
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),  # rotation of +/- 15 degrees
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # lighting variations
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}

data_dir = '../data'  # assumes the notebook lives in 01_IA_LAB/notebooks/
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in PHASES}
# Only the training split needs shuffling; a fixed validation order keeps evaluation deterministic.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                             shuffle=(x == 'train'), num_workers=NUM_WORKERS)
              for x in PHASES}

dataset_sizes = {x: len(image_datasets[x]) for x in PHASES}
class_names = image_datasets['train'].classes

# ImageFolder derives label indices from each split's own sorted sub-folder names, so a class
# folder missing from one split would silently shift every label index after it.
if image_datasets['val'].classes != class_names:
    raise ValueError(
        f"train/val class folders differ: train={class_names}, "
        f"val={image_datasets['val'].classes}"
    )

print(f"Classes : {class_names}")
print(f"Images Train : {dataset_sizes['train']}")
print(f"Images Val : {dataset_sizes['val']}")

Classes : ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
Images Train : 2019
Images Val : 508


## 2. Chargement du Modèle (Transfer Learning)

In [3]:
# True: train only the new classification head (feature extraction).
# False: fine-tune the whole network (the setting used for the recorded run).
FREEZE_BACKBONE = False

# ResNet50_Weights.IMAGENET1K_V1 is the checkpoint the deprecated `pretrained=True` flag loaded
# (resnet50-0676ba61.pth, as shown in the download log of this cell).
model_ft = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

if FREEZE_BACKBONE:
    for param in model_ft.parameters():
        param.requires_grad = False

# Replace the final fully connected layer so it outputs one logit per class.
# A freshly created nn.Linear has requires_grad=True, so the head stays trainable when frozen.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()
# Optimise only trainable parameters; with FREEZE_BACKBONE = False this is every parameter.
trainable_params = [p for p in model_ft.parameters() if p.requires_grad]
optimizer_ft = optim.SGD(trainable_params, lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 7 scheduler steps (one step per epoch in train_model).
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/ramadane/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100.0%


## 3. Boucle d'Entraînement

In [4]:
def _run_epoch(model, loader, criterion, optimizer, device, training):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)`` as Python floats.

    When ``training`` is True the model is put in train mode and updated with ``optimizer``;
    otherwise it is put in eval mode and gradients are disabled.

    Raises:
        ValueError: if ``loader`` yields no samples (the averages would be undefined).
    """
    model.train(training)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    n_seen = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if training:
            optimizer.zero_grad()

        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

        batch_size = inputs.size(0)
        # Assumes criterion returns a per-batch mean (CrossEntropyLoss default), hence the
        # re-weighting by batch size so a smaller last batch is averaged correctly.
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("dataloader yielded no samples")
    return running_loss / n_seen, running_corrects / n_seen


def train_model(model, criterion, optimizer, scheduler, num_epochs=25, *,
                loaders=None, target_device=None):
    """Train ``model`` and return it loaded with the weights of its best validation epoch.

    Each epoch runs a 'train' pass (parameter updates, then one ``scheduler.step()``) followed
    by a 'val' pass. The state_dict with the highest validation accuracy is kept and restored
    at the end.

    Args:
        model: the network to train (modified in place and also returned).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: learning-rate scheduler stepped once per epoch.
        num_epochs: number of train+val epochs.
        loaders: dict with 'train' and 'val' DataLoaders; defaults to the notebook-level
            ``dataloaders``.
        target_device: device the batches are moved to; defaults to the notebook-level
            ``device``.

    Returns:
        ``model``, with the best validation weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    if target_device is None:
        target_device = device

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            training = phase == 'train'
            epoch_loss, epoch_acc = _run_epoch(model, loaders[phase], criterion, optimizer,
                                               target_device, training)
            if training:
                # Step the LR schedule once per epoch, after that epoch's optimizer updates.
                scheduler.step()

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not training and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

In [5]:
# Increased to 15 epochs; train_model returns the model loaded with its best validation weights.
model_ft = train_model(
    model=model_ft,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    num_epochs=15,
)

Epoch 0/14
----------
train Loss: 1.2068 Acc: 0.5632
val Loss: 0.6180 Acc: 0.8091

Epoch 1/14
----------
train Loss: 0.5627 Acc: 0.8053
val Loss: 0.4171 Acc: 0.8780

Epoch 2/14
----------
train Loss: 0.3756 Acc: 0.8757
val Loss: 0.3402 Acc: 0.8996

Epoch 3/14
----------
train Loss: 0.2548 Acc: 0.9183
val Loss: 0.3402 Acc: 0.8917

Epoch 4/14
----------
train Loss: 0.2081 Acc: 0.9297
val Loss: 0.2692 Acc: 0.9193

Epoch 5/14
----------
train Loss: 0.1614 Acc: 0.9505
val Loss: 0.2871 Acc: 0.9114

Epoch 6/14
----------
train Loss: 0.1182 Acc: 0.9673
val Loss: 0.2586 Acc: 0.9173

Epoch 7/14
----------
train Loss: 0.0990 Acc: 0.9718
val Loss: 0.2378 Acc: 0.9213

Epoch 8/14
----------
train Loss: 0.0964 Acc: 0.9723
val Loss: 0.2394 Acc: 0.9232

Epoch 9/14
----------
train Loss: 0.0751 Acc: 0.9792
val Loss: 0.2567 Acc: 0.9252

Epoch 10/14
----------
train Loss: 0.0927 Acc: 0.9752
val Loss: 0.2359 Acc: 0.9252

Epoch 11/14
----------
train Loss: 0.0723 Acc: 0.9837
val Loss: 0.2722 Acc: 0.9193

Ep

## 4. Sauvegarde du Modèle

In [6]:
import os

save_path = '../models/waste_model.pth'
# torch.save does not create missing parent directories, so make sure the target folder exists.
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
# Only the weights (state_dict) are saved. Reloading requires rebuilding ResNet50 with an fc layer
# of len(class_names) outputs before load_state_dict; the class order itself is not stored here.
torch.save(model_ft.state_dict(), save_path)
print(f"Modèle sauvegardé dans {save_path}")

Modèle sauvegardé dans ../models/waste_model.pth
