In [None]:
import os
import argparse
import torch
import numpy as np
import pandas as pd
import random
import yaml
import types
import wandb
from torch.utils.data import DataLoader

In [None]:
# toolkit
from gTDR.datasets import GANF_Dataset
from gTDR.models import GANF
from gTDR.trainers.GANF_trainer import Trainer

## Arguments & Parameters

Specify the setup in config, including:
* `data_dir`: (str) Path to the dataset to be used for training the model.
* `name`: (str) Name of the model run. This can be used to identify different runs or configurations.
* `seed`: (int) Random seed for reproducibility.
* `use_cuda`: (bool) Whether to use CUDA for training. If True and a GPU is available, the model will be trained on the GPU.
* `save_results`: (bool) Whether to save the training checkpoints.
* `output_dir`: (str) Directory to save the trainer's checkpoints.

In [None]:
# Read the experiment settings from YAML and expose them as attributes on `args`.
config_filename = "../configs/GANF_METR_LA_parameters.yaml"
with open(config_filename) as config_file:
    # safe_load only constructs plain Python objects (same loader as SafeLoader).
    configs = yaml.safe_load(config_file)
args = types.SimpleNamespace(**configs)

Use the GPU when one is available and `use_cuda` is enabled in the config; otherwise fall back to the CPU.

In [None]:
# Keep CUDA enabled only when the machine has it AND the config asked for it.
args.use_cuda = (torch.cuda.is_available() and args.use_cuda)
args.device = 'cuda' if args.use_cuda else 'cpu'
print ("use CUDA:", args.use_cuda, "- device:", args.device)

Set seed for reproducibility.

In [None]:
seed = args.seed
# Seed every RNG source the notebook touches: Python, NumPy, then PyTorch (CPU).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if args.use_cuda:
    # Seed the current CUDA device first, then all visible CUDA devices.
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(seed)

Start a `wandb` run to monitor the experiment (train loss, validation loss, test loss).

In [None]:
# Open a wandb run in project "GANF" named "METR-LA"; the returned handle is kept in `run`.
run = wandb.init(project="GANF", name="METR-LA")

## Data (Part 1)

In this demo, we use the `METR-LA` dataset.

**First, unzip the dataset located at [../data/METR-LA/](../data/METR-LA/):**

`gunzip -c ../data/METR-LA/metr-la.h5.zip > ../data/METR-LA/metr-la.h5`

## Data (Part 2)

Next, create dataloaders.

GANF defines a dataset class `GANF_Dataset` inherited from `torch.utils.data.Dataset`. This class takes a DataFrame of sensor readings (`df`) and optional labels (`label`), as well as a window size and stride size for the sliding windows. It preprocesses the data to create sliding windows of sensor readings and keeps track of the labels associated with each window.

Here, we define a function `load_traffic()` that loads the dataset file `metr-la.h5`, normalizes the features, splits the data, and creates dataloaders from `GANF_Dataset` objects for training, validation, and testing.

In [None]:
def load_traffic(root, batch_size, train_frac=0.75, val_frac=0.125):
    """
    Load the traffic dataset and build train/validation/test dataloaders.

    The rows are sorted by time and split chronologically into three
    non-overlapping, contiguous segments: train, then validation, then test.

    Parameters
    ----------
    root : str
        Path to the HDF5 file, read with ``pd.read_hdf``.
    batch_size : int
        Batch size used by all three dataloaders.
    train_frac : float, optional
        Fraction of rows (the earliest ones) assigned to training.
    val_frac : float, optional
        Fraction of rows, immediately after the training rows, assigned to
        validation. All remaining rows form the test split.

    Returns
    -------
    tuple
        ``(train_loader, val_loader, test_loader, n_sensor)`` where
        ``n_sensor`` is the number of columns in the loaded frame.

    Raises
    ------
    ValueError
        If the fractions do not leave room for a train and a test segment.
    """
    if not (0 < train_frac < 1) or val_frac < 0 or train_frac + val_frac >= 1:
        raise ValueError(
            "train_frac and val_frac must satisfy 0 < train_frac, 0 <= val_frac "
            "and train_frac + val_frac < 1; got {} and {}".format(train_frac, val_frac)
        )

    df = pd.read_hdf(root)
    # Move the index into a column (named "index" when the index is unnamed),
    # rename it to "utc", parse it as datetimes and restore it as the index.
    df = df.reset_index()
    df = df.rename(columns={"index": "utc"})
    df["utc"] = pd.to_datetime(df["utc"], unit="s")
    df = df.set_index("utc")
    n_sensor = len(df.columns)

    # One global mean/std over every column and every row.
    # NOTE(review): these statistics are computed before the split, so they
    # include the validation and test rows as well.
    mean = df.values.flatten().mean()
    std = df.values.flatten().std()

    df = (df - mean)/std
    df = df.sort_index()

    # Chronological split boundaries.
    n_rows = len(df)
    train_end = int(train_frac*n_rows)
    val_end = int((train_frac + val_frac)*n_rows)

    train_df = df.iloc[:train_end]
    val_df = df.iloc[train_end:val_end]
    # Start the test split where validation ends; starting it at `train_end`
    # would make every validation row part of the test set too.
    test_df = df.iloc[val_end:]

    # Only the training loader shuffles; evaluation loaders keep time order.
    train_loader = DataLoader(GANF_Dataset(train_df), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(GANF_Dataset(val_df), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(GANF_Dataset(test_df), batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, n_sensor

Create dataloaders.

In [None]:
# Build the three loaders from the HDF5 file under `args.data_dir`;
# `n_sensor` (number of columns in the file) is needed by the trainer later.
train_loader, val_loader, test_loader, n_sensor = load_traffic(
    "{}/metr-la.h5".format(args.data_dir),
    args.batch_size,
)

## Model

You may specify these model parameters in config:

* `input_size`: (int) Size of the input data for the GNN and NF models.

* `hidden_size`: (int) Size of the hidden layers in the GNN.

* `n_hidden`: (int) Number of hidden layers in the NF model.

* `dropout`: (float) Dropout rate for the LSTM layer in the GANF model.

* `n_blocks`: (int) Number of flow blocks in the NF model.

* `batch_norm`: (bool) Whether to use batch normalization in the NF model.

In [None]:
# Instantiate GANF from the settings loaded into `args`.
model = GANF(args)

## Training

You may specify these training parameters in config:

* `batch_size`: (int) Number of samples per batch during training. This is the number of samples that the model will process at one time.

* `weight_decay`: (float) Weight decay (L2 penalty) for the optimizer. This can help to prevent overfitting.

* `n_epochs`: (int) Number of training epochs. An epoch is one complete pass through the entire training dataset.

* `additional_iter`: (int) Number of additional iterations for the second stage training.

* `lr`: (float) Learning rate for the optimizer. This controls the size of the updates to the model's parameters during training.

* `max_iter`: (int) Maximum number of iterations for the training.

* `h_tol`: (float) Tolerance for the stopping criterion.

* `rho`: (float) This is a regularization parameter in the Alternating Direction Method of Multipliers (ADMM) algorithm. It balances the trade-off between the loss function and the constraints in the optimization problem. A larger value of `rho` can make the constraints more strictly enforced, while a smaller value can make the solution focus more on minimizing the loss function.

* `rho_max`: (float) This is the maximum value that `rho` can take. In ADMM, `rho` is allowed to increase adaptively during the training process to enforce the constraints more strictly. `rho_max` sets an upper limit to prevent `rho` from becoming too large.

* `alpha`: (float) This is a parameter that controls the over-relaxation in the ADMM algorithm. It is used to accelerate the convergence of the algorithm. The value of `alpha` is typically between 1 and 2. A value of 1 corresponds to no over-relaxation, while values greater than 1 correspond to increasing degrees of over-relaxation.

* `graph_dir`: (str) Directory containing the graph structure for the model. If provided, the model will initialize its graph structure from this file. If not provided, the model will initialize its graph structure randomly.

In [None]:
# Hand the model, the sensor count and the three dataloaders to the trainer.
# has_label=False: load_traffic() builds its GANF_Dataset objects without labels.
trainer = Trainer(args, model, n_sensor, train_loader, val_loader, test_loader, has_label=False) 
# Run training with wandb logging switched on (requires the wandb run started earlier).
trainer.train(use_wandb=True)

## Inference

Load the best checkpoint and perform testing.

In [None]:
# Load the trainer's best checkpoint, then run its test routine.
trainer.load_best_checkpoint()
trainer.test()

In [None]:
# Close the wandb run opened earlier in the notebook.
wandb.finish()