# Predictions Visualization

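Visualize ice-classification predictions on the test set:
- Inspect individual predictions with per-class confidence scores
- Compare predictions with ground-truth labels in a grid
- Show a simulated time series built from consecutive samples
- Plot a per-sample confidence heatmap
- Export all predictions to JSON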

In [None]:
import sys
sys.path.append('../training')

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.colors import ListedColormap
import seaborn as sns
from PIL import Image
from IPython.display import HTML

from train_ice_classifier import IceDataset, IceClassifier

print("Imports successful!")

## 1. Load Model and Data

In [None]:
# Load the trained classifier checkpoint and switch it to inference mode
MODEL_PATH = '../models/ice_classifier_resnet50.pth'
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

model = IceClassifier(num_classes=3, pretrained=False)
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed PyTorch's weights_only default -- only load trusted checkpoints.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE).eval()

# Held-out split used by every visualization below
test_dataset = IceDataset('../data/processed', 'test')

print("✅ Model loaded")
print("✅ Test dataset: {} samples".format(len(test_dataset)))

## 2. Visualize Single Prediction

In [None]:
def visualize_prediction(model, dataset, idx=0):
    """Show one sample's image, per-class confidence bars and a text summary.

    Args:
        model: classifier whose output is softmaxed over dim 1; the three class
            names below are assumed to match its output order.
        dataset: indexable returning ``(image_tensor, label)`` pairs, where the
            image is channel-first (it is permuted to HWC for display).
        idx: index of the sample to visualize.
    """
    labels = ['Open Water', 'Thin Ice', 'Thick Ice']
    palette = ['#3498db', '#e67e22', '#2ecc71']

    image, true_label = dataset[idx]

    # Single-image inference without gradient tracking.
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(DEVICE))
        scores = torch.softmax(logits, dim=1)[0]
        best = torch.argmax(scores).item()
    percent = scores.cpu().numpy() * 100

    fig = plt.figure(figsize=(16, 6))
    grid = fig.add_gridspec(1, 3, width_ratios=[1.2, 1, 1])

    # Left panel: the input image.
    ax_img = fig.add_subplot(grid[0])
    ax_img.imshow(image.permute(1, 2, 0).numpy())
    ax_img.set_title('Satellite Image', fontsize=14, fontweight='bold', pad=15)
    ax_img.axis('off')

    # Middle panel: horizontal confidence bars, predicted class emphasised.
    ax_bar = fig.add_subplot(grid[1])
    bar_patches = ax_bar.barh(labels, percent, color=palette, alpha=0.7, edgecolor='black')
    winner = bar_patches[best]
    winner.set_alpha(1.0)
    winner.set_linewidth(3)
    ax_bar.set_xlabel('Confidence (%)', fontsize=12)
    ax_bar.set_title('Predictions', fontsize=14, fontweight='bold', pad=15)
    ax_bar.set_xlim(0, 100)
    ax_bar.grid(axis='x', alpha=0.3)
    for patch, pct in zip(bar_patches, percent):
        y_mid = patch.get_y() + patch.get_height() / 2
        ax_bar.text(pct + 2, y_mid, f'{pct:.1f}%', va='center', fontsize=10)

    # Right panel: plain-text summary of the outcome.
    ax_txt = fig.add_subplot(grid[2])
    ax_txt.axis('off')
    status = '✅ Correct' if best == true_label else '❌ Wrong'
    summary = (
        "Result Summary\n"
        "    \n"
        f"True Label:\n{labels[true_label]}\n"
        "\n"
        f"Predicted:\n{labels[best]}\n"
        "\n"
        f"Confidence:\n{percent[best]:.1f}%\n"
        "\n"
        f"Status:\n{status}\n"
    )
    ax_txt.text(0.1, 0.5, summary, fontsize=12,
                verticalalignment='center', family='monospace',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

    plt.tight_layout()
    plt.show()

# Preview the first few test samples (at most three of them)
preview_count = min(3, len(test_dataset))
for sample_idx in range(preview_count):
    visualize_prediction(model, test_dataset, idx=sample_idx)

## 3. Batch Prediction Grid

In [None]:
def create_prediction_grid(model, dataset, num_samples=9,
                           save_path='../models/prediction_grid.png'):
    """Plot the first ``num_samples`` images with true vs. predicted labels.

    The grid is sized to the number of samples actually shown: roughly square,
    with 5x5 inches per cell. With the default 9 samples this is the original
    3x3 layout. Unused cells are hidden. Titles are green for correct
    predictions and red for mistakes.

    Args:
        model: classifier whose output is softmaxed over dim 1; the three class
            names below are assumed to match its output order.
        dataset: indexable returning ``(image_tensor, label)`` pairs, where the
            image is channel-first (it is permuted to HWC for display).
        num_samples: maximum number of samples to plot (clamped to len(dataset)).
        save_path: PNG destination for the figure; ``None`` skips saving.
    """
    class_names = ['Open Water', 'Thin Ice', 'Thick Ice']
    n = min(num_samples, len(dataset))
    if n <= 0:
        print("No samples to plot.")
        return

    # The grid used to be fixed at 3x3: num_samples > 9 raised IndexError and
    # small datasets left empty, framed axes behind.
    ncols = int(np.ceil(np.sqrt(n)))
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    axes = axes.flatten()

    for i in range(n):
        image, true_label = dataset[i]

        # Predict
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(DEVICE))
            probabilities = torch.softmax(output, dim=1)[0]
            predicted = torch.argmax(probabilities).item()
            confidence = probabilities[predicted].item() * 100

        # Display
        img_display = image.permute(1, 2, 0).numpy()
        axes[i].imshow(img_display)

        # Title with color coding
        color = 'green' if predicted == true_label else 'red'
        axes[i].set_title(f"True: {class_names[true_label]}\nPred: {class_names[predicted]} ({confidence:.0f}%)",
                          fontsize=10, color=color, fontweight='bold')
        axes[i].axis('off')

    # Hide any leftover cells in the last row.
    for ax in axes[n:]:
        ax.axis('off')

    plt.suptitle('Batch Predictions', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

create_prediction_grid(model=model, dataset=test_dataset)  # default: first 9 samples

## 4. Time Series Simulation

In [None]:
def create_time_series_plot(model, dataset, num_timesteps=10):
    """Show consecutive dataset samples as a simulated day-by-day monitoring strip.

    Samples are taken in dataset order and labeled "Day 1", "Day 2", and so on.
    Nothing here checks that they are actually consecutive observations. Each
    panel shows the image with its predicted class and confidence. The layout
    has at most five panels per row (2x5 for the default 10 steps), and unused
    panels are hidden.

    Args:
        model: classifier whose output is softmaxed over dim 1; the three class
            names below are assumed to match its output order.
        dataset: indexable returning ``(image_tensor, label)`` pairs, where the
            image is channel-first (it is permuted to HWC for display).
        num_timesteps: maximum number of samples to show (clamped to len(dataset)).
    """
    class_names = ['Open Water', 'Thin Ice', 'Thick Ice']
    n = min(num_timesteps, len(dataset))
    if n <= 0:
        print("No samples to plot.")
        return

    # The grid used to be fixed at 2x5: num_timesteps > 10 raised IndexError
    # and small datasets left empty, framed panels behind.
    ncols = min(5, n)
    nrows = int(np.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(3.6 * ncols, 4 * nrows), squeeze=False)
    axes = axes.flatten()

    for i in range(n):
        image, _ = dataset[i]

        # Predict
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(DEVICE))
            probabilities = torch.softmax(output, dim=1)[0]
            predicted = torch.argmax(probabilities).item()
            confidence = probabilities[predicted].item() * 100

        # Display
        img_display = image.permute(1, 2, 0).numpy()
        axes[i].imshow(img_display)
        axes[i].set_title(f"Day {i+1}\n{class_names[predicted]} ({confidence:.0f}%)",
                          fontsize=9, fontweight='bold')
        axes[i].axis('off')

    for ax in axes[n:]:
        ax.axis('off')

    # The title reflects the number of panels actually shown (was always "10-Day").
    plt.suptitle(f'Simulated {n}-Day Ice Monitoring', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

create_time_series_plot(model=model, dataset=test_dataset)  # default: 10 simulated days

## 5. Confidence Heatmap

In [None]:
def create_confidence_heatmap(model, dataset, num_samples=20,
                              save_path='../models/confidence_heatmap.png'):
    """Plot a heatmap of per-class softmax confidences (in %) for the first samples.

    Rows are samples (in dataset order) and columns are the class names below.

    Args:
        model: classifier whose output is softmaxed over dim 1; the three class
            names below are assumed to match its output order.
        dataset: indexable returning ``(image_tensor, label)`` pairs.
        num_samples: maximum number of rows (clamped to len(dataset)).
        save_path: PNG destination for the figure; ``None`` skips saving.
    """
    class_names = ['Open Water', 'Thin Ice', 'Thick Ice']
    n = min(num_samples, len(dataset))
    if n <= 0:
        # seaborn cannot draw an empty heatmap.
        print("No samples to plot.")
        return

    # Size the array by the samples actually evaluated. Previously it had
    # num_samples rows, so a smaller dataset produced all-zero fake rows.
    confidences = np.zeros((n, len(class_names)))

    for i in range(n):
        image, _ = dataset[i]

        with torch.no_grad():
            output = model(image.unsqueeze(0).to(DEVICE))
            probabilities = torch.softmax(output, dim=1)[0]
            confidences[i] = probabilities.cpu().numpy() * 100

    # Plot heatmap; height scales with the row count (12 in. for the default 20).
    plt.figure(figsize=(10, max(4, 0.6 * n)))
    sns.heatmap(confidences, annot=True, fmt='.1f', cmap='RdYlGn',
                xticklabels=class_names, yticklabels=[f'Sample {i+1}' for i in range(n)],
                cbar_kws={'label': 'Confidence (%)'})
    plt.title('Prediction Confidence Heatmap', fontsize=14, fontweight='bold', pad=15)
    plt.xlabel('Ice Type', fontsize=12)
    plt.ylabel('Sample', fontsize=12)
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

create_confidence_heatmap(model=model, dataset=test_dataset)  # default: first 20 samples

## 6. Export Predictions for Analysis

In [None]:
import json

def export_predictions(model, dataset, output_file='../models/predictions.json', device=None):
    """Run the model over every sample in ``dataset`` and write the results to JSON.

    Each record holds:
    - the sample index,
    - the true and predicted class names,
    - a correctness flag,
    - per-class softmax confidences in percent.

    Overall accuracy is printed afterwards. For an empty dataset it is
    reported as n/a instead of raising ZeroDivisionError.

    Args:
        model: classifier whose output for one unsqueezed image is indexed as
            (batch, class); softmax is taken over dim 1.
        dataset: indexable of ``(image_tensor, label)`` pairs whose integer-like
            labels index the three class names below.
        output_file: destination path for the JSON list of records.
        device: device the inputs are moved to; defaults to the notebook-level
            ``DEVICE`` when omitted.
    """
    class_names = ['Open Water', 'Thin Ice', 'Thick Ice']
    if device is None:
        device = DEVICE
    results = []

    for i in range(len(dataset)):
        image, true_label = dataset[i]
        # Normalize Python ints, NumPy ints and 0-d tensors to a plain int.
        true_label = int(true_label)

        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))
            probabilities = torch.softmax(output, dim=1)[0].cpu().numpy()
        predicted = int(np.argmax(probabilities))

        results.append({
            'sample_id': i,
            'true_label': class_names[true_label],
            'predicted_label': class_names[predicted],
            'correct': predicted == true_label,
            'confidences': {
                class_names[j]: float(probabilities[j] * 100)
                for j in range(len(class_names))
            }
        })

    # Save to JSON
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"✅ Exported {len(results)} predictions to: {output_file}")

    # Calculate accuracy (guarded: an empty dataset would divide by zero)
    if results:
        correct = sum(1 for r in results if r['correct'])
        accuracy = correct / len(results) * 100
        print(f"   Overall Accuracy: {accuracy:.2f}%")
    else:
        print("   Overall Accuracy: n/a (no samples)")

export_predictions(model=model, dataset=test_dataset)  # writes ../models/predictions.json

## Summary

Predictions visualization complete!

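**Outputs Created**:
- Individual prediction visualizations (displayed inline)
- Prediction grid (saved to `../models/prediction_grid.png`)
- Time series simulation (displayed inline)
- Confidence heatmap (saved to `../models/confidence_heatmap.png`)
- Exported predictions (JSON, saved to `../models/predictions.json`)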

**Next Steps**:
- Analyze prediction patterns
- Check `05_satellite_data_analysis.ipynb`
- Deploy model to production