Import Libraries

In [1]:
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import networkx as nx
import metis
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
import matplotlib.pyplot as plt

Download Cora Dataset

In [2]:
dataset = DglNodePropPredDataset('ogbn-arxiv')
arxiv_dgl = dgl.data.AsNodePredDataset(dataset)[0]
arxiv_dgl = dgl.add_reverse_edges(arxiv_dgl)

Convert to a networkx graph

In [3]:
arxiv_nx = dgl.to_networkx(arxiv_dgl)

METIS Partitioning

In [9]:
k=4
_, parts = metis.part_graph(arxiv_nx,k,contig=True)

In [10]:
parts_tensor = torch.Tensor(parts)
sgs = []

# for each partition
for i in range(k):
    # get the nodes ids for the partition
    sg_nodes = (parts_tensor == i).nonzero()[:,0].tolist()

    # get the dgl subgraph from these ids
    sg = dgl.node_subgraph(arxiv_dgl, sg_nodes)
    sg.ndata['og_ids'] = torch.LongTensor(sg_nodes)
    sgs.append(sg)

get the og ids of the nodes that become disconnected

In [11]:
node_ids = (sgs[0].in_degrees() == 0).nonzero()
node_ids = sgs[0].ndata['og_ids'][node_ids]

In [12]:
node_ids

tensor([], size=(0, 1), dtype=torch.int64)

confirm that their neighbors have a different partition

proposed solution is to just reconnect the nodes to its old neighbors so will be part of the correct partition

In [None]:
# for sg in sgs:
#     node_ids = (sgs[0].in_degrees() == 0).nonzero()
#     node_ids = sgs[0].ndata['og_ids'][node_ids]
#     for node in node_ids:
#         # get the partition of its neighbor
#         neighbor = arxiv_nx.neighbors(node)
#         neighbor_part = parts_tensor(neighbor)

#         # move the node to this partition and add an edge
#         sg.remove_nodes(node)
#         sgs[neighbor_part].add_node(node)


In [57]:
parts_tensor[46080]

tensor(3.)

In [54]:
list(arxiv_nx.neighbors(763))

[46080]

In [50]:
arxiv_dgl.in_degrees()[node_ids]

tensor([[1],
        [1],
        [1],
        [1],
        [2],
        [1],
        [1],
        [1],
        [1],
        [1]])

In [45]:
parts_tensor[node_ids]

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]])

Create a dictionary that maps node ids to partition ids

In [27]:
part_dict = {}
for id,part_id in enumerate(parts):
    part_dict[id] = part_id

Label the nodes with their partition id and save to visualize in Gephi

In [28]:
nx.set_node_attributes(cora_c_nx,part_dict,name="partition")
nx.write_gexf(cora_c_nx,"dgl_cora_connected_parts.gexf")

Get node ids for each partition

In [29]:
parts_tensor = torch.Tensor(parts)
sg0_nodes = (parts_tensor == 0).nonzero()[:,0].tolist()
sg1_nodes = (parts_tensor == 1).nonzero()[:,0].tolist()

Create DGL subgraphs using partition list. This splits all the node features as well which is very conveniant

In [30]:
sg0 = dgl.node_subgraph(cora_dgl, sg0_nodes)
sg1 = dgl.node_subgraph(cora_dgl, sg1_nodes)

In [31]:
print("sg0 test:",sum(sg0.ndata['test_mask']==True))
print("sg0 val:",sum(sg0.ndata['val_mask']==True))
print("sg0 train:",sum(sg0.ndata['train_mask']==True))
print("sg1 test:",sum(sg1.ndata['test_mask']==True))
print("sg1 val:",sum(sg1.ndata['val_mask']==True))
print("sg1 train:",sum(sg1.ndata['train_mask']==True))

sg0 test: tensor(437)
sg0 val: tensor(213)
sg0 train: tensor(64)
sg1 test: tensor(478)
sg1 val: tensor(246)
sg1 train: tensor(58)


In [32]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

In [33]:
def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))

In [34]:
model = GCN(cora_dgl.ndata['feat'].shape[1], 16, dataset.num_classes)
model0 = GCN(sg0.ndata['feat'].shape[1], 16, dataset.num_classes)
model1 = GCN(sg1.ndata['feat'].shape[1], 16, dataset.num_classes)

In [35]:
train(cora_dgl,model)

In epoch 0, loss: 1.946, val acc: 0.118 (best 0.118), test acc: 0.131 (best 0.131)
In epoch 5, loss: 1.888, val acc: 0.268 (best 0.268), test acc: 0.293 (best 0.293)
In epoch 10, loss: 1.806, val acc: 0.314 (best 0.314), test acc: 0.342 (best 0.342)
In epoch 15, loss: 1.703, val acc: 0.436 (best 0.436), test acc: 0.478 (best 0.478)
In epoch 20, loss: 1.579, val acc: 0.584 (best 0.584), test acc: 0.593 (best 0.593)
In epoch 25, loss: 1.434, val acc: 0.675 (best 0.675), test acc: 0.674 (best 0.674)
In epoch 30, loss: 1.274, val acc: 0.717 (best 0.717), test acc: 0.718 (best 0.718)
In epoch 35, loss: 1.104, val acc: 0.732 (best 0.732), test acc: 0.739 (best 0.739)
In epoch 40, loss: 0.935, val acc: 0.752 (best 0.752), test acc: 0.741 (best 0.741)
In epoch 45, loss: 0.774, val acc: 0.760 (best 0.763), test acc: 0.751 (best 0.750)
In epoch 50, loss: 0.631, val acc: 0.769 (best 0.769), test acc: 0.757 (best 0.755)
In epoch 55, loss: 0.508, val acc: 0.769 (best 0.771), test acc: 0.756 (best 0

In [36]:
train(sg0, model0)

In epoch 0, loss: 1.945, val acc: 0.155 (best 0.155), test acc: 0.140 (best 0.140)
In epoch 5, loss: 1.810, val acc: 0.385 (best 0.394), test acc: 0.398 (best 0.405)
In epoch 10, loss: 1.647, val acc: 0.347 (best 0.394), test acc: 0.366 (best 0.405)
In epoch 15, loss: 1.488, val acc: 0.347 (best 0.394), test acc: 0.362 (best 0.405)
In epoch 20, loss: 1.333, val acc: 0.371 (best 0.394), test acc: 0.378 (best 0.405)
In epoch 25, loss: 1.172, val acc: 0.394 (best 0.394), test acc: 0.437 (best 0.405)
In epoch 30, loss: 1.009, val acc: 0.521 (best 0.521), test acc: 0.503 (best 0.503)
In epoch 35, loss: 0.851, val acc: 0.563 (best 0.563), test acc: 0.542 (best 0.542)
In epoch 40, loss: 0.705, val acc: 0.596 (best 0.596), test acc: 0.565 (best 0.563)
In epoch 45, loss: 0.577, val acc: 0.606 (best 0.606), test acc: 0.616 (best 0.586)
In epoch 50, loss: 0.467, val acc: 0.629 (best 0.629), test acc: 0.638 (best 0.636)
In epoch 55, loss: 0.376, val acc: 0.648 (best 0.648), test acc: 0.638 (best 0

In [37]:
train(sg1, model1)

In epoch 0, loss: 1.948, val acc: 0.041 (best 0.041), test acc: 0.033 (best 0.033)
In epoch 5, loss: 1.871, val acc: 0.366 (best 0.435), test acc: 0.431 (best 0.458)
In epoch 10, loss: 1.781, val acc: 0.402 (best 0.435), test acc: 0.458 (best 0.458)
In epoch 15, loss: 1.676, val acc: 0.423 (best 0.435), test acc: 0.464 (best 0.458)
In epoch 20, loss: 1.563, val acc: 0.244 (best 0.435), test acc: 0.318 (best 0.458)
In epoch 25, loss: 1.450, val acc: 0.171 (best 0.435), test acc: 0.243 (best 0.458)
In epoch 30, loss: 1.342, val acc: 0.175 (best 0.435), test acc: 0.243 (best 0.458)
In epoch 35, loss: 1.239, val acc: 0.244 (best 0.435), test acc: 0.318 (best 0.458)
In epoch 40, loss: 1.143, val acc: 0.333 (best 0.435), test acc: 0.372 (best 0.458)
In epoch 45, loss: 1.053, val acc: 0.398 (best 0.435), test acc: 0.444 (best 0.458)
In epoch 50, loss: 0.971, val acc: 0.476 (best 0.476), test acc: 0.498 (best 0.498)
In epoch 55, loss: 0.894, val acc: 0.545 (best 0.545), test acc: 0.559 (best 0