# Partition-Based Active Learning for Graph Neural Networks - Demo Notebook


## Library Installation


In [None]:
import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric 

In [None]:
!pip install dgl
!pip install ogb

In [14]:
from __future__ import division
from __future__ import print_function

import os
import argparse
import json
import codecs
from timeit import default_timer as timer
from copy import deepcopy

import random
import numpy as np

from sklearn import cluster
import sklearn.metrics as metrics

import dgl
import networkx as nx
from networkx.algorithms.link_analysis.pagerank_alg import pagerank
from networkx.algorithms.community.quality import modularity
from networkx.utils.mapped_queue import MappedQueue

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.datasets import Planetoid, CoraFull, Coauthor
from ogb.nodeproppred import PygNodePropPredDataset

## Queries

In [4]:
class ActiveLearning:
    """
    An active learning framework that...
    * queries from an oracle;
    * updates its known set,
    * trains the GNN model, and
    * evaluate the Macro F-1 score.
    """

    def __init__(self, data, model, seed, args):
        self.round = 0
        self.data = data
        self.model = model
        self.seed = seed
        self.args = args
        self.retrain = args.retrain
        self.clf = None
        self.aggregated = None
        self.num_centers = args.num_centers
        self.num_parts = -1

    def query(self, b):
        pass

    def update(self, train_mask):
        self.data.train_mask = train_mask
        self.round += 1

    def train(self):
        if self.retrain:
            self.clf = deepcopy(self.model).to(self.args.device)
        else:
            self.clf = self.model.to(self.args.device)
        optimizer = optim.Adam(
            self.clf.parameters(), lr=self.args.lr,
            weight_decay=self.args.weight_decay)
        for epoch in range(self.args.epochs):
            self.clf.train()
            optimizer.zero_grad()
            out = self.clf(self.data.x, self.data.adj_t)
            true = self.data.y
            if len(true.shape) > 1:
                true = true.squeeze(1)
            loss = F.cross_entropy(
                out[self.data.train_mask],
                true[self.data.train_mask])
            if self.args.verbose == 2:
                print('Epoch {:03d}: Training loss: {:.4f}'.format(epoch, loss))
            loss.backward()
            optimizer.step()

    def evaluate(self):
        self.clf.eval()
        logits = self.clf(self.data.x, self.data.adj_t)
        y_pred = logits.max(1)[1].cpu()
        y_true = self.data.y.cpu()
        f1 = metrics.f1_score(y_true, y_pred, average='macro')
        acc = metrics.f1_score(y_true, y_pred, average='micro')
        if self.args.verbose == 2:
            print('Macro-f1 score: {:.4f}'.format(f1))
            print('Micro-f1 score: {:.4f}'.format(acc))
        return f1, acc

    def get_node_representation(self, rep='aggregation', encoder='gcn'):

        if rep == 'aggregation':
            if self.aggregated is None:

                # Dense Implementation

                # A_ = self.data.adj_t.to_dense() + torch.eye(self.data.num_nodes, device=self.args.device)  # A + I
                # D_ = torch.diag(A_.sum(dim=0) ** (-1 / 2))  # (A + I)^(-1/2)
                # A_norm = torch.sparse.mm(torch.sparse.mm(D_, A_), D_)
                # self.aggregated = A_norm
                # for i in range(self.clf.num_layers - 1):
                #     self.aggregated = torch.sparse.mm(self.aggregated, A_norm)
                # self.aggregated = torch.sparse.mm(self.aggregated, self.data.x)

                # Sparse Implementation

                # nnz = len(self.data.adj_t.storage._row)
                # indice = torch.cat([self.data.adj_t.storage._row.unsqueeze(dim=0),
                #                     self.data.adj_t.storage._col.unsqueeze(dim=0)], dim=0)
                # values = torch.ones(nnz, device=self.args.device)
                # A = torch.sparse_coo_tensor(indice, values, [self.data.num_nodes, self.data.num_nodes])
                # diag = torch.tensor(range(self.data.num_nodes), device=self.args.device).unsqueeze(dim=0)
                # indice = torch.cat([diag, diag], dim=0)
                # values = torch.ones(self.data.num_nodes, device=self.args.device)
                # I = torch.sparse_coo_tensor(indice, values, [self.data.num_nodes, self.data.num_nodes])
                # A = A + I
                # value = torch.sparse.sum(A, dim=0) ** (-1 / 2)
                # D = torch.sparse_coo_tensor(indice, value.to_dense(), [self.data.num_nodes, self.data.num_nodes])
                # A_norm = torch.sparse.mm(torch.sparse.mm(D, A), D)
                # self.aggregated = self.data.x.to_sparse()
                # for i in range(self.clf.num_layers):
                #     self.aggregated = torch.sparse.mm(A_norm, self.aggregated)
                # self.aggregated = self.aggregated.to_dense()

                feat_dim = self.data.x.size(1)
                if encoder == 'sage':
                    conv = SAGEConv(feat_dim, feat_dim, bias=False)
                    conv.lin_l.weight = torch.nn.Parameter(torch.eye(feat_dim))
                    conv.lin_r.weight = torch.nn.Parameter(torch.eye(feat_dim))
                else:
                    conv = GCNConv(feat_dim, feat_dim, cached=True, bias=False)
                    conv.lin.weight = torch.nn.Parameter(torch.eye(feat_dim))
                conv.to(self.args.device)
                with torch.no_grad():
                    self.aggregated = conv(self.data.x, self.data.adj_t)
                    self.aggregated = conv(self.aggregated, self.data.adj_t)
            return self.aggregated

        elif rep == 'embedding':
            with torch.no_grad():
                embed = self.clf.embed(self.data.x, self.data.adj_t)
            return embed

        else:
            return self.data.x

    def split_cluster(self, b, partitions, x_embed=None, method='default'):

        if method == 'inertia':
            part_size = []
            for i in range(self.num_parts):
                part_id = np.where(partitions == i)[0]
                x = x_embed[part_id]
                kmeans = Cluster(n_clusters=1, n_dim=x_embed.shape[1], seed=self.seed, device=self.args.device)
                kmeans.train(x.cpu())
                inertia = kmeans.get_inertia()
                part_size.append(inertia)

            part_size = np.rint(b * np.array(part_size) / sum(part_size)).astype(int)
            part_size = np.maximum(self.num_centers, part_size)
            i = 0
            while part_size.sum() - b != 0:
                if part_size.sum() - b > 0:
                    i = self.num_parts - 1 if i <= 0 else i
                    while part_size[i] <= 1:
                        i -= 1
                    part_size[i] -= 1
                    i -= 1
                else:
                    i = 0 if i >= self.num_parts else i
                    part_size[i] += 1
                    i += 1

        elif method == 'size':
            part_size = []
            for i in range(self.num_parts):
                part_size.append(len(np.where(partitions == i)[0]))
            part_size = np.rint(b * np.array(part_size) / sum(part_size)).astype(int)
            part_size = np.maximum(self.num_centers, part_size)
            i = 0
            while part_size.sum() - b != 0:
                if part_size.sum() - b > 0:
                    i = self.num_parts - 1 if i <= 0 else i
                    while part_size[i] <= 1:
                        i -= 1
                    part_size[i] -= 1
                    i -= 1
                else:
                    i = 0 if i >= self.num_parts else i
                    part_size[i] += 1
                    i += 1

        else:
            part_size = [b // self.num_parts for _ in range(self.num_parts)]
            for i in range(b % self.num_parts):
                part_size[i] += 1

        return part_size

    def __str__(self):
        return "Active Learning Agent (uninitialized)"


class Random(ActiveLearning):
    """
    Random:
    The Random Sampling method chooses nodes uniformly at random,
    similarly as the commonly used semi-supervised learning experiment setting for GCN.
    """

    def __init__(self, data, model, seed, args):
        super(Random, self).__init__(data, model, seed, args)

    def query(self, b):
        indice = np.random.choice(
            np.where(self.data.train_mask == 0)[0], b, replace=False
        )
        return torch.tensor(indice)

    def __str__(self):
        return "Random"


class Density(ActiveLearning):
    """
    Density:
    The Density method first performs a clustering algorithm on the hidden representations of the nodes,
    and then chooses nodes with maximum density score, which is (approximately) inversely proportional to
    the L2-distance between each node and its cluster center.
    """

    def __init__(self, data, model, seed, args):
        super(Density, self).__init__(data, model, seed, args)

    def query(self, b):
        # Get propagated nodes
        x_embed = self.get_node_representation('embedding').cpu()

        # Perform K-Means as approximation
        kmeans = Cluster(n_clusters=b, n_dim=x_embed.shape[1], seed=self.seed, device=self.args.device)
        kmeans.train(x_embed)

        # Calculate density
        centers = kmeans.get_centroids()
        label = kmeans.predict(x_embed)
        centers = centers[label]
        dist_map = torch.linalg.norm(x_embed - centers, dim=1)
        density = 1 / (1 + dist_map)

        density[np.where(self.data.train_mask != 0)[0]] = 0
        _, indices = torch.topk(density, k=b)

        return indices

    def __str__(self):
        return "Density"


class Uncertainty(ActiveLearning):
    """
    Uncertainty:
    The Uncertainty method chooses the nodes with maximum entropy on the predicted class distribution.
    """

    def __init__(self, data, model, seed, args):
        super(Uncertainty, self).__init__(data, model, seed, args)

    def query(self, b):
        logits = self.clf(self.data.x, self.data.adj_t)
        entropy = -torch.sum(F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1), dim=1)
        entropy[np.where(self.data.train_mask != 0)[0]] = 0
        _, indices = torch.topk(entropy, k=b)
        return indices

    def __str__(self):
        return "Uncertainty"


class CoreSetGreedy(ActiveLearning):
    """
    CoreSet:
    The CoreSet method performs a K-Center clustering over the hidden representations of nodes.
    A time-efficient greedy approximation version by choosing node closest to the cluster centers.
    """

    def __init__(self, data, model, seed, args):
        super(CoreSetGreedy, self).__init__(data, model, seed, args)

    def query(self, b):

        embed = self.get_node_representation('embedding').cpu()
        indices = list(np.where(self.data.train_mask != 0)[0])

        for i in range(b):
            dist = metrics.pairwise_distances(embed, embed[indices], metric='euclidean')
            min_distances = torch.min(torch.tensor(dist), dim=1)[0]
            new_index = min_distances.argmax()
            indices.append(int(new_index))
        return indices

    def __str__(self):
        return "Core Set (Greedy)"


class CoreSetMIP(ActiveLearning):
    """
    CoreSet:
    The CoreSet method performs a K-Center clustering over the hidden representations of nodes.
    Optimized by gurobipy MIP.
    """

    def __init__(self, data, model, seed, args):
        super(CoreSetMIP, self).__init__(data, model, seed, args)

    def query(self, b):
        import gurobipy

        # Get distance matrix
        embed = self.get_node_representation('embedding')
        dist_mat = embed.matmul(embed.t())
        sq = dist_mat.diagonal().reshape(self.data.num_nodes, 1)
        dist_mat = torch.sqrt(-dist_mat * 2 + sq + sq.t())

        # Perform greedy K-center
        mask = self.data.train_mask.copy()
        mat = dist_mat[~mask, :][:, mask]
        _, indices = mat.min(dim=1)[0].topk(k=b)
        indices = torch.arange(self.data.num_nodes)[~mask][indices]
        mask[indices] = True

        # Robust approximation
        opt = mat.min(dim=1)[0].max()
        ub = opt
        lb = opt / 2.0
        xx, yy = np.where(dist_mat <= opt)
        dd = dist_mat[xx, yy]

        flag = self.data.train_mask.copy()
        subset = np.where(flag == 0)[0].tolist()

        # Solve MIP for fac_loc
        x = {}
        y = {}
        z = {}
        n = self.data.num_nodes
        m = len(xx)

        model = gurobipy.Model("k-center")
        for i in range(n):
            z[i] = model.addVar(
                obj=1, ub=0.0, vtype="B", name="z_{}".format(i))

        for i in range(m):
            _x = xx[i]
            _y = yy[i]
            if _y not in y:
                if _y in subset:
                    y[_y] = model.addVar(
                        obj=0, ub=1.0, lb=1.0, vtype="B", name="y_{}".format(_y))
                else:
                    y[_y] = model.addVar(
                        obj=0, vtype="B", name="y_{}".format(_y))
            x[_x, _y] = model.addVar(
                obj=0, vtype="B", name="x_{},{}".format(_x, _y))
        model.update()

        coef = [1 for j in range(n)]
        var = [y[j] for j in range(n)]
        model.addConstr(
            gurobipy.LinExpr(coef, var), "=", rhs=b + len(subset), name="k_center")

        for i in range(m):
            _x = xx[i]
            _y = yy[i]
            model.addConstr(
                x[_x, _y], "<", y[_y], name="Strong_{},{}".format(_x, _y))

        yyy = {}
        for v in range(m):
            _x = xx[v]
            _y = yy[v]
            if _x not in yyy:
                yyy[_x] = []
            if _y not in yyy[_x]:
                yyy[_x].append(_y)

        for _x in yyy:
            coef = []
            var = []
            for _y in yyy[_x]:
                coef.append(1)
                var.append(x[_x, _y])
            coef.append(1)
            var.append(z[_x])
            model.addConstr(
                gurobipy.LinExpr(coef, var), "=", 1, name="Assign{}".format(_x))

        # Approximate
        delta = 1e-7
        sol_file = None
        while ub - lb > delta:
            cur_r = (ub + lb) / 2.0
            viol = np.where(dd > cur_r)
            new_max_d = torch.min(dd[dd >= cur_r])
            new_min_d = torch.max(dd[dd <= cur_r])
            for v in viol[0]:
                x[xx[v], yy[v]].UB = 0

            model.update()
            r = model.optimize()
            if model.getAttr(gurobipy.GRB.Attr.Status) == gurobipy.GRB.INFEASIBLE:
                failed = True
                print("Infeasible")
            elif sum([z[i].X for i in range(len(z))]) > 0:
                failed = True
                print("Failed")
            else:
                failed = False
            if failed:
                lb = max(cur_r, new_max_d)
                for v in viol[0]:
                    x[xx[v], yy[v]].UB = 1
            else:
                print("sol founded", cur_r, lb, ub)
                ub = min(cur_r, new_min_d)
                sol_file = "s_{}_solution_{}.sol".format(b, cur_r)
                model.write(sol_file)

        # Process results
        if sol_file is not None:
            results = open(sol_file).read().split('\n')
            results_nodes = filter(lambda x1: 'y' in x1,
                                   filter(lambda x1: '#' not in x1, results))
            string_to_id = lambda x1: (
                int(x1.split(' ')[0].split('_')[1]),
                int(x1.split(' ')[1]))
            result_node_ids = map(string_to_id, results_nodes)
            centers = []
            for node_result in result_node_ids:
                if node_result[1] > 0:
                    centers.append(node_result[0])
            return torch.tensor(centers)
        else:
            return None

    def __str__(self):
        return "Core Set (MIP)"


class Degree(ActiveLearning):
    """
    Centrality:
    The Centrality method chooses nodes with the largest graph centrality metric value.
    This framework chooses node degree as the metric.
    """

    def __init__(self, data, model, seed, args):
        super(Degree, self).__init__(data, model, seed, args)

    def query(self, b):

        if hasattr(self.data.adj_t.storage, '_row'):
            degree = self.data.adj_t.sum(dim=0)
        else:
            indice = torch.cat([self.data.adj_t[0].unsqueeze(dim=0),
                                self.data.adj_t[1].unsqueeze(dim=0)], dim=0)
            values = torch.ones(self.data.adj_t.shape[1], device=self.args.device)
            adj = torch.sparse_coo_tensor(indice, values, [self.data.num_nodes, self.data.num_nodes]).to_dense()
            degree = adj.sum(dim=0)

        degree[np.where(self.data.train_mask != 0)[0]] = 0
        _, indices = torch.topk(degree, k=b)
        return indices

    def __str__(self):
        return "Centrality (Degree)"


class PageRank(ActiveLearning):
    """
    PageRank:
    The Centrality method chooses nodes with the largest graph centrality metric value.
    This framework chooses node degree as the metric.
    """

    def __init__(self, data, model, seed, args):
        super(PageRank, self).__init__(data, model, seed, args)

    def query(self, b):
        page = torch.tensor(list(pagerank(self.data.g).values()))
        page[np.where(self.data.train_mask != 0)[0]] = 0
        _, indices = torch.topk(page, k=b)
        return indices

    def __str__(self):
        return "Centrality (PageRank)"


class AGE(ActiveLearning):
    """
    AGE:
    AGE defines the informativeness of nodes by linearly combining three metrics:
    centrality, density and uncertainty.
    It further chooses nodes with the highest scores.
    """

    def __init__(self, data, model, seed, args):
        super(AGE, self).__init__(data, model, seed, args)

    def query(self, b):

        # Get entropy
        logits = self.clf(self.data.x, self.data.adj_t)
        entropy = -torch.sum(F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1), dim=1)

        # Get centrality
        page = torch.tensor(list(pagerank(self.data.g).values()),
                            dtype=logits.dtype, device=self.args.device)

        # Get density
        x = self.get_node_representation('embedding').cpu()
        N = x.shape[0]

        kmeans = Cluster(n_clusters=b, n_dim=x.shape[1], seed=self.seed, device=self.args.device)
        kmeans.train(x)
        centers = kmeans.get_centroids()
        label = kmeans.predict(x)

        x = x.to(logits.device)
        centers = torch.tensor(centers[label], dtype=x.dtype, device=x.device)
        dist_map = torch.linalg.norm(x - centers, dim=1).to(logits.dtype)
        density = 1 / (1 + dist_map)

        # Get percentile
        percentile = (torch.arange(N, dtype=logits.dtype, device=self.args.device) / N)
        id_sorted = density.argsort(descending=False)
        density[id_sorted] = percentile
        id_sorted = entropy.argsort(descending=False)
        entropy[id_sorted] = percentile
        id_sorted = page.argsort(descending=False)
        page[id_sorted] = percentile

        # Get linear combination
        alpha, beta, gamma = self.data.params['age']
        age_score = alpha * entropy + beta * density + gamma * page
        age_score[np.where(self.data.train_mask != 0)[0]] = 0
        _, indices = torch.topk(age_score, k=b)
        return indices


class ClusterBased(ActiveLearning):
    """
    Cluster:
    The cluster method first performs clustering (K-Means as approximation of K-Medoids)
    on the aggregated node features and then choose the nodes closest to the K-means centers.

    rep {'feature', 'embedding', 'aggregation'}
    init {‘k-means++’, ‘random’}
    """

    def __init__(self, data, model, seed, args,
                 representation='aggregation',
                 encoder='gcn',
                 initialization='k-means++'):
        super(ClusterBased, self).__init__(data, model, seed, args)
        self.representation = representation
        self.encoder = encoder
        self.initialization = None if initialization != 'k-means++' else initialization

    def query(self, b):

        # Get node representations
        x = self.get_node_representation(self.representation, self.encoder)

        # Perform K-Means clustering:
        kmeans = Cluster(n_clusters=b, n_dim=x.shape[1], seed=self.seed, device=self.args.device)
        kmeans.train(x.cpu().numpy())
        centers = torch.tensor(kmeans.get_centroids(), dtype=x.dtype, device=x.device)

        # Obtain the centers
        indices = list(np.where(self.data.train_mask != 0)[0])
        for center in centers:
            center = center.to(dtype=x.dtype, device=x.device)
            dist_map = torch.linalg.norm(x - center, dim=1)
            dist_map[indices] = torch.tensor(np.infty, dtype=dist_map.dtype, device=dist_map.device)
            idx = int(torch.argmin(dist_map))
            indices.append(idx)

        return torch.tensor(indices)


class PartitionBased(ActiveLearning):
    """
    Partition:
    Our method, which first partitions the graph into communities, and
    performs clustering over each graph community on the aggregated node features.

    rep {'none', 'embed', 'prop'}
    init {‘k-means++’, ‘random’}
    compensation {float: 0 - 1}
    """

    def __init__(self, data, model, seed, args,
                 representation='aggregation',
                 encoder='gcn',
                 initialization='k-means++',
                 compensation=1):
        super(PartitionBased, self).__init__(data, model, seed, args)
        self.representation = representation
        self.encoder = encoder
        self.initialization = None if initialization != 'k-means++' else initialization
        self.compensation = compensation

    def query(self, b):

        # Perform graph partition (preprocessed)
        self.num_parts = int(np.ceil(b / self.num_centers))
        compensation = 0
        if self.num_parts > self.data.max_part:
            self.num_parts = self.data.max_part
            compensation = self.compensation
        partitions = np.array(self.data.partitions[self.num_parts].cpu())

        # Get node representations
        x = self.get_node_representation(self.representation, self.encoder)

        # Determine the number of partitions and number of centers
        part_size = self.split_cluster(b, partitions, x)

        # Iterate over each partition
        indices = list(np.where(self.data.train_mask != 0)[0])
        for i in range(self.num_parts):
            part_id = np.where(partitions == i)[0]
            masked_id = [i for i, x in enumerate(part_id) if x in indices]
            xi = x[part_id]

            n_clusters = part_size[i]
            if n_clusters <= 0:
                continue

            # Perform K-Means clustering:
            kmeans = Cluster(n_clusters=n_clusters, n_dim=xi.shape[1], seed=self.seed, device=self.args.device)
            kmeans.train(xi.cpu().numpy())
            centers = kmeans.get_centroids()

            # Compensating for the interference across partitions
            dist = None
            if self.compensation > 0:
                dist_to_center = torch.ones(x.shape[0], dtype=x.dtype, device=x.device) * np.infty
                for idx in indices:
                    dist_to_center = torch.minimum(dist_to_center, torch.linalg.norm(x - x[idx], dim=1))
                dist = dist_to_center[part_id]

            # Obtain the centers
            for center in centers:
                center = torch.tensor(center, dtype=x.dtype, device=x.device)
                dist_map = torch.linalg.norm(xi - center, dim=1)
                if self.compensation > 0:
                    dist_map -= dist * compensation
                dist_map[masked_id] = torch.tensor(np.infty, dtype=dist_map.dtype, device=dist_map.device)
                idx = int(torch.argmin(dist_map))
                masked_id.append(idx)
                indices.append(part_id[idx])

        return torch.tensor(indices)


## Models

In [5]:
class Cluster:
    """
    Kmeans Clustering
    """
    def __init__(self, n_clusters, n_dim, seed,
                 implementation='sklearn',
                 init='k-means++',
                 device=torch.cuda.is_available()):

        assert implementation in ['sklearn', 'faiss', 'cuml']
        assert init in ['k-means++', 'random']

        self.n_clusters = n_clusters
        self.n_dim = n_dim
        self.implementation = implementation
        self.initialization = init
        self.model = None

        if implementation == 'sklearn':
            self.model = cluster.KMeans(n_clusters=n_clusters, init=init, random_state=seed)
        elif implementation == 'faiss':
            import faiss
            self.model = faiss.Kmeans(n_dim, n_clusters, niter=20, nredo=10, seed=seed, gpu=device != 'cpu')
        elif implementation == 'cuml':
            import cuml
            if init == 'k-means++':
                init = 'scalable-kmeans++'
            self.model = cuml.KMeans(n_dim, n_clusters, random_state=seed, init=init, output_type='numpy')
        else:
            raise NotImplemented

    def train(self, x):
        if self.implementation == 'sklearn':
            self.model.fit(x)
        elif self.implementation == 'faiss':
            if self.initialization == 'kmeans++':
                init_centroids = self._kmeans_plusplus(x, self.n_clusters).cpu().numpy()
            else:
                init_centroids = None
            self.model.train(x, init_centroids=init_centroids)
        elif self.implementation == 'cuml':
            self.model.fit(x)
        else:
            raise NotImplemented

    def predict(self, x):
        if self.implementation == 'sklearn':
            return self.model.predict(x)
        elif self.implementation == 'faiss':
            _, labels = self.model.index.search(x, 1)
            return labels
        else:
            raise NotImplemented

    def get_centroids(self):
        if self.implementation == 'sklearn':
            return self.model.cluster_centers_
        elif self.implementation == 'faiss':
            return self.model.centroids
        elif self.implementation == 'cuml':
            return self.model.cluster_centers_
        else:
            raise NotImplemented

    def get_inertia(self):
        if self.implementation == 'sklearn':
            return self.model.inertia_
        else:
            raise NotImplemented

    @staticmethod
    def _kmeans_plusplus(X, n_clusters):
        """
        K-means++ initialization in PyTorch for Faiss.

        Modified from sklearn version of implementation.
        https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/cluster/_kmeans.py
        """

        n_samples, n_features = X.shape

        # Set the number of local seeding trials if none is given
        n_local_trials = 2 + int(np.log(n_clusters))

        # Pick first center randomly and track index of point
        center_id = torch.randint(n_samples, (1,)).item()
        centers = [X[center_id]]

        # Initialize list of closest distances and calculate current potential
        closest_dist_sq = torch.cdist(X, X[center_id].unsqueeze(dim=0)).pow(2).squeeze()
        current_pot = closest_dist_sq.sum()

        # Pick the remaining n_clusters-1 points
        for c in range(1, n_clusters):
            # Choose center candidates by sampling with probability proportional
            # to the squared distance to the closest existing center
            rand_vals = torch.rand(n_local_trials).to(current_pot.device) * current_pot
            candidate_ids = torch.searchsorted(torch.cumsum(closest_dist_sq.flatten(), dim=0), rand_vals)

            # Numerical imprecision can result in a candidate_id out of range
            torch.clip(candidate_ids, min=None, max=closest_dist_sq.shape[0] - 1, out=candidate_ids)

            # Compute distances to center candidates
            distance_to_candidates = torch.cdist(X[candidate_ids].unsqueeze(dim=0), X).pow(2).squeeze()

            # update closest distances squared and potential for each candidate
            torch.minimum(closest_dist_sq, distance_to_candidates, out=distance_to_candidates)
            candidates_pot = distance_to_candidates.sum(dim=1)

            # Decide which candidate is the best
            best_candidate = torch.argmin(candidates_pot)
            current_pot = candidates_pot[best_candidate]
            closest_dist_sq = distance_to_candidates[best_candidate]
            best_candidate = candidate_ids[best_candidate]

            # Permanently add best center candidate found in local tries
            centers.append(X[best_candidate])

        centers = torch.stack(centers, dim=0).to(dtype=X.dtype)
        return centers


class GCN(torch.nn.Module):
    """
    Graph Convolutional Network
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5, batchnorm=False, activation="relu"):
        super(GCN, self).__init__()

        assert activation in ["relu", "elu"]

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
        self.bns = torch.nn.ModuleList()
        if batchnorm:
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))

        self.num_layers = num_layers
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=True))
            if batchnorm:
                self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))

        self.dropout = dropout
        self.activation = getattr(F, activation)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        x = self.embed(x, adj_t)
        x = self.convs[-1](x, adj_t)
        return x

    def embed(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            if len(self.bns) > 0:
                x = self.bns[i](x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class SAGE(torch.nn.Module):
    """
    GraphSAGE
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5, batchnorm=False, activation="relu"):
        super(SAGE, self).__init__()

        assert activation in ["relu", "elu"]

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        if batchnorm:
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))

        self.num_layers = num_layers
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            if batchnorm:
                self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

        self.dropout = dropout
        self.activation = getattr(F, activation)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        x = self.embed(x, adj_t)
        x = self.convs[-1](x, adj_t)
        return x

    def embed(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            if len(self.bns) > 0:
                x = self.bns[i](x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class GAT(torch.nn.Module):
    """
    Graph Attention Network
    """
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads,
                 num_layers=2, dropout=0.5, batchnorm=False, activation="relu"):
        super(GAT, self).__init__()

        assert activation in ["relu", "elu"]

        self.convs = torch.nn.ModuleList()
        self.convs.append(
            GATConv(in_channels, hidden_channels, heads=num_heads, bias=False))
        self.bns = torch.nn.ModuleList()
        if batchnorm:
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))

        self.num_layers = num_layers
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(hidden_channels * num_heads, hidden_channels, heads=num_heads, bias=False))
            if batchnorm:
                self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(
            GATConv(hidden_channels * num_heads, out_channels, heads=num_heads, bias=False))

        self.dropout = dropout
        self.activation = getattr(F, activation)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        x = self.embed(x, adj_t)
        x = self.convs[-1](x, adj_t)
        return x

    def embed(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            if len(self.bns) > 0:
                x = self.bns[i](x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return x

## Partition

In [6]:
class GraphPartition:

    def __init__(self, graph, x, num_classes):

        self.graph = graph
        self.x = x
        self.n_cluster = num_classes
        self.costs = []

    def clauset_newman_moore(self, num_part=-1, weight=None, q_break=0):
        """
        Find communities in graph using Clauset-Newman-Moore greedy modularity maximization.

        Greedy modularity maximization begins with each node in its own community
        and joins the pair of communities that most increases (least decrease) modularity
        until q_break.

        Modified from
        https://networkx.org/documentation/stable/_modules/networkx/algorithms/community/modularity_max.html#greedy_modularity_communities
        """

        # Count nodes and edges
        N = len(self.graph.nodes())
        m = sum([d.get("weight", 1) for u, v, d in self.graph.edges(data=True)])
        q0 = 1.0 / (2.0 * m)

        # Map node labels to contiguous integers
        label_for_node = {i: v for i, v in enumerate(self.graph.nodes())}
        node_for_label = {label_for_node[i]: i for i in range(N)}

        # Calculate edge weight
        if weight is not None:
            edge_weight = []
            for edge in self.graph.edges:
                edge_weight.append(torch.linalg.norm(self.x[edge[0]] - self.x[edge[1]]).item())
            edge_weight = torch.tensor(edge_weight)
            edge_weight -= edge_weight.min()
            edge_weight /= edge_weight.max()
            attrs = {}
            for edge, distance in zip(self.graph.edges, list(edge_weight)):
                attrs[edge] = {'distance': distance}
            weight = 'distance'

        # Calculate degrees
        k_for_label = self.graph.degree(self.graph.nodes(), weight=weight)
        k = [k_for_label[label_for_node[i]] for i in range(N)]

        # Initialize community and merge lists
        communities = {i: frozenset([i]) for i in range(N)}

        # Initial modularity and homophily
        partition = [[label_for_node[x] for x in c] for c in communities.values()]
        q_cnm = modularity(self.graph, partition)

        # Initialize data structures
        # CNM Eq 8-9 (Eq 8 was missing a factor of 2 (from A_ij + A_ji)
        # a[i]: fraction of edges within community i
        # dq_dict[i][j]: dQ for merging community i, j
        # dq_heap[i][n] : (-dq, i, j) for communitiy i nth largest dQ
        # H[n]: (-dq, i, j) for community with nth largest max_j(dQ_ij)
        a = [k[i] * q0 for i in range(N)]
        dq_dict = {
            i: {
                j: 2 * q0 - 2 * k[i] * k[j] * q0 * q0
                for j in [node_for_label[u] for u in self.graph.neighbors(label_for_node[i])]
                if j != i
            }
            for i in range(N)
        }
        dq_heap = [
            MappedQueue([(-dq, i, j) for j, dq in dq_dict[i].items()]) for i in range(N)
        ]
        H = MappedQueue([dq_heap[i].heap[0] for i in range(N) if len(dq_heap[i]) > 0])

        # Merge communities until we can't improve modularity
        while len(H) > 1:
            # Find best merge
            # Remove from heap of row maxes
            # Ties will be broken by choosing the pair with lowest min community id
            try:
                dq, i, j = H.pop()
            except IndexError:
                break
            dq = -dq

            # Remove best merge from row i heap
            dq_heap[i].pop()

            # Push new row max onto H
            if len(dq_heap[i]) > 0:
                H.push(dq_heap[i].heap[0])

            # If this element was also at the root of row j, we need to remove the
            # duplicate entry from H
            if dq_heap[j].heap[0] == (-dq, j, i):
                H.remove((-dq, j, i))
                # Remove best merge from row j heap
                dq_heap[j].remove((-dq, j, i))
                # Push new row max onto H
                if len(dq_heap[j]) > 0:
                    H.push(dq_heap[j].heap[0])
            else:
                # Duplicate wasn't in H, just remove from row j heap
                dq_heap[j].remove((-dq, j, i))

            # Stop when change is non-positive 0
            if 0 < num_part == len(communities):
                break
            elif dq <= q_break:
                break

            # New modularity and homophily
            q_cnm += dq

            # Perform merge
            communities[j] = frozenset(communities[i] | communities[j])
            del communities[i]

            # Get list of communities connected to merged communities
            i_set = set(dq_dict[i].keys())
            j_set = set(dq_dict[j].keys())
            all_set = (i_set | j_set) - {i, j}
            both_set = i_set & j_set

            # Merge i into j and update dQ
            for k in all_set:

                # Calculate new dq value
                if k in both_set:
                    dq_jk = dq_dict[j][k] + dq_dict[i][k]
                elif k in j_set:
                    dq_jk = dq_dict[j][k] - 2.0 * a[i] * a[k]
                else:
                    # k in i_set
                    dq_jk = dq_dict[i][k] - 2.0 * a[j] * a[k]

                # Update rows j and k
                for row, col in [(j, k), (k, j)]:
                    # Save old value for finding heap index
                    if k in j_set:
                        d_old = (-dq_dict[row][col], row, col)
                    else:
                        d_old = None
                    # Update dict for j,k only (i is removed below)
                    dq_dict[row][col] = dq_jk
                    # Save old max of per-row heap
                    if len(dq_heap[row]) > 0:
                        d_oldmax = dq_heap[row].heap[0]
                    else:
                        d_oldmax = None
                    # Add/update heaps
                    d = (-dq_jk, row, col)
                    if d_old is None:
                        # We're creating a new nonzero element, add to heap
                        dq_heap[row].push(d)
                    else:
                        # Update existing element in per-row heap
                        dq_heap[row].update(d_old, d)
                    # Update heap of row maxes if necessary
                    if d_oldmax is None:
                        # No entries previously in this row, push new max
                        H.push(d)
                    else:
                        # We've updated an entry in this row, has the max changed?
                        if dq_heap[row].heap[0] != d_oldmax:
                            H.update(d_oldmax, dq_heap[row].heap[0])

            # Remove row/col i from matrix
            i_neighbors = dq_dict[i].keys()
            for k in i_neighbors:
                # Remove from dict
                dq_old = dq_dict[k][i]
                del dq_dict[k][i]
                # Remove from heaps if we haven't already
                if k != j:
                    # Remove both row and column
                    for row, col in [(k, i), (i, k)]:
                        # Check if replaced dq is row max
                        d_old = (-dq_old, row, col)
                        if dq_heap[row].heap[0] == d_old:
                            # Update per-row heap and heap of row maxes
                            dq_heap[row].remove(d_old)
                            H.remove(d_old)
                            # Update row max
                            if len(dq_heap[row]) > 0:
                                H.push(dq_heap[row].heap[0])
                        else:
                            # Only update per-row heap
                            dq_heap[row].remove(d_old)

            del dq_dict[i]
            # Mark row i as deleted, but keep placeholder
            dq_heap[i] = MappedQueue()
            # Merge i into j and update a
            a[j] += a[i]
            a[i] = 0

        communities = [
            [label_for_node[i] for i in c] for c in communities.values()
        ]
        return sorted(communities, key=len, reverse=True)

    def agglomerative_clustering(self, communities, min_clusters=2):
        """
        Agglomerative Clustering: Ward's Linkage Method
        """

        n_clusters = list(range(min_clusters, len(communities)))
        n_clusters.reverse()
        partitions = {}

        dist, x_com = self.community_linkage(communities, full=True)

        num_clusters = len(communities)
        while num_clusters > min(n_clusters):

            sorted_communities = sorted(communities, key=lambda c: len(c), reverse=True)
            partitions[num_clusters] = torch.zeros(self.x.shape[0], dtype=torch.int)
            for i, com in enumerate(sorted_communities):
                partitions[num_clusters][com] = i

            merge_cost, closest_idx = torch.min(dist, dim=1)
            j = torch.argmin(merge_cost).item()
            i = closest_idx[j].item()
            assert i > j

            communities[j].extend(communities[i])
            del communities[i]
            x_com = torch.cat((x_com[0:i], x_com[i + 1:]), dim=0)
            x_com[j] = self.x[communities[j]].mean(axis=0)

            dist = torch.cat((dist[0:i], dist[i + 1:]), dim=0)
            dist = torch.cat((dist[:, 0:i], dist[:, i + 1:]), dim=1)
            num_clusters -= 1

            for k in range(len(communities)):
                if k == j:
                    continue
                nk, nj = len(communities[k]), len(communities[j])
                n = nk * nj / (nk + nj)
                d = torch.linalg.norm(x_com[j] - x_com[k])
                dist[k, j] = d * n
                dist[j, k] = d * n

            cost = merge_cost.min().item()
            self.costs.append(cost)

        return partitions

    def community_linkage(self, communities, full=True):

        n = self.x.shape[1]
        x_com = []
        for com in communities:
            x_com.append(self.x[com].mean(axis=0))
        x_com = torch.stack(x_com, dim=0)

        linkage = torch.linalg.norm(
            x_com.reshape(1, -1, n) - x_com.reshape(-1, 1, n), dim=2
        )
        for i in range(len(communities)):
            for j in range(i + 1, len(communities)):
                ni, nj = len(communities[i]), len(communities[j])
                n = ni * nj / (ni + nj)
                linkage[i, j] *= n
                linkage[j, i] *= n
        linkage += torch.diag(torch.ones(linkage.shape[0]) * float("Inf"))

        if full:
            return linkage, x_com
        return linkage


## Experiment Setup

In [16]:
def run(data, args):

    gnn = args.model
    baseline = args.baselines
    budget = int(args.budget)
    seed = int(args.seed)

    # Choose model
    model_args = {
        "in_channels": data.num_features,
        "out_channels": data.num_classes,
        "hidden_channels": args.hidden,
        "num_layers": args.num_layers,
        "dropout": args.dropout,
        "activation": args.activation,
        "batchnorm": args.batchnorm
    }

    # Initialize models
    if gnn == "gat":
        model_args["num_heads"] = args.num_heads
        model_args["hidden_channels"] = int(args.hidden / args.num_heads)
        model = GAT(**model_args)
    elif gnn == "gcn":
        model = GCN(**model_args)
    elif gnn == "sage":
        model = SAGE(**model_args)
    else:
        raise NotImplemented

    model = model.to(args.device)

    # General-Purpose Methods
    if baseline == "random":
        agent = Random(data, model, seed, args)
    elif baseline == "density":
        agent = Density(data, model, seed, args)
    elif baseline == "uncertainty":
        agent = Uncertainty(data, model, seed, args)
    elif baseline == "coreset":
        agent = CoreSetGreedy(data, model, seed, args)

    # Graph-specific Methods
    elif baseline == "degree":
        agent = Degree(data, model, seed, args)
    elif baseline == "pagerank":
        agent = PageRank(data, model, seed, args)
    elif baseline == "age":
        agent = AGE(data, model, seed, args)
    elif baseline == "featprop":
        agent = ClusterBased(data, model, seed, args,
                              representation='aggregation',
                              encoder='gcn')

    # Our Methods
    elif baseline == "graphpart":
        agent = PartitionBased(data, model, seed, args,
                                representation='aggregation',
                                encoder='gcn',
                                compensation=0)
    elif baseline == "graphpartfar":
        agent = PartitionBased(data, model, seed, args,
                                representation='aggregation',
                                encoder='gcn',
                                compensation=1)

    # Ablation Studies
    elif 'part' in baseline:
        agent = PartitionBased(data, model, seed, args,
                                representation=args.representation,
                                compensation=0)
    else:
        agent = ClusterBased(data, model, seed, args,
                              representation=args.representation)

    # Initialization
    training_mask = np.zeros(data.num_nodes, dtype=bool)
    initial_mask = np.arange(data.num_nodes)
    np.random.shuffle(initial_mask)
    init = args.init
    if baseline in ['density', 'uncertainty', 'coreset', 'age']:
        init = budget // 3
    training_mask[initial_mask[:init]] = True

    training_mask = torch.tensor(training_mask)
    agent.update(training_mask)
    agent.train()

    if args.verbose > 0:
        print('Round {:03d}: Labelled: {:d}, Prediction macro-f1 score {:.4f}'
              .format(0, init, agent.evaluate()))
        
    # Query
    start = timer()
    indices = agent.query(budget - init)
    end = timer()
    print('Total Query Runtime [s]:', end - start)

    # Update
    training_mask[indices] = True
    agent.update(training_mask)

    # Training
    agent.train()

    # Evaluate
    f1, acc = agent.evaluate()
    labelled = len(np.where(agent.data.train_mask != 0)[0])

    if args.verbose > 0:
        print('Round {:03d}: # Labelled nodes: {:d}, Prediction macro-f1 score {:.4f}'
              .format(rd, labelled, f1))
    else:
        print("{},{},{},{},{},{}"
              .format(gnn, baseline, seed,
                      labelled, f1, acc))

    return f1, acc, indices, agent

In [18]:
def test(data, gnns, budgets, baselines, seed=0):

    # Set seeds
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Training settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--verbose", type=int, default=0, help="Verbose: 0, 1 or 2")
    parser.add_argument(
        "--device", default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    # General configs
    parser.add_argument(
        "--baselines", type=str, default=baselines)
    parser.add_argument(
        "--model", default=gnns)
    parser.add_argument(
        "--partition", type=str, default='greedy')

    # Active Learning parameters
    parser.add_argument(
        "--budget", type=int, default=budgets,
        help="Number of rounds to run the agent.")
    parser.add_argument(
        "--retrain", type=bool, default=True)
    parser.add_argument(
        "--num_centers", type=int, default=1)
    parser.add_argument(
        "--representation", type=str, default='aggregation')
    parser.add_argument(
        "--compensation", type=float, default=1.0)
    parser.add_argument(
        "--init", type=float, default=0, help="Number of initially labelled nodes.")
    parser.add_argument(
        "--epochs", type=int, default=300, help="Number of epochs to train.")
    parser.add_argument(
        "--steps", type=int, default=4, help="Number of steps of random walk.")

    # GNN parameters
    parser.add_argument(
        "--seed", type=int, default=seed, help="Number of random seeds.")
    parser.add_argument(
        "--lr", type=float, default=0.01, help="Initial learning rate.")
    parser.add_argument(
        "--weight_decay", type=float, default=5e-4,
        help="Weight decay (L2 loss on parameters).")
    parser.add_argument(
        "--hidden", type=int, default=16, help="Number of hidden units.")
    parser.add_argument(
        "--num_layers", type=int, default=2, help="Number of layers.")
    parser.add_argument(
        "--dropout", type=float, default=0,
        help="Dropout rate (1 - keep probability).")
    parser.add_argument(
        "--batchnorm", type=bool, default=False,
        help="Perform batch normalization")
    parser.add_argument(
        "--activation", default="relu")

    # GAT hyper-parameters
    parser.add_argument(
        "--num_heads", type=int, default=8, help="Number of heads.")

    args, _ = parser.parse_known_args()

    f1_all, acc_all, queried, agent = run(data, args)

    agg = agent.get_node_representation('aggregation', 'gcn').cpu().numpy()
    agg_distance = {}
    train_idx = list(queried.numpy())
    test_idx = list(range(data.num_nodes))
    train_agg = agg[train_idx]

    for i in test_idx:
        distance_tmp = train_agg - agg[i]
        agg_distance[int(i)] = float(min(np.linalg.norm(distance_tmp, axis=1)))
    
    group_num = 10
    sort_res = list(
        map(lambda x: x[0], sorted(agg_distance.items(), key=lambda x: x[1])))
    node_num_group = len(sort_res) // group_num
    res = [
        sort_res[i:i + node_num_group + 1]
        for i in range(0, len(sort_res), node_num_group + 1)
    ]

    acc_list = []
    f1_list = []
    agent.clf.eval()
    logits = agent.clf(agent.data.x, agent.data.adj_t)
    y_pred = logits.max(1)[1].cpu()
    y_true = agent.data.y.cpu()
    for test_set in res:
        acc = metrics.accuracy_score(y_true[test_set], y_pred[test_set])
        f1 = metrics.f1_score(y_true[test_set], y_pred[test_set], average='macro')
        acc_list.append(acc)
        f1_list.append(f1)

    print("{},{},{},{},{},\n{},{},{},{},{}"
    .format(acc_list[0], acc_list[1], acc_list[2], acc_list[3], acc_list[4],
            acc_list[5], acc_list[6], acc_list[7], acc_list[8], acc_list[9]))

## Experiment

In [12]:
name = 'cora'

path = os.path.join("data", name)
dataset = Planetoid(root=path, name=name, transform=T.ToSparseTensor())

data = dataset[0]
data.max_part = 7
data.num_classes = dataset.num_classes
data.params = {'age': [0.05, 0.05, 0.9]}

print(data.num_nodes)
print(data.num_edges)
print(data.num_classes)
print(data.x.shape[1])

data.printname = name
data.adj_t = data.adj_t.to_symmetric() if not isinstance(data.adj_t, torch.Tensor) else data.adj_t
edges = [(int(i), int(j)) for i, j in zip(data.adj_t.storage._row,
                                          data.adj_t.storage._col)]

data.g = nx.Graph()
data.g.add_edges_from(edges)
graph = data.g.to_undirected()
edges = [(int(i), int(i)) for i in range(data.num_nodes)]
data.g.add_edges_from(edges)

feat_dim = data.x.size(1)
conv = GCNConv(feat_dim, feat_dim, cached=True, bias=False)
conv.lin.weight = torch.nn.Parameter(torch.eye(feat_dim))
with torch.no_grad():
    data.aggregated = conv(data.x, data.adj_t)
    data.aggregated = conv(data.aggregated, data.adj_t)
    
filename = "data/partitions.json"
if os.path.exists(filename):
    data.partitions = {}
    obj_text = codecs.open(filename, 'r', encoding='utf-8').read()
    part_dict = json.loads(obj_text)
    data.max_part = part_dict[name]['num_part']
    data.partitions[data.max_part] = torch.tensor(part_dict[name]['partition'])
else:
    print('Partition file not found!')
    raise NotImplemented

2708
10556
7
1433


In [19]:
test(data, 'gcn', 40, 'graphpart', seed=0)
test(data, 'gcn', 40, 'graphpartfar', seed=0)
test(data, 'gcn', 40, 'random', seed=0)

Total Query Runtime [s]: 3.956631758999947
gcn,graphpart,0,40,0.7985027144005709,0.8142540620384048
0.915129151291513,0.8523985239852399,0.8228782287822878,0.8302583025830258,0.8339483394833949,
0.8339483394833949,0.7712177121771218,0.8118081180811808,0.7749077490774908,0.6951672862453532
Total Query Runtime [s]: 5.6262782779999725
gcn,graphpartfar,0,40,0.7702959697159066,0.7961595273264401
0.9261992619926199,0.8302583025830258,0.7601476014760148,0.8228782287822878,0.7859778597785978,
0.7970479704797048,0.7970479704797048,0.7601476014760148,0.7527675276752768,0.7286245353159851
Total Query Runtime [s]: 0.0003246920000492537
gcn,random,0,40,0.5655700544372413,0.6063515509601182
0.8302583025830258,0.7121771217712177,0.6715867158671587,0.6826568265682657,0.6346863468634686,
0.4981549815498155,0.5239852398523985,0.5202952029520295,0.45387453874538747,0.5353159851301115
