In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score, classification_report
import numpy as np
from mlxtend.plotting import plot_confusion_matrix
from tqdm import tqdm
import os
import math
from google.colab import drive

drive.mount('/content/drive')

# Paths to the test dataset and to the trained model weights
path_test_dataset = "/content/drive/MyDrive/TR_DIMA/test_set_reduit"
path_model = "/content/drive/MyDrive/TR_DIMA/Entrainement/best_model_logit_adjustment_0_8.pth"

# Transformations applied to every test image.
# NOTE(review): no transforms.Normalize is applied here; if the model was
# trained with normalized inputs (e.g. ImageNet mean/std), add the same
# normalization here — TODO confirm against the training script.
base_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),  # Convert images to PyTorch tensors
])

# Pick the device first so the checkpoint can be mapped onto it when loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_dataset = datasets.ImageFolder(root=path_test_dataset, transform=base_transform)
# Reuse the class list the dataset already computed instead of scanning the
# dataset folder a second time with a new ImageFolder.
valid_classes = test_dataset.classes
# Classes of the test dataset
num_classes = len(valid_classes)

test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    # Do not spawn more workers than the machine has CPUs (Colab usually has 2).
    num_workers=min(16, os.cpu_count() or 1),
    # Pinned memory only helps (and only avoids a warning) when a GPU is present.
    pin_memory=torch.cuda.is_available(),
)

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
state_dict = torch.load(path_model, map_location=device)

# Build the model: ResNet-50 without pretrained weights, with a classification
# head sized to the number of test classes, then load the trained weights.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model.load_state_dict(state_dict)

model = model.to(device)
model.eval()

# Collect ground-truth and predicted class indices for the whole test set.
y_true = []
y_pred = []
# Misclassified samples, stored as (CPU image tensor, true label, predicted label).
misclassified_images = []

with torch.no_grad():
    for inputs_cpu, labels_cpu in tqdm(test_loader, desc="Test", leave=False):
        # Only the images go to the device; labels are compared on the CPU.
        outputs = model(inputs_cpu.to(device, non_blocking=True))
        # Predicted class = index of the highest logit in each row.
        predicted_cpu = outputs.argmax(dim=1).cpu()

        y_true.extend(labels_cpu.numpy())
        y_pred.extend(predicted_cpu.numpy())

        # Vectorized mismatch mask: one comparison per batch instead of one
        # device synchronization per element. Boolean indexing copies the
        # selected images, so the rest of the batch can be freed.
        wrong = predicted_cpu != labels_cpu
        for img, true_label, pred_label in zip(
            inputs_cpu[wrong], labels_cpu[wrong].tolist(), predicted_cpu[wrong].tolist()
        ):
            misclassified_images.append((img, true_label, pred_label))

# Build the confusion matrix over ALL classes so that row/column i always
# corresponds to valid_classes[i], even if a class never appears in either
# y_true or y_pred (otherwise sklearn silently drops it and the labels shift).
all_labels = list(range(len(valid_classes)))
conf_matrix = confusion_matrix(y_true, y_pred, labels=all_labels)

# Row-normalized matrix (per-true-class proportions), in the original class
# order. Rows with no samples are left at 0 instead of dividing by zero.
# The plot below does its own normalization (show_normed=True).
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_mat = np.divide(
    conf_matrix.astype(float),
    row_totals,
    out=np.zeros(conf_matrix.shape, dtype=float),
    where=row_totals != 0,
)

# Reorder classes by descending number of true samples (row sums); ties keep
# the same order as sorted(...)[::-1] thanks to the stable sort.
liste = conf_matrix.sum(axis=1)
sorted_indices = np.argsort(liste, kind="stable")[::-1].tolist()
conf_matrix = conf_matrix[np.ix_(sorted_indices, sorted_indices)]
valid_classes2 = [valid_classes[i] for i in sorted_indices]

print(classification_report(y_true, y_pred, labels=all_labels,
                            target_names=valid_classes, zero_division=0))

# Plot the (reordered) confusion matrix with counts and row-normalized values
fig, ax = plot_confusion_matrix(conf_mat=conf_matrix,
                                colorbar=True,
                                show_absolute=True,
                                show_normed=True,
                                class_names=valid_classes2,
                                figsize=(16, 16))
ax.set_title("Confusion Matrix for training with two stages training post classifier")
# bbox_inches="tight" keeps long class names from being clipped in the file.
plt.savefig("ConfusionMatrix_plancton.png", bbox_inches="tight")
plt.show()

# ==================================================
# === Display the misclassified images ===
# ==================================================

# Optional cap on the number of images to draw (e.g. set to 20). Showing every
# error can produce a figure too large for matplotlib to render when the
# error count is big.
max_images = len(misclassified_images)
images_to_show = misclassified_images[:max_images]

if not images_to_show:
    # plt.subplots(0, cols) would raise, so handle the "no errors" case.
    print("No misclassified images to display.")
else:
    # Grid dimensions (at most 4 columns, fewer if there are fewer images)
    cols = min(4, len(images_to_show))
    rows = math.ceil(len(images_to_show) / cols)

    # squeeze=False guarantees a 2-D array of axes, whatever rows/cols are.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    axes = axes.flatten()

    for ax, (img, true_label, pred_label) in zip(axes, images_to_show):
        # Images were stored on the CPU as (C, H, W); imshow expects (H, W, C).
        ax.imshow(img.permute(1, 2, 0).numpy())
        ax.set_title(f'True: {valid_classes[true_label]}\nPred: {valid_classes[pred_label]}')

    # Hide the axes frames on every cell, including unused trailing cells.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()




Output hidden; open in https://colab.research.google.com to view.