This notebook allows testing the GraphWorld setup with GNN implementations.
It is currently set up to test the self-supervised learning (SSL) methods with the JL benchmarker.

Running the code through this notebook lets you attach a debugger.
Note that graph_tool does not work on Windows, so the graph generators cannot be used there.
Instead, we use the standard datasets from PyG (PyTorch Geometric).

In [1]:
# Standard-library modules used for the sys.path setup below.
import os
import sys
# Resolve the parent of the current working directory and make it importable
# (appended only once, so re-running this cell does not duplicate the entry).
# Presumably this is the directory that contains the graph_world package -- verify
# against where the notebook is launched from.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# Benchmarker and model class that are passed to each other in the training cell.
from graph_world.self_supervised_learning.benchmarker_jl import NNNodeBenchmarkerJL
from graph_world.models.basic_gnn import GCN
# Planetoid provides the standard dataset that is loaded further down.
from torch_geometric.datasets import Planetoid

# NOTE(review): wildcard import -- every public name of pretext_tasks lands in the
# notebook namespace. NodeClusteringWithAlignment (used in the parameter cell) is
# presumably one of them; consider importing the needed names explicitly.
from graph_world.self_supervised_learning.pretext_tasks import *


In [2]:
# Parameter setup (values chosen for Cora)

# Settings read by the benchmarker. 'lambda' is a Python keyword, so this
# dictionary has to be written as a literal rather than with dict(...).
benchmark_params = {'epochs': 50, 'lr': 0.01, 'lambda': 1}

# Hyper-parameters handed to the benchmarker as h_params.
h_params = dict(
    in_channels=1433,
    hidden_channels=16,
    num_layers=2,
    dropout=0.5,
    embedding_corruption_ratio=0.1,
    partial_embedding_reconstruction=True,
    n_clusters=200,
)

# Generator configuration handed to the benchmarker as generator_config.
generator_config = dict(num_clusters=7)

# Pretext-task classes (not instances) that are passed to the benchmarker through
# its pretext_tasks argument in the training cell.
pretext_tasks = [NodeClusteringWithAlignment]

In [3]:
# Get dataset
import os
import tempfile

# Cache the download under the platform's temporary directory instead of a
# hard-coded POSIX path, so the cell behaves the same on Windows and Linux
# (on a typical Linux setup this still resolves to /tmp/Cora).
cora_root = os.path.join(tempfile.gettempdir(), 'Cora')

# Index 0 takes the first graph object out of the Planetoid dataset; the rest of
# the notebook reads its x, y and *_mask attributes.
dataset = Planetoid(root=cora_root, name='Cora')[0]

In [47]:
# torch is imported explicitly so this cell does not rely on the name leaking in
# through a wildcard import elsewhere in the notebook.
import torch
from sklearn.cluster import KMeans

# Step 0: Setup
train_mask = dataset.train_mask
X = dataset.x
y = dataset.y
num_classes = y.unique().shape[0]
feat_dim = X.shape[1]
# Same value as h_params['n_clusters'] in the parameter cell; keep the two in sync.
n_clusters = 200
random_state = 1

# Step 1: Compute one centroid per class as the mean feature vector of the
# labeled (training) nodes of that class.
# The labeled subset does not change between classes, so it is sliced once.
labeled_features = X[train_mask]
labeled_targets = y[train_mask]
centroids_labeled = torch.zeros((num_classes, feat_dim))
# Tracks which classes actually have labeled nodes. The mean of an empty selection
# is NaN, and an argmin over distances containing NaN would pick exactly that
# class, so such classes are left at zero here and excluded in step 4.
class_has_labels = torch.zeros(num_classes, dtype=torch.bool)
for cn in range(num_classes):
    class_features = labeled_features[labeled_targets == cn]
    if class_features.shape[0] > 0:
        centroids_labeled[cn] = class_features.mean(dim=0)
        class_has_labels[cn] = True

if not class_has_labels.any():
    raise ValueError("train_mask selects no labeled nodes; cannot align clusters to classes")

# Step 2: Set cluster labels for each node
# -1 marks "not assigned yet"; labeled nodes keep their ground-truth class.
cluster_labels = torch.ones(y.shape, dtype=torch.int64) * -1
cluster_labels[train_mask] = y[train_mask]

# Step 3: Train KMeans on all points
kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(X)

# Step 4: Perform alignment mechanism
# For every KMeans cluster:
# 1) Compute its centroids
# 2) Find the class centroid (from step 1) closest to it
# 3) Assign all unlabeled nodes of the cluster to that class.
kmeans_labels = torch.from_numpy(kmeans.labels_).long()
# .bool() keeps the negation correct even if the mask is stored as uint8.
unlabeled_mask = ~train_mask.bool()

for cn in range(n_clusters):
    in_cluster = kmeans_labels == cn
    if not in_cluster.any():
        # Nothing to relabel for an empty cluster.
        continue

    # v_l
    centroids_unlabeled = X[in_cluster].mean(dim=0)

    # Equation 5: Euclidean distance from the cluster centroid to every class centroid.
    distances = torch.linalg.norm(centroids_labeled - centroids_unlabeled, dim=1)
    # Classes without labeled nodes must never win the argmin.
    distances[~class_has_labels] = float('inf')
    label_for_cluster = int(distances.argmin())

    # Single masked assignment instead of a per-node Python loop; labeled nodes
    # are excluded by the mask and therefore keep their true class.
    cluster_labels[in_cluster & unlabeled_mask] = label_for_cluster

# Alias (same tensor object) under the name used for the resulting pseudo labels.
pseudo_labels = cluster_labels

In [4]:
# Training. A debugger can be attached to whatever is needed inside train.
benchmarker = NNNodeBenchmarkerJL(
    generator_config=generator_config,
    model_class=GCN,
    benchmark_params=benchmark_params,
    h_params=h_params,
    pretext_tasks=pretext_tasks,
)

# Reuse the train/val/test masks that ship with the loaded graph.
benchmarker.SetMasks(
    train_mask=dataset.train_mask,
    val_mask=dataset.val_mask,
    test_mask=dataset.test_mask,
)

# Kept as the last expression of the cell so that its return value is displayed.
benchmarker.train(
    data=dataset,
    tuning_metric="rocauc_ovr",
    tuning_metric_is_loss=False,
)

GCN(1433, 16, num_layers=2)


([3.8831515312194824,
  3.8361587524414062,
  3.781806468963623,
  3.708414077758789,
  3.6395370960235596,
  3.586761713027954,
  3.5007543563842773,
  3.4798481464385986,
  3.3896591663360596,
  3.314657688140869,
  3.2314326763153076,
  3.1372103691101074,
  3.0012292861938477,
  2.9711737632751465,
  2.9290904998779297,
  2.8051087856292725,
  2.792025089263916,
  2.691007137298584,
  2.6372809410095215,
  2.6427688598632812,
  2.4474406242370605,
  2.467951536178589,
  2.513339042663574,
  2.289475202560425,
  2.383141040802002,
  2.192713499069214,
  2.195323944091797,
  2.0936620235443115,
  2.22123384475708,
  2.1294121742248535,
  2.0268070697784424,
  2.0474939346313477,
  2.025908946990967,
  2.055022716522217,
  1.9677447080612183,
  1.7958502769470215,
  1.9264920949935913,
  1.825606346130371,
  1.859749674797058,
  1.857128620147705,
  1.7803820371627808,
  1.7667732238769531,
  1.6591300964355469,
  1.813990831375122,
  1.6925104856491089,
  1.6272449493408203,
  1.6607