# Explainable AI for Lung Cancer Classification
## Complete Pipeline Demo

This notebook demonstrates the complete end-to-end pipeline:
1. Load a CT scan image
2. Classify it with a ResNet-50 model
3. Generate a Grad-CAM heatmap
4. Generate a retrieval-augmented (RAG) textual explanation

---

## 1. Setup and Imports

In [None]:
# Add project root to path
import sys
sys.path.insert(0, '..')

# Standard imports
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# PyTorch
import torch

# Project imports
from src.utils.config import Config
from src.utils.helpers import set_seed, get_device
from src.pipeline import ExplainablePipeline, create_demo_visualization

# Set random seed for reproducibility
set_seed(42)

# Get device
device = get_device()

# Load configuration
config = Config()

print("Setup complete!")

## 2. Initialize the Pipeline

The `ExplainablePipeline` class combines:
- ResNet-50 classifier
- Grad-CAM visualization
- RAG-based explanation generator
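
Conceptually, the three stages run in sequence on one image: the classifier picks a class, Grad-CAM explains that class visually, and the explanation generator turns the class and heatmap into text. The sketch below shows that flow with plain callables standing in for the real components. `run_pipeline`, `PipelineOutput`, the stage signatures and the labels `class_a`/`class_b` are hypothetical illustrations, not the `ExplainablePipeline` API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

import numpy as np


@dataclass
class PipelineOutput:
    """Hypothetical container mirroring fields this notebook reads from `result`."""
    predicted_class: str
    confidence: float
    all_probabilities: Dict[str, float]
    heatmap: np.ndarray
    explanation: str


def run_pipeline(
    image: np.ndarray,
    class_names: List[str],
    classify: Callable[[np.ndarray], np.ndarray],       # image -> class probabilities
    grad_cam: Callable[[np.ndarray, int], np.ndarray],  # (image, class index) -> heatmap
    explain: Callable[[str, np.ndarray], str],          # (class name, heatmap) -> text
) -> PipelineOutput:
    # Stage 1: classification
    probs = classify(image)
    idx = int(np.argmax(probs))
    # Stage 2: visual explanation for the predicted class
    heatmap = grad_cam(image, idx)
    # Stage 3: textual explanation built from the class and the heatmap
    text = explain(class_names[idx], heatmap)
    return PipelineOutput(
        predicted_class=class_names[idx],
        confidence=float(probs[idx]),
        all_probabilities={c: float(p) for c, p in zip(class_names, probs)},
        heatmap=heatmap,
        explanation=text,
    )


# Usage with stub stages (illustrative values only)
out = run_pipeline(
    image=np.zeros((224, 224, 3)),
    class_names=["class_a", "class_b"],
    classify=lambda img: np.array([0.8, 0.2]),
    grad_cam=lambda img, k: np.ones((7, 7)),
    explain=lambda cls, hm: f"Predicted {cls}",
)
print(out.predicted_class, out.confidence)
```

Passing the stages in as callables keeps the orchestration independent of any particular model, which makes each stage easy to swap or test in isolation.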

In [None]:
# Initialize the explainable pipeline
# Use the trained checkpoint if it exists; otherwise pass None
checkpoint_path = os.path.join(config.checkpoint_dir, "best_model.pth")
if not os.path.exists(checkpoint_path):
    print(f"⚠️ No checkpoint found at {checkpoint_path}; initializing the pipeline without one.")
    checkpoint_path = None

pipeline = ExplainablePipeline(
    checkpoint_path=checkpoint_path,
    config=config,
    device=device
)

## 3. Find a Sample Image

Let's find a CT scan image from the dataset to analyze.

In [None]:
# Find sample images from the dataset
sample_images = {}

for class_name in config.class_names:
    class_dir = Path(config.dataset_dir) / class_name
    if class_dir.exists():
        # Sorted for a deterministic sample order; extension match is case-insensitive
        images = sorted(
            p for p in class_dir.iterdir()
            if p.suffix.lower() in {".png", ".jpg", ".jpeg"}
        )
        if images:
            sample_images[class_name] = images[:3]  # Get up to 3 samples per class
            print(f"{class_name}: {len(images)} images found")

if not sample_images:
    print("\n⚠️ No images found in dataset folder.")
    print(f"Please add images to the class subfolders of {config.dataset_dir}.")
else:
    print(f"\n✓ Found images in {len(sample_images)} classes")

## 4. Run Prediction on a Sample Image

Select an image and run the complete pipeline.

In [None]:
# Select a sample image (change this to analyze different images)
if sample_images:
    # Get first available image
    first_class = list(sample_images.keys())[0]
    image_path = str(sample_images[first_class][0])
    print(f"Selected image: {image_path}")
    print(f"True class: {first_class}")
else:
    # If no images, you can manually set a path
    image_path = None
    print("Set image_path manually to an image file")

In [None]:
# Run prediction
if image_path and os.path.exists(image_path):
    result = pipeline.predict(image_path)
    
    print("\n" + "=" * 50)
    print("PREDICTION RESULT")
    print("=" * 50)
    print(f"Predicted Class: {result.predicted_class.replace('_', ' ').title()}")
    print(f"Confidence: {result.confidence * 100:.2f}%")
    print("\nAll Probabilities:")
    for cls, prob in sorted(result.all_probabilities.items(), key=lambda x: -x[1]):
        print(f"  {cls:25s}: {prob*100:5.1f}%")
else:
    print("No image to process. Please add images to the dataset.")
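
The `confidence` and `all_probabilities` fields are presumably derived from a softmax over the classifier's outputs. A minimal NumPy sketch of that step, assuming the model returns one raw logit per class in the order of `config.class_names` (the function name `logits_to_prediction` is hypothetical):

```python
import numpy as np


def logits_to_prediction(logits, class_names):
    """Convert raw logits to (predicted_class, confidence, all_probabilities)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the max logit before exponentiating for numerical stability
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    idx = int(np.argmax(probs))
    all_probs = {name: float(p) for name, p in zip(class_names, probs)}
    return class_names[idx], float(probs[idx]), all_probs


cls, conf, probs = logits_to_prediction([2.0, 0.5, -1.0, 0.0], ["a", "b", "c", "d"])
print(cls, round(conf, 3))
```

Note that softmax confidence reflects how strongly the model prefers one class; it is not a calibrated probability of being correct.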

## 5. Visualize Grad-CAM Heatmap

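The Grad-CAM heatmap highlights the image regions that contributed most to the score of the predicted class; with the `jet` colormap, red marks the strongest contribution and blue the weakest.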

In [None]:
if 'result' in globals():
    # Create visualization
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Original image
    axes[0].imshow(result.original_image)
    axes[0].set_title('Original CT Image', fontsize=12)
    axes[0].axis('off')
    
    # Heatmap
    im = axes[1].imshow(result.heatmap, cmap='jet', vmin=0, vmax=1)
    axes[1].set_title('Grad-CAM Heatmap', fontsize=12)
    axes[1].axis('off')
    plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
    
    # Overlay
    axes[2].imshow(result.overlay)
    axes[2].set_title('Overlay', fontsize=12)
    axes[2].axis('off')
    
    plt.suptitle(f'Prediction: {result.predicted_class.replace("_", " ").title()} ({result.confidence*100:.1f}%)', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("Run prediction first.")
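
Grad-CAM weights each feature map of a convolutional layer by the spatially averaged gradient of the class score with respect to that map, sums the weighted maps, and keeps only positive evidence with a ReLU. The NumPy sketch below shows just that arithmetic, assuming the activations and gradients for one image have already been captured (for example with forward/backward hooks on ResNet-50's last convolutional block). It is not the project's implementation, and `grad_cam_from_arrays` is a hypothetical name.

```python
import numpy as np


def grad_cam_from_arrays(activations, gradients, eps=1e-8):
    """Compute a Grad-CAM map from one layer's activations and gradients.

    activations, gradients: arrays of shape (K, H, W) for a single image,
    where K is the number of channels in the chosen conv layer.
    Returns an (H, W) map scaled to [0, 1].
    """
    activations = np.asarray(activations, dtype=np.float64)
    gradients = np.asarray(gradients, dtype=np.float64)
    # Channel weights: global-average-pool the gradients over space
    weights = gradients.mean(axis=(1, 2))               # shape (K,)
    # Weighted sum of the feature maps over the channel axis
    cam = np.tensordot(weights, activations, axes=1)   # shape (H, W)
    # ReLU: keep only regions with a positive influence on the class score
    cam = np.maximum(cam, 0.0)
    # Normalize to [0, 1] so it matches imshow(..., vmin=0, vmax=1)
    cam_max = cam.max()
    return cam / (cam_max + eps) if cam_max > 0 else cam


rng = np.random.default_rng(0)
cam = grad_cam_from_arrays(rng.random((4, 7, 7)), rng.standard_normal((4, 7, 7)))
print(cam.shape)
```

For ResNet-50 at 224×224 input, the last convolutional block yields a 7×7 map, which is then upsampled to the image size before display and overlay.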

## 6. RAG-Based Explanation

The explanation combines:
- **Visual Evidence**: What the model focused on (from Grad-CAM)
- **Medical Context**: Relevant medical knowledge (from knowledge base)
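
One simple way to turn a heatmap into "visual evidence" text is to threshold it and describe where the high-attention pixels lie and how much of the image they cover. The sketch below is a hypothetical illustration (the function name `describe_heatmap`, the 0.5 threshold and the wording are assumptions), not the project's XAI→Text module. Positions are in image coordinates; in radiological convention the patient's left appears on the image's right.

```python
import numpy as np


def describe_heatmap(heatmap, threshold=0.5):
    """Summarize a [0, 1] heatmap as a short sentence about location and extent."""
    heatmap = np.asarray(heatmap, dtype=np.float64)
    mask = heatmap >= threshold
    if not mask.any():
        return "No region exceeded the attention threshold."
    coverage = mask.mean()  # fraction of pixels above the threshold
    rows, cols = np.nonzero(mask)
    h, w = heatmap.shape
    # Centroid of the high-attention pixels, mapped to coarse image quadrants
    vertical = "upper" if rows.mean() < h / 2 else "lower"
    horizontal = "left" if cols.mean() < w / 2 else "right"
    return (f"The model focused on the {vertical}-{horizontal} part of the image, "
            f"covering about {coverage * 100:.0f}% of the area.")


hm = np.zeros((10, 10))
hm[1:4, 6:9] = 0.9
print(describe_heatmap(hm))
```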

In [None]:
if 'result' in globals():
    # Print the full explanation
    result.print_explanation()
else:
    print("Run prediction first.")
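
The retrieval step can be illustrated with a toy knowledge base and a bag-of-words ranking: entries for the predicted class are scored by cosine similarity to a query built from the visual description, and the top entries are returned with their sources so they can be cited. Production RAG systems usually rank with dense embeddings instead. The entry fields (`content`, `source`) follow those printed in Section 9; `retrieve`, the scoring rule and the placeholder entries are hypothetical.

```python
import re
from collections import Counter
from math import sqrt


def _tokens(text):
    # Lowercase alphabetic tokens only
    return re.findall(r"[a-z]+", text.lower())


def _cosine(a, b):
    # Cosine similarity between two term-count vectors
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm = sqrt(sum(v * v for v in a.values())) * sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve(knowledge, class_name, query, k=2):
    """Rank one class's entries by bag-of-words cosine similarity to the query."""
    q = Counter(_tokens(query))
    entries = knowledge.get(class_name, [])
    ranked = sorted(entries, key=lambda e: _cosine(q, Counter(_tokens(e["content"]))),
                    reverse=True)
    return ranked[:k]


knowledge = {
    "class_a": [
        {"content": "Placeholder fact about peripheral lung findings.", "source": "Toy source 1"},
        {"content": "Placeholder fact about central airway findings.", "source": "Toy source 2"},
    ]
}
top = retrieve(knowledge, "class_a", "attention on a peripheral region of the lung", k=1)
for e in top:
    print(f"- {e['content']} [{e['source']}]")
```

Restricting retrieval to the predicted class keeps the explanation on topic; the visual description then only decides which of that class's facts come first.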

## 7. Comprehensive Demo Visualization

Combine the results above into a single figure and save it as `notebook_demo.png` in `config.results_dir`.

In [None]:
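if 'result' in globals():
    # Make sure the output directory exists before saving
    os.makedirs(config.results_dir, exist_ok=True)
    # Create comprehensive visualization
    fig = create_demo_visualization(
        result,
        save_path=os.path.join(config.results_dir, "notebook_demo.png")
    )
    plt.show()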
else:
    print("Run prediction first.")

## 8. Batch Processing (Multiple Images)

Run the pipeline on several images at once (up to two per class, capped at eight for this demo).

In [None]:
# Collect up to 2 sample images per class, capped at 8 for the demo
all_images = []
for class_name, images in sample_images.items():
    all_images.extend(str(img) for img in images[:2])
all_images = all_images[:8]

print(f"Processing {len(all_images)} images...")

if all_images:
    results = pipeline.predict_batch(all_images)
    
    # Create summary
    print("\n" + "=" * 60)
    print("BATCH RESULTS")
    print("=" * 60)
    for r in results:
        print(f"{Path(r.image_path).name:30s} → {r.predicted_class:25s} ({r.confidence*100:.1f}%)")
else:
    print("No images to process.")
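
Because the sample images are collected from `<dataset_dir>/<class_name>/` folders, each file's parent directory name gives its ground-truth label. A minimal sketch of scoring a batch this way, assuming `predicted_class` uses exactly the same strings as the folder names (the helper `batch_accuracy` is hypothetical and only reads `image_path` and `predicted_class`, the two attributes used above):

```python
from collections import namedtuple
from pathlib import Path


def batch_accuracy(results):
    """Return (accuracy, list of misclassified (filename, true, predicted))."""
    if not results:
        return 0.0, []
    errors = []
    for r in results:
        # Ground truth is the name of the folder containing the image
        true_label = Path(r.image_path).parent.name
        if r.predicted_class != true_label:
            errors.append((Path(r.image_path).name, true_label, r.predicted_class))
    return 1 - len(errors) / len(results), errors


# Usage with stand-in result objects (in the notebook, pass `results` directly)
Result = namedtuple("Result", ["image_path", "predicted_class"])
demo = [
    Result("dataset/class_a/img1.png", "class_a"),
    Result("dataset/class_b/img2.png", "class_a"),
]
acc, errs = batch_accuracy(demo)
print(f"Accuracy: {acc:.0%}")
for name, true, pred in errs:
    print(f"  {name}: true={true}, predicted={pred}")
```

On a handful of demo images this is only a sanity check, not an evaluation of the model.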

## 9. Explore the Knowledge Base

See what medical knowledge is available for each class.

In [None]:
from src.rag.knowledge_base import MedicalKnowledgeBase

kb = MedicalKnowledgeBase()

# Show knowledge for each class
for class_name in config.class_names:
    print(f"\n{'='*60}")
    print(f"KNOWLEDGE: {class_name.upper()}")
    print(f"{'='*60}")
    
    entries = kb.get_class_knowledge(class_name)
    for entry in entries[:2]:  # Show first 2 entries
        print(f"\n• {entry['content'][:200]}...")
        print(f"  Source: {entry['source']}")

---

## Summary

This notebook demonstrated the complete **Explainable AI for Lung Cancer Classification** pipeline:

| Component | Purpose |
|-----------|--------|
| ResNet-50 | Classification of CT images into 4 classes |
| Grad-CAM | Visual explanation of model attention |
| XAI→Text | Converts visual attention to textual description |
| Knowledge Base | Stores curated medical facts |
| RAG Pipeline | Retrieves relevant knowledge for explanation |

### Key Points for Viva:
1. **Transfer Learning**: ResNet-50 starts from ImageNet-pretrained weights, which typically improves generalization when labeled medical images are limited
2. **Explainability**: Grad-CAM shows WHERE, RAG explains WHY
3. **Novel Contribution**: Bridging visual XAI to textual explanations
4. **Citable**: All medical knowledge has sources