In [1]:
"""
Final AUKF Tuning Script
========================
Improvements for better adaptive behavior and realistic noise estimation.
"""

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import chi2
from tqdm import tqdm
from utils import propagate_orbit, validate_orbit_parameters, R_EARTH
from aukf import OptimizedAUKF

plt.rcParams["figure.dpi"] = 110

def estimate_measurement_noise_improved(Z):
    """Improved noise estimation using innovation-based approach."""
    
    # Use second differences to estimate measurement noise
    # This removes the effect of dynamics
    if len(Z) > 100:
        # Take a stable portion of data
        mid = len(Z) // 2
        window = 1000
        Z_subset = Z[mid-window:mid+window]
        
        # Calculate second differences
        # z[k+1] - 2*z[k] + z[k-1] ≈ noise (for small dt)
        second_diff = Z_subset[2:] - 2*Z_subset[1:-1] + Z_subset[:-2]
        
        # Remove outliers using IQR method
        for i in range(6):
            q1, q3 = np.percentile(second_diff[:, i], [25, 75])
            iqr = q3 - q1
            mask = (second_diff[:, i] > q1 - 1.5*iqr) & (second_diff[:, i] < q3 + 1.5*iqr)
            second_diff = second_diff[mask]
        
        # Estimate noise standard deviation
        # Factor of sqrt(6) comes from variance of second difference
        noise_std = np.std(second_diff, axis=0) / np.sqrt(6)
        
    else:
        # Fallback for short sequences
        noise_std = np.std(Z, axis=0) * 0.001
    
    print("\n=== Improved Noise Estimation ===")
    print(f"Position noise (m): {noise_std[:3] * 1000}")  # Convert km to m
    print(f"Velocity noise (m/s): {noise_std[3:] * 0.1}")  # Convert dm/s to m/s
    
    # Set realistic GPS measurement noise
    # Modern GPS: 10-30m position, 0.1-0.3 m/s velocity
    r0_vec = np.array([
        (0.020)**2,  # 20m in km
        (0.020)**2,
        (0.030)**2,  # 30m vertical (worse than horizontal)
        (2.0)**2,    # 0.2 m/s in dm/s
        (2.0)**2,
        (3.0)**2,    # 0.3 m/s vertical
    ])
    
    print(f"\nMeasurement R (GPS-based):")
    print(f"  Position std (m): {np.sqrt(r0_vec[:3]) * 1000}")
    print(f"  Velocity std (m/s): {np.sqrt(r0_vec[3:]) * 0.1}")
    
    return r0_vec

def run_tuned_filter(adapt, desc, cols, Z, epochs, r0_vec):
    """Run filter with improved tuning for better adaptive behavior."""
    print(f"\n=== Running {desc} ===")
    
    # Create filter with tuned parameters
    ukf = OptimizedAUKF(
        cols, propagate_orbit,
        σ_a=1e-6 if adapt else 1e-7,  # Higher process noise for adaptive
        r0=r0_vec,
        γQ=0.99 if adapt else 1.0,     # Slower adaptation
        γR=0.98 if adapt else 1.0,     # More aggressive R adaptation
        adaptive_window=50             # Smaller window for faster response
    )
    
    # Initialize
    ukf.init_from_measurement(0.0, Z[0])
    
    # For adaptive filter, perturb initial R to test adaptation
    if adapt:
        ukf.R = ukf.R * 5.0  # Start with 5x larger R
        print("Starting adaptive filter with inflated R to test adaptation")
    
    # Process measurements
    t_prev = epochs[0]
    succ = 0
    
    # Track R evolution for adaptive filter
    R_history = []
    
    pbar = tqdm(enumerate(zip(epochs[1:], Z[1:])), desc=desc, total=len(epochs)-1)
    for i, (t, z) in pbar:
        # Calculate time step
        if isinstance(t, pd.Timestamp):
            dt = (t - t_prev).total_seconds()
        else:
            dt = float((t - t_prev).astype('timedelta64[s]').astype(int))
        
        if dt <= 0 or dt > 300:
            continue
        
        ukf.predict(dt)
        ukf.update(z)
        succ += 1
        t_prev = t
        
        # Save R diagonal periodically
        if i % 100 == 0:
            R_history.append(np.diag(ukf.R).copy())
            val = validate_orbit_parameters(ukf.x)
            pbar.set_postfix(alt=f"{val['altitude_km']:.0f}km", 
                           R_pos=f"{np.sqrt(ukf.R[0,0])*1000:.0f}m")
    
    print(f"Completed: {succ} updates")
    
    # Save R history for plotting
    ukf.R_history = np.array(R_history) if R_history else None
    
    return ukf

def create_comparison_plots(ukf_a, ukf_f, cols):
    """Create detailed comparison plots with robust RMS calculation."""
    print("\n=== Creating Comparison Plots ===")
    
    # Extract histories
    states_a = np.array([s for _, s in ukf_a.hist])
    states_f = np.array([s for _, s in ukf_f.hist])
    
    # Create figure with subplots
    fig = plt.figure(figsize=(15, 12))
    
    # 1. R adaptation plot (if available)
    if hasattr(ukf_a, 'R_history') and ukf_a.R_history is not None:
        ax1 = plt.subplot(3, 2, 1)
        
        # Plot position R evolution
        for i in range(3):
            ax1.plot(np.sqrt(ukf_a.R_history[:, i]) * 1000, 
                     label=f'{["X", "Y", "Z"][i]} (Adaptive)')
        
        # Add fixed values as horizontal lines
        for i in range(3):
            ax1.axhline(np.sqrt(ukf_f.R[i, i]) * 1000, 
                        color='red', linestyle='--', alpha=0.5)
        
        ax1.set_xlabel('Update (x100)')
        ax1.set_ylabel('Position Measurement Std (m)')
        ax1.set_title('Measurement Noise Adaptation')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Velocity R evolution
        ax2 = plt.subplot(3, 2, 2)
        for i in range(3, 6):
            ax2.plot(np.sqrt(ukf_a.R_history[:, i]) * 0.1,
                     label=f'{["Vx", "Vy", "Vz"][i-3]} (Adaptive)')
        
        for i in range(3, 6):
            ax2.axhline(np.sqrt(ukf_f.R[i, i]) * 0.1,
                        color='red', linestyle='--', alpha=0.5)
        
        ax2.set_xlabel('Update (x100)')
        ax2.set_ylabel('Velocity Measurement Std (m/s)')
        ax2.set_title('Velocity Noise Adaptation')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

    # 2. Innovation statistics comparison (Robustly Calculated)
    ax3 = plt.subplot(3, 2, 3)
    
    window = 100
    if ukf_a._innov_history and ukf_f._innov_history and \
       len(ukf_a._innov_history) > window and len(ukf_f._innov_history) > window:
        
        innov_a = np.array(ukf_a._innov_history)
        innov_f = np.array(ukf_f._innov_history)

        # Use pandas for a robust rolling RMS calculation
        innov_a_df = pd.DataFrame(innov_a)
        innov_f_df = pd.DataFrame(innov_f)

        # Get indices to sample the rolling window, matching original logic
        sample_indices = range(window - 1, len(innov_a), 10)

        # Calculate rolling RMS for the adaptive filter
        rolling_ms_a = innov_a_df.pow(2).rolling(window=window).mean()
        rms_a = np.sqrt(rolling_ms_a.iloc[sample_indices]).values

        # Calculate rolling RMS for the fixed filter
        rolling_ms_f = innov_f_df.pow(2).rolling(window=window).mean()
        rms_f = np.sqrt(rolling_ms_f.iloc[sample_indices]).values
        
        # Plot position innovation RMS
        for i in range(3):
            ax3.plot(rms_a[:, i], label=f'{["X", "Y", "Z"][i]} (Adaptive)')
            ax3.plot(rms_f[:, i], '--', label=f'{["X", "Y", "Z"][i]} (Fixed)')
            
    ax3.set_xlabel('Update (x10)') # Note: The x-axis scale changes due to sampling
    ax3.set_ylabel('Innovation RMS (km)')
    ax3.set_title('Position Innovation RMS')
    ax3.legend(fontsize=8)
    ax3.grid(True, alpha=0.3)

    # ... (The rest of your function remains the same)
    
    # 3. NIS comparison with statistics
    ax4 = plt.subplot(3, 2, 4)
    
    stats_a = ukf_a.get_innovation_stats()
    stats_f = ukf_f.get_innovation_stats()
    
    if stats_a and stats_f:
        nis_a = stats_a['nis']
        nis_f = stats_f['nis']
        
        # Calculate running mean NIS
        window = 100
        nis_mean_a = np.convolve(nis_a, np.ones(window)/window, mode='valid')
        nis_mean_f = np.convolve(nis_f, np.ones(window)/window, mode='valid')
        
        ax4.plot(nis_mean_a, 'b-', label='Adaptive', alpha=0.7)
        ax4.plot(nis_mean_f, 'r-', label='Fixed', alpha=0.7)
        ax4.axhline(6, color='green', linestyle='-', label='Expected')
        ax4.axhline(chi2.ppf(0.95, 6), color='k', linestyle='--', alpha=0.5)
        
        ax4.set_xlabel('Update')
        ax4.set_ylabel('Mean NIS (100-pt window)')
        ax4.set_title('NIS Running Mean')
        ax4.legend()
        ax4.grid(True, alpha=0.3)
    
    # 4. Position error (if ground truth available)
    ax5 = plt.subplot(3, 2, 5)
    
    # Since we don't have ground truth, plot position difference
    pos_diff = np.linalg.norm(states_a[:, :3] - states_f[:, :3], axis=1)
    ax5.plot(pos_diff)
    ax5.set_xlabel('Update')
    ax5.set_ylabel('Position Difference (m)')
    ax5.set_title('Adaptive vs Fixed Position Difference')
    ax5.grid(True, alpha=0.3)
    
    # 5. Summary statistics
    ax6 = plt.subplot(3, 2, 6)
    ax6.axis('off')
    
    # Calculate statistics
    if stats_a and stats_f:
        nis_a_mean = np.mean(nis_a)
        nis_f_mean = np.mean(nis_f)
        nis_a_std = np.std(nis_a)
        nis_f_std = np.std(nis_f)
        
        # Chi-squared test
        chi2_95 = chi2.ppf(0.95, 6)
        nis_a_pct = np.mean(nis_a < chi2_95) * 100
        nis_f_pct = np.mean(nis_f < chi2_95) * 100
    
        summary = f"""Performance Comparison:
    
Adaptive Filter:
  Mean NIS: {nis_a_mean:.2f} ± {nis_a_std:.2f}
  NIS < χ²(95%): {nis_a_pct:.1f}%
  Final R_pos: {np.sqrt(np.diag(ukf_a.R)[:3]).mean()*1000:.1f} m
  Final R_vel: {np.sqrt(np.diag(ukf_a.R)[3:]).mean()*0.1:.2f} m/s
  
Fixed Filter:
  Mean NIS: {nis_f_mean:.2f} ± {nis_f_std:.2f}
  NIS < χ²(95%): {nis_f_pct:.1f}%
  Final R_pos: {np.sqrt(np.diag(ukf_f.R)[:3]).mean()*1000:.1f} m
  Final R_vel: {np.sqrt(np.diag(ukf_f.R)[3:]).mean()*0.1:.2f} m/s
  
Position RMSE: {np.sqrt(np.mean(pos_diff**2)):.2f} m
Max difference: {np.max(pos_diff):.2f} m

Expected NIS: 6.00 ± 3.46 (χ²(6))
    """
    
        ax6.text(0.1, 0.9, summary, transform=ax6.transAxes,
                 verticalalignment='top', fontfamily='monospace', fontsize=10)
    
    plt.tight_layout()
    plt.savefig('aukf_detailed_comparison.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("✅ aukf_detailed_comparison.png")

def main():
    """Run improved AUKF comparison."""
    
    # Load data
    df = pd.read_parquet("GPS_clean.parquet")
    cols = json.load(open("meas_cols.json"))
    
    # Apply unit corrections
    print("Applying unit corrections...")
    pos_cols = cols[:3]
    vel_cols = cols[3:]
    
    # Check if positions need scaling
    mean_radius = np.mean(np.linalg.norm(df[pos_cols].values, axis=1))
    if mean_radius > 10000:  # Likely in meters
        print("  Converting positions from m to km")
        for col in pos_cols:
            df[col] = df[col] / 1000
    
    # Check if velocities need scaling
    mean_vel = np.mean(np.linalg.norm(df[vel_cols].values, axis=1))
    if 7000 < mean_vel < 8000:  # Likely in m/s
        print("  Converting velocities from m/s to dm/s")
        for col in vel_cols:
            df[col] = df[col] * 10
    
    # Get data
    epochs = pd.to_datetime(df['time'])
    Z = df[cols].values
    
    # Subsample
    subsample = slice(0, len(df), 10)
    epochs = epochs[subsample]
    Z = Z[subsample]
    
    print(f"\nProcessing {len(epochs)} measurements")
    
    # Estimate noise
    r0_vec = estimate_measurement_noise_improved(Z)
    
    # Run filters
    ukf_a = run_tuned_filter(True, "Adaptive UKF", cols, Z, epochs, r0_vec)
    ukf_f = run_tuned_filter(False, "Fixed UKF", cols, Z, epochs, r0_vec)
    
    if ukf_a and ukf_f:
        # Create detailed comparison
        create_comparison_plots(ukf_a, ukf_f, cols)
        print("\n✅ Analysis complete!")
    
if __name__ == "__main__":
    main()

Applying unit corrections...
  Converting positions from m to km
  Converting velocities from m/s to dm/s

Processing 146880 measurements

=== Improved Noise Estimation ===
Position noise (m): [123.15019664 210.06687122 246.27691146]
Velocity noise (m/s): [0.14854551 0.23531059 0.26895924]

Measurement R (GPS-based):
  Position std (m): [20. 20. 30.]
  Velocity std (m/s): [0.2 0.2 0.3]

=== Running Adaptive UKF ===
Starting adaptive filter with inflated R to test adaptation


Adaptive UKF: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 146879/146879 [00:45<00:00, 3195.32it/s, R_pos=1588m, alt=470km]


Completed: 25919 updates

=== Running Fixed UKF ===


Fixed UKF: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 146879/146879 [00:42<00:00, 3438.81it/s, R_pos=1000m, alt=719km]
  ax3.legend(fontsize=8)


Completed: 25919 updates

=== Creating Comparison Plots ===
✅ aukf_detailed_comparison.png

✅ Analysis complete!
