In [1]:
import os
from pathlib import Path
import pandas as pd
from pymongo import MongoClient, ASCENDING
from dotenv import load_dotenv
from itertools import product
import torch
from contextlib import redirect_stdout
from gnn import TSPDataset, TSPGNN, train_with_hyperparams, FocalLoss
import matplotlib.pyplot as plt


base = Path().cwd()
mongo_configs = base / "container" / "mongo.env"
load_dotenv(str(mongo_configs), override=True)
mongo_username = os.getenv("MONGO_INITDB_ROOT_USERNAME", "root")
mongo_password = os.getenv("MONGO_INITDB_ROOT_PASSWORD", "secret")

In [2]:
mongo_uri = f"mongodb://{mongo_username}:{mongo_password}@localhost:27017"
client = MongoClient(mongo_uri)
database = client["tsp_database"]
collection = database["tsp_solutions"]
all_docs = list(collection.find().sort("_id", ASCENDING))  # Deterministic order
tsp_dataset = TSPDataset(all_docs, normalize_edges=True, num_samples=500)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Hyperparameters
optimizers = ["AdamW", "SGD", "RMSprop"]
batch_size = [8, 16, 32]
pos_weight = [3.0, 5.0]
criterion = [(None, "BCE"), (FocalLoss(), "FocalLoss")]

# Model definition

model = TSPGNN(node_dim=2, edge_dim=16, hidden_dim=64, num_heads=4)
other_dir = base / "other"
other_dir.mkdir(parents=True, exist_ok=True)

results = []

with redirect_stdout(open("other/training_with_hyperparams_info.log", "w")):
    for optimizer, batch_size, pos_weigh, criterion in product(optimizers, batch_size, pos_weight, criterion):
        print(
            "Training Parameters:\n ",
            f"\t- Optimizer: {optimizer}\n",
            f"\t- Criterion: {criterion[1]}\n",
            f"\t- Batch size: {batch_size}\n",
            f"\t- Positive class weigh: {pos_weigh}\n",
        )
        if batch_size == 8:
            print_every = 100
        elif batch_size == 16:
            print_every = 50
        else:
            print_every = 10

        train_loader, val_loader, test_loader = tsp_dataset.get_dataloaders(
            test_size=0.2, batch_size=batch_size, shuffle=True
        )
        metrics = train_with_hyperparams(
            model=model,
            device=device,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer_name=optimizer,
            criterion=criterion[0],
            lr=1e-4,
            pos_weight=pos_weigh,
            num_epochs=15,
            early_stopping_patience=5,
            gradient_clip=1,
            print_every=print_every,
        )

         # Store last-epoch metrics + hyperparams
        results.append({
            "optimizer": optimizer,
            "batch_size": batch_size,
            "pos_weight": pos_weigh,
            "criterion": criterion[1],
            "final_train_loss": metrics["train_loss"][-1],
            "final_val_loss": metrics["val_loss"][-1],
            "final_val_precision": metrics["val_precision"][-1],
            "final_val_recall": metrics["val_recall"][-1],
            "final_val_roc_auc": metrics["val_roc_auc"][-1],
            "final_val_pr_auc": metrics["val_pr_auc"][-1],
        })

# Convert to DataFrame for easy analysis
df_results = pd.DataFrame(results)
df_results.to_csv(other_dir / "hyperparam_results.csv", index=False)
print("Saved results to hyperparam_results.csv")

Saved results to hyperparam_results.csv


In [8]:
pivot = pd.pivot_table(
    df_results,
    values=["final_val_recall", "final_val_pr_auc", "final_val_precision"],
    index=["optimizer", "batch_size"],
    columns=["criterion"],
    aggfunc="max",
)

print(pivot)
pivot.to_excel("hyperparam_results_pivot.xlsx")

                     final_val_pr_auc           final_val_precision            \
criterion                         BCE FocalLoss                 BCE FocalLoss   
optimizer batch_size                                                            
AdamW     8                  0.547883  0.557336            0.537775  0.710720   
          16                 0.554810  0.568201            0.519792  0.766478   
          32                 0.559682  0.567927            0.563233  0.736486   
RMSprop   8                  0.560917  0.561214            0.564387  0.777070   
          16                 0.562430  0.566673            0.485075  0.701473   
          32                 0.564695  0.565543            0.507715  0.716718   
SGD       8                  0.555747  0.541488            0.555486  0.614770   
          16                 0.560478  0.545515            0.576471  0.616901   
          32                 0.561850  0.552572            0.554245  0.625698   

                     final_