In [1]:
import argparse
import traceback
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
import models
import tasks
import utils.callbacks
import utils.data
import utils.logging
import pandas as pd
from torch.utils.data import DataLoader

In [2]:
# Per-dataset input files, keyed by dataset name. For each dataset, "feat" and
# "adj" are the paths handed to the data module as `feat_path` and `adj_path`.
DATA_PATHS = dict(
    tailing=dict(
        feat="data/3611817550_feature_matrix_X.csv",
        adj="data/3611817550_adj.csv",
    ),
)


def get_model(args, dm):
    """Instantiate the forecasting model selected by ``args.model_name``.

    Args:
        args: Parsed arguments. ``model_name`` and ``hidden_dim`` are read for
            every supported model; ``seq_len`` is additionally read for "GCN".
        dm: Data module whose ``adj`` attribute is passed to (or inspected for)
            the model being built.

    Returns:
        The ``models.GCN``, ``models.GRU`` or ``models.TGCN`` instance that
        matches ``args.model_name``.

    Raises:
        ValueError: If ``args.model_name`` is not "GCN", "GRU" or "TGCN".
            Previously this case silently returned ``None``, which only failed
            later, far from the actual cause.
    """
    if args.model_name == "GCN":
        return models.GCN(adj=dm.adj, input_dim=args.seq_len, output_dim=args.hidden_dim)
    if args.model_name == "GRU":
        # The GRU input size is taken from the first dimension of the adjacency matrix.
        return models.GRU(input_dim=dm.adj.shape[0], hidden_dim=args.hidden_dim)
    if args.model_name == "TGCN":
        return models.TGCN(adj=dm.adj, hidden_dim=args.hidden_dim)
    raise ValueError(
        "Unknown model_name {!r}; expected one of 'GCN', 'GRU', 'TGCN'".format(args.model_name)
    )

def get_task(args, model, dm):
    """Build the forecast task that corresponds to ``args.settings``.

    The task class is looked up on the ``tasks`` module under the name
    ``<Settings>ForecastTask``, where ``<Settings>`` is ``args.settings``
    capitalized.

    Args:
        args: Parsed arguments; every attribute is forwarded to the task
            constructor as a keyword argument.
        model: Model object handed to the task as ``model``.
        dm: Data module; its ``feat_max_val`` is handed to the task.

    Returns:
        The constructed task object.
    """
    task_class_name = args.settings.capitalize() + "ForecastTask"
    task_factory = getattr(tasks, task_class_name)
    return task_factory(model=model, feat_max_val=dm.feat_max_val, **vars(args))


def get_callbacks(args):
    """Return the trainer callbacks: a model checkpoint and a prediction plotter.

    Both callbacks are configured to monitor "train_loss".

    Args:
        args: Accepted for a uniform call signature; not read here.

    Returns:
        A two-element list: the ``ModelCheckpoint`` callback followed by the
        ``PlotValidationPredictionsCallback``.
    """
    monitored_metric = "train_loss"
    return [
        pl.callbacks.ModelCheckpoint(monitor=monitored_metric),
        utils.callbacks.PlotValidationPredictionsCallback(monitor=monitored_metric),
    ]


def main_supervised(args, default_gpus=1, default_max_epochs=10):
    """Train the selected model in the supervised setting, then validate it.

    The data module is built from the CSV paths registered for ``args.data``
    in ``DATA_PATHS``; the model, task and callbacks come from ``get_model``,
    ``get_task`` and ``get_callbacks``.

    Args:
        args: Parsed arguments. ``args.data`` selects the dataset; all
            attributes are forwarded to the data module and used to configure
            the trainer.
        default_gpus: Value used for the trainer's ``gpus`` option when
            ``args.gpus`` is unset (``None``).
        default_max_epochs: Value used for the trainer's ``max_epochs`` option
            when ``args.max_epochs`` is unset (``None``).

    Returns:
        Whatever ``trainer.validate`` returns for the data module.
    """
    dataset_paths = DATA_PATHS[args.data]
    dm = utils.data.SpatioTemporalCSVDataModule(
        feat_path=dataset_paths["feat"], adj_path=dataset_paths["adj"], **vars(args)
    )
    model = get_model(args, dm)
    task = get_task(args, model, dm)
    callbacks = get_callbacks(args)

    # Explicit keyword arguments take precedence over the values carried by
    # `args`, so the defaults are only supplied when the user left the option
    # unset. Unconditionally passing gpus/max_epochs would discard
    # user-provided --gpus / --max_epochs values.
    trainer_overrides = {}
    if getattr(args, "gpus", None) is None:
        trainer_overrides["gpus"] = default_gpus
    if getattr(args, "max_epochs", None) is None:
        trainer_overrides["max_epochs"] = default_max_epochs

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, **trainer_overrides)
    trainer.fit(task, dm)
    results = trainer.validate(datamodule=dm)
    return results


def main(args):
    """Log the run configuration and dispatch to the matching ``main_<settings>``.

    Args:
        args: Parsed arguments; ``args.settings`` names the module-level
            ``main_<settings>`` function to run.

    Returns:
        The value returned by the selected ``main_<settings>`` function.
    """
    rank_zero_info(vars(args))
    entry_point = globals()["main_" + args.settings]
    return entry_point(args)
if __name__ == "__main__":
    # Build the argument parser in two passes: the generic options are parsed
    # first so that the data-, model- and task-specific options can be added
    # based on the selected settings and model.
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)

    parser.add_argument(
        "--data",
        type=str,
        help="The name of the dataset",
        # Must be a real sequence of names: a bare string such as ("tailing")
        # makes argparse do substring membership, accepting e.g. "tail".
        choices=tuple(DATA_PATHS),
        default="tailing",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="The name of the model for spatiotemporal prediction",
        choices=("GCN", "GRU", "TGCN"),
        default="TGCN",
    )
    parser.add_argument(
        "--settings",
        type=str,
        help="The type of settings, e.g. supervised learning",
        choices=("supervised",),
        default="supervised",
    )
    parser.add_argument("--log_path", type=str, default=None, help="Path to the output console log file")
    parser.add_argument("--send_email", "--email", action="store_true", help="Send email when finished")
    # First pass: only the options registered so far are interpreted; anything
    # else on the command line is ignored.
    temp_args, _ = parser.parse_known_args()

    parser = getattr(utils.data, temp_args.settings.capitalize() + "DataModule").add_data_specific_arguments(parser)
    parser = getattr(models, temp_args.model_name).add_model_specific_arguments(parser)
    parser = getattr(tasks, temp_args.settings.capitalize() + "ForecastTask").add_task_specific_arguments(parser)

    # NOTE: the final parse is given an empty argument list, so every option
    # takes its default value regardless of the process command line.
    args = parser.parse_args(args=[])
    # utils.logging.format_logger(pl._logger)
    if args.log_path is not None:
        utils.logging.output_logger_to_file(pl._logger, args.log_path)

    # Pre-bind `results` so the name exists even when the run fails.
    results = None
    try:
        results = main(args)
    except Exception:
        # Report the failure without aborting; unlike a bare `except`, this
        # lets KeyboardInterrupt and SystemExit propagate normally.
        traceback.print_exc()

{'logger': True, 'enable_checkpointing': True, 'default_root_dir': None, 'gradient_clip_val': None, 'gradient_clip_algorithm': None, 'num_nodes': 1, 'num_processes': None, 'devices': None, 'gpus': None, 'auto_select_gpus': None, 'tpu_cores': None, 'ipus': None, 'enable_progress_bar': True, 'overfit_batches': 0.0, 'track_grad_norm': -1, 'check_val_every_n_epoch': 1, 'fast_dev_run': False, 'accumulate_grad_batches': None, 'max_epochs': None, 'min_epochs': None, 'max_steps': -1, 'min_steps': None, 'max_time': None, 'limit_train_batches': None, 'limit_val_batches': None, 'limit_test_batches': None, 'limit_predict_batches': None, 'val_check_interval': None, 'log_every_n_steps': 50, 'accelerator': None, 'strategy': None, 'sync_batchnorm': False, 'precision': 32, 'enable_model_summary': True, 'num_sanity_val_steps': 2, 'resume_from_checkpoint': None, 'profiler': None, 'benchmark': None, 'reload_dataloaders_every_n_epochs': 0, 'auto_lr_find': False, 'replace_sampler_ddp': True, 'detect_anomaly

(16812, 38, 30)
(16812, 3, 30)
(4198, 38, 30)
(4198, 3, 30)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


{'val_loss': tensor(80490.1875, device='cuda:0'), 'RMSE': tensor(0.6527, device='cuda:0'), 'MAE': tensor(0.6272, device='cuda:0'), 'accuracy': tensor(-0.1786, device='cuda:0'), 'R2': tensor(-0.0064, device='cuda:0'), 'ExplainedVar': tensor(-0.0908, device='cuda:0')}


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

{'val_loss': tensor(5318.9902, device='cuda:0'), 'RMSE': tensor(0.1678, device='cuda:0'), 'MAE': tensor(0.1220, device='cuda:0'), 'accuracy': tensor(0.6970, device='cuda:0'), 'R2': tensor(0.0600, device='cuda:0'), 'ExplainedVar': tensor(0.0600, device='cuda:0')}


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
  rank_zero_warn(
You are using a CUDA device ('NVIDIA GeForce RTX 4060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Restoring states from the checkpoint path at D:\T-GCN\T-GCN-PyTorch\lightning_logs\version_9\checkpoints\epoch=0-step=263.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


(16812, 38, 30)
(16812, 3, 30)
(4198, 38, 30)
(4198, 3, 30)


Loaded model weights from checkpoint at D:\T-GCN\T-GCN-PyTorch\lightning_logs\version_9\checkpoints\epoch=0-step=263.ckpt
  rank_zero_warn(


Validation: 0it [00:00, ?it/s]

{'val_loss': tensor(5318.9902, device='cuda:0'), 'RMSE': tensor(0.1678, device='cuda:0'), 'MAE': tensor(0.1220, device='cuda:0'), 'accuracy': tensor(0.6970, device='cuda:0'), 'R2': tensor(0.0600, device='cuda:0'), 'ExplainedVar': tensor(0.0600, device='cuda:0')}
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      ExplainedVar          0.06001746654510498
           MAE              0.12197772413492203
           R2              0.059998393058776855
          RMSE              0.16779471933841705
        accuracy            0.6970181465148926
        val_loss              5318.990234375
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
