In [2]:
# -*- coding: utf-8 -*-
"""
Batch Analysis of Multiple Sunspot Prediction Models - Complete Feature Repair + Comprehensive Fit Results
Function: Loads position and velocity data simultaneously, fixes feature matching issues, and generates comprehensive CSV results.
     (v7: Spectral plot changed to 1x3 subplots: Base SSN | Fitted Values | Residuals)
"""

import pandas as pd
import numpy as np
import joblib
import os
import matplotlib.pyplot as plt
from scipy import signal
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from pybaselines.whittaker import asls
import warnings
import logging
from datetime import datetime

# --- User Configuration ---
# Directory that holds the serialized model artifacts named in MODEL_FILES.
MODEL_DIR = '../../results/05_p_m_a_model/p_model_4'
# Model artifact file names. main_analysis() joins each name onto MODEL_DIR
# and opens it with joblib.load().
MODEL_FILES = [
    '12stars_Ridge_CV-R2_0.6880_OOT-SMOOTH-R2_0.4587_OOT-RAW-R2_0.3326_Params_alpha_0.1000.joblib',
    '15stars_Ridge_CV-R2_0.7291_OOT-SMOOTH-R2_0.4643_OOT-RAW-R2_0.3510_Params_alpha_0.0010.joblib',
    '21stars_Ridge_CV-R2_0.7281_OOT-SMOOTH-R2_0.5876_OOT-RAW-R2_0.4431_Params_alpha_0.1339.joblib',
    '25stars_Ridge_CV-R2_0.6885_OOT-SMOOTH-R2_0.6117_OOT-RAW-R2_0.4448_Params_alpha_0.3329.joblib'
]
# Display labels, paired positionally with MODEL_FILES (main_analysis() zips
# the two lists, so they must have the same length and order). The labels are
# used in log messages, CSV column names and plot legends.
MODEL_LABELS = ['M8+2', 'M8+3', 'M0+3', 'M0+2']
# Directory that holds the input data files referenced below.
DATA_DIR = '../../data/ready'

# Daily sunspot numbers. Expecting CSV with columns: ['Day', 'SSN']
# (load_sunspot_data() lower-cases the headers and also accepts 'date').
SSN_FILE = os.path.join(DATA_DIR, 'ssn_daily_1849_2025.csv')

# Monthly smoothed sunspot numbers. Expecting CSV with columns:
# ['Year', 'Month', 'SSN'] (previously the Chinese headers for year, month
# and sunspot number). Rows whose SSN equals -1 are discarded on load.
SIDC_MONTHLY_FILE = os.path.join(DATA_DIR, 'ssn_smoothed_monthly_1749_2025.csv')

# Expecting Parquet with columns: ['date', 'SSB_x', 'SSB_y', 'SSB_z'...]
PLANET_POSITION_FILE = os.path.join(DATA_DIR, '781_planets_dwarfs_asteroids_xyz.parquet')

# Expecting Parquet with columns: ['date', 'SSB_vx', 'SSB_vy', 'SSB_vz'...]
PLANET_VELOCITY_FILE = os.path.join(DATA_DIR, '781_planets_dwarfs_asteroids_velocity.parquet')

# Destination for the log file, CSV summaries and plots; created on demand
# by setup_logging().
OUTPUT_DIR = '../../results/05_p_m_a_model/p_model_4/residual'

# --- Configuration End ---

# --- Helper Functions ---
def setup_logging():
    """Prepare the output directory and configure file + console logging.

    A timestamped ``analysis_log_<YYYYmmdd_HHMMSS>.txt`` file is opened inside
    ``OUTPUT_DIR`` and registered, together with a console stream handler,
    through ``logging.basicConfig`` at INFO level.

    Returns:
        logging.Logger: The logger named after this module.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(OUTPUT_DIR, f'analysis_log_{timestamp}.txt')

    log_handlers = [
        logging.FileHandler(log_path, encoding='utf-8'),
        logging.StreamHandler(),
    ]
    # NOTE: basicConfig() leaves the root logger untouched when it already has
    # handlers, e.g. when this is called a second time in the same session.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=log_handlers,
    )
    return logging.getLogger(__name__)

def load_planet_data(logger):
    """Read the position and velocity tables and merge them on their dates.

    Each parquet file is indexed by its parsed ``date`` column (the column
    itself is dropped) and sorted chronologically. The two tables are then
    inner-joined, so only dates present in both files survive.

    Args:
        logger: Logger used for progress, warning and error messages.

    Returns:
        pandas.DataFrame: Date-indexed frame with the position columns
        followed by the velocity columns.

    Raises:
        FileNotFoundError: If either parquet file is missing (logged first).
        Exception: Any other loading/merging error is logged and re-raised.
    """
    logger.info("Loading planet position and velocity data...")

    def _read_date_indexed(parquet_path):
        # Index by parsed date, discard the now-redundant column, sort by time.
        frame = pd.read_parquet(parquet_path)
        frame = frame.set_index(pd.to_datetime(frame['date']))
        return frame.drop('date', axis=1).sort_index()

    try:
        positions = _read_date_indexed(PLANET_POSITION_FILE)
        velocities = _read_date_indexed(PLANET_VELOCITY_FILE)

        same_dates = positions.index.equals(velocities.index)
        if not same_dates:
            logger.warning("Position and velocity file date indices do not match perfectly; taking intersection.")

        merged = positions.join(velocities, how='inner').sort_index()
        logger.info(f"Data merge complete: Position={len(positions)}, Velocity={len(velocities)}, Combined={len(merged)}")
        return merged

    except FileNotFoundError as e:
        logger.error(f"Data file not found: {e}")
        raise
    except Exception as e:
        logger.error(f"Error loading planet data: {e}")
        raise

def load_sunspot_data(logger):
    """Read the daily sunspot CSV and return it as a gap-free daily Series.

    Column headers are lower-cased and stripped before matching. The date
    column is ``date`` or, failing that, ``day``. The value column is ``ssn``
    or ``sunspot``; if neither exists the second column is used (with a
    warning). The result is re-sampled to a strict daily frequency and any
    missing value is replaced with 0.

    Args:
        logger: Logger used for progress, warning and error messages.

    Returns:
        pandas.Series: Daily values named ``'SSN'`` with a datetime index.

    Raises:
        ValueError: If no date column or no usable value column is found.
        Exception: Any other error is logged and re-raised unchanged.
    """
    logger.info(f"Loading sunspot data: {SSN_FILE}")

    try:
        # Dates are parsed only after the columns have been identified.
        df = pd.read_csv(SSN_FILE)

        # Normalise headers so that e.g. 'Date' and ' SSN ' are recognised.
        df.columns = [header.lower().strip() for header in df.columns]

        # Date column: first match among the accepted names.
        date_col = next((name for name in ('date', 'day') if name in df.columns), None)
        if date_col is None:
            raise ValueError(f"Missing date column. Found columns: {list(df.columns)}")

        # Value column: accepted names first, then the positional fallback.
        val_col = next((name for name in ('ssn', 'sunspot') if name in df.columns), None)
        if val_col is None:
            if len(df.columns) < 2:
                raise ValueError(f"Missing SSN column. Found columns: {list(df.columns)}")
            val_col = df.columns[1]
            logger.warning(f"Column 'ssn' not found. Using 2nd column '{val_col}' as SSN values.")

        # Index by parsed date, force a daily grid, fill gaps with zero.
        df[date_col] = pd.to_datetime(df[date_col])
        daily_values = df.set_index(date_col)[val_col].asfreq('D')
        df_sunspot_raw = daily_values.fillna(0)

        # Downstream code expects the Series to be called 'SSN'.
        df_sunspot_raw.name = 'SSN'
        return df_sunspot_raw

    except Exception as e:
        logger.error(f"Error loading sunspot data: {e}")
        raise

def load_sidc_interpolated_data(logger, file_path, full_date_range):
    """Load monthly smoothed sunspot numbers and expand them to daily values.

    The CSV must provide ``Year``, ``Month`` and ``SSN`` columns. Every row is
    dated on the first day of its month, rows whose value equals -1 are
    removed, and the remaining points are linearly interpolated onto a daily
    grid. The result is finally re-indexed onto ``full_date_range``; dates
    outside the interpolated span become NaN.

    Args:
        logger: Logger used for progress and error messages.
        file_path: Path of the monthly CSV file.
        full_date_range: Index (dates) the returned Series must follow.

    Returns:
        pandas.Series | None: Series named ``'SIDC_SSN'``, or ``None`` when
        the file is missing, a required column is absent, or any other error
        occurs (the error is logged, not raised).
    """
    logger.info(f"Loading SIDC monthly smoothed data: {file_path}")
    try:
        table = pd.read_csv(file_path)

        # Build a first-of-month timestamp from the Year / Month columns.
        year_text = table['Year'].astype(int).astype(str)
        month_text = table['Month'].astype(int).astype(str)
        table['date'] = pd.to_datetime(year_text + '-' + month_text + '-01')

        table = table.set_index('date').rename(columns={'SSN': 'smoothed_number'})

        # -1 marks an invalid entry in the source file; drop those rows.
        usable = table[table['smoothed_number'] != -1]

        # Monthly points -> daily grid, filled by straight-line interpolation.
        daily_values = usable['smoothed_number'].resample('D').interpolate(method='linear')

        # Align with the caller's date range and give the Series its final name.
        daily_sidc = daily_values.reindex(full_date_range).rename('SIDC_SSN')

        logger.info("SIDC monthly smoothed data loaded and interpolated to daily.")
        return daily_sidc

    except FileNotFoundError:
        logger.error(f"SIDC monthly data file not found: {file_path}. Skipping this column.")
        return None
    except KeyError as e:
        logger.error(f"Column name error in SIDC file: {e}. Please ensure columns are ['Year', 'Month', 'SSN'].")
        return None
    except Exception as e:
        logger.error(f"Error loading SIDC monthly data: {e}. Skipping this column.")
        return None

def get_smoothed_sunspots(raw_sunspot_series, logger):
    """Fit a smooth trend to the raw sunspot series with ``asls``.

    The values are passed to ``pybaselines.whittaker.asls`` with a fixed
    smoothing strength (``lam=7e7``) and ``p=0.5``.
    # NOTE(review): p=0.5 presumably weights points above and below the fit
    # equally (plain smoothing rather than a baseline) - confirm in the
    # pybaselines documentation.

    Args:
        raw_sunspot_series: pandas Series of raw sunspot numbers.
        logger: Logger used for progress messages.

    Returns:
        pandas.Series: Smoothed values on the same index as the input.
    """
    logger.info("Calculating SSN smoothed trend line...")
    smoothing_strength = 7e7
    trend_values, _fit_params = asls(
        raw_sunspot_series.values,
        lam=smoothing_strength,
        p=0.5,
    )
    logger.info("SSN smoothed trend line calculation complete.")
    trend = pd.Series(trend_values, index=raw_sunspot_series.index)
    return trend

def calculate_psd_welch(data_series, fs=1.0, nperseg=365*22):
    """Estimate the power spectral density of a series with Welch's method.

    NaN entries are dropped and the mean is removed before the estimate. The
    zero-frequency bin is discarded and the remaining frequencies are turned
    into periods (``1 / frequency``).

    Args:
        data_series: pandas Series holding the signal.
        fs: Sampling frequency forwarded to ``scipy.signal.welch``.
        nperseg: Segment length forwarded to ``scipy.signal.welch``.

    Returns:
        tuple[numpy.ndarray, numpy.ndarray]: ``(periods, psd)`` sorted by
        ascending period. Both arrays are empty when the input holds no valid
        samples or no positive frequency is produced.
    """
    cleaned = data_series.dropna()
    if cleaned.empty:
        return np.array([]), np.array([])

    # Remove the mean so the DC component does not dominate the estimate.
    centered = cleaned.values - cleaned.mean()
    freqs, density = signal.welch(centered, fs=fs, nperseg=nperseg)

    positive = freqs > 0
    if not np.any(positive):
        return np.array([]), np.array([])

    periods = 1 / freqs[positive]
    density = density[positive]

    # Present the spectrum from the shortest to the longest period.
    order = np.argsort(periods)
    return periods[order], density[order]

def find_main_periods(periods, power, top_n=100, min_period=10, max_period=20000):
    """
    Select the periods that carry the most spectral power.

    Candidates are the spectrum bins whose period lies in
    ``[min_period, max_period]`` and whose power is finite. They are ranked by
    power (strongest first) and the first ``top_n`` periods are returned.
    Note that this ranks individual bins; it does not detect local maxima.

    Args:
        periods: Array-like of periods (lists and integer arrays accepted).
        power: Array-like of power values, same length as ``periods``.
        top_n: Number of periods to return.
        min_period: Smallest period considered (inclusive).
        max_period: Largest period considered (inclusive).

    Returns:
        numpy.ndarray: Float array of length ``top_n``, ordered by descending
        power and padded with NaN when fewer candidates exist. Empty when
        ``top_n`` is not positive.
    """
    # Work on float arrays: NaN padding is impossible on integer arrays and
    # plain lists do not support the vectorised comparisons below.
    periods = np.asarray(periods, dtype=float)
    power = np.asarray(power, dtype=float)

    # ``x[-0:]`` is the whole array, so a zero request must be handled
    # explicitly instead of falling through to the slice below.
    if top_n <= 0:
        return np.array([], dtype=float)

    # NaN power sorts last in argsort and would otherwise be reported as the
    # strongest component, so non-finite bins are excluded up front.
    mask = (periods >= min_period) & (periods <= max_period) & np.isfinite(power)
    periods_valid = periods[mask]
    power_valid = power[mask]

    if periods_valid.size == 0:
        return np.full(top_n, np.nan)

    # --- Sort by power (significance), strongest first ---
    idx = np.argsort(power_valid)[-top_n:][::-1]
    main_periods = periods_valid[idx]

    shortfall = top_n - main_periods.size
    if shortfall > 0:
        main_periods = np.pad(main_periods, (0, shortfall), constant_values=np.nan)

    return main_periods

def enhanced_spectral_analysis(residuals_series, label, logger):
    """
    Run the Welch spectrum on a residual series and summarise it.

    Args:
        residuals_series: pandas Series of model residuals.
        label: Model label used in log messages and in the summary.
        logger: Logger used for progress messages.

    Returns:
        tuple: ``(periods, psd, stats)``. ``periods`` and ``psd`` come from
        ``calculate_psd_welch`` so callers can reuse them; ``stats`` is a dict
        with the model label, the residual variance, the period at the PSD
        maximum, the (up to) ten strongest periods and their mean / standard
        deviation.
    """
    logger.info(f"Performing spectral analysis on Residuals for {label}...")
    periods, psd = calculate_psd_welch(residuals_series, nperseg=365*22)

    # find_main_periods pads with NaN; keep only the real entries.
    top_candidates = find_main_periods(periods, psd, top_n=10)
    valid_periods = top_candidates[~np.isnan(top_candidates)]

    has_spectrum = len(periods) > 0
    has_valid = len(valid_periods) > 0

    stats = {
        'model': label,
        'total_variance': residuals_series.var(),
        # NOTE: despite the key name this stores a period (taken from
        # ``periods``), not a frequency.
        'spectral_peak_frequency': periods[np.argmax(psd)] if has_spectrum else np.nan,
        'top_periods': valid_periods,
        'mean_period': np.mean(valid_periods) if has_valid else np.nan,
        'period_std': np.std(valid_periods) if has_valid else np.nan,
    }

    logger.info(f"{label} Residual spectral analysis complete: {len(valid_periods)} valid periods found.")
    return periods, psd, stats

def calculate_model_metrics(all_residuals_ts, df_sunspot_raw, logger):
    """Compute error statistics for every model's residual series.

    Each residual series is paired with the observed values on their common
    index entries (rows with a missing value on either side are dropped).
    Models without any paired row are skipped with a warning.

    Args:
        all_residuals_ts: Mapping ``label -> residual Series``.
        df_sunspot_raw: Series of observed sunspot numbers.
        logger: Logger used for progress and warning messages.

    Returns:
        pandas.DataFrame: One row per evaluated model with the columns
        ``Model, RMSE, MAE, MSE, Mean_Residual, Std_Residual, Max_Residual,
        R2_vs_actual, Data_Points``. ``R2_vs_actual`` is floored at 0.
    """
    logger.info("Calculating model performance metrics...")
    rows = []

    for label, residuals in all_residuals_ts.items():
        paired = pd.DataFrame({'actual': df_sunspot_raw, 'residual': residuals}).dropna()

        if paired.empty:
            logger.warning(f"Model {label} has no valid aligned data, skipping metrics.")
            continue

        actual = paired['actual']
        residual = paired['residual']

        squared_error = residual**2
        mse = np.mean(squared_error)
        rmse = np.sqrt(mse)
        mae = np.mean(np.abs(residual))

        # Coefficient of determination against the observed values; negative
        # results are reported as 0.
        ss_residual = np.sum(squared_error)
        ss_total = np.sum((actual - actual.mean())**2)
        r2_floored = max(0, 1 - (ss_residual / ss_total))

        row = {
            'Model': label,
            'RMSE': rmse,
            'MAE': mae,
            'MSE': mse,
            'Mean_Residual': np.mean(residual),
            'Std_Residual': np.std(residual),
            'Max_Residual': np.max(np.abs(residual)),
            'R2_vs_actual': r2_floored,
            'Data_Points': len(paired),
        }
        rows.append(row)

        logger.info(f"{label}: RMSE={rmse:.3f}, MAE={mae:.3f}, R²={row['R2_vs_actual']:.3f}")

    return pd.DataFrame(rows)

def extract_model_info(saved_data, label, logger):
    """Pull the pipeline, feature list and metadata out of a saved model dict.

    Args:
        saved_data: Mapping loaded from the model file. ``'model_pipeline'``
            and ``'features'`` are required; ``'cv_r2_score'``,
            ``'oot_r2_score'``, ``'best_params'``, ``'star_count'`` and
            ``'dimension_mode'`` are optional.
        label: Model label used in log messages and in the metadata.
        logger: Logger used for info and error messages.

    Returns:
        tuple: ``(pipeline, features, model_info)``. When a required entry is
        missing an error is logged and ``(None, None, None)`` is returned.
        Optional entries that are absent appear as ``None`` in ``model_info``.
    """
    def _format_score(score):
        # Optional scores may be absent (None) or non-numeric; the summary
        # log line must not abort the whole model because of that.
        try:
            return f"{score:.4f}"
        except (TypeError, ValueError):
            return 'N/A'

    pipeline = saved_data.get('model_pipeline')
    features = saved_data.get('features')

    if pipeline is None:
        logger.error(f"Model {label} missing 'model_pipeline'")
        return None, None, None

    if features is None:
        logger.error(f"Model {label} missing 'features'")
        return None, None, None

    model_info = {
        'Model': label,
        'cv_r2_score': saved_data.get('cv_r2_score'),
        'oot_r2_score': saved_data.get('oot_r2_score'),
        'best_params': saved_data.get('best_params'),
        'star_count': saved_data.get('star_count'),
        'dimension_mode': saved_data.get('dimension_mode')
    }

    logger.info(f"{label} - Feature count: {len(features)}, Stars: {model_info['star_count']}, Mode: {model_info['dimension_mode']}")
    # The keys always exist in model_info, so a dict default can never apply;
    # format defensively instead.
    cv_text = _format_score(model_info['cv_r2_score'])
    oot_text = _format_score(model_info['oot_r2_score'])
    logger.info(f"{label} - CV R²: {cv_text}, OOT R²: {oot_text}")

    return pipeline, features, model_info

# --- Main Analysis Flow ---
def main_analysis():
    """Run the complete batch analysis for every configured model.

    Steps, in order:
      1. Load the daily sunspot series, the merged planet position/velocity
         table, the smoothed sunspot trend and the interpolated SIDC series.
      2. For each (MODEL_FILES, MODEL_LABELS) pair: load the saved model,
         predict over the whole planet date range, compute residuals
         (actual - predicted) and Welch spectra of residuals and fits.
         A failure in one model is logged and the loop moves on.
      3. Write the outputs into OUTPUT_DIR:
         - Summary_Fit_Results.csv        (base series, fits and residuals)
         - Model_Comprehensive_Stats.csv  (model metadata + error metrics)
         - Spectral_Stats.csv             (residual spectrum summary)
         - Comparison_Spectrum_1x3.png    (base | fits | residuals spectra)
         - Comparison_Top100_Periods.csv  (top residual periods per model)
         - Comparison_ACF_PACF.png        (residual ACF / PACF per model)

    Returns:
        None. Returns early (after logging an error) when no model could be
        processed.

    Raises:
        Exception: Any error outside the per-model loop and the locally
        guarded sections is logged with its traceback and re-raised.
    """
    logger = setup_logging()
    logger.info(f"Starting batch analysis of {len(MODEL_FILES)} models...")
    
    try:
        # Load base data
        df_sunspot_raw = load_sunspot_data(logger)
        df_planet_all = load_planet_data(logger)
        
        # Define global date index (the planet table's dates drive the
        # prediction range and the rows of the summary CSV)
        full_date_range = df_planet_all.index
        
        # Calculate smoothed trend line
        df_sunspot_smooth = get_smoothed_sunspots(df_sunspot_raw, logger)

        # Load SIDC interpolated data (None when the file could not be used)
        df_sidc_daily = load_sidc_interpolated_data(logger, SIDC_MONTHLY_FILE, full_date_range)

        # Prepare result storage
        all_results_for_csv = {} # Stores fitted values and residuals for final CSV
        all_residuals_ts = {} # Stores valid residuals for statistics and spectrum
        all_spectral_stats = []
        all_model_info_list = [] # Stores model info dictionaries
        all_psd_results = {} # Stores Residual PSD results
        all_psd_fits_results = {} # (v7) Stores Fitted Value PSD results

        # Loop through each model. zip() stops at the shorter list, so a
        # length mismatch between MODEL_FILES and MODEL_LABELS silently drops
        # the surplus entries.
        for model_file, label in zip(MODEL_FILES, MODEL_LABELS):
            logger.info(f"--- Processing Model: {label} ({model_file}) ---")
            
            try:
                # Load model
                model_path = os.path.join(MODEL_DIR, model_file)
                logger.info(f"Loading model file: {model_path}")
                
                if not os.path.exists(model_path):
                    logger.error(f"Model file does not exist: {model_path}")
                    continue
                    
                # NOTE(review): joblib files are pickle-based; only load
                # artifacts from a trusted source.
                saved_data = joblib.load(model_path)
                pipeline, features, model_info = extract_model_info(saved_data, label, logger)
                
                # extract_model_info returns (None, None, None) when a
                # required entry is missing; it has already logged the reason.
                if pipeline is None:
                    continue
                    
                all_model_info_list.append(model_info)

                # Check features: every feature the model was trained on must
                # exist as a column of the planet table.
                missing_features = set(features) - set(df_planet_all.columns)
                if missing_features:
                    logger.error(f"{label} is missing {len(missing_features)} features. Model requires all original features.")
                    logger.error(f"Example missing features: {list(missing_features)[:5]}")
                    logger.error(f"Skipping model: {model_file}")
                    continue 
                
                logger.info(f"{label} using all {len(features)} features for prediction.")
                # Select the columns in the order stored with the model.
                X_full = df_planet_all[features].copy()
                
                # Predict
                logger.info("Performing full prediction...")
                ssn_pred = pipeline.predict(X_full)
                ssn_pred_series = pd.Series(ssn_pred, index=X_full.index, name=f'Fit_SSN_{label}')
                
                logger.info(f"Calculating residuals and storing results for {label}...")
                
                # Calculate Residuals. Building the frame from two Series
                # aligns them on the union of their indices, so the residual
                # is NaN wherever either the prediction or the observation is
                # missing.
                df_temp_calc = pd.DataFrame({'pred': ssn_pred_series, 'actual': df_sunspot_raw})
                residual_series = df_temp_calc['actual'] - df_temp_calc['pred']
                residual_series.name = f'Residual_{label}'

                # Store for comprehensive CSV
                all_results_for_csv[label] = pd.concat([ssn_pred_series, residual_series], axis=1)
                
                # Extract valid residuals on 'df_sunspot_raw' original index
                valid_residuals = residual_series.reindex(df_sunspot_raw.index).dropna()
                all_residuals_ts[label] = valid_residuals

                # Spectral Analysis (Residuals)
                periods, psd, spectral_stats = enhanced_spectral_analysis(valid_residuals, label, logger)
                all_spectral_stats.append(spectral_stats)
                all_psd_results[label] = (periods, psd) 

                # (v7) Calculate Fit Spectrum. Guarded separately so that a
                # failure here does not discard the residual results stored
                # above. (calculate_psd_welch drops NaN itself; the explicit
                # dropna() is redundant but harmless.)
                try:
                    logger.info(f"Calculating spectrum for Fitted Values of {label}...")
                    periods_fit, psd_fit = calculate_psd_welch(ssn_pred_series.dropna(), nperseg=365*22)
                    if len(periods_fit) > 0:
                        all_psd_fits_results[label] = (periods_fit, psd_fit)
                except Exception as e:
                    logger.error(f"Error calculating Fit spectrum for {label}: {e}")
                
                # Cleanup: release the per-model intermediates before the
                # next model is loaded.
                del X_full, ssn_pred, ssn_pred_series, residual_series, df_temp_calc, valid_residuals

            except Exception as e:
                logger.error(f"Error processing model {model_file}: {e}", exc_info=True)
                continue

        # Nothing to report if every model was skipped or failed.
        if not all_residuals_ts:
            logger.error("No models successfully processed. Analysis aborted.")
            return

        # --- Generate Comprehensive Results CSV ---
        
        logger.info("Generating comprehensive fitted results CSV...")
        df_complete_results = pd.DataFrame(index=full_date_range)

        # Base Data (column assignment aligns each Series on the frame's
        # date index; dates without a value become NaN)
        df_complete_results['Raw_SSN'] = df_sunspot_raw
        df_complete_results['Smoothed_SSN'] = df_sunspot_smooth
        if df_sidc_daily is not None:
            df_complete_results['SIDC_SSN'] = df_sidc_daily

        # Add fits and residuals for all models, in the configured label
        # order; models that were skipped simply have no columns.
        for label in MODEL_LABELS:
            if label in all_results_for_csv:
                df_complete_results = df_complete_results.join(all_results_for_csv[label])

        # Save complete results
        csv_complete_filename = os.path.join(OUTPUT_DIR, 'Summary_Fit_Results.csv')
        df_complete_results.to_csv(csv_complete_filename, index=True, index_label='Date', encoding='utf-8-sig')
        logger.info(f"Saved complete results: {csv_complete_filename}")
        logger.info(f"File contains {len(df_complete_results.columns)} columns: {list(df_complete_results.columns)}")

        # --- Merge Model Info and Metrics ---
        # Best effort: a failure here is logged but does not stop the
        # remaining outputs.
        logger.info("Merging model info and metrics...")
        try:
            model_info_df = pd.DataFrame(all_model_info_list)
            metrics_df = calculate_model_metrics(all_residuals_ts, df_sunspot_raw, logger)
            
            # Outer merge keeps models that have metadata but no metrics
            # (or the reverse); fall back to whichever table is non-empty.
            if not model_info_df.empty and not metrics_df.empty:
                combined_stats_df = pd.merge(model_info_df, metrics_df, on='Model', how='outer')
            elif not model_info_df.empty:
                combined_stats_df = model_info_df
            elif not metrics_df.empty:
                combined_stats_df = metrics_df
            else:
                combined_stats_df = pd.DataFrame() 
                
            if not combined_stats_df.empty:
                stats_filename = os.path.join(OUTPUT_DIR, 'Model_Comprehensive_Stats.csv')
                combined_stats_df.to_csv(stats_filename, index=False, encoding='utf-8-sig')
                logger.info(f"Saved combined stats: {stats_filename}")
            else:
                logger.warning("Model info and metrics are empty. No stats file generated.")

        except Exception as e:
            logger.error(f"Error merging stats: {e}", exc_info=True)
        
        # Save Spectral Stats (Residual based)
        spectral_stats_df = pd.DataFrame(all_spectral_stats)
        spectral_stats_filename = os.path.join(OUTPUT_DIR, 'Spectral_Stats.csv')
        spectral_stats_df.to_csv(spectral_stats_filename, index=False, encoding='utf-8-sig')
        logger.info(f"Saved spectral stats: {spectral_stats_filename}")

        # --- Spectral Analysis Visualization (v7) ---
        logger.info("Plotting 1x3 Spectral Analysis (Welch)...")
        
        # 1. Calculate Base SSN Spectrum (calculate_psd_welch defaults are
        #    used here, i.e. the same nperseg=365*22 as for the models).
        #    An error leaves psd_results_base partially filled and the plot
        #    is still produced with whatever was computed.
        logger.info("Calculating spectrum for Raw, Smoothed, and SIDC SSN...")
        psd_results_base = {}
        try:
            periods_raw, psd_raw = calculate_psd_welch(df_sunspot_raw)
            if len(periods_raw) > 0:
                psd_results_base['Raw_SSN'] = (periods_raw, psd_raw)
                
            periods_smooth, psd_smooth = calculate_psd_welch(df_sunspot_smooth)
            if len(periods_smooth) > 0:
                psd_results_base['Smoothed_SSN'] = (periods_smooth, psd_smooth)
            
            if df_sidc_daily is not None:
                periods_sidc, psd_sidc = calculate_psd_welch(df_sidc_daily)
                if len(periods_sidc) > 0:
                    psd_results_base['SIDC_SSN'] = (periods_sidc, psd_sidc)
                    
        except Exception as e:
            logger.error(f"Error calculating base SSN spectrum: {e}", exc_info=True)
        
        # 2. Create 1x3 Subplots (sharey=True: the three panels share one
        #    y-axis so power levels are directly comparable)
        fig_psd, axes = plt.subplots(1, 3, figsize=(24, 8), sharey=True)
        all_periods_dict = {} 
        
        # Styles: (color, linestyle, linewidth) per base series
        base_styles = {
            'Raw_SSN': ('black', '--', 2.0),
            'Smoothed_SSN': ('gray', ':', 2.0),
            'SIDC_SSN': ('orange', '-.', 2.0)
        }
        # Per-model colors; reused cyclically if there are more models.
        colors = ['blue', 'red', 'green', 'purple'] 
        
        # --- 3. Left Plot (Base SSN) ---
        ax_left = axes[0]
        for label, (periods, psd) in psd_results_base.items():
            # Every key written to psd_results_base above exists in
            # base_styles, so .get() does not return None here.
            style = base_styles.get(label)
            ax_left.loglog(periods, psd, label=label.replace('_', ' '), color=style[0], linestyle=style[1], linewidth=style[2], alpha=0.9)
        
        ax_left.set_title('Base SSN Spectrum')
        ax_left.set_xlabel('Period (Days)', fontsize=12)
        ax_left.set_ylabel('Power Spectral Density', fontsize=12)
        ax_left.legend(fontsize=10)
        ax_left.grid(True, which="both", ls="--", alpha=0.7)

        # --- 4. Middle Plot (Fits) ---
        ax_mid = axes[1]
        for i, (label, (periods, psd)) in enumerate(all_psd_fits_results.items()):
            color = colors[i % len(colors)]
            ax_mid.loglog(periods, psd, label=f'Fit - {label}', alpha=0.8, color=color, linewidth=1.5)

        ax_mid.set_title('Model Fits Spectrum')
        ax_mid.set_xlabel('Period (Days)', fontsize=12)
        ax_mid.legend(fontsize=10)
        ax_mid.grid(True, which="both", ls="--", alpha=0.7)

        # --- 5. Right Plot (Residuals) ---
        ax_right = axes[2]
        for i, (label, (periods, psd)) in enumerate(all_psd_results.items()):
            color = colors[i % len(colors)]
            ax_right.loglog(periods, psd, label=f'Resid - {label}', alpha=0.8, color=color, linewidth=1.5)
            
            # Find Top 100 periods for residuals (ranked by power, restricted
            # to find_main_periods' default minimum of 10 days up to 20000)
            main_periods = find_main_periods(periods, psd, top_n=100, max_period=20000)
            all_periods_dict[f'Period_{label}'] = main_periods

        ax_right.set_title('Model Residuals Spectrum')
        ax_right.set_xlabel('Period (Days)', fontsize=12)
        ax_right.legend(fontsize=10)
        ax_right.grid(True, which="both", ls="--", alpha=0.7)
        
        # --- 6. Save Plot ---
        fig_psd.suptitle(f'Spectral Comparison 1x3 (Welch Method)', fontsize=16)
        # rect reserves room at the top of the figure for the suptitle.
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        
        plot_filename = os.path.join(OUTPUT_DIR, 'Comparison_Spectrum_1x3.png')
        plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
        plt.close(fig_psd)
        logger.info(f"Saved 1x3 Spectrum Plot: {plot_filename}")

        # Save Period Data (Residuals only): one 'Period_<label>' column per
        # model, rows ordered from strongest to weakest component.
        all_periods_df = pd.DataFrame({k: pd.Series(v) for k, v in all_periods_dict.items()})
        csv_periods_filename = os.path.join(OUTPUT_DIR, 'Comparison_Top100_Periods.csv')
        all_periods_df.to_csv(csv_periods_filename, index=False, encoding='utf-8-sig')
        logger.info(f"Saved Periods CSV: {csv_periods_filename}")

        # ACF/PACF Analysis
        logger.info("Plotting ACF/PACF...")
        num_models = len(all_residuals_ts)
        # The early return above guarantees at least one model at this
        # point, so the zero branch is only a defensive fallback.
        if num_models == 0:
            logger.warning("No models available for ACF/PACF analysis.")
        else:
            # squeeze=False keeps `axes` two-dimensional even for one model,
            # so axes[i, 0] / axes[i, 1] indexing always works.
            fig_acf, axes = plt.subplots(num_models, 2, figsize=(16, 5 * num_models), squeeze=False) 
            fig_acf.suptitle('ACF and PACF Analysis of Model Residuals', fontsize=20, y=1.02)
            
            # With one model the figsize above is already (16, 5); this
            # resize therefore leaves the figure unchanged.
            if num_models == 1:
                fig_acf.set_size_inches(16, 5)
                
            # One row per model: ACF on the left, PACF on the right.
            for i, (label, residuals) in enumerate(all_residuals_ts.items()):
                plot_acf(residuals, lags=60, ax=axes[i, 0], title=f'ACF - {label}')
                axes[i, 0].set_xlabel('Lags')
                plot_pacf(residuals, lags=60, ax=axes[i, 1], title=f'PACF - {label}', method='ywm')
                axes[i, 1].set_xlabel('Lags')

            plt.tight_layout()
            acf_plot_filename = os.path.join(OUTPUT_DIR, 'Comparison_ACF_PACF.png')
            plt.savefig(acf_plot_filename, dpi=300, bbox_inches='tight')
            plt.close(fig_acf) 
            logger.info(f"Saved ACF/PACF Plot: {acf_plot_filename}")
        
        logger.info("Batch analysis complete!")
        
    except Exception as e:
        # Top-level boundary: record the traceback in the log file, then let
        # the error propagate to the caller.
        logger.error(f"Error during analysis: {e}", exc_info=True)
        raise

if __name__ == "__main__":
    # Silence the noisy warning categories for the duration of the run.
    for noisy_category in (RuntimeWarning, UserWarning, FutureWarning):
        warnings.filterwarnings('ignore', category=noisy_category)

    # Keep matplotlib's own logger from flooding the analysis log.
    matplotlib_logger = logging.getLogger('matplotlib')
    matplotlib_logger.setLevel(logging.WARNING)

    main_analysis()

2025-12-29 17:02:05,024 - INFO - Starting batch analysis of 4 models...
2025-12-29 17:02:05,025 - INFO - Loading sunspot data: ../../data/ready\ssn_daily_1849_2025.csv
2025-12-29 17:02:05,057 - INFO - Loading planet position and velocity data...
2025-12-29 17:02:09,601 - INFO - Data merge complete: Position=73780, Velocity=73780, Combined=73780
2025-12-29 17:02:09,653 - INFO - Calculating SSN smoothed trend line...
2025-12-29 17:02:09,730 - INFO - SSN smoothed trend line calculation complete.
2025-12-29 17:02:09,731 - INFO - Loading SIDC monthly smoothed data: ../../data/ready\ssn_smoothed_monthly_1749_2025.csv
2025-12-29 17:02:09,771 - INFO - SIDC monthly smoothed data loaded and interpolated to daily.
2025-12-29 17:02:09,772 - INFO - --- Processing Model: M8+2 (12stars_Ridge_CV-R2_0.6880_OOT-SMOOTH-R2_0.4587_OOT-RAW-R2_0.3326_Params_alpha_0.1000.joblib) ---
2025-12-29 17:02:09,773 - INFO - Loading model file: ../../results/05_p_m_a_model/p_model_4\12stars_Ridge_CV-R2_0.6880_OOT-SMOOT