In [1]:
import pandas as pd
import numpy as np
from pathlib import Path

FLOWS = Path("../data/processed/vnat/flows.parquet")
OUT_DIR = Path("../data/splits")
OUT_DIR.mkdir(parents=True, exist_ok=True)

train_path = OUT_DIR / "vnat_train_captures.txt"
val_path   = OUT_DIR / "vnat_val_captures.txt"
test_path  = OUT_DIR / "vnat_test_captures.txt"

flows = pd.read_parquet(FLOWS)

# flow counts per capture (one label per capture in VNAT, but we compute safely anyway)
cap_stats = flows.groupby(["capture_id", "label"]).size().reset_index(name="n_flows")

vpn_caps = cap_stats[cap_stats["label"] == 1][["capture_id", "n_flows"]].copy()
non_caps = cap_stats[cap_stats["label"] == 0][["capture_id", "n_flows"]].copy()

# Shuffle deterministically
vpn_caps = vpn_caps.sample(frac=1.0, random_state=42).reset_index(drop=True)
non_caps = non_caps.sample(frac=1.0, random_state=42).reset_index(drop=True)

def greedy_split_by_flows(df_caps: pd.DataFrame, train=0.70, val=0.15):
    """
    Greedy assignment of captures to splits aiming for flow-count targets.
    Keeps captures disjoint and balances by total number of flows, not capture count.
    """
    total = float(df_caps["n_flows"].sum())
    target_train = total * train
    target_val   = total * val

    tr, va, te = [], [], []
    s_tr, s_va = 0.0, 0.0

    for cap, n in zip(df_caps["capture_id"].tolist(), df_caps["n_flows"].tolist()):
        n = float(n)
        if s_tr < target_train:
            tr.append(cap); s_tr += n
        elif s_va < target_val:
            va.append(cap); s_va += n
        else:
            te.append(cap)

    return tr, va, te

vpn_tr, vpn_va, vpn_te = greedy_split_by_flows(vpn_caps)
non_tr, non_va, non_te = greedy_split_by_flows(non_caps)

train_caps = sorted(vpn_tr + non_tr)
val_caps   = sorted(vpn_va + non_va)
test_caps  = sorted(vpn_te + non_te)

# --- SAFETY CHECKS ---
train_set = set(train_caps)
val_set   = set(val_caps)
test_set  = set(test_caps)

assert train_set.isdisjoint(val_set)
assert train_set.isdisjoint(test_set)
assert val_set.isdisjoint(test_set)

all_split = train_set | val_set | test_set
all_caps  = set(flows["capture_id"].unique().tolist())

assert all_split == all_caps, "Some captures are missing from splits!"

def flow_counts(caps):
    sub = flows[flows["capture_id"].isin(caps)]
    return sub["label"].value_counts().to_dict(), len(sub)

tr_counts, tr_flows = flow_counts(train_caps)
va_counts, va_flows = flow_counts(val_caps)
te_counts, te_flows = flow_counts(test_caps)

# Save
train_path.write_text("\n".join(train_caps), encoding="utf-8")
val_path.write_text("\n".join(val_caps), encoding="utf-8")
test_path.write_text("\n".join(test_caps), encoding="utf-8")

print("Captures total:", len(all_caps))
print("Captures Train/Val/Test:", len(train_caps), len(val_caps), len(test_caps))
print("Flows Train:", tr_flows, "class_counts:", tr_counts)
print("Flows Val:  ", va_flows, "class_counts:", va_counts)
print("Flows Test: ", te_flows, "class_counts:", te_counts)
print("Saved:", train_path, val_path, test_path)

Captures total: 165
Captures Train/Val/Test: 154 11 0
Flows Train: 33249 class_counts: {0: 32876, 1: 373}
Flows Val:   462 class_counts: {0: 456, 1: 6}
Flows Test:  0 class_counts: {}
Saved: ..\data\splits\vnat_train_captures.txt ..\data\splits\vnat_val_captures.txt ..\data\splits\vnat_test_captures.txt
