This file allows for testing the GraphWorld setup with GNN implementations.
It is currently set up to test the SSL methods for the JL benchmarker.

Through this notebook you can attach a debugger.
Note that graph_tool does not work on Windows, so we cannot use the GraphWorld graph generators here.
Instead, we use datasets from PyG (currently a synthetic graph generated by `FakeDataset`).

In [1]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from graph_world.self_supervised_learning.benchmarker import NNNodeBenchmarkerSSL
from graph_world.models.basic_gnn import GCN, SuperGAT
from torch_geometric import seed_everything
from torch_geometric.datasets import Planetoid, KarateClub, FakeDataset
from torch_geometric.transforms import RandomNodeSplit

from graph_world.self_supervised_learning.pretext_tasks.auxiliary_property_based import *


In [2]:
# Seed the Python, NumPy and PyTorch random number generators before anything
# stochastic runs. Both the synthetic graph and the node split below are drawn
# at random, so without a fixed seed every kernel restart would produce a
# different graph and different masks, which makes debugging hard to repeat.
SEED = 42
seed_everything(SEED)

# Get dataset: generate one synthetic graph and take it out of the dataset
# wrapper with [0].
dataset = FakeDataset(num_graphs=1, avg_num_nodes=100, num_channels=16, num_classes=4)[0]
# Add random train/val/test node masks to the graph (read later as
# dataset.train_mask / dataset.val_mask / dataset.test_mask).
dataset = RandomNodeSplit(split="random", num_test=20, num_val=20, num_train_per_class=2)(dataset)

In [3]:
# Parameter setup.
# NOTE(review): these values were originally annotated as "for cora", but the
# dataset cell builds a FakeDataset graph -- confirm they are still appropriate.

# Settings handed to the benchmarker as `benchmark_params`.
benchmark_params = dict(
    downstream_epochs=200,
    pretext_epochs=200,
    downstream_lr=3e-4,
    pretext_lr=3e-4,
    patience=100,  # presumably the early-stopping patience -- verify in the benchmarker
)

# Model hyperparameters; the input width is read from the node feature matrix.
h_params = dict(
    in_channels=dataset.x.shape[1],
    hidden_channels=128,
    num_layers=2,
    dropout=0.5,
)

# Extra arguments for the selected pretext task.
pretext_params = dict(k_largest=3)

# NOTE(review): 4 equals the num_classes used to generate the dataset;
# presumably the two are meant to agree -- confirm.
generator_config = dict(num_clusters=4)

# SSL configuration: which pretext task class to use and how it is combined
# with the downstream task ('JL' is passed to the benchmarker unchanged).
training_scheme = 'JL'
pretext_task = PairwiseAttrSim

In [4]:
# Training. Attach a debugger to whatever you need to inspect inside train().
benchmarker = NNNodeBenchmarkerSSL(
    generator_config=generator_config,
    model_class=GCN,
    benchmark_params=benchmark_params,
    h_params=h_params,
    pretext_task=pretext_task,
    pretext_params=pretext_params,
    training_scheme=training_scheme,
)

# Hand the node masks stored on the graph over to the benchmarker.
benchmarker.SetMasks(
    train_mask=dataset.train_mask,
    val_mask=dataset.val_mask,
    test_mask=dataset.test_mask,
)

# Kept as the last expression so the notebook displays what train() returns.
benchmarker.train(
    data=dataset,
    tuning_metric="rocauc_ovr",
    tuning_metric_is_loss=False,
)

([],
 [2.1767492294311523,
  2.1490089893341064,
  2.17777419090271,
  2.1470038890838623,
  2.1398282051086426,
  2.192577362060547,
  2.149956226348877,
  2.1515071392059326,
  2.161991596221924,
  2.1279027462005615,
  2.1137962341308594,
  2.193523406982422,
  2.1907527446746826,
  2.159590721130371,
  2.165247917175293,
  2.1893229484558105,
  2.1340012550354004,
  2.1140055656433105,
  2.15274715423584,
  2.114549398422241,
  2.0716285705566406,
  2.1964821815490723,
  2.2144992351531982,
  2.1985256671905518,
  2.2189998626708984,
  2.14143443107605,
  2.2199175357818604,
  2.124197483062744,
  2.1487441062927246,
  2.2451322078704834,
  2.176652431488037,
  2.150010108947754,
  2.2175333499908447,
  2.1868739128112793,
  2.2042508125305176,
  2.1452369689941406,
  2.1880154609680176,
  2.156446695327759,
  2.2171664237976074,
  2.147845983505249,
  2.0891478061676025,
  2.117432117462158,
  2.1924142837524414,
  2.166722059249878,
  2.151304244995117,
  2.1071767807006836,
  2.