In [None]:
import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from asim.common.visualization.color.color import TAB_10
from asim.common.visualization.matplotlib.utils import add_non_repeating_legend_to_ax
from asim.training.models.sim_agent.smart.smart import SMART
from asim.training.models.sim_agent.smart.smart_config import SMARTConfig

In [None]:


# Build the SMART sim-agent with its default hyper-parameters; `config` stays
# bound for later inspection, and the bare name on the last line lets Jupyter
# render the module tree.
smart_model = SMART(config := SMARTConfig())
smart_model

In [None]:
# Load one cached training sample for inspection. The cache directory can be
# overridden via the NUPLAN_TRAINING_CACHE environment variable instead of
# editing this cell.
# NOTE(security): pickle.load can execute arbitrary code — only point this at
# caches you produced yourself.
training_path = Path(os.environ.get("NUPLAN_TRAINING_CACHE", "/home/daniel/nuplan_cache/training"))
# iterdir() yields entries in filesystem-dependent order; sort so that `idx`
# selects the same file on every run / machine.
pickle_paths = sorted(training_path.iterdir())

idx = 1  # which cached sample (in sorted order) to inspect

with open(pickle_paths[idx], "rb") as f:
    data = pickle.load(f)
data


In [None]:
# Print a one-level-deep summary of the loaded sample: tensors/arrays are reduced
# to shape + dtype, any other value is printed as-is.
for key in data.keys():
    value = data[key]
    print(f"{key}:")
    # Explicit check instead of a bare `except:` — the old version also swallowed
    # unrelated errors (e.g. a shadowed builtin `type`) and printed wrong output.
    if not hasattr(value, "items"):
        # Leaf entries (e.g. a plain scenario id) have no sub-keys.
        print(f"    {value}")
        continue
    for part_key, part_data in value.items():
        if isinstance(part_data, (torch.Tensor, np.ndarray)):
            print(f"  {part_key}:")
            print(f"    Tensor: shape: {list(part_data.shape)}, dtype: {part_data.dtype}")
        else:
            print(f"  {part_key}: {type(part_data)} - {part_data}")

In [None]:
"""
map_save:
  traj_pos:
    Tensor: shape: [3013, 3, 2], dtype: torch.float32
  traj_theta:
    Tensor: shape: [3013], dtype: torch.float32
pt_token:
  type:
    Tensor: shape: [3013], dtype: torch.uint8
  pl_type:
    Tensor: shape: [3013], dtype: torch.uint8
  light_type:
    Tensor: shape: [3013], dtype: torch.uint8
  num_nodes: <class 'int'> - 3013
    <class 'int'> - 3013
agent:
  num_nodes: <class 'int'> - 48
    <class 'int'> - 48
  valid_mask:
    Tensor: shape: [48, 91], dtype: torch.bool
  role:
    Tensor: shape: [48, 3], dtype: torch.bool
  id:
    Tensor: shape: [48], dtype: torch.int64
  type:
    Tensor: shape: [48], dtype: torch.uint8
  position:
    Tensor: shape: [48, 91, 3], dtype: torch.float32
  heading:
    Tensor: shape: [48, 91], dtype: torch.float32
  velocity:
    Tensor: shape: [48, 91, 2], dtype: torch.float32
  shape:
    Tensor: shape: [48, 3], dtype: torch.float32
scenario_id:
    5e1ba6c841ae6ccd
"""

In [None]:

# map_save: map geometry, one polyline per row of traj_pos
#   traj_pos:   Tensor [num_tokens, 3, 2], float32 (3 points per polyline in the recorded sample)
#   traj_theta: Tensor [num_tokens], float32

traj_pos = data["map_save"]["traj_pos"]
traj_xy = np.asarray(traj_pos)  # CPU tensor -> plain ndarray for matplotlib

fig, ax = plt.subplots(figsize=(10, 10))
# Passing [num_points, num_tokens] arrays draws one line per column, i.e. one per
# polyline, with the same color cycling as a per-polyline loop but in a single call.
ax.plot(traj_xy[:, :, 0].T, traj_xy[:, :, 1].T)
ax.set_aspect("equal", adjustable="box")  # same scale on x and y so geometry isn't distorted
ax.set(title="map_save/traj_pos", xlabel="x", ylabel="y")
plt.show()


In [None]:
# Distribution of distances between consecutive points of each map polyline.
HIST_BINS = 50

traj_xy = np.asarray(data["map_save"]["traj_pos"])
# [num_tokens, num_points - 1]: length of every segment of every polyline
distance = np.linalg.norm(np.diff(traj_xy, axis=1), axis=-1)

fig, ax = plt.subplots()
# Flatten first: hist() treats a 2-D input as one dataset per column, which would
# draw a separate histogram per segment index instead of one distribution.
ax.hist(distance.ravel(), bins=HIST_BINS)
ax.set(title="map_save/traj_pos: consecutive point distance", xlabel="distance", ylabel="count")
plt.show()

In [None]:
# pt_token: per-map-token categorical attributes, one entry per polyline in
# map_save/traj_pos (3013 of each in the recorded sample):
#   type, pl_type, light_type: Tensor [num_tokens], uint8
#   num_nodes: int


def plot_map_tokens_by_category(traj_pos, categories, name):
    """Draw every map polyline, colored by one integer category per polyline.

    Args:
        traj_pos: polyline points indexed as ``[token, point, xy]`` (CPU tensor
            or array; converted with ``np.asarray``).
        categories: one integer category per polyline, same length as ``traj_pos``.
        name: attribute name, used in legend labels and the title.

    Returns:
        The matplotlib ``Axes`` the polylines were drawn on.
    """
    traj_xy = np.asarray(traj_pos)
    # Plain ints so labels/titles print numbers rather than tensor reprs.
    category_values = np.asarray(categories).astype(np.int64)
    unique_values = np.unique(category_values).tolist()

    _, ax = plt.subplots(figsize=(10, 10))
    # One plot() call per category instead of per polyline: 2-D inputs draw each
    # column as its own line, all sharing this category's color and label.
    # NOTE: lines are therefore layered by category rather than by token index.
    for value in unique_values:
        in_category = category_values == value
        ax.plot(
            traj_xy[in_category, :, 0].T,
            traj_xy[in_category, :, 1].T,
            color=TAB_10[value % len(TAB_10)].hex,
            label=f"{name}: {value}",
        )
    ax.set_title(f"map_save/traj_pos with {name} {set(unique_values)}")
    ax.set_aspect("equal", adjustable="box")
    add_non_repeating_legend_to_ax(ax)  # collapses the repeated per-line labels
    return ax


# No notebook-level `type` variable: rebinding the builtin breaks later/re-run
# cells that call type().
traj_pos = data["map_save"]["traj_pos"]
for attribute in ("type", "pl_type", "light_type"):
    plot_map_tokens_by_category(traj_pos, data["pt_token"][attribute], attribute)
    plt.show()

In [None]:
# agent: per-agent tensors (48 agents; second axis of length 91 presumably timesteps):
#   num_nodes:  int
#   valid_mask: Tensor [num_agents, 91], bool
#   role:       Tensor [num_agents, 3], bool
#   id:         Tensor [num_agents], int64
#   type:       Tensor [num_agents], uint8
#   position:   Tensor [num_agents, 91, 3], float32
#   heading:    Tensor [num_agents, 91], float32
#   velocity:   Tensor [num_agents, 91, 2], float32
#   shape:      Tensor [num_agents, 3], float32

agent_data = data["agent"]
num_nodes = agent_data["num_nodes"]
valid_mask = agent_data["valid_mask"]
# role looks one-hot over 3 flags; argmax collapses it to a flag index
# (an agent with no flag set also maps to 0).
role = agent_data["role"].argmax(axis=-1)
# `agent_id` / `agent_type` rather than `id` / `type`: rebinding the builtins at
# notebook scope breaks any later (or re-run) cell that calls id() or type().
agent_id = agent_data["id"]
agent_type = agent_data["type"]
position = agent_data["position"]
heading = agent_data["heading"]
velocity = agent_data["velocity"]
shape = agent_data["shape"]

# Agent type left out of the plot. TODO: confirm which class `1` is in the type encoding.
SKIPPED_AGENT_TYPE = 1

fig, ax = plt.subplots(figsize=(10, 10))
for i in range(num_nodes):
    if int(agent_type[i]) == SKIPPED_AGENT_TYPE:
        continue
    steps_valid = valid_mask[i]  # only plot the steps where this agent is valid
    ax.plot(
        position[i, steps_valid, 0],
        position[i, steps_valid, 1],
        label=f"type: {int(agent_type[i])}, id: {int(agent_id[i])}, role: {int(role[i])}",
    )

# Labels are kept for interactive inspection; a legend with one entry per agent
# is too cluttered to draw by default.
ax.set_aspect("equal", adjustable="box")
ax.set(title=f"agent/position (valid steps, type {SKIPPED_AGENT_TYPE} skipped)", xlabel="x", ylabel="y")
plt.show()