# Exoplanet Detection - Visualization Dashboard

This notebook contains visualization tools for analyzing predictions from the exoplanet detection pipeline.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

In [None]:
def load_predictions(filename='predictions.csv'):
    """Read a predictions CSV into a pandas DataFrame.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        The parsed DataFrame, or ``None`` when nothing exists at
        ``filename`` (an error line is printed in that case).
    """
    source = Path(filename)
    if source.exists():
        predictions = pd.read_csv(filename)
        print(f"Loaded {len(predictions)} predictions from {filename}")
        return predictions

    # Missing file: report it and let the caller decide what to do with None.
    print(f"Error: {filename} not found")
    return None

In [None]:
def visualize_predictions(filename='predictions.csv', save_path='predictions_visualization.png'):
    """Draw a 2x2 dashboard of the predictions and save it as an image.

    The four panels are: a histogram of ``planet_probability``, a pie chart
    of the ``prediction`` labels, a scatter of ``period_days`` against
    ``depth`` and a scatter of ``snr`` against ``duration_days`` (scaled by
    24 and labelled as hours). Rows whose ``prediction`` equals ``'Planet'``
    are drawn in green; every other row is drawn in red.

    Args:
        filename: CSV file to read via ``load_predictions``.
        save_path: Destination of the rendered figure.

    Returns:
        None. Returns early, without plotting, when the file cannot be loaded.
    """
    predictions = load_predictions(filename)
    if predictions is None:
        return

    # Text styles shared by all four panels.
    label_style = {'fontsize': 12}
    title_style = {'fontsize': 14, 'fontweight': 'bold'}

    figure, grid = plt.subplots(2, 2, figsize=(16, 12))
    figure.suptitle('Exoplanet Detection Analysis Dashboard', fontsize=16, fontweight='bold')
    (hist_ax, pie_ax), (depth_ax, snr_ax) = grid

    # Panel 1: distribution of the predicted probabilities.
    hist_ax.hist(predictions['planet_probability'], bins=20, color='steelblue',
                 edgecolor='black', alpha=0.7)
    hist_ax.axvline(0.5, color='red', linestyle='--', linewidth=2, label='Decision Threshold')
    hist_ax.set_xlabel('Planet Probability', **label_style)
    hist_ax.set_ylabel('Frequency', **label_style)
    hist_ax.set_title('Distribution of Planet Probabilities', **title_style)
    hist_ax.legend()
    hist_ax.grid(alpha=0.3)

    # Panel 2: share of each predicted class (green for 'Planet', red otherwise).
    class_counts = predictions['prediction'].value_counts()
    slice_colors = [
        '#2ecc71' if label == 'Planet' else '#e74c3c'
        for label in class_counts.index
    ]
    pie_ax.pie(
        class_counts.values,
        labels=class_counts.index,
        autopct='%1.1f%%',
        startangle=90,
        colors=slice_colors,
        textprops={'fontsize': 12, 'fontweight': 'bold'},
    )
    pie_ax.set_title('Planet vs Non-Planet Classifications', **title_style)

    # Split the rows once; both scatter panels reuse the same two subsets.
    is_planet = predictions['prediction'] == 'Planet'
    class_subsets = (
        (predictions[is_planet], {'c': 'green', 'label': 'Planet', 'marker': 'o'}),
        (predictions[~is_planet], {'c': 'red', 'label': 'Non-Planet', 'marker': 'x'}),
    )

    # Panel 3: orbital period against transit depth, period on a log axis.
    for subset, marker_style in class_subsets:
        depth_ax.scatter(subset['period_days'], subset['depth'],
                         s=100, alpha=0.6, **marker_style)
    depth_ax.set_xlabel('Orbital Period (days)', **label_style)
    depth_ax.set_ylabel('Transit Depth', **label_style)
    depth_ax.set_title('Period vs Transit Depth', **title_style)
    depth_ax.legend()
    depth_ax.grid(alpha=0.3)
    depth_ax.set_xscale('log')

    # Panel 4: SNR against transit duration; durations are multiplied by 24
    # to match the hours label on the y axis.
    for subset, marker_style in class_subsets:
        snr_ax.scatter(subset['snr'], subset['duration_days'] * 24,
                       s=100, alpha=0.6, **marker_style)
    snr_ax.axvline(7.0, color='orange', linestyle='--', linewidth=2, label='SDE Threshold')
    snr_ax.set_xlabel('Signal-to-Noise Ratio (SDE)', **label_style)
    snr_ax.set_ylabel('Transit Duration (hours)', **label_style)
    snr_ax.set_title('SNR vs Transit Duration', **title_style)
    snr_ax.legend()
    snr_ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Visualization saved to {save_path}")
    plt.show()

In [None]:
def feature_correlation_heatmap(filename='predictions.csv', save_path='feature_correlation.png'):
    """Plot and save a correlation heatmap of the numeric prediction columns.

    Every numeric column except one named ``timestamp`` is included. Each
    cell of the heatmap is annotated with its correlation coefficient,
    formatted to two decimals.

    Args:
        filename: CSV file to read via ``load_predictions``.
        save_path: Destination of the rendered image. Defaults to
            ``feature_correlation.png``, the path that used to be hard-coded.

    Returns:
        None. Returns early, without plotting, when the file cannot be
        loaded or when it has no numeric feature columns.
    """
    df = load_predictions(filename)
    if df is None:
        return

    # Select numeric columns only, dropping the timestamp column if present.
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    feature_cols = [col for col in numeric_cols if col != 'timestamp']

    if not feature_cols:
        print("No numeric features found for correlation")
        return

    corr_matrix = df[feature_cols].corr()
    tick_positions = range(len(feature_cols))

    plt.figure(figsize=(12, 10))
    # Pin the colour scale to [-1, 1] so colours are comparable between runs.
    plt.imshow(corr_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1)
    plt.colorbar(label='Correlation Coefficient')
    plt.xticks(tick_positions, feature_cols, rotation=45, ha='right')
    plt.yticks(tick_positions, feature_cols)
    plt.title('Feature Correlation Heatmap', fontsize=16, fontweight='bold', pad=20)

    # Annotate each cell; imshow puts matrix rows on y and columns on x.
    for (row, col), value in np.ndenumerate(corr_matrix.to_numpy()):
        plt.text(col, row, f'{value:.2f}', ha='center', va='center', fontsize=8)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✓ Correlation heatmap saved to {save_path}")
    plt.show()

In [None]:
def summary_statistics(filename='predictions.csv'):
    """Print a plain-text summary of the predictions file.

    The report lists the total row count, how many rows are labelled
    ``'Planet'`` and ``'Non-Planet'``, and the mean/median/standard deviation
    of ``planet_probability``. Sections for ``period_days`` and ``snr`` are
    printed only when those columns are present.

    Args:
        filename: CSV file to read via ``load_predictions``.

    Returns:
        None. Returns early, without printing a report, when the file
        cannot be loaded.
    """
    report = load_predictions(filename)
    if report is None:
        return

    banner = "=" * 60
    total = len(report)

    print("\n" + banner)
    print("PREDICTION SUMMARY STATISTICS")
    print(banner)

    print(f"\nTotal predictions: {total}")
    labels = report['prediction']
    print(f"Planets detected: {(labels == 'Planet').sum()}")
    print(f"Non-planets: {(labels == 'Non-Planet').sum()}")

    print("\nPlanet Probability Stats:")
    probability = report['planet_probability']
    # Bound methods are called inside the loop so each statistic is computed
    # right before its own line is printed.
    for caption, compute in (
        ("Mean", probability.mean),
        ("Median", probability.median),
        ("Std Dev", probability.std),
    ):
        print(f"  {caption}: {compute():.3f}")

    columns = report.columns

    if 'period_days' in columns:
        print("\nOrbital Period Stats (days):")
        period = report['period_days']
        for caption, compute in (("Mean", period.mean), ("Median", period.median)):
            print(f"  {caption}: {compute():.2f}")
        print(f"  Range: {period.min():.2f} - {period.max():.2f}")

    if 'snr' in columns:
        print("\nSNR Stats:")
        snr = report['snr']
        for caption, compute in (("Mean", snr.mean), ("Median", snr.median)):
            print(f"  {caption}: {compute():.2f}")
        print(f"  Above threshold (≥7): {(snr >= 7).sum()} / {total}")

    print("\n" + banner)

## Usage Examples

In [None]:
# Build the four-panel dashboard from predictions.csv. No save_path is passed,
# so the figure is written to the default predictions_visualization.png.
visualize_predictions('predictions.csv')

In [None]:
# Plot the correlation heatmap of the numeric columns in predictions.csv;
# the image is also written to feature_correlation.png.
feature_correlation_heatmap('predictions.csv')

In [None]:
# Print the text summary for predictions.csv: class counts, probability stats,
# and the period/SNR sections when those columns are present.
summary_statistics('predictions.csv')