# **1. Importer les bibliothèques**

In [1]:
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
from tf_explain.core.grad_cam import GradCAM
from keras.models import load_model
import os
import glob




# **2. Définir les fonctions utilitaires**

In [2]:
def get_images_in_folder(folder_path, extensions=('jpg', 'jpeg', 'png', 'gif')):
    """Return the paths of the image files located directly inside ``folder_path``.

    Files are selected by extension, case-insensitively, so ``IMG.JPG`` is found
    even on case-sensitive file systems. Hidden files (names starting with ``.``)
    and sub-directories are skipped, mirroring what a ``*.ext`` glob ignores.
    Listing the directory instead of globbing also avoids misinterpreting glob
    metacharacters (``[``, ``*``, ``?``) that may appear in ``folder_path``.

    Args:
        folder_path: Directory to scan (not recursive).
        extensions: Iterable of accepted extensions, with or without a leading
            dot (e.g. ``'png'`` or ``'.png'``). A single string is also accepted.

    Returns:
        A list of paths built as ``os.path.join(folder_path, name)``, sorted by
        file name so the processing order is deterministic. Returns an empty
        list when ``folder_path`` does not exist or is not a directory.
    """
    # Treat a bare string as one extension rather than iterating its characters.
    if isinstance(extensions, str):
        extensions = (extensions,)
    wanted = {ext.lower().lstrip('.') for ext in extensions}

    if not os.path.isdir(folder_path):
        return []

    image_files = []
    for name in sorted(os.listdir(folder_path)):
        # A "*.ext" glob never matches dot-files; keep that behavior.
        if name.startswith('.'):
            continue
        full_path = os.path.join(folder_path, name)
        ext = os.path.splitext(name)[1].lower().lstrip('.')
        if ext in wanted and os.path.isfile(full_path):
            image_files.append(full_path)
    return image_files

# **3. Charger votre modèle**

In [3]:
# Saved Keras model (.h5 file) to explain; point MODEL_PATH at your own
# pretrained model to reuse this notebook with another network.
MODEL_PATH = "../4 - Modele/Groupe9_DB3_VGG16_30_16.h5"
model = load_model(MODEL_PATH)





# **4. Définir les classes à expliquer**

In [4]:
# Class labels. Each name is also the sub-folder used for input images and for
# the generated XAI figures.
# NOTE(review): presumably listed in the same order as the model's output
# indices — confirm against the label mapping used during training.
classNames = [
    "fire",
    "no_fire",
    "start_fire",
]

# **5. Expliquer le modèle avec la méthode XAI "GradCAM"**

In [5]:
# For every image of every class folder, run Grad-CAM on the class the model
# predicts and save a side-by-side figure (input | XAI overlay) to
# images/<className>/<file name>_XAI.png.
# NOTE(review): the explainer is assumed reusable across images — confirm.
explainer = GradCAM()
target_size = (model.input_shape[1], model.input_shape[2])

for className in classNames:
    input_path = f"../2 - Traitement temps reel/images/{className}"
    output_path = f"images/{className}"
    # savefig does not create missing directories, so create the target first.
    os.makedirs(output_path, exist_ok=True)

    for image_path in get_images_in_folder(input_path):
        # Load the image at the model's input resolution and add a batch axis.
        # NOTE(review): pixels are fed as raw 0-255 floats; confirm this matches
        # the preprocessing used when the model was trained.
        img = tf.keras.preprocessing.image.img_to_array(
            tf.keras.preprocessing.image.load_img(image_path, target_size=target_size)
        )
        batch = np.expand_dims(img, axis=0)

        # Explain the class the model actually predicts for this image instead
        # of a hard-coded index (previously always class 1 for every image).
        predicted_index = int(np.argmax(model.predict(batch, verbose=0)[0]))
        grid = explainer.explain((batch, None), model, class_index=predicted_index)

        fig, axs = plt.subplots(1, 2, figsize=(15, 15))
        axs[0].imshow(img.astype(np.uint8))
        axs[0].set_title("input")
        axs[1].imshow(grid.astype(np.uint8))
        axs[1].set_title("XAI")

        # Save the figure with a unique filename derived from the source image.
        output_filename = os.path.join(output_path, f"{os.path.basename(image_path)}_XAI.png")
        fig.savefig(output_filename)

        # Close this specific figure so open figures do not accumulate in memory.
        plt.close(fig)