In [1]:
import os
import warnings
from contextlib import redirect_stdout
from itertools import product
from pathlib import Path
from urllib.parse import quote_plus

import torch
from dotenv import load_dotenv
from pymongo import MongoClient, ASCENDING

from gnn import TSPDataset, TSPGNN, train_with_hyperparams, FocalLoss


# Project root is taken to be the directory the kernel was started in, so the
# notebook must be launched from the repository root for these paths to resolve.
base = Path.cwd()
mongo_configs = base / "container" / "mongo.env"

# load_dotenv silently does nothing when the file is missing; without this check
# the hard-coded fallback credentials below would be used with no hint as to why
# (typically because the kernel was started from a different working directory).
if not mongo_configs.is_file():
    warnings.warn(
        f"Mongo env file not found at {mongo_configs}; MongoDB credentials will "
        "come from the process environment or the built-in defaults.",
        stacklevel=1,
    )
load_dotenv(str(mongo_configs), override=True)

# Fallbacks are only used when neither the env file nor the process environment
# defines the variables.
mongo_username = os.getenv("MONGO_INITDB_ROOT_USERNAME", "root")
mongo_password = os.getenv("MONGO_INITDB_ROOT_PASSWORD", "secret")

In [2]:
# Credentials must be percent-escaped before being embedded in a MongoDB URI;
# otherwise characters such as "@", ":", "/" or "%" in the password break
# URI parsing (PyMongo raises InvalidURI or authenticates with the wrong value).
mongo_uri = (
    f"mongodb://{quote_plus(mongo_username)}:{quote_plus(mongo_password)}"
    "@localhost:27017"
)
client = MongoClient(mongo_uri)
database = client["tsp_database"]
collection = database["tsp_solutions"]

# Sort by _id so the document order, and therefore the dataset, is deterministic.
all_docs = list(collection.find().sort("_id", ASCENDING))
# NOTE(review): semantics of num_samples (cap vs. subsample) live in gnn.TSPDataset — confirm there.
tsp_dataset = TSPDataset(all_docs, normalize_edges=True, num_samples=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Hyperparameter grid for the sweep. Plural names keep each grid distinct from
# the per-run value the loop binds, so the grids survive the cell unchanged.
optimizers = ["AdamW", "SGD", "RMSprop"]
batch_sizes = [8, 16, 32]
pos_weights = [3.0, 5.0]
# (loss instance or None, display name). NOTE(review): presumably None makes
# train_with_hyperparams use BCE weighted by pos_weight — confirm in gnn.
criteria = [(None, "BCE"), (FocalLoss(), "FocalLoss")]

# Smaller batches mean more steps per epoch, so log them less often; any batch
# size not listed here logs every 10 steps.
print_every_by_batch_size = {8: 100, 16: 50}

# Re-applied before every run so each configuration starts from the same
# initial weights (and the same torch-driven shuffling), keeping runs comparable.
SEED = 42

# Architecture is fixed across the sweep; a fresh model is built from these per run.
model_kwargs = dict(node_dim=2, edge_dim=16, hidden_dim=64, num_heads=4)

other_dir = base / "other"
other_dir.mkdir(parents=True, exist_ok=True)
log_path = other_dir / "training_with_hyperparams_info.log"

# The file gets its own context manager: redirect_stdout does not close the
# stream it is given, so this guarantees the log is flushed and closed even if
# a training run raises.
with open(log_path, "w") as log_file, redirect_stdout(log_file):
    for optimizer, batch_size, pos_weight, (loss_fn, loss_name) in product(
        optimizers, batch_sizes, pos_weights, criteria
    ):
        print(
            "Training Parameters:\n"
            f"\t- Optimizer: {optimizer}\n"
            f"\t- Criterion: {loss_name}\n"
            f"\t- Batch size: {batch_size}\n"
            f"\t- Positive class weight: {pos_weight}\n"
        )
        print_every = print_every_by_batch_size.get(batch_size, 10)

        # Build a new, identically seeded model for every configuration. Reusing
        # a single instance would make each run continue from the weights left
        # by the previous configuration, invalidating the comparison.
        torch.manual_seed(SEED)
        model = TSPGNN(**model_kwargs)

        train_loader, val_loader, _test_loader = tsp_dataset.get_dataloaders(
            test_size=0.2, batch_size=batch_size, shuffle=True
        )
        train_with_hyperparams(
            model=model,
            device=device,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer_name=optimizer,
            criterion=loss_fn,
            lr=1e-4,
            pos_weight=pos_weight,
            num_epochs=15,
            early_stopping_patience=5,
            gradient_clip=1,
            print_every=print_every,
        )