# Exploring the Impact of Structured Representations in Scenario Generation Based on Large Language Models

## Preparation

In [None]:
import os

from google.colab import userdata


# Colab working directory and the two repository checkouts used by this notebook.
base_dir = "/content"
lctgen_extended_dir = f"{base_dir}/lctgen-extended"
lctgen_dir = f"{base_dir}/lctgen"

# GitHub credentials for cloning the private lctgen-extended repo; the token is
# read from Colab user secrets (google.colab.userdata).
t = userdata.get("GITHUB_TOKEN")
u = "WilliamLiao2015"
r = "lctgen-extended"


os.chdir(base_dir)
# NOTE(review): a token embedded in the clone URL is saved as the "origin" remote
# URL in lctgen-extended/.git/config; consider resetting it after cloning.
!git clone https://{t}@github.com/{u}/{r}.git
# Entering and leaving the clone leaves the cwd at base_dir, but os.chdir raises
# if the clone did not create the directory.
os.chdir(lctgen_extended_dir)
os.chdir(base_dir)

In [None]:
import sys


# Fetch the upstream LCTGen codebase into ./lctgen (the cwd is base_dir after the
# previous cell, so this is lctgen_dir), then make it the cwd and importable.
# NOTE: on re-run the clone fails because the directory is no longer empty;
# a failing `!` command does not stop the cell.
CODE_DIR = "lctgen"
os.makedirs(f"./{CODE_DIR}", exist_ok=True)
!git clone https://github.com/Ariostgx/lctgen.git $CODE_DIR
os.chdir(lctgen_dir)
sys.path.append(lctgen_dir)

In [None]:
# Install the upstream requirements (requirements_colab.txt is resolved relative to
# the lctgen checkout, the cwd after the previous cell) plus the extras this notebook
# imports: google-generativeai (Gemini), ipympl (`%matplotlib widget`), pymongo.
# NOTE: pip's -q/--quiet is additive, so the first line runs at quiet level 2.
%pip install -q -r requirements_colab.txt --quiet
%pip install google-generativeai --quiet
%pip install ipympl --quiet
%pip install pymongo --quiet

In [None]:
# Download the demo Waymo data archive and unpack it under lctgen/data/demo/waymo,
# the directory the Inference section assigns to DATASET.DATA_PATH.
!gdown https://drive.google.com/uc?id=17_TI-q4qkCOt988spWIZCqDLkZpMSptO -O data.zip
!unzip data.zip -d {lctgen_dir}/data/demo/waymo/

In [None]:
# Download the example checkpoint and move it into lctgen-extended/checkpoints.
# NOTE: plain `mkdir` reports an error on re-run when the directory exists (shell
# errors from `!` do not stop the cell); `mv` then overwrites the old checkpoint.
!gdown https://drive.google.com/uc?id=1_s_35QO6OiHHgDxHHAa7Djadm-_I7Usr -O example.ckpt
!mkdir {lctgen_extended_dir}/checkpoints
!mv example.ckpt {lctgen_extended_dir}/checkpoints

In [None]:
# Work from the lctgen-extended checkout from here on; the project packages imported
# by later cells (scripts, llms, metrics, ...) are presumably resolved from this directory.
os.chdir(lctgen_extended_dir)

In [None]:
from google.colab import output

# Let Colab render third-party ipywidgets; presumably needed for the ipympl
# (`%matplotlib widget`) backend used in the Visualize section.
output.enable_custom_widget_manager()

## Setup

In [None]:
from scripts.colab import setup_colab


# Project-specific Colab setup from ``scripts.colab`` (its effects are not visible here).
setup_colab()

## LLM

### OpenAI API

In [None]:
from llms.openai import get_openai_llm, inference_openai_llm


llm_name = "gpt-3.5-turbo" # @param ["gpt-3.5-turbo", "gpt-4"]
llm_model = get_openai_llm(llm_name)


def inference_llm(query, model=llm_model):
    """Send *query* to the OpenAI model selected above and return its reply.

    The model is bound as a default argument when this cell runs. Later
    rebinding of the global ``llm_model`` therefore cannot silently change
    which model this function queries.
    """
    return inference_openai_llm(model, query)

### Google Generative AI

In [None]:
from imports.system import pprint
from imports.packages import genai


# Print every model reported by the Google Generative AI client.
# The loop variable is deliberately not ``llm_name``. That global is recorded with
# each result in the Inference section, and the leaked loop value would otherwise
# replace the selected LLM's name with the last listed model.
for available_model in genai.list_models():
    pprint(available_model)

In [None]:
from llms.google_generativeai import get_google_llm, inference_google_llm


llm_name = "gemini-1.0-pro-latest" # @param ["gemini-1.0-pro-latest", "gemini-1.5-pro-latest"]
# NOTE(review): ``llm_name`` is not passed to ``get_google_llm()``. The form choice
# above may therefore only change the name recorded with results, not the model
# actually queried. Confirm get_google_llm's signature and where it reads the name.
llm_model = get_google_llm()


def inference_llm(query, model=llm_model):
    """Send *query* to the Google Generative AI model built above and return its reply.

    The model is bound as a default argument when this cell runs. Later
    rebinding of the global ``llm_model`` therefore cannot silently change
    which model this function queries.
    """
    return inference_google_llm(model, query)

### Llama 3 via Together AI API

In [None]:
from llms.llama3 import inference_llama3_llm


# Name stored alongside each result in MongoDB by the Inference section.
llm_name = "llama3"

# Use the Together AI-hosted Llama 3 client directly as the notebook-wide query function.
inference_llm = inference_llama3_llm

### Llama 3 Local API

In [None]:
from llms.llama3_local import inference_llama3_llm


# Name stored alongside each result in MongoDB by the Inference section.
llm_name = "llama3 (local)"

# Use the local Llama 3 client (``llms.llama3_local``) as the notebook-wide query function.
inference_llm = inference_llama3_llm

### LLM Inference

In [None]:
import random

from prompts.structured import scenarios, generate_prompt


# Alternative hand-written query (Colab form field, disabled). If re-enabled, also
# disable the generate_prompt line below, since it reassigns ``query``.
# query = 'V1 goes straight and collides with V2 while V2 turns left'  # @param {type:"string"}

# Prompt built from one randomly selected predefined scenario.
query = generate_prompt(random.choice(scenarios))
print("Query:", query, "", sep="\n")

llm_result = inference_llm(query)
print("LLM inference result:", llm_result, sep="\n")

### Predefined LLM Result

In [None]:
# Hand-written LLM output in the "Actor Vector" / "Map Vector" text format,
# usable in place of a live LLM call.
# NOTE(review): the Inference section regenerates ``query`` and ``llm_result`` on
# every iteration, so no cell visible here consumes these values.
query = "Predefined result."
llm_result = """
Actor Vector:
- 'V1': [-1, 0, 0, 6, 4, 3, 3, 3]
- 'V2': [0, 0, 1, 1, 1, 1, 1, 1]
- 'V3': [2, 0, 0, 2, 4, 4, 3, 3]
- 'V4': [1, 0, 1, 2, 4, 4, 3, 3]
- 'V5': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V6': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V7': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V8': [0, 1, 2, 0, 0, 0, 0, 0]
- 'V9': [0, 1, 1, 2, 4, 4, 4, 4]
- 'V10': [0, 1, 1, 2, 4, 4, 4, 4]
- 'V11': [0, 1, 1, 1, 4, 4, 4, 4]
- 'V12': [3, 1, 0, 2, 4, 4, 4, 4]
Map Vector:
- 'Map': [2, 2, 2, 2, 1, 2]
"""

## Batch

### Metrics

In [None]:
from imports.packages import np, plt, animation, HTML
from scripts.visualize import visualize
from utils.check_types import is_number

from metrics.overlapped_area_rate import evaluate_overlapped_area_rate, visualize_overlapped_area_rate
from metrics.road_collision_rate import evaluate_road_collision_rate, visualize_road_collision_rate
from metrics.car_collision_rate import evaluate_car_collision_rate, visualize_car_collision_rate
from metrics.minimum_speed_rate import evaluate_minimum_speed_rate, visualize_minimum_speed_rate


# Metric overlays drawn on every animation frame by ``visualize_all``, in this order.
visualizations = [
    visualize_overlapped_area_rate,
    visualize_road_collision_rate,
    visualize_car_collision_rate,
    visualize_minimum_speed_rate
]
# Metric evaluators run over every frame by ``evaluate_all``. Results are keyed by
# these function objects and serialized under each one's ``__name__``.
evaluations = [
    evaluate_overlapped_area_rate,
    evaluate_road_collision_rate,
    evaluate_car_collision_rate,
    evaluate_minimum_speed_rate
]


def visualize_all(data, agents, t, visualizations=visualizations):
    """Redraw frame ``t``: clear the current axes, draw the scene, then each overlay.

    Overlay ``k`` (0-based, in ``visualizations`` order) is called as
    ``overlay(data, agents, t, 50 - 5 * k)``. The last argument is presumably a
    vertical text position so successive metric labels stack downwards; confirm
    in ``metrics/``.
    """
    axes = plt.gca()
    axes.cla()
    visualize(data, agents, t)

    y_position = 50
    for overlay in visualizations:
        overlay(data, agents, t, y_position)
        y_position -= 5

def evaluate_all(data, agents, t, evaluations=evaluations, num_frames=50):
    """Evaluate every metric on every frame of a scenario and print each mean.

    Args:
        data: Scenario data returned by ``inference``; passed through to each metric.
        agents: Agents returned by ``inference``; passed through to each metric.
        t: Ignored. Kept so existing calls such as ``evaluate_all(data, agents, 0)``
            keep working; metrics are always evaluated over all frames.
        evaluations: Metric callables invoked as ``metric(data, agents, frame)``.
            A metric's docstring, when present, is used as a ``str.format``
            template that receives the mean as ``{result}``.
        num_frames: Number of frames (``0 .. num_frames - 1``) to evaluate.
            The default of 50 matches the frame count used by ``get_anim_html``.

    Returns:
        dict mapping each metric callable to its list of per-frame results.
    """
    results = {}

    for evaluation in evaluations:
        method_name = evaluation.__name__
        print(f"Evaluating method: {method_name}")
        # A separate loop name keeps the ``t`` parameter from being shadowed.
        per_frame = [evaluation(data, agents, frame) for frame in range(num_frames)]
        results[evaluation] = per_frame
        # Metrics without a docstring would otherwise crash on ``None.format``.
        template = evaluation.__doc__ or f"{method_name}: {{result}}"
        print(template.format(result=np.mean(per_frame)))
        print()

    return results


def get_state_str(results, query=None, llm_result=None):
    """Serialize a run to JSON text, keeping numeric arrays on a single line.

    The payload has the keys ``query``, ``llm_result`` and ``results``. The
    ``results`` value maps each metric's ``__name__`` to its per-frame values.
    ``json.dumps(indent=2)`` puts every number on its own line. Here, lines that
    are numbers (ignoring commas) are joined onto the line that opens their
    array, and the array's closing line follows directly after them.

    Args:
        results: Mapping of metric callable -> list of values (as returned by
            ``evaluate_all``).
        query: Prompt sent to the LLM; ``None`` or empty is stored as ``""``.
        llm_result: Raw LLM output text.

    Returns:
        The formatted JSON string, terminated by a newline.
    """
    # Imported locally: this cell never imports ``json``. The module-level name
    # only exists once the Inference cell's ``from imports.system import ... json``
    # has run.
    import json

    lines = json.dumps({
        "query": query if query else "",
        "llm_result": llm_result,
        "results": {evaluation.__name__: values for evaluation, values in results.items()}
    }, indent=2).splitlines()
    # Classify each line once instead of re-checking every neighbour three times.
    numeric = [is_number(line.strip().replace(",", "")) for line in lines]
    last = len(lines) - 1

    parts = []
    for i, line in enumerate(lines):
        prev_numeric = numeric[i - 1] if i > 0 else False
        next_numeric = numeric[i + 1] if i < last else False
        if numeric[i]:
            parts.append(line.strip())  # array element: glue onto the current line
        elif next_numeric:
            parts.append(line.rstrip())  # line that opens a numeric array
        elif prev_numeric:
            parts.append(line.lstrip() + "\n")  # line that closes a numeric array
        else:
            parts.append(line + "\n")

    return "".join(parts)

def get_anim_html(data, agents):
    """Render the scenario as a 50-frame JS/HTML animation and return the markup.

    Draws onto the current matplotlib figure, after resizing it to 5x5 inches at
    100 dpi, by calling ``visualize_all`` once per frame. Frames are 100 ms apart
    and do not repeat.
    """
    figure = plt.gcf()
    figure.set_size_inches(5, 5)
    figure.set_dpi(100)

    def draw_frame(frame):
        return visualize_all(data, agents, frame)

    animated = animation.FuncAnimation(
        figure,
        draw_frame,
        frames=50,
        interval=100,
        repeat=False,
    )
    return animated.to_jshtml()

### Inference

In [None]:
import random
from imports.system import os, sys, json
from configs.paths import lctgen_dir
# Put the upstream lctgen checkout (path re-read from configs.paths) on sys.path;
# presumably required by the lctgen-based imports that follow.
sys.path.append(lctgen_dir)

from imports.packages import tqdm
from configs.demo import cfg, model, map_vecs, map_ids
from imports.trafficgen import *
from scripts.colab import in_colab
from scripts.inference import inference
from prompts.structured import scenarios, generate_prompt

from IPython.display import clear_output
from pymongo import MongoClient


def get_database(uri: str, database_name: str) -> "pymongo.database.Database":
    """Connect to MongoDB at *uri* and return the database named *database_name*.

    The return value is ``client[database_name]``, a pymongo ``Database`` rather
    than the ``MongoClient``. Callers index it by collection name, e.g.
    ``database["batch-test"]``. The client remains reachable as
    ``database.client`` and is not closed here.
    """
    # The annotation is a string because pymongo.database is not imported in this cell.
    client = MongoClient(uri)
    return client[database_name]


if not in_colab:
    # NOTE(review): if ``in_colab`` is a function rather than a bool, ``not in_colab``
    # is always False and .env is never loaded. Confirm in scripts.colab.
    from dotenv import load_dotenv
    load_dotenv()

cfg.merge_from_list(["DATASET.DATA_LIST.ROOT", f"{lctgen_dir}/data/list"])
cfg.merge_from_list(["DATASET.DATA_PATH", f"{lctgen_dir}/data/demo/waymo"])


# Number of random scenarios to generate, evaluate and store.
n = 1

user = "williamliao2015"
password = os.getenv("MONGODB_PASSWORD")
if password is None:
    # Fail early instead of connecting with the literal string "None" as the password.
    raise RuntimeError("MONGODB_PASSWORD is not set; cannot connect to MongoDB.")

# Credentials embedded in a MongoDB URI must be percent-escaped. Characters such as
# '@', ':', '/' or '%' in the password would otherwise corrupt the URI.
from urllib.parse import quote_plus

database = get_database(
    f"mongodb+srv://{quote_plus(user)}:{quote_plus(password)}"
    "@cluster-1.8eayemu.mongodb.net/?retryWrites=true&w=majority&appName=Cluster-1",
    "lctgen-extended",
)


for i in tqdm(range(n)):
    clear_output()

    query = generate_prompt(random.choice(scenarios))
    # Remove every "Note: " marker from the reply, presumably so the parser
    # in ``inference`` is not confused by it. Confirm.
    llm_result = inference_llm(query).replace("Note: ", "")
    print("Query:")
    print(query)
    print()
    print("LLM inference result:")
    print(llm_result)

    data, agents = inference(model, cfg, map_vecs, map_ids, llm_result)
    results = evaluate_all(data, agents, 0)

    # The last iteration's values are what the Save section writes, unless the
    # Visualize section recomputes them first.
    state_str = get_state_str(results, query=query, llm_result=llm_result)
    anim_html = get_anim_html(data, agents)

    database["batch-test"].insert_one({
        "query": query if query else "",
        "llm_name": llm_name,
        "llm_result": llm_result,
        "results": {evaluation.__name__: values for evaluation, values in results.items()}
    })

### Visualize

In [None]:
%matplotlib widget

# Re-evaluate and re-render the scenario left in ``data``/``agents`` by the last
# iteration of the Inference loop. This replaces its ``state_str``/``anim_html``.
state_str = get_state_str(evaluate_all(data, agents, 0), query=query, llm_result=llm_result)

anim_html = get_anim_html(data, agents)
display(HTML(anim_html))

### Save

In [None]:
from configs.paths import base_dir
from imports.system import datetime, os


# Timestamp used as the file stem. It has one-second resolution, so two saves
# within the same second overwrite each other.
current_time = datetime.now().strftime("%Y-%m-%d %H-%M-%S")
save_to = f"{base_dir}/lctgen_records"

# makedirs(exist_ok=True) avoids the exists()/mkdir() check-then-create race and
# also creates any missing parent directories.
os.makedirs(save_to, exist_ok=True)

# Query, LLM output and per-frame metrics (text produced by get_state_str).
state_dir = f"{save_to}/{current_time}.json"
print(f"Saving state to \"{state_dir}\"")
with open(state_dir, "w", encoding="utf-8") as fp:
    fp.write(state_str)

# JS/HTML animation markup (produced by get_anim_html).
html_dir = f"{save_to}/{current_time}.html"
print(f"Saving HTML to \"{html_dir}\"")
with open(html_dir, "w", encoding="utf-8") as fp:
    fp.write(anim_html)

In [None]:
# Archive every saved record (directory paths stripped by -j) into
# <save_to>.zip, next to the records folder.
!zip -j {save_to}.zip {save_to}/*