In [33]:
import os
import torch
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from functools import partial
import optuna
import gc
from typing import Literal, Optional
import torch.nn.functional as F
from torch.utils.data import Subset, Dataset
from torch.utils.data import WeightedRandomSampler
import numpy as np
from optuna.pruners import MedianPruner

# Load utility functions from cloned repository
from src.loadData import GraphDataset
from src.utils import set_seed
from src.models import GNN


# Seed the random number generators through the project helper so reruns are
# comparable. Called with its defaults — the seed value is presumably defined
# in src.utils (not visible in this notebook).
set_seed()


In [34]:
def add_zeros(data):
    """Replace ``data.x`` with one zero-valued integer feature per node.

    Intended as a dataset transform: every node receives the same placeholder
    feature (value 0, dtype ``torch.long``).

    Args:
        data: Graph-like object exposing ``num_nodes``; it is mutated in place.

    Returns:
        The same ``data`` object with its ``x`` attribute overwritten.
    """
    node_count = data.num_nodes
    placeholder = torch.zeros(node_count, dtype=torch.long)
    data.x = placeholder
    return data

In [35]:
def save_predictions(predictions, test_path):
    """Write predictions to ``./submission/testset_<name>.csv``.

    ``<name>`` is the name of the directory that contains ``test_path``.
    The CSV has two columns: ``id`` (0..len(predictions)-1) and ``pred``.

    Args:
        predictions: Sequence of predicted labels, one per test graph.
        test_path: Path of the test file; only its parent directory name is used.
    """
    # The parent folder of the test file identifies the dataset.
    dataset_tag = os.path.basename(os.path.dirname(test_path))

    out_dir = os.path.join(os.getcwd(), "submission")
    os.makedirs(out_dir, exist_ok=True)
    output_csv_path = os.path.join(out_dir, f"testset_{dataset_tag}.csv")

    # Ids are simply the positions of the predictions.
    frame = pd.DataFrame(
        {"id": list(range(len(predictions))), "pred": predictions}
    )
    frame.to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")

In [36]:
def plot_training_progress(train_losses, train_accuracies, output_dir):
    """Save a two-panel figure (loss and accuracy per epoch) as a PNG.

    The figure is written to ``<output_dir>/training_progress.png``; the
    directory is created if needed and the figure is closed afterwards.

    Args:
        train_losses: Per-epoch values for the left panel.
        train_accuracies: Per-epoch values for the right panel (same length
            as ``train_losses`` is expected, since both share the epoch axis).
        output_dir: Directory receiving ``training_progress.png``.
    """
    epochs = range(1, len(train_losses) + 1)
    fig = plt.figure(figsize=(12, 6))

    # (series, legend label, colour, y-axis label, title) for each panel,
    # drawn left to right.
    panels = (
        (train_losses, "Training Loss", "blue", "Loss", "Training Loss per Epoch"),
        (
            train_accuracies,
            "Training Accuracy",
            "green",
            "Accuracy",
            "Training Accuracy per Epoch",
        ),
    )
    for position, (series, label, color, y_label, title) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 2, position)
        ax.plot(epochs, series, label=label, color=color)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.set_title(title)

    # Persist the figure, then release it so repeated calls do not pile up
    # open figures.
    os.makedirs(output_dir, exist_ok=True)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "training_progress.png"))
    plt.close(fig)

In [37]:
from torch_geometric.data import Data


class IndexedSubset(Dataset):
    """Dataset wrapper that tags every returned item with its position.

    Each item fetched through this wrapper gets an ``idx`` attribute holding
    its index *within the wrapped subset* as a 0-dim ``torch.long`` tensor.
    """

    def __init__(self, subset: Subset[Data]):
        """Store the wrapped subset; nothing is loaded eagerly."""
        self.subset = subset

    def __len__(self):
        """Number of items in the wrapped subset."""
        return len(self.subset)

    def __getitem__(self, i):
        """Return item ``i`` of the subset with ``idx`` set to ``i``."""
        item = self.subset[i]
        position = torch.tensor(i, dtype=torch.long)
        item.idx = position  # type: ignore
        return item

In [38]:
class NCODLoss(nn.Module):
    past_embeddings: torch.Tensor
    centroids: torch.Tensor

    def __init__(
        self,
        dataset: Dataset,
        embedding_dimensions: int = 300,
        total_epochs: int = 150,
        lambda_consistency: float = 1.0,
        device: torch.device | None = None,
    ):
        """
        Args
        ----
        dataset : iterable whose elements expose an integer label in `elem.y`
        embedding_dimensions : size of the feature vectors
        total_epochs : number of training epochs (used for centroid update schedule)
        lambda_consistency : weight for the MSE consistency term
        device : cuda / cpu device.  If None, picks CUDA if available.
        """
        super().__init__()

        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.embedding_dimensions = embedding_dimensions
        self.total_epochs = total_epochs
        self.lambda_consistency = lambda_consistency

        labels = [int(elem.y) for elem in dataset]
        self.num_elements = len(labels)
        self.num_classes = max(labels) + 1  # robust to gaps (e.g. labels {1,3})
        # self.register_buffer("bins", torch.empty(self.num_classes, 0, dtype=torch.long))

        # Convert bins to a list-of-lists for easy appends, then to tensors
        tmp_bins: list[list[int]] = [[] for _ in range(self.num_classes)]
        for idx, lab in enumerate(labels):
            tmp_bins[lab].append(idx)
        self.bins = [
            torch.as_tensor(b, dtype=torch.long, device=self.device) for b in tmp_bins
        ]

        # Confidence parameter per sample (trainable!)
        self.u = nn.Parameter(torch.empty(self.num_elements, 1, device=self.device))
        nn.init.normal_(self.u, mean=1e-8, std=1e-9)

        # Running memory of embeddings
        self.register_buffer(
            "past_embeddings",
            torch.rand(
                self.num_elements, self.embedding_dimensions, device=self.device
            ),
        )
        # Class centroids
        self.register_buffer(
            "centroids",
            torch.rand(self.num_classes, self.embedding_dimensions, device=self.device),
        )

    def forward(
        self,
        *,
        logits: torch.Tensor,  # (B, C)
        indexes: torch.Tensor,  # (B,) – dataset indices of the current batch
        embeddings: torch.Tensor,  # (B, D)
        targets: torch.Tensor,  # (B,)
        epoch: int,
    ) -> torch.Tensor:
        eps = 1e-6

        # Keep an L2-normalised copy of current embeddings
        embeddings = F.normalize(embeddings, dim=1)
        self.past_embeddings[indexes] = embeddings.detach()

        # ---------------- Centroid update ------------------------------------
        if epoch == 0:
            with torch.no_grad():
                for c, idxs in enumerate(self.bins):
                    if idxs.numel():
                        self.centroids[c] = self.past_embeddings[idxs].mean(0)
        else:
            # Shrink the subset of samples that contribute to the centroid
            percent = int(max(1, min(100, 50 + 50 * (1 - epoch / self.total_epochs))))
            for c, idxs in enumerate(self.bins):
                if idxs.numel() == 0:
                    continue
                # bottom-k u’s  (small u  ⇒ low confidence ⇒ smaller weight)
                k = max(1, idxs.numel() * percent // 100)
                u_batch = self.u[idxs].squeeze(1)
                keep = torch.topk(u_batch, k, largest=False).indices  # (k,)
                selected = idxs[keep]  # (k,)
                self.centroids[c] = self.past_embeddings[selected].mean(0)

        centroids = F.normalize(self.centroids, dim=1)  # (C, D)

        # ---------------- Probability shaping --------------------------------
        soft_labels = F.softmax(embeddings @ centroids.T, dim=1)  # (B, C)
        probs = F.softmax(logits, dim=1)  # (B, C)
        u_vals = torch.sigmoid(self.u[indexes]).squeeze(1)  # (B,)

        adjusted = (probs + u_vals[:, None] * soft_labels).clamp(min=eps)
        adjusted = adjusted / adjusted.sum(1, keepdim=True)

        # ---------------- Loss terms -----------------------------------------
        hard_ce = (
            (1.0 - u_vals) * F.cross_entropy(logits, targets, reduction="none")
        ).mean()
        soft_ce = -(soft_labels * torch.log(adjusted)).sum(1).mean()
        consistency = F.mse_loss(adjusted, soft_labels)

        return hard_ce + soft_ce + self.lambda_consistency * consistency


In [39]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy whose per-sample terms are scaled before averaging.

    NOTE(review): the one-hot encoding of ``targets`` always sums to 1 per
    row, so the weight computed in ``forward`` is the constant ``1 - p`` for
    every sample. Confirm whether a sample-dependent weight was intended.
    """

    def __init__(self, p_noisy):
        """Store the noise rate ``p_noisy`` and an unreduced CE criterion."""
        super().__init__()
        self.p = p_noisy
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, logits, targets):
        """Return the mean of the weighted per-sample cross-entropy."""
        per_sample = self.ce(logits, targets)

        # Total one-hot mass per row (equals 1 for every valid target).
        n_classes = logits.size(1)
        one_hot_mass = F.one_hot(targets, num_classes=n_classes).float().sum(dim=1)

        weights = (1 - self.p) + self.p * (1 - one_hot_mass)
        weighted = per_sample * weights
        return weighted.mean()

In [40]:
class SCELoss(torch.nn.Module):
    """Weighted sum of cross-entropy and a reverse-style penalty term.

    ``loss = alpha * CE + beta * RCE`` averaged over the batch, where RCE
    sums ``-log(1 - p_k)`` over every class ``k`` that is *not* the target.
    """

    def __init__(self, num_classes: int = 6, alpha: float = 0.1, beta: float = 1.0):
        """Store the class count and the two mixing coefficients."""
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes

    def forward(self, logits, targets):
        """Return the scalar loss for ``logits`` (B, num_classes) and ``targets`` (B,)."""
        # Standard cross-entropy, one value per sample.
        cce = F.cross_entropy(logits, targets, reduction="none")

        # Probabilities are clamped away from 0 and 1 so log(1 - p) stays finite.
        probs = F.softmax(logits, dim=1).clamp(min=1e-6, max=1 - 1e-6)
        off_target = 1 - F.one_hot(targets, self.num_classes).float()
        rce = (-off_target * torch.log(1 - probs)).sum(dim=1)

        combined = self.alpha * cce + self.beta * rce
        return combined.mean()

In [41]:
def train(
    data_loader: DataLoader,
    model: GNN,
    optimizer_theta: torch.optim.Optimizer,
    optimizer_u: torch.optim.Optimizer | None,
    criterion: nn.Module,
    device: torch.device,
    checkpoint_path: str,
    current_epoch: int,
    save_checkpoints: bool = True,
):
    """Run one training epoch and return sample-averaged statistics.

    The model is applied in three explicit stages (``gnn_node`` -> ``pool`` ->
    ``graph_pred_linear``) because an ``NCODLoss`` criterion also needs the
    pooled graph embedding and the per-sample ``batch.idx``.

    Args:
        data_loader: Yields batches exposing ``to``, ``batch``, ``y`` (and
            ``idx`` when the criterion is an ``NCODLoss``).
        model: Network exposing ``gnn_node``, ``pool`` and ``graph_pred_linear``.
        optimizer_theta: Optimizer stepped on every batch.
        optimizer_u: Optional second optimizer, stepped after ``optimizer_theta``.
        criterion: Loss module; ``NCODLoss`` is called with keyword arguments,
            anything else as ``criterion(logits, batch.y)``.
        device: Device the batches are moved to.
        checkpoint_path: Path prefix; ``_epoch_<n>.pth`` is appended.
        current_epoch: Zero-based epoch index (checkpoint name uses ``+ 1``).
        save_checkpoints: Whether to save the model state after the epoch.

    Returns:
        ``(loss, confidence, accuracy, entropy)`` averaged over all samples.
    """
    model.train()

    uses_ncod = isinstance(criterion, NCODLoss)
    # Optimizers in the order they must be stepped.
    optimizers = [optimizer_theta]
    if optimizer_u is not None:
        optimizers.append(optimizer_u)

    loss_sum = 0.0
    conf_sum = 0.0
    entropy_sum = 0.0
    hits = 0
    seen = 0

    for batch in data_loader:
        batch = batch.to(device)

        graph_emb = model.pool(model.gnn_node(batch), batch.batch)
        logits = model.graph_pred_linear(graph_emb)

        if uses_ncod:
            loss = criterion(
                logits=logits,
                indexes=batch.idx.to(device),
                embeddings=graph_emb,
                targets=batch.y.to(device),
                epoch=current_epoch,
            )
        else:
            loss = criterion(logits, batch.y)

        for opt in optimizers:
            opt.zero_grad(set_to_none=True)
        loss.backward()
        for opt in optimizers:
            opt.step()

        # Statistics use the logits computed before the optimizer step.
        with torch.no_grad():
            probs = F.softmax(logits, dim=1)
            n_graphs = batch.y.size(0)

            loss_sum += loss.item() * n_graphs
            hits += (probs.argmax(dim=1) == batch.y).sum().item()
            seen += n_graphs

            conf_sum += probs.max(dim=1).values.sum().item()
            per_sample_entropy = -torch.sum(probs * torch.log(probs + 1e-10), 1)
            entropy_sum += per_sample_entropy.sum().item()

    if save_checkpoints:
        ckpt = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), ckpt)
        print(f"[checkpoint] saved: {ckpt}")

    return (
        loss_sum / seen,
        conf_sum / seen,
        hits / seen,
        entropy_sum / seen,
    )


In [42]:
def evaluate(
    data_loader: DataLoader,
    model: GNN,
    criterion: nn.CrossEntropyLoss,
    device: torch.device,
) -> tuple[float, float, float, float]:
    """Score ``model`` on ``data_loader`` without tracking gradients.

    The model is put in eval mode and applied as
    ``gnn_node`` -> ``pool`` -> ``graph_pred_linear``.

    Returns
    -------
    avg_loss, avg_confidence, accuracy, avg_entropy
        All four are averaged over the total number of samples seen;
        confidence is the maximum softmax probability of each sample.
    """
    model.eval()

    loss_sum = 0.0
    conf_sum = 0.0
    entropy_sum = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)

            graph_emb = model.pool(model.gnn_node(batch), batch.batch)
            logits = model.graph_pred_linear(graph_emb)
            probs = F.softmax(logits, dim=1)

            n_graphs = batch.y.size(0)
            # The criterion returns a batch value; re-weight it by the batch
            # size so the final division gives a per-sample average.
            loss_sum += criterion(logits, batch.y).item() * n_graphs

            conf_sum += probs.max(dim=1).values.sum().item()
            per_sample_entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)
            entropy_sum += per_sample_entropy.sum().item()

            hits += (probs.argmax(dim=1) == batch.y).sum().item()
            seen += n_graphs

    return (
        loss_sum / seen,
        conf_sum / seen,
        hits / seen,
        entropy_sum / seen,
    )


In [43]:
def directory_setup(dataset_name):
    """Create the per-dataset output folders and (re)configure file logging.

    Layout, relative to the current working directory::

        logs/<dataset_name>/training.log
        checkpoints/<dataset_name>/model_<dataset_name>_best.pth

    Both directories are created here so that later ``torch.save`` calls do
    not fail on a missing folder.

    Args:
        dataset_name: Name used for the sub-folders and the best-model file.

    Returns:
        ``(logs_dir, checkpoints_dir, best_model_path)``.
    """
    script_root = os.getcwd()
    logs_dir = os.path.join(script_root, "logs", dataset_name)
    checkpoints_dir = os.path.join(script_root, "checkpoints", dataset_name)
    os.makedirs(logs_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    # force=True: without it basicConfig is a silent no-op once the root
    # logger has handlers (second dataset, or a pre-configured kernel), and
    # messages would keep going to the previous destination.
    logging.basicConfig(
        filename=os.path.join(logs_dir, "training.log"),
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        filemode="w",
        force=True,
    )

    best_model_path = os.path.join(checkpoints_dir, f"model_{dataset_name}_best.pth")
    return logs_dir, checkpoints_dir, best_model_path


def dataloader_setup(dataset_name, batch_size):
    """Load a dataset, split it 80/20 and build train/validation loaders.

    The split uses a fixed generator seed (12). Both parts are wrapped in
    ``IndexedSubset``. The training loader draws samples with replacement,
    weighted by inverse class frequency; the validation loader is sequential.

    Args:
        dataset_name: Sub-folder of ``./datasets`` containing ``train.json.gz``.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, val_loader, full_dataset, train_dataset, val_dataset)``.
    """
    full_dataset = GraphDataset(
        f"./datasets/{dataset_name}/train.json.gz", transform=add_zeros
    )

    # 20% of the data (rounded down) is held out for validation.
    total = len(full_dataset)
    val_size = int(0.2 * total)
    split_rng = torch.Generator().manual_seed(12)
    train_part, val_part = random_split(
        full_dataset, [total - val_size, val_size], generator=split_rng
    )
    train_dataset = IndexedSubset(train_part)
    val_dataset = IndexedSubset(val_part)

    # ---- WeightedRandomSampler for class balancing ----
    # Each sample is weighted by the inverse frequency of its class.
    labels = [int(sample.y[0]) for sample in train_dataset]
    inverse_freq = 1.0 / np.bincount(labels)
    sample_weights = [inverse_freq[lab] for lab in labels]
    sampler = WeightedRandomSampler(
        sample_weights, num_samples=len(sample_weights), replacement=True
    )

    # The sampler takes the place of shuffle=True.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, full_dataset, train_dataset, val_dataset


In [44]:
def objective(
    trial: Optional[optuna.trial.Trial] = None,
    *,
    dataset_name: str,
    score_type: Literal["loss", "accuracy", "entropy", "confidence"] = "accuracy",
    checkpoints: int = 5,
    epochs: int = 150,
    train_loader: Optional[DataLoader] = None,
    val_loader: Optional[DataLoader] = None,
    full_dataset=None,
    train_dataset=None,
    val_dataset=None,
    batch_size: Optional[int] = None,
    logs_dir: Optional[str] = None,
    checkpoints_dir: Optional[str] = None,
    best_model_path: Optional[str] = None,
    gnn_type: Optional[str] = None,
    loss_type: Optional[str] = None,
    graph_pooling: Optional[str] = None,
    drop_ratio: Optional[float] = None,
    num_layers: Optional[int] = None,
    embedding_dim: Optional[int] = None,
    model: Optional[GNN] = None,
    optimizer_theta: Optional[torch.optim.Optimizer] = None,
    optimizer_u: Optional[torch.optim.Optimizer] = None,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Train and validate one configuration; usable as an Optuna objective.

    With a ``trial`` the hyperparameters are sampled from the search space
    below and intermediate scores are reported for pruning. Without a trial
    the explicitly passed hyperparameters are used, falling back to fixed
    defaults.

    Args:
        trial: Optuna trial, or ``None`` for a plain training run.
        dataset_name: Dataset identifier (used for paths and log messages).
        score_type: Metric that defines the score. ``loss`` and ``entropy``
            are negated for validation so that larger is always better.
        checkpoints: Number of evenly spaced epoch checkpoints to save.
        epochs: Number of training epochs.
        train_loader, val_loader: Pre-built loaders; built via
            ``dataloader_setup`` when either is missing.
        full_dataset, train_dataset, val_dataset: Only overwritten when the
            loaders are built here; otherwise unused.
        batch_size: Batch size when loaders are built here (default 64).
        logs_dir, checkpoints_dir, best_model_path: Output locations; all
            three are derived via ``directory_setup`` if any is ``None``.
        gnn_type, loss_type, graph_pooling, drop_ratio, num_layers,
            embedding_dim: Hyperparameters, used only when ``trial`` is ``None``.
        model: Pre-built model; a new ``GNN`` is created when ``None``.
        optimizer_theta: Model optimizer; Adam is created when ``None``.
        optimizer_u: Currently ignored — it is always reset, and for the
            ``ncod`` loss a fresh SGD optimizer over the criterion's own
            parameters is created (an external optimizer could not reference
            parameters that only exist once the criterion is built here).
        device: Device used for the model and batches.

    Returns:
        The best validation score observed over all epochs.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner stops the trial.
        Exception: Any other error is printed and re-raised unchanged.
    """
    # Defined before the try-block so the `finally` clause can always refer
    # to it, even when something fails before the bar is created.
    progress_bar = None
    try:
        logging.info("#" * 80)
        # Hyperparameter search space
        logging.info("Start case study with parameters:")

        batch_size = batch_size or 64
        if logs_dir is None or checkpoints_dir is None or best_model_path is None:
            logs_dir, checkpoints_dir, best_model_path = directory_setup(dataset_name)

        # Both save targets must exist before the first torch.save call.
        os.makedirs(checkpoints_dir, exist_ok=True)
        best_model_dir = os.path.dirname(best_model_path)
        if best_model_dir:
            os.makedirs(best_model_dir, exist_ok=True)

        if train_loader is None or val_loader is None:
            train_loader, val_loader, full_dataset, train_dataset, val_dataset = (
                dataloader_setup(dataset_name, batch_size)
            )
        if trial is None:
            gnn_type = gnn_type or "gcn"
            graph_pooling = graph_pooling or "mean"
            loss_type = loss_type or "noisy_ce"
            # Explicit None check: a dropout of 0.0 is a valid request and
            # must not be replaced by the default.
            drop_ratio = 0.2 if drop_ratio is None else drop_ratio
            num_layers = num_layers or 11
            embedding_dim = embedding_dim or 300

        else:
            # Repeated entries bias the categorical sampling towards them.
            gnn_type = trial.suggest_categorical(
                "gnn_type", ["gin", "gin-virtual", "gcn", "gcn-virtual"]
            )
            loss_type = trial.suggest_categorical(
                "loss_type", ["ce"] + ["ncod"] * 5 + ["noisy_ce"] * 5 + ["sce"]
            )
            graph_pooling = trial.suggest_categorical(
                "graph_pooling",
                ["sum"] * 2
                + ["mean"] * 4
                + ["max"] * 2
                + ["attention"] * 2
                + ["set2set"],
            )
            drop_ratio = trial.suggest_float("dropout", 0.0, 0.7)
            num_layers = trial.suggest_int("num_layers", 6, 12)
            embedding_dim = trial.suggest_categorical(
                "embedding_dim", [64, 128, 300, 600]
            )
        embedding_dim = (
            2 * embedding_dim
            if graph_pooling == "set2set" and loss_type == "ncod"
            else embedding_dim
        )

        if model is None:
            model = GNN(
                gnn_type="gin" if "gin" in gnn_type else "gcn",
                num_class=6,
                num_layer=num_layers,
                emb_dim=embedding_dim,
                drop_ratio=drop_ratio,
                virtual_node="virtual" in gnn_type,
                graph_pooling=graph_pooling,
            )
        # Initialize model
        model = model.to(device)
        print(f"{model.__dict__}\n{epochs=}")
        logging.info(f"{model.__dict__} {epochs=}")
        if optimizer_theta is None:
            optimizer_theta = torch.optim.Adam(
                model.parameters(), lr=3e-4, weight_decay=1e-5
            )

        optimizer_u = None
        if loss_type == "ce":
            train_criterion = nn.CrossEntropyLoss()
        elif loss_type == "ncod":
            # The NCOD memory must match the width of the pooled embedding,
            # which is exactly what the classification head consumes. Prefer
            # the head's real input width when it exposes one; otherwise fall
            # back to the hyperparameter.
            ncod_dim = getattr(
                getattr(model, "graph_pred_linear", None),
                "in_features",
                embedding_dim,
            )
            train_criterion = NCODLoss(
                train_loader.dataset,
                embedding_dimensions=ncod_dim,
                total_epochs=epochs,
                device=device,
            )
            optimizer_u = torch.optim.SGD(train_criterion.parameters(), lr=1e-3)
        elif loss_type == "sce":
            train_criterion = SCELoss()
        else:
            train_criterion = NoisyCrossEntropyLoss(0.2)

        val_criterion = nn.CrossEntropyLoss()

        # Checkpoint logic: 1-based epochs at which a checkpoint is written.
        checkpoint_epochs = [
            int((i + 1) * epochs / checkpoints) for i in range(checkpoints)
        ]

        train_losses, train_confs, train_accs, train_entropies, train_scores = (
            [],
            [],
            [],
            [],
            [],
        )
        best_val_score = -float("inf")
        val_losses, val_confs, val_accs, val_entropies, val_scores = [], [], [], [], []

        progress_bar = tqdm(range(epochs), desc="Training...", leave=False)
        for epoch in progress_bar:
            train_loss, train_conf, train_acc, train_entropy = train(
                train_loader,
                model,
                optimizer_theta,
                optimizer_u,
                train_criterion,
                device,
                save_checkpoints=(epoch + 1 in checkpoint_epochs),
                checkpoint_path=os.path.join(checkpoints_dir, f"model_{dataset_name}"),
                current_epoch=epoch,
            )

            val_loss, val_conf, val_acc, val_entropy = evaluate(
                val_loader,
                model,
                val_criterion,
                device,
            )

            train_losses.append(train_loss)
            train_accs.append(train_acc)
            train_confs.append(train_conf)
            train_entropies.append(train_entropy)

            val_losses.append(val_loss)
            val_accs.append(val_acc)
            val_confs.append(val_conf)
            val_entropies.append(val_entropy)

            # Validation scores are oriented so that larger is better.
            if score_type == "loss":
                train_score = train_loss
                val_score = -val_loss
            elif score_type == "entropy":
                train_score = train_entropy
                val_score = -val_entropy
            elif score_type == "confidence":
                train_score = train_conf
                val_score = val_conf
            else:
                train_score = train_acc
                val_score = val_acc

            train_scores.append(train_score)
            val_scores.append(val_score)
            if trial is not None:
                trial.report(val_score, step=epoch)

            if val_score > best_val_score:
                best_val_score = val_score
                torch.save(model.state_dict(), best_model_path)
                logging.info(
                    f"[{dataset_name}] Best model updated at {best_model_path}"
                )
            progress_bar.set_postfix_str(
                f"Train Score: {train_score:.4f}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, Confidence: {train_conf:.4f}, Entropy: {train_entropy:.4f}| "
                f"Val Score: {val_score:.4f}, Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, Confidence: {val_conf:.4f}, Entropy: {val_entropy:.4f}"
            )
            logging.info(
                f"Epoch {epoch}/{epochs}| "
                f"Train Score: {train_score:.4f}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, Confidence: {train_conf:.4f}, Entropy: {train_entropy:.4f}| "
                f"Val Score: {val_score:.4f}, Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, Confidence: {val_conf:.4f}, Entropy: {val_entropy:.4f}"
            )
            if trial is not None and trial.should_prune():
                logging.warning(f"Trial was pruned at epoch {epoch}")
                raise optuna.exceptions.TrialPruned()
        logging.info(f"Case study end, {best_val_score}")
        # Plot training curves
        plot_training_progress(
            train_losses, train_scores, os.path.join(logs_dir, "train_plots")
        )
        plot_training_progress(
            val_losses, val_scores, os.path.join(logs_dir, "val_plots")
        )

        return best_val_score
    except optuna.exceptions.TrialPruned:
        # Pruning is expected control flow: propagate without the error print.
        raise

    except Exception as e:
        print("Unhandled exception ", e)
        # Bare raise keeps the original traceback intact.
        raise

    finally:
        logging.info("#" * 80)
        logging.info("\n")
        if trial is None:
            del train_loader, val_loader, full_dataset, train_dataset, val_dataset
            gc.collect()
        if progress_bar is not None:
            progress_bar.close()

In [45]:
def case_study(
    dataset_name: Literal["A", "B", "C", "D"],
    n_trials: int = 30,
    checkpoints: int = 5,
    batch_size: int = 64,
    epochs=150,
    score_type: Literal["loss", "accuracy", "entropy", "confidence"] = "accuracy",
):
    """Run an Optuna hyperparameter search for one dataset.

    Data loaders are built once and shared by all trials. Completed trials
    are written to ``<logs_dir>/optuna_summary_<dataset_name>.csv`` sorted by
    score (best first) and displayed.

    Args:
        dataset_name: Dataset to tune on.
        n_trials: Number of Optuna trials.
        checkpoints: Number of epoch checkpoints saved per trial.
        batch_size: Batch size of the shared loaders.
        epochs: Training epochs per trial.
        score_type: Metric maximised by the study (see ``objective``).

    Returns:
        ``(best_params, best_value)`` of the study.

    Raises:
        ValueError: If no trial completed (all were pruned or failed), so
            there is no best trial to report.
    """
    logs_dir, checkpoints_dir, best_model_path = directory_setup(dataset_name)
    summary_csv_path = os.path.join(logs_dir, f"optuna_summary_{dataset_name}.csv")
    os.makedirs(checkpoints_dir, exist_ok=True)
    train_loader, val_loader, full_dataset, train_dataset, val_dataset = (
        dataloader_setup(dataset_name, batch_size)
    )
    try:
        print(f"Starting Optuna optimization for dataset {dataset_name}")

        study = optuna.create_study(
            study_name=dataset_name,
            direction="maximize",
            pruner=MedianPruner(n_warmup_steps=10),
            sampler=optuna.samplers.TPESampler(n_startup_trials=0),
        )

        obj = partial(
            objective,
            dataset_name=dataset_name,
            score_type=score_type,
            epochs=epochs,
            checkpoints=checkpoints,
            train_loader=train_loader,
            val_loader=val_loader,
            batch_size=batch_size,
            logs_dir=logs_dir,
            checkpoints_dir=checkpoints_dir,
            best_model_path=best_model_path,
        )
        study.optimize(obj, n_trials=n_trials, show_progress_bar=True)

        # Only completed trials carry a final value; pruned/failed ones are skipped.
        all_trials = [
            {score_type: trial.value, **trial.params}
            for trial in study.trials
            if trial.state == optuna.trial.TrialState.COMPLETE
        ]

        results_df = pd.DataFrame(all_trials)
        if not results_df.empty:
            # Sorting an empty frame by a missing column would raise KeyError.
            results_df = results_df.sort_values(score_type, ascending=False)
        results_df.to_csv(summary_csv_path, index=False)
        print(f"\nAll trials saved to: {summary_csv_path}")

        if results_df.empty:
            raise ValueError(
                f"No trial completed for dataset {dataset_name}; "
                "there is no best result to report."
            )

        print(f"\nBest result for dataset {dataset_name}:")
        display(results_df)
        print(f"\nBest Params for {dataset_name}:")
        for k, v in study.best_params.items():
            print(f"  {k}: {v}")

        return study.best_trial.params, study.best_value
    finally:
        # Drop this frame's references to the data even when the study raised,
        # so a retained traceback does not keep the datasets alive.
        del train_loader, val_loader, full_dataset, train_dataset, val_dataset
        gc.collect()


In [46]:
# def train_on_dataset(
#     model,
#     optimizer_theta,
#     dataset_name,
#     logs_dir,
#     checkpoints_dir,
#     best_model_path,
#     device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
# ):
#     train_loader, val_loader, full_dataset, train_dataset, val_dataset = (
#         dataloader_setup(dataset_name, 64)
#     )
#     objective(
#         dataset_name="aggregated",
#         score_type="accuracy",
#         checkpoints=5,
#         epochs=50,
#         train_loader=train_loader,
#         val_loader=val_loader,
#         batch_size=64,
#         logs_dir=logs_dir,
#         checkpoints_dir=checkpoints_dir,
#         best_model_path=best_model_path,
#         model=model,
#         optimizer_theta = optimizer_theta,
#         device=device,
#     )
#     # del train_loader, val_loader, full_dataset, train_dataset, val_dataset
#     # gc.collect()


In [47]:
# gnn_type = "gcn"
# graph_pooling = "mean"
# loss_type = "noisy_ce"
# drop_ratio = 0.2
# num_layers = 11
# embedding_dim = 300

# model = GNN(
#     gnn_type=gnn_type,
#     num_class=6,
#     num_layer=num_layers,
#     emb_dim=embedding_dim,
#     drop_ratio=drop_ratio,
#     virtual_node="virtual" in gnn_type,
#     graph_pooling=graph_pooling,
# )
# optimizer_theta = torch.optim.Adam(
#             model.parameters(), lr=3e-4, weight_decay=1e-5
#         )
# logs_dir, checkpoints_dir, best_model_path = directory_setup("aggregated")
# train_loader, val_loader, full_dataset, train_dataset, val_dataset = dataloader_setup(
#     "C", 64
# )
# for name in ["D","C",  "B", "A"]:
#     train_on_dataset(model, optimizer_theta,name,logs_dir,checkpoints_dir,best_model_path)

In [48]:
# Hyperparameter search on dataset "A": 60 Optuna trials (second positional
# argument is n_trials), 20 training epochs per trial, batches of 64 graphs,
# maximising validation accuracy.
case_study(
    "A",
    60,
    batch_size=64,
    epochs = 20,
    score_type="accuracy"
)

[I 2025-05-29 20:33:55,755] A new study created in memory with name: A


Starting Optuna optimization for dataset A


  0%|          | 0/60 [00:00<?, ?it/s]

{'training': True, '_parameters': {}, '_buffers': {}, '_non_persistent_buffers_set': set(), '_backward_pre_hooks': OrderedDict(), '_backward_hooks': OrderedDict(), '_is_full_backward_hook': None, '_forward_hooks': OrderedDict(), '_forward_hooks_with_kwargs': OrderedDict(), '_forward_hooks_always_called': OrderedDict(), '_forward_pre_hooks': OrderedDict(), '_forward_pre_hooks_with_kwargs': OrderedDict(), '_state_dict_hooks': OrderedDict(), '_state_dict_pre_hooks': OrderedDict(), '_load_state_dict_pre_hooks': OrderedDict(), '_load_state_dict_post_hooks': OrderedDict(), '_modules': {'gnn_node': GNN_node_Virtualnode(
  (node_encoder): Embedding(1, 300)
  (virtualnode_embedding): Embedding(1, 300)
  (convs): ModuleList(
    (0-9): 10 x GINConv()
  )
  (batch_norms): ModuleList(
    (0-9): 10 x BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mlp_virtualnode_list): ModuleList(
    (0-8): 9 x Sequential(
      (0): Linear(in_features=300, out_features=60

Training...:   0%|          | 0/20 [00:00<?, ?it/s]



Unhandled exception  CUDA out of memory. Tried to allocate 368.00 MiB. GPU 0 has a total capacity of 11.60 GiB of which 391.06 MiB is free. Process 2160 has 7.59 MiB memory in use. Process 10052 has 23.84 MiB memory in use. Process 11053 has 3.98 GiB memory in use. Including non-PyTorch memory, this process has 6.46 GiB memory in use. Of the allocated memory 5.85 GiB is allocated by PyTorch, and 397.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[W 2025-05-29 20:33:57,378] Trial 0 failed with parameters: {'gnn_type': 'gin-virtual', 'loss_type': 'sce', 'graph_pooling': 'max', 'dropout': 0.28425527459347893, 'num_layers': 10, 'embedding_dim': 300} because of the following error: OutOfMemoryError('CUDA out of memory. Tried to allocate 368.00 MiB. GPU

OutOfMemoryError: CUDA out of memory. Tried to allocate 368.00 MiB. GPU 0 has a total capacity of 11.60 GiB of which 391.06 MiB is free. Process 2160 has 7.59 MiB memory in use. Process 10052 has 23.84 MiB memory in use. Process 11053 has 3.98 GiB memory in use. Including non-PyTorch memory, this process has 6.46 GiB memory in use. Of the allocated memory 5.85 GiB is allocated by PyTorch, and 397.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# case_study(
#     "C", n_trials=100, checkpoints=5, batch_size=64, score_type="accuracy"
# )

In [None]:
# case_study(
#     "D", n_trials=100, checkpoints=5, batch_size=64, score_type="accuracy"
# )