# Adaptive Propagation Graph Convolutional Network (AP-GCN)


In [1]:
import torch
print(f"Current PyTorch version: {torch.__version__}")
print(f"Current CUDA version: {torch.version.cuda}")

!pip uninstall -y torch torchvision torchaudio
# PyTorch 2.5.0 with CUDA 12.4
!pip install torch==2.5.0+cu124 torchvision==0.20.0+cu124 torchaudio==2.5.0+cu124 --index-url https://download.pytorch.org/whl/cu124
# now geometric stuff
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.0+cu124.html
!pip install torch-geometric
!pip install matplotlib seaborn PyYAML tqdm
# NOTE: the session (kernel) needs to be restarted after this cell so the reinstalled packages are picked up.

Current PyTorch version: 2.6.0+cu124
Current CUDA version: 12.4
Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.5.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torch-2.5.0%2Bcu124-cp311-cp311-linux_x86_64.whl (908.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m908.3/908.3 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.20.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torchvision-0.20.0%2Bcu124-cp311-cp311-linux_x86_64.whl (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Looking in links: https://data.pyg.org/whl/torch-2.5.0+cu124.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m83.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m102.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-scatter, torch-sparse
Successfully installed torch-scatter-2.1.2+pt25cu124 torch-sparse-0.6.18+pt25cu124
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_ge

In [1]:
import torch
print(f"Current PyTorch version: {torch.__version__}")
print(f"Current CUDA version: {torch.version.cuda}")

Current PyTorch version: 2.5.0+cu124
Current CUDA version: 12.4


In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.nn import ModuleList, Dropout, ReLU, Linear
from torch_geometric.nn import GCNConv, MessagePassing
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import dropout_adj, to_networkx, degree, remove_self_loops, dropout_edge
from torch_geometric.utils import add_self_loops as add_self_loops_fn
from torch_geometric.data import Data, InMemoryDataset
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparse_sum, mul
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from sklearn.decomposition import PCA
from matplotlib.colors import LinearSegmentedColormap, Normalize
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from matplotlib.cm import ScalarMappable
import os
import time
from typing import List

# Reproducibility and device selection.
# Both the PyTorch and the NumPy global RNGs are seeded with the same value
# (the same number is assigned to `development_seed` further down the notebook).
torch.manual_seed(4143496719)
np.random.seed(4143496719)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## Model

In [3]:
class AdaptivePropagation(MessagePassing):
    """
    Adaptive Propagation layer that allows each node to determine
    its own optimal number of propagation steps.

    This implements the core adaptive halting mechanism described in the paper.
    """
    def __init__(self, niter: int, h_size: int, bias = True, **kwargs):
        """
        Adaptive propagation layer.

        where:
            niter: max number of propagation steps (T in the paper), must be >= 1
            h_size: size of the node embeddings
            bias: whether the halting unit has a bias term
            **kwargs: forwarded to ``MessagePassing.__init__``

        raises:
            ValueError: if ``niter`` is smaller than 1.
        """
        super(AdaptivePropagation, self).__init__(aggr='add', **kwargs)

        if niter < 1:
            # reset_parameters() would otherwise fail with an opaque ZeroDivisionError.
            raise ValueError(f"niter must be >= 1 (got {niter}).")

        self.niter = niter
        # Halting unit (Q and q in equation 6). The `bias` flag is honoured here;
        # previously it was accepted but silently ignored.
        self.halt = Linear(h_size, 1, bias=bias)

        self.reg_params = list(self.halt.parameters())  # halting params
        self.dropout = Dropout()

        # Settings forwarded to `gcn_norm` in forward().
        self.improved = False
        self.add_self_loops = True

        # init params
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialise the halting unit.

        When the unit has a bias it is set to logit(1 / (niter + 1)), so that after the
        sigmoid the initial halting probability is around 1 / (niter + 1).
        """
        self.halt.reset_parameters()

        if self.halt.bias is not None:
            x = self.niter + 1
            b = math.log((1/x)/(1-(1/x)))
            self.halt.bias.data.fill_(b)

    def forward(self, local_preds: torch.FloatTensor, edge_index):
        """
        local_preds: node embeddings from local prediction network, one row per node
        edge_index: graph connectivity in COO format

        returns:
            Updated node embeddings, number of steps per node, and remainders (1 - accumulated
            halting probability) per node
        """
        sz = local_preds.size(0)  # num of nodes.
        dev = local_preds.device

        # Per-node state, allocated directly on the input's device.
        steps = torch.ones(sz, device=dev)  # steps for each node (K_i)
        sum_h = torch.zeros(sz, device=dev)  # accumulated halting probs
        continue_mask = torch.ones(sz, dtype=torch.bool, device=dev)  # active nodes
        x = torch.zeros_like(local_preds)  # accumulated embeddings

        # dropout of embedding.
        prop = self.dropout(local_preds)

        # propagation loop
        for _ in range(self.niter):
            old_prop = prop  # h^(t-1)

            continue_fmask = continue_mask.float()
            # Edge dropout with p=0.5; only active in training mode.
            drop_edge_index, _ = dropout_edge(edge_index, p=0.5, training=self.training)

            # GCN normalization using the util that is now available.  -> https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gcn_conv.html#GCNConv
            edge_index_norm, norm = gcn_norm(
                drop_edge_index, None,
                sz, self.improved,
                self.add_self_loops,
                self.flow, local_preds.dtype
            )

            prop = self.propagate(edge_index_norm, x=prop, norm=norm)
            # h^k_i = σ(Qz^k_i + q). squeeze(-1) drops only the feature dimension, so h stays
            # 1-D (one entry per node) even for a single-node graph.
            h = torch.sigmoid(self.halt(prop)).squeeze(-1)

            # here we do the soft update based on equation 7
            # K_i = min{k : ∑(j=1 to k) h^j_i >= 1 - ε}
            # 0.99 is equivalent to (1 - ε) where ε = 0.01
            prob_mask = ((sum_h + h) < 0.99) & continue_mask
            prob_fmask = prob_mask.float()

            # add another step for the nodes that continue and whose accumulated prob is below the threshold.
            steps = steps + prob_fmask
            # update the accumulation only for those nodes (the mask is 0 elsewhere, so no update).
            sum_h = sum_h + prob_fmask * h

            final_iter = steps < self.niter

            # the node has to continue and has not reached the step limit yet.
            condition = prob_mask & final_iter
            p = torch.where(condition, sum_h, 1-sum_h)  # p^k_i according to equation 8

            # dropout is also applied to the update mask (as in the reference code).
            to_update = self.dropout(continue_fmask).unsqueeze(1)

            # equation 9 -> soft update.
            # z̃_i = (1/K_i) * ∑(k=1 to K_i) p^k_i * z^k_i + (1-p^k_i) * z^(k-1)_i
            x = x + (prop * p.unsqueeze(1) + old_prop * (1-p).unsqueeze(1)) * to_update
            continue_mask = continue_mask & prob_mask

            # if all nodes halted, then stop.
            if (~continue_mask).all():
                break

        # continuation of equation 9 (the 1/K_i factor)
        x = x / steps.unsqueeze(1)

        # updated embeddings, steps, and R_i
        return x, steps, (1-sum_h)

    def message(self, x_j, norm):
        """Scale each neighbour embedding by its normalization coefficient."""
        return norm.view(-1, 1) * x_j

class APGCN(torch.nn.Module):
    """
    The actual Adaptive Propagation Graph Convolutional Network:
    an MLP producing per-node predictions followed by an adaptive propagation layer.
    """
    def __init__(self,
                 dataset,
                 niter=10,
                 prop_penalty=0.005,
                 hidden=[64],
                 dropout=0.5):
        """
        dataset: graph dataset exposing `data.x` and `num_classes`
        niter: max number of propagation steps
        prop_penalty: prop penalty α in equation 11
        hidden: sequence (list or tuple) of hidden layer sizes; it is never mutated
        dropout: dropout rate used in the MLP
        """
        super(APGCN, self).__init__()

        # layer sizes: input features -> hidden sizes -> classes.
        # list(hidden) lets callers pass a tuple as well as a list.
        num_features = [dataset.data.x.shape[1]] + list(hidden) + [dataset.num_classes]

        # as the authors did, we create the MLP before the propagation layer.
        layers = [
            Linear(in_features, out_features)
            for in_features, out_features in zip(num_features[:-1], num_features[1:])
        ]

        # propagation operates on the class-sized output of the MLP.
        self.prop = AdaptivePropagation(niter, dataset.num_classes)

        self.prop_penalty = prop_penalty  # alpha

        self.layers = ModuleList(layers)  # mlp

        # parameters are separated into regularized (first layer) and
        # non-regularized (remaining layers) groups -> they did this in their code.
        self.reg_params = list(layers[0].parameters())
        self.non_reg_params = list([p for l in layers[1:] for p in l.parameters()])
        self.dropout = Dropout(p=dropout)
        self.act_fn = ReLU()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the propagation layer and every MLP layer."""
        self.prop.reset_parameters()
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, data, return_propagation_cost=False):
        """
        data: PyG data object containing x and edge_index
        return_propagation_cost: kept for backward compatibility; it has no effect because
            the steps and remainders are always returned.

        returns:
            Log probabilities, number of steps, and remainders
        """
        x, edge_index = data.x, data.edge_index

        # MLP: dropout before every layer, non-linearity after every layer except the last.
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(self.dropout(x))
            if i < last_index:
                x = self.act_fn(x)

        # the adaptive propagation.
        x, steps, reminders = self.prop(x, edge_index)

        # log probabilities, steps, and remainders
        return torch.nn.functional.log_softmax(x, dim=1), steps, reminders

Seeds used by the authors:

In [4]:
import numpy as np

def gen_seeds(size: int = None) -> np.ndarray:
    """Draw random seeds from NumPy's global RNG.

    Values are sampled from the whole uint32 range (upper bound exclusive, hence the +1).
    `size` is forwarded to `np.random.randint`; with the default `None` a single
    value is returned.
    """
    exclusive_upper_bound = np.iinfo(np.uint32).max + 1
    seeds = np.random.randint(exclusive_upper_bound, size=size, dtype=np.uint32)
    return seeds

# Short seed list; its two values are the first two entries of `test_seeds`
# (presumably meant for quick experiments).
quick_seeds = [2144199730, 794209841]

# Twenty fixed seeds (presumably one per evaluation run).
test_seeds = [2144199730, 794209841, 2985733717, 2282690970, 1901557222,
        2009332812, 2266730407, 635625077, 3538425002, 960893189,
        497096336, 3940842554, 3594628340, 948012117, 3305901371,
        3644534211, 2297033685, 4092258879, 2590091101, 1694925034]

# Seed used by `set_train_val_test_split` to draw the development node set;
# it is the same value the notebook uses to seed torch and numpy globally.
development_seed = 4143496719


Data management code from the authors:

In [5]:
import warnings
from typing import Dict, Union, Tuple, Any
import numpy as np
import scipy.sparse as sp

__all__ = ['SparseGraph']

sparse_graph_properties = [
        'adj_matrix', 'attr_matrix', 'labels',
        'node_names', 'attr_names', 'class_names',
        'metadata']


class SparseGraph:
    """Attributed labeled graph stored in sparse matrix form.

    Parameters
    ----------
    adj_matrix
        Adjacency matrix in CSR format. Shape [num_nodes, num_nodes]
    attr_matrix
        Attribute matrix in CSR or numpy format. Shape [num_nodes, num_attr]
    labels
        Array, where each entry represents respective node's label(s). Shape [num_nodes]
        Alternatively, CSR matrix with labels in one-hot format. Shape [num_nodes, num_classes]
    node_names
        Names of nodes (as strings). Shape [num_nodes]
    attr_names
        Names of the attributes (as strings). Shape [num_attr]
    class_names
        Names of the class labels (as strings). Shape [num_classes]
    metadata
        Additional metadata such as text.

    Raises
    ------
    ValueError
        If the adjacency matrix is not sparse or not square, if the attribute matrix has an
        unsupported type, or if the sizes of the provided arrays do not agree.

    """
    def __init__(
            self, adj_matrix: sp.spmatrix,
            attr_matrix: Union[np.ndarray, sp.spmatrix] = None,
            labels: Union[np.ndarray, sp.spmatrix] = None,
            node_names: np.ndarray = None,
            attr_names: np.ndarray = None,
            class_names: np.ndarray = None,
            metadata: Any = None):
        # Make sure that the dimensions of matrices / arrays all agree
        if sp.isspmatrix(adj_matrix):
            adj_matrix = adj_matrix.tocsr().astype(np.float32)
        else:
            raise ValueError("Adjacency matrix must be in sparse format (got {0} instead)."
                             .format(type(adj_matrix)))

        if adj_matrix.shape[0] != adj_matrix.shape[1]:
            raise ValueError("Dimensions of the adjacency matrix don't agree.")

        if attr_matrix is not None:
            if sp.isspmatrix(attr_matrix):
                attr_matrix = attr_matrix.tocsr().astype(np.float32)
            elif isinstance(attr_matrix, np.ndarray):
                attr_matrix = attr_matrix.astype(np.float32)
            else:
                raise ValueError("Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)."
                                 .format(type(attr_matrix)))

            if attr_matrix.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency and attribute matrices don't agree.")

        if labels is not None:
            if labels.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the label vector don't agree.")

        if node_names is not None:
            if len(node_names) != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the node names don't agree.")

        if attr_names is not None:
            # Attribute names are meaningless without a matrix; fail with a clear message
            # instead of an AttributeError on `None.shape`.
            if attr_matrix is None:
                raise ValueError("Attribute names were provided without an attribute matrix.")
            if len(attr_names) != attr_matrix.shape[1]:
                raise ValueError("Dimensions of the attribute matrix and the attribute names don't agree.")

        self.adj_matrix = adj_matrix
        self.attr_matrix = attr_matrix
        self.labels = labels
        self.node_names = node_names
        self.attr_names = attr_names
        self.class_names = class_names
        self.metadata = metadata

    def num_nodes(self) -> int:
        """Get the number of nodes in the graph.
        """
        return self.adj_matrix.shape[0]

    def num_edges(self) -> int:
        """Get the number of edges in the graph.

        For undirected graphs, (i, j) and (j, i) are counted as _two_ edges.

        """
        return self.adj_matrix.nnz

    def get_neighbors(self, idx: int) -> np.ndarray:
        """Get the indices of neighbors of a given node.

        Parameters
        ----------
        idx
            Index of the node whose neighbors are of interest.

        """
        return self.adj_matrix[idx].indices

    def get_edgeid_to_idx_array(self) -> np.ndarray:
        """Return a Numpy Array that maps edgeids to the indices in the adjacency matrix.

        Returns
        -------
        np.ndarray
            The i'th entry contains the x- and y-coordinates of edge i in the adjacency matrix.
            Shape [num_edges, 2]

        """
        return np.transpose(self.adj_matrix.nonzero())

    def is_directed(self) -> bool:
        """Check if the graph is directed (adjacency matrix is not symmetric).
        """
        return (self.adj_matrix != self.adj_matrix.T).sum() != 0

    def to_undirected(self) -> 'SparseGraph':
        """Convert to an undirected graph (make adjacency matrix symmetric).

        The adjacency matrix of this object is replaced; `self` is returned.

        Raises
        ------
        ValueError
            If an edge and its opposing edge both exist but carry different weights.
        """
        idx = self.get_edgeid_to_idx_array().T
        ridx = np.ravel_multi_index(idx, self.adj_matrix.shape)
        ridx_rev = np.ravel_multi_index(idx[::-1], self.adj_matrix.shape)

        # Get duplicate edges (self-loops and opposing edges)
        dup_ridx = ridx[np.isin(ridx, ridx_rev)]
        dup_idx = np.unravel_index(dup_ridx, self.adj_matrix.shape)

        # Check if the adjacency matrix weights are symmetric (if nonzero)
        if len(dup_ridx) > 0 and not np.allclose(self.adj_matrix[dup_idx], self.adj_matrix[dup_idx[::-1]]):
            raise ValueError("Adjacency matrix weights of opposing edges differ.")

        # Create symmetric matrix
        new_adj_matrix = self.adj_matrix + self.adj_matrix.T
        if len(dup_ridx) > 0:
            # Entries present in both directions were counted twice by the sum above;
            # subtract one copy so they keep their original weight.
            new_adj_matrix[dup_idx] = (new_adj_matrix[dup_idx] - self.adj_matrix[dup_idx]).A1

        self.adj_matrix = new_adj_matrix
        return self

    def is_weighted(self) -> bool:
        """Check if the graph is weighted (edge weights other than 1).
        """
        return np.any(np.unique(self.adj_matrix[self.adj_matrix.nonzero()].A1) != 1)

    def to_unweighted(self) -> 'SparseGraph':
        """Convert to an unweighted graph (set all edge weights to 1).

        The stored values are overwritten in place; `self` is returned.
        """
        self.adj_matrix.data = np.ones_like(self.adj_matrix.data)
        return self

    def is_connected(self) -> bool:
        """Check if the graph is connected.
        """
        return sp.csgraph.connected_components(self.adj_matrix, return_labels=False) == 1

    def has_self_loops(self) -> bool:
        """Check if the graph has self-loops.
        """
        return not np.allclose(self.adj_matrix.diagonal(), 0)

    def __repr__(self) -> str:
        """Summarise the graph: directedness, weighting, connectivity, edge count and stored data."""
        props = []
        for prop_name in sparse_graph_properties:
            prop = getattr(self, prop_name)
            if prop is not None:
                if prop_name == 'metadata':
                    props.append(prop_name)
                else:
                    shape_string = 'x'.join([str(x) for x in prop.shape])
                    props.append("{} ({})".format(prop_name, shape_string))
        dir_string = 'Directed' if self.is_directed() else 'Undirected'
        weight_string = 'weighted' if self.is_weighted() else 'unweighted'
        conn_string = 'connected' if self.is_connected() else 'disconnected'
        loop_string = 'has self-loops' if self.has_self_loops() else 'no self-loops'
        return ("<{}, {} and {} SparseGraph with {} edges ({}). Data: {}>"
                .format(dir_string, weight_string, conn_string,
                        self.num_edges(), loop_string,
                        ', '.join(props)))

    # Quality of life (shortcuts)
    def standardize(
            self, make_unweighted: bool = True,
            make_undirected: bool = True,
            no_self_loops: bool = True,
            select_lcc: bool = True
            ) -> 'SparseGraph':
        """Perform common preprocessing steps: remove self-loops, make unweighted/undirected, select LCC.

        All changes are done inplace.

        Parameters
        ----------
        make_unweighted
            Whether to set all edge weights to 1.
        make_undirected
            Whether to make the adjacency matrix symmetric. Can only be used if make_unweighted is True.
        no_self_loops
            Whether to remove self loops.
        select_lcc
            Whether to select the largest connected component of the graph.

        """
        G = self
        if make_unweighted and G.is_weighted():
            G = G.to_unweighted()
        if make_undirected and G.is_directed():
            G = G.to_undirected()
        if no_self_loops and G.has_self_loops():
            # NOTE(review): this resolves the module-level `remove_self_loops` at call time.
            # The notebook also imports a function of the same name from torch_geometric.utils,
            # so the SparseGraph helper must be the most recent definition when this runs.
            G = remove_self_loops(G)
        if select_lcc and not G.is_connected():
            G = largest_connected_components(G, 1)
        return G

    def unpack(self) -> Tuple[sp.csr_matrix,
                              Union[np.ndarray, sp.csr_matrix],
                              Union[np.ndarray, sp.csr_matrix]]:
        """Return the (adj_matrix, attr_matrix, labels) triplet.
        """
        return self.adj_matrix, self.attr_matrix, self.labels

    def to_flat_dict(self) -> Dict[str, Any]:
        """Return flat dictionary containing all SparseGraph properties.

        Sparse matrices are split into `<name>.data`, `<name>.indices`, `<name>.indptr`
        and `<name>.shape` entries; everything else is stored under its own name.
        """
        data_dict = {}
        for key in sparse_graph_properties:
            val = getattr(self, key)
            if sp.isspmatrix(val):
                data_dict['{}.data'.format(key)] = val.data
                data_dict['{}.indices'.format(key)] = val.indices
                data_dict['{}.indptr'.format(key)] = val.indptr
                data_dict['{}.shape'.format(key)] = val.shape
            else:
                data_dict[key] = val
        return data_dict

    @staticmethod
    def from_flat_dict(data_dict: Dict[str, Any]) -> 'SparseGraph':
        """Initialize SparseGraph from a flat dictionary.

        The input dictionary is not modified. Keys using the deprecated '_' separator or the
        short 'adj' / 'attr' matrix names are still accepted and trigger a DeprecationWarning.

        Raises
        ------
        ValueError
            If the dictionary contains keys that are not SparseGraph properties.
        """
        # Work on a shallow copy so the caller's dictionary keeps its sparse-matrix entries.
        data_dict = dict(data_dict)
        init_dict = {}
        del_entries = []

        # Construct sparse matrices
        for key in data_dict.keys():
            if key.endswith('_data') or key.endswith('.data'):
                if key.endswith('_data'):
                    sep = '_'
                    warnings.warn(
                            "The separator used for sparse matrices during export (for .npz files) "
                            "is now '.' instead of '_'. Please update (re-save) your stored graphs.",
                            DeprecationWarning, stacklevel=2)
                else:
                    sep = '.'
                # Strip the 5-character suffix ('.data' / '_data') to get the matrix name.
                matrix_name = key[:-5]
                mat_data = key
                mat_indices = '{}{}indices'.format(matrix_name, sep)
                mat_indptr = '{}{}indptr'.format(matrix_name, sep)
                mat_shape = '{}{}shape'.format(matrix_name, sep)
                if matrix_name == 'adj' or matrix_name == 'attr':
                    warnings.warn(
                            "Matrices are exported (for .npz files) with full names now. "
                            "Please update (re-save) your stored graphs.",
                            DeprecationWarning, stacklevel=2)
                    matrix_name += '_matrix'
                init_dict[matrix_name] = sp.csr_matrix(
                        (data_dict[mat_data],
                         data_dict[mat_indices],
                         data_dict[mat_indptr]),
                        shape=data_dict[mat_shape])
                del_entries.extend([mat_data, mat_indices, mat_indptr, mat_shape])

        # Delete sparse matrix entries
        for del_entry in del_entries:
            del data_dict[del_entry]

        # Load everything else
        for key, val in data_dict.items():
            if ((val is not None) and (None not in val)):
                init_dict[key] = val

        # Check if the dictionary contains only entries in sparse_graph_properties
        unknown_keys = [key for key in init_dict.keys() if key not in sparse_graph_properties]
        if len(unknown_keys) > 0:
            raise ValueError("Input dictionary contains keys that are not SparseGraph properties ({})."
                             .format(unknown_keys))

        return SparseGraph(**init_dict)


def create_subgraph(
        sparse_graph: SparseGraph,
        _sentinel: None = None,
        nodes_to_remove: np.ndarray = None,
        nodes_to_keep: np.ndarray = None
        ) -> SparseGraph:
    """Restrict a graph to the specified subset of nodes.

    Exactly one of (nodes_to_remove, nodes_to_keep) should be provided, while the other stays None.
    Note that to avoid confusion, it is required to pass node indices as named arguments to this function.

    The input graph is modified in place (adjacency matrix, attribute matrix, labels and node
    names are re-assigned) and the same object is returned.

    Parameters
    ----------
    sparse_graph
        Input graph.
    _sentinel
        Internal, to prevent passing positional arguments. Do not use.
    nodes_to_remove
        Indices of nodes that have to removed.
    nodes_to_keep
        Indices of nodes that have to be kept. They are sorted before use.

    Returns
    -------
    SparseGraph
        The input graph, restricted to the selected nodes.

    Raises
    ------
    ValueError
        If `_sentinel` is used, or if not exactly one of the two index arguments is given.

    """
    # Check that arguments are passed correctly
    if _sentinel is not None:
        raise ValueError("Only call `create_subgraph` with named arguments',"
                         " (nodes_to_remove=...) or (nodes_to_keep=...).")
    if nodes_to_remove is None and nodes_to_keep is None:
        raise ValueError("Either nodes_to_remove or nodes_to_keep must be provided.")
    elif nodes_to_remove is not None and nodes_to_keep is not None:
        raise ValueError("Only one of nodes_to_remove or nodes_to_keep must be provided.")
    elif nodes_to_remove is not None:
        # A set gives constant-time membership tests instead of scanning the
        # removal list once per node.
        removed = set(np.asarray(list(nodes_to_remove)).ravel().tolist())
        nodes_to_keep = [i for i in range(sparse_graph.num_nodes()) if i not in removed]
    else:
        nodes_to_keep = sorted(nodes_to_keep)

    # Select rows first, then columns, of the adjacency matrix.
    sparse_graph.adj_matrix = sparse_graph.adj_matrix[nodes_to_keep][:, nodes_to_keep]
    if sparse_graph.attr_matrix is not None:
        sparse_graph.attr_matrix = sparse_graph.attr_matrix[nodes_to_keep]
    if sparse_graph.labels is not None:
        sparse_graph.labels = sparse_graph.labels[nodes_to_keep]
    if sparse_graph.node_names is not None:
        sparse_graph.node_names = sparse_graph.node_names[nodes_to_keep]
    return sparse_graph


def largest_connected_components(sparse_graph: SparseGraph, n_components: int = 1) -> SparseGraph:
    """Keep only the nodes that belong to the largest connected components.

    The node selection is applied through `create_subgraph`, whose result is returned.

    Parameters
    ----------
    sparse_graph
        Input graph.
    n_components
        Number of largest connected components to keep.

    Returns
    -------
    SparseGraph
        Graph in which only the nodes of the largest n_components components remain.

    """
    _, node_component = sp.csgraph.connected_components(sparse_graph.adj_matrix)
    sizes = np.bincount(node_component)
    # Component ids ordered from largest to smallest, truncated to the requested count.
    largest_first = np.argsort(sizes)[::-1]
    components_to_keep = largest_first[:n_components]
    # Node indices (ascending) whose component is among the selected ones.
    in_selected = np.isin(node_component, components_to_keep)
    nodes_to_keep = np.flatnonzero(in_selected).tolist()
    return create_subgraph(sparse_graph, nodes_to_keep=nodes_to_keep)


def remove_self_loops(sparse_graph: SparseGraph) -> SparseGraph:
    """Zero out the diagonal of the adjacency matrix (i.e. drop self loops).

    The graph's adjacency matrix is re-assigned when self loops are found and a warning
    reports how many were removed; the same graph object is returned.

    """
    diagonal = sparse_graph.adj_matrix.diagonal()
    num_self_loops = (~np.isclose(diagonal, 0)).sum()
    if not num_self_loops > 0:
        return sparse_graph

    # LIL format is used for the diagonal assignment, then the matrix goes back to CSR.
    adj_lil = sparse_graph.adj_matrix.tolil()
    adj_lil.setdiag(0)
    sparse_graph.adj_matrix = adj_lil.tocsr()
    warnings.warn("{0} self loops removed".format(num_self_loops))
    return sparse_graph


In [6]:
from numbers import Number
from typing import Union
from pathlib import Path
import numpy as np
import scipy.sparse as sp
from os.path import join

data_dir = os.path.join(os.getcwd(), "data")

def load_from_npz(file_name: str) -> SparseGraph:
    """Read a SparseGraph stored as a Numpy binary (.npz) archive.

    NOTE(review): the archive is opened with ``allow_pickle=True``, which can execute
    arbitrary code when unpickling; only load files from trusted sources.

    Parameters
    ----------
    file_name
        Name of the file to load.

    Returns
    -------
    SparseGraph
        Graph in sparse matrix format.

    """
    with np.load(file_name, allow_pickle=True) as npz_archive:
        # Materialise the archive as a plain dict before building the graph.
        flat_dict = dict(npz_archive)
        graph = SparseGraph.from_flat_dict(flat_dict)
    return graph


def load_dataset(name: str,
                 directory: Union[Path, str] = data_dir
                 ) -> SparseGraph:
    """Load a stored dataset by name.

    Parameters
    ----------
    name
        Name of the dataset to load; the '.npz' extension is appended when missing.
    directory
        Path to the directory where the datasets are stored.

    Returns
    -------
    SparseGraph
        The requested dataset in sparse format.

    Raises
    ------
    ValueError
        If the resulting file path does not exist.

    """
    base_dir = Path(directory) if isinstance(directory, str) else directory
    file_name = name if name.endswith('.npz') else name + '.npz'
    path_to_file = base_dir / file_name
    if not path_to_file.exists():
        raise ValueError("{} doesn't exist.".format(path_to_file))
    return load_from_npz(path_to_file)


def networkx_to_sparsegraph(
        nx_graph: Union['nx.Graph', 'nx.DiGraph'],
        label_name: str = None,
        sparse_node_attrs: bool = True,
        sparse_edge_attrs: bool = True
        ) -> 'SparseGraph':
    """Convert NetworkX graph to SparseGraph.

    Node attributes need to be numeric.
    Missing entries are interpreted as 0.
    Labels can be any object. If non-numeric they are interpreted as
    categorical and enumerated.

    This ignores all edge attributes except the edge weights.

    Parameters
    ----------
    nx_graph
        Graph to convert.
    label_name
        Name of the node attribute that holds the label. If None, no labels are extracted.
    sparse_node_attrs
        Whether the attribute matrix is built as a sparse matrix (True) or a dense array (False).
    sparse_edge_attrs
        Currently unused.

    Returns
    -------
    SparseGraph
        Converted graph.

    Raises
    ------
    ValueError
        If `label_name` is not a node attribute, or if a non-label attribute value is not a number.

    """
    import networkx as nx

    # Extract node names (only kept when at least one node is not identified by an int)
    int_names = all(isinstance(node, int) for node in nx_graph.nodes)
    if int_names:
        node_names = None
    else:
        node_names = np.array(nx_graph.nodes)
        nx_graph = nx.convert_node_labels_to_integers(nx_graph)

    # Extract adjacency matrix. Recent NetworkX versions return a SciPy sparse *array*,
    # which SparseGraph rejects (it checks `sp.isspmatrix`), so convert explicitly.
    adj_matrix = sp.csr_matrix(nx.adjacency_matrix(nx_graph))

    # Collect all node attribute names
    attrs = set()
    for _, node_data in nx_graph.nodes().data():
        attrs.update(node_data.keys())

    # Initialize labels and remove them from the attribute names
    if label_name is None:
        labels = None
    else:
        if label_name not in attrs:
            raise ValueError("No attribute with label name '{}' found.".format(label_name))
        attrs.remove(label_name)
        labels = [0 for _ in range(nx_graph.number_of_nodes())]

    attr_mapping = None
    if len(attrs) > 0:
        # Save attribute names if not integer
        all_integer = all((isinstance(attr, int) for attr in attrs))
        if all_integer:
            # Integer keys are used directly as column indices, so the matrix needs
            # max(key) + 1 columns (there are no names to count in this case).
            attr_names = None
            num_attrs = max(attrs) + 1
        else:
            attr_names = np.array(list(attrs))
            attr_mapping = {k: i for i, k in enumerate(attr_names)}
            num_attrs = len(attr_names)

        # Initialize attribute matrix
        if sparse_node_attrs:
            attr_matrix = sp.lil_matrix((nx_graph.number_of_nodes(), num_attrs), dtype=np.float32)
        else:
            attr_matrix = np.zeros((nx_graph.number_of_nodes(), num_attrs), dtype=np.float32)
    else:
        attr_matrix = None
        attr_names = None

    # Fill label and attribute matrices
    for inode, node_attrs in nx_graph.nodes.data():
        for key, val in node_attrs.items():
            if key == label_name:
                labels[inode] = val
            else:
                if not isinstance(val, Number):
                    if node_names is None:
                        raise ValueError("Node {} has attribute '{}' with value '{}', which is not a number."
                                         .format(inode, key, val))
                    else:
                        raise ValueError("Node '{}' has attribute '{}' with value '{}', which is not a number."
                                         .format(node_names[inode], key, val))
                if attr_mapping is None:
                    attr_matrix[inode, key] = val
                else:
                    attr_matrix[inode, attr_mapping[key]] = val
    if attr_matrix is not None and sparse_node_attrs:
        attr_matrix = attr_matrix.tocsr()

    # Convert labels to numbers
    if labels is None:
        class_names = None
    else:
        try:
            labels = np.array(labels, dtype=np.float32)
            class_names = None
        except ValueError:
            # Non-numeric labels: enumerate the distinct values and store their index.
            class_names = np.unique(labels)
            class_mapping = {k: i for i, k in enumerate(class_names)}
            labels_int = np.empty(nx_graph.number_of_nodes(), dtype=np.float32)
            for inode, label in enumerate(labels):
                labels_int[inode] = class_mapping[label]
            labels = labels_int

    return SparseGraph(
            adj_matrix=adj_matrix, attr_matrix=attr_matrix, labels=labels,
            node_names=node_names, attr_names=attr_names, class_names=class_names,
            metadata=None)


In [7]:
__author__ = "Stefan Weißenberger and Johannes Klicpera"
__license__ = "MIT"

import os

import numpy as np
from scipy.linalg import expm

import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import Planetoid, Amazon, Coauthor
import scipy.sparse as sp
import scipy.sparse.linalg as spla

DATA_PATH = 'data'

def normalize_attributes(attr_matrix):
    """L1-normalize every row of an attribute matrix.

    Each row is divided by its L1 norm. Norms smaller than ``1e-12`` are
    replaced by that floor before dividing, so all-zero rows stay zero
    instead of producing NaN/inf.

    Parameters
    ----------
    attr_matrix : scipy sparse matrix or np.ndarray
        Two-dimensional attribute matrix; rows are normalized independently.

    Returns
    -------
    The row-normalized matrix: a scipy sparse matrix for sparse input, an
    ``np.ndarray`` for dense input.
    """
    epsilon = 1e-12
    # Dispatch on "is sparse" rather than "is CSR": np.linalg.norm in the
    # dense branch cannot process CSC/COO/... containers.
    if sp.issparse(attr_matrix):
        attr_norms = spla.norm(attr_matrix, ord=1, axis=1)
        attr_invnorms = 1 / np.maximum(attr_norms, epsilon)
        return attr_matrix.multiply(attr_invnorms[:, np.newaxis])

    attr_norms = np.linalg.norm(attr_matrix, ord=1, axis=1)
    attr_invnorms = 1 / np.maximum(attr_norms, epsilon)
    return attr_matrix * attr_invnorms[:, np.newaxis]


def get_dataset(name: str, use_lcc: bool = True) -> InMemoryDataset:
    """Load the graph called `name` and attach it to ``InMemoryDataset``.

    The graph is read with ``load_dataset``, standardized, and converted to
    a ``Data`` object with row-normalized dense features, an edge index,
    integer labels and all-False train/val/test masks (the masks are meant
    to be filled in later by a split routine).

    Parameters
    ----------
    name : str
        Dataset name forwarded to ``load_dataset``.
    use_lcc : bool
        Forwarded to ``graph.standardize(select_lcc=...)``.

    Returns
    -------
    The ``InMemoryDataset`` class object with ``data`` and ``num_classes``
    set on it.
    """
    # NOTE: this is the InMemoryDataset *class*, not an instance, so the
    # attributes assigned below are stored on the class itself.
    dataset = InMemoryDataset
    graph = load_dataset(name)
    # Respect the caller's flag; it used to be ignored in favour of True.
    graph.standardize(select_lcc=use_lcc)
    new_y = torch.LongTensor(graph.labels)
    num_nodes = new_y.size(0)
    data = Data(
            x=torch.FloatTensor(normalize_attributes(graph.attr_matrix).toarray()),
            edge_index=torch.LongTensor(graph.get_edgeid_to_idx_array().T),
            y=new_y,
            train_mask=torch.zeros(num_nodes, dtype=torch.bool),
            test_mask=torch.zeros(num_nodes, dtype=torch.bool),
            val_mask=torch.zeros(num_nodes, dtype=torch.bool)
        )
    dataset.data = data
    dataset.num_classes = len(np.unique(new_y))
    return dataset


def set_train_val_test_split(
        seed: int,
        data: Data,
        num_development: int = 1500,
        num_per_class: int = 20,
        num_val: int = 500) -> Data:
    """Draw random train/val/test masks and store them on `data`.

    A development set of `num_development` nodes is sampled with the
    module-level ``development_seed`` (which must be defined before this
    function is called); every other node becomes a test node. Inside the
    development set, `num_per_class` training nodes are sampled per class
    and `num_val` validation nodes are sampled from the remainder, both
    using `seed`.

    Parameters
    ----------
    seed : int
        Seed for the train/validation sampling.
    data : Data
        Object with a label tensor ``y``; its ``train_mask``, ``val_mask``
        and ``test_mask`` attributes are overwritten.
    num_development : int
        Size of the development set.
    num_per_class : int
        Training nodes sampled for each class.
    num_val : int
        Validation nodes sampled from the development nodes not used for
        training.

    Returns
    -------
    The same `data` object, with the three boolean masks replaced.

    Raises
    ------
    ValueError
        Raised by ``RandomState.choice`` when more items are requested than
        are available (e.g. a class with fewer than `num_per_class`
        development nodes).
    """
    num_nodes = data.y.shape[0]

    # Development/test partition: independent of `seed`.
    rnd_state = np.random.RandomState(development_seed)
    development_idx = rnd_state.choice(num_nodes, num_development, replace=False)
    # Vectorized complement instead of a per-node `in` scan over the array.
    test_idx = np.setdiff1d(np.arange(num_nodes), development_idx)

    rnd_state = np.random.RandomState(seed)
    dev_labels = data.y[development_idx].cpu().numpy()
    num_classes = int(data.y.max()) + 1

    train_idx = []
    for c in range(num_classes):
        class_idx = development_idx[dev_labels == c]
        train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))

    # Boolean filtering keeps the original order of development_idx, so the
    # validation draw below is identical to the list-based version.
    val_candidates = development_idx[~np.isin(development_idx, train_idx)]
    val_idx = rnd_state.choice(val_candidates, num_val, replace=False)

    def get_mask(idx):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[idx] = True
        return mask

    data.train_mask = get_mask(train_idx)
    data.val_mask = get_mask(val_idx)
    data.test_mask = get_mask(test_idx)

    return data



Authors' utilities:

In [8]:
import time
import yaml
import torch
import logging
import pickle
import matplotlib.pyplot as plt
import scipy.sparse as sp
import numpy as np
import seaborn as sns
import torch.nn.functional as F
import seaborn as sns

from tqdm.notebook import tqdm
from torch.optim import Adam, Optimizer
from collections import defaultdict
from torch_geometric.data import Data, InMemoryDataset


def save_obj(obj, name):
    """Pickle `obj` to ``results/<name>.pkl``.

    The ``results`` directory is created when it does not exist yet, so the
    first save in a fresh working directory no longer fails.

    Parameters
    ----------
    obj : any picklable object
    name : str
        File stem; ``.pkl`` is appended.
    """
    os.makedirs('results', exist_ok=True)
    with open('results/'+ name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    """Return the object unpickled from ``results/<name>.pkl``.

    Only use this on files from a trusted source: unpickling can execute
    arbitrary code.
    """
    pickle_path = 'results/' + name + '.pkl'
    with open(pickle_path, 'rb') as handle:
        restored = pickle.load(handle)
    return restored

def summary(results):
    """Reduce per-run results to a mean and a bootstrap CI half-width.

    For every key ``k`` in `results` the report contains ``k`` (the mean
    over runs) and ``f'{k}_ci'`` (the largest absolute distance between the
    mean and the bounds of the 95% interval of 1000 bootstrapped means).

    Entries under ``'steps'`` and ``'probs'`` are treated as sequences of
    tensors: each tensor is first collapsed to its own mean. All other
    entries are used as they are.

    Parameters
    ----------
    results : dict
        Maps a metric name to the list of per-run values.

    Returns
    -------
    dict with two entries per input key.
    """
    report = {}
    for k, v in results.items():
        if k != 'steps' and k != 'probs':
            values = v
        else:
            # Use this key's own tensors; previously 'probs' was summarised
            # from results['steps'] by mistake.
            values = np.array([t.mean().cpu().detach().numpy() for t in v])
        boots_series = sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000)
        report[k] = np.mean(values)
        report[f'{k}_ci'] = np.max(np.abs(sns.utils.ci(boots_series, 95) - report[k]))
    return report

def plot_density(results):
    """Plot a kernel density estimate of the run-averaged step counts.

    Every tensor in ``results['steps']`` is moved to the CPU, truncated to
    integers and stacked as one row; the rows are then averaged
    element-wise (presumably giving one mean step count per node -- confirm
    against the model's output) and the density of those averages is drawn.

    Parameters
    ----------
    results : dict
        Must contain ``'steps'``: a sequence of equally sized tensors.
    """
    fig, ax = plt.subplots()

    steps_per_run = np.vstack(
        [(x.cpu().numpy()).astype(int) for x in results['steps']])
    mean_steps = np.mean(steps_per_run, axis=0)

    # kdeplot(fill=True) is the supported replacement for the deprecated
    # distplot(hist=False, kde_kws={'shade': True}).
    sns.kdeplot(x=mean_steps, fill=True, linewidth=3, ax=ax)
    plt.xlabel('Number of Steps')
    plt.ylabel('Density')
    plt.tight_layout()
    plt.show()

Training Functions:

In [9]:
def train(model: torch.nn.Module, optimizer: Optimizer, data: Data, train_halt, weight_decay: float):
    """Run a single optimisation step on the training nodes.

    The objective is the NLL on ``data.train_mask`` plus an L2 term over
    ``model.reg_params`` (scaled by ``weight_decay / 2``) plus
    ``model.prop_penalty`` times the mean of ``steps + remainders`` on the
    training nodes.

    Parameters
    ----------
    model : torch.nn.Module
        Must expose ``prop``, ``reg_params`` and ``prop_penalty`` and
        return a 3-tuple when called on `data` (the first element is fed to
        ``F.nll_loss``, so it is presumably log-probabilities).
    optimizer : Optimizer
    data : Data
        Provides ``y`` and ``train_mask``.
    train_halt : bool
        Written to ``requires_grad`` of every parameter of ``model.prop``.
    weight_decay : float
    """
    model.train()

    # Switch gradient tracking of the propagation parameters on or off.
    for halting_param in model.prop.parameters():
        halting_param.requires_grad = train_halt

    optimizer.zero_grad()
    log_probs, n_steps, remainders = model(data)

    fit_mask = data.train_mask
    objective = F.nll_loss(log_probs[fit_mask], data.y[fit_mask])
    squared_norm = sum(torch.sum(weight ** 2) for weight in model.reg_params)
    propagation_cost = (n_steps[fit_mask] + remainders[fit_mask]).mean()
    objective += weight_decay/2 * squared_norm + model.prop_penalty * propagation_cost

    objective.backward()
    optimizer.step()

def evaluate(model: torch.nn.Module, data: Data, test: bool, weight_decay: float):
    """Score the model on the train and validation masks without gradients.

    The returned loss is the same regularised objective used for training
    and is computed on ``data.train_mask``.

    Parameters
    ----------
    model : torch.nn.Module
        Must expose ``reg_params`` and ``prop_penalty`` and return a
        3-tuple when called on `data`.
    data : Data
        Provides ``y``, ``train_mask`` and item access to ``'train_mask'``
        and ``'val_mask'``.
    test : bool
        Not used by this function.
    weight_decay : float

    Returns
    -------
    (dict, tensor)
        The dict holds ``'steps'`` (second model output), ``'train_acc'``
        and ``'val_acc'``; the tensor is the loss.
    """
    model.eval()

    with torch.no_grad():
        log_probs, n_steps, remainders = model(data)

        fit_mask = data.train_mask
        objective = F.nll_loss(log_probs[fit_mask], data.y[fit_mask])
        squared_norm = sum(torch.sum(weight ** 2) for weight in model.reg_params)
        propagation_cost = (n_steps[fit_mask] + remainders[fit_mask]).mean()
        objective += weight_decay/2 * squared_norm + model.prop_penalty * propagation_cost

    metrics = {'steps': n_steps}
    for split in ('train', 'val'):
        split_mask = data[f'{split}_mask']
        _, predicted = log_probs[split_mask].max(1)
        n_correct = predicted.eq(data.y[split_mask]).sum().item()
        metrics[f'{split}_acc'] = n_correct / split_mask.sum().item()
    return metrics, objective


def test_acc(model: torch.nn.Module, data: Data):
    model.eval()

    with torch.no_grad():
        logits, steps, reminders = model(data)
    mask = data['test_mask']
    pred = logits[mask].max(1)[1]
    acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
    return acc

In [10]:
def run(dataset: InMemoryDataset,
        model: torch.nn.Module,
        seeds: np.ndarray,
        test: bool = False,
        max_epochs: int = 10000,
        patience: int = 100,
        lr: float = 0.01,
        weight_decay: float = 0.01,
        num_development: int = 1500,
        device: str = 'cuda'):
    """Train and score `model` repeatedly over several data splits.

    For every entry of `seeds` the experiment is repeated
    ``config['niter_per_seed']`` times (``config`` and ``gen_seeds`` are
    read from module scope). Each repetition seeds torch with
    ``gen_seeds()``, redraws the train/val/test masks, resets the model
    parameters and trains with Adam for at most `max_epochs` epochs.

    A checkpoint is taken whenever validation accuracy improves, or ties
    while the loss returned by ``evaluate`` is below the best loss so far.
    The patience counter grows on epochs without a checkpoint and is reset
    by a checkpoint or by a new best loss; training stops once it equals
    `patience`. The checkpoint is then restored and scored on the test
    mask.

    Parameters
    ----------
    dataset : InMemoryDataset
        Its ``data`` attribute is replaced by the re-split data on `device`.
    model : torch.nn.Module
        Moved to `device` and re-initialised in place for every repetition.
    seeds : np.ndarray
        Seeds for the train/validation split.
    test : bool
        Forwarded to ``evaluate``.
    max_epochs, patience, lr, weight_decay, num_development
        Training-loop and split settings as described above.
    device : str

    Returns
    -------
    dict mapping each recorded key (``'val_acc'``, ``'epoch'``,
    ``'runtime'``, ``'steps'``, ``'train_acc'``, ``'test_acc'``) to a list
    with one entry per repetition.
    """
    best_dict = defaultdict(list)

    for seed in tqdm(seeds):
        for _ in range(config['niter_per_seed']):
            torch_seed = gen_seeds()
            torch.manual_seed(seed=torch_seed)

            dataset.data = set_train_val_test_split(
                seed,
                dataset.data,
                num_development=num_development,
                num_per_class=20
                ).to(device)

            model.to(device).reset_parameters()
            optimizer = Adam(model.parameters(),lr=lr)

            patience_counter = 0
            best_loss = 999
            tmp_dict = {'val_acc': 0}

            start_time = time.perf_counter()
            for epoch in range(1, max_epochs + 1):
                if patience_counter == patience:
                    break

                # The propagation parameters are only trained every 5th epoch.
                train(model, optimizer, dataset.data, epoch%5==0, weight_decay)
                eval_dict, loss = evaluate(model, dataset.data, test, weight_decay)

                if(eval_dict['val_acc'] > tmp_dict['val_acc']) or (
                  (eval_dict['val_acc'] == tmp_dict['val_acc']) and loss < best_loss):
                    patience_counter = 0
                    tmp_dict['epoch'] = epoch
                    tmp_dict['runtime'] = time.perf_counter() - start_time

                    for k, v in eval_dict.items():
                        tmp_dict[k] = v

                    # Force a copy: `.cpu()` returns the very same tensor
                    # when the model already lives on the CPU, so the
                    # snapshot would otherwise follow later updates.
                    best_state = {key: value.detach().to('cpu', copy=True)
                                  for key, value in model.state_dict().items()}

                else:
                    patience_counter += 1

                if loss < best_loss:
                    best_loss = loss
                    patience_counter = 0

            model.load_state_dict(best_state)
            tmp_dict['test_acc'] = test_acc(model,dataset.data)
            print("Epoch: {:.1f}"" Train: {:.2f}"" Val: {:.2f}"" Test: {:.2f}".format(
                  tmp_dict['epoch'],
                  tmp_dict['train_acc'] * 100,
                  tmp_dict['val_acc'] * 100,
                  tmp_dict['test_acc'] * 100))

            for k, v in tmp_dict.items():
                best_dict[k].append(v)

    return dict(best_dict)

RUN:

In [12]:
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if torch.cuda.is_available():
    torch.cuda.synchronize()
# Datasets: 'citeseer', 'cora_ml', 'pubmed', 'ms_academic', 'amazon_electronics_computers', 'amazon_electronics_photo'
# num_development (same order): 1500, 1500, 1500, 5000, 1500, 1500
# weight_decay: 0 for the Amazon datasets, 8e-03 for the others
# NOTE(review): 'cora_ml' is selected below together with weight_decay 0,
# which does not match the guidance above -- confirm this is intentional.
config = {'dataset_name': 'cora_ml',
          'test': True,
          'use_lcc': True,
          'num_development': 1500,
          'niter_per_seed': 5,
          'hidden_units': 64,
          'lr': 0.01,
          'dropout': 0.5,
          'weight_decay': 0
         }

dataset = get_dataset(
    name=config['dataset_name'],
    use_lcc=config['use_lcc']
    )

dataset.data = dataset.data.to(device)

In [13]:
# Display the assembled Data object (features, edge index, labels and masks).
dataset.data

Data(x=[2810, 2879], edge_index=[2, 15962], y=[2810], train_mask=[2810], test_mask=[2810], val_mask=[2810])

In [None]:
# Build the model. APGCN is defined elsewhere in the notebook; the meaning of
# the positional argument 10 is not visible here -- TODO confirm.
model = APGCN(dataset,10, prop_penalty=0.05)

# Total number of scalar parameters in the model.
total_params = sum(
  param.numel() for param in model.parameters()
)

print(total_params)

# Repeat training over the test or validation seed list, depending on
# config['test']; the remaining hyper-parameters come from `config`.
results = run(
    dataset,
    model,
    seeds=test_seeds if config['test'] else val_seeds,
    #seeds= quick_seeds,
    lr=config['lr'],
    weight_decay=config['weight_decay'],
    test=config['test'],
    num_development=config['num_development'],
    device=device
    )

#save_obj(results,'results_' + config['dataset_name'])
# Mean and bootstrap confidence half-width for every recorded metric.
report = summary(results)

# Accuracies are stored as fractions and printed as percentages.
print("FINAL\n"
      "Train Accuracy: {:.2f} ± {:.2f}%\n"
      "Stopping Accuracy: {:.2f} ± {:.2f}%\n"
      "Test     Accuracy: {:.2f} ± {:.2f}%\n"
      "Steps: {:.2f} ± {:.2f}\n"
      "Epochs:  {:.2f} ± {:.2f}\n"
      "Runtime: {:.4f} ± {:.4f}\n"
      .format(
          report['train_acc'] * 100,
          report['train_acc_ci'] * 100,
          report['val_acc'] * 100,
          report['val_acc_ci'] * 100,
          report['test_acc']*100,
          report['test_acc_ci']*100,
          report['steps'],
          report['steps_ci'],
          report['epoch'],
          report['epoch_ci'],
          report['runtime'],
          report['runtime_ci']))

plot_density(results)

# Drop the references to the model and dataset and release cached GPU memory.
# NOTE: after this cell, `model` and `dataset` must be recreated before re-running.
del model, dataset
torch.cuda.empty_cache()

184783


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 66.0 Train: 99.29 Val: 80.00 Test: 80.46
Epoch: 37.0 Train: 93.57 Val: 82.00 Test: 80.61
Epoch: 132.0 Train: 100.00 Val: 82.00 Test: 81.98
Epoch: 87.0 Train: 100.00 Val: 82.40 Test: 81.53
Epoch: 88.0 Train: 100.00 Val: 80.60 Test: 80.46
Epoch: 81.0 Train: 97.86 Val: 83.20 Test: 82.44
Epoch: 57.0 Train: 96.43 Val: 82.80 Test: 84.43
Epoch: 33.0 Train: 95.71 Val: 82.00 Test: 82.14
Epoch: 52.0 Train: 97.14 Val: 81.80 Test: 82.06
Epoch: 64.0 Train: 96.43 Val: 81.60 Test: 81.60
Epoch: 70.0 Train: 99.29 Val: 84.20 Test: 85.04
Epoch: 85.0 Train: 100.00 Val: 84.00 Test: 85.73
Epoch: 104.0 Train: 99.29 Val: 83.80 Test: 85.65
