# Sharks from Space - Enhanced Visualization Demo

This notebook demonstrates the enhanced visualization capabilities of the Sharks from Space project, including:

- Time series visualization of habitat predictions
- Multi-model comparison
- Interactive map generation
- Statistical analysis of predictions

## Prerequisites

Make sure you have run the enhanced pipeline:
```bash
make all-enhanced  # For full pipeline with enhanced visualization
# or
make demo-enhanced  # For demo with time series
```


In [None]:
import os
import sys
import json
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from PIL import Image
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# Add src to path
# NOTE(review): the relative path is resolved against the current working
# directory, so this presumably expects the notebook to be run from a folder
# that sits next to src/ -- confirm.
sys.path.append('../src')
# Must come after the sys.path change above, otherwise `utils` is not importable.
from utils import load_config, setup_logging

# Set up plotting style
# NOTE(review): the 'seaborn-v0_8' style name is only known to newer
# matplotlib releases -- confirm the minimum supported version.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Load configuration
# `config` and `logger` are module-level names available to later cells.
config = load_config('../config/params.yaml')
logger = setup_logging()

# Banner so the cell output shows that setup completed.
print("Enhanced Visualization Demo - Sharks from Space")
print("=" * 50)


## 1. Discover Available Prediction Files


In [None]:
def discover_prediction_files(data_dir='../web/data'):
    """Discover available prediction files and organize by model and date.

    Files are expected to be named ``habitat_prob_{model}_{date}.png``. The
    model part is matched against the known model names rather than split on
    underscores, so multi-word names such as ``random_forest`` are recognised.

    Args:
        data_dir: Directory to scan for prediction PNG files.

    Returns:
        A dict mapping each known model name ('xgboost', 'random_forest',
        'lightgbm') to a list of ``(date_str, file_path)`` tuples sorted by
        ``date_str``. Models without any matching file map to an empty list;
        files for unknown models are ignored.
    """
    files = {
        'xgboost': [],
        'random_forest': [],
        'lightgbm': []
    }
    prefix, suffix = 'habitat_prob_', '.png'

    # Try longer names first so that a model whose name is a prefix of
    # another model's name can never shadow the longer one.
    known_models = sorted(files, key=len, reverse=True)

    # Look for PNG files
    for file_path in glob.glob(os.path.join(data_dir, f'{prefix}*{suffix}')):
        # The glob pattern guarantees the prefix and suffix, so slice them off
        # instead of replacing every occurrence inside the name.
        stem = os.path.basename(file_path)[len(prefix):-len(suffix)]
        for model in known_models:
            if stem.startswith(model + '_'):
                # The date is the first token after the model name; any
                # trailing tokens are ignored.
                date_str = stem[len(model) + 1:].split('_')[0]
                files[model].append((date_str, file_path))
                break

    # Sort by date
    for entries in files.values():
        entries.sort(key=lambda entry: entry[0])

    return files

# Scan the data directory and keep only models that have at least one file
prediction_files = discover_prediction_files()
available_models = [name for name, entries in prediction_files.items() if entries]

# Summarise what was found: file count and covered date span per model
print(f"Available models: {available_models}")
for model in available_models:
    model_entries = prediction_files[model]
    print(f"{model}: {len(model_entries)} prediction files")
    if not model_entries:
        continue
    dates = [entry[0] for entry in model_entries]
    print(f"  Date range: {dates[0]} to {dates[-1]}")


## 2. Load and Analyze Prediction Metadata


In [None]:
# Read the optional prediction metadata; `metadata` stays None when absent
metadata_path = '../web/data/prediction_metadata.json'
if not os.path.exists(metadata_path):
    print("No metadata file found")
    metadata = None
else:
    with open(metadata_path, 'r') as metadata_file:
        metadata = json.load(metadata_file)
    # Echo the parsed content so it can be inspected in the cell output
    print("Prediction Metadata:")
    print(json.dumps(metadata, indent=2))


## 3. Visualize Time Series of Predictions


In [None]:
def create_time_series_plot(files, model='xgboost'):
    """Show up to four evenly spaced prediction images for one model.

    The images are drawn in a 2x2 grid, each titled with its date string.
    When the model has fewer than four files, each file is shown once and
    the remaining panels are left blank.

    Args:
        files: Mapping of model name to a list of ``(date_str, file_path)``
            tuples, in the order they should be sampled.
        model: Key in ``files`` selecting which model's images to show.

    Returns:
        None. Prints a notice and returns early when ``model`` has no files;
        otherwise displays the figure via ``plt.show()``.
    """
    if model not in files or not files[model]:
        print(f"No files found for model: {model}")
        return

    dates = [item[0] for item in files[model]]
    file_paths = [item[1] for item in files[model]]

    # Create subplot for time series
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'Habitat Prediction Time Series - {model.title()}', fontsize=16)

    # Hide every panel's axes up front so unused panels stay blank.
    panels = list(axes.flat)
    for ax in panels:
        ax.axis('off')

    # Evenly spaced sample positions; np.unique drops the repeated indices
    # that integer truncation produces when there are fewer than 4 files.
    sample_indices = np.unique(np.linspace(0, len(file_paths) - 1, 4, dtype=int))
    for ax, idx in zip(panels, sample_indices):
        # Close the file handle as soon as the pixels have been handed over.
        with Image.open(file_paths[idx]) as img:
            ax.imshow(img)
        ax.set_title(f'{dates[idx]}')

    plt.tight_layout()
    plt.show()

# Render the sample grid for the first model that has prediction files, if any
first_model = next(iter(available_models), None)
if first_model is not None:
    create_time_series_plot(prediction_files, first_model)


## 4. Multi-Model Comparison


In [None]:
def compare_models(files):
    """Plot, per model, the dates for which prediction files are available.

    Only models with at least one entry in ``files`` are considered. One
    subplot is drawn per such model, with a marker at each available date.

    Args:
        files: Mapping of model name to a list of ``(date_str, file_path)``
            tuples, where ``date_str`` is formatted as ``YYYYMMDD``.

    Returns:
        None. Prints a notice and returns early when fewer than two models
        have files; otherwise displays the figure via ``plt.show()``.

    Raises:
        ValueError: If a date string does not match ``%Y%m%d``.
    """
    # Derive the models from the argument itself rather than from notebook
    # globals, so the function works on whatever mapping it is given.
    models = [model for model, entries in files.items() if entries]
    if len(models) < 2:
        print("Need at least 2 models for comparison")
        return

    # Union of all dates seen in any model (not the intersection).
    all_dates = set()
    for model in models:
        all_dates.update(item[0] for item in files[model])
    print(f"Found {len(all_dates)} unique dates across all models")

    # Create comparison plot; with two or more rows `axes` is always an
    # array of Axes, so no single-Axes special case is needed.
    fig, axes = plt.subplots(len(models), 1, figsize=(12, 4 * len(models)))
    fig.suptitle('Multi-Model Comparison', fontsize=16)

    for ax, model in zip(axes, models):
        # Plot available dates for this model
        date_objects = [datetime.strptime(item[0], '%Y%m%d') for item in files[model]]
        # All markers share y=1; only the x position (date) carries information.
        ax.scatter(date_objects, [1] * len(date_objects), label=f'{model} predictions', alpha=0.7)
        ax.set_title(f'{model.title()} Model - Available Predictions')
        ax.set_ylabel('Model')
        ax.set_xlabel('Date')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Compare models: draw one availability plot per model with prediction files.
# compare_models prints a notice and draws nothing when fewer than two models
# are available.
compare_models(prediction_files)


## 5. Interactive Map Generation
