In [None]:
import itertools
import os, sys

# Repository root is taken to be the parent of the current working directory,
# so this cell must be run from a folder one level below the root.
ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Expose <ROOT>/src on the import path for the `dagc` import that follows.
# Guarded so that re-running the cell does not append duplicate entries.
_SRC_DIR = os.path.join(ROOT, "src")
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)
import matplotlib.pyplot as plt 
from dagc.sparsifiers.gnn.train import train_gnn_sparsifier_on_dataset

# ---------------------------------------------------------------------------
# Small hyperparameter sweep for the GNN sparsifier.
#
# Trains one model per (keep_ratio, lambda_sparsity, lr) combination, records
# the lowest validation loss reached by each run, and prints a ranked summary.
# ---------------------------------------------------------------------------

# Dataset locations, relative to the current working directory.
train_dir = "../data/synthetic_graphs/ba/train"
val_dir = "../data/synthetic_graphs/ba/val"

# Sweep grid: the Cartesian product of these lists is explored.
keep_ratios = [0.8]
lambda_sparsities = [0.1, 1.0, 5.0]
lrs = [1e-3, 3e-4]

# Settings shared by every run in the sweep.
fixed_train_kwargs = {
    "train_dir": train_dir,
    "val_dir": val_dir,
    "num_steps": 1,
    "num_epochs": 5,  # keep small for sweeps
    "device": "cpu",
    "print_every": 1,
}

# One summary dict per configuration, appended in sweep order.
results = []

for keep_ratio, lambda_sparsity, lr in itertools.product(
    keep_ratios, lambda_sparsities, lrs
):
    # The tag encodes the configuration so each run gets its own checkpoint file.
    tag = f"kr{keep_ratio}_lam{lambda_sparsity}_lr{lr}"
    ckpt = f"results/checkpoints/gnn_rw_k1_{tag}.pt"

    print("\n============================")
    print(f"Config: keep_ratio={keep_ratio}, lambda={lambda_sparsity}, lr={lr}")
    print("============================")

    model, history = train_gnn_sparsifier_on_dataset(
        keep_ratio=keep_ratio,
        learning_rate=lr,
        lambda_sparsity=lambda_sparsity,
        checkpoint_path=ckpt,
        **fixed_train_kwargs,
    )

    # Rank runs by the smallest validation loss they reached.
    # NOTE(review): presumably history["val_loss"] holds one value per epoch;
    # min() raises ValueError if it is empty (e.g. num_epochs=0) - confirm.
    best_val = min(history["val_loss"])
    results.append(
        {
            "keep_ratio": keep_ratio,
            "lambda_sparsity": lambda_sparsity,
            "lr": lr,
            "best_val_loss": best_val,
            "checkpoint": ckpt,
        }
    )

# Sort and print summary (ascending validation loss, i.e. best first).
results_sorted = sorted(results, key=lambda d: d["best_val_loss"])
print("\n=== Sweep summary (best to worst by val loss) ===")
for r in results_sorted:
    print(
        f"keep_ratio={r['keep_ratio']}, lambda={r['lambda_sparsity']}, "
        f"lr={r['lr']}, best_val_loss={r['best_val_loss']:.4f}"
    )



Config: keep_ratio=0.8, lambda=0.1, lr=0.001
[train] Using device: cpu
[train] Loaded 120 train graphs and 40 val graphs
[Epoch 1/5] train_loss = 0.000926 (rw=0.0000, sparse=0.0089) val_loss = 0.000050 (rw=0.0000, sparse=0.0001)
[Epoch 2/5] train_loss = 0.000077 (rw=0.0000, sparse=0.0004) val_loss = 0.000037 (rw=0.0000, sparse=0.0001)
[Epoch 3/5] train_loss = 0.000067 (rw=0.0000, sparse=0.0004) val_loss = 0.000027 (rw=0.0000, sparse=0.0001)
[Epoch 4/5] train_loss = 0.000030 (rw=0.0000, sparse=0.0002) val_loss = 0.000010 (rw=0.0000, sparse=0.0000)
[Epoch 5/5] train_loss = 0.000020 (rw=0.0000, sparse=0.0001) val_loss = 0.000013 (rw=0.0000, sparse=0.0001)
[train] Restoring best model with val_loss = 0.000010
[train] Saved checkpoint to: results/checkpoints/gnn_rw_k1_kr0.8_lam0.1_lr0.001.pt

Config: keep_ratio=0.8, lambda=0.1, lr=0.0003
[train] Using device: cpu
[train] Loaded 120 train graphs and 40 val graphs
[Epoch 1/5] train_loss = 0.002294 (rw=0.0000, sparse=0.0229) val_loss = 0.0000