In [None]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
from google.colab import drive
from torch.utils.data import random_split
from torchvision import models
from torch import nn
from torch.optim import Adam
import numpy as np
from sklearn.model_selection import KFold


# Mount Google Drive in the Colab runtime; the dataset and the output file used
# later in this notebook are read from / written to /content/drive/MyDrive.
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:

# Preprocessing applied to every image: bicubic resize to 224x224, then
# conversion to a tensor.
# NOTE(review): no mean/std normalization is applied even though a pretrained
# ResNet is fine-tuned later in this notebook — confirm this is intentional.
transform = transforms.Compose(
    [
        transforms.Resize(
            (224, 224),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
    ]
)

# Build the dataset with ImageFolder (one sub-directory per class).
train_dataset = datasets.ImageFolder(
    root="/content/drive/MyDrive/TR_DIMA/training_set_reduit",
    transform=transform,
)

# Parallel arrays, in dataset order: file path and class index of each sample.
image_paths = np.array([sample_path for sample_path, _ in train_dataset.samples])
labels = np.array(train_dataset.targets)

# Class names as discovered by ImageFolder; reused later to size the classifier head.
valid_classes = train_dataset.classes
print(valid_classes)


['Brachionus', 'Ceratiums', 'Closterium', 'Cyclopoides', 'Daphnies', 'Daphnies bb', 'Hexarthra', 'Keratella cochlearis', 'Keratella cochlearis vides', 'Nauplius de copepodes', 'Polyarthra', 'Pompholyx', 'X']


In [None]:
from tqdm import tqdm

# Training parameters
num_epochs = 5
num_folds = 5
num_classes = len(valid_classes)
seed = 42

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Never ask for more loader workers than the runtime has CPUs (PyTorch warns
# and may slow down otherwise); pinning memory only helps when a GPU is used.
num_workers = min(16, os.cpu_count() or 1)
pin_memory = device.type == "cuda"

# Index of the class named "X", looked up by name instead of hard-coding 12 so
# the check keeps working if classes are added, removed or renamed.
suspect_class_idx = train_dataset.class_to_idx["X"]

# Seed torch so head initialisation and batch shuffling are repeatable
# (GPU kernels may still introduce small run-to-run differences).
torch.manual_seed(seed)

kf = KFold(n_splits=num_folds, shuffle=True, random_state=seed)
# Paths of all images labelled "X" that a fold's model predicted as another
# class, accumulated over every fold.
error_paths = []

for fold, (train_idx, test_idx) in enumerate(kf.split(image_paths)):
    print(f"Fold {fold+1}/{num_folds}")

    train_subset = Subset(train_dataset, train_idx)
    test_subset = Subset(train_dataset, test_idx)

    train_loader = DataLoader(
        train_subset,
        batch_size=128,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    # shuffle=False keeps the loader order aligned with test_idx, which the
    # path lookup below relies on.
    test_loader = DataLoader(
        test_subset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # A fresh ImageNet-pretrained ResNet-50 per fold, with its final layer
    # replaced to output one logit per dataset class.
    # NOTE(review): `pretrained=True` is deprecated in recent torchvision in
    # favour of the `weights=` argument; kept as-is for compatibility.
    model = models.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=1e-5, weight_decay=1e-4)

    # Model training
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        # `targets` (not `labels`) so the module-level `labels` array built
        # from the dataset is not overwritten by a batch tensor.
        for images, targets in train_loader_tqdm:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Running statistics
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

            # Refresh the progress bar
            train_loader_tqdm.set_postfix(loss=loss.item(), acc=100 * correct / total)

        # End-of-epoch metrics (loss is averaged over batches)
        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        print(f"\n[Epoch {epoch+1}/{num_epochs}] Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_accuracy:.2f}%")

    # Evaluation on the held-out fold and recording of the suspect images
    model.eval()
    fold_error_count = 0
    seen = 0  # number of test samples already processed in this fold
    with torch.no_grad():
        for images, targets in test_loader:
            images, targets = images.to(device), targets.to(device)
            predicted = torch.argmax(model(images), dim=1)

            # Samples labelled "X" that the model assigned to another class.
            is_suspect = (predicted != targets) & (targets == suspect_class_idx)
            for j in torch.nonzero(is_suspect).flatten().tolist():
                # seen + j is the position inside test_subset; test_idx maps
                # it back to the index in the full dataset.
                error_paths.append(image_paths[test_idx[seen + j]])
                fold_error_count += 1
            seen += targets.size(0)

    # Report this fold's own count; the running total is shown separately.
    print(f"Images suspectes dans le fold {fold+1}: {fold_error_count} (total: {len(error_paths)})")

# Save the suspect image paths, one per line
txt_path = "/content/drive/MyDrive/TR_DIMA/erreurs_5_folds.txt"
with open(txt_path, "w", encoding="utf-8") as f:
    for path in error_paths:
        f.write(path + "\n")

print(f"Liste des images suspectes enregistrée dans {txt_path}")


Fold 1/5





[Epoch 1/5] Train Loss: 0.6888 | Train Acc: 83.89%





[Epoch 2/5] Train Loss: 0.1831 | Train Acc: 95.09%





[Epoch 3/5] Train Loss: 0.1126 | Train Acc: 96.94%





[Epoch 4/5] Train Loss: 0.0738 | Train Acc: 98.07%


                                                                                   


[Epoch 5/5] Train Loss: 0.0468 | Train Acc: 98.89%




Images mal classées dans le fold 1: 462
Fold 2/5





[Epoch 1/5] Train Loss: 0.7099 | Train Acc: 82.95%





[Epoch 2/5] Train Loss: 0.1895 | Train Acc: 95.00%





[Epoch 3/5] Train Loss: 0.1165 | Train Acc: 96.97%





[Epoch 4/5] Train Loss: 0.0778 | Train Acc: 98.00%


                                                                                   


[Epoch 5/5] Train Loss: 0.0499 | Train Acc: 98.85%




Images mal classées dans le fold 2: 914
Fold 3/5





[Epoch 1/5] Train Loss: 0.6627 | Train Acc: 85.79%





[Epoch 2/5] Train Loss: 0.1724 | Train Acc: 95.31%





[Epoch 3/5] Train Loss: 0.1084 | Train Acc: 97.03%





[Epoch 4/5] Train Loss: 0.0712 | Train Acc: 98.11%


                                                                                   


[Epoch 5/5] Train Loss: 0.0476 | Train Acc: 98.83%




Images mal classées dans le fold 3: 1388
Fold 4/5





[Epoch 1/5] Train Loss: 0.6851 | Train Acc: 83.93%





[Epoch 2/5] Train Loss: 0.1836 | Train Acc: 95.06%





[Epoch 3/5] Train Loss: 0.1168 | Train Acc: 96.84%





[Epoch 4/5] Train Loss: 0.0778 | Train Acc: 97.95%


                                                                                   


[Epoch 5/5] Train Loss: 0.0506 | Train Acc: 98.82%




Images mal classées dans le fold 4: 1835
Fold 5/5





[Epoch 1/5] Train Loss: 0.6482 | Train Acc: 85.03%





[Epoch 2/5] Train Loss: 0.1798 | Train Acc: 95.21%





[Epoch 3/5] Train Loss: 0.1135 | Train Acc: 96.81%





[Epoch 4/5] Train Loss: 0.0732 | Train Acc: 98.04%


                                                                                   


[Epoch 5/5] Train Loss: 0.0490 | Train Acc: 98.81%




Images mal classées dans le fold 5: 2291
Liste des images mal classées enregistrée dans /content/drive/MyDrive/TR_DIMA/erreurs_5_folds.txt
