# Model 4: XGB SpecType

This notebook contains Model 4 for the MALLORN challenge.

This model is the first one that performed very well: it was enough to place me around 130th out of 925 participants.
The biggest change is using SpecType (train-only metadata) to generate features that can also be computed for the test set.

In a discussion post, a Kaggle user suggested focusing on separating TDEs from SNe and AGN, which are values of `SpecType`. Since `SpecType` is not available at test time, I train a separate model to predict (a grouped version of) it, and then use its predicted probabilities as additional features in the main TDE classifier.

1) Train a multiclass LightGBM model to predict `SpecTypeGroup`:
   - TDE
   - AGN
   - SNIa
   - SNother
   - Other

2) Generate OOF predicted probabilities for the train set:
   - Each training object gets probabilities only from a model that was not trained on its fold (the folds are group-based splits).

3) Fit the multiclass model on full train and predict probabilities for test.

4) Append these probabilities as features:
   - `p_spec_<class>` for each class
   - `spec_entropy` as a confidence / ambiguity signal

This gives the main classifier extra information about “what kind of transient this looks like” using only lightcurve-derived features.

In [2]:
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, average_precision_score
from sklearn.model_selection import StratifiedGroupKFold
from xgboost import XGBClassifier
from extinction import fitzpatrick99
from lightgbm import LGBMClassifier
import lightgbm as lgb
import optuna

## Constants / Configuration

- `PRE_BASE_FRAC`: fraction of a band's earliest points used to estimate the pre-peak baseline (at least 2 points are always used)
- `MIN_BAND_POINTS`: minimum points needed for certain per-band features
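- `PEAK_SIGMA_K`: how many median flux errors a local maximum must exceed the pre-peak baseline by to count as a significant peak
- `REBRIGHT_FRAC`: fraction of the peak amplitude (above the pre-peak baseline) that defines the rebrightening level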
- `EPS`: numerical stability


In [3]:
FILTERS = ["u", "g", "r", "i", "z", "y"]

# Effective wavelength per filter, in Angstrom; used when evaluating the dust-extinction law.
EFF_WL_AA = dict(zip(FILTERS, (3641.0, 4704.0, 6155.0, 7504.0, 8695.0, 10056.0)))

# R_V passed to fitzpatrick99 (A_V = E(B-V) * R_V).
R_V = 3.1

# Feature-extraction tuning knobs (described in the list above).
PRE_BASE_FRAC = 0.20   # fraction of a band's earliest points used for the baseline
MIN_BAND_POINTS = 5    # minimum points needed for certain per-band features
PEAK_SIGMA_K = 3.0     # a peak must clear the baseline by this many median errors
REBRIGHT_FRAC = 0.30   # fraction of the amplitude that defines the rebrightening level
EPS = 1e-8             # guards divisions and logs against zero


In [None]:
def safe_float(x, default=np.nan):
    """Coerce ``x`` to a Python float, falling back to ``default``.

    ``default`` is returned when ``x`` is None, when ``float(x)`` fails for
    any reason, or when the converted value is NaN. Conversion errors are
    swallowed deliberately so messy metadata never aborts feature extraction.
    """
    if x is None:
        return default
    try:
        value = float(x)
        return default if np.isnan(value) else value
    except Exception:
        return default


def trapz_safe(y, x):
    """Trapezoidal integral of ``y`` over ``x``, returned as a Python float.

    Uses ``np.trapezoid`` when this NumPy provides it; otherwise sums the
    trapezoid areas between consecutive samples explicitly.
    """
    integrate = getattr(np, "trapezoid", None)
    if integrate is not None:
        return float(integrate(y, x))
    xs = np.asarray(x)
    ys = np.asarray(y)
    return float(np.sum(np.diff(xs) * (ys[1:] + ys[:-1]) * 0.5))


def median_abs_dev(x):
    """Unscaled median absolute deviation of ``x``; NaN for empty input."""
    values = np.asarray(x)
    if not len(values):
        return np.nan
    center = np.median(values)
    deviations = np.abs(values - center)
    return float(np.median(deviations))


def iqr(x):
    """Interquartile range (75th minus 25th percentile); NaN for fewer than 2 values."""
    values = np.asarray(x)
    if len(values) < 2:
        return np.nan
    lower, upper = np.percentile(values, [25, 75])
    return float(upper - lower)


def skewness(x):
    """Biased sample skewness ``mean((x - mean)^3) / std^3`` (std with ddof=0).

    Returns NaN for fewer than 3 values and 0.0 when the standard deviation
    is below 1e-12 (effectively constant input).
    """
    values = np.asarray(x)
    if len(values) < 3:
        return np.nan
    sigma = np.std(values)
    if sigma < 1e-12:
        return 0.0
    centered = values - np.mean(values)
    third_moment = np.mean(centered ** 3)
    return float(third_moment / (sigma ** 3))


def kurtosis_excess(x):
    """Biased excess kurtosis ``mean((x - mean)^4) / std^4 - 3`` (std with ddof=0).

    Returns NaN for fewer than 4 values and 0.0 when the standard deviation
    is below 1e-12 (effectively constant input).
    """
    values = np.asarray(x)
    if len(values) < 4:
        return np.nan
    sigma = np.std(values)
    if sigma < 1e-12:
        return 0.0
    centered = values - np.mean(values)
    fourth_moment = np.mean(centered ** 4)
    return float(fourth_moment / (sigma ** 4) - 3.0)


def von_neumann_eta(x):
    """Von Neumann ratio: mean squared successive difference divided by variance.

    ``eta = mean(diff(x)**2) / var(x)`` with population variance (ddof=0).
    Small values mean neighbouring points are close relative to the overall
    spread (smooth evolution); large values mean point-to-point jumps.
    Returns NaN for fewer than 3 values and 0.0 when the variance is below 1e-12.
    """
    values = np.asarray(x)
    if len(values) < 3:
        return np.nan
    variance = np.var(values)
    if variance < 1e-12:
        return 0.0
    successive = np.diff(values)
    return float(np.mean(successive ** 2) / variance)


def max_slope(t, f):
    """Largest absolute finite-difference slope |df/dt| between consecutive samples.

    Pairs whose time step is not strictly positive (duplicate or out-of-order
    timestamps) are ignored. Returns NaN for fewer than 3 samples or when no
    pair has a positive time step.
    """
    times = np.asarray(t)
    fluxes = np.asarray(f)
    if len(times) < 3:
        return np.nan
    step_t = np.diff(times)
    step_f = np.diff(fluxes)
    forward = step_t > 0
    if not np.any(forward):
        return np.nan
    return float(np.abs(step_f[forward] / step_t[forward]).max())


def median_abs_slope(t, f):
    """Median absolute finite-difference slope |df/dt| between consecutive samples.

    Pairs whose time step is not strictly positive (duplicate or out-of-order
    timestamps) are ignored. Returns NaN for fewer than 3 samples or when no
    pair has a positive time step.
    """
    times = np.asarray(t)
    fluxes = np.asarray(f)
    if len(times) < 3:
        return np.nan
    step_t = np.diff(times)
    step_f = np.diff(fluxes)
    forward = step_t > 0
    if not np.any(forward):
        return np.nan
    rates = np.abs(step_f[forward] / step_t[forward])
    return float(np.median(rates))


def linear_slope(t, f):
    """Least-squares slope of a straight-line fit of ``f`` against ``t``.

    Returns NaN for fewer than 3 samples or if ``np.polyfit`` raises
    (e.g. on NaN input).
    """
    times = np.asarray(t)
    fluxes = np.asarray(f)
    if len(times) < 3:
        return np.nan
    try:
        coeffs = np.polyfit(times, fluxes, 1)
    except Exception:
        return np.nan
    # polyfit orders coefficients from highest degree: [slope, intercept]
    return float(coeffs[0])


def chi2_to_constant(f, ferr):
    """Reduced chi-square of ``f`` against a constant model placed at its median.

    Each squared residual is divided by ``(ferr + EPS)**2``; the total is
    divided by ``max(1, n - 1)`` degrees of freedom. Returns NaN for fewer
    than 3 points.
    """
    fluxes = np.asarray(f)
    errors = np.asarray(ferr)
    n_points = len(fluxes)
    if n_points < 3:
        return np.nan
    residuals = fluxes - np.median(fluxes)
    chi2 = np.sum(residuals ** 2 / (errors + EPS) ** 2)
    return float(chi2 / max(1, n_points - 1))


def interp_flux_at_time(tb, fb, t0):
    """Linearly interpolate flux at time ``t0`` without extrapolating.

    Returns NaN when fewer than two samples are given or ``t0`` lies outside
    ``[min(tb), max(tb)]``. ``np.interp`` expects increasing ``tb``; callers
    are presumably passing time-sorted band data (TODO confirm).
    """
    times = np.asarray(tb)
    fluxes = np.asarray(fb)
    if len(times) < 2:
        return np.nan
    if t0 < times.min() or t0 > times.max():
        return np.nan
    return float(np.interp(t0, times, fluxes))


def fractional_variability(f, ferr):
    """
    Noise-corrected fractional variability.

    F_var = sqrt(max(0, S^2 - mean(err^2))) / |mean(f)|, where S^2 is the
    sample variance (ddof=1). Returns NaN for fewer than 3 points or when
    |mean(f)| < 1e-8; returns 0.0 when measurement noise explains all scatter.
    """
    fluxes = np.asarray(f, float)
    errors = np.asarray(ferr, float)
    if len(fluxes) < 3:
        return np.nan

    mean_flux = np.mean(fluxes)
    if np.abs(mean_flux) < 1e-8:
        return np.nan

    excess_var = max(0.0, np.var(fluxes, ddof=1) - np.mean(errors ** 2))
    return float(np.sqrt(excess_var) / np.abs(mean_flux))

def stetson_J_consecutive(t, f, ferr):
    """
    Stetson J variability index built from consecutive-sample pairs.

    Each sample gets a scaled residual
    ``delta_i = sqrt(n / (n - 1)) * (f_i - mean(f)) / (ferr_i + EPS)``.
    For each neighbouring pair, the signed root ``sign(P) * sqrt(|P|)`` of
    ``P = delta_i * delta_{i+1}`` is taken, and J is their unweighted mean.

    ``t`` only sets the number of samples ``n``; the time spacing itself is
    not used, so obs-frame and rest-frame times give identical results.
    Returns NaN for n < 4.
    """
    times = np.asarray(t)
    fluxes = np.asarray(f)
    errors = np.asarray(ferr)
    n_points = len(times)
    if n_points < 4:
        return np.nan

    scale = np.sqrt(n_points / max(1, n_points - 1))
    delta = scale * (fluxes - np.mean(fluxes)) / (errors + EPS)
    pair_products = delta[: n_points - 1] * delta[1:n_points]
    signed_roots = np.sign(pair_products) * np.sqrt(np.abs(pair_products))
    return float(np.mean(signed_roots))

def pre_peak_baseline(tb, fb, eb, frac=PRE_BASE_FRAC):
    """
    Robust baseline from the earliest ``frac`` of a band's samples.

    Uses the first ``k = min(n, max(2, ceil(frac * n)))`` samples in the
    order given (presumably time-sorted — TODO confirm with caller). The peak
    location is not consulted; "pre-peak" relies on the early points
    preceding the peak.

    Returns (baseline_median, baseline_mad, median_error) over those samples,
    or (NaN, NaN, NaN) when fewer than 3 samples are given.
    """
    times = np.asarray(tb)
    fluxes = np.asarray(fb)
    errors = np.asarray(eb)
    n_points = len(times)
    if n_points < 3:
        return np.nan, np.nan, np.nan

    n_early = min(n_points, max(2, int(np.ceil(frac * n_points))))
    early_flux = fluxes[:n_early]

    baseline = float(np.median(early_flux))
    spread = median_abs_dev(early_flux)
    early_err = float(np.median(errors[:n_early]))  # n_early >= 2, never empty
    return baseline, spread, early_err


def count_significant_peaks(tb, fb, eb, baseline_pre, k_sigma=PEAK_SIGMA_K):
    """
    Count strict local maxima that clear a noise-scaled threshold.

    threshold = baseline_pre + k_sigma * median(eb). The median error is taken
    over ALL samples in ``eb`` (not only the pre-peak ones) and replaced by
    0.0 when it is not finite. Sample i (1 <= i <= n-2) counts when it is
    strictly greater than both neighbours and than the threshold; a NaN
    threshold therefore yields 0.

    ``tb`` is unused. Returns 0 for fewer than 5 samples.
    """
    fluxes = np.asarray(fb)
    errors = np.asarray(eb)
    if len(fluxes) < 5:
        return 0

    median_err = np.median(errors)
    noise = float(median_err) if np.isfinite(median_err) else 0.0
    threshold = baseline_pre + k_sigma * noise

    middle = fluxes[1:-1]
    is_peak = (middle > fluxes[:-2]) & (middle > fluxes[2:]) & (middle > threshold)
    return int(np.count_nonzero(is_peak))


def postpeak_monotonicity(tb, fb, pidx):
    """
    Fraction of post-peak consecutive steps with a negative slope.

    Considers samples from index ``pidx`` onward and ignores steps whose time
    difference is not strictly positive. A value of 1.0 means every step after
    the peak is declining. Returns NaN when ``pidx`` is None, fewer than three
    samples start at the peak, or no step has a positive time difference.
    """
    times = np.asarray(tb)
    fluxes = np.asarray(fb)
    if pidx is None or pidx >= len(fluxes) - 2:
        return np.nan

    tail_t = times[pidx:]
    tail_f = fluxes[pidx:]
    if len(tail_f) < 3:
        return np.nan

    step_t = np.diff(tail_t)
    step_f = np.diff(tail_f)
    forward = step_t > 0
    if not np.any(forward):
        return np.nan

    falling = (step_f[forward] / step_t[forward]) < 0
    return float(np.mean(falling))


def count_rebrighten(tb, fb, baseline_pre, amp, pidx, frac=REBRIGHT_FRAC):
    """
    Count post-peak rebrightenings as upward crossings of a fractional level.

    level = baseline_pre + frac * amp. Over the samples from index ``pidx``
    onward, counts consecutive pairs where flux goes from <= level to > level.
    A NaN level makes every comparison false, giving 0.

    Parameters
    ----------
    tb : array-like
        Unused; kept so the signature matches the other post-peak helpers.
    fb : array-like
        Band fluxes (lists are accepted and converted to an array).
    baseline_pre, amp : float
        Baseline flux and peak amplitude defining the level.
    pidx : int or None
        Index of the peak sample.
    frac : float
        Fraction of ``amp`` above ``baseline_pre`` that defines the level.

    Returns
    -------
    int
        Number of upward crossings; 0 when ``pidx`` is None or fewer than
        three samples start at the peak.
    """
    # Convert like the sibling helpers so list input works with the
    # element-wise comparison below (a plain list would raise TypeError).
    fluxes = np.asarray(fb)
    if pidx is None or pidx >= len(fluxes) - 2:
        return 0

    level = baseline_pre + frac * amp
    post = fluxes[pidx:]
    if len(post) < 3:
        return 0

    above = post > level
    # Upward crossing: previous sample at/below the level, next strictly above.
    crossings = np.count_nonzero(~above[:-1] & above[1:])
    return int(crossings)


def fall_time_to_level(tb, fb, baseline_pre, amp, pidx, frac):
    """
    Time from the peak sample until flux first drops to a fractional level.

    level = baseline_pre + frac * amp. Scans samples from index ``pidx``
    onward (the peak sample itself included) and returns
    ``tb[first] - tb[pidx]`` for the first sample with flux <= level.

    Parameters
    ----------
    tb, fb : array-like
        Band times and fluxes (lists are accepted). Presumably time-sorted;
        TODO confirm with caller.
    baseline_pre : float
        Baseline flux the level is measured from.
    amp : float
        Peak amplitude above ``baseline_pre``; values <= 0 give NaN.
    pidx : int or None
        Index of the peak sample; None gives NaN.
    frac : float
        Fraction of ``amp`` defining the target level.

    Returns
    -------
    float
        Elapsed time in the units of ``tb``, or NaN when fewer than two
        samples start at the peak or flux never reaches the level
        (including when the level is NaN).
    """
    if amp <= 0 or pidx is None:
        return np.nan

    # Convert like the sibling helpers; a plain list would make the
    # `<=` comparison below raise TypeError.
    times = np.asarray(tb)
    fluxes = np.asarray(fb)

    level = baseline_pre + frac * amp
    t_dec = times[pidx:]
    f_dec = fluxes[pidx:]
    if len(f_dec) < 2:
        return np.nan

    below = np.flatnonzero(f_dec <= level)
    if below.size == 0:
        return np.nan
    return float(t_dec[below[0]] - t_dec[0])


def decay_powerlaw_fit(tb, fb, baseline_pre, pidx, tmax=300.0):
    """
    Fit a power-law decay to the post-peak segment in log-log space.

    Fits ``log(f - baseline_pre) = a + b * log(dt)``, where ``dt`` is the time
    since the peak sample ``tb[pidx]``. Only points with ``0 < dt <= tmax``
    and ``f > baseline_pre`` are used, so the peak sample itself is excluded.

    Parameters
    ----------
    tb, fb : array-like
        Band times and fluxes (lists are accepted).
    baseline_pre : float
        Baseline subtracted before taking the log; NaN leaves no usable points.
    pidx : int or None
        Index of the peak sample.
    tmax : float
        Maximum time after the peak (units of ``tb``) included in the fit.

    Returns
    -------
    (b, r2, npts) : (float, float, int)
        Power-law slope, R^2 of the log-log fit, and number of points used.
        ``(nan, nan, 0)`` when ``pidx`` is None or fewer than four samples
        start at the peak; ``(nan, nan, npts)`` when fewer than 4 usable points
        remain or ``np.polyfit`` raises.
    """
    # Convert like the sibling helpers; with a plain list `t_dec - t0`
    # below would raise TypeError.
    times = np.asarray(tb)
    fluxes = np.asarray(fb)
    if pidx is None or pidx >= len(fluxes) - 3:
        return np.nan, np.nan, 0

    t0 = times[pidx]
    dt = times[pidx:] - t0
    in_window = (dt > 0.0) & (dt <= tmax)
    dt = dt[in_window]
    fd = fluxes[pidx:][in_window] - baseline_pre

    # log() needs strictly positive excess flux over the baseline.
    positive = fd > 0.0
    dt = dt[positive]
    fd = fd[positive]
    n_used = int(len(dt))

    if n_used < 4:
        return np.nan, np.nan, n_used

    x = np.log(dt + EPS)
    y = np.log(fd + EPS)

    try:
        # polyfit returns highest degree first: (slope, intercept)
        b, a = np.polyfit(x, y, 1)
    except Exception:
        return np.nan, np.nan, n_used

    yhat = a + b * x
    ss_res = float(np.sum((y - yhat) ** 2))
    # EPS keeps r2 finite when y is (numerically) constant.
    ss_tot = float(np.sum((y - np.mean(y)) ** 2)) + EPS
    r2 = 1.0 - ss_res / ss_tot

    return float(b), float(r2), n_used

In [7]:
def deextinct_band(flux, flux_err, ebv, band, r_v=R_V):
    """
    Correct one band's fluxes for dust extinction with fitzpatrick99.

    A_V = ebv * r_v. A_lambda (mag) is evaluated at the band's effective
    wavelength EFF_WL_AA[band] (Angstrom), and both ``flux`` and ``flux_err``
    are multiplied by 10**(0.4 * A_lambda).

    Parameters
    ----------
    flux, flux_err : array-like
        Fluxes and errors for a single band.
    ebv : float-like or None
        E(B-V) reddening. None or any non-finite value (NaN/inf of any
        float type, including numpy float32) means "no correction".
    band : str
        Key into EFF_WL_AA.
    r_v : float
        R_V value passed to fitzpatrick99.

    Returns
    -------
    (flux_corr, flux_err_corr, A_lambda)
        The inputs unchanged and A_lambda = 0.0 when no correction is applied.

    Raises
    ------
    TypeError, ValueError
        If ``ebv`` cannot be converted with ``float()``.
    KeyError
        If ``band`` is not in EFF_WL_AA.
    """
    if ebv is None:
        return flux, flux_err, 0.0
    # float() + isfinite catches NaN/inf from every numeric type; an
    # isinstance(ebv, float) check misses e.g. np.float32('nan'), which would
    # silently turn every corrected flux into NaN.
    ebv_val = float(ebv)
    if not np.isfinite(ebv_val):
        return flux, flux_err, 0.0

    A_V = ebv_val * float(r_v)
    wave = np.array([EFF_WL_AA[band]], dtype=float)  # Angstrom
    A_lambda = float(fitzpatrick99(wave, A_V, r_v=r_v, unit="aa")[0])  # mag

    fac = 10.0 ** (0.4 * A_lambda)
    return flux * fac, flux_err * fac, A_lambda


def deextinct_lightcurve(lc, ebv):
    """
    Apply per-band dust de-extinction to a whole lightcurve.

    Reads the ``Flux``, ``Flux_err`` and ``Filter`` columns of ``lc`` and
    corrects the rows of each band in FILTERS via ``deextinct_band``. Rows
    whose filter is not in FILTERS are left unchanged. Returns new float
    arrays (flux_corr, ferr_corr); ``lc`` itself is not modified.
    """
    # astype(float) always returns a fresh copy, so in-place edits below are safe.
    flux_out = lc["Flux"].to_numpy().astype(float)
    err_out = lc["Flux_err"].to_numpy().astype(float)
    bands = lc["Filter"].to_numpy()

    for band in FILTERS:
        rows = bands == band
        if np.any(rows):
            flux_out[rows], err_out[rows], _ = deextinct_band(
                flux_out[rows], err_out[rows], ebv, band
            )

    return flux_out, err_out


## Global features

These are computed using all observations across all bands for a given object.  
They summarize time coverage, brightness distribution, cadence, variability, and context (redshift + dust + redshift uncertainty).

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `n_obs` | Total number of observations across all filters | Captures overall sampling density and how well-measured the object is |
| `total_time_obs` | Observed-frame time baseline: `max(t_rel) - min(t_rel)` | Separates short transients vs long events and measures overall monitoring duration |
| `total_time_rest` | Rest-frame time baseline: `total_time_obs / (1+z)` | Removes time dilation so the model compares intrinsic evolution speed across redshifts |

### Flux distribution (dust-corrected `flux_corr`)

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `flux_mean` | Mean corrected flux | Measures average intrinsic brightness level (sensitive to sustained high flux) |
| `flux_median` | Median corrected flux | Robust typical brightness baseline (less sensitive to one-off spikes) |
| `flux_std` | Standard deviation of corrected flux | Captures variability strength (high = more change over time) |
| `flux_min` | Minimum corrected flux | Captures deep fades / dips / negative excursions from noise-subtraction artifacts |
| `flux_max` | Maximum corrected flux | Captures peak brightness or flare intensity (key transient signature) |
| `flux_mad` | Median absolute deviation of corrected flux | Robust variability estimate that is not dominated by outliers |
| `flux_iqr` | Interquartile range of corrected flux | Another robust variability measure (spread of the middle 50%) |
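| `flux_skew` | Skewness of corrected flux distribution | Detects asymmetric flux distributions (e.g., a few bright flare points over a faint baseline); it ignores time order, so it does not directly measure rise/decay asymmetry |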
| `flux_kurt_excess` | Excess kurtosis of corrected flux distribution | Detects heavy tails/spiky behavior from rare bursts or sharp transients |
| `flux_p5` | 5th percentile of corrected flux | Robust low-end level (less sensitive than min) |
| `flux_p25` | 25th percentile of corrected flux | Lower-quartile level |
| `flux_p75` | 75th percentile of corrected flux | Upper-quartile level |
| `flux_p95` | 95th percentile of corrected flux | Robust high-end level (less sensitive than max) |
| `robust_amp_global` | Robust amplitude: `flux_p95 - flux_p5` | Outlier-resistant variability scale, often better than max-min |
| `neg_flux_frac` | Fraction of corrected flux values `< 0` | Flags noise-dominated objects or weak detections where measurements hover around zero |

### SNR (using corrected errors `err_corr`)

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `snr_median` | Median SNR where `snr = \|flux_corr\| / (err_corr + EPS)` | Typical detection quality (separates clean signals from noisy junk) |
| `snr_max` | Maximum SNR | Captures the strongest detection event (some transients “light up” briefly) |

### Cadence / gaps

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `median_dt` | Median time gap between consecutive observations in `t_rel` | Describes typical cadence (important since sparse sampling hides shape) |
| `max_gap` | Maximum time gap between consecutive observations in `t_rel` | Detects missing windows (large gaps can explain unreliable peak/width estimates) |

### Time-series variability / shape

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `eta_von_neumann` | Von Neumann eta statistic on `flux_corr` (smoothness vs jumpiness) | Separates smooth evolving curves from noisy jitter or sudden jumps |
| `chi2_const_global` | Reduced chi-square vs a constant model at the median corrected flux, using `err_corr` (divided by `n - 1`) | Quantifies variability relative to measurement noise (true variability vs noise) |
| `stetsonJ_global_obs` | Stetson J (consecutive-pairs) over the time-ordered observations | Measures correlation between neighbouring noise-normalized residuals (coherent trends vs uncorrelated jitter); time spacing itself is not used |
| `stetsonJ_global_rest` | Same Stetson J, called with rest-frame times | The implementation uses the times only to count points, so this value is identical to `stetsonJ_global_obs` (redundant feature) |

### Slopes / rate of change (global)

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `max_slope_global_obs` | Maximum absolute slope in observed time (`t_rel`) | Captures fastest brightness change (sharp rise/fall events) |
| `max_slope_global_rest` | Maximum absolute slope in rest-frame time (`t_rest`) | Intrinsic fastest change rate (removes redshift stretching) |
| `med_abs_slope_global_obs` | Median absolute slope in observed time | Typical observed change rate (slow drifters vs active transients) |
| `med_abs_slope_global_rest` | Median absolute slope in rest-frame time | Typical intrinsic change rate |
| `slope_global_obs` | Best-fit linear slope over observed time | Captures long-term trend direction (rising vs fading overall) |
| `slope_global_rest` | Best-fit linear slope over rest-frame time | Same trend, but comparable across redshifts |

### Fractional variability

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `fvar_global` | Fractional variability accounting for measurement errors | Estimates intrinsic variability strength after subtracting noise contribution |

### Context metadata

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `Z` | Redshift `z` | Encodes distance/epoch effects and shifts events into different observed regimes |
| `log1pZ` | `log(1+z)` | Stabilizes redshift scaling for models (less extreme leverage at high `z`) |
| `Z_err` | Redshift uncertainty (clipped to `>= 0`) | Captures confidence in rest-frame correction; noisy redshifts degrade timing features |
| `log1pZerr` | `log(1+Z_err)` | Stabilizes uncertainty scaling and helps tree models split more smoothly |
| `EBV` | Dust reddening used for extinction correction | Helps the model learn residual dust systematics and measurement conditions |

### Filter coverage

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `n_filters_present` | Number of filters with ≥ 1 observation | Multi-band coverage gives richer color/shape info; missing bands can correlate with class |
| `total_obs` | Total observations summed across all filters (same as `n_obs`) | Redundant but convenient sanity/coverage signal that some models exploit |

## Per-filter (band-wise) features

For each band `b ∈ {u,g,r,i,z,y}`, the following features are computed independently per filter.  
This version adds **pre-peak baseline features** and richer **post-peak decay morphology**.

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `n_{b}` | Number of observations in band `b` | Band missingness and sampling density vary by object/class and affect reliability |
| `baseline_pre_{b}` | Estimated baseline flux before the main peak (from earliest fraction of points) | Gives a cleaner “true baseline” than median when post-peak tail biases the median |
| `amp_{b}` | Peak above median baseline: `peak_flux - median(fb)` | Simple band strength; works even if pre-peak baseline is unreliable |
| `amp_pre_{b}` | Peak above pre-peak baseline: `peak_flux - baseline_pre` | Physically better peak amplitude when baseline is stable; improves peak-related shape features |
| `robust_amp_{b}` | Robust amplitude: `p95_b - p5_b` | More stable amplitude estimate when peaks/outliers are noisy |
| `tpeak_{b}_obs` | Observed-frame time of peak flux in band `b` | Captures when the band reaches maximum brightness (timing is class-dependent) |
| `tpeak_{b}_rest` | Rest-frame time of peak flux: `tpeak_obs / (1+z)` | Removes time dilation so peak timing is comparable across redshifts |
| `snrmax_{b}` | Maximum SNR within band `b` | Strongest detection in that band (some classes peak strongly only in certain filters) |
| `eta_{b}` | Von Neumann eta within band `b` | Detects smooth evolution vs noise inside a single wavelength band |
| `chi2_const_{b}` | Chi-square vs constant-flux model within band | Measures variability significance relative to band-specific noise |
| `slope_{b}_obs` | Best-fit linear slope in band over observed time | Captures overall rise/fade trend per band |
| `slope_{b}_rest` | Best-fit linear slope in band over rest-frame time | Intrinsic trend per band (comparable across redshifts) |
| `maxslope_{b}_obs` | Maximum absolute slope in band (observed time) | Captures sharpest observed change per band |
| `maxslope_{b}_rest` | Maximum absolute slope in band (rest time) | Captures sharpest intrinsic change rate per band |
| `stetsonJ_{b}_obs` | Stetson J (consecutive-pairs) in band over its time-ordered observations | Per-band correlation of neighbouring residuals; time spacing itself is not used |
| `stetsonJ_{b}_rest` | Same per-band Stetson J, called with rest-frame times | Times only set the point count, so this equals `stetsonJ_{b}_obs` (redundant feature) |
| `fvar_{b}` | Fractional variability within band (noise-corrected) | Intrinsic variability strength per band |
| `p5_{b}` | 5th percentile of band flux `fb` | Robust low-end level per band |
| `p25_{b}` | 25th percentile of `fb` | Lower-quartile level per band |
| `p75_{b}` | 75th percentile of `fb` | Upper-quartile level per band |
| `p95_{b}` | 95th percentile of `fb` | Robust high-end level per band |
| `mad_{b}` | Median absolute deviation of `fb` | Robust band variability (outlier-resistant) |
| `iqr_{b}` | Interquartile range of `fb` | Robust spread of the middle 50% per band |
| `mad_over_std_{b}` | `mad_b / (std_b + EPS)` | Flags spike-dominated vs Gaussian-like variability (robustness/shape cue) |

### Post-peak fall times, widths, and sharpness (only if `amp_pre_{b} > 0`)

These use `baseline_pre_{b}` and `amp_pre_{b}` to define levels as fractions of the peak amplitude.

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `t_fall50_{b}_obs` | Observed-frame time after peak to reach `baseline_pre + 0.50 * amp_pre` | Encodes decay speed in observed time (fast vs slow fall) |
| `t_fall20_{b}_obs` | Observed-frame time after peak to reach `baseline_pre + 0.20 * amp_pre` | Longer-tail decay behavior; distinguishes slow fade vs quick drop |
| `t_fall50_{b}_rest` | Rest-frame fall time to 50% level | Intrinsic decay speed comparable across redshifts |
| `t_fall20_{b}_rest` | Rest-frame fall time to 20% level | Intrinsic late-time fading timescale |
| `width50_{b}_obs` | Observed-frame width above 50% level (time span where `fb >= base + 0.5*amp`) | Measures how long the event stays bright in observed time |
| `width80_{b}_obs` | Observed-frame width above 80% level | Captures core peak width (sharp vs broad peak) |
| `width50_{b}_rest` | Rest-frame width above 50% level | Intrinsic duration at mid-brightness |
| `width80_{b}_rest` | Rest-frame width above 80% level | Intrinsic core-peak duration |
| `sharp50_{b}_obs` | Sharpness proxy: `amp_pre / (width50_obs + EPS)` | High = tall + narrow peaks (very class-discriminative) |
| `sharp50_{b}_rest` | Sharpness proxy in rest-frame | Same idea, but intrinsic (less redshift-biased) |
| `auc_pos_{b}_obs` | Observed-frame AUC above `baseline_pre`: `∫ max(fb - baseline_pre, 0) dt` | Energy-like summary tied to true baseline, not median-biased |
| `auc_pos_{b}_rest` | Rest-frame AUC above `baseline_pre` | Comparable across redshifts; strong spectral-energy cue |

### Peak structure and post-peak behavior (only if `amp_pre_{b} > 0`)

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `peak_dominance_{b}` | `amp_pre / (mad_pre + EPS)` where `mad_pre` is pre-peak baseline MAD | Measures how dominant the peak is relative to baseline noise (real transients pop out) |
| `std_ratio_prepost_{b}` | `std(pre_seg) / (std(post_seg) + EPS)` | Captures how variability changes after peak (e.g., noisy baseline vs smooth decay) |
| `n_peaks_{b}` | Count of strict local maxima exceeding `baseline_pre + PEAK_SIGMA_K × median(band flux error)` | Distinguishes single-peaked transients from multi-peaked/variable sources |
| `postpeak_monotone_frac_{b}` | Fraction of post-peak consecutive steps with a negative slope | Smooth decays (high) vs rebrightening/AGN-like variability (low) |
| `n_rebrighten_{b}` | Number of post-peak upward crossings of `baseline_pre + REBRIGHT_FRAC × amp_pre` | Strong discriminator: rebrightening often means non-simple transient behavior |

### Decay power-law fit (post-peak, only if `amp_pre_{b} > 0`)

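A power-law fit `log(f - baseline_pre) = a + b·log(Δt)` is attempted on post-peak points with `0 < Δt ≤ tmax` (default `tmax = 300`, in the same time units as the input) and flux above `baseline_pre`; at least 4 such points are required.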

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `decay_pl_slope_{b}_obs` | Fitted power-law decay slope in observed time | Encodes decay physics/shape; different classes have different decay slopes |
| `decay_pl_r2_{b}_obs` | R² of the observed-frame power-law fit | Measures how well a clean power-law explains the decay (clean transient vs messy variability) |
| `decay_pl_npts_{b}_obs` | Number of points used in the observed-frame decay fit | Reliability indicator: more points = more trustworthy slope |
| `decay_pl_slope_{b}_rest` | Fitted power-law decay slope in rest-frame time | Intrinsic decay slope, comparable across redshifts |
| `decay_pl_r2_{b}_rest` | R² of the rest-frame power-law fit | Fit quality after time dilation correction |
| `decay_pl_npts_{b}_rest` | Number of points used in the rest-frame decay fit | Reliability indicator in rest-frame |

## Multi-band peak timing dispersion

These summarize how synchronized (or not) the band peaks are across filters.

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `tpeak_std_obs` | Standard deviation of `tpeak_b_obs` across bands with peaks | Measures chromatic timing spread in observed time (class-dependent) |
| `tpeak_std_rest` | Standard deviation of `tpeak_b_rest` across bands with peaks | Intrinsic chromatic peak spread (less redshift-biased) |

## Cross-band pair features (adjacent pairs: `ug, gr, ri, iz, zy`)

For each adjacent filter pair `(a,b)`, these compare peak timing and peak flux ratios across wavelengths.

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `tpeakdiff_{a}{b}_obs` | Observed-frame peak time difference: `tpeak_a_obs - tpeak_b_obs` | Chromatic peak lag/lead in observed time (includes cadence + dilation effects) |
| `tpeakdiff_{a}{b}_rest` | Rest-frame peak time difference: `tpeak_a_rest - tpeak_b_rest` | Intrinsic chromatic lag/lead; strong class signature (blue earlier than red, etc.) |
| `peakratio_{a}{b}` | Peak flux ratio: `peak_flux_a / (peak_flux_b + EPS)` | Strong color/SED proxy without needing explicit magnitudes |

## Color features at r-peak (observed-frame) + 20/40-day evolution

These interpolate `g`, `r`, `i` flux at the observed time when the r-band peaks (`tpeak_r_obs`), then compute log-flux colors.  
They also sample the same colors at `+20` and `+40` days to capture cooling/heating trends.

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `color_gr_at_rpeak_obs` | `log1p(f_g) - log1p(f_r)` evaluated at `tpeak_r_obs` | Measures g-r color at peak, highly class-dependent |
| `color_ri_at_rpeak_obs` | `log1p(f_r) - log1p(f_i)` evaluated at `tpeak_r_obs` | Measures r-i color at peak (temperature / SED proxy) |
| `color_gr_rpeak_p20_obs` | g-r color at `tpeak_r_obs + 20` days | Captures medium-term color evolution after peak |
| `color_ri_rpeak_p20_obs` | r-i color at `tpeak_r_obs + 20` days | Same, for redder color index |
| `color_gr_rpeak_p40_obs` | g-r color at `tpeak_r_obs + 40` days | Captures longer-term cooling/heating behavior |
| `color_ri_rpeak_p40_obs` | r-i color at `tpeak_r_obs + 40` days | Longer-term evolution in redder bands |
| `color_gr_slope20_obs` | `(color_gr(+20) - color_gr(0)) / 20` | Rate of color change over 20 days (cooling slope) |
| `color_ri_slope20_obs` | `(color_ri(+20) - color_ri(0)) / 20` | Rate of red color change over 20 days |
| `color_gr_slope40_obs` | `(color_gr(+40) - color_gr(0)) / 40` | Rate of color change over 40 days (more stable, less noisy) |
| `color_ri_slope40_obs` | `(color_ri(+40) - color_ri(0)) / 40` | Rate of red color change over 40 days |

## SpecType teacher stacking features (high-level)

`add_spectype_teacher_features()` adds *legal stacking* features by training a multiclass model on **train only** to predict a grouped version of `SpecType`, then appending the predicted class probabilities as new features.

Key steps:
- Map `SpecType` → `SpecTypeGroup` (TDE, AGN, SNIa, SNother, Other)
- Train a LightGBM multiclass model using CV splits by `split`
- Create:
  - OOF probabilities for train
  - full-fit probabilities for test
- Append probabilities + entropy as new features

### Per-class probability features

For every group label `c` in `classes`:

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `p_spec_{c}` | Predicted probability the object belongs to SpecTypeGroup `c` | Injects a strong “soft label” summary of transient type, which improves the final binary classifier |

### Probability-uncertainty feature

| Feature | Meaning | Why it helps |
|--------|---------|--------------|
| `spec_entropy` | Entropy of the teacher probability vector | High entropy = teacher unsure (ambiguous object); low entropy = confident type signal (more reliable stacking) |

In [8]:
def extract_features_for_object(lc_raw, z, z_err, ebv):
    """Compute a flat dict of lightcurve features for a single object.

    Parameters
    ----------
    lc_raw : pandas.DataFrame
        Lightcurve rows of one object. Columns "Time (MJD)" and "Filter" are
        read here; the flux / flux-error columns are read by
        `deextinct_lightcurve` (names not visible in this block — TODO confirm).
    z, z_err, ebv
        Redshift, redshift uncertainty and E(B-V) from the object log, coerced
        with `safe_float` (z and z_err default to 0.0, ebv to NaN).

    Returns
    -------
    dict
        Feature name -> value. An empty lightcurve yields only {"n_obs": 0}.
        Otherwise: global flux / variability statistics, redshift and EBV,
        per-band features for every band in FILTERS (NaN when the band is
        missing or its shape features cannot be computed), cross-band
        peak-timing / peak-ratio features, and colors around the r-band peak.
    """
    feats = {}

    # Chronological order so that diffs, slopes and argmax positions follow time.
    lc = lc_raw.sort_values("Time (MJD)").reset_index(drop=True)

    t = lc["Time (MJD)"].to_numpy().astype(float)
    filt = lc["Filter"].to_numpy()

    # No observations: report only n_obs; every other feature key is absent.
    if len(t) == 0:
        feats["n_obs"] = 0
        return feats

    z = safe_float(z, default=0.0)
    z_err = safe_float(z_err, default=0.0)
    ebv = safe_float(ebv, default=np.nan)

    # Observer-frame days since the first observation, and the same axis
    # divided by (1 + z) for the rest-frame variants of timing features.
    t_rel = t - t.min()
    t_rest = t_rel / (1.0 + z)

    # NOTE(review): helper defined elsewhere; presumably returns fluxes and
    # errors corrected for Galactic extinction using `ebv`, aligned with the
    # rows of `lc` — confirm.
    flux_corr, err_corr = deextinct_lightcurve(lc, ebv)

    # ---- Global statistics (all bands pooled) ----
    feats["n_obs"] = int(len(t))
    feats["total_time_obs"] = float(t_rel.max() - t_rel.min())
    feats["total_time_rest"] = float(t_rest.max() - t_rest.min())

    feats["flux_mean"] = float(np.mean(flux_corr))
    feats["flux_median"] = float(np.median(flux_corr))
    feats["flux_std"] = float(np.std(flux_corr))
    feats["flux_min"] = float(np.min(flux_corr))
    feats["flux_max"] = float(np.max(flux_corr))

    feats["flux_mad"] = median_abs_dev(flux_corr)
    feats["flux_iqr"] = iqr(flux_corr)
    feats["flux_skew"] = skewness(flux_corr)
    feats["flux_kurt_excess"] = kurtosis_excess(flux_corr)

    p5, p25, p75, p95 = np.percentile(flux_corr, [5, 25, 75, 95])
    feats["flux_p5"] = float(p5)
    feats["flux_p25"] = float(p25)
    feats["flux_p75"] = float(p75)
    feats["flux_p95"] = float(p95)
    feats["robust_amp_global"] = float(p95 - p5)

    feats["neg_flux_frac"] = float(np.mean(flux_corr < 0))

    # Per-point |flux| / error; EPS guards against zero errors.
    snr = np.abs(flux_corr) / (err_corr + EPS)
    feats["snr_median"] = float(np.median(snr))
    feats["snr_max"] = float(np.max(snr))

    # Cadence statistics (observer frame) over all bands.
    if len(t_rel) >= 2:
        dt = np.diff(t_rel)
        feats["median_dt"] = float(np.median(dt))
        feats["max_gap"] = float(np.max(dt))
    else:
        feats["median_dt"] = np.nan
        feats["max_gap"] = np.nan

    feats["eta_von_neumann"] = von_neumann_eta(flux_corr)
    feats["chi2_const_global"] = chi2_to_constant(flux_corr, err_corr)

    feats["stetsonJ_global_obs"] = stetson_J_consecutive(t_rel, flux_corr, err_corr)
    feats["stetsonJ_global_rest"] = stetson_J_consecutive(t_rest, flux_corr, err_corr)

    feats["max_slope_global_obs"] = max_slope(t_rel, flux_corr)
    feats["max_slope_global_rest"] = max_slope(t_rest, flux_corr)

    feats["med_abs_slope_global_obs"] = median_abs_slope(t_rel, flux_corr)
    feats["med_abs_slope_global_rest"] = median_abs_slope(t_rest, flux_corr)

    feats["slope_global_obs"] = linear_slope(t_rel, flux_corr)
    feats["slope_global_rest"] = linear_slope(t_rest, flux_corr)

    feats["fvar_global"] = fractional_variability(flux_corr, err_corr)

    # Redshift / extinction metadata as features (log1p versions clip negatives to 0).
    feats["Z"] = float(z)
    feats["log1pZ"] = float(np.log1p(max(0.0, z)))
    feats["Z_err"] = float(max(0.0, z_err))
    feats["log1pZerr"] = float(np.log1p(max(0.0, feats["Z_err"])))
    feats["EBV"] = ebv

    # Counters incremented inside the band loop below.
    feats["n_filters_present"] = 0
    feats["total_obs"] = 0

    # Per-band peak info, kept for the cross-band and color features after the loop.
    band_tpeak_obs = {}
    band_tpeak_rest = {}
    band_peak_flux = {}
    band_tb_obs = {}
    band_tb_rest = {}
    band_fb = {}

    # ---- Per-band features ----
    for b in FILTERS:
        m = (filt == b)
        nb = int(np.sum(m))
        feats[f"n_{b}"] = nb
        feats["total_obs"] += nb

        # Pre-initialise every per-band key to NaN so all objects expose the
        # same columns whether or not the band is observed.
        for k in [
            f"amp_{b}",
            f"amp_pre_{b}",
            f"baseline_pre_{b}",
            f"robust_amp_{b}",
            f"tpeak_{b}_obs",
            f"tpeak_{b}_rest",
            f"width50_{b}_obs",
            f"width50_{b}_rest",
            f"width80_{b}_obs",
            f"width80_{b}_rest",
            f"auc_pos_{b}_obs",
            f"auc_pos_{b}_rest",
            f"snrmax_{b}",
            f"eta_{b}",
            f"chi2_const_{b}",
            f"slope_{b}_obs",
            f"slope_{b}_rest",
            f"maxslope_{b}_obs",
            f"maxslope_{b}_rest",
            f"stetsonJ_{b}_obs",
            f"stetsonJ_{b}_rest",
            f"p5_{b}",
            f"p25_{b}",
            f"p75_{b}",
            f"p95_{b}",
            f"mad_{b}",
            f"iqr_{b}",
            f"mad_over_std_{b}",
            f"fvar_{b}",
            f"t_fall50_{b}_obs",
            f"t_fall20_{b}_obs",
            f"t_fall50_{b}_rest",
            f"t_fall20_{b}_rest",
            f"sharp50_{b}_obs",
            f"sharp50_{b}_rest",
            f"peak_dominance_{b}",
            f"std_ratio_prepost_{b}",
            f"n_peaks_{b}",
            f"postpeak_monotone_frac_{b}",
            f"n_rebrighten_{b}",
            f"decay_pl_slope_{b}_obs",
            f"decay_pl_r2_{b}_obs",
            f"decay_pl_npts_{b}_obs",
            f"decay_pl_slope_{b}_rest",
            f"decay_pl_r2_{b}_rest",
            f"decay_pl_npts_{b}_rest",
        ]:
            feats[k] = np.nan

        # Band not observed: its keys stay NaN and it is not stored for cross-band use.
        if nb == 0:
            continue

        feats["n_filters_present"] += 1

        tb_obs = t_rel[m]
        fb = flux_corr[m]
        eb = err_corr[m]

        # Defensive re-sort by time within the band.
        order = np.argsort(tb_obs)
        tb_obs = tb_obs[order]
        fb = fb[order]
        eb = eb[order]

        tb_rest = tb_obs / (1.0 + z)

        p5b, p25b, p75b, p95b = np.percentile(fb, [5, 25, 75, 95])
        feats[f"p5_{b}"] = float(p5b)
        feats[f"p25_{b}"] = float(p25b)
        feats[f"p75_{b}"] = float(p75b)
        feats[f"p95_{b}"] = float(p95b)
        feats[f"robust_amp_{b}"] = float(p95b - p5b)

        feats[f"mad_{b}"] = median_abs_dev(fb)
        feats[f"iqr_{b}"] = iqr(fb)
        stdb = float(np.std(fb))
        feats[f"mad_over_std_{b}"] = float(feats[f"mad_{b}"] / (stdb + EPS))

        # NOTE(review): helper defined elsewhere; presumably returns (baseline
        # level, MAD, median error) of the earliest PRE_BASE_FRAC of points — confirm.
        base_pre, mad_pre, mederr_pre = pre_peak_baseline(tb_obs, fb, eb, frac=PRE_BASE_FRAC)
        feats[f"baseline_pre_{b}"] = float(base_pre) if np.isfinite(base_pre) else np.nan

        # Peak = brightest sample in this band (no curve fitting).
        pidx = int(np.argmax(fb))
        peak_flux = float(fb[pidx])
        tpeak_obs = float(tb_obs[pidx])
        tpeak_rest = float(tb_rest[pidx])

        # Amplitude above the band median, and above the pre-peak baseline.
        amp_median = peak_flux - float(np.median(fb))
        amp_pre = peak_flux - base_pre if np.isfinite(base_pre) else np.nan

        feats[f"amp_{b}"] = float(amp_median)
        feats[f"amp_pre_{b}"] = float(amp_pre) if np.isfinite(amp_pre) else np.nan

        feats[f"tpeak_{b}_obs"] = tpeak_obs
        feats[f"tpeak_{b}_rest"] = tpeak_rest
        feats[f"snrmax_{b}"] = float(np.max(np.abs(fb) / (eb + EPS)))

        feats[f"eta_{b}"] = von_neumann_eta(fb)
        feats[f"chi2_const_{b}"] = chi2_to_constant(fb, eb)

        feats[f"slope_{b}_obs"] = linear_slope(tb_obs, fb)
        feats[f"slope_{b}_rest"] = linear_slope(tb_rest, fb)

        feats[f"maxslope_{b}_obs"] = max_slope(tb_obs, fb)
        feats[f"maxslope_{b}_rest"] = max_slope(tb_rest, fb)

        feats[f"stetsonJ_{b}_obs"] = stetson_J_consecutive(tb_obs, fb, eb)
        feats[f"stetsonJ_{b}_rest"] = stetson_J_consecutive(tb_rest, fb, eb)

        feats[f"fvar_{b}"] = fractional_variability(fb, eb)

        # Shape features need a positive amplitude above a finite pre-peak
        # baseline; otherwise they keep their NaN defaults.
        if np.isfinite(amp_pre) and amp_pre > 0:
            feats[f"peak_dominance_{b}"] = float(amp_pre / (mad_pre + EPS))

            # When pidx < 2 this slice is fb[:2], so it includes the peak sample
            # (and, for pidx == 0, one post-peak sample).
            pre_seg = fb[:max(2, pidx)]
            post_seg = fb[pidx:]
            std_pre = float(np.std(pre_seg)) if len(pre_seg) >= 2 else np.nan
            std_post = float(np.std(post_seg)) if len(post_seg) >= 2 else np.nan
            if np.isfinite(std_pre) and np.isfinite(std_post):
                feats[f"std_ratio_prepost_{b}"] = float(std_pre / (std_post + EPS))

            feats[f"postpeak_monotone_frac_{b}"] = float(postpeak_monotonicity(tb_obs, fb, pidx))
            feats[f"n_peaks_{b}"] = float(count_significant_peaks(tb_obs, fb, eb, base_pre, k_sigma=PEAK_SIGMA_K))
            feats[f"n_rebrighten_{b}"] = float(count_rebrighten(tb_obs, fb, base_pre, amp_pre, pidx, frac=REBRIGHT_FRAC))
            feats[f"t_fall50_{b}_obs"] = float(fall_time_to_level(tb_obs, fb, base_pre, amp_pre, pidx, frac=0.50))
            feats[f"t_fall20_{b}_obs"] = float(fall_time_to_level(tb_obs, fb, base_pre, amp_pre, pidx, frac=0.20))
            feats[f"t_fall50_{b}_rest"] = float(fall_time_to_level(tb_rest, fb, base_pre, amp_pre, pidx, frac=0.50))
            feats[f"t_fall20_{b}_rest"] = float(fall_time_to_level(tb_rest, fb, base_pre, amp_pre, pidx, frac=0.20))
            # Area of the flux above the pre-peak baseline (negative excursions clipped to 0).
            feats[f"auc_pos_{b}_obs"] = float(trapz_safe(np.maximum(fb - base_pre, 0.0), tb_obs))
            feats[f"auc_pos_{b}_rest"] = float(trapz_safe(np.maximum(fb - base_pre, 0.0), tb_rest))

            # Time span between the first and last samples at or above
            # base + frac * amp. Samples in between that dip below the level are
            # not excluded, so this is an outer width. NaN if amp <= 0 or < 3 points.
            def width_at_level(tt, ff, base, amp, frac):
                if amp <= 0 or len(ff) < 3:
                    return np.nan
                level = base + frac * amp
                above = ff >= level
                if not np.any(above):
                    return np.nan
                idx = np.where(above)[0]
                return float(tt[idx[-1]] - tt[idx[0]])

            w50_obs = width_at_level(tb_obs, fb, base_pre, amp_pre, 0.50)
            w80_obs = width_at_level(tb_obs, fb, base_pre, amp_pre, 0.80)
            w50_rest = width_at_level(tb_rest, fb, base_pre, amp_pre, 0.50)
            w80_rest = width_at_level(tb_rest, fb, base_pre, amp_pre, 0.80)

            feats[f"width50_{b}_obs"] = w50_obs
            feats[f"width80_{b}_obs"] = w80_obs
            feats[f"width50_{b}_rest"] = w50_rest
            feats[f"width80_{b}_rest"] = w80_rest

            # Amplitude per unit half-max width.
            feats[f"sharp50_{b}_obs"] = float(amp_pre / (w50_obs + EPS)) if np.isfinite(w50_obs) else np.nan
            feats[f"sharp50_{b}_rest"] = float(amp_pre / (w50_rest + EPS)) if np.isfinite(w50_rest) else np.nan

            # NOTE(review): helper defined elsewhere; returns (slope, r2, npts) as
            # unpacked here — presumably a power-law fit to the post-peak decay
            # limited to tmax days after the peak — confirm.
            b_obs, r2_obs, npts_obs = decay_powerlaw_fit(tb_obs, fb, base_pre, pidx, tmax=300.0)
            b_rest, r2_rest, npts_rest = decay_powerlaw_fit(tb_rest, fb, base_pre, pidx, tmax=300.0)

            feats[f"decay_pl_slope_{b}_obs"] = b_obs
            feats[f"decay_pl_r2_{b}_obs"] = r2_obs
            feats[f"decay_pl_npts_{b}_obs"] = float(npts_obs)

            feats[f"decay_pl_slope_{b}_rest"] = b_rest
            feats[f"decay_pl_r2_{b}_rest"] = r2_rest
            feats[f"decay_pl_npts_{b}_rest"] = float(npts_rest)

        # Stored for every observed band (nb > 0), regardless of amp_pre.
        band_tpeak_obs[b] = tpeak_obs
        band_tpeak_rest[b] = tpeak_rest
        band_peak_flux[b] = peak_flux
        band_tb_obs[b] = tb_obs
        band_tb_rest[b] = tb_rest
        band_fb[b] = fb

    # ---- Cross-band features ----
    # Spread of per-band peak times (needs at least two observed bands).
    tpeaks_obs = [band_tpeak_obs.get(b, np.nan) for b in FILTERS]
    tpeaks_rest = [band_tpeak_rest.get(b, np.nan) for b in FILTERS]

    tpeaks_obs = np.array([x for x in tpeaks_obs if np.isfinite(x)], float)
    tpeaks_rest = np.array([x for x in tpeaks_rest if np.isfinite(x)], float)

    feats["tpeak_std_obs"] = float(np.std(tpeaks_obs)) if len(tpeaks_obs) >= 2 else np.nan
    feats["tpeak_std_rest"] = float(np.std(tpeaks_rest)) if len(tpeaks_rest) >= 2 else np.nan

    # Adjacent-band pairs. NOTE: `tb_obs` / `tb_rest` are reused here as scalar
    # peak times, shadowing the per-band arrays of the same name from the loop above.
    pairs = [("u", "g"), ("g", "r"), ("r", "i"), ("i", "z"), ("z", "y")]
    for a, b in pairs:
        ta_obs, tb_obs = band_tpeak_obs.get(a, np.nan), band_tpeak_obs.get(b, np.nan)
        ta_rest, tb_rest = band_tpeak_rest.get(a, np.nan), band_tpeak_rest.get(b, np.nan)
        pa, pb = band_peak_flux.get(a, np.nan), band_peak_flux.get(b, np.nan)

        feats[f"tpeakdiff_{a}{b}_obs"] = (ta_obs - tb_obs) if (np.isfinite(ta_obs) and np.isfinite(tb_obs)) else np.nan
        feats[f"tpeakdiff_{a}{b}_rest"] = (ta_rest - tb_rest) if (np.isfinite(ta_rest) and np.isfinite(tb_rest)) else np.nan
        feats[f"peakratio_{a}{b}"] = (pa / (pb + EPS)) if (np.isfinite(pa) and np.isfinite(pb)) else np.nan

    # log1p of flux clipped at 0 (NaN passes through). The colors below are
    # differences of these values, i.e. log-flux ratios, not magnitudes.
    def logp(x):
        if np.isnan(x):
            return np.nan
        return float(np.log1p(max(0.0, x)))

    # ---- Colors relative to the r-band peak (observer-frame days) ----
    tpr_obs = feats.get("tpeak_r_obs", np.nan)
    if np.isfinite(tpr_obs):
        # NOTE(review): interp_flux_at_time is defined elsewhere; presumably it
        # returns NaN when a band is missing or t0 is outside its coverage — confirm.
        def colors_at_time(t0):
            fr = interp_flux_at_time(band_tb_obs.get("r", np.array([])), band_fb.get("r", np.array([])), t0)
            fg = interp_flux_at_time(band_tb_obs.get("g", np.array([])), band_fb.get("g", np.array([])), t0)
            fi = interp_flux_at_time(band_tb_obs.get("i", np.array([])), band_fb.get("i", np.array([])), t0)

            cgr = (logp(fg) - logp(fr)) if (np.isfinite(fg) and np.isfinite(fr)) else np.nan
            cri = (logp(fr) - logp(fi)) if (np.isfinite(fr) and np.isfinite(fi)) else np.nan
            return cgr, cri

        cgr0, cri0 = colors_at_time(tpr_obs)
        feats["color_gr_at_rpeak_obs"] = cgr0
        feats["color_ri_at_rpeak_obs"] = cri0

        cgr20, cri20 = colors_at_time(tpr_obs + 20.0)
        cgr40, cri40 = colors_at_time(tpr_obs + 40.0)

        feats["color_gr_rpeak_p20_obs"] = cgr20
        feats["color_ri_rpeak_p20_obs"] = cri20
        feats["color_gr_rpeak_p40_obs"] = cgr40
        feats["color_ri_rpeak_p40_obs"] = cri40

        # Per-day change between two colors; NaN if either is missing.
        def slope(c1, c2, dt):
            if np.isfinite(c1) and np.isfinite(c2):
                return float((c2 - c1) / dt)
            return np.nan

        feats["color_gr_slope20_obs"] = slope(cgr0, cgr20, 20.0)
        feats["color_ri_slope20_obs"] = slope(cri0, cri20, 20.0)
        feats["color_gr_slope40_obs"] = slope(cgr0, cgr40, 40.0)
        feats["color_ri_slope40_obs"] = slope(cri0, cri40, 40.0)

    else:
        # No r-band peak: all color features are NaN.
        feats["color_gr_at_rpeak_obs"] = np.nan
        feats["color_ri_at_rpeak_obs"] = np.nan
        feats["color_gr_rpeak_p20_obs"] = np.nan
        feats["color_ri_rpeak_p20_obs"] = np.nan
        feats["color_gr_rpeak_p40_obs"] = np.nan
        feats["color_ri_rpeak_p40_obs"] = np.nan
        feats["color_gr_slope20_obs"] = np.nan
        feats["color_ri_slope20_obs"] = np.nan
        feats["color_gr_slope40_obs"] = np.nan
        feats["color_ri_slope40_obs"] = np.nan

    return feats

In [9]:
def build_lightcurve_cache(splits, base_dir="data", kind="train"):
    """Load every split's lightcurve CSV once and index its rows by object.

    Reads ``{base_dir}/{split}/{kind}_full_lightcurves.csv`` for each split.

    Returns
    -------
    (lc_cache, idx_cache)
        ``lc_cache[split]`` is the split's full DataFrame;
        ``idx_cache[split]`` maps object_id -> positional row indices into it.
    """
    frames = {}
    row_positions = {}

    for split_name in splits:
        csv_path = f"{base_dir}/{split_name}/{kind}_full_lightcurves.csv"
        frame = pd.read_csv(csv_path)
        frames[split_name] = frame
        row_positions[split_name] = frame.groupby("object_id").indices

    return frames, row_positions


def get_lightcurve(lc_cache, idx_cache, split, object_id):
    """Return the lightcurve rows of `object_id` in `split`, or None if it has none."""
    positions = idx_cache[split].get(object_id)
    return None if positions is None else lc_cache[split].iloc[positions]


def build_feature_table(
    log_df,
    lc_cache,
    idx_cache,
    augment_photoz=False,
    test_zerr_pool=None,
    n_aug=1,
    seed=6
):
    """Build one feature row per object in `log_df` (plus optional photo-z copies).

    Every row carries "object_id", "split", "photoz_aug" (0 for the original
    row) and, when `log_df` has a "target" column, "target". Objects without a
    cached lightcurve get a row with only "n_obs" = 0 plus those keys.

    If `augment_photoz` is set, `log_df` has targets and the filtered
    `test_zerr_pool` (finite, > 0) is non-empty, each object with a lightcurve
    also gets `n_aug` extra rows ("photoz_aug" = 1) whose redshift is jittered by
    N(0, sigma), sigma drawn from the pool and reused as that row's z_err.
    """
    rng = np.random.default_rng(seed)
    has_target = "target" in log_df.columns

    zerr_pool = test_zerr_pool
    if zerr_pool is not None:
        zerr_pool = np.asarray(zerr_pool, float)
        zerr_pool = zerr_pool[np.isfinite(zerr_pool) & (zerr_pool > 0)]

    can_augment = augment_photoz and has_target and (zerr_pool is not None) and (len(zerr_pool) > 0)

    records = []
    for pos in range(len(log_df)):
        rec = log_df.iloc[pos]
        obj_id = rec["object_id"]
        split_name = rec["split"]

        lightcurve = get_lightcurve(lc_cache, idx_cache, split_name, obj_id)

        if lightcurve is None:
            row = {"n_obs": 0}
        else:
            row = extract_features_for_object(
                lc_raw=lightcurve,
                z=rec["Z"],
                z_err=rec.get("Z_err", 0.0),
                ebv=rec["EBV"],
            )
        row["object_id"] = obj_id
        row["split"] = split_name
        row["photoz_aug"] = 0
        if has_target:
            row["target"] = int(rec["target"])
        records.append(row)

        if lightcurve is None or not can_augment:
            continue

        z_true = safe_float(rec["Z"], default=0.0)
        for _ in range(n_aug):
            # RNG order matters for reproducibility: draw sigma, then the jitter.
            sigma = float(rng.choice(zerr_pool))
            z_jittered = max(0.0, z_true + float(rng.normal(0.0, sigma)))

            aug_row = extract_features_for_object(
                lc_raw=lightcurve,
                z=z_jittered,
                z_err=sigma,
                ebv=rec["EBV"],
            )
            aug_row["object_id"] = obj_id
            aug_row["split"] = split_name
            aug_row["target"] = int(rec["target"])
            aug_row["photoz_aug"] = 1
            records.append(aug_row)

    return pd.DataFrame(records)


In [10]:
def clean_features(df, drop_cols):
    """Drop non-feature columns and impute non-finite values.

    +/-inf become NaN, NaN is filled with each numeric column's median
    (computed on `df` itself), and anything still missing (e.g. an all-NaN
    column) becomes 0.0. Returns a new DataFrame; `df` is not modified.

    NOTE(review): medians come from the frame being cleaned, so a test frame is
    imputed with test medians rather than train medians.
    """
    features = df.drop(columns=drop_cols).replace([np.inf, -np.inf], np.nan)
    column_medians = features.median(numeric_only=True)
    return features.fillna(column_medians).fillna(0.0)


def best_threshold_f1(y_true, probs):
    ths = np.linspace(0.01, 0.99, 200)
    f1s = [f1_score(y_true, probs > t, zero_division=0) for t in ths]
    j = int(np.argmax(f1s))
    return float(ths[j]), float(f1s[j])


def best_alpha_and_threshold(y_true, p_xgb, p_lgb):
    """Grid-search the XGB blend weight (0.00 .. 1.00, step 0.01) and threshold for F1.

    The blend is ``alpha * p_xgb + (1 - alpha) * p_lgb``. Only a strictly better
    F1 replaces the current best, so ties keep the smallest alpha.

    Returns
    -------
    (alpha, threshold, f1) as floats.
    """
    best_alpha, best_th, best_f1 = 0.5, 0.5, -1.0

    for weight in np.linspace(0.0, 1.0, 101):
        blended = weight * p_xgb + (1.0 - weight) * p_lgb
        th, f1 = best_threshold_f1(y_true, blended)
        if f1 > best_f1:
            best_alpha, best_th, best_f1 = float(weight), float(th), float(f1)

    return best_alpha, best_th, best_f1


def make_splitter(n_splits, random_state=6):
    """Shuffled StratifiedGroupKFold shared by every CV loop in this notebook."""
    splitter_kwargs = {"n_splits": n_splits, "shuffle": True, "random_state": random_state}
    return StratifiedGroupKFold(**splitter_kwargs)

In [11]:
def add_spectype_teacher_features(train_feat, train_log, test_feat, n_splits=10):
    """
    Legal stacking with a SpecType "teacher" model:
    - Train a multiclass LightGBM model to predict "SpecTypeGroup" on train
    - Generate OOF probs for train, full-fit probs for test
    - Append probs (`p_spec_<class>`) + `spec_entropy` as additional features

    SpecType itself is never a feature; the teacher only sees feature columns
    that exist in both train and test.

    Parameters
    ----------
    train_feat : DataFrame of train features with "object_id", "split", "target"
        (may hold several rows per object, e.g. photo-z augmented copies).
    train_log : DataFrame with one row per "object_id" and a "SpecType" column.
    test_feat : DataFrame of test features with "object_id" and "split".
    n_splits : number of grouped CV folds (groups = "split") for the OOF probs.

    Returns
    -------
    (train_out, test_out)
        New DataFrames with the teacher columns appended. The inputs are not
        modified. Teacher columns already present from an earlier call are
        replaced, never used as teacher inputs.
    """

    def _teacher_cols(frame):
        # Columns written by a previous call of this function.
        return [c for c in frame.columns if c.startswith("p_spec_") or c == "spec_entropy"]

    # `drop` returns new frames: the caller's frames are left untouched, and
    # re-running the cell cannot leak stale teacher outputs into the teacher.
    train_feat = train_feat.drop(columns=_teacher_cols(train_feat))
    test_feat = test_feat.drop(columns=_teacher_cols(test_feat))

    # m:1 check: a duplicated object_id in train_log would add rows and
    # misalign the OOF matrix with train_feat.
    df = train_feat.merge(
        train_log[["object_id", "SpecType"]],
        on="object_id",
        how="left",
        validate="many_to_one",
    )
    spec = df["SpecType"].fillna("Unknown").astype(str)

    def map_group(s):
        if s == "TDE":
            return "TDE"
        if s == "AGN":
            return "AGN"
        if s == "SN Ia" or s.startswith("SN Ia"):
            return "SNIa"
        if s.startswith("SN"):
            return "SNother"
        return "Other"

    spec_group = spec.map(map_group).astype(str)

    classes = sorted(spec_group.unique())
    class_to_idx = {c: i for i, c in enumerate(classes)}
    y_mc = spec_group.map(class_to_idx).to_numpy()

    X_tr = clean_features(df, drop_cols=["object_id", "split", "target", "SpecType"])
    X_te = clean_features(test_feat, drop_cols=["object_id", "split"])
    # LightGBM predicts by column position: align test columns to train order
    # (raises KeyError if a train feature is missing from test).
    X_te = X_te[X_tr.columns]

    groups = df["split"].to_numpy()

    splitter = make_splitter(n_splits, random_state=6)
    split_iter = splitter.split(X_tr, y_mc, groups)

    oof = np.zeros((len(X_tr), len(classes)), dtype=float)

    base = dict(
        objective="multiclass",
        num_class=len(classes),
        metric="multi_logloss",

        n_estimators=20000,
        learning_rate=0.03,

        num_leaves=63,
        min_child_samples=5,

        subsample=0.8,
        subsample_freq=1,
        colsample_bytree=0.8,

        n_jobs=-1,
        random_state=42,
        verbosity=-1,
        force_col_wise=True
    )

    best_iters = []
    for fold, (tr_idx, va_idx) in enumerate(split_iter, 1):
        model = LGBMClassifier(**base)
        model.fit(
            X_tr.iloc[tr_idx], y_mc[tr_idx],
            eval_set=[(X_tr.iloc[va_idx], y_mc[va_idx])],
            eval_metric="multi_logloss",
            callbacks=[lgb.early_stopping(200, verbose=False)]
        )
        oof[va_idx] = model.predict_proba(
            X_tr.iloc[va_idx],
            num_iteration=model.best_iteration_
        )
        best_iters.append(int(model.best_iteration_ or base["n_estimators"]))
        print(f"[SpecType teacher] fold {fold:02d} done")

    # The full-data model has no eval set to early-stop on. Use the mean
    # early-stopped size of the fold models instead of all 20000 trees.
    # Otherwise test probabilities are far more over-confident than the train
    # OOF probabilities, and the main classifier sees differently distributed
    # p_spec_* features at train and test time.
    n_full = max(1, int(round(float(np.mean(best_iters))))) if best_iters else base["n_estimators"]
    model_full = LGBMClassifier(**{**base, "n_estimators": n_full})
    model_full.fit(X_tr, y_mc)
    p_test = model_full.predict_proba(X_te)
    print(f"[SpecType teacher] full fit with {n_full} trees")

    def entropy(p):
        # Shannon entropy (nats) of each probability row; clip avoids log(0).
        p = np.clip(p, 1e-12, 1.0)
        return -np.sum(p * np.log(p), axis=1)

    for i, c in enumerate(classes):
        train_feat[f"p_spec_{c}"] = oof[:, i]
        test_feat[f"p_spec_{c}"] = p_test[:, i]

    train_feat["spec_entropy"] = entropy(oof)
    test_feat["spec_entropy"] = entropy(p_test)

    return train_feat, test_feat

In [12]:
def run_optuna_xgb(
    train_feat,
    n_folds_tune=10,
    timeout_sec=28800,
    study_name="xgb_ap_split_cv_gpu",
    storage="sqlite:///optuna_xgb_ap_gpu.db",
    load_if_exists=True,
    n_trials=999999,
    device="cuda",
):
    """Tune XGBoost hyper-parameters with Optuna, maximizing mean fold average precision.

    CV uses `make_splitter` (StratifiedGroupKFold grouped by "split").
    `scale_pos_weight` is set per fold to neg / pos of that fold's training rows.
    Trials are pruned with a MedianPruner on the running mean AP after each fold.

    Parameters
    ----------
    train_feat : DataFrame of features plus "object_id", "split", "target".
    n_folds_tune : number of CV folds per trial.
    timeout_sec : wall-clock budget passed to `study.optimize`.
    study_name, storage, load_if_exists : forwarded to `optuna.create_study`.
        With `load_if_exists=True` an existing study of the same name is
        resumed, including trials scored on an older feature set, so
        `best_params` may come from a stale trial. Pass a fresh `study_name`
        (or `storage=None` for an in-memory study) whenever the features change.
    n_trials : trial cap; the default effectively makes `timeout_sec` the limit.
    device : XGBoost device ("cuda" or "cpu").

    Returns
    -------
    dict
        `study.best_params`.
    """
    y = train_feat["target"].astype(int).to_numpy()
    groups = train_feat["split"].to_numpy()

    X = clean_features(train_feat, drop_cols=["object_id", "split", "target"])

    def objective(trial):
        params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "random_state": 6,
            "n_jobs": -1,

            "tree_method": "hist",
            "device": device,

            "n_estimators": trial.suggest_int("n_estimators", 800, 8000),
            "learning_rate": trial.suggest_float("learning_rate", 0.003, 0.12, log=True),

            "max_depth": trial.suggest_int("max_depth", 2, 10),
            "min_child_weight": trial.suggest_int("min_child_weight", 1, 40),

            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),

            "gamma": trial.suggest_float("gamma", 0.0, 10.0),
            "reg_alpha": trial.suggest_float("reg_alpha", 0.0, 20.0),
            "reg_lambda": trial.suggest_float("reg_lambda", 0.05, 30.0),

            "max_delta_step": trial.suggest_int("max_delta_step", 0, 10),

            "grow_policy": trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"]),
        }

        # Conditional parameter: only sampled for lossguide trees (hence group=True below).
        if params["grow_policy"] == "lossguide":
            params["max_leaves"] = trial.suggest_int("max_leaves", 16, 256)

        scores = []

        splitter = make_splitter(n_folds_tune, random_state=6)
        split_iter = splitter.split(X, y, groups)

        for fold, (tr_idx, va_idx) in enumerate(split_iter, 1):
            X_tr, y_tr = X.iloc[tr_idx], y[tr_idx]
            X_va, y_va = X.iloc[va_idx], y[va_idx]

            neg = np.sum(y_tr == 0)
            pos = np.sum(y_tr == 1)
            params["scale_pos_weight"] = float(neg / max(1, pos))

            model = XGBClassifier(**params)
            model.fit(X_tr, y_tr, verbose=False)

            probs = model.predict_proba(X_va)[:, 1]
            ap = average_precision_score(y_va, probs)
            scores.append(ap)

            # Report the running mean so the pruner compares like with like.
            trial.report(float(np.mean(scores)), step=fold)
            if trial.should_prune():
                raise optuna.TrialPruned()

        return float(np.mean(scores))

    sampler = optuna.samplers.TPESampler(seed=6, multivariate=True, group=True)
    pruner = optuna.pruners.MedianPruner(n_startup_trials=30, n_warmup_steps=3)

    study = optuna.create_study(
        direction="maximize",
        sampler=sampler,
        pruner=pruner,
        study_name=study_name,
        storage=storage,
        load_if_exists=load_if_exists
    )

    study.optimize(objective, n_trials=n_trials, timeout=timeout_sec)

    print("\nOptuna best AP:", study.best_value)
    print("Best params:")
    for k, v in study.best_params.items():
        print(k, "=", v)

    return study.best_params

In [13]:
def train_full_ensemble(train_feat, xgb_params, n_splits_full=20):
    """Train one XGBoost and one LightGBM model per grouped CV fold, then tune the blend.

    For every fold of `make_splitter(n_splits_full)` (groups = "split"):
      * XGBoost with fixed base settings overridden by `xgb_params`;
      * LightGBM with fixed settings and early stopping on the held-out fold.
    Both use scale_pos_weight = neg / pos of the fold's training rows. The
    out-of-fold probabilities of both models are then searched for the blend
    weight and threshold that maximize F1.

    Returns
    -------
    (xgb_models, lgb_models, alpha_best, th_best)
        Per-fold fitted models, the XGB weight of the blend, and the threshold.

    NOTE(review): LightGBM early-stops on the same fold whose OOF predictions
    are scored, so the OOF F1 (and the tuned alpha / threshold) is somewhat
    optimistic.
    """
    labels = train_feat["target"].astype(int).to_numpy()
    split_groups = train_feat["split"].to_numpy()
    features = clean_features(train_feat, drop_cols=["object_id", "split", "target"])

    xgb_settings = dict(
        objective="binary:logistic",
        eval_metric="logloss",
        random_state=6,
        n_jobs=-1,
        tree_method="hist",
        device="cuda",
    )
    xgb_settings.update(xgb_params)

    lgb_settings = {
        "objective": "binary",
        "boosting_type": "gbdt",
        "n_estimators": 20000,
        "learning_rate": 0.02,
        "num_leaves": 63,
        "max_depth": -1,
        "min_child_samples": 5,
        "subsample": 0.8,
        "subsample_freq": 1,
        "colsample_bytree": 0.8,
        "reg_alpha": 0.0,
        "reg_lambda": 0.0,
        "n_jobs": -1,
        "random_state": 6,
        "verbosity": -1,
    }

    n_rows = len(features)
    oof_xgb = np.zeros(n_rows, dtype=float)
    oof_lgb = np.zeros(n_rows, dtype=float)
    xgb_models, lgb_models = [], []

    folds = make_splitter(n_splits_full, random_state=6).split(features, labels, split_groups)
    for fold_no, (fit_idx, hold_idx) in enumerate(folds, start=1):
        X_fit, y_fit = features.iloc[fit_idx], labels[fit_idx]
        X_hold, y_hold = features.iloc[hold_idx], labels[hold_idx]

        # Class re-weighting from this fold's training rows only.
        n_pos = np.sum(y_fit == 1)
        n_neg = np.sum(y_fit == 0)
        pos_weight = float(n_neg / max(1, n_pos))

        xgb_fold = XGBClassifier(**{**xgb_settings, "scale_pos_weight": pos_weight})
        xgb_fold.fit(X_fit, y_fit, verbose=False)
        prob_xgb = xgb_fold.predict_proba(X_hold)[:, 1]
        oof_xgb[hold_idx] = prob_xgb
        xgb_models.append(xgb_fold)

        lgb_fold = LGBMClassifier(**{**lgb_settings, "scale_pos_weight": pos_weight})
        lgb_fold.fit(
            X_fit,
            y_fit,
            eval_set=[(X_hold, y_hold)],
            eval_metric="binary_logloss",
            callbacks=[lgb.early_stopping(200, verbose=False)],
        )
        prob_lgb = lgb_fold.predict_proba(X_hold, num_iteration=lgb_fold.best_iteration_)[:, 1]
        oof_lgb[hold_idx] = prob_lgb
        lgb_models.append(lgb_fold)

        # Diagnostic only: equal-weight blend on this fold.
        fold_th, fold_f1 = best_threshold_f1(y_hold, 0.5 * prob_xgb + 0.5 * prob_lgb)
        print(f"Fold {fold_no:02d} | temp blend(0.5) best F1={fold_f1:.4f} @ th={fold_th:.3f}")

    alpha_best, th_best, f1_best = best_alpha_and_threshold(labels, oof_xgb, oof_lgb)
    print("\nOOF best alpha:", alpha_best)
    print("OOF best threshold:", th_best)
    print("OOF blended best F1:", f1_best)

    return xgb_models, lgb_models, alpha_best, th_best


def predict_ensemble(test_feat, xgb_models, lgb_models, alpha):
    """Blend fold-averaged positive-class probabilities of the XGB and LGBM models.

    Returns ``alpha * mean_xgb + (1 - alpha) * mean_lgb`` as a 1-D array aligned
    with the rows of `test_feat`.
    """
    features = clean_features(test_feat, drop_cols=["object_id", "split"])

    def _mean_positive_prob(models):
        return np.mean([mdl.predict_proba(features)[:, 1] for mdl in models], axis=0)

    xgb_mean = _mean_positive_prob(xgb_models)
    lgb_mean = _mean_positive_prob(lgb_models)
    return alpha * xgb_mean + (1.0 - alpha) * lgb_mean

In [None]:
USE_SPECTYPE_TEACHER = True

from pathlib import Path
ROOT = Path.cwd().parents[1]
DATA_DIR = ROOT / "data"

train_log = pd.read_csv(DATA_DIR / "train_log.csv")
test_log  = pd.read_csv(DATA_DIR / "test_log.csv")

# Photo-z uncertainty may be absent or partly missing: treat it as 0.0 in both logs.
for log_frame in (train_log, test_log):
    if "Z_err" not in log_frame.columns:
        log_frame["Z_err"] = 0.0
    log_frame["Z_err"] = log_frame["Z_err"].fillna(0.0)

train_splits = sorted(train_log["split"].unique())
test_splits = sorted(test_log["split"].unique())

# Lightcurves come from the same DATA_DIR as the logs. Previously a cwd-relative
# "data" was used here, i.e. a different root than the logs above.
train_lc_cache, train_idx_cache = build_lightcurve_cache(train_splits, base_dir=DATA_DIR, kind="train")
test_lc_cache, test_idx_cache = build_lightcurve_cache(test_splits, base_dir=DATA_DIR, kind="test")

In [15]:
# Photo-z errors seen in the test log; train redshifts are jittered with these
# so train features resemble test conditions (one augmented copy per object).
test_zerr_pool = test_log["Z_err"].dropna().to_numpy()

photoz_aug_cfg = dict(augment_photoz=True, test_zerr_pool=test_zerr_pool, n_aug=1, seed=6)
train_feat = build_feature_table(train_log, train_lc_cache, train_idx_cache, **photoz_aug_cfg)
test_feat = build_feature_table(test_log, test_lc_cache, test_idx_cache)

print("train_feat:", train_feat.shape)
print("test_feat :", test_feat.shape)

train_feat: (6086, 353)
test_feat : (7135, 352)


In [13]:
# The USE_SPECTYPE_TEACHER flag from the data-loading cell was previously ignored.
if USE_SPECTYPE_TEACHER:
    train_feat, test_feat = add_spectype_teacher_features(train_feat, train_log, test_feat, n_splits=10)
best_xgb_params = run_optuna_xgb(train_feat, n_folds_tune=10, timeout_sec=28800)

[SpecType teacher] fold 01 done
[SpecType teacher] fold 02 done
[SpecType teacher] fold 03 done
[SpecType teacher] fold 04 done
[SpecType teacher] fold 05 done
[SpecType teacher] fold 06 done
[SpecType teacher] fold 07 done
[SpecType teacher] fold 08 done
[SpecType teacher] fold 09 done
[SpecType teacher] fold 10 done


  optuna_warn(
  optuna_warn(
[32m[I 2026-01-23 02:30:22,444][0m Using an existing study with name 'xgb_ap_split_cv_gpu' instead of creating a new one.[0m
Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.


  return func(**kwargs)
[32m[I 2026-01-23 02:36:52,974][0m Trial 54 finished with value: 0.5720202657858958 and parameters: {'n_estimators': 5517, 'learning_rate': 0.007468075888488339, 'max_depth': 7, 'min_child_weight': 1, 'subsample': 0.535318379384974, 'colsample_bytree': 0.8508864478794678, 'gamma': 1.1792960319994241, 'reg_alpha': 4.043209140453466, 'reg_lambda': 8.834042521353396, 'max_delta_step': 2, 'grow_policy': 'lossguide', 'max_leaves': 256}. Best is trial 54 with value: 0.5720202657858958.[0m
[32m[I 2026-01-23 02:43:39,548][0m Trial 55 finished with value: 0.57956268187457 and parameters: {'n_estimators': 5839, 'learning_rate': 0.005438201455415597, 'max_depth


Optuna best AP: 0.6134734232399863
Best params:
n_estimators = 4770
learning_rate = 0.009408348066891026
max_depth = 5
min_child_weight = 38
subsample = 0.9580408244820326
colsample_bytree = 0.5859885271647445
gamma = 8.679259249940205
reg_alpha = 17.53744145401043
reg_lambda = 24.224933334472816
max_delta_step = 2
grow_policy = depthwise


A LightGBM classifier is trained to predict `SpecTypeGroup`, a method pointed out by a Kaggle user, so that the model focuses on lightcurve features that distinguish TDEs from SNe and AGN. Because SpecType is only available in the train set and not in the test set, a separate model is trained to predict it, and its predicted probabilities are used as additional features.
Trial 204 finished with value: 0.6134734232399863 and parameters: {'n_estimators': 4770, 'learning_rate': 0.009408348066891026, 'max_depth': 5, 'min_child_weight': 38, 'subsample': 0.9580408244820326, 'colsample_bytree': 0.5859885271647445, 'gamma': 8.679259249940205, 'reg_alpha': 17.53744145401043, 'reg_lambda': 24.224933334472816, 'max_delta_step': 2, 'grow_policy': 'depthwise'}

In [14]:
# One grouped CV fold per original data split.
n_split_groups = len(train_splits)
xgb_models, lgb_models, alpha_best, best_th = train_full_ensemble(
    train_feat,
    best_xgb_params,
    n_splits_full=n_split_groups,
)

Fold 01 | temp blend(0.5) best F1=0.7500 @ th=0.502
Fold 02 | temp blend(0.5) best F1=0.5714 @ th=0.143
Fold 03 | temp blend(0.5) best F1=0.6061 @ th=0.202
Fold 04 | temp blend(0.5) best F1=0.0000 @ th=0.010
Fold 05 | temp blend(0.5) best F1=0.5455 @ th=0.305
Fold 06 | temp blend(0.5) best F1=0.6250 @ th=0.463
Fold 07 | temp blend(0.5) best F1=0.7368 @ th=0.301
Fold 08 | temp blend(0.5) best F1=0.6667 @ th=0.143
Fold 09 | temp blend(0.5) best F1=0.8333 @ th=0.562
Fold 10 | temp blend(0.5) best F1=0.9231 @ th=0.438
Fold 11 | temp blend(0.5) best F1=0.3922 @ th=0.246
Fold 12 | temp blend(0.5) best F1=0.7500 @ th=0.315
Fold 13 | temp blend(0.5) best F1=0.6957 @ th=0.291
Fold 14 | temp blend(0.5) best F1=0.5714 @ th=0.148
Fold 15 | temp blend(0.5) best F1=0.6875 @ th=0.355
Fold 16 | temp blend(0.5) best F1=0.0000 @ th=0.010
Fold 17 | temp blend(0.5) best F1=0.6154 @ th=0.399
Fold 18 | temp blend(0.5) best F1=0.5000 @ th=0.566
Fold 19 | temp blend(0.5) best F1=0.8980 @ th=0.443
Fold 20 | te

In [15]:
ensemble_csv = "XGB-LGBM-2.csv"

test_probs = predict_ensemble(test_feat, xgb_models, lgb_models, alpha=alpha_best)
# Same strict ">" rule used when the threshold was tuned on OOF predictions.
test_pred = (test_probs > best_th).astype(int)

sub = pd.DataFrame({"object_id": test_feat["object_id"].values, "target": test_pred})
sub.to_csv(ensemble_csv, index=False)
print("Saved XGB-LGBM-2.csv | alpha:", alpha_best, "| threshold:", best_th)

Saved XGB-LGBM-2.csv | alpha: 0.2 | threshold: 0.18728643216080404


In [17]:
# XGBoost-only baseline (no LightGBM blend).
#  1. Out-of-fold (OOF) probabilities over the grouped folds pick one global F1 threshold.
#  2. A model refit on all training rows scores the test set, using that threshold.
#  3. The result is written as a submission CSV.
# `np`, `pd` and `XGBClassifier` come from the imports cell at the top of the notebook.
# NOTE(review): this cell reassigns `best_th`, `test_probs`, `test_pred` and `sub`, which the
# ensemble cells above also use. Re-run those cells before reusing their results.

SUBMISSION_PATH = "XGB-only3.csv"

y = train_feat["target"].astype(int).to_numpy()
groups = train_feat["split"].to_numpy()
X = clean_features(train_feat, drop_cols=["object_id", "split", "target"])

X_test = clean_features(test_feat, drop_cols=["object_id", "split"])

# Hyperparameters copied from Optuna trial 204 (logged above). They are hard-coded here
# instead of being read from `best_xgb_params`.
xgb_base = {
    "objective": "binary:logistic",
    "eval_metric": "logloss",
    "random_state": 67,
    "n_jobs": -1,
    "tree_method": "hist",
    "device": "cuda",  # switch to "cpu" when no CUDA GPU is available
    "n_estimators" : 4770,
    "learning_rate" : 0.009408348066891026,
    "max_depth" : 5,
    "min_child_weight" : 38,
    "subsample" : 0.9580408244820326,
    "colsample_bytree" : 0.5859885271647445,
    "gamma" : 8.679259249940205,
    "reg_alpha" : 17.53744145401043,
    "reg_lambda" : 24.224933334472816,
    "max_delta_step" : 2,
    "grow_policy" : "depthwise"
}


def _scale_pos_weight(labels):
    """Return the negative/positive count ratio of a 0/1 label array as a float.

    Used as XGBoost's ``scale_pos_weight`` to upweight the rare positive class.
    The denominator is clamped to 1 so a label set without positives does not divide by zero.
    """
    n_neg = np.sum(labels == 0)
    n_pos = np.sum(labels == 1)
    return float(n_neg / max(1, n_pos))


# `groups` (the `split` column) is passed to the splitter so folds can be grouped by split.
splitter = make_splitter(n_splits=len(train_splits), random_state=6)

oof = np.zeros(len(X), dtype=float)
oof_filled = np.zeros(len(X), dtype=bool)

for fold, (tr_idx, va_idx) in enumerate(splitter.split(X, y, groups), 1):
    X_tr, y_tr = X.iloc[tr_idx], y[tr_idx]
    X_va, y_va = X.iloc[va_idx], y[va_idx]

    # Class weight comes from this fold's training labels only.
    model = XGBClassifier(**{**xgb_base, "scale_pos_weight": _scale_pos_weight(y_tr)})
    model.fit(X_tr, y_tr, verbose=False)

    oof[va_idx] = model.predict_proba(X_va)[:, 1]
    oof_filled[va_idx] = True

    th, f1 = best_threshold_f1(y_va, oof[va_idx])
    print(f"Fold {fold:02d} | XGB best F1={f1:.4f} @ th={th:.3f}")

# Rows never placed in a validation fold would keep a fake 0.0 probability and bias the
# threshold search below.
assert oof_filled.all(), "splitter left some training rows without an OOF prediction"

best_th, best_f1 = best_threshold_f1(y, oof)
print("\nOOF XGB best threshold:", best_th)
print("OOF XGB best F1:", best_f1)

# Final model: same hyperparameters, refit on every training row.
final_model = XGBClassifier(**{**xgb_base, "scale_pos_weight": _scale_pos_weight(y)})
final_model.fit(X, y, verbose=False)

test_probs = final_model.predict_proba(X_test)[:, 1]
test_pred = (test_probs > best_th).astype(int)

sub = pd.DataFrame({
    "object_id": test_feat["object_id"].values,
    "target": test_pred
})
sub.to_csv(SUBMISSION_PATH, index=False)
# Use the same constant as to_csv so the log always names the file actually written.
print(f"Saved {SUBMISSION_PATH} | threshold:", best_th)


Fold 01 | XGB best F1=0.8571 @ th=0.571
Fold 02 | XGB best F1=0.4000 @ th=0.537
Fold 03 | XGB best F1=0.4444 @ th=0.335
Fold 04 | XGB best F1=0.0000 @ th=0.010
Fold 05 | XGB best F1=0.5000 @ th=0.488
Fold 06 | XGB best F1=0.6829 @ th=0.463
Fold 07 | XGB best F1=0.6667 @ th=0.468
Fold 08 | XGB best F1=0.5333 @ th=0.389
Fold 09 | XGB best F1=0.6667 @ th=0.517
Fold 10 | XGB best F1=0.6667 @ th=0.586
Fold 11 | XGB best F1=0.3846 @ th=0.207
Fold 12 | XGB best F1=0.6000 @ th=0.246
Fold 13 | XGB best F1=0.7660 @ th=0.502
Fold 14 | XGB best F1=0.6753 @ th=0.374
Fold 15 | XGB best F1=0.6875 @ th=0.547
Fold 16 | XGB best F1=0.0000 @ th=0.010
Fold 17 | XGB best F1=0.5000 @ th=0.251
Fold 18 | XGB best F1=0.5714 @ th=0.670
Fold 19 | XGB best F1=0.8000 @ th=0.424
Fold 20 | XGB best F1=0.5405 @ th=0.286

OOF XGB best threshold: 0.46798994974874375
OOF XGB best F1: 0.5531914893617021
Saved XGB-only2.csv | threshold: 0.46798994974874375
