In [None]:
# =============================================================================
# Diagnostic du glaucome - Classification fundus (RIM-ONE + AIROGS)
# =============================================================================
# Pré-requis :
#   - Python 3.8+
#   - torch, torchvision, albumentations, tqdm, numpy
#   - Dossier fusionné : C:\Users\MH-CONFIG\Desktop\glaucoma_fused
#     avec sous-dossiers train/val/test chacun contenant 0/ et 1/
#
# Astuces performance Windows / CPU :
#   - num_workers=0
#   - batch_size 8-16
#   - MobileNetV3-Small plus rapide qu'EfficientNet-B0
# =============================================================================

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.datasets import ImageFolder
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import numpy as np

# -----------------------------------------------------------------------------
# Wrapper pour Albumentations
# -----------------------------------------------------------------------------
class AlbumentationsTransform:
    """Adapter that lets an Albumentations pipeline be used as a dataset ``transform``.

    The wrapped callable is invoked with the keyword argument ``image`` and
    must return a mapping that has an ``"image"`` entry.
    """

    def __init__(self, aug):
        # Augmentation callable, stored as-is and invoked on every sample.
        self.aug = aug

    def __call__(self, image):
        """Convert ``image`` to a numpy array, augment it, and return the result."""
        pixels = np.array(image)  # e.g. PIL image -> numpy array
        # Whatever the pipeline stores under "image" is returned unchanged
        # (a tensor when the pipeline ends with ToTensorV2).
        return self.aug(image=pixels)["image"]

# -----------------------------------------------------------------------------
# 1. Transformations
# -----------------------------------------------------------------------------
def _closing_steps():
    """Return fresh instances of the two steps that end both pipelines.

    Normalization uses the mean/std values conventionally paired with
    ImageNet-pretrained torchvision models, followed by the conversion
    to a tensor.
    """
    return [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]


# Training pipeline: resize, then random geometric / photometric augmentation.
_train_steps = [
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.GaussNoise(p=0.3),
]
train_transform = AlbumentationsTransform(A.Compose(_train_steps + _closing_steps()))

# Evaluation pipeline: deterministic, resize only (no random augmentation).
val_transform = AlbumentationsTransform(
    A.Compose([A.Resize(224, 224)] + _closing_steps())
)

# -----------------------------------------------------------------------------
# 2. Datasets et DataLoaders
# -----------------------------------------------------------------------------
# Root of the merged dataset; expected to contain train/, val/ and test/.
data_path = r"C:\Users\MH-CONFIG\Desktop\glaucoma_fused"


def _split_dir(split):
    """Return the directory of the given split below ``data_path``."""
    return os.path.join(data_path, split)


# Only the training split gets random augmentation; val and test share the
# deterministic evaluation transform.
train_ds = ImageFolder(_split_dir("train"), transform=train_transform)
val_ds = ImageFolder(_split_dir("val"), transform=val_transform)
test_ds = ImageFolder(_split_dir("test"), transform=val_transform)

# Evaluation loaders keep the dataset order and load in the main process.
_eval_loader_kwargs = dict(batch_size=16, shuffle=False, num_workers=0)

# Training batches are shuffled and the last incomplete batch is dropped.
train_loader = DataLoader(
    train_ds,
    batch_size=12,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)
val_loader = DataLoader(val_ds, **_eval_loader_kwargs)
test_loader = DataLoader(test_ds, **_eval_loader_kwargs)

# -----------------------------------------------------------------------------
# 3. Modèle corrigé (MobileNetV3-Small)
# -----------------------------------------------------------------------------
class GlaucomaModel(nn.Module):
    """Binary fundus classifier (glaucoma / healthy) on a pretrained backbone.

    Args:
        model_name: ``"mobilenet_v3_small"`` (default) or ``"efficientnet_b0"``.
        num_classes: Number of output logits of the replacement head.

    Raises:
        ValueError: If ``model_name`` is not one of the supported backbones.
    """

    def __init__(self, model_name="mobilenet_v3_small", num_classes=2):
        super().__init__()

        if model_name not in ("efficientnet_b0", "mobilenet_v3_small"):
            raise ValueError("Modèle non supporté. Choisir 'efficientnet_b0' ou 'mobilenet_v3_small'")

        if model_name == "mobilenet_v3_small":
            backbone = models.mobilenet_v3_small(
                weights=models.MobileNet_V3_Small_Weights.DEFAULT
            )
            # Key detail: the feature width is read from the FIRST layer of
            # the stock classifier.
            width = backbone.classifier[0].in_features
            head = nn.Sequential(
                nn.Linear(width, 1024),
                nn.Hardswish(),
                nn.Dropout(p=0.5),
                nn.Linear(1024, num_classes),
            )
        else:
            backbone = models.efficientnet_b0(
                weights=models.EfficientNet_B0_Weights.DEFAULT
            )
            # Here the feature width is read from index 1 of the stock
            # classifier.
            width = backbone.classifier[1].in_features
            head = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Linear(width, num_classes),
            )

        # Swap the stock classifier for the freshly initialised head.
        backbone.classifier = head
        self.base = backbone

    def forward(self, x):
        """Return the logits produced by the wrapped backbone for ``x``."""
        return self.base(x)

# -----------------------------------------------------------------------------
# 4. Fonction principale (protection Windows)
# -----------------------------------------------------------------------------
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    When ``optimizer`` is given, the model is switched to training mode and
    its weights are updated after every batch. Otherwise the model is
    switched to evaluation mode and gradient tracking is disabled.

    Args:
        model: Network producing one row of class scores per sample.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.
        optimizer: Optimizer to step after each batch, or ``None`` to evaluate.

    Returns:
        Tuple of the per-batch loss averaged with batch-size weights, and the
        accuracy in percent.

    Raises:
        RuntimeError: If ``loader`` yields no sample at all.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Weight each batch loss by its size so the epoch figure is a
            # per-sample average even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            n_correct += (predicted == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        # Explicit error instead of a bare ZeroDivisionError below.
        raise RuntimeError(
            "Aucun échantillon reçu du DataLoader : vérifier le dataset, "
            "batch_size et drop_last"
        )

    return loss_sum / n_seen, 100. * n_correct / n_seen


def main():
    """Train the glaucoma classifier and keep the best checkpoint.

    Uses the module-level ``train_loader`` and ``val_loader``. After every
    epoch the validation accuracy drives the learning-rate scheduler, and the
    model weights are saved whenever that accuracy improves on the best value
    seen so far.

    Raises:
        RuntimeError: If the training or validation loader yields no sample.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device utilisé : {device}")

    MODEL_NAME = "mobilenet_v3_small"
    model = GlaucomaModel(model_name=MODEL_NAME).to(device)

    # Count only the parameters that will receive gradients.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Modèle : {MODEL_NAME} — Paramètres entraînables : {trainable_params:,}")

    # Loss, optimizer, and a scheduler that halves the LR when the monitored
    # value (validation accuracy, hence mode='max') stops improving.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    # Smoke test: load at most three batches before starting the long run.
    print("\nTest rapide du train_loader...")
    for i, (imgs, _) in enumerate(train_loader):
        print(f"Batch {i+1} chargé ({imgs.shape})")
        if i >= 2:
            break

    epochs = 10
    best_val_acc = 0.0
    save_path = f"best_{MODEL_NAME}_glaucoma.pth"

    for epoch in range(epochs):
        # Only the training pass shows a progress bar.
        train_batches = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        train_loss, train_acc = _run_epoch(model, train_batches, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1:2d} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        scheduler.step(val_acc)

        # Checkpoint only on a strict improvement of validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f"→ Nouveau meilleur modèle sauvegardé : {save_path}")

    print("\nEntraînement terminé ! Meilleure Val Acc :", best_val_acc)

# -----------------------------------------------------------------------------
# Lancement sécurisé
# -----------------------------------------------------------------------------
# Start training only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()


Device utilisé : cpu
Modèle : mobilenet_v3_small — Paramètres entraînables : 1,519,906

Test rapide du train_loader...
Batch 1 chargé (torch.Size([12, 3, 224, 224]))
Batch 2 chargé (torch.Size([12, 3, 224, 224]))
Batch 3 chargé (torch.Size([12, 3, 224, 224]))


Epoch 1/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [09:13<00:00,  1.25it/s]


Epoch  1 | Train Loss: 0.5427 | Train Acc: 71.42% | Val Loss: 0.3791 | Val Acc: 82.86%
→ Nouveau meilleur modèle sauvegardé : best_mobilenet_v3_small_glaucoma.pth


Epoch 2/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:44<00:00,  1.49it/s]


Epoch  2 | Train Loss: 0.4406 | Train Acc: 78.55% | Val Loss: 0.3576 | Val Acc: 86.36%
→ Nouveau meilleur modèle sauvegardé : best_mobilenet_v3_small_glaucoma.pth


Epoch 3/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:29<00:00,  1.54it/s]


Epoch  3 | Train Loss: 0.4108 | Train Acc: 81.38% | Val Loss: 0.2939 | Val Acc: 88.83%
→ Nouveau meilleur modèle sauvegardé : best_mobilenet_v3_small_glaucoma.pth


Epoch 4/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:15<00:00,  1.59it/s]


Epoch  4 | Train Loss: 0.3825 | Train Acc: 82.23% | Val Loss: 0.2623 | Val Acc: 89.87%
→ Nouveau meilleur modèle sauvegardé : best_mobilenet_v3_small_glaucoma.pth


Epoch 5/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:21<00:00,  1.57it/s]


Epoch  5 | Train Loss: 0.3698 | Train Acc: 83.56% | Val Loss: 0.2916 | Val Acc: 87.92%


Epoch 6/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:38<00:00,  1.52it/s]


Epoch  6 | Train Loss: 0.3552 | Train Acc: 84.25% | Val Loss: 0.2777 | Val Acc: 89.48%


Epoch 7/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:28<00:00,  1.55it/s]


Epoch  7 | Train Loss: 0.3364 | Train Acc: 85.24% | Val Loss: 0.2488 | Val Acc: 90.00%
→ Nouveau meilleur modèle sauvegardé : best_mobilenet_v3_small_glaucoma.pth


Epoch 8/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:51<00:00,  1.47it/s]


Epoch  8 | Train Loss: 0.3304 | Train Acc: 85.55% | Val Loss: 0.2451 | Val Acc: 90.78%
→ Nouveau meilleur modèle sauvegardé : best_mobilenet_v3_small_glaucoma.pth


Epoch 9/40 [Train]: 100%|████████████████████████████████████████████████████████████| 694/694 [07:28<00:00,  1.55it/s]


Epoch  9 | Train Loss: 0.3175 | Train Acc: 85.57% | Val Loss: 0.2311 | Val Acc: 90.78%


Epoch 10/40 [Train]: 100%|███████████████████████████████████████████████████████████| 694/694 [07:24<00:00,  1.56it/s]


Epoch 10 | Train Loss: 0.3117 | Train Acc: 86.37% | Val Loss: 0.2672 | Val Acc: 88.57%


Epoch 11/40 [Train]: 100%|███████████████████████████████████████████████████████████| 694/694 [07:49<00:00,  1.48it/s]


Epoch 11 | Train Loss: 0.3138 | Train Acc: 86.30% | Val Loss: 0.2349 | Val Acc: 90.39%


Epoch 12/40 [Train]: 100%|███████████████████████████████████████████████████████████| 694/694 [07:38<00:00,  1.51it/s]


Epoch 12 | Train Loss: 0.2944 | Train Acc: 87.12% | Val Loss: 0.2637 | Val Acc: 89.87%


Epoch 13/40 [Train]:  19%|███████████▏                                               | 132/694 [01:24<06:26,  1.45it/s]