# In [None]:
# Cell 1: Import libraries and setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import json
from pathlib import Path

# Device selection: prefer Apple's Metal (MPS) backend, e.g. on an M2 Mac,
# and fall back to the CPU when it is not available.
device = torch.device('cpu')
if torch.backends.mps.is_available():
    device = torch.device('mps')
print(f"Using device: {device}")

# Silence every warning for cleaner notebook output (disable while debugging).
import warnings
warnings.filterwarnings('ignore')

# Cell 2: Grad-CAM Class
class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM) for one layer.

    A forward hook caches the output of ``target_layer`` and a full backward
    hook caches the gradient flowing back into that output. ``generate_cam``
    combines the two into a coarse, class-specific saliency map, and
    ``overlay_heatmap`` renders such a map on top of a PIL image.

    Args:
        model: The ``torch.nn.Module`` to explain.
        target_layer: Sub-module of ``model`` whose output is explained.
            ``generate_cam`` assumes it produces (N, C, H, W) feature maps.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # last gradient w.r.t. the layer output (detached)
        self.activations = None  # last layer output (detached)

        # Keep the handles so the hooks can be detached again. Without them,
        # every new GradCAM on the same layer (e.g. re-running a notebook cell)
        # leaves another pair of hooks attached for the lifetime of the model.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def remove_hooks(self):
        """Detach the hooks registered in ``__init__``; safe to call repeatedly."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def save_activation(self, module, input, output):
        """Forward hook: cache the target layer's output."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: cache the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_tensor, target_class=None):
        """Compute the Grad-CAM map for batch element 0 of ``input_tensor``.

        Note: puts ``self.model`` into eval mode and overwrites its parameter
        gradients.

        Args:
            input_tensor: Model input batch; only element 0 is explained.
            target_class: Class index to explain. Defaults to the arg-max of
                the model output for element 0.

        Returns:
            ``(cam, target_class, output)``: ``cam`` is a float32 array with
            the target layer's spatial size, scaled to [0, 1] (all zeros if no
            location has positive evidence); ``output`` is the raw model
            output, still attached to the autograd graph.

        Raises:
            RuntimeError: If the hooks captured nothing (``target_layer`` not
                used in the forward pass, or ``remove_hooks`` was called).
        """
        self.model.eval()
        # Clear the caches so tensors from an earlier call can never be
        # mistaken for this pass's results.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)

            if target_class is None:
                # Only batch element 0 is explained (see the one-hot below).
                target_class = output[0].argmax().item()

            self.model.zero_grad()
            # Seed the backward pass with d(output[0, target_class]) only.
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            output.backward(gradient=one_hot)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire: target_layer was not used in the "
                "forward pass or remove_hooks() has been called."
            )

        gradients = self.gradients.cpu().numpy()[0]      # (C, H, W)
        activations = self.activations.cpu().numpy()[0]  # (C, H, W)

        # Channel importance = spatially averaged gradient.
        weights = gradients.mean(axis=(1, 2))
        # Weighted sum of activation maps over channels: (C,) x (C, H, W) -> (H, W).
        cam = np.tensordot(weights, activations, axes=1).astype(np.float32)

        # Keep only evidence *for* the class, then scale to [0, 1].
        cam = np.maximum(cam, 0)
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam, target_class, output

    def overlay_heatmap(self, original_image, cam, alpha=0.5):
        """Colour-map ``cam`` with JET and alpha-blend it over ``original_image``.

        Args:
            original_image: PIL image to draw on; converted to RGB for blending
                (``Image.blend`` requires matching modes, e.g. for grayscale MRIs).
            cam: 2-D array of values in [0, 1], as returned by ``generate_cam``.
            alpha: Heatmap weight in the blend (0 = image only, 1 = heatmap only).

        Returns:
            ``(blended, heatmap_img)``: PIL RGB images at the size of
            ``original_image``.
        """
        base = original_image.convert('RGB')

        # cv2.resize takes the target size as (width, height).
        cam_resized = cv2.resize(np.asarray(cam, dtype=np.float32), (base.width, base.height))
        # Guard the uint8 cast: values outside [0, 1] would wrap around.
        cam_resized = np.clip(cam_resized, 0.0, 1.0)

        heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
        # OpenCV colour maps are BGR; PIL expects RGB.
        heatmap_img = Image.fromarray(cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB))

        blended = Image.blend(base, heatmap_img, alpha)
        return blended, heatmap_img

# Cell 3: Load the trained model
def _extract_state_dict(checkpoint):
    """Return the model ``state_dict`` contained in a ``torch.load`` result.

    Accepts a bare state dict, a training checkpoint dict that stores it under
    ``'model_state_dict'`` or ``'state_dict'``, or a pickled ``nn.Module``.
    A ``'module.'`` key prefix (added when saving an ``nn.DataParallel`` /
    ``DistributedDataParallel`` wrapper) is stripped.

    Raises:
        TypeError: If ``checkpoint`` is none of the supported formats.
    """
    if isinstance(checkpoint, nn.Module):
        # torch.save(model, ...) pickles the whole module, not a state dict;
        # a real state dict is always a dict, so it never lands here.
        checkpoint = checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise TypeError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")

    for key in ('model_state_dict', 'state_dict'):
        if isinstance(checkpoint.get(key), dict):
            checkpoint = checkpoint[key]
            break

    prefix = 'module.'
    if checkpoint and all(isinstance(k, str) and k.startswith(prefix) for k in checkpoint):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


def load_brain_tumor_model(model_path, class_names):
    """Build the ResNet-50 classifier used in training and load its weights.

    Args:
        model_path: Path to a PyTorch checkpoint: a bare state dict, a dict
            holding it under ``'model_state_dict'`` / ``'state_dict'``, or a
            pickled ``nn.Module``.
        class_names: Sequence of class labels; its length sets the number of
            output logits.

    Returns:
        The model moved to the module-level ``device`` and set to eval mode,
        or ``None`` if anything fails (the error is printed).
    """
    try:
        # Same architecture as in training. weights=None builds it without
        # downloading ImageNet weights (replaces the deprecated pretrained=False).
        model = models.resnet50(weights=None)

        # The classifier head must match the training definition exactly,
        # otherwise load_state_dict reports missing/unexpected keys.
        in_features = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.Linear(512, len(class_names)),
        )

        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(_extract_state_dict(checkpoint))

        model = model.to(device)
        model.eval()
        print("✅ Model loaded successfully!")
        return model

    except Exception as e:
        # Loader boundary: report and return None so the caller can check.
        print(f"❌ Error loading model: {e}")
        return None

# Define class names (adjust based on your dataset).
# NOTE(review): the order must match the label indices used in training. If
# training used torchvision's ImageFolder, indices follow the alphabetically
# sorted class-folder names -- verify this list against the training code.
class_names = ['Glioma', 'Meningioma', 'Pituitary', 'No Tumor']

# Load the model.
# NOTE(review): ".h5" usually denotes a Keras/HDF5 file, which torch.load
# cannot read -- confirm this file really is a PyTorch checkpoint.
model_path = "/Users/kshitijverma/Downloads/best_model.h5"
model = load_brain_tumor_model(model_path, class_names)

if model is None:
    print("❌ Failed to load model!")
else:
    print("🎯 Model ready for Grad-CAM analysis!")

# Cell 4: Setup transforms and Grad-CAM
# Preprocessing for inference. NOTE(review): assumed to be the eval-time
# transform used in training (standard ImageNet resize / center-crop /
# normalisation) -- confirm against the training code.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

if model is None:
    # Fail with a clear message instead of an AttributeError on None.layer4.
    raise RuntimeError("Model failed to load in Cell 3; cannot set up Grad-CAM.")

# Target the last convolution (conv3) of the final ResNet-50 bottleneck.
# NOTE(review): Grad-CAM setups for ResNet commonly target the whole block
# `model.layer4[-1]` (post-BN, post-residual output); conv3 is pre-BN and
# pre-skip-connection, so the maps may differ -- confirm the intended layer.
target_layer = model.layer4[-1].conv3
grad_cam = GradCAM(model, target_layer)
print("✅ Grad-CAM initialized!")

# Cell 5: Prediction function with Grad-CAM
def predict_with_gradcam(image_path, alpha=0.5):
    """Classify one image and compute its Grad-CAM heatmap.

    Uses the module-level ``transform``, ``device``, ``grad_cam`` and
    ``class_names``.

    Args:
        image_path: Path to an image file PIL can open.
        alpha: Heatmap weight for the overlay blend.

    Returns:
        A dict with ``'predicted_class'``, ``'confidence'`` (percent),
        ``'probabilities'`` (numpy array in ``class_names`` order),
        ``'original_image'`` (PIL RGB), ``'heatmap_overlay'``,
        ``'heatmap_only'`` and ``'cam'``; or ``None`` on any error (printed).
    """
    try:
        # The context manager closes the file handle Image.open keeps open;
        # convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as img:
            original_image = img.convert('RGB')
        input_batch = transform(original_image).unsqueeze(0).to(device)

        cam, predicted_idx, output = grad_cam.generate_cam(input_batch)

        # Detach first so the softmax is not recorded in the autograd graph.
        probabilities = F.softmax(output[0].detach(), dim=0)
        # predicted_idx is the arg-max class, so this is the top probability.
        confidence_percent = probabilities[predicted_idx].item() * 100

        # NOTE(review): `transform` resizes and center-crops, so `cam` covers
        # only the central crop the model saw, while overlay_heatmap stretches
        # it over the whole original image -- highlights can be offset,
        # especially near the borders or for non-square images.
        heatmap_overlay, heatmap_only = grad_cam.overlay_heatmap(
            original_image, cam, alpha=alpha
        )

        return {
            'predicted_class': class_names[predicted_idx],
            'confidence': confidence_percent,
            'probabilities': probabilities.cpu().numpy(),
            'original_image': original_image,
            'heatmap_overlay': heatmap_overlay,
            'heatmap_only': heatmap_only,
            'cam': cam,
        }

    except Exception as e:
        print(f"❌ Prediction error: {e}")
        return None

# Cell 6: Visualization function
def visualize_results(result, figsize=(15, 5)):
    """Plot image / heatmap / overlay, then print per-class scores and guidance.

    Args:
        result: Dict returned by ``predict_with_gradcam``.
        figsize: Matplotlib figure size for the three-panel plot.
    """
    _, axes = plt.subplots(1, 3, figsize=figsize)
    panels = [
        (result['original_image'],
         f"Original MRI Scan\nPredicted: {result['predicted_class']}\nConfidence: {result['confidence']:.2f}%"),
        (result['heatmap_only'], "Grad-CAM Heatmap\n(Red = High Attention)"),
        (result['heatmap_overlay'], "Heatmap Overlay\nTumor Region Highlighted"),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Per-class scores; probabilities are in class_names order.
    print("\n📊 Confidence Scores:")
    for class_name, prob in zip(class_names, result['probabilities']):
        marker = "🎯" if class_name == result['predicted_class'] else "  "
        print(f"  {marker} {class_name}: {prob * 100:.2f}%")

    if "No Tumor" not in result['predicted_class']:
        print("\n⚠️  Medical Recommendation:")
        advice = {
            "Glioma": "Gliomas are tumors that occur in the brain and spinal cord. The red highlighted areas show potential tumor regions. Consult a neurologist immediately.",
            "Meningioma": "Meningiomas are tumors that arise from the meninges. The heatmap highlights suspicious regions. Consult a neurosurgeon.",
            "Pituitary": "Pituitary tumors affect the pituitary gland. The highlighted areas indicate regions of interest. Endocrine evaluation recommended."
        }
        print(f"  {advice.get(result['predicted_class'], 'Consult a medical professional for proper diagnosis.')}")
    else:
        print("\n✅ No tumor detected - Brain appears healthy!")

    # The classifier output is not a clinical finding; say so explicitly.
    print("\nℹ️  Model output only - not a medical diagnosis.")

# Cell 7: Interactive analysis function
def analyze_brain_mri(image_path, alpha=0.5):
    """Run prediction + Grad-CAM on one MRI image and display the results.

    Args:
        image_path: Path (str or Path) to the image.
        alpha: Heatmap weight for the overlay blend.

    Returns:
        The ``predict_with_gradcam`` result dict, or ``None`` on failure.
    """
    print(f"🔍 Analyzing: {Path(image_path).name}")
    print("=" * 50)

    result = predict_with_gradcam(image_path, alpha)
    if result is None:
        print("❌ Failed to analyze image!")
        return None

    visualize_results(result)
    return result

# Alternative function with adjustable heatmap
def analyze_with_interactive_heatmap(image_path, alpha_values=(0.3, 0.5, 0.7)):
    """Show the image next to Grad-CAM overlays at several blend strengths.

    Args:
        image_path: Path to the image.
        alpha_values: Iterable of heatmap weights, one overlay panel each. The
            default is a tuple to avoid a shared mutable default; lists work too.

    Returns:
        The ``predict_with_gradcam`` result dict, or ``None`` on failure.
    """
    result = predict_with_gradcam(image_path)
    if result is None:
        print("❌ Failed to analyze image!")
        return None

    alpha_values = list(alpha_values)
    # squeeze=False always yields a 2-D axes array, so indexing works even
    # when alpha_values is empty and there is only the original-image panel.
    _, axes = plt.subplots(1, len(alpha_values) + 1, figsize=(15, 4), squeeze=False)
    axes = axes[0]

    axes[0].imshow(result['original_image'])
    axes[0].set_title(f"Original\n{result['predicted_class']}\n{result['confidence']:.1f}%")
    axes[0].axis('off')

    # The CAM is reused; only the blend weight changes per panel.
    for ax, alpha in zip(axes[1:], alpha_values):
        heatmap_overlay, _ = grad_cam.overlay_heatmap(result['original_image'], result['cam'], alpha=alpha)
        ax.imshow(heatmap_overlay)
        ax.set_title(f"Heatmap (α={alpha})")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    return result

# Cell 8: Example usage
# Example 1: Analyze a single image
image_path = "/path/to/your/brain_mri.jpg"  # Change this to your image path
# Skip the demo while the placeholder path is still in place, instead of
# printing a prediction error for a file that does not exist.
if Path(image_path).exists():
    result = analyze_brain_mri(image_path)
else:
    result = None
    print(f"ℹ️  Image not found: {image_path} - set image_path to a real MRI image")

# Example 2: Analyze with different heatmap intensities
# result = analyze_with_interactive_heatmap(image_path)

# Example 3: Batch analyze multiple images
def analyze_multiple_images(image_folder):
    """Run ``analyze_brain_mri`` on every image file directly inside a folder.

    Files are matched by suffix, case-insensitively, and processed in sorted
    order. Sub-folders are not searched.

    Args:
        image_folder: Folder path (str or Path).

    Returns:
        List of result dicts for the images analysed successfully (empty if
        the folder does not exist or contains no images).
    """
    folder_path = Path(image_folder)
    image_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

    # One directory scan with a case-insensitive suffix test. Globbing both
    # '*.jpg' and '*.JPG' returns every file twice where glob matching is
    # case-insensitive (Windows) and misses mixed-case suffixes like '.Jpg'.
    if folder_path.is_dir():
        image_paths = sorted(
            p for p in folder_path.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
    else:
        image_paths = []

    print(f"📁 Found {len(image_paths)} images")

    results = []
    for img_path in image_paths:
        print(f"\n{'=' * 40}")
        result = analyze_brain_mri(img_path)
        if result is not None:
            results.append(result)

    return results

# Uncomment to analyze multiple images
# results = analyze_multiple_images("/path/to/your/image/folder")

# Cell 9: Utility functions
def save_results(result, output_dir="gradcam_results"):
    """Write the images and prediction summary of one analysis to disk.

    Creates ``output_dir`` (and any missing parents) and writes
    ``original.png``, ``heatmap_overlay.png``, ``heatmap_only.png`` and
    ``prediction_info.json``; existing files with these names are overwritten.

    Args:
        result: Dict returned by ``predict_with_gradcam``.
        output_dir: Destination folder.
    """
    output_path = Path(output_dir)
    # parents=True so nested destinations such as "runs/patient_01" work.
    output_path.mkdir(parents=True, exist_ok=True)

    for key, filename in (
        ('original_image', "original.png"),
        ('heatmap_overlay', "heatmap_overlay.png"),
        ('heatmap_only', "heatmap_only.png"),
    ):
        result[key].save(output_path / filename)

    # float() so numpy scalars serialise as plain JSON numbers.
    info = {
        'predicted_class': result['predicted_class'],
        'confidence': float(result['confidence']),
        'probabilities': {
            name: float(prob) for name, prob in zip(class_names, result['probabilities'])
        },
    }

    with open(output_path / "prediction_info.json", 'w', encoding='utf-8') as f:
        json.dump(info, f, indent=2)

    print(f"💾 Results saved to {output_path}/")

def compare_predictions(image_paths):
    """Plot one image / heatmap / overlay row per image in a single figure.

    Args:
        image_paths: Iterable of image paths. Rows whose prediction fails are
            left blank with a "Prediction failed" title.
    """
    image_paths = list(image_paths)
    if not image_paths:
        print("No images to compare.")
        return

    n_rows = len(image_paths)
    # squeeze=False always yields a 2-D (rows, 3) axes array, so a single
    # image needs no special case.
    _, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    for row, img_path in zip(axes, image_paths):
        result = predict_with_gradcam(img_path)
        if result is None:
            for ax in row:
                ax.axis('off')
            row[0].set_title(f"Prediction failed\n{Path(img_path).name}")
            continue

        panels = (
            (result['original_image'],
             f"Original\n{result['predicted_class']}\n{result['confidence']:.1f}%"),
            (result['heatmap_only'], "Heatmap"),
            (result['heatmap_overlay'], "Overlay"),
        )
        for ax, (image, title) in zip(row, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Cell 10: Quick test with sample workflow
print("🧠 Brain Tumor Grad-CAM Analysis Ready!")
print("📋 Available functions:")
print("  1. analyze_brain_mri(image_path) - Single image analysis")
print("  2. analyze_with_interactive_heatmap(image_path) - Multiple heatmap intensities")
print("  3. analyze_multiple_images(folder_path) - Batch analysis")
print("  4. save_results(result) - Save visualization results")
print("  5. compare_predictions(image_paths) - Side-by-side comparison")
print("\n🚀 To get started, run:")
print("   result = analyze_brain_mri('/path/to/your/mri/image.jpg')")

# Test with a sample (replace with your actual image path)
# sample_image = "/Users/kshitijverma/Downloads/sample_mri.jpg"
# if Path(sample_image).exists():
#     analyze_brain_mri(sample_image)
# else:
#     print("ℹ️  Replace 'sample_image' path with your actual MRI image path")