#### Imports

In [1]:
import os, json, math, time
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, random_split

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.backends.cudnn.benchmark = True
# Seed both NumPy's global RNG and PyTorch's default generator.
np.random.seed(42)
torch.manual_seed(42)
print("Device:", device)

Device: cuda


In [2]:
from cnn import CNN1DEncoder, AstronetMVP
from dataset import ExoplanetDataset, collate

In [3]:
##### FAKE DATA
import numpy as np
import torch

# ---------- helpers: signal models ----------

def trapezoid_transit(phase, depth_ppm, dur_frac, ingress_frac=0.2, center=0.0):
    """
    Evaluate a symmetric trapezoidal dip on a phase grid.

    phase        : array of phase values (the generator uses [-0.5, 0.5])
    depth_ppm    : dip depth in parts per million (positive)
    dur_frac     : full width of the dip, in phase units
    ingress_frac : share of each half-width taken up by the sloped edge;
                   0 gives a box, 1 gives a pure V
    center       : phase at which the dip is centred

    Returns an array shaped like `phase` holding the flux offset:
    0 outside the dip, -depth_ppm * 1e-6 on the flat bottom, and a linear
    ramp between the two on the ingress and egress.
    """
    rel_depth = depth_ppm * 1e-6  # ppm -> relative flux
    half_dur = dur_frac / 2.0
    # Half-width of the flat bottom; clamped so ingress_frac > 1 cannot make it negative.
    flat_edge = max(half_dur * (1 - ingress_frac), 0.0)
    ramp_width = half_dur - flat_edge
    offset = phase - center

    # The three regions are disjoint, so the assignment order below is irrelevant.
    on_bottom = (offset >= -flat_edge) & (offset <= flat_edge)
    on_ingress = (offset >= -half_dur) & (offset < -flat_edge)
    on_egress = (offset > flat_edge) & (offset <= half_dur)

    profile = np.zeros_like(offset)
    profile[on_bottom] = -rel_depth
    # Ramp runs from 0 at the outer edge down to -rel_depth at the flat bottom.
    profile[on_ingress] = -rel_depth * ((offset[on_ingress] + half_dur) / ramp_width)
    # Mirror image: from -rel_depth at the flat bottom back up to 0.
    profile[on_egress] = -rel_depth * (1 - (offset[on_egress] - flat_edge) / ramp_width)
    return profile

def eb_like_signal(phase, depth_ppm_primary, depth_ppm_secondary, dur_frac, v_shape=True):
    """
    Eclipsing-binary-like signal: primary dip at phase 0, secondary dip half a
    cycle later (phase +/-0.5).

    phase               : array of phase values
    depth_ppm_primary   : depth of the primary dip in ppm
    depth_ppm_secondary : depth of the secondary dip in ppm
    dur_frac            : full width of each dip, in phase units
    v_shape             : True uses ingress_frac=0.9 (long sloped edges, almost
                          no flat bottom, i.e. a V); False uses 0.3

    Returns the sum of the two dips as an array shaped like `phase`.

    The secondary dip is evaluated on a wrapped phase axis, so on a folded
    curve spanning [-0.5, 0.5] it appears symmetrically at both ends instead
    of as half a dip at only one edge.
    """
    ingress_frac = 0.9 if v_shape else 0.3
    primary = trapezoid_transit(
        phase, depth_ppm_primary, dur_frac, ingress_frac=ingress_frac, center=0.0
    )
    # Distance from the secondary eclipse, wrapped into [-0.5, 0.5):
    # phase +0.5 and phase -0.5 both map to 0.
    phase_from_secondary = np.mod(phase, 1.0) - 0.5
    secondary = trapezoid_transit(
        phase_from_secondary, depth_ppm_secondary, dur_frac, ingress_frac=ingress_frac, center=0.0
    )
    return primary + secondary

def colored_noise(n, alpha=1.0, scale=1.0, rng=None):
    """
    Draw `n` samples of power-law ("1/f^alpha"-like) noise by shaping a random
    spectrum and inverse-transforming it.

    n     : number of samples to return
    alpha : spectral exponent; 0 leaves the spectrum flat (white)
    scale : standard deviation of the returned series (up to the 1e-8 guard)
    rng   : None, a seed, or a numpy Generator, forwarded to np.random.default_rng

    Returns a float array of length n.
    """
    gen = np.random.default_rng(rng)
    # Two length-n Gaussian draws (real parts first, then imaginary parts);
    # only the first n//2 + 1 values of each are used for the half-spectrum.
    real_draw = gen.normal(size=n)
    imag_draw = gen.normal(size=n)

    freq_bins = np.fft.rfftfreq(n)
    n_bins = n // 2 + 1

    # Per-bin amplitude weights: the zero-frequency bin keeps weight 1, the rest get f^-alpha.
    weights = np.ones_like(freq_bins)
    weights[1:] = 1.0 / np.power(freq_bins[1:], alpha)

    spectrum = np.zeros(n_bins, dtype=np.complex128)
    spectrum.real = real_draw[:n_bins] * weights
    spectrum.imag = imag_draw[:n_bins] * weights

    noise = np.fft.irfft(spectrum, n=n)
    # Rescale to unit standard deviation (epsilon guards against a zero std), then apply `scale`.
    noise = noise / (np.std(noise) + 1e-8)
    return scale * noise

# ---------- generator ----------

def generate_synthetic_exoplanet_dataset(
    N=1024,
    Tg=2001, Tl=201,
    pos_frac=0.5,
    tabular_dim=4,
    rng_seed=42
):
    """
    Build a seeded synthetic dataset of folded light curves.

    Args:
      N           : number of examples
      Tg, Tl      : number of samples in the global / local view
      pos_frac    : probability that an example is labelled positive
      tabular_dim : number of tabular columns to return, 0..4. Columns are taken
                    in order from [period_days, depth_ppm, dur_hours, teff];
                    0 disables the tabular output.
      rng_seed    : seed for the numpy Generator driving all random draws

    Returns:
      global_curves: torch.float32 [N,1,Tg]
      local_curves : torch.float32 [N,1,Tl]
      labels       : torch.float32 [N] (0/1)
      ids          : list[str] length N
      tabular      : torch.float32 [N, F]  (or None if tabular_dim=0)

    Raises:
      ValueError if tabular_dim is outside 0..4.

    All curves are normalized around ~1.0 and *folded* into phase space.
    The local view is a zoom around phase=0 whose width scales with transit duration.
    Positives carry a single trapezoid dip at phase 0; negatives are either
    noise only or an eclipsing-binary-like primary + secondary pair.
    """
    # Only four tabular features are synthesised; fail early with a clear
    # message instead of a broadcast error inside the loop.
    if not 0 <= tabular_dim <= 4:
        raise ValueError(
            f"tabular_dim must be between 0 and 4 (got {tabular_dim}); only "
            "period_days, depth_ppm, dur_hours and teff are generated"
        )

    rng = np.random.default_rng(rng_seed)

    # Phase axes
    phase_g = np.linspace(-0.5, 0.5, Tg, dtype=np.float64)

    global_curves = np.empty((N, 1, Tg), dtype=np.float32)
    local_curves  = np.empty((N, 1, Tl), dtype=np.float32)
    labels        = np.zeros((N,), dtype=np.float32)
    ids           = [f"synth_{i:06d}" for i in range(N)]
    tabular       = None if tabular_dim == 0 else np.empty((N, tabular_dim), dtype=np.float32)

    # NOTE: the order of the rng draws below determines the generated data for
    # a given seed; keep it stable.
    for i in range(N):
        is_pos = (rng.random() < pos_frac)
        labels[i] = 1.0 if is_pos else 0.0

        # Random astrophysical-ish parameters
        period_days   = float(rng.uniform(0.5, 20.0))
        depth_ppm     = float(rng.uniform(300, 10000)) if is_pos else float(rng.uniform(200, 15000))
        dur_hours     = float(rng.uniform(1.0, 6.0))
        dur_frac      = dur_hours / (24.0 * period_days)  # duration as fraction of phase
        dur_frac      = float(np.clip(dur_frac, 0.002, 0.05))
        ingress_frac  = float(rng.uniform(0.1, 0.3))      # smoother U-shape for planets

        # Noise level relative to depth
        snr_target    = float(rng.uniform(8.0, 50.0)) if is_pos else float(rng.uniform(3.0, 30.0))
        noise_std     = (depth_ppm * 1e-6) / max(snr_target, 1.0)

        # Base flux ~1 with colored noise
        flux = 1.0 + colored_noise(Tg, alpha=rng.uniform(0.5, 1.2), scale=noise_std, rng=rng)

        if is_pos:
            # Planet-like: single trapezoid at phase=0
            transit = trapezoid_transit(phase_g, depth_ppm, dur_frac, ingress_frac=ingress_frac, center=0.0)
            flux = flux + transit
        else:
            # Negatives: half pure noise, half EB-like
            if rng.random() < 0.5:
                # EB-like: primary + secondary
                depth2 = depth_ppm * rng.uniform(0.3, 0.9)
                flux = flux + eb_like_signal(phase_g, depth_ppm, depth2, dur_frac, v_shape=True)
            else:
                # Just noise / variability (already in 'flux')
                pass

        # Normalize roughly to median ~1
        med = np.median(flux)
        if med != 0:
            flux = flux / med

        # Build local view around phase 0 with width scaled by duration
        k = 4.0  # half-width in units of duration (wider => more context)
        half_width = max(k * dur_frac, 2.0 / Tg)  # ensure at least a couple points
        phase_l = np.linspace(-half_width, +half_width, Tl, dtype=np.float64)
        # interpolate global flux onto local phase window
        local_flux = np.interp(phase_l, phase_g, flux)

        # Store
        global_curves[i, 0, :] = flux.astype(np.float32)
        local_curves[i, 0, :]  = local_flux.astype(np.float32)

        # Tabular features: [period_days, depth_ppm, dur_hours, T_eff (dummy)]
        if tabular is not None:
            # teff is always drawn so the random stream does not depend on tabular_dim (when > 0).
            teff = float(rng.uniform(3000, 7000))
            features = np.array([period_days, depth_ppm, dur_hours, teff], dtype=np.float32)
            # Keep only the requested leading columns.
            tabular[i, :] = features[:tabular_dim]

    # Convert to tensors
    global_t = torch.from_numpy(global_curves)
    local_t  = torch.from_numpy(local_curves)
    labels_t = torch.from_numpy(labels)
    tab_t    = None if tabular is None else torch.from_numpy(tabular)

    return global_t, local_t, labels_t, ids, tab_t








###### -- - - - -  

#### Data

To-do: replace the synthetic data below with real data.
The dataset needs:
- global curves
- local curves
- labels
- ids
- tabular data

In [4]:
# create data set

# TODO: replace with real data

# ds = ExoplanetDataset(
#     global_curves = ...,
#     local_curves = ...,
#     labels = ...,
#     ids = ...,
#     tabular = ...
# )

# Stand-in for the real dataset: 2000 seeded synthetic examples with 4 tabular columns.
global_t, local_t, labels_t, ids_seq, tab_t = generate_synthetic_exoplanet_dataset(
    N=2000, Tg=2001, Tl=201, pos_frac=0.5, tabular_dim=4, rng_seed=123
)

# Split
# Shuffle the example indices with a fixed seed, then use the first 80% for
# training and the remaining 20% for testing.
idx = np.arange(len(labels_t))
np.random.default_rng(123).shuffle(idx)
split = int(0.8 * len(idx))
train_idx, test_idx = idx[:split], idx[split:]

def take(arr, ix):
    """
    Pick the entries of `arr` at the positions listed in `ix`.

    NumPy arrays are indexed directly, torch tensors are indexed with `ix`
    converted to a long tensor, and anything else is indexed element by
    element and returned as a list.
    """
    if isinstance(arr, np.ndarray):
        return arr[ix]
    if torch.is_tensor(arr):
        positions = torch.as_tensor(ix, dtype=torch.long)
        return arr[positions]
    return [arr[pos] for pos in ix]

# Slice every tensor with the same index arrays so curves, labels and tabular
# rows stay aligned per example.
global_train, global_test = global_t[train_idx], global_t[test_idx]
local_train,  local_test  = local_t[train_idx],  local_t[test_idx]
labels_train, labels_test = labels_t[train_idx], labels_t[test_idx]
# ids is a plain Python list, so it is subset element by element.
ids_train   = [ids_seq[i] for i in train_idx]
ids_test    = [ids_seq[i] for i in test_idx]
# Tabular features are optional (None when the generator was run with tabular_dim=0).
tab_train   = None if tab_t is None else tab_t[train_idx]
tab_test    = None if tab_t is None else tab_t[test_idx]

# Wrap each split in the project dataset class.
ds_train = ExoplanetDataset(global_train, local_train, labels_train, ids_train, tabular=tab_train)
ds_test  = ExoplanetDataset(global_test,  local_test,  labels_test,  ids_test,  tabular=tab_test)

In [5]:
# create train test split from data set
#TODO: use real data
# train_size = int(0.8 * len(ds))
# test_size = len(ds) - train_size

# ds_train, ds_test = random_split(ds, [train_size, test_size], generator = torch.Generator().manual_seed(42))

# Shuffle only the training batches. The test loader keeps a fixed order so
# repeated evaluation passes see the examples in the same sequence.
dl_train = DataLoader(ds_train, batch_size=64, shuffle=True, collate_fn=collate, drop_last=False)
dl_test = DataLoader(ds_test, batch_size=64, shuffle=False, collate_fn=collate, drop_last=False)


In [6]:
# loss, optimizer, pos weight from train set
def compute_pos_weight_from_loader(dloader):
    """
    Count positive (== 1) and negative (== 0) labels over one pass of
    `dloader` and return the negative/positive ratio.

    Each batch must unpack into five items with the label tensor in fourth
    position. Returns a one-element float32 tensor on the module-level
    `device`, or None when either class is absent.
    """
    n_pos = 0
    n_neg = 0
    with torch.no_grad():
        for _, _, _, targets, _ in dloader:
            n_pos += int((targets == 1).sum())
            n_neg += int((targets == 0).sum())

    # A ratio is only meaningful when both classes were seen.
    if not (n_pos and n_neg):
        return None
    ratio = n_neg / n_pos
    return torch.tensor([ratio], dtype=torch.float32, device=device)

# Infer tabular_dim from one train batch: 0 when the batch carries no tabular
# tensor, otherwise its second dimension.
xg0, xl0, xt0, y0, ids0 = next(iter(dl_train))
tabular_dim = 0 if xt0 is None else xt0.shape[1]

# NOTE: these hyperparameters are repeated in the saved config; keep the two in sync.
model = AstronetMVP(hidden=64, k_global=7, k_local=5, tabular_dim=tabular_dim).to(device)

# Weight the positive class by the neg/pos ratio of the training set; fall back
# to the unweighted loss when the ratio is unavailable (a class is missing).
pos_weight = compute_pos_weight_from_loader(dl_train)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) if pos_weight is not None else nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [7]:
# make & save config / checkpoints dir
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)

# Hyperparameters recorded alongside the checkpoints.
# pos_weight is stored as a plain float (or None) so it is JSON-serialisable.
config = {
    "hidden":64,
    "k_global": 7,
    "k_local": 5,
    "tabular_dim": tabular_dim,
    "lr": 1e-3,
    "wd": 1e-5,
    "pos_weight": None if pos_weight is None else float(pos_weight.item())
}
# Use a context manager so the file is flushed and closed even if dumping fails.
with open(os.path.join(save_dir, "config.json"), "w") as config_file:
    json.dump(config, config_file, indent=2)

In [8]:
# train loop: train + save checkpoints

def train_epoch(model, dloader, optim, criterion):
    """
    Run one training pass over `dloader` and return the mean loss.

    model     : module called as model(xg, xl, xt) and returning logits whose
                dim 1 has size 1
    dloader   : iterable of (xg, xl, xt, y, ids) batches; xt may be None
    optim     : optimizer that steps the model's parameters
    criterion : loss called as criterion(logits, targets)

    Returns the per-example average loss over the epoch (0.0 for an empty
    loader). Batches are moved to the device the model's parameters live on.
    """
    model.train()
    # Follow the model's own device rather than a notebook-level global.
    dev = next(model.parameters()).device
    total, n = 0.0, 0

    for xg, xl, xt, y, _ids in dloader:
        # Move the batch to the model's device.
        xg = xg.to(dev)
        xl = xl.to(dev)
        xt = None if xt is None else xt.to(dev).float()
        y = y.to(dev).float().view(-1)  # targets, flattened to [B]

        logits = model(xg, xl, xt).squeeze(1)  # predictions, [B, 1] -> [B]
        loss = criterion(logits, y)

        # Backprop and update using the optimizer that was passed in.
        optim.zero_grad()
        loss.backward()
        # In-place clipping; the un-suffixed clip_grad_norm is deprecated.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optim.step()

        # Weight each batch loss by its size so the epoch mean is per example.
        bs = y.size(0)
        total += loss.item() * bs
        n += bs

    return total / max(n, 1)


num_epochs = 10
for epoch in range(1, num_epochs+1):
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during the epoch.
    t0 = time.perf_counter()
    tr_loss = train_epoch(model, dl_train, optimizer, criterion)
    dt = time.perf_counter() - t0

    print(f"Epoch {epoch:02d}/{num_epochs} | train_loss={tr_loss:.4f} | {dt:.1f}s")

# Save final checkpoint: model weights plus the config dict needed to rebuild the model.
ckpt_path = os.path.join(save_dir, "model_final.pt")
torch.save({"model_state": model.state_dict(), "config": config}, ckpt_path)
print("Saved:", ckpt_path)

TypeError: adaptive_avg_pool1d(): argument 'input' (position 1) must be Tensor, not Sequential

In [None]:
# create test set predictions for eval
@torch.no_grad()
def predict_on_loader(model, dloader):
    """
    Score every example yielded by `dloader` with `model` in eval mode.

    Batches are (xg, xl, xt, y, ids) tuples; xt may be None and y is not used.
    Returns a list with one dict per example: {"id", "logit", "prob"}, where
    prob is the sigmoid of the logit. Inputs are moved to the module-level
    `device`.
    """
    model.eval()
    records = []
    for xg, xl, xt, _y, batch_ids in dloader:
        if xt is not None:
            xt = xt.to(device).float()
        batch_logits = model(xg.to(device), xl.to(device), xt).squeeze(1)  # [B]
        batch_probs = torch.sigmoid(batch_logits)
        # Bring results back to the CPU as plain Python floats, one record per example.
        logit_values = batch_logits.cpu().tolist()
        prob_values = batch_probs.cpu().tolist()
        records.extend(
            {"id": ex_id, "logit": float(lo), "prob": float(pr)}
            for ex_id, lo, pr in zip(batch_ids, logit_values, prob_values)
        )
    return records

# Score the held-out split with the current model weights.
test_preds = predict_on_loader(model, dl_test)
print("Test preds:", len(test_preds))

# save it as csv
# import csv
# pred_csv = os.path.join(save_dir, "test_predictions.csv")
# with open(pred_csv, "w", newline="") as f:
#     w = csv.DictWriter(f, fieldnames=["id","logit","prob"])
#     w.writeheader()
#     w.writerows(test_preds)
# print("Saved:", pred_csv)

In [None]:
## EVAL

# get ground truth labels for test ids
# Map each test id to its integer ground-truth label by walking the test loader once.
labels_by_id = {}
for _xg, _xl, _xt, _y, _ids in dl_test:
    for i, y in zip(_ids, _y.tolist()):
        labels_by_id[i] = int(y)
# NOTE(review): bare expression in the middle of the cell; its value is
# discarded, so it only serves as a peek when run as a cell's last line.
len(labels_by_id), list(list(labels_by_id.items())[:3])

# get predictions for labels
# Join predictions to labels by id, so the result does not depend on the
# order in which the loader yielded examples.
ids  = [d["id"] for d in test_preds if d["id"] in labels_by_id]
probs = [d["prob"] for d in test_preds if d["id"] in labels_by_id]
labels = [labels_by_id[i] for i in ids]

print(f"N test: {len(labels)} | Positives: {sum(labels)} | Negatives: {len(labels)-sum(labels)}")

In [None]:
import numpy as np

def metrics_at_threshold(y, p, thr=0.5):
    """
    Compute confusion-matrix counts and summary metrics at one threshold.

    y   : sequence of integer labels (1 = positive, 0 = negative)
    p   : sequence of scores; an example is predicted positive when p >= thr
    thr : decision threshold

    Returns a dict with keys threshold, tp, fp, tn, fn, acc, precision,
    recall and f1. Ratios with an empty denominator are reported as 0.0.
    """
    truth = np.asarray(y, dtype=int)
    scores = np.asarray(p, dtype=float)

    predicted_pos = scores >= thr
    actual_pos = truth == 1
    actual_neg = truth == 0

    tp = int(np.count_nonzero(predicted_pos & actual_pos))
    fp = int(np.count_nonzero(predicted_pos & actual_neg))
    tn = int(np.count_nonzero(~predicted_pos & actual_neg))
    fn = int(np.count_nonzero(~predicted_pos & actual_pos))

    n_total = max(len(truth), 1)  # avoid dividing by zero on empty input
    n_pred_pos = tp + fp
    n_actual_pos = tp + fn

    acc = (tp + tn) / n_total
    prec = tp / n_pred_pos if n_pred_pos > 0 else 0.0
    rec = tp / n_actual_pos if n_actual_pos > 0 else 0.0
    if prec + rec > 0:
        f1 = (2 * prec * rec) / (prec + rec)
    else:
        f1 = 0.0

    return {
        "threshold": thr,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
        "acc": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
    }

def roc_auc(y, p):
    """
    Rank-based ROC AUC (Mann-Whitney U statistic divided by n_pos * n_neg).

    y : sequence of integer labels (1 = positive, 0 = negative)
    p : sequence of scores, higher meaning more likely positive

    Returns the AUC as a float, or NaN when either class is absent.
    Tied scores receive the average of the ranks they span, so a tie between
    a positive and a negative counts as half a correct ordering.
    """
    y = np.asarray(y, dtype=int).ravel()
    p = np.asarray(p, dtype=float).ravel()
    n_pos = int((y == 1).sum())
    n_neg = int((y == 0).sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Group equal scores: each distinct value occupies a run of consecutive
    # 1-based ranks ending at cumsum(counts); use the midpoint of that run.
    _, inverse, counts = np.unique(p, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)
    mean_rank = last_rank - (counts - 1) / 2.0
    ranks = mean_rank[inverse.reshape(-1)]

    sum_ranks_pos = ranks[y == 1].sum()
    # Subtract the smallest possible positive rank sum, then normalise by the number of pos/neg pairs.
    return float((sum_ranks_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg))


# Compute threshold metrics (at 0.5) and the rank-based AUC on the test
# predictions, then print them.
m = metrics_at_threshold(labels, probs, thr=0.5)
auc = roc_auc(labels, probs)

print(f"AUC: {auc:.4f}")
print("Confusion Matrix @ 0.5")
# Layout: first row is predicted-positive (TP, FP), second row predicted-negative (FN, TN).
print(f"TP: {m['tp']}  FP: {m['fp']}")
print(f"FN: {m['fn']}  TN: {m['tn']}")
print(f"Acc: {m['acc']:.3f}  Precision: {m['precision']:.3f}  Recall: {m['recall']:.3f}  F1: {m['f1']:.3f}")