<a href="https://colab.research.google.com/github/LaurentTits/ResponsibleTrainingDeepLearning/blob/main/SimpleTraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Vérification de la sélection du GPU**

In [None]:
# Show the attached GPU (if any); in Colab, select one via Runtime > Change runtime type.
!nvidia-smi

# **1. Installation de pytorch lightning**

In [None]:
!pip install pytorch_lightning torchsummary pytorch_bench captum

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Collecting pytorch_bench
  Downloading pytorch_bench-0.1.2-py3-none-any.whl.metadata (3.5 kB)
Collecting captum
  Downloading captum-0.7.0-py3-none-any.whl.metadata (26 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.6.2-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.14.0-py3-none-any.whl.metadata (5.6 kB)
Collecting torchprofile (from pytorch_bench)
  Downloading torchprofile-0.0.4-py3-none-any.whl.metadata (303 bytes)
Collecting codecarbon (from pytorch_bench)
  Downloading codecarbon-2.8.3-py3-none-any.whl.metadata (8.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch_lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from

# **2. Chargement des librairies**

In [None]:
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import transforms
from torchmetrics.classification import MulticlassConfusionMatrix, Accuracy
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchsummary import summary
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset, Subset
from pytorch_bench import benchmark
import numpy as np
import os
import tarfile
import json
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from google.colab import drive
import random
import wandb
import seaborn as sns

# **1. Google drive**

In [None]:
# Mount Google Drive: the datasets, the W&B key file and the saved models all live there.
drive.mount('/content/drive')

# Local (VM disk) destinations for the extracted archives.
extract_path = '/content/ILSVRC2012_val'
bbox_extract_path = '/content/ILSVRC2012_bbox'
bbox_annotation_path = f'{bbox_extract_path}/val'  # one <image_name>.xml per image

# Source archives stored on Drive.
dataset_folder = '/content/drive/My Drive/ResponsibleTraining/Datasets'
images_tar_path = os.path.join(dataset_folder, 'ILSVRC2012_img_val.tar')
bbox_tgz_path = os.path.join(dataset_folder, 'ILSVRC2012_bbox_val_v3.tgz')

def extract_tar_files(file_tar_path, mode, extract_path):
    """Extract a tar/tgz archive into ``extract_path`` unless it is already populated.

    Extraction is skipped when ``extract_path`` exists and is non-empty, so
    re-running the cell does not redo a slow extraction. An existing but *empty*
    directory is treated as "not extracted yet" and is filled.

    Args:
        file_tar_path: Path to the archive.
        mode: ``tarfile.open`` mode, e.g. ``'r'`` for .tar or ``'r:gz'`` for .tgz.
        extract_path: Destination directory (created if missing).

    Returns:
        Number of top-level entries (files or folders) in ``extract_path``.
    """
    os.makedirs(extract_path, exist_ok=True)
    if not os.listdir(extract_path):
        with tarfile.open(file_tar_path, mode) as tar:
            if hasattr(tarfile, 'data_filter'):
                # PEP 706 'data' filter rejects absolute paths, '..' members and
                # other unsafe entries instead of writing outside extract_path.
                tar.extractall(path=extract_path, filter='data')
            else:
                tar.extractall(path=extract_path)

    # Count extracted elements
    num_files = len(os.listdir(extract_path))
    print(f"Total extracted elements (file or folder) in the extract path: {num_files}")
    return num_files

# Unpack the validation images (no-op when already extracted) and list them in a stable order.
extract_tar_files(images_tar_path, 'r', extract_path)
image_files = sorted(name for name in os.listdir(extract_path) if name.endswith('.JPEG'))

# Unpack the per-image bounding-box XML annotations the same way.
extract_tar_files(bbox_tgz_path, 'r:gz', bbox_extract_path)

# ImageNet index -> [WordNet ID, class name] mapping.
class_index_path = os.path.join(dataset_folder, 'imagenet_class_index.json')
with open(class_index_path, 'r') as index_file:
    class_mapping = json.load(index_file)

# Best checkpoints (and the TorchScript export) are written to Drive.
models_folder = '/content/drive/My Drive/ResponsibleTraining/Models'
os.makedirs(models_folder, exist_ok=True)

# Convert WordNet ID to readable class name
def get_class_name(wnid, mapping=None):
    """Return the human-readable ImageNet class name for a WordNet ID.

    Args:
        wnid: WordNet ID, e.g. ``'n01440764'``.
        mapping: ``{"<idx>": [wnid, name], ...}`` in the format of
            ``imagenet_class_index.json``. Defaults to the notebook-level
            ``class_mapping``.

    Returns:
        The class name with underscores replaced by spaces, or ``'Unknown'``
        if ``wnid`` is not in the mapping.
    """
    if mapping is None:
        mapping = class_mapping
    return next(
        (entry[1].replace('_', ' ') for entry in mapping.values() if entry[0] == wnid),
        'Unknown',
    )

Connexion à wandb

In [None]:
with open(os.path.join(dataset_folder, 'wandb_key.txt'), 'r') as f:
    wandb_key = f.read().strip()

!pip install wandb
!wandb login {wandb_key}

# Initialize wandb
wandb.init(project="ResponsibleTraining", config={
    "learning_rate": Learning_rate,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "model": "mobilenet_v2"
})

In [None]:
# Display the first 9 validation images with their ground-truth bounding boxes (3x3 grid).
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
axes = axes.ravel()
# Hide every panel up front so slots skipped below (missing annotation) stay blank.
for ax in axes:
    ax.axis('off')

for i, image_file in enumerate(image_files[:9]):
    image_path = os.path.join(extract_path, image_file)
    image_name = os.path.splitext(image_file)[0]  # Remove .JPEG
    annotation_path = os.path.join(bbox_annotation_path, f'{image_name}.xml')

    if not os.path.exists(annotation_path):
        print(f'Annotation not found for {image_file}')
        continue

    # Parse XML annotation and resolve the class name from the first object's WordNet ID.
    root = ET.parse(annotation_path).getroot()
    name_node = root.find('object/name')
    wnid = name_node.text if name_node is not None else 'Unknown'
    class_name = get_class_name(wnid)

    # imshow copies the pixels, so the file can be closed right away.
    with Image.open(image_path) as img:
        axes[i].imshow(img)
    axes[i].set_title(f'{image_file}\nClass: {class_name}')

    # Draw every annotated box (objects without a <bndbox> are skipped).
    for obj in root.findall('object'):
        bbox = obj.find('bndbox')
        if bbox is None:
            continue
        xmin, ymin, xmax, ymax = (int(bbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='r', facecolor='none')
        axes[i].add_patch(rect)

plt.tight_layout()
plt.show()


# **3. Paramètres**

In [None]:
NUM_CLASSES = 1000
BATCH_SIZE=32 #@param [1,2,4,8,16,32,64,128] {type:"raw"}
EPOCHS=5 #@param [1,5, 10,20,50,100,200] {type:"raw"}
Learning_rate = 0.001 #@param [0.1, 0.01,0.02,0.05,0.001,0.002,0.005] {type:"raw"}
Train_split = 0.8  # @param [0.7, 0.8, 0.9] {type:"raw"}
# Fraction of the non-training data allocated to test (the rest goes to validation)
Test_ratio = 0.5  # @param [0.3, 0.4, 0.5, 0.6, 0.7] {type:"raw"}
Val_split = (1 - Train_split) * (1 - Test_ratio)
Test_split = 1 - Train_split - Val_split
Img_size = 224 #@param [224,299] {type:"raw"}
Accelerator= "auto" #@param ["cpu","gpu","auto"]
num_workers = 4 #@param [1,2,4,8,16] {type:"raw"}
# Weight for localization loss
lambda_loc = 0.1  #@param [0.1, 0.2, 0.5, 0.9, 1] {type:"raw"}
# size of the dataset (we don't take the full dataset for faster training)
subsample_size = 1000 #@param [100,1000,5000,10000,50000] {type:"raw"}
DATA_DIR="."
LOG_DIR="logs/"

# Seed Python's `random`, NumPy and PyTorch so the subsample, the train/val/test
# split, the new classifier head and the shuffling are reproducible on Restart & Run All.
SEED = 42 #@param {type:"integer"}
pl.seed_everything(SEED)

# **4. Création de la classe pour le modèle**

In [None]:
class CNNSimpleModel(pl.LightningModule):
    """LightningModule training a backbone with a classification + localization loss.

    total loss = cross_entropy(logits, labels) + lambda_loc * localization_loss(features, bboxes)

    The wrapped ``model`` must return ``(logits, features)`` from its forward and
    expose the torchvision backbone as ``model.model`` whose ``classifier[1]`` is an
    ``nn.Linear`` (as ``MobilenetModel`` does).
    """

    def __init__(self, model, num_classes=NUM_CLASSES, lambda_loc=1.0, learning_rate=None):
        """
        Args:
            model: Backbone wrapper whose forward returns ``(logits, features)``.
            num_classes: Number of classes; the ``classifier[1]`` head is rebuilt for it.
            lambda_loc: Weight of the localization loss in the total loss.
            learning_rate: Adam learning rate; ``None`` uses the notebook-level
                ``Learning_rate`` when the optimizer is created.
        """
        super().__init__()
        # Stored in the checkpoint so load_from_checkpoint restores them; the backbone
        # itself is passed again as `model=` when loading.
        self.save_hyperparameters(ignore=["model"])

        # Keep the wrapper: its forward returns (logits, features).
        self.original_model = model

        # Rebuild the classifier head in place for `num_classes` outputs.
        in_features = self.original_model.model.classifier[1].in_features
        self.original_model.model.classifier[1] = nn.Linear(in_features, num_classes)

        # Alias of the torchvision backbone (same parameters as original_model.model).
        # Its own forward returns logits only; it is kept for direct access to the
        # head and conv layers (CAM / Grad-CAM), not for the training forward pass.
        self.model = self.original_model.model

        self.num_classes = num_classes
        self.lambda_loc = lambda_loc
        self.learning_rate = learning_rate
        # Metrics (confusion matrix is only used on the test set)
        self.confusion_matrix = MulticlassConfusionMatrix(num_classes=self.num_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)

    def forward(self, x):
        """Return ``(logits, features)`` from the wrapped model.

        Calls the wrapper, not ``self.model``: the bare torchvision backbone returns
        only a ``[batch, num_classes]`` logits tensor, and unpacking that into two
        names fails for every batch size other than 2 (and is silently wrong at 2).
        """
        logits, features = self.original_model(x)
        return logits, features

    def bbox_to_mask(self, bboxes, height=Img_size, width=Img_size):
        """
        Convert bounding box coordinates to a binary mask.
        Args:
            bboxes: Tensor of shape [batch_size, 4] with (xmin, ymin, xmax, ymax).
            height: Height of the mask.
            width: Width of the mask.
        Returns:
            mask: float32 tensor of shape [batch_size, height, width], 1.0 inside
            each box (half-open ranges [ymin, ymax) x [xmin, xmax)) and 0.0 elsewhere.
        """
        # Truncate toward zero like int(); one broadcast comparison instead of a Python loop.
        coords = bboxes.to(torch.long)
        xmin, ymin, xmax, ymax = (coords[:, k].view(-1, 1, 1) for k in range(4))
        rows = torch.arange(height, device=bboxes.device).view(1, height, 1)
        cols = torch.arange(width, device=bboxes.device).view(1, 1, width)
        inside = (rows >= ymin) & (rows < ymax) & (cols >= xmin) & (cols < xmax)
        return inside.to(torch.float32)

    def localization_loss(self, features, gt_bboxes):
        """
        Compute localization loss using feature maps and bounding box masks.
        Args:
            features: Tensor of shape [batch_size, channels, height, width].
            gt_bboxes: Tensor of shape [batch_size, 4] containing bounding box coordinates.
        Returns:
            loc_loss: Mean L1 distance between the channel-averaged feature map
            (upsampled to Img_size x Img_size) and the binary box mask.
        """
        gt_masks = self.bbox_to_mask(gt_bboxes)
        features_resized = F.interpolate(features, size=(Img_size, Img_size), mode='bilinear', align_corners=False)
        loc_loss = F.l1_loss(features_resized.mean(dim=1), gt_masks, reduction='mean')
        return loc_loss

    def _shared_step(self, batch):
        """Run one batch; return (logits, labels, loss_cls, loss_loc, loss_total)."""
        images, labels, bboxes = batch
        logits, features = self(images)
        loss_cls = F.cross_entropy(logits, labels)
        loss_loc = self.localization_loss(features, bboxes)
        loss_total = loss_cls + self.lambda_loc * loss_loc
        return logits, labels, loss_cls, loss_loc, loss_total

    def training_step(self, batch, batch_idx):
        logits, labels, loss_cls, loss_loc, loss_total = self._shared_step(batch)
        acc = self.train_accuracy(logits.argmax(dim=1), labels)
        # on_step + on_epoch produce *_step and *_epoch columns (plot_metrics reads *_epoch).
        self.log_dict({'train_loss_cls': loss_cls, 'train_loss_loc': loss_loc, 'train_loss': loss_total, "train_acc": acc},
                      on_step=True, prog_bar=True, logger=True, on_epoch=True)
        return loss_total

    def on_train_epoch_end(self):
        self.train_accuracy.reset()

    def validation_step(self, batch, batch_idx):
        logits, labels, loss_cls, loss_loc, loss_total = self._shared_step(batch)
        acc = self.val_accuracy(logits.argmax(dim=1), labels)
        self.log_dict({'val_loss_cls': loss_cls, 'val_loss_loc': loss_loc, 'val_loss': loss_total, "val_acc": acc},
                      on_step=True, prog_bar=True, logger=True, on_epoch=True)
        return loss_total

    def on_validation_epoch_end(self):
        self.val_accuracy.reset()

    def test_step(self, batch, batch_idx):
        logits, labels, loss_cls, loss_loc, loss_total = self._shared_step(batch)
        predictions = logits.argmax(dim=1)
        acc = self.test_accuracy(predictions, labels)
        self.log_dict({'test_loss_cls': loss_cls, 'test_loss_loc': loss_loc, 'test_loss': loss_total, "test_acc": acc},
                      prog_bar=True, on_step=False, on_epoch=True)
        self.confusion_matrix.update(predictions, labels)
        return loss_total

    def on_test_end(self):
        """Plot the confusion matrix restricted to the 20 most frequent true classes."""
        print("Test finished!")
        self.test_accuracy.reset()

        print("Generating confusion matrix...")
        cm = self.confusion_matrix.compute().cpu()

        # Rows are true classes: keep the 20 with the most test samples.
        row_sums = cm.sum(dim=1)
        top_n_indices = row_sums.argsort(descending=True)[:20]
        cm_subset = cm[top_n_indices][:, top_n_indices]

        fig, ax = plt.subplots(figsize=(12, 10))
        sns.heatmap(cm_subset.numpy(), annot=True, fmt="d", ax=ax)
        ax.set_title("Confusion Matrix (Top 20 Classes)")
        plt.show()

        self.confusion_matrix.reset()

    def configure_optimizers(self):
        """Adam over the trainable parameters (frozen backbone layers are excluded)."""
        lr = self.learning_rate if self.learning_rate is not None else Learning_rate
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return optim.Adam(trainable_params, lr=lr)

# **5. Création du premier modèle**

In [None]:
# Model MobileNet
class MobilenetModel(nn.Module):
    """ImageNet-pretrained MobileNetV2 whose forward returns ``(logits, features)``.

    Only the last block of ``features`` is trainable, and the classifier head is
    replaced by a freshly initialised ``nn.Linear`` with ``num_classes`` outputs.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.num_classes = num_classes
        backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        # Freeze the whole feature extractor, then unfreeze only its final block.
        backbone.features.requires_grad_(False)
        backbone.features[-1].requires_grad_(True)
        # New classification head sized for num_classes.
        backbone.classifier[1] = nn.Linear(backbone.last_channel, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the class logits and the last convolutional feature map."""
        feature_map = self.model.features(x)
        # Global average pooling, flattened to [batch, channels].
        pooled = torch.flatten(F.adaptive_avg_pool2d(feature_map, (1, 1)), 1)
        return self.model.classifier(pooled), feature_map

In [None]:
# Instantiate the backbone and print a layer-by-layer summary.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mn1 = MobilenetModel()
mn1.to(device)
# torchsummary defaults to device="cuda"; pass the real device so this also runs on CPU.
summary(mn1, (3, Img_size, Img_size), device=device.type)

# **7. Création des jeux de données d'entraînement, validation et test "Data Loaders"** #

In [None]:
# Custom dataset class
class ImageNetDataset(Dataset):
    """ImageNet validation images with their class index and first bounding box.

    Each item is ``(image, class_idx, bbox)``:
      * ``image``: RGB PIL image, or ``transform(image)`` when a transform is given.
      * ``class_idx``: ``torch.long`` scalar; ``-1`` when the annotation file is
        missing or its WordNet ID is not in ``class_mapping``.
      * ``bbox``: float32 tensor ``[xmin, ymin, xmax, ymax]`` of the first object,
        rescaled to an ``img_size`` x ``img_size`` frame; all zeros when no box exists.
    """

    def __init__(self, image_folder, annotation_folder, transform=None, img_size=None):
        self.image_folder = image_folder
        self.annotation_folder = annotation_folder
        self.image_files = sorted([f for f in os.listdir(image_folder) if f.endswith('.JPEG')])
        self.transform = transform
        # Side of the resized square frame the boxes are rescaled to. Defaults to the
        # notebook-level Img_size so boxes match Resize((Img_size, Img_size)) and the
        # masks built by CNNSimpleModel.bbox_to_mask (whose default size is Img_size).
        self.img_size = Img_size if img_size is None else img_size
        # WordNet ID -> class index, built once. Matching on the class *name* was
        # O(classes^2) per item and wrong for names shared by several WordNet IDs
        # (ImageNet has e.g. two "crane" classes), which always resolved to the first.
        self.wnid_to_idx = {value[0]: int(key) for key, value in class_mapping.items()}

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_name = self.image_files[idx]
        image_path = os.path.join(self.image_folder, image_name)
        annotation_path = os.path.join(self.annotation_folder, f'{os.path.splitext(image_name)[0]}.xml')

        # convert() returns a new image, so the source file can be closed immediately.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert("RGB")
        class_idx = -1
        bbox = None

        if os.path.exists(annotation_path):
            root = ET.parse(annotation_path).getroot()
            name_node = root.find('object/name')
            if name_node is not None:
                class_idx = self.wnid_to_idx.get(name_node.text, -1)

            obj = root.find('object/bndbox')
            if obj is not None:
                xmin, ymin, xmax, ymax = (int(obj.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
                # Rescale from original pixel coordinates to the resized square frame.
                bbox = [
                    xmin * self.img_size // img.width,
                    ymin * self.img_size // img.height,
                    xmax * self.img_size // img.width,
                    ymax * self.img_size // img.height,
                ]

        if self.transform:
            img = self.transform(img)

        # Same dtype in both branches so every batch collates into one float tensor.
        bbox_tensor = torch.tensor(bbox, dtype=torch.float32) if bbox is not None else torch.zeros(4)
        return img, torch.tensor(class_idx, dtype=torch.long), bbox_tensor

# Wrapper to apply different transforms to subsets
class TransformDataset(Dataset):
    """Apply ``transform`` to the image of each ``(img, class_idx, bbox)`` item of ``subset``.

    Lets the train / val / test splits of one untransformed dataset use
    different preprocessing pipelines; labels and boxes pass through untouched.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label, box = self.subset[idx]
        transformed = self.transform(image) if self.transform else image
        return transformed, label, box

def create_data_loaders(image_path, bb_path, batch_size=BATCH_SIZE,
                         img_size=Img_size, num_workers=num_workers):
    """Build train / validation / test DataLoaders on a random subsample of the dataset.

    Up to ``subsample_size`` images are drawn at random. They are split into
    ``Train_split`` / ``Val_split`` fractions, and the remainder goes to test.

    Args:
        image_path: Folder with the ``.JPEG`` images.
        bb_path: Folder with the per-image XML annotations.
        batch_size: Batch size of all three loaders.
        img_size: Side of the square network input.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # NOTE(review): RandomResizedCrop / RandomHorizontalFlip move the object, but the
    # bounding boxes are not transformed with it, so the localization target is
    # misaligned on training images. Consider box-aware transforms (torchvision.transforms.v2).
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Deterministic preprocessing for evaluation (validation AND test).
    eval_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    main_dataset = ImageNetDataset(image_path, bb_path, transform=None)

    # Subsample; random.sample raises if asked for more items than exist.
    n_samples = min(subsample_size, len(main_dataset))
    subsample_indices = random.sample(range(len(main_dataset)), n_samples)
    dataset = Subset(main_dataset, subsample_indices)

    # Split indices
    dataset_size = len(dataset)
    indices = np.arange(dataset_size)
    np.random.shuffle(indices)

    train_end = int(Train_split * dataset_size)
    val_end = train_end + int(Val_split * dataset_size)
    train_indices, val_indices, test_indices = indices[:train_end], indices[train_end:val_end], indices[val_end:]

    # Validation uses the deterministic eval transform: random augmentation made the
    # val_loss monitored by checkpointing / early stopping noisy and non-repeatable.
    train_subset = TransformDataset(Subset(dataset, train_indices), transform=train_transform)
    val_subset = TransformDataset(Subset(dataset, val_indices), transform=eval_transform)
    test_subset = TransformDataset(Subset(dataset, test_indices), transform=eval_transform)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

# **8. Définir les hyper-paramètres, EarlyStopping, Checkpoints** #

In [None]:
# Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(extract_path, bbox_annotation_path, BATCH_SIZE, Img_size, num_workers)

# Initialize model on a fresh backbone (re-running this cell must not resume from the
# weights trained in a previous run) and pass the configured localization weight;
# the class default is 1.0, not the lambda_loc chosen in the parameters cell.
model = CNNSimpleModel(MobilenetModel(), lambda_loc=lambda_loc)

# Setup callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=models_folder, #Save the best model in Google Drive
    filename='best-checkpoint-mobilenet', #filename='best-checkpoint-mobilenet-{epoch:02d}-{val_loss:.2f}'
    save_top_k=1,
    mode='min'
)
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    mode='min'
)

# Initialize WandbLogger and record this run's hyper-parameters in W&B.
wandb_logger = WandbLogger(project="ResponsibleTraining")
wandb_logger.experiment.config.update({
    "learning_rate": Learning_rate,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "model": "mobilenet_v2",
    "lambda_loc": lambda_loc,
    "subsample_size": subsample_size,
}, allow_val_change=True)

# Initialize CSVLogger: writes LOG_DIR/cnn/metrics.csv, which plot_metrics("cnn") reads.
csv_logger = CSVLogger(LOG_DIR, name="cnn", version='')

# Initialize Trainer
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=Accelerator,
    log_every_n_steps=1,
    devices=1,
    logger=[wandb_logger, csv_logger],
    callbacks=[checkpoint_callback, early_stop_callback],
)

# **6. Lancement de l'entraînement**

In [None]:
trainer.fit(model, train_loader, val_loader)

# **10. Evaluer le modèle** ##

In [None]:
# Evaluate the best checkpoint of the run that just finished. When the target file
# already exists (e.g. from an earlier session on Drive), ModelCheckpoint saves to
# "best-checkpoint-mobilenet-v1.ckpt", "-v2", ..., so a hand-built path can point at a
# stale model. Ask the callback for the real path and fall back to the default name.
best_model_path = checkpoint_callback.best_model_path or os.path.join(models_folder, "best-checkpoint-mobilenet.ckpt")
trainer.test(dataloaders=test_loader, ckpt_path=best_model_path)

# **11 Exporter le modèle en .jit**

In [None]:
# Reload the best checkpoint. The backbone is not part of the saved hyper-parameters,
# so a fresh MobilenetModel is passed and then overwritten by the checkpoint weights.
# map_location="cpu" keeps the model on CPU whatever device it was trained on.
best_model = CNNSimpleModel.load_from_checkpoint(best_model_path, model=MobilenetModel(), map_location="cpu")
best_model.eval()  # Set the model to evaluation mode

# Export by tracing: only the forward pass (image -> (logits, features)) is recorded,
# so the metric submodules and Python-side training helpers never have to be
# TorchScript-compilable.
example_input_jit = torch.randn(1, 3, Img_size, Img_size)
jit_model = best_model.to_torchscript(method="trace", example_inputs=example_input_jit)
jit_save_path = os.path.join(models_folder, "best-checkpoint-mobilenet_jit.pth")

torch.jit.save(jit_model, jit_save_path)


# **11. Afficher les courbes d'entrainement avec la fonction "plot_metrics"** ##

In [None]:
def plot_metrics(log_folder, log_dir=None):
    """Plot per-epoch loss and accuracy curves from a CSVLogger ``metrics.csv``.

    Args:
        log_folder: Sub-folder of ``log_dir`` holding ``metrics.csv`` (the CSVLogger ``name``).
        log_dir: Root log directory; defaults to the notebook-level ``LOG_DIR``.
    """
    import pandas as pd
    import matplotlib.pyplot as plt

    if log_dir is None:
        log_dir = LOG_DIR
    metrics = pd.read_csv(os.path.join(log_dir, log_folder, 'metrics.csv'))
    # Step-level and epoch-level values are on separate rows; keep the rows that
    # carry each split's epoch aggregate.
    train_df = metrics[metrics['train_loss_epoch'].notna()]
    val_df = metrics[metrics['val_loss_epoch'].notna()]

    panels = (
        ('train_loss_epoch', 'val_loss_epoch', 'Train Loss', 'Validation Loss', 'Loss', 'Training & Validation Loss'),
        ('train_acc_epoch', 'val_acc_epoch', 'Train Acc', 'Val Acc', 'Accuracy', 'Training & Validation Accuracy'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, (train_col, val_col, train_label, val_label, ylabel, title) in zip(axes, panels):
        ax.plot(train_df['epoch'], train_df[train_col], label=train_label)
        ax.plot(val_df['epoch'], val_df[val_col], label=val_label)
        ax.set(xlabel='Epoch', ylabel=ylabel, title=title)
        ax.legend()
        ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:
# Plot the epoch curves written by the CSVLogger named "cnn".
plot_metrics("cnn")

# **12. Evaluation du modèle selon différentes métriques de notre librairie "Benchmark"** #

In [None]:
# Benchmark the reloaded best model on one random image placed on the model's device.
example_input = torch.randn(1, 3, Img_size, Img_size, device=best_model.device)
results = benchmark(best_model, example_input)

# Log to the active W&B run (wandb.log raises when no run is active), then close the run.
if wandb.run is not None:
    wandb.log({
        "benchmark": results
    })
wandb.finish()


# **13. Tester le modèle avec une image de test de votre choix**

In [None]:
from captum.attr import LayerGradCam
from pytorch_lightning import LightningModule
import cv2

def get_last_conv_layer(model: torch.nn.Module):
    """Return the last ``nn.Conv2d`` in ``model.modules()`` traversal order.

    Accepts any ``nn.Module``; the notebook passes the bare torchvision backbone
    (``best_model.model``), not a LightningModule.

    Raises:
        ValueError: If the model contains no ``nn.Conv2d``.
    """
    conv_layers = [module for module in model.modules() if isinstance(module, torch.nn.Conv2d)]
    if not conv_layers:
        raise ValueError("No convolutional layer found in model")
    return conv_layers[-1]

def generate_gradcam(model, input_tensor, target_class, last_conv_layer):
    """Compute a Grad-CAM attribution of ``target_class`` at ``last_conv_layer`` with Captum.

    Returns the first sample's attribution as a NumPy array (ReLU applied by Captum).
    """
    explainer = LayerGradCam(model, last_conv_layer)
    heatmaps = explainer.attribute(input_tensor, target=target_class, relu_attributions=True)
    first_heatmap = heatmaps[0]
    return first_heatmap.detach().cpu().numpy()

def generate_cam(features, weights, class_idx):
    """Compute a Class Activation Map (CAM) for a single image.

    Args:
        features: Feature map of shape [1, C, H, W]; only the first image is used.
        weights: Classifier weight matrix of shape [num_classes, C].
        class_idx: Index of the class to explain.

    Returns:
        NumPy array [H, W], min-max normalised to [0, 1]. A constant map is returned
        as all zeros (dividing by max - min would produce NaNs).
    """
    feature_map = features[0].detach()
    channels, height, width = feature_map.shape
    cam = weights[class_idx].detach() @ feature_map.reshape(channels, -1)
    cam = cam.reshape(height, width)
    value_range = cam.max() - cam.min()
    if value_range > 0:
        cam = (cam - cam.min()) / value_range
    else:
        cam = torch.zeros_like(cam)
    return cam.cpu().numpy()

def visualize_results(image, bbox, cam, gradcam, predicted_class, true_class, class_mapping):
    """Show the image with its box, a CAM overlay and a Grad-CAM overlay side by side.

    Args:
        image: Either an ImageNet-normalised ``[3, H, W]`` tensor (de-normalised for
            display) or a PIL image / ``[H, W, 3]`` array (uint8 is scaled to [0, 1]).
        bbox: ``[xmin, ymin, xmax, ymax]`` in the displayed image's pixel coordinates;
            an all-zero box (or ``None``) means "no box".
        cam, gradcam: 2-D heatmaps. Each is stretched over the full image extent, so a
            low-resolution map still covers the whole picture.
        predicted_class, true_class: Class indices (int or 0-d tensor). ``-1`` or any
            index missing from ``class_mapping`` is shown as "Unknown".
        class_mapping: ``{"<idx>": [wnid, name]}`` ImageNet index mapping.
    """
    def _class_name(idx):
        entry = class_mapping.get(str(int(idx)))
        return entry[1] if entry is not None else 'Unknown'

    if torch.is_tensor(image):
        # Undo Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
        image = image.detach().permute(1, 2, 0).cpu().numpy()
        image = (image * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
    else:
        image = np.asarray(image)
        if image.dtype == np.uint8:
            image = image / 255.0
    image = np.clip(image, 0, 1)

    height, width = image.shape[:2]
    # Same extent as the image itself, so heatmaps of any resolution align with it.
    extent = (-0.5, width - 0.5, height - 0.5, -0.5)
    predicted_name = _class_name(predicted_class)

    fig, axes = plt.subplots(1, 3, figsize=(24, 8))

    # Original image with bbox
    axes[0].imshow(image)
    box = np.asarray(bbox, dtype=float) if bbox is not None else np.zeros(4)
    if box.sum() > 0:
        xmin, ymin, xmax, ymax = box
        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             linewidth=2, edgecolor='r', facecolor='none')
        axes[0].add_patch(rect)
    axes[0].set_title(f"True: {_class_name(true_class)}")
    axes[0].axis('off')

    # CAM and Grad-CAM overlays
    for ax, heatmap, method in ((axes[1], cam, "CAM"), (axes[2], gradcam, "Grad-CAM")):
        ax.imshow(image)
        ax.imshow(heatmap, cmap='jet', alpha=0.5, extent=extent)
        ax.set_title(f"Predicted: {predicted_name}\n{method}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Get a random validation sample (label and box come from its XML annotation).
dataset = ImageNetDataset(extract_path, bbox_annotation_path, transform=None)
random_idx = random.randint(0, len(dataset) - 1)
image, true_class, bbox = dataset[random_idx]
true_class = int(true_class)  # 0-d tensor -> int for the class_mapping lookups

# Preprocess like the evaluation pipeline used in training (resize + ImageNet
# normalisation); the model was trained on normalised inputs.
transform = transforms.Compose([
    transforms.Resize((Img_size, Img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = transform(image).unsqueeze(0).to(best_model.device)

# Predictions and last-conv feature maps. `original_model` is the MobilenetModel
# wrapper, whose forward returns (logits, features).
best_model.eval()
with torch.no_grad():
    logits, features = best_model.original_model(input_tensor)
    probs = torch.nn.functional.softmax(logits, dim=1)
    predicted_class = torch.argmax(probs).item()

# CAM from the classifier weights, upsampled to the input resolution.
weights = best_model.model.classifier[1].weight.data
cam = generate_cam(features, weights, predicted_class)
cam = cv2.resize(cam, (Img_size, Img_size))

# Grad-CAM on the last convolution of the torchvision backbone.
last_conv_layer = get_last_conv_layer(best_model.model)
gradcam = generate_gradcam(best_model.model, input_tensor, predicted_class, last_conv_layer)

# Resize Grad-CAM to match image size and scale to [0, 1] (skip an all-zero map: 0/0).
gradcam = cv2.resize(gradcam[0], (Img_size, Img_size))
gradcam = np.maximum(gradcam, 0)
if gradcam.max() > 0:
    gradcam = gradcam / gradcam.max()

# Show the same normalised tensor the model saw (visualize_results de-normalises a
# tensor), so the box the dataset rescaled to the resized frame lines up with it.
visualize_results(input_tensor[0], bbox, cam, gradcam, predicted_class, true_class, class_mapping)

# Top-5 probabilities (printing all 1000 classes floods the output).
print("Prediction probabilities:")
top_probs, top_indices = torch.topk(probs[0], k=5)
for idx, prob in zip(top_indices.tolist(), top_probs.tolist()):
    print(f"{class_mapping[str(idx)][1]:<25} {prob*100:.2f}%")
