# SwellSight Wave Analysis - Inference and Evaluation

## Interactive Inference Interface

In [None]:
class WaveInferenceEngine:
    """Interactive inference engine for wave analysis.

    Bundles a trained ``WaveAnalysisModel`` checkpoint with the image
    preprocessing pipeline, and offers single-image prediction plus a
    matplotlib visualization of the result.
    """

    def __init__(self, model_path, device='cuda'):
        """Build the preprocessing pipeline and load the checkpoint.

        Args:
            model_path: Path of the checkpoint file passed to ``torch.load``.
            device: Device on which the model and input tensors are placed.
        """
        self.device = device
        self.model = None
        # Resize to 768x768, convert to a tensor, then normalize each of the
        # three channels (the constants match the widely used ImageNet
        # statistics).
        self.transform = transforms.Compose([
            transforms.Resize((768, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Load model
        self.load_model(model_path)

    def load_model(self, model_path):
        """Load trained model from checkpoint.

        The checkpoint must provide ``'config'`` and ``'model_state_dict'``.
        ``'val_loss'`` is only reported and may be absent.

        Args:
            model_path: Path of the checkpoint file.
        """
        # NOTE: torch.load relies on pickle; only load checkpoints that come
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)

        # Recreate model from the stored configuration, then restore weights
        config = checkpoint['config']
        self.model = WaveAnalysisModel(config).to(self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        print(f"✅ Model loaded from {model_path}")

        # The validation loss is purely informational: a checkpoint that does
        # not carry it must not make loading fail.
        val_loss = checkpoint.get('val_loss')
        if isinstance(val_loss, dict) and 'total_loss' in val_loss:
            print(f"📊 Validation loss: {val_loss['total_loss']:.4f}")

    def predict_image(self, image):
        """Predict wave parameters from image.

        Args:
            image: A ``numpy.ndarray`` (converted through ``Image.fromarray``)
                or a PIL image. Non-RGB PIL images are converted to RGB.

        Returns:
            dict with keys ``'height_meters'``, ``'wave_type'``,
            ``'direction'``, ``'wave_type_confidence'``,
            ``'direction_confidence'``, ``'wave_type_probs'`` and
            ``'direction_probs'``. The two ``*_probs`` entries map each class
            name to its probability. All numeric values are Python floats.
        """
        # Preprocess image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Normalize() is configured for exactly three channels, so grayscale,
        # palette or RGBA inputs have to be converted first.
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            outputs = self.model(image_tensor)

        # Process outputs (batch of one: take the single entry)
        height = outputs['height'].cpu().item()
        wave_type_probs = outputs['wave_type_probs'].cpu().numpy()[0]
        direction_probs = outputs['direction_probs'].cpu().numpy()[0]

        # Get predictions
        wave_type_idx = int(np.argmax(wave_type_probs))
        direction_idx = int(np.argmax(direction_probs))

        # Convert NumPy scalars to plain floats so the result is uniform with
        # 'height_meters' and easy to serialize.
        return {
            'height_meters': height,
            'wave_type': WAVE_TYPES[wave_type_idx],
            'direction': DIRECTIONS[direction_idx],
            'wave_type_confidence': float(wave_type_probs[wave_type_idx]),
            'direction_confidence': float(direction_probs[direction_idx]),
            'wave_type_probs': {WAVE_TYPES[i]: float(prob) for i, prob in enumerate(wave_type_probs)},
            'direction_probs': {DIRECTIONS[i]: float(prob) for i, prob in enumerate(direction_probs)}
        }

    def visualize_prediction(self, image, prediction):
        """Visualize prediction results.

        Draws three panels: the input image, a bar chart of wave type
        probabilities and a bar chart of direction probabilities, with the
        most probable class highlighted in each chart.

        Args:
            image: Image to display with ``imshow``.
            prediction: Result dict as returned by ``predict_image``.
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Original image
        axes[0].imshow(image)
        axes[0].set_title('Input Image')
        axes[0].axis('off')

        # Wave type probabilities
        wave_types = list(prediction['wave_type_probs'].keys())
        wave_probs = list(prediction['wave_type_probs'].values())

        bars1 = axes[1].bar(wave_types, wave_probs, color='skyblue')
        axes[1].set_title('Wave Type Probabilities')
        axes[1].set_ylabel('Probability')
        axes[1].tick_params(axis='x', rotation=45)

        # Highlight predicted class
        max_idx = np.argmax(wave_probs)
        bars1[max_idx].set_color('orange')

        # Direction probabilities
        directions = list(prediction['direction_probs'].keys())
        dir_probs = list(prediction['direction_probs'].values())

        bars2 = axes[2].bar(directions, dir_probs, color='lightgreen')
        axes[2].set_title('Direction Probabilities')
        axes[2].set_ylabel('Probability')

        # Highlight predicted class
        max_idx = np.argmax(dir_probs)
        bars2[max_idx].set_color('red')

        # Add prediction text
        plt.suptitle(
            f"Prediction: {prediction['height_meters']:.2f}m, {prediction['wave_type']}, {prediction['direction']}\n"
            f"Confidence: Type {prediction['wave_type_confidence']:.3f}, Direction {prediction['direction_confidence']:.3f}",
            fontsize=14, fontweight='bold'
        )

        plt.tight_layout()
        plt.show()

# Build the inference engine from the best checkpoint saved during training
best_checkpoint_path = checkpoints_dir / 'best_model.pth'
inference_engine = WaveInferenceEngine(best_checkpoint_path, device=device)

In [None]:
# Interactive demo with synthetic samples
def demo_inference(num_samples=5):
    """Demo inference on synthetic samples.

    For each sample a synthetic image is generated, run through the
    module-level ``inference_engine``, compared against the generator's
    ground-truth parameters, and visualized.

    Args:
        num_samples: Number of synthetic samples to generate and evaluate.
    """
    print("🌊 SwellSight Wave Analysis Demo")
    print("=" * 40)

    # Generate test samples
    test_generator = SyntheticDataGenerator()

    for i in range(num_samples):
        print(f"\n📸 Sample {i+1}/{num_samples}")

        # Generate sample; the encoded labels are not needed here because the
        # comparison below uses the human-readable ``params``.
        image, _, params = test_generator.generate_sample()

        # Make prediction
        prediction = inference_engine.predict_image(image)

        # Show results
        print(f"Ground Truth: {params['height_meters']:.2f}m, {params['wave_type']}, {params['direction']}")
        print(f"Prediction:   {prediction['height_meters']:.2f}m, {prediction['wave_type']}, {prediction['direction']}")

        # Calculate errors
        height_error = abs(prediction['height_meters'] - params['height_meters'])
        type_correct = prediction['wave_type'] == params['wave_type']
        dir_correct = prediction['direction'] == params['direction']

        print(f"Height Error: {height_error:.3f}m")
        print(f"Type Correct: {'✅' if type_correct else '❌'}")
        print(f"Direction Correct: {'✅' if dir_correct else '❌'}")

        # Visualize
        inference_engine.visualize_prediction(image, prediction)

# Run the demo on three synthetic samples
demo_inference(num_samples=3)

## Model Evaluation and Metrics

In [None]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import seaborn as sns

def evaluate_model(model, test_loader, device):
    """Comprehensive model evaluation.

    Runs the model over every batch of ``test_loader`` without gradients and
    computes regression metrics for wave height and classification metrics
    for wave type and direction.

    Args:
        model: Model whose output dict provides ``'height'``,
            ``'wave_type_logits'`` and ``'direction_logits'``.
        test_loader: Iterable yielding ``(images, targets)`` where ``targets``
            is a dict with ``'height'``, ``'wave_type'`` and ``'direction'``.
        device: Device the images are moved to before the forward pass.

    Returns:
        Tuple ``(metrics, predictions)``. ``metrics`` holds ``'height_mae'``,
        ``'height_rmse'``, ``'type_accuracy'``, ``'type_f1'``,
        ``'direction_accuracy'`` and ``'direction_f1'``. ``predictions`` holds
        the flat NumPy arrays ``'height_preds'``, ``'height_true'``,
        ``'type_preds'``, ``'type_true'``, ``'dir_preds'`` and ``'dir_true'``.
    """
    model.eval()

    collected = {
        'height_preds': [],
        'height_true': [],
        'type_preds': [],
        'type_true': [],
        'dir_preds': [],
        'dir_true': []
    }

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc='Evaluating'):
            # Only the images are needed on the device; the targets are just
            # collected on the CPU, so they are not transferred.
            outputs = model(images.to(device))

            # Flatten every per-batch tensor to 1-D. Without this, a (B, 1)
            # height output compared against (B,) targets would broadcast to
            # an (N, N) error matrix and silently corrupt MAE/RMSE, and a 0-d
            # tensor could not be iterated by ``extend``.
            collected['height_preds'].extend(outputs['height'].reshape(-1).cpu().numpy())
            collected['height_true'].extend(targets['height'].reshape(-1).cpu().numpy())

            collected['type_preds'].extend(outputs['wave_type_logits'].argmax(dim=1).reshape(-1).cpu().numpy())
            collected['type_true'].extend(targets['wave_type'].reshape(-1).cpu().numpy())

            collected['dir_preds'].extend(outputs['direction_logits'].argmax(dim=1).reshape(-1).cpu().numpy())
            collected['dir_true'].extend(targets['direction'].reshape(-1).cpu().numpy())

    # Convert to numpy arrays
    predictions = {name: np.array(values) for name, values in collected.items()}

    height_errors = predictions['height_preds'] - predictions['height_true']

    # Calculate metrics
    metrics = {
        'height_mae': np.mean(np.abs(height_errors)),
        'height_rmse': np.sqrt(np.mean(height_errors ** 2)),
        'type_accuracy': accuracy_score(predictions['type_true'], predictions['type_preds']),
        'type_f1': f1_score(predictions['type_true'], predictions['type_preds'], average='weighted'),
        'direction_accuracy': accuracy_score(predictions['dir_true'], predictions['dir_preds']),
        'direction_f1': f1_score(predictions['dir_true'], predictions['dir_preds'], average='weighted')
    }

    return metrics, predictions

def plot_evaluation_results(metrics, predictions):
    """Plot comprehensive evaluation results.

    Draws a 2x3 grid: height scatter plot, height error histogram, wave type
    confusion matrix, direction confusion matrix, a textual metrics summary
    and the height MAE per 1 m range of true height.

    Args:
        metrics: Dict with ``'height_mae'``, ``'height_rmse'``,
            ``'type_accuracy'``, ``'type_f1'``, ``'direction_accuracy'`` and
            ``'direction_f1'``.
        predictions: Dict with the arrays ``'height_preds'``,
            ``'height_true'``, ``'type_preds'``, ``'type_true'``,
            ``'dir_preds'`` and ``'dir_true'``. Class arrays are expected to
            hold integer indices into ``WAVE_TYPES`` / ``DIRECTIONS``.
    """
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    height_true = np.asarray(predictions['height_true'], dtype=float).reshape(-1)
    height_preds = np.asarray(predictions['height_preds'], dtype=float).reshape(-1)

    # The ranges start at the historical 0-4 m and grow in 1 m steps so that
    # samples at or above 4 m are not silently dropped from the analysis.
    top_bin_edge = 4
    finite_true = height_true[np.isfinite(height_true)]
    if finite_true.size:
        top_bin_edge = max(top_bin_edge, int(np.floor(finite_true.max())) + 1)

    # The diagonal must span predictions as well as true values.
    diagonal_max = top_bin_edge
    finite_preds = height_preds[np.isfinite(height_preds)]
    if finite_preds.size:
        diagonal_max = max(diagonal_max, int(np.ceil(finite_preds.max())))

    # Height regression scatter plot
    axes[0, 0].scatter(height_true, height_preds, alpha=0.6)
    axes[0, 0].plot([0, diagonal_max], [0, diagonal_max], 'r--', label='Perfect prediction')
    axes[0, 0].set_xlabel('True Height (m)')
    axes[0, 0].set_ylabel('Predicted Height (m)')
    axes[0, 0].set_title(f'Height Prediction\nMAE: {metrics["height_mae"]:.3f}m, RMSE: {metrics["height_rmse"]:.3f}m')
    axes[0, 0].legend()
    axes[0, 0].grid(True)

    # Height error distribution
    height_errors = height_preds - height_true
    axes[0, 1].hist(height_errors, bins=30, alpha=0.7, color='skyblue')
    axes[0, 1].axvline(0, color='red', linestyle='--', label='Perfect prediction')
    axes[0, 1].set_xlabel('Prediction Error (m)')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('Height Prediction Error Distribution')
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Wave type confusion matrix. ``labels`` pins one row/column per known
    # class; otherwise a class absent from the data shrinks the matrix and the
    # tick labels no longer line up with the cells.
    type_cm = confusion_matrix(predictions['type_true'], predictions['type_preds'],
                               labels=list(range(len(WAVE_TYPES))))
    sns.heatmap(type_cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=WAVE_TYPES, yticklabels=WAVE_TYPES, ax=axes[0, 2])
    axes[0, 2].set_title(f'Wave Type Confusion Matrix\nAccuracy: {metrics["type_accuracy"]:.3f}, F1: {metrics["type_f1"]:.3f}')
    axes[0, 2].set_xlabel('Predicted')
    axes[0, 2].set_ylabel('True')

    # Direction confusion matrix (same label pinning as above)
    dir_cm = confusion_matrix(predictions['dir_true'], predictions['dir_preds'],
                              labels=list(range(len(DIRECTIONS))))
    sns.heatmap(dir_cm, annot=True, fmt='d', cmap='Greens',
                xticklabels=DIRECTIONS, yticklabels=DIRECTIONS, ax=axes[1, 0])
    axes[1, 0].set_title(f'Direction Confusion Matrix\nAccuracy: {metrics["direction_accuracy"]:.3f}, F1: {metrics["direction_f1"]:.3f}')
    axes[1, 0].set_xlabel('Predicted')
    axes[1, 0].set_ylabel('True')

    # Metrics summary
    axes[1, 1].axis('off')
    metrics_text = f"""
    📊 Model Performance Summary
    
    Height Regression:
    • MAE: {metrics['height_mae']:.3f} meters
    • RMSE: {metrics['height_rmse']:.3f} meters
    
    Wave Type Classification:
    • Accuracy: {metrics['type_accuracy']:.3f}
    • F1-Score: {metrics['type_f1']:.3f}
    
    Direction Classification:
    • Accuracy: {metrics['direction_accuracy']:.3f}
    • F1-Score: {metrics['direction_f1']:.3f}
    """
    axes[1, 1].text(0.1, 0.5, metrics_text, fontsize=12, verticalalignment='center')

    # Performance by height range (half-open 1 m intervals of true height;
    # empty ranges are skipped)
    range_maes = []
    range_labels = []

    for low in range(top_bin_edge):
        high = low + 1
        mask = (height_true >= low) & (height_true < high)
        if mask.sum() > 0:
            range_mae = np.mean(np.abs(height_preds[mask] - height_true[mask]))
            range_maes.append(range_mae)
            range_labels.append(f'{low}-{high}m')

    axes[1, 2].bar(range_labels, range_maes, color='coral')
    axes[1, 2].set_xlabel('Height Range')
    axes[1, 2].set_ylabel('MAE (meters)')
    axes[1, 2].set_title('Height Prediction Error by Range')
    axes[1, 2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Evaluate on validation set and plot the resulting metrics
print("🔍 Evaluating model performance...")
metrics, predictions = evaluate_model(model, val_loader, device)
plot_evaluation_results(metrics, predictions)