In [1]:
import os
import torch
from torch import nn
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm.notebook import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
from functools import partial
import optuna
import gc
from typing import Literal
import torch.nn.functional as F
import math
from torch_geometric.data import Batch
from torch.utils.data import Subset, Dataset
from torch_geometric.loader import DataLoader

# Load utility functions from cloned repository
from src.loadData import GraphDataset
from src.utils import set_seed
from src.models import GNN


# Set the random seed once, up front, via the project helper (called with its
# default arguments). Which RNGs it seeds is defined in src.utils — presumably
# torch (and possibly numpy / random); TODO confirm.
set_seed()


In [6]:
def add_zeros(data):
    """Dataset transform: set ``data.x`` to an all-zero ``torch.long`` vector.

    The vector has one entry per node (``data.num_nodes``), so every node gets
    the same categorical id 0 — presumably because the graphs ship without node
    features (TODO confirm against the dataset loader). The input object is
    mutated and returned.
    """
    node_count = data.num_nodes
    data.x = torch.zeros(node_count, dtype=torch.long)
    return data

In [5]:
def save_predictions(predictions, test_path):
    """Write ``predictions`` to ``<cwd>/submission/testset_<dir>.csv``.

    ``<dir>`` is the name of the directory that contains ``test_path``
    (e.g. ``A`` for ``./datasets/A/test.json.gz``). Rows get sequential ids
    ``0 .. len(predictions) - 1`` in the order the predictions are given.
    """
    out_dir = os.path.join(os.getcwd(), "submission")
    parent_name = os.path.basename(os.path.dirname(test_path))
    os.makedirs(out_dir, exist_ok=True)

    csv_path = os.path.join(out_dir, f"testset_{parent_name}.csv")
    frame = pd.DataFrame(
        {"id": list(range(len(predictions))), "pred": predictions}
    )
    frame.to_csv(csv_path, index=False)
    print(f"Predictions saved to {csv_path}")

In [4]:
def plot_training_progress(
    train_losses,
    train_accuracies,
    output_dir,
    *,
    split_name: str = "Training",
    metric_name: str = "Accuracy",
    filename: str = "training_progress.png",
):
    """Save a two-panel per-epoch figure (loss | metric) to ``output_dir/filename``.

    Args:
        train_losses: per-epoch loss values (left panel).
        train_accuracies: per-epoch metric values (right panel). Despite the
            name, any scalar score can be plotted; label it via ``metric_name``.
        output_dir: directory for the PNG; created if missing.
        split_name: prefix used in titles and legend entries (e.g. "Validation").
        metric_name: name of the right-panel metric (e.g. "Confidence").
        filename: output file name inside ``output_dir``.

    The defaults reproduce the historical "Training Loss" / "Training Accuracy"
    titles and the ``training_progress.png`` file name.
    """
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        loss_epochs = range(1, len(train_losses) + 1)
        loss_ax.plot(loss_epochs, train_losses, label=f"{split_name} Loss", color="blue")
        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.set_title(f"{split_name} Loss per Epoch")
        loss_ax.legend()  # the curve has a label; without this it is never shown

        # Own x-range so a metric list of a different length cannot break the plot.
        metric_epochs = range(1, len(train_accuracies) + 1)
        metric_ax.plot(
            metric_epochs,
            train_accuracies,
            label=f"{split_name} {metric_name}",
            color="green",
        )
        metric_ax.set_xlabel("Epoch")
        metric_ax.set_ylabel(metric_name)
        metric_ax.set_title(f"{split_name} {metric_name} per Epoch")
        metric_ax.legend()

        os.makedirs(output_dir, exist_ok=True)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, filename))
    finally:
        # Always release the figure, even if saving fails, to avoid leaking
        # figures across many Optuna trials.
        plt.close(fig)

In [7]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    """Mean cross-entropy with a per-sample weight derived from ``p_noisy``.

    NOTE(review): the per-sample "noise" term is ``1 - sum(one_hot(targets))``.
    Every one-hot row sums to exactly 1, so that term is always 0 and the weight
    collapses to the constant ``1 - p_noisy``. In effect this is plain mean
    cross-entropy scaled by ``(1 - p_noisy)``. Kept unchanged to preserve
    behavior — TODO decide what weighting was actually intended.
    """

    def __init__(self, p_noisy):
        super().__init__()
        self.p = p_noisy  # presumably an assumed noise rate in [0, 1] — confirm with callers
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, logits, targets):
        per_sample = self.ce(logits, targets)
        class_count = logits.size(1)
        encoded = F.one_hot(targets, num_classes=class_count).float()
        target_mass = encoded.sum(dim=1)
        noise_term = 1 - target_mass
        weight = (1 - self.p) + self.p * noise_term
        return (per_sample * weight).mean()

In [8]:
class SCELoss(torch.nn.Module):
    def __init__(self, num_classes: int = 6, alpha: float = 0.1, beta: float = 1.0):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.num_classes = num_classes

    def forward(self, logits, targets):
        # CCE
        ce = F.cross_entropy(logits, targets, reduction="none")

        # RCE
        pred = F.softmax(logits, dim=1).clamp(min=1e-6, max=1 - 1e-6)
        one_hot = F.one_hot(targets, self.num_classes).float()
        rce = -(1 - one_hot) * torch.log(1 - pred)
        rce = rce.sum(dim=1)
        return (self.alpha * ce + self.beta * rce).mean()

In [9]:
class NCODLoss(torch.nn.Module):
    def __init__(
        self,
        dataset,
        embedding_dimensions: int = 300,
        total_epochs: int = 150,
        lambda_consistency: float = 1.0,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.device = device
        self.embedding_dimensions = embedding_dimensions
        self.total_epochs = total_epochs
        self.lambda_consistency = lambda_consistency

        label_counts = {}
        for elem in dataset:
            y = int(elem.y)
            label_counts[y] = label_counts.get(y, 0) + 1

        self.num_elements = len(dataset)
        self.num_classes = len(label_counts)

        self.u = nn.Parameter(torch.empty(self.num_elements, 1, device=device))
        torch.nn.init.normal_(self.u, mean=1e-8, std=1e-9)

        self.past_embeddings = torch.rand(
            (self.num_elements, embedding_dimensions), device=device
        )
        self.centroids = torch.rand(
            (self.num_classes, embedding_dimensions), device=device
        )

        self.bins = [[] for _ in range(self.num_classes)]
        for idx, d in enumerate(dataset):
            self.bins[int(d.y)].append(idx)

    def forward(
        self,
        logits: torch.Tensor,     
        indexes: torch.Tensor,      
        embeddings: torch.Tensor,   
        targets: torch.Tensor,      
        epoch: int,
    ):
        eps = 1e-6
        device = logits.device

        embeddings = F.normalize(embeddings, dim=1)
        self.past_embeddings[indexes] = embeddings.detach()

        if epoch == 0:
            with torch.no_grad():
                for c in range(self.num_classes):
                    idxs = self.bins[c]
                    if idxs:
                        self.centroids[c] = self.past_embeddings[idxs].mean(0)
        else:
            percent = math.ceil((50 - (50 / self.total_epochs) * epoch) + 50)
            for c in range(self.num_classes):
                idxs = self.bins[c]
                if not idxs:
                    continue
                u_vals = self.u[idxs]
                k = int(len(u_vals) * percent / 100)
                keep = torch.topk(u_vals.squeeze(), k, largest=False).indices
                self.centroids[c] = self.past_embeddings[[idxs[i] for i in keep]].mean(0)

        cent = F.normalize(self.centroids, dim=1)
        cosine_sim = embeddings @ cent.T                      
        soft_labels = F.softmax(cosine_sim, dim=1)           

        # ---- prediction adjustment ---------------------------------
        probs = F.softmax(logits, dim=1)                      # [B, C]
        u_vals = torch.sigmoid(self.u[indexes]).squeeze(1)    # (0,1)
        adjusted_probs = probs + u_vals.unsqueeze(1) * soft_labels
        adjusted_probs = adjusted_probs.clamp(min=eps)
        adjusted_probs = adjusted_probs / adjusted_probs.sum(1, keepdim=True)

        # ---- losses -------------------------------------------------
        # 1. discounted hard CE
        hard_ce = F.cross_entropy(logits, targets, reduction='none')
        hard_ce = ((1.0 - u_vals) * hard_ce).mean()

        # 2. soft-label CE (only on the target column, as in the paper)
        tgt_onehot = F.one_hot(targets, num_classes=self.num_classes).float()
        soft_loss = -torch.sum(
            (tgt_onehot * soft_labels) * torch.log(adjusted_probs),
            dim=1
        ).mean()

        # 3. consistency
        loss_cons = F.mse_loss(adjusted_probs, soft_labels)

        return hard_ce + soft_loss + self.lambda_consistency * loss_cons


In [None]:
def train(
    data_loader,
    model,
    optimizer_theta,
    optimizer_u,
    criterion,          # may be NCODLoss
    device,
    save_checkpoints,
    checkpoint_path,
    current_epoch=0,
):
    """Run one training epoch over ``data_loader``.

    ``model`` must expose ``gnn_node``, ``pool`` and ``graph_pred_linear``; the
    pooled graph embedding is what ``NCODLoss`` consumes. When ``criterion`` is an
    ``NCODLoss``, every batch must carry an ``idx`` attribute (dataset positions)
    used to address the loss's per-sample state.

    Args:
        optimizer_theta: optimizer over the model parameters.
        optimizer_u: optimizer over the criterion's parameters (NCOD ``u``);
            may be ``None`` when the criterion has nothing to train.
        save_checkpoints: if true, save ``model.state_dict()`` to
            ``f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"`` after the epoch.
        current_epoch: 0-based epoch index, forwarded to ``NCODLoss``.

    Returns:
        ``(avg_loss, avg_conf, accuracy, avg_entropy)``: mean per-batch loss,
        mean max-softmax probability, accuracy, and mean predictive entropy
        per sample.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = total_conf = total_entropy = 0.0
    correct = total = num_batches = 0

    for data in data_loader:
        data = data.to(device)

        node_emb = model.gnn_node(data)
        graph_emb = model.pool(node_emb, data.batch)
        logits = model.graph_pred_linear(graph_emb)

        optimizer_theta.zero_grad()
        if optimizer_u is not None:
            optimizer_u.zero_grad()

        if isinstance(criterion, NCODLoss):
            loss = criterion(
                logits=logits,
                indexes=data.idx.to(device),
                embeddings=graph_emb,
                targets=data.y.to(device),
                epoch=current_epoch,
            )
        else:
            loss = criterion(logits, data.y)

        loss.backward()
        optimizer_theta.step()
        if optimizer_u is not None:
            optimizer_u.step()

        # Metrics are bookkeeping only: keep them out of the autograd graph.
        with torch.no_grad():
            probs = F.softmax(logits, dim=1)
            pred = probs.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.y.size(0)
            total_conf += probs.max(1).values.sum().item()
            total_entropy += (
                (-torch.sum(probs * torch.log(probs + 1e-10), 1)).sum().item()
            )
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0 or total == 0:
        raise ValueError("train() received an empty data_loader")

    if save_checkpoints:
        ckpt_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), ckpt_file)
        print(f"[checkpoint] saved: {ckpt_file}")

    avg_loss = total_loss / num_batches
    avg_conf = total_conf / total
    avg_entropy = total_entropy / total
    accuracy = correct / total
    return avg_loss, avg_conf, accuracy, avg_entropy


In [11]:
def evaluate(
    data_loader,
    model,
    criterion,
    device,
):
    """Evaluate ``model`` on ``data_loader`` without updating any state.

    Returns:
        ``(avg_loss, avg_conf, accuracy, avg_entropy)``: mean per-batch loss,
        mean max-softmax probability, accuracy, and mean predictive entropy
        per sample.

    When ``criterion`` is an ``NCODLoss`` the reported loss is plain
    cross-entropy. ``NCODLoss.forward`` requires ``targets`` and training-set
    indexes, and it overwrites its embedding bank and centroids. Calling it on
    validation batches (whose ``idx`` values are positions in the validation
    split in this notebook) would both fail and corrupt the training-time state.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    total_loss = total_confidence = total_entropy = 0.0
    correct = total = num_batches = 0

    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)

            # Forward pass
            node_emb = model.gnn_node(data)
            graph_emb = model.pool(node_emb, data.batch)
            logits = model.graph_pred_linear(graph_emb)
            probs = F.softmax(logits, dim=1)

            if isinstance(criterion, NCODLoss):
                # Side-effect-free validation loss (see docstring).
                loss = F.cross_entropy(logits, data.y)
            else:
                loss = criterion(logits, data.y)

            total_loss += loss.item()
            num_batches += 1

            # Metrics
            pred = probs.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.y.size(0)
            total_confidence += probs.max(dim=1).values.sum().item()
            total_entropy += (
                (-torch.sum(probs * torch.log(probs + 1e-10), dim=1)).sum().item()
            )

    if num_batches == 0 or total == 0:
        raise ValueError("evaluate() received an empty data_loader")

    avg_loss = total_loss / num_batches
    avg_conf = total_confidence / total
    avg_entropy = total_entropy / total
    accuracy = correct / total

    return (
        avg_loss,
        avg_conf,
        accuracy,
        avg_entropy,
    )


In [10]:
# def predict(data_loader, model, device, criterion=None):
#     model.eval()
#     correct = 0
#     total = 0
#     total_loss = 0.0
#     total_confidence = 0.0
#     total_entropy = 0.0
#     predictions = []

#     use_ce_loss = isinstance(criterion, torch.nn.CrossEntropyLoss)

#     with torch.no_grad():
#         for data in data_loader:
#             data = data.to(device)

#             # Forward
#             node_emb = model.gnn_node(data)
#             graph_emb = model.pool(node_emb, data.batch)
#             logits = model.graph_pred_linear(graph_emb)
#             probs = F.softmax(logits, dim=1)

#             pred = probs.argmax(dim=1)
#             predictions.extend(pred.cpu().numpy())

#             if calculate_accuracy:
#                 correct += (pred == data.y).sum().item()
#                 total += data.y.size(0)

#                 if use_ce_loss:
#                     total_loss += criterion(logits, data.y).item()

#             # Confidence = max prob
#             total_confidence += probs.max(dim=1).values.sum().item()

#             # Entropy = -∑p log p
#             entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)  # [B]
#             total_entropy += entropy.sum().item()

#     if calculate_accuracy:
#         accuracy = correct / total if total > 0 else 0.0
#         avg_loss = total_loss / len(data_loader) if use_ce_loss else None
#         avg_conf = total_confidence / total
#         avg_entropy = total_entropy / total

#         return {
#             "accuracy": accuracy,
#             "cross_entropy_loss": avg_loss,
#             "avg_confidence": avg_conf,
#             "avg_entropy": avg_entropy,
#         }

#     return predictions


In [12]:
def objective(
    trial,
    train_loader,
    val_loader,
    num_checkpoints,
    checkpoints_dir,
    run_name,
    best_model_path,
    logs_dir,
    *,
    score_type: Literal["loss", "accuracy", "entropy", "confidence"] = "confidence",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> float:
    """Train one GNN configuration and return its best validation score.

    The hyper-parameters are currently fixed (the ``trial.suggest_*`` calls are
    commented out), so every trial trains the same architecture. The model
    state with the best validation ``score_type`` is saved to
    ``best_model_path``. Periodic checkpoints go to ``checkpoints_dir`` and
    loss/score plots go under ``logs_dir``.

    ``score_type == "loss"`` is treated as lower-is-better; every other score
    as higher-is-better. This mirrors the study direction chosen in
    ``case_study``.

    Returns:
        The best validation score observed across all epochs.
    """
    logging.info("#" * 80)
    # Hyperparameter search space
    logging.info("Start case study with parameters:")
    gnn_type = "gin-virtual"
    # trial.suggest_categorical(
    #     "gnn_type", ["gin", ]
    # )
    drop_ratio = 0.0  # trial.suggest_float("dropout", 0.0, 0.7)
    num_layers = 5  # trial.suggest_int("num_layers", 6, 12)
    embedding_dim = 300  # trial.suggest_categorical("embedding_dim", [64, 128, 300])
    num_epochs = 60  # trial.suggest_int("num_epochs", 10, 11, step=1)

    logging.info(f"{gnn_type=}")
    logging.info(f"{drop_ratio=}")
    logging.info(f"{num_layers=}")
    logging.info(f"{embedding_dim=}")
    logging.info(f"{num_epochs=}")

    # Initialize model
    model = GNN(
        gnn_type="gin" if "gin" in gnn_type else "gcn",
        num_class=6,
        num_layer=num_layers,
        emb_dim=embedding_dim,
        drop_ratio=drop_ratio,
        virtual_node="virtual" in gnn_type,
    ).to(device)
    # NCOD by default — switch to CE for testing if needed
    criterion = NCODLoss(
        train_loader.dataset,
        embedding_dimensions=embedding_dim,
        total_epochs=num_epochs,
    ).to(device)
    optimizer_theta = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)
    optimizer_u = torch.optim.SGD(criterion.parameters(), lr=1e-3)

    # Checkpoint logic: evenly spaced 1-based epochs at which to save.
    checkpoint_epochs = [
        int((i + 1) * num_epochs / num_checkpoints) for i in range(num_checkpoints)
    ]

    # Loss must be minimised; the previous `>` against -inf kept the WORST
    # model whenever score_type == "loss".
    minimize = score_type == "loss"
    best_val_score = float("inf") if minimize else -float("inf")

    train_losses, train_confs, train_accs, train_entropies, train_scores = [], [], [], [], []
    val_losses, val_confs, val_accs, val_entropies, val_scores = [], [], [], [], []

    progress_bar = tqdm(range(num_epochs), desc="Training...", leave=False)
    for epoch in progress_bar:
        train_loss, train_conf, train_acc, train_entropy = train(
            train_loader,
            model,
            optimizer_theta,
            optimizer_u,
            criterion,
            device,
            save_checkpoints=(epoch + 1 in checkpoint_epochs),
            checkpoint_path=os.path.join(checkpoints_dir, f"model_{run_name}"),
            current_epoch=epoch,
        )

        val_loss, val_conf, val_acc, val_entropy = evaluate(
            val_loader,
            model,
            criterion,
            device,
        )

        tqdm.write(
            f"Epoch {epoch}/{num_epochs}\n"
            f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Conf: {train_conf:.4f}, Entropy: {train_entropy:.4f}\n"
            f"Val Loss  : {val_loss:.4f}, Acc: {val_acc:.4f}, Conf: {val_conf:.4f}, Entropy: {val_entropy:.4f}\n"
        )
        logging.info(
            f"Epoch {epoch}/{num_epochs}\n"
            f"\tTrain Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Conf: {train_conf:.4f}, Entropy: {train_entropy:.4f}\n"
            f"\tVal Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Conf: {val_conf:.4f}, Entropy: {val_entropy:.4f}\n"
        )

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_confs.append(train_conf)
        train_entropies.append(train_entropy)

        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_confs.append(val_conf)
        val_entropies.append(val_entropy)

        # Any unrecognised score_type falls back to accuracy (as before).
        scores_by_type = {
            "loss": (train_loss, val_loss),
            "entropy": (train_entropy, val_entropy),
            "confidence": (train_conf, val_conf),
        }
        train_score, val_score = scores_by_type.get(score_type, (train_acc, val_acc))

        train_scores.append(train_score)
        val_scores.append(val_score)

        improved = val_score < best_val_score if minimize else val_score > best_val_score
        if improved:
            best_val_score = val_score
            torch.save(model.state_dict(), best_model_path)
            logging.info(f"[{run_name}] Best model updated at {best_model_path}")

    # Plot training curves
    plot_training_progress(
        train_losses, train_scores, os.path.join(logs_dir, "train_plots")
    )
    plot_training_progress(val_losses, val_scores, os.path.join(logs_dir, "val_plots"))

    logging.info(f"Case study end, {best_val_score}")
    logging.info("#" * 80)
    logging.info("\n")

    return best_val_score

In [13]:
class IndexedSubset(Dataset):
    """Wrap a ``Subset`` so every returned element carries its position as ``idx``.

    ``idx`` is a 0-d ``torch.long`` tensor in ``0 .. len(subset) - 1``, which is
    used to address per-element state (e.g. ``NCODLoss``'s ``u`` and embedding
    bank). The element returned by the wrapped subset is mutated in place:
    ``data.idx`` is set on it.
    """

    def __init__(self, subset: Subset):
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, i):
        size = len(self.subset)
        # Normalise negative positions so ``idx`` is always the canonical id;
        # previously ds[-1] got idx == -1, which silently aliases another element.
        pos = i + size if i < 0 else i
        if not 0 <= pos < size:
            raise IndexError(f"index {i} out of range for subset of length {size}")
        data = self.subset[pos]
        data.idx = torch.tensor(pos, dtype=torch.long)  # permanent id in 0…len-1
        return data

In [14]:
def case_study(
    dataset_name: Literal["A", "B", "C", "D"],
    n_trials: int = 30,
    num_checkpoints: int = 10,
    default_batch_size: int = 32,
    score_type: Literal["loss", "accuracy", "entropy", "confidence"] = "confidence",
    full_dataset=None,
):
    """Run an Optuna study on one dataset, split 80/20 into train/validation.

    Args:
        dataset_name: sub-directory of ``./datasets`` to train on.
        n_trials: number of Optuna trials.
        num_checkpoints: evenly spaced checkpoints saved per trial.
        default_batch_size: batch size of both loaders.
        score_type: validation score optimised by the study. ``"loss"`` is
            minimised; every other score is maximised.
        full_dataset: pre-loaded dataset; loaded from ``train_path`` when ``None``.

    Side effects: writes ``logs/<name>/training.log``,
    ``logs/<name>/optuna_summary_<name>.csv`` and checkpoints under
    ``checkpoints/<name>/``; prints and displays a results summary.
    """
    script_root = os.getcwd()
    train_path = f"./datasets/{dataset_name}/train.json.gz"
    run_name = dataset_name

    logs_dir = os.path.join(script_root, "logs", run_name)
    os.makedirs(logs_dir, exist_ok=True)
    # force=True: basicConfig is otherwise a no-op once the root logger has
    # handlers, so later calls (datasets B, C, D) would keep logging into A's file.
    logging.basicConfig(
        filename=os.path.join(logs_dir, "training.log"),
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        filemode="w",
        force=True,
    )
    # logging.getLogger().addHandler(logging.StreamHandler())

    checkpoints_dir = os.path.join(script_root, "checkpoints", run_name)
    best_model_path = os.path.join(checkpoints_dir, f"model_{run_name}_best.pth")
    summary_csv_path = os.path.join(logs_dir, f"optuna_summary_{dataset_name}.csv")
    os.makedirs(checkpoints_dir, exist_ok=True)
    if full_dataset is None:
        full_dataset = GraphDataset(train_path, transform=add_zeros)
    val_size = int(0.2 * len(full_dataset))
    train_size = len(full_dataset) - val_size
    # Fixed generator: the split is identical across calls and trials.
    generator = torch.Generator().manual_seed(12)
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], generator=generator
    )
    train_dataset = IndexedSubset(train_dataset)
    val_dataset = IndexedSubset(val_dataset)

    train_loader = DataLoader(
        train_dataset,  # type:ignore
        batch_size=default_batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset,  # type:ignore
        batch_size=default_batch_size,
        shuffle=False,
    )

    print(f"Starting Optuna optimization for dataset {dataset_name}")

    minimize = score_type == "loss"
    try:
        study = optuna.create_study(
            study_name=run_name, direction="minimize" if minimize else "maximize"
        )

        obj = partial(
            objective,
            train_loader=train_loader,
            val_loader=val_loader,
            num_checkpoints=num_checkpoints,
            checkpoints_dir=checkpoints_dir,
            run_name=run_name,
            best_model_path=best_model_path,
            logs_dir=logs_dir,
            # Previously hard-coded to "confidence", silently ignoring the argument.
            score_type=score_type,
        )
        study.optimize(obj, n_trials=n_trials)

        all_trials = [
            {score_type: trial.value, **trial.params}
            for trial in study.trials
            if trial.state == optuna.trial.TrialState.COMPLETE
        ]

        results_df = pd.DataFrame(all_trials)
        results_df.to_csv(summary_csv_path, index=False)

        print(f"\nAll trials saved to: {summary_csv_path}")
        if results_df.empty:
            # study.best_params would raise without a completed trial.
            print(f"\nNo completed trials for dataset {dataset_name}.")
            return
        print(f"\nBest result for dataset {dataset_name}:")
        # Best first: ascending for loss, descending for every other score.
        display(results_df.sort_values(score_type, ascending=minimize))
        print(f"\nBest Params for {dataset_name}:")
        for k, v in study.best_params.items():
            print(f"  {k}: {v}")
    finally:
        # Release local references even when a trial raises (the saved output
        # above shows optuna propagating the objective's exception).
        del train_loader, val_loader, full_dataset, train_dataset, val_dataset
        gc.collect()


In [15]:
# Load dataset A once so repeated case-study runs can reuse it.
dataset_a = GraphDataset(
    "./datasets/A/train.json.gz",
    transform=add_zeros,
)

In [17]:
# Single-trial run on dataset A, reusing the pre-loaded dataset.
case_study("A", n_trials=1, full_dataset=dataset_a)
# case_study("B", 1)
# case_study("C", 1)
# case_study("D", 1)

[I 2025-05-26 16:45:48,557] A new study created in memory with name: A
[W 2025-05-26 16:45:48,578] Trial 0 failed with parameters: {} because of the following error: RuntimeError('CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n').
Traceback (most recent call last):
  File "/home/haislich/Documents/noisy_labels/.venv/lib/python3.12/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_3360/2992990716.py", line 40, in objective
    ).to(device)
      ^^^^^^^^^^
  File "/home/haislich/Documents/noisy_labels/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1355, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  

Starting Optuna optimization for dataset A


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
