This notebook is for testing the GraphWorld setup with GNN implementations.
It is currently set up to test the SSL (self-supervised learning) methods for the JL benchmarker.

Running the code through this notebook makes it possible to attach a debugger.
Note that graph_tool does not work on Windows, so we cannot use the graph generators.
Instead, we use the standard datasets from PyG (PyTorch Geometric).

In [2]:
import os
import sys
# Put the parent of the current working directory on the import path
# (presumably so that the `graph_world` imports below resolve without
# installing the package — confirm against the repository layout).
module_path = os.path.abspath(os.path.join('..'))
# Guard against appending the same entry again when the cell is re-run.
if module_path not in sys.path:
    sys.path.append(module_path)

from graph_world.self_supervised_learning.benchmarker_jl import NNNodeBenchmarkerJL
from graph_world.models.basic_gnn import GCN
from torch_geometric.datasets import Planetoid

from graph_world.self_supervised_learning.pretext_tasks.auxiliary_property_based import *
from graph_world.self_supervised_learning.pretext_tasks.contrastive_based_different_scale import *
from graph_world.self_supervised_learning.tensor_utils import get_top_k_indices

In [105]:
import torch
from torch_ppr import personalized_page_rank

# Toy graph on nodes 0..4, written as one pair of node ids per edge.
edge_pairs = [(0, 1), (1, 2), (1, 3), (2, 4)]
device = torch.device('cpu')

# Transpose the (num_edges, 2) pair list into a (2, num_edges) edge_index.
edge_index = torch.as_tensor(edge_pairs).t()

# Personalized PageRank scores for the toy graph (a 5 x 5 matrix here).
S = personalized_page_rank(edge_index=edge_index, alpha=0.1, device=device)

# For every row of S, the column indices of its two largest scores.
I = torch.topk(S, k=2, dim=1).indices

# Last expression of the cell: display both tensors.
I, S

Using maximize_memory_utilization on non-CUDA tensors. This may lead to undocumented crashes due to CPU OOM killer.


(tensor([[1, 0],
         [1, 2],
         [2, 1],
         [1, 3],
         [2, 1]]),
 tensor([[0.2158, 0.3861, 0.1946, 0.1158, 0.0876],
         [0.1287, 0.4289, 0.2163, 0.1287, 0.0973],
         [0.0973, 0.3245, 0.3316, 0.0973, 0.1493],
         [0.1158, 0.3861, 0.1946, 0.2158, 0.0876],
         [0.0876, 0.2920, 0.2985, 0.0876, 0.2343]]))

In [124]:
# Toy feature matrix: five rows, each holding one value repeated three times.
X = torch.tensor([[value] * 3 for value in (1, 1, 2, 3, 4)])

# Boolean row selector that starts out all True ...
indices = torch.ones(len(X), dtype=torch.bool)
# ... and is switched off for the rows that should survive.
i = [2]
indices[i] = False

# Zero every row that is still selected; only row 2 keeps its values.
X[indices] = 0

# Last expression of the cell: display a copy of the selector.
indices.clone()

tensor([ True,  True, False,  True,  True])

In [125]:
# Builds a length-5 boolean tensor that is True everywhere — the same starting
# value as `torch.tensor([True] * 5)` in the previous cell, without the Python list.
torch.ones(5, dtype=torch.bool)

tensor([True, True, True, True, True])

In [None]:
import torch_geometric



In [128]:
from torch_geometric.utils import subgraph

K = 5  # number of nodes in the toy graph (ids 0..4)

# Restrict the toy graph to the node subset {4, 2, 1}.
# With return_edge_mask=True, `subgraph` returns (edge_index, edge_attr, edge_mask),
# so `idx` is the filtered edge_index (despite its name) and `mask` flags, per
# original edge, whether both of its endpoints lie in the subset.
# The original wrapped this call in `for i in range(5): ... break`, which ran
# exactly once and overwrote the global `i`; the dead loop has been removed.
idx, _, mask = subgraph([4, 2, 1], edge_index=edge_index, num_nodes=K, return_edge_mask=True)

print(mask)

tensor([False,  True, False,  True])


In [79]:
# torch.select(input, dim, index) slices out one index along `dim`.
# The original line (`torch.select(X; )`) was a syntax error with no dim/index.
# When the cells run top to bottom, X is the (5, 3) tensor from the masking cell
# and row 2 is the only row that was not zeroed.
# NOTE(review): the IndexError recorded under this cell ("index 2 ... size 2")
# came from an earlier run in which X had size 2 along dim 0 — confirm which
# index was actually intended.
torch.select(X, dim=0, index=2)

IndexError: index 2 is out of bounds for dimension 0 with size 2

In [70]:
# Stack two copies of the PPR matrix S along a new leading dimension
# (for the 5 x 5 matrix computed above this gives shape (2, 5, 5)).
X = S.repeat(2, 1, 1)
print(X)

# Transposed top-k index tensor. Note that it holds node indices, not booleans.
mask = I.t()
print('Mask: ', mask)

# Add a trailing dimension to the mask tensor and expand it to the size of X
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Select based on the expanded mask. `mask_ == 1` is True only where the stored
# index equals 1, and masked_select returns a flat 1-D tensor (dims are not preserved).
# NOTE(review): if the goal is to pick the top-k rows per node, an index-based
# gather is probably what is wanted instead — confirm the intent.
Y = torch.masked_select(X, (mask_ == 1))
# Fixed: the original `print(Y).` had a stray trailing dot (SyntaxError).
print(Y)

tensor([[[0.2158, 0.3861, 0.1946, 0.1158, 0.0876],
         [0.1287, 0.4289, 0.2163, 0.1287, 0.0973],
         [0.0973, 0.3245, 0.3316, 0.0973, 0.1493],
         [0.1158, 0.3861, 0.1946, 0.2158, 0.0876],
         [0.0876, 0.2920, 0.2985, 0.0876, 0.2343]],

        [[0.2158, 0.3861, 0.1946, 0.1158, 0.0876],
         [0.1287, 0.4289, 0.2163, 0.1287, 0.0973],
         [0.0973, 0.3245, 0.3316, 0.0973, 0.1493],
         [0.1158, 0.3861, 0.1946, 0.2158, 0.0876],
         [0.0876, 0.2920, 0.2985, 0.0876, 0.2343]]])
Mask:  tensor([[1, 1, 2, 1, 2],
        [0, 2, 1, 3, 1]])
tensor([[[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2],
         [1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2]],

        [[0, 0, 0, 0, 0],
         [2, 2, 2, 2, 2],
         [1, 1, 1, 1, 1],
         [3, 3, 3, 3, 3],
         [1, 1, 1, 1, 1]]])
tensor([0.2158, 0.3861, 0.1946, 0.1158, 0.0876, 0.1287, 0.4289, 0.2163, 0.1287,
        0.0973, 0.1158, 0.3861, 0.1946, 0.2158, 0.0876, 0.0973, 0.3245, 0.3316,
    

In [62]:
# Display the current value of `mask_`.
# NOTE(review): the output recorded below has shape (4, 3, 2), which does not
# match the (2, 5, 5) expansion built in the masked_select cell — it appears to
# come from an earlier definition of `mask_`; re-run to refresh.
mask_

tensor([[[1, 1],
         [0, 0],
         [0, 0]],

        [[0, 0],
         [1, 1],
         [0, 0]],

        [[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [0, 0],
         [0, 0]]])

In [60]:
# Preview of tiling S twice along a new leading dimension (two stacked copies of
# the 5 x 5 matrix); the result is only displayed, not stored.
S.repeat(2, 1, 1)

tensor([[[0.2158, 0.3861, 0.1946, 0.1158, 0.0876],
         [0.1287, 0.4289, 0.2163, 0.1287, 0.0973],
         [0.0973, 0.3245, 0.3316, 0.0973, 0.1493],
         [0.1158, 0.3861, 0.1946, 0.2158, 0.0876],
         [0.0876, 0.2920, 0.2985, 0.0876, 0.2343]],

        [[0.2158, 0.3861, 0.1946, 0.1158, 0.0876],
         [0.1287, 0.4289, 0.2163, 0.1287, 0.0973],
         [0.0973, 0.3245, 0.3316, 0.0973, 0.1493],
         [0.1158, 0.3861, 0.1946, 0.2158, 0.0876],
         [0.0876, 0.2920, 0.2985, 0.0876, 0.2343]]])

In [2]:
# Parameter setup (values chosen for the Cora dataset)

# Settings passed to the benchmarker as `benchmark_params`.
benchmark_params = {
    "epochs": 50,
    "lr": 0.01,
    "lambda": 1,
}

# Hyper-parameters passed to the benchmarker as `h_params`.
# NOTE(review): the first four entries look like encoder (GCN) settings and the
# remainder like pretext-task settings — confirm against the benchmarker code.
h_params = {
    "in_channels": 1433,
    "hidden_channels": 16,
    "num_layers": 2,
    "dropout": 0.5,
    "embedding_corruption_ratio": 0.1,
    "partial_embedding_reconstruction": True,
    "n_parts": 10,
    "shortest_path_cutoff": 6,
    "N_classes": 4,
    "k_largest": 20,
    "k": 30,
    "temperature": 1,
    "num_cluster_iter": 500,
    "alpha": 0.5,
}

# Passed to the benchmarker as `generator_config`.
generator_config = {"num_clusters": 7}

# Pretext tasks handed to the benchmarker as `pretext_tasks`.
# NOTE(review): `DeepGraphInfomax` is not imported by name in this notebook; it
# presumably comes from one of the wildcard `pretext_tasks` imports — confirm.
pretext_tasks = [DeepGraphInfomax]

In [3]:
# Get dataset
import tempfile

# Cache the download under the platform's temporary directory instead of a
# hard-coded POSIX '/tmp' path, since this notebook is meant to run on Windows too.
cora_root = os.path.join(tempfile.gettempdir(), 'Cora')

# Planetoid(...) is the dataset wrapper; [0] takes its first graph object, which
# is what the benchmarker cell consumes as `dataset`.
dataset = Planetoid(root=cora_root, name='Cora')[0]
# Alternative toy dataset (KarateClub is not imported in this notebook; add the import before enabling):
# dataset = KarateClub()[0]


In [4]:
# Training. Attach a debugger to whatever code path inside `train` needs inspecting.
# The benchmarker is built from the configuration dictionaries and the pretext-task
# list defined in the parameter-setup cell, with GCN as the model class.
benchmarker = NNNodeBenchmarkerJL(generator_config=generator_config, model_class=GCN, 
                benchmark_params=benchmark_params, h_params=h_params, pretext_tasks=pretext_tasks)
# Alternative split (disabled): validate and test on every node outside the training mask.
# benchmarker.SetMasks(train_mask=dataset.train_mask, val_mask=~dataset.train_mask, test_mask=~dataset.train_mask)
# Active split: use the train/val/test masks stored on the loaded data object.
benchmarker.SetMasks(train_mask=dataset.train_mask, val_mask=dataset.val_mask, test_mask=dataset.test_mask)
# Last expression of the cell, so the notebook displays whatever `train` returns
# (the recorded output begins with a list of decreasing floats — presumably
# per-epoch losses; confirm against the benchmarker implementation).
benchmarker.train(data=dataset, tuning_metric="rocauc_ovr", tuning_metric_is_loss=False)

GCN(1433, 16, num_layers=2)


([3.3405697345733643,
  3.3243794441223145,
  3.306535005569458,
  3.2765321731567383,
  3.236759662628174,
  3.2011375427246094,
  3.120988607406616,
  3.0833921432495117,
  2.955915927886963,
  2.941879987716675,
  2.8654351234436035,
  2.795201539993286,
  2.671541690826416,
  2.683580160140991,
  2.4859378337860107,
  2.4457130432128906,
  2.467940330505371,
  2.412355899810791,
  2.308687210083008,
  2.186960458755493,
  2.2213480472564697,
  2.1361312866210938,
  2.1299047470092773,
  2.1119649410247803,
  2.090444564819336,
  2.0715558528900146,
  1.9986580610275269,
  2.0037200450897217,
  1.8863670825958252,
  1.953214406967163,
  1.7313849925994873,
  1.8875255584716797,
  1.7715727090835571,
  1.828615427017212,
  1.7530622482299805,
  1.6012623310089111,
  1.6210243701934814,
  1.6686795949935913,
  1.6426074504852295,
  1.6730304956436157,
  1.596615195274353,
  1.5263845920562744,
  1.5755417346954346,
  1.439550518989563,
  1.3799687623977661,
  1.3692924976348877,
  1.3