# Model 5: Teacher Features + Feature Selection + Optuna (OOF F1) + Multiseed XGBoost

This notebook contains Model 5 for the MALLORN challenge, the highest-performing of all the models.

Model 5 builds on the teacher-stacking concept from Model 4, but pushes further in three directions:
1) Richer time-series and cross-band physics-inspired features
2) Feature selection using XGBoost gain importance
3) Direct optimization for OOF F1 using split-aware Optuna, followed by multiseed training/inference
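
A minimal sketch of step 2, assuming gain scores are available as a dict keyed by feature name. XGBoost's `Booster.get_score(importance_type="gain")` returns a dict of this shape and omits features that were never used in a split. The helper name `select_top_k_by_gain`, the toy scores and the value of `k` are illustrative only; the notebook's actual K is not stated here.

```python
def select_top_k_by_gain(gain_scores, feature_names, k):
    """Return the k feature names with the highest gain, best first.

    gain_scores: dict of feature name -> gain. Features missing from the dict
    (never used in a split) are treated as having zero gain.
    """
    scored = [(float(gain_scores.get(name, 0.0)), name) for name in feature_names]
    # Sort by gain (descending), then by name so ties are deterministic.
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [name for _, name in scored[:k]]


# Toy gain scores (made-up numbers for illustration only)
gains = {"amp_r": 12.0, "Z": 30.5, "fvar_g": 4.2}
features = ["amp_r", "Z", "fvar_g", "n_obs"]
top2 = select_top_k_by_gain(gains, features, 2)
print(top2)  # ['Z', 'amp_r']
```

One design consideration: if gain is computed on the full training set before cross-validation, the OOF F1 that Optuna later optimizes can be slightly optimistic.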

## Results

Best parameters:
- n_estimators: 9476
- learning_rate: 0.0024306289953670325
- max_depth: 7
- min_child_weight: 6
- subsample: 0.5344962939912224
- colsample_bytree: 0.464696420753079
- colsample_bylevel: 0.8146569410634974
- colsample_bynode: 0.7285475291884695
- max_bin: 181
- gamma: 8.476938947246458
- reg_alpha: 0.44957196104419117
- reg_lambda: 5.23806334613521
- max_delta_step: 0
- grow_policy: depthwise

OOF multiseed best threshold: 0.419  
OOF multiseed best F1: 0.6243705941591138  
OOF average precision (AP, a summary of the precision-recall curve): 0.5164263302782434  

| Submission | Public LB F1 | Private LB F1 |
|-------------|--------------|----------------|
| 1 | 0.6222 | 0.6231 |
| 2 | 0.6358 | 0.6413 |
| 3 | 0.5830 | 0.5962 |
| 4 | 0.6304 | 0.6491 |
| 5 | 0.5840 | 0.6100 |


Key upgrades vs Model 4:
- Expanded feature set (seasonality, structure function, SED slope fits, Bazin fits + rise/fall asymmetry)
- SpecType teacher grouping expanded (includes SLSN, SNII)
- Missing-value indicators added (`*_isnan`)
- Gain-based Top-K feature selection before Optuna
- Optuna objective returns OOF F1 (not AP/logloss)
- Final prediction is multiseed averaged XGB + global OOF-optimized threshold
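
The last upgrade (multiseed averaging plus one global threshold) can be sketched as follows. This is a minimal illustration, not the notebook's code: it assumes one OOF probability vector per seed is already available, uses synthetic labels and scores, and scans a hypothetical 0.01-spaced grid (the reported threshold of 0.419 implies the notebook used a finer one).

```python
import numpy as np
from sklearn.metrics import average_precision_score, f1_score


def average_seed_predictions(prob_list):
    """Average per-seed probability vectors into one ensemble vector."""
    return np.mean(np.vstack(prob_list), axis=0)


def best_f1_threshold(y_true, y_prob, grid=None):
    """Scan candidate thresholds; return (threshold, F1) with the highest F1."""
    if grid is None:
        grid = np.linspace(0.01, 0.99, 99)  # hypothetical grid, step 0.01
    best_t, best_f1 = 0.5, -1.0
    for t in grid:
        f1 = f1_score(y_true, (y_prob >= t).astype(int), zero_division=0)
        if f1 > best_f1:  # strict ">" keeps the first (lowest) best threshold
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1


# Synthetic, imbalanced labels and three noisy "seed" runs (illustration only)
rng = np.random.default_rng(0)
y = (rng.random(500) < 0.2).astype(int)
seed_probs = [np.clip(0.35 + 0.3 * y + rng.normal(0.0, 0.15, y.size), 0.0, 1.0)
              for _ in range(3)]

oof = average_seed_predictions(seed_probs)
thr, oof_f1 = best_f1_threshold(y, oof)
ap = average_precision_score(y, oof)
print(f"threshold={thr:.2f}  F1={oof_f1:.3f}  AP={ap:.3f}")
```

Because the threshold is chosen on the same OOF predictions it is scored on, the OOF F1 can be slightly optimistic; the leaderboard scores are the held-out check.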

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
from sklearn.metrics import average_precision_score
from sklearn.model_selection import StratifiedGroupKFold
from xgboost import XGBClassifier
from extinction import fitzpatrick99
import lightgbm as lgb
from lightgbm import LGBMClassifier
import optuna
from scipy.optimize import curve_fit
import xgboost as xgb

FILTERS = ["u", "g", "r", "i", "z", "y"]

EFF_WL_AA = {
    "u": 3641.0,
    "g": 4704.0,
    "r": 6155.0,
    "i": 7504.0,
    "z": 8695.0,
    "y": 10056.0,
}

R_V = 3.1
PRE_BASE_FRAC = 0.20
MIN_BAND_POINTS = 5
PEAK_SIGMA_K = 3.0
REBRIGHT_FRAC = 0.30
EPS = 1e-8

SEASON_GAP_DAYS = 90.0

SF_LAGS = [5.0, 10.0, 20.0, 50.0, 100.0]


## Helper Functions

This notebook reuses the same helper functions from prior models (robust stats, slopes, chi2, Stetson J, de-extinction, peak logic, etc.).

New helper functions introduced in Model 5:
- `seasonality_features`
- `structure_function_lags`
- `peak_vs_wavelength_slope`
- `sed_logflux_vs_loglambda_at_time`
- Bazin fit functions (`fit_bazin`, etc.)

Docstrings were generated by AI after the code was written.

In [None]:
def safe_float(x, default=np.nan):
    """Convert value to float safely.

    Returns default if x is None, NaN, or cannot be converted.
    Prevents crashes from bad metadata inputs.
    """
    try:
        if x is None:
            return default
        x = float(x)
        if np.isnan(x):
            return default
        return x
    except Exception:
        return default


def trapz_safe(y, x):
    """Compute trapezoidal integral safely.

    Uses numpy.trapezoid if available, otherwise manual formula.
    Returns NaN if fewer than two x points.
    """
    y = np.asarray(y, float)
    x = np.asarray(x, float)
    # Check length first so both code paths honor the NaN contract
    if len(x) < 2:
        return np.nan
    if hasattr(np, "trapezoid"):
        return float(np.trapezoid(y, x))
    return float(np.sum((x[1:] - x[:-1]) * (y[1:] + y[:-1]) * 0.5))


def median_abs_dev(x):
    """Median absolute deviation (robust spread measure)."""
    x = np.asarray(x)
    if len(x) == 0:
        return np.nan
    med = np.median(x)
    return float(np.median(np.abs(x - med)))


def iqr(x):
    """Interquartile range (Q75 − Q25)."""
    x = np.asarray(x)
    if len(x) < 2:
        return np.nan
    q75, q25 = np.percentile(x, [75, 25])
    return float(q75 - q25)


def skewness(x):
    """Compute skewness (distribution asymmetry)."""
    x = np.asarray(x)
    n = len(x)
    if n < 3:
        return np.nan
    mu = np.mean(x)
    s = np.std(x)
    if s < 1e-12:
        return 0.0
    m3 = np.mean((x - mu) ** 3)
    return float(m3 / (s ** 3))


def kurtosis_excess(x):
    """Compute excess kurtosis (tail heaviness vs normal)."""
    x = np.asarray(x)
    n = len(x)
    if n < 4:
        return np.nan
    mu = np.mean(x)
    s = np.std(x)
    if s < 1e-12:
        return 0.0
    m4 = np.mean((x - mu) ** 4)
    return float(m4 / (s ** 4) - 3.0)


def von_neumann_eta(x):
    """Von Neumann variability ratio using successive differences."""
    x = np.asarray(x)
    n = len(x)
    if n < 3:
        return np.nan
    v = np.var(x)
    if v < 1e-12:
        return 0.0
    dif = np.diff(x)
    return float(np.mean(dif ** 2) / v)


def max_slope(t, f):
    """Maximum absolute slope between consecutive points."""
    t = np.asarray(t)
    f = np.asarray(f)
    if len(t) < 3:
        return np.nan
    dt = np.diff(t)
    df = np.diff(f)
    good = dt > 0
    if not np.any(good):
        return np.nan
    slopes = df[good] / dt[good]
    return float(np.max(np.abs(slopes)))


def median_abs_slope(t, f):
    """Median absolute slope between consecutive points."""
    t = np.asarray(t)
    f = np.asarray(f)
    if len(t) < 3:
        return np.nan
    dt = np.diff(t)
    df = np.diff(f)
    good = dt > 0
    if not np.any(good):
        return np.nan
    slopes = df[good] / dt[good]
    return float(np.median(np.abs(slopes)))


def linear_slope(t, f):
    """Slope of best-fit line f(t) using linear regression."""
    t = np.asarray(t)
    f = np.asarray(f)
    if len(t) < 3:
        return np.nan
    try:
        a, b = np.polyfit(t, f, 1)
        return float(a)
    except Exception:
        return np.nan


def chi2_to_constant(f, ferr):
    """Reduced chi-square vs constant (median) flux model."""
    f = np.asarray(f)
    ferr = np.asarray(ferr)
    n = len(f)
    if n < 3:
        return np.nan
    mu = np.median(f)
    denom = (ferr + EPS) ** 2
    chi2 = np.sum((f - mu) ** 2 / denom)
    dof = max(1, n - 1)
    return float(chi2 / dof)


def interp_flux_at_time(tb, fb, t0):
    """Linearly interpolate flux at time t0 within range."""
    tb = np.asarray(tb)
    fb = np.asarray(fb)
    if len(tb) < 2:
        return np.nan
    if (t0 < tb.min()) or (t0 > tb.max()):
        return np.nan
    return float(np.interp(t0, tb, fb))


def interp_err_at_time(tb, eb, t0):
    """Linearly interpolate flux error at time t0."""
    tb = np.asarray(tb)
    eb = np.asarray(eb)
    if len(tb) < 2:
        return np.nan
    if (t0 < tb.min()) or (t0 > tb.max()):
        return np.nan
    return float(np.interp(t0, tb, eb))


def fractional_variability(f, ferr):
    """Fractional intrinsic variability after error correction."""
    f = np.asarray(f, float)
    ferr = np.asarray(ferr, float)
    n = len(f)
    if n < 3:
        return np.nan
    mu = np.mean(f)
    if np.abs(mu) < 1e-8:
        return np.nan
    s2 = np.var(f, ddof=1)
    mean_err2 = np.mean(ferr ** 2)
    excess = max(0.0, s2 - mean_err2)
    return float(np.sqrt(excess) / np.abs(mu))


def stetson_J_consecutive(t, f, ferr):
    """Stetson J variability index using consecutive pairs."""
    t = np.asarray(t)
    f = np.asarray(f)
    ferr = np.asarray(ferr)
    n = len(t)
    if n < 4:
        return np.nan
    mu = np.mean(f)
    scale = np.sqrt(n / max(1, n - 1))
    delta = scale * (f - mu) / (ferr + EPS)
    vals = []
    for i in range(n - 1):
        P = delta[i] * delta[i + 1]
        vals.append(np.sign(P) * np.sqrt(np.abs(P)))
    return float(np.mean(vals))


def pre_peak_baseline(tb, fb, eb, frac=PRE_BASE_FRAC):
    """Estimate baseline stats from early light-curve segment."""
    tb = np.asarray(tb)
    fb = np.asarray(fb)
    eb = np.asarray(eb)
    n = len(tb)
    if n < 3:
        return np.nan, np.nan, np.nan
    k = max(2, int(np.ceil(frac * n)))
    k = min(k, n)
    base = float(np.median(fb[:k]))
    mad_pre = median_abs_dev(fb[:k])
    mederr_pre = float(np.median(eb[:k])) if k > 0 else np.nan
    return base, mad_pre, mederr_pre


def count_significant_peaks(tb, fb, eb, baseline_pre, k_sigma=PEAK_SIGMA_K):
    """Count local peaks above noise-scaled baseline threshold."""
    tb = np.asarray(tb)
    fb = np.asarray(fb)
    eb = np.asarray(eb)
    n = len(fb)
    if n < 5:
        return 0
    mederr = float(np.median(eb)) if np.isfinite(np.median(eb)) else 0.0
    thresh = baseline_pre + k_sigma * mederr
    peaks = 0
    for i in range(1, n - 1):
        if (fb[i] > fb[i - 1]) and (fb[i] > fb[i + 1]) and (fb[i] > thresh):
            peaks += 1
    return int(peaks)


def postpeak_monotonicity(tb, fb, pidx):
    """Fraction of negative slopes after peak index."""
    tb = np.asarray(tb)
    fb = np.asarray(fb)
    if pidx is None or pidx >= len(fb) - 2:
        return np.nan
    t2 = tb[pidx:]
    f2 = fb[pidx:]
    if len(f2) < 3:
        return np.nan
    dt = np.diff(t2)
    df = np.diff(f2)
    good = dt > 0
    if not np.any(good):
        return np.nan
    return float(np.mean((df[good] / dt[good]) < 0))


def count_rebrighten(tb, fb, baseline_pre, amp, pidx, frac=REBRIGHT_FRAC):
    """Count post-peak upward crossings above fractional amplitude."""
    if pidx is None or pidx >= len(fb) - 2:
        return 0
    level = baseline_pre + frac * amp
    post = fb[pidx:]
    if len(post) < 3:
        return 0
    above = post > level
    crossings = np.sum((~above[:-1]) & (above[1:]))
    return int(crossings)


def fall_time_to_level(tb, fb, baseline_pre, amp, pidx, frac):
    """Time from peak to decay below fractional amplitude."""
    if amp <= 0 or pidx is None:
        return np.nan
    level = baseline_pre + frac * amp
    t_dec = tb[pidx:]
    f_dec = fb[pidx:]
    if len(f_dec) < 2:
        return np.nan
    idx = np.where(f_dec <= level)[0]
    if len(idx) == 0:
        return np.nan
    return float(t_dec[idx[0]] - t_dec[0])


def rise_time_to_level(tb, fb, baseline_pre, amp, pidx, frac):
    """Time to rise from level to peak."""
    if amp <= 0 or pidx is None or pidx < 2:
        return np.nan
    level = baseline_pre + frac * amp
    t_pre = tb[:pidx + 1]
    f_pre = fb[:pidx + 1]
    idx = np.where(f_pre >= level)[0]
    if len(idx) == 0:
        return np.nan
    return float(t_pre[-1] - t_pre[idx[0]])


def decay_powerlaw_fit(tb, fb, baseline_pre, pidx, tmax=300.0):
    """Fit post-peak decay to power-law in log-log space.

    Returns slope, R², and number of points used.
    """
    if pidx is None or pidx >= len(fb) - 3:
        return np.nan, np.nan, 0
    t0 = tb[pidx]
    t_dec = tb[pidx:]
    f_dec = fb[pidx:]
    dt = t_dec - t0
    m = (dt > 0.0) & (dt <= tmax)
    dt = dt[m]
    fd = f_dec[m] - baseline_pre
    m2 = fd > 0.0
    dt = dt[m2]
    fd = fd[m2]
    if len(dt) < 4:
        return np.nan, np.nan, int(len(dt))
    x = np.log(dt + EPS)
    y = np.log(fd + EPS)
    try:
        b, a = np.polyfit(x, y, 1)
    except Exception:
        return np.nan, np.nan, int(len(dt))
    yhat = a + b * x
    ss_res = float(np.sum((y - yhat) ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2)) + EPS
    r2 = 1.0 - ss_res / ss_tot
    return float(b), float(r2), int(len(dt))


def signed_log1p(x):
    """Signed log(1+|x|) transform for scale compression."""
    x = float(x)
    return float(np.sign(x) * np.log1p(np.abs(x)))


def deextinct_band(flux, flux_err, ebv, band, r_v=R_V):
    """Apply dust extinction correction to one band."""
    if ebv is None or (isinstance(ebv, float) and np.isnan(ebv)):
        return flux, flux_err
    A_V = float(ebv) * float(r_v)
    wave = np.array([EFF_WL_AA[band]], dtype=float)
    A_lambda = float(fitzpatrick99(wave, A_V, r_v=r_v, unit="aa")[0])
    fac = 10.0 ** (0.4 * A_lambda)
    return flux * fac, flux_err * fac


def deextinct_lightcurve(lc, ebv):
    """Apply extinction correction across all filters."""
    flux = lc["Flux"].to_numpy().astype(float)
    ferr = lc["Flux_err"].to_numpy().astype(float)
    filt = lc["Filter"].to_numpy()
    flux_corr = flux.copy()
    ferr_corr = ferr.copy()
    for b in FILTERS:
        m = (filt == b)
        if not np.any(m):
            continue
        flux_corr[m], ferr_corr[m] = deextinct_band(flux_corr[m], ferr_corr[m], ebv, b)
    return flux_corr, ferr_corr


def band_corr(tt_a, ff_a, tt_b, ff_b, n_grid=30):
    """Correlation between two bands via interpolated grid."""
    tt_a = np.asarray(tt_a, float)
    ff_a = np.asarray(ff_a, float)
    tt_b = np.asarray(tt_b, float)
    ff_b = np.asarray(ff_b, float)

    if len(tt_a) < 3 or len(tt_b) < 3:
        return np.nan

    tmin = max(tt_a.min(), tt_b.min())
    tmax = min(tt_a.max(), tt_b.max())
    if (tmax - tmin) < 5.0:
        return np.nan

    grid = np.linspace(tmin, tmax, n_grid)
    fa = np.interp(grid, tt_a, ff_a)
    fb = np.interp(grid, tt_b, ff_b)

    sa = np.std(fa)
    sb = np.std(fb)
    if sa < 1e-12 or sb < 1e-12:
        return 0.0
    return float(np.corrcoef(fa, fb)[0, 1])


def sigmoid(x):
    """Logistic function; input is clipped to [-60, 60] so np.exp cannot overflow."""
    x = np.clip(x, -60.0, 60.0)
    return 1.0 / (1.0 + np.exp(-x))


def bazin_stable(t, A, t0, trise, tfall, B, eps=EPS):
    """
    Numerically stable Bazin-like function:
        f(t) = A * exp(-(t-t0)/tfall) * sigmoid((t-t0)/trise) + B
    """
    trise = np.maximum(trise, eps)
    tfall = np.maximum(tfall, eps)

    x = (t - t0) / trise
    exp_term = np.exp(np.clip(-(t - t0) / tfall, -60.0, 60.0))
    return A * exp_term * sigmoid(x) + B


def should_fit_bazin(tb, fb, eb, min_points=8, amp_sigma=3.0):
    """
    Gate: only fit when there is enough data and a detectable transient-like signal.
    """
    tb = np.asarray(tb, float)
    fb = np.asarray(fb, float)
    eb = np.asarray(eb, float)

    if len(tb) < min_points:
        return False

    mederr = float(np.median(eb)) if np.isfinite(np.median(eb)) else np.inf
    if not np.isfinite(mederr) or mederr <= 0:
        return False

    amp = float(np.percentile(fb, 95) - np.percentile(fb, 5))
    if not np.isfinite(amp) or amp < amp_sigma * mederr:
        return False

    if float(np.std(fb)) < 1e-10:
        return False

    return True


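def fit_bazin(tb, fb, eb):
    """
    Fit bazin_stable to one band with bounded, error-weighted least squares.

    Returns (A, t0, trise, tfall, B, chi2_red) on success.
    Returns a 6-tuple of NaN if the fit fails or should_fit_bazin rejects the band.
    """
    nan_out = (np.nan, np.nan, np.nan, np.nan, np.nan, np.nan)
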
    tb = np.asarray(tb, float)
    fb = np.asarray(fb, float)
    eb = np.asarray(eb, float)

    order = np.argsort(tb)
    t = tb[order]
    f = fb[order]
    e = eb[order]

    m = np.isfinite(t) & np.isfinite(f) & np.isfinite(e)
    t, f, e = t[m], f[m], e[m]
    if len(t) < 3:
        return nan_out

    e = np.maximum(e, 1e-6)

    if not should_fit_bazin(t, f, e, min_points=8, amp_sigma=3.0):
        return nan_out

    B0 = float(np.median(f))
    A0 = float(max(1e-6, np.percentile(f, 95) - B0))
    t0_0 = float(t[int(np.argmax(f))])
    tr0 = 20.0
    tf0 = 60.0
    p0 = [A0, t0_0, tr0, tf0, B0]

    tmin, tmax = float(t.min()), float(t.max())
    # Local name avoids shadowing the module-level iqr() helper
    f_iqr = float(np.percentile(f, 75) - np.percentile(f, 25))
    amp = float(max(1e-6, np.percentile(f, 95) - np.percentile(f, 5)))

    lo = [0.0, tmin - 50.0, 0.5, 1.0, B0 - 5.0 * (f_iqr + 1e-6)]
    hi = [10.0 * amp, tmax + 50.0, 200.0, 600.0, B0 + 5.0 * (f_iqr + 1e-6)]

    try:
        popt, _ = curve_fit(
            bazin_stable, t, f,
            p0=p0,
            sigma=e,
            absolute_sigma=True,
            bounds=(lo, hi),
            maxfev=5000
        )

        fhat = bazin_stable(t, *popt)
        resid = (f - fhat) / e
        chi2 = float(np.sum(resid * resid))
        dof = max(1, len(t) - len(popt))
        chi2_red = chi2 / dof

        A, t0, trise, tfall, B = [float(x) for x in popt]
        return (A, t0, trise, tfall, B, float(chi2_red))

    except Exception:
        return nan_out

## Seasonality Features

Astronomical time series often have seasonal observing gaps.

This computes:
- `n_seasons`: number of observing segments separated by gaps > `SEASON_GAP_DAYS`
- `season_maxspan`: max time span of any segment
- `season_meanspan`: mean time span of segments

These are useful because:
- Some classes are observed continuously, others only show up in narrow windows
- Seasonal gaps can distort measured peak width/shape
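
The segmentation rule can be checked on a toy time array. This sketch uses `np.split` at the gap indices instead of the explicit start/end loop in `seasonality_features`; the function name `split_into_seasons` and the times are illustrative.

```python
import numpy as np

SEASON_GAP_DAYS = 90.0  # same cut as the notebook constant


def split_into_seasons(t, gap=SEASON_GAP_DAYS):
    """Split sorted observation times into segments at gaps larger than `gap`."""
    t = np.asarray(t, float)
    breaks = np.where(np.diff(t) > gap)[0] + 1  # index where each new season starts
    return np.split(t, breaks)


t = np.array([0.0, 3.0, 10.0, 40.0, 200.0, 205.0, 230.0, 500.0])
seasons = split_into_seasons(t)
spans = [s[-1] - s[0] for s in seasons]
# 3 seasons with spans 40, 30 and 0 days (mean about 23.3 days)
print(len(seasons), spans)
```

A season with a single observation has span 0, which pulls the mean span down; `seasonality_features` behaves the same way.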

In [None]:
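def seasonality_features(tb):
    """Return (n_seasons, max season span, mean season span) for sorted times tb.

    A new season starts wherever consecutive times differ by more than SEASON_GAP_DAYS.
    """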
    tb = np.asarray(tb, float)
    if len(tb) < 2:
        return np.nan, np.nan, np.nan
    dt = np.diff(tb)
    breaks = np.where(dt > SEASON_GAP_DAYS)[0]
    seg_starts = [0] + (breaks + 1).tolist()
    seg_ends = breaks.tolist() + [len(tb) - 1]
    spans = []
    for s, e in zip(seg_starts, seg_ends):
        spans.append(tb[e] - tb[s])
    spans = np.asarray(spans, float)
    n_seasons = float(len(spans))
    return n_seasons, float(np.max(spans)), float(np.mean(spans))

## Structure Function Features (Variability vs Timescale)

A structure function measures typical variability amplitude at different time lags.

For each lag in `SF_LAGS`, compute:
- `sf_medabs_<lag>`: median absolute flux difference for pairs separated by that lag (within ±max(2, 0.2·lag) days)
- `sf_n_<lag>`: number of qualifying pairs

This helps differentiate:
- smooth transients vs noisy stochastic variability (AGN-like)
- fast vs slow-changing behavior
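
A vectorized sketch of the same per-lag statistic, assuming a light curve small enough that enumerating all O(n²) pairs at once is affordable. The tolerance rule matches `structure_function_lags`; taking the absolute time difference (so unsorted input also works) and the name `sf_medabs_at_lag` are choices made for this illustration.

```python
import numpy as np


def sf_medabs_at_lag(t, f, lag, tol=None):
    """Median |f_j - f_i| over pairs whose separation is within lag +/- tol.

    Returns (median, n_pairs); (nan, 0) if no pair qualifies.
    """
    t = np.asarray(t, float)
    f = np.asarray(f, float)
    if tol is None:
        tol = max(2.0, 0.2 * lag)  # same tolerance rule as structure_function_lags
    i, j = np.triu_indices(len(t), k=1)  # every pair with i < j
    dt = np.abs(t[j] - t[i])
    m = (dt >= lag - tol) & (dt <= lag + tol)
    if not np.any(m):
        return np.nan, 0
    return float(np.median(np.abs(f[j[m]] - f[i[m]]))), int(m.sum())


t = np.array([0.0, 5.0, 10.0, 15.0, 20.0])
f = np.array([1.0, 3.0, 2.0, 6.0, 4.0])
sf5 = sf_medabs_at_lag(t, f, 5.0)    # 4 pairs at ~5 d, |diffs| = 2, 1, 4, 2
sf10 = sf_medabs_at_lag(t, f, 10.0)  # 3 pairs at ~10 d, |diffs| = 1, 3, 2
sf50 = sf_medabs_at_lag(t, f, 50.0)  # no pairs that far apart
print(sf5, sf10, sf50)
```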

In [None]:
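def structure_function_lags(tb, fb, lags=SF_LAGS):
    """Median |delta flux| and pair count for each lag in `lags`.

    A pair counts toward a lag if its separation is within max(2, 0.2*lag) days.
    Assumes tb is sorted in increasing time.
    """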
    tb = np.asarray(tb, float)
    fb = np.asarray(fb, float)
    n = len(tb)
    out = {}
    if n < 6:
        for lag in lags:
            out[f"sf_medabs_{int(lag)}"] = np.nan
            out[f"sf_n_{int(lag)}"] = 0.0
        return out

    for lag in lags:
        tol = max(2.0, 0.2 * lag)
        vals = []
        for i in range(n - 1):
            dt = tb[i + 1:] - tb[i]
            m = (dt >= (lag - tol)) & (dt <= (lag + tol))
            if np.any(m):
                dif = np.abs(fb[i + 1:][m] - fb[i])
                vals.extend(dif.tolist())
        if len(vals) == 0:
            out[f"sf_medabs_{int(lag)}"] = np.nan
            out[f"sf_n_{int(lag)}"] = 0.0
        else:
            out[f"sf_medabs_{int(lag)}"] = float(np.median(vals))
            out[f"sf_n_{int(lag)}"] = float(len(vals))
    return out

## Peak Timing / Peak Flux vs Wavelength (Rest-frame)

This fits a simple linear trend across filters in rest-frame wavelength:

- For a per-band value `v_b` (e.g. peak flux), use x = λ_rest = λ_eff / (1 + z) and fit y = slope*x + intercept
- Returns slope/intercept and fit quality (R²)

This captures chromatic behavior:
- some classes peak earlier in blue, later in red
- peak flux distribution across wavelength is a crude "temperature/SED" proxy
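
A worked example with hypothetical g/r/i peak times at z = 0.5, using the notebook's effective wavelengths. It also checks a useful property: dividing every wavelength by (1 + z) only rescales x, so the rest-frame slope equals (1 + z) times the observed-frame slope, while R² is unchanged.

```python
import numpy as np

EFF_WL_AA = {"u": 3641.0, "g": 4704.0, "r": 6155.0, "i": 7504.0, "z": 8695.0, "y": 10056.0}

z = 0.5
peak_time = {"g": 10.0, "r": 12.0, "i": 15.0}  # hypothetical peak times (days)

bands = [b for b in EFF_WL_AA if b in peak_time]  # keeps u..y ordering
lam_obs = np.array([EFF_WL_AA[b] for b in bands])
lam_rest = lam_obs / (1.0 + z)
y = np.array([peak_time[b] for b in bands])

slope, intercept = np.polyfit(lam_rest, y, 1)
slope_obs, _ = np.polyfit(lam_obs, y, 1)

# Positive slope: redder bands peak later
print(f"{slope * 1000:.3f} days per 1000 Angstrom (rest frame)")
```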

In [None]:
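def peak_vs_wavelength_slope(tpeak_by_band, val_by_band, z=0.0):
    """Fit val_by_band[b] linearly against rest-frame effective wavelength.

    Returns (slope, intercept, r2); NaNs if fewer than two bands are finite.
    tpeak_by_band is currently unused.
    """
    xs = []
    ys = []
    for b in FILTERS:
        v = val_by_band.get(b, np.nan)
        if np.isfinite(v):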
            lam = float(EFF_WL_AA[b] / (1.0 + float(z)))
            xs.append(lam)
            ys.append(float(v))
    xs = np.asarray(xs, float)
    ys = np.asarray(ys, float)
    if len(xs) < 2:
        return np.nan, np.nan, np.nan
    try:
        slope, intercept = np.polyfit(xs, ys, 1)
        yhat = slope * xs + intercept
        ss_res = float(np.sum((ys - yhat) ** 2))
        ss_tot = float(np.sum((ys - np.mean(ys)) ** 2)) + EPS
        r2 = 1.0 - ss_res / ss_tot
        return float(slope), float(intercept), float(r2)
    except Exception:
        return np.nan, np.nan, np.nan

## SED Slope: log(flux) vs log(wavelength) at Specific Times

At a chosen time `t0` (e.g. r-band peak), interpolate flux and error in each band.

Then fit a weighted line with:
- x = log(λ_rest)
- y = log(flux)
- weights = 1 / (σ_f / f)², the approximate inverse variance of log(flux); bands with f ≤ 0 are skipped

Outputs:
- `sed_logflux_loglambda_slope_*`: spectral slope proxy
- `sed_*_r2`: fit quality
- `sed_*_nbands`: how many bands contributed

This gives a compact color/SED summary without full SED modeling.
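
A quick synthetic check of the weighting, assuming a noise-free power law f ∝ λ^α with α = -2 and made-up per-band relative errors. `np.polyfit` expects `w = 1/σ`, so passing `w = f/σ_f` minimizes the same weighted sum of squares as the inverse-variance weights `1/(σ_f/f)²` used in `sed_logflux_vs_loglambda_at_time`; the fitted slope should recover α.

```python
import numpy as np

EFF_WL_AA = {"g": 4704.0, "r": 6155.0, "i": 7504.0, "z": 8695.0}
z = 0.1
alpha_true = -2.0

lam_rest = np.array([EFF_WL_AA[b] / (1.0 + z) for b in EFF_WL_AA])
flux = 1e9 * lam_rest ** alpha_true           # noise-free power law, arbitrary scale
rel_err = np.array([0.10, 0.05, 0.05, 0.08])  # hypothetical sigma_f / f per band
err = rel_err * flux

x = np.log(lam_rest)
y = np.log(flux)
# sigma(ln f) ~= err / flux, and np.polyfit wants w = 1 / sigma
slope, intercept = np.polyfit(x, y, 1, w=flux / err)
print(f"slope={slope:.4f}")  # recovers alpha_true for an exact power law
```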

In [None]:
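def sed_logflux_vs_loglambda_at_time(band_tb, band_fb, band_eb, t0, z=0.0):
    """Weighted fit of log(flux) vs log(rest-frame wavelength) at time t0.

    Returns (slope, intercept, r2, n_bands). Bands whose flux at t0 is
    non-positive or outside the observed time range are skipped.
    """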
    xs = []
    ys = []
    ws = []
    for b in FILTERS:
        tb = band_tb.get(b, None)
        fb = band_fb.get(b, None)
        eb = band_eb.get(b, None)
        if tb is None or fb is None or eb is None:
            continue
        f = interp_flux_at_time(tb, fb, t0)
        e = interp_err_at_time(tb, eb, t0)
        if not np.isfinite(f) or not np.isfinite(e):
            continue
        if f <= 0:
            continue
        lam_rest = float(EFF_WL_AA[b] / (1.0 + float(z)))
        xs.append(np.log(lam_rest + EPS))
        ys.append(np.log(f + EPS))
        ws.append(1.0 / ((e / (f + EPS)) ** 2 + EPS))  # weight by relative error
    if len(xs) < 2:
        return np.nan, np.nan, np.nan, float(len(xs))
    xs = np.asarray(xs, float)
    ys = np.asarray(ys, float)
    ws = np.asarray(ws, float)

    try:
        W = np.sum(ws)
        xbar = np.sum(ws * xs) / (W + EPS)
        ybar = np.sum(ws * ys) / (W + EPS)
        cov = np.sum(ws * (xs - xbar) * (ys - ybar))
        var = np.sum(ws * (xs - xbar) ** 2) + EPS
        slope = cov / var
        intercept = ybar - slope * xbar
        yhat = slope * xs + intercept
        ss_res = float(np.sum(ws * (ys - yhat) ** 2))
        ss_tot = float(np.sum(ws * (ys - ybar) ** 2)) + EPS
        r2 = 1.0 - ss_res / ss_tot
        return float(slope), float(intercept), float(r2), float(len(xs))
    except Exception:
        return np.nan, np.nan, np.nan, float(len(xs))

# Model 5: Raw-vs-deextinct deltas + seasonality/structure-function + Bazin fits + richer cross-band + upgraded SpecType teacher

Differences vs Model 4:
- Adds raw (uncorrected) flux stats alongside de-extincted stats, plus "delta" features (de-extincted minus raw).
- Adds observation seasonality features globally and per band (season counts, gap fractions, season spans).
- Adds structure function features per band at multiple time lags (captures variability vs timescale).
- Attempts a Bazin fit in every band (A, t0, trise/tfall, B, chi2red), gated by `should_fit_bazin` so sparse or flat bands get NaN, and adds rest-frame versions of rise/fall.
- Adds rise-time metrics (to 20%/50%) and asymmetry ratios (fall/rise) per band.
- Adds cross-band ratios (amp_pre, AUC, width50, asym50) and band-to-band correlations (g-r, r-i, i-z).
- Adds wavelength-trend features (peak time vs λ, peak flux vs λ) and SED slope fits at r-peak and +20 d.
- SpecType teacher: richer label mapping (splits out SLSN and SNII), uses missing-value flags, and exposes `spec_topprob`.
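
To illustrate the per-band Bazin fit and the rest-frame conversion, here is a self-contained sketch on a synthetic light curve with known parameters. It re-implements the Bazin profile used by `bazin_stable`; the bounds are similar to those in `fit_bazin`, but the baseline bound is simplified to a fixed ±50, and dividing timescales by (1 + z) is an assumption about how the `_rest` variants are derived (consistent with the time-dilation correction described for `total_time_rest`).

```python
import numpy as np
from scipy.optimize import curve_fit


def bazin(t, A, t0, trise, tfall, B):
    """A * exp(-(t - t0)/tfall) * sigmoid((t - t0)/trise) + B, with clipped exponents."""
    fall = np.exp(np.clip(-(t - t0) / tfall, -60.0, 60.0))
    rise = 1.0 / (1.0 + np.exp(np.clip(-(t - t0) / trise, -60.0, 60.0)))
    return A * fall * rise + B


# Synthetic single-band light curve with known (made-up) parameters
rng = np.random.default_rng(1)
true = {"A": 100.0, "t0": 50.0, "trise": 5.0, "tfall": 30.0, "B": 2.0}
t = np.sort(rng.uniform(0.0, 200.0, 60))
err = np.full_like(t, 2.0)
f = bazin(t, **true) + rng.normal(0.0, 2.0, t.size)

# Initial guess and bounds in the spirit of fit_bazin (baseline bound simplified)
p0 = [f.max(), t[np.argmax(f)], 20.0, 60.0, np.median(f)]
lo = [0.0, t.min() - 50.0, 0.5, 1.0, -50.0]
hi = [10.0 * np.ptp(f), t.max() + 50.0, 200.0, 600.0, 50.0]
popt, _ = curve_fit(bazin, t, f, p0=p0, sigma=err, absolute_sigma=True,
                    bounds=(lo, hi), maxfev=5000)
A, t0, trise, tfall, B = popt

# Reduced chi-square of the fit (5 free parameters)
chi2_red = np.sum(((f - bazin(t, *popt)) / err) ** 2) / (t.size - 5)

# Rest-frame timescales: divide observed-frame durations by (1 + z)
z = 0.3
trise_rest, tfall_rest = trise / (1.0 + z), tfall / (1.0 + z)
print(f"tfall_obs={tfall:.1f} d  tfall_rest={tfall_rest:.1f} d  chi2_red={chi2_red:.2f}")
```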

There are roughly 562 training features, so the tables below are kept compact.

## New Features

| Feature | Meaning | Why it helps |
|---------|----------|--------------|
| `flux_mean_raw` | Mean flux before dust de-extinction correction | Lets the model compare raw vs corrected brightness scale and learn dust-impact patterns |
| `flux_std_raw` | Standard deviation of raw flux | Captures variability before correction to detect dust-driven distortions |
| `snr_max_raw` | Maximum SNR using raw flux and raw error | Measures best raw detection strength as a correction sanity check |
| `fvar_raw` | Fractional variability using raw flux and error | Provides intrinsic-variability proxy before dust adjustment |
| `flux_mean_deext_minus_raw` | Difference between corrected and raw mean flux | Direct signal of how strongly extinction correction shifts brightness |
| `snrmax_deext_minus_raw` | Difference between corrected and raw max SNR | Measures how much detectability improves after correction |
| `n_seasons_global` | Number of observing seasons inferred from large time gaps | Separates single-season vs multi-season coverage patterns |
| `gap_frac_gt90` | Fraction of time gaps greater than 90 days | Flags strongly seasonal sampling |
| `gap_frac_gt30` | Fraction of time gaps greater than 30 days | Captures moderate sampling fragmentation |
| `n_seasons_{b}` | Number of observing seasons in band b | Band-specific sampling structure can differ by class |
| `season_maxspan_{b}` | Longest continuous season span in band b | Measures longest uninterrupted coverage window |
| `season_meanspan_{b}` | Mean season span in band b | Captures typical continuous coverage length |
| `sf_medabs_5_{b}` | Median absolute flux difference at ~5-day lag | Measures short-timescale variability strength |
| `sf_n_5_{b}` | Number of pairs used for 5-day lag SF | Reliability indicator for short-lag estimate |
| `sf_medabs_10_{b}` | Median absolute flux difference at ~10-day lag | Captures slightly longer-timescale changes |
| `sf_n_10_{b}` | Pair count for 10-day lag SF | Reliability indicator |
| `sf_medabs_20_{b}` | Median absolute flux difference at ~20-day lag | Mid-scale variability measure |
| `sf_n_20_{b}` | Pair count for 20-day lag SF | Reliability indicator |
| `sf_medabs_50_{b}` | Median absolute flux difference at ~50-day lag | Long-timescale variability proxy |
| `sf_n_50_{b}` | Pair count for 50-day lag SF | Reliability indicator |
| `sf_medabs_100_{b}` | Median absolute flux difference at ~100-day lag | Very long-timescale variability proxy |
| `sf_n_100_{b}` | Pair count for 100-day lag SF | Reliability indicator |
| `bazin_A_{b}` | Bazin model amplitude parameter | Smooth transient strength estimate |
| `bazin_t0_{b}_obs` | Bazin peak-time parameter (observed frame) | Parametric peak timing estimate |
| `bazin_trise_{b}_obs` | Bazin rise timescale (observed frame) | Encodes rise speed |
| `bazin_tfall_{b}_obs` | Bazin decay timescale (observed frame) | Encodes decay speed |
| `bazin_B_{b}` | Bazin baseline parameter | Estimates underlying baseline level |
| `bazin_chi2red_{b}_obs` | Reduced chi-square of Bazin fit | Fit quality indicator |
| `bazin_trise_{b}_rest` | Bazin rise timescale (rest frame) | Intrinsic rise speed |
| `bazin_tfall_{b}_rest` | Bazin fall timescale (rest frame) | Intrinsic decay speed |
| `t_rise50_{b}_obs` | Time from first reaching 50% of amplitude to peak (observed) | Measures rise speed |
| `t_rise20_{b}_obs` | Time from first reaching 20% of amplitude to peak (observed) | Early-rise behavior |
| `t_rise50_{b}_rest` | Rise50 (rest frame) | Intrinsic rise speed |
| `t_rise20_{b}_rest` | Rise20 (rest frame) | Intrinsic early-rise behavior |
| `asym50_{b}_obs` | Fall50 / Rise50 ratio (observed) | Captures peak asymmetry |
| `asym50_{b}_rest` | Fall50 / Rise50 ratio (rest frame) | Intrinsic asymmetry measure |
| `amppreratio_{a}{b}` | Ratio of `amp_pre` (peak minus pre-peak baseline) between bands a and b | Color-dependent peak strength comparison |
| `aucratio_{a}{b}_obs` | Ratio of positive AUC between bands | Relative emitted-energy proxy |
| `width50ratio_{a}{b}_obs` | Ratio of 50% widths between bands | Cross-band duration contrast |
| `asym50ratio_{a}{b}_obs` | Ratio of asymmetry metrics between bands | Cross-band shape contrast |
| `corr_gr_obs` | Correlation between g and r band lightcurves | Measures multi-band coherence |
| `corr_ri_obs` | Correlation between r and i bands | Same, redder wavelengths |
| `corr_iz_obs` | Correlation between i and z bands | Same, further red |
| `tpeak_vs_lambda_slope_obs` | Slope of peak-time vs wavelength fit | Detects chromatic timing trends |
| `tpeak_vs_lambda_intercept_obs` | Intercept of that regression | Baseline timing offset |
| `tpeak_vs_lambda_r2_obs` | R² of peak-time vs wavelength fit | Reliability of chromatic timing trend |
| `peakflux_vs_lambda_slope` | Slope of peak-flux vs wavelength fit | Spectral energy trend |
| `peakflux_vs_lambda_intercept` | Intercept of flux–wavelength fit | Baseline spectral level |
| `peakflux_vs_lambda_r2` | R² of flux–wavelength fit | Reliability of spectral slope |
| `sed_logflux_loglambda_slope_rpeak` | Slope of log(flux) vs log(wavelength) at r-peak | Spectral slope at peak |
| `sed_logflux_loglambda_r2_rpeak` | R² of SED fit at r-peak | Fit reliability |
| `sed_logflux_loglambda_nbands_rpeak` | Number of bands used in SED fit | Coverage reliability |
| `sed_slope_rpeak_p20` | SED slope at r-peak + 20 days | Compared with the at-peak slope, tracks spectral (color) evolution |
| `sed_r2_rpeak_p20` | R² of SED fit at +20 days | Reliability indicator |
| `sed_nbands_rpeak_p20` | Bands used at +20 days | Coverage indicator |
| `spec_topprob` | Maximum teacher-model class probability | Teacher confidence summary for meta-learning |
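
The code that builds the global sampling features is not part of this section. The sketch below is a hypothetical reconstruction from the table descriptions of `n_seasons_global`, `gap_frac_gt90` and `gap_frac_gt30`, assuming observation times from all bands are pooled and sorted, and that a new season starts at any gap longer than 90 days (as with `SEASON_GAP_DAYS`). The function name `global_gap_features` is illustrative.

```python
import numpy as np


def global_gap_features(t, season_gap=90.0):
    """Hypothetical global sampling features from pooled observation times."""
    t = np.sort(np.asarray(t, float))
    if len(t) < 2:
        return {"n_seasons_global": np.nan, "gap_frac_gt90": np.nan, "gap_frac_gt30": np.nan}
    dt = np.diff(t)
    return {
        # one season, plus one more for every gap longer than season_gap
        "n_seasons_global": float(1 + np.sum(dt > season_gap)),
        "gap_frac_gt90": float(np.mean(dt > 90.0)),
        "gap_frac_gt30": float(np.mean(dt > 30.0)),
    }


# Pooled times from all bands (unsorted input is fine)
times = [205.0, 0.0, 3.0, 500.0, 10.0, 45.0, 200.0, 230.0]
feats = global_gap_features(times)
# gaps: 3, 7, 35, 155, 5, 25, 270 -> 3 seasons, 2/7 gaps > 90 d, 3/7 gaps > 30 d
print(feats)
```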

## Features from other models

| Feature | Meaning | Why it helps |
|---------|----------|--------------|
| `n_obs` | Total number of observations across all filters | Coverage proxy; some classes are observed more densely |
| `total_time_obs` | Total observed duration (max time − min time) | Separates long-timescale variability from short transients |
| `total_time_rest` | Duration corrected by (1+z) time dilation | Makes durations comparable across redshift |
| `flux_mean` | Mean dust-corrected flux | Overall brightness level |
| `flux_median` | Median dust-corrected flux | Robust brightness estimate |
| `flux_std` | Standard deviation of corrected flux | Overall variability strength |
| `flux_min` | Minimum corrected flux | Captures deep dips / noise floor |
| `flux_max` | Maximum corrected flux | Captures peak brightness |
| `flux_mad` | Median absolute deviation | Robust variability measure |
| `flux_iqr` | Interquartile range | Robust spread measure |
| `flux_skew` | Skewness of flux distribution | Detects asymmetric burst-like shapes |
| `flux_kurt_excess` | Excess kurtosis | Detects heavy-tailed spike behavior |
| `flux_p5` | 5th percentile flux | Robust low level |
| `flux_p25` | 25th percentile flux | Lower quartile |
| `flux_p75` | 75th percentile flux | Upper quartile |
| `flux_p95` | 95th percentile flux | Robust high level |
| `robust_amp_global` | p95 − p5 | Stable global amplitude proxy |
| `neg_flux_frac` | Fraction of flux values below zero | Noise-dominated vs real detection signal |
| `snr_median` | Median signal-to-noise ratio | Typical detection quality |
| `snr_max` | Maximum signal-to-noise ratio | Strongest detection strength |
| `median_dt` | Median time gap between observations | Sampling cadence proxy |
| `max_gap` | Largest time gap | Detects large seasonal breaks |
| `eta_von_neumann` | Von Neumann eta statistic | Smoothness vs randomness indicator |
| `chi2_const_global` | Chi-square vs constant model | Detects variability vs flat signal |
| `stetsonJ_global_obs` | Stetson J index (observed frame) | Robust correlated variability measure |
| `stetsonJ_global_rest` | Stetson J index (rest frame) | Intrinsic variability measure |
| `max_slope_global_obs` | Maximum absolute slope (observed) | Fastest brightness change |
| `max_slope_global_rest` | Maximum slope (rest frame) | Intrinsic fastest change |
| `med_abs_slope_global_obs` | Median absolute slope (observed) | Typical change rate |
| `med_abs_slope_global_rest` | Median absolute slope (rest) | Intrinsic change rate |
| `slope_global_obs` | Linear trend slope (observed) | Long-term drift indicator |
| `slope_global_rest` | Linear trend slope (rest) | Intrinsic drift |
| `fvar_global` | Fractional variability | Noise-corrected variability strength |
| `Z` | Redshift | Distance and time-dilation proxy |
| `log1pZ` | log(1+Z) | Stabilized redshift scale |
| `Z_err` | Redshift uncertainty | Reliability of distance estimate |
| `log1pZerr` | log(1+Z_err) | Stabilized uncertainty scale |
| `EBV` | Dust extinction value | Measures dust impact |
| `n_filters_present` | Number of filters with data | Multi-band coverage indicator |
| `total_obs` | Total observations across bands | Coverage strength |
| `n_{b}` | Number of observations in band b | Band completeness differs by class |
| `p5_{b}` | 5th percentile flux in band b | Robust low level |
| `p25_{b}` | 25th percentile | Lower quartile |
| `p75_{b}` | 75th percentile | Upper quartile |
| `p95_{b}` | 95th percentile | Robust high level |
| `robust_amp_{b}` | p95 − p5 in band b | Stable band amplitude |
| `mad_{b}` | Median absolute deviation | Robust variability |
| `iqr_{b}` | Interquartile range | Robust spread |
| `mad_over_std_{b}` | MAD / std ratio | Outlier sensitivity indicator |
| `eta_{b}` | Von Neumann eta | Smoothness vs noise |
| `chi2_const_{b}` | Chi-square vs constant | Variability detector |
| `stetsonJ_{b}_obs` | Stetson J (observed) | Correlated variability |
| `stetsonJ_{b}_rest` | Stetson J (rest) | Intrinsic correlated variability |
| `fvar_{b}` | Fractional variability | Normalized variability strength |
| `snrmax_{b}` | Maximum SNR | Best detection strength |
| `baseline_pre_{b}` | Estimated pre-peak baseline | Reference level for amplitude |
| `amp_{b}` | Peak − median flux | Simple amplitude |
| `amp_pre_{b}` | Peak − pre-peak baseline | Cleaner transient amplitude |
| `tpeak_{b}_obs` | Peak time (observed frame) | Band timing behavior |
| `tpeak_{b}_rest` | Peak time (rest frame) | Intrinsic timing |
| `peak_dominance_{b}` | `amp_pre` / pre-peak MAD | Peak significance above baseline noise |
| `std_ratio_prepost_{b}` | Pre/post peak std ratio | Stability vs post-peak chaos |
| `width50_{b}_obs` | Width above 50% amplitude (obs) | Duration at mid level |
| `width80_{b}_obs` | Width above 80% amplitude (obs) | Peak sharpness |
| `width50_{b}_rest` | Width50 (rest) | Intrinsic duration |
| `width80_{b}_rest` | Width80 (rest) | Intrinsic peak shape |
| `t_fall50_{b}_obs` | Fall time to 50% (obs) | Decay speed |
| `t_fall20_{b}_obs` | Fall time to 20% (obs) | Late decay |
| `t_fall50_{b}_rest` | Fall50 (rest) | Intrinsic decay |
| `t_fall20_{b}_rest` | Fall20 (rest) | Intrinsic late decay |
| `sharp50_{b}_obs` | Amplitude / width50 (obs) | Spike sharpness |
| `sharp50_{b}_rest` | Amplitude / width50 (rest) | Intrinsic sharpness |
| `postpeak_monotone_frac_{b}` | Fraction monotonic after peak | Smooth decay vs noisy |
| `n_peaks_{b}` | Significant peak count | Multi-peak vs single transient |
| `n_rebrighten_{b}` | Rebrightening count | Secondary bump behavior |
| `decay_pl_slope_{b}_obs` | Power-law decay slope (obs) | Decay steepness |
| `decay_pl_r2_{b}_obs` | Fit R² (obs) | Fit reliability |
| `decay_pl_npts_{b}_obs` | Points used (obs) | Support size |
| `decay_pl_slope_{b}_rest` | Power-law slope (rest) | Intrinsic decay |
| `decay_pl_r2_{b}_rest` | Fit R² (rest) | Reliability |
| `decay_pl_npts_{b}_rest` | Points used (rest) | Support size |
| `tpeak_std_obs` | Std of peak times across bands (obs) | Peak alignment indicator |
| `tpeak_std_rest` | Std of peak times (rest) | Intrinsic alignment |
| `tpeakdiff_{a}{b}_obs` | Peak time of band a minus band b, adjacent filters (obs) | Chromatic lag signal |
| `tpeakdiff_{a}{b}_rest` | Peak time difference (rest) | Intrinsic lag |
| `peakratio_{a}{b}` | Peak flux ratio | Peak color proxy |
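| `color_gr_at_rpeak_obs` | signed_log1p(g) − signed_log1p(r) flux at r-band peak (log-flux difference, opposite sign to a magnitude color) | Color at peak (higher = bluer) |
| `color_ri_at_rpeak_obs` | signed_log1p(r) − signed_log1p(i) flux at r-band peak | Redder-band color (higher = bluer) |
| `color_gr_rpeak_p20_obs` | Same g−r proxy at r-peak + 20 d | Color evolution |
| `color_ri_rpeak_p20_obs` | Same r−i proxy at r-peak + 20 d | Color evolution |
| `color_gr_rpeak_p40_obs` | Same g−r proxy at r-peak + 40 d | Slower evolution |
| `color_ri_rpeak_p40_obs` | Same r−i proxy at r-peak + 40 d | Slower evolution |
| `color_gr_slope20_obs` | (g−r at +20 d − g−r at peak) / 20 | Early color change rate (per day) |
| `color_ri_slope20_obs` | (r−i at +20 d − r−i at peak) / 20 | Early color change rate (per day) |
| `color_gr_slope40_obs` | (g−r at +40 d − g−r at peak) / 40 | Longer color trend (per day) |
| `color_ri_slope40_obs` | (r−i at +40 d − r−i at peak) / 40 | Longer color trend (per day) |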
| `p_spec_{c}` | Teacher probability for class c | Soft-label prior signal |
| `spec_entropy` | Entropy of teacher probs | Teacher uncertainty |
| `spec_topprob` | Max teacher probability | Teacher confidence summary |
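
The variability rows above (`eta_*`, `chi2_const_*`, `fvar_*`) rely on helpers defined earlier in the notebook. As a rough guide to what they measure, here is a sketch of the textbook definitions. The `sketch_*` names are hypothetical, and the notebook's own helpers may differ in normalization, weighting, or edge-case handling.

```python
import numpy as np


def sketch_von_neumann_eta(x):
    """Von Neumann ratio: sum of squared successive differences / ((N - 1) * sample variance)."""
    x = np.asarray(x, dtype=float)
    if x.size < 3:
        return np.nan
    var = np.var(x, ddof=1)
    if var <= 0:  # constant series: ratio undefined
        return np.nan
    return float(np.sum(np.diff(x) ** 2) / ((x.size - 1) * var))


def sketch_chi2_to_constant(x, err):
    """Reduced chi-square of x against its inverse-variance weighted mean (N - 1 dof)."""
    x = np.asarray(x, dtype=float)
    err = np.asarray(err, dtype=float)
    if x.size < 2:
        return np.nan
    w = 1.0 / np.maximum(err, 1e-12) ** 2
    mu = np.sum(w * x) / np.sum(w)
    return float(np.sum(w * (x - mu) ** 2) / (x.size - 1))


def sketch_fractional_variability(x, err):
    """F_var = sqrt(S^2 - mean(err^2)) / |mean(x)|; NaN when noise explains all the scatter."""
    x = np.asarray(x, dtype=float)
    err = np.asarray(err, dtype=float)
    if x.size < 2:
        return np.nan
    excess = np.var(x, ddof=1) - np.mean(err ** 2)
    mean = np.mean(x)
    if excess <= 0 or mean == 0:
        return np.nan
    return float(np.sqrt(excess) / abs(mean))


rng = np.random.default_rng(0)
eta_noise = sketch_von_neumann_eta(rng.normal(size=10_000))          # white noise
eta_smooth = sketch_von_neumann_eta(np.sin(np.linspace(0, 3, 200)))  # smooth curve
print(eta_noise, eta_smooth)
```

For white noise eta tends to 2, while smooth, slowly varying light curves (such as a transient's rise and decay) push it far below 2; that contrast is what makes it useful for separating coherent transients from noise-dominated objects.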

In [None]:
def extract_features_for_object(lc_raw, z, z_err, ebv):
    feats = {}

    # Sort observations by time so time-based calculations make sense
    lc = lc_raw.sort_values("Time (MJD)").reset_index(drop=True)

    # Extract time values and filter (band) labels
    t = lc["Time (MJD)"].to_numpy().astype(float)
    filt = lc["Filter"].to_numpy()

    # If there are no observations, return minimal info
    if len(t) == 0:
        feats["n_obs"] = 0
        return feats

    # Make sure metadata fields are valid numbers (avoid NaNs / strings / missing values)
    z = safe_float(z, default=0.0)                     # redshift (distance proxy)
    z_err = safe_float(z_err, default=0.0)             # redshift uncertainty
    ebv = safe_float(ebv, default=np.nan)              # dust amount (can be missing)

    # Convert time to start at 0 (relative time axis)
    t_rel = t - t.min()

    # Convert observed time to intrinsic time of the object
    # Distant objects appear to evolve slower, so divide by (1 + z)
    t_rest = t_rel / (1.0 + z)

    # Extract raw flux/error (before dust correction)
    flux_raw = lc["Flux"].to_numpy().astype(float)
    err_raw = lc["Flux_err"].to_numpy().astype(float)

    # Correct brightness values for dust in the Milky Way
    # (dust makes objects look dimmer than they really are)
    flux_corr, err_corr = deextinct_lightcurve(lc, ebv)

    # Basic observation statistics
    feats["n_obs"] = int(len(t))                                  # total number of measurements
    feats["total_time_obs"] = float(t_rel.max() - t_rel.min())    # total observed duration
    feats["total_time_rest"] = float(t_rest.max() - t_rest.min()) # duration corrected for distance effects

    # Global brightness statistics (after dust correction)
    feats["flux_mean"] = float(np.mean(flux_corr))                # average brightness level
    feats["flux_median"] = float(np.median(flux_corr))            # robust typical brightness
    feats["flux_std"] = float(np.std(flux_corr))                  # overall variability
    feats["flux_min"] = float(np.min(flux_corr))                  # dimmest point
    feats["flux_max"] = float(np.max(flux_corr))                  # brightest point

    # Robust statistics that are less sensitive to outliers
    feats["flux_mad"] = median_abs_dev(flux_corr)                 # median absolute deviation
    feats["flux_iqr"] = iqr(flux_corr)                            # interquartile range (Q3 - Q1)

    # Distribution shape features
    feats["flux_skew"] = skewness(flux_corr)                      # asymmetry of values
    feats["flux_kurt_excess"] = kurtosis_excess(flux_corr)        # tail heaviness / spikiness

    # Robust amplitude using percentiles (stable against a few extreme points)
    p5, p25, p75, p95 = np.percentile(flux_corr, [5, 25, 75, 95])
    feats["flux_p5"] = float(p5)
    feats["flux_p25"] = float(p25)
    feats["flux_p75"] = float(p75)
    feats["flux_p95"] = float(p95)
    feats["robust_amp_global"] = float(p95 - p5)                  # robust amplitude proxy

    # Fraction of measurements that are below zero
    # (often indicates noise-dominated detections)
    feats["neg_flux_frac"] = float(np.mean(flux_corr < 0))

    # Signal-to-noise ratio summaries (dust-corrected)
    snr = np.abs(flux_corr) / (err_corr + EPS)
    feats["snr_median"] = float(np.median(snr))                   # typical signal quality
    feats["snr_max"] = float(np.max(snr))                         # strongest detection

    # Raw-flux sanity features (captures how dust correction shifts statistics)
    feats["flux_mean_raw"] = float(np.mean(flux_raw))             # mean before de-extinction
    feats["flux_std_raw"] = float(np.std(flux_raw))               # variability before de-extinction
    feats["snr_max_raw"] = float(np.max(np.abs(flux_raw) / (err_raw + EPS)))  # raw best SNR
    feats["fvar_raw"] = fractional_variability(flux_raw, err_raw) # raw fractional variability

    # Dust-correction “delta” features (how much correction changes the signal)
    feats["flux_mean_deext_minus_raw"] = float(feats["flux_mean"] - feats["flux_mean_raw"])
    feats["snrmax_deext_minus_raw"] = float(feats["snr_max"] - feats["snr_max_raw"])

    # Observation timing + seasonality (global)
    if len(t_rel) >= 2:
        dt = np.diff(t_rel)
        feats["median_dt"] = float(np.median(dt))                 # typical time between observations
        feats["max_gap"] = float(np.max(dt))                      # largest observation gap

        # Seasonality proxies: large gaps often indicate separate observing seasons
        feats["n_seasons_global"] = float(np.sum(dt > SEASON_GAP_DAYS) + 1)  # number of seasons
        feats["gap_frac_gt90"] = float(np.mean(dt > SEASON_GAP_DAYS))        # fraction of gaps > SEASON_GAP_DAYS (name assumes 90 d)
        feats["gap_frac_gt30"] = float(np.mean(dt > 30.0))                  # fraction of gaps > 30d
    else:
        feats["median_dt"] = np.nan
        feats["max_gap"] = np.nan
        feats["n_seasons_global"] = np.nan
        feats["gap_frac_gt90"] = np.nan
        feats["gap_frac_gt30"] = np.nan

    # Global time-series variability diagnostics
    feats["eta_von_neumann"] = von_neumann_eta(flux_corr)              # smoothness vs noise proxy
    feats["chi2_const_global"] = chi2_to_constant(flux_corr, err_corr) # variability vs constant model

    feats["stetsonJ_global_obs"] = stetson_J_consecutive(t_rel, flux_corr, err_corr)
    feats["stetsonJ_global_rest"] = stetson_J_consecutive(t_rest, flux_corr, err_corr)

    # Global slope features (obs + rest frame)
    feats["max_slope_global_obs"] = max_slope(t_rel, flux_corr)        # fastest brightness change (obs)
    feats["max_slope_global_rest"] = max_slope(t_rest, flux_corr)      # fastest brightness change (rest)

    feats["med_abs_slope_global_obs"] = median_abs_slope(t_rel, flux_corr)   # typical change rate (obs)
    feats["med_abs_slope_global_rest"] = median_abs_slope(t_rest, flux_corr) # typical change rate (rest)

    feats["slope_global_obs"] = linear_slope(t_rel, flux_corr)         # best-fit linear trend (obs)
    feats["slope_global_rest"] = linear_slope(t_rest, flux_corr)       # best-fit linear trend (rest)

    feats["fvar_global"] = fractional_variability(flux_corr, err_corr) # noise-corrected variability

    # Metadata features
    feats["Z"] = float(z)                           # distance proxy (redshift)
    feats["log1pZ"] = float(np.log1p(max(0.0, z)))  # compressed redshift scale
    feats["Z_err"] = float(max(0.0, z_err))         # clamp negative uncertainty to 0
    feats["log1pZerr"] = float(np.log1p(max(0.0, feats["Z_err"])))     # compressed uncertainty scale
    feats["EBV"] = ebv                              # dust amount

    # Counters for band coverage
    feats["n_filters_present"] = 0                  # how many bands have >= 1 observation
    feats["total_obs"] = 0                          # total observations across all bands

    # Storage for cross-band timing/color/SED features later
    band_tpeak_obs = {}                             # per-band peak time (obs frame)
    band_tpeak_rest = {}                            # per-band peak time (rest frame)
    band_peak_flux = {}                             # per-band peak flux

    band_tb_obs = {}                                # per-band time arrays (obs frame)
    band_tb_rest = {}                               # per-band time arrays (rest frame)
    band_fb = {}                                    # per-band flux arrays (dust-corrected)
    band_eb = {}                                    # per-band error arrays (dust-corrected)

    # Loop over each wavelength band (u, g, r, i, z, y)
    for b in FILTERS:
        m = (filt == b)
        nb = int(np.sum(m))

        # Number of observations in this band
        feats[f"n_{b}"] = nb
        feats["total_obs"] += nb

        # Initialize all band features as missing by default
        keys = [
            f"amp_{b}", f"amp_pre_{b}", f"baseline_pre_{b}", f"robust_amp_{b}",
            f"tpeak_{b}_obs", f"tpeak_{b}_rest",
            f"width50_{b}_obs", f"width50_{b}_rest",
            f"width80_{b}_obs", f"width80_{b}_rest",
            f"auc_pos_{b}_obs", f"auc_pos_{b}_rest",
            f"snrmax_{b}", f"eta_{b}", f"chi2_const_{b}",
            f"slope_{b}_obs", f"slope_{b}_rest",
            f"maxslope_{b}_obs", f"maxslope_{b}_rest",
            f"stetsonJ_{b}_obs", f"stetsonJ_{b}_rest",
            f"p5_{b}", f"p25_{b}", f"p75_{b}", f"p95_{b}",
            f"mad_{b}", f"iqr_{b}", f"mad_over_std_{b}", f"fvar_{b}",
            f"t_fall50_{b}_obs", f"t_fall20_{b}_obs", f"t_fall50_{b}_rest", f"t_fall20_{b}_rest",
            f"t_rise50_{b}_obs", f"t_rise20_{b}_obs", f"t_rise50_{b}_rest", f"t_rise20_{b}_rest",
            f"asym50_{b}_obs", f"asym50_{b}_rest",
            f"sharp50_{b}_obs", f"sharp50_{b}_rest",
            f"peak_dominance_{b}", f"std_ratio_prepost_{b}",
            f"n_peaks_{b}", f"postpeak_monotone_frac_{b}", f"n_rebrighten_{b}",
            f"decay_pl_slope_{b}_obs", f"decay_pl_r2_{b}_obs", f"decay_pl_npts_{b}_obs",
            f"decay_pl_slope_{b}_rest", f"decay_pl_r2_{b}_rest", f"decay_pl_npts_{b}_rest",

            # Seasonality and structure function per band
            f"n_seasons_{b}", f"season_maxspan_{b}", f"season_meanspan_{b}",
            f"sf_medabs_5_{b}", f"sf_n_5_{b}",
            f"sf_medabs_10_{b}", f"sf_n_10_{b}",
            f"sf_medabs_20_{b}", f"sf_n_20_{b}",
            f"sf_medabs_50_{b}", f"sf_n_50_{b}",
            f"sf_medabs_100_{b}", f"sf_n_100_{b}",

            # Bazin shape fit parameters (obs + rest)
            f"bazin_A_{b}", f"bazin_t0_{b}_obs",
            f"bazin_trise_{b}_obs", f"bazin_tfall_{b}_obs",
            f"bazin_B_{b}", f"bazin_chi2red_{b}_obs",
            f"bazin_trise_{b}_rest", f"bazin_tfall_{b}_rest",
        ]
        for k in keys:
            feats[k] = np.nan

        # Skip bands with no data
        if nb == 0:
            continue

        feats["n_filters_present"] += 1

        # Extract time, brightness, and error for this band
        tb_obs = t_rel[m]
        fb = flux_corr[m]
        eb = err_corr[m]

        # Sort observations within the band by time
        order = np.argsort(tb_obs)
        tb_obs = tb_obs[order]
        fb = fb[order]
        eb = eb[order]

        # Convert to intrinsic time scale
        tb_rest = tb_obs / (1.0 + z)

        # Cache arrays for later cross-band operations (correlations, SED fits, colors)
        band_tb_obs[b] = tb_obs
        band_tb_rest[b] = tb_rest
        band_fb[b] = fb
        band_eb[b] = eb

        # Seasonality features for this band
        ns, maxsp, meansp = seasonality_features(tb_obs)
        feats[f"n_seasons_{b}"] = ns
        feats[f"season_maxspan_{b}"] = maxsp
        feats[f"season_meanspan_{b}"] = meansp

        # Structure function features
        sf = structure_function_lags(tb_obs, fb, lags=SF_LAGS)
        for lag in SF_LAGS:
            feats[f"sf_medabs_{int(lag)}_{b}"] = sf.get(f"sf_medabs_{int(lag)}", np.nan)
            feats[f"sf_n_{int(lag)}_{b}"] = sf.get(f"sf_n_{int(lag)}", 0.0)

        # Robust per-band amplitude using percentiles (stable against outliers)
        p5b, p25b, p75b, p95b = np.percentile(fb, [5, 25, 75, 95])
        feats[f"p5_{b}"] = float(p5b)
        feats[f"p25_{b}"] = float(p25b)
        feats[f"p75_{b}"] = float(p75b)
        feats[f"p95_{b}"] = float(p95b)
        feats[f"robust_amp_{b}"] = float(p95b - p5b)

        # Robust variability summaries
        feats[f"mad_{b}"] = median_abs_dev(fb)
        feats[f"iqr_{b}"] = iqr(fb)
        stdb = float(np.std(fb))
        feats[f"mad_over_std_{b}"] = float(feats[f"mad_{b}"] / (stdb + EPS))

        # Estimate a pre-peak baseline using early observations
        base_pre, mad_pre, mederr_pre = pre_peak_baseline(tb_obs, fb, eb, frac=PRE_BASE_FRAC)
        feats[f"baseline_pre_{b}"] = float(base_pre) if np.isfinite(base_pre) else np.nan

        # Identify peak flux and peak time
        pidx = int(np.argmax(fb))
        peak_flux = float(fb[pidx])
        tpeak_obs = float(tb_obs[pidx])
        tpeak_rest = float(tb_rest[pidx])

        # Amplitude relative to two different baselines
        amp_median = peak_flux - float(np.median(fb))                             # peak relative to median
        amp_pre = peak_flux - base_pre if np.isfinite(base_pre) else np.nan       # peak relative to pre-peak baseline

        feats[f"amp_{b}"] = float(amp_median)
        feats[f"amp_pre_{b}"] = float(amp_pre) if np.isfinite(amp_pre) else np.nan

        feats[f"tpeak_{b}_obs"] = tpeak_obs
        feats[f"tpeak_{b}_rest"] = tpeak_rest
        feats[f"snrmax_{b}"] = float(np.max(np.abs(fb) / (eb + EPS)))              # best detection quality

        # Band-level variability diagnostics
        feats[f"eta_{b}"] = von_neumann_eta(fb)
        feats[f"chi2_const_{b}"] = chi2_to_constant(fb, eb)

        feats[f"slope_{b}_obs"] = linear_slope(tb_obs, fb)
        feats[f"slope_{b}_rest"] = linear_slope(tb_rest, fb)

        feats[f"maxslope_{b}_obs"] = max_slope(tb_obs, fb)
        feats[f"maxslope_{b}_rest"] = max_slope(tb_rest, fb)

        feats[f"stetsonJ_{b}_obs"] = stetson_J_consecutive(tb_obs, fb, eb)
        feats[f"stetsonJ_{b}_rest"] = stetson_J_consecutive(tb_rest, fb, eb)

        feats[f"fvar_{b}"] = fractional_variability(fb, eb)

        # Bazin fit parameters (smooth transient model)
        A, t0, trise, tfall, B, chi2 = fit_bazin(tb_obs, fb, eb)
        feats[f"bazin_A_{b}"] = A
        feats[f"bazin_t0_{b}_obs"] = t0
        feats[f"bazin_trise_{b}_obs"] = trise
        feats[f"bazin_tfall_{b}_obs"] = tfall
        feats[f"bazin_B_{b}"] = B
        feats[f"bazin_chi2red_{b}_obs"] = chi2

        # Convert Bazin times to rest frame (distance/time-dilation corrected)
        feats[f"bazin_trise_{b}_rest"] = trise / (1.0 + z) if np.isfinite(trise) else np.nan
        feats[f"bazin_tfall_{b}_rest"] = tfall / (1.0 + z) if np.isfinite(tfall) else np.nan

        # Peak morphology features only make sense if we have a positive transient above baseline
        if np.isfinite(amp_pre) and amp_pre > 0:
            feats[f"peak_dominance_{b}"] = float(amp_pre / (mad_pre + EPS))        # peak relative to baseline noise

            # Pre vs post variability ratio
            # (if the peak is one of the first two points, max(2, pidx) gives fb[:2], which includes the peak)
            pre_seg = fb[:max(2, pidx)]
            post_seg = fb[pidx:]
            std_pre = float(np.std(pre_seg)) if len(pre_seg) >= 2 else np.nan
            std_post = float(np.std(post_seg)) if len(post_seg) >= 2 else np.nan
            if np.isfinite(std_pre) and np.isfinite(std_post):
                feats[f"std_ratio_prepost_{b}"] = float(std_pre / (std_post + EPS))

            feats[f"postpeak_monotone_frac_{b}"] = float(postpeak_monotonicity(tb_obs, fb, pidx))
            feats[f"n_peaks_{b}"] = float(count_significant_peaks(tb_obs, fb, eb, base_pre, k_sigma=PEAK_SIGMA_K))
            feats[f"n_rebrighten_{b}"] = float(count_rebrighten(tb_obs, fb, base_pre, amp_pre, pidx, frac=REBRIGHT_FRAC))

            # Fall times from peak to given fractional levels (observed + rest frame)
            feats[f"t_fall50_{b}_obs"] = float(fall_time_to_level(tb_obs, fb, base_pre, amp_pre, pidx, frac=0.50))
            feats[f"t_fall20_{b}_obs"] = float(fall_time_to_level(tb_obs, fb, base_pre, amp_pre, pidx, frac=0.20))
            feats[f"t_fall50_{b}_rest"] = float(fall_time_to_level(tb_rest, fb, base_pre, amp_pre, pidx, frac=0.50))
            feats[f"t_fall20_{b}_rest"] = float(fall_time_to_level(tb_rest, fb, base_pre, amp_pre, pidx, frac=0.20))

            # Rise times from baseline to given fractional levels
            feats[f"t_rise50_{b}_obs"] = float(rise_time_to_level(tb_obs, fb, base_pre, amp_pre, pidx, frac=0.50))
            feats[f"t_rise20_{b}_obs"] = float(rise_time_to_level(tb_obs, fb, base_pre, amp_pre, pidx, frac=0.20))
            feats[f"t_rise50_{b}_rest"] = float(rise_time_to_level(tb_rest, fb, base_pre, amp_pre, pidx, frac=0.50))
            feats[f"t_rise20_{b}_rest"] = float(rise_time_to_level(tb_rest, fb, base_pre, amp_pre, pidx, frac=0.20))

            # Asymmetry: fall / rise (large means slow decay compared to rise)
            tr50o = feats[f"t_rise50_{b}_obs"]
            tf50o = feats[f"t_fall50_{b}_obs"]
            tr50r = feats[f"t_rise50_{b}_rest"]
            tf50r = feats[f"t_fall50_{b}_rest"]
            feats[f"asym50_{b}_obs"] = float(tf50o / (tr50o + EPS)) if np.isfinite(tf50o) and np.isfinite(tr50o) else np.nan
            feats[f"asym50_{b}_rest"] = float(tf50r / (tr50r + EPS)) if np.isfinite(tf50r) and np.isfinite(tr50r) else np.nan

            # Area under the curve above the pre-peak baseline (positive only)
            feats[f"auc_pos_{b}_obs"] = float(trapz_safe(np.maximum(fb - base_pre, 0.0), tb_obs))
            feats[f"auc_pos_{b}_rest"] = float(trapz_safe(np.maximum(fb - base_pre, 0.0), tb_rest))

            # Width at a fractional level: time between the first and last observation at or above
            # base + frac * amp (interior dips below the level are not subtracted)
            def width_at_level(tt, ff, base, amp, frac):
                if amp <= 0 or len(ff) < 3:
                    return np.nan
                level = base + frac * amp
                above = ff >= level
                if not np.any(above):
                    return np.nan
                idx = np.where(above)[0]
                return float(tt[idx[-1]] - tt[idx[0]])

            w50_obs = width_at_level(tb_obs, fb, base_pre, amp_pre, 0.50)
            w80_obs = width_at_level(tb_obs, fb, base_pre, amp_pre, 0.80)
            w50_rest = width_at_level(tb_rest, fb, base_pre, amp_pre, 0.50)
            w80_rest = width_at_level(tb_rest, fb, base_pre, amp_pre, 0.80)

            feats[f"width50_{b}_obs"] = w50_obs
            feats[f"width80_{b}_obs"] = w80_obs
            feats[f"width50_{b}_rest"] = w50_rest
            feats[f"width80_{b}_rest"] = w80_rest

            # Sharpness: high amplitude + short width means a “spiky” transient
            feats[f"sharp50_{b}_obs"] = float(amp_pre / (w50_obs + EPS)) if np.isfinite(w50_obs) else np.nan
            feats[f"sharp50_{b}_rest"] = float(amp_pre / (w50_rest + EPS)) if np.isfinite(w50_rest) else np.nan

            # Fit a simple power-law decay model post-peak (captures decay steepness)
            b_obs, r2_obs, npts_obs = decay_powerlaw_fit(tb_obs, fb, base_pre, pidx, tmax=300.0)
            b_rest, r2_rest, npts_rest = decay_powerlaw_fit(tb_rest, fb, base_pre, pidx, tmax=300.0)

            feats[f"decay_pl_slope_{b}_obs"] = b_obs
            feats[f"decay_pl_r2_{b}_obs"] = r2_obs
            feats[f"decay_pl_npts_{b}_obs"] = float(npts_obs)

            feats[f"decay_pl_slope_{b}_rest"] = b_rest
            feats[f"decay_pl_r2_{b}_rest"] = r2_rest
            feats[f"decay_pl_npts_{b}_rest"] = float(npts_rest)

        # Store values for cross-band comparisons and wavelength-trend features
        band_tpeak_obs[b] = tpeak_obs
        band_tpeak_rest[b] = tpeak_rest
        band_peak_flux[b] = peak_flux

    # Peak-time dispersion across filters (how synchronized the bands are)
    tpeaks_obs = np.array([band_tpeak_obs.get(b, np.nan) for b in FILTERS], float)
    tpeaks_rest = np.array([band_tpeak_rest.get(b, np.nan) for b in FILTERS], float)
    tpeaks_obs = np.array([x for x in tpeaks_obs if np.isfinite(x)], float)
    tpeaks_rest = np.array([x for x in tpeaks_rest if np.isfinite(x)], float)
    feats["tpeak_std_obs"] = float(np.std(tpeaks_obs)) if len(tpeaks_obs) >= 2 else np.nan
    feats["tpeak_std_rest"] = float(np.std(tpeaks_rest)) if len(tpeaks_rest) >= 2 else np.nan

    # Cross-band peak-time lags and peak ratios for adjacent filters
    pairs = [("u", "g"), ("g", "r"), ("r", "i"), ("i", "z"), ("z", "y")]
    for a, b in pairs:
        ta_obs = band_tpeak_obs.get(a, np.nan)
        tb_obs2 = band_tpeak_obs.get(b, np.nan)
        ta_rest = band_tpeak_rest.get(a, np.nan)
        tb_rest2 = band_tpeak_rest.get(b, np.nan)
        pa = band_peak_flux.get(a, np.nan)
        pb = band_peak_flux.get(b, np.nan)

        feats[f"tpeakdiff_{a}{b}_obs"] = (ta_obs - tb_obs2) if (np.isfinite(ta_obs) and np.isfinite(tb_obs2)) else np.nan
        feats[f"tpeakdiff_{a}{b}_rest"] = (ta_rest - tb_rest2) if (np.isfinite(ta_rest) and np.isfinite(tb_rest2)) else np.nan
        feats[f"peakratio_{a}{b}"] = (pa / (pb + EPS)) if (np.isfinite(pa) and np.isfinite(pb)) else np.nan

    # Helper for safe ratio features
    def ratio_feature(name, num, den):
        if np.isfinite(num) and np.isfinite(den):
            feats[name] = float(num / (den + EPS))
        else:
            feats[name] = np.nan

    # Cross-band ratios (adjacent filters)
    # (helps capture “relative shape” instead of absolute scale)
    for a, b in pairs:
        ratio_feature(f"amppreratio_{a}{b}", feats.get(f"amp_pre_{a}", np.nan), feats.get(f"amp_pre_{b}", np.nan))
        ratio_feature(f"aucratio_{a}{b}_obs", feats.get(f"auc_pos_{a}_obs", np.nan), feats.get(f"auc_pos_{b}_obs", np.nan))
        ratio_feature(f"width50ratio_{a}{b}_obs", feats.get(f"width50_{a}_obs", np.nan), feats.get(f"width50_{b}_obs", np.nan))
        ratio_feature(f"asym50ratio_{a}{b}_obs", feats.get(f"asym50_{a}_obs", np.nan), feats.get(f"asym50_{b}_obs", np.nan))

    # Band-to-band correlations (captures whether bands rise/fall together)
    for a, b in [("g", "r"), ("r", "i"), ("i", "z")]:
        if (a in band_tb_obs) and (b in band_tb_obs):
            feats[f"corr_{a}{b}_obs"] = band_corr(
                band_tb_obs[a], band_fb[a],
                band_tb_obs[b], band_fb[b]
            )
        else:
            feats[f"corr_{a}{b}_obs"] = np.nan

    # Wavelength-trend features (peak time vs wavelength, peak flux vs wavelength)
    # NOTE: peak_vs_wavelength_slope is assumed to map filters to effective wavelengths internally.
    # Passing band_tpeak_obs as both arguments regresses peak time itself on wavelength.
    slope_t, intercept_t, r2_t = peak_vs_wavelength_slope(band_tpeak_obs, band_tpeak_obs, z=z)
    feats["tpeak_vs_lambda_slope_obs"] = slope_t
    feats["tpeak_vs_lambda_intercept_obs"] = intercept_t
    feats["tpeak_vs_lambda_r2_obs"] = r2_t

    slope_pf, intercept_pf, r2_pf = peak_vs_wavelength_slope(band_tpeak_obs, band_peak_flux, z=z)
    feats["peakflux_vs_lambda_slope"] = slope_pf
    feats["peakflux_vs_lambda_intercept"] = intercept_pf
    feats["peakflux_vs_lambda_r2"] = r2_pf

    # Color features anchored at r-band peak time
    tpr_obs = feats.get("tpeak_r_obs", np.nan)
    if np.isfinite(tpr_obs):
        def colors_at_time(t0):
            fr = interp_flux_at_time(band_tb_obs.get("r", np.array([])), band_fb.get("r", np.array([])), t0)
            fg = interp_flux_at_time(band_tb_obs.get("g", np.array([])), band_fb.get("g", np.array([])), t0)
            fi = interp_flux_at_time(band_tb_obs.get("i", np.array([])), band_fb.get("i", np.array([])), t0)

            # signed_log1p allows negative flux without crashing log transforms
            cgr = (signed_log1p(fg) - signed_log1p(fr)) if (np.isfinite(fg) and np.isfinite(fr)) else np.nan
            cri = (signed_log1p(fr) - signed_log1p(fi)) if (np.isfinite(fr) and np.isfinite(fi)) else np.nan
            return cgr, cri

        # Colors at peak
        cgr0, cri0 = colors_at_time(tpr_obs)
        feats["color_gr_at_rpeak_obs"] = cgr0
        feats["color_ri_at_rpeak_obs"] = cri0

        # Colors at +20d and +40d to capture slower spectral evolution
        cgr20, cri20 = colors_at_time(tpr_obs + 20.0)
        cgr40, cri40 = colors_at_time(tpr_obs + 40.0)

        feats["color_gr_rpeak_p20_obs"] = cgr20
        feats["color_ri_rpeak_p20_obs"] = cri20
        feats["color_gr_rpeak_p40_obs"] = cgr40
        feats["color_ri_rpeak_p40_obs"] = cri40

        # Simple finite-difference slopes (color change per day)
        def slope(c1, c2, dt):
            if np.isfinite(c1) and np.isfinite(c2):
                return float((c2 - c1) / dt)
            return np.nan

        feats["color_gr_slope20_obs"] = slope(cgr0, cgr20, 20.0)
        feats["color_ri_slope20_obs"] = slope(cri0, cri20, 20.0)
        feats["color_gr_slope40_obs"] = slope(cgr0, cgr40, 40.0)
        feats["color_ri_slope40_obs"] = slope(cri0, cri40, 40.0)
    else:
        feats["color_gr_at_rpeak_obs"] = np.nan
        feats["color_ri_at_rpeak_obs"] = np.nan
        feats["color_gr_rpeak_p20_obs"] = np.nan
        feats["color_ri_rpeak_p20_obs"] = np.nan
        feats["color_gr_rpeak_p40_obs"] = np.nan
        feats["color_ri_rpeak_p40_obs"] = np.nan
        feats["color_gr_slope20_obs"] = np.nan
        feats["color_ri_slope20_obs"] = np.nan
        feats["color_gr_slope40_obs"] = np.nan
        feats["color_ri_slope40_obs"] = np.nan

    # SED slope features:
    if np.isfinite(tpr_obs):
        sed_slope, sed_int, sed_r2, sed_n = sed_logflux_vs_loglambda_at_time(
            band_tb_obs, band_fb, band_eb, tpr_obs, z=z
        )
        feats["sed_logflux_loglambda_slope_rpeak"] = sed_slope
        feats["sed_logflux_loglambda_r2_rpeak"] = sed_r2
        feats["sed_logflux_loglambda_nbands_rpeak"] = sed_n

        sed_slope20, sed_int20, sed_r2_20, sed_n20 = sed_logflux_vs_loglambda_at_time(
            band_tb_obs, band_fb, band_eb, tpr_obs + 20.0, z=z
        )
        feats["sed_slope_rpeak_p20"] = sed_slope20
        feats["sed_r2_rpeak_p20"] = sed_r2_20
        feats["sed_nbands_rpeak_p20"] = sed_n20
    else:
        feats["sed_logflux_loglambda_slope_rpeak"] = np.nan
        feats["sed_logflux_loglambda_r2_rpeak"] = np.nan
        feats["sed_logflux_loglambda_nbands_rpeak"] = np.nan
        feats["sed_slope_rpeak_p20"] = np.nan
        feats["sed_r2_rpeak_p20"] = np.nan
        feats["sed_nbands_rpeak_p20"] = np.nan

    return feats
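
The Bazin features come from `fit_bazin`, which is defined elsewhere in the notebook. Below is a minimal sketch of the standard Bazin et al. (2009) parametric model with a bounded least-squares fit via `scipy.optimize.curve_fit`, assuming the same return order as the call site (`A, t0, trise, tfall, B, chi2`). The starting guess, the bounds, and the name `sketch_fit_bazin` are illustrative assumptions, not the notebook's settings.

```python
import numpy as np
from scipy.optimize import curve_fit


def bazin(t, A, t0, trise, tfall, B):
    """Bazin et al. (2009): A * exp(-(t - t0)/tfall) / (1 + exp(-(t - t0)/trise)) + B."""
    x = t - t0
    # log(1 / (1 + exp(-x/trise))) == -logaddexp(0, -x/trise), evaluated stably in log space
    return A * np.exp(-x / tfall - np.logaddexp(0.0, -x / trise)) + B


def sketch_fit_bazin(t, f, ferr):
    """Return (A, t0, trise, tfall, B, reduced_chi2); all NaN if too few points or the fit fails."""
    t, f, ferr = (np.asarray(a, dtype=float) for a in (t, f, ferr))
    nan6 = (np.nan,) * 6
    if t.size < 6:  # need more points than the 5 free parameters
        return nan6
    # Illustrative starting point and bounds (assumptions, not the notebook's values)
    p0 = [max(f.max() - np.median(f), 1e-3), t[np.argmax(f)], 5.0, 30.0, float(f.min())]
    lower = [0.0, t.min() - 50.0, 0.5, 1.0, -np.inf]
    upper = [10.0 * np.abs(f).max() + 1.0, t.max() + 50.0, 50.0, 300.0, np.inf]
    try:
        popt, _ = curve_fit(bazin, t, f, p0=p0, sigma=ferr, absolute_sigma=True,
                            bounds=(lower, upper), method="trf", max_nfev=5000)
    except (RuntimeError, ValueError):
        return nan6
    resid = (f - bazin(t, *popt)) / ferr  # assumes all ferr > 0
    chi2red = float(np.sum(resid ** 2) / (t.size - 5))
    return (*(float(p) for p in popt), chi2red)
```

Note that `t0` is not the peak time: when `tfall > trise`, this model peaks at `t0 + trise * ln(tfall / trise - 1)`. That is why `bazin_t0_*` and `tpeak_*` carry related but different timing information.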

In [None]:
from pathlib import Path

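def build_lightcurve_cache(splits, base_dir, kind="train"):
    # Read each split's lightcurve CSV once and index its rows by object_id,
    # so per-object lookups never re-read or re-filter the full table.
    base_dir = Path(base_dir)
    lc_cache = {}   # split -> full lightcurve DataFrame
    idx_cache = {}  # split -> {object_id: array of row positions}
    for s in splits:
        path = base_dir / str(s) / f"{kind}_full_lightcurves.csv"
        lc = pd.read_csv(path)
        lc["object_id"] = lc["object_id"].astype(str)  # match the str ids used in the logs
        groups = lc.groupby("object_id").indices
        lc_cache[s] = lc
        idx_cache[s] = groups
    return lc_cache, idx_cache


def get_lightcurve(lc_cache, idx_cache, split, object_id):
    # Return this object's rows (in file order), or None if it has no lightcurve in the split
    object_id = str(object_id)
    idx = idx_cache[split].get(object_id, None)
    if idx is None:
        return None
    return lc_cache[split].iloc[idx]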
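
`build_lightcurve_cache` relies on pandas `groupby(...).indices`, which maps each group key to a NumPy array of integer row positions. A toy illustration (the data is made up):

```python
import pandas as pd

# Toy lightcurve table (made-up values); one object's rows need not be adjacent or time-sorted
lc = pd.DataFrame({
    "object_id": ["a", "b", "a", "c", "b"],
    "Time (MJD)": [60010.0, 60001.0, 60000.0, 60003.0, 60002.0],
    "Flux": [5.0, 1.0, 3.0, 2.0, 4.0],
})

groups = lc.groupby("object_id").indices  # dict: object_id -> array of row positions
rows_a = lc.iloc[groups["a"]]             # one positional take, no boolean mask over the whole table
print(groups["a"], rows_a["Time (MJD)"].tolist())

# A missing id gives None, which is the "no lightcurve" branch handled by get_lightcurve
print(groups.get("zzz"))
```

Rows come back in file order, not time order, which is why `extract_features_for_object` sorts by `Time (MJD)` before computing anything.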

## Feature Table Construction + Photo-z Augmentation

Build a row per object by pulling its lightcurve from the cached split CSV.

Augmentation:
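- For each training object (only when `augment_photoz=True`, the log has a `target` column, and a non-empty test `Z_err` pool is supplied), create `N_AUG` additional rows (`n_aug`, default 2)
- Sample `sigma` from the test `Z_err` pool, keeping only finite, positive values
- Set `z_sim = max(0, z0 + Normal(0, sigma))` and re-extract all features with `Z = z_sim` and `Z_err = sigma`
- Mark augmented rows with `photoz_aug = 1` (original rows have `photoz_aug = 0`)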
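
Isolated from feature extraction, the sampling step looks roughly like this. `sketch_photoz_draws` is a hypothetical helper; unlike `build_feature_table`, which seeds one generator for the whole table, it seeds per call, so its draws will not match the notebook's row for row.

```python
import numpy as np


def sketch_photoz_draws(z0, zerr_pool, n_aug=2, seed=6):
    """Draw (z_sim, sigma) pairs following the augmentation steps listed above."""
    rng = np.random.default_rng(seed)
    pool = np.asarray(zerr_pool, dtype=float)
    pool = pool[np.isfinite(pool) & (pool > 0)]              # keep finite, positive Z_err only
    draws = []
    for _ in range(n_aug):
        sigma = float(rng.choice(pool))                      # borrow a realistic test-set uncertainty
        z_sim = max(0.0, z0 + float(rng.normal(0.0, sigma)))  # clip: redshift cannot be negative
        draws.append((z_sim, sigma))
    return draws


# Each draw changes every rest-frame quantity: a 100-day observed span becomes 100 / (1 + z_sim)
for z_sim, sigma in sketch_photoz_draws(0.3, [0.05, np.nan, 0.1, -1.0], n_aug=3):
    print(f"z_sim={z_sim:.3f}  sigma={sigma:.2f}  rest-frame span={100.0 / (1.0 + z_sim):.1f} d")
```

Because augmented rows keep the original `object_id`, any cross-validation split should group on it so that copies of one object never land in both training and validation folds.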

In [None]:
def build_feature_table(
    log_df,
    lc_cache,
    idx_cache,
    augment_photoz=False,
    test_zerr_pool=None,
    n_aug=2,
    seed=6
):
    rng = np.random.default_rng(seed)
    rows = []

    # Keep only finite, positive Z_err values as the pool of realistic photo-z uncertainties
    if test_zerr_pool is not None:
        test_zerr_pool = np.asarray(test_zerr_pool, float)
        test_zerr_pool = test_zerr_pool[np.isfinite(test_zerr_pool)]
        test_zerr_pool = test_zerr_pool[test_zerr_pool > 0]

    for i in range(len(log_df)):
        r = log_df.iloc[i]
        obj = str(r["object_id"])
        split = r["split"]

        lc = get_lightcurve(lc_cache, idx_cache, split, obj)
        if lc is None:
            feats = {"n_obs": 0}
            feats["object_id"] = obj
            feats["split"] = split
            feats["photoz_aug"] = 0
            if "target" in log_df.columns:
                feats["target"] = int(r["target"])
            rows.append(feats)
            continue

        feats = extract_features_for_object(
            lc_raw=lc,
            z=r["Z"],
            z_err=r.get("Z_err", 0.0),
            ebv=r["EBV"],
        )
        feats["object_id"] = obj
        feats["split"] = split
        feats["photoz_aug"] = 0
        if "target" in log_df.columns:
            feats["target"] = int(r["target"])
        rows.append(feats)

        # Photo-z augmentation: re-extract features under n_aug perturbed redshifts (labelled rows only)
        if augment_photoz and ("target" in log_df.columns) and (test_zerr_pool is not None) and (len(test_zerr_pool) > 0):
            z0 = safe_float(r["Z"], default=0.0)
            for _ in range(n_aug):
                sigma = float(rng.choice(test_zerr_pool))
                z_sim = max(0.0, z0 + float(rng.normal(0.0, sigma)))
                feats2 = extract_features_for_object(
                    lc_raw=lc,
                    z=z_sim,
                    z_err=sigma,
                    ebv=r["EBV"],
                )
                feats2["object_id"] = obj
                feats2["split"] = split
                feats2["target"] = int(r["target"])
                feats2["photoz_aug"] = 1
                rows.append(feats2)

    return pd.DataFrame(rows)

## Feature Cleaning (with Missing Flags)

Unlike earlier models that median-filled immediately, Model 5 keeps NaNs and adds explicit missing indicators.

`clean_features(..., add_missing_flags=True)`:
- drops the columns listed in `drop_cols` (IDs, split, labels)
- replaces ±inf with NaN
- appends one 0/1 `<feature>_isnan` column per remaining feature

This helps tree models learn patterns like:
- "missing band features"
- "failed Bazin fit"
- "no valid SED bands at peak"
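A toy run makes the output layout concrete. This sketch applies the same three steps as `clean_features` to a hypothetical three-object table (the column names `bazin_rise` and `sed_slope` are made up for illustration). The NaNs themselves are kept: XGBoost and LightGBM route missing values down a learned default branch, while the `_isnan` flag makes missingness an explicit split candidate.

```python
import numpy as np
import pandas as pd

# Hypothetical toy table: object "b" has a failed Bazin fit (NaN),
# object "c" has a blown-up SED slope (inf)
toy = pd.DataFrame({
    "object_id": ["a", "b", "c"],
    "bazin_rise": [1.2, np.nan, 0.8],
    "sed_slope": [0.3, 0.1, np.inf],
})

# Same steps as clean_features(..., add_missing_flags=True)
X = toy.drop(columns=["object_id"]).replace([np.inf, -np.inf], np.nan)
miss = X.isna().astype(np.uint8)
miss.columns = [c + "_isnan" for c in miss.columns]
X = pd.concat([X, miss], axis=1)
print(X)
```

Because ±inf is converted before the flags are computed, an infinite value is flagged as missing too.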

In [None]:
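def clean_features(df, drop_cols, add_missing_flags=True):
    # Drop ID/label columns and turn ±inf into NaN so they are treated as missing
    X = df.drop(columns=drop_cols).copy()
    X = X.replace([np.inf, -np.inf], np.nan)

    if add_missing_flags:
        # One uint8 indicator per column (1 = missing), named "<feature>_isnan"
        miss = X.isna().astype(np.uint8)
        miss.columns = [c + "_isnan" for c in miss.columns]
        X = pd.concat([X, miss], axis=1)

    return X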

In [None]:
def best_threshold_f1(y_true, probs):
    # Grid-search one decision threshold that maximizes F1 (predict 1 when probs > t)
    ths = np.linspace(0.01, 0.99, 401)
    f1s = [f1_score(y_true, probs > t, zero_division=0) for t in ths]
    j = int(np.argmax(f1s))
    return float(ths[j]), float(f1s[j])

## SpecType Teacher Features (Expanded Classes)

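Model 5 reuses the leak-free (out-of-fold) teacher stacking approach from Model 4, with two additions:
- SLSN and SNII become their own SpecType groups
- `spec_topprob` (max class probability) is added as a confidence feature, alongside the per-class `p_spec_*` probabilities and `spec_entropy`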
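The two summary features are simple row-wise reductions of the teacher's probability matrix. This sketch (the helper name `teacher_summaries` is hypothetical) reproduces the notebook's clipped entropy, measured in nats, and the top-class probability. For K groups the entropy ranges from 0 (all mass on one group) to log(K) (uniform).

```python
import numpy as np

def teacher_summaries(p):
    """Per-row entropy (nats) and top-class probability of a teacher probability matrix."""
    p = np.asarray(p, dtype=float)
    p_clip = np.clip(p, 1e-12, 1.0)                  # avoid log(0), as in the notebook
    entropy = -np.sum(p_clip * np.log(p_clip), axis=1)
    topprob = p.max(axis=1)
    return entropy, topprob

# Row 0: confident teacher, row 1: maximally uncertain over 3 groups
ent, top = teacher_summaries([[1.0, 0.0, 0.0], [1/3, 1/3, 1/3]])
print(ent, top)
```

The two features carry overlapping but not identical information: two rows can share the same top probability while spreading the remaining mass very differently across the other groups.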

In [None]:
def add_spectype_teacher_features(train_feat, train_log, test_feat, n_splits=10, seed=6):

    # Join spectroscopic labels onto the feature table
    df = train_feat.merge(train_log[["object_id", "SpecType"]], on="object_id", how="left")
    spec = df["SpecType"].fillna("Unknown").astype(str)

    # Map detailed SpecType strings into a richer set of coarse groups
    # (separates SLSN and SNII explicitly vs earlier models)
    def map_group(s):
        s2 = s.strip()
        if s2 == "TDE":
            return "TDE"
        if s2 == "AGN":
            return "AGN"
        if "SLSN" in s2:
            return "SLSN"
        if s2.startswith("SN Ia"):
            return "SNIa"
        if "SN II" in s2:
            return "SNII"
        if s2.startswith("SN"):
            return "SNother"
        return "Other"

    spec_group = spec.map(map_group).astype(str)

    # Encode group labels into integers for multiclass training
    classes = sorted(spec_group.unique())
    class_to_idx = {c: i for i, c in enumerate(classes)}
    y_mc = spec_group.map(class_to_idx).to_numpy()

    # Build train/test matrices; add_missing_flags=True exposes "_isnan" indicators as extra features.
    # Then align test columns to the train column order so both matrices share the same layout.
    X_tr = clean_features(df, drop_cols=["object_id", "split", "target", "SpecType"], add_missing_flags=True)
    X_te = clean_features(test_feat, drop_cols=["object_id", "split"], add_missing_flags=True)
    X_te = X_te.reindex(columns=X_tr.columns)

    # Use split folders as groups to prevent leakage across simulated splits
    groups = df["split"].to_numpy()

    # Splitter
    splitter = StratifiedGroupKFold(n_splits, shuffle=True, random_state=seed)
    split_iter = splitter.split(X_tr, y_mc, groups)

    # Out-of-fold predicted probabilities for the teacher
    oof = np.zeros((len(X_tr), len(classes)), dtype=float)

    # LightGBM multiclass teacher configuration
    base = dict(
        objective="multiclass",
        num_class=len(classes),
        metric="multi_logloss",
        n_estimators=20000,
        learning_rate=0.03,
        num_leaves=63,
        min_child_samples=5,
        subsample=0.8,
        subsample_freq=1,
        colsample_bytree=0.8,
        reg_alpha=0.0,
        reg_lambda=0.0,
        n_jobs=-1,
        random_state=seed,
        verbosity=-1,
        force_col_wise=True
    )

    # Train teacher in CV and collect OOF probabilities (leak-free stacking)
    best_iters = []
    for fold, (tr_idx, va_idx) in enumerate(split_iter, 1):
        model = LGBMClassifier(**base)
        model.fit(
            X_tr.iloc[tr_idx], y_mc[tr_idx],
            eval_set=[(X_tr.iloc[va_idx], y_mc[va_idx])],
            eval_metric="multi_logloss",
            callbacks=[lgb.early_stopping(200, verbose=False)]
        )
        best_iters.append(model.best_iteration_)

        # Predict probabilities on the validation fold using the best iteration
        oof[va_idx] = model.predict_proba(
            X_tr.iloc[va_idx],
            num_iteration=model.best_iteration_
        )

    # Fit on full training data and predict teacher probabilities for test.
    # Cap the tree count at the mean early-stopped iteration so test probabilities
    # resemble the OOF ones (without a cap the full model would grow all 20000 trees).
    n_full = max(1, int(np.mean(best_iters)))
    model_full = LGBMClassifier(**{**base, "n_estimators": n_full})
    model_full.fit(X_tr, y_mc)
    p_test = model_full.predict_proba(X_te)

    # Entropy summary: high entropy means the teacher is uncertain
    def entropy(p):
        p = np.clip(p, 1e-12, 1.0)
        return -np.sum(p * np.log(p), axis=1)

    # Append per-class probabilities as new features
    for i, c in enumerate(classes):
        train_feat[f"p_spec_{c}"] = oof[:, i]
        test_feat[f"p_spec_{c}"] = p_test[:, i]

    # Append teacher uncertainty + top-probability confidence
    train_feat["spec_entropy"] = entropy(oof)
    test_feat["spec_entropy"] = entropy(p_test)

    train_feat["spec_topprob"] = np.max(oof, axis=1)   # teacher confidence (train OOF)
    test_feat["spec_topprob"] = np.max(p_test, axis=1) # teacher confidence (test)

    return train_feat, test_feat

## Feature Selection (Top-K)

Before Optuna, Model 5 performs feature selection:

1) Train a reasonably strong, fixed-parameter XGB baseline on the training part of each fold
2) Extract feature importance via `importance_type="gain"` (average gain per split)
3) Sum gains across folds
4) Keep the Top-K features (`FS_TOPK`)

This shrinks the feature set, which makes each Optuna trial cheaper and lets tuning focus on the most informative features. I did not have the compute to test performance without feature selection.
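The aggregation in steps 3 and 4 does not depend on XGBoost itself, so it can be sketched on plain dicts. This is a minimal sketch with a hypothetical helper `topk_by_summed_gain`; each input dict is shaped like the output of `Booster.get_score(importance_type="gain")`, which omits features never used in a split, so those keep a score of 0. Note that XGBoost's `"gain"` is the average gain per split; `"total_gain"` would also reward features that are split on often.

```python
def topk_by_summed_gain(fold_scores, all_features, k):
    """Sum per-fold importance dicts and return the k highest-scoring feature names."""
    gains = {f: 0.0 for f in all_features}
    for score in fold_scores:
        for feat, g in score.items():
            if feat in gains:            # ignore names that are not candidate features
                gains[feat] += float(g)
    ranked = sorted(gains.items(), key=lambda kv: kv[1], reverse=True)
    return [f for f, _ in ranked[:k]]

# Toy importances from two folds (hypothetical feature names)
folds = [{"a": 5.0, "b": 1.0}, {"b": 6.0, "c": 2.0}]
print(topk_by_summed_gain(folds, ["a", "b", "c", "d"], k=2))
```

Summing across folds favors features that are useful in every fold over features that score highly in only one.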

In [None]:
def feature_select_gain_topk(train_feat, k=350, n_splits=10, seed=6):
    y = train_feat["target"].astype(int).to_numpy()
    groups = train_feat["split"].to_numpy()
    X = clean_features(train_feat, drop_cols=["object_id", "split", "target"], add_missing_flags=True)

    splitter = StratifiedGroupKFold(n_splits, shuffle=True, random_state=seed)
    split_iter = splitter.split(X, y, groups)

    gains = {c: 0.0 for c in X.columns}

    base_params = dict(
        objective="binary:logistic",
        eval_metric="aucpr",
        random_state=seed,
        n_jobs=-1,
        tree_method="hist",
        device="cuda",
        n_estimators=6000,
        learning_rate=0.02,
        max_depth=6,
        min_child_weight=10,
        subsample=0.9,
        colsample_bytree=0.9,
        reg_alpha=2.0,
        reg_lambda=2.0,
        gamma=0.0,
        max_bin=256,
    )

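    # Fit on each fold's training part only. va_idx is unused: the folds just vary
    # the training subset. With no eval_set, eval_metric has no effect and all
    # n_estimators trees are grown.
    for tr_idx, va_idx in split_iter:
        X_tr, y_tr = X.iloc[tr_idx], y[tr_idx]

        # Per-fold class-imbalance weight (negatives / positives)
        neg = np.sum(y_tr == 0)
        pos = np.sum(y_tr == 1)
        spw = float(neg / max(1, pos))

        model = XGBClassifier(**{**base_params, "scale_pos_weight": spw})
        model.fit(X_tr, y_tr, verbose=False)

        # "gain" = average gain per split; features never split on are absent
        score = model.get_booster().get_score(importance_type="gain")
        for feat, g in score.items():
            if feat in gains:
                gains[feat] += float(g)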

    ranked = sorted(gains.items(), key=lambda x: x[1], reverse=True)
    top = [f for f, _ in ranked[:k]]
    return top

## Optuna XGBoost Hyperparameter Tuning

This function runs Optuna using:
- Out-of-fold (OOF) F1 score as the optimization target
- StratifiedGroupKFold to prevent split/group leakage
- Selected feature subset only
- Early stopping + Optuna pruning

Key behavior:
- Cleans features and adds missing-value indicator flags
- Uses split labels as groups to keep objects from the same split together
- Computes `scale_pos_weight` per fold for class imbalance
- Tunes tree structure, sampling, regularization, and binning parameters
- Stores results in a persistent Optuna SQLite study
- Returns the best parameter set found within the time limit
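The quantity the objective returns can be sketched without XGBoost or Optuna. This minimal sketch (all helper names are hypothetical) shows the per-fold `scale_pos_weight` and the global OOF F1: one threshold chosen for the whole concatenated OOF vector, on the same 401-point grid as `best_threshold_f1`. F1 is computed directly as 2TP / (2TP + FP + FN), which should agree with `sklearn.metrics.f1_score(..., zero_division=0)`. The per-fold F1 values reported for pruning use per-fold thresholds, so they are only a progress signal.

```python
import numpy as np

def fold_scale_pos_weight(y_tr):
    """neg/pos ratio on the training part of a fold (guards against pos == 0)."""
    neg = int(np.sum(y_tr == 0))
    pos = int(np.sum(y_tr == 1))
    return float(neg / max(1, pos))

def f1_at_threshold(y_true, probs, t):
    """Binary F1 for predictions probs > t, from TP/FP/FN counts."""
    pred = probs > t
    tp = np.sum(pred & (y_true == 1))
    fp = np.sum(pred & (y_true == 0))
    fn = np.sum(~pred & (y_true == 1))
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else float(2 * tp / denom)

def global_oof_f1(y_true, oof_probs, ths=np.linspace(0.01, 0.99, 401)):
    """One threshold for the whole OOF vector, as the objective's final score."""
    y_true = np.asarray(y_true)
    oof_probs = np.asarray(oof_probs, dtype=float)
    scores = [f1_at_threshold(y_true, oof_probs, t) for t in ths]
    j = int(np.argmax(scores))
    return float(ths[j]), scores[j]

y = np.array([0, 0, 0, 1, 0, 1])
oof = np.array([0.05, 0.20, 0.40, 0.70, 0.10, 0.90])
print(fold_scale_pos_weight(y[:4]))
print(global_oof_f1(y, oof))
```

Scoring one global threshold matches how the submission is made: a single cutoff is applied to all test probabilities.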

In [None]:
import optuna
import xgboost as xgb


def run_optuna_xgb_f1(train_feat, feature_cols, n_folds_tune=10, timeout_sec=28800, seed=6):
    y = train_feat["target"].astype(int).to_numpy()
    groups = train_feat["split"].to_numpy()

    X_all = clean_features(train_feat, drop_cols=["object_id", "split", "target"], add_missing_flags=True)
    X = X_all[feature_cols].copy()

    splitter = StratifiedGroupKFold(n_folds_tune, shuffle=True, random_state=seed)
    split_iter_all = list(splitter.split(X, y, groups))

    def objective(trial):
        params = {
            "objective": "binary:logistic",
            "eval_metric": "aucpr",
            "random_state": seed,
            "n_jobs": -1,
            "tree_method": "hist",
            "device": "cuda",
            "n_estimators": trial.suggest_int("n_estimators", 1500, 14000),
            "learning_rate": trial.suggest_float("learning_rate", 0.002, 0.08, log=True),
            "max_depth": trial.suggest_int("max_depth", 2, 10),
            "min_child_weight": trial.suggest_int("min_child_weight", 1, 80),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.4, 1.0),
            "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.5, 1.0),
            "colsample_bynode": trial.suggest_float("colsample_bynode", 0.5, 1.0),
            "max_bin": trial.suggest_int("max_bin", 128, 512),
            "gamma": trial.suggest_float("gamma", 0.0, 12.0),
            "reg_alpha": trial.suggest_float("reg_alpha", 0.0, 35.0),
            "reg_lambda": trial.suggest_float("reg_lambda", 0.05, 50.0),
            "max_delta_step": trial.suggest_int("max_delta_step", 0, 10),
            "grow_policy": trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"]),
        }

        if params["grow_policy"] == "lossguide":
            params["max_leaves"] = trial.suggest_int("max_leaves", 16, 512)

        oof = np.zeros(len(X), dtype=float)
        f1_progress = []

        for fold, (tr_idx, va_idx) in enumerate(split_iter_all, 1):
            X_tr, y_tr = X.iloc[tr_idx], y[tr_idx]
            X_va, y_va = X.iloc[va_idx], y[va_idx]

            neg = np.sum(y_tr == 0)
            pos = np.sum(y_tr == 1)
            spw = float(neg / max(1, pos))

            # EarlyStopping goes in the constructor: passing callbacks to fit()
            # is deprecated since XGBoost 1.6
            model = XGBClassifier(
                **params,
                scale_pos_weight=spw,
                callbacks=[xgb.callback.EarlyStopping(rounds=200, save_best=True)],
            )
            model.fit(X_tr, y_tr, eval_set=[(X_va, y_va)], verbose=False)

            oof[va_idx] = model.predict_proba(X_va)[:, 1]
            _, f1_fold = best_threshold_f1(y_va, oof[va_idx])
            f1_progress.append(f1_fold)

            trial.report(float(np.mean(f1_progress)), step=fold)
            if trial.should_prune():
                raise optuna.TrialPruned()

        th, f1 = best_threshold_f1(y, oof)
        return float(f1)

    sampler = optuna.samplers.TPESampler(seed=seed, multivariate=True, group=True)
    pruner = optuna.pruners.MedianPruner(n_startup_trials=40, n_warmup_steps=3)

    study = optuna.create_study(
        direction="maximize",
        sampler=sampler,
        pruner=pruner,
        study_name="xgb_oof_f1_splitcv_gpu_selected",
        storage="sqlite:///optuna_xgb_oof_f1_gpu_selected.db",
        load_if_exists=True
    )

    study.optimize(objective, n_trials=999999, timeout=timeout_sec)

    print("\nOptuna best OOF F1:", study.best_value)
    print("Best params:")
    for k, v in study.best_params.items():
        print(k, "=", v)

    return study.best_params

## Multiseed XGBoost OOF Ensemble

This function trains a multiseed XGBoost ensemble using split-aware cross-validation:

- Cleans features and adds missing flags
- Uses the selected feature subset only
- StratifiedGroupKFold prevents group leakage
- Trains multiple seeds per fold for prediction stability
- Trains every model for the full tuned `n_estimators` (no early stopping)
- Averages seed predictions for OOF and test (I did not compare against a single-seed run, although some participants reported that seed averaging lowered performance)
- Computes the best OOF F1 threshold and AP
- Trains final multiseed models on the full training data for test predictions
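The OOF assembly with seed averaging reduces to a fold-by-fold fill of one vector. This minimal sketch uses a hypothetical helper `assemble_multiseed_oof` and toy probabilities in place of the per-seed `predict_proba` outputs; the same equal-weight `np.mean` over seeds is what produces the test probabilities from the full-data models.

```python
import numpy as np

def assemble_multiseed_oof(n_rows, fold_val_indices, fold_seed_probs):
    """Fill an OOF vector fold by fold, averaging the per-seed validation predictions.

    fold_val_indices: list of index arrays (one per fold)
    fold_seed_probs: list (one per fold) of lists (one per seed) of probability arrays
    """
    oof = np.zeros(n_rows, dtype=float)
    for va_idx, seed_probs in zip(fold_val_indices, fold_seed_probs):
        oof[va_idx] = np.mean(seed_probs, axis=0)   # equal-weight seed average
    return oof

# Toy run: 4 rows, 2 folds, 2 seeds per fold
folds = [np.array([0, 2]), np.array([1, 3])]
probs = [
    [np.array([0.2, 0.4]), np.array([0.4, 0.6])],   # fold 1, seeds A and B
    [np.array([0.1, 0.9]), np.array([0.3, 0.7])],   # fold 2, seeds A and B
]
oof = assemble_multiseed_oof(4, folds, probs)
labels = (oof > 0.419).astype(int)   # same strict ">" rule used for the submission
print(oof, labels)
```

Because every row appears in exactly one validation fold, each OOF entry is written once, and averaging happens only across seeds, never across folds.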

In [None]:
def predict_xgb_multiseed(train_feat, test_feat, best_params, feature_cols, n_splits_oof=20, seeds=(6, 67, 6767)):
    y = train_feat["target"].astype(int).to_numpy()
    groups = train_feat["split"].to_numpy()

    X_all = clean_features(train_feat, drop_cols=["object_id", "split", "target"], add_missing_flags=True)
    X_test_all = clean_features(test_feat, drop_cols=["object_id", "split"], add_missing_flags=True)

    X = X_all[feature_cols].copy()
    X_test = X_test_all[feature_cols].copy()

    splitter = StratifiedGroupKFold(n_splits=n_splits_oof, shuffle=True, random_state=6)
    split_iter = list(splitter.split(X, y, groups))

    oof = np.zeros(len(X), dtype=float)

    for fold, (tr_idx, va_idx) in enumerate(split_iter, 1):
        X_tr, y_tr = X.iloc[tr_idx], y[tr_idx]
        X_va, y_va = X.iloc[va_idx], y[va_idx]
        
        neg = np.sum(y_tr == 0)
        pos = np.sum(y_tr == 1)
        spw = float(neg / max(1, pos))

        probs_va = []
        for sd in seeds:
            model = XGBClassifier(
                objective="binary:logistic",
                eval_metric="logloss",
                random_state=sd,
                n_jobs=-1,
                tree_method="hist",
                device="cuda",
                scale_pos_weight=spw,
                **best_params,
            )
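            # No early stopping here, so an eval_set would only add evaluation cost;
            # each model grows the full n_estimators from best_params
            model.fit(X_tr, y_tr, verbose=False)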
            probs_va.append(model.predict_proba(X_va)[:, 1])

        oof[va_idx] = np.mean(probs_va, axis=0)

    best_th, best_f1 = best_threshold_f1(y, oof)
    ap = average_precision_score(y, oof)
    print("\nOOF multiseed best threshold:", best_th)
    print("OOF multiseed best F1:", best_f1)
    print("OOF AP (aucpr-ish):", ap)

    probs_test = []
    neg_full = np.sum(y == 0)
    pos_full = np.sum(y == 1)
    spw_full = float(neg_full / max(1, pos_full))

    for sd in seeds:
        model = XGBClassifier(
            objective="binary:logistic",
            eval_metric="aucpr",
            random_state=sd,
            n_jobs=-1,
            tree_method="hist",
            device="cuda",
            scale_pos_weight=spw_full,
            **best_params,
        )
        model.fit(X, y, verbose=False)
        probs_test.append(model.predict_proba(X_test)[:, 1])

    p_test = np.mean(probs_test, axis=0)
    return p_test, best_th

## Data Init and Feature Table Build

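Builds the train/test feature tables and prints their shapes. These shapes are measured before the SpecType teacher features are added and before feature selection; the train table includes the photo-z augmented rows and the `target` column (hence its one extra column).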

In [None]:
from pathlib import Path

N_AUG = 2
FS_TOPK = 380
FS_FOLDS = 10
OPTUNA_FOLDS = 10
OPTUNA_TIMEOUT_SEC = 28800
FINAL_OOF_FOLDS = 20
SEEDS = (6, 67, 6767)

ROOT = Path.cwd().parents[0]
DATA_DIR = ROOT / "data"

train_log = pd.read_csv(DATA_DIR / "train_log.csv")
test_log  = pd.read_csv(DATA_DIR / "test_log.csv")

train_log["object_id"] = train_log["object_id"].astype(str)
test_log["object_id"] = test_log["object_id"].astype(str)

train_log["Z_err"] = train_log["Z_err"].fillna(0.0)
test_log["Z_err"] = test_log["Z_err"].fillna(0.0)

train_splits = sorted(train_log["split"].unique())
test_splits = sorted(test_log["split"].unique())

train_lc_cache, train_idx_cache = build_lightcurve_cache(train_splits, base_dir=DATA_DIR, kind="train")
test_lc_cache, test_idx_cache = build_lightcurve_cache(test_splits, base_dir=DATA_DIR, kind="test")

test_zerr_pool = test_log["Z_err"].dropna().values

train_feat = build_feature_table(
    train_log, train_lc_cache, train_idx_cache,
    augment_photoz=True,
    test_zerr_pool=test_zerr_pool,
    n_aug=N_AUG,
    seed=6
)

test_feat = build_feature_table(
    test_log, test_lc_cache, test_idx_cache,
    augment_photoz=False
)
print(f"train shape: {train_feat.shape}")
print(f"test shape: {test_feat.shape}")


train shape: (9129, 559)
test shape: (7135, 558)


## SpecType Features and Feature Selection

In [14]:
train_feat, test_feat = add_spectype_teacher_features(train_feat, train_log, test_feat, n_splits=10, seed=6)
selected_cols = feature_select_gain_topk(train_feat, k=FS_TOPK, n_splits=FS_FOLDS, seed=6)

In [None]:
best_params = run_optuna_xgb_f1(
    train_feat,
    feature_cols=selected_cols,
    n_folds_tune=OPTUNA_FOLDS,
    timeout_sec=OPTUNA_TIMEOUT_SEC,
    seed=6
)

Optuna best OOF F1 (10-fold tuning CV with early stopping): **0.6093264248704663**

Best Parameters:
```json
{
  "n_estimators": 9476,
  "learning_rate": 0.0024306289953670325,
  "max_depth": 7,
  "min_child_weight": 6,
  "subsample": 0.5344962939912224,
  "colsample_bytree": 0.464696420753079,
  "colsample_bylevel": 0.8146569410634974,
  "colsample_bynode": 0.7285475291884695,
  "max_bin": 181,
  "gamma": 8.476938947246458,
  "reg_alpha": 0.44957196104419117,
  "reg_lambda": 5.23806334613521,
  "max_delta_step": 0,
  "grow_policy": "depthwise"
}
```

Re-initializing `best_params` by hand because the kernel was reset after the Optuna run.

In [20]:
best_params = {
    'n_estimators': 9476,
    'learning_rate': 0.0024306289953670325,
    'max_depth': 7,
    'min_child_weight': 6,
    'subsample': 0.5344962939912224,
    'colsample_bytree': 0.464696420753079,
    'colsample_bylevel': 0.8146569410634974,
    'colsample_bynode': 0.7285475291884695,
    'max_bin': 181,
    'gamma': 8.476938947246458,
    'reg_alpha': 0.44957196104419117,
    'reg_lambda': 5.23806334613521,
    'max_delta_step': 0,
    'grow_policy': 'depthwise',
}

In [21]:
p_test, best_th = predict_xgb_multiseed(
    train_feat,
    test_feat,
    best_params,
    feature_cols=selected_cols,
    n_splits_oof=min(FINAL_OOF_FOLDS, len(train_splits)),
    seeds=(99, 999, 909)
)


OOF multiseed best threshold: 0.419
OOF multiseed best F1: 0.6243705941591138
OOF AP (aucpr-ish): 0.5164263302782434


In [22]:
test_pred = (p_test > best_th).astype(int)

sub = pd.DataFrame({
    "object_id": test_feat["object_id"].values,
    "target": test_pred
})
out_name = "XGB_multiseed_teacher5.csv"
sub.to_csv(out_name, index=False)
print("Saved", out_name, " threshold:", best_th)

Saved XGB_multiseed_teacher5.csv  threshold: 0.419
