In [None]:
!pip -q install lightkurve==2.* numpy pandas scipy astropy tqdm pyarrow

In [None]:
import numpy as np
import pandas as pd
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
from scipy.signal import savgol_filter
import lightkurve as lk

def _to_clean_arrays(time_in, flux_in):
    t = np.array(time_in, dtype=float)
    f = np.array(flux_in, dtype=float)
    m = np.isfinite(t) & np.isfinite(f)
    t, f = t[m], f[m]
    if len(t) == 0:
        return t, f
    order = np.argsort(t)
    return t[order], f[order]

def _median_cadence_min(time):
    if len(time) < 2: return 30.0
    dt_days = np.nanmedian(np.diff(time))
    return float(dt_days * 24.0 * 60.0)

def _savgol_detrend(time, flux, window_length=401, polyorder=2):
    if len(flux) < 7:
        return time, flux
    wl = max(5, int(window_length))
    wl = wl if wl % 2 == 1 else wl + 1
    wl = min(wl, len(flux) - (1 - len(flux) % 2))
    if wl < 5:
        return time, flux
    trend = savgol_filter(flux, window_length=wl, polyorder=polyorder, mode="interp")
    med = np.nanmedian(trend) if np.isfinite(trend).any() else 1.0
    trend = np.where(trend == 0, med, trend)
    return time, flux / trend - 1.0

def _lk_flatten(time, flux, window_length=401, polyorder=2, break_tolerance=5):
    """Flatten a light curve with lightkurve, falling back to `_savgol_detrend`.

    The window is raised to at least 5 and made odd before being handed to
    ``LightCurve.flatten``. If anything in the lightkurve path raises, the
    original `time`/`flux` are detrended with `_savgol_detrend` instead, using
    the caller's original `window_length` and `polyorder`.

    Returns:
        (time, flux) as ndarrays from the flattened light curve, or whatever
        `_savgol_detrend` returns on the fallback path.
    """
    try:
        window = max(5, int(window_length))
        if window % 2 == 0:
            window += 1
        flattened = lk.LightCurve(time=time, flux=flux).flatten(
            window_length=window,
            polyorder=polyorder,
            break_tolerance=break_tolerance,
        )
        flat_time = np.array(flattened.time.value)
        flat_flux = np.array(flattened.flux.value)
    except Exception:
        # Best-effort: any lightkurve failure degrades to the plain filter.
        return _savgol_detrend(time, flux, window_length=window_length, polyorder=polyorder)
    return flat_time, flat_flux

def detrend_if_needed(time, flux, already_detrended: bool, cadence_min: Optional[float]=None,
                      method="lightkurve", max_transit_hours=10.0):
    """Clean a light curve, optionally detrend it, and centre it on zero median.

    The inputs are first passed through `_to_clean_arrays`. When
    `already_detrended` is true the cleaned flux is only median-subtracted.
    Otherwise a filter window of 2.5x `max_transit_hours` (expressed in
    cadences, at least 51, forced odd) is used with `_lk_flatten` when
    `method == "lightkurve"` and with `_savgol_detrend` for any other value.

    Args:
        cadence_min: Cadence in minutes; estimated from the cleaned times via
            `_median_cadence_min` when None.

    Returns:
        (time, flux) with the NaN-ignoring median of the flux removed.
    """
    clean_time, clean_flux = _to_clean_arrays(time, flux)

    if not already_detrended:
        if cadence_min is None:
            cadence_min = _median_cadence_min(clean_time)

        # Window wide enough to span 2.5x the longest expected transit.
        window_minutes = max_transit_hours * 60.0 * 2.5
        window = max(51, int(round(window_minutes / max(cadence_min, 1e-6))))
        if window % 2 == 0:
            window += 1

        flattener = _lk_flatten if method == "lightkurve" else _savgol_detrend
        clean_time, clean_flux = flattener(clean_time, clean_flux, window_length=window, polyorder=2)

    return clean_time, clean_flux - np.nanmedian(clean_flux)

def sample_transit_params(
    period_range=(0.5, 30.0), duration_hours_range=(1.0, 10.0), depth_ppm_range=(200, 5000),
    t0: Optional[float]=None, rng: Optional[np.random.Generator]=None
) -> Dict[str, float]:
    """Draw log-uniform transit parameters.

    Period, duration (hours) and depth (ppm) are each sampled log-uniformly
    between the two bounds of their range, in that order, from `rng` (a fresh
    default generator when none is supplied).

    Returns:
        Dict with "period" (as sampled), "duration_days" (hours / 24),
        "depth" (ppm * 1e-6) and "t0" (the argument, passed through unchanged).
    """
    rng = rng or np.random.default_rng()

    def _log_uniform(bounds):
        # Uniform in log10 space between bounds[0] and bounds[1].
        return 10**rng.uniform(np.log10(bounds[0]), np.log10(bounds[1]))

    period = _log_uniform(period_range)
    duration_hours = _log_uniform(duration_hours_range)
    depth_ppm = _log_uniform(depth_ppm_range)

    return {
        "period": period,
        "duration_days": duration_hours/24.0,
        "depth": depth_ppm*1e-6,
        "t0": t0,
    }

def box_transit_model(time, period, duration_days, depth, t0=None):
    """Build a periodic box-shaped transit signal sampled at `time`.

    Samples whose orbital phase lies strictly within half a duration of a
    transit centre get the value ``-depth``; all others are 0. When `t0` is
    None the first transit centre is placed a quarter period after the
    earliest time.

    Returns:
        Float ndarray with the same shape as `time`.
    """
    t = np.asarray(time)
    epoch = t.min() + 0.25*period if t0 is None else t0

    phase = ((t - epoch) % period) / period
    half_width = 0.5 * duration_days / period  # half duration as a phase fraction

    # Phase wraps at 1, so a transit centred on phase 0 covers both ends.
    in_transit = (phase < half_width) | (phase > 1 - half_width)
    return np.where(in_transit, -depth, 0.0).astype(float)

def inject_transit(flux_detrended, model):
    """Return the sum of the detrended flux and the transit model."""
    injected = flux_detrended + model
    return injected

def add_white_noise(flux, sigma_ppm=300, rng=None):
    """Add zero-mean Gaussian noise with standard deviation `sigma_ppm` * 1e-6.

    One independent draw is made per flux sample from `rng` (a fresh default
    generator when none is supplied). Returns the sum of `flux` and the noise.
    """
    rng = rng or np.random.default_rng()
    noise_sd = sigma_ppm * 1e-6
    noise = rng.normal(loc=0.0, scale=noise_sd, size=len(flux))
    return flux + noise

def add_red_noise_ar1(flux, rho=0.5, sigma_ppm=300, rng=None):
    """Add AR(1) correlated ("red") noise to `flux`.

    The noise follows ``red[i] = rho * red[i-1] + eps[i]`` with ``red[0] = 0``
    and Gaussian innovations of standard deviation
    ``sigma_ppm * 1e-6 * sqrt(1 - rho**2)``.

    All innovations are drawn in a single vectorised call and the recursion is
    evaluated with an IIR filter, avoiding one Python-level RNG call per
    cadence. The draws come from the same generator stream in the same order
    as element-by-element sampling would use.

    Args:
        flux: 1-D array of flux values.
        rho: Lag-1 autocorrelation; must satisfy ``|rho| <= 1``.
        sigma_ppm: Target noise standard deviation in parts per million.
        rng: Optional ``numpy.random.Generator``; a fresh one is created if None.

    Returns:
        ``flux`` plus the generated noise.

    Raises:
        ValueError: If ``|rho| > 1`` (or rho is NaN). The innovation scale
            would be the square root of a negative number, which would
            otherwise silently turn the whole light curve into NaN.
    """
    # Function-scope import: the file only imports savgol_filter from scipy.signal.
    from scipy.signal import lfilter

    if not abs(rho) <= 1:
        raise ValueError(f"rho must satisfy |rho| <= 1, got {rho!r}")

    rng = rng or np.random.default_rng()
    sigma = sigma_ppm * 1e-6
    eps_sd = sigma * np.sqrt(1 - rho**2)

    n = len(flux)
    red = np.zeros(n, dtype=float)
    if n > 1:
        innovations = rng.normal(0.0, eps_sd, size=n - 1)
        # y[k] = x[k] + rho * y[k-1] with zero initial state; red[0] stays 0.
        red[1:] = lfilter([1.0], [1.0, -rho], innovations)
    return flux + red

def insert_cadence_gaps(time, flux, gap_prob=0.01, block_drop_prob=0.02, block_len=50, rng=None):
    """Randomly remove cadences to mimic missing data.

    Each sample is dropped independently with probability `gap_prob`. In
    addition, with probability `block_drop_prob` (and only when the series is
    longer than `block_len`) one contiguous run of `block_len` samples starting
    at a random index is dropped.

    `time` and `flux` are indexed with a boolean mask, so they must support
    that (e.g. ndarrays).

    Returns:
        (time, flux) restricted to the surviving samples.
    """
    rng = rng or np.random.default_rng()
    n_samples = len(time)

    # Per-sample survival: a uniform draw above gap_prob keeps the sample.
    keep = rng.uniform(size=n_samples) > gap_prob

    # The block-drop draw is always consumed, even if the series is too short.
    drop_block = rng.uniform() < block_drop_prob
    if drop_block and n_samples > block_len:
        first = rng.integers(0, n_samples - block_len)
        keep[first:first + block_len] = False

    return time[keep], flux[keep]

def inject_flare(flux, n_flares=1, amp_ppm_range=(200, 5000), tau_hours_range=(0.2, 3.0), cadence_min=30.0, rng=None):
    """Add `n_flares` exponentially decaying brightenings to a copy of `flux`.

    For each flare a start index, a log-uniform amplitude (ppm, scaled by
    1e-6) and a log-uniform decay time (hours) are drawn, in that order. The
    decay time is converted to cadences (at least 1) and the flare is added
    over up to 8 decay times, truncated at the end of the series.

    `flux` must provide ``.copy()`` (e.g. an ndarray); it is not modified.

    Returns:
        The flared copy of `flux`.
    """
    rng = rng or np.random.default_rng()
    flared = flux.copy()
    n_samples = len(flared)

    def _log_uniform(bounds):
        # Uniform in log10 space between bounds[0] and bounds[1].
        return 10**rng.uniform(np.log10(bounds[0]), np.log10(bounds[1]))

    for _flare in range(n_flares):
        onset = rng.integers(low=0, high=max(1, n_samples-3))
        amplitude = _log_uniform(amp_ppm_range) * 1e-6
        tau_hours = _log_uniform(tau_hours_range)
        tau_cadences = max(1.0, tau_hours * 60.0 / max(cadence_min, 1e-6))

        # Stop after 8 decay times or at the end of the light curve.
        span = min(int(8*tau_cadences), n_samples-onset)
        steps = np.arange(0, span)
        flared[onset:onset+len(steps)] += amplitude * np.exp(-steps / tau_cadences)
    return flared


In [None]:
# Identifier columns copied verbatim from the input row into every record.
_AUG_ID_KEYS = ("obs_block_id", "target_id", "mission", "sector", "quarter", "campaign")


def _resolve_cadence_min(row, time):
    """Return the row's cadence in minutes, estimated from `time` if unusable.

    A stored value that is missing, non-numeric, non-finite or non-positive is
    replaced by `_median_cadence_min` of the raw times.
    """
    try:
        cadence = float(row.get("cadence_min"))
    except (TypeError, ValueError):
        cadence = float("nan")
    if not np.isfinite(cadence) or cadence <= 0:
        cadence = _median_cadence_min(np.array(time))
    return cadence


def _perturb_and_gap(t, f, rng, cadence_min, white_sigma_ppm, red_rho, red_sigma_ppm,
                     flare_prob, max_flares, gap_prob, block_drop_prob, block_len):
    """Apply white noise, red noise, optional flares and cadence gaps.

    Returns (time, flux, flare_used) where `flare_used` reports whether a
    flare was actually injected.
    """
    f = add_white_noise(f, sigma_ppm=white_sigma_ppm, rng=rng)
    f = add_red_noise_ar1(f, rho=red_rho, sigma_ppm=red_sigma_ppm, rng=rng)

    # The probability draw is always consumed so the random stream does not
    # depend on max_flares; flares are skipped when none are allowed.
    flare_used = bool(rng.uniform() < flare_prob) and max_flares >= 1
    if flare_used:
        f = inject_flare(f, n_flares=rng.integers(1, max_flares+1), cadence_min=cadence_min, rng=rng)

    t_aug, f_aug = insert_cadence_gaps(t, f, gap_prob=gap_prob, block_drop_prob=block_drop_prob,
                                       block_len=block_len, rng=rng)
    return t_aug, f_aug, flare_used


def _make_record(row, label, aug_type, t_aug, f_aug, cadence_min, transit_pars,
                 white_sigma_ppm, red_rho, red_sigma_ppm, flare_used, gap_prob, block_drop_prob):
    """Assemble one output record; transit provenance is added only if given."""
    record = {key: row.get(key) for key in _AUG_ID_KEYS}
    record["label_aug"] = label
    record["aug_type"] = aug_type
    record["time"] = t_aug.tolist()
    record["flux"] = f_aug.tolist()
    record["cadence_min"] = float(cadence_min)
    if transit_pars is not None:
        # provenance
        record["period"] = float(transit_pars["period"])
        record["duration_days"] = float(transit_pars["duration_days"])
        record["depth"] = float(transit_pars["depth"])
    record["white_sigma_ppm"] = float(white_sigma_ppm)
    record["red_rho"] = float(red_rho)
    record["red_sigma_ppm"] = float(red_sigma_ppm)
    record["flare_used"] = 1.0 if flare_used else 0.0
    record["gap_prob"] = float(gap_prob)
    record["block_drop_prob"] = float(block_drop_prob)
    return record


def augment_block(row: pd.Series,
                  already_detrended: bool,
                  n_pos: int = 2,
                  n_neg: int = 2,
                  rng: Optional[np.random.Generator] = None,
                  # transit sampling
                  period_range=(0.5, 30.0),
                  duration_hours_range=(1.0, 10.0),
                  depth_ppm_range=(200, 5000),
                  # noise knobs
                  white_sigma_ppm=300,
                  red_rho=0.5,
                  red_sigma_ppm=200,
                  flare_prob=0.2,
                  max_flares=2,
                  gap_prob=0.01,
                  block_drop_prob=0.02,
                  block_len=50,
                  detrend_method="lightkurve",
                  max_transit_hours=10.0) -> List[Dict[str, Any]]:
    """Generate augmented positive and negative samples from one light curve.

    The row's "time" and "flux" are detrended once with `detrend_if_needed`.
    Then `n_pos` samples with an injected box transit (label 1) and `n_neg`
    noise-only samples (label 0) are produced. Every sample gets white noise,
    AR(1) red noise, flares with probability `flare_prob`, and random cadence
    gaps.

    Args:
        row: Mapping-like row providing "time" and "flux", plus optional
            "cadence_min" and identifier columns copied into each record.
        already_detrended: Forwarded to `detrend_if_needed`.
        n_pos, n_neg: Number of transit / noise-only samples (negative counts
            are treated as 0).
        rng: Generator shared by all random steps; created if None.
        Remaining keyword arguments configure transit sampling, noise, flares,
        gaps and detrending, and are forwarded to the corresponding helpers.

    Returns:
        List of record dicts. Each holds the identifier columns, "label_aug",
        "aug_type", "time", "flux", "cadence_min" and the noise settings;
        positive records also hold "period", "duration_days" and "depth".
        "flare_used" is 1.0 only when a flare was actually injected into that
        sample. Returns an empty list when "time" or "flux" is missing.
    """
    rng = rng or np.random.default_rng()
    time = row.get("time"); flux = row.get("flux")
    if time is None or flux is None:
        return []

    cadence_min = _resolve_cadence_min(row, time)
    t, f0 = detrend_if_needed(time, flux, already_detrended, cadence_min,
                              method=detrend_method, max_transit_hours=max_transit_hours)

    noise_kwargs = dict(
        rng=rng, cadence_min=cadence_min,
        white_sigma_ppm=white_sigma_ppm, red_rho=red_rho, red_sigma_ppm=red_sigma_ppm,
        flare_prob=flare_prob, max_flares=max_flares,
        gap_prob=gap_prob, block_drop_prob=block_drop_prob, block_len=block_len,
    )
    record_kwargs = dict(
        white_sigma_ppm=white_sigma_ppm, red_rho=red_rho, red_sigma_ppm=red_sigma_ppm,
        gap_prob=gap_prob, block_drop_prob=block_drop_prob,
    )

    out: List[Dict[str, Any]] = []

    # Positive samples: inject a freshly sampled transit before adding noise.
    for _ in range(max(0, int(n_pos))):
        pars = sample_transit_params(period_range, duration_hours_range, depth_ppm_range, rng=rng)
        model = box_transit_model(t, pars["period"], pars["duration_days"], pars["depth"], pars["t0"])
        t_aug, f_aug, flare_used = _perturb_and_gap(t, inject_transit(f0, model), **noise_kwargs)
        out.append(_make_record(row, 1, "transit+noise", t_aug, f_aug, cadence_min, pars,
                                flare_used=flare_used, **record_kwargs))

    # Negative samples: same noise pipeline on the untouched detrended flux.
    for _ in range(max(0, int(n_neg))):
        t_aug, f_aug, flare_used = _perturb_and_gap(t, f0.copy(), **noise_kwargs)
        out.append(_make_record(row, 0, "noise-only", t_aug, f_aug, cadence_min, None,
                                flare_used=flare_used, **record_kwargs))

    return out

def build_augmented_dataset(
    vet_df: pd.DataFrame,
    already_detrended: bool = False,
    n_pos_per_block: int = 2,
    n_neg_per_block: int = 2,
    rng_seed: int = 42,
    **kwargs
) -> pd.DataFrame:
    """Run `augment_block` over every row of `vet_df` and collect the samples.

    A single generator seeded with `rng_seed` is shared across all rows. Extra
    keyword arguments are forwarded to `augment_block`. A row whose
    augmentation raises is reported with a printed warning and skipped.

    Returns:
        DataFrame built from all generated records.
    """
    shared_rng = np.random.default_rng(rng_seed)
    collected: List[Dict[str, Any]] = []

    progress = tqdm(vet_df.iterrows(), total=len(vet_df), desc="Augmenting")
    for _, row in progress:
        try:
            samples = augment_block(
                row,
                already_detrended,
                n_pos_per_block,
                n_neg_per_block,
                rng=shared_rng,
                **kwargs,
            )
            collected.extend(samples)
        except Exception as e:
            # Best-effort: one bad block must not abort the whole run.
            print(f"[WARN] augment failed for {row.get('obs_block_id','?')}: {repr(e)}")

    return pd.DataFrame.from_records(collected)


In [None]:
# --- Colab driver: mount Drive, load vetting blocks, augment, save ---
from google.colab import drive
# Mount Google Drive at /content/drive so the output below can be persisted.
drive.mount('/content/drive')

# Input parquet of vetting blocks and output parquet for the augmented set.
# NOTE(review): the input path contains a doubled slash and is flagged as a
# placeholder by its own comment; confirm the real location before running.
VET_PARQUET = "/content//vetting_kepler.parquet"  # update path
# NOTE(review): Drive files usually live under /content/drive/MyDrive on
# Colab; verify that writing directly under /content/drive works as intended.
AUG_OUT_PARQUET = "/content/drive/augmented_training.parquet"

vet_df = pd.read_parquet(VET_PARQUET)
print("Loaded blocks:", len(vet_df))
display(vet_df.head(2))
# Build 2 transit-injected and 2 noise-only samples per block with a fixed
# seed; the remaining keyword arguments are forwarded to augment_block.
aug_df = build_augmented_dataset(
    vet_df,
    already_detrended=False,
    n_pos_per_block=2,
    n_neg_per_block=2,
    rng_seed=123,
    period_range=(0.5, 30.0),
    duration_hours_range=(1.0, 10.0),
    depth_ppm_range=(200, 5000),
    white_sigma_ppm=300,
    red_rho=0.5,
    red_sigma_ppm=200,
    flare_prob=0.2,
    max_flares=2,
    gap_prob=0.01,
    block_drop_prob=0.02,
    block_len=50,
    detrend_method="lightkurve",
    max_transit_hours=10.0,
)

print("Augmented samples:", len(aug_df))
display(aug_df.head(3))
# Persist the augmented dataset without the DataFrame index.
aug_df.to_parquet(AUG_OUT_PARQUET, index=False)
print("Saved augmented set →", AUG_OUT_PARQUET)

In [None]:
import matplotlib.pyplot as plt

def quick_plot_sample(sample_row):
    """Plot one sample's flux against time and show the figure.

    The title reports the sample's "aug_type" and "label_aug" ('?' when
    absent) and, when a non-NaN "period" is present, the period to two
    decimals.
    """
    times = np.array(sample_row["time"])
    fluxes = np.array(sample_row["flux"])

    plt.figure(figsize=(10,3))
    plt.plot(times, fluxes, lw=0.5)
    plt.xlabel("Time (BKJD/BTJD)")
    plt.ylabel("Flux (detrended)")

    title_parts = [f"{sample_row.get('aug_type','?')}  label={sample_row.get('label_aug','?')}"]
    has_period = "period" in sample_row and not np.isnan(sample_row.get("period", np.nan))
    if has_period:
        title_parts.append(f"  P={sample_row['period']:.2f} d")
    plt.title("".join(title_parts))

    plt.show()

# Visual sanity check: plot the first augmented sample when any exist.
if not aug_df.empty:
    quick_plot_sample(aug_df.iloc[0])