# XGBoost Forecast Implementation Stage

In [2]:
import numpy as np
from matplotlib.dates import DayLocator, MonthLocator, DateFormatter
from matplotlib.ticker import AutoLocator
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import statsmodels.api as sm
from scipy.signal import savgol_filter
from statsmodels.tsa.stattools import adfuller
import warnings
import pickle
warnings.filterwarnings("ignore")
from xgboost import XGBRegressor
import seaborn as sns
from sklearn.metrics import mean_absolute_error
import optuna
from optuna.visualization import plot_param_importances

# Import Cleaned Data

In [13]:
# Load the cleaned orbital elements and manoeuvre logs.
# The data folder can be overridden with the SATELLITE_DATA_DIR environment
# variable so the notebook runs on other machines; the default is the
# original author's local path.
import os
from pathlib import Path

DATA_DIR = Path(
    os.environ.get(
        "SATELLITE_DATA_DIR",
        r"C:\Users\Suare\satellite_anomaly_project\data\cleaned",
    )
)
orbital_file = DATA_DIR / "all_satellite_orbitals.csv"
maneuver_file = DATA_DIR / "cleaned_maneuvers.csv"
df_orbital = pd.read_csv(orbital_file)
df_maneuver = pd.read_csv(maneuver_file)

In [15]:
# Keep only the satellites analysed in the report.
report_satellites = [
    "Fengyun_2E",
    "Fengyun_2F",
    "Fengyun_4A",
    "Jason_3",
    "Sentinel_3A",
    "Sentinel_3B",
    "CryoSat_2",
]
is_report_satellite = df_orbital['satellite_name'].isin(report_satellites)
df_orbital_filtered = df_orbital[is_report_satellite]

In [17]:
# --- Global parameters -------------------------------------------------------
# Multiplier applied to the mean-centred series in preprocess_series.
SCALING_FACTOR = 1e7
# Largest candidate lag (in daily rows) offered to the Optuna lag search.
MAX_LAG = 15
# Number of Optuna trials per satellite / variable / scenario combination.
N_TRIALS = 200
# Satellites present in the filtered data, in order of first appearance.
SATELLITES = df_orbital_filtered['satellite_name'].unique()
# Orbital elements modelled for every satellite.
VARIABLES = [
    "Brouwer mean motion",
    "eccentricity",
    "argument of perigee",
    "right ascension",
    "inclination",
    "mean anomaly",
]
# MAD sigma thresholds.
# NOTE(review): not referenced by the cells shown here; the SCENARIOS dict
# hard-codes sigma 10 and 20 on its own, so keep the two in sync.
SIGMAS = [10, 20]

# Plots

In [63]:
def plot_precision_recall_comparison_by_scenario_arima(
    pr_data_arima,
    satellite,
    variable,
    match_days_pair=(2, 3),
    figsize=(16, 7),
    scenarios=None,
    scenario_labels=None,
    colors=None,
    model_name="XGBoost",
):
    """
    Plot precision-recall curves per preprocessing scenario, one panel per matching window.

    The ``_arima`` suffix is historical: with the default ``model_name`` the figure and
    legend titles refer to XGBoost. Scenarios that have no matching entry in
    ``pr_data_arima`` are skipped. The default scenario list also names Scenarios 4
    and 5, which have no entry unless the results contain them.

    Parameters
    ----------
    pr_data_arima : list of dict
        Curve entries with keys "satellite", "variable", "scenario",
        "matching_max_days", "recall" and "precision".
    satellite, variable : str
        Select which entries are drawn.
    match_days_pair : sequence of int
        Matching windows in days; one subplot per value, left to right.
    figsize : tuple
        Matplotlib figure size.
    scenarios, scenario_labels, colors : list, optional
        Parallel lists overriding the default scenario keys, legend labels and
        line colours.
    model_name : str
        Model name used in the suptitle and legend title.
    """
    if scenarios is None:
        # Scenario keys exactly as stored in the results. The lone integer 4 is
        # kept because that scenario may have been stored as an int.
        scenarios = [
            "Scenario 1",
            "Scenario 2",
            "Scenario 3 (MAD σ=10)",
            "Scenario 3 (MAD σ=20)",
            4,
            "Scenario 5 (win=7)",
            "Scenario 5 (win=9)",
        ]
    if scenario_labels is None:
        scenario_labels = [
            "Base Model",     # Scenario 1
            "Forced Diff",    # Scenario 2
            "MAD σ=10",       # Scenario 3, 10-sigma MAD threshold
            "MAD σ=20",       # Scenario 3, 20-sigma MAD threshold
            "Log Transform",  # Scenario 4
            "Smooth w=7",     # Scenario 5, window 7
            "Smooth w=9",     # Scenario 5, window 9
        ]
    if colors is None:
        # NOTE(review): this palette contains red (#ed1c24) alongside aqua and
        # blues; check it with a colour-vision-deficiency simulator if
        # colourblind safety matters.
        colors = [
            '#0d0d0d',  # near-black  (Base Model)
            '#ff7f0e',  # orange      (Forced Diff)
            '#ed1c24',  # red         (MAD σ=10)
            '#ffa1a8',  # light pink  (MAD σ=20) - lighter shade of σ=10
            '#a2efe7',  # pale aqua   (Log Transform)
            '#4b51e9',  # blue        (Smooth w=7)
            '#929af4',  # light blue  (Smooth w=9) - lighter shade of w=7
        ]

    n_panels = len(match_days_pair)
    fig, axes = plt.subplots(1, n_panels, figsize=figsize, sharey=True, squeeze=False)

    # First line drawn for each label, from any panel. A scenario that only has
    # data for a later matching window still appears in the legend.
    legend_handles = {}

    for ax_idx, (ax, match_days) in enumerate(zip(axes[0], match_days_pair)):
        ax.set_title(f"Matching Window: {match_days} Days", fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel("Recall", fontsize=12, fontweight='bold')
        if ax_idx == 0:
            ax.set_ylabel("Precision", fontsize=12, fontweight='bold')

        # Reference lines and the "both metrics >= 0.5" region
        ax.axhline(0.5, color='lightgrey', linestyle='--', linewidth=1, alpha=0.7)
        ax.axvline(0.5, color='lightgrey', linestyle='--', linewidth=1, alpha=0.7)
        ax.fill_between([0.5, 1.0], 0.5, 1.0, color='lightblue', alpha=0.15)

        for scenario, scenario_label, color in zip(scenarios, scenario_labels, colors):
            entry = next(
                (
                    e for e in pr_data_arima
                    if e["satellite"] == satellite
                    and e["variable"] == variable
                    and e["scenario"] == scenario
                    and e["matching_max_days"] == match_days
                ),
                None,
            )
            if entry is None:
                continue

            # The second variant of a scenario pair (MAD σ=20, window 9) is
            # dashed and slightly more transparent so it stands apart from the first.
            scenario_str = str(scenario)
            second_variant = "MAD σ=20" in scenario_str or "win=9" in scenario_str
            (line,) = ax.plot(
                entry["recall"],
                entry["precision"],
                label=scenario_label,
                linewidth=2,
                color=color,
                linestyle='--' if second_variant else '-',
                alpha=0.8 if second_variant else 0.9,
            )
            legend_handles.setdefault(scenario_label, line)

        ax.grid(True, linestyle='--', alpha=0.3, linewidth=0.5)
        ax.set_xlim(-0.02, 1.02)
        ax.set_ylim(-0.02, 1.02)
        ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
        ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
        ax.tick_params(labelsize=10)

    legend = fig.legend(
        list(legend_handles.values()),
        list(legend_handles.keys()),
        title=f"{model_name} Preprocessing Scenarios",
        title_fontsize=12,
        loc='upper center',
        bbox_to_anchor=(0.5, 0.96),
        ncol=min(4, len(scenarios)),
        frameon=True,
        fancybox=True,
        shadow=True,
        fontsize=10,
        columnspacing=1.8,
        handletextpad=0.8,
    )
    legend.get_title().set_fontweight('bold')

    fig.suptitle(
        f"{model_name} Precision-Recall Curves: {satellite} – {variable}",
        fontsize=16,
        fontweight='bold',
        y=0.99,
    )

    fig.tight_layout(rect=[0, 0.02, 1, 0.89])
    plt.show()

# 1. Optimization of Lags and Hyperparameters (Optuna)

In [19]:
# Only show Optuna warnings/errors, not per-trial INFO logs.
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Preprocessing scenarios evaluated for every satellite/variable pair.
#   differencing : first-difference the daily series (used only when mad is None)
#   mad          : sigma threshold for MAD outlier removal, or None to disable
#   description  : short label stored alongside the results
SCENARIOS = dict([
    ("Scenario 1", dict(differencing=False, mad=None, description="Baseline")),
    ("Scenario 2", dict(differencing=True, mad=None, description="Differencing")),
    ("Scenario 3 (MAD σ=10)", dict(differencing=False, mad=10, description="MAD σ=10")),
    ("Scenario 3 (MAD σ=20)", dict(differencing=False, mad=20, description="MAD σ=20")),
])

# Regime breakpoints per (satellite, variable). MAD statistics are computed
# separately before and after the date, so one median/MAD over a level shift
# does not flag an entire regime as outliers.
MAD_BREAKPOINTS = {
    ("CryoSat_2", "Brouwer mean motion"): pd.Timestamp('2020-07-17'),
}


def _mad_outlier_mask(values, sigma_threshold):
    """Flag points more than ``sigma_threshold`` robust sigmas from the median.

    The robust sigma is 1.4826 * MAD. If it is NaN or zero, the sample standard
    deviation is used instead. Returns a boolean Series aligned to ``values``.
    """
    median_val = values.median()
    abs_dev = (values - median_val).abs()
    sigma = abs_dev.median() * 1.4826
    if pd.isna(sigma) or sigma == 0:
        sigma = values.std()
    return abs_dev > sigma_threshold * sigma


def apply_mad(series_mm, satellite, variable, sigma_threshold, breakpoints=None):
    """Replace MAD outliers in ``series_mm`` with NaN.

    :param series_mm: Series indexed by epoch timestamps (a datetime index is
        required when a breakpoint applies).
    :param satellite: satellite name, used to look up a regime breakpoint.
    :param variable: orbital element name, used to look up a regime breakpoint.
    :param sigma_threshold: number of robust sigmas beyond which a point is an outlier.
    :param breakpoints: optional mapping {(satellite, variable): date}. When the
        pair has an entry, statistics are computed separately before and on/after
        that date. Defaults to MAD_BREAKPOINTS.
    :return: (series_with_nan, outliers). series_with_nan is a copy with outliers
        set to NaN; outliers holds the original outlier values (index-sorted when
        a breakpoint applied).
    """
    if breakpoints is None:
        breakpoints = MAD_BREAKPOINTS
    breakpoint_date = breakpoints.get((satellite, variable))

    if breakpoint_date is None:
        regimes = [np.ones(len(series_mm), dtype=bool)]
    else:
        before = np.asarray(series_mm.index < pd.Timestamp(breakpoint_date))
        regimes = [before, ~before]

    # Positional mask: duplicate timestamps are handled row by row instead of
    # being matched by index label.
    outlier_mask = np.zeros(len(series_mm), dtype=bool)
    for regime in regimes:
        if regime.any():
            outlier_mask[regime] = _mad_outlier_mask(series_mm[regime], sigma_threshold).to_numpy()

    series_with_nan = series_mm.copy()
    series_with_nan.loc[outlier_mask] = np.nan
    outliers = series_mm[outlier_mask]
    if breakpoint_date is not None:
        outliers = outliers.sort_index()
    return series_with_nan, outliers

def preprocess_series(df, satellite, variable, scenario_key):
    """Build the daily, scaled series for one satellite/variable under a scenario.

    Without MAD (``SCENARIOS[scenario_key]["mad"] is None``) the raw series is
    resampled to daily means and back-filled. It is optionally first-differenced,
    then centred on its own mean and multiplied by SCALING_FACTOR. The returned
    frame keeps the original ``variable`` column next to ``value``.

    With MAD, the raw samples are centred on their mean and scaled first. Outliers
    are then replaced with NaN by ``apply_mad``, and the result is resampled to
    daily means and back-filled; it is not re-centred afterwards.

    :return: (frame with a 'value' column on a daily DatetimeIndex,
              outliers Series from apply_mad or None)
    """
    series_df = (
        df.loc[df['satellite_name'] == satellite, ['epoch', variable]]
        .assign(epoch=lambda d: pd.to_datetime(d['epoch']))
        .set_index('epoch')
    )

    settings = SCENARIOS[scenario_key]
    mad_sigma = settings["mad"]

    if mad_sigma is None:
        daily = series_df.resample('D').mean()
        daily['value'] = daily[variable].bfill()
        if settings["differencing"]:
            daily['value'] = daily['value'].diff()
            daily = daily.dropna(subset=['value'])
        centred = daily['value'] - daily['value'].mean()
        daily['value'] = centred * SCALING_FACTOR
        return daily, None

    scaled = (series_df - series_df.mean()) * SCALING_FACTOR
    cleaned_series, outliers = apply_mad(scaled[variable], satellite, variable, mad_sigma)
    daily = cleaned_series.to_frame('value').resample('D').mean()
    daily['value'] = daily['value'].bfill()
    return daily, outliers

def prepare_lagged_df(df, lags):
    """Return a copy of ``df`` with ``lag_<k>`` columns (``value`` shifted by k rows).

    Rows whose ``value`` is NaN are dropped after the lags are built. The first
    ``max(lags)`` remaining rows are then discarded, because their lag columns
    reach back before the start of the series.
    """
    lagged = df.copy()
    lagged = lagged.assign(**{f'lag_{k}': lagged['value'].shift(k) for k in lags})
    lagged = lagged.dropna(subset=['value'])
    warmup_rows = max(lags)
    return lagged.iloc[warmup_rows:]

def objective(trial, df, train_fraction=0.8):
    """Optuna objective: pick a lag subset and XGBoost hyperparameters, return validation MAE.

    Each lag 1..MAX_LAG is switched on or off through a categorical {0, 1}
    parameter ``lag_<i>``. Trials selecting fewer than two lags are pruned. The
    lagged frame is split chronologically (no shuffling): the first
    ``train_fraction`` of rows trains the model and the rest validates it.

    NOTE(review): early stopping monitors the same validation rows that produce
    the returned MAE, so the score is optimistically biased. The retraining step
    also fits the full ``n_estimators`` rather than the early-stopped count. The
    early-stopped count is stored in the trial user attribute "best_iteration"
    so it can be inspected or reused.

    :param trial: optuna trial
    :param df: frame with a 'value' column (output of preprocess_series)
    :param train_fraction: fraction of rows used for training (default 0.8)
    :return: mean absolute error on the validation split
    :raises optuna.exceptions.TrialPruned: when fewer than two lags are selected
    """
    lags = [i for i in range(1, MAX_LAG + 1) if trial.suggest_categorical(f'lag_{i}', [0, 1])]
    if len(lags) < 2:
        raise optuna.exceptions.TrialPruned()

    df_lagged = prepare_lagged_df(df, lags)  # prepare_lagged_df copies its input
    X = df_lagged[[f'lag_{lag}' for lag in lags]]
    y = df_lagged['value']
    split = int(len(X) * train_fraction)
    # Positional slicing: the index is a DatetimeIndex.
    X_train, X_val = X.iloc[:split], X.iloc[split:]
    y_train, y_val = y.iloc[:split], y.iloc[split:]

    model = XGBRegressor(
        max_depth=trial.suggest_int("max_depth", 2, 6),
        learning_rate=trial.suggest_float("learning_rate", 0.01, 0.2, log=True),
        gamma=trial.suggest_float("gamma", 0, 5),
        reg_alpha=trial.suggest_float("reg_alpha", 0, 3),
        reg_lambda=trial.suggest_float("reg_lambda", 0, 3),
        colsample_bytree=trial.suggest_float("colsample_bytree", 0.6, 1.0),
        n_estimators=trial.suggest_int("n_estimators", 50, 300),
        eval_metric='mae',
        early_stopping_rounds=10,
        random_state=42,
        verbosity=0
    )
    model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
    trial.set_user_attr("best_iteration", getattr(model, "best_iteration", None))
    y_pred = model.predict(X_val)
    return mean_absolute_error(y_val, y_pred)

# Main optimization loop: one Optuna study per satellite × variable × scenario.
# SATELLITES and VARIABLES come from the global-parameters cell (single source
# of truth). The TPE sampler and the fANOVA evaluator are seeded so a re-run
# reproduces the same search.
import itertools

OPTUNA_SEED = 42
results = []
combinations = list(itertools.product(SATELLITES, VARIABLES, SCENARIOS))
total_combinations = len(combinations)

for count, (sat, var, scen_key) in enumerate(combinations, start=1):
    print(f"[{count}/{total_combinations}] {sat} | {var} | {scen_key}")

    # Preprocess data for this combination
    df_prepared, outliers = preprocess_series(df_orbital_filtered, sat, var, scen_key)

    # Run Optuna optimization
    study = optuna.create_study(
        direction="minimize",
        sampler=optuna.samplers.TPESampler(seed=OPTUNA_SEED),
    )
    study.optimize(lambda trial: objective(trial, df_prepared), n_trials=N_TRIALS)

    # Extract results
    best_trial = study.best_trial
    lags_used = [i for i in range(1, MAX_LAG + 1) if best_trial.params.get(f'lag_{i}', 0) == 1]

    # Parameter importance (fANOVA). Treated as best-effort: failures are
    # reported but do not stop the loop; Ctrl-C still interrupts it.
    try:
        param_importance = optuna.importance.get_param_importances(
            study,
            evaluator=optuna.importance.FanovaImportanceEvaluator(seed=OPTUNA_SEED),
        )
    except Exception as exc:
        print(f"   parameter importance unavailable: {exc}")
        param_importance = {}

    results.append({
        "satellite": sat,
        "variable": var,
        "scenario": scen_key,
        "scenario_description": SCENARIOS[scen_key]["description"],
        "best_mae": best_trial.value,
        "lags": lags_used,
        "n_lags": len(lags_used),
        "params": best_trial.params,
        "importance": param_importance,
        "n_trials_completed": sum(
            t.state == optuna.trial.TrialState.COMPLETE for t in study.trials
        ),
        "study": study,
        "outliers": outliers
    })

df_xgb_optimized_results3 = pd.DataFrame(results)
print("\n All XGBoost optimizations complete.")

[1/168] CryoSat_2 | Brouwer mean motion | Scenario 1
[2/168] CryoSat_2 | Brouwer mean motion | Scenario 2
[3/168] CryoSat_2 | Brouwer mean motion | Scenario 3 (MAD σ=10)
[4/168] CryoSat_2 | Brouwer mean motion | Scenario 3 (MAD σ=20)
[5/168] CryoSat_2 | eccentricity | Scenario 1
[6/168] CryoSat_2 | eccentricity | Scenario 2
[7/168] CryoSat_2 | eccentricity | Scenario 3 (MAD σ=10)
[8/168] CryoSat_2 | eccentricity | Scenario 3 (MAD σ=20)
[9/168] CryoSat_2 | argument of perigee | Scenario 1
[10/168] CryoSat_2 | argument of perigee | Scenario 2
[11/168] CryoSat_2 | argument of perigee | Scenario 3 (MAD σ=10)
[12/168] CryoSat_2 | argument of perigee | Scenario 3 (MAD σ=20)
[13/168] CryoSat_2 | right ascension | Scenario 1
[14/168] CryoSat_2 | right ascension | Scenario 2
[15/168] CryoSat_2 | right ascension | Scenario 3 (MAD σ=10)
[16/168] CryoSat_2 | right ascension | Scenario 3 (MAD σ=20)
[17/168] CryoSat_2 | inclination | Scenario 1
[18/168] CryoSat_2 | inclination | Scenario 2
[19/168] 

In [21]:
# Persist the optimisation results (including the Optuna study objects).
optimization_results_path = 'xgboost_optimization_results3.pkl'
df_xgb_optimized_results3.to_pickle(optimization_results_path)

# Retrain Models on the Full Series and Compute Residuals for the Orbital Elements

In [27]:
# Retrain every optimised configuration on its full preprocessed series and
# compute residuals, which later serve as anomaly scores.
#
# NOTE(review): predict() is called on the same rows passed to fit(), so the
# residuals are in-sample fit errors, not out-of-sample forecast errors. A
# model that overfits will shrink residuals around real manoeuvres.
# NOTE(review): the Optuna objective used early stopping, but this model trains
# for the full n_estimators, so its tree count may differ from what was scored.
HYPERPARAM_NAMES = (
    "max_depth", "learning_rate", "gamma", "reg_alpha",
    "reg_lambda", "colsample_bytree", "n_estimators",
)
IMPORTANCE_TYPES = ("weight", "gain", "cover")

retraining_results = []
n_models = len(df_xgb_optimized_results3)
print(f"Starting retraining for {n_models} models...")

for idx, row in df_xgb_optimized_results3.iterrows():
    satellite = row["satellite"]
    variable = row["variable"]
    scenario = row["scenario"]
    lags = row["lags"]
    best_params = row["params"]

    print(f"[{idx+1}/{n_models}] Retraining: {satellite} | {variable} | {scenario}")

    # Preprocess the full dataset and build lag features
    df_prepared, outliers = preprocess_series(df_orbital_filtered, satellite, variable, scenario)
    df_lagged = prepare_lagged_df(df_prepared, lags)
    X_all = df_lagged[[f'lag_{lag}' for lag in lags]]
    y_all = df_lagged['value']

    # best_params also holds the lag_<i> switches; pass only model hyperparameters.
    model = XGBRegressor(
        **{name: best_params[name] for name in HYPERPARAM_NAMES},
        random_state=42,
        verbosity=0,
    )
    model.fit(X_all, y_all)

    predictions = model.predict(X_all)
    residuals = y_all - predictions
    mae = mean_absolute_error(y_all, predictions)

    # Feature importances: best-effort, failures are reported and stored as {}.
    booster = model.get_booster()
    importances = {}
    for importance_type in IMPORTANCE_TYPES:
        try:
            importances[importance_type] = booster.get_score(importance_type=importance_type)
        except Exception as exc:
            print(f"   could not compute {importance_type} importance: {exc}")
            importances[importance_type] = {}

    # The frame's index is the epoch DatetimeIndex. The ground-truth and
    # precision/recall cells read timestamps from result_df.index.
    result_df = pd.DataFrame(
        {
            "epoch": X_all.index,
            "predicted": predictions,
            "original": y_all.values,
            "residuals": residuals.values,
        },
        index=X_all.index,
    )

    retraining_results.append({
        "satellite": satellite,
        "variable": variable,
        "scenario": scenario,
        "lags": lags,
        "mae": mae,
        "result_df": result_df,
        "weight_importance": importances["weight"],
        "gain_importance": importances["gain"],
        "cover_importance": importances["cover"],
        "model": model,
        "outliers": outliers  # populated for the scenario 3 (MAD) variants
    })

print(f"\nRetraining complete! Processed {len(retraining_results)} models.")

Starting retraining for 168 models...
[1/168] Retraining: CryoSat_2 | Brouwer mean motion | Scenario 1
[2/168] Retraining: CryoSat_2 | Brouwer mean motion | Scenario 2
[3/168] Retraining: CryoSat_2 | Brouwer mean motion | Scenario 3 (MAD σ=10)
[4/168] Retraining: CryoSat_2 | Brouwer mean motion | Scenario 3 (MAD σ=20)
[5/168] Retraining: CryoSat_2 | eccentricity | Scenario 1
[6/168] Retraining: CryoSat_2 | eccentricity | Scenario 2
[7/168] Retraining: CryoSat_2 | eccentricity | Scenario 3 (MAD σ=10)
[8/168] Retraining: CryoSat_2 | eccentricity | Scenario 3 (MAD σ=20)
[9/168] Retraining: CryoSat_2 | argument of perigee | Scenario 1
[10/168] Retraining: CryoSat_2 | argument of perigee | Scenario 2
[11/168] Retraining: CryoSat_2 | argument of perigee | Scenario 3 (MAD σ=10)
[12/168] Retraining: CryoSat_2 | argument of perigee | Scenario 3 (MAD σ=20)
[13/168] Retraining: CryoSat_2 | right ascension | Scenario 1
[14/168] Retraining: CryoSat_2 | right ascension | Scenario 2
[15/168] Retraini

In [29]:
# Persist the retraining results (fitted models, residual frames, importances).
# NOTE: unpickling can execute code, so only reload files you created yourself.
retraining_pickle_path = 'xgboost_retraining_results3.pkl'
with open(retraining_pickle_path, mode='wb') as fh:
    pickle.dump(retraining_results, fh)

In [25]:
# Per-model summary of the retraining results (one row per satellite / variable /
# scenario). Built explicitly from retraining_results so the cell also works on
# a fresh kernel.
df_retrain_summary = pd.DataFrame(
    [
        {
            "satellite": r["satellite"],
            "variable": r["variable"],
            "scenario": r["scenario"],
            "lags": r["lags"],
            "n_lags": len(r["lags"]),
            "mae": r["mae"],
        }
        for r in retraining_results
    ]
)
df_retrain_summary

Unnamed: 0,satellite,variable,scenario,lags,n_lags,mae
0,CryoSat_2,Brouwer mean motion,Scenario 1,"[1, 2, 5, 7, 8]",5,0.639958
1,CryoSat_2,Brouwer mean motion,Scenario 2,"[1, 6, 7]",3,0.324364
2,CryoSat_2,Brouwer mean motion,Scenario 3 (MAD σ=10),"[1, 2, 6]",3,0.643459
3,CryoSat_2,Brouwer mean motion,Scenario 3 (MAD σ=20),"[1, 2, 4, 7, 11, 12]",6,0.349658
4,CryoSat_2,eccentricity,Scenario 1,"[1, 2, 3]",3,51.604325
...,...,...,...,...,...,...
163,Sentinel_3B,inclination,Scenario 3 (MAD σ=20),"[1, 5, 6, 10, 11, 13, 15]",7,21.485159
164,Sentinel_3B,mean anomaly,Scenario 1,"[1, 2, 3, 8, 9, 11, 12, 15]",8,79374.695763
165,Sentinel_3B,mean anomaly,Scenario 2,"[1, 3, 6, 7, 9, 10, 11, 14]",8,24088.416074
166,Sentinel_3B,mean anomaly,Scenario 3 (MAD σ=10),"[1, 2, 5, 6, 9, 10, 12, 13, 14, 15]",10,101054.428138


# Evaluation: Precision vs. Recall Curves

In [33]:
# Ground-truth manoeuvre start dates for each (satellite, variable) pair.
#
# Scenarios of the same pair keep different numbers of rows: their lag sets
# differ, and differencing drops the first day. Their residual date ranges
# therefore differ. The window used here is the union of those ranges rather
# than the range of a single scenario.
#
# Manoeuvre dates are parsed once on a new frame; df_maneuver is not modified.
maneuver_starts = df_maneuver.assign(
    start_date=pd.to_datetime(df_maneuver['start_date'], errors='coerce')
)

residual_spans = {}
for r in retraining_results:
    res_index = r['result_df'].index
    if len(res_index) == 0:
        continue
    key = (r['satellite'], r['variable'])
    span_lo, span_hi = res_index.min(), res_index.max()
    if key in residual_spans:
        prev_lo, prev_hi = residual_spans[key]
        span_lo, span_hi = min(span_lo, prev_lo), max(span_hi, prev_hi)
    residual_spans[key] = (span_lo, span_hi)

ground_truth_dict = {}
for (satellite, variable), (start_date, end_date) in residual_spans.items():
    starts = maneuver_starts.loc[maneuver_starts['OrbitalKeyName'] == satellite, 'start_date']
    in_span = starts[(starts >= start_date) & (starts <= end_date)]
    # Sorted: the matcher relies on searchsorted over these timestamps.
    ground_truth_dict[(satellite, variable)] = in_span.dropna().sort_values().tolist()

In [35]:
def convert_timestamp_series_to_epoch(series):
    """Convert datetimes to integer seconds since 1970-01-01 (floor division).

    :param series: Series of datetimes or a DatetimeIndex. It must be
        timezone-naive, because a naive epoch Timestamp is subtracted.
    :return: numpy array of integer seconds
    """
    return (
        (series - pd.Timestamp(year=1970, month=1, day=1)) // pd.Timedelta(seconds=1)
    ).values


def compute_simple_matching_precision_recall_for_one_threshold(
    matching_max_days,
    threshold,
    series_ground_truth_manoeuvre_timestamps,
    series_predictions,
):
    """
    Compute precision and recall at one anomaly threshold.

    Implements the matching framework proposed by Zhao:
    Zhao, L. (2021). Event prediction in the big data era: A systematic survey.
    ACM Computing Surveys (CSUR), 54(5), 1-37. https://doi.org/10.1145/3450287

    Every prediction with score >= ``threshold`` is matched to its closest
    ground-truth manoeuvre. On an exact distance tie, the earlier manoeuvre wins.
    The match counts only if it is strictly closer than ``matching_max_days``.
    - tp: number of distinct ground-truth manoeuvres matched by at least one prediction
    - fp: positive predictions without a match
    - fn: ground-truth manoeuvres without a matching prediction

    :param matching_max_days: matching window in days
    :param threshold: anomaly-score threshold; predictions >= threshold are positives
    :param series_ground_truth_manoeuvre_timestamps: Series of manoeuvre datetimes
        (any order; sorted internally). It may be empty, in which case every
        positive is a false positive.
    :param series_predictions: anomaly scores indexed by their timestamps
    :return: (precision, recall, tp, fp, fn)
    """
    matching_max_distance_seconds = pd.Timedelta(days=matching_max_days).total_seconds()

    predictions = series_predictions.to_numpy()
    positive = predictions >= threshold  # NaN scores are never positive
    n_positive = int(np.count_nonzero(positive))
    n_ground_truth = len(series_ground_truth_manoeuvre_timestamps)

    if n_ground_truth == 0:
        # Nothing to match against.
        tp, fp, fn = 0, n_positive, 0
    else:
        gt_seconds = np.sort(convert_timestamp_series_to_epoch(series_ground_truth_manoeuvre_timestamps))
        pred_seconds = convert_timestamp_series_to_epoch(series_predictions.index)[positive]

        # Candidate neighbours: the manoeuvre just before (or at) each prediction
        # and the next one; pick the strictly closer, preferring the earlier on ties.
        left = np.maximum(np.searchsorted(gt_seconds, pred_seconds) - 1, 0)
        right = np.minimum(left + 1, n_ground_truth - 1)
        dist_left = np.abs(gt_seconds[left] - pred_seconds)
        dist_right = np.abs(gt_seconds[right] - pred_seconds)
        use_right = (left < n_ground_truth - 1) & (dist_left > dist_right)
        closest = np.where(use_right, right, left)
        distance = np.where(use_right, dist_right, dist_left)
        matched = distance < matching_max_distance_seconds

        tp = int(np.unique(closest[matched]).size)
        fp = n_positive - int(np.count_nonzero(matched))
        fn = n_ground_truth - tp

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    return precision, recall, tp, fp, fn

In [37]:
def evaluate_all_thresholds_with_flags_xgboost(
    results,
    ground_truth_dict,
    min_precision_best_recall=0.6
):
    """
    Sweep anomaly thresholds over |residual| and collect precision/recall statistics.

    For each retraining result, the absolute residuals are the anomaly scores.
    Candidate thresholds are the unique scores strictly between the 50th and
    99th percentiles, evaluated from highest to lowest. Matching windows are
    1-4 days for "Brouwer mean motion" and 3 days otherwise. For each window the
    function records the full precision/recall curve and the thresholds that
    maximise F1, F2, and recall subject to precision >= min_precision_best_recall.

    Ground truth is restricted to the result's own residual time span, since
    manoeuvres outside it can never be matched. It is also sorted, because the
    matcher assumes sorted timestamps. Results whose span contains no manoeuvre
    are skipped with a message.

    Args:
        results: retraining results; each dict needs 'satellite', 'variable',
            'scenario' and 'result_df', whose index holds the residual timestamps
            and whose 'residuals' column holds the residuals.
        ground_truth_dict: {(satellite, variable): [manoeuvre timestamps]}.
        min_precision_best_recall: minimum precision for the best-recall strategy.

    Returns:
        (df_summary, pr_data). df_summary has one row per (result, matching
        window) whose best F1 and F2 are positive; pr_data holds the
        per-threshold curves for every evaluated window.
    """
    summary = []
    pr_data = []

    for r in results:
        satellite = r['satellite']
        variable = r['variable']
        scenario = r['scenario']

        if (satellite, variable) not in ground_truth_dict:
            print(f" No ground truth for satellite-variable: ({satellite}, {variable})")
            continue

        # Anomaly score = |residual|, indexed by the residual timestamps
        series_predictions = r['result_df']['residuals'].abs()

        ground_truth = pd.Series(pd.to_datetime(ground_truth_dict[(satellite, variable)]))
        span_start = series_predictions.index.min()
        span_end = series_predictions.index.max()
        filtered_ground_truth = (
            ground_truth[(ground_truth >= span_start) & (ground_truth <= span_end)]
            .sort_values()
            .reset_index(drop=True)
        )
        if filtered_ground_truth.empty:
            print(f" No ground-truth manoeuvres inside the residual span for: ({satellite}, {variable}, {scenario})")
            continue

        thresholds = series_predictions[
            (series_predictions > series_predictions.quantile(0.5)) &
            (series_predictions < series_predictions.quantile(0.99))
        ].unique()
        thresholds = sorted(thresholds, reverse=True)

        matching_days_list = [1, 2, 3, 4] if variable == "Brouwer mean motion" else [3]

        for match_days in matching_days_list:
            print(f"Processing: {satellite} - {variable} - {scenario} - MatchDays={match_days}")

            precision_list = []
            recall_list = []
            threshold_list = []
            tp_list = []
            fp_list = []
            fn_list = []

            best_f1 = 0
            best_f2 = 0
            best_recall_over_min_precision = 0

            best_row_f1 = {}
            best_row_f2 = {}
            best_row_recall_strategy = {}

            for threshold in thresholds:
                precision, recall, tp, fp, fn = compute_simple_matching_precision_recall_for_one_threshold(
                    matching_max_days=match_days,
                    threshold=threshold,
                    series_ground_truth_manoeuvre_timestamps=filtered_ground_truth,
                    series_predictions=series_predictions
                )

                f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
                # F2 = 5PR / (4P + R); the denominator is positive whenever P + R > 0
                f2 = (5 * precision * recall) / ((4 * precision) + recall) if (precision + recall) > 0 else 0

                precision_list.append(precision)
                recall_list.append(recall)
                threshold_list.append(threshold)
                tp_list.append(tp)
                fp_list.append(fp)
                fn_list.append(fn)

                if f1 > best_f1:
                    best_f1 = f1
                    best_row_f1 = {
                        "threshold_f1": threshold,
                        "recall_f1": recall,
                        "precision_f1": precision,
                        "f1_score": f1,
                        "tp_f1": tp,
                        "fp_f1": fp,
                        "fn_f1": fn
                    }

                if f2 > best_f2:
                    best_f2 = f2
                    best_row_f2 = {
                        "threshold_f2": threshold,
                        "recall_f2": recall,
                        "precision_f2": precision,
                        "f2_score": f2,
                        "tp_f2": tp,
                        "fp_f2": fp,
                        "fn_f2": fn
                    }

                if precision >= min_precision_best_recall and recall > best_recall_over_min_precision:
                    best_recall_over_min_precision = recall
                    best_row_recall_strategy = {
                        "threshold_best_recall_strategy": threshold,
                        "recall_best_recall_strategy": recall,
                        "precision_best_recall_strategy": precision,
                        "tp_best_recall_strategy": tp,
                        "fp_best_recall_strategy": fp,
                        "fn_best_recall_strategy": fn
                    }

            if best_row_f1 and best_row_f2:
                combined_row = {
                    "satellite": satellite,
                    "variable": variable,
                    "scenario": scenario,
                    "matching_max_days": match_days,
                    **best_row_f1,
                    **best_row_f2,
                    **(best_row_recall_strategy if best_row_recall_strategy else {
                        "threshold_best_recall_strategy": None,
                        "recall_best_recall_strategy": None,
                        "precision_best_recall_strategy": None,
                        "tp_best_recall_strategy": None,
                        "fp_best_recall_strategy": None,
                        "fn_best_recall_strategy": None
                    })
                }
                summary.append(combined_row)

            pr_data.append({
                "satellite": satellite,
                "variable": variable,
                "scenario": scenario,
                "matching_max_days": match_days,
                "thresholds": threshold_list,
                "precision": precision_list,
                "recall": recall_list,
                "tp": tp_list,
                "fp": fp_list,
                "fn": fn_list
            })

    df_summary = pd.DataFrame(summary)
    return df_summary, pr_data

In [39]:
# Evaluate every retrained model against the ground-truth manoeuvres, using the
# default minimum precision of 0.6 for the best-recall strategy.
pr_xgb_summary, pr_data_xgb = evaluate_all_thresholds_with_flags_xgboost(
    retraining_results,
    ground_truth_dict,
    min_precision_best_recall=0.6,
)

Processing: CryoSat_2 - Brouwer mean motion - Scenario 1 - MatchDays=1
Processing: CryoSat_2 - Brouwer mean motion - Scenario 1 - MatchDays=2
Processing: CryoSat_2 - Brouwer mean motion - Scenario 1 - MatchDays=3
Processing: CryoSat_2 - Brouwer mean motion - Scenario 1 - MatchDays=4
Processing: CryoSat_2 - Brouwer mean motion - Scenario 2 - MatchDays=1
Processing: CryoSat_2 - Brouwer mean motion - Scenario 2 - MatchDays=2
Processing: CryoSat_2 - Brouwer mean motion - Scenario 2 - MatchDays=3
Processing: CryoSat_2 - Brouwer mean motion - Scenario 2 - MatchDays=4
Processing: CryoSat_2 - Brouwer mean motion - Scenario 3 (MAD σ=10) - MatchDays=1
Processing: CryoSat_2 - Brouwer mean motion - Scenario 3 (MAD σ=10) - MatchDays=2
Processing: CryoSat_2 - Brouwer mean motion - Scenario 3 (MAD σ=10) - MatchDays=3
Processing: CryoSat_2 - Brouwer mean motion - Scenario 3 (MAD σ=10) - MatchDays=4
Processing: CryoSat_2 - Brouwer mean motion - Scenario 3 (MAD σ=20) - MatchDays=1
Processing: CryoSat_2 

In [41]:
# Persist the per-threshold precision/recall curves.
pr_data_path = 'xgb_pr_data3.pkl'
with open(pr_data_path, mode='wb') as fh:
    pickle.dump(pr_data_xgb, fh)

In [117]:
# Persist the best-threshold summary table as a pickle.
pr_summary_path = 'xgb_pr_summary3.pkl'
with open(pr_summary_path, mode='wb') as fh:
    pickle.dump(pr_xgb_summary, fh)

In [119]:
# Export the best-threshold summary table as CSV (without the row index).
pr_xgb_summary.to_csv(path_or_buf='xgb_metrics_pr3.csv', index=False)