In [1]:
# --- CELLULE DE CONFIGURATION ROBUSTE ---
import os

# 1. Trouver automatiquement le nom du dossier
# Kaggle met les donn√©es dans /kaggle/input/
base_path = '/kaggle/input'
found_path = None

# On cherche un dossier qui contient 'garbage'
for root, dirs, files in os.walk(base_path):
    if 'garbage_classification' in dirs:
        # C'est souvent /kaggle/input/garbage-classification/garbage_classification
        found_path = os.path.join(root, 'garbage_classification')
        break
    elif 'Garbage classification' in dirs:
        # Parfois avec une majuscule
        found_path = os.path.join(root, 'Garbage classification')
        break

if found_path:
    print(f"‚úÖ Dataset trouv√© ici : {found_path}")
    DATA_PATH = found_path
else:
    # Si l'auto-d√©tection √©choue, on liste ce qu'il y a pour vous aider √† corriger
    print("‚ùå Dataset introuvable automatiquement. Voici ce qu'il y a dans /kaggle/input :")
    print(os.listdir(base_path))
    # Mettez le chemin manuellement ici si besoin :
    DATA_PATH = '/kaggle/input/garbage-classification/Garbage classification' 

# 2. V√©rification des classes
try:
    classes = os.listdir(DATA_PATH)
    print(f"üìÇ Classes d√©tect√©es ({len(classes)}) : {classes}")
    # Doit afficher : ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
except:
    print("Erreur : Le chemin est incorrect.")

‚úÖ Dataset trouv√© ici : /kaggle/input/garbage-classification/garbage_classification
üìÇ Classes d√©tect√©es (12) : ['metal', 'white-glass', 'biological', 'paper', 'brown-glass', 'battery', 'trash', 'cardboard', 'shoes', 'clothes', 'plastic', 'green-glass']


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split
import time
import copy

# 1. CONFIGURATION
DATA_PATH = '/kaggle/input/garbage-classification/garbage_classification' # Votre chemin confirm√©
BATCH_SIZE = 32
EPOCHS = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ Entra√Ænement lanc√© sur : {device}")

# 2. PR√âPARATION DES DONN√âES (+ Augmentation)
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15), # Un peu de rotation aide pour les d√©chets
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

full_dataset = datasets.ImageFolder(DATA_PATH, transform=train_transforms)
CLASSES = full_dataset.classes
print(f"üìã Liste des classes ({len(CLASSES)}) : {CLASSES}")

# IMPORTANT : Notez cette liste quelque part, vous en aurez besoin pour Flutter !
# Ordre probable : ['battery', 'biological', 'brown-glass', 'cardboard', 'clothes', 'green-glass', 'metal', 'paper', 'plastic', 'shoes', 'trash', 'white-glass']

# Split 80% train / 20% validation
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# On applique les transformations simples (sans rotation) pour la validation
val_dataset.dataset.transform = val_transforms 

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 3. LE MOD√àLE (MobileNetV3 Large)
model = models.mobilenet_v3_large(weights='DEFAULT')

# Geler les couches (Feature Extractor)
for param in model.parameters():
    param.requires_grad = False

# Remplacer la t√™te (Classifier) pour 12 classes
num_ftrs = model.classifier[3].in_features
model.classifier[3] = nn.Linear(num_ftrs, len(CLASSES))

model = model.to(device)

# 4. BOUCLE D'ENTRA√éNEMENT
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

print("‚è≥ D√©but de l'entra√Ænement...")
since = time.time()

best_acc = 0.0

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Phase d'entra√Ænement
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = correct / total
    
    # Phase de validation (Rapide)
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    val_acc = val_correct / val_total
    
    print(f'Epoch {epoch+1}/{EPOCHS} | Train Acc: {epoch_acc:.4f} | Val Acc: {val_acc:.4f}')
    
    # On sauvegarde le meilleur mod√®le
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_waste_model.pth')

time_elapsed = time.time() - since
print(f'‚úÖ Entra√Ænement termin√© en {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Meilleure pr√©cision Val: {best_acc:.4f}')
print("‚¨áÔ∏è T√©l√©chargez 'best_waste_model.pth' dans la section Output !")

üöÄ Entra√Ænement lanc√© sur : cuda
üìã Liste des classes (12) : ['battery', 'biological', 'brown-glass', 'cardboard', 'clothes', 'green-glass', 'metal', 'paper', 'plastic', 'shoes', 'trash', 'white-glass']
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 21.1M/21.1M [00:00<00:00, 145MB/s]


‚è≥ D√©but de l'entra√Ænement...
Epoch 1/10 | Train Acc: 0.8564 | Val Acc: 0.9384
Epoch 2/10 | Train Acc: 0.9259 | Val Acc: 0.9426
Epoch 3/10 | Train Acc: 0.9376 | Val Acc: 0.9484
Epoch 4/10 | Train Acc: 0.9447 | Val Acc: 0.9504
Epoch 5/10 | Train Acc: 0.9463 | Val Acc: 0.9526
Epoch 6/10 | Train Acc: 0.9496 | Val Acc: 0.9533
Epoch 7/10 | Train Acc: 0.9550 | Val Acc: 0.9526
Epoch 8/10 | Train Acc: 0.9544 | Val Acc: 0.9526
Epoch 9/10 | Train Acc: 0.9567 | Val Acc: 0.9520
Epoch 10/10 | Train Acc: 0.9575 | Val Acc: 0.9549
‚úÖ Entra√Ænement termin√© en 8m 17s
Meilleure pr√©cision Val: 0.9549
‚¨áÔ∏è T√©l√©chargez 'best_waste_model.pth' dans la section Output !
