In [1]:
import torch
print(f"Current PyTorch version: {torch.__version__}")
print(f"Current CUDA version: {torch.version.cuda}")

!pip uninstall -y torch torchvision torchaudio
# PyTorch 2.5.0 with CUDA 12.4
!pip install torch==2.5.0+cu124 torchvision==0.20.0+cu124 torchaudio==2.5.0+cu124 --index-url https://download.pytorch.org/whl/cu124
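# Next: install the PyTorch Geometric packages (done in the following cell).
# NOTE: the session must be restarted after this reinstall so the new torch version is loaded.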

Current PyTorch version: 2.6.0+cu124
Current CUDA version: 12.4
Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.5.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torch-2.5.0%2Bcu124-cp311-cp311-linux_x86_64.whl (908.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m908.3/908.3 MB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.20.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torchvision-0.20.0%2Bcu124-cp311-cp311-linux_x86_64.whl (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.0+cu124.html
!pip install torch-geometric
!pip install matplotlib seaborn PyYAML tqdm

Looking in links: https://data.pyg.org/whl/torch-2.5.0+cu124.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m94.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m61.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-scatter, torch-sparse
Successfully installed torch-scatter-2.1.2+pt25cu124 torch-sparse-0.6.18+pt25cu124
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geo

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import math
import json
from typing import List, Tuple, Dict, Optional, Union
from sklearn.manifold import TSNE
import seaborn as sns

# PyTorch Geometric imports
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_networkx, degree
import networkx as nx

# Seed the global PyTorch and NumPy random number generators so that runs are
# repeatable. NOTE(review): this does not by itself force deterministic GPU
# kernels — confirm if bit-exact reproducibility is required.
torch.manual_seed(42)
np.random.seed(42)

## Model

In [3]:
class PonderGAT(nn.Module):
    """
    PonderNet-style Graph Attention Network with Gumbel-Sigmoid halting.

    Implements adaptive computation for graph neural networks: every node decides
    independently how many message-passing steps it needs through a learned halting
    mechanism. The per-layer halting decision of a node is modelled as a conditional
    Bernoulli variable, so the unconditional probability of halting at a given layer
    follows a geometric-style distribution, as suggested in the PonderNet paper.
    Unlike PonderNet's "expected" loss, training samples the halting decision with a
    Gumbel-Sigmoid straight-through estimator; the claim is that this is more
    representative of the behaviour the model sees at inference while preserving the
    computational graph for backprop.

    Message passing uses GAT layers plus an attention-entropy regulariser: the
    attention weights of each layer are meant to capture the importance of each
    neighbour, adding some degree of directionality as suggested by CO-GNN (which
    uses a completely different approach).

    Args:
        in_features: Size of the raw node feature vectors.
        hidden_features: Per-head hidden size of every GAT layer.
        out_features: Number of output classes.
        num_layers: Maximum number of message-passing steps (must be >= 1).
        heads: Number of attention heads; head outputs are concatenated, so the
            working embedding size is ``hidden_features * heads``.
        dropout: Dropout probability used in the MLPs, the GAT layers and on the
            node embeddings.
        halting_hidden: Hidden width of the per-layer halting networks.
        beta_kl: Weight of the KL term between the halting distribution and the
            geometric prior.
        beta_attn: Weight of the attention-entropy term.
        prior_lambda: Per-step halting probability of the geometric prior.
        use_mlp_input: If True, project the input with a 2-layer MLP, otherwise
            with a single linear layer.
        shared_gat: If True, a single GAT layer is applied at every step.
        gumbel_temperature: Temperature of the Gumbel-Sigmoid relaxation.
        use_residual: If True, add ``residual_alpha * h`` after normalisation.
        residual_alpha: Scale of the residual connection.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        num_layers: int = 5,
        heads: int = 1,
        dropout: float = 0.5,
        halting_hidden: int = 32,
        beta_kl: float = 0.01,
        beta_attn: float = 0.01,
        prior_lambda: float = 0.2,
        use_mlp_input: bool = True,
        shared_gat: bool = False,
        gumbel_temperature: float = 1.0,
        use_residual: bool = True,
        residual_alpha: float = 0.1
    ):
        super(PonderGAT, self).__init__()

        # The halting-bias initialisation and the forward pass both need at least one layer.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.num_layers = num_layers
        self.heads = heads
        self.beta_kl = beta_kl
        self.beta_attn = beta_attn
        self.prior_lambda = prior_lambda
        self.shared_gat = shared_gat
        self.gumbel_temperature = gumbel_temperature
        self.use_residual = use_residual
        self.residual_alpha = residual_alpha

        # hidden dimension considering multi-head concatenation
        self.hidden_dim = hidden_features * heads

        # first projection: maps x_i to h_i^(0)
        if use_mlp_input:
            self.input_mlp = nn.Sequential(
                nn.Linear(in_features, 2*self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(2*self.hidden_dim, self.hidden_dim)
            )
        else:
            self.input_mlp = nn.Linear(in_features, self.hidden_dim)

        # GAT layers - perform message passing with attention.
        # With shared_gat a single layer is reused at every step (fewer weights),
        # otherwise every step owns its layer.
        num_gat_layers = 1 if shared_gat else num_layers
        self.gat_layers = nn.ModuleList([
            GATConv(
                in_channels=self.hidden_dim,
                out_channels=hidden_features,
                heads=heads,
                concat=True,
                dropout=dropout,
                add_self_loops=True
            ) for _ in range(num_gat_layers)
        ])

        # Halting networks - one for each layer
        self.halting_networks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_dim + 1, halting_hidden),  # +1 for the embedding distance
                nn.ReLU(),
                nn.Linear(halting_hidden, 1)
            ) for _ in range(num_layers)
        ])

        # Classification networks - one for each layer
        self.classification_networks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.hidden_dim, out_features)
            ) for _ in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialise the input projection and the halting biases.

        The input projection's linear layers get Xavier-uniform weights and zero
        biases. The output bias of every halting network is set to the logit of
        ``1 / (num_layers + 1)``, following the PonderNet authors' idea of starting
        from that halting expectation. Other sub-modules keep their own default
        initialisation.
        """
        if isinstance(self.input_mlp, nn.Sequential):
            for module in self.input_mlp:
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
        else:
            nn.init.xavier_uniform_(self.input_mlp.weight)
            nn.init.zeros_(self.input_mlp.bias)

        initial_prob = 1 / (self.num_layers + 1)
        initial_bias = math.log(initial_prob / (1 - initial_prob))
        for halt_net in self.halting_networks:
            halt_net[-1].bias.data.fill_(initial_bias)

    def gumbel_sigmoid(self, logits, temperature=1.0, hard=False, eps=1e-10):
        """
        Draw relaxed Bernoulli samples from ``logits`` with the Gumbel-Sigmoid trick.

        The perturbation is the difference of two independent Gumbel(0, 1) samples,
        i.e. logistic noise. This makes the hard decision fire with probability
        ``sigmoid(logits)``, matching the Bernoulli sampling used at evaluation.
        (A single Gumbel sample would fire with probability
        ``1 - exp(-exp(logits))``, i.e. ~0.63 instead of 0.5 at ``logits == 0``.)

        Args:
            logits: Tensor of Bernoulli logits.
            temperature: Relaxation temperature; affects the soft values (and thus
                the gradients) but not the hard 0/1 decisions.
            hard: If True, return (numerically ~) 0/1 values in the forward pass
                with the soft sample's gradient (straight-through estimator).
            eps: Small constant that keeps the logarithms finite.

        Returns:
            Tensor with the same shape as ``logits``.
        """
        gumbel_a = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
        gumbel_b = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
        logistic_noise = gumbel_a - gumbel_b

        y_soft = torch.sigmoid((logits + logistic_noise) / temperature)

        if hard:
            y_hard = (y_soft > 0.5).float()
            # straight-through trick: use hard values in forward, but soft gradient in backward
            return y_hard - y_soft.detach() + y_soft

        return y_soft

    def forward(self, x, edge_index, training=True):
        """
        Run up to ``num_layers`` message-passing steps with per-node halting.

        Args:
            x: Node feature matrix; the first dimension is the number of nodes.
            edge_index: Edge index tensor passed unchanged to the GAT layers.
            training: If True, halting decisions are sampled with the hard
                Gumbel-Sigmoid; otherwise with ``torch.bernoulli``. This flag is
                independent of ``self.training`` (which controls dropout).

        Returns:
            Tuple ``(all_logits, p_n, all_attention, halting_layers, node_outputs)``:
            ``all_logits`` is a list with one ``[N, out_features]`` logit tensor per
            layer; ``p_n`` is the ``[N, num_layers]`` halting distribution (rows sum
            to 1, leftover mass assigned to the last layer); ``all_attention`` is the
            list of attention outputs returned by each GAT layer; ``halting_layers``
            is the 1-based layer at which each node halted; ``node_outputs`` holds
            each node's logits taken at its halting layer.
        """
        device = x.device
        num_nodes = x.size(0)

        h_0 = self.input_mlp(x)

        # per-node halting bookkeeping
        not_halted_mask = torch.ones(num_nodes, dtype=torch.bool, device=device)
        node_outputs = torch.zeros(num_nodes, self.out_features, device=device)
        halting_layers = torch.zeros(num_nodes, dtype=torch.long, device=device)

        all_logits = []
        all_lambdas = []
        all_attention = []

        # embeddings used for the distance feature of the halting network
        h = h_0
        h_prev = h_0

        for l in range(self.num_layers):
            gat_layer = self.gat_layers[0] if self.shared_gat else self.gat_layers[l]

            # message passing updates all node representations (even halted ones);
            # the halted ones are discarded below
            h_new_all, attn_weights = gat_layer(h, edge_index, return_attention_weights=True)
            h_new_all = F.elu(h_new_all)

            # parameter-free layer norm over the feature dimension
            # (the authors did it and it seems to help a lot)
            mean = h_new_all.mean(dim=-1, keepdim=True)
            var = h_new_all.var(dim=-1, keepdim=True, unbiased=False)
            h_new_all = (h_new_all - mean) / torch.sqrt(var + 1e-5)

            # optional scaled residual connection
            if self.use_residual:
                h_new_all = h_new_all + self.residual_alpha * h

            # update non-halted nodes, keep previous embedding for halted nodes
            h_new = torch.where(not_halted_mask.unsqueeze(1), h_new_all, h)
            h_new = self.dropout(h_new)

            # normalized embedding distance for all nodes
            # NOTE(review): h_prev is the *input* of the previous layer, so from the
            # second layer on this distance spans two steps — confirm this is intended.
            distances = torch.norm(h_new - h_prev, dim=1, keepdim=True) / math.sqrt(self.hidden_dim)

            # halting probabilities for all nodes
            halting_input = torch.cat([h_new, distances], dim=1)
            raw_lambda = self.halting_networks[l](halting_input).squeeze(-1)
            lambda_l = torch.sigmoid(raw_lambda)

            # classification logits for all nodes
            logits_l = self.classification_networks[l](h_new)

            # stored for the loss computation
            all_logits.append(logits_l)
            all_lambdas.append(lambda_l)
            all_attention.append(attn_weights)

            # nodes that have not halted yet decide whether to halt at this layer
            if not_halted_mask.any():
                active_indices = torch.nonzero(not_halted_mask, as_tuple=True)[0]

                if training:
                    # training: Gumbel-Sigmoid sampling of the halting decision
                    gumbel_samples = self.gumbel_sigmoid(
                        raw_lambda[active_indices],
                        temperature=self.gumbel_temperature,
                        hard=True
                    )
                    halt_decisions = (gumbel_samples > 0.5)
                else:
                    # evaluation: direct Bernoulli sampling with the halting probability
                    halt_decisions = torch.bernoulli(lambda_l[active_indices]).bool()

                # nodes that halt at this layer
                halting_indices = active_indices[halt_decisions]

                if halting_indices.numel() > 0:
                    # clones keep the previously stored tensors untouched for autograd
                    node_outputs_new = node_outputs.clone()
                    node_outputs_new[halting_indices] = logits_l[halting_indices]
                    node_outputs = node_outputs_new  # logits at the halting layer

                    halting_layers_new = halting_layers.clone()
                    halting_layers_new[halting_indices] = l + 1
                    halting_layers = halting_layers_new  # 1-based halting layer

                    not_halted_mask_new = not_halted_mask.clone()
                    not_halted_mask_new[halting_indices] = False
                    not_halted_mask = not_halted_mask_new

            # shift the embeddings for the next layer
            h_prev = h
            h = h_new

        # nodes that reached the end without halting use the last layer
        if not_halted_mask.any():
            still_active = torch.nonzero(not_halted_mask, as_tuple=True)[0]

            node_outputs_new = node_outputs.clone()
            node_outputs_new[still_active] = logits_l[still_active]
            node_outputs = node_outputs_new

            halting_layers_new = halting_layers.clone()
            halting_layers_new[still_active] = self.num_layers
            halting_layers = halting_layers_new

        # halting distribution: p_n[l] = lambda[l] * prod_{j<l} (1 - lambda[j])
        lambda_matrix = torch.stack(all_lambdas, dim=1)  # [N, L]
        survival = torch.cumprod(1 - lambda_matrix, dim=1)
        reach_prob = torch.cat(
            [torch.ones_like(lambda_matrix[:, :1]), survival[:, :-1]], dim=1
        )  # probability of still being active when entering each layer
        p_n = lambda_matrix * reach_prob

        # normalization to 1: the probability mass of "never halted" goes to the last layer
        leftover = (1.0 - p_n.sum(dim=1, keepdim=True)).clamp(min=0.0)
        p_n = torch.cat([p_n[:, :-1], p_n[:, -1:] + leftover], dim=1)

        return all_logits, p_n, all_attention, halting_layers, node_outputs

    def compute_loss(self, all_logits, p_n, labels, halting_layers, final_logits, all_attention=None, mask=None):
        """
        Combine task, KL and attention-entropy losses.

        Args:
            all_logits: Per-layer logits (accepted for interface symmetry; unused here).
            p_n: ``[N, num_layers]`` halting distribution from ``forward``.
            labels: Ground-truth class index per node.
            halting_layers: Halting layer per node (unused here).
            final_logits: Logits of each node at its halting layer.
            all_attention: Optional list of per-layer attention outputs; each item is
                unpacked as a 2-tuple whose second element holds the weights.
            mask: Optional boolean node mask (train/val/test split) restricting the
                task and KL terms.

        Returns:
            ``(total_loss, loss_dict)`` where ``total_loss`` is a differentiable
            scalar tensor and ``loss_dict`` maps ``'task_loss'``, ``'kl_loss'``,
            ``'attn_loss'`` and ``'total_loss'`` to Python floats.
        """
        device = labels.device
        # the datasets bring masks for training, val and test.
        if mask is not None:
            masked_p_n = p_n[mask]
            masked_labels = labels[mask]
            masked_final_logits = final_logits[mask]
        else:
            masked_p_n = p_n
            masked_labels = labels
            masked_final_logits = final_logits

        # 1. Task Loss - Cross Entropy on the sampled layer predictions
        task_loss = F.cross_entropy(masked_final_logits, masked_labels)

        # 2. KL Divergence with a (truncated, renormalised) geometric prior
        layer_indices = torch.arange(1, self.num_layers + 1, device=device).float()
        prior_probs = self.prior_lambda * ((1 - self.prior_lambda) ** (layer_indices - 1))
        prior_probs = prior_probs / prior_probs.sum()  # and norm to 1
        # KL(p_n || prior) = sum p_n * log(p_n / prior); epsilon keeps the logs finite
        epsilon = 1e-10
        masked_p_n_safe = masked_p_n + epsilon
        prior_probs_safe = prior_probs + epsilon

        kl_div = masked_p_n_safe * (torch.log(masked_p_n_safe) - torch.log(prior_probs_safe.unsqueeze(0)))
        kl_loss = self.beta_kl * kl_div.sum(dim=1).mean()

        # 3. Attention Entropy Loss (mean of -a*log(a) over all attention coefficients)
        attn_loss = torch.tensor(0.0, device=device)

        if all_attention is not None and self.beta_attn > 0:
            total_entropy = 0.0
            count = 0

            for attn_tuple in all_attention:
                _, attn_weights = attn_tuple
                if isinstance(attn_weights, list):
                    attn_weights = attn_weights[0]

                # non-zero values are needed for the log
                attn_weights = torch.clamp(attn_weights, min=1e-10)
                entropy = -(attn_weights * torch.log(attn_weights)).sum()
                total_entropy += entropy
                count += attn_weights.numel()

            if count > 0:
                attn_loss = self.beta_attn * (total_entropy / count)

        # TOTAL LOSS
        total_loss = task_loss + kl_loss + attn_loss

        loss_dict = {
            'task_loss': task_loss.item(),
            'kl_loss': kl_loss.item(),
            'attn_loss': attn_loss.item(),
            'total_loss': total_loss.item()
        }

        return total_loss, loss_dict

    def inference(self, x, edge_index):
        """
        Predict without gradients, sampling the halting decision with Bernoulli draws
        at each layer.

        Returns:
            ``(final_preds, final_logits, halting_layers)``: arg-max class per node,
            the logits at each node's halting layer, and the 1-based halting layer.
        """
        with torch.no_grad():
            _, _, _, halting_layers, final_logits = self.forward(
                x, edge_index, training=False
            )

            final_preds = final_logits.argmax(dim=1)

        return final_preds, final_logits, halting_layers

## Training

In [4]:
def train_epoch(model, data, optimizer, device):
    """
    Perform one full-batch optimisation step using Gumbel-sampled halting.

    Args:
        model: Module called as ``model(x, edge_index, training=True)`` and exposing
            ``compute_loss``.
        data: Graph object providing ``x``, ``edge_index``, ``y`` and ``train_mask``.
        optimizer: Optimizer holding the model parameters.
        device: Device the tensors are moved to.

    Returns:
        ``(loss, loss_dict, acc, avg_layers)``: total loss as a float, the loss
        breakdown returned by ``compute_loss``, the accuracy on the training mask and
        the mean halting layer on the training mask.
    """
    model.train()
    optimizer.zero_grad()

    features = data.x.to(device)
    edges = data.edge_index.to(device)
    targets = data.y.to(device)
    train_mask = data.train_mask.to(device)

    # Forward pass; halting decisions come from the Gumbel sampler.
    outputs = model(features, edges, training=True)
    all_logits, p_n, all_attention, halting_layers, final_logits = outputs

    # Loss and parameter update.
    loss, loss_dict = model.compute_loss(
        all_logits, p_n, targets, halting_layers,
        final_logits, all_attention, train_mask
    )
    loss.backward()
    optimizer.step()

    # Accuracy of the logits taken at each node's halting layer.
    predictions = final_logits.argmax(dim=1)
    n_correct = predictions[train_mask].eq(targets[train_mask]).sum().item()
    n_masked = train_mask.sum().item()
    if n_masked > 0:
        acc = n_correct / n_masked
    else:
        acc = 0

    # Mean halting layer over the training nodes.
    avg_layers = halting_layers[train_mask].float().mean().item()

    return loss.item(), loss_dict, acc, avg_layers


def evaluate(model, data, mask_name, device):
    """
    Measure accuracy and mean halting layer on one of the dataset splits.

    Args:
        model: Module exposing ``eval()`` and ``inference(x, edge_index)``.
        data: Graph object providing ``x``, ``edge_index``, ``y`` and the
            ``train_mask`` / ``val_mask`` / ``test_mask`` attributes.
        mask_name: One of ``'train'``, ``'val'`` or ``'test'``.
        device: Device the tensors are moved to.

    Returns:
        ``(acc, avg_layers)`` for the selected split.

    Raises:
        ValueError: If ``mask_name`` is not a known split name.
    """
    model.eval()

    features = data.x.to(device)
    edges = data.edge_index.to(device)
    targets = data.y.to(device)

    # Pick the split mask that matches the requested name.
    for split in ('train', 'val', 'test'):
        if mask_name == split:
            mask = getattr(data, split + '_mask').to(device)
            break
    else:
        raise ValueError(f"Unknown mask: {mask_name}")

    with torch.no_grad():
        pred, final_logits, halting_layers = model.inference(features, edges)

        n_correct = pred[mask].eq(targets[mask]).sum().item()
        n_masked = mask.sum().item()
        acc = 0 if not n_masked > 0 else n_correct / n_masked

        avg_layers = halting_layers[mask].float().mean().item()

    return acc, avg_layers

## Visualization

In [5]:
def visualize_training_curves(train_stats, dataset_name, save_dir):
    """
    Save a 2x2 figure summarising a training run.

    Panels: loss components, train/val accuracy, average halting layers and the
    total training loss on a log scale. The figure is written to
    ``{save_dir}/training_curves_{dataset_name}.png`` and then closed.

    Args:
        train_stats: Dict of per-epoch lists. Requires ``'train_loss'``,
            ``'train_acc'``, ``'val_acc'``, ``'train_layers'`` and ``'val_layers'``;
            ``'task_loss'``, ``'kl_loss'`` and ``'attn_loss'`` are plotted if present.
        dataset_name: Name used in the output file name.
        save_dir: Output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    def finish(ax, title, ylabel):
        # Shared cosmetics applied to every panel.
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Panel 1: total loss plus whichever components were recorded.
    loss_ax = axes[0, 0]
    loss_ax.plot(train_stats['train_loss'], 'b-', label='Total Loss')
    optional_components = (
        ('task_loss', 'g-', 'Task Loss'),
        ('kl_loss', 'r-', 'KL Loss'),
        ('attn_loss', 'y-', 'Attn Loss'),
    )
    for key, line_style, label in optional_components:
        if key in train_stats:
            loss_ax.plot(train_stats[key], line_style, label=label)
    finish(loss_ax, 'Loss Components', 'Loss')

    # Panel 2: training and validation accuracy.
    acc_ax = axes[0, 1]
    acc_ax.plot(train_stats['train_acc'], 'b-', label='Train')
    acc_ax.plot(train_stats['val_acc'], 'r-', label='Val')
    finish(acc_ax, 'Accuracy', 'Accuracy')

    # Panel 3: average number of layers used before halting.
    layers_ax = axes[1, 0]
    layers_ax.plot(train_stats['train_layers'], 'b-', label='Train')
    layers_ax.plot(train_stats['val_layers'], 'r-', label='Val')
    finish(layers_ax, 'Average Layers', 'Layers')

    # Panel 4: total training loss on a logarithmic y-axis.
    log_ax = axes[1, 1]
    log_ax.semilogy(train_stats['train_loss'], 'b-', label='Train')
    finish(log_ax, 'Loss (log scale)', 'Loss')

    plt.tight_layout()
    plt.savefig(f'{save_dir}/training_curves_{dataset_name}.png', dpi=300)
    plt.close()


def visualize_layer_distribution(halting_layers, num_layers, dataset_name, save_dir):
    """
    Save a bar chart of how many nodes halted at each layer.

    The figure is written to ``{save_dir}/layer_distribution_{dataset_name}.png``
    and then closed.

    Args:
        halting_layers: Tensor of 1-based halting layers, one entry per node.
        num_layers: Number of layers shown on the x-axis.
        dataset_name: Name used in the title and in the output file name.
        save_dir: Output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)

    plt.figure(figsize=(10, 6))
    halting_np = halting_layers.cpu().numpy()

    # Histogram over layer ids; index 0 (no layer) is dropped.
    counts_with_zero = np.bincount(halting_np, minlength=num_layers+1)
    layer_counts = counts_with_zero[1:]
    layer_ids = range(1, num_layers+1)

    plt.bar(layer_ids, layer_counts, alpha=0.7, width=0.7, color='skyblue', edgecolor='navy')

    # Print each count just above its bar.
    for position, count in enumerate(layer_counts, start=1):
        plt.text(position, count + max(layer_counts)*0.01, str(count), ha='center')

    plt.title(f'Distribution of Halting Layers - {dataset_name}')
    plt.xlabel('Layer')
    plt.ylabel('Number of Nodes')
    plt.xticks(layer_ids)
    plt.grid(True, axis='y', alpha=0.3)

    # Summary statistics shown in a box under the plot.
    summary = f'Mean Layer: {halting_np.mean():.2f} | Std Dev: {halting_np.std():.2f} | Min: {halting_np.min()} | Max: {halting_np.max()}'
    box_style = dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8)
    plt.figtext(0.5, 0.01, summary, ha='center', bbox=box_style)

    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    plt.savefig(f'{save_dir}/layer_distribution_{dataset_name}.png', dpi=300)
    plt.close()


## Main

In [6]:
def train_pondergat(dataset, device, model_params=None, train_params=None, save_dir='./results'):
    """
    Train PonderGAT with node-level Gumbel-Sigmoid halting and report the results.

    The best checkpoint (by validation accuracy) is saved to
    ``{save_dir}/best_model_{dataset.name}.pt`` and reloaded for the final
    evaluation; plots and a ``results_{dataset.name}.json`` file are written to
    ``save_dir``.

    Args:
        dataset: Dataset whose first element is the graph; must expose
            ``num_node_features``, ``num_classes`` and ``name``.
        device: Device used for the model and the data.
        model_params: Optional dict overriding any of the default model
            hyper-parameters (missing keys fall back to the defaults).
        train_params: Optional dict overriding any of the default training
            hyper-parameters (``lr``, ``weight_decay``, ``epochs``, ``patience``,
            ``anneal_rate``).
        save_dir: Output directory (created if missing).

    Returns:
        ``(model, results)``: the trained model with the best checkpoint loaded, and
        a dict with the used parameters, final metrics and per-epoch statistics.
    """
    os.makedirs(save_dir, exist_ok=True)
    data = dataset[0].to(device)

    default_model_params = {
        'hidden_features': 64,
        'num_layers': 5,
        'heads': 8,
        'dropout': 0.4,
        'halting_hidden': 64,
        'beta_kl': 0.005,
        'beta_attn': 0.001,
        'prior_lambda': 0.05,
        'use_mlp_input': True,
        'shared_gat': False,
        'gumbel_temperature': 1.0,
        'use_residual': False,
        'residual_alpha': 0.1
    }
    default_train_params = {
        'lr': 0.001,
        'weight_decay': 5e-5,
        'epochs': 1000,
        'patience': 500,
        'anneal_rate': 0.003
    }
    # User-supplied values win; anything not supplied falls back to the defaults,
    # so partial override dicts no longer raise KeyError.
    model_params = {**default_model_params, **(model_params or {})}
    train_params = {**default_train_params, **(train_params or {})}

    model = PonderGAT(
        in_features=dataset.num_node_features,
        hidden_features=model_params['hidden_features'],
        out_features=dataset.num_classes,
        num_layers=model_params['num_layers'],
        heads=model_params['heads'],
        dropout=model_params['dropout'],
        halting_hidden=model_params['halting_hidden'],
        beta_kl=model_params['beta_kl'],
        beta_attn=model_params['beta_attn'],
        prior_lambda=model_params['prior_lambda'],
        use_mlp_input=model_params['use_mlp_input'],
        shared_gat=model_params['shared_gat'],
        gumbel_temperature=model_params['gumbel_temperature'],
        use_residual=model_params['use_residual'],
        residual_alpha=model_params['residual_alpha']
    ).to(device)

    print(model)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=train_params['lr'],
        weight_decay=train_params['weight_decay']
    )

    # scheduler with linear warm-up (20 epochs) followed by exponential decay
    def lr_lambda(epoch):
        if epoch < 20:
            # (epoch + 1) so that the very first optimisation step does not run
            # with a learning rate of exactly zero
            return (epoch + 1) / 20.0
        return max(0.1, 0.98 ** (epoch - 20))  # exp decay, floored at 10%

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    train_stats = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_layers': [], 'val_layers': [],
        'task_loss': [], 'kl_loss': [],
        'attn_loss': []
    }

    # -inf guarantees that the first epoch always writes a checkpoint, even when
    # the validation accuracy is 0.
    best_val_acc = float('-inf')
    best_epoch = 0
    patience_counter = 0
    best_model_path = f'{save_dir}/best_model_{dataset.name}.pt'

    # Training loop:
    start_time = time.time()
    print(f"Starting training on {dataset.name} dataset")

    for epoch in range(1, train_params['epochs'] + 1):
        # anneal Gumbel-Sigmoid temperature (exponential decay with a floor of 0.1)
        model.gumbel_temperature = float(max(
            0.1,  # min temperature
            model_params['gumbel_temperature'] * np.exp(-train_params['anneal_rate'] * epoch)
        ))
        # train
        train_loss, loss_dict, train_acc, train_layers = train_epoch(model, data, optimizer, device)

        # eval
        val_acc, val_layers = evaluate(model, data, 'val', device)
        val_loss = _validation_loss(model, data, device)

        # update learning rate
        scheduler.step()

        train_stats['train_loss'].append(train_loss)
        train_stats['val_loss'].append(val_loss)
        train_stats['train_acc'].append(train_acc)
        train_stats['val_acc'].append(val_acc)
        train_stats['train_layers'].append(train_layers)
        train_stats['val_layers'].append(val_layers)
        train_stats['task_loss'].append(loss_dict['task_loss'])
        train_stats['kl_loss'].append(loss_dict['kl_loss'])
        train_stats['attn_loss'].append(loss_dict['attn_loss'])

        # checkpoint on validation improvement, otherwise count towards early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            patience_counter = 0

            torch.save(model.state_dict(), best_model_path)
        else:
            patience_counter += 1

        # start counting patience after 50 epochs to allow exploration
        if epoch < 50:
            patience_counter = 0

        if patience_counter >= train_params['patience']:
            print(f"Early stopping after {epoch} epochs")
            break

        # progress
        if epoch % 10 == 0 or epoch == 1:
            print(f"Epoch: {epoch:03d}, Loss: {train_loss:.4f}, "
                  f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, "
                  f"Train Layers: {train_layers:.2f}, Val Layers: {val_layers:.2f}, "
                  f"Temp: {model.gumbel_temperature:.4f}")

    total_time = time.time() - start_time
    print(f"Total training time: {total_time:.2f} seconds")

    # final evaluation with the best checkpoint (none exists if no epoch was run)
    if best_epoch > 0:
        # weights_only=True: the file only holds a state dict, so the restricted
        # (safe) unpickler is sufficient; map_location keeps the load device-agnostic.
        model.load_state_dict(
            torch.load(best_model_path, map_location=device, weights_only=True)
        )

    train_acc, train_layers = evaluate(model, data, 'train', device)
    val_acc, val_layers = evaluate(model, data, 'val', device)
    test_acc, test_layers = evaluate(model, data, 'test', device)

    print(f"\nFinal Results - {dataset.name}:")
    print(f"  Best epoch: {best_epoch}")
    print(f"  Train accuracy: {train_acc:.4f}, layers: {train_layers:.2f}")
    print(f"  Val accuracy: {val_acc:.4f}, layers: {val_layers:.2f}")
    print(f"  Test accuracy: {test_acc:.4f}, layers: {test_layers:.2f}")

    # halting layer distribution for visualization
    with torch.no_grad():
        _, _, halting_layers = model.inference(data.x.to(device), data.edge_index.to(device))

    visualize_training_curves(train_stats, dataset.name, save_dir)
    visualize_layer_distribution(halting_layers, model.num_layers, dataset.name, save_dir)

    # results
    results = {
        'model_params': model_params,
        'train_params': train_params,
        'final_results': {
            'train_acc': train_acc,
            'val_acc': val_acc,
            'test_acc': test_acc,
            'train_layers': train_layers,
            'val_layers': val_layers,
            'test_layers': test_layers,
            'best_epoch': best_epoch,
            'training_time': total_time
        },
        'train_stats': dict(train_stats)
    }

    # Save results; numpy / tensor values at any nesting depth are converted by the
    # default hook instead of a hand-written traversal.
    with open(f'{save_dir}/results_{dataset.name}.json', 'w') as f:
        json.dump(results, f, indent=2, default=_json_default)

    return model, results


def _validation_loss(model, data, device):
    """Return the total loss on the validation mask from a gradient-free eval-mode pass."""
    model.eval()
    with torch.no_grad():
        all_logits, p_n, all_attention, halting_layers, final_logits = model(
            data.x.to(device), data.edge_index.to(device), training=False
        )
        _, val_loss_dict = model.compute_loss(
            all_logits, p_n, data.y.to(device), halting_layers,
            final_logits, all_attention, data.val_mask.to(device)
        )
    return val_loss_dict['total_loss']


def _json_default(value):
    """Convert numpy scalars/arrays and tensors to plain Python values for json.dump."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (np.ndarray, torch.Tensor)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

def main():
    """
    Entry point: train PonderGAT on a Planetoid citation dataset.

    Selects the device, loads the dataset with row-normalised features, prints a
    short dataset summary and runs the training pipeline.

    Returns:
        ``(model, results)`` as returned by ``train_pondergat``.
    """
    # Prefer the GPU when one is available.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f"Using device: {device}")

    # Dataset selection. Options: "Cora", "CiteSeer", "PubMed"
    dataset_name = "Cora"
    dataset = Planetoid(
        root=f'./data/{dataset_name}',
        name=dataset_name,
        transform=NormalizeFeatures(),
    )

    # Short summary of the loaded graph.
    summary_lines = [
        f"Dataset: {dataset_name}",
        f"Number of nodes: {dataset[0].num_nodes}",
        f"Number of edges: {dataset[0].num_edges}",
        f"Number of features: {dataset.num_node_features}",
        f"Number of classes: {dataset.num_classes}",
    ]
    print("\n".join(summary_lines))

    # Train and hand the outcome back to the caller.
    output_dir = f'./results/{dataset_name}'
    trained_model, run_results = train_pondergat(dataset, device, save_dir=output_dir)

    return trained_model, run_results

# Run the experiment when this cell (or the exported script) is executed directly.
if __name__ == "__main__":
    main()

Using device: cuda


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Dataset: Cora
Number of nodes: 2708
Number of edges: 10556
Number of features: 1433
Number of classes: 7
PonderGAT(
  (input_mlp): Sequential(
    (0): Linear(in_features=1433, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.4, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
  )
  (gat_layers): ModuleList(
    (0-4): 5 x GATConv(512, 64, heads=8)
  )
  (halting_networks): ModuleList(
    (0-4): 5 x Sequential(
      (0): Linear(in_features=513, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (classification_networks): ModuleList(
    (0-4): 5 x Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.4, inplace=False)
      (3): Linear(in_features=512, out_features=7, bias=True)
    )
  )
  (dropout): Dropout(p=0.4, inplace=False)
)
Total parameters: 4807656
Starting training on Cora dataset
Epoch: 001, Loss:

  model.load_state_dict(torch.load(best_model_path))
