374 changes: 374 additions & 0 deletions notebooks/07_bias_audit.ipynb
@@ -0,0 +1,374 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ClimateVision Regional Bias Audit\n",
"\n",
"This notebook demonstrates how to evaluate model fairness across geographic regions.\n",
"Ensuring equitable predictions is critical for NGOs operating in different parts of the world.\n",
"\n",
"**Author:** Linda Oraegbunam (@obielin) \n",
"**Module:** `src/climatevision/governance/bias_audit.py`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0, '../src')  # package lives under src/\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from pathlib import Path\n",
"\n",
"from climatevision.governance import (\n",
" run_bias_audit,\n",
" BiasAuditor,\n",
" BiasReport,\n",
" check_fairness_gate,\n",
" SUPPORTED_REGIONS,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Understanding Regional Bias\n",
"\n",
"Climate models trained primarily on Amazon data may underperform on Congo Basin imagery due to:\n",
"- Different forest types and canopy structures\n",
"- Varying cloud patterns and seasonal effects\n",
"- Different satellite viewing angles and atmospheric conditions\n",
"\n",
"This audit ensures NGOs in all regions receive equally reliable predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# View supported regions\n",
"print(\"Supported Regions for Bias Audit:\")\n",
"print(\"=\" * 50)\n",
"for key, info in SUPPORTED_REGIONS.items():\n",
" print(f\"\\n{info['name']} ({key})\")\n",
" print(f\" Bounding Box: {info['bbox']}\")\n",
" print(f\" Description: {info['description']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Creating a Bias Auditor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create auditor with 85% fairness threshold\n",
"auditor = BiasAuditor(model=None, threshold=0.85)\n",
"\n",
"# Simulate regional prediction data\n",
"# In production, this would be real model outputs on test sets\n",
"np.random.seed(42)\n",
"\n",
"regions_data = {\n",
" 'amazon': {'accuracy': 0.92, 'forest_ratio': 0.70},\n",
" 'congo': {'accuracy': 0.85, 'forest_ratio': 0.65},\n",
" 'southeast_asia': {'accuracy': 0.88, 'forest_ratio': 0.55},\n",
"}\n",
"\n",
"for region, params in regions_data.items():\n",
" n_samples = 1000\n",
" \n",
" # Ground truth based on regional forest coverage\n",
" ground_truth = (np.random.random(n_samples) < params['forest_ratio']).astype(int)\n",
" \n",
" # Predictions based on regional accuracy\n",
" correct = np.random.random(n_samples) < params['accuracy']\n",
" predictions = np.where(correct, ground_truth, 1 - ground_truth)\n",
" \n",
" auditor.add_region_data(region, predictions, ground_truth)\n",
" print(f\"Added {n_samples} samples for {region}\")"
]
},
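{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the simulation recipe above, the cell below verifies that flipping labels wherever `correct` is `False` yields an empirical accuracy near the target rate. This is illustrative only (fresh RNG, hypothetical target), not part of `BiasAuditor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (not part of BiasAuditor): with label flipping,\n",
"# empirical accuracy should land near the target rate, within sampling noise.\n",
"rng = np.random.default_rng(0)\n",
"truth = (rng.random(1000) < 0.70).astype(int)\n",
"correct = rng.random(1000) < 0.85  # hypothetical 85% target accuracy\n",
"preds = np.where(correct, truth, 1 - truth)\n",
"print(f\"Empirical accuracy: {(preds == truth).mean():.3f} (target 0.850)\")"
]
},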
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Computing Fairness Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run full bias audit\n",
"report = auditor.run_audit(\n",
" metric='equalized_odds',\n",
" model_path='models/demo_model.pth',\n",
" model_version='v1.0-demo',\n",
" analysis_type='deforestation',\n",
")\n",
"\n",
"print(f\"Fairness Score: {report.fairness_score:.4f}\")\n",
"print(f\"Threshold: {report.threshold}\")\n",
"print(f\"Passed: {'✅' if report.passed else '❌'}\")\n",
"print(f\"\\nDisparity Regions: {report.disparity_regions or 'None'}\")"
]
},
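{
"cell_type": "markdown",
"metadata": {},
"source": [
"The equalized-odds score above is computed inside `BiasAuditor`. The cell below sketches the underlying idea by hand with hypothetical per-region rates (illustrative numbers, not audit output): equalized odds asks that true-positive and false-positive rates agree across regions, so the worst-case gaps summarize the disparity."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hand-rolled equalized-odds sketch (hypothetical rates, not BiasAuditor internals).\n",
"tpr = {'amazon': 0.93, 'congo': 0.86, 'southeast_asia': 0.89}\n",
"fpr = {'amazon': 0.08, 'congo': 0.15, 'southeast_asia': 0.12}\n",
"tpr_gap = max(tpr.values()) - min(tpr.values())\n",
"fpr_gap = max(fpr.values()) - min(fpr.values())\n",
"# A simple scalar summary: 1 minus the larger of the two gaps.\n",
"print(f\"TPR gap: {tpr_gap:.2f}, FPR gap: {fpr_gap:.2f}, score: {1 - max(tpr_gap, fpr_gap):.2f}\")"
]
},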
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# View per-region metrics\n",
"print(\"Per-Region Metrics:\")\n",
"print(\"=\" * 60)\n",
"\n",
"for metrics in report.region_metrics:\n",
" print(f\"\\n{metrics.region_name} ({metrics.region}):\")\n",
" print(f\" Samples: {metrics.n_samples}\")\n",
" print(f\" IoU: {metrics.iou:.4f}\")\n",
" print(f\" F1: {metrics.f1:.4f}\")\n",
" print(f\" Precision: {metrics.precision:.4f}\")\n",
" print(f\" Recall: {metrics.recall:.4f}\")\n",
" print(f\" TPR: {metrics.true_positive_rate:.4f}\")\n",
" print(f\" FPR: {metrics.false_positive_rate:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Visualizing Regional Disparities"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prepare data for visualization\n",
"regions = [m.region_name for m in report.region_metrics]\n",
"ious = [m.iou for m in report.region_metrics]\n",
"f1s = [m.f1 for m in report.region_metrics]\n",
"tprs = [m.true_positive_rate for m in report.region_metrics]\n",
"\n",
"x = np.arange(len(regions))\n",
"width = 0.25\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"ax.bar(x - width, ious, width, label='IoU', color='#3498db')\n",
"ax.bar(x, f1s, width, label='F1 Score', color='#2ecc71')\n",
"ax.bar(x + width, tprs, width, label='True Positive Rate', color='#e74c3c')\n",
"\n",
"ax.set_ylabel('Score')\n",
"ax.set_title('Model Performance by Region')\n",
"ax.set_xticks(x)\n",
"ax.set_xticklabels(regions)\n",
"ax.set_ylim(0, 1.1)\n",
"ax.axhline(y=0.85, color='gray', linestyle='--', label='Threshold')\n",
"ax.legend()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Radar chart for multi-metric comparison\n",
"from math import pi\n",
"\n",
"categories = ['IoU', 'F1', 'Precision', 'Recall', 'TPR']\n",
"N = len(categories)\n",
"\n",
"angles = [n / float(N) * 2 * pi for n in range(N)]\n",
"angles += angles[:1]\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))\n",
"\n",
"colors = ['#3498db', '#2ecc71', '#e74c3c']\n",
"for i, metrics in enumerate(report.region_metrics):\n",
" values = [metrics.iou, metrics.f1, metrics.precision, metrics.recall, metrics.true_positive_rate]\n",
" values += values[:1]\n",
" ax.plot(angles, values, 'o-', linewidth=2, label=metrics.region_name, color=colors[i % len(colors)])\n",
" ax.fill(angles, values, alpha=0.25, color=colors[i % len(colors)])\n",
"\n",
"ax.set_xticks(angles[:-1])\n",
"ax.set_xticklabels(categories)\n",
"ax.set_ylim(0, 1)\n",
"ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))\n",
"ax.set_title('Regional Performance Comparison', y=1.08)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Comparing Fairness Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare different fairness metrics\n",
"metrics_to_test = ['demographic_parity', 'equalized_odds', 'predictive_parity']\n",
"results = {}\n",
"\n",
"for metric in metrics_to_test:\n",
" metric_report = auditor.run_audit(metric=metric)\n",
" results[metric] = {\n",
" 'score': metric_report.fairness_score,\n",
" 'passed': metric_report.passed,\n",
" 'disparity_regions': metric_report.disparity_regions,\n",
" }\n",
"\n",
"print(\"Fairness Metrics Comparison:\")\n",
"print(\"=\" * 50)\n",
"for metric, result in results.items():\n",
" status = '✅' if result['passed'] else '❌'\n",
" print(f\"\\n{metric}:\")\n",
" print(f\" Score: {result['score']:.4f} {status}\")\n",
" if result['disparity_regions']:\n",
" print(f\" Disparity in: {', '.join(result['disparity_regions'])}\")"
]
},
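{
"cell_type": "markdown",
"metadata": {},
"source": [
"For contrast with equalized odds, the cell below hand-computes a demographic-parity style score from hypothetical positive-prediction rates (illustrative numbers, not `BiasAuditor` internals): the score is the ratio of the lowest to the highest rate, so 1.0 means every region is flagged at the same rate."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Demographic-parity sketch (hypothetical rates, not BiasAuditor internals).\n",
"# Demographic parity compares how often each region receives a positive prediction.\n",
"pos_rate = {'amazon': 0.68, 'congo': 0.61, 'southeast_asia': 0.54}\n",
"dp_score = min(pos_rate.values()) / max(pos_rate.values())\n",
"print(f\"Demographic parity ratio: {dp_score:.4f}\")  # 1.0 = perfectly equal rates"
]
},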
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Using the High-Level API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# For real usage with trained models:\n",
"# result = run_bias_audit(\n",
"# model_path='models/unet_deforestation.pth',\n",
"# regions=['amazon', 'congo', 'southeast_asia'],\n",
"# metric='equalized_odds',\n",
"# threshold=0.85,\n",
"# )\n",
"# \n",
"# print(f\"Score: {result['score']}\")\n",
"# print(f\"Passed: {result['passed']}\")\n",
"# print(f\"Report: {result['report_path']}\")\n",
"\n",
"print(\"See run_bias_audit() for production usage\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. CI/CD Integration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# CI gate function for automated checks\n",
"# This would be called in GitHub Actions or similar\n",
"\n",
"# passed = check_fairness_gate(\n",
"# model_path='models/best_model.pth',\n",
"# regions=['amazon', 'congo', 'southeast_asia'],\n",
"# threshold=0.85,\n",
"# )\n",
"# \n",
"# if not passed:\n",
"# sys.exit(1) # Fail the CI build\n",
"\n",
"print(\"Use check_fairness_gate() in CI/CD pipelines\")\n",
"print(\"Command: python scripts/audit_model.py --model models/best.pth --ci-gate\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Recommendations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get recommendations from the audit\n",
"print(\"Recommendations:\")\n",
"print(\"=\" * 50)\n",
"for rec in report.recommendations:\n",
" print(f\"\\n• {rec}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"This notebook demonstrated:\n",
"\n",
"1. **BiasAuditor** - Core class for fairness evaluation\n",
"2. **Fairness Metrics** - Demographic parity, equalized odds, predictive parity\n",
"3. **Regional Analysis** - Per-region IoU, F1, precision, recall\n",
"4. **Visualization** - Bar charts and radar plots for stakeholder reports\n",
"5. **CI/CD Integration** - `check_fairness_gate()` for automated checks\n",
"\n",
"For production use:\n",
"- Run `python scripts/audit_model.py --model <path> --regions amazon,congo`\n",
"- Add `--ci-gate` flag to fail builds with poor fairness scores"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}