In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib as plt

In [2]:
from NRI.experiments.config import ExperimentConfig


experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=18,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=64,
    checkpoint_prefix='dnri_18_NoTraff',
    checkpoint_interval=10,
    use_cuda=True,
    mask_traffic_signal=True
)

In [3]:
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset

dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)
train_files, valid_files, test_files = split_dataset(dataset_files, 0.7, 0.2)

load datasets

In [4]:
from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig

dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0
)

train_set = SignalizedIntersectionDatasetForNRI(dataset_config)
train_set.load_records(dataset_path, train_files, verbose=True)

valid_set = SignalizedIntersectionDatasetForNRI(dataset_config)
valid_set.load_records(dataset_path, valid_files, verbose=True)


load_records: 100%|██████████| 16/16 [02:18<00:00,  8.68s/it]
load_records: 100%|██████████| 4/4 [00:40<00:00, 10.19s/it]


In [5]:
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.main import train

model = DynamicNeuralRelationalInference(
    hid_dim=64,
    n_edges=4,
    dgvae=False
)

if experiment_config.use_cuda:
    model.cuda()

optimizer = torch.optim.Adam(model.parameters())

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

train(
    model,
    train_set=train_set,
    valid_set=valid_set,
    config=experiment_config,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler
)

  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 0: 100%|██████████| 172/172 [02:25<00:00,  1.19it/s]


Overall Loss: 33.77419683545136
Encoder KL Loss: 0.018318692989089686
Decoder NLL Loss: 33.75587829323702
Signal Cross-Entropy Loss: 1.7816878862159162


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 0: 100%|██████████| 48/48 [00:10<00:00,  4.49it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.775298660000165
Average Displacement Error: 2.228523169954618


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 1: 100%|██████████| 172/172 [02:24<00:00,  1.19it/s]


Overall Loss: 16.171535963235897
Encoder KL Loss: 0.02128282363450717
Decoder NLL Loss: 16.15025315173835
Signal Cross-Entropy Loss: 1.7816823873408987


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 1: 100%|██████████| 48/48 [00:09<00:00,  5.04it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 6.3463621238867445
Average Displacement Error: 3.003972848256429


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 2: 100%|██████████| 172/172 [02:25<00:00,  1.18it/s]


Overall Loss: 12.80350357432698
Encoder KL Loss: 0.05278921690444614
Decoder NLL Loss: 12.750714357509175
Signal Cross-Entropy Loss: 1.7816871487817112


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 2: 100%|██████████| 48/48 [00:09<00:00,  4.84it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.204722245534261
Average Displacement Error: 2.3704915270209312


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 3: 100%|██████████| 172/172 [02:40<00:00,  1.07it/s]


Overall Loss: 11.347978680632837
Encoder KL Loss: 0.0651669574165067
Decoder NLL Loss: 11.282811719317776
Signal Cross-Entropy Loss: 1.7816794556240698


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 3: 100%|██████████| 48/48 [00:11<00:00,  4.24it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.893272812167804
Average Displacement Error: 2.3205504591266313


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 4: 100%|██████████| 172/172 [02:43<00:00,  1.05it/s]


Overall Loss: 10.437753508257314
Encoder KL Loss: 0.08506814137014537
Decoder NLL Loss: 10.352685382199835
Signal Cross-Entropy Loss: 1.7816733246625862


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 4: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.275467470288277
Average Displacement Error: 2.464212921758493


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 5: 100%|██████████| 172/172 [02:39<00:00,  1.08it/s]


Overall Loss: 5.524356749168659
Encoder KL Loss: 0.08319096509800403
Decoder NLL Loss: 5.441165760506032
Signal Cross-Entropy Loss: 1.781665510216425


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 5: 100%|██████████| 48/48 [00:11<00:00,  4.26it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 6.638623724381129
Average Displacement Error: 3.0379897554715476


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 6: 100%|██████████| 172/172 [02:41<00:00,  1.07it/s]


Overall Loss: 4.472943908946459
Encoder KL Loss: 0.08661034023172644
Decoder NLL Loss: 4.386333558448524
Signal Cross-Entropy Loss: 1.7816625556280428


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 6: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 6.571618651350339
Average Displacement Error: 3.0649593199292817


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 7: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: 3.8693483686724357
Encoder KL Loss: 0.08833271836818649
Decoder NLL Loss: 3.7810156518636764
Signal Cross-Entropy Loss: 1.781661240860477


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 7: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.302382508913676
Average Displacement Error: 2.2771383126576743


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 8: 100%|██████████| 172/172 [02:41<00:00,  1.06it/s]


Overall Loss: 3.0357480412998874
Encoder KL Loss: 0.08621484591344061
Decoder NLL Loss: 2.9495331925708177
Signal Cross-Entropy Loss: 1.7816578087418595


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 8: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.305331180493037
Average Displacement Error: 2.2375955606500306


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 9: 100%|██████████| 172/172 [02:41<00:00,  1.07it/s]


Overall Loss: 1.6483239611870688
Encoder KL Loss: 0.07862482822045337
Decoder NLL Loss: 1.569699125543106
Signal Cross-Entropy Loss: 1.781653355720429


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 9: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.169546927014987
Average Displacement Error: 2.206044060488542


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 10: 100%|██████████| 172/172 [02:41<00:00,  1.07it/s]


Overall Loss: -0.4253471418998615
Encoder KL Loss: 0.07788238037637506
Decoder NLL Loss: -0.503229523207559
Signal Cross-Entropy Loss: 1.7816499762756894


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 10: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.960537309447925
Average Displacement Error: 2.1147676731149354


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 11: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: -1.5454262095871716
Encoder KL Loss: 0.0729361366653858
Decoder NLL Loss: -1.6183623486837415
Signal Cross-Entropy Loss: 1.781648079323209


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 11: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.015959978103638
Average Displacement Error: 2.1378600547711053


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 12: 100%|██████████| 172/172 [02:41<00:00,  1.07it/s]


Overall Loss: -1.405685569852763
Encoder KL Loss: 0.06992789808400843
Decoder NLL Loss: -1.4756134675577455
Signal Cross-Entropy Loss: 1.7816457755343826


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 12: 100%|██████████| 48/48 [00:11<00:00,  4.29it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.867833033204079
Average Displacement Error: 2.1059751585125923


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 13: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: -1.6432933446972864
Encoder KL Loss: 0.0674537768544153
Decoder NLL Loss: -1.7107471164619157
Signal Cross-Entropy Loss: 1.781644961861674


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 13: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.907699808478355
Average Displacement Error: 2.0974437644084296


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 14: 100%|██████████| 172/172 [02:40<00:00,  1.07it/s]


Overall Loss: -2.4134062367902938
Encoder KL Loss: 0.0652201458551856
Decoder NLL Loss: -2.4786263752409217
Signal Cross-Entropy Loss: 1.7816429491652945


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 14: 100%|██████████| 48/48 [00:11<00:00,  4.27it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.939689626296361
Average Displacement Error: 2.1402989526589713


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 15: 100%|██████████| 172/172 [02:41<00:00,  1.07it/s]


Overall Loss: -3.777455963887446
Encoder KL Loss: 0.06487926541892595
Decoder NLL Loss: -3.8423352306437955
Signal Cross-Entropy Loss: 1.7816409323104583


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 15: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.998148669799169
Average Displacement Error: 2.1764689683914185


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 16: 100%|██████████| 172/172 [02:40<00:00,  1.07it/s]


Overall Loss: -4.49195310022942
Encoder KL Loss: 0.06291799019848887
Decoder NLL Loss: -4.55487106671167
Signal Cross-Entropy Loss: 1.7816397215044835


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 16: 100%|██████████| 48/48 [00:11<00:00,  4.27it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.938395082950592
Average Displacement Error: 2.1451578016082444


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 17: 100%|██████████| 172/172 [02:41<00:00,  1.06it/s]


Overall Loss: -4.126851330662881
Encoder KL Loss: 0.06142083995130867
Decoder NLL Loss: -4.188272158159783
Signal Cross-Entropy Loss: 1.7816387525824635


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 17: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.844655081629753
Average Displacement Error: 2.1042445624868074


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 18: 100%|██████████| 172/172 [02:42<00:00,  1.06it/s]


Overall Loss: -4.720964643151261
Encoder KL Loss: 0.06076704249392416
Decoder NLL Loss: -4.781731674837512
Signal Cross-Entropy Loss: 1.7816376187080518


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 18: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.911520704627037
Average Displacement Error: 2.1511467744906745


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 19: 100%|██████████| 172/172 [02:40<00:00,  1.07it/s]


Overall Loss: -4.940957611383396
Encoder KL Loss: 0.06015496475752011
Decoder NLL Loss: -5.001112581165724
Signal Cross-Entropy Loss: 1.7816367897876466


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 19: 100%|██████████| 48/48 [00:11<00:00,  4.28it/s]

Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.926220695177714
Average Displacement Error: 2.172109251221021





In [1]:
import os
import torch
from torch.utils.data import DataLoader

from NRI.dataset import SignalizedIntersectionDatasetForNRI
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.config import ExperimentConfig
from NRI.experiments.test import generate_result
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset
from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig

experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=12,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=64,
    checkpoint_prefix='dnri_18_NoTraff',
    checkpoint_interval=10,
    use_cuda=True,
    mask_traffic_signal=True
)

# Load the dataset
dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)
_, _, test_files = split_dataset(dataset_files, 0.7, 0.2)

dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0,
    include_incomplete_trajectories=False
)

test_set = SignalizedIntersectionDatasetForNRI(dataset_config)
test_set.load_records(dataset_path, test_files, verbose=True)

model = DynamicNeuralRelationalInference(
    hid_dim=64,
    n_edges=4,
    dgvae=False
)

# Load the saved model and configuration
checkpoint = torch.load('../checkpoints/dnri_18_NoTraff_best.pt')
model.load_state_dict(checkpoint['params'])
test_stats = generate_result(model, test_set, experiment_config)

# Report the results
test_stats.report()

load_records: 100%|██████████| 3/3 [00:37<00:00, 12.44s/it]
  return _nested.nested_tensor(
[test] generating result: 100%|██████████| 34/34 [00:06<00:00,  5.39it/s]

=== Overall Result ===
Signal Prediction Accuracy: 0.0
Overall Final Displacement Error: 2.8654088973999023
Overall Average Displacement Error: 1.1738905906677246
=== Displacement Error of Class `pedestrian` ===
FDE: 5.389487702640204
ADE: 2.7717833351974304
=== Displacement Error of Class `car` ===
FDE: 2.6083696736347872
ADE: 0.7664632424754578
=== Displacement Error of Class `truck` ===
FDE: 20.44562187887007
ADE: 6.0214136608185305
=== Displacement Error of Class `bus` ===
FDE: 0.7366150116608537
ADE: 0.262057336327435
=== Displacement Error of Class `motorcycle` ===
FDE: 3.0535265687537145
ADE: 1.1208451778813462
=== Displacement Error of Class `tricycle` ===
FDE: 2.2579938850605688
ADE: 0.977609463123894
=== Displacement Error of Class `bicycle` ===
FDE: 4.7598493628458005
ADE: 2.132578789074484
=== Overall Result ===
Signal Prediction Accuracy: 0.0
Overall Final Displacement Error: 2.8654088973999023
Overall Average Displacement Error: 1.1738905906677246
=== Displacement Error o


