In [1]:
import math
import os
from copy import deepcopy
from os import path
from pathlib import Path

import dgl
import networkx as nx
import numpy as np
import torch
from dgl.data import DGLDataset

from GraphDatasetInfo import (DistributionType, Split, Distribution, GraphSubdatasetInfo, GraphDatasetInfo)

# Name of the dataset; also used as the name of its output sub-folder.
datasetname = 'ToyDataset01'

# Root folder holding the generated graph datasets. The GRAPH_DATASETS_DIR
# environment variable overrides it so the notebook is not tied to a single
# machine; when it is unset or empty the original location is used unchanged.
datasetsRoot = os.environ.get('GRAPH_DATASETS_DIR') or '/home/andrew/GNN_Sandbox/GraphDatasets'

# Kept as a plain str because later cells pass it to path.join and save_dir.
outputFolder = path.join(datasetsRoot, datasetname)
# Create the folder (and any missing parents); no error if it already exists.
Path(outputFolder).mkdir(parents=True, exist_ok=True)

Using backend: pytorch


In [2]:
# Configuration of the sub-dataset for graph class 0.
graphlabel = 0
graphCount = 10
name = f'GraphClass{graphlabel}'
description = 'Equal distribution of all node features'

# Distribution of the node count per graph: TruncNorm with the given
# minimum, maximum, mean and standard deviation.
nodesPerGraph = Distribution(
    minimum=10,
    maximum=120,
    mean=30,
    standardDeviation=10,
    distributionType=DistributionType.TruncNorm,
)

# One Uniform distribution per node feature.
# NOTE(review): the two leading positional arguments are presumably the
# lower/upper bound of the distribution - confirm against Distribution.
nodeFeatures = {}
nodeFeatures['P_t'] = Distribution(10, 100, DistributionType.Uniform)
nodeFeatures['Eta'] = Distribution(-10, 10, DistributionType.Uniform)
nodeFeatures['Phi'] = Distribution(0, 2 * math.pi, DistributionType.Uniform)
nodeFeatures['Mass'] = Distribution(0.001, 1, DistributionType.Uniform)
nodeFeatures['Type'] = Distribution(0, 2, DistributionType.Uniform, roundToNearestInt=True)

# Edge features are only declared by name; every value is None, i.e. no
# Distribution is supplied for them here.
edgeFeatures = dict.fromkeys(('DeltaPhi', 'DeltaEta', 'RapiditySquared'))

graphsubdatasetInfo1 = GraphSubdatasetInfo(
    name,
    description,
    graphlabel,
    graphCount,
    nodesPerGraph,
    nodeFeatures,
    edgeFeatures,
)

In [3]:
# Configuration of the sub-dataset for graph class 1.
graphlabel = 1
name = f'GraphClass{graphlabel}'

# Same TruncNorm bounds and standard deviation as class 0, but with mean 65.
nodesPerGraph2 = Distribution(
    distributionType=DistributionType.TruncNorm,
    minimum=10, maximum=120,
    mean=65, standardDeviation=10)

# Independent copy of the class-0 node features, so overriding 'P_t' here
# does not modify the dict already handed to graphsubdatasetInfo1.
nodeFeatures2 = deepcopy(nodeFeatures)
nodeFeatures2['P_t'] = Distribution(60, 80, DistributionType.Uniform)

# NOTE(review): `description` and `graphCount` are reused unchanged from the
# class-0 cell. The reused text says 'Equal distribution of all node
# features', although 'P_t' is replaced just above - confirm whether this
# sub-dataset should get its own description.
graphsubdatasetInfo2 = GraphSubdatasetInfo(
    name, description, graphlabel, graphCount, nodesPerGraph2, nodeFeatures2, edgeFeatures)

In [4]:
# Sub-datasets in the order they are handed to the dataset info (class 1 first).
subdatasets = [graphsubdatasetInfo2, graphsubdatasetInfo1]

# The description previously contained '\(' and '\)', which are invalid string
# escapes (literal backslashes ended up in the text); plain parentheses are used.
# NOTE(review): the text states a single difference between the sub-datasets,
# but the class-1 cell also overrides the 'P_t' distribution - confirm wording.
graphdatasetInfo = GraphDatasetInfo(
    name=datasetname,
    description='Two equally sized sub datasets with only one difference (number of nodes per graph)',
    splitPercentages={'train': 0.6, 'valid': 0.2, 'test': 0.2},
    graphSubDatasetInfos=subdatasets
)

# Build the file name once so saving and loading cannot drift apart.
jsonFileName = f'{graphdatasetInfo.Name}.json'
graphdatasetInfo.SaveToJsonfile(outputFolder, jsonFileName)

# The saved info can be loaded back from the JSON file like this:
graphdatasetInfo = GraphDatasetInfo.LoadFromJsonfile(path.join(outputFolder, jsonFileName))

In [5]:
from ToyDGLDataset import ToyDGLDataset

# Build the DGL dataset from the dataset info reloaded from JSON. The recorded
# output of this cell shows graph generation per sub-dataset followed by
# saving into cached files.
dataset = ToyDGLDataset(
    name=datasetname,
    info=graphdatasetInfo,
    save_dir=outputFolder,
    shuffleDataset=True,
)

(1/2) Generating graphs from SubDataset GraphClass1: 100%|██████████| 10/10 [00:00<00:00, 47.76it/s]
(2/2) Generating graphs from SubDataset GraphClass0: 100%|██████████| 10/10 [00:00<00:00, 201.07it/s]


Calculating and saving histograms...
Num Graph classes: 2
Graph classes: [0, 1]
Number of graphs: 20
Number of all nodes in all graphs: 916
Number of all edges in all graphs: 48304
Dim node features: 5
Node feature keys: ['P_t', 'Eta', 'Phi', 'Mass', 'Type']
Dim edge features: 3
Edge feature keys: ['DeltaPhi', 'DeltaEta', 'RapiditySquared']
Done saving data into cached files.


<Figure size 720x504 with 0 Axes>

# Inspect the first generated graph as a networkx graph, draw it, and convert
# it to a DGL graph to compare node/edge counts and features.

# NOTE(review): rng is not used anywhere in this cell; it is kept in case
# later cells rely on it.
rng = np.random.default_rng(seed=42)
nxgraphs = graphdatasetInfo.ToNetworkxGraphList()

print(f'Graphs in the dataset: {len(nxgraphs)}')
print('Node features of the first graph in the graph list: ')
for node in nxgraphs[0].nodes(data=True):
    print(node)

print("Fully connected graph with edge features: ")
print(nxgraphs[0].edges.data())

pos = nx.spring_layout(nxgraphs[0])
options = {
    "node_color": "#A0CBE2",
    "width": 0.5,
    "with_labels": True,
    "node_size": 600
}
# NOTE(review): `plt` is not imported in any visible cell; this line needs
# `import matplotlib.pyplot as plt` in the imports cell to run on a fresh kernel.
plt.figure(1,figsize=(10,10))
nx.draw(nxgraphs[0], pos, **options)

# Convert the first graph to DGL, carrying over the node and edge attributes.
# The attribute names are passed as lists, the type from_networkx documents.
dglgraph = dgl.from_networkx(
    nxgraphs[0],
    node_attrs=list(nodeFeatures.keys()),
    edge_attrs=list(edgeFeatures.keys()))
print(f'Node count - nx: {nxgraphs[0].number_of_nodes()}, dgl: {dglgraph.num_nodes()}')
print(f'Edge count - nx: {nxgraphs[0].number_of_edges()}, dgl: {dglgraph.num_edges()}')

print(f'Node Features \n nx: {nxgraphs[0].nodes(data=True)[0]},\n dgl: {dglgraph.ndata}')

# This print used to be the second statement of the cell, before `dglgraph`
# was assigned, which raises NameError on a fresh kernel; it now runs after
# the conversion.
print(f'Edge Features \n nx: {list(nxgraphs[0].edges.data())[0]},\n dgl: {dglgraph.edata}')

print(dglgraph)