In [None]:
import yaml
import wandb

In [None]:
# toolkit
import gTDR.utils.GTS as utils
from gTDR.trainers.GTS_trainer import Trainer
from gTDR.models import GTS

## Arguments & Parameters

Specify the setup in the config file, including:
* `base_dir`: (str) The directory in which the trained model and results are saved.
* `dataset_dir`: (str) The path to the dataset.

In [None]:
config_filename = "../configs/GTS_METR_LA_parameters.yaml"
# Use a context manager so the config file is closed after reading.
with open(config_filename) as f:
    args = yaml.safe_load(f)

Start `wandb` to monitor the experiment (training loss, validation loss).

In [None]:
run = wandb.init(project="GTS", name="METR-LA")

## Data (Part 1)

In this demo, we use the `METR-LA` dataset.

**First, unzip the dataset located at [../data/METR-LA/](../data/METR-LA/):**

`gunzip -c ../data/METR-LA/metr-la.h5.zip > ../data/METR-LA/metr-la.h5`

**Then, run the script [./data_preparation/prepare_METR_LA.py](./data_preparation/prepare_METR_LA.py) to create data files `train.npz`, `val.npz`, and `test.npz` under [../data/METR-LA/](../data/METR-LA/).**

`python ./data_preparation/prepare_METR_LA.py --traffic_df_filename ../data/METR-LA/metr-la.h5 --output_dir ../data/METR-LA/`
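
The preparation script itself is not shown in this notebook. As a rough illustration of how a long time series is typically cut into supervised samples for forecasting, the sketch below pairs each input window with the steps that follow it. The function name `make_windows` and the toy readings are hypothetical, and the assumption that `prepare_METR_LA.py` works this way is ours; `seq_len` and `horizon` are the parameters described in the Model section.

```python
def make_windows(series, seq_len, horizon):
    """Pair each seq_len-long input window with the horizon steps that follow it."""
    xs, ys = [], []
    # Last start index that still leaves room for a full input and target.
    last_start = len(series) - seq_len - horizon
    for start in range(last_start + 1):
        xs.append(series[start:start + seq_len])
        ys.append(series[start + seq_len:start + seq_len + horizon])
    return xs, ys


# Toy readings for illustration only (not METR-LA data).
readings = [10, 11, 12, 13, 14, 15, 16]
xs, ys = make_windows(readings, seq_len=3, horizon=2)
```

Each entry of `xs` holds the past observations and the matching entry of `ys` holds the future values to predict.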

## Data (Part 2)

Next, load the data files and build the dataloaders.

The GTS codebase provides a dataloader class, `DataLoader`, with a `get_iterator()` method.

It also provides a `load_dataset()` function (see [../gTDR/utils/GTS/utils.py](../gTDR/utils/GTS/utils.py)), which returns a dictionary `data` with the following entries:
* `x_train`, `x_val`, `x_test`: time series covariates
* `y_train`, `y_val`, `y_test`: time series targets
* `train_loader`, `val_loader`, `test_loader`: dataloaders
* `scaler`: the scaler used for feature normalization

Because the implementation of `load_dataset()` is lengthy, we call it directly rather than show it here. To use a new dataset, modify [../gTDR/utils/GTS/utils.py](../gTDR/utils/GTS/utils.py).
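
To make the roles of `get_iterator()` and `scaler` more concrete, here is a minimal sketch in plain Python. The class names `SketchScaler` and `SketchLoader` are hypothetical, and their behavior is an assumption about how such components commonly work; the actual implementations in `gTDR/utils/GTS/utils.py` may differ.

```python
class SketchScaler:
    """Hypothetical standard scaler: z = (x - mean) / std."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, values):
        return [(v - self.mean) / self.std for v in values]

    def inverse_transform(self, values):
        # Map normalized values back to the original units.
        return [v * self.std + self.mean for v in values]


class SketchLoader:
    """Hypothetical loader that yields (x, y) mini-batches in order."""

    def __init__(self, xs, ys, batch_size):
        assert len(xs) == len(ys), "inputs and targets must align"
        self.xs = xs
        self.ys = ys
        self.batch_size = batch_size

    def get_iterator(self):
        # Walk over the samples in consecutive slices of batch_size.
        for start in range(0, len(self.xs), self.batch_size):
            end = start + self.batch_size
            yield self.xs[start:end], self.ys[start:end]


# Toy values for illustration only.
scaler = SketchScaler(mean=50.0, std=10.0)
xs = scaler.transform([40.0, 50.0, 60.0, 70.0, 80.0])
ys = [1, 2, 3, 4, 5]
loader = SketchLoader(xs, ys, batch_size=2)
batches = list(loader.get_iterator())
```

The last batch is smaller than `batch_size` when the number of samples is not a multiple of it; a real loader may pad or drop that remainder instead.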

In [None]:
data_kwargs = args.get('data_para')
data = utils.load_dataset(**data_kwargs)

## Model

You may specify these model parameters in config:

* `use_curriculum_learning`: (bool) Whether to use a curriculum learning strategy. If `True`, the model gradually increases the difficulty of training samples.

* `cl_decay_steps`: (int) The number of steps for the curriculum learning decay. This determines how quickly the "lessons" become more difficult.

* `filter_type`: (str) The type of graph convolutional filter to use. This parameter affects how the graph convolution operation is computed. The provided options are:
    * `dual_random_walk`
    * `random_walk`
    * `laplacian`

* `horizon`: (int) The prediction horizon of the model. This is the number of future time steps the model is trained to predict.

* `input_dim`: (int) The dimensionality of the input data, that is, the number of input features for each node at each time step.

* `output_dim`: (int) The dimensionality of the output data, that is, the number of features the model is trained to predict for each node at each time step.

* `l1_decay`: (float) The strength of L1 regularization applied during training. Regularization helps prevent overfitting by adding a penalty to the loss for large weights.

* `max_diffusion_step`: (int) The maximum number of diffusion steps in the graph convolution operation. This parameter affects how far information travels along the graph in each layer of the GCN.

* `num_nodes`: (int) The number of nodes in the graph. This is simply the number of different locations or sensors in your dataset.

* `num_rnn_layers`: (int) The number of recurrent layers in the model. More layers can capture more complex patterns but also increase the risk of overfitting and the computational cost.

* `rnn_units`: (int) The number of units in each recurrent layer. More units can capture more complex patterns but also increase the risk of overfitting and the computational cost.

* `seq_len`: (int) The length of the input sequences. This is the number of past time steps the model uses to make its predictions.

* `dim_fc`: (int) The size of the fully connected layer. This parameter affects the complexity and capacity of the model.

* `temperature`: (float) The temperature of the Gumbel reparameterization trick, which allows the model to handle discrete graph structures. It controls the "sharpness" of the distribution from which the adjacency matrix of the graph is sampled.

In the code below, the `temperature` is set to `0.5`. As the temperature approaches 0, the model's sampling of the adjacency matrix becomes more deterministic, tending to choose either 0 or 1 with higher probability. Conversely, a higher temperature value leads to a more uniform sampling, making the model's choices more exploratory.
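
The effect of the temperature can be illustrated with a small Gumbel-softmax sketch in plain Python. This is an assumption-laden illustration, not the sampling code used inside `GTS`: the helper names `gumbel_softmax` and `mean_max_prob` are hypothetical, and the two equal logits stand in for the "edge" and "no edge" choices of a single adjacency entry.

```python
import math
import random


def gumbel_softmax(logits, temperature, rng):
    """Draw one relaxed sample from a categorical distribution over logits."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # avoid log(0)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / temperature)
    # Numerically stable softmax over the perturbed, scaled logits.
    peak = max(noisy)
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def mean_max_prob(temperature, trials=2000, seed=0):
    """Average of the largest component; values near 1 mean near one-hot samples."""
    rng = random.Random(seed)
    logits = [0.0, 0.0]  # equal preference for "edge" and "no edge"
    total = sum(max(gumbel_softmax(logits, temperature, rng)) for _ in range(trials))
    return total / trials


sharp = mean_max_prob(0.1)  # low temperature: samples close to 0 or 1
soft = mean_max_prob(5.0)   # high temperature: samples close to uniform
```

Comparing `sharp` and `soft` shows the trend described above: the lower the temperature, the closer each sample is to a hard 0/1 decision.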

In [None]:
model = GTS(temperature=0.5, **args)
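
As an illustration of `cl_decay_steps`: recurrent forecasters in the DCRNN family commonly implement curriculum learning as scheduled sampling, where the decoder is fed the ground truth with a probability that decays as training progresses. The sketch below assumes an inverse-sigmoid decay; whether the `GTS` model here uses exactly this formula is an assumption, and `sampling_threshold` is a hypothetical name.

```python
import math


def sampling_threshold(batches_seen, cl_decay_steps):
    """Probability of feeding the ground truth to the decoder (inverse-sigmoid decay)."""
    k = float(cl_decay_steps)
    return k / (k + math.exp(batches_seen / k))


# The value 2000 is illustrative only, not read from the config.
for step in (0, 2000, 10000, 20000, 40000):
    print(step, round(sampling_threshold(step, cl_decay_steps=2000), 4))
```

A larger `cl_decay_steps` keeps the threshold near 1 for longer, so the model relies on ground-truth inputs for more training steps before it has to consume its own predictions.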

## Training

You may specify these training parameters in config:

* `base_lr`: (float) The base learning rate for the optimizer.

* `lr_decay_ratio`: (float) The ratio for learning rate decay.

* `dropout`: (float) The dropout rate for regularization.

* `epoch`: (int) The starting epoch number for training. If you are continuing training from a saved model, this would be the epoch at which training stopped. If you are training from scratch, it should be `0`.

* `epochs`: (int) The total number of epochs (complete passes through the training dataset) to train the model.

* `epsilon`: (float) A small constant for numerical stability in the optimizer.

* `global_step`: (int) A counter for the total number of steps taken so far in training. This could be used for logging or for learning rate scheduling. If you are training from scratch, it should be `0`.

* `max_grad_norm`: (float) The maximum allowed norm for the gradient clipping. This is used to prevent the problem of exploding gradients in deep neural networks.

* `min_learning_rate`: (float) The minimum learning rate. This sets a lower bound on the learning rate, preventing it from going too low.

* `optimizer`: (str) The optimizer to use for training. Common options are `adam` and `sgd`. If no optimizer is specified, the default is `adam`.

* `patience`: (int) The number of epochs to wait for improvement before stopping training. This is used in early stopping.

* `steps`: (list of int) The epochs at which to decrease the learning rate. For example, if set to `[20, 30, 40]`, the learning rate is decreased after 20, 30, and 40 epochs.

* `test_every_n_epochs`: (int) The number of epochs after which to test the model.

* `knn_k`: (int) The number of nearest neighbors to consider in KNN graph construction. This is used when constructing the adjacency matrix for the graph-based model.

* `epoch_use_regularization`: (int) The epoch at which to start using regularization. This could be used to delay the use of regularization to allow the model to learn more freely in early epochs.

* `num_sample`: (int) The number of samples to use in each training step. This could be used in methods that involve sampling from the model or the data.
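
The interaction of `base_lr`, `lr_decay_ratio`, `steps`, and `min_learning_rate` can be sketched as a step-decay schedule. This is a minimal illustration under the assumption that the rate is multiplied by `lr_decay_ratio` at each milestone and clamped from below; the function name `learning_rate_at` and the numeric values are hypothetical, and the `Trainer` may implement its schedule differently.

```python
def learning_rate_at(epoch, base_lr, lr_decay_ratio, steps, min_learning_rate):
    """Learning rate in effect during a given epoch under a step-decay schedule."""
    # Count the milestones that have already been reached.
    num_decays = sum(1 for milestone in steps if epoch >= milestone)
    lr = base_lr * (lr_decay_ratio ** num_decays)
    # Never let the rate fall below the configured floor.
    return max(lr, min_learning_rate)


# Illustrative values only, not taken from the config file.
schedule = {
    epoch: learning_rate_at(epoch, base_lr=0.01, lr_decay_ratio=0.1,
                            steps=[20, 30, 40], min_learning_rate=2e-6)
    for epoch in (0, 19, 20, 30, 40)
}
```

With these values the rate stays at `base_lr` through epoch 19 and drops by a factor of ten at each of the three milestones.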

In [None]:
trainer = Trainer(model, data, **args)
trainer.train(use_wandb=True)
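
The `patience` parameter controls early stopping. A minimal sketch of the idea follows, assuming training stops once the validation loss has failed to improve for `patience` consecutive epochs; the function name `run_with_patience` and the loss values are hypothetical, and the exact rule inside `Trainer` may differ.

```python
def run_with_patience(val_losses, patience):
    """Return (best_epoch, stop_epoch) for a sequence of validation losses."""
    best_loss = float("inf")
    best_epoch = -1
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch  # stop early
    return best_epoch, len(val_losses) - 1  # ran to completion


# Made-up validation losses for illustration.
losses = [3.2, 2.9, 2.7, 2.8, 2.75, 2.71, 2.6]
best_epoch, stop_epoch = run_with_patience(losses, patience=3)
```

Note the trade-off: with a small `patience` the run above stops before reaching the lower loss at the final epoch, whereas a larger value would have continued and found it.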

## Inference

Load the best checkpoint and perform testing.

In [None]:
trainer.load_best_checkpoint()
trainer.test()

In [None]:
wandb.finish()