In [None]:
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from pathlib import Path
import os
import yaml
import types
import collections
from itertools import repeat
from typing import List, Dict, Any
import wandb

In [None]:
# toolkit
from gTDR.datasets import FastDataset
import gTDR.utils.SALIENT as utils 
from gTDR.utils.SALIENT.fast_trainer.utils import Timer
from gTDR.trainers.SALIENT_trainer import Trainer

Specify the model. Other options are `GAT`, `GIN`, `SAGEResInception`, `SAGEClassic`, `JKNet`, `GCN`, and `ARMA`. The imported model must match the one specified in `../configs/SALIENT_ogbn_arxiv_single_machine_parameters.yaml`.

In [None]:
from gTDR.models.SALIENT import SAGE
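
The imported class has to agree with the model named in the YAML config, so a quick check like the one below can catch a mismatch early. This is a sketch only: it assumes the config stores the model name under a key called `model_name`, which is a hypothetical field name. Check the actual YAML file for the real key.

```python
from types import SimpleNamespace

# Model class names mentioned in this notebook (SAGE plus the listed alternatives).
SUPPORTED_MODELS = {"SAGE", "GAT", "GIN", "SAGEResInception",
                    "SAGEClassic", "JKNet", "GCN", "ARMA"}


def check_model_matches_config(model_cls, args, key="model_name"):
    """Return the class name; raise ValueError on an unknown or mismatched model.

    `key` is a hypothetical config field; replace it with the real one.
    """
    name = model_cls.__name__
    if name not in SUPPORTED_MODELS:
        raise ValueError(f"Unsupported model class: {name}")
    configured = getattr(args, key, None)
    if configured is not None and configured != name:
        raise ValueError(f"Imported {name}, but the config specifies {configured}")
    return name


# Stand-in class named "SAGE", created without shadowing the real import.
FakeSAGE = type("SAGE", (), {})
print(check_model_matches_config(FakeSAGE, SimpleNamespace(model_name="SAGE")))
```

In the notebook itself, the call would be `check_model_matches_config(SAGE, args)` once `args` has been loaded.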

## Arguments & Parameters

Specify the setup in the config file, including:
* `dataset_root`: (str) The directory where the downloaded dataset is stored.
* `dataset_name`: (str) The name of the dataset to load.
* `output_root`: (str) The directory where the trained model and results are saved.
* `save_results`: (bool) Whether to save the training and test results.

In [None]:
config_filename = "../configs/SALIENT_ogbn_arxiv_single_machine_parameters.yaml"
with open(config_filename) as f:
    configs = yaml.safe_load(f)  # parse the YAML file into a dict
# Expose each top-level config key as an attribute, e.g. args.lr
args = types.SimpleNamespace(**configs)
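
To make the mapping concrete, the sketch below parses a small inline YAML document the same way and shows that each top-level key becomes an attribute. The keys and values are illustrative placeholders, not the contents of the real config file.

```python
import types

import yaml

# Illustrative config; the real file may use different keys and values.
example_yaml = """
dataset_root: ./dataset
output_root: ./output
save_results: true
lr: 0.001
epochs: 10
"""

configs_demo = yaml.safe_load(example_yaml)          # -> plain dict
args_demo = types.SimpleNamespace(**configs_demo)    # keys become attributes

print(args_demo.dataset_root, args_demo.save_results, args_demo.lr)
```

One detail worth knowing when editing `lr`: PyYAML follows YAML 1.1, so a value written as `1e-3` is parsed as a string, while `0.001` or `1.0e-3` is parsed as a float.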

Complete the specification of `args` and set up multi-GPU training.

In [None]:
args = utils.setup(args)

Start a `wandb` run to monitor the experiment.

In [None]:
run = wandb.init(project="SALIENT", name="ogbn-arxiv")

## Data 

In this demo, we use the `ogbn-arxiv` dataset from `ogb`.

SALIENT defines a dataset class, `FastDataset`, that supports loading `ogb` datasets. To use custom data, modify `FastDataset` (see [../gTDR/datasets/SALIENT_Dataset.py](../gTDR/datasets/SALIENT_Dataset.py)).

`FastDataset` includes these properties:

* `name`: (str) The name of the dataset.

* `x`: (torch.Tensor) A tensor containing the feature vectors of the nodes in the graph.

* `y`: (torch.Tensor) A tensor containing the labels of the nodes in the graph.

* `rowptr`: (torch.Tensor) A tensor containing the row pointers of the adjacency matrix of the graph in Compressed Sparse Row (CSR) format.

* `col`: (torch.Tensor) A tensor containing the column indices of the adjacency matrix of the graph in CSR format.

* `split_idx`: (Mapping[str, torch.Tensor]) A dictionary containing the indices for splitting the dataset into training, validation, and testing sets. The keys are strings ("train", "valid", "test") and the values are tensors of indices.

* `meta_info`: (Mapping[str, Any]) A dictionary containing additional metadata about the dataset. The keys are strings and the values can be of any type.
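
To illustrate the CSR layout used by `rowptr` and `col`, the sketch below builds both arrays for a toy directed graph with plain Python lists (the real dataset stores them as `torch.Tensor`s). The neighbors of node `i` are `col[rowptr[i]:rowptr[i + 1]]`, so `rowptr` has `num_nodes + 1` entries and its last entry equals the number of edges.

```python
from typing import List, Tuple


def edges_to_csr(num_nodes: int, edges: List[Tuple[int, int]]):
    """Convert (row, col) edge pairs into CSR arrays (rowptr, col)."""
    # Count the entries in each row.
    counts = [0] * num_nodes
    for r, _ in edges:
        counts[r] += 1
    # A prefix sum gives the start offset of each row; rowptr[-1] == len(edges).
    rowptr = [0] * (num_nodes + 1)
    for i in range(num_nodes):
        rowptr[i + 1] = rowptr[i] + counts[i]
    # Place column indices row by row, sorted within each row.
    col = []
    for i in range(num_nodes):
        col.extend(sorted(c for r, c in edges if r == i))
    return rowptr, col


def neighbors(rowptr, col, i):
    """Column indices stored for row i."""
    return col[rowptr[i]:rowptr[i + 1]]


edges = [(0, 1), (0, 2), (1, 2), (2, 0), (3, 2)]
rowptr, col = edges_to_csr(4, edges)
print(rowptr)                     # [0, 2, 3, 4, 5]
print(col)                        # [1, 2, 2, 0, 2]
print(neighbors(rowptr, col, 0))  # [1, 2]
```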

`FastDataset` includes these methods:

* `adj_t(self)`: Constructs and returns a `SparseTensor` representing the adjacency matrix of the graph from the `rowptr` and `col` attributes, which hold the row pointers and column indices of the matrix in Compressed Sparse Row (CSR) format. The number of nodes in the graph, `num_nodes`, determines the size of the matrix. The `is_sorted=True` and `trust_data=True` arguments declare that the data is already sorted and well formed, so no further sorting or validity checks are performed.

* `share_memory_(self)`: Moves the dataset's tensors (`self.x`, `self.y`, `self.rowptr`, `self.col`, and every tensor in the `split_idx` dictionary) to shared memory using PyTorch's `share_memory_()`. This lets multiple processes, for example in data-parallel training, access the same data without each keeping its own copy.

* `save(self, _path, name)` (optional): Saves the dataset to disk by looping over its fields (`x`, `y`, `rowptr`, etc.) and writing each one to a separate file in the processed data directory. `_path` is the base directory and `name` is the name of the dataset; the processed data is saved in a subdirectory named after the dataset under `_path`.
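
The per-field saving pattern can be sketched with the standard library as below. This is not the gTDR implementation: the directory layout (`<_path>/<name>/<field>.pkl`) and the use of `pickle` instead of PyTorch's serialization are assumptions made so that the example runs without PyTorch.

```python
import pickle
import tempfile
from pathlib import Path


def save_fields(fields: dict, _path: str, name: str) -> Path:
    """Write each field to its own file under <_path>/<name>/ (assumed layout)."""
    out_dir = Path(_path) / name
    out_dir.mkdir(parents=True, exist_ok=True)
    for field, value in fields.items():
        with open(out_dir / f"{field}.pkl", "wb") as f:
            pickle.dump(value, f)
    return out_dir


def load_fields(out_dir: Path) -> dict:
    """Read every <field>.pkl file back into a dict keyed by field name."""
    fields = {}
    for file in sorted(out_dir.glob("*.pkl")):
        with open(file, "rb") as f:
            fields[file.stem] = pickle.load(f)
    return fields


# Toy 2-node graph: node 0 -> {0, 1}, node 1 -> {0}.
toy = {"rowptr": [0, 2, 3], "col": [0, 1, 0], "y": [0, 1]}
with tempfile.TemporaryDirectory() as tmp:
    out_dir = save_fields(toy, tmp, "toy-graph")
    print(sorted(p.name for p in out_dir.iterdir()))
    assert load_fields(out_dir) == toy  # round trip preserves every field
```

Saving one file per field makes it possible to reload or inspect a single field without reading the whole dataset.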

In [None]:
with Timer('Loading dataset'):
    dataset = FastDataset.from_path(args.dataset_root, args.dataset_name)
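
`Timer` is used as a context manager that reports how long the enclosed block took. A minimal stand-in with the same usage pattern is sketched below; `SimpleTimer` is a hypothetical reimplementation built on `time.perf_counter`, not the actual `fast_trainer.utils.Timer`, whose output format may differ.

```python
import time


class SimpleTimer:
    """Hypothetical stand-in for a `with Timer(name):` style timer."""

    def __init__(self, name: str):
        self.name = name
        self.elapsed = None

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapsed = time.perf_counter() - self._start
        print(f"{self.name}: {self.elapsed:.3f}s")
        return False  # do not suppress exceptions raised inside the block


with SimpleTimer("Sleeping") as t:
    time.sleep(0.1)
```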

## Model

You may specify these model parameters in the config:

* `hidden_features` (int): The number of hidden units in each layer, i.e., the dimension of the feature vectors that each hidden layer outputs. More hidden features let the model capture more complex patterns, but they also increase computational cost and the risk of overfitting.

* `num_layers` (int): The number of layers in the model. Deeper models can, in principle, learn more complex patterns; as with `hidden_features`, however, increasing this value also increases computational cost and the risk of overfitting. In a message-passing GNN, each layer aggregates information from one additional hop, so `num_layers` also sets how many hops of neighborhood each node sees.
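
As a rough illustration of how these two parameters shape the network, the helper below lists the (input, output) sizes of each layer for a plain stack of layers: the first maps the input features to `hidden_features`, intermediate layers keep that width, and the last maps to the number of classes. This is an assumed generic layout; individual models such as `JKNet` or `SAGEResInception` may wire their layers differently. The example uses ogbn-arxiv's 128-dimensional node features and 40 classes; the hidden size and depth are arbitrary.

```python
from typing import List, Tuple


def layer_dims(in_features: int, hidden_features: int,
               out_features: int, num_layers: int) -> List[Tuple[int, int]]:
    """(in, out) sizes per layer for a plain stack of `num_layers` layers."""
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    if num_layers == 1:
        return [(in_features, out_features)]
    dims = [(in_features, hidden_features)]
    dims += [(hidden_features, hidden_features)] * (num_layers - 2)
    dims.append((hidden_features, out_features))
    return dims


# ogbn-arxiv: 128 input features, 40 classes; 256 and 3 are example settings.
print(layer_dims(128, 256, 40, num_layers=3))
# [(128, 256), (256, 256), (256, 40)]
```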

In [None]:
# Pass the model class itself (not an instance) to the Trainer.
model = SAGE

## Training

You may specify these training parameters in the config:

* `lr` (float): The learning rate for the model's optimizer.

* `epochs` (int): The number of complete passes through the entire training dataset.

* `train_batch_size` (int): The number of training nodes in each mini-batch.

* `test_batch_size` (int): The number of nodes in each evaluation mini-batch.

* `test_epoch_frequency` (int): How often, in epochs, to run test evaluation.

* `test_max_num_batches` (int): The maximum number of batches to use during testing.
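
A quick back-of-the-envelope calculation shows how the batch-size and evaluation settings interact. The official ogbn-arxiv split has 90,941 training nodes and 48,603 test nodes. The sketch assumes that the last partial batch is kept, that epochs are numbered from 1, and that evaluation runs when the epoch number is a multiple of `test_epoch_frequency`; the Trainer's exact conventions may differ, and the batch sizes and epoch counts shown are hypothetical.

```python
import math


def num_batches(num_examples: int, batch_size: int, max_batches=None) -> int:
    """Mini-batches needed to cover `num_examples`, optionally capped."""
    n = math.ceil(num_examples / batch_size)  # the last partial batch counts
    return n if max_batches is None else min(n, max_batches)


def eval_epochs(epochs: int, test_epoch_frequency: int):
    """Epochs (1-based) on which evaluation would run under the assumed rule."""
    return [e for e in range(1, epochs + 1) if e % test_epoch_frequency == 0]


print(num_batches(90_941, 1024))                  # 89 training batches per epoch
print(num_batches(48_603, 4096, max_batches=10))  # 10 (12 needed, capped at 10)
print(eval_epochs(epochs=20, test_epoch_frequency=5))  # [5, 10, 15, 20]
```

Capping evaluation with `test_max_num_batches` trades accuracy of the reported metric for speed, since only part of the evaluation set is scored.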

In [None]:
trainer = Trainer(model, args)
trainer.train(dataset, use_wandb=True)

## Inference

Load the best checkpoint and evaluate it on the test set.

In [None]:
trainer.load_best_checkpoint()
trainer.test(dataset)
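
Conceptually, the "best" checkpoint is the one saved at the epoch with the highest validation score. The sketch below shows that selection rule on a made-up list of per-epoch validation accuracies. That the Trainer ranks checkpoints by validation accuracy, and how it breaks ties, are assumptions rather than documented behavior.

```python
from typing import List, Tuple


def best_epoch(history: List[Tuple[int, float]]) -> Tuple[int, float]:
    """Return the (epoch, score) pair with the highest score; earliest wins ties."""
    if not history:
        raise ValueError("history is empty")
    best = history[0]
    for epoch, score in history[1:]:
        if score > best[1]:  # strict '>' keeps the earlier epoch on ties
            best = (epoch, score)
    return best


# Made-up validation accuracies, for illustration only.
history = [(1, 0.61), (2, 0.68), (3, 0.71), (4, 0.70), (5, 0.71)]
print(best_epoch(history))  # (3, 0.71)
```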

In [None]:
wandb.finish()