In [None]:
import os
import yaml
import optuna
import subprocess

In [None]:
class FlexibleGraphDataset(Dataset):
    """In-memory dataset wrapping a pre-built list of graph samples.

    Implements the ``len``/``get`` hooks of ``torch_geometric.data.Dataset``.
    NOTE(review): ``Dataset`` is presumably the PyG base class; its import is
    not visible in this notebook. TODO: add it to the imports cell.

    Args:
        data_list: List of torch_geometric.data.Data objects.
        transform: Optional callable applied to a sample when it is accessed
            through indexing (``dataset[idx]``).
    """

    def __init__(self, data_list, transform=None):
        # Hand the transform to the base class. PyG's Dataset.__getitem__
        # already applies ``self.transform`` to whatever ``get()`` returns.
        # Applying it again inside ``get()`` would transform each sample twice.
        super().__init__(transform=transform)
        self.data_list = data_list

    def len(self):
        """Return the number of samples held by the dataset."""
        return len(self.data_list)

    def get(self, idx):
        """Return the raw (untransformed) sample stored at position ``idx``."""
        return self.data_list[idx]

def prepare_datasets(data_list, test_size=0.2, val_size=0.1, random_seed=42):
    """Split ``data_list`` into train / validation / test datasets.

    ``test_size`` and ``val_size`` are fractions of the *whole* input. The test
    split is carved off first. The validation fraction is then rescaled against
    the remaining pool, so it still amounts to ``val_size`` of the original data.

    Args:
        data_list: Sequence of samples to split.
        test_size: Fraction of all samples reserved for testing.
        val_size: Fraction of all samples reserved for validation.
        random_seed: Seed passed to both splits for reproducibility.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``, each wrapped in a
        ``FlexibleGraphDataset``.
    """
    remainder, test_part = train_test_split(
        data_list, test_size=test_size, random_state=random_seed
    )
    # Express val_size relative to what is left after removing the test split.
    val_fraction_of_remainder = val_size / (1 - test_size)
    train_part, val_part = train_test_split(
        remainder, test_size=val_fraction_of_remainder, random_state=random_seed
    )
    return tuple(
        FlexibleGraphDataset(part) for part in (train_part, val_part, test_part)
    )

def create_dataloader(dataset, batch_size=32, shuffle=True, **loader_kwargs):
    """Build a ``DataLoader`` over ``dataset``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch.
        **loader_kwargs: Extra keyword arguments forwarded unchanged to
            ``DataLoader`` (e.g. ``num_workers``, ``drop_last``), so settings
            such as the config's ``num_workers`` can be passed through.

    Returns:
        The constructed ``DataLoader``.
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **loader_kwargs)

In [None]:
def create_config(trial, folder="configs"):
    """Sample one hyperparameter configuration and write it to a YAML file.

    Fixed settings (seed, epochs, model I/O sizes, logging, checkpointing) are
    combined with values drawn from ``trial`` for ``hidden_dim``, ``lr``,
    ``weight_decay`` and the two auxiliary loss weights.

    Args:
        trial: Optuna trial used to sample hyperparameters. ``trial.number``
            also makes the checkpoint filename and the output file unique.
        folder: Directory the config is written to (created if missing).

    Returns:
        Path of the written file, ``<folder>/config_<trial.number>.yaml``.
    """
    os.makedirs(folder, exist_ok=True)
    params = {
        "seed": 42,
        "max_epochs": 5,
        "gpus": 1,
        "batch_size": 32,
        "num_workers": 4,
        "in_channels": 16,
        "hidden_dim": trial.suggest_int("hidden_dim", 16, 64),
        "out_channels": 4,
        # suggest_float(..., log=True) is the non-deprecated replacement for
        # suggest_loguniform and samples from the same log-uniform range.
        "lr": trial.suggest_float("lr", 1e-4, 1e-2, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True),
        "logging": {
            "project_name": "gnn_dist_reconfig_hpo",
            "offline": True
        },
        "checkpoint": {
            "dirpath": "checkpoints",
            "filename": f"best-checkpoint-trial{trial.number}",
            "monitor": "val_total_loss"
        },
        "loss_weights": {
            "main": 1.0,
            "radiality": trial.suggest_float("radiality_w", 0.1, 0.5),
            "voltage_flow": trial.suggest_float("voltage_flow_w", 0.05, 0.3)
        }
    }
    config_path = os.path.join(folder, f"config_{trial.number}.yaml")
    with open(config_path, "w") as f:
        yaml.dump(params, f)
    return config_path

def objective(trial):
    """Optuna objective: train one sampled configuration in a subprocess.

    Writes the trial's config via ``create_config`` and runs ``pl_training.py``
    on it.

    Raises:
        subprocess.CalledProcessError: If the training script exits non-zero.
            Optuna then marks the trial as failed rather than scoring it.

    Returns:
        Placeholder objective value (always 0.0 for now).
    """
    config_path = create_config(trial)
    # Pass argv as a list (no shell). The path is never re-parsed by a shell,
    # so spaces or metacharacters in it cannot break or alter the command.
    # check=True makes a crashed run raise. Without it, the placeholder 0.0
    # below would be recorded, which is the *best* score under "minimize".
    subprocess.run(["python", "pl_training.py", config_path], check=True)
    # TODO: parse the run's logged metrics (e.g. val_total_loss) and return them.
    # Below is a placeholder zero return value.
    return 0.0

# Run a small HPO study. The sampler is seeded with the same value as the
# configs' "seed", so repeated runs propose the same hyperparameter sequence.
study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
)
study.optimize(objective, n_trials=3)


In [None]:
# Submit one generated config to Slurm for a full training run.
# NOTE(review): config_0 is hard-coded. Once the objective returns real metrics,
# this should presumably be the best trial's config. TODO: confirm intent.
config_file = "configs/config_0.yaml"
slurm_script = "run_training.slurm"
if not os.path.isfile(config_file):
    raise FileNotFoundError(f"{config_file} not found; run the HPO cell first")
# check=True raises CalledProcessError if sbatch exits non-zero (e.g. rejected job)
# instead of silently returning a failed CompletedProcess.
subprocess.run(["sbatch", slurm_script, config_file], check=True)

In [None]:
# Shell equivalent of the submission above: sbatch run_training.slurm configs/config_0.yaml