#### Imports

In [1]:
import os, json, math, time
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, random_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
torch.manual_seed(42); np.random.seed(42)
print("Device:", device)

Device: cuda


In [2]:
from cnn import CNN1DEncoder, AstronetMVP
from dataset import ExoplanetDataset, collate

#### Data

In [3]:
# load npz for train / test
import numpy as np

npz_train = np.load("npzs/train.npz")
npz_test = np.load("npzs/test.npz")

In [4]:
npz_train

NpzFile 'npzs/train.npz' with keys: global_view, local_view, kepid, label

In [19]:
# create train / test set from numpy

global_train = torch.from_numpy(npz_train['global_view']).float()
local_train = torch.from_numpy(npz_train['local_view']).float()
labels_train = torch.from_numpy(npz_train['label']).int()
ids_train = torch.from_numpy(npz_train['kepid']).float()
tab_train = None #torch.from_numpy(npz_train['tabular'])

global_test = torch.from_numpy(npz_test['global_view']).float()
local_test = torch.from_numpy(npz_test['local_view']).float()
labels_test = torch.from_numpy(npz_test['label']).int()
ids_test = torch.from_numpy(npz_test['kepid']).float()
tab_test = None #torch.from_numpy(npz_test['tabular'])


In [20]:
labels_train

tensor([0, 1, 0, 0, 1, 0, 0, 0, 0, 2, 2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
        0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 2, 0, 0, 0, 0,
        0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0,
        0, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 2, 0, 1, 1, 2, 2, 1, 0, 2, 0, 2, 1, 1,
        0, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 0, 0, 2,
        0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 2, 0, 2, 0, 0, 2, 0, 0,
        0, 0, 0, 2, 0, 0, 0, 2, 0, 0, 1, 0, 0, 2, 2, 0], dtype=torch.int32)

In [6]:
print(global_train.shape[1])
print(global_test.ndim, "\n")

print(local_train.ndim)
print(local_test.shape, "\n")

print(labels_train.ndim)
print(labels_test.ndim)

2001
2 

2
torch.Size([72, 201]) 

1
1


In [7]:
ds_train = ExoplanetDataset(global_train, local_train, labels_train, ids_train, tabular=tab_train)
ds_test  = ExoplanetDataset(global_test,  local_test,  labels_test,  ids_test,  tabular=tab_test)

In [8]:
# create data loaders from train / test data sets
dl_train = DataLoader(ds_train, batch_size = 64, shuffle=True, collate_fn=collate, drop_last=False)
dl_test = DataLoader(ds_test, batch_size = 64, shuffle=True, collate_fn=collate, drop_last=False)

In [9]:
# loss, optimizer, pos weight from train set
def compute_pos_weight_from_loader(dloader):
    pos = neg = 0
    with torch.no_grad():
        for _xg, _xl, _xt, _y, _ids in dloader:
            pos += (_y == 1).sum().item()
            neg += (_y == 0).sum().item()
    if pos == 0 or neg == 0: return None
    return torch.tensor([neg / max(pos, 1.0)], dtype=torch.float32, device=device)

# Infer tabular_dim from one train batch
xg0, xl0, xt0, y0, ids0 = next(iter(dl_train))
tabular_dim = 0 if xt0 is None else xt0.shape[1]

model = AstronetMVP(hidden=64, k_global=7, k_local=5, tabular_dim=tabular_dim).to(device)

pos_weight = compute_pos_weight_from_loader(dl_train)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) if pos_weight is not None else nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [10]:
# make & save config / checkpoints dir
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)

# make & save config
config = {
    "hidden":64,
    "k_global": 7,
    "k_local": 5,
    "tabular_dim": tabular_dim,
    "lr": 1e-3,
    "wd": 1e-5,
    "pos_weight": None if pos_weight is None else float(pos_weight.item())
}
json.dump(config, open(os.path.join(save_dir, "config.json"), "w"), indent = 2)

In [11]:
def train_epoch(model, dloader, optim, criterion):
    model.train()
    total, n = 0.0, 0

    # feed train data
    for xg, xl, xt, y, _ids in dloader:

        # move to gpu
        xg = xg.to(device)
        xl = xl.to(device)
        y = y.to(device)# expected

        # feed
        xt = None if xt is None else xt.to(device)
        logits = model(xg, xl, xt).squeeze(1) # actual
        loss = criterion(logits, y)

        #compute optim
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm(model.parameters(), max_norm=5.0)
        optimizer.step()

        bs = y.size(0)
        total += loss.item() * bs
        n+= bs

    return total / max(n,1)



In [17]:
# train loop: train + save checkpoints
num_epochs = 1000
for epoch in range(1, num_epochs+1):
    t0 = time.time()
    tr_loss = train_epoch(model, dl_train, optimizer, criterion)
    dt = time.time() - t0

    print(f"Epoch {epoch:02d}/{num_epochs} | train_loss={tr_loss:.4f} | {dt:.1f}s")

# Save final checkpoint
ckpt_path = os.path.join(save_dir, "model_final.pt")
torch.save({"model_state": model.state_dict(), "config": config}, ckpt_path)
print("Saved:", ckpt_path)

  nn.utils.clip_grad_norm(model.parameters(), max_norm=5.0)


Epoch 01/1000 | train_loss=1.0256 | 0.0s
Epoch 02/1000 | train_loss=1.0230 | 0.0s
Epoch 03/1000 | train_loss=1.0200 | 0.0s
Epoch 04/1000 | train_loss=1.0183 | 0.0s
Epoch 05/1000 | train_loss=1.0159 | 0.0s
Epoch 06/1000 | train_loss=1.0117 | 0.0s
Epoch 07/1000 | train_loss=1.0096 | 0.0s
Epoch 08/1000 | train_loss=1.0068 | 0.0s
Epoch 09/1000 | train_loss=1.0057 | 0.0s
Epoch 10/1000 | train_loss=1.0051 | 0.0s
Epoch 11/1000 | train_loss=1.0011 | 0.0s
Epoch 12/1000 | train_loss=0.9953 | 0.0s
Epoch 13/1000 | train_loss=0.9977 | 0.0s
Epoch 14/1000 | train_loss=0.9910 | 0.0s
Epoch 15/1000 | train_loss=0.9896 | 0.0s
Epoch 16/1000 | train_loss=0.9867 | 0.0s
Epoch 17/1000 | train_loss=0.9837 | 0.0s
Epoch 18/1000 | train_loss=0.9823 | 0.0s
Epoch 19/1000 | train_loss=0.9841 | 0.0s
Epoch 20/1000 | train_loss=0.9810 | 0.0s
Epoch 21/1000 | train_loss=0.9771 | 0.0s
Epoch 22/1000 | train_loss=0.9758 | 0.0s
Epoch 23/1000 | train_loss=0.9749 | 0.0s
Epoch 24/1000 | train_loss=0.9728 | 0.0s
Epoch 25/1000 | 

In [13]:
# create test set predictions for eval
@torch.no_grad()
def predict_on_loader(model, dloader):
    model.eval()
    out = []  # list of dicts: {id, logit, prob}
    for xg, xl, xt, y, ids in dloader:
        xg = xg.to(device); xl = xl.to(device)
        xt = None if xt is None else xt.to(device).float()
        logits = model(xg, xl, xt).squeeze(1)          # [B]
        probs  = torch.sigmoid(logits)
        for _id, lo, pr in zip(ids, logits.cpu().tolist(), probs.cpu().tolist()):
            out.append({"id": _id, "logit": float(lo), "prob": float(pr)})
    return out

test_preds = predict_on_loader(model, dl_test)
print("Test preds:", len(test_preds))

# save it as csv
# import csv
# pred_csv = os.path.join(save_dir, "test_predictions.csv")
# with open(pred_csv, "w", newline="") as f:
#     w = csv.DictWriter(f, fieldnames=["id","logit","prob"])
#     w.writeheader()
#     w.writerows(test_preds)
# print("Saved:", pred_csv)

Test preds: 72


In [14]:
## EVAL

# get ground truth labels for test ids
labels_by_id = {}
for _xg, _xl, _xt, _y, _ids in dl_test:
    for i, y in zip(_ids, _y.tolist()):
        labels_by_id[i] = int(y)
len(labels_by_id), list(list(labels_by_id.items())[:3])

# get predictions for labels
ids  = [d["id"] for d in test_preds if d["id"] in labels_by_id]
probs = [d["prob"] for d in test_preds if d["id"] in labels_by_id]
labels = [labels_by_id[i] for i in ids]

print(f"N test: {len(labels)} | Positives: {sum(labels)} | Negatives: {len(labels)-sum(labels)}")

N test: 0 | Positives: 0 | Negatives: 0


In [15]:
import numpy as np

def metrics_at_threshold(y, p, thr=0.5): #eval metrics at some thresh
    y = np.asarray(y, dtype=int)
    p = np.asarray(p, dtype=float)
    pred = (p >= thr).astype(int)
    tp = int(((pred==1)&(y==1)).sum())
    fp = int(((pred==1)&(y==0)).sum())
    tn = int(((pred==0)&(y==0)).sum())
    fn = int(((pred==0)&(y==1)).sum())
    acc  = (tp+tn)/max(len(y),1)
    prec = tp/max(tp+fp,1) if (tp+fp)>0 else 0.0
    rec  = tp/max(tp+fn,1) if (tp+fn)>0 else 0.0
    f1   = (2*prec*rec)/(prec+rec) if (prec+rec)>0 else 0.0
    return dict(threshold=thr, tp=tp, fp=fp, tn=tn, fn=fn, acc=acc, precision=prec, recall=rec, f1=f1)

def roc_auc(y, p):# rank based auc
    y = np.asarray(y, dtype=int)
    p = np.asarray(p, dtype=float)
    n_pos = (y==1).sum(); n_neg = (y==0).sum()
    if n_pos==0 or n_neg==0: return float("nan")
    order = np.argsort(p)
    ranks = np.empty_like(order); ranks[order] = np.arange(1, len(p)+1)
    sum_ranks_pos = ranks[y==1].sum()
    return float((sum_ranks_pos - n_pos*(n_pos+1)/2) / (n_pos*n_neg))


# get metrics / rank + print
m = metrics_at_threshold(labels, probs, thr=0.5)
auc = roc_auc(labels, probs)

print(f"AUC: {auc:.4f}")
print("Confusion Matrix @ 0.5")
print(f"TP: {m['tp']}  FP: {m['fp']}")
print(f"FN: {m['fn']}  TN: {m['tn']}")
print(f"Acc: {m['acc']:.3f}  Precision: {m['precision']:.3f}  Recall: {m['recall']:.3f}  F1: {m['f1']:.3f}")

AUC: nan
Confusion Matrix @ 0.5
TP: 0  FP: 0
FN: 0  TN: 0
Acc: 0.000  Precision: 0.000  Recall: 0.000  F1: 0.000


In [16]:
# Best-F1 threshold (quick scan over unique probs)
uniq = sorted(set(probs))
best = {"f1": -1.0, "threshold": 0.5}
for thr in uniq:
    mm = metrics_at_threshold(labels, probs, thr)
    if mm["f1"] > best["f1"]:
        best = mm
print("\nBest F1 threshold search:")
print(f"thr={best['threshold']:.4f}  F1={best['f1']:.3f}  Prec={best['precision']:.3f}  Rec={best['recall']:.3f}")

# Show a few confident mistakes (helps sanity-check training)
arr = list(zip(ids, labels, probs))
false_pos = sorted([(i,y,p) for i,y,p in arr if y==0], key=lambda t: -t[2])[:5]
false_neg = sorted([(i,y,p) for i,y,p in arr if y==1], key=lambda t:  t[2])[:5]

print("\nTop false positives by prob (should look like near-miss/non-planet):")
for i,y,p in false_pos: print(f"id={i}  label={y}  prob={p:.3f}")

print("\nTop false negatives by prob (missed likely planets):")
for i,y,p in false_neg: print(f"id={i}  label={y}  prob={p:.3f}")


Best F1 threshold search:


KeyError: 'precision'

In [None]:
# TORCH SCRIPT EXPORT FOR PRODUCTION #
import copy
model.eval()
model_cpu = copy.deepcopy(model).cpu()
scripted = torch.jit.script(model_cpu)
scripted.save("model_scripted.pt")