In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib as plt

In [2]:
from NRI.experiments.config import ExperimentConfig


experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=18,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=64,
    checkpoint_prefix='dnri_18_baseline',
    checkpoint_interval=10,
    use_cuda=True,
    mask_traffic_signal=False
)

In [3]:
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset

dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)
train_files, valid_files, test_files = split_dataset(dataset_files, 0.7, 0.2)

load datasets

In [4]:
from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig

dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0
)

train_set = SignalizedIntersectionDatasetForNRI(dataset_config)
train_set.load_records(dataset_path, train_files, verbose=True)

valid_set = SignalizedIntersectionDatasetForNRI(dataset_config)
valid_set.load_records(dataset_path, valid_files, verbose=True)


load_records: 100%|██████████| 16/16 [02:19<00:00,  8.71s/it]
load_records: 100%|██████████| 4/4 [00:41<00:00, 10.29s/it]


In [5]:
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.main import train

model = DynamicNeuralRelationalInference(
    hid_dim=64,
    n_edges=4,
    dgvae=False
)

if experiment_config.use_cuda:
    model.cuda()

optimizer = torch.optim.Adam(model.parameters())

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

train(
    model,
    train_set=train_set,
    valid_set=valid_set,
    config=experiment_config,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler
)

  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 0: 100%|██████████| 172/172 [02:51<00:00,  1.00it/s]


Overall Loss: 23.12710672201112
Encoder KL Loss: 0.020016891099404296
Decoder NLL Loss: 21.974801623544014
Signal Cross-Entropy Loss: 1.1322880931371868


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 0: 100%|██████████| 48/48 [00:11<00:00,  4.03it/s]


Signal Prediction Accuracy: 0.7695481777191162
Final Displacement Error: 6.4911322593688965
Average Displacement Error: 3.0889740685621896


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 1: 100%|██████████| 172/172 [02:46<00:00,  1.03it/s]


Overall Loss: 14.95589488606121
Encoder KL Loss: 0.002560693173003275
Decoder NLL Loss: 14.205913954002908
Signal Cross-Entropy Loss: 0.7474202680033307


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 1: 100%|██████████| 48/48 [00:11<00:00,  4.30it/s]


Signal Prediction Accuracy: 0.7711758017539978
Final Displacement Error: 4.295049364368121
Average Displacement Error: 1.9779285366336505


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 2: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 12.368139821429587
Encoder KL Loss: 0.017150129852843535
Decoder NLL Loss: 11.660533217496653
Signal Cross-Entropy Loss: 0.6904564915701406


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 2: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.8165135979652405
Final Displacement Error: 4.4628248661756516
Average Displacement Error: 2.1192208727200827


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 3: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 9.470548396886787
Encoder KL Loss: 0.07187633458958115
Decoder NLL Loss: 8.734563785930014
Signal Cross-Entropy Loss: 0.6641082729018009


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 3: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.8105999231338501
Final Displacement Error: 4.702115344504516
Average Displacement Error: 2.0554223532478013


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 4: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 6.805470156115155
Encoder KL Loss: 0.07913896246531676
Decoder NLL Loss: 6.081102110618768
Signal Cross-Entropy Loss: 0.6452290945967962


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 4: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.8221017122268677
Final Displacement Error: 4.321256843705972
Average Displacement Error: 1.878059656669696


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 5: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 3.0956604792628184
Encoder KL Loss: 0.07652512469957042
Decoder NLL Loss: 2.3865318535718805
Signal Cross-Entropy Loss: 0.632603492154631


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 5: 100%|██████████| 48/48 [00:11<00:00,  4.30it/s]


Signal Prediction Accuracy: 0.8229697346687317
Final Displacement Error: 4.408668627341588
Average Displacement Error: 1.8940510681519906


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 6: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 2.3982636020627126
Encoder KL Loss: 0.07412184458659142
Decoder NLL Loss: 1.6963540580499974
Signal Cross-Entropy Loss: 0.627787684631902


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 6: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.8237111568450928
Final Displacement Error: 4.353103991597891
Average Displacement Error: 1.8944216805199783


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 7: 100%|██████████| 172/172 [02:43<00:00,  1.05it/s]


Overall Loss: 1.6997535627248672
Encoder KL Loss: 0.07966766346159372
Decoder NLL Loss: 0.9963810706651919
Signal Cross-Entropy Loss: 0.6237048312675124


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 7: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.8220836520195007
Final Displacement Error: 4.291820044318835
Average Displacement Error: 1.8754525346060593


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 8: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 1.2047697870537297
Encoder KL Loss: 0.08176428958946878
Decoder NLL Loss: 0.5024934531362771
Signal Cross-Entropy Loss: 0.6205120384693146


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 8: 100%|██████████| 48/48 [00:11<00:00,  4.31it/s]


Signal Prediction Accuracy: 0.8234400153160095
Final Displacement Error: 4.267748329788446
Average Displacement Error: 1.8450594581663609


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 9: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 1.1976522307756334
Encoder KL Loss: 0.08560741146983103
Decoder NLL Loss: 0.4938413768811798
Signal Cross-Entropy Loss: 0.6182034421105713


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 9: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.8221560120582581
Final Displacement Error: 4.388752594590187
Average Displacement Error: 1.9322941762705643


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 10: 100%|██████████| 172/172 [02:40<00:00,  1.07it/s]


Overall Loss: -1.4878314255281937
Encoder KL Loss: 0.07958614163447256
Decoder NLL Loss: -2.1818205070963446
Signal Cross-Entropy Loss: 0.6144029439188712


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 10: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.8237655758857727
Final Displacement Error: 4.181503318250179
Average Displacement Error: 1.8103979068497817


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 11: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: -2.0335316931785536
Encoder KL Loss: 0.07876834836463596
Decoder NLL Loss: -2.724729438593914
Signal Cross-Entropy Loss: 0.6124293963576473


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 11: 100%|██████████| 48/48 [00:11<00:00,  4.30it/s]


Signal Prediction Accuracy: 0.8244345784187317
Final Displacement Error: 4.104260917752981
Average Displacement Error: 1.7753301778187354


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 12: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: -2.287232042398564
Encoder KL Loss: 0.07669331856765021
Decoder NLL Loss: -2.9748728066076375
Signal Cross-Entropy Loss: 0.6109474530053693


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 12: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.8257186412811279
Final Displacement Error: 4.302691804865996
Average Displacement Error: 1.8505426365882158


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 13: 100%|██████████| 172/172 [02:41<00:00,  1.07it/s]


Overall Loss: -2.5564259746393487
Encoder KL Loss: 0.08127391901473664
Decoder NLL Loss: -3.2474859159786353
Signal Cross-Entropy Loss: 0.6097860247936361


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 13: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.8241814374923706
Final Displacement Error: 3.870811148236195
Average Displacement Error: 1.6518649061520894


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 14: 100%|██████████| 172/172 [02:41<00:00,  1.06it/s]


Overall Loss: -2.6221268951546306
Encoder KL Loss: 0.08023827514329628
Decoder NLL Loss: -3.311211965263409
Signal Cross-Entropy Loss: 0.6088467914351194


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 14: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.8237655162811279
Final Displacement Error: 4.222917859752973
Average Displacement Error: 1.8024086852868397


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 15: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: -4.329255209759226
Encoder KL Loss: 0.08034701628047365
Decoder NLL Loss: -5.016692378217094
Signal Cross-Entropy Loss: 0.607090174632017


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 15: 100%|██████████| 48/48 [00:11<00:00,  4.30it/s]


Signal Prediction Accuracy: 0.8244165182113647
Final Displacement Error: 4.055582170685132
Average Displacement Error: 1.7275462554146845


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 16: 100%|██████████| 172/172 [02:41<00:00,  1.06it/s]


Overall Loss: -4.849725003505861
Encoder KL Loss: 0.08351640663174695
Decoder NLL Loss: -5.539122332320654
Signal Cross-Entropy Loss: 0.6058809279009354


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 16: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.825664222240448
Final Displacement Error: 4.178926852842172
Average Displacement Error: 1.7828885490695636


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 17: 100%|██████████| 172/172 [02:27<00:00,  1.17it/s]


Overall Loss: -5.0537599858849545
Encoder KL Loss: 0.08709027873742028
Decoder NLL Loss: -5.74611193351014
Signal Cross-Entropy Loss: 0.6052616600726926


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 17: 100%|██████████| 48/48 [00:09<00:00,  5.19it/s]


Signal Prediction Accuracy: 0.8234398365020752
Final Displacement Error: 3.9940112804373107
Average Displacement Error: 1.703436965122819


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 18: 100%|██████████| 172/172 [02:10<00:00,  1.32it/s]


Overall Loss: -5.478701125743779
Encoder KL Loss: 0.09137270027814909
Decoder NLL Loss: -6.1741542391652295
Signal Cross-Entropy Loss: 0.6040804034402202


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 18: 100%|██████████| 48/48 [00:08<00:00,  5.44it/s]


Signal Prediction Accuracy: 0.8257908821105957
Final Displacement Error: 4.124460392942031
Average Displacement Error: 1.769045827910304


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 19: 100%|██████████| 172/172 [02:11<00:00,  1.31it/s]


Overall Loss: -5.776605830636134
Encoder KL Loss: 0.08771735997220807
Decoder NLL Loss: -6.467666645382725
Signal Cross-Entropy Loss: 0.603343469458957


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 19: 100%|██████████| 48/48 [00:08<00:00,  5.44it/s]

Signal Prediction Accuracy: 0.8241451978683472
Final Displacement Error: 4.068690974265337
Average Displacement Error: 1.7385124439994495





In [1]:
import os
import torch
from torch.utils.data import DataLoader

from NRI.dataset import SignalizedIntersectionDatasetForNRI
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.config import ExperimentConfig
from NRI.experiments.test import generate_result
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset
from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig

experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=12,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=64,
    checkpoint_prefix='dnri_18_baseline',
    checkpoint_interval=10,
    use_cuda=True,
    mask_traffic_signal=False
)

# Load the dataset
dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)
_, _, test_files = split_dataset(dataset_files, 0.7, 0.2)

dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0,
    include_incomplete_trajectories=False
)

test_set = SignalizedIntersectionDatasetForNRI(dataset_config)
test_set.load_records(dataset_path, test_files, verbose=True)

model = DynamicNeuralRelationalInference(
    hid_dim=64,
    n_edges=4,
    dgvae=False
)

# Load the saved model and configuration
checkpoint = torch.load('../checkpoints/dnri_18_baseline_best.pt')
model.load_state_dict(checkpoint['params'])
test_stats = generate_result(model, test_set, experiment_config)

# Report the results
test_stats.report()

load_records: 100%|██████████| 3/3 [00:44<00:00, 14.93s/it]
  return _nested.nested_tensor(
[test] generating result: 100%|██████████| 34/34 [00:06<00:00,  5.49it/s]

=== Overall Result ===
Signal Prediction Accuracy: 0.7773651480674744
Overall Final Displacement Error: 2.0950443744659424
Overall Average Displacement Error: 0.86806720495224
=== Displacement Error of Class `pedestrian` ===
FDE: 4.165600869019444
ADE: 2.2922888979774254
=== Displacement Error of Class `car` ===
FDE: 2.2438755868319338
ADE: 0.6204790487683736
=== Displacement Error of Class `truck` ===
FDE: 20.68140629799135
ADE: 5.773994162077865
=== Displacement Error of Class `bus` ===
FDE: 0.1488626353264396
ADE: 0.07469738022727695
=== Displacement Error of Class `motorcycle` ===
FDE: 2.3508817310053747
ADE: 0.8774539249190506
=== Displacement Error of Class `tricycle` ===
FDE: 0.8040027418429235
ADE: 0.3618348764101934
=== Displacement Error of Class `bicycle` ===
FDE: 3.7436919240275945
ADE: 1.668580631065446
=== Overall Result ===
Signal Prediction Accuracy: 0.7773651480674744
Overall Final Displacement Error: 2.0950443744659424
Overall Average Displacement Error: 0.86806720495


