In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# `plt` must be the pyplot plotting interface: the bare `matplotlib` package
# has no `plot`/`subplots`/`show`, so `import matplotlib as plt` breaks any
# plotting call made through `plt`.
import matplotlib.pyplot as plt

In [2]:
import random

from NRI.experiments.config import ExperimentConfig

# Seed every global RNG the pipeline may draw from (Python `random`, NumPy,
# PyTorch -- `torch.manual_seed` seeds all devices) *before* the dataset split
# and model construction, so Restart & Run All reproduces the run.
# NOTE(review): whether `split_dataset` / `train` use these global RNGs is not
# visible here; seeding is harmless if they are deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=12,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=64,
    # Checkpoints are later loaded from '../checkpoints/<prefix>_best.pt'.
    checkpoint_prefix='dnri_12_baseline',
    checkpoint_interval=10,
    use_cuda=True,
    mask_traffic_signal=False
)

In [3]:
from NRI.dataset.utils import split_dataset
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records

# Fractions passed to `split_dataset` for the train and validation splits;
# the remaining records presumably form the test split.
TRAIN_FRACTION = 0.7
VALID_FRACTION = 0.2

dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)

# NOTE(review): the test section repeats this split in a fresh kernel; it only
# yields the same test files if `split_dataset` is deterministic or seeded
# identically in both places -- confirm, otherwise test files may overlap
# the training files.
train_files, valid_files, test_files = split_dataset(
    dataset_files, TRAIN_FRACTION, VALID_FRACTION
)

Load the training and validation datasets

In [4]:
from NRI.dataset import SignalizedIntersectionDatasetConfig, SignalizedIntersectionDatasetForNRI

# Window lengths are taken from the experiment config so the two stay in sync.
dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0,
)


def _build_split(files):
    """Create a dataset that uses `dataset_config` and load `files` into it."""
    split = SignalizedIntersectionDatasetForNRI(dataset_config)
    split.load_records(dataset_path, files, verbose=True)
    return split


train_set = _build_split(train_files)
valid_set = _build_split(valid_files)


load_records: 100%|██████████| 16/16 [01:58<00:00,  7.39s/it]
load_records: 100%|██████████| 4/4 [00:35<00:00,  8.86s/it]


In [5]:
from NRI.experiments.main import train
from NRI.models import DynamicNeuralRelationalInference

model_kwargs = dict(hid_dim=64, n_edges=4, dgvae=False)
model = DynamicNeuralRelationalInference(**model_kwargs)

if experiment_config.use_cuda:
    model.cuda()  # nn.Module.cuda() moves the parameters in place

# Adam with PyTorch's default learning rate (1e-3). StepLR multiplies the
# learning rate by 0.5 every 5 scheduler steps (presumably one step per epoch
# inside `train` -- confirm).
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

train(
    model,
    train_set=train_set,
    valid_set=valid_set,
    config=experiment_config,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)

  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 0: 100%|██████████| 174/174 [02:12<00:00,  1.31it/s]


Overall Loss: 195.6832767629074
Encoder KL Loss: 0.04281331140055566
Decoder NLL Loss: 194.59245202733186
Signal Cross-Entropy Loss: 1.048010437652983


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 0: 100%|██████████| 49/49 [00:10<00:00,  4.70it/s]


Signal Prediction Accuracy: 0.7915444374084473
Final Displacement Error: 6.26243545571152
Average Displacement Error: 3.0850474737128435


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 1: 100%|██████████| 174/174 [01:53<00:00,  1.53it/s]


Overall Loss: 14.560713461075705
Encoder KL Loss: 0.016043799624232385
Decoder NLL Loss: 13.860887395924543
Signal Cross-Entropy Loss: 0.6837822628774861


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 1: 100%|██████████| 49/49 [00:07<00:00,  6.37it/s]


Signal Prediction Accuracy: 0.8158057332038879
Final Displacement Error: 4.955397985419449
Average Displacement Error: 2.3777888283437614


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 2: 100%|██████████| 174/174 [01:47<00:00,  1.62it/s]


Overall Loss: 11.092661660293052
Encoder KL Loss: 0.026587549018962632
Decoder NLL Loss: 10.437285261592645
Signal Cross-Entropy Loss: 0.6287888055560229


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 2: 100%|██████████| 49/49 [00:07<00:00,  6.40it/s]


Signal Prediction Accuracy: 0.8205623626708984
Final Displacement Error: 5.3923598454923045
Average Displacement Error: 2.591437962590432


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 3: 100%|██████████| 174/174 [01:59<00:00,  1.45it/s]


Overall Loss: 7.335311730702716
Encoder KL Loss: 0.04345318475930857
Decoder NLL Loss: 6.68618255922164
Signal Cross-Entropy Loss: 0.6056760023722709


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 3: 100%|██████████| 49/49 [00:07<00:00,  6.48it/s]


Signal Prediction Accuracy: 0.8177720904350281
Final Displacement Error: 5.022664634548888
Average Displacement Error: 2.3810494116374423


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 4: 100%|██████████| 174/174 [02:02<00:00,  1.42it/s]


Overall Loss: 5.8016365163627714
Encoder KL Loss: 0.057161021185503626
Decoder NLL Loss: 5.151540977516393
Signal Cross-Entropy Loss: 0.5929345284727798


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 4: 100%|██████████| 49/49 [00:07<00:00,  6.50it/s]


Signal Prediction Accuracy: 0.8227413296699524
Final Displacement Error: 4.967337905144205
Average Displacement Error: 2.3260544514169497


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 5: 100%|██████████| 174/174 [02:10<00:00,  1.33it/s]


Overall Loss: 1.999125733457762
Encoder KL Loss: 0.05991195653961309
Decoder NLL Loss: 1.3569840733390093
Signal Cross-Entropy Loss: 0.5822296956147268


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 5: 100%|██████████| 49/49 [00:07<00:00,  6.48it/s]


Signal Prediction Accuracy: 0.8215720057487488
Final Displacement Error: 5.079691414930383
Average Displacement Error: 2.3627530312051577


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 6: 100%|██████████| 174/174 [01:47<00:00,  1.62it/s]


Overall Loss: 0.709129320650265
Encoder KL Loss: 0.05663993567142679
Decoder NLL Loss: 0.07445639151351885
Signal Cross-Entropy Loss: 0.5780329949211802


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 6: 100%|██████████| 49/49 [00:07<00:00,  6.39it/s]


Signal Prediction Accuracy: 0.8219706416130066
Final Displacement Error: 4.983383835578452
Average Displacement Error: 2.327674941140778


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 7: 100%|██████████| 174/174 [01:48<00:00,  1.61it/s]


Overall Loss: 0.2694384474521395
Encoder KL Loss: 0.056546420244307345
Decoder NLL Loss: -0.3628031255218373
Signal Cross-Entropy Loss: 0.5756951556808646


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 7: 100%|██████████| 49/49 [00:07<00:00,  6.40it/s]


Signal Prediction Accuracy: 0.8187286853790283
Final Displacement Error: 4.871837679220706
Average Displacement Error: 2.2872565303530012


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 8: 100%|██████████| 174/174 [01:46<00:00,  1.63it/s]


Overall Loss: 1.7389418659196507
Encoder KL Loss: 0.06524618798545724
Decoder NLL Loss: 1.0991110144371716
Signal Cross-Entropy Loss: 0.5745846675387748


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 8: 100%|██████████| 49/49 [00:07<00:00,  6.55it/s]


Signal Prediction Accuracy: 0.8206421136856079
Final Displacement Error: 4.849023891955006
Average Displacement Error: 2.285345291604801


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 9: 100%|██████████| 174/174 [01:51<00:00,  1.56it/s]


Overall Loss: 1.2416942191535028
Encoder KL Loss: 0.07801404332035575
Decoder NLL Loss: 0.5916368021253737
Signal Cross-Entropy Loss: 0.5720433669871294


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 9: 100%|██████████| 49/49 [00:07<00:00,  6.53it/s]


Signal Prediction Accuracy: 0.8236979246139526
Final Displacement Error: 5.04387735833927
Average Displacement Error: 2.3634958899751


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 10: 100%|██████████| 174/174 [02:08<00:00,  1.35it/s]


Overall Loss: -0.7910515138472634
Encoder KL Loss: 0.08342284580756881
Decoder NLL Loss: -1.4434585425783293
Signal Cross-Entropy Loss: 0.568984184456968


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 10: 100%|██████████| 49/49 [00:07<00:00,  6.46it/s]


Signal Prediction Accuracy: 0.8247873187065125
Final Displacement Error: 4.8548590504393285
Average Displacement Error: 2.2788482417865676


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 11: 100%|██████████| 174/174 [02:10<00:00,  1.34it/s]


Overall Loss: -2.2896004629203643
Encoder KL Loss: 0.08278378309018304
Decoder NLL Loss: -2.93873739910537
Signal Cross-Entropy Loss: 0.5663531623009977


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 11: 100%|██████████| 49/49 [00:07<00:00,  6.47it/s]


Signal Prediction Accuracy: 0.8278964161872864
Final Displacement Error: 4.875974767062129
Average Displacement Error: 2.292960558618818


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 12: 100%|██████████| 174/174 [02:09<00:00,  1.35it/s]


Overall Loss: -2.856028769550652
Encoder KL Loss: 0.08244775171423775
Decoder NLL Loss: -3.503666663991995
Signal Cross-Entropy Loss: 0.5651901530465863


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 12: 100%|██████████| 49/49 [00:07<00:00,  6.46it/s]


Signal Prediction Accuracy: 0.8251593708992004
Final Displacement Error: 4.836557096364547
Average Displacement Error: 2.2698940379279002


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 13: 100%|██████████| 174/174 [01:47<00:00,  1.62it/s]


Overall Loss: -3.3665330118831562
Encoder KL Loss: 0.08063439331178006
Decoder NLL Loss: -4.012534117219092
Signal Cross-Entropy Loss: 0.5653667100544635


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 13: 100%|██████████| 49/49 [00:07<00:00,  6.52it/s]


Signal Prediction Accuracy: 0.8265411853790283
Final Displacement Error: 4.74970343648171
Average Displacement Error: 2.232592614329591


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 14: 100%|██████████| 174/174 [01:47<00:00,  1.62it/s]


Overall Loss: -3.5619557883547635
Encoder KL Loss: 0.07553391462598727
Decoder NLL Loss: -4.20144053505755
Signal Cross-Entropy Loss: 0.563950844879808


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 14: 100%|██████████| 49/49 [00:07<00:00,  6.52it/s]


Signal Prediction Accuracy: 0.823272705078125
Final Displacement Error: 4.755261727741787
Average Displacement Error: 2.2346281737697367


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 15: 100%|██████████| 174/174 [01:46<00:00,  1.63it/s]


Overall Loss: -4.918372796184717
Encoder KL Loss: 0.07370298664117682
Decoder NLL Loss: -5.555359233384844
Signal Cross-Entropy Loss: 0.5632834655457529


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 15: 100%|██████████| 49/49 [00:07<00:00,  6.55it/s]


Signal Prediction Accuracy: 0.8244419097900391
Final Displacement Error: 4.752256466417896
Average Displacement Error: 2.2324596716433156


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 16: 100%|██████████| 174/174 [01:44<00:00,  1.67it/s]


Overall Loss: -5.579170581938206
Encoder KL Loss: 0.07064785251672238
Decoder NLL Loss: -6.211874194528865
Signal Cross-Entropy Loss: 0.5620557686378217


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 16: 100%|██████████| 49/49 [00:07<00:00,  6.59it/s]


Signal Prediction Accuracy: 0.8245482444763184
Final Displacement Error: 4.741152505485379
Average Displacement Error: 2.22657509725921


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 17: 100%|██████████| 174/174 [01:46<00:00,  1.64it/s]


Overall Loss: -5.91279977148977
Encoder KL Loss: 0.06930524971464584
Decoder NLL Loss: -6.543342491676066
Signal Cross-Entropy Loss: 0.5612374614367539


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 17: 100%|██████████| 49/49 [00:07<00:00,  6.61it/s]


Signal Prediction Accuracy: 0.824787437915802
Final Displacement Error: 4.729623458823379
Average Displacement Error: 2.220778839928763


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 18: 100%|██████████| 174/174 [01:45<00:00,  1.64it/s]


Overall Loss: -6.141065962698269
Encoder KL Loss: 0.06923921419114898
Decoder NLL Loss: -6.7710318469453155
Signal Cross-Entropy Loss: 0.5607266727535204


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 18: 100%|██████████| 49/49 [00:07<00:00,  6.55it/s]


Signal Prediction Accuracy: 0.827949583530426
Final Displacement Error: 4.753188829032743
Average Displacement Error: 2.2289942746259728


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 19: 100%|██████████| 174/174 [01:44<00:00,  1.67it/s]


Overall Loss: -6.099895799982141
Encoder KL Loss: 0.06852024084963328
Decoder NLL Loss: -6.729240174951221
Signal Cross-Entropy Loss: 0.5608241475861646


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 19: 100%|██████████| 49/49 [00:07<00:00,  6.59it/s]

Signal Prediction Accuracy: 0.824415385723114
Final Displacement Error: 4.724471379299553
Average Displacement Error: 2.2147310436988366





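Evaluate the best checkpoint on the test set

This section runs in a fresh kernel, so it re-creates the experiment configuration and the dataset split before loading `dnri_12_baseline_best.pt`. The test dataset is built with `include_incomplete_trajectories=False`, which the training and validation datasets do not set. The test ADE/FDE below may therefore not be directly comparable to the validation metrics above.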

In [1]:
import os
import torch
from torch.utils.data import DataLoader

from NRI.dataset import SignalizedIntersectionDatasetForNRI
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.config import ExperimentConfig
from NRI.experiments.test import generate_result
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset
from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig

import random

import numpy as np

# This section starts from a fresh kernel and re-runs `split_dataset`, so seed
# the global RNGs (Python, NumPy, PyTorch) before it. The split is
# reproducible if `split_dataset` draws from these RNGs.
# NOTE(review): to recover the exact training-time split, SEED must match the
# seed used when the checkpoint was trained -- confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=12,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=64,
    checkpoint_prefix='dnri_12_baseline',
    checkpoint_interval=10,
    use_cuda=True,
    mask_traffic_signal=False
)

# Load the dataset and rebuild the train/valid/test file split; only the
# test files are used here.
dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)
_, _, test_files = split_dataset(dataset_files, 0.7, 0.2)

# NOTE(review): unlike the training/validation datasets, this config sets
# include_incomplete_trajectories=False. Unless that is also the default, test
# ADE/FDE are computed on a different trajectory population than the
# validation metrics -- confirm before comparing them.
test_dataset_options = dict(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0,
    include_incomplete_trajectories=False,
)
dataset_config = SignalizedIntersectionDatasetConfig(**test_dataset_options)

test_set = SignalizedIntersectionDatasetForNRI(dataset_config)
test_set.load_records(dataset_path, test_files, verbose=True)

model = DynamicNeuralRelationalInference(
    hid_dim=64,
    n_edges=4,
    dgvae=False
)

# Derive the checkpoint path from the config instead of repeating the prefix,
# so a changed `checkpoint_prefix` cannot silently evaluate a stale model.
checkpoint_path = f'../checkpoints/{experiment_config.checkpoint_prefix}_best.pt'

# `model` is still on the CPU here, so map the stored tensors to the CPU as well.
# load_state_dict copies values into the existing parameters either way, and
# this lets a GPU-saved checkpoint load on a machine without CUDA.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['params'])

# Switch off any training-only behaviour (e.g. dropout) for inference; this is
# harmless if generate_result already sets eval mode itself.
model.eval()

test_stats = generate_result(model, test_set, experiment_config)

# Report the results
test_stats.report()

load_records: 100%|██████████| 3/3 [00:37<00:00, 12.60s/it]
  return _nested.nested_tensor(
[test] generating result: 100%|██████████| 34/34 [00:06<00:00,  5.66it/s]

=== Overall Result ===
Signal Prediction Accuracy: 0.7718290090560913
Overall Final Displacement Error: 1.9836796522140503
Overall Average Displacement Error: 0.7808113098144531
=== Displacement Error of Class `pedestrian` ===
FDE: 3.556249748331029
ADE: 1.9006368871441541
=== Displacement Error of Class `car` ===
FDE: 2.210488477220763
ADE: 0.5796978505316064
=== Displacement Error of Class `truck` ===
FDE: 20.928002699728935
ADE: 5.911805153854432
=== Displacement Error of Class `bus` ===
FDE: 0.05091768265484478
ADE: 0.02343659750122238
=== Displacement Error of Class `motorcycle` ===
FDE: 2.2600996758829237
ADE: 0.8208592312035741
=== Displacement Error of Class `tricycle` ===
FDE: 1.1623582486798372
ADE: 0.4981800039675467
=== Displacement Error of Class `bicycle` ===
FDE: 3.5434410093189834
ADE: 1.521978985106761
=== Overall Result ===
Signal Prediction Accuracy: 0.7718290090560913
Overall Final Displacement Error: 1.9836796522140503
Overall Average Displacement Error: 0.78081130


