In [1]:
import math
from copy import deepcopy
from os import path
from pathlib import Path

import dgl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from dgl.data import DGLDataset

from GraphDatasetInfo import (DistributionType, Distribution, GraphSubdatasetInfo, GraphDatasetInfo)


Using backend: pytorch


In [2]:
# Describe the toy dataset: two sub-datasets (graph class 0 and graph class 1)
# that are bundled into one GraphDatasetInfo and persisted as a JSON file.

datasetname = 'ToyDataset01'
# NOTE(review): machine-specific absolute path; consider a configurable base directory.
outputFolder = path.join('/home/andrew/GNN_Sandbox/GraphToyDatasets/00', datasetname)
Path(outputFolder).mkdir(parents=True, exist_ok=True)

# ---- Sub-dataset for graph class 0 ----
graphlabel = 0
name = f'GraphClass{graphlabel}'
description = 'Equal distribution of all node features'
graphCount = 2000

# Number of nodes per graph: truncated normal on [10, 120] centred at 30.
nodesPerGraph = Distribution(
    distributionType=DistributionType.truncnorm,
    minimum=10, maximum=120,
    mean=30, standardDeviation=10)

# Feature name -> position of its Distribution in the nodeFeat list.
nFeatMapping = {'P_t': 0, 'Eta': 1, 'Phi': 2, 'Mass': 3, 'Type': 4}
nodeFeat = [
    Distribution(10, 100, DistributionType.uniform), # index 0 -> P_t
    Distribution(-10, 10, DistributionType.uniform), # index 1 -> Eta
    Distribution(0, 2 * math.pi, DistributionType.uniform), # index 2 -> Phi
    Distribution(0.001, 1, DistributionType.uniform), # index 3 -> Mass
    Distribution(0, 2, DistributionType.uniform, roundToNearestInt=True) # index 4 -> Type
]

# Feature name -> position in the edgeFeat list. No distribution is configured
# for edge features here (all entries are None).
eFeatMapping = {'DeltaEta': 0, 'DeltaPhi': 1, 'RapiditySquared': 2}
edgeFeat = [
    None, # index 0 -> DeltaEta
    None, # index 1 -> DeltaPhi
    None # index 2 -> RapiditySquared
]

graphsubdatasetInfo1 = GraphSubdatasetInfo(
    name=name, description=description, label=graphlabel,
    graphCount=graphCount, nodesPerGraph=nodesPerGraph,
    nodeFeatMapping=nFeatMapping, nodeFeat=nodeFeat,
    edgeFeatMapping=eFeatMapping, edgeFeat=edgeFeat)

# ---- Sub-dataset for graph class 1 ----
graphlabel = 1
name = f'GraphClass{graphlabel}'

# Same node-count distribution as class 0, but centred at 65 instead of 30.
# deepcopy keeps the class-0 Distribution object untouched.
nodesPerGraph2 = deepcopy(nodesPerGraph)
nodesPerGraph2.mean = 65

# Same node features as class 0, except that P_t is drawn from [60, 80].
nodeFeat2 = deepcopy(nodeFeat)
nodeFeat2[nFeatMapping['P_t']] = Distribution(60, 80, DistributionType.uniform)

# This object must remain the class-1 definition: replacing it with a copy of
# graphsubdatasetInfo1 would give both sub-datasets label 0 and identical settings.
graphsubdatasetInfo2 = GraphSubdatasetInfo(
    name=name, description=description, label=graphlabel,
    graphCount=graphCount, nodesPerGraph=nodesPerGraph2,
    nodeFeatMapping=nFeatMapping, nodeFeat=nodeFeat2,
    edgeFeatMapping=eFeatMapping, edgeFeat=edgeFeat)

# Class 1 is listed first, class 0 second.
subdatasets = []
subdatasets.append(graphsubdatasetInfo2)
subdatasets.append(graphsubdatasetInfo1)

# NOTE(review): the description mentions a single difference (nodes per graph),
# but nodeFeat2 also narrows the P_t range -- confirm which one is intended.
graphdatasetInfo = GraphDatasetInfo(
    name=datasetname,
    description='Two equally sized sub datasets with only one difference (number of nodes per graph)',
    splitPercentages={'train': 0.6, 'valid': 0.2, 'test': 0.2},
    graphSubDatasetInfos=subdatasets
)

graphdatasetInfo.SaveToJsonfile(outputFolder, f'{graphdatasetInfo.name}.json')

#can be loaded from json file like below:
#graphdatasetInfo = GraphDatasetInfo.LoadFromJsonfile(path.join(outputFolder, f'{graphdatasetInfo.name}.json'))

File ToyDataset01.json already exists! Do you want to overwrite the file? (y/n)
y
File ToyDataset01.json overwritten.


In [6]:
# Instantiate the toy DGL dataset from the dataset description built above,
# using the dataset's output folder as its save directory.
from ToyDGLDataset import ToyDGLDataset

dataset = ToyDGLDataset(
    name=datasetname,
    info=graphdatasetInfo,
    shuffleDataset=True,
    save_dir=outputFolder,
)

(1/2) Generating graphs from SubDataset GraphClass1: 100%|██████████| 2000/2000 [00:10<00:00, 197.53it/s]
(2/2) Generating graphs from SubDataset GraphClass0: 100%|██████████| 2000/2000 [00:02<00:00, 827.83it/s]


Calculating and saving histograms...
Num Graph classes: 2
Graph classes: [0, 1]
Number of graphs: 4000
Number of all nodes in all graphs: 188767
Number of all edges in all graphs: 10284774
Dim node features: 5
Node feature keys: ['P_t', 'Eta', 'Phi', 'Mass', 'Type']
Dim edge features: 3
Edge feature keys: ['DeltaEta', 'DeltaPhi', 'RapiditySquared']
Done saving data into cached files.


<Figure size 720x504 with 0 Axes>

# Inspect the first generated graph as a networkx graph, draw it, convert it to
# a DGL graph and compare node/edge counts and features between the two.

# Seeded generator; not used in this cell -- presumably consumed by later cells (TODO confirm).
rng = np.random.default_rng(seed=42)
nxgraphs = graphdatasetInfo.ToNetworkxGraphList()

print(f'Graphs in the dataset: {len(nxgraphs)}')
print('Node features of the first graph in the graph list: ')
for node in nxgraphs[0].nodes(data=True):
    print(node)

print("Fully connected graph with edge features: ")
print(nxgraphs[0].edges.data())

# Draw the first graph with a force-directed layout.
pos = nx.spring_layout(nxgraphs[0])
options = {
    "node_color": "#A0CBE2",
    "width": 0.5,
    "with_labels": True,
    "node_size": 600
}
plt.figure(1,figsize=(10,10)) 
nx.draw(nxgraphs[0], pos, **options)

# Convert to DGL before any print that reads dglgraph. The attribute names come
# from the feature mappings of the dataset description.
# Assumes the networkx graphs store their attributes under these keys -- TODO confirm.
dglgraph = dgl.from_networkx(
    nxgraphs[0], 
    node_attrs=list(nFeatMapping.keys()), 
    edge_attrs=list(eFeatMapping.keys()))
print(f'Node count - nx: {nxgraphs[0].number_of_nodes()}, dgl: {dglgraph.num_nodes()}')
print(f'Edge count - nx: {nxgraphs[0].number_of_edges()}, dgl: {dglgraph.num_edges()}')

print(f'Node Features \n nx: {nxgraphs[0].nodes(data=True)[0]},\n dgl: {dglgraph.ndata}')

print(f'Edge Features \n nx: {list(nxgraphs[0].edges.data())[0]},\n dgl: {dglgraph.edata}')

print(dglgraph)