In [None]:
import types
import yaml
import numpy as np
import torch 
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
import wandb

In [None]:
# toolkit
from gTDR.models import FastGCN
from gTDR.trainers.FastGCN_trainer import Trainer

## Arguments & Parameters

Specify the setup in config, including:
* `dataset_root`: (str) The directory in which to save the downloaded dataset.
* `dataset`: (str) The name of the `Planetoid` dataset to load (here, `Cora`).
* `fast`: (bool) Whether to use the FastGCN method.
* `save_results`: (bool) Whether to save the training and testing results.
* `save_path`: (str) The path at which to save the trained model and results.
* `use_cuda`: (bool) Whether to use CUDA for GPU acceleration.
* `seed`: (int) The random seed for reproducibility.

In [None]:
config_filename = "../configs/FastGCN_Cora_parameters.yaml"
with open(config_filename) as f:
    configs = yaml.safe_load(f)  # equivalent to yaml.load(f, Loader=yaml.SafeLoader)
args = types.SimpleNamespace(**configs)
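The loaded configuration is a flat mapping from the keys listed above to values, and `types.SimpleNamespace(**configs)` turns it into an object with attribute access (`args.seed`, `args.dataset`, ...). The sketch below uses a hypothetical dictionary in place of the YAML file; its values are placeholders, not the contents of `FastGCN_Cora_parameters.yaml`, and the `example_` prefix avoids overwriting the notebook's real `args`.

```python
import types

# Hypothetical stand-in for the dict returned by yaml.safe_load(...).
# The values are placeholders, not the contents of FastGCN_Cora_parameters.yaml.
example_configs = {
    "dataset_root": "../data",
    "dataset": "Cora",
    "fast": True,
    "save_results": False,
    "save_path": "../results",
    "use_cuda": True,
    "seed": 0,
}

# Every top-level key becomes an attribute.
example_args = types.SimpleNamespace(**example_configs)
print(example_args.dataset, example_args.seed)  # Cora 0

# Attributes can be added later (the GPU cell sets args.device this way),
# and vars() turns the namespace back into a plain dict, e.g. for logging.
example_args.device = "cpu"
print(sorted(vars(example_args)))
```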

Use the GPU if CUDA is available and `use_cuda` is enabled in the config.

In [None]:
# Fall back to the CPU when CUDA is unavailable, even if the config requests it.
args.use_cuda = torch.cuda.is_available() and args.use_cuda
args.device = 'cuda' if args.use_cuda else 'cpu'
print("use CUDA:", args.use_cuda, "- device:", args.device)

Set seed for reproducibility.

In [None]:
seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)
if args.use_cuda:
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
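The cell above seeds NumPy and PyTorch. If any code path also draws from Python's built-in `random` module (we have not checked whether gTDR does), that module has its own generator and would need `random.seed(seed)` as well; bit-for-bit reproducibility on GPU may additionally depend on cuDNN settings such as `torch.backends.cudnn.deterministic`. The stdlib sketch below shows the underlying idea: re-seeding a generator replays exactly the same sequence of draws.

```python
import random

def draws(seed, n=5):
    """Return n pseudo-random floats from a fresh generator seeded with `seed`."""
    rng = random.Random(seed)  # independent generator, like a per-library RNG
    return [rng.random() for _ in range(n)]

# Same seed -> identical sequence; different seed -> (almost surely) different sequence.
assert draws(42) == draws(42)
assert draws(42) != draws(43)

# Seeding the module-level generator behaves the same way.
random.seed(42)
first = [random.random() for _ in range(3)]
random.seed(42)
assert first == [random.random() for _ in range(3)]
```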

Start a `wandb` run to monitor the experiment (training loss, validation loss, and test accuracy).

In [None]:
run = wandb.init(project="FastGCN_ParameterTuning", name="Cora")

## Hyperparameter Search (see [wandb's documentation](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration))

To run a parameter search, specify the search method, the metric to optimize, and the parameters to search over together with their ranges. In this example, we use random search to maximize test accuracy over learning rates between 0.0001 and 0.1.

In [None]:
sweep_config = {
    'method': 'random',  # other options: 'grid', 'bayes'
    'metric': {
        'name': 'Test Accuracy',  # must match a metric name logged during training
        'goal': 'maximize',
    },
    'parameters': {
        'lr': {
            'min': 0.0001,
            'max': 0.1,
        },
    },
}

sweep_id = wandb.sweep(sweep_config, project="FastGCN_ParameterTuning")
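As we read the W&B sweep documentation, a parameter given only float `min`/`max` values is sampled uniformly, so with the range above about 99% of trials would use a learning rate above 0.001. Because learning rates are usually searched on a log scale, a variant worth considering is `'distribution': 'log_uniform_values'`. The stdlib sketch below (not W&B's sampler) compares the two schemes; `LR_MIN`/`LR_MAX` mirror the config above and everything else is illustrative.

```python
import math
import random

LR_MIN, LR_MAX = 0.0001, 0.1  # same range as the sweep above

def sample_uniform(rng):
    """Uniform on [LR_MIN, LR_MAX]."""
    return rng.uniform(LR_MIN, LR_MAX)

def sample_log_uniform(rng):
    """Log-uniform: uniform in log(lr), so each decade gets equal probability."""
    return math.exp(rng.uniform(math.log(LR_MIN), math.log(LR_MAX)))

def fraction_below(sampler, threshold, n=10_000, seed=0):
    """Fraction of n draws from `sampler` that fall below `threshold`."""
    rng = random.Random(seed)
    return sum(sampler(rng) < threshold for _ in range(n)) / n

# Roughly 0.9% of uniform draws fall below 1e-3, versus roughly 1/3 of
# log-uniform draws (the decade [1e-4, 1e-3) out of three decades).
print(fraction_below(sample_uniform, 1e-3))
print(fraction_below(sample_log_uniform, 1e-3))

# A log-scale variant of the sweep's parameter block (assumed W&B syntax):
log_scale_parameters = {
    "lr": {"distribution": "log_uniform_values", "min": LR_MIN, "max": LR_MAX},
}
```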

## Data

In this demo, we use `Cora` from the PyG `Planetoid` collection. For custom datasets, refer to the PyG [documentation](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_dataset.html) on creating your own datasets.

In [None]:
dataset = Planetoid(root=args.dataset_root, name=args.dataset, 
                    transform=T.ToSparseTensor(remove_edge_index=False), 
                    split='full')
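With `split='full'`, PyG's `Planetoid` keeps the standard 500 validation and 1,000 test nodes and marks every remaining node as a training node, so Cora (2,708 nodes) ends up with 1,208 training nodes. The stdlib sketch below reproduces that mask arithmetic; the contiguous index ranges are illustrative only and are not necessarily where PyG places the validation and test nodes.

```python
NUM_NODES = 2708                   # number of nodes in Cora
val_idx = set(range(140, 640))     # 500 validation nodes (illustrative positions)
test_idx = set(range(1708, 2708))  # 1000 test nodes (illustrative positions)

# 'full' split: every node that is neither a validation nor a test node trains.
train_mask = [i not in val_idx and i not in test_idx for i in range(NUM_NODES)]
val_mask = [i in val_idx for i in range(NUM_NODES)]
test_mask = [i in test_idx for i in range(NUM_NODES)]

print(sum(train_mask), sum(val_mask), sum(test_mask))  # 1208 500 1000
```

In the notebook itself, `dataset[0].train_mask.sum()` should report the same training count.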

## Model

You may specify these model parameters in config:

* `hidden_dims`: (int) The dimensionality of the hidden layers in the graph convolutional network (GCN).

* `num_layers`: (int) The number of layers in the GCN.

* `dropout`: (float) The dropout rate for regularization.

* `batch_norm`: (bool) Whether to use batch normalization in the GCN.

In [None]:
model = FastGCN(args, dataset)
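To make `hidden_dims` and `num_layers` concrete: in a standard GCN stack, the first layer maps the input features to `hidden_dims`, intermediate layers keep that width, and the last layer maps to the number of classes. We have not confirmed that gTDR's `FastGCN` counts layers exactly this way, so treat the hypothetical helper `layer_shapes` below as an illustration. Cora has 1,433 input features and 7 classes; the hidden width of 64 is a placeholder.

```python
def layer_shapes(in_dim, hidden_dims, out_dim, num_layers):
    """Weight-matrix shapes (in, out) for a GCN with `num_layers` layers,
    assuming every layer except the last has width `hidden_dims`."""
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1")
    dims = [in_dim] + [hidden_dims] * (num_layers - 1) + [out_dim]
    return list(zip(dims[:-1], dims[1:]))

print(layer_shapes(1433, 64, 7, num_layers=2))  # [(1433, 64), (64, 7)]
```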

## Training

You may specify these training parameters in config:

* `normalize_features`: (bool) If set to True, the input features to the model will be normalized (mean-centered and divided by their standard deviation). If set to False, the features will be used as is.

* `init_batch`: (int) The number of nodes in the initial batch of training data, i.e., the batch size at the input layer of the FastGCN model during training.

* `sample_size`: (int) The number of nodes sampled in each hidden layer of the FastGCN model during training, after the initial batch.

* `scale_factor`: (float) The factor by which the batch size is increased at each successive layer of the GCN, so that deeper layers can use progressively larger batches.

* `epochs`: (int) The total number of training epochs. An epoch is a complete pass through the entire training dataset.

* `lr`: (float) The learning rate for the optimizer, i.e., the step size that scales each update to the model parameters.

* `early_stop`: (int) The number of epochs with no improvement in validation loss after which training will be stopped. This is a form of early stopping, which can prevent overfitting.

* `weight_decay`: (float) The weight decay (L2 penalty) for the optimizer. This adds a regularization term to the loss function, which can help prevent overfitting.

* `samp_inference`: (bool) If set to `True`, importance sampling will be used during the inference phase (i.e., when making predictions on unseen data). If set to `False`, all instances will be used.

* `use_val`: (bool) If set to `True`, a validation set is used during training (for example, to track the validation loss that `early_stop` monitors).

* `num_samp_inference`: (int) The number of samples to use for each inference step. Relevant only if `samp_inference` is `True`.

* `inference_init_batch`: (int) The number of nodes in the initial batch of inference data, i.e., the batch size at the input layer of the FastGCN model during inference.

* `inference_sample_size`: (int) The number of nodes sampled in each hidden layer of the FastGCN model during inference, after the initial batch.

* `report`: (int) The interval, in epochs, at which a report (validation loss and test accuracy) is printed.
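The sampling parameters above (`init_batch`, `sample_size`, `scale_factor`) control FastGCN's layer-wise importance sampling. In the original FastGCN paper (Chen, Ma, and Xiao, ICLR 2018), the aggregation `sum_u A(v, u) * h(u)` for a node `v` is estimated from `t` nodes drawn i.i.d. from `q(u) ∝ ||A(:, u)||²`, with each sampled term reweighted by `1 / (t * q(u))` so the estimate stays unbiased; `t` plays roughly the role of `sample_size`. The stdlib sketch below implements that estimator on a made-up 4-node graph with scalar features. It illustrates the technique and is not gTDR's implementation; the graph, features, and function names are hypothetical.

```python
import random

# Toy row-normalized adjacency (with self-loops) as {row v: {column u: A[v][u]}}.
adj = {
    0: {0: 1 / 2, 1: 1 / 2},
    1: {0: 1 / 3, 1: 1 / 3, 2: 1 / 3},
    2: {1: 1 / 3, 2: 1 / 3, 3: 1 / 3},
    3: {2: 1 / 2, 3: 1 / 2},
}
features = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}  # one scalar feature per node

def importance_distribution(adj):
    """q(u) proportional to the squared L2 norm of column u of the adjacency."""
    sq_norms = {}
    for row in adj.values():
        for u, w in row.items():
            sq_norms[u] = sq_norms.get(u, 0.0) + w * w
    total = sum(sq_norms.values())
    return {u: s / total for u, s in sq_norms.items()}

def exact_aggregate(adj, features, v):
    """Full-neighborhood aggregation sum_u A[v][u] * h(u)."""
    return sum(w * features[u] for u, w in adj[v].items())

def sampled_aggregate(adj, features, v, q, t, rng):
    """Unbiased Monte Carlo estimate of exact_aggregate from t nodes drawn i.i.d. from q."""
    nodes = list(q)
    sample = rng.choices(nodes, weights=[q[u] for u in nodes], k=t)
    # Sampled nodes that are not neighbors of v contribute A[v][u] = 0.
    return sum(adj[v].get(u, 0.0) * features[u] / q[u] for u in sample) / t

q = importance_distribution(adj)
rng = random.Random(0)
estimates = [sampled_aggregate(adj, features, 1, q, t=5, rng=rng) for _ in range(2000)]
print(exact_aggregate(adj, features, 1))  # mathematically (1 + 2 + 3) / 3 = 2
print(sum(estimates) / len(estimates))    # close to the exact value
```

Each individual estimate is noisy, but their average converges to the exact aggregation; choosing `q` from the column norms is what keeps that noise (the variance) low in the paper's analysis.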

In [None]:
trainer = Trainer(model=model, args=args)
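The `early_stop` setting describes a patience rule on validation loss. The sketch below shows one common way to implement it; the helper name `epochs_until_stop` and the strict-improvement comparison are assumptions, and gTDR's `Trainer` may differ in details such as how ties are handled.

```python
def epochs_until_stop(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None if it runs to completion.

    Training stops once `patience` consecutive epochs pass without a strict
    improvement over the best validation loss seen so far.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None

losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69]
print(epochs_until_stop(losses, patience=3))  # 6
print(epochs_until_stop(losses, patience=5))  # None
```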

## Running the Hyperparameter Search

To tune hyperparameters, write your own `parameter_search(self)` function that specifies which hyperparameters to tune. The example below searches over `lr` and then rebuilds every trainer attribute that depends on `lr` (here, the optimizer) so that the new value takes effect.

In [None]:
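def parameter_search(self):
    run = wandb.init()  # required: start a run for this sweep trial
    self.run = run
    config = run.config  # required: hyperparameters sampled by the sweep

    # Customize here: copy each sampled value into args and rebuild
    # every object that captured the old value.
    self.args.lr = config.lr
    self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                      lr=self.args.lr,
                                      weight_decay=self.args.weight_decay)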
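Rebuilding the optimizer matters because an optimizer copies the learning rate when it is constructed, so changing `args.lr` afterwards does not reach it. The stdlib sketch below mimics that situation with hypothetical stand-ins (`ToyOptimizer`, `ToyTrainer`, and a plain dict in place of `run.config`); it illustrates the pattern only and uses neither wandb nor gTDR.

```python
import types

class ToyOptimizer:
    """Hypothetical stand-in for torch.optim.Adam: copies lr when constructed."""
    def __init__(self, lr):
        self.lr = lr

class ToyTrainer:
    """Hypothetical stand-in for the gTDR Trainer."""
    def __init__(self, args):
        self.args = args
        self.optimizer = ToyOptimizer(lr=args.lr)

    def train(self, parameter_search=None):
        if parameter_search is not None:
            parameter_search(self)  # the hook receives the trainer as `self`
        return self.optimizer.lr    # learning rate this run would actually use

def make_search_hook(sampled, rebuild_optimizer=True):
    """Build a parameter_search-style hook; `sampled` plays the role of run.config."""
    def hook(self):
        self.args.lr = sampled["lr"]
        if rebuild_optimizer:
            self.optimizer = ToyOptimizer(lr=self.args.lr)
    return hook

good = ToyTrainer(types.SimpleNamespace(lr=0.01))
print(good.train(parameter_search=make_search_hook({"lr": 0.003})))  # 0.003

stale = ToyTrainer(types.SimpleNamespace(lr=0.01))
print(stale.train(parameter_search=make_search_hook({"lr": 0.003}, rebuild_optimizer=False)))  # 0.01
```

For a real PyTorch optimizer, setting `group['lr']` for each entry in `optimizer.param_groups` is another way to change the learning rate in place; rebuilding the optimizer, as the notebook does, also discards Adam's accumulated moment estimates.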

Finally, start a wandb agent to run the sweep; `count=10` limits it to 10 trials, each of which calls `trainer.train` with the `parameter_search` hook.

In [None]:
wandb.agent(sweep_id, lambda: trainer.train(parameter_search=parameter_search), count=10)
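Conceptually, `wandb.agent(..., count=10)` asks the sweep controller for a new set of hyperparameters, runs the function once with it, and repeats ten times, while W&B records which run reached the best `Test Accuracy`. The local sketch below imitates that loop for random search; `local_random_search` and the toy objective are hypothetical stand-ins for W&B's controller and `trainer.train`, not their actual implementations.

```python
import math
import random

def local_random_search(objective, sample, count, seed=0):
    """Run `objective` on `count` randomly sampled configs and return the best (score, config)."""
    rng = random.Random(seed)
    best = None
    for _ in range(count):
        config = sample(rng)       # analogous to run.config for one trial
        score = objective(config)  # analogous to one trainer.train(...) call
        if best is None or score > best[0]:  # goal: maximize
            best = (score, config)
    return best

def sample_lr(rng):
    # Same range and (uniform) distribution as the sweep config above.
    return {"lr": rng.uniform(0.0001, 0.1)}

def toy_accuracy(config):
    # Made-up objective that peaks at lr = 0.01 on a log scale.
    return 1.0 - abs(math.log10(config["lr"]) - math.log10(0.01))

best_score, best_config = local_random_search(toy_accuracy, sample_lr, count=10)
print(best_config, best_score)
```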