In [1]:
import os
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# The plotting API lives in `matplotlib.pyplot`; binding `plt` to the bare
# `matplotlib` package would make every `plt.<plot fn>` call fail.
import matplotlib.pyplot as plt

# Seed every RNG the notebook may touch so that model initialisation (and any
# other stochastic step) is reproducible under Restart Kernel -> Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA
np.random.seed(SEED)

In [2]:
from NRI.experiments.config import ExperimentConfig


# Hyper-parameters for the dNRI training run. The checkpoint prefix
# ("NoTraff") together with `mask_traffic_signal=True` suggests traffic-signal
# inputs are hidden from the model in this experiment.
experiment_config = ExperimentConfig(
    # Observation / prediction window lengths.
    obs_len=12,
    pred_len=12,
    # Optimisation loop.
    n_epoch=20,
    batch_size=64,
    use_cuda=True,
    # Relative weights of the loss terms reported during training.
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    # NOTE(review): signals are masked, yet this loss term is still weighted
    # 1.0; the logged signal cross-entropy stays at ~1.506 and signal accuracy
    # at 0.0 for every epoch. Consider 0.0 here -- confirm intent.
    signal_loss_weight=1.0,
    mask_traffic_signal=True,
    # Checkpoint naming and frequency (in epochs, presumably -- confirm in
    # NRI.experiments.main.train).
    checkpoint_prefix='dnri_12_NoTraff',
    checkpoint_interval=10,
)

In [3]:
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset

# Share of recordings assigned to the training and validation splits;
# presumably the remainder becomes the test split.
TRAIN_FRACTION, VALID_FRACTION = 0.7, 0.2

dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)
train_files, valid_files, test_files = split_dataset(
    dataset_files, TRAIN_FRACTION, VALID_FRACTION
)

Load the training and validation datasets.

In [4]:
from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig

# Windowing/encoding options shared by the training and validation splits.
# NOTE(review): the evaluation cell at the end of the notebook additionally
# passes include_incomplete_trajectories=False, while these splits use the
# library default; validation and test displacement errors may therefore not
# be directly comparable -- confirm the default.
dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0
)


def _load_split(files):
    """Create a dataset from `dataset_config` and load the given recordings into it."""
    split = SignalizedIntersectionDatasetForNRI(dataset_config)
    split.load_records(dataset_path, files, verbose=True)
    return split


train_set = _load_split(train_files)
valid_set = _load_split(valid_files)


load_records: 100%|██████████| 16/16 [01:59<00:00,  7.45s/it]
load_records: 100%|██████████| 4/4 [00:35<00:00,  8.84s/it]


In [5]:
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.main import train

# Dynamic NRI model (hid_dim=64, n_edges=4, dgvae disabled).
model = DynamicNeuralRelationalInference(hid_dim=64, n_edges=4, dgvae=False)
if experiment_config.use_cuda:
    # Module.cuda() moves parameters in place and returns the same module.
    model = model.cuda()

# Adam at its default learning rate (1e-3, spelled out for readability);
# StepLR multiplies the learning rate by 0.5 every 5 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=5)

# NOTE(review): in the logged run, validation FDE is lowest at epoch 2 and
# stays flat or higher afterwards while training NLL keeps falling -- a sign
# of overfitting worth revisiting (early stopping / regularisation).
train(
    model,
    train_set=train_set,
    valid_set=valid_set,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    config=experiment_config,
)

  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 0: 100%|██████████| 174/174 [01:49<00:00,  1.58it/s]


Overall Loss: 23.142181308790192
Encoder KL Loss: 0.018519533337789714
Decoder NLL Loss: 23.12366173185152
Signal Cross-Entropy Loss: 1.5064185741304013


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 0: 100%|██████████| 49/49 [00:09<00:00,  5.18it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 6.328040911226856
Average Displacement Error: 3.118809262100531


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 1: 100%|██████████| 174/174 [02:09<00:00,  1.34it/s]


Overall Loss: 14.251021434520846
Encoder KL Loss: 0.06333459031115148
Decoder NLL Loss: 14.187686870838032
Signal Cross-Entropy Loss: 1.5064111475286817


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 1: 100%|██████████| 49/49 [00:09<00:00,  5.19it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.716417108263288
Average Displacement Error: 2.796819879084217


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 2: 100%|██████████| 174/174 [02:09<00:00,  1.35it/s]


Overall Loss: 9.284223419496383
Encoder KL Loss: 0.0820855401756092
Decoder NLL Loss: 9.202137886792766
Signal Cross-Entropy Loss: 1.5063997116582155


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 2: 100%|██████████| 49/49 [00:09<00:00,  5.22it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 4.980822967023266
Average Displacement Error: 2.403894310094872


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 3: 100%|██████████| 174/174 [02:10<00:00,  1.34it/s]


Overall Loss: 5.478097881393868
Encoder KL Loss: 0.07417793723957977
Decoder NLL Loss: 5.4039199420775486
Signal Cross-Entropy Loss: 1.5063937518788506


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 3: 100%|██████████| 49/49 [00:09<00:00,  5.21it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.194051742553711
Average Displacement Error: 2.4717269333041445


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 4: 100%|██████████| 174/174 [02:08<00:00,  1.35it/s]


Overall Loss: 3.493428972603259
Encoder KL Loss: 0.07348520948883445
Decoder NLL Loss: 3.419943764805796
Signal Cross-Entropy Loss: 1.5063966834682163


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 4: 100%|██████████| 49/49 [00:09<00:00,  5.19it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.021263064170371
Average Displacement Error: 2.4490429041336994


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 5: 100%|██████████| 174/174 [02:09<00:00,  1.34it/s]


Overall Loss: -0.23727273367259688
Encoder KL Loss: 0.07307265103719701
Decoder NLL Loss: -0.31034539029772934
Signal Cross-Entropy Loss: 1.5063962066310606


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 5: 100%|██████████| 49/49 [00:09<00:00,  5.22it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.233740928221722
Average Displacement Error: 2.4865577585843144


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 6: 100%|██████████| 174/174 [02:10<00:00,  1.33it/s]


Overall Loss: -1.5544995747763533
Encoder KL Loss: 0.06590410660224394
Decoder NLL Loss: -1.620403680490094
Signal Cross-Entropy Loss: 1.5063964251814237


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 6: 100%|██████████| 49/49 [00:09<00:00,  5.23it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.32903275197866
Average Displacement Error: 2.572858988022318


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 7: 100%|██████████| 174/174 [02:10<00:00,  1.33it/s]


Overall Loss: -1.839049376397469
Encoder KL Loss: 0.06728506936081523
Decoder NLL Loss: -1.906334451276638
Signal Cross-Entropy Loss: 1.5063965656291503


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 7: 100%|██████████| 49/49 [00:09<00:00,  5.19it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.1060015036135304
Average Displacement Error: 2.459836441643384


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 8: 100%|██████████| 174/174 [02:09<00:00,  1.34it/s]


Overall Loss: -1.3134046833876574
Encoder KL Loss: 0.07786154179651847
Decoder NLL Loss: -1.391266220238532
Signal Cross-Entropy Loss: 1.5063921692727658


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 8: 100%|██████████| 49/49 [00:09<00:00,  5.19it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.080618337709076
Average Displacement Error: 2.4233307205900854


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 9: 100%|██████████| 174/174 [02:12<00:00,  1.32it/s]


Overall Loss: -2.2980745962296396
Encoder KL Loss: 0.06909586714002587
Decoder NLL Loss: -2.36717046441874
Signal Cross-Entropy Loss: 1.5063926378885903


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 9: 100%|██████████| 49/49 [00:09<00:00,  4.93it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.114655509286997
Average Displacement Error: 2.4316402235809638


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 10: 100%|██████████| 174/174 [02:16<00:00,  1.28it/s]


Overall Loss: -4.479540699686126
Encoder KL Loss: 0.07306102110907953
Decoder NLL Loss: -4.552601719724717
Signal Cross-Entropy Loss: 1.5063915807625292


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 10: 100%|██████████| 49/49 [00:10<00:00,  4.74it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.133960811459288
Average Displacement Error: 2.426927388930807


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 11: 100%|██████████| 174/174 [02:16<00:00,  1.27it/s]


Overall Loss: -5.580272174943453
Encoder KL Loss: 0.06655949613229298
Decoder NLL Loss: -5.646831667885698
Signal Cross-Entropy Loss: 1.506392045952814


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 11: 100%|██████████| 49/49 [00:09<00:00,  5.16it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.108576584835442
Average Displacement Error: 2.403720113695884


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 12: 100%|██████████| 174/174 [02:09<00:00,  1.34it/s]


Overall Loss: -4.725893364298618
Encoder KL Loss: 0.07760545692738453
Decoder NLL Loss: -4.803498831117293
Signal Cross-Entropy Loss: 1.5063894733615264


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 12: 100%|██████████| 49/49 [00:09<00:00,  5.15it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.105430442459729
Average Displacement Error: 2.4379357683415317


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 13: 100%|██████████| 174/174 [02:14<00:00,  1.29it/s]


Overall Loss: -5.6153616514699225
Encoder KL Loss: 0.07442930105260052
Decoder NLL Loss: -5.689790952822257
Signal Cross-Entropy Loss: 1.5063883299115035


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 13: 100%|██████████| 49/49 [00:09<00:00,  5.19it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.133644575975379
Average Displacement Error: 2.4296983918365167


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 14: 100%|██████████| 174/174 [02:10<00:00,  1.34it/s]


Overall Loss: -5.9064051757684375
Encoder KL Loss: 0.07298451263842916
Decoder NLL Loss: -5.979389676288974
Signal Cross-Entropy Loss: 1.506388325115726


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 14: 100%|██████████| 49/49 [00:09<00:00,  5.17it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.118647667826439
Average Displacement Error: 2.4159703279028135


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 15: 100%|██████████| 174/174 [01:51<00:00,  1.57it/s]


Overall Loss: -7.9837959374504495
Encoder KL Loss: 0.06733631980658951
Decoder NLL Loss: -8.05113223229332
Signal Cross-Entropy Loss: 1.506387773601485


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 15: 100%|██████████| 49/49 [00:07<00:00,  6.71it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.164518487696745
Average Displacement Error: 2.4460692746298656


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 16: 100%|██████████| 174/174 [01:47<00:00,  1.61it/s]


Overall Loss: -8.540729007501708
Encoder KL Loss: 0.06510324818992065
Decoder NLL Loss: -8.605832239677172
Signal Cross-Entropy Loss: 1.5063885450363121


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 16: 100%|██████████| 49/49 [00:07<00:00,  6.85it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.1959773861632055
Average Displacement Error: 2.4568872670738067


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 17: 100%|██████████| 174/174 [01:47<00:00,  1.63it/s]


Overall Loss: -8.687009134511834
Encoder KL Loss: 0.06458741625578236
Decoder NLL Loss: -8.751596554942514
Signal Cross-Entropy Loss: 1.506388930068614


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 17: 100%|██████████| 49/49 [00:07<00:00,  6.83it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.238011102287137
Average Displacement Error: 2.464673918120715


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 18: 100%|██████████| 174/174 [01:45<00:00,  1.65it/s]


Overall Loss: -8.71250295639038
Encoder KL Loss: 0.06418771230369462
Decoder NLL Loss: -8.776690694107407
Signal Cross-Entropy Loss: 1.5063890300947995


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 18: 100%|██████████| 49/49 [00:07<00:00,  6.83it/s]


Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.245896120460666
Average Displacement Error: 2.4919705780185


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 19: 100%|██████████| 174/174 [01:45<00:00,  1.65it/s]


Overall Loss: -9.009849474347869
Encoder KL Loss: 0.06421803334064186
Decoder NLL Loss: -9.07406749396489
Signal Cross-Entropy Loss: 1.5063890218734692


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 19: 100%|██████████| 49/49 [00:07<00:00,  6.61it/s]

Signal Prediction Accuracy: 0.0
Final Displacement Error: 5.243369944241582
Average Displacement Error: 2.487721010130279





In [3]:
import os
import torch
from torch.utils.data import DataLoader

from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.config import ExperimentConfig
from NRI.experiments.test import generate_result
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset

# This cell is self-contained so it can run in a fresh kernel. The settings
# below must mirror the training configuration for the checkpoint to match.
experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=12,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=64,
    checkpoint_prefix='dnri_12_NoTraff',
    checkpoint_interval=10,
    use_cuda=True,
    mask_traffic_signal=True
)

# Load the test split.
# NOTE(review): the split is recomputed here independently of the training
# cells; it only matches the training run if split_dataset is deterministic
# (otherwise test recordings could overlap training ones) -- confirm.
dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)
_, _, test_files = split_dataset(dataset_files, 0.7, 0.2)

# NOTE(review): unlike the training/validation config, incomplete
# trajectories are excluded here, so these test metrics are presumably not
# directly comparable to the validation numbers logged during training.
dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0,
    include_incomplete_trajectories=False
)

test_set = SignalizedIntersectionDatasetForNRI(dataset_config)
test_set.load_records(dataset_path, test_files, verbose=True)

model = DynamicNeuralRelationalInference(
    hid_dim=64,
    n_edges=4,
    dgvae=False
)

# Derive the checkpoint path from the configured prefix so the evaluated
# weights always belong to the experiment configured above.
checkpoint_path = os.path.join(
    '..', 'checkpoints', f'{experiment_config.checkpoint_prefix}_best.pt'
)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on any host;
# load_state_dict then copies the values into the model's own parameters.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['params'])
# Evaluation mode for inference; a no-op if generate_result already sets it.
model.eval()

test_stats = generate_result(model, test_set, experiment_config)

# Report the results. The trailing ';' stops Jupyter from also echoing
# report()'s return value, which likely caused the report to appear twice in
# the saved output.
test_stats.report();

load_records: 100%|██████████| 3/3 [00:36<00:00, 12.30s/it]
[test] generating result: 100%|██████████| 34/34 [00:06<00:00,  5.55it/s]

=== Overall Result ===
Signal Prediction Accuracy: 0.0
Overall Final Displacement Error: 2.4803903102874756
Overall Average Displacement Error: 1.0032001733779907
=== Displacement Error of Class `pedestrian` ===
FDE: 5.000961668067253
ADE: 2.6214960854557843
=== Displacement Error of Class `car` ===
FDE: 2.4817559997956034
ADE: 0.704604554504484
=== Displacement Error of Class `truck` ===
FDE: 21.160104630454892
ADE: 5.962149904620263
=== Displacement Error of Class `bus` ===
FDE: 0.5421742046517986
ADE: 0.24542376993049264
=== Displacement Error of Class `motorcycle` ===
FDE: 2.5717481467376886
ADE: 1.0075953586331956
=== Displacement Error of Class `tricycle` ===
FDE: 1.2659166227550704
ADE: 0.5014884084203437
=== Displacement Error of Class `bicycle` ===
FDE: 3.7331809061536547
ADE: 1.558264241219535
=== Overall Result ===
Signal Prediction Accuracy: 0.0
Overall Final Displacement Error: 2.4803903102874756
Overall Average Displacement Error: 1.0032001733779907
=== Displacement Error


