In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib as plt

In [2]:
import random

from NRI.experiments.config import ExperimentConfig

# Seed every RNG this notebook can reach so that a "Restart & Run All"
# reproduces the same run (weight init, shuffling, sampling).
# NOTE(review): project helpers may use additional RNG sources - confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Central experiment settings shared by the dataset, model and training cells.
# obs_len / pred_len are presumably the number of observed / predicted time
# steps per sample - TODO confirm against ExperimentConfig.
experiment_config = ExperimentConfig(
    obs_len=12,
    pred_len=24,
    encoder_loss_weight=1.0,
    decoder_loss_weight=1.0,
    signal_loss_weight=1.0,
    n_epoch=20,
    batch_size=32,
    checkpoint_prefix='dnri',
    checkpoint_interval=10,
    # Only request the GPU when one is actually present, so the notebook
    # still runs (slowly) on CPU-only machines instead of failing at
    # `model.cuda()`.
    use_cuda=torch.cuda.is_available(),
    mask_traffic_signal=False
)

In [3]:
from SinD.config import get_dataset_path
from SinD.dataset.io import get_dataset_records
from NRI.dataset.utils import split_dataset

# The two ratios handed to `split_dataset`, in the order it receives them.
# Presumably the train and validation fractions, with the remainder going
# to the test split - TODO confirm against `split_dataset`.
SPLIT_FRACTIONS = (0.7, 0.2)

# Locate the dataset on disk and enumerate its recordings.
dataset_path = get_dataset_path()
dataset_files = get_dataset_records(dataset_path)

# Partition the recordings into the three file lists used below.
train_files, valid_files, test_files = split_dataset(dataset_files, *SPLIT_FRACTIONS)

Load the training and validation datasets.

In [4]:
from NRI.dataset import SignalizedIntersectionDatasetForNRI, SignalizedIntersectionDatasetConfig

# One configuration object shared by every split; the observation and
# prediction lengths are taken from the experiment configuration so the
# two stay in sync.
dataset_config = SignalizedIntersectionDatasetConfig(
    obs_len=experiment_config.obs_len,
    pred_len=experiment_config.pred_len,
    stride=15,
    encode_traffic_signals=True,
    padding_value=0.0
)


def _load_split(record_files):
    """Build a dataset from `dataset_config` and load `record_files` into it.

    The records are read from `dataset_path` with progress output enabled.
    Returns the populated dataset object.
    """
    split = SignalizedIntersectionDatasetForNRI(dataset_config)
    split.load_records(dataset_path, record_files, verbose=True)
    return split


# Training split first, then validation (same order as the progress bars).
train_set = _load_split(train_files)
valid_set = _load_split(valid_files)


load_records: 100%|██████████| 16/16 [02:28<00:00,  9.30s/it]
load_records: 100%|██████████| 4/4 [00:42<00:00, 10.63s/it]


In [5]:
from NRI.models import DynamicNeuralRelationalInference
from NRI.experiments.main import train

# Model hyper-parameters collected in one place so they are easy to tune.
model_kwargs = dict(hid_dim=64, n_edges=4, dgvae=False)
model = DynamicNeuralRelationalInference(**model_kwargs)

# Move the parameters to the GPU before the optimizer captures them.
if experiment_config.use_cuda:
    model.cuda()

# Adam with the library's default settings; the scheduler halves the
# learning rate every 5 scheduler steps.
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5,
)

# Run the training loop defined in NRI.experiments.main.
train(
    model,
    train_set=train_set,
    valid_set=valid_set,
    config=experiment_config,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)

  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 0: 100%|██████████| 341/341 [02:40<00:00,  2.12it/s]


Overall Loss: 24.984617171749026
Encoder KL Loss: 0.012981241496547483
Decoder NLL Loss: 23.97270452661596
Signal Cross-Entropy Loss: 0.9989312467686954


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 0: 100%|██████████| 95/95 [00:11<00:00,  8.23it/s]


Signal Prediction Accuracy: 0.8040881156921387
Final Displacement Error: 5.993442038485878
Average Displacement Error: 2.8577335834503175


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 1: 100%|██████████| 341/341 [02:42<00:00,  2.10it/s]


Overall Loss: 13.178670424520094
Encoder KL Loss: 0.04967885112657579
Decoder NLL Loss: 12.400249655994855
Signal Cross-Entropy Loss: 0.7287419193650967


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 1: 100%|██████████| 95/95 [00:11<00:00,  8.21it/s]


Signal Prediction Accuracy: 0.8152310252189636
Final Displacement Error: 5.926944873207494
Average Displacement Error: 2.7077833865818226


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 2: 100%|██████████| 341/341 [02:42<00:00,  2.10it/s]


Overall Loss: 7.871178869627779
Encoder KL Loss: 0.08032908748365562
Decoder NLL Loss: 7.1001665787962445
Signal Cross-Entropy Loss: 0.6906832072042648


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 2: 100%|██████████| 95/95 [00:11<00:00,  8.24it/s]


Signal Prediction Accuracy: 0.8089811205863953
Final Displacement Error: 5.53472799501921
Average Displacement Error: 2.4075523338819806


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 3: 100%|██████████| 341/341 [02:40<00:00,  2.12it/s]


Overall Loss: 3.9379097190071035
Encoder KL Loss: 0.06290524358973128
Decoder NLL Loss: 3.204007385544415
Signal Cross-Entropy Loss: 0.6709970868577702


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 3: 100%|██████████| 95/95 [00:11<00:00,  8.19it/s]


Signal Prediction Accuracy: 0.8243866562843323
Final Displacement Error: 6.0350471346001875
Average Displacement Error: 2.441641720972563


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 4: 100%|██████████| 341/341 [02:41<00:00,  2.11it/s]


Overall Loss: 2.9000645549066597
Encoder KL Loss: 0.07848546436586344
Decoder NLL Loss: 2.160635249727457
Signal Cross-Entropy Loss: 0.66094384427644


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 4: 100%|██████████| 95/95 [00:11<00:00,  8.21it/s]


Signal Prediction Accuracy: 0.8192196488380432
Final Displacement Error: 6.462026927345677
Average Displacement Error: 2.420266295734205


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 5: 100%|██████████| 341/341 [02:40<00:00,  2.12it/s]


Overall Loss: -1.592664258483567
Encoder KL Loss: 0.06891603113488019
Decoder NLL Loss: -2.306043069307142
Signal Cross-Entropy Loss: 0.6444627775474727


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 5: 100%|██████████| 95/95 [00:11<00:00,  8.22it/s]


Signal Prediction Accuracy: 0.8265798091888428
Final Displacement Error: 5.35760047310277
Average Displacement Error: 2.0139231593985305


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 6: 100%|██████████| 341/341 [02:40<00:00,  2.13it/s]


Overall Loss: -1.6200154553061004
Encoder KL Loss: 0.08955437427829087
Decoder NLL Loss: -2.3505976463277483
Signal Cross-Entropy Loss: 0.6410278067910424


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 6: 100%|██████████| 95/95 [00:11<00:00,  8.23it/s]


Signal Prediction Accuracy: 0.8216593861579895
Final Displacement Error: 6.025309876391762
Average Displacement Error: 2.3500540381983708


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 7: 100%|██████████| 341/341 [02:40<00:00,  2.13it/s]


Overall Loss: -3.2018361997919014
Encoder KL Loss: 0.06895935059118131
Decoder NLL Loss: -3.905784693046791
Signal Cross-Entropy Loss: 0.6349891508080046


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 7: 100%|██████████| 95/95 [00:11<00:00,  8.21it/s]


Signal Prediction Accuracy: 0.823879599571228
Final Displacement Error: 7.164597797393799
Average Displacement Error: 2.8406059139653257


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 8: 100%|██████████| 341/341 [02:40<00:00,  2.13it/s]


Overall Loss: -3.3826267079523946
Encoder KL Loss: 0.06698742310060435
Decoder NLL Loss: -4.081637404924612
Signal Cross-Entropy Loss: 0.6320232763842755


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 8: 100%|██████████| 95/95 [00:11<00:00,  8.23it/s]


Signal Prediction Accuracy: 0.8247568607330322
Final Displacement Error: 6.5399296057851695
Average Displacement Error: 2.539578702575282


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 9: 100%|██████████| 341/341 [02:40<00:00,  2.12it/s]


Overall Loss: -3.8111030730334203
Encoder KL Loss: 0.06377414024033504
Decoder NLL Loss: -4.505260211543952
Signal Cross-Entropy Loss: 0.6303829923927614


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[valid] epoch 9: 100%|██████████| 95/95 [00:11<00:00,  8.22it/s]


Signal Prediction Accuracy: 0.8186302185058594
Final Displacement Error: 7.130697024495978
Average Displacement Error: 2.7560066072564378


  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
  return _nested.nested_tensor(
[train] epoch 10:  93%|█████████▎| 316/341 [02:28<00:11,  2.13it/s]


AssertionError: 

In [None]:
import math

import torch.distributions as td

# Sanity check: a bivariate normal parameterized by its full covariance
# matrix and by the closed-form lower Cholesky factor must give the same
# log-density.
# NOTE(review): the numeric values below look like they were copied from a
# single model output while debugging - confirm their origin.
corr = 9.6457e-01
sigma_x = 4.1535e-01
sigma_y = 1.0788e-01

# Compute the off-diagonal term once so both entries are bit-identical and
# the covariance matrix is exactly symmetric.
cross_cov = corr * sigma_x * sigma_y

cov = torch.tensor([
    [sigma_x ** 2, cross_cov],
    [cross_cov, sigma_y ** 2],
])

# Closed-form lower-triangular factor L with L @ L.T == cov. Plain Python
# floats are used so no 0-dim tensor ends up inside the nested list.
chol = torch.tensor([
    [sigma_x, 0.0],
    [corr * sigma_y, math.sqrt(1.0 - corr ** 2) * sigma_y],
])

assert torch.allclose(chol @ chol.T, cov, atol=1e-6), \
    "closed-form Cholesky factor does not reproduce the covariance matrix"

# Shared mean and evaluation point for both parameterizations.
loc = torch.tensor([-1.3988e-02, 1.6305e-02])
point = torch.zeros(2)

log_prob_cov = td.MultivariateNormal(loc=loc, covariance_matrix=cov).log_prob(point)
log_prob_chol = td.MultivariateNormal(loc=loc, scale_tril=chol).log_prob(point)

print(log_prob_cov)
print(log_prob_chol)