# Image Dehazing Pipeline Demo

This notebook demonstrates the complete image dehazing pipeline for the hazy winter conditions common in India.

## Features:
- Multiple dehazing models (AOD-Net, DehazeNet, MSBDN)
- Haze generation for testing
- Comprehensive evaluation metrics
- Hallucination detection
- Interactive web dashboard

## 1. Setup and Imports

In [None]:
# Install required packages if needed
# !pip install -r requirements.txt

import sys
import os
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Make the project importable. The imports below are package-qualified
# (`src.inference`, `evaluation.metrics`), so the project root must be on
# sys.path; putting only `<root>/src` there leaves `src.*` unresolvable.
# `<root>/src` is kept too, in case modules inside `src` import their
# siblings without the package prefix (NOTE(review): confirm whether needed).
# Assumes the notebook runs from a direct subdirectory of the project root.
PROJECT_ROOT = Path().absolute().parent
for _extra_path in (PROJECT_ROOT, PROJECT_ROOT / "src"):
    if str(_extra_path) not in sys.path:
        sys.path.append(str(_extra_path))

from src.inference import DehazeInferencePipeline
from src.haze_generator import HazeGenerator
from evaluation.metrics import DehazeMetrics
from evaluation.evaluator import DehazeEvaluator

print("✓ All imports successful")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 2. Initialize Models

In [None]:
# Build the inference pipeline, then load the models it provides.
pipeline = DehazeInferencePipeline()
print("Loading models...")
pipeline.initialize_models()

# Summarise what was loaded and which device inference will run on.
model_info = pipeline.get_model_info()
loaded_model_names = model_info['available_models']
print(f"Available models: {loaded_model_names}")
inference_device = model_info['device']
print(f"Device: {inference_device}")

## 3. Haze Generation (Bonus Feature)

In [None]:
# Initialize haze generator
haze_generator = HazeGenerator()

# Load a clear image (replace with your image path)
clear_image_path = "../data/sample_clear.jpg"  # Replace with actual path

# Generated hazy images are written here (section 4 reads them back), so make
# sure the directory exists before any save on a fresh checkout.
output_dir = "../data"
os.makedirs(output_dir, exist_ok=True)

# For demo, create a synthetic clear image if none is available
if not os.path.exists(clear_image_path):
    print("Creating sample clear image...")
    os.makedirs(os.path.dirname(clear_image_path) or ".", exist_ok=True)
    img_array = np.array(Image.new('RGB', (512, 512), color='skyblue'))
    # Two solid squares give the image some structure to dehaze and measure
    img_array[100:200, 100:200] = [255, 0, 0]  # Red square
    img_array[300:400, 300:400] = [0, 255, 0]  # Green square
    Image.fromarray(img_array).save(clear_image_path)
    print(f"Sample image saved to {clear_image_path}")

# Load the clear image once; later cells reuse `clear_img` as ground truth.
clear_img = Image.open(clear_image_path).convert('RGB')
clear_array = np.array(clear_img)

# Generate different types of haze
haze_types = ['light', 'moderate', 'heavy', 'extreme']
hazy_images = {}

for haze_type in haze_types:
    # Pass a copy so every level starts from the pristine clear pixels even if
    # the generator modifies its input in place (not visible from here).
    hazy_array = haze_generator.generate_composite_haze(clear_array.copy(), haze_type)
    hazy_images[haze_type] = Image.fromarray(hazy_array)

    # Save hazy image
    hazy_path = f"{output_dir}/hazy_{haze_type}.jpg"
    hazy_images[haze_type].save(hazy_path)
    print(f"Generated {haze_type} haze: {hazy_path}")

# Display the clear image followed by one panel per haze level
n_panels = len(haze_types) + 1
fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4))
panels = [('Clear Image', clear_img)] + [
    (f'{haze_type.capitalize()} Haze', hazy_images[haze_type]) for haze_type in haze_types
]
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 4. Image Dehazing

In [None]:
# Select a hazy image for dehazing
hazy_image_path = "../data/hazy_moderate.jpg"  # Use generated haze

# Load the hazy input unconditionally: the metrics cells below need `hazy_img`
# even when every model fails here. convert('RGB') forces the pixels to load
# and keeps the channel count consistent with `clear_img`.
hazy_img = Image.open(hazy_image_path).convert('RGB')

# Test all models
models = model_info['available_models']
dehazed_results = {}

print("Dehazing with different models...")
for model_name in models:
    print(f"Processing with {model_name}...")

    output_path = f"../data/dehazed_{model_name}.jpg"
    result = pipeline.dehaze_single_image(hazy_image_path, model_name, output_path)

    if not result['success']:
        print(f"  ✗ Failed: {result.get('error', 'unknown error')}")
        continue

    dehazed_results[model_name] = {
        'image': Image.open(output_path).convert('RGB'),
        'result': result,
    }
    print(f"  ✓ Success - Time: {result['processing_time']:.2f}s")

    # Metrics may be missing (e.g. no reference available). Skip only absent
    # values: a truthiness test would also hide a genuine 0.0 score.
    reported = result.get('metrics') or {}
    if reported.get('psnr') is not None:
        print(f"    PSNR: {reported['psnr']:.2f} dB")
    if reported.get('ssim') is not None:
        print(f"    SSIM: {reported['ssim']:.4f}")

# Display comparison: clear reference, hazy input, then one panel per model.
if dehazed_results:
    fig, axes = plt.subplots(1, len(models) + 2, figsize=(15, 3))

    panels = [('Original Clear', clear_img), ('Hazy Input', hazy_img)]
    for model_name in models:
        if model_name in dehazed_results:
            panels.append((model_name.upper(), dehazed_results[model_name]['image']))
        else:
            # Keep the slot so panel order matches `models`, but label it.
            panels.append((f'{model_name.upper()} (failed)', None))

    for ax, (title, image) in zip(axes, panels):
        if image is not None:
            ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')  # also hides the empty frame of a failed model

    plt.tight_layout()
    plt.show()

## 5. Metrics Evaluation

In [None]:
# Initialize metrics calculator
metrics_calc = DehazeMetrics()

# Calculate comprehensive metrics for each model
from torchvision import transforms

to_tensor = transforms.ToTensor()
clear_tensor = to_tensor(clear_img)
hazy_tensor = to_tensor(hazy_img)

print("Metrics Comparison:")
print("-" * 60)
print(f"{'Model':<12} {'PSNR':<8} {'SSIM':<8} {'MAE':<8} {'Time(s)':<8}")
print("-" * 60)

# Hazy baseline metrics: how far the untouched hazy input is from the reference
hazy_metrics = metrics_calc.calculate_all_metrics(hazy_tensor, clear_tensor)
print(f"{'Hazy':<12} {hazy_metrics['psnr']:<8.2f} {hazy_metrics['ssim']:<8.4f} {hazy_metrics['mae']:<8.4f} {'N/A':<8}")

# Model results. Each model's metrics are computed once here and reused by
# the bar charts below instead of being recalculated.
model_metrics = {}
for model_name, data in dehazed_results.items():
    dehazed_tensor = to_tensor(data['image'])
    metrics = metrics_calc.calculate_all_metrics(dehazed_tensor, clear_tensor)
    model_metrics[model_name] = metrics
    time_val = data['result']['processing_time']

    print(f"{model_name:<12} {metrics['psnr']:<8.2f} {metrics['ssim']:<8.4f} {metrics['mae']:<8.4f} {time_val:<8.3f}")

print()

# Visual comparison of metrics (only meaningful with two or more models)
if len(model_metrics) > 1:
    models_list = list(model_metrics)
    psnr_values = [model_metrics[name]['psnr'] for name in models_list]
    ssim_values = [model_metrics[name]['ssim'] for name in models_list]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.bar(models_list, psnr_values)
    ax1.set_title('PSNR Comparison')
    ax1.set_ylabel('PSNR (dB)')

    ax2.bar(models_list, ssim_values)
    ax2.set_title('SSIM Comparison')
    ax2.set_ylabel('SSIM')

    plt.tight_layout()
    plt.show()

## 6. Hallucination Detection

In [None]:
from evaluation.metrics import HallucinationDetector

# Compare every dehazed output with the hazy input and the clear reference,
# then report the detector's indicators and a coarse risk bucket.
hallucination_detector = HallucinationDetector()

print("Hallucination Analysis:")
print("-" * 50)

# (printed label, indicator key) pairs, in display order.
indicator_rows = (
    ("Edge Preservation Improvement", 'edge_preservation_improvement'),
    ("Texture Similarity Change", 'texture_similarity_change'),
    ("Color Consistency", 'color_consistency'),
)

for model_name, data in dehazed_results.items():
    output_tensor = transforms.ToTensor()(data['image'])
    indicators = hallucination_detector.detect_structure_changes(
        hazy_tensor, output_tensor, clear_tensor
    )

    print(f"\n{model_name.upper()}:")
    for label, key in indicator_rows:
        print(f"  {label}: {indicators[key]:.4f}")
    score = indicators['hallucination_score']
    print(f"  Hallucination Score: {score:.4f} (lower is better)")

    # Buckets: below 0.3 is low, below 0.6 moderate, anything else high.
    risk_level = "Low" if score < 0.3 else ("Moderate" if score < 0.6 else "High")
    assessment = f"{risk_level} hallucination risk"

    print(f"  Assessment: {assessment}")

## 7. Web Dashboard Integration

In [None]:
# Information about web dashboard
print("Web Dashboard Information:")
print("-" * 30)
print("To start the interactive web dashboard:")
print()
print("1. From command line:")
print("   python main.py web")
print()
print("2. Or directly with Streamlit:")
print("   streamlit run web/app.py")
print()
print("3. Then open: http://localhost:8501")
print()

# Dashboard capabilities, printed as a bulleted list (one string per feature).
DASHBOARD_FEATURES = (
    "Upload hazy images",
    "Choose dehazing model",
    "Real-time processing",
    "Download results",
    "Model comparison",
    "Batch processing",
    "Metrics display",
)

print("Features:")
for feature in DASHBOARD_FEATURES:
    print(f"- {feature}")

## 8. Complete Pipeline Demo

In [None]:
# Complete demo: Clear -> Haze -> Dehaze
print("Complete Pipeline Demo")
print("=" * 40)

# Step 1: Start with clear image
print("Step 1: Original clear image")
display(clear_img)

# Step 2: Add haze
print("\nStep 2: Adding moderate haze...")
hazy_array = haze_generator.generate_composite_haze(np.array(clear_img), 'moderate')
hazy_demo = Image.fromarray(hazy_array)
display(hazy_demo)

# Step 3: choose the model with the best PSNR from section 4, then run it on
# the image hazed in Step 2 so the demo really goes Clear -> Haze -> Dehaze.
print("\nStep 3: Dehazing with best performing model...")

# Score each successful section-4 result once against the clear reference
to_tensor = transforms.ToTensor()
psnr_by_model = {
    model_name: metrics_calc.calculate_all_metrics(to_tensor(data['image']), clear_tensor)['psnr']
    for model_name, data in dehazed_results.items()
}
# max() keeps the first model on ties, matching a strict ">" scan.
best_model = max(psnr_by_model, key=psnr_by_model.get) if psnr_by_model else None

demo_succeeded = False
if best_model is None:
    print("No successful dehazing results found")
else:
    best_psnr = psnr_by_model[best_model]
    print(f"Best model: {best_model} (PSNR: {best_psnr:.2f} dB)")

    # The pipeline works on file paths, so persist the Step-2 image first.
    demo_hazy_path = "../data/hazy_demo.jpg"
    demo_output_path = f"../data/dehazed_demo_{best_model}.jpg"
    hazy_demo.save(demo_hazy_path)
    demo_result = pipeline.dehaze_single_image(demo_hazy_path, best_model, demo_output_path)

    if demo_result['success']:
        demo_succeeded = True
        dehazed_demo = Image.open(demo_output_path).convert('RGB')
        display(dehazed_demo)

        # Final metrics for the Step-2 input against the clear reference
        final_metrics = metrics_calc.calculate_all_metrics(to_tensor(dehazed_demo), clear_tensor)

        print("\nFinal Results:")
        print(f"  PSNR: {final_metrics['psnr']:.2f} dB")
        print(f"  SSIM: {final_metrics['ssim']:.4f}")
        print(f"  MAE: {final_metrics['mae']:.4f}")
        print(f"  Processing time: {demo_result['processing_time']:.3f}s")
    else:
        print(f"Dehazing failed: {demo_result.get('error', 'unknown error')}")

print("\n" + "=" * 40)
print("Demo completed successfully!" if demo_succeeded else "Demo finished without a dehazed result.")
print("\nNext steps:")
print("1. Try with your own images")
print("2. Experiment with different haze types")
print("3. Use the web dashboard for interactive processing")
print("4. Run full evaluation on NTIRE datasets")