# Importations

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os
import pathlib
import math
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, array_to_img

#Augmentation des images
import random
from tensorflow.keras.layers import RandomZoom, RandomRotation, RandomContrast, Rescaling, Resizing, RandomBrightness

#Importation pour le modèle ResNet50V2 et l'encodage des labels
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.utils import to_categorical, load_img

# Importations pour la construction du modèle
from collections import defaultdict
from tensorflow.keras.models import Model
from tensorflow.keras.layers import RandomGrayscale
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import ReLU
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dropout

# Importations pour évaluation des performances
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score, recall_score
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Importation de l'utilitaire d'image de keras.utils
from tensorflow.keras.utils import image_dataset_from_directory


# MobileNet

In [None]:
import os
import numpy as np
import pickle
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import classification_report, accuracy_score, recall_score, confusion_matrix
# === Parameters ===
# Dataset root; it must contain the train/, val/ and test/ split folders.
base_dir = r"C:\Users\romai\OneDrive\Tiedostot\Projet COVID\DATASET_V2V2"
train_dir, val_dir, test_dir = [os.path.join(base_dir, split) for split in ('train', 'val', 'test')]
batch_size = 16
img_size = (224, 224)
epochs = 10

# === Image generators ===
# Training images are randomly augmented; validation/test images only go
# through MobileNet's preprocess_input.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=15,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
)
val_test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Settings shared by the three directory iterators.
_flow_settings = {
    'target_size': img_size,
    'batch_size': batch_size,
    'class_mode': 'categorical',
}
train_generator = train_datagen.flow_from_directory(train_dir, shuffle=True, seed=42, **_flow_settings)
# Validation and test keep file order (shuffle=False) so that predictions can
# be compared with the generators' `.classes` arrays.
val_generator = val_test_datagen.flow_from_directory(val_dir, shuffle=False, **_flow_settings)
test_generator = val_test_datagen.flow_from_directory(test_dir, shuffle=False, **_flow_settings)
print("Classes détectées :", train_generator.class_indices)
# === Class weights ===
# Class index 0 (COVID) weighs 3x in the loss; the other three classes weigh 1.
class_weights = {index: (3.0 if index == 0 else 1.0) for index in range(4)}

# === Model construction ===
# MobileNet backbone (ImageNet weights, no top). Everything except the last
# 30 layers is frozen, so only the end of the backbone is fine-tuned.
base_model = MobileNet(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = True
for frozen_layer in base_model.layers[:-30]:
    frozen_layer.trainable = False

# Classification head: global average pooling -> 256-unit ReLU -> dropout -> softmax.
head = GlobalAveragePooling2D()(base_model.output)
head = Dense(256, activation='relu')(head)
head = Dropout(0.5)(head)
predictions = Dense(train_generator.num_classes, activation='softmax')(head)
model = Model(inputs=base_model.input, outputs=predictions)
model.compile(
    loss='categorical_crossentropy',
    # Low learning rate because part of the pretrained backbone is trainable.
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    metrics=['accuracy'],
)

# Training callbacks: the model with the lowest val_loss is written to disk,
# and training stops after 5 epochs without val_loss improvement (restoring
# the best weights in memory).
checkpoint_cb = ModelCheckpoint(
    filepath="mobilenet_3111.keras",
    monitor="val_loss",
    save_best_only=True,
)
early_stopping_cb = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)
callbacks = [early_stopping_cb, checkpoint_cb]

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
    class_weight=class_weights,
    callbacks=callbacks,
)
# === Test-set evaluation ===
# Reload the best checkpoint written during training and predict the test set.
best_model = load_model("mobilenet_3111.keras")
test_generator.reset()
test_pred_probs = best_model.predict(test_generator)
test_preds = test_pred_probs.argmax(axis=1)  # most probable class per image
y_test_true = test_generator.classes
class_names = [name for name in test_generator.class_indices]

# === Metrics ===
test_acc = accuracy_score(y_test_true, test_preds)
test_recall = recall_score(y_test_true, test_preds, average='macro')
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Recall (macro): {test_recall:.4f}")
print("Test Classification Report:")
print(classification_report(y_test_true, test_preds, target_names=class_names))

# === Confusion matrix (rows: true class, columns: predicted class) ===
cm = confusion_matrix(y_test_true, test_preds)
_, cm_axes = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    xticklabels=class_names,
    yticklabels=class_names,
    cmap="Greens",
    ax=cm_axes,
)
cm_axes.set_title("Matrice de confusion - Test")
cm_axes.set_xlabel("Prédit")
cm_axes.set_ylabel("Réel")
plt.tight_layout()
plt.show()

# EfficientNet

In [None]:
# Dataset directories
DATA_DIR = r'C:\Users\Morvan\Documents\DATA\projet_covid\COVID-PROJET\DATASETS\DATASET V2V2  - équilibré -augmenté_split'
train_dir = os.path.join(DATA_DIR, 'train')
val_dir = os.path.join(DATA_DIR, 'val')
test_dir = os.path.join(DATA_DIR, 'test')

# Parameters
batch_size = 32
img_height = 224
img_width = 224


def _load_effnet_split(directory, shuffle=True):
    """Load one split folder as batches of grayscale images with one-hot labels.

    Args:
        directory: folder containing one sub-folder per class (labels are
            inferred from the sub-folder names).
        shuffle: whether to shuffle the files. The test split uses
            ``shuffle=False`` so its order is stable.
    """
    return image_dataset_from_directory(
        directory,
        labels='inferred',
        label_mode='categorical',
        color_mode='grayscale',
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=shuffle,
    )


# Load the datasets
train_ds = _load_effnet_split(train_dir)
val_ds = _load_effnet_split(val_dir)
test_ds = _load_effnet_split(test_dir, shuffle=False)

# `class_names` is the name the other cells use (e.g. Grad-CAM);
# `classes_names` is kept as an alias for backward compatibility.
class_names = train_ds.class_names
classes_names = class_names
num_classes = len(class_names)


# Grayscale -> 3 channels: EfficientNet is built with a 3-channel input shape.
def convert_grayscale_to_rgb(image, label):
    """Replicate the single grayscale channel into 3 identical channels."""
    image = tf.image.grayscale_to_rgb(image)
    return image, label


# Input pipeline: the single-channel images are cached (3x smaller than their
# RGB version), converted after the cache, and prefetch comes last so that the
# conversion runs in the background instead of on the training step.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(convert_grayscale_to_rgb, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds.cache()
    .map(convert_grayscale_to_rgb, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
test_ds = test_ds.map(convert_grayscale_to_rgb, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)

# EfficientNet model construction.
# These names are not imported anywhere else in the notebook; without these
# imports the cell fails with NameError.
import time

from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0

# EfficientNetB0 backbone with ImageNet weights, without its classification top.
# No preprocess_input is applied: the model is fed the raw pixel values.
base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))
base_model.trainable = False  # freeze the backbone to start with

for layer in base_model.layers:
    layer.trainable = False  # freeze every backbone layer

for layer in base_model.layers[-20:]:
    layer.trainable = True  # unfreeze the last 20 layers for fine-tuning

# Classification head
x = base_model.output
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(num_classes, activation='softmax')(x)
model = models.Model(inputs=base_model.input, outputs=x)

# Compile the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Model summary (summary() prints directly and returns None)
model.summary()

# Class weights (adapt to the dataset / preferences):
# class 0 (COVID) is weighted 3x, the others keep weight 1.
class_weights = {
    0: 3.0,     # COVID
    1: 1.0,     # Lung_Opacity
    2: 1.0,     # Normal
    3: 1.0      # Viral Pneumonia
}

timestr = time.strftime("%Y%m%d-%H%M%S")
# Weights rendered as digits (e.g. "3111") to tag the saved artefacts.
weight_code = ''.join(str(int(value)) for value in class_weights.values())

# Training
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    class_weight=class_weights
)

model_name = f'efficientnet_{timestr}_{weight_code}'
model_path = f'models/{model_name}'
os.makedirs(model_path, exist_ok=True)
# Persist the trained model in its run folder, next to the training curves.
model.save(os.path.join(model_path, f'{model_name}.keras'))


# Training curves
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(acc))

plt.figure(figsize=(12, 5))

# Accuracy curve
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Train Accuracy')
plt.plot(epochs_range, val_acc, label='Val Accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

# Loss curve
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Train Loss')
plt.plot(epochs_range, val_loss, label='Val Loss')
plt.legend(loc='upper right')
plt.title('Loss')

plt.savefig(f"{model_path}/training_history.png")
plt.show()

# ResNet

In [None]:
# Dataset paths
TRAINING_DIR = r"C:\Users\ambro\Desktop\Datasets\DATASET V2V2  - balanced -augmented_split\train"
VALIDATION_DIR = r"C:\Users\ambro\Desktop\Datasets\DATASET V2V2  - balanced -augmented_split\val"
TEST_DIR = r"C:\Users\ambro\Desktop\Datasets\DATASET V2V2  - balanced -augmented_split\test"


def _load_resnet_split(directory, shuffle=True):
    """Load one split folder as batches of grayscale 224x224 images with one-hot labels.

    Args:
        directory: folder containing one sub-folder per class (labels are
            inferred from the sub-folder names).
        shuffle: whether to shuffle the files (seeded with 42). The test split
            uses ``shuffle=False`` so its order is stable for evaluation.
    """
    return image_dataset_from_directory(
        directory,
        seed=42,                  # random seed
        batch_size=64,            # images per batch
        image_size=(224, 224),
        labels='inferred',
        label_mode='categorical',
        color_mode='grayscale',
        shuffle=shuffle,
    )


## Training / validation / test datasets
train_ds = _load_resnet_split(TRAINING_DIR)
val_ds = _load_resnet_split(VALIDATION_DIR)
test_ds = _load_resnet_split(TEST_DIR, shuffle=False)

class_names = train_ds.class_names
num_classes = len(class_names)


# Grayscale -> 3 channels, as ResNet50V2 is built with a 3-channel input
def convert_grayscale_to_rgb(image, label):
    """Replicate the single grayscale channel into 3 identical channels."""
    image = tf.image.grayscale_to_rgb(image)
    return image, label


# Input pipeline: cache the (3x smaller) single-channel images, convert after
# the cache, and prefetch last so the conversion overlaps with training.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(convert_grayscale_to_rgb, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_ds = (
    val_ds.cache()
    .map(convert_grayscale_to_rgb, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
test_ds = test_ds.map(convert_grayscale_to_rgb, num_parallel_calls=AUTOTUNE).prefetch(buffer_size=AUTOTUNE)


## NOTE: the training callbacks (early stopping, learning-rate reduction and
## checkpoints) are instantiated once, further down, after the run name used
## for the checkpoint file names has been built.


## Model architecture
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input, Rescaling
from tensorflow.keras.models import Model

# Keras documents that ResNet50V2 expects inputs passed through
# resnet_v2.preprocess_input, i.e. pixels scaled to [-1, 1] (x / 127.5 - 1).
# The datasets above yield raw pixel values, so the same scaling is done here,
# inside the graph. The datasets, and the code that displays their images,
# therefore keep working on raw pixels.
inputs = Input(shape=(224, 224, 3))
scaled_inputs = Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs)

# ResNet50V2 without its classification top, built on the scaled input.
# Using input_tensor keeps the backbone layers in the same flat graph as the
# head, so they stay reachable through model.get_layer (used by Grad-CAM).
base_model = ResNet50V2(weights='imagenet', include_top=False,
                        input_shape=(224, 224, 3), input_tensor=scaled_inputs)

# Freeze the whole backbone...
base_model.trainable = False
for layer in base_model.layers:
    layer.trainable = False

# ...then unfreeze its last 15 layers for fine-tuning
for layer in base_model.layers[-15:]:
    layer.trainable = True

# Classification head (the pooled output is already flat, no Flatten needed)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.2)(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(rate=0.2)(x)  # dropout against overfitting
x = Dense(512, activation='relu')(x)
x = Dropout(rate=0.2)(x)  # dropout against overfitting
outputs = Dense(num_classes, activation='softmax')(x)  # one unit per X-ray class

# Model instantiation
model = Model(inputs=inputs, outputs=outputs)


## Compilation
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Model summary (summary() prints directly and returns None)
model.summary()

import time

from keras import callbacks  # noqa: F401 (not used in this cell)
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Class weights (adapt to the dataset / preferences):
# class 0 (COVID) is weighted 3x, the others keep weight 1.
class_weights = {
    0: 3.0,     # COVID
    1: 1.0,     # Lung_Opacity
    2: 1.0,     # Normal
    3: 1.0      # Viral Pneumonia
}

timestr = time.strftime("%Y%m%d-%H%M%S")
# Weights rendered as digits (e.g. "3111") to tag the saved artefacts.
weight_code = ''.join(str(int(value)) for value in class_weights.values())

model_name = f'resnet50V2_{timestr}_{weight_code}'
model_path = f'models/{model_name}'

## Callbacks

# a) Early stopping: stop when val_loss has not decreased by at least 0.01
#    for 5 epochs, then restore the best weights seen during training.
early_stopping = EarlyStopping(monitor='val_loss',
                               min_delta=0.01,
                               patience=5,
                               mode='min',
                               verbose=1,
                               restore_best_weights=True)

# b) Learning-rate reduction: divide the LR by 10 when val_loss has not
#    improved by 0.01 for 3 epochs, then wait 4 epochs (cooldown) before
#    monitoring again. ReduceLROnPlateau has no `restore_best_weights`
#    option; restoring the best weights is done by early stopping above.
reduce_learning_rate = ReduceLROnPlateau(monitor='val_loss',
                                         min_delta=0.01,
                                         patience=3,
                                         factor=0.1,
                                         cooldown=4,
                                         verbose=1)

# c) Checkpoints: keep the model with the lowest val_loss, in both formats
save0 = ModelCheckpoint(f"{model_name}.keras",
                        save_best_only=True,
                        monitor='val_loss',
                        mode='min')
save1 = ModelCheckpoint(f"{model_name}.h5",
                        save_best_only=True,
                        monitor='val_loss',
                        mode='min')

# batch_size is not passed: the tf.data datasets are already batched.
resnet_history = model.fit(train_ds,
                           epochs=20,
                           validation_data=val_ds,
                           class_weight=class_weights,
                           callbacks=[early_stopping,
                                      reduce_learning_rate,
                                      save0, save1])

## Per-epoch training / validation metrics
train_loss = resnet_history.history["loss"]
val_loss = resnet_history.history["val_loss"]

# The compiled metric is accuracy (not a mean absolute error)
train_acc = resnet_history.history["accuracy"]
val_acc = resnet_history.history["val_accuracy"]


## Plot the curves
plt.figure(figsize=(20, 8))

# Loss: the model is compiled with categorical_crossentropy
plt.subplot(121)
plt.plot(train_loss)
plt.plot(val_loss)
plt.title('categorical_crossentropy par époque')
plt.ylabel('categorical_crossentropy')
plt.xlabel('Époque')
plt.legend(['Entraînement', 'Validation'], loc='best')

# Accuracy
plt.subplot(122)
plt.plot(train_acc)
plt.plot(val_acc)
plt.title('Accuracy par époque')
plt.ylabel('Accuracy')
plt.xlabel('Époque')
plt.legend(['Entraînement', 'Validation'], loc='best')

plt.show()


# VGG

In [None]:
#Librairie d'utilitaires
import os
import numpy as np
import random
import datetime
import tensorflow as tf
from tensorflow.keras.applications import VGG19
from tensorflow.keras.optimizers import Adam, Adamax
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard
from sklearn.metrics import classification_report, accuracy_score, recall_score, confusion_matrix, f1_score
from collections import Counter
from tensorflow.keras import Input
import seaborn as sns
import matplotlib.pyplot as plt
import shap
import pickle
# Categories and parameters
categories = ["COVID", "Normal", "Lung_Opacity", "Viral_Pneumonia"]  # the four dataset classes
base_dir = ...  # TODO: path to the images folder (must contain train/, val/ and test/)
if base_dir is ...:
    raise ValueError("Renseignez base_dir : chemin du dossier contenant train/, val/ et test/")

# Parameters
img_size = (224, 224)
batch_size = 32
seed = 42
AUTOTUNE = tf.data.AUTOTUNE


def _load_vgg_split(split, shuffle):
    """Load ``base_dir/<split>`` as batches of images with one-hot labels inferred from sub-folders."""
    return tf.keras.preprocessing.image_dataset_from_directory(
        os.path.join(base_dir, split),
        labels='inferred',
        label_mode='categorical',
        batch_size=batch_size,
        image_size=img_size,
        seed=seed,
        shuffle=shuffle,
    )


# Load the datasets (val/test keep file order)
train_ds = _load_vgg_split("train", shuffle=True)
val_ds = _load_vgg_split("val", shuffle=False)
test_ds = _load_vgg_split("test", shuffle=False)

# image_dataset_from_directory numbers the classes in the alphanumeric order
# of the sub-folder names, NOT in the order of `categories`. Record that order
# now: the datasets returned by .map() below no longer expose `class_names`.
class_names = train_ds.class_names
if len(class_names) != len(categories):
    raise ValueError(
        f"{len(class_names)} classes trouvées ({class_names}), "
        f"{len(categories)} attendues ({categories})"
    )

# Apply the VGG19 preprocessing
preprocess_input = tf.keras.applications.vgg19.preprocess_input
train_ds = train_ds.map(lambda x, y: (preprocess_input(x), y)).prefetch(AUTOTUNE)
val_ds = val_ds.map(lambda x, y: (preprocess_input(x), y)).prefetch(AUTOTUNE)
test_ds = test_ds.map(lambda x, y: (preprocess_input(x), y)).prefetch(AUTOTUNE)

# Class name -> label index actually used by the datasets. sorted() reproduces
# the dataset's alphanumeric ordering (assuming the folders are named as in
# `categories`), so a weight set on e.g. 'COVID' really applies to COVID.
class_mapping = {name: idx for idx, name in enumerate(sorted(categories))}
# Per-class training weights
class_weights = {
    class_mapping['Viral_Pneumonia']: 1.0,
    class_mapping['Lung_Opacity']: 1.0,
    class_mapping['Normal']: 1.0,
    class_mapping['COVID']: 1.0
}
# Model factory: VGG19 backbone with a light classification head
def build_model(input_shape=(224, 224, 3), num_classes=len(categories)):
    """Create and compile the VGG19 transfer-learning classifier.

    The backbone uses ImageNet weights with global max pooling. All of its
    layers are frozen except 'block5_conv4', which is fine-tuned.

    Args:
        input_shape: shape of one input image.
        num_classes: number of softmax outputs. Defaults to ``len(categories)``,
            evaluated once when the function is defined.

    Returns:
        The compiled Keras model (Adamax with lr 1e-4, categorical cross-entropy).
    """
    backbone = VGG19(weights='imagenet', include_top=False, input_shape=input_shape, pooling="max")
    for layer in backbone.layers:
        unfreeze = layer.name == 'block5_conv4'
        layer.trainable = unfreeze
        if unfreeze:
            print(f"Déverrouillage de : {layer.name}")

    pooled_features = backbone.output  # globally max-pooled backbone features
    regularized = Dropout(0.5)(pooled_features)
    class_probs = Dense(num_classes, activation='softmax')(regularized)
    model = Model(inputs=backbone.input, outputs=class_probs)

    # Compile once the trainable flags are final; a low learning rate suits fine-tuning.
    model.compile(
        optimizer=Adamax(learning_rate=0.0001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Callback factory
def get_callbacks():
    """Return the training callbacks.

    The list holds, in order: early stopping on val_loss (patience 3, best
    weights restored), LR reduction on val_loss plateaus (x0.2, floor 1e-6),
    best-model checkpointing to 'best_model.h5', and TensorBoard logging into a
    timestamped folder under logs/fit/.
    """
    return [
        EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True, verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=1e-6, verbose=1),
        ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, verbose=1),
        TensorBoard(
            log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
            histogram_freq=1,
        ),
    ]
# Model training
model = build_model()
callbacks = get_callbacks()
fit_config = {
    'validation_data': val_ds,
    'epochs': 30,
    'callbacks': callbacks,
    'class_weight': class_weights,
}
history = model.fit(train_ds, **fit_config)

# Interprétabilité Grad-CAM

In [None]:
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for the first image of ``img_array``.

    Args:
        img_array: batch of images in the format ``model`` expects; only the
            first image is used for the heatmap.
        model: Keras functional model whose graph contains ``last_conv_layer_name``.
        last_conv_layer_name: name of the layer whose activations are weighted
            by the pooled gradients of the class score.
        pred_index: class index to explain. Defaults to the top predicted class
            of the first image.

    Returns:
        2-D numpy array (spatial size of the chosen layer's output) with values
        in [0, 1]. It is all zeros when no location contributes positively to
        the class score (the normalization would otherwise divide by zero).
    """
    # model.inputs is already a list; wrapping it again nests the input structure.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]

    # Gradient of the class score w.r.t. the feature maps, averaged per channel
    grads = tape.gradient(class_channel, conv_outputs)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Channel-weighted sum of the first image's feature maps
    heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # Keep positive evidence only, then scale to [0, 1] when possible
    heatmap = tf.maximum(heatmap, 0).numpy()
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap


# Unbatch the test set and materialize it as a list of (image, one-hot label) pairs
all_test_data = list(test_ds.unbatch().as_numpy_iterator())

# Draw 100 random samples (or fewer if the dataset is smaller)
selected_samples = random.sample(all_test_data, k=min(100, len(all_test_data)))

# Split images and labels (one-hot labels back to class indices)
test_images = np.array([img for img, _ in selected_samples])
test_labels = np.array([np.argmax(label) for _, label in selected_samples])

# Predictions on these images
pred_probs = model.predict(test_images)
pred_labels = np.argmax(pred_probs, axis=1)

# Pick up to 2 correctly and up to 2 incorrectly classified images
correct_idx = np.where(pred_labels == test_labels)[0]
incorrect_idx = np.where(pred_labels != test_labels)[0]

selected_correct = np.random.choice(correct_idx, size=min(2, len(correct_idx)), replace=False)
selected_incorrect = np.random.choice(incorrect_idx, size=min(2, len(incorrect_idx)), replace=False)
selected_idx = np.concatenate([selected_correct, selected_incorrect])

# Grad-CAM uses the last Conv2D layer of `model`. For ResNet50V2 this is expected
# to be 'conv5_block3_3_conv'. Detecting it lets the cell run with the other
# backbones of this notebook as well.
last_conv_layer_name = next(
    layer.name
    for layer in reversed(model.layers)
    if isinstance(layer, tf.keras.layers.Conv2D)
)

# One row per selected image: overlay on the left, class probabilities on the right
fig, axes = plt.subplots(4, 2, figsize=(10, 16))
axes = axes.flatten()

for row, idx in enumerate(selected_idx):
    img_array = np.expand_dims(test_images[idx], axis=0)

    heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name)

    # Overlay the heatmap on the image. The image is shown as uint8, so this
    # assumes raw [0, 255] pixels; clipping first prevents values outside that
    # range from wrapping around in the cast.
    img = np.clip(test_images[idx], 0, 255).astype('uint8')
    heatmap_resized = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
    superimposed_img = cv2.addWeighted(cv2.cvtColor(img, cv2.COLOR_RGB2BGR), 0.6, heatmap_colored, 0.4, 0)

    ax = axes[2 * row]
    ax.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
    ax.set_title(f"Réel: {class_names[test_labels[idx]]}, Prédit: {class_names[pred_labels[idx]]}")
    ax.axis('off')

    ax = axes[2 * row + 1]
    ax.bar(class_names, pred_probs[idx], color='skyblue')
    ax.set_title(f'Probabilités pour l\'image {idx}')
    ax.set_ylabel('Probabilité')

plt.tight_layout()
# if not os.path.exists(f"{model_path}/gradcam"):
#     os.makedirs(f"{model_path}/gradcam")
# timestamp = time.strftime("%Y%m%d-%H%M%S")
# plt.savefig(f"{model_path}/gradcam/gradcam_{timestamp}.png")
plt.show()