TODO:
1. Refactor the grid workflow into object-oriented code (classes).
2. Add explicit type casting and type hints throughout.

In [2]:
import pandas as pd

In [4]:
# data_akr = pd.read_csv('../../data/fogg_akr_burst_list_2000_2004.csv')
# data_akr.head()

In [None]:
import sys

# Record which interpreter/kernel produced this notebook's results
for label, value in (("Python executable:", sys.executable),
                     ("Python version:", sys.version)):
    print(label, value)

Python executable: c:\Users\Sudipta\AppData\Local\Programs\Python\Python312\python.exe
Python version: 3.12.0 (tags/v3.12.0:0fb18b0, Oct  2 2023, 13:03:39) [MSC v.1935 64 bit (AMD64)]


In [None]:
from typing import Tuple, List, Dict, Optional, Union, Any
from typing import Literal
import numpy as np
import pandas as pd
import xarray as xr
from datetime import datetime
from pathlib import Path

Grid Creation Functions

In [3]:
from typing import Tuple, Optional
import numpy as np
import xarray as xr

def create_3d_grid(
    x_range: Tuple[float, float] = (-15.0, 15.0),
    y_range: Tuple[float, float] = (-15.0, 15.0),
    z_range: Tuple[float, float] = (-10.0, 10.0),
    bin_size: float = 0.5
) -> xr.Dataset:
    """
    Create empty 3D grid around Earth.

    Each axis is divided into cells of width ``bin_size`` starting at the
    range minimum. If a range is not an exact multiple of ``bin_size`` the
    last cell extends past the range maximum so the whole range is covered.

    Parameters
    ----------
    x_range : Tuple[float, float]
        (min, max) X coordinates in Earth radii
    y_range : Tuple[float, float]
        (min, max) Y coordinates in Earth radii
    z_range : Tuple[float, float]
        (min, max) Z coordinates in Earth radii
    bin_size : float
        Size of each grid cell in Earth radii

    Returns
    -------
    xr.Dataset
        Grid whose ``x``/``y``/``z`` coordinates are cell centers, holding
        zero-filled ``residence_time``, ``burst_count`` (int32),
        ``burst_time`` and ``probability`` arrays of shape (n_x, n_y, n_z).

    Raises
    ------
    TypeError
        If a range is not a pair of numbers.
    ValueError
        If ``bin_size`` is not a positive finite number, or a range is not
        finite with min < max.

    Examples
    --------
    >>> grid = create_3d_grid(bin_size=1.0)
    >>> print(grid.dims)
    """

    # Validate all three ranges the same way (previously only x_range was
    # checked, and a reversed range silently produced an empty axis).
    bounds = []
    for name, axis_range in (('x_range', x_range), ('y_range', y_range), ('z_range', z_range)):
        try:
            lo, hi = (float(v) for v in axis_range)
        except (TypeError, ValueError):
            raise TypeError(
                f"{name} must be tuple of 2 floats, got {type(axis_range)}"
            ) from None
        if not (np.isfinite(lo) and np.isfinite(hi) and lo < hi):
            raise ValueError(f"{name} must be finite with min < max, got {axis_range}")
        bounds.append((lo, hi))

    # bool is an int subclass and NaN/inf slip past `<= 0`; reject them explicitly.
    if (isinstance(bin_size, bool)
            or not isinstance(bin_size, (int, float, np.integer, np.floating))
            or not np.isfinite(bin_size)
            or bin_size <= 0):
        raise ValueError(f"bin_size must be positive number, got {bin_size}")
    width = float(bin_size)

    # Bin centers per axis. The bin count is derived from the span instead of
    # np.arange(lo, hi + width, width): arange's length is
    # ceil((stop - start) / step), so float round-off can add a spurious bin
    # past `hi` (np.arange(0, 1.1, 0.1) has 12 elements, not 11). The small
    # tolerance absorbs that round-off.
    centers = []
    for lo, hi in bounds:
        n_bins = max(1, int(np.ceil((hi - lo) / width - 1e-9)))
        edges = lo + width * np.arange(n_bins + 1, dtype=np.float64)
        centers.append((edges[:-1] + edges[1:]) / 2)
    x_centers, y_centers, z_centers = centers

    # Get dimensions
    n_x: int = len(x_centers)
    n_y: int = len(y_centers)
    n_z: int = len(z_centers)

    # Initialize arrays with proper dtype
    shape: Tuple[int, int, int] = (n_x, n_y, n_z)
    residence_time: np.ndarray = np.zeros(shape, dtype=np.float64)
    burst_count: np.ndarray = np.zeros(shape, dtype=np.int32)
    burst_time: np.ndarray = np.zeros(shape, dtype=np.float64)
    probability: np.ndarray = np.zeros(shape, dtype=np.float64)

    # Create Dataset
    grid: xr.Dataset = xr.Dataset(
        data_vars={
            'residence_time': (['x', 'y', 'z'], residence_time,
                              {'units': 'seconds', 'dtype': 'float64'}),
            'burst_count': (['x', 'y', 'z'], burst_count,
                           {'units': 'count', 'dtype': 'int32'}),
            'burst_time': (['x', 'y', 'z'], burst_time,
                          {'units': 'seconds', 'dtype': 'float64'}),
            'probability': (['x', 'y', 'z'], probability,
                          {'units': 'percent', 'dtype': 'float64'}),
        },
        coords={
            'x': (['x'], x_centers, {'units': 'R_E', 'dtype': 'float64'}),
            'y': (['y'], y_centers, {'units': 'R_E', 'dtype': 'float64'}),
            'z': (['z'], z_centers, {'units': 'R_E', 'dtype': 'float64'}),
        },
        attrs={
            'coordinate_system': 'GSE',
            'units': 'Earth_radii',
            'bin_size': float(bin_size),
            'description': 'AKR detection probability grid'
        }
    )

    return grid

Data Loading Functions

In [None]:
from typing import Union, Optional, List
from pathlib import Path
import pandas as pd

def load_burst_catalog(
    filepath: Union[str, Path],
    parse_dates: Optional[List[str]] = None,
    encoding: str = 'utf-8'
) -> pd.DataFrame:
    """
    Load AKR burst catalog from CSV file.

    Parameters
    ----------
    filepath : str or Path
        Path to CSV file
    parse_dates : List[str], optional
        Column names to parse as datetime (default ['STIME', 'ETIME'])
    encoding : str, default 'utf-8'
        File encoding

    Returns
    -------
    pd.DataFrame
        Burst catalog with parsed columns

    Raises
    ------
    FileNotFoundError
        If the path does not point to an existing regular file
    ValueError
        If the CSV cannot be read/parsed (original error chained as
        ``__cause__``), or if any of STIME, ETIME, X_GSE, Y_GSE, Z_GSE
        is missing
    """

    # Convert to Path object for type safety
    filepath = Path(filepath)

    # is_file() rather than exists(): a directory passes exists() and would
    # only fail later inside read_csv with a less helpful message.
    if not filepath.is_file():
        raise FileNotFoundError(f"File not found: {filepath}")

    if parse_dates is None:
        parse_dates = ['STIME', 'ETIME']

    try:
        df: pd.DataFrame = pd.read_csv(
            filepath,
            parse_dates=parse_dates,
            encoding=encoding
        )
    except Exception as e:
        # Chain the original exception so its traceback is preserved.
        raise ValueError(f"Error reading CSV: {e}") from e

    # Validate required columns
    required_cols: List[str] = ['STIME', 'ETIME', 'X_GSE', 'Y_GSE', 'Z_GSE']
    missing_cols: List[str] = [col for col in required_cols if col not in df.columns]

    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    return df


def parse_comma_separated_column(
    series: pd.Series,
    dtype: type = float
) -> pd.Series:
    """
    Parse comma-separated string values into lists of specified type.

    Empty or missing cells become ``[]``. Items that are blank, 'nan'/'nat'
    (any case), or that cannot be converted to ``dtype`` are dropped. Cells
    that are not strings are treated as a single value; pandas reads a
    column whose cells contain no commas as numeric.

    Parameters
    ----------
    series : pd.Series
        Series with comma-separated strings
    dtype : type, default float
        Target data type: ``float``; ``int`` (parsed via float, so "5.0" -> 5);
        anything else keeps items as stripped strings

    Returns
    -------
    pd.Series
        Series with lists of parsed values

    Examples
    --------
    >>> df['x_positions'] = parse_comma_separated_column(df['X_GSE'], dtype=float)
    """

    def parse_value(x) -> List[Union[float, int, str]]:
        """Parse a single comma-separated cell into a list."""
        if not isinstance(x, str):
            # Missing values (None/NaN/NaT) -> empty list. Other scalars
            # (e.g. a float from a numeric column) use their text form, so
            # they parse like a one-item string instead of failing on `.split`.
            if pd.isna(x):
                return []
            x = str(x)
        if x == '':
            return []

        parsed: List[Union[float, int, str]] = []
        for val in x.split(','):
            val = val.strip()
            if val.lower() in ('nan', 'nat', ''):
                continue
            try:
                if dtype == float:
                    parsed.append(float(val))
                elif dtype == int:
                    parsed.append(int(float(val)))  # Handle "5.0" strings
                else:
                    parsed.append(str(val))
            except (ValueError, TypeError, OverflowError):
                # OverflowError: int(float('inf')) when dtype is int
                continue

        return parsed

    return series.apply(parse_value)

Position and Coordinate Functions

In [None]:
from typing import Tuple, Optional
import numpy as np

def find_grid_indices(
    x: float,
    y: float,
    z: float,
    x_bins: np.ndarray,
    y_bins: np.ndarray,
    z_bins: np.ndarray
) -> Tuple[int, int, int]:
    """
    Find grid indices for a given position.

    Bins are half-open ``[edge[i], edge[i+1])`` except the last one, which
    also includes its upper edge (the ``np.histogram`` convention). A point
    exactly on the grid's maximum therefore falls inside the grid.

    Parameters
    ----------
    x, y, z : float
        Position coordinates in R_E
    x_bins, y_bins, z_bins : np.ndarray
        Bin EDGES (not centers) for each dimension, in increasing order

    Returns
    -------
    Tuple[int, int, int]
        Grid indices (i, j, k)

    Raises
    ------
    ValueError
        If a coordinate is outside grid bounds or not finite
    """

    def _axis_index(value: float, edges: np.ndarray, label: str) -> int:
        # 0-based bin index of `value` along one axis, or ValueError.
        n_bins = len(edges) - 1
        # NaN fails every comparison, so test "inside" rather than "outside".
        inside = n_bins >= 1 and edges[0] <= value <= edges[-1]
        if not inside:
            raise ValueError(f"{label} position {value} outside grid bounds")
        # np.digitize returns len(edges) for value == edges[-1]; fold that
        # into the last bin instead of rejecting it.
        return min(int(np.digitize(value, edges)) - 1, n_bins - 1)

    # Type casting to ensure float
    x, y, z = float(x), float(y), float(z)

    return (
        _axis_index(x, x_bins, 'X'),
        _axis_index(y, y_bins, 'Y'),
        _axis_index(z, z_bins, 'Z'),
    )


def calculate_mean_position(
    positions: List[float]
) -> float:
    """
    Calculate mean position, ignoring NaN values.

    Parameters
    ----------
    positions : List[float]
        Position values; any sequence accepted by ``np.asarray``
        (list, numpy array, pandas Series) works

    Returns
    -------
    float
        Mean of the non-NaN values, or NaN if ``positions`` is None, empty,
        or all-NaN
    """

    # `not positions` raises "truth value ... is ambiguous" for multi-element
    # numpy arrays / Series, so test None and length explicitly.
    if positions is None or len(positions) == 0:
        return np.nan

    # Ensure all values are floats
    positions_array: np.ndarray = np.asarray(positions, dtype=np.float64).ravel()

    # Remove NaN values
    valid_positions: np.ndarray = positions_array[~np.isnan(positions_array)]

    if valid_positions.size == 0:
        return np.nan

    return float(valid_positions.mean())

Grid Filling Functions

In [None]:
from typing import Dict, Any
import numpy as np
import pandas as pd
import xarray as xr

def fill_grid_with_bursts(
    grid: xr.Dataset,
    bursts: pd.DataFrame,
    time_col: str = 'duration_minutes'
) -> xr.Dataset:
    """
    Fill grid with burst statistics.

    For every burst, the cell containing its mean position (columns
    ``x_mean``/``y_mean``/``z_mean``) gets ``burst_count`` += 1 and
    ``burst_time`` += duration, converted from minutes to seconds.

    Skipped bursts:
    - NaN position or duration: skipped silently.
    - Missing column, non-numeric value, or position outside the grid:
      skipped with a printed warning.

    Parameters
    ----------
    grid : xr.Dataset
        Empty or partially filled grid whose ``x``/``y``/``z`` coordinates
        are cell centers (as built by ``create_3d_grid``)
    bursts : pd.DataFrame
        Burst catalog with position and time info
    time_col : str, default 'duration_minutes'
        Column name for burst duration, in minutes

    Returns
    -------
    xr.Dataset
        Grid filled with burst statistics

    Raises
    ------
    TypeError
        If ``grid`` is not an xr.Dataset or ``bursts`` is not a pd.DataFrame
    ValueError
        If cell edges cannot be derived (empty axis, or a one-cell axis
        without ``grid.attrs['bin_size']``)

    Notes
    -----
    Modifies grid in place and returns reference
    """

    def _centers_to_edges(centers, width):
        # Cell edges from cell centers. Interior edges are midpoints between
        # neighbouring centers; outer edges lie half a cell beyond the
        # first/last center. `width` is only needed for a one-cell axis.
        centers = np.asarray(centers, dtype=np.float64)
        if centers.size >= 2:
            mids = (centers[:-1] + centers[1:]) / 2
            first = 2 * centers[0] - mids[0]
            last = 2 * centers[-1] - mids[-1]
            return np.concatenate(([first], mids, [last]))
        if centers.size == 1 and width is not None:
            half = float(width) / 2
            return np.array([centers[0] - half, centers[0] + half])
        raise ValueError(
            f"cannot derive cell edges from {centers.size} center(s) "
            "without grid.attrs['bin_size']"
        )

    # Type validation
    if not isinstance(grid, xr.Dataset):
        raise TypeError(f"grid must be xr.Dataset, got {type(grid)}")
    if not isinstance(bursts, pd.DataFrame):
        raise TypeError(f"bursts must be pd.DataFrame, got {type(bursts)}")

    # grid.x/.y/.z hold cell CENTERS, but find_grid_indices expects bin EDGES.
    # Passing the centers directly (as before) shifted assignments by half a
    # cell and rejected bursts in the outer half of the first/last cells.
    bin_size = grid.attrs.get('bin_size')
    x_bins: np.ndarray = _centers_to_edges(grid.x.values, bin_size)
    y_bins: np.ndarray = _centers_to_edges(grid.y.values, bin_size)
    z_bins: np.ndarray = _centers_to_edges(grid.z.values, bin_size)

    # Working copies with the intended dtypes
    burst_count: np.ndarray = grid['burst_count'].values.astype(np.int32)
    burst_time: np.ndarray = grid['burst_time'].values.astype(np.float64)

    # Iterate through bursts
    for idx, burst in bursts.iterrows():
        try:
            # Extract position (with type casting)
            x: float = float(burst['x_mean'])
            y: float = float(burst['y_mean'])
            z: float = float(burst['z_mean'])
            duration: float = float(burst[time_col])

            # Skip invalid positions
            if np.isnan(x) or np.isnan(y) or np.isnan(z) or np.isnan(duration):
                continue

            # Find grid cell
            i, j, k = find_grid_indices(x, y, z, x_bins, y_bins, z_bins)

            # Increment counters
            burst_count[i, j, k] += 1
            burst_time[i, j, k] += duration * 60  # Convert minutes to seconds

        except (ValueError, KeyError, IndexError, TypeError) as e:
            # TypeError: e.g. float(None) or a list-valued cell
            print(f"Warning: Skipping burst {idx}: {e}")
            continue

    # Write the accumulated arrays back into the grid
    grid['burst_count'].values = burst_count
    grid['burst_time'].values = burst_time

    return grid


def calculate_probabilities(
    grid: xr.Dataset,
    min_residence_time: float = 60.0
) -> xr.Dataset:
    """
    Turn accumulated burst time into a per-cell detection probability.

    probability = 100 * burst_time / residence_time. It is computed only for
    cells whose residence time is at least ``min_residence_time``; every
    other cell is set to 0. Values are clipped to [0, 100] and written into
    ``grid['probability']``.

    Parameters
    ----------
    grid : xr.Dataset
        Grid with residence_time and burst_time filled
    min_residence_time : float, default 60.0
        Minimum residence time (seconds) a cell needs before a probability
        is computed for it

    Returns
    -------
    xr.Dataset
        The same grid, with ``probability`` updated in place
    """
    residence_s = np.asarray(grid['residence_time'].values, dtype=np.float64)
    burst_s = np.asarray(grid['burst_time'].values, dtype=np.float64)

    # Cells observed for too short a time stay at 0 rather than dividing by
    # a tiny (or zero) residence time.
    well_sampled = residence_s >= min_residence_time

    fraction = np.zeros_like(residence_s)
    np.divide(burst_s, residence_s, out=fraction, where=well_sampled)

    grid['probability'].values = np.clip(fraction * 100.0, 0.0, 100.0)
    return grid

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go
import xarray as xr

# ==============================================================================
# STEP 1: Create 3D Grid Structure
# ==============================================================================

def create_3d_grid(x_range=(-15, 15), y_range=(-15, 15), z_range=(-10, 10),
                   bin_size=0.5):
    """
    Create empty 3D grid around Earth

    NOTE: this cell re-defines ``create_3d_grid``. It replaces the typed
    version defined earlier in the notebook for every cell run after it.

    Parameters:
    -----------
    x_range, y_range, z_range : tuple
        (min, max) in Earth radii
    bin_size : float
        Size of each grid cell in Earth radii

    Returns:
    --------
    grid : xarray.Dataset
        3D grid whose x/y/z coordinates are cell centers, with zero-filled
        residence_time, burst_count, burst_time and probability arrays
    """

    def _edges(axis_range):
        # Bin edges from min to max. The count is derived from the span
        # instead of np.arange(min, max + bin_size, bin_size), whose length
        # float round-off can inflate by one (np.arange(0, 1.1, 0.1) has 12
        # elements), adding a spurious bin past `max`. If the span is not a
        # multiple of bin_size the last bin still extends past `max`.
        lo, hi = float(axis_range[0]), float(axis_range[1])
        n_bins = max(0, int(np.ceil((hi - lo) / bin_size - 1e-9)))
        return lo + bin_size * np.arange(n_bins + 1)

    # Create bin edges
    x_bins = _edges(x_range)
    y_bins = _edges(y_range)
    z_bins = _edges(z_range)

    # Create bin centers (for plotting)
    x_centers = (x_bins[:-1] + x_bins[1:]) / 2
    y_centers = (y_bins[:-1] + y_bins[1:]) / 2
    z_centers = (z_bins[:-1] + z_bins[1:]) / 2

    # Create empty arrays
    n_x = len(x_centers)
    n_y = len(y_centers)
    n_z = len(z_centers)

    # Initialize with zeros
    residence_time = np.zeros((n_x, n_y, n_z))
    burst_count = np.zeros((n_x, n_y, n_z))
    burst_time = np.zeros((n_x, n_y, n_z))
    probability = np.zeros((n_x, n_y, n_z))

    # Create Xarray Dataset (keeps everything organized!)
    grid = xr.Dataset(
        data_vars={
            'residence_time': (['x', 'y', 'z'], residence_time,
                              {'units': 'seconds', 'long_name': 'Time spent in cell'}),
            'burst_count': (['x', 'y', 'z'], burst_count,
                           {'units': 'count', 'long_name': 'Number of bursts'}),
            'burst_time': (['x', 'y', 'z'], burst_time,
                          {'units': 'seconds', 'long_name': 'Total burst time'}),
            'probability': (['x', 'y', 'z'], probability,
                          {'units': 'percent', 'long_name': 'Detection probability'}),
        },
        coords={
            'x': x_centers,
            'y': y_centers,
            'z': z_centers,
        },
        attrs={
            'coordinate_system': 'GSE',
            'units': 'Earth_radii',
            'bin_size': bin_size,
            'description': 'AKR detection probability grid'
        }
    )

    print(f"Created 3D grid:")
    print(f"  X range: {x_range[0]} to {x_range[1]} R_E ({n_x} bins)")
    print(f"  Y range: {y_range[0]} to {y_range[1]} R_E ({n_y} bins)")
    print(f"  Z range: {z_range[0]} to {z_range[1]} R_E ({n_z} bins)")
    print(f"  Total cells: {n_x * n_y * n_z:,}")
    print(f"  Bin size: {bin_size} R_E")

    return grid

# Create your grid: 1 R_E cells (30 x 30 x 20) keep it small for testing;
# lower bin_size to 0.5 for finer maps.
grid = create_3d_grid(
    bin_size=1.0,
    x_range=(-15, 15),
    y_range=(-15, 15),
    z_range=(-10, 10),
)

# ==============================================================================
# STEP 2: Fill with Random Data (for testing visualization)
# ==============================================================================

def fill_with_random_data(grid, seed=42):
    """
    Fill grid with random data for testing

    Strategy: Make probability higher on nightside and near auroral regions

    Parameters
    ----------
    grid : xarray.Dataset
        Grid from ``create_3d_grid``. Its ``probability``,
        ``residence_time``, ``burst_count`` and ``burst_time`` values are
        overwritten in place.
    seed : int, default 42
        Seed for a private random generator

    Returns
    -------
    grid : xarray.Dataset
        The same (modified) grid object

    Notes
    -----
    Uses ``np.random.RandomState(seed)`` instead of ``np.random.seed(seed)``
    so that calling this function does not reset numpy's global random state
    for the rest of the notebook. RandomState produces the same stream as the
    seeded legacy ``np.random.*`` functions, so the generated values are
    unchanged.
    """
    rng = np.random.RandomState(seed)

    # Cell-center coordinates as 3D arrays
    X, Y, Z = np.meshgrid(grid.x, grid.y, grid.z, indexing='ij')

    # Distance from Earth's center
    R = np.sqrt(X**2 + Y**2 + Z**2)

    # Shape a plausible-looking distribution with higher values:
    #   - on the nightside (X < 0)
    #   - at auroral latitudes (|Z| near 5)
    #   - inside the magnetosphere (R below ~10)
    nightside = np.exp(-X / 5)
    auroral = np.exp(-((np.abs(Z) - 5) / 2)**2)
    distance = np.exp(-(R - 6) / 3)

    # Combine factors with random noise, then clip to 0-100%
    probability = 30 * nightside * auroral * distance + rng.rand(*X.shape) * 10
    probability = np.clip(probability, 0, 100)

    # Set to zero very close to Earth (< 3 R_E) and far away (> 12 R_E)
    probability[R < 3] = 0
    probability[R > 12] = 0

    grid['probability'].values = probability

    # Companion variables, drawn in the same order as before so the random
    # stream (and hence every value) matches the original implementation.
    grid['residence_time'].values = rng.uniform(1000, 10000, probability.shape)
    grid['burst_count'].values = rng.poisson(probability / 10, probability.shape)
    grid['burst_time'].values = grid['burst_count'].values * rng.uniform(600, 3600, probability.shape)

    print("Filled grid with random data")
    print(f"  Probability range: {probability.min():.1f} - {probability.max():.1f}%")
    print(f"  Mean probability: {probability.mean():.1f}%")

    return grid

# Populate the grid with synthetic values so the visualizations below have
# something to show; seed 42 is the function's default.
grid = fill_with_random_data(grid, seed=42)

# ==============================================================================
# STEP 3: Visualization Option 1 - PyVista (BEST for scientific viz)
# ==============================================================================

# PyVista is optional. Only the import sits in the try: the old version
# wrapped the whole definition and call, so an ImportError raised while
# rendering would have been misreported as "PyVista not installed".
try:
    import pyvista as pv
except ImportError:
    pv = None
    print("PyVista not installed. Install with: pip install pyvista")


def visualize_with_pyvista(grid, threshold=10, screenshot_path='akr_grid_pyvista.png'):
    """
    Create interactive 3D visualization with PyVista

    Parameters:
    -----------
    grid : xarray.Dataset
        Grid data (``probability`` on ``x``/``y``/``z`` cell centers)
    threshold : float
        Probability threshold passed to ``mesh.threshold``; lower-probability
        regions are hidden
    screenshot_path : str or None, default 'akr_grid_pyvista.png'
        File to save a screenshot of the view to; None to skip saving

    Raises
    ------
    ImportError
        If PyVista is not installed
    """
    if pv is None:
        raise ImportError("PyVista not installed. Install with: pip install pyvista")

    # Cell centers become the points of a structured grid
    X, Y, Z = np.meshgrid(grid.x.values, grid.y.values, grid.z.values, indexing='ij')
    mesh = pv.StructuredGrid(X, Y, Z)

    # order='F' matches the point ordering PyVista uses when building a
    # StructuredGrid from 3D coordinate arrays.
    mesh['probability'] = grid['probability'].values.flatten(order='F')

    # Threshold (only show high-probability regions)
    thresholded = mesh.threshold(threshold, scalars='probability')

    plotter = pv.Plotter(window_size=[1200, 800])

    plotter.add_mesh(
        thresholded,
        scalars='probability',
        cmap='hot',
        opacity=0.6,
        show_scalar_bar=True,
        scalar_bar_args={
            'title': 'AKR Detection\nProbability (%)',
            'vertical': True,
        }
    )

    # Earth as a unit sphere (1 R_E)
    earth = pv.Sphere(radius=1.0, center=(0, 0, 0))
    plotter.add_mesh(earth, color='lightblue', opacity=0.8, label='Earth')

    plotter.add_axes(
        xlabel='X (GSE) [R_E]',
        ylabel='Y (GSE) [R_E]',
        zlabel='Z (GSE) [R_E]',
        line_width=3,
        color='black'
    )

    # Sun direction arrow along +X
    sun_arrow = pv.Arrow(
        start=(15, 0, 0),
        direction=(5, 0, 0),
        scale=3
    )
    plotter.add_mesh(sun_arrow, color='yellow', label='To Sun')

    plotter.add_text(
        'AKR Detection Probability Map\n(Random Test Data)',
        position='upper_edge',
        font_size=14,
        color='black'
    )

    plotter.camera_position = 'xz'
    plotter.camera.zoom(1.2)

    # Take the screenshot through show(). show() closes the plotter by
    # default, so calling plotter.screenshot() afterwards (as before) can fail.
    plotter.show(screenshot=screenshot_path)


if pv is not None:
    visualize_with_pyvista(grid, threshold=15)

# ==============================================================================
# STEP 4: Visualization Option 2 - Plotly (Interactive web-based)
# ==============================================================================

def visualize_with_plotly(grid, threshold=10, html_path='akr_grid_plotly.html',
                          title='AKR Detection Probability Map (Random Test Data)'):
    """
    Create interactive 3D visualization with Plotly

    The figure contains:
    - cells with probability above ``threshold``, drawn as colored markers
      at their cell centers;
    - a unit-radius Earth;
    - a Sun-direction arrow along +X (GSE).

    Parameters
    ----------
    grid : xarray.Dataset
        Grid with a ``probability`` variable on ``x``/``y``/``z`` cell centers
    threshold : float, default 10
        Only cells with probability strictly greater than this are shown
    html_path : str or None, default 'akr_grid_plotly.html'
        Where to write the interactive HTML; None skips saving
    title : str
        Figure title

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, so more traces (e.g. ``add_satellite_trajectory``) can be
        added afterwards
    """
    x = grid['x'].values
    y = grid['y'].values
    z = grid['z'].values
    prob = grid['probability'].values

    # Coordinates and values of cells above the threshold
    mask = prob > threshold
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
    x_flat = X[mask]
    y_flat = Y[mask]
    z_flat = Z[mask]
    prob_flat = prob[mask]

    fig = go.Figure()

    # Probability cells as 3D scatter
    fig.add_trace(go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=3,
            color=prob_flat,
            colorscale='Hot',
            colorbar=dict(title='Probability (%)', x=1.1),
            opacity=0.6,
            showscale=True
        ),
        name='AKR Probability',
        text=[f'Prob: {p:.1f}%<br>X: {xi:.1f}<br>Y: {yi:.1f}<br>Z: {zi:.1f}'
              for xi, yi, zi, p in zip(x_flat, y_flat, z_flat, prob_flat)],
        hovertemplate='%{text}<extra></extra>'
    ))

    # Earth as a unit sphere (1 R_E)
    u = np.linspace(0, 2 * np.pi, 50)
    v = np.linspace(0, np.pi, 50)
    fig.add_trace(go.Surface(
        x=np.outer(np.cos(u), np.sin(v)),
        y=np.outer(np.sin(u), np.sin(v)),
        z=np.outer(np.ones(np.size(u)), np.cos(v)),
        colorscale=[[0, 'lightblue'], [1, 'lightblue']],
        showscale=False,
        opacity=0.8,
        name='Earth'
    ))

    # Plot extent follows the grid: cell centers +/- half a cell. The old code
    # hard-coded [-15, 15] x [-15, 15] x [-10, 10], which only fits the
    # default grid and left the Sun label drawn at X = 18 outside the X range.
    half = float(grid.attrs.get('bin_size', 0.0)) / 2
    x_lo, x_hi = float(x.min()) - half, float(x.max()) + half
    y_lo, y_hi = float(y.min()) - half, float(y.max()) + half
    z_lo, z_hi = float(z.min()) - half, float(z.max()) + half

    # Sun direction arrow (as a line) straddling the grid's +X edge
    fig.add_trace(go.Scatter3d(
        x=[x_hi - 3, x_hi + 3],
        y=[0, 0],
        z=[0, 0],
        mode='lines+text',
        line=dict(color='yellow', width=10),
        text=['', '☀ Sun'],
        textposition='top center',
        textfont=dict(size=20),
        name='Sun Direction',
        showlegend=False
    ))

    # Extend X beyond the arrow tip so the Sun label stays visible
    x_axis_hi = x_hi + 5
    y_span = y_hi - y_lo
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X (GSE) [R_E]',
            yaxis_title='Y (GSE) [R_E]',
            zaxis_title='Z (GSE) [R_E]',
            xaxis=dict(range=[x_lo, x_axis_hi]),
            yaxis=dict(range=[y_lo, y_hi]),
            zaxis=dict(range=[z_lo, z_hi]),
            aspectmode='manual',
            aspectratio=dict(x=(x_axis_hi - x_lo) / y_span, y=1,
                             z=(z_hi - z_lo) / y_span)
        ),
        width=1000,
        height=800,
        showlegend=True
    )

    fig.show()

    if html_path:
        fig.write_html(html_path)
        print(f"Saved interactive plot to: {html_path}")

    return fig

# Interactive 3D scatter of cells above 15 % detection probability
visualize_with_plotly(grid=grid, threshold=15)

# ==============================================================================
# STEP 5: 2D Slice Views (Simple matplotlib)
# ==============================================================================

def visualize_2d_slices(grid, vmax=80, save_path='akr_grid_slices.png'):
    """
    Create 2D slice views (top, side, front) and a probability histogram.

    Each slice passes through the cell whose center is closest to 0 along the
    sliced axis. The panel title shows that cell's actual coordinate.

    Parameters
    ----------
    grid : xarray.Dataset
        Grid with a ``probability`` variable on ``x``/``y``/``z`` cell centers
    vmax : float, default 80
        Upper end of the color scale (%); higher values saturate
    save_path : str or None, default 'akr_grid_slices.png'
        Where to save the figure (300 dpi); None skips saving

    Returns
    -------
    matplotlib.figure.Figure
        The figure
    """
    x = grid['x'].values
    y = grid['y'].values
    z = grid['z'].values
    prob = grid['probability'].values

    # Slice at the center nearest 0. len(axis) // 2 is only near 0 for a
    # symmetric range. The old titles always claimed "= 0", but with an even
    # number of cells no center lies exactly at 0.
    ix = int(np.argmin(np.abs(x)))
    iy = int(np.argmin(np.abs(y)))
    iz = int(np.argmin(np.abs(z)))

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    def _draw_slice(ax, h_coord, v_coord, plane, xlabel, ylabel, title):
        # One slice on a fixed 0..vmax color scale, with Earth disc and colorbar.
        im = ax.pcolormesh(h_coord, v_coord, plane.T, cmap='hot',
                           vmin=0, vmax=vmax, shading='auto')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_aspect('equal')
        ax.add_patch(plt.Circle((0, 0), 1, color='lightblue', zorder=10, alpha=0.8))
        fig.colorbar(im, ax=ax, label='Probability (%)')
        return ax

    # 1. Top-down view (X-Y plane)
    ax = _draw_slice(axes[0, 0], x, y, prob[:, :, iz],
                     'X (GSE) [R_E]', 'Y (GSE) [R_E]',
                     f'Top-Down View (Z = {z[iz]:.1f} plane)')
    # Sun arrow pointing toward +X
    ax.arrow(12, 0, 3, 0, head_width=0.8, head_length=0.5,
             fc='yellow', ec='orange', linewidth=2, zorder=5)
    ax.text(16, 1, '☀ Sun', fontsize=14, color='orange')

    # 2. Side view (X-Z plane)
    _draw_slice(axes[0, 1], x, z, prob[:, iy, :],
                'X (GSE) [R_E]', 'Z (GSE) [R_E]',
                f'Side View (Y = {y[iy]:.1f} plane)')

    # 3. Front view (Y-Z plane)
    _draw_slice(axes[1, 0], y, z, prob[ix, :, :],
                'Y (GSE) [R_E]', 'Z (GSE) [R_E]',
                f'Front View (X = {x[ix]:.1f} plane)')

    # 4. Histogram of probabilities over all cells
    ax = axes[1, 1]
    ax.hist(prob.ravel(), bins=50, edgecolor='black', alpha=0.7)
    ax.set_xlabel('Probability (%)')
    ax.set_ylabel('Number of Grid Cells')
    ax.set_title('Probability Distribution')
    ax.axvline(prob.mean(), color='red', linestyle='--',
               label=f'Mean: {prob.mean():.1f}%')
    ax.legend()
    ax.grid(alpha=0.3)

    fig.suptitle('AKR Detection Probability - 2D Slices (Random Test Data)',
                 fontsize=16, y=0.995)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    return fig

# Top, side and front slices plus the probability histogram
visualize_2d_slices(grid=grid)

# ==============================================================================
# STEP 6: Add Satellite Trajectory (example)
# ==============================================================================

def add_satellite_trajectory(fig, trajectory_data=None):
    """
    Add satellite trajectory to Plotly figure

    Draws the path as a cyan line plus start (green diamond) and end (red x)
    markers.

    Parameters:
    -----------
    fig : plotly figure
        Figure to add traces to; it is modified in place
    trajectory_data : DataFrame or mapping, optional
        Provides 'x', 'y', 'z' positions (e.g. a DataFrame with columns
        ['x', 'y', 'z', 'time']). Any column or sequence accepted by
        ``np.asarray`` works. When None, a synthetic demo loop is drawn.

    Returns
    -------
    fig
        The same figure
    """

    if trajectory_data is None:
        # Synthetic placeholder path, not a model of any real orbit: a circle
        # of radius 5 R_E centred at X = 3, with Z oscillating between -2 and +2.
        t = np.linspace(0, 2 * np.pi, 200)
        x_traj = 5 * np.cos(t) + 3
        y_traj = 5 * np.sin(t)
        z_traj = 2 * np.sin(2 * t)
    else:
        # np.asarray accepts DataFrame columns as well as plain lists/arrays
        x_traj = np.asarray(trajectory_data['x'])
        y_traj = np.asarray(trajectory_data['y'])
        z_traj = np.asarray(trajectory_data['z'])

    # Trajectory line
    fig.add_trace(go.Scatter3d(
        x=x_traj,
        y=y_traj,
        z=z_traj,
        mode='lines',
        line=dict(color='cyan', width=4),
        name='Satellite Orbit'
    ))

    # Start/end markers need at least one point (empty input used to raise IndexError)
    if len(x_traj) == 0:
        return fig

    fig.add_trace(go.Scatter3d(
        x=[x_traj[0]],
        y=[y_traj[0]],
        z=[z_traj[0]],
        mode='markers',
        marker=dict(size=8, color='green', symbol='diamond'),
        name='Start'
    ))

    fig.add_trace(go.Scatter3d(
        x=[x_traj[-1]],
        y=[y_traj[-1]],
        z=[z_traj[-1]],
        mode='markers',
        marker=dict(size=8, color='red', symbol='x'),
        name='End'
    ))

    return fig

# ==============================================================================
# STEP 7: Save/Load Grid
# ==============================================================================

# Save grid to NetCDF
grid.to_netcdf('akr_grid_test.nc')
print("Saved grid to: akr_grid_test.nc")

# Load back. open_dataset() is lazy and keeps the file handle open, which can
# make the to_netcdf() above fail when this cell is re-run (notably on Windows,
# where an open file cannot be overwritten). Read into memory and close the
# file right away.
with xr.open_dataset('akr_grid_test.nc') as ds:
    loaded_grid = ds.load()
# Round-trip sanity check: the same variables came back
assert set(loaded_grid.data_vars) == set(grid.data_vars)
print("Loaded grid successfully")

# Also save as CSV (less efficient but human-readable)
# Flatten to a 2D table: one row per (x, y, z) cell
df = grid.to_dataframe().reset_index()
df.to_csv('akr_grid_test.csv', index=False)
print("Saved grid to: akr_grid_test.csv")

print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Created {grid.x.size}x{grid.y.size}x{grid.z.size} grid")
print(f"Total cells: {grid.x.size * grid.y.size * grid.z.size:,}")
print("Filled with random data for testing")
print("Generated visualizations:")
print("  - 2D slices: akr_grid_slices.png")
print("  - Interactive 3D: akr_grid_plotly.html")
print("  - Saved data: akr_grid_test.nc")
print("  - Saved table: akr_grid_test.csv")