This notebook tests the GraphWorld setup with GNN implementations.
It is currently configured to test the self-supervised learning (SSL) methods of the JL benchmarker (`NNNodeBenchmarkerJL`).

Running the pipeline from this notebook lets you attach a debugger.
Note that graph_tool does not work on Windows, so the GraphWorld graph generators cannot be used here.
Instead, we use the standard PyG datasets (currently Cora, loaded via `Planetoid`).

In [1]:
import os
import random
import sys

import numpy as np
import torch
from torch import nn
from torch_geometric.datasets import Planetoid

# Make the parent directory importable so the local `graph_world` package resolves.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from graph_world.self_supervised_learning.benchmarker_jl import NNNodeBenchmarkerJL
from graph_world.models.basic_gnn import GCN

# Star imports supply the pretext-task names selected below (e.g. DeepGraphInfomax).
from graph_world.self_supervised_learning.pretext_tasks.auxiliary_property_based import *
from graph_world.self_supervised_learning.pretext_tasks.contrastive_based_different_scale import *

In [2]:
# Parameter setup for the Cora dataset.

# Passed to NNNodeBenchmarkerJL as `benchmark_params`.
benchmark_params = {
    "epochs": 10,
    "lr": 0.01,
    "lambda": 1,
}

# Passed to NNNodeBenchmarkerJL as `h_params`. in_channels / hidden_channels /
# num_layers match the backbone printed at training time
# (`GCN(1433, 16, num_layers=2)`); dropout presumably also configures the GCN.
# The remaining keys are presumably read by the individual pretext tasks —
# confirm against each task's implementation.
h_params = {
    # backbone
    "in_channels": 1433,
    "hidden_channels": 16,
    "num_layers": 2,
    "dropout": 0.5,
    # pretext-task options
    "embedding_corruption_ratio": 0.1,
    "partial_embedding_reconstruction": True,
    "n_parts": 10,
    "shortest_path_cutoff": 6,
    "N_classes": 4,
    "k_largest": 20,
    "k": 15,
    "temperature": 1,
    "num_cluster_iter": 500,
    "alpha": 0.15,
    "n_clusters": 30,
}

# Passed to NNNodeBenchmarkerJL as `generator_config`.
generator_config = {"num_clusters": 7}

# Self-supervised pretext task handed to NNNodeBenchmarkerJL in the training cell.
# `DeepGraphInfomax` is not defined in this notebook; it presumably comes from one
# of the pretext_tasks star imports at the top.
pretext_task = DeepGraphInfomax

In [3]:
# Load Cora through PyG's Planetoid loader and keep its first graph. Despite the
# name, `dataset` is therefore a single graph object (its `.x` and split masks are
# used below), not the dataset container.
# To try a toy graph instead use `KarateClub()[0]` — KarateClub is not imported
# above, so add `from torch_geometric.datasets import KarateClub` first.
dataset = Planetoid(name='Cora', root='/tmp/Cora')[0]


In [4]:
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: a SciPy sparse matrix (anything exposing ``.tocoo()``).

    Returns:
        A sparse COO tensor with the same shape, int64 indices and float32
        values. Duplicate entries are not coalesced.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # `torch.sparse.FloatTensor` is a deprecated legacy constructor whose
    # overload errors surface as cryptic "new() received an invalid combination
    # of arguments" TypeErrors; `torch.sparse_coo_tensor` is the supported API.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

class AvgNeighbor(nn.Module):
    """Aggregate node features over neighbours with a sparse adjacency product.

    Computes ``adj_ori @ seq`` (after dropping a leading size-1 dimension of
    ``seq``) and adds a leading dimension back to the result. Despite the name,
    the output is only a neighbour *average* if ``adj_ori`` is row-normalised —
    presumably the caller's responsibility; confirm.
    """

    def __init__(self):
        super().__init__()

    def forward(self, seq, adj_ori):
        """Multiply the adjacency with the node features.

        Args:
            seq: node features, presumably ``(1, num_nodes, num_features)`` —
                TODO confirm. A leading dimension of size 1 is squeezed away
                before the product.
            adj_ori: sparse adjacency, presumably ``(num_nodes, num_nodes)``.

        Returns:
            ``adj_ori @ squeeze(seq, 0)`` with a new leading dimension of size 1.
        """
        # Put the adjacency on the same device as the features. The previous
        # "move to CUDA whenever CUDA exists" check caused a device mismatch
        # when CUDA was available but the features were on the CPU.
        adj_ori = adj_ori.to(seq.device)
        return torch.unsqueeze(torch.spmm(adj_ori, torch.squeeze(seq, 0)), 0)

In [8]:
# Scratch inspection: gather node 0's feature row three times via list indexing.
dataset.x[[0] * 3]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [6]:
# Scratch cell: instantiate the AvgNeighbor aggregator (`c` is not used elsewhere
# in this notebook).
# NOTE(review): the TypeError recorded under this cell cannot be raised by this
# line — AvgNeighbor's constructor only initialises nn.Module. It appears to be
# stale output from a since-deleted call (the listed overloads look like those of
# the legacy sparse tensor constructor). TODO re-run on a fresh kernel.
c = AvgNeighbor()






TypeError: new() received an invalid combination of arguments - got (Tensor), but expected one of:
 * (*, torch.device device)
 * (Tensor indices, Tensor values, *, torch.device device)
 * (Tensor indices, Tensor values, tuple of ints size, *, torch.device device)
 * (tuple of ints size, *, torch.device device)


In [4]:
# Training. Attach a debugger to whatever is needed inside `benchmarker.train`.
# Seed the Python, NumPy and PyTorch RNGs so the losses and metrics shown below
# can be reproduced on a fresh kernel (some CUDA kernels may still be
# nondeterministic).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

benchmarker = NNNodeBenchmarkerJL(
    generator_config=generator_config,
    model_class=GCN,
    benchmark_params=benchmark_params,
    h_params=h_params,
    pretext_task=pretext_task,
)

# Alternative split: validate and test on every node outside the training mask.
# benchmarker.SetMasks(train_mask=dataset.train_mask, val_mask=~dataset.train_mask, test_mask=~dataset.train_mask)
benchmarker.SetMasks(
    train_mask=dataset.train_mask,
    val_mask=dataset.val_mask,
    test_mask=dataset.test_mask,
)
benchmarker.train(data=dataset, tuning_metric="rocauc_ovr", tuning_metric_is_loss=False)

GCN(1433, 16, num_layers=2)


  super(Adam, self).__init__(params, defaults)
  warn("CUDA is not available, disabling CUDA profiling")


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::zeros         0.88%       2.311ms         1.17%       3.078ms      66.913us            46  
                                            aten::empty         0.54%       1.421ms         0.54%       1.421ms       3.116us           456  
                                            aten::zero_         0.10%     251.000us         0.42%       1.096ms      11.913us            92  
                                          ProfilerStep*        39.58%     104.010ms        99.20%     260.723ms     130.362ms             2  
      

  warn("CUDA is not available, disabling CUDA profiling")


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                            aten::zeros         0.13%     450.000us         0.46%       1.572ms      24.562us            64  
                                            aten::empty         0.67%       2.265ms         0.67%       2.265ms       3.528us           642  
                                            aten::zero_         0.09%     313.000us         0.46%       1.547ms      11.632us           133  
                                          ProfilerStep*        35.10%     119.067ms        99.97%     339.130ms     169.565ms             2  
      

  warn("CUDA is not available, disabling CUDA profiling")


([3.360574245452881,
  3.327719211578369,
  3.276233196258545,
  3.229010581970215,
  3.1075940132141113,
  2.993231773376465,
  2.928018808364868,
  2.830142021179199,
  2.616269588470459,
  2.595198154449463],
 {'accuracy': 0.642,
  'f1_micro': 0.642,
  'f1_macro': 0.6047545919642572,
  'rocauc_ovr': 0.7926955713042793,
  'rocauc_ovo': 0.7926955713042793,
  'logloss': 2.7923250497430563},
 {'accuracy': 0.59,
  'f1_micro': 0.59,
  'f1_macro': 0.5592747802505944,
  'rocauc_ovr': 0.7696136704594043,
  'rocauc_ovo': 0.7696136704594043,
  'logloss': 3.346689196899533})