# 🗑️ Trash-Buddy Model Interpretability & Explainability

## Overview
This notebook covers the seventh step of the Trash-Buddy pipeline: **Model Interpretability & Explainability**. We analyze and visualize how the model makes predictions using:

- Grad-CAM visualizations (Class Activation Maps)
- Feature importance analysis
- SHAP values for model explanations
- Attention maps
- Misclassification analysis with explanations

---

## 📊 Prerequisites

From previous steps, we have:
- **Trained model** saved as checkpoint (from Step 3)
- **Test dataset** for analysis (from Step 2)
- **Label classes** and encoders (from Step 2)
- **Evaluation results** (from Step 4)

---

## 🎯 Objectives
1. Load trained model and sample images
2. Generate Grad-CAM visualizations
3. Analyze feature importance
4. Compute SHAP values for explanations
5. Visualize attention maps
6. Analyze misclassifications with explanations


In [None]:
# Import necessary libraries
import os
import json
import pandas as pd
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from sklearn.preprocessing import LabelEncoder
import warnings
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print(" Libraries imported successfully!")


## 📁 Load Model and Data

Load the trained model and prepare sample images for analysis.


In [None]:
# Define paths
models_dir = Path('models')
processed_data_dir = Path('processed_data')

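# Load model checkpoint (sorted so the choice is deterministic if several files match)
checkpoint_paths = sorted(models_dir.glob('best_model_*.pth'))
if not checkpoint_paths:
    raise FileNotFoundError(f"No 'best_model_*.pth' checkpoint found in {models_dir}")
model_checkpoint_path = checkpoint_paths[0]
checkpoint = torch.load(model_checkpoint_path, map_location=device, weights_only=False)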

# Get model configuration
model_config = checkpoint.get('config', {})
MODEL_NAME = model_config.get('MODEL_NAME', 'resnet50')
IMAGE_SIZE = model_config.get('IMAGE_SIZE', 224)
IMAGENET_MEAN = model_config.get('IMAGENET_MEAN', [0.485, 0.456, 0.406])
IMAGENET_STD = model_config.get('IMAGENET_STD', [0.229, 0.224, 0.225])

# Load label classes
label_classes = np.load(processed_data_dir / 'label_classes.npy', allow_pickle=True)
NUM_CLASSES = len(label_classes)

# Create label encoder
label_encoder = LabelEncoder()
label_encoder.classes_ = label_classes

# Load test split
df_test = pd.read_csv(processed_data_dir / 'test_split.csv')

print("=" * 80)
print("DATA LOADING")
print("=" * 80)
print(f"\n Model loaded: {MODEL_NAME}")
print(f" Trained for {checkpoint['epoch']} epochs")
print(f" Validation Accuracy: {checkpoint['val_acc']:.2f}%")
print(f"\n Test set loaded: {len(df_test):,} images")
print(f"\n Label classes loaded: {NUM_CLASSES} classes")


## 🔧 Model Architecture Recreation

Recreate the model architecture and load weights.


In [None]:
# Function to create model architecture
def create_model(model_name='resnet50', num_classes=18, pretrained=False):
    """Create the model architecture and replace the classifier head."""
    try:
        if model_name == 'resnet50':
            weights = models.ResNet50_Weights.DEFAULT if pretrained else None
            model = models.resnet50(weights=weights)
            num_features = model.fc.in_features
            model.fc = nn.Linear(num_features, num_classes)
        elif model_name == 'efficientnet_b0':
            weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
            model = models.efficientnet_b0(weights=weights)
            num_features = model.classifier[1].in_features
            model.classifier[1] = nn.Linear(num_features, num_classes)
        elif model_name == 'mobilenet_v2':
            weights = models.MobileNet_V2_Weights.DEFAULT if pretrained else None
            model = models.mobilenet_v2(weights=weights)
            num_features = model.classifier[1].in_features
            model.classifier[1] = nn.Linear(num_features, num_classes)
        else:
            raise ValueError(f"Unknown model: {model_name}")
    except (AttributeError, TypeError):
        # Fallback for older torchvision versions without the `weights=` API
        # (a missing weights enum raises AttributeError; an unknown `weights`
        # keyword raises TypeError)
        if model_name == 'resnet50':
            model = models.resnet50(pretrained=pretrained)
            model.fc = nn.Linear(model.fc.in_features, num_classes)
        elif model_name == 'efficientnet_b0':
            model = models.efficientnet_b0(pretrained=pretrained)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        elif model_name == 'mobilenet_v2':
            model = models.mobilenet_v2(pretrained=pretrained)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    return model

# Create and load model
model = create_model(MODEL_NAME, NUM_CLASSES, pretrained=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

print(f" Model architecture recreated and weights loaded")
print(f" Model: {MODEL_NAME}")
print(f" Number of classes: {NUM_CLASSES}")


## 🎯 Grad-CAM Visualizations

Generate Grad-CAM (Gradient-weighted Class Activation Mapping) visualizations to see which parts of the image the model focuses on.
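
For reference, the standard Grad-CAM formulation can be written as below. The symbols are notation introduced here, not names from the notebook: $A^k$ is the $k$-th feature map of the target layer, $y^c$ is the pre-softmax score for class $c$, and $Z$ is the number of spatial positions in the feature map.

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_{ij}^k},
\qquad
L_{\text{Grad-CAM}}^c = \operatorname{ReLU}\!\left( \sum_{k} \alpha_k^c \, A^k \right)
```

In the implementation that follows, the spatial mean of the gradients plays the role of $\alpha_k^c$, and the resulting map is additionally divided by its maximum so that values fall in $[0, 1]$ for display.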


In [None]:
# Grad-CAM implementation
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks to capture the target layer's activations and gradients
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        # Detach so the CAM can later be converted to NumPy
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, class_idx=None):
        # Forward pass
        output = self.model(input_image)

        if class_idx is None:
            class_idx = output.argmax(dim=1)
        # Accept either a Python int or a single-element tensor
        class_idx = int(class_idx)

        # Backward pass on the score of the target class
        self.model.zero_grad()
        class_score = output[0, class_idx]
        class_score.backward()

        # Channel weights: spatial mean of the gradients
        gradients = self.gradients[0]
        activations = self.activations[0]
        weights = torch.mean(gradients, dim=(1, 2), keepdim=True)

        # Weighted sum of activation maps, then ReLU and normalization to [0, 1]
        cam = torch.sum(weights * activations, dim=0)
        cam = F.relu(cam)
        cam = cam / (cam.max() + 1e-8)

        return cam.cpu().numpy(), class_idx

# Get target layer based on model architecture
if MODEL_NAME == 'resnet50':
    target_layer = model.layer4[-1].conv3
elif MODEL_NAME == 'efficientnet_b0':
    target_layer = model.features[-1]
elif MODEL_NAME == 'mobilenet_v2':
    target_layer = model.features[-1]
else:
    target_layer = list(model.children())[-2]  # Fallback

gradcam = GradCAM(model, target_layer)

print(" Grad-CAM class initialized")
print(f" Target layer: {target_layer}")


In [None]:
# Visualize Grad-CAM for sample images
def visualize_gradcam(image_path, model, gradcam, transform, label_classes, device):
    """Generate and visualize Grad-CAM for an image"""
    # Load and preprocess image
    img = Image.open(image_path).convert('RGB')
    original_img = np.array(img)

    # Transform image
    input_tensor = transform(img).unsqueeze(0).to(device)

    # Get prediction
    with torch.no_grad():
        output = model(input_tensor)
        probs = F.softmax(output, dim=1)
        pred_class = output.argmax(dim=1).item()
        confidence = probs[0, pred_class].item()

    # Generate CAM
    cam, _ = gradcam.generate_cam(input_tensor, pred_class)

    # Resize CAM to original image size
    cam_resized = cv2.resize(cam, (original_img.shape[1], original_img.shape[0]))
    cam_resized = np.uint8(255 * cam_resized)
    heatmap = cv2.applyColorMap(cam_resized, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay heatmap on original image
    overlayed = cv2.addWeighted(original_img, 0.6, heatmap, 0.4, 0)

    return original_img, overlayed, pred_class, confidence

# Test transforms
test_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Select sample images for visualization
num_samples = 6
sample_indices = np.random.choice(len(df_test), num_samples, replace=False)
sample_images = df_test.iloc[sample_indices]

print("=" * 80)
print("GRAD-CAM VISUALIZATION")
print("=" * 80)
print(f"\n Generating Grad-CAM for {num_samples} sample images...")

# Create visualization
fig, axes = plt.subplots(num_samples, 2, figsize=(14, 4*num_samples))
fig.suptitle('Grad-CAM Visualizations', fontsize=18, fontweight='bold', y=0.995)

for idx, (_, row) in enumerate(sample_images.iterrows()):
    img_path = row['image_path']
    true_label = row['subcategory']

    try:
        original, overlayed, pred_class, confidence = visualize_gradcam(
            img_path, model, gradcam, test_transform, label_classes, device
        )

        pred_label = label_classes[pred_class]

        # Original image
        axes[idx, 0].imshow(original)
        axes[idx, 0].axis('off')
        axes[idx, 0].set_title(f'Original\nTrue: {true_label}', fontsize=10, fontweight='bold')

        # Grad-CAM overlay (green title if the prediction is correct, red otherwise)
        axes[idx, 1].imshow(overlayed)
        axes[idx, 1].axis('off')
        color = 'green' if pred_label == true_label else 'red'
        axes[idx, 1].set_title(
            f'Grad-CAM\nPred: {pred_label} ({confidence*100:.1f}%)\nTrue: {true_label}',
            fontsize=10, fontweight='bold', color=color)
    except Exception as e:
        # Show the error in place of the image so the remaining samples still render
        axes[idx, 0].text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center')
        axes[idx, 0].axis('off')
        axes[idx, 1].axis('off')

plt.tight_layout()
plt.show()

print(" Grad-CAM visualizations generated")


## 📋 Summary

### Key Findings
- **Grad-CAM Visualizations**: Show which image regions the model focuses on
- **Feature Importance**: Identify important features for classification
- **Model Explanations**: Understand model decision-making process

### Next Steps
1. **SHAP Integration**: Install and use SHAP library for detailed explanations
2. **Feature Importance**: Analyze layer-wise feature importance
3. **Misclassification Analysis**: Deep dive into model errors with explanations
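
As a possible starting point for the misclassification analysis, the sketch below counts the most frequent (true, predicted) label pairs among errors. The helper name `top_confusions` and the example labels are hypothetical and for illustration only; they are not results from this model. In practice the two lists would come from running the model over `df_test`.

```python
from collections import Counter


def top_confusions(true_labels, pred_labels, k=3):
    """Return the k most frequent (true, predicted) pairs among misclassified samples."""
    errors = Counter(
        (t, p) for t, p in zip(true_labels, pred_labels) if t != p
    )
    return errors.most_common(k)


# Hypothetical labels for illustration only (not output of the trained model)
true_labels = ["glass", "plastic", "paper", "plastic", "metal", "plastic"]
pred_labels = ["glass", "glass", "paper", "glass", "metal", "paper"]

for (true, pred), count in top_confusions(true_labels, pred_labels):
    print(f"{true} -> {pred}: {count}")
```

The most frequent pairs are natural candidates for a closer look with Grad-CAM, since they show where the model's errors concentrate.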

---

## 🎯 Notes

### Interpretability Techniques
- **Grad-CAM**: Visualizes attention regions in images
- **SHAP Values**: Provides feature-level explanations
- **Feature Importance**: Identifies important model components
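
One model-agnostic way to estimate which image regions matter is occlusion sensitivity: mask one patch at a time and record how much the class score drops. This is a different technique from the Grad-CAM code in this notebook, shown only as a minimal sketch. The names `occlusion_map` and `toy_score` are hypothetical, and `toy_score` is a stand-in for a real model; with the trained network, `score_fn` would instead return the softmax probability of the class of interest.

```python
import numpy as np


def occlusion_map(image, score_fn, patch=4, stride=4, fill=0.0):
    """Score drop when each patch is masked; a larger drop means a more important region.

    Assumes stride == patch (non-overlapping patches), so each pixel is written once.
    """
    h, w = image.shape[:2]
    baseline = score_fn(image)
    heat = np.zeros((h, w), dtype=float)
    for top in range(0, h, stride):
        for left in range(0, w, stride):
            occluded = image.copy()
            occluded[top:top + patch, left:left + patch] = fill
            heat[top:top + patch, left:left + patch] = baseline - score_fn(occluded)
    return heat


def toy_score(img):
    """Toy stand-in for a model: the 'class score' is the mean of the top-left 4x4 corner."""
    return float(img[:4, :4].mean())


image = np.ones((8, 8))
heat = occlusion_map(image, toy_score)
print(heat)
```

Because the toy score depends only on the top-left corner, only that patch receives a non-zero importance. Occlusion needs one forward pass per patch, so it is slower than Grad-CAM but does not require access to gradients.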

### Usage
- Use Grad-CAM to understand model focus areas
- Analyze misclassifications with visual explanations
- Validate model behavior on edge cases
