In [None]:
%pip install "sentence-transformers>=2.2.0,<4.0.0" --upgrade


In [2]:
import os
import sqlite3
from pathlib import Path
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
from tqdm.auto import tqdm

NUPLAN_DATA_ROOT = Path(os.environ["NUPLAN_DATA_ROOT"])
NUPLAN_EXP_ROOT = Path(os.environ["NUPLAN_EXP_ROOT"])

DB_DIR = NUPLAN_DATA_ROOT / "nuplan-v1.1" / "splits" / "mini"
print("DB dir:", DB_DIR)

db_files = sorted(DB_DIR.glob("*.db"))
print(f"Found {len(db_files)} DB files")
db_file = db_files[0]
print("Using DB:", db_file)

INDEX_DIR = NUPLAN_EXP_ROOT / "rag_index"
INDEX_DIR.mkdir(parents=True, exist_ok=True)
INDEX_PATH = INDEX_DIR / "faiss_index.bin"
METADATA_PATH = INDEX_DIR / "metadata.parquet"
INDEX_DIR

DB dir: /home/kshah26_/nuplan/dataset/nuplan-v1.1/splits/mini
Found 64 DB files
Using DB: /home/kshah26_/nuplan/dataset/nuplan-v1.1/splits/mini/2021.05.12.22.00.38_veh-35_01008_01518.db


PosixPath('/home/kshah26_/nuplan/exp/rag_index')

In [3]:
conn = sqlite3.connect(str(db_file))
query = """
SELECT 
    token               AS scenario_id,
    type                AS scenario_type,
    lidar_pc_token      AS lidar_pc_token
FROM scenario_tag
"""
df = pd.read_sql_query(query, conn)
conn.close()
print("Num scenarios:", len(df))
df.head()

Num scenarios: 13812


Unnamed: 0,scenario_id,scenario_type,lidar_pc_token
0,"b'\x18,\nF\x81\x1d[='",low_magnitude_speed,"b'&*Kz\x1f,Z\xb6'"
1,b'\x06]\x86_\xf3C^X',low_magnitude_speed,b'\\YL\x7fU>_\xb9'
2,b'\xc8gu\xa2uKPv',low_magnitude_speed,b'\x92\xb7\xf6\x9d\x95DXc'
3,b'{\xa9\x16;\x95yW\xbf',low_magnitude_speed,b'\x90\x0f\x14\xe9\xcf\xfdY\xdf'
4,b'g\xc5pY\xcdXTs',low_magnitude_speed,b'.y\x9d\xee\x12\xda]<'


In [4]:
def build_text(row):
    return (
        f"Scenario type: {row['scenario_type']} | "
        f"Lidar token: {row['lidar_pc_token']} | "
        f"Scenario id: {row['scenario_id']}"
    )
df["text"] = df.apply(build_text, axis=1)
df[["scenario_id", "scenario_type", "text"]].head()

Unnamed: 0,scenario_id,scenario_type,text
0,"b'\x18,\nF\x81\x1d[='",low_magnitude_speed,Scenario type: low_magnitude_speed | Lidar tok...
1,b'\x06]\x86_\xf3C^X',low_magnitude_speed,Scenario type: low_magnitude_speed | Lidar tok...
2,b'\xc8gu\xa2uKPv',low_magnitude_speed,Scenario type: low_magnitude_speed | Lidar tok...
3,b'{\xa9\x16;\x95yW\xbf',low_magnitude_speed,Scenario type: low_magnitude_speed | Lidar tok...
4,b'g\xc5pY\xcdXTs',low_magnitude_speed,Scenario type: low_magnitude_speed | Lidar tok...


In [5]:
model = SentenceTransformer("all-MiniLM-L6-v2")
texts = df["text"].tolist()
embeddings = model.encode(
    texts,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
)
embeddings = embeddings.astype("float32")
embeddings.shape

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/216 [00:00<?, ?it/s]

(13812, 384)

In [7]:
d = embeddings.shape[1]
index = faiss.IndexFlatIP(d)
index.add(embeddings)
print("Vectors in index:", index.ntotal)
faiss.write_index(index, str(INDEX_PATH))
metadata = df[["scenario_id", "scenario_type", "lidar_pc_token", "text"]].copy()
metadata.to_parquet(METADATA_PATH, index=False)
INDEX_PATH, METADATA_PATH

Vectors in index: 13812


(PosixPath('/home/kshah26_/nuplan/exp/rag_index/faiss_index.bin'),
 PosixPath('/home/kshah26_/nuplan/exp/rag_index/metadata.parquet'))

In [8]:
index = faiss.read_index(str(INDEX_PATH))
metadata = pd.read_parquet(METADATA_PATH)

def embed_query(query: str):
    q_emb = model.encode(
        [query],
        normalize_embeddings=True,
        convert_to_numpy=True,
    ).astype("float32")
    return q_emb

def search(query: str, k: int = 5):
    q_emb = embed_query(query)
    scores, idxs = index.search(q_emb, k)
    idxs = idxs[0]
    scores = scores[0]
    results = metadata.iloc[idxs].copy()
    results["score"] = scores
    return results

search("hard braking scenario", k=5) # Sanity check

Unnamed: 0,scenario_id,scenario_type,lidar_pc_token,text,score
12945,b'*mpH\x94\xb6V^',stationary_in_traffic,b'\xd5\x881\x87\xca\xbeX\\',Scenario type: stationary_in_traffic | Lidar t...,0.260185
12217,b'\xa9<\xfb\x1b\xa7\xab[l',stationary_in_traffic,b'\x94\x1a\xdc\x13\xb5AU\x0b',Scenario type: stationary_in_traffic | Lidar t...,0.238049
13547,b'\x7f\xa4\x9b<I\xfd^m',stationary_in_traffic,b'\xe1\xc7A\xa2s\xffY\xb1',Scenario type: stationary_in_traffic | Lidar t...,0.23331
2120,b'\xae\xf4]H\x95\xd8X\xc4',near_high_speed_vehicle,b'9\xa2\xb7\xad\xc1\x1b^f',Scenario type: near_high_speed_vehicle | Lidar...,0.232866
2239,b'\xb0\x91!\x0b\xaf\x13Y1',near_high_speed_vehicle,b'q\xd5\x03\xf3I\x00\\6',Scenario type: near_high_speed_vehicle | Lidar...,0.232294
