In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import requests
from io import BytesIO
import os
from pytorch_grad_cam import GradCAM, AblationCAM, ScoreCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

In [None]:
class GradCAMAnalyzer:
    """Explain a pretrained ResNet50's predictions with three CAM methods.

    The analyzer downloads a fixed set of ImageNet sample images. For each one it
    explains the model's top-1 prediction with Grad-CAM, AblationCAM and ScoreCAM,
    then writes comparison figures, per-method heatmaps and a text report to disk.
    """

    # CAM method keys, in the order they are plotted and reported.
    CAM_METHODS = ('gradcam', 'ablationcam', 'scorecam')

    # Seconds to wait on any single HTTP request before giving up, so a stalled
    # server cannot hang the notebook indefinitely.
    REQUEST_TIMEOUT = 30

    # Number of output classes of the ImageNet-1k classifier head.
    NUM_CLASSES = 1000

    # Side length (pixels) of the square model-input and display images.
    IMAGE_SIZE = 224

    # Plain-text list of ImageNet class names, one per line, ordered by class id.
    LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

    def __init__(self):
        # ResNet50 with ImageNet-1k weights; eval() puts batch norm in inference mode.
        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.model.eval()

        # Model-input preprocessing: resize, convert to a [0, 1] tensor, then
        # normalize with the ImageNet per-channel mean/std.
        self.preprocess = transforms.Compose([
            transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # CAMs are computed on the output of the last block of layer4.
        self.target_layers = [self.model.layer4[-1]]

        # ImageNet sample images to analyze (from the assignment), keyed by class name.
        self.image_urls = {
            'West_Highland_white_terrier': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n02098286_West_Highland_white_terrier.JPEG',
            'American_coot': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n02018207_American_coot.JPEG',
            'racer': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n04037443_racer.JPEG',
            'flamingo': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n02007558_flamingo.JPEG',
            'kite': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n01608432_kite.JPEG',
            'goldfish': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n01443537_goldfish.JPEG',
            'tiger_shark': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n01491361_tiger_shark.JPEG',
            'vulture': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n01616318_vulture.JPEG',
            'common_iguana': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n01677366_common_iguana.JPEG',
            'orange': 'https://github.com/EliSchwartz/imagenet-sample-images/raw/master/n07747607_orange.JPEG'
        }

        # List of class names indexed by predicted class id.
        self.class_labels = self.load_imagenet_labels()

    def load_imagenet_labels(self):
        """Return the ImageNet-1k class names as a list indexed by class id.

        Tries the label file at LABELS_URL first. If the request fails, returns an
        HTTP error, or does not yield exactly NUM_CLASSES names, this falls back to
        the category list bundled with the torchvision weights. That fallback works
        offline and still gives real class names.
        """
        try:
            response = requests.get(self.LABELS_URL, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            labels = [line.strip() for line in response.text.strip().splitlines()]
            if len(labels) == self.NUM_CLASSES:
                return labels
            print(f"Expected {self.NUM_CLASSES} labels from {self.LABELS_URL}, "
                  f"got {len(labels)}; using bundled torchvision labels")
        except requests.RequestException as e:
            print(f"Could not download ImageNet labels ({e}); using bundled torchvision labels")
        return list(ResNet50_Weights.IMAGENET1K_V2.meta["categories"])

    def download_image(self, url):
        """Download an image and return it as an RGB PIL image.

        This is best-effort. It returns None (after printing the error) if the
        request fails, times out, returns an HTTP error status, or the payload
        cannot be decoded as an image.
        """
        try:
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            # convert() returns an independent image, so closing the source is safe.
            with Image.open(BytesIO(response.content)) as image:
                return image.convert('RGB')
        except Exception as e:
            print(f"Error downloading image from {url}: {e}")
            return None

    def preprocess_image_for_cam(self, image):
        """Prepare a PIL image for CAM generation.

        Returns:
            rgb_img: float32 array scaled to [0, 1] of the image resized to
                IMAGE_SIZE x IMAGE_SIZE. It is the background for the CAM overlay.
                Presumably (H, W, 3) for an RGB input; download_image always
                supplies RGB.
            input_tensor: the normalized model input with a leading batch
                dimension of 1.
        """
        resized_image = image.resize((self.IMAGE_SIZE, self.IMAGE_SIZE))
        # show_cam_on_image documents a float32 RGB image in the range [0, 1].
        rgb_img = np.asarray(resized_image, dtype=np.float32) / 255.0
        input_tensor = self.preprocess(resized_image).unsqueeze(0)
        return rgb_img, input_tensor

    def get_top_prediction(self, input_tensor):
        """Return (class_id, probability) of the model's top-1 softmax prediction.

        The `.item()` calls require a batch of exactly one image.
        """
        with torch.no_grad():
            output = self.model(input_tensor)
            probabilities = F.softmax(output, dim=1)
            top_prob, top_class = torch.topk(probabilities, 1)
        return top_class.item(), top_prob.item()

    def generate_gradcam(self, image, method='gradcam'):
        """Explain the model's top-1 prediction for `image` with a CAM method.

        Args:
            image: PIL image to analyze.
            method: one of 'gradcam', 'ablationcam', 'scorecam'.

        Returns:
            (visualization, grayscale_cam, top_class, confidence):
            - visualization: the CAM blended over the resized image.
            - grayscale_cam: the CAM for the single input image.
            - top_class, confidence: the predicted class id and its softmax probability.

        Raises:
            ValueError: if `method` is not recognised.
        """
        cam_classes = {
            'gradcam': GradCAM,
            'ablationcam': AblationCAM,
            'scorecam': ScoreCAM,
        }
        if method not in cam_classes:
            raise ValueError(f"Unknown method: {method}")

        rgb_img, input_tensor = self.preprocess_image_for_cam(image)

        # Explain the class the model actually predicts.
        top_class, confidence = self.get_top_prediction(input_tensor)
        targets = [ClassifierOutputTarget(top_class)]

        # pytorch_grad_cam CAM objects attach hooks to the target layer. The
        # context manager releases them afterwards, so repeated calls do not
        # accumulate hooks on the shared model.
        with cam_classes[method](model=self.model, target_layers=self.target_layers) as cam:
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]  # first (and only) image of the batch

        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        return visualization, grayscale_cam, top_class, confidence

    def _class_name(self, class_id):
        """Return the label for `class_id`, or 'class_<id>' if no label is known."""
        if 0 <= class_id < len(self.class_labels):
            return self.class_labels[class_id]
        return f"class_{class_id}"

    def analyze_single_image(self, image_name, save_dir='results'):
        """Analyze one image with every CAM method and save the figures.

        Writes `<image_name>_cam_analysis.png` (the original image plus one panel
        per method). It also writes `<image_name>_<method>_heatmap.png` for each
        method that succeeded.

        Returns:
            A dict mapping method name to its result dict, or None if the image
            could not be downloaded. Methods that raised are omitted.

        Raises:
            KeyError: if `image_name` is not a key of self.image_urls.
        """
        os.makedirs(save_dir, exist_ok=True)

        image = self.download_image(self.image_urls[image_name])
        if image is None:
            print(f"Failed to download {image_name}")
            return None

        print(f"Analyzing {image_name}...")

        results = {}
        # Panel layout: original image at (0, 0), then one panel per CAM method.
        method_positions = [(0, 1), (1, 0), (1, 1)]

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            fig.suptitle(f'CAM Analysis: {image_name}', fontsize=16)

            axes[0, 0].imshow(image.resize((self.IMAGE_SIZE, self.IMAGE_SIZE)))
            axes[0, 0].set_title('Original Image')
            axes[0, 0].axis('off')

            for method, (row, col) in zip(self.CAM_METHODS, method_positions):
                ax = axes[row, col]
                try:
                    visualization, grayscale_cam, top_class, confidence = self.generate_gradcam(image, method)
                    class_name = self._class_name(top_class)
                    results[method] = {
                        'visualization': visualization,
                        'grayscale_cam': grayscale_cam,
                        'top_class': top_class,
                        'confidence': confidence,
                        'class_name': class_name,
                    }
                    ax.imshow(visualization)
                    ax.set_title(f'{method.upper()}\nPred: {class_name}\nConf: {confidence:.3f}')
                    ax.axis('off')
                except Exception as e:
                    # One failing method should not discard the other methods' output.
                    print(f"Error with {method} for {image_name}: {e}")
                    ax.text(0.5, 0.5, f'Error: {method}',
                            transform=ax.transAxes, ha='center', va='center')
                    ax.axis('off')

            fig.tight_layout()
            fig.savefig(os.path.join(save_dir, f'{image_name}_cam_analysis.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)

        for method in self.CAM_METHODS:
            if method not in results:
                continue
            heat_fig, heat_ax = plt.subplots(figsize=(8, 6))
            try:
                heat_img = heat_ax.imshow(results[method]['grayscale_cam'], cmap='jet')
                heat_fig.colorbar(heat_img, ax=heat_ax)
                heat_ax.set_title(f'{method.upper()} Heatmap: {image_name}')
                heat_ax.axis('off')
                heat_fig.savefig(os.path.join(save_dir, f'{image_name}_{method}_heatmap.png'),
                                 dpi=300, bbox_inches='tight')
            finally:
                plt.close(heat_fig)

        return results

    def analyze_all_images(self, save_dir='results'):
        """Analyze every image in self.image_urls and write a summary report.

        Returns:
            A dict mapping image name to its per-method results. Images that failed
            to download, or that produced no successful method, are omitted.
        """
        all_results = {}
        for image_name in self.image_urls:
            results = self.analyze_single_image(image_name, save_dir)
            if results:
                all_results[image_name] = results

        self.generate_summary_report(all_results, save_dir)
        return all_results

    def generate_summary_report(self, all_results, save_dir):
        """Write a plain-text report of the predictions per image and CAM method.

        The report goes to `<save_dir>/gradcam_analysis_report.txt`; `save_dir` is
        created if missing. Methods absent from an image's results are reported
        as "Error occurred".
        """
        # save_dir may not exist yet, e.g. when no image was analyzed.
        os.makedirs(save_dir, exist_ok=True)
        report_path = os.path.join(save_dir, 'gradcam_analysis_report.txt')

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("Grad-CAM Analysis Report\n")
            f.write("=" * 50 + "\n\n")

            for image_name, results in all_results.items():
                f.write(f"Image: {image_name}\n")
                f.write("-" * 30 + "\n")

                for method in self.CAM_METHODS:
                    if method in results:
                        result = results[method]
                        f.write(f"{method.upper()}:\n")
                        f.write(f"  Predicted Class: {result['class_name']}\n")
                        f.write(f"  Confidence: {result['confidence']:.4f}\n")
                        f.write(f"  Class ID: {result['top_class']}\n")
                    else:
                        f.write(f"{method.upper()}: Error occurred\n")

                f.write("\n")

        print(f"Summary report saved to {report_path}")

In [None]:
# Usage example: run the full analysis and print a per-image summary.
# Inside a notebook kernel __name__ is "__main__", so this guard always runs there.
if __name__ == "__main__":
    # Directory that receives the figures, heatmaps and text report. The same
    # value is passed to the analyzer and echoed in the completion message.
    SAVE_DIR = 'results'
    CAM_METHODS = ['gradcam', 'ablationcam', 'scorecam']

    analyzer = GradCAMAnalyzer()

    print("Starting Grad-CAM analysis...")
    # For a single image instead, call: analyzer.analyze_single_image('goldfish', SAVE_DIR)
    results = analyzer.analyze_all_images(save_dir=SAVE_DIR)

    print(f"Analysis complete! Check the '{SAVE_DIR}' directory for outputs.")

    print("\nSummary of Results:")
    for image_name, image_results in results.items():
        print(f"\n{image_name}:")
        for method in CAM_METHODS:
            if method in image_results:
                result = image_results[method]
                print(f"  {method.upper()}: {result['class_name']} (conf: {result['confidence']:.3f})")