In [None]:
import os
import torch
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from functools import partial
import optuna
import gc
from typing import Literal
import torch.nn.functional as F
from torch.utils.data import Subset, Dataset, WeightedRandomSampler


# Load utility functions from cloned repository
from src.loadData import GraphDataset
from src.utils import set_seed
from src.models import GNN


# Seed the random number generators through the project helper before any
# model or data-split code runs.
# NOTE(review): set_seed() is called with its defaults; the seed value and the
# set of libraries it covers are defined in src.utils — confirm there.
set_seed()


In [2]:
def add_zeros(data):
    """Give a graph a placeholder node-feature vector made of zeros.

    Args:
        data: graph object exposing ``num_nodes``; it is modified in place.

    Returns:
        The same ``data`` object, with ``data.x`` set to a 1-D ``torch.long``
        tensor of zeros holding one entry per node.
    """
    node_count = data.num_nodes
    placeholder_features = torch.zeros(node_count, dtype=torch.long)
    data.x = placeholder_features
    return data

In [3]:
def save_predictions(predictions, test_path):
    """Write predictions to ``<cwd>/submission/testset_<name>.csv``.

    ``<name>`` is the name of the directory that directly contains
    ``test_path``. The CSV has two columns: ``id`` (0..len(predictions)-1)
    and ``pred`` (the predictions in the order given).

    Args:
        predictions: sized sequence of predicted values, one per test graph.
        test_path: path of the test file; only its parent directory name is used.
    """
    out_dir = os.path.join(os.getcwd(), "submission")
    dataset_tag = os.path.basename(os.path.dirname(test_path))
    os.makedirs(out_dir, exist_ok=True)

    csv_path = os.path.join(out_dir, f"testset_{dataset_tag}.csv")

    # Ids are simply the positions of the predictions.
    frame = pd.DataFrame(
        {"id": list(range(len(predictions))), "pred": predictions}
    )
    frame.to_csv(csv_path, index=False)
    print(f"Predictions saved to {csv_path}")

In [4]:
def plot_training_progress(train_losses, train_accuracies, output_dir):
    """Save a two-panel figure (loss on the left, accuracy on the right).

    The x axis is the 1-based epoch number derived from ``len(train_losses)``.
    The figure is written to ``<output_dir>/training_progress.png`` (the
    directory is created if needed) and then closed.

    Note: the curves are given ``label=`` values but no legend is drawn.

    Args:
        train_losses: one value per epoch, plotted in the left panel.
        train_accuracies: one value per epoch, plotted in the right panel.
        output_dir: directory that receives ``training_progress.png``.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    plt.figure(figsize=(12, 6))

    # (subplot position, series, curve label, colour, y-axis label, title)
    panels = (
        (1, train_losses, "Training Loss", "blue", "Loss", "Training Loss per Epoch"),
        (
            2,
            train_accuracies,
            "Training Accuracy",
            "green",
            "Accuracy",
            "Training Accuracy per Epoch",
        ),
    )
    for position, series, curve_label, colour, y_label, heading in panels:
        plt.subplot(1, 2, position)
        plt.plot(epoch_axis, series, label=curve_label, color=colour)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(heading)

    # Persist the figure, then release it so repeated calls do not pile up figures.
    os.makedirs(output_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_progress.png"))
    plt.close()

In [5]:
class IndexedSubset(Dataset):
    """Dataset wrapper that tags every sample with its position in a ``Subset``.

    The tag is stored on the returned sample as ``idx`` (a 0-d ``torch.long``
    tensor). It is the index used to query this wrapper, not the index in the
    subset's parent dataset.
    """

    def __init__(self, subset: Subset):
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, i):
        sample = self.subset[i]
        # Attach the lookup index so per-sample state can be addressed later.
        sample.idx = torch.tensor(i, dtype=torch.long)  # type: ignore
        return sample

In [6]:
class NCODLoss(nn.Module):
    """Loss with per-sample trainable parameters and class centroids.

    State kept by the module:

    * ``u`` — trainable parameter of shape ``(num_elements, 1)``, one value per
      dataset element, initialised very close to zero.
    * ``past_embeddings`` — buffer ``(num_elements, embedding_dimensions)``
      holding the last L2-normalised embedding seen for every element
      (random until an element has been seen once).
    * ``centroids`` — buffer ``(num_classes, embedding_dimensions)``.
    * ``bins`` — plain Python list (not a registered buffer) with, for every
      class, the tensor of dataset indices carrying that label.

    ``forward`` expects the dataset index of every sample in the batch, so the
    dataset must yield those indices alongside the samples.
    """

    # Declared for type checkers; the tensors themselves are registered as
    # buffers in __init__.
    past_embeddings: torch.Tensor
    centroids: torch.Tensor

    def __init__(
        self,
        dataset: Dataset,
        embedding_dimensions: int = 300,
        total_epochs: int = 150,
        lambda_consistency: float = 1.0,
        device: torch.device | None = None,
    ):
        """
        Args
        ----
        dataset : iterable whose elements expose an integer label in `elem.y`;
            it is iterated once here to collect all labels
        embedding_dimensions : size of the feature vectors
        total_epochs : number of training epochs (used for centroid update schedule)
        lambda_consistency : weight for the MSE consistency term
        device : cuda / cpu device.  If None, picks CUDA if available.
        """
        super().__init__()

        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.embedding_dimensions = embedding_dimensions
        self.total_epochs = total_epochs
        self.lambda_consistency = lambda_consistency

        # Position in this list == dataset index used later in forward().
        labels = [int(elem.y) for elem in dataset]
        self.num_elements = len(labels)
        self.num_classes = max(labels) + 1  # robust to gaps (e.g. labels {1,3})

        # Group dataset indices by label: list-of-lists for easy appends, then
        # one index tensor per class (a class with no samples gets an empty tensor).
        tmp_bins: list[list[int]] = [[] for _ in range(self.num_classes)]
        for idx, lab in enumerate(labels):
            tmp_bins[lab].append(idx)
        self.bins = [
            torch.as_tensor(b, dtype=torch.long, device=self.device) for b in tmp_bins
        ]

        # Per-sample parameter (trainable!). Initialised around 1e-8, so
        # sigmoid(u) starts at roughly 0.5 in forward().
        self.u = nn.Parameter(torch.empty(self.num_elements, 1, device=self.device))
        nn.init.normal_(self.u, mean=1e-8, std=1e-9)

        # Running memory of embeddings (uniform random until overwritten in forward)
        self.register_buffer(
            "past_embeddings",
            torch.rand(
                self.num_elements, self.embedding_dimensions, device=self.device
            ),
        )
        # Class centroids (uniform random until the first forward call)
        self.register_buffer(
            "centroids",
            torch.rand(self.num_classes, self.embedding_dimensions, device=self.device),
        )

    def forward(
        self,
        *,
        logits: torch.Tensor,  # (B, C)
        indexes: torch.Tensor,  # (B,) – dataset indices of the current batch
        embeddings: torch.Tensor,  # (B, D)
        targets: torch.Tensor,  # (B,)
        epoch: int,
    ) -> torch.Tensor:
        """Compute the loss for one batch (all arguments are keyword-only).

        Steps, in order:

        1. L2-normalise ``embeddings`` and store a detached copy in
           ``past_embeddings`` at ``indexes``.
        2. Recompute every class centroid from ``past_embeddings`` (this happens
           on every call, i.e. once per batch).
        3. Build soft labels from embedding/centroid similarity and mix them
           into the model probabilities, scaled by ``sigmoid(u)``.
        4. Return ``hard_ce + soft_ce + lambda_consistency * consistency``.

        Returns
        -------
        Scalar loss tensor.
        """
        # Lower bound applied before taking the log of the mixed probabilities.
        eps = 1e-6

        # Keep an L2-normalised copy of current embeddings
        embeddings = F.normalize(embeddings, dim=1)
        self.past_embeddings[indexes] = embeddings.detach()

        # ---------------- Centroid update ------------------------------------
        if epoch == 0:
            # First epoch: every sample of a class contributes to its centroid.
            with torch.no_grad():
                for c, idxs in enumerate(self.bins):
                    if idxs.numel():
                        self.centroids[c] = self.past_embeddings[idxs].mean(0)
        else:
            # Later epochs: only a fraction of each class contributes. The
            # fraction goes from ~100% down towards 50% as epoch approaches
            # total_epochs (clamped to [1, 100], truncated to an int).
            # Unlike the epoch-0 branch, this one is not wrapped in no_grad().
            percent = int(max(1, min(100, 50 + 50 * (1 - epoch / self.total_epochs))))
            for c, idxs in enumerate(self.bins):
                if idxs.numel() == 0:
                    continue
                # Select the k samples of this class with the SMALLEST raw u.
                # In the hard-CE term below a sample is weighted by
                # (1 - sigmoid(u)), so these are the samples that term weights
                # most heavily.
                k = max(1, idxs.numel() * percent // 100)
                u_batch = self.u[idxs].squeeze(1)
                keep = torch.topk(u_batch, k, largest=False).indices  # (k,)
                selected = idxs[keep]  # (k,)
                self.centroids[c] = self.past_embeddings[selected].mean(0)

        centroids = F.normalize(self.centroids, dim=1)  # (C, D)

        # ---------------- Probability shaping --------------------------------
        # Soft labels: softmax over the dot products between each (normalised)
        # embedding and each (normalised) centroid.
        soft_labels = F.softmax(embeddings @ centroids.T, dim=1)  # (B, C)
        probs = F.softmax(logits, dim=1)  # (B, C)
        u_vals = torch.sigmoid(self.u[indexes]).squeeze(1)  # (B,)

        # Mix model probabilities with u-scaled soft labels, then renormalise
        # each row so it sums to 1 again.
        adjusted = (probs + u_vals[:, None] * soft_labels).clamp(min=eps)
        adjusted = adjusted / adjusted.sum(1, keepdim=True)

        # ---------------- Loss terms -----------------------------------------
        # Cross-entropy against the given targets, down-weighted per sample by
        # sigmoid(u).
        hard_ce = (
            (1.0 - u_vals) * F.cross_entropy(logits, targets, reduction="none")
        ).mean()
        # Cross-entropy of the mixed distribution against the soft labels.
        soft_ce = -(soft_labels * torch.log(adjusted)).sum(1).mean()
        # Pulls the mixed distribution towards the soft labels.
        consistency = F.mse_loss(adjusted, soft_labels)

        return hard_ce + soft_ce + self.lambda_consistency * consistency


In [7]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    """Mean cross-entropy scaled by the constant factor ``1 - p_noisy``.

    The earlier formulation built a "per-sample" weight
    ``(1 - p) + p * (1 - one_hot(targets).sum(dim=1))``. A one-hot row always
    sums to exactly 1, so that weight was ``1 - p`` for every sample. This
    version applies the same constant directly and skips the one-hot tensor;
    the returned value is unchanged.

    NOTE(review): because the weight does not depend on the sample, this loss
    only rescales plain cross-entropy. If genuine per-sample noise handling was
    intended, the weighting scheme needs to be redesigned.

    Args:
        p_noisy: scaling parameter ``p``; presumably the assumed fraction of
            noisy labels — confirm with the caller.
    """

    def __init__(self, p_noisy):
        super().__init__()
        self.p = p_noisy
        # Per-sample losses are needed so the weight can be applied before averaging.
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, logits, targets):
        """Return the scalar ``mean((1 - p) * CE(logits, targets))``.

        Args:
            logits: ``(B, C)`` unnormalised class scores.
            targets: ``(B,)`` integer class indices.
        """
        losses = self.ce(logits, targets)
        # Same multiplication order as before: weight each sample, then average.
        return (losses * (1.0 - self.p)).mean()

In [8]:
class SCELoss(torch.nn.Module):
    """Weighted sum of cross-entropy and a reverse-style penalty term.

    Per sample the loss is ``alpha * CE + beta * R`` where ``R`` sums
    ``-log(1 - p_c)`` over every class ``c`` other than the target, with the
    softmax probabilities clamped to ``[1e-6, 1 - 1e-6]``. The result is
    averaged over the batch.

    Args:
        num_classes: width used for the one-hot encoding of the targets.
        alpha: weight of the cross-entropy term.
        beta: weight of the reverse term.
    """

    def __init__(self, num_classes: int = 6, alpha: float = 0.1, beta: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes

    def forward(self, logits, targets):
        """Return the scalar batch-mean loss for ``(B, C)`` logits and ``(B,)`` targets."""
        # Standard term
        per_sample_ce = F.cross_entropy(logits, targets, reduction="none")

        # Reverse term: penalise probability mass placed on non-target classes
        clipped_probs = F.softmax(logits, dim=1).clamp(min=1e-6, max=1 - 1e-6)
        target_mask = F.one_hot(targets, self.num_classes).float()
        reverse_term = (-(1 - target_mask) * torch.log(1 - clipped_probs)).sum(dim=1)

        combined = self.alpha * per_sample_ce + self.beta * reverse_term
        return combined.mean()

In [9]:
def train(
    data_loader: DataLoader,
    model: GNN,
    optimizer_theta: torch.optim.Optimizer,
    optimizer_u: torch.optim.Optimizer | None,
    criterion: nn.Module,
    device: torch.device,
    checkpoint_path: str,
    current_epoch: int,
    save_checkpoints: bool = True,
):
    """Run one training epoch and report sample-averaged metrics.

    The model is driven piecewise (``gnn_node`` → ``pool`` →
    ``graph_pred_linear``) so the pooled graph embedding is available to the
    loss. An ``NCODLoss`` criterion additionally receives the batch indices,
    the embeddings and the epoch; any other criterion is called as
    ``criterion(logits, targets)``.

    Args:
        data_loader: yields batches exposing ``y``, ``batch`` and (for NCOD) ``idx``.
        model: network exposing ``gnn_node``, ``pool`` and ``graph_pred_linear``.
        optimizer_theta: optimizer stepped on every batch.
        optimizer_u: optional second optimizer, zeroed and stepped right after
            ``optimizer_theta``.
        criterion: training loss.
        device: device batches are moved to.
        checkpoint_path: prefix of the checkpoint file name.
        current_epoch: zero-based epoch; the checkpoint is named with ``epoch + 1``.
        save_checkpoints: when true, save ``model.state_dict()`` after the epoch.

    Returns:
        ``(loss, confidence, accuracy, entropy)``, each averaged over all samples.
    """
    model.train()

    # Fixed handling order: model optimizer first, then the optional one.
    optimizers = [optimizer_theta]
    if optimizer_u is not None:
        optimizers.append(optimizer_u)

    loss_sum = 0.0
    confidence_sum = 0.0
    entropy_sum = 0.0
    hits = 0
    seen = 0

    for batch in data_loader:
        batch = batch.to(device)

        graph_emb = model.pool(model.gnn_node(batch), batch.batch)
        logits = model.graph_pred_linear(graph_emb)

        if not isinstance(criterion, NCODLoss):
            loss = criterion(logits, batch.y)
        else:
            loss = criterion(
                logits=logits,
                indexes=batch.idx.to(device),
                embeddings=graph_emb,
                targets=batch.y.to(device),
                epoch=current_epoch,
            )

        for opt in optimizers:
            opt.zero_grad(set_to_none=True)
        loss.backward()
        for opt in optimizers:
            opt.step()

        # Bookkeeping only; nothing here should enter the autograd graph.
        with torch.no_grad():
            probs = F.softmax(logits, dim=1)
            n_graphs = batch.y.size(0)

            loss_sum += loss.item() * n_graphs
            hits += (probs.argmax(dim=1) == batch.y).sum().item()
            seen += n_graphs

            confidence_sum += probs.max(dim=1).values.sum().item()
            per_graph_entropy = -torch.sum(probs * torch.log(probs + 1e-10), 1)
            entropy_sum += per_graph_entropy.sum().item()

    if save_checkpoints:
        ckpt = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), ckpt)
        print(f"[checkpoint] saved: {ckpt}")

    return (
        loss_sum / seen,
        confidence_sum / seen,
        hits / seen,
        entropy_sum / seen,
    )


In [10]:
def evaluate(
    data_loader: DataLoader,
    model: GNN,
    criterion: nn.CrossEntropyLoss,
    device: torch.device,
) -> tuple[float, float, float, float]:
    """Score the model on a loader without updating it.

    The model is put in eval mode and run under ``torch.no_grad()``. The
    criterion's batch value is weighted by the batch size, so the reported loss
    is a per-sample average.

    Returns
    -------
    avg_loss, avg_confidence, accuracy, avg_entropy
    """
    model.eval()

    loss_sum = 0.0
    confidence_sum = 0.0
    entropy_sum = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)

            graph_emb = model.pool(model.gnn_node(batch), batch.batch)
            logits = model.graph_pred_linear(graph_emb)
            probs = F.softmax(logits, dim=1)

            batch_loss = criterion(logits, batch.y)

            n_graphs = batch.y.size(0)
            loss_sum += batch_loss.item() * n_graphs

            # Confidence = probability of the predicted class.
            confidence_sum += probs.max(dim=1).values.sum().item()
            per_graph_entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)
            entropy_sum += per_graph_entropy.sum().item()

            hits += (probs.argmax(dim=1) == batch.y).sum().item()
            seen += n_graphs

    return (
        loss_sum / seen,
        confidence_sum / seen,
        hits / seen,
        entropy_sum / seen,
    )


In [11]:
def objective(
    trial,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_checkpoints: int,
    checkpoints_dir: str,
    run_name: str,
    best_model_path: str,
    logs_dir: str,
    resume_training: bool,
    *,
    score_type: Literal["loss", "accuracy", "entropy", "confidence"] = "confidence",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):  # -> float | Any:
    """Optuna objective: train one sampled configuration and return its best validation score.

    Sampled per trial: GNN type (with or without virtual node), loss type,
    graph pooling and dropout. Depth, embedding size and epoch count are fixed.

    The score follows the study direction used by ``case_study``: ``"loss"``
    and ``"entropy"`` are minimised, every other ``score_type`` is maximised.
    "Best" epoch selection and the value returned for a failed trial both
    respect that direction, so a crashed trial can never look like the best one.

    NOTE(review): ``best_model_path`` is shared by all trials and each trial
    restarts its own best-score tracking, so the file holds the best epoch of
    the most recent trial, not the best model across the whole study.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        train_loader: training batches; its ``dataset`` is also handed to ``NCODLoss``.
        val_loader: validation batches.
        num_checkpoints: number of evenly spaced epoch checkpoints to save.
        checkpoints_dir: directory receiving the epoch checkpoints.
        run_name: tag used in checkpoint names and log lines.
        best_model_path: file the best-epoch weights are written to (and read
            from when ``resume_training`` is true).
        logs_dir: directory receiving the train/val plots.
        resume_training: load ``best_model_path`` into the fresh model first.
        score_type: metric used for model selection and returned to Optuna.
        device: device used for the model, the loss and the batches.

    Returns:
        The best validation score over all epochs, or the worst possible value
        for the study direction (``+inf`` when minimising, ``-inf`` when
        maximising) if the trial raised.
    """
    # Must mirror the direction chosen when the study is created.
    minimize = score_type in ("loss", "entropy")
    worst_score = float("inf") if minimize else -float("inf")

    try:
        logging.info("#" * 80)
        # Hyperparameter search space
        logging.info("Start case study with parameters:")
        gnn_type = trial.suggest_categorical(
            "gnn_type", ["gin", "gin-virtual", "gcn", "gcn-virtual"]
        )
        loss_type = trial.suggest_categorical(
            "loss_type", ["ce", "ncod", "noisy_ce", "sce"]
        )
        graph_pooling = trial.suggest_categorical(
            "graph_pooling", ["sum", "mean", "max", "attention", "set2set"]
        )
        drop_ratio = trial.suggest_float("dropout", 0.0, 0.7)
        # Currently fixed rather than searched.
        num_layers = 6
        embedding_dim = 300
        num_epochs = 30

        print(f"{gnn_type=}\n{loss_type=}\n{graph_pooling=}\n{drop_ratio=}\n{num_layers=}\n{embedding_dim=}\n{num_epochs=}")
        logging.info(f"{gnn_type=} {loss_type=} {graph_pooling=} {drop_ratio=} {num_layers=} {embedding_dim=} {num_epochs=}")

        # Initialize model
        model = GNN(
            gnn_type="gin" if "gin" in gnn_type else "gcn",
            num_class=6,
            num_layer=num_layers,
            emb_dim=embedding_dim,
            drop_ratio=drop_ratio,
            virtual_node="virtual" in gnn_type,
            graph_pooling=graph_pooling,
        ).to(device)
        if resume_training:
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        optimizer_theta = torch.optim.Adam(
            model.parameters(), lr=3e-4, weight_decay=1e-5
        )

        # Only NCOD has trainable loss parameters, hence a second optimizer.
        optimizer_u = None
        if loss_type == "ce":
            train_criterion = nn.CrossEntropyLoss()
        elif loss_type == "ncod":
            train_criterion = NCODLoss(
                train_loader.dataset,
                embedding_dimensions=embedding_dim,
                total_epochs=num_epochs,
                device=device,
            )
            optimizer_u = torch.optim.SGD(train_criterion.parameters(), lr=1e-3)
        elif loss_type == "sce":
            train_criterion = SCELoss()
        else:
            train_criterion = NoisyCrossEntropyLoss(0.2)

        # Validation is always scored with plain cross-entropy so that trials
        # using different training losses stay comparable.
        val_criterion = nn.CrossEntropyLoss()

        # 1-based epochs after which an epoch checkpoint is written
        checkpoint_epochs = [
            int((i + 1) * num_epochs / num_checkpoints) for i in range(num_checkpoints)
        ]

        best_val_score = worst_score
        train_losses, train_scores = [], []
        val_losses, val_scores = [], []

        progress_bar = tqdm(range(num_epochs), desc="Training...", leave=False)
        for epoch in progress_bar:
            train_loss, train_conf, train_acc, train_entropy = train(
                train_loader,
                model,
                optimizer_theta,
                optimizer_u,
                train_criterion,
                device,
                save_checkpoints=(epoch + 1 in checkpoint_epochs),
                checkpoint_path=os.path.join(checkpoints_dir, f"model_{run_name}"),
                current_epoch=epoch,
            )

            val_loss, val_conf, val_acc, val_entropy = evaluate(
                val_loader,
                model,
                val_criterion,
                device,
            )

            tqdm.write(
                f"Epoch {epoch}/{num_epochs}\n"
                f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, Confidence: {train_conf:.4f}, Entropy: {train_entropy:.4f}\n"
                f"Val Loss  : {val_loss:.4f}, Accuracy: {val_acc:.4f}, Confidence: {val_conf:.4f}, Entropy: {val_entropy:.4f}\n"
            )
            logging.info(
                f"Epoch {epoch}/{num_epochs}| "
                f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, Confidence: {train_conf:.4f}, Entropy: {train_entropy:.4f}| "
                f"Val Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, Confidence: {val_conf:.4f}, Entropy: {val_entropy:.4f}"
            )

            train_losses.append(train_loss)
            val_losses.append(val_loss)

            # Pick the metric that drives model selection; accuracy is the fallback.
            train_score = {
                "loss": train_loss,
                "entropy": train_entropy,
                "confidence": train_conf,
            }.get(score_type, train_acc)
            val_score = {
                "loss": val_loss,
                "entropy": val_entropy,
                "confidence": val_conf,
            }.get(score_type, val_acc)

            train_scores.append(train_score)
            val_scores.append(val_score)

            # Lower is better for loss/entropy, higher is better otherwise.
            improved = (
                val_score < best_val_score if minimize else val_score > best_val_score
            )
            if improved:
                best_val_score = val_score
                torch.save(model.state_dict(), best_model_path)
                logging.info(f"[{run_name}] Best model updated at {best_model_path}")

        # Plot training curves (second panel shows the selected score metric)
        plot_training_progress(
            train_losses, train_scores, os.path.join(logs_dir, "train_plots")
        )
        plot_training_progress(
            val_losses, val_scores, os.path.join(logs_dir, "val_plots")
        )

        logging.info(f"Case study end, {best_val_score}")
        logging.info("#" * 80)
        logging.info("\n")

        return best_val_score
    except Exception as e:
        # Best effort: one broken configuration must not abort the whole study,
        # but keep the traceback in the log and report the worst possible score.
        print("Unhandled exception ", e)
        logging.exception(f"[{run_name}] Trial aborted by an unhandled exception")
        return worst_score

In [None]:
def case_study(
    version_number,
    resume_training: bool,
    n_trials: int = 30,
    num_checkpoints: int = 10,
    default_batch_size: int = 32,
    score_type: Literal["loss", "accuracy", "entropy", "confidence"] = "confidence",
    full_dataset=None,
):
    """Run an Optuna study for one dataset version and save a summary CSV.

    Creates ``logs/<run>`` and ``checkpoints/<run>`` under the current working
    directory, splits the dataset 80/20 with a fixed generator seed, optimises
    ``objective`` for ``n_trials`` trials and writes every completed trial
    (score and parameters) to ``logs/<run>/optuna_summary_<version>.csv``.

    Args:
        version_number: dataset version; used in the dataset path and run name.
        resume_training: forwarded to ``objective``.
        n_trials: number of Optuna trials.
        num_checkpoints: forwarded to ``objective``.
        default_batch_size: batch size of both loaders.
        score_type: metric to optimise; ``"loss"`` and ``"entropy"`` are
            minimised, anything else is maximised.
        full_dataset: pre-loaded dataset; when ``None`` the dataset for
            ``version_number`` is loaded from ``./datasets/filtered_aggregated``.
    """
    script_root = os.getcwd()
    train_path = (
        f"./datasets/filtered_aggregated/filtered_aggregated_{version_number}.json.gz"
    )
    run_name = f"filtered_aggregated_{version_number}"
    minimize = score_type == "loss" or score_type == "entropy"

    logs_dir = os.path.join(script_root, "logs", run_name)
    os.makedirs(logs_dir, exist_ok=True)
    # force=True: without it basicConfig is a no-op once the root logger has
    # handlers, so a second case_study() call in the same kernel would keep
    # writing into the first run's log file.
    logging.basicConfig(
        filename=os.path.join(logs_dir, "training.log"),
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        filemode="w",
        force=True,
    )

    checkpoints_dir = os.path.join(script_root, "checkpoints", run_name)
    best_model_path = os.path.join(
        checkpoints_dir, f"baseline_{version_number}_best.pth"
    )
    summary_csv_path = os.path.join(logs_dir, f"optuna_summary_{version_number}.csv")
    os.makedirs(checkpoints_dir, exist_ok=True)
    if full_dataset is None:
        full_dataset = GraphDataset(train_path, transform=add_zeros)

    # Fixed-seed 80/20 split so every run sees the same partition.
    val_size = int(0.2 * len(full_dataset))
    train_size = len(full_dataset) - val_size
    generator = torch.Generator().manual_seed(12)
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], generator=generator
    )
    # Samples must carry their subset index for losses that keep per-sample state.
    train_dataset = IndexedSubset(train_dataset)
    val_dataset = IndexedSubset(val_dataset)

    train_loader = DataLoader(
        train_dataset,  # type:ignore
        batch_size=default_batch_size,
        shuffle=True,
    )

    val_loader = DataLoader(
        val_dataset,  # type:ignore
        batch_size=default_batch_size,
        shuffle=False,
    )

    print(f"Starting Optuna optimization for dataset {version_number}")

    study = optuna.create_study(
        study_name=run_name,
        direction="minimize" if minimize else "maximize",
    )

    obj = partial(
        objective,
        train_loader=train_loader,
        val_loader=val_loader,
        num_checkpoints=num_checkpoints,
        checkpoints_dir=checkpoints_dir,
        run_name=run_name,
        best_model_path=best_model_path,
        logs_dir=logs_dir,
        resume_training=resume_training,
        score_type=score_type,
    )
    study.optimize(obj, n_trials=n_trials)

    # One row per completed trial: the score plus the sampled parameters.
    all_trials = [
        {score_type: trial.value, **trial.params}
        for trial in study.trials
        if trial.state == optuna.trial.TrialState.COMPLETE
    ]

    results_df = pd.DataFrame(all_trials)
    results_df.to_csv(summary_csv_path, index=False)

    print(f"\nAll trials saved to: {summary_csv_path}")
    if results_df.empty:
        # Nothing to rank; sort_values and study.best_params would both raise here.
        print(f"\nNo completed trials for dataset {version_number}.")
    else:
        print(f"\nBest result for dataset {version_number}:")
        display(results_df.sort_values(score_type, ascending=minimize))
        print(f"\nBest Params for baseline {version_number}:")
        for k, v in study.best_params.items():
            print(f"  {k}: {v}")

    # Drop local references and collect so memory is released between runs.
    del train_loader, val_loader, full_dataset, train_dataset, val_dataset
    gc.collect()


In [13]:
# Load the dataset that the cells below iterate over; add_zeros is applied as
# the transform so each graph receives placeholder node features.
# NOTE(review): the variable is called `filtered_aggregated` but the path points
# at the merged v2 dataset — confirm which dataset is intended.
# NOTE(review): load_in_memory=False presumably avoids materialising the whole
# file in RAM; its exact semantics live in src.loadData — verify there.
filtered_aggregated = GraphDataset(
    "./datasets/merged/merged_dataset_v2.json.gz", transform=add_zeros,load_in_memory=False
)

KeyboardInterrupt: 

In [None]:
# Smoke test: batch the dataset and move every batch to the compute device.
# Use CUDA when present and fall back to CPU otherwise, matching the device
# selection used by the training code, so this cell also runs on CPU-only hosts.
smoke_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = DataLoader(filtered_aggregated, batch_size=64)
for data in dataloader:
    print(data.to(smoke_device))
    

In [None]:
# case_study(1, False, n_trials=50,score_type="accuracy", full_dataset=filtered_aggregated)