Merged
294 changes: 294 additions & 0 deletions notebooks/06_explainability.ipynb
@@ -0,0 +1,294 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ClimateVision SHAP Explainability\n",
"\n",
"This notebook demonstrates how to use SHAP (SHapley Additive exPlanations) to understand\n",
"why the ClimateVision segmentation model makes specific predictions.\n",
"\n",
"**Author:** Linda Oraegbunam (@obielin)  \n",
"**Module:** `src/climatevision/governance/explainability.py`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '..')\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"\n",
"# ClimateVision imports\n",
"from climatevision.governance import explain_prediction, SHAPExplainer, get_band_contributions\n",
"from climatevision.inference.pipeline import _load_model\n",
"from climatevision.models import UNet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Understanding SHAP for Segmentation\n",
"\n",
"SHAP values tell us how much each input feature (here, each pixel value in each spectral band) contributed to the model's prediction.\n",
"For satellite imagery:\n",
"- **Positive SHAP**: Feature pushed prediction toward the target class\n",
"- **Negative SHAP**: Feature pushed prediction away from the target class\n",
"- **Magnitude**: Strength of the contribution"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the deforestation model\n",
"model, device = _load_model('deforestation')\n",
"print(f\"Model: {model.__class__.__name__}\")\n",
"print(f\"Input channels: {model.n_channels}\")\n",
"print(f\"Output classes: {model.n_classes}\")\n",
"print(f\"Device: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Create SHAP Explainer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the explainer with an all-zero 64x64 tile as background (baseline) data\n",
"background = torch.zeros(1, model.n_channels, 64, 64).to(device)\n",
"explainer = SHAPExplainer(model, background_data=background, device=device)\n",
"print(\"SHAP Explainer initialized\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Generate Explanation for Sample Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a synthetic forest-like image for demonstration\n",
"np.random.seed(42)\n",
"\n",
"# Simulate Sentinel-2 bands: Red, Green, Blue, NIR\n",
"# Forest typically has high NIR and low Red\n",
"h, w = 256, 256\n",
"red = np.random.normal(0.2, 0.1, (h, w)).clip(0, 1)  # Low red reflectance\n",
"green = np.random.normal(0.3, 0.1, (h, w)).clip(0, 1)\n",
"blue = np.random.normal(0.25, 0.1, (h, w)).clip(0, 1)\n",
"nir = np.random.normal(0.7, 0.15, (h, w)).clip(0, 1)  # High NIR for vegetation\n",
"\n",
"sample_image = np.stack([red, green, blue, nir], axis=0).astype(np.float32)\n",
"sample_tensor = torch.from_numpy(sample_image).unsqueeze(0).to(device)  # add batch dimension\n",
"\n",
"print(f\"Sample image shape: {sample_image.shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate SHAP explanation\n",
"explanation = explainer.explain(sample_tensor, target_class=1)  # Class 1 = Forest\n",
"\n",
"print(\"\\n=== Explanation Results ===\")\n",
"print(f\"Predicted class: {explanation['prediction']}\")\n",
"print(f\"Target class: {explanation['target_class']}\")\n",
"print(f\"Confidence: {explanation['confidence']:.4f}\")\n",
"print(f\"Explainer type: {explanation['explainer_type']}\")\n",
"print(\"\\nBand contributions:\")\n",
"for band, importance in explanation['band_contributions'].items():\n",
"    print(f\"  {band}: {importance:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Visualize Band Contributions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plot band importance\n",
"band_names = ['Red (B04)', 'Green (B03)', 'Blue (B02)', 'NIR (B08)']\n",
"contributions = explanation['band_contributions']\n",
"importances = [contributions[f'band_{i}'] for i in range(len(band_names))]\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"colors = ['#e74c3c', '#27ae60', '#3498db', '#9b59b6']\n",
"bars = ax.bar(band_names, importances, color=colors)\n",
"ax.set_ylabel('Relative Importance')\n",
"ax.set_title('Band Contributions to Forest Classification')\n",
"ax.set_ylim(0, max(importances) * 1.2)\n",
"\n",
"# Add value labels\n",
"for bar, imp in zip(bars, importances):\n",
"    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,\n",
"            f'{imp:.3f}', ha='center', va='bottom', fontsize=10)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Spatial Importance Heatmap"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize spatial importance\n",
"spatial_importance = explanation['spatial_importance']\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
"\n",
"# Original RGB composite\n",
"rgb = np.stack([sample_image[0], sample_image[1], sample_image[2]], axis=-1)\n",
"rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)\n",
"axes[0].imshow(rgb)\n",
"axes[0].set_title('RGB Composite')\n",
"axes[0].axis('off')\n",
"\n",
"# SHAP importance heatmap\n",
"im = axes[1].imshow(spatial_importance, cmap='hot')\n",
"axes[1].set_title('SHAP Importance Heatmap')\n",
"axes[1].axis('off')\n",
"plt.colorbar(im, ax=axes[1], fraction=0.046)\n",
"\n",
"# Overlay\n",
"axes[2].imshow(rgb)\n",
"axes[2].imshow(spatial_importance, cmap='hot', alpha=0.5)\n",
"axes[2].set_title('RGB + SHAP Overlay')\n",
"axes[2].axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Compare Explanations Across Analysis Types"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare band importance across different analysis types\n",
"analysis_types = ['deforestation', 'ice_melting', 'flooding']\n",
"all_contributions = {}\n",
"\n",
"for atype in analysis_types:\n",
"    try:\n",
"        model, device = _load_model(atype)\n",
"        explainer = SHAPExplainer(model, device=device)\n",
"\n",
"        # Create appropriate test tensor\n",
"        test_tensor = torch.randn(1, model.n_channels, 128, 128).to(device)\n",
"        result = explainer.explain(test_tensor)\n",
"        all_contributions[atype] = result['band_contributions']\n",
"        print(f\"{atype}: {len(result['band_contributions'])} bands analyzed\")\n",
"    except Exception as e:\n",
"        print(f\"{atype}: Failed - {e}\")\n",
"\n",
"print(\"\\nComparison complete!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Using the High-Level API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For real usage with saved images:\n",
"# result = explain_prediction(\n",
"# model_path='models/unet_deforestation.pth',\n",
"# image_path='data/test/amazon_tile.tif',\n",
"# analysis_type='deforestation',\n",
"# save_heatmap=True\n",
"# )\n",
"# print(f\"Top bands: {result['top_bands']}\")\n",
"# print(f\"Heatmap saved to: {result['heatmap_path']}\")\n",
"\n",
"print(\"See explain_prediction() for file-based explanations\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"This notebook demonstrated:\n",
"1. **SHAPExplainer** - Core class for generating explanations\n",
"2. **Band contributions** - Which spectral bands drive predictions\n",
"3. **Spatial importance** - Which image regions matter most\n",
"4. **Visualization** - Heatmaps and bar charts for stakeholder communication\n",
"\n",
"For production use, call the `/api/explain` endpoint or use `explain_prediction()` directly."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
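Note on the notebook above: it reads `band_contributions` and `spatial_importance` from the explanation dictionary, but this diff does not include `explainability.py`, so how `SHAPExplainer` computes them is not visible here. The following is a minimal sketch of one common way to derive both summaries from a raw attribution array of shape `(bands, height, width)`. The helper name `summarize_attributions` and the normalisation choice are assumptions for illustration, not the module's actual implementation; only the `band_{i}` key format is taken from the notebook.

```python
import numpy as np


def summarize_attributions(shap_values: np.ndarray) -> dict:
    """Collapse a (bands, height, width) attribution array into two summaries.

    Hypothetical helper: the real SHAPExplainer may aggregate differently.
    """
    if shap_values.ndim != 3:
        raise ValueError("expected an array of shape (bands, height, width)")

    magnitude = np.abs(shap_values)

    # Per-band score: mean absolute attribution over all pixels of that band.
    per_band = magnitude.mean(axis=(1, 2))
    total = per_band.sum()
    # Normalise so the scores sum to 1; guard against an all-zero attribution map.
    relative = per_band / total if total > 0 else np.zeros_like(per_band)

    # Per-pixel score: absolute attribution summed across bands.
    spatial = magnitude.sum(axis=0)

    return {
        "band_contributions": {f"band_{i}": float(v) for i, v in enumerate(relative)},
        "spatial_importance": spatial,
    }


# Synthetic attributions for 4 bands on an 8x8 tile; band 3 (NIR) is made dominant.
rng = np.random.default_rng(0)
fake_shap = rng.normal(0.0, 0.01, size=(4, 8, 8))
fake_shap[3] += 0.5

summary = summarize_attributions(fake_shap)
top_band = max(summary["band_contributions"], key=summary["band_contributions"].get)
print(top_band)
```

Taking absolute values discards the sign of each attribution, so a summary built this way shows how strongly a band influenced the prediction but not whether it pushed toward or away from the target class. If the sign matters, the signed mean per band would need to be reported alongside it.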
3 changes: 3 additions & 0 deletions requirements.txt
@@ -46,6 +46,9 @@ python-multipart>=0.0.5
mlflow>=2.1.0
optuna>=3.1.0

# Explainability & Governance
shap>=0.42.0

# Testing and Development
pytest>=7.0.0
pytest-cov>=3.0.0
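Note on the `shap>=0.42.0` pin in this hunk: environments built before the pin was added may still carry an older `shap`. A minimal sketch of a runtime check follows, assuming only the standard library. The helper names `parse_release`, `meets_minimum`, and `shap_is_recent_enough` are hypothetical, and the comparison deliberately looks at numeric release components only, so pre-release suffixes such as `rc1` are ignored.

```python
from importlib import metadata


def parse_release(version: str) -> tuple:
    """Return the leading numeric components of a version string as integers."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # component has no leading digits, e.g. "dev3"
        parts.append(int(digits))
        if digits != piece:
            break  # component carries a suffix, e.g. "0rc1"; stop here
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True when `installed` is at least `minimum`, comparing release numbers only."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "0.42" and "0.42.0" compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def shap_is_recent_enough(minimum: str = "0.42.0") -> bool:
    """Check the installed shap distribution against the pinned minimum."""
    try:
        installed = metadata.version("shap")
    except metadata.PackageNotFoundError:
        return False  # shap is not installed at all
    return meets_minimum(installed, minimum)


print(shap_is_recent_enough())
```

For stricter handling of pre-releases and local versions, a full version parser would be the safer choice; this sketch only covers plain dotted release numbers.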