In [1]:
import torch
import imp
import ppuu.costs

from ppuu.data.dataloader import DataStore, Dataset, Normalizer
from ppuu import configs

store = DataStore('/home/us441/nvidia-collab/vlad/traffic-data-5-small/state-action-cost/data_i80_v0/')
ds = Dataset(store, 'train', 20, 30, 40, shift=False, random_actions=False, state_diffs=True)
loader = torch.utils.data.DataLoader(ds, batch_size=2)
normalizer = Normalizer(store.stats)

In [89]:
from typing import Dict

class DisplacementErrorMetric:
    def __init__(self, normalizer: Normalizer, *, diffs: bool):
        self.normalizer = normalizer
        self.diffs = diffs
        assert diffs, "The metric only works for diffs representation for now"

    def _convert_states(self, states: torch.Tensor) -> torch.Tensor:
        unnormalized = self.normalizer.unnormalize_states(states)
        if self.diffs:
            cumulative = torch.cumsum(states[..., :2], dim=-2)
            return torch.cat([cumulative, states[..., 2:]], dim=-1)
        else:
            return unnormalized

    def build_log_dict(
        self,
        # [b_size, n_modes, out_horizon, state_dim]
        pred_states: torch.Tensor,
        # [b_size, n_modes, out_horizon, state_dim]
        target_states: torch.Tensor,
    ) -> Dict[str, float]:
        converted_pred = self._convert_states(pred_states)
        converted_targets = self._convert_states(target_states)
        # [batch_size, n_modes, out_horizon, state_dim]
        repeated_targets = converted_targets.unsqueeze(1).repeat_interleave(
            converted_pred.shape[1], dim=1
        )
        # [batch_size, n_modes, out_horizon, 2]
        pos_difference = converted_pred[..., :2] - repeated_targets[..., :2]
        # [batch_size, n_modes, out_horizon]
        difference_norm = torch.norm(pos_difference, dim=-1, p=2)
        # [batch_size, n_modes]
        final_difference_norm = difference_norm[..., -1]
        # [batch_size, n_modes]
        avg_difference_norm = difference_norm.mean(dim=-1)
        # ([batch_size], [batch_size])
        best_avg, best_avg_indices = torch.min(avg_difference_norm, dim=-1)
        # ([batch_size], [batch_size])
        best_final, best_final_indices = torch.min(
            final_difference_norm, dim=-1
        )
        breakpoint()
        return {
            "ADE": best_avg.mean().item(),
            "FDE": best_final.mean().item(),
            "ADE_idx_std": best_avg_indices.float().std().item(),
            "FDE_idx_std": best_final_indices.float().std().item(),
        }

# Test 1 - pred = target 

In [90]:
it = iter(loader)
batch = next(it)

In [91]:
target_states = batch.target_state_seq.states

In [92]:
target_states.shape

torch.Size([2, 30, 5])

In [93]:
mock_pred = target_states.unsqueeze(1).repeat_interleave(3, dim=1)
mock_pred[:, :3] = -100

In [94]:
metric = DisplacementErrorMetric(normalizer, diffs=True)

In [95]:
metric.build_log_dict(mock_pred, target_states)

> <ipython-input-89-aba23eefa74c>(46)build_log_dict()
-> "ADE": best_avg.mean().item(),
(Pdb) 
(Pdb) c


{'ADE': 2138.9619140625,
 'FDE': 4139.9228515625,
 'ADE_idx_std': 0.0,
 'FDE_idx_std': 0.0}

# Mock data with mock normalizer 

In [96]:
metric = DisplacementErrorMetric(normalizer=Normalizer.dummy(), diffs=True)

In [97]:
target_states = torch.zeros(10, 30, 5)
pred_states = torch.ones(10, 3, 30, 5)

In [98]:
metric.build_log_dict(pred_states, target_states)

> <ipython-input-89-aba23eefa74c>(46)build_log_dict()
-> "ADE": best_avg.mean().item(),
(Pdb) c


{'ADE': 21.920307159423828,
 'FDE': 42.4264030456543,
 'ADE_idx_std': 0.0,
 'FDE_idx_std': 0.0}