In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
from tqdm.auto import tqdm
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from pathlib import Path
from torch_geometric.nn import GATv2Conv, global_mean_pool
import torch.nn as nn
from torch_geometric.nn import GINEConv
import logging
import os
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datetime import datetime
from sklearn.model_selection import ParameterGrid, StratifiedKFold
from sklearn.metrics import (
    roc_auc_score,
    precision_recall_curve,
    auc,
    confusion_matrix,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
    roc_curve,
)
import seaborn as sns
import time
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
import warnings

# Install a blanket "ignore" filter for every warning category.
# Accepted actions: "error", "ignore", "always", "default", "module" or "once"
warnings.filterwarnings("ignore")

# Seed the stdlib, NumPy and PyTorch random number generators with one fixed value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Request deterministic cuDNN kernels and keep the cuDNN benchmark mode off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
from tqdm.auto import tqdm
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, GINEConv, global_mean_pool
import torch.nn as nn
import logging
import os
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    roc_auc_score,
    precision_recall_curve,
    auc,
    confusion_matrix,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
import seaborn as sns
import time
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
import warnings
import optuna
from optuna.trial import Trial
import hydra
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Union

# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings from the
# libraries imported above — consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")


def objective(
    trial: Trial, cfg: DictConfig, full_data: List[Data], log_dir: Path
) -> float:
    """Optuna objective: mean validation ROC-AUC over stratified CV folds.

    Samples one hyperparameter configuration from the ranges in ``cfg.hparams``,
    merges it over ``cfg`` and trains/evaluates a fresh ``GNNModel`` on every
    fold of a ``StratifiedKFold`` split of ``full_data``.

    Args:
        trial: Optuna trial used to sample hyperparameters, report per-fold
            scores and query the pruner.
        cfg: Base configuration. Read here for ``hparams``, ``model.type``,
            ``training.*``, ``seed`` and ``device``.
        full_data: Graphs to cross-validate on. Each ``data.y`` must hold a
            single label because it is read with ``.item()``.
        log_dir: Directory under which ``trial_<n>/fold_<k>`` is created.

    Returns:
        Mean of the per-fold ROC-AUC values. A fold whose training or
        evaluation raises contributes ``0.0``.

    Raises:
        optuna.TrialPruned: If the pruner asks to stop after a fold has been
            reported.
    """
    # Suggest hyperparameters. The order of the suggest_* calls is kept stable
    # because it determines how a seeded sampler consumes its random stream.
    params = {
        "model": {
            "hidden_channels": trial.suggest_categorical(
                "hidden_channels", cfg.hparams.hidden_channels.options
            ),
            "num_layers": trial.suggest_int(
                "num_layers", cfg.hparams.num_layers.min, cfg.hparams.num_layers.max
            ),
            "dropout": trial.suggest_float(
                "dropout", cfg.hparams.dropout.min, cfg.hparams.dropout.max
            ),
            # "heads" is only sampled for GATv2; every other model type gets 1.
            "heads": trial.suggest_int(
                "heads", cfg.hparams.heads.min, cfg.hparams.heads.max
            )
            if cfg.model.type == "GATv2"
            else 1,
            "use_edge_encoders": trial.suggest_categorical(
                "use_edge_encoders", [True, False]
            ),
            "residual": trial.suggest_categorical("residual", [True, False]),
            "use_classifier_mlp": trial.suggest_categorical(
                "use_classifier_mlp", [True, False]
            ),
            "classifier_mlp_dims": trial.suggest_categorical(
                "classifier_mlp_dims", cfg.hparams.classifier_mlp_dims.options
            ),
        },
        "training": {
            "learning_rate": trial.suggest_float(
                "learning_rate",
                cfg.hparams.learning_rate.min,
                cfg.hparams.learning_rate.max,
                log=True,
            ),
        },
    }

    # Create trial-specific config (sampled values override the base config)
    trial_cfg = OmegaConf.merge(cfg, OmegaConf.create(params))

    # Set up logging
    trial_log_dir = log_dir / f"trial_{trial.number}"
    trial_log_dir.mkdir(parents=True, exist_ok=True)
    logger, tb_writer = setup_logging(trial_log_dir)

    device = torch.device(trial_cfg.device)
    # os.cpu_count() may return None; fall back to in-process loading.
    num_workers = os.cpu_count() or 0

    cv_scores = []
    try:
        # Cross-validation
        skf = StratifiedKFold(
            n_splits=trial_cfg.training.cv_folds,
            shuffle=True,
            random_state=trial_cfg.seed,
        )
        labels = [data.y.item() for data in full_data]

        for fold, (train_idx, val_idx) in enumerate(skf.split(full_data, labels)):
            fold_log_dir = trial_log_dir / f"fold_{fold}"
            fold_log_dir.mkdir(parents=True, exist_ok=True)

            # Create data loaders
            train_loader = DataLoader(
                [full_data[i] for i in train_idx],
                batch_size=trial_cfg.training.batch_size,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=True,
            )
            val_loader = DataLoader(
                [full_data[i] for i in val_idx],
                batch_size=trial_cfg.training.batch_size,
                num_workers=num_workers,
                pin_memory=True,
            )

            # Initialize a fresh model for this fold
            model = GNNModel(trial_cfg, in_channels=1, out_channels=2, edge_dim=1).to(
                device
            )

            # Optimizer and scheduler
            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=trial_cfg.training.learning_rate,
                weight_decay=trial_cfg.training.weight_decay,
            )
            scheduler = ReduceLROnPlateau(
                optimizer,
                mode="max",
                patience=trial_cfg.training.lr_patience,
                factor=trial_cfg.training.lr_factor,
            )
            criterion = nn.CrossEntropyLoss()

            # Trainer setup
            trainer = GNNTrainer(trial_cfg, device=trial_cfg.device)

            try:
                # Training
                trainer.train(
                    model,
                    train_loader,
                    val_loader,
                    optimizer,
                    criterion,
                    scheduler,
                    logger,
                    tb_writer,
                    fold_log_dir,
                )

                # Validation metrics
                val_metrics = trainer.evaluate(model, val_loader, criterion)
                fold_score = val_metrics["roc_auc"]
            except Exception as e:
                # Best-effort: a broken fold scores 0.0 instead of aborting the
                # study. exc_info keeps the traceback in the log.
                logger.error(f"Training failed: {str(e)}", exc_info=True)
                cv_scores.append(0.0)
                continue

            cv_scores.append(fold_score)

            # Report intermediate result
            trial.report(fold_score, fold)

            # Handle pruning. This sits outside the broad ``except Exception``
            # above on purpose: TrialPruned is an Exception subclass and would
            # otherwise be swallowed, disabling the pruner and adding a bogus
            # 0.0 to the fold scores.
            if trial.should_prune():
                raise optuna.TrialPruned()
    finally:
        # Clean up even when the trial is pruned or setup fails
        tb_writer.close()

    return float(np.mean(cv_scores))


def _apply_best_params(cfg: DictConfig, best_params: Dict[str, Any]) -> None:
    """Write Optuna's flat best-parameter dict into ``cfg.model`` / ``cfg.training``."""
    # study.best_params is keyed by the names passed to trial.suggest_*
    # ("hidden_channels", "learning_rate", ...), not by config section.
    model_overrides = dict(best_params)
    training_overrides = {}
    if "learning_rate" in model_overrides:
        training_overrides["learning_rate"] = model_overrides.pop("learning_rate")
    # objective() fixes heads to 1 without sampling it for non-GATv2 models,
    # so the key is absent from best_params in that case.
    model_overrides.setdefault("heads", 1)
    cfg.model.update(model_overrides)
    cfg.training.update(training_overrides)


@hydra.main(config_path="conf", config_name="config", version_base="1.3")
def main(cfg: DictConfig) -> None:
    """Run the full experiment grid described by the Hydra configuration.

    For every dataset in ``cfg.datasets``, every entry of ``cfg.model.types``
    and every entry of ``cfg.model.activations`` this function:

    1. runs an Optuna study (``objective``) on the dataset's training split,
    2. writes the best hyperparameters into ``cfg``,
    3. trains a final ``GNNModel`` and evaluates it on the test split,
    4. logs the metrics and saves the training curves and confusion matrix
       under ``<cwd>/<dataset>/<model_type>/<activation>/``.

    Args:
        cfg: Hydra/OmegaConf configuration. It is mutated in place:
            ``model.type``, ``model.activation`` and the tuned hyperparameters
            are overwritten for each experiment.
    """
    # Initialize output directory
    base_dir = Path.cwd()
    # os.cpu_count() may return None; fall back to in-process loading.
    num_workers = os.cpu_count() or 0

    # Load datasets
    all_data = cfg.datasets

    for dataset_name, data in all_data.items():
        dataset_dir = base_dir / dataset_name
        dataset_dir.mkdir(parents=True, exist_ok=True)

        # Prepare data
        train_data = data.train
        test_data = data.test
        full_data = train_data  # For cross-validation

        for model_type in cfg.model.types:
            model_dir = dataset_dir / model_type
            model_dir.mkdir(parents=True, exist_ok=True)

            for activation in cfg.model.activations:
                exp_dir = model_dir / activation
                exp_dir.mkdir(parents=True, exist_ok=True)

                # Set up logging
                logger, tb_writer = setup_logging(exp_dir)
                try:
                    logger.info(
                        f"Starting experiment: {dataset_name}/{model_type}/{activation}"
                    )

                    # Update config for current experiment
                    cfg.model.type = model_type
                    cfg.model.activation = activation

                    # Optuna hyperparameter optimization
                    study = optuna.create_study(
                        direction="maximize",
                        sampler=optuna.samplers.TPESampler(seed=cfg.seed),
                        pruner=optuna.pruners.MedianPruner(
                            n_startup_trials=cfg.optuna.n_startup_trials,
                            n_warmup_steps=cfg.optuna.n_warmup_steps,
                        ),
                    )

                    study.optimize(
                        lambda trial: objective(trial, cfg, full_data, exp_dir),
                        n_trials=cfg.optuna.n_trials,
                        timeout=cfg.optuna.timeout,
                        show_progress_bar=True,
                    )

                    # Save best parameters
                    best_params = study.best_params
                    logger.info(f"Best parameters: {best_params}")
                    logger.info(f"Best ROC-AUC: {study.best_value:.4f}")

                    # Final training with best parameters. best_params is a
                    # flat dict, so it is routed to the right config sections.
                    _apply_best_params(cfg, best_params)

                    # Data loaders
                    train_loader = DataLoader(
                        train_data,
                        batch_size=cfg.training.batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                    )
                    test_loader = DataLoader(
                        test_data,
                        batch_size=cfg.training.batch_size,
                        num_workers=num_workers,
                        pin_memory=True,
                    )

                    # Initialize model
                    model = GNNModel(
                        cfg, in_channels=1, out_channels=2, edge_dim=1
                    ).to(torch.device(cfg.device))

                    # Optimizer and scheduler
                    optimizer = torch.optim.AdamW(
                        model.parameters(),
                        lr=cfg.training.learning_rate,
                        weight_decay=cfg.training.weight_decay,
                    )
                    scheduler = ReduceLROnPlateau(
                        optimizer,
                        mode="max",
                        patience=cfg.training.lr_patience,
                        factor=cfg.training.lr_factor,
                    )
                    criterion = nn.CrossEntropyLoss()

                    # Train final model.
                    # NOTE(review): the test loader is passed in the validation
                    # slot, so whatever the trainer does with validation data
                    # (LR scheduling, checkpoint selection — not visible here)
                    # sees the test set. The reported test metrics are then
                    # optimistic; consider holding out a validation split.
                    trainer = GNNTrainer(cfg, device=cfg.device)
                    history, model = trainer.train(
                        model,
                        train_loader,
                        test_loader,  # Using test as validation for final training
                        optimizer,
                        criterion,
                        scheduler,
                        logger,
                        tb_writer,
                        exp_dir,
                    )

                    # Final evaluation
                    test_metrics = trainer.evaluate(model, test_loader, criterion)

                    # Log results
                    logger.info(f"\n{'=' * 50}")
                    logger.info(
                        f"FINAL RESULTS: {dataset_name}/{model_type}/{activation}"
                    )
                    logger.info(f"Test ROC-AUC: {test_metrics['roc_auc']:.4f}")
                    logger.info(f"Test F1: {test_metrics['f1']:.4f}")
                    logger.info(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
                    logger.info(f"Test PR-AUC: {test_metrics['pr_auc']:.4f}")
                    logger.info("\nClassification Report:")
                    # The stored entry is logged as-is. sklearn's
                    # classification_report() expects (y_true, y_pred) and
                    # cannot be fed two per-class entries of a finished report.
                    # NOTE(review): presumably evaluate() already stores the
                    # computed report here — confirm against GNNTrainer.evaluate.
                    logger.info(test_metrics["classification_report"])
                    logger.info("=" * 50)

                    # Visualizations
                    plot_metrics(history, exp_dir)

                    # Save confusion matrix
                    fig, ax = plt.subplots(figsize=(8, 6))
                    sns.heatmap(
                        test_metrics["confusion_matrix"],
                        annot=True,
                        fmt="d",
                        cmap="Blues",
                        ax=ax,
                    )
                    ax.set_title("Confusion Matrix")
                    fig.savefig(exp_dir / "confusion_matrix.png")
                    plt.close(fig)
                finally:
                    # Close resources even if the experiment fails midway
                    tb_writer.close()

# Entry point: call the Hydra-decorated main() when this code runs as __main__.
# NOTE(review): inside a Jupyter kernel __name__ is also "__main__", so executing
# this cell calls main() immediately. Presumably Hydra will then try to parse the
# kernel's own command-line arguments — confirm this code is meant to run as a
# script rather than as a notebook cell.
if __name__ == "__main__":
    main()

In [5]:
# Launch the experiment sweep.
# NOTE(review): run_experiments, datasets and all_data are not defined in the
# cells visible here — presumably they come from cells that were removed or run
# out of order; confirm the notebook survives Restart Kernel -> Run All.
run_experiments(datasets, all_data)

  0%|          | 0/15 [00:00<?, ?it/s]

experiment: GATv2 with leaky_relu activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 10:52:47,066 - Ionoshp_GATv2_leaky_relu_final - INFO - 
2025-07-23 10:52:47,067 - Ionoshp_GATv2_leaky_relu_final - INFO - FINAL RESULTS: GATv2 with leaky_relu
2025-07-23 10:52:47,067 - Ionoshp_GATv2_leaky_relu_final - INFO - Test ROC-AUC: 0.9939
2025-07-23 10:52:47,067 - Ionoshp_GATv2_leaky_relu_final - INFO - Test F1: 0.7931
2025-07-23 10:52:47,068 - Ionoshp_GATv2_leaky_relu_final - INFO - Test Accuracy: 0.6620
2025-07-23 10:52:47,068 - Ionoshp_GATv2_leaky_relu_final - INFO - Test PR-AUC: 0.9965
2025-07-23 10:52:47,068 - Ionoshp_GATv2_leaky_relu_final - INFO - 
Classification Report:
2025-07-23 10:52:47,068 - Ionoshp_GATv2_leaky_relu_final - INFO - {'False': {'precision': 1.0, 'recall': 0.04, 'f1-score': 0.07692307692307693, 'support': 25.0}, 'True': {'precision': 0.6571428571428571, 'recall': 1.0, 'f1-score': 0.7931034482758621, 'support': 46.0}, 'accuracy': 0.6619718309859155, 'macro avg': {'precision': 0.8285714285714285, 'recall': 0.52, 'f1-score': 0.4350132625994695, '

experiment: GINE with leaky_relu activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 11:03:13,754 - Ionoshp_GINE_leaky_relu_final - INFO - 
2025-07-23 11:03:13,754 - Ionoshp_GINE_leaky_relu_final - INFO - FINAL RESULTS: GINE with leaky_relu
2025-07-23 11:03:13,754 - Ionoshp_GINE_leaky_relu_final - INFO - Test ROC-AUC: 0.9852
2025-07-23 11:03:13,754 - Ionoshp_GINE_leaky_relu_final - INFO - Test F1: 0.9783
2025-07-23 11:03:13,755 - Ionoshp_GINE_leaky_relu_final - INFO - Test Accuracy: 0.9718
2025-07-23 11:03:13,755 - Ionoshp_GINE_leaky_relu_final - INFO - Test PR-AUC: 0.9902
2025-07-23 11:03:13,755 - Ionoshp_GINE_leaky_relu_final - INFO - 
Classification Report:
2025-07-23 11:03:13,755 - Ionoshp_GINE_leaky_relu_final - INFO - {'False': {'precision': 0.96, 'recall': 0.96, 'f1-score': 0.96, 'support': 25.0}, 'True': {'precision': 0.9782608695652174, 'recall': 0.9782608695652174, 'f1-score': 0.9782608695652174, 'support': 46.0}, 'accuracy': 0.971830985915493, 'macro avg': {'precision': 0.9691304347826086, 'recall': 0.9691304347826086, 'f1-score': 0.96913043478260

experiment: GATv2 with tanh activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 11:35:01,505 - Ionoshp_GATv2_tanh_final - INFO - 
2025-07-23 11:35:01,505 - Ionoshp_GATv2_tanh_final - INFO - FINAL RESULTS: GATv2 with tanh
2025-07-23 11:35:01,505 - Ionoshp_GATv2_tanh_final - INFO - Test ROC-AUC: 0.9939
2025-07-23 11:35:01,506 - Ionoshp_GATv2_tanh_final - INFO - Test F1: 0.9020
2025-07-23 11:35:01,506 - Ionoshp_GATv2_tanh_final - INFO - Test Accuracy: 0.8592
2025-07-23 11:35:01,506 - Ionoshp_GATv2_tanh_final - INFO - Test PR-AUC: 0.9965
2025-07-23 11:35:01,506 - Ionoshp_GATv2_tanh_final - INFO - 
Classification Report:
2025-07-23 11:35:01,507 - Ionoshp_GATv2_tanh_final - INFO - {'False': {'precision': 1.0, 'recall': 0.6, 'f1-score': 0.75, 'support': 25.0}, 'True': {'precision': 0.8214285714285714, 'recall': 1.0, 'f1-score': 0.9019607843137255, 'support': 46.0}, 'accuracy': 0.8591549295774648, 'macro avg': {'precision': 0.9107142857142857, 'recall': 0.8, 'f1-score': 0.8259803921568627, 'support': 71.0}, 'weighted avg': {'precision': 0.8843058350100603, 'rec

experiment: GINE with tanh activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 11:44:16,524 - Ionoshp_GINE_tanh_final - INFO - 
2025-07-23 11:44:16,525 - Ionoshp_GINE_tanh_final - INFO - FINAL RESULTS: GINE with tanh
2025-07-23 11:44:16,525 - Ionoshp_GINE_tanh_final - INFO - Test ROC-AUC: 0.9870
2025-07-23 11:44:16,525 - Ionoshp_GINE_tanh_final - INFO - Test F1: 0.7863
2025-07-23 11:44:16,525 - Ionoshp_GINE_tanh_final - INFO - Test Accuracy: 0.6479
2025-07-23 11:44:16,526 - Ionoshp_GINE_tanh_final - INFO - Test PR-AUC: 0.9916
2025-07-23 11:44:16,526 - Ionoshp_GINE_tanh_final - INFO - 
Classification Report:
2025-07-23 11:44:16,526 - Ionoshp_GINE_tanh_final - INFO - {'False': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 25.0}, 'True': {'precision': 0.647887323943662, 'recall': 1.0, 'f1-score': 0.7863247863247863, 'support': 46.0}, 'accuracy': 0.647887323943662, 'macro avg': {'precision': 0.323943661971831, 'recall': 0.5, 'f1-score': 0.39316239316239315, 'support': 71.0}, 'weighted avg': {'precision': 0.4197579845268796, 'recall': 0.6478

experiment: GATv2 with leaky_relu activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 11:51:34,586 - Cryotherapy_GATv2_leaky_relu_final - INFO - 
2025-07-23 11:51:34,587 - Cryotherapy_GATv2_leaky_relu_final - INFO - FINAL RESULTS: GATv2 with leaky_relu
2025-07-23 11:51:34,587 - Cryotherapy_GATv2_leaky_relu_final - INFO - Test ROC-AUC: 0.8750
2025-07-23 11:51:34,587 - Cryotherapy_GATv2_leaky_relu_final - INFO - Test F1: 0.4615
2025-07-23 11:51:34,588 - Cryotherapy_GATv2_leaky_relu_final - INFO - Test Accuracy: 0.6111
2025-07-23 11:51:34,588 - Cryotherapy_GATv2_leaky_relu_final - INFO - Test PR-AUC: 0.9292
2025-07-23 11:51:34,588 - Cryotherapy_GATv2_leaky_relu_final - INFO - 
Classification Report:
2025-07-23 11:51:34,589 - Cryotherapy_GATv2_leaky_relu_final - INFO - {'False': {'precision': 0.5333333333333333, 'recall': 1.0, 'f1-score': 0.6956521739130435, 'support': 8.0}, 'True': {'precision': 1.0, 'recall': 0.3, 'f1-score': 0.46153846153846156, 'support': 10.0}, 'accuracy': 0.6111111111111112, 'macro avg': {'precision': 0.7666666666666666, 'recall': 0.65, 'f1

experiment: GINE with leaky_relu activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 11:54:03,715 - Cryotherapy_GINE_leaky_relu_final - INFO - 
2025-07-23 11:54:03,715 - Cryotherapy_GINE_leaky_relu_final - INFO - FINAL RESULTS: GINE with leaky_relu
2025-07-23 11:54:03,715 - Cryotherapy_GINE_leaky_relu_final - INFO - Test ROC-AUC: 0.9000
2025-07-23 11:54:03,716 - Cryotherapy_GINE_leaky_relu_final - INFO - Test F1: 0.0000
2025-07-23 11:54:03,716 - Cryotherapy_GINE_leaky_relu_final - INFO - Test Accuracy: 0.4444
2025-07-23 11:54:03,716 - Cryotherapy_GINE_leaky_relu_final - INFO - Test PR-AUC: 0.9064
2025-07-23 11:54:03,716 - Cryotherapy_GINE_leaky_relu_final - INFO - 
Classification Report:
2025-07-23 11:54:03,716 - Cryotherapy_GINE_leaky_relu_final - INFO - {'False': {'precision': 0.4444444444444444, 'recall': 1.0, 'f1-score': 0.6153846153846154, 'support': 8.0}, 'True': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 10.0}, 'accuracy': 0.4444444444444444, 'macro avg': {'precision': 0.2222222222222222, 'recall': 0.5, 'f1-score': 0.307692307692307

experiment: GATv2 with tanh activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 12:02:32,734 - Cryotherapy_GATv2_tanh_final - INFO - 
2025-07-23 12:02:32,734 - Cryotherapy_GATv2_tanh_final - INFO - FINAL RESULTS: GATv2 with tanh
2025-07-23 12:02:32,734 - Cryotherapy_GATv2_tanh_final - INFO - Test ROC-AUC: 0.9125
2025-07-23 12:02:32,735 - Cryotherapy_GATv2_tanh_final - INFO - Test F1: 0.7143
2025-07-23 12:02:32,735 - Cryotherapy_GATv2_tanh_final - INFO - Test Accuracy: 0.5556
2025-07-23 12:02:32,735 - Cryotherapy_GATv2_tanh_final - INFO - Test PR-AUC: 0.9247
2025-07-23 12:02:32,735 - Cryotherapy_GATv2_tanh_final - INFO - 
Classification Report:
2025-07-23 12:02:32,736 - Cryotherapy_GATv2_tanh_final - INFO - {'False': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 8.0}, 'True': {'precision': 0.5555555555555556, 'recall': 1.0, 'f1-score': 0.7142857142857143, 'support': 10.0}, 'accuracy': 0.5555555555555556, 'macro avg': {'precision': 0.2777777777777778, 'recall': 0.5, 'f1-score': 0.35714285714285715, 'support': 18.0}, 'weighted avg': {'preci

experiment: GINE with tanh activation:   0%|          | 0/128 [00:00<?, ?it/s]

2025-07-23 12:05:23,626 - Cryotherapy_GINE_tanh_final - INFO - 
2025-07-23 12:05:23,626 - Cryotherapy_GINE_tanh_final - INFO - FINAL RESULTS: GINE with tanh
2025-07-23 12:05:23,626 - Cryotherapy_GINE_tanh_final - INFO - Test ROC-AUC: 0.9750
2025-07-23 12:05:23,627 - Cryotherapy_GINE_tanh_final - INFO - Test F1: 0.0000
2025-07-23 12:05:23,627 - Cryotherapy_GINE_tanh_final - INFO - Test Accuracy: 0.4444
2025-07-23 12:05:23,627 - Cryotherapy_GINE_tanh_final - INFO - Test PR-AUC: 0.9826
2025-07-23 12:05:23,627 - Cryotherapy_GINE_tanh_final - INFO - 
Classification Report:
2025-07-23 12:05:23,628 - Cryotherapy_GINE_tanh_final - INFO - {'False': {'precision': 0.4444444444444444, 'recall': 1.0, 'f1-score': 0.6153846153846154, 'support': 8.0}, 'True': {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 10.0}, 'accuracy': 0.4444444444444444, 'macro avg': {'precision': 0.2222222222222222, 'recall': 0.5, 'f1-score': 0.3076923076923077, 'support': 18.0}, 'weighted avg': {'precision': 0.1

experiment: GATv2 with leaky_relu activation:   0%|          | 0/128 [00:00<?, ?it/s]

KeyboardInterrupt: 