# In [None]:
import argparse
import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


def parse_args():
    """Parse the command-line options of the CNN training script.

    Returns:
        argparse.Namespace with:
        - the required input paths: flows_parquet, train_caps, val_caps, X, Y, K;
        - out_dir and min_k;
        - the optimisation settings: batch_size, epochs, lr, patience;
        - device, one of "auto", "cpu" or "cuda".
    """
    parser = argparse.ArgumentParser()

    # Required input paths, all plain strings.
    for flag in ("--flows_parquet", "--train_caps", "--val_caps", "--X", "--Y", "--K"):
        parser.add_argument(flag, type=str, required=True)

    parser.add_argument("--out_dir", type=str, default="artifacts/cnn")
    parser.add_argument("--min_k", type=int, default=50)

    # Optimisation hyper-parameters: (flag, converter, default).
    for flag, kind, default in (
        ("--batch_size", int, 256),
        ("--epochs", int, 50),
        ("--lr", float, 1e-3),
        ("--patience", int, 5),
    ):
        parser.add_argument(flag, type=kind, default=default)

    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"])
    return parser.parse_args()


class FlowSeqDataset(Dataset):
    """Map-style dataset pairing per-flow sequences with their labels.

    Each row of ``X`` is stored as (length, channels). ``__getitem__``
    returns it transposed to (channels, length), which is the layout
    ``nn.Conv1d`` consumes. The label is returned alongside it as a float
    tensor.
    """

    def __init__(self, X, y):
        # Both arrays are kept by reference; no copy is made here.
        self.X, self.y = X, y

    def __len__(self):
        n_samples = self.X.shape[0]
        return int(n_samples)

    def __getitem__(self, idx):
        sample = torch.from_numpy(self.X[idx]).float()
        # (length, channels) -> (channels, length) for Conv1d; (100, 2) -> (2, 100) in this script.
        features = torch.transpose(sample, 0, 1)
        label = torch.tensor(self.y[idx]).float()
        return features, label


class SimpleCNN(nn.Module):
    """Small 1-D CNN that produces one logit per sample.

    The input is expected as (batch, 2, length). The network is:
    - two 5-wide convolutions (2 -> 32 -> 64 channels), each with padding 2
      and followed by ReLU;
    - global average pooling over the length axis;
    - an MLP head 64 -> 64 -> 1.

    Submodule names and indices (``net.*``, ``head.*``) are part of the
    ``state_dict`` key layout saved by this script, so they are kept fixed.
    """

    def __init__(self):
        super().__init__()
        conv_layers = [
            nn.Conv1d(2, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            # Global average pool: (batch, 64, length) -> (batch, 64, 1).
            nn.AdaptiveAvgPool1d(1),
        ]
        self.net = nn.Sequential(*conv_layers)

        head_layers = [nn.Flatten(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1)]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw logits of shape (batch,); no sigmoid is applied."""
        logits = self.head(self.net(x))
        return logits.squeeze(1)


def pick_device(choice: str) -> torch.device:
    """Resolve the ``--device`` option to a ``torch.device``.

    Args:
        choice: ``"auto"`` selects CUDA when ``torch.cuda.is_available()``,
            otherwise CPU. Any other value is passed to ``torch.device``
            (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"``).

    Returns:
        The resolved device.

    Raises:
        RuntimeError: if a CUDA device is requested but CUDA is not available.
            An invalid device string raises whatever ``torch.device`` raises
            for it.
    """
    if choice == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    device = torch.device(choice)
    # torch.device("cuda") does not check availability. Without this check the
    # failure only surfaces later, with an opaque message, at model.to(device).
    if device.type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            f"device {choice!r} requested but torch.cuda.is_available() is False"
        )
    return device


def _read_caps(path):
    """Return the set of capture ids listed one per line in *path*.

    Surrounding whitespace is stripped and blank lines are dropped. This
    stops stray spaces or empty lines from producing bogus ids such as "".
    """
    text = Path(path).read_text(encoding="utf-8")
    return {line.strip() for line in text.splitlines() if line.strip()}


def _label_counts(labels):
    """Return ``{label: count}`` for the distinct values in *labels*, as ints."""
    values, counts = np.unique(labels, return_counts=True)
    return {int(v): int(c) for v, c in zip(values, counts)}


def _train_epoch(model, loader, criterion, opt, device):
    """Run one optimisation pass over *loader*; return the per-sample mean loss."""
    model.train()
    total = 0.0
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        loss = criterion(model(xb), yb)
        loss.backward()
        opt.step()
        # The criterion returns a batch mean (default reduction). Weighting it by
        # batch size keeps the epoch average per-sample when the last batch is short.
        total += float(loss.item()) * xb.size(0)
    return total / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return the per-sample mean loss of *model* over *loader*, without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            total += float(criterion(model(xb), yb).item()) * xb.size(0)
    return total / len(loader.dataset)


def main():
    """Train SimpleCNN on flow sequences, splitting train/val by capture_id.

    Inputs:
    - the flows table (``--flows_parquet``);
    - the train/val capture-id lists;
    - the ``X``/``Y``/``K`` arrays.

    Training:
    - Only the train split is filtered to ``K >= --min_k``.
    - Loss is BCE-with-logits with ``pos_weight`` = (#label 0) / (#label 1)
      in the train split; the optimiser is Adam.
    - Training stops early after ``--patience`` epochs without a
      validation improvement.

    Outputs:
    - ``<out_dir>/model.pt``: the checkpoint with the lowest validation loss.
    - ``<out_dir>/history.json``: per-epoch train/val losses.

    Raises:
        ValueError: if X/Y/K do not have one row per flow, or if the train or
            validation split ends up empty.
    """
    args = parse_args()
    device = pick_device(args.device)
    print("Device:", device)

    # NOTE(review): rows of X/Y/K are assumed to follow the flows table sorted by
    # flow_id. Confirm this against the script that produced the arrays.
    flows = pd.read_parquet(Path(args.flows_parquet)).set_index("flow_id").sort_index()

    train_caps = _read_caps(args.train_caps)
    val_caps = _read_caps(args.val_caps)
    overlap = train_caps & val_caps
    if overlap:
        print(
            f"WARNING: {len(overlap)} capture_id(s) appear in both train and val lists; "
            "validation loss will be optimistic."
        )

    X = np.load(args.X, mmap_mode="r")  # memory-mapped: rows are paged in when indexed
    y = np.load(args.Y).reshape(-1)
    K = np.load(args.K).reshape(-1)

    # Fail with a clear message instead of a cryptic boolean-index error below.
    n_rows = len(flows)
    for name, arr in (("X", X), ("Y", y), ("K", K)):
        if arr.shape[0] != n_rows:
            raise ValueError(
                f"--{name} has {arr.shape[0]} rows but --flows_parquet has {n_rows}; "
                "the arrays must have one row per flow"
            )

    # Build the split by capture_id.
    capture_ids = flows["capture_id"]
    is_train = capture_ids.isin(train_caps).to_numpy()
    is_val = capture_ids.isin(val_caps).to_numpy()

    # Only the training split is filtered by MIN_K; validation keeps every flow.
    # (Not in-place: to_numpy() may return a read-only view under pandas copy-on-write.)
    is_train = is_train & (K >= args.min_k)

    X_train = np.asarray(X[is_train], dtype=np.float32)
    y_train = y[is_train].astype(np.float32)

    X_val = np.asarray(X[is_val], dtype=np.float32)
    y_val = y[is_val].astype(np.float32)

    print("Train:", X_train.shape, "Val:", X_val.shape)
    if len(X_train) == 0 or len(X_val) == 0:
        # An empty split would otherwise crash later with ZeroDivisionError.
        raise ValueError(
            f"empty split (train={len(X_train)}, val={len(X_val)}); "
            "check the capture lists and --min_k"
        )
    print("Train labels:", _label_counts(y_train))
    print("Val labels:", _label_counts(y_val))

    train_ds = FlowSeqDataset(X_train, y_train)
    val_ds = FlowSeqDataset(X_val, y_val)

    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=2,
        pin_memory=device.type == "cuda",
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = SimpleCNN().to(device)

    # Class imbalance: weight positives by the negative/positive ratio of the train split.
    n_pos = float((y_train == 1).sum())
    n_neg = float((y_train == 0).sum())
    pos_weight = torch.tensor([n_neg / max(n_pos, 1.0)], device=device)
    print("pos_weight:", float(pos_weight.item()))

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    model_path = out_dir / "model.pt"
    hist_path = out_dir / "history.json"

    best_val = float("inf")
    best_epoch = None
    bad = 0
    history = []

    for epoch in range(1, args.epochs + 1):
        tr_loss = _train_epoch(model, train_loader, criterion, opt, device)
        va_loss = _evaluate(model, val_loader, criterion, device)

        rec = {"epoch": epoch, "train_loss": tr_loss, "val_loss": va_loss}
        history.append(rec)
        print(rec)

        # A small margin keeps float noise from counting as an improvement.
        if va_loss < best_val - 1e-6:
            best_val = va_loss
            best_epoch = epoch
            bad = 0
            torch.save(
                {"model_state": model.state_dict(), "best_val_loss": best_val, "epoch": epoch},
                model_path,
            )
        else:
            bad += 1
            if bad >= args.patience:
                print("Early stopping.")
                break

    hist_path.write_text(json.dumps(history, indent=2), encoding="utf-8")
    if best_epoch is None:
        # e.g. every val_loss was NaN. Do not claim model.pt was saved: it may be
        # missing or stale from an earlier run in the same out_dir.
        print("WARNING: validation loss never improved; no checkpoint written to", model_path)
        print("Saved:", hist_path)
    else:
        print(f"Best val_loss {best_val:.6f} at epoch {best_epoch}.")
        print("Saved:", model_path, hist_path)


if __name__ == "__main__":
    # Run training only when this file is executed directly, not when imported.
    main()
