In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from prettytable import PrettyTable

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset

from l5kit.dataset import EgoDatasetVectorized
from l5kit.vectorization.vectorizer_builder import build_vectorizer

from l5kit.simulation.dataset import SimulationConfig
from l5kit.simulation.unroll import ClosedLoopSimulator
from l5kit.cle.closed_loop_evaluator import ClosedLoopEvaluator, EvaluationPlan
from l5kit.cle.metrics import (CollisionFrontMetric, CollisionRearMetric, CollisionSideMetric,
                               DisplacementErrorL2Metric, DistanceToRefTrajectoryMetric)
from l5kit.cle.validators import RangeValidator, ValidationCountingAggregator
from bokeh.models import Button

from l5kit.visualization.visualizer.zarr_utils import simulation_out_to_visualizer_scene
from l5kit.visualization.visualizer.visualizer import visualize, visualize2, visualize3, visualize4
from bokeh.io import output_notebook, show
from l5kit.data import MapAPI

from collections import defaultdict
import os
from stable_baselines3 import SAC
# Set the environment variable that tells l5kit where the dataset lives; this
# must happen before the data manager is created.
os.environ["L5KIT_DATA_FOLDER"] = "/home/pronton/rl/l5kit_dataset/"
dm = LocalDataManager(None)
# Load the experiment configuration.
cfg = load_config_data("/home/pronton/rl/l5kit/examples/urban_driver/config.yaml")
model_path = "/home/pronton/rl/l5kit/examples/urban_driver/BPTT.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location remaps the checkpoint's tensors onto `device` while loading, so
# a checkpoint written from a CUDA process can still be opened when this
# machine falls back to the CPU. Without it torch.load tries to restore the
# tensors on the device they were saved from and fails on CPU-only hosts.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model = torch.load(model_path, map_location=device).to(device)
# Alternative policy (disabled): a Stable-Baselines3 SAC checkpoint.
# model = SAC.load("/home/pronton/rl/l5kit/examples/RL/gg colabs/logs/SAC_640000_steps.zip")
# Inference only: switch the model to eval mode and disable autograd globally.
model = model.eval()
torch.set_grad_enabled(False)
# ===== INIT DATASET
# Settings for the evaluation data come from the "val_data_loader" config section.
eval_cfg = cfg["val_data_loader"]
# Resolve the zarr named by the config key through the data manager and open it.
eval_zarr = ChunkedDataset(dm.require(eval_cfg["key"])).open()
vectorizer = build_vectorizer(cfg, dm)
eval_dataset = EgoDatasetVectorized(cfg, eval_zarr, vectorizer)
print(eval_dataset)
# Simulation size knobs used further down: how many scenes to request for
# unrolling, and the step count handed to SimulationConfig.
num_scenes_to_unroll = 2
num_simulation_steps = 10
# ==== DEFINE CLOSED-LOOP SIMULATION
# model_ego=model is paired with use_ego_gt=False, and model_agents=None with
# use_agents_gt=True (presumably: the ego is model-driven while other agents
# follow the recorded data -- confirm against the l5kit documentation).
sim_cfg = SimulationConfig(use_ego_gt=False, use_agents_gt=True, disable_new_agents=True,
                           distance_th_far=500, distance_th_close=50, num_simulation_steps=num_simulation_steps,
                           start_frame_index=0, show_info=True)

sim_loop = ClosedLoopSimulator(sim_cfg, eval_dataset, device, model_ego=model, model_agents=None)
# ==== UNROLL
idx1 = 0
# Pick scene indices spread evenly over the dataset. The stride is clamped to
# 1 because range() raises ValueError for a zero step, which the bare integer
# division produced whenever the dataset held fewer scenes than requested.
# NOTE(review): when the scene count is not a multiple of
# num_scenes_to_unroll this can still select more than num_scenes_to_unroll
# scenes (behaviour unchanged from before).
num_scenes_available = len(eval_zarr.scenes)
scene_stride = max(1, num_scenes_available // num_scenes_to_unroll)
scenes_to_unroll = list(range(0, num_scenes_available, scene_stride))
sim_outs = sim_loop.unroll(scenes_to_unroll)
mapAPI = MapAPI.from_cfg(dm, cfg)
from bokeh.layouts import column, LayoutDOM, row, gridplot

# Earlier variant kept for reference (disabled): one visualize4 result per
# scene, stacked vertically in a single column.
# fs = []
# for sim_out in sim_outs[:2]: # for each scene
#     vis_in = simulation_out_to_visualizer_scene(sim_out, mapAPI)
#     fs.append(visualize4(sim_out.scene_id, vis_in))
#     # show(f)
# show(column(fs))
# Per-scene layouts, filled by the loop later in this cell.
cols = []
# Last choice written by the button callbacks ('' until one of them runs).
a = ''
# Send Bokeh output to the notebook so show() renders inline.
output_notebook()
def left_button_callback():
    """Record the "Left" choice.

    Stores 'left' in the module-level variable ``a`` and prints "Left".
    """
    global a
    choice, message = 'left', "Left"
    a = choice
    print(message)

def right_button_callback():
    """Record the "Right" choice.

    Stores 'right' in the module-level variable ``a`` and prints "Right".
    """
    global a
    choice, message = 'right', "Right"
    a = choice
    print(message)

def cannot_tell_button_callback():
    """Record the "Cannot tell" choice.

    Stores 'cannottell' in the module-level variable ``a`` and prints
    "Cannot tell".
    """
    global a
    choice, message = 'cannottell', "Cannot tell"
    a = choice
    print(message)

def same_button_callback():
    """Record the "Same" choice.

    Stores 'same' in the module-level variable ``a`` and prints "Same".
    """
    global a
    choice, message = 'same', "Same"
    a = choice
    print(message)


# Define the buttons
# These four preference buttons are only referenced by commented-out layout
# lines later in this cell, so they are currently not part of the shown layout.
left_button = Button(label="Left", button_type="success")
right_button = Button(label="Right", button_type="success")
cannot_tell_button = Button(label="Cannot tell", button_type="warning")
same_button = Button(label="Same", button_type="danger")

# Attach the callbacks to the buttons
# NOTE(review): Python handlers registered with on_click normally need a
# running Bokeh server to fire; when plain models are rendered through
# output_notebook()/show() they may never be invoked -- confirm.
left_button.on_click(left_button_callback)
right_button.on_click(right_button_callback)
cannot_tell_button.on_click(cannot_tell_button_callback)
same_button.on_click(same_button_callback)

# Demo button wired to a browser-side (JavaScript) callback instead.
button = Button(label="Click me")
from bokeh.models import CustomJS

# define the JavaScript callback function
callback = CustomJS(code="alert('Button clicked!');")

# add the callback to the button
button.js_on_click(callback)

# Build one column per simulated scene: the `buttons` object returned by
# visualize3 stacked on top of the `f` object it returns for the same scene.
for i,sim_out in enumerate(sim_outs): # for each scene
    vis_in = simulation_out_to_visualizer_scene(sim_out, mapAPI)
    f, buttons = visualize3(sim_out.scene_id, vis_in)
    cols.append(column(buttons, f)) # [column(buttons, f), column(buttons, f)]
# cols.append(row(left_button, column(cannot_tell_button, same_button), right_button))
# grid = gridplot(cols, ncols=2, plot_width=250, plot_height=250)
# Place the per-scene columns side by side.
demo = row(cols)
# pref = row(left_button, column(cannot_tell_button, same_button), right_button, button)
# Only the JS-alert demo button is placed under the scenes for now.
pref = button
# NOTE: `f` is rebound here, from the last scene's visualize3 result to the
# full page layout that gets shown.
f = column(demo,pref )

show(f)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from prettytable import PrettyTable

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset

from l5kit.dataset import EgoDatasetVectorized
from l5kit.vectorization.vectorizer_builder import build_vectorizer

from l5kit.simulation.dataset import SimulationConfig
from l5kit.simulation.unroll import ClosedLoopSimulator
from l5kit.cle.closed_loop_evaluator import ClosedLoopEvaluator, EvaluationPlan
from l5kit.cle.metrics import (CollisionFrontMetric, CollisionRearMetric, CollisionSideMetric,
                               DisplacementErrorL2Metric, DistanceToRefTrajectoryMetric)
from l5kit.cle.validators import RangeValidator, ValidationCountingAggregator
from bokeh.models import Button

from l5kit.visualization.visualizer.zarr_utils import simulation_out_to_visualizer_scene
from l5kit.visualization.visualizer.visualizer import visualize, visualize2, visualize3, visualize4
from bokeh.io import output_notebook, show
from l5kit.data import MapAPI

from collections import defaultdict
import os
from stable_baselines3 import SAC
# Set the environment variable that tells l5kit where the dataset lives; this
# must happen before the data manager is created.
os.environ["L5KIT_DATA_FOLDER"] = "/home/pronton/rl/l5kit_dataset/"
dm = LocalDataManager(None)
# Load the experiment configuration.
cfg = load_config_data("/home/pronton/rl/l5kit/examples/urban_driver/config.yaml")
model_path = "/home/pronton/rl/l5kit/examples/urban_driver/BPTT.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location remaps the checkpoint's tensors onto `device` while loading, so
# a checkpoint written from a CUDA process can still be opened when this
# machine falls back to the CPU. Without it torch.load tries to restore the
# tensors on the device they were saved from and fails on CPU-only hosts.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model = torch.load(model_path, map_location=device).to(device)
# Alternative policy (disabled): a Stable-Baselines3 SAC checkpoint.
# model = SAC.load("/home/pronton/rl/l5kit/examples/RL/gg colabs/logs/SAC_640000_steps.zip")
# Inference only: switch the model to eval mode and disable autograd globally.
model = model.eval()
torch.set_grad_enabled(False)
# ===== INIT DATASET
# Settings for the evaluation data come from the "val_data_loader" config section.
eval_cfg = cfg["val_data_loader"]
# Resolve the zarr named by the config key through the data manager and open it.
eval_zarr = ChunkedDataset(dm.require(eval_cfg["key"])).open()
vectorizer = build_vectorizer(cfg, dm)
eval_dataset = EgoDatasetVectorized(cfg, eval_zarr, vectorizer)
print(eval_dataset)
# Simulation size knobs used further down: how many scenes to request for
# unrolling, and the step count handed to SimulationConfig.
num_scenes_to_unroll = 2
num_simulation_steps = 4
# ==== DEFINE CLOSED-LOOP SIMULATION
# model_ego=model is paired with use_ego_gt=False, and model_agents=None with
# use_agents_gt=True (presumably: the ego is model-driven while other agents
# follow the recorded data -- confirm against the l5kit documentation).
sim_cfg = SimulationConfig(use_ego_gt=False, use_agents_gt=True, disable_new_agents=True,
                           distance_th_far=500, distance_th_close=50, num_simulation_steps=num_simulation_steps,
                           start_frame_index=0, show_info=True)

sim_loop = ClosedLoopSimulator(sim_cfg, eval_dataset, device, model_ego=model, model_agents=None)
# ==== UNROLL
idx1 = 0
# Pick scene indices spread evenly over the dataset. The stride is clamped to
# 1 because range() raises ValueError for a zero step, which the bare integer
# division produced whenever the dataset held fewer scenes than requested.
# NOTE(review): when the scene count is not a multiple of
# num_scenes_to_unroll this can still select more than num_scenes_to_unroll
# scenes (behaviour unchanged from before).
num_scenes_available = len(eval_zarr.scenes)
scene_stride = max(1, num_scenes_available // num_scenes_to_unroll)
scenes_to_unroll = list(range(0, num_scenes_available, scene_stride))
sim_outs = sim_loop.unroll(scenes_to_unroll)
mapAPI = MapAPI.from_cfg(dm, cfg)
from bokeh.layouts import column, LayoutDOM, row, gridplot

# Earlier variant kept for reference (disabled): one visualize4 result per
# scene, stacked vertically in a single column.
# fs = []
# for sim_out in sim_outs[:2]: # for each scene
#     vis_in = simulation_out_to_visualizer_scene(sim_out, mapAPI)
#     fs.append(visualize4(sim_out.scene_id, vis_in))
#     # show(f)
# show(column(fs))
# Per-scene layouts, filled by the loop later in this cell.
cols = []
# Last choice written by the button callbacks ('' until one of them runs).
a = ''
# Send Bokeh output to the notebook so show() renders inline.
output_notebook()
def left_button_callback():
    """Record the "Left" choice.

    Stores 'left' in the module-level variable ``a`` and prints "Left".
    """
    global a
    choice, message = 'left', "Left"
    a = choice
    print(message)

def right_button_callback():
    """Record the "Right" choice.

    Stores 'right' in the module-level variable ``a`` and prints "Right".
    """
    global a
    choice, message = 'right', "Right"
    a = choice
    print(message)

def cannot_tell_button_callback():
    """Record the "Cannot tell" choice.

    Stores 'cannottell' in the module-level variable ``a`` and prints
    "Cannot tell".
    """
    global a
    choice, message = 'cannottell', "Cannot tell"
    a = choice
    print(message)

def same_button_callback():
    """Record the "Same" choice.

    Stores 'same' in the module-level variable ``a`` and prints "Same".
    """
    global a
    choice, message = 'same', "Same"
    a = choice
    print(message)


# Define the buttons
# These four preference buttons are only referenced by commented-out layout
# lines later in this cell, so they are currently not part of any layout.
left_button = Button(label="Left", button_type="success")
right_button = Button(label="Right", button_type="success")
cannot_tell_button = Button(label="Cannot tell", button_type="warning")
same_button = Button(label="Same", button_type="danger")

# Attach the callbacks to the buttons
# NOTE(review): Python handlers registered with on_click normally need a
# running Bokeh server to fire; when plain models are rendered through
# output_notebook()/show() they may never be invoked -- confirm.
left_button.on_click(left_button_callback)
right_button.on_click(right_button_callback)
cannot_tell_button.on_click(cannot_tell_button_callback)
same_button.on_click(same_button_callback)

# Demo button wired to a browser-side (JavaScript) callback instead.
button = Button(label="Click me")
from bokeh.models import CustomJS

# define the JavaScript callback function
callback = CustomJS(code="alert('Button clicked!');")

# add the callback to the button
button.js_on_click(callback)

# Build one layout entry per simulated scene and place them side by side.
for sim_out in sim_outs: # for each scene
    vis_in = simulation_out_to_visualizer_scene(sim_out, mapAPI)
    # Keep the layout produced for this scene. Previously the visualize4
    # result was discarded and the loop appended column(buttons, f), using
    # names this cell never assigns (stale objects from an earlier cell at
    # best, NameError on a fresh kernel otherwise).
    # NOTE(review): assumes visualize4 returns a single Bokeh layoutable, as
    # the commented-out `fs.append(visualize4(...))` snippet in this cell
    # implies -- confirm against the visualizer module.
    scene_layout = visualize4(sim_out.scene_id, vis_in)
    cols.append(scene_layout)
# cols.append(row(left_button, column(cannot_tell_button, same_button), right_button))
# grid = gridplot(cols, ncols=2, plot_width=250, plot_height=250)
demo = row(cols)
# pref = row(left_button, column(cannot_tell_button, same_button), right_button, button)
# Only the JS-alert demo button is placed under the scenes for now.
pref = button
# Full page layout: the scene row with the demo button underneath.
f = column(demo,pref )
