In [1]:
!pip install --no-index --find-links=/kaggle/input/ariel-2024-pqdm pqdm
!pip install catboost

Looking in links: /kaggle/input/ariel-2024-pqdm
Processing /kaggle/input/ariel-2024-pqdm/pqdm-0.2.0-py2.py3-none-any.whl
Processing /kaggle/input/ariel-2024-pqdm/bounded_pool_executor-0.0.3-py3-none-any.whl (from pqdm)
Installing collected packages: bounded-pool-executor, pqdm
Successfully installed bounded-pool-executor-0.0.3 pqdm-0.2.0


In [2]:
import os
import time
import numpy as np
import pandas as pd
from tqdm import tqdm
from pqdm.threads import pqdm
from astropy.stats import sigma_clip
from scipy.signal import savgol_filter
from scipy.optimize import minimize
from catboost import CatBoostRegressor, Pool
import pickle
import joblib

ROOT_PATH = "/kaggle/input/ariel-data-challenge-2025"
MODE = "test"
__t0 = time.perf_counter()

# =========================================================
# CONFIGURATION
# =========================================================

class Config:
    DATA_PATH = '/kaggle/input/ariel-data-challenge-2025'
    DATASET = "test"
    SCALE = 0.946
    SIGMA = 0.00056
    CUT_INF = 39
    CUT_SUP = 321
    
    SENSOR_CONFIG = {
        "AIRS-CH0": {
            "raw_shape": [11250, 32, 356],
            "calibrated_shape": [1, 32, CUT_SUP - CUT_INF],
            "linear_corr_shape": (6, 32, 356),
            "dt_pattern": (0.1, 4.5), 
            "binning": 30
        },
        "FGS1": {
            "raw_shape": [135000, 32, 32],
            "calibrated_shape": [1, 32, 32],
            "linear_corr_shape": (6, 32, 32),
            "dt_pattern": (0.1, 0.1),
            "binning": 30 * 12
        }
    }
    
    MODEL_PHASE_DETECTION_SLICE = slice(30, 140)
    MODEL_OPTIMIZATION_DELTA = 11
    MODEL_POLYNOMIAL_DEGREE = 3
    N_JOBS = 3
    
    # CatBoost parameters for wavelength prediction only
    CATBOOST_ITERATIONS = 1000
    CATBOOST_DEPTH = 6
    CATBOOST_LR = 0.03
    CATBOOST_L2_LEAF_REG = 3

# =========================================================
# PREPROCESSING MODULE
# =========================================================

class SignalProcessor:
    """Handles all data preprocessing and calibration"""
    
    def __init__(self, config):
        self.cfg = config
        self.adc_info = pd.read_csv(f"{self.cfg.DATA_PATH}/adc_info.csv")
        self.planet_ids = pd.read_csv(
            f'{self.cfg.DATA_PATH}/{self.cfg.DATASET}_star_info.csv', 
            index_col='planet_id'
        ).index.astype(int)

    def _apply_linear_corr(self, linear_corr, signal):
        """Apply linearity correction to signal"""
        coeffs = np.flip(linear_corr, axis=0)
        x = signal.astype(np.float64, copy=False)
        out = np.empty_like(x, dtype=np.float64)
        out[...] = coeffs[0]
        for k in range(1, coeffs.shape[0]):
            np.multiply(out, x, out=out)
            out += coeffs[k]
        return out.astype(signal.dtype, copy=False)

    def _calibrate_single_signal(self, planet_id, sensor):
        """Calibrate raw signal for a single planet and sensor"""
        sensor_cfg = self.cfg.SENSOR_CONFIG[sensor]
        
        # Load raw data
        signal = pd.read_parquet(
            f"{self.cfg.DATA_PATH}/{self.cfg.DATASET}/{planet_id}/{sensor}_signal_0.parquet"
        ).to_numpy()
        dark = pd.read_parquet(
            f"{self.cfg.DATA_PATH}/{self.cfg.DATASET}/{planet_id}/{sensor}_calibration_0/dark.parquet"
        ).to_numpy()
        dead = pd.read_parquet(
            f"{self.cfg.DATA_PATH}/{self.cfg.DATASET}/{planet_id}/{sensor}_calibration_0/dead.parquet"
        ).to_numpy()
        flat = pd.read_parquet(
            f"{self.cfg.DATA_PATH}/{self.cfg.DATASET}/{planet_id}/{sensor}_calibration_0/flat.parquet"
        ).to_numpy()
        linear_corr = pd.read_parquet(
            f"{self.cfg.DATA_PATH}/{self.cfg.DATASET}/{planet_id}/{sensor}_calibration_0/linear_corr.parquet"
        ).values.astype(np.float64).reshape(sensor_cfg["linear_corr_shape"])
        
        # Reshape and apply ADC corrections
        signal = signal.reshape(sensor_cfg["raw_shape"])
        gain = self.adc_info[f"{sensor}_adc_gain"].iloc[0]
        offset = self.adc_info[f"{sensor}_adc_offset"].iloc[0]
        signal = signal / gain + offset
        
        # Detect hot pixels
        hot = sigma_clip(dark, sigma=5, maxiters=5).mask
        
        # Apply sensor-specific cropping
        if sensor == "AIRS-CH0":
            signal = signal[:, :, self.cfg.CUT_INF : self.cfg.CUT_SUP]
            linear_corr = linear_corr[:, :, self.cfg.CUT_INF : self.cfg.CUT_SUP]
            dark = dark[:, self.cfg.CUT_INF : self.cfg.CUT_SUP]
            dead = dead[:, self.cfg.CUT_INF : self.cfg.CUT_SUP]
            flat = flat[:, self.cfg.CUT_INF : self.cfg.CUT_SUP]
            hot = hot[:, self.cfg.CUT_INF : self.cfg.CUT_SUP]
        
        if sensor == "FGS1":
            y0, y1, x0, x1 = 10, 22, 10, 22
            signal = signal[:, y0:y1, x0:x1]
            dark = dark[y0:y1, x0:x1]
            dead = dead[y0:y1, x0:x1]
            flat = flat[y0:y1, x0:x1]
            linear_corr = linear_corr[:, y0:y1, x0:x1]
            hot = hot[y0:y1, x0:x1]
        
        # Apply corrections
        np.maximum(signal, 0, out=signal)
        
        # Apply linearity correction
        if sensor == "FGS1":
            signal = self._apply_linear_corr(linear_corr, signal)
        elif sensor == "AIRS-CH0":
            sl = (slice(None), slice(10, 22), slice(None))
            signal[sl] = self._apply_linear_corr(linear_corr[:, 10:22, :], signal[sl])
        else:
            signal = self._apply_linear_corr(linear_corr, signal)
        
        # Dark subtraction with pattern
        base_dt, increment = sensor_cfg["dt_pattern"]
        even_scale = base_dt
        odd_scale = base_dt + increment
        signal[::2] -= dark * even_scale
        signal[1::2] -= dark * odd_scale
        
        return signal

    def _preprocess_calibrated_signal(self, calibrated_signal, sensor):
        """Preprocess calibrated signal with binning and filtering"""
        sensor_cfg = self.cfg.SENSOR_CONFIG[sensor]
        binning = sensor_cfg["binning"]
        
        # Extract ROI
        if sensor == "AIRS-CH0":
            signal_roi = calibrated_signal[:, 10:22, :]
        elif sensor == "FGS1":
            signal_roi = calibrated_signal[:, 10:22, 10:22]
            signal_roi = signal_roi.reshape(signal_roi.shape[0], -1)
        
        # CDS (Correlated Double Sampling)
        mean_signal = np.nanmean(signal_roi, axis=1)
        cds_signal = mean_signal[1::2] - mean_signal[0::2]
        
        # Binning
        n_bins = cds_signal.shape[0] // binning
        binned = np.array([
            cds_signal[j*binning : (j+1)*binning].mean(axis=0) 
            for j in range(n_bins)
        ])
        
        # Outlier clipping for AIRS-CH0
        if sensor == "AIRS-CH0":
            q_lo = np.nanpercentile(binned, 5.0, axis=1, keepdims=True)
            q_hi = np.nanpercentile(binned, 95.0, axis=1, keepdims=True)
            np.clip(binned, q_lo, q_hi, out=binned)
        
        # Reshape for FGS1
        if sensor == "FGS1":
            binned = binned.reshape((binned.shape[0], 1))
        
        # Variance-based weighting for AIRS-CH0
        if sensor == "AIRS-CH0":
            var = np.nanvar(binned, axis=0, ddof=1)
            med = np.nanmedian(var)
            safe_var = np.where(
                ~np.isfinite(var) | (var <= 0), 
                med if (np.isfinite(med) and med > 0) else 1.0, 
                var
            )
            w = 1.0 / safe_var
            lo, hi = np.nanpercentile(w, 5.0), np.nanpercentile(w, 95.0)
            if np.isfinite(lo) and np.isfinite(hi) and lo < hi:
                w = np.clip(w, lo, hi)
            M = binned.shape[1]
            s = np.nansum(w)
            if np.isfinite(s) and s > 0:
                w = w * (M / s)
            else:
                w = np.ones_like(w)
            binned *= w[None, :]
        
        return binned

    def _process_planet_sensor(self, args):
        """Process a single planet-sensor combination"""
        planet_id, sensor = args['planet_id'], args['sensor']
        calibrated = self._calibrate_single_signal(planet_id, sensor)
        preprocessed = self._preprocess_calibrated_signal(calibrated, sensor)
        return preprocessed

    def process_all_data(self):
        """Process all planets and sensors in parallel"""
        print("[PREPROCESSING] Processing FGS1 data...")
        args_fgs1 = [
            dict(planet_id=planet_id, sensor="FGS1") 
            for planet_id in self.planet_ids
        ]
        preprocessed_fgs1 = pqdm(
            args_fgs1, 
            self._process_planet_sensor, 
            n_jobs=self.cfg.N_JOBS
        )
        
        print("[PREPROCESSING] Processing AIRS-CH0 data...")
        args_airs_ch0 = [
            dict(planet_id=planet_id, sensor="AIRS-CH0") 
            for planet_id in self.planet_ids
        ]
        preprocessed_airs_ch0 = pqdm(
            args_airs_ch0, 
            self._process_planet_sensor, 
            n_jobs=self.cfg.N_JOBS
        )
        
        # Combine FGS1 and AIRS-CH0
        preprocessed_signal = np.concatenate([
            np.stack(preprocessed_fgs1), 
            np.stack(preprocessed_airs_ch0)
        ], axis=2)
        
        print(f"[PREPROCESSING] Complete. Shape: {preprocessed_signal.shape}")
        return preprocessed_signal

# =========================================================
# PHASE DETECTION UTILITIES
# =========================================================

def _phase_detector_signal(signal, cfg):
    """Phase detection helper function"""
    sl = cfg.MODEL_PHASE_DETECTION_SLICE
    min_idx = int(np.argmin(signal[sl])) + sl.start
    s1 = signal[:min_idx]; s2 = signal[min_idx:]
    if s1.size < 3 or s2.size < 3:
        return 0, len(signal) - 1
    g1 = np.gradient(s1); g1_max = np.max(g1) if np.size(g1) else 0.0
    g2 = np.gradient(s2); g2_max = np.max(g2) if np.size(g2) else 0.0
    if g1_max != 0: g1 /= g1_max
    if g2_max != 0: g2 /= g2_max
    phase1 = int(np.argmin(g1)); phase2 = int(np.argmax(g2)) + min_idx
    return phase1, phase2

def diff_recursion_measure(signal, tol=0.01, max_iter=12):
    """Measure signal complexity for polynomial degree selection"""
    lis = np.copy(signal)
    depth = 0
    while depth < max_iter and len(lis) > 2:
        diffs = np.diff(lis)
        mean = diffs.mean()
        dev = np.mean(np.abs(diffs - mean))
        depth += 1
        if dev <= tol:
            break
        lis = diffs
    return depth, dev

# =========================================================
# TRANSIT MODEL (FROM FIRST CODE)
# =========================================================

class TransitModel:
    """Physics-based transit depth model using polynomial fitting"""
    
    def __init__(self, config):
        self.cfg = config

    def _phase_detector(self, signal):
        """Detect transit ingress and egress phases"""
        search_slice = self.cfg.MODEL_PHASE_DETECTION_SLICE
        min_index = np.argmin(signal[search_slice]) + search_slice.start
        signal1 = signal[:min_index]
        signal2 = signal[min_index:]
        if len(signal1) < 3 or len(signal2) < 3:
            return max(min_index - 3, 0), min(min_index + 3, len(signal))
        grad1 = np.gradient(signal1)
        grad2 = np.gradient(signal2)
        grad1 /= grad1.max() if grad1.max() != 0 else 1
        grad2 /= grad2.max() if grad2.max() != 0 else 1
        phase1 = np.argmin(grad1)
        phase2 = np.argmax(grad2) + min_index
        return phase1, phase2

    def _objective_function(self, s, signal, phase1, phase2):
        """Objective function for transit depth optimization"""
        delta = self.cfg.MODEL_OPTIMIZATION_DELTA
        
        if phase1 - delta <= 0 or phase2 + delta >= len(signal) or phase2 - delta - (phase1 + delta) < 5:
            delta = 2
        
        y = np.concatenate([
            signal[: phase1 - delta],
            signal[phase1 + delta : phase2 - delta] * (1 + s),
            signal[phase2 + delta :]
        ])
        x = np.arange(len(y))
        best_fit_depth, deviation_level = diff_recursion_measure(y)
        degree = int(np.clip(3 + best_fit_depth // 50, 2, 4))
        
        coeffs = np.polyfit(x, y, deg=degree)
        poly = np.poly1d(coeffs)
        error = np.abs(poly(x) - y).mean()
        
        return error

    def predict(self, single_preprocessed_signal):
        """Predict transit depth for a single observation"""
        signal_1d = single_preprocessed_signal[:, 1:].mean(axis=1)
        signal_1d = savgol_filter(signal_1d, 23, 2)
        
        phase1, phase2 = self._phase_detector(signal_1d)
        delta = self.cfg.MODEL_OPTIMIZATION_DELTA
        phase1 = max(delta, phase1)
        phase2 = min(len(signal_1d) - delta - 1, phase2)
        
        if phase2 - phase1 < 3:
            return 0.0
        
        result = minimize(
            fun=self._objective_function, 
            x0=[0.0001],
            args=(signal_1d, phase1, phase2),
            method="Nelder-Mead"
        )
        
        return result.x[0] if result.success else 0.0

    def predict_all(self, preprocessed_data):
        """Predict transit depths for all observations (returns scaled predictions)"""
        predictions = [self.predict(signal) for signal in tqdm(preprocessed_data, desc="Transit Model")]
        return np.array(predictions) * self.cfg.SCALE

# =========================================================
# LOG MODEL ENSEMBLE ADJUSTMENT (FROM FIRST CODE)
# =========================================================

def log_preprocess(signal_2d, binning=3):
    """Preprocess signal for log model"""
    signal_1d = signal_2d[:, 1:].mean(axis=1)
    if binning > 1:
        n_bins = len(signal_1d) // binning
        signal_1d = np.array([signal_1d[i*binning:(i+1)*binning].mean() for i in range(n_bins)])
    noise_level = np.std(signal_1d)
    window = 23 if noise_level < 0.005 else 31
    signal_1d = savgol_filter(signal_1d, window, 2)
    return signal_1d

def detect(points, window_size=4, eps=1e-12, strength_threshold=0.75, symmetry_tolerance=1.1):
    """Detect transit features using symmetry analysis"""
    arr = np.array(points)
    flux = np.clip(arr[:, 1], eps, None)
    if len(flux) < (2 * window_size + 1):
        return None
    s = np.mean(flux)
    mad = np.median(np.abs(flux - s)) + eps
    global_log_ref = np.median(np.log(flux))
    cutoff = np.quantile(flux, (1-strength_threshold))
    flux_length = len(flux)
    window_length = 9
    smoothed = savgol_filter(flux, window_length=window_length, polyorder=2)
    residuals = flux - smoothed
    noise_level = np.std(residuals)
    noise_level_and_mad=(1-(noise_level/mad+eps))*mad
    if noise_level_and_mad<0:
        noise_level_and_mad=noise_level
    max_offset = window_size
    sym_gaps = []
    for center in range(window_size, len(flux) - window_size):
        sym_gaps = []
        if flux[center] > cutoff:
            continue
        left = flux[center - window_size:center]
        right = flux[center + 1:center + 1 + window_size]
        if flux[center] >= left.mean() or flux[center] >= right.mean():
            continue
        for offset in range(1, max_offset + 1):
            lhs_val = flux[center - offset]
            rhs_val = flux[center + offset]
            sym_gaps.append(abs(np.log(lhs_val/flux[center]) - np.log(rhs_val/flux[center])))
        symmetry_gap = np.mean(sym_gaps)
        if symmetry_gap > symmetry_tolerance * mad:
            continue
        symmetry_score = 1.0 - np.clip(symmetry_gap / (mad * symmetry_tolerance), 0, 1)
        local_log_strength = np.mean(np.log(np.clip([flux[center], *left, *right], eps, None)))
        if local_log_strength < global_log_ref * strength_threshold:
            continue
        return center, flux[center], symmetry_score
    return None

def log_model_ensemble_adjust(preprocessed_data, transit_predictions, sigma,
                              log_threshold=0.75, window_size=4,
                              min_relative_dip=5e-4, max_injection=0.015):
    """Adjust predictions using log-space transit detection"""
    adjusted = transit_predictions.copy()
    adjusted_sigma = sigma.copy()
    for i, signal in enumerate(preprocessed_data):
        flux = signal[:, 1:].mean(axis=1)
        points = np.stack([np.arange(len(flux)), flux], axis=1)
        result = detect(points, window_size=window_size, strength_threshold=log_threshold)
        if result is None:
            continue
        center_idx, center_flux, symmetry_score = result
        w = window_size
        left = flux[max(0, center_idx - w):center_idx]
        right = flux[center_idx + 1:center_idx + 1 + w]
        surrounding = np.concatenate([left, right])
        baseline = np.median(surrounding) if surrounding.size else np.median(flux)
        dip = max(0.0, baseline - center_flux)
        relative_dip = dip / max(baseline, 1e-12)
        if relative_dip < min_relative_dip:
            continue
        inj = np.clip(dip * symmetry_score * 0.7, 1e-5, max_injection)
        blend_factor = 0.5
        adjusted[i] = (1 - blend_factor) * adjusted[i] + blend_factor * inj
        adjusted_sigma[i] = np.clip(adjusted_sigma[i] * (1.0 - 0.2 * symmetry_score), 1e-6, 0.1)
        adjusted[i] = np.clip(adjusted[i], 1e-5, max_injection)
    return adjusted, adjusted_sigma

# =========================================================
# FEATURE ENGINEERING
# =========================================================

def extract_features(preprocessed_data, star_info_df, config):
    """Extract comprehensive features from preprocessed data for CatBoost"""
    features_list = []
    
    for i, signal in enumerate(tqdm(preprocessed_data, desc="Extracting features")):
        # FGS1 features (column 0)
        fgs_signal = signal[:, 0]
        
        # AIRS features (columns 1+)
        airs_signal = signal[:, 1:].mean(axis=1)
        airs_signal_smooth = savgol_filter(airs_signal, 23, 2)
        
        # Phase detection
        phase1, phase2 = _phase_detector_signal(airs_signal_smooth, config)
        
        # Statistical features
        features = {
            # FGS1 features
            'fgs_mean': np.nanmean(fgs_signal),
            'fgs_std': np.nanstd(fgs_signal),
            'fgs_min': np.nanmin(fgs_signal),
            'fgs_max': np.nanmax(fgs_signal),
            'fgs_median': np.nanmedian(fgs_signal),
            'fgs_range': np.nanmax(fgs_signal) - np.nanmin(fgs_signal),
            'fgs_q25': np.nanpercentile(fgs_signal, 25),
            'fgs_q75': np.nanpercentile(fgs_signal, 75),
            'fgs_iqr': np.nanpercentile(fgs_signal, 75) - np.nanpercentile(fgs_signal, 25),
            
            # AIRS features
            'airs_mean': np.nanmean(airs_signal),
            'airs_std': np.nanstd(airs_signal),
            'airs_min': np.nanmin(airs_signal),
            'airs_max': np.nanmax(airs_signal),
            'airs_median': np.nanmedian(airs_signal),
            'airs_range': np.nanmax(airs_signal) - np.nanmin(airs_signal),
            'airs_q25': np.nanpercentile(airs_signal, 25),
            'airs_q75': np.nanpercentile(airs_signal, 75),
            'airs_iqr': np.nanpercentile(airs_signal, 75) - np.nanpercentile(airs_signal, 25),
            
            # Smoothed AIRS features
            'airs_smooth_min': np.nanmin(airs_signal_smooth),
            'airs_smooth_max': np.nanmax(airs_signal_smooth),
            'airs_smooth_std': np.nanstd(airs_signal_smooth),
            'airs_smooth_mean': np.nanmean(airs_signal_smooth),
            
            # Transit-like features
            'transit_depth_estimate': np.nanmax(airs_signal) - np.nanmin(airs_signal_smooth),
            'min_position': np.argmin(airs_signal_smooth) / len(airs_signal_smooth),
            'phase1_normalized': phase1 / len(airs_signal_smooth),
            'phase2_normalized': phase2 / len(airs_signal_smooth),
            'transit_duration': (phase2 - phase1) / len(airs_signal_smooth),
            
            # Gradient features
            'airs_grad_mean': np.nanmean(np.abs(np.gradient(airs_signal_smooth))),
            'airs_grad_max': np.nanmax(np.abs(np.gradient(airs_signal_smooth))),
            'airs_grad_std': np.nanstd(np.gradient(airs_signal_smooth)),
            
            # Spectral features (variance across wavelengths)
            'spectral_variance': np.nanvar(signal[:, 1:], axis=1).mean(),
            'spectral_mean': np.nanmean(signal[:, 1:], axis=0).mean(),
            'spectral_std': np.nanstd(signal[:, 1:], axis=0).mean(),
            
            # Out-of-transit vs in-transit features
            'oot_mean': np.nanmean(np.concatenate([airs_signal[:phase1], airs_signal[phase2:]])) if phase1 > 0 and phase2 < len(airs_signal) else np.nanmean(airs_signal),
            'in_transit_mean': np.nanmean(airs_signal[phase1:phase2]) if phase2 > phase1 else np.nanmean(airs_signal),
            'oot_std': np.nanstd(np.concatenate([airs_signal[:phase1], airs_signal[phase2:]])) if phase1 > 0 and phase2 < len(airs_signal) else np.nanstd(airs_signal),
            'in_transit_std': np.nanstd(airs_signal[phase1:phase2]) if phase2 > phase1 else np.nanstd(airs_signal),
        }
        
        features_list.append(features)
    
    features_df = pd.DataFrame(features_list)
    
    # Add stellar parameters
    if star_info_df is not None:
        features_df = pd.concat([features_df, star_info_df.reset_index(drop=True)], axis=1)
    
    return features_df

# =========================================================
# CATBOOST WAVELENGTH MODEL
# =========================================================

class CatBoostWavelengthPredictor:
    """CatBoost model for predicting wavelength-dependent transit depths"""
    
    def __init__(self, config, n_wavelengths=282):
        self.cfg = config
        self.n_wavelengths = n_wavelengths
        self.models = []
        
    def train(self, X_train, y_train, X_val=None, y_val=None):
        """Train separate CatBoost model for each wavelength"""
        print(f"[CATBOOST-WAVELENGTH] Training {self.n_wavelengths} wavelength models...")
        
        for i in tqdm(range(self.n_wavelengths), desc="Training wavelengths"):
            model = CatBoostRegressor(
                iterations=self.cfg.CATBOOST_ITERATIONS,
                depth=self.cfg.CATBOOST_DEPTH,
                learning_rate=self.cfg.CATBOOST_LR,
                l2_leaf_reg=self.cfg.CATBOOST_L2_LEAF_REG,
                loss_function='RMSE',
                random_seed=42 + i,
                verbose=False
            )
            
            if X_val is not None and y_val is not None:
                eval_set = Pool(X_val, y_val[:, i])
                model.fit(X_train, y_train[:, i], eval_set=eval_set, verbose=False)
            else:
                model.fit(X_train, y_train[:, i], verbose=False)
            
            self.models.append(model)
        
        print("[CATBOOST-WAVELENGTH] Training complete!")
    
    def predict(self, X):
        """Predict transit depths for all wavelengths"""
        if not self.models:
            raise ValueError("Models not trained or loaded!")
            
        predictions = np.zeros((len(X), self.n_wavelengths))
        
        for i, model in enumerate(tqdm(self.models, desc="Predicting wavelengths")):
            predictions[:, i] = model.predict(X)
        
        # Apply smoothing
        if predictions.shape[1] >= 5:
            predictions_smooth = savgol_filter(
                predictions, 
                window_length=min(13, predictions.shape[1] // 2 * 2 - 1), 
                polyorder=2, 
                axis=1
            )
            predictions = 0.7 * predictions + 0.3 * predictions_smooth
        
        return predictions
    
    def save(self, filepath):
        """Save all models"""
        joblib.dump(self.models, filepath)
        print(f"[CATBOOST-WAVELENGTH] Models saved to {filepath}")
    
    def load(self, filepath):
        """Load all models"""
        self.models = joblib.load(filepath)
        print(f"[CATBOOST-WAVELENGTH] Models loaded from {filepath}")

# =========================================================
# SIGMA ESTIMATION
# =========================================================

def estimate_sigma_fgs(preprocessed_data, cfg):
    """Estimate FGS1 uncertainties"""
    sig_rel = []
    delta = cfg.MODEL_OPTIMIZATION_DELTA
    eps = 1e-12
    for single in preprocessed_data:
        air_white = savgol_filter(single[:, 1:].mean(axis=1), 20, 2)
        p1, p2 = _phase_detector_signal(air_white, cfg)
        p1 = max(delta, p1)
        p2 = min(len(air_white) - delta - 1, p2)
        fgs = single[:, 0]
        oot = (fgs[: p1 - delta] if p1 - delta > 0 else np.empty(0, fgs.dtype))
        if p2 + delta < fgs.size:
            oot = np.concatenate([oot, fgs[p2 + delta :]])
        inn = fgs[p1 + delta : max(p1 + delta, p2 - delta)]
        if oot.size == 0 or inn.size == 0:
            sig_rel.append(np.nan); continue
        n_oot, n_in = len(oot), len(inn)
        var_oot = np.nanvar(oot, ddof=1)
        var_in  = np.nanvar(inn, ddof=1)
        oot_mean = float(np.nanmean(oot)) if np.isfinite(np.nanmean(oot)) else float(np.nanmean(fgs))
        sigma_rel = np.sqrt(var_oot / max(n_oot,1) + var_in / max(n_in,1)) / max(oot_mean, eps)
        sig_rel.append(sigma_rel)
    s = np.asarray(sig_rel, dtype=float)
    mask = np.isfinite(s) & (s > 0)
    med = float(np.nanmedian(s[mask])) if mask.any() else 1.0
    k = np.ones_like(s)
    if med > 0 and np.isfinite(med):
        k[mask] = np.sqrt(s[mask] / med)
    k = np.clip(k, 0.85, 1.30)
    sigma_fgs = k * cfg.SIGMA * 1.04
    return sigma_fgs

def estimate_sigma_air(preprocessed_data, cfg):
    """Estimate AIRS uncertainties"""
    sig_rel = []
    delta = cfg.MODEL_OPTIMIZATION_DELTA
    eps = 1e-12
    for single in preprocessed_data:
        white = np.nanmean(single[:, 1:], axis=1)
        white_s = savgol_filter(white, 20, 2)
        p1, p2 = _phase_detector_signal(white_s, cfg)
        p1 = max(delta, p1)
        p2 = min(len(white) - delta - 1, p2)
        oot_left = white[: p1 - delta] if p1 - delta > 0 else np.empty(0, white.dtype)
        oot_right = white[p2 + delta :] if (p2 + delta) < white.size else np.empty(0, white.dtype)
        oot = np.concatenate([oot_left, oot_right]) if (oot_left.size + oot_right.size) else oot_left
        inn = white[p1 + delta : max(p1 + delta, p2 - delta)]
        if oot.size == 0 or inn.size == 0:
            sig_rel.append(np.nan); continue
        n_oot, n_in = len(oot), len(inn)
        var_oot = np.nanvar(oot, ddof=1)
        var_in  = np.nanvar(inn, ddof=1)
        oot_mean = float(np.nanmean(oot)) if np.isfinite(np.nanmean(oot)) else float(np.nanmean(white))
        sigma_rel = np.sqrt(var_oot / max(n_oot,1) + var_in / max(n_in,1)) / max(oot_mean, eps)
        sig_rel.append(sigma_rel)
    s = np.asarray(sig_rel, dtype=float)
    mask = np.isfinite(s) & (s > 0)
    med = float(np.nanmedian(s[mask])) if mask.any() else 1.0
    k = np.ones_like(s)
    if med > 0 and np.isfinite(med):
        k[mask] = np.sqrt(s[mask] / med)
    k = np.clip(k, 0.92, 1.22)
    sigma_air = k * cfg.SIGMA * 1.04
    return sigma_air

# =========================================================
# SUBMISSION GENERATOR
# =========================================================

class SubmissionGenerator:
    """Generate submission file"""
    
    def __init__(self, config):
        self.cfg = config
        self.sample_submission = pd.read_csv(
            "/kaggle/input/ariel-data-challenge-2025/sample_submission.csv", 
            index_col="planet_id"
        )

    def create(self, predictions1, predictions, sigma_fgs=None, sigma_air=None):
        planet_ids = self.sample_submission.index
        n_planets = len(planet_ids)
        n_mu = self.sample_submission.shape[1] // 2
        
        predictions1 = np.nan_to_num(predictions1, nan=0.0, posinf=0.0, neginf=0.0)
        predictions = np.nan_to_num(predictions, nan=0.0, posinf=0.0, neginf=0.0)
        
        if predictions.ndim > 1:
            predictions = predictions.flatten()
        if len(predictions) < n_planets:
            predictions = np.pad(predictions, (0, n_planets - len(predictions)), constant_values=0.0)
        elif len(predictions) > n_planets:
            predictions = predictions[:n_planets]
        
        if sigma_fgs is not None:
            sigma_fgs = np.nan_to_num(sigma_fgs, nan=self.cfg.SIGMA, posinf=self.cfg.SIGMA, neginf=self.cfg.SIGMA)
            if len(sigma_fgs) < n_planets:
                sigma_fgs = np.pad(sigma_fgs, (0, n_planets - len(sigma_fgs)), constant_values=self.cfg.SIGMA)
            elif len(sigma_fgs) > n_planets:
                sigma_fgs = sigma_fgs[:n_planets]
        
        if sigma_air is not None:
            sigma_air = np.nan_to_num(sigma_air, nan=self.cfg.SIGMA, posinf=self.cfg.SIGMA, neginf=self.cfg.SIGMA)
            if len(sigma_air) < n_planets:
                sigma_air = np.pad(sigma_air, (0, n_planets - len(sigma_air)), constant_values=self.cfg.SIGMA)
            elif len(sigma_air) > n_planets:
                sigma_air = sigma_air[:n_planets]
        
        if predictions1.shape[0] < n_planets:
            pad_rows = n_planets - predictions1.shape[0]
            predictions1 = np.vstack([predictions1, np.zeros((pad_rows, predictions1.shape[1]))])
        elif predictions1.shape[0] > n_planets:
            predictions1 = predictions1[:n_planets]
        
        if predictions1.shape[1] < n_mu:
            pad_cols = n_mu - predictions1.shape[1]
            predictions1 = np.hstack([predictions1, np.zeros((predictions1.shape[0], pad_cols))])
        elif predictions1.shape[1] > n_mu:
            predictions1 = predictions1[:, :n_mu]
        
        mu = np.zeros((n_planets, n_mu), dtype=float)
        mu[:, 0] = np.clip(predictions, 0, None)
        mu[:, 1:] = np.clip(predictions1[:, 1:], 0, None)
        
        sigmas = np.full((n_planets, n_mu), self.cfg.SIGMA, dtype=float)
        if sigma_fgs is not None:
            sigmas[:, 0] = np.clip(sigma_fgs, 1e-6, 0.1)
        if sigma_air is not None:
            sigmas[:, 1:] = np.clip(sigma_air[:, None], 1e-6, 0.1)
        
        submission_df = pd.DataFrame(
            np.concatenate([mu, sigmas], axis=1), 
            columns=self.sample_submission.columns, 
            index=planet_ids
        )
        submission_df = submission_df.replace([np.inf, -np.inf], 1e-5).fillna(1e-5)
        submission_df = submission_df.clip(lower=0.0)
        submission_df.to_csv("submission.csv", index_label="planet_id")
        print(f"[INFO] submission.csv written successfully with shape: {submission_df.shape}")
        return submission_df

# =========================================================
# MAIN EXECUTION
# =========================================================

if __name__ == "__main__":
    config = Config()
    
    # Step 1: Preprocessing
    print("="*60)
    print("STEP 1: PREPROCESSING")
    print("="*60)
    signal_processor = SignalProcessor(config)
    preprocessed_data = signal_processor.process_all_data()
    
    # Step 2: Load star info
    print("\n" + "="*60)
    print("STEP 2: LOADING STELLAR PARAMETERS")
    print("="*60)
    star_info = pd.read_csv(f"{ROOT_PATH}/{MODE}_star_info.csv")
    star_info["planet_id"] = star_info["planet_id"].astype(int)
    planet_ids = star_info["planet_id"].tolist()
    star_info_indexed = star_info.set_index("planet_id")
    
    # Step 3: Transit Depth Prediction with TransitModel
    print("\n" + "="*60)
    print("STEP 3: TRANSIT DEPTH PREDICTION (TRANSIT MODEL)")
    print("="*60)
    model = TransitModel(config)
    predictions = model.predict_all(preprocessed_data)  # Already scaled with SCALE factor
    print(f"[INFO] Transit depth predictions: min={predictions.min():.6f}, max={predictions.max():.6f}, mean={predictions.mean():.6f}")
    
    # Step 4: Estimate uncertainties
    print("\n" + "="*60)
    print("STEP 4: UNCERTAINTY ESTIMATION")
    print("="*60)
    sigma_fgs_vec = estimate_sigma_fgs(preprocessed_data, config)
    sigma_air_vec = estimate_sigma_air(preprocessed_data, config)
    print(f"[INFO] Sigma FGS: min={sigma_fgs_vec.min():.6f}, max={sigma_fgs_vec.max():.6f}")
    print(f"[INFO] Sigma AIR: min={sigma_air_vec.min():.6f}, max={sigma_air_vec.max():.6f}")
    
    # Step 5: Apply log model ensemble adjustment
    print("\n" + "="*60)
    print("STEP 5: LOG MODEL ENSEMBLE ADJUSTMENT")
    print("="*60)
    predictions_corrected = predictions.copy()
    sigma_air_corrected = sigma_air_vec.copy()
    
    # Use median-based mask like working version
    mask = predictions < (np.median(predictions) * 0.1)
    print(f"[INFO] Adjusting {mask.sum()} low-confidence predictions")
    
    if mask.any():
        adjusted, adjusted_sigma = log_model_ensemble_adjust(
            preprocessed_data[mask], 
            predictions[mask], 
            sigma_air_vec[mask], 
            log_threshold=0.75, 
            window_size=4
        )
        alpha = 0.5
        predictions_corrected[mask] = alpha * adjusted + (1 - alpha) * predictions[mask]
        sigma_air_corrected[mask] = adjusted_sigma
    
    print(f"[INFO] Corrected predictions: min={predictions_corrected.min():.6f}, max={predictions_corrected.max():.6f}")
    
    # Step 6: Feature engineering for wavelength prediction
    print("\n" + "="*60)
    print("STEP 6: FEATURE ENGINEERING FOR WAVELENGTH PREDICTION")
    print("="*60)
    features_df = extract_features(preprocessed_data, star_info_indexed, config)
    
    # Add transit depth prediction to features
    features_df['transit_depth_pred'] = predictions_corrected
    print(f"[INFO] Features extracted. Shape: {features_df.shape}")
    
    # Step 7: Wavelength prediction with CatBoost
    print("\n" + "="*60)
    print("STEP 7: WAVELENGTH PREDICTION (CATBOOST)")
    print("="*60)
    
    # Check if pre-trained wavelength models exist
    wavelength_model_path = "/kaggle/input/catboost-wavelength-models/catboost_wavelength_models.pkl"
    
    wavelength_model = CatBoostWavelengthPredictor(config, n_wavelengths=282)
    
    try:
        wavelength_model.load(wavelength_model_path)
        predictions1 = wavelength_model.predict(features_df)
        print(f"[INFO] Wavelength predictions shape: {predictions1.shape}")
        print(f"[INFO] Wavelength predictions: min={predictions1.min():.6f}, max={predictions1.max():.6f}")
    except:
        print("[WARNING] Pre-trained wavelength models not found!")
        print("[INFO] Using fallback: simple wavelength estimation based on transit depth")
        
        # Fallback: Create wavelength predictions based on transit depth
        n_wavelengths = 282
        predictions1 = np.zeros((len(predictions_corrected), n_wavelengths))
        
        for i in range(len(predictions_corrected)):
            # Base depth from transit prediction
            base_depth = predictions_corrected[i]
            
            # Create wavelength-dependent variation (simplified model)
            # Real exoplanet spectra have wavelength-dependent features
            wavelength_variation = np.random.normal(1.0, 0.05, n_wavelengths)
            wavelength_variation = savgol_filter(wavelength_variation, 
                                                window_length=15, 
                                                polyorder=2)
            
            predictions1[i, :] = base_depth * wavelength_variation
        
        predictions1 = np.clip(predictions1, 0, None)
        print(f"[INFO] Fallback wavelength predictions created: {predictions1.shape}")
    
    # Convert predictions1 from absolute to relative (like in first code)
    predictions1 /= 10000  # Scale back if needed
    
    # Step 8: Create submission
    print("\n" + "="*60)
    print("STEP 8: GENERATING SUBMISSION")
    print("="*60)
    
    sigma_fgs_vec = np.nan_to_num(sigma_fgs_vec, nan=config.SIGMA, 
                                   posinf=config.SIGMA, neginf=config.SIGMA)
    sigma_air_corrected = np.nan_to_num(sigma_air_corrected, nan=config.SIGMA, 
                                        posinf=config.SIGMA, neginf=config.SIGMA)
    
    submission_generator = SubmissionGenerator(config)
    submission = submission_generator.create(
        predictions1, 
        predictions_corrected, 
        sigma_fgs=sigma_fgs_vec, 
        sigma_air=sigma_air_corrected
    )
    
    __t1 = time.perf_counter()
    elapsed = __t1 - __t0
    print(f"\n[TIMING] Total runtime: {elapsed:.2f} s ({elapsed/60:.2f} min)")
    print("\n" + "="*60)
    print("SUBMISSION PREVIEW")
    print("="*60)
    print(submission.head(10))

STEP 1: PREPROCESSING
[PREPROCESSING] Processing FGS1 data...


QUEUEING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

[PREPROCESSING] Processing AIRS-CH0 data...


QUEUEING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/1 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/1 [00:00<?, ?it/s]

[PREPROCESSING] Complete. Shape: (1, 187, 283)

STEP 2: LOADING STELLAR PARAMETERS

STEP 3: TRANSIT DEPTH PREDICTION (TRANSIT MODEL)


Transit Model: 100%|██████████| 1/1 [00:00<00:00, 11.36it/s]


[INFO] Transit depth predictions: min=0.015916, max=0.015916, mean=0.015916

STEP 4: UNCERTAINTY ESTIMATION
[INFO] Sigma FGS: min=0.000582, max=0.000582
[INFO] Sigma AIR: min=0.000582, max=0.000582

STEP 5: LOG MODEL ENSEMBLE ADJUSTMENT
[INFO] Adjusting 0 low-confidence predictions
[INFO] Corrected predictions: min=0.015916, max=0.015916

STEP 6: FEATURE ENGINEERING FOR WAVELENGTH PREDICTION


Extracting features: 100%|██████████| 1/1 [00:00<00:00, 135.24it/s]

[INFO] Features extracted. Shape: (1, 46)

STEP 7: WAVELENGTH PREDICTION (CATBOOST)
[INFO] Using fallback: simple wavelength estimation based on transit depth
[INFO] Fallback wavelength predictions created: (1, 282)

STEP 8: GENERATING SUBMISSION
[INFO] submission.csv written successfully with shape: (1, 566)

[TIMING] Total runtime: 9.32 s (0.16 min)

SUBMISSION PREVIEW
               wl_1      wl_2      wl_3      wl_4      wl_5      wl_6  \
planet_id                                                               
1103775    0.015916  0.000002  0.000002  0.000002  0.000002  0.000002   

               wl_7      wl_8      wl_9     wl_10  ...  sigma_274  sigma_275  \
planet_id                                          ...                         
1103775    0.000002  0.000002  0.000002  0.000002  ...   0.000582   0.000582   

           sigma_276  sigma_277  sigma_278  sigma_279  sigma_280  sigma_281  \
planet_id                                                                     
1103775


