# AI Model Evaluation

This notebook evaluates the `WasteClassifier` model trained on a mix of the RealWaste and TrashNet datasets.
It includes visualizations of predictions with Grad-CAM heatmaps, confusion matrices, and accuracy metrics.

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import os
import random
import sys

# Make the project root importable so that `utils.gradcam` resolves.
# os.path.abspath('') is the notebook's working directory; the folder one level
# up is the same one the notebook reaches through '../pretrained' and
# '../dataset', and is assumed to hold the `utils` package as well.
# The previous code climbed two levels and put the wrong folder on sys.path,
# which produced "ModuleNotFoundError: No module named 'utils'".
current_dir = os.path.abspath('')
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    # Guard keeps sys.path free of duplicates when the cell is re-run.
    sys.path.append(parent_dir)

from utils.gradcam import GradCAM

# Set device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

ModuleNotFoundError: No module named 'utils'

: 

## 1. Model Definition and Loading
We reconstruct the model architecture used during training and load the best saved weights.

In [None]:
class WasteClassifier(nn.Module):
    """ResNet-18 whose final fully connected layer is replaced by a new linear head.

    Args:
        num_classes: Number of outputs of the replacement `fc` layer.
        pretrained: When True, the backbone is initialised from torchvision's
            default ResNet-18 weights (this may trigger a download) before the
            head is replaced. When False (default) the backbone is randomly
            initialised, which is sufficient if a checkpoint is loaded
            afterwards.
    """

    def __init__(self, num_classes=9, pretrained=False):
        super().__init__()
        # The `weights=` API is used rather than the deprecated `pretrained=True`
        # flag to avoid deprecation warnings on newer torchvision releases.
        # `pretrained` is translated to that API instead of being ignored.
        backbone_weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet18(weights=backbone_weights)
        in_features = self.backbone.fc.in_features
        # Swap the stock classification layer for one sized to our classes.
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Run `x` through the backbone and return the output of its linear head."""
        return self.backbone(x)

# Build the 9-class network and place it on the selected device.
model = WasteClassifier(num_classes=9).to(device)

# Restore the saved checkpoint when it is present. A missing file only produces
# a warning, in which case the network keeps its freshly initialised weights.
weights_path = '../pretrained/best_waste_model.pth'
if not os.path.exists(weights_path):
    print(f"Warning: {weights_path} not found. Please ensure the model file exists.")
else:
    model.load_state_dict(torch.load(weights_path, map_location=device))
    print("Model weights loaded successfully!")

# Switch to evaluation mode; the trailing `;` keeps the module repr out of the cell output.
model.eval();

## 2. Dataset and Transforms
We define the validation transforms, list the class names, and build the evaluation subset.

In [None]:
# Evaluation-subset settings. A fixed seed means the same images are selected on
# every run, so the reported metrics are reproducible.
EVAL_SEED = 42
EVAL_SUBSET_SIZE = 500  # number of images to evaluate on

data_dir = "../dataset"

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing: resize, centre-crop, convert to tensor, normalise.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Create dataset reference
full_dataset = ImageFolder(root=data_dir, transform=val_transforms)

# Take the class names from ImageFolder itself (sub-directories in alphabetical
# order) so that classes[i] is guaranteed to be the name behind label i.
classes = full_dataset.classes
print(f"Classes found: {classes}")

# Note: the original train/validation split was random and is not reproduced
# here, so this subset may include images the model saw during training.
# A subset is used for a quicker demonstration; the slice simply returns the
# whole dataset when it holds fewer than EVAL_SUBSET_SIZE images.
indices = list(range(len(full_dataset)))
# A private RNG keeps the selection reproducible without reseeding the global
# `random` module that other cells use.
random.Random(EVAL_SEED).shuffle(indices)
eval_subset_indices = indices[:EVAL_SUBSET_SIZE]
eval_dataset = Subset(full_dataset, eval_subset_indices)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

## 3. Visual Predictions with Grad-CAM
Let's see the model in action by predicting on random images and overlaying Grad-CAM heatmaps.

In [None]:
def denormalize(tensor):
    """Undo the mean/std normalisation and return a displayable image array.

    Works on a detached CPU copy, so the caller's tensor is left untouched.
    Each leading-axis slice is multiplied by the matching IMAGENET_STD entry
    and shifted by the matching IMAGENET_MEAN entry, values are clipped to
    [0, 1], and the axes are reordered from (0, 1, 2) to (1, 2, 0) before
    conversion to NumPy.
    """
    restored = tensor.clone().detach().cpu()
    for channel, mean, std in zip(restored, IMAGENET_MEAN, IMAGENET_STD):
        # `channel` is a view into `restored`, so these in-place ops update it.
        channel *= std
        channel += mean
    clipped = torch.clamp(restored, 0, 1)
    channels_last = clipped.permute(1, 2, 0)
    return channels_last.numpy()

def visualize_predictions(dataset, model, num_images=6):
    """Plot random samples with a Grad-CAM overlay and true/predicted labels.

    Args:
        dataset: Indexable dataset returning `(image_tensor, label)` pairs whose
            images were normalised with IMAGENET_MEAN / IMAGENET_STD.
        model: Classifier exposing `backbone.layer4`; the last block of that
            layer is used as the Grad-CAM target.
        num_images: Number of samples to draw; capped at `len(dataset)`.

    Raises:
        ValueError: If there is nothing to show (`num_images` < 1 or the
            dataset is empty).

    Relies on the notebook-level names `device`, `classes` and `GradCAM`.
    """
    num_images = min(num_images, len(dataset))
    if num_images < 1:
        raise ValueError("num_images must be at least 1 and the dataset must not be empty.")

    # Size the grid from num_images (at most 3 per row) instead of a fixed 2x3
    # layout; the default of 6 still yields a 2x3 grid with figsize (15, 10).
    n_cols = min(3, num_images)
    n_rows = (num_images + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    indices = random.sample(range(len(dataset)), num_images)

    # Initialize GradCAM
    target_layer = model.backbone.layer4[-1]
    grad_cam = GradCAM(model, target_layer)

    for i, idx in enumerate(indices):
        image, label = dataset[idx]

        # Add a batch dimension and move the sample to the compute device.
        input_tensor = image.unsqueeze(0).to(device)

        # Generate CAM (kept outside no_grad; presumably the helper needs
        # gradients — confirm against utils.gradcam).
        cam = grad_cam.generate_cam(input_tensor)

        with torch.no_grad():
            output = model(input_tensor)
            _, predicted = torch.max(output, 1)
            predicted_idx = predicted.item()

        # Convert the normalised tensor back to an 8-bit image for display.
        ax = axes[i]
        img_display_np = denormalize(image)
        img_display_pil = Image.fromarray((img_display_np * 255).astype(np.uint8))

        # Overlay
        overlayed = grad_cam.overlay_heatmap(img_display_pil, cam, alpha=0.5)

        ax.imshow(overlayed)

        # Green title for a correct prediction, red for a miss.
        color = 'green' if predicted_idx == label else 'red'
        ax.set_title(f"True: {classes[label]}\nPred: {classes[predicted_idx]}", color=color, fontsize=12, fontweight='bold')
        ax.axis('off')

    # Hide grid cells left over when num_images does not fill the last row.
    for unused_ax in axes[num_images:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show the default six randomly chosen images from the full dataset.
visualize_predictions(full_dataset, model)

## 4. Quantitative Evaluation
Confusion Matrix and Classification Report.

In [None]:
def evaluate_model(model, loader, device=None):
    """Run the model over a loader and collect true and predicted labels.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable yielding `(images, labels)` tensor batches.
        device: Device the inputs are moved to. When omitted, it is taken from
            the model's first parameter (CPU if the model has no parameters),
            so the function does not depend on a notebook-level variable.

    Returns:
        Tuple `(y_true, y_pred)` of NumPy arrays with one entry per sample, in
        loader order. Both are empty when the loader yields nothing.

    The model is put in evaluation mode for the duration of the call and its
    previous training/eval mode is restored afterwards.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    all_preds = []
    all_labels = []

    was_training = model.training
    model.eval()

    print("Evaluating...")
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)
                outputs = model(images)
                # Index of the highest score along dim 1 is the predicted class.
                _, preds = torch.max(outputs, 1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    return np.array(all_labels), np.array(all_preds)

y_true, y_pred = evaluate_model(model, eval_loader)

# Pin the label set to every known class index. The evaluation subset is a
# random sample, so a class may be absent from it; without `labels=` the matrix
# would shrink and no longer line up with `classes`, and classification_report
# would reject `target_names` for having the wrong length.
label_ids = list(range(len(classes)))

# Confusion Matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

# Classification Report
# zero_division=0 reports 0 (without a warning) for classes that have no
# samples or no predictions in the subset.
print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=classes, zero_division=0))