This file allows for testing the GraphWorld setup with GNN implementations.
It is currently set up to test the SSL methods for the JL benchmarker.

Through this notebook you can attach a debugger.
Note that graph_tool does not work on windows, so we cannot use the graph generators.
Instead, we use the standard datasets from PyG.

In [2]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from graph_world.self_supervised_learning.benchmarker_jl import NNNodeBenchmarkerJL
from graph_world.models.basic_gnn import GCN
from torch_geometric.datasets import Planetoid

from graph_world.self_supervised_learning.pretext_tasks.auxiliary_property_based import *
from graph_world.self_supervised_learning.pretext_tasks.contrastive_based_different_scale import *
from graph_world.self_supervised_learning.pretext_tasks.hybrid import *

In [3]:
# Parameter setup (for cora)
benchmark_params = {
    'epochs' : 500,
    'lr' : 0.01,
    'lambda' : 1
}

h_params = {
    'in_channels' : 1433,
    'hidden_channels' : 128,
    'num_layers' : 2,
    'dropout' : 0.2,
    "embedding_corruption_ratio" : 0.1, 
    "partial_embedding_reconstruction" : True,
    'n_parts': 10,
    'shortest_path_cutoff': 6,
    'N_classes': 4,
    'k_largest': 20,
    'k': 20,
    'temperature': 1,
    'num_cluster_iter': 500,
    'alpha': 0.15,
    'n_clusters': 30,
    'num_iter': 10,
    'n_partitions': 8,
    'B_perc': 0.1,
    'k': 8,
    'P_perc': 1.3,
    'alpha': 0.15,
    'micro_meso_macro_weights': [1/3, 1/3, 1/3]
}

generator_config = {
    'num_clusters' : 7,
}

pretext_task = G_Zoom

In [4]:
# Get dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')[0]
# dataset = KarateClub()[0]


In [5]:
# Training. You can attach a debugger to w/e is needed inside train
benchmarker = NNNodeBenchmarkerJL(generator_config=generator_config, model_class=GCN, 
                benchmark_params=benchmark_params, h_params=h_params, pretext_task=pretext_task)
                
# benchmarker.SetMasks(train_mask=dataset.train_mask, val_mask=~dataset.train_mask, test_mask=~dataset.train_mask)
benchmarker.SetMasks(train_mask=dataset.train_mask, val_mask=dataset.val_mask, test_mask=dataset.test_mask)
benchmarker.train(data=dataset, tuning_metric="rocauc_ovr", tuning_metric_is_loss=False)

([4.927452087402344,
  4.674591541290283,
  4.182645797729492,
  3.441171646118164,
  3.0060007572174072,
  3.0839226245880127,
  9.174221992492676,
  6.384239673614502,
  4.529088497161865,
  6.852355003356934,
  3.2675297260284424,
  3.698986768722534,
  3.355661630630493,
  3.1942474842071533,
  3.2070717811584473,
  3.266439199447632,
  3.310626745223999,
  3.3232743740081787,
  3.3119874000549316,
  3.248483419418335,
  3.1848349571228027,
  3.113126516342163,
  3.043903350830078,
  2.973041534423828,
  2.9322633743286133,
  2.8919997215270996,
  2.891313314437866,
  2.849264144897461,
  2.8333539962768555,
  2.8138811588287354,
  2.805692672729492,
  2.7852783203125,
  2.780034303665161,
  2.7704477310180664,
  2.7731032371520996,
  2.762077808380127,
  2.738797187805176,
  2.744074583053589,
  2.7339065074920654,
  2.7136070728302,
  2.7322874069213867,
  2.7391233444213867,
  2.750556707382202,
  2.7241716384887695,
  2.721383810043335,
  2.7514407634735107,
  2.728906631469726