# Phase 3: Memory-Augmented RAG on nuPlan
**Goal:** 
- Build a continual memory on top of the static FAISS index from Phase 2
- Use a small LLM (GPT-2) to produce high-level driving decisions
- Log interactions so Phase 4 can evaluate performance

**Team:** Karina Shah, Dhruvina Gujarati, Nilay Kumar, Nishanth Krishna Churchmal

**Course:** CSE 475 - Fall 2025

Phase 3: Memory-Augmented RAG Planner

We start from the FAISS index + metadata built in Phase 2 (static scenario memory).

We create a second FAISS index (index_memory) that stores newly encountered scenarios. It starts empty on the first run and is reloaded from disk on later runs if a saved memory index exists.

For each query (text description of a nuPlan scenario):

1. Encode the query with SentenceTransformer into an embedding.

2. Retrieve nearest neighbors from both the static index and the memory index.

3. Concatenate the results and sort them by similarity score, tagging each row as source=static or source=memory.

4. Format the top-k retrieved scenarios into a prompt and pass it to GPT-2, which generates a 2–3 sentence high-level driving plan.

5. Log this new “experience” into the memory index so future queries can retrieve it.

We show that over multiple queries, the memory index grows and starts to dominate retrieval for similar situations (see source=memory in the top retrieved rows).

In [None]:
import os
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Any

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

print("Libraries imported.")

# Dataset and experiment roots are read from the environment (same variables
# as Phase 2); a missing variable raises KeyError right here.
NUPLAN_DATA_ROOT = Path(os.environ["NUPLAN_DATA_ROOT"])
NUPLAN_EXP_ROOT = Path(os.environ["NUPLAN_EXP_ROOT"])

# Static artifacts produced by Phase 2.
INDEX_DIR = NUPLAN_EXP_ROOT.joinpath("rag_index")
INDEX_PATH = INDEX_DIR.joinpath("faiss_index.bin")
METADATA_PATH = INDEX_DIR.joinpath("metadata.parquet")

# Memory artifacts introduced in this phase; stored alongside the static ones.
MEMORY_INDEX_PATH = INDEX_DIR.joinpath("faiss_index_memory.bin")
MEMORY_METADATA_PATH = INDEX_DIR.joinpath("metadata_memory.parquet")

print(f"INDEX_DIR: {INDEX_DIR}")
print(f"INDEX_PATH: {INDEX_PATH}")
print(f"METADATA_PATH: {METADATA_PATH}")
print(f"MEMORY_INDEX_PATH: {MEMORY_INDEX_PATH}")
print(f"MEMORY_METADATA_PATH: {MEMORY_METADATA_PATH}")


In [None]:
# Bring the Phase 2 artifacts into memory: the FAISS index plus the metadata
# table whose rows are looked up by the positions FAISS returns.
print("Loading static FAISS index and metadata...")

index_static = faiss.read_index(str(INDEX_PATH))
metadata_static = pd.read_parquet(METADATA_PATH, engine="fastparquet")

print(f"Static index size: {index_static.ntotal}")
print(f"Metadata rows: {len(metadata_static)}")

# Queries have to be embedded with the same model that built the Phase 2 index,
# otherwise query vectors and stored vectors would not be comparable.
embedding_model_name = "all-MiniLM-L6-v2"
embed_model = SentenceTransformer(embedding_model_name)

print(f"Embedding model loaded: {embedding_model_name}")

In [None]:
# Initialize or load memory index + metadata

def init_or_load_memory_index(static_index: faiss.Index):
    """Load the persisted memory index and metadata, or create empty ones.

    Args:
        static_index: The Phase 2 index. Only its dimensionality ``d`` is read,
            so that the memory index holds vectors of the same size.

    Returns:
        A ``(index_mem, metadata_mem)`` tuple: the FAISS memory index and the
        DataFrame whose row ``i`` describes vector ``i`` of that index.

    Raises:
        ValueError: If the files on disk do not match the static index
            dimension, or if the index and metadata hold a different number
            of entries (row positions would then point at the wrong metadata).
    """
    d = static_index.d  # embedding dimension

    if MEMORY_INDEX_PATH.exists() and MEMORY_METADATA_PATH.exists():
        print("Loading existing memory index...")
        index_mem = faiss.read_index(str(MEMORY_INDEX_PATH))
        # Read with the same parquet engine that add_memory_entry writes with
        # and that the static metadata is loaded with.
        metadata_mem = pd.read_parquet(MEMORY_METADATA_PATH, engine="fastparquet")

        if index_mem.d != d:
            raise ValueError(
                f"Memory index dimension {index_mem.d} does not match static "
                f"index dimension {d}; delete {MEMORY_INDEX_PATH} and "
                f"{MEMORY_METADATA_PATH} to start a fresh memory."
            )
        if index_mem.ntotal != len(metadata_mem):
            raise ValueError(
                f"Memory index holds {index_mem.ntotal} vectors but metadata has "
                f"{len(metadata_mem)} rows; delete {MEMORY_INDEX_PATH} and "
                f"{MEMORY_METADATA_PATH} to start a fresh memory."
            )
    else:
        print("No memory index found. Creating an empty one...")
        # Inner-product index; the embeddings added to it elsewhere in this
        # file are encoded with normalize_embeddings=True.
        index_mem = faiss.IndexFlatIP(d)
        metadata_mem = pd.DataFrame(
            columns=["scenario_id", "scenario_type", "lidar_pc_token", "text", "source"]
        )

    print("Memory index size:", index_mem.ntotal)
    print("Memory metadata rows:", len(metadata_mem))
    return index_mem, metadata_mem


# Module-level memory state: search_memory reads these two names and
# add_memory_entry updates them (it declares both as `global`).
index_memory, metadata_memory = init_or_load_memory_index(index_static)


In [None]:
def embed_query(query: str) -> np.ndarray:
    """Encode one text query with the module-level ``embed_model``.

    The query is wrapped in a single-item batch and encoded with normalized
    embeddings; the result is cast to float32 before being returned, which is
    the form passed to the FAISS ``search`` calls in this file.
    """
    batch = [query]
    vectors = embed_model.encode(
        batch,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return vectors.astype("float32")


def search_static(query: str, k: int = 5) -> pd.DataFrame:
    """Retrieve the nearest scenarios from the fixed Phase 2 index.

    Args:
        query: Free-text description of the situation to look up.
        k: Maximum number of neighbours to return.

    Returns:
        A copy of the matching ``metadata_static`` rows with two extra columns:
        ``score`` (the FAISS similarity) and ``source`` (always ``"static"``).
        The frame has fewer than ``k`` rows when the index cannot supply that
        many neighbours, and is empty when ``k <= 0`` or the index is empty.
    """
    # Never request more neighbours than the index holds.
    k = min(k, index_static.ntotal)
    if k <= 0:
        results = metadata_static.iloc[0:0].copy()
        results["score"] = np.empty(0, dtype="float32")
        results["source"] = "static"
        return results

    q_emb = embed_query(query)
    scores, idxs = index_static.search(q_emb, k)
    idxs = idxs[0]
    scores = scores[0]

    # FAISS marks missing neighbours with label -1. Passing -1 to iloc would
    # silently select the LAST metadata row, so drop those slots first.
    found = idxs >= 0
    idxs = idxs[found]
    scores = scores[found]

    results = metadata_static.iloc[idxs].copy()
    results["score"] = scores
    results["source"] = "static"
    return results


def search_memory(query: str, k: int = 5) -> pd.DataFrame:
    """Retrieve the nearest entries from the growing memory index.

    Args:
        query: Free-text description of the situation to look up.
        k: Maximum number of neighbours to return.

    Returns:
        A copy of the matching ``metadata_memory`` rows with ``score`` (the
        FAISS similarity) and ``source`` (always ``"memory"``) columns. The
        frame has fewer than ``k`` rows while the memory holds fewer than ``k``
        entries, and is empty when the memory is empty or ``k <= 0``.
    """
    # The memory is usually much smaller than k early on; clamp the request.
    k = min(k, index_memory.ntotal)
    if k <= 0:
        empty = pd.DataFrame(
            columns=["scenario_id", "scenario_type", "lidar_pc_token", "text", "score", "source"]
        )
        # Keep the score column numeric so concatenating this frame cannot turn
        # the combined score column into object dtype.
        return empty.astype({"score": "float32"})

    q_emb = embed_query(query)
    scores, idxs = index_memory.search(q_emb, k)
    idxs = idxs[0]
    scores = scores[0]

    # FAISS marks missing neighbours with label -1. Passing -1 to iloc would
    # silently duplicate the LAST memory row, so drop those slots first.
    found = idxs >= 0
    idxs = idxs[found]
    scores = scores[found]

    results = metadata_memory.iloc[idxs].copy()
    results["score"] = scores
    results["source"] = "memory"
    return results


def search_with_memory(query: str, k_static: int = 5, k_memory: int = 5) -> pd.DataFrame:
    """Combine results from the static and memory indexes, best score first.

    Args:
        query: Free-text description of the situation to look up.
        k_static: Number of neighbours requested from the static index.
        k_memory: Number of neighbours requested from the memory index.

    Returns:
        One DataFrame holding the rows from both searches, sorted by ``score``
        in descending order with a fresh 0..n-1 index. Each row's ``source``
        column tells which index it came from. The frame is empty when neither
        search returned anything.

    Note:
        Each helper embeds the query itself, so the query is encoded twice per
        call.
    """
    static_results = search_static(query, k=k_static)
    memory_results = search_memory(query, k=k_memory)

    # Skip empty frames (e.g. the memory before its first entry): concatenating
    # an empty frame is deprecated in pandas and may affect the score dtype.
    non_empty = [frame for frame in (static_results, memory_results) if not frame.empty]
    if not non_empty:
        return static_results.reset_index(drop=True)

    combined = pd.concat(non_empty, ignore_index=True)

    # Sort by score descending
    combined = combined.sort_values("score", ascending=False).reset_index(drop=True)
    return combined


In [None]:
# Retrieval-only smoke test: no LLM call and no memory write happens here.
search_with_memory("hard braking scenario", k_static=3, k_memory=3).head()


In [None]:
# Small causal language model used to write the plan text.
llm_name = "gpt2"  # could also use "distilgpt2" if you want even smaller

tokenizer = AutoTokenizer.from_pretrained(llm_name)
# GPT-2 has no pad token by default; reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_llm = AutoModelForCausalLM.from_pretrained(llm_name).to(device)
model_llm.config.pad_token_id = tokenizer.eos_token_id

print(f"Loaded LLM: {llm_name} on {device}")


In [None]:
def build_context_block(retrieved: pd.DataFrame, top_m: int = 3) -> str:
    """Format the first ``top_m`` retrieved rows as a numbered context string.

    Args:
        retrieved: Retrieval results; must contain the ``source``,
            ``scenario_type`` and ``text`` columns.
        top_m: Number of leading rows to include.

    Returns:
        One line per row in the form
        ``"<rank>. [source=<source>] type=<scenario_type> | text=<text>"``,
        joined with newlines. An empty frame yields an empty string.
    """
    rows = retrieved.head(top_m)
    # Number by position, not by index label: the frame may carry any index
    # (for example the original metadata labels), and labels need not be ints.
    return "\n".join(
        f"{rank}. [source={source}] type={scenario_type} | text={text}"
        for rank, (source, scenario_type, text) in enumerate(
            zip(rows["source"], rows["scenario_type"], rows["text"]),
            start=1,
        )
    )


def generate_plan(query: str, retrieved: pd.DataFrame, top_m: int = 3) -> str:
    """Use GPT-2 to generate a high-level driving plan.

    The prompt is the planner instructions, the formatted retrieved scenarios,
    the current query, and a trailing ``Plan:`` cue. Generation is sampled
    (``do_sample=True``), so repeated calls can return different text.

    Args:
        query: Free-text description of the current situation.
        retrieved: Retrieval results passed to ``build_context_block``.
        top_m: Number of leading retrieved rows to put into the prompt.

    Returns:
        Only the text generated after the prompt, stripped of surrounding
        whitespace.
    """
    max_prompt_tokens = 512

    context = build_context_block(retrieved, top_m=top_m)

    # The prompt is split in two so that, when it is too long, the retrieved
    # context is shortened instead of the query and the closing "Plan:" cue.
    head = (
        "You are an autonomous driving planner that reasons over past scenarios.\n\n"
        "Retrieved past scenarios:\n"
        f"{context}"
    )
    tail = (
        "\n\n"
        f"Current query / situation:\n{query}\n\n"
        "Based on the retrieved scenarios, describe a safe high-level plan "
        "for the ego vehicle in 2-3 sentences.\n\n"
        "Plan:"
    )

    # Keep the END of the tail (it carries the cue) and give the remaining
    # token budget to the head.
    tail_ids = tokenizer(tail)["input_ids"][-max_prompt_tokens:]
    head_budget = max_prompt_tokens - len(tail_ids)
    head_ids = tokenizer(head)["input_ids"][:head_budget]

    input_ids = torch.tensor([head_ids + tail_ids], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model_llm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=80,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only what was
    # generated after them instead of searching the text for "Plan:", which
    # could also occur inside the query or the retrieved context.
    new_tokens = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


In [None]:
def add_memory_entry(
    query: str,
    plan: str,
    top_retrieved: pd.Series,
):
    """Store one query/plan experience in the memory index and persist it.

    A memory text combining the query, the plan and the anchor scenario's type
    and id is embedded and added to ``index_memory``; a matching row is
    appended to ``metadata_memory``. Both are then written to disk.

    Args:
        query: The query that was answered.
        plan: The plan text generated for that query.
        top_retrieved: The retrieved row used as anchor; must provide
            ``scenario_id``, ``scenario_type`` and ``lidar_pc_token``.

    Raises:
        KeyError: If ``top_retrieved`` lacks one of the required fields. This
            is raised before the index is modified.
    """
    global index_memory, metadata_memory

    memory_text = (
        f"Memory note | query: {query} | "
        f"plan: {plan} | "
        f"base_scenario_type: {top_retrieved['scenario_type']} | "
        f"base_scenario_id: {top_retrieved['scenario_id']}"
    )

    # Build the metadata row BEFORE touching FAISS: if a field is missing, the
    # failure happens while index and metadata still have the same length.
    new_row = pd.DataFrame(
        [
            {
                "scenario_id": str(top_retrieved["scenario_id"]),
                "scenario_type": str(top_retrieved["scenario_type"]),
                "lidar_pc_token": str(top_retrieved["lidar_pc_token"]),
                "text": memory_text,
                "source": "memory",
            }
        ]
    )

    # Reuse the query embedder so stored vectors and query vectors are always
    # produced with the same encoding settings.
    emb = embed_query(memory_text)

    index_memory.add(emb)
    metadata_memory = pd.concat([metadata_memory, new_row], ignore_index=True)

    # Persist to disk
    faiss.write_index(index_memory, str(MEMORY_INDEX_PATH))
    metadata_memory.to_parquet(MEMORY_METADATA_PATH, index=False, engine="fastparquet")

    print("✅ Memory updated. New memory index size:", index_memory.ntotal)


In [None]:
def plan_with_memory(
    query: str,
    k_static: int = 5,
    k_memory: int = 5,
    top_m_for_llm: int = 3,
    update_memory_flag: bool = True,
):
    """
    Run the complete Phase 3 pipeline for one query.

    Steps: retrieve from the static and memory indexes, let GPT-2 write a
    high-level plan from the top ``top_m_for_llm`` rows, and, when
    ``update_memory_flag`` is true, store the experience in memory.

    Returns a dict with the keys "query", "retrieved" and "plan", or None
    (after printing a message) when retrieval produced nothing.
    """
    candidates = search_with_memory(query, k_static=k_static, k_memory=k_memory)
    if candidates.empty:
        print("No retrieved scenarios. Something is wrong with the index.")
        return None

    generated_plan = generate_plan(query, candidates, top_m=top_m_for_llm)
    outcome = {
        "query": query,
        "retrieved": candidates,
        "plan": generated_plan,
    }

    if not update_memory_flag:
        return outcome

    # The best-scoring retrieved row serves as the anchor of the new memory.
    add_memory_entry(query, generated_plan, candidates.iloc[0])
    return outcome


In [None]:
# Demo run: each query goes through the full pipeline and, because
# update_memory_flag=True, adds one entry to the memory index.
example_queries = [
    "The ego vehicle is approaching a red light with a slow lead car ahead.",
    "The ego vehicle is entering a roundabout with multiple vehicles already inside.",
    "A pedestrian starts crossing unexpectedly at a crosswalk while the ego car is braking.",
]

results = []
for q in example_queries:
    print("=" * 80)
    print("QUERY:", q)
    out = plan_with_memory(q, k_static=5, k_memory=5, top_m_for_llm=3, update_memory_flag=True)
    if out is None:
        # plan_with_memory returns None (after printing why) when nothing was
        # retrieved; indexing into it would raise TypeError.
        continue
    results.append(out)
    print("\nGenerated plan:\n", out["plan"])
    print("\nTop retrieved sources:\n", out["retrieved"][["scenario_type", "source", "score"]].head())
