# Exploring the Impact of Structured Representations in Scenario Generation Based on Large Language Models

## Preparation

In [None]:
import os
os.chdir('/content')

# fetch codebase
CODE_DIR = 'lctgen'
os.makedirs(f'./{CODE_DIR}', exist_ok=True)
!git clone https://github.com/Ariostgx/lctgen.git $CODE_DIR
os.chdir('/content/lctgen')

In [None]:
%pip install -q -r requirements_colab.txt --quiet
%pip install google-generativeai --quiet
%pip install ipympl --quiet

In [None]:
!gdown https://drive.google.com/uc?id=17_TI-q4qkCOt988spWIZCqDLkZpMSptO -O data.zip
!unzip data.zip -d /content/lctgen/data/demo/waymo/

In [None]:
!gdown https://drive.google.com/uc?id=1_s_35QO6OiHHgDxHHAa7Djadm-_I7Usr -O example.ckpt
!mkdir /content/lctgen/checkpoints
!mv example.ckpt /content/lctgen/checkpoints

In [None]:
from google.colab import output

output.enable_custom_widget_manager()

## Setup

In [None]:
import itertools
import json
import math
import sys

import torch
import ipywidgets

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import openai
import google.generativeai as genai

from google.colab import userdata
from datetime import datetime
from IPython.display import HTML, Javascript
from matplotlib.patches import Wedge
from pprint import pprint
from shapely.geometry import Point, Polygon, LineString

# LCTGen
from lctgen.config.default import get_config
from lctgen.core.registry import registry
from lctgen.datasets.waymo_open_motion import WaymoOpenMotionDataset
from lctgen.inference.utils import output_formating_cot, map_retrival, get_map_data_batch, load_all_map_vectors
from lctgen.models.utils import visualize_input_seq, visualize_output_seq, transform_traj_output_to_waymo_agent

# TrafficGen
from trafficgen.utils.typedef import *
from trafficgen.utils.data_process.agent_process import WaymoAgent

In [None]:
in_colab = "google.colab" in sys.modules

In [None]:
# Load the demo inference configuration. The path is relative, so it is resolved
# against the current working directory.
cfg_file = 'cfgs/demo_inference.yaml'
cfg = get_config(cfg_file)

# Look up the model class named in the config and restore it from the checkpoint
# path given by cfg.LOAD_CHECKPOINT_PATH.
# NOTE(review): strict=False presumably tolerates key mismatches between the
# checkpoint and the model — confirm against the checkpoint loader.
model_cls = registry.get_model(cfg.MODEL.TYPE)
model = model_cls.load_from_checkpoint(cfg.LOAD_CHECKPOINT_PATH, config=cfg, metrics=[], strict=False)
model.eval()

# Load the map vectors and their ids; both are used later for map retrieval.
map_data_file = 'data/demo/waymo/demo_map_vec.npy'
map_vecs, map_ids = load_all_map_vectors(map_data_file)

## LLM

### OpenAI API

In [None]:
openai.organization = userdata.get("OPENAI_ORGANIZATION")
openai.api_key = userdata.get("OPENAI_API_KEY")

In [None]:
# Model name chosen through the Colab form field declared by the trailing @param comment.
llm_model_name = "gpt-3.5-turbo" # @param ["gpt-3.5-turbo", "gpt-4"]

# Load the LLM configuration and override its LLM.CODEX.MODEL entry with the chosen name.
llm_cfg = get_config("/content/lctgen/lctgen/gpt/cfgs/attr_ind_motion/non_api_cot_attr_20m.yaml")
llm_cfg.merge_from_list(["LLM.CODEX.MODEL", llm_model_name])

# Instantiate the LLM wrapper registered under 'codex' with that configuration.
llm_model = registry.get_llm('codex')(llm_cfg)

In [None]:
query = 'V1 speeds up and crashes V2 from back'  # @param {type:"string"}
llm_result = llm_model.forward(query)

print('LLM inference result:')
print(llm_result)

### Google AI Studio API

In [None]:
genai.configure(api_key=userdata.get("GOOGLE_API_KEY"))

# Print every model the API key can access. The loop variable is deliberately
# not called `llm_model`, so listing models does not overwrite the LLM backend
# object that other cells store under that name.
for available_model in genai.list_models():
    pprint(available_model)

In [None]:
genai.configure(api_key=userdata.get("GOOGLE_API_KEY"))

# Generation parameters handed to the Gemini model.
generation_config = dict(
    temperature=0.9,
    top_p=1,
    top_k=1,
    max_output_tokens=2048,
)

# Every harm category uses the same blocking threshold.
safety_settings = [
    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# From here on `llm_model` refers to the Gemini backend.
llm_model = genai.GenerativeModel(
    model_name="gemini-1.0-pro-latest",
    generation_config=generation_config,
    safety_settings=safety_settings,
)

In [None]:
query = 'V1 goes straight and collides with V2 while V2 turns left'  # @param {type:"string"}


# Assemble one prompt: the system instructions, a separator, then the query
# template. The files are read as UTF-8 explicitly so the result does not depend
# on the platform's default encoding.
prompt = "## INSTRUCTIONS\n"

with open("/content/lctgen/lctgen/gpt/prompts/attr_ind_motion/sys_non_api_cot_attr_20m.prompt", encoding="utf-8") as fp:
    prompt += fp.read()

prompt += "\n---\n## QUERIES\n"

with open("/content/lctgen/lctgen/gpt/prompts/attr_ind_motion/non_api_cot_attr_20m.prompt", encoding="utf-8") as fp:
    prompt += fp.read()

# Substitute the user query for the placeholder in the template.
prompt = prompt.replace("INSERT_QUERY_HERE", query)

print("Prompt:")
print(prompt)
print("\n\n\n")


# NOTE(review): `result.text` may raise when the response contains no usable
# candidate (for example if it was blocked by the safety settings) — confirm
# against the google-generativeai documentation.
result = llm_model.generate_content(prompt)
llm_result = result.text


print("LLM inference result:")
print(llm_result)

### Predefined Result

In [None]:
llm_result = """
Actor Vector:
- 'V1': [-1, 0, 0, 6, 4, 3, 3, 3]
- 'V2': [0, 0, 1, 1, 1, 1, 1, 1]
- 'V3': [2, 0, 0, 2, 4, 4, 3, 3]
- 'V4': [1, 0, 1, 2, 4, 4, 3, 3]
- 'V5': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V6': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V7': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V8': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V9': [0, 1, 1, 2, 4, 4, 4, 4]
- 'V10': [0, 1, 1, 2, 4, 4, 4, 4]
- 'V11': [0, 1, 1, 1, 4, 4, 4, 4]
- 'V12': [3, 1, 0, 2, 4, 4, 4, 4]
Map Vector:
- 'Map': [2, 2, 2, 2, 1, 2]
"""

## Utils

### Type

In [None]:
def is_number(x):
    """Return True if ``float(x)`` succeeds, False otherwise.

    Values that ``float`` rejects with ``ValueError`` (e.g. ``"abc"``, ``""``)
    or with ``TypeError`` (e.g. ``None``, lists) are reported as non-numeric
    instead of raising.
    """
    try:
        float(x)
    except (TypeError, ValueError):
        return False
    return True

### Geometry

In [None]:
def compute_magnitude(x, y):
    """Return the Euclidean length of the 2-D vector ``(x, y)``."""
    squared_length = x ** 2 + y ** 2
    return squared_length ** 0.5

### Display

- V1: red (ego car)
- V2: blue
- V3: orange
- V4: green
- V5: purple
- V6: brown
- V7: pink
- V8: gray
- V9: olive
- V10: cyan
- V11: blue (same as V2)
- V12: orange (same as V3)
- V13: green (same as V4)
- ... (the nine non-ego colors repeat in the same order)

In [None]:
def draw(center, agents, traj=None, other=None, edge=None, heat_map=False):
    """Draw lane segments, road lines and agent boxes on the current axes.

    Args:
        center: 2-D array of lane segments. Columns 0-3 are read as
            ``x0, y0, x1, y1`` and the last column as the traffic state
            (1, 2 and 3 are overlaid in red, yellow and green). Drawing stops
            at the first row whose ``x0`` equals 0.
        agents: sequence of agent objects exposing ``position[0]`` and
            ``get_rect()[0]``. Agents whose first position lies more than 45
            from the origin on either axis are skipped.
        traj: optional array indexed as ``traj[:, i]`` yielding the (x, y)
            points of agent ``i``. Trajectories are drawn for every agent
            except the first one.
        other: optional array of extra line segments (thin lines), same
            column layout and stop rule as ``center``.
        edge: optional array of boundary segments (thicker lines), same
            column layout and stop rule as ``center``.
        heat_map: if truthy, use white lanes and a wider, more opaque
            traffic-state overlay.
    """
    palette = ['tab:red', 'tab:blue', 'tab:orange', 'tab:green', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
    # Traffic state value -> overlay colour, checked in this order.
    signal_palette = ((1, 'red'), (2, 'yellow'), (3, 'green'))

    if heat_map:
        lane_color, signal_alpha, signal_width = 'white', 0.2, 6
    else:
        lane_color, signal_alpha, signal_width = 'black', 0.12, 3

    ax = plt.gca()
    plt.axis('equal')
    ax.axis('off')

    def _plot_segments(segments, **line_style):
        # Draw solid segments until the first row whose x0 is 0.
        for segment in segments:
            sx0, sy0, sx1, sy1 = segment[:4]
            if sx0 == 0:
                break
            ax.plot((sx0, sx1), (sy0, sy1), lane_color, **line_style)

    # Lane centre lines, plus a coloured overlay where a traffic state is set.
    for segment in center:
        signal_state = segment[-1]
        x0, y0, x1, y1 = segment[:4]
        if x0 == 0:
            break
        ax.plot((x0, x1), (y0, y1), '--', color=lane_color, linewidth=1, alpha=0.2)
        for state, signal_color in signal_palette:
            if signal_state == state:
                ax.plot((x0, x1), (y0, y1), color=signal_color, alpha=signal_alpha, linewidth=signal_width, zorder=5000)
                break

    if edge is not None:
        _plot_segments(edge, linewidth=1.5)

    if other is not None:
        _plot_segments(other, linewidth=0.7, alpha=0.9)

    for index, agent in enumerate(agents):
        start = agent.position[0]
        if abs(start[0]) > 45 or abs(start[1]) > 45:
            continue

        is_first = index == 0
        # The first agent always gets palette[0]; the others cycle through palette[1:].
        fill = palette[0] if is_first else palette[(index - 1) % 9 + 1]

        if not is_first and traj is not None:
            path = traj[:, index]
            for (px0, py0), (px1, py1) in zip(path[:-1], path[1:]):
                # Only draw steps whose two endpoints are both inside the plotted window.
                if all(abs(value) < 60 for value in (px0, py0, px1, py1)):
                    ax.plot((px0, px1), (py0, py1), '-', color=fill, linewidth=1.8, marker='.', markersize=3)

        outline = agent.get_rect()[0]
        ax.add_patch(plt.Polygon(outline, edgecolor='black', facecolor=fill, linewidth=0.5, zorder=10000))

    plt.autoscale()
    plt.xlim([-60, 60])
    plt.ylim([-60, 60])

In [None]:
def visualize(data, agents, t = 0):
    """Draw the map and the agents of timestep ``t`` and label the frame.

    Args:
        data: mapping whose "center", "bound" and "rest" entries are batched
            tensors; only the first batch element of each is drawn.
        agents: per-timestep collection of agents; ``agents[t]`` is drawn.
        t: zero-based timestep index (the label shows ``t + 1``).

    Returns:
        The current matplotlib axes, after drawing.
    """
    lane_center, road_bound, remaining = (
        data[key][0].cpu().numpy() for key in ("center", "bound", "rest")
    )

    draw(lane_center, agents[t], other=remaining, edge=road_bound)

    axes = plt.gca()
    axes.text(-60, 60, f"Timestep {t + 1}", fontsize=12, color="black", weight="bold")
    return axes

In [None]:
def copy_text_button(text: str) -> ipywidgets.Widget:
    """Build a "Copy" button that writes ``text`` to the browser clipboard.

    The text is embedded in a small JavaScript snippet as a JSON string
    literal. Each click re-displays that snippet inside a hidden output
    widget, which is what triggers the clipboard write.

    Returns:
        A box widget containing the button and the hidden output.
    """
    copy_button = ipywidgets.Button(description="Copy", icon="copy")
    hidden_output = ipywidgets.Output(layout=ipywidgets.Layout(display="none"))
    clipboard_script = Javascript(f"navigator.clipboard.writeText({json.dumps(text)})")

    def _copy_to_clipboard(_: ipywidgets.Button) -> None:
        # Clear the previous snippet so the script is displayed afresh on every click.
        hidden_output.clear_output()
        hidden_output.append_display_data(clipboard_script)

    copy_button.on_click(_copy_to_clipboard)
    return ipywidgets.Box((copy_button, hidden_output))

## Evaluation

### Inference

In [None]:
%matplotlib widget


# Format the LLM output into the Structured Representation (agent and map vectors).
# Lower bounds applied to each agent's length_width after inference.
MIN_LENGTH = 4.0
MIN_WIDTH = 1.5

MAX_TIMESTEPS = 50
MAX_AGENT_NUM = 32

agent_vector, map_vector = output_formating_cot(llm_result)

# LLM output is not guaranteed to parse into a usable agent list: fail early with
# a clear message instead of an IndexError on an empty list or mis-shaped
# tensors when there are more agents than the padded size.
agent_num = len(agent_vector)
if not 0 < agent_num <= MAX_AGENT_NUM:
    raise ValueError(
        f"Expected between 1 and {MAX_AGENT_NUM} agents in the LLM output, got {agent_num}"
    )

# Pad the agent list up to MAX_AGENT_NUM with rows of -1, one independent row per slot.
vector_dim = len(agent_vector[0])
agent_vector = agent_vector + [[-1] * vector_dim for _ in range(MAX_AGENT_NUM - agent_num)]

# Retrieve the best-matching map from the map dataset.
sorted_idx = map_retrival(map_vector, map_vecs)[:1]
map_id = map_ids[sorted_idx[0]]

# Load map data.
data = get_map_data_batch(map_id, cfg)

# Inference with the LLM-output Structured Representation: replace the text
# features and the agent mask of the batch, keeping their original dtypes and
# adding a leading batch dimension.
data['text'] = torch.tensor(agent_vector, dtype=data['text'].dtype)[None, ...]
data['agent_mask'] = torch.tensor([1]*agent_num + [0]*(MAX_AGENT_NUM - agent_num), dtype=data['agent_mask'].dtype)[None, ...]

model_output = model.forward(data, 'val')['text_decode_output']
output_scene = model.process(model_output, data, num_limit=1, with_attribute=True, pred_ego=True, pred_motion=True)

agents = transform_traj_output_to_waymo_agent(output_scene[0])

# Clamp every agent's length_width into [MIN_LENGTH, MIN_WIDTH] .. [10.0, 5.0] at each timestep.
for t in range(MAX_TIMESTEPS):
    for agent in agents[t]:
        agent.length_width = np.clip(agent.length_width, [MIN_LENGTH, MIN_WIDTH], [10.0, 5.0])

### Visualize

In [None]:
# Overlay renderers applied to every animation frame; visualize_all calls each one as
# method(data, agents, t, <number>), with the number decreasing by 5 per entry.
# NOTE(review): these callables are not defined in the visible part of the notebook —
# confirm they exist in the kernel namespace before this cell runs.
visualizations = [
    visualize_overlapped_area,
    visualize_road_collisions,
    visualize_car_collisions,
    visualize_minimum_speed
]
# Metrics computed per timestep as evaluation(data, agents, t); each docstring is used
# as a format template with a `result` field when the mean is printed.
# NOTE(review): also not defined in the visible part of the notebook — confirm as above.
evaluations = [
    evaluate_overlapped_area,
    evaluate_road_collisions,
    evaluate_car_collisions,
    evaluate_minimum_speed
]

# Maps each evaluation function to the list of its per-timestep results.
results = {}


def visualize_all(data, agents, t, visualizations=visualizations):
    """Redraw the scene for timestep ``t`` and apply every overlay renderer.

    The current axes are cleared first. Each renderer is then called with
    ``data``, ``agents``, ``t`` and a number that starts at 50 for the first
    renderer and decreases by 5 for each following one.
    """
    plt.gca().cla()
    visualize(data, agents, t)
    for rank, overlay in enumerate(visualizations, start=1):
        overlay(data, agents, t, 55 - 5 * rank)


# Run every metric over all timesteps and print its mean through the metric's
# docstring template. MAX_TIMESTEPS is the same horizon the inference cell uses
# when post-processing `agents`.
for evaluation in evaluations:
    method_name = evaluation.__name__
    print(f"Evaluating method: {method_name}")
    results[evaluation] = [evaluation(data, agents, t) for t in range(MAX_TIMESTEPS)]
    print(evaluation.__doc__.format(result=np.mean(results[evaluation])))
    print()


print("Copy state as text:")
lines = json.dumps({
    "query": query if query else "",
    "llm_result": llm_result,
    "results": {evaluation.__name__: values for evaluation, values in results.items()}
}, indent=2).splitlines()

# Classify each pretty-printed line once: True when the line, without surrounding
# whitespace and commas, parses as a number.
numeric = [is_number(line.strip().replace(",", "")) for line in lines]

# Collapse runs of numeric lines (the per-timestep result arrays) onto a single
# line; all other lines keep their indentation and line break. No name here
# shadows the builtin `next`.
pieces = []
for i, line in enumerate(lines):
    follows_number = i > 0 and numeric[i - 1]
    precedes_number = i < len(lines) - 1 and numeric[i + 1]
    if numeric[i]:
        pieces.append(line.strip())
    elif precedes_number:
        # Opening line of a numeric run: keep the indentation, drop the line break.
        pieces.append(line.rstrip())
    elif follows_number:
        # Closing line of a numeric run: drop the indentation, end the line.
        pieces.append(line.lstrip() + "\n")
    else:
        pieces.append(line + "\n")
state_str = "".join(pieces)
display(copy_text_button(state_str))
print()


# Render one frame per timestep on a 5x5 inch, 100 dpi figure and embed the result
# as an interactive JavaScript/HTML animation.
fig = plt.gcf()
fig.set_size_inches(5, 5)
fig.set_dpi(100)

# The frame count follows MAX_TIMESTEPS, the horizon used by the inference cell.
anim = animation.FuncAnimation(fig, lambda t: visualize_all(data, agents, t), frames=MAX_TIMESTEPS, interval=100, repeat=False)
anim_html = anim.to_jshtml()
display(HTML(anim_html))

### Save

In [None]:
# Save the state text and the animation HTML under a shared timestamped file name.
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
save_to = "/content/lctgen_records"

# Create the records directory if needed; exist_ok avoids a separate existence check.
os.makedirs(save_to, exist_ok=True)

state_dir = f"{save_to}/{current_time}.json"
print(f"Saving state to \"{state_dir}\"")
with open(state_dir, "w", encoding="utf-8") as fp:
    fp.write(state_str)

html_dir = f"{save_to}/{current_time}.html"
print(f"Saving HTML to \"{html_dir}\"")
with open(html_dir, "w", encoding="utf-8") as fp:
    fp.write(anim_html)

In [None]:
!zip -j /content/lctgen_records.zip /content/lctgen_records/*