This notebook tests the GraphWorld setup with GNN implementations.
It is currently configured to test the self-supervised learning (SSL) methods of the JL benchmarker.

Running the code from this notebook lets you attach a debugger to it.
Note that graph_tool does not work on Windows, so GraphWorld's graph generators cannot be used here.
Instead, we use the standard datasets from PyTorch Geometric (PyG).

In [1]:
import os
import sys
# Make the parent of the current working directory importable, so that the
# `graph_world` package imported below resolves to the local checkout
# (presumably the repository root when the kernel starts in this notebook's
# directory).
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
from torch_geometric.datasets import Planetoid

from graph_world.self_supervised_learning.benchmarker_jl import NNNodeBenchmarkerJL
from graph_world.models.basic_gnn import GCN, SuperGAT

from graph_world.self_supervised_learning.pretext_tasks.generation_based import *


In [2]:
# Parameter setup (for Cora).
# Fix the torch RNG so that torch-driven randomness (weight init, dropout and,
# presumably, the EdgeMask pretext task) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

benchmark_params = {
    'epochs' : 50,
    'lr' : 0.01,
    'lambda' : 0.5,  # NOTE(review): presumably weights the pretext-task loss — confirm in NNNodeBenchmarkerJL
}

h_params = {
    'in_channels' : 1433,  # Cora node-feature width; must equal dataset.num_features
    'hidden_channels' : 16,
    'num_layers' : 2,
    'dropout' : 0.5,
    'edge_mask_ratio' : 0.3,  # presumably read by the EdgeMask pretext task — TODO confirm
}

generator_config = {
    'num_clusters' : 7,  # Cora has 7 classes; presumably used as the class count — TODO confirm
}

# Pretext (self-supervised) task classes handed to the benchmarker;
# EdgeMask comes from the generation_based star import.
pretext_tasks = [
    EdgeMask,
]

In [3]:
# Load Cora from PyG. NOTE: despite its name, `dataset` holds the graph at
# index 0 of the Planetoid dataset, not the dataset object itself; its masks,
# features and labels are what the training cell consumes.
dataset = Planetoid(
    name='Cora',
    root='/tmp/Cora',
)[0]

In [4]:
# Training. Attach a debugger to whatever you need to inspect inside `train`.

# Fail fast if the hard-coded feature width in `h_params` does not match the
# loaded graph; otherwise the mismatch would likely only surface as a shape
# error inside the model's first layer.
assert h_params['in_channels'] == dataset.num_features, (
    f"h_params['in_channels']={h_params['in_channels']} but the graph has "
    f"{dataset.num_features} node features"
)

benchmarker = NNNodeBenchmarkerJL(
    generator_config=generator_config,
    model_class=GCN,
    benchmark_params=benchmark_params,
    h_params=h_params,
    pretext_tasks=pretext_tasks,
)
# Use the graph's own train/val/test split.
benchmarker.SetMasks(
    train_mask=dataset.train_mask,
    val_mask=dataset.val_mask,
    test_mask=dataset.test_mask,
)
benchmarker.train(data=dataset, tuning_metric="rocauc_ovr", tuning_metric_is_loss=False)

GCN(1433, 16, num_layers=2)


([2.6537816524505615,
  2.624413251876831,
  2.5956201553344727,
  2.5812158584594727,
  2.5307974815368652,
  2.4700379371643066,
  2.4327988624572754,
  2.3743720054626465,
  2.3125715255737305,
  2.2486112117767334,
  2.1161229610443115,
  2.039182662963867,
  1.9125522375106812,
  1.836040735244751,
  1.733703851699829,
  1.6972777843475342,
  1.5582373142242432,
  1.5974316596984863,
  1.497349739074707,
  1.402721643447876,
  1.4374160766601562,
  1.4730026721954346,
  1.3112297058105469,
  1.2888453006744385,
  1.3025895357131958,
  1.2284963130950928,
  1.1566731929779053,
  1.1827781200408936,
  1.2713623046875,
  1.0194432735443115,
  1.0712332725524902,
  1.0471159219741821,
  1.0063486099243164,
  0.9605315923690796,
  0.99690842628479,
  0.9155166149139404,
  1.039700984954834,
  0.9176446199417114,
  0.8657811880111694,
  0.837931215763092,
  0.8789262175559998,
  0.9885526895523071,
  0.8928833603858948,
  0.9445092678070068,
  0.9821723699569702,
  0.7828702926635742,
 