In [1]:
import math
import os
from os import path
from pathlib import Path

import dgl
import numpy as np
import torch
from dgl.data import DGLDataset

from ToyDGLDataset import ToyDGLDataset, GetNodeFeatureVec, GetEdgeFeatureVec

Using backend: pytorch


In [2]:
# Root directory holding the cached graph datasets. Set the GNN_DATASET_ROOT
# environment variable to point elsewhere instead of editing this cell.
DATASET_ROOT = os.environ.get('GNN_DATASET_ROOT', '/home/andrew/GNN_Sandbox/GraphDatasets')
datasetName = 'ToyDataset01'
datasetDir = path.join(DATASET_ROOT, datasetName)
dataset = ToyDGLDataset(datasetName, datasetDir)

Done loading data from cached files.


In [3]:
# Summary of the loaded dataset: graph classes, graph/node/edge counts and feature keys.
dataset.printProperties()

Num Graph classes: 2
Graph classes: [0, 1]
Number of graphs: 20
Number of all nodes in all graphs: 982
Number of all edges in all graphs: 55454
Dim node features: 5
Node feature keys: ['P_t', 'Eta', 'Phi', 'Mass', 'Type']
Dim edge features: 3
Edge feature keys: ['DeltaPhi', 'DeltaEta', 'RapiditySquared']


In [5]:
# Inspect one sample: dataset size, the first graph, its label and its node-feature matrix.
# `graph` and `label` are reused by later cells.
print(len(dataset))
graph, label = dataset[0]
print(graph)
print(f'Label: {label}')
# NOTE(review): comparing with the graph.dstnodes[0] output further down, the columns
# appear to be ordered Type, Mass, Eta, Phi, P_t — confirm in GetNodeFeatureVec.
print(GetNodeFeatureVec(graph))

20
Graph(num_nodes=27, num_edges=702,
      ndata_schemes={'Type': Scheme(shape=(), dtype=torch.float64), 'Mass': Scheme(shape=(), dtype=torch.float64), 'Eta': Scheme(shape=(), dtype=torch.float64), 'Phi': Scheme(shape=(), dtype=torch.float64), 'P_t': Scheme(shape=(), dtype=torch.float64)}
      edata_schemes={'RapiditySquared': Scheme(shape=(), dtype=torch.float64), 'DeltaEta': Scheme(shape=(), dtype=torch.float64), 'DeltaPhi': Scheme(shape=(), dtype=torch.float64)})
Label: 0
tensor([[ 0.0000,  0.9674, -1.3320,  3.7294, 39.5653],
        [ 1.0000,  0.3934, -1.3661,  1.9948, 73.1044],
        [ 1.0000,  0.5667,  8.0511,  5.2543, 12.7182],
        [ 1.0000,  0.1026,  1.7799,  6.2018, 27.1879],
        [ 1.0000,  0.7496,  8.1134,  4.1509, 15.2236],
        [ 2.0000,  0.5874, -0.2570,  0.3350, 76.8397],
        [ 2.0000,  0.9794,  2.1959,  3.6444, 13.5918],
        [ 1.0000,  0.8795, -3.0623,  4.4351, 27.9040],
        [ 0.0000,  0.3838, -8.1866,  2.3626, 49.4673],
        [ 0.0000,  0.86

In [6]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# NOTE(review): num_examples is not used by any visible cell.
num_examples = len(dataset)
# Train/test graph indices as provided by the dataset.
splitIndices = dataset.get_split_indices()

# Each sampler yields its subset of graph indices in a random order on every pass.
train_sampler = SubsetRandomSampler(splitIndices['train'])
test_sampler = SubsetRandomSampler(splitIndices['test'])

# Each batch is collated into one batched DGLGraph plus a label tensor.
# NOTE(review): the dataset holds only 20 graphs, so with batch_size=32 each loader
# yields a single batch per epoch.
train_dataloader = GraphDataLoader(dataset, sampler=train_sampler, batch_size=32, drop_last=False)
test_dataloader = GraphDataLoader(dataset, sampler=test_sampler, batch_size=32, drop_last=False)

In [7]:
# Pull one batch from the training loader to see its structure: [batched graph, labels].
it = iter(train_dataloader)
batch = next(it)
print(batch)

[Graph(num_nodes=554, num_edges=31144,
      ndata_schemes={'Type': Scheme(shape=(), dtype=torch.float64), 'Mass': Scheme(shape=(), dtype=torch.float64), 'Eta': Scheme(shape=(), dtype=torch.float64), 'Phi': Scheme(shape=(), dtype=torch.float64), 'P_t': Scheme(shape=(), dtype=torch.float64)}
      edata_schemes={'RapiditySquared': Scheme(shape=(), dtype=torch.float64), 'DeltaEta': Scheme(shape=(), dtype=torch.float64), 'DeltaPhi': Scheme(shape=(), dtype=torch.float64)}), tensor([1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0])]


In [8]:
batched_graph, labels = batch

# The batched graph remembers how many nodes / edges each member graph contributed.
for kind, counts in (('nodes', batched_graph.batch_num_nodes()),
                     ('edges', batched_graph.batch_num_edges())):
    print(f'Number of {kind} for each graph element in the batch:', counts)
print(labels)

# Split the minibatch back into its individual graphs.
graphs = dgl.unbatch(batched_graph)
print('The original graphs in the minibatch:')
print(graphs)

Number of nodes for each graph element in the batch: tensor([59, 25, 70, 68, 27, 63, 30, 56, 26, 88, 17, 25])
Number of edges for each graph element in the batch: tensor([3422,  600, 4830, 4556,  702, 3906,  870, 3080,  650, 7656,  272,  600])
tensor([1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0])
The original graphs in the minibatch:
[Graph(num_nodes=59, num_edges=3422,
      ndata_schemes={'Type': Scheme(shape=(), dtype=torch.float64), 'Mass': Scheme(shape=(), dtype=torch.float64), 'Eta': Scheme(shape=(), dtype=torch.float64), 'Phi': Scheme(shape=(), dtype=torch.float64), 'P_t': Scheme(shape=(), dtype=torch.float64)}
      edata_schemes={'RapiditySquared': Scheme(shape=(), dtype=torch.float64), 'DeltaEta': Scheme(shape=(), dtype=torch.float64), 'DeltaPhi': Scheme(shape=(), dtype=torch.float64)}), Graph(num_nodes=25, num_edges=600,
      ndata_schemes={'Type': Scheme(shape=(), dtype=torch.float64), 'Mass': Scheme(shape=(), dtype=torch.float64), 'Eta': Scheme(shape=(), dtype=torch.float64), 'Phi

In [9]:
from dgl.nn import GraphConv
import torch.nn as nn
import torch.nn.functional as F

class GCN(nn.Module):
    """
    Two-layer graph convolutional network for graph-level classification.

    Node features pass through GraphConv -> ReLU -> GraphConv. The resulting
    per-node class scores are averaged per graph with ``dgl.mean_nodes``, giving
    one row of logits for each graph in the (possibly batched) input.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        """
        Parameters
        ----------
        in_feats : int
            Dimension of the input node features.
        h_feats : int
            Output dimension of the first GraphConv layer.
        num_classes : int
            Output dimension of the second GraphConv layer (number of classes).
        """
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """
        Compute graph-level logits.

        Parameters
        ----------
        g : DGLGraph
            A single graph or a batched graph.
        in_feat : torch.Tensor
            Node feature matrix for ``g`` (one row per node; assumed to have
            ``in_feats`` columns — TODO confirm dtype matches the layer weights).

        Returns
        -------
        torch.Tensor
            One row of ``num_classes`` scores per graph in ``g``.
        """
        h = F.relu(self.conv1(g, in_feat))
        h = self.conv2(g, h)
        # local_scope() keeps the temporary 'h' node feature from leaking into the
        # caller's graph (e.g. a graph object cached by the dataset).
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')

In [10]:
# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.num_graph_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(50):
    for batched_graph, labels in train_dataloader:
        nodeFeatVec = GetNodeFeatureVec(batched_graph)
        pred = model(batched_graph, nodeFeatVec)
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #if epoch % 5 == 0:
    print(f'Epoch: {epoch}, Loss: {loss:.3f}')

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    nodeFeatVec = GetNodeFeatureVec(batched_graph)
    pred = model(batched_graph, nodeFeatVec)
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Epoch: 0, Loss: 5.563
Epoch: 1, Loss: 4.493
Epoch: 2, Loss: 3.434
Epoch: 3, Loss: 2.392
Epoch: 4, Loss: 1.431
Epoch: 5, Loss: 0.827
Epoch: 6, Loss: 0.943
Epoch: 7, Loss: 1.346
Epoch: 8, Loss: 1.631
Epoch: 9, Loss: 1.753
Epoch: 10, Loss: 1.739
Epoch: 11, Loss: 1.623
Epoch: 12, Loss: 1.435
Epoch: 13, Loss: 1.208
Epoch: 14, Loss: 0.985
Epoch: 15, Loss: 0.821
Epoch: 16, Loss: 0.766
Epoch: 17, Loss: 0.819
Epoch: 18, Loss: 0.921
Epoch: 19, Loss: 1.007
Epoch: 20, Loss: 1.043
Epoch: 21, Loss: 1.027
Epoch: 22, Loss: 0.967
Epoch: 23, Loss: 0.885
Epoch: 24, Loss: 0.809
Epoch: 25, Loss: 0.761
Epoch: 26, Loss: 0.752
Epoch: 27, Loss: 0.774
Epoch: 28, Loss: 0.809
Epoch: 29, Loss: 0.837
Epoch: 30, Loss: 0.849
Epoch: 31, Loss: 0.841
Epoch: 32, Loss: 0.817
Epoch: 33, Loss: 0.786
Epoch: 34, Loss: 0.758
Epoch: 35, Loss: 0.742
Epoch: 36, Loss: 0.740
Epoch: 37, Loss: 0.749
Epoch: 38, Loss: 0.763
Epoch: 39, Loss: 0.772
Epoch: 40, Loss: 0.773
Epoch: 41, Loss: 0.765
Epoch: 42, Loss: 0.752
Epoch: 43, Loss: 0.73

In [11]:
# Node IDs of the first sample graph (dataset[0]).
print(graph.nodes())

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26])


In [14]:
# Node-feature data of node 0 only, shown as a NodeSpace of per-feature tensors.
print(graph.dstnodes[0])

NodeSpace(data={'Type': tensor([0.], dtype=torch.float64), 'Mass': tensor([0.9674], dtype=torch.float64), 'Eta': tensor([-1.3320], dtype=torch.float64), 'Phi': tensor([3.7294], dtype=torch.float64), 'P_t': tensor([39.5653], dtype=torch.float64)})


In [18]:
# All edges of `graph` as (src, dst) rows.
# NOTE(review): .squeeze() would collapse a single-edge graph to a 1-D tensor.
print(torch.dstack(graph.edges()).squeeze())
# Node-feature data of nodes 0 and 1.
print(graph.nodes[[0,1]])


tensor([[ 0,  1],
        [ 0,  2],
        [ 0,  3],
        ...,
        [26, 23],
        [26, 24],
        [26, 25]])
NodeSpace(data={'Type': tensor([0., 1.], dtype=torch.float64), 'Mass': tensor([0.9674, 0.3934], dtype=torch.float64), 'Eta': tensor([-1.3320, -1.3661], dtype=torch.float64), 'Phi': tensor([3.7294, 1.9948], dtype=torch.float64), 'P_t': tensor([39.5653, 73.1044], dtype=torch.float64)})


In [20]:
# Out-edges of node 0 as a (source IDs, destination IDs) tuple of tensors.
print(graph.out_edges(0))

(tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0]), tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26]))


In [64]:
def GetEdgeListOfNode(edgeList, nodeLabel: int):
    """
    Return the outgoing edges of one node as ``[srcNodeID, dstNodeID]`` rows.

    Parameters
    ----------
    edgeList : DGLGraph
        The graph to query. Despite the parameter name, callers pass the graph
        itself, e.g. ``GetEdgeListOfNode(graph, 0)``.
    nodeLabel : int
        ID of the node whose out-edges are listed.

    Returns
    -------
    torch.Tensor
        Shape ``(num_out_edges, 2)``: column 0 is the source node ID, column 1
        the destination node ID. Edge IDs are not included; use
        ``graph.edge_ids(src, dst)`` to obtain them.
    """
    # Query the graph that was passed in, not a notebook-global `graph`.
    src, dst = edgeList.out_edges(nodeLabel)
    # stack(dim=1) always yields a 2-D (E, 2) tensor; dstack(...).squeeze()
    # would collapse the single-out-edge case to shape (2,).
    return torch.stack((src, dst), dim=1)

def GetEdgeList(graph, order='srcdst'):
    """
    Return every edge of ``graph`` with its ID.

    Parameters
    ----------
    graph : DGLGraph
        The graph to query.
    order : str, optional
        Edge ordering passed to ``graph.edges``; defaults to ``'srcdst'``
        (sorted by source node, then destination node).

    Returns
    -------
    tuple of torch.Tensor
        ``(src, dst, eid)``: three 1-D tensors of equal length holding the
        source node ID, destination node ID and edge ID of each edge.
        This is a tuple, not a single stacked tensor.
    """
    return graph.edges('all', order=order)

# Edge IDs whose source is `sourceNodeID`, found by masking the full edge list.
u, v, eid = GetEdgeList(graph)
sourceNodeID = 0
print(torch.nonzero(u == sourceNodeID, as_tuple=True)[0])

# The same node's out-edges as (src, dst) pairs.
print(GetEdgeListOfNode(graph, 0))

# Map those (src, dst) pairs back to edge IDs.
u, v = graph.out_edges(0)
print(graph.edge_ids(u, v))

# Full (src, dst, eid) edge list.
print(GetEdgeList(graph))

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25])
tensor([[ 0,  1],
        [ 0,  2],
        [ 0,  3],
        [ 0,  4],
        [ 0,  5],
        [ 0,  6],
        [ 0,  7],
        [ 0,  8],
        [ 0,  9],
        [ 0, 10],
        [ 0, 11],
        [ 0, 12],
        [ 0, 13],
        [ 0, 14],
        [ 0, 15],
        [ 0, 16],
        [ 0, 17],
        [ 0, 18],
        [ 0, 19],
        [ 0, 20],
        [ 0, 21],
        [ 0, 22],
        [ 0, 23],
        [ 0, 24],
        [ 0, 25],
        [ 0, 26]])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25])
(tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 

In [53]:
def ConcatNodeAndEdgeFeatures(graph, nodeLabel: int):
    """
    Placeholder, not implemented yet: it ignores its arguments and returns None.

    Judging by the name, it is presumably meant to combine the features of node
    ``nodeLabel`` with those of its edges — TODO confirm and implement.
    """
    return None

# Compare the features of the two directed edges between the first and last node
# (a -> b and b -> a).
nodeA, nodeB = 0, graph.num_nodes() - 1
reversePairEids = graph.edge_ids([nodeA, nodeB], [nodeB, nodeA])
print(reversePairEids)

# Look the feature rows up by the edge IDs just returned, not by hard-coded
# indices, so the cell stays correct for any graph.
edgeFeatureVec = GetEdgeFeatureVec(graph)
for edgeID in reversePairEids.tolist():
    print(edgeFeatureVec[edgeID])

tensor([ 25, 676])
tensor([113.6136,  10.3400,   2.5879])
tensor([113.6136,  10.3400,   2.5879])


In [51]:
# The 'DeltaPhi' edge feature for every edge of `graph`.
print(graph.edata['DeltaPhi'])
# The 'P_t' node feature, read through dstdata.
print(graph.dstdata['P_t'])

tensor([1.7346, 1.5249, 2.4724, 0.4215, 3.3944, 0.0850, 0.7057, 1.3668, 2.3474,
        1.4097, 1.1981, 1.8204, 2.5296, 2.4507, 0.3479, 0.7056, 1.6235, 1.3469,
        1.3588, 2.2284, 0.4408, 1.8273, 1.8325, 1.0572, 0.6779, 2.5879, 1.7346,
        3.2595, 4.2070, 2.1561, 1.6598, 1.6496, 2.4403, 0.3678, 4.0820, 3.1443,
        2.9327, 0.0858, 0.7950, 4.1853, 2.0825, 1.0290, 0.1111, 3.0815, 3.0934,
        0.4938, 1.2938, 0.0927, 3.5671, 0.6774, 2.4125, 0.8533, 1.5249, 3.2595,
        0.9475, 1.1034, 4.9193, 1.6099, 0.8192, 2.8917, 0.8224, 0.1152, 0.3268,
        3.3453, 4.0545, 0.9258, 1.1770, 2.2305, 3.1484, 0.1780, 0.1661, 3.7533,
        1.9657, 3.3522, 0.3076, 2.5821, 0.8470, 4.1128, 2.4724, 4.2070, 0.9475,
        2.0509, 5.8668, 2.5574, 1.7667, 3.8392, 0.1250, 1.0627, 1.2743, 4.2928,
        5.0020, 0.0217, 2.1245, 3.1780, 4.0959, 1.1255, 1.1136, 4.7008, 2.9132,
        4.2997, 0.6399, 3.5296, 1.7945, 5.0602, 0.4215, 2.1561, 1.1034, 2.0509,
        3.8159, 0.5065, 0.2842, 1.7883, 