In [None]:
from asim.training.models.sim_agent.smart.smart import SMART
from asim.training.models.sim_agent.smart.smart_config import SMARTConfig

from pathlib import Path
from asim.dataset.scene.arrow_scene import ArrowScene
from asim.common.visualization.matplotlib.plots import plot_scene_at_iteration

from asim.dataset.scene.scene_builder import ArrowSceneBuilder
from asim.dataset.scene.scene_filter import SceneFilter

from nuplan.planning.utils.multithreading.worker_sequential import Sequential

In [None]:
import os

# Root of the converted Arrow dataset. Defaults to the original author's
# workspace; set ASIM_DATA_ROOT to point at your own copy without editing code.
DATA_ROOT = os.environ.get("ASIM_DATA_ROOT", "/home/daniel/asim_workspace/data")
# Index of the scene (within the filtered split) inspected by the rest of the notebook.
SCENE_INDEX = 100

# Forwarded to SceneFilter unchanged; None presumably means "no explicit log
# selection" — confirm against SceneFilter.
log_names = None
split = "nuplan_mini_val"
scene_filter = SceneFilter(
    split_names=[split],
    log_names=log_names,
    timestamp_threshold_s=8.0,
    duration_s=8.1,
    history_s=1.0,
)
scene_builder = ArrowSceneBuilder(DATA_ROOT)
# Single-process worker; a distributed worker (e.g. RayDistributed) can be swapped in here.
worker = Sequential()
scenes = scene_builder.get_scenes(scene_filter, worker)

# Fail with an actionable message instead of a bare IndexError (or a silent
# wrap-around for a negative index) when the filter yields too few scenes.
if not 0 <= SCENE_INDEX < len(scenes):
    raise IndexError(
        f"SCENE_INDEX={SCENE_INDEX} is out of range: the filter matched {len(scenes)} scene(s)."
    )
scene: ArrowScene = scenes[SCENE_INDEX]

plot_scene_at_iteration(scene, iteration=0)
print(
    f"iterations: {scene.get_number_of_iterations()}, "
    f"history iterations: {scene.get_number_of_history_iterations()}"
)

In [None]:
import os

# SMART, SMARTConfig and Path are already imported in the first cell.

# Checkpoint to evaluate. Defaults to the original author's file; set
# SMART_CHECKPOINT_PATH to load another checkpoint without editing code.
checkpoint_path = Path(os.environ.get("SMART_CHECKPOINT_PATH", "/home/daniel/epoch_027.ckpt"))

# Hyper-parameters passed explicitly to load_from_checkpoint.
# NOTE(review): these presumably have to match the architecture the checkpoint
# was trained with, or weight loading will fail — confirm.
config = SMARTConfig(
    hidden_dim=64,
    num_freq_bands=64,
    num_heads=4,
    head_dim=8,
    dropout=0.1,
    hist_drop_prob=0.1,
    num_map_layers=2,
    num_agent_layers=4,
    pl2pl_radius=10,
    pl2a_radius=20,
    a2a_radius=20,
    time_span=20,
    num_historical_steps=11,
    num_future_steps=80,
)

smart_model = SMART.load_from_checkpoint(checkpoint_path, config=config, map_location="cpu")
# eval() returns the module itself; the trailing ';' keeps the notebook from
# rendering its (very long) repr as the cell output.
smart_model.eval();

In [None]:

import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader

from asim.training.feature_builder.smart_feature_builder import SMARTFeatureBuilder
from asim.training.models.sim_agent.smart.datamodules.target_builder import _numpy_dict_to_torch

# Build the model inputs for the selected scene.
feature_builder = SMARTFeatureBuilder()
features = feature_builder.build_features(scene)
# The return value is ignored, so this relies on the (private) helper converting
# the nested numpy arrays to tensors in place — NOTE(review): confirm.
_numpy_dict_to_torch(features)
torch_features = HeteroData(features)

# Wrap the single sample in a DataLoader so torch_geometric collates it into a
# batched HeteroData, presumably the form test_step expects.
dataset = [torch_features]
loader = DataLoader(dataset, batch_size=1, shuffle=False)
batch = next(iter(loader))

with torch.no_grad():
    pred_traj, pred_z, pred_head = smart_model.test_step(batch, 0)

In [None]:
# Detach and move to host memory before converting. A plain .numpy() raises
# when the tensor lives on a GPU or still requires grad.
array = pred_traj.detach().cpu().numpy()

array.shape

In [None]:
from matplotlib import pyplot as plt

from asim.common.geometry.transform.se2_array import convert_relative_to_absolute_point_2d_array


# The predictions are treated as relative to the ego bounding-box center pose
# at iteration 0; map them back into the scene's absolute frame.
origin = scene.get_ego_state_at_iteration(0).bounding_box.center.state_se2
abs_array = convert_relative_to_absolute_point_2d_array(origin, array)

# The indexing below treats abs_array as (agent, rollout, time, xy...).
num_agents, num_rollouts = abs_array.shape[:2]
for roll_out in range(num_rollouts):
    fig, ax = plot_scene_at_iteration(scene, iteration=0)
    for agent_idx in range(num_agents):
        ax.plot(
            abs_array[agent_idx, roll_out, :, 0],
            abs_array[agent_idx, roll_out, :, 1],
            zorder=15,  # high zorder so trajectories draw above the scene's artists (assumed < 15)
            linewidth=3,
            alpha=0.5,
        )
    ax.set_title(f"SMART rollout {roll_out}")
    plt.show()
    # Release each figure explicitly; otherwise every rollout stays open in
    # non-inline backends and matplotlib warns about too many open figures.
    plt.close(fig)
