# ANC Databricks: Wikipedia + FAISS + Databricks + StateGraph

This notebook combines two Databricks workflows in one place:

- Baseline notebook pipeline: Wikipedia -> chunking -> FAISS -> Databricks-hosted Q&A with `ChatDatabricks`
- Canonical orchestration diagram + run flow via shared `naturalist_companion.stategraph_shared`

Provider mapping used here:
- `ChatOllama` -> `ChatDatabricks`
- `OllamaEmbeddings` -> `DatabricksEmbeddings`
- Same `WikipediaLoader` + `FAISS` retrieval flow


## Setup (Install -> Restart -> Check)

Run this setup section top-to-bottom once. It is intentionally split into small steps so it is obvious which libraries are being installed.

1. Setup Step 1/7: Install core libraries (`backoff`, `databricks-langchain`, `mlflow-skinny[databricks]`, `langchain*`, `langgraph`, `faiss-cpu`, `wikipedia`).
2. Setup Step 2/7: Restart kernel.
3. Setup Step 3/7: Verify core imports.
4. Setup Step 4/7: Install widget libraries (`ipywidgets`, `jupyterlab_widgets`).
5. Setup Step 5/7: Restart kernel.
6. Setup Step 6/7: Verify widget imports.
7. Setup Step 7/7: Run final preflight (dependencies + repo path checks).

### Databricks auth prerequisites

- Create or sign in to your Databricks Free Edition workspace.
- Create a Personal Access Token (PAT) in Databricks.
- For local notebook runs, set:
  - `export DATABRICKS_HOST="https://<your-workspace-host>"`
  - `export DATABRICKS_TOKEN="<your-pat>"`
- Ensure your user/token has `Can Query` permission for the selected model endpoints.
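
For local runs, a quick check like the one below can confirm that both variables are visible to the kernel before any endpoint call is made. This is a minimal sketch: `missing_databricks_env` is a hypothetical helper for illustration, not part of this notebook or of any Databricks library.

```python
import os

# Environment variables named in the prerequisites above.
REQUIRED_VARS = ("DATABRICKS_HOST", "DATABRICKS_TOKEN")


def missing_databricks_env(env=None):
    """Return the required variable names that are unset or blank.

    Hypothetical helper for illustration only.
    """
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not str(env.get(name, "")).strip()]


missing_vars = missing_databricks_env()
if missing_vars:
    print("Missing for local runs: " + ", ".join(missing_vars))
else:
    print("DATABRICKS_HOST and DATABRICKS_TOKEN are set.")
```

The check only reports which names are missing and never prints the token value.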

### Foundation model endpoints

This notebook defaults to Databricks-hosted Foundation Model API endpoints:
- Embeddings: `databricks-bge-large-en`
- Chat: `databricks-meta-llama-3-1-8b-instruct`

If your workspace exposes different endpoint names, override the defaults with these environment variables:
- `DATABRICKS_EMBEDDING_ENDPOINT`
- `DATABRICKS_LLM_ENDPOINT`
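
The overrides are plain environment lookups that fall back to the defaults listed above. The sketch below illustrates that resolution. `resolve_endpoints` is a hypothetical helper, while the variable names and default endpoint names are the ones this notebook uses.

```python
import os

DEFAULT_EMBEDDING_ENDPOINT = "databricks-bge-large-en"
DEFAULT_LLM_ENDPOINT = "databricks-meta-llama-3-1-8b-instruct"


def resolve_endpoints(env=None):
    """Return the endpoint names to use, preferring environment overrides.

    Hypothetical helper for illustration only.
    """
    env = os.environ if env is None else env
    return {
        "embedding": env.get("DATABRICKS_EMBEDDING_ENDPOINT", DEFAULT_EMBEDDING_ENDPOINT),
        "llm": env.get("DATABRICKS_LLM_ENDPOINT", DEFAULT_LLM_ENDPOINT),
    }


print(resolve_endpoints())
```

Because the lookup happens when the config cell runs, set the variables before starting the kernel, or assign them through `os.environ` in a cell that runs before the config cell.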


In [None]:
#######################################################################################################

###### Setup Step 1/7: Install Core Libraries                                                       ######

#######################################################################################################


import importlib.util
import shlex
from pathlib import Path
from IPython import get_ipython


core_packages = [
    "backoff==2.2.1",
    "databricks-langchain",
    "mlflow-skinny[databricks]",
    "langchain==0.3.27",
    "langchain-core==0.3.83",
    "langchain-community==0.3.31",
    "langchain-text-splitters==0.3.11",
    "langgraph==0.5.3",
    "faiss-cpu==1.13.2",
    "wikipedia==1.4.0",
]
core_module_checks = [
    "faiss",
    "langgraph",
    "databricks_langchain",
    "langchain_community",
    "langchain_text_splitters",
    "wikipedia",
]

requirements_file = Path("../requirements-dbrx-dev.txt")
use_requirements_file_for_core = False  # Set True to install from ../requirements-dbrx-dev.txt.
force_reinstall_core = False


def _flatten_requirements(path: Path, seen=None):
    seen = set() if seen is None else seen
    resolved = path.resolve()
    key = str(resolved)
    if key in seen:
        return []
    seen.add(key)

    if not resolved.exists():
        return []

    packages = []
    for raw_line in resolved.read_text().splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("-r "):
            nested = (resolved.parent / line.split(maxsplit=1)[1]).resolve()
            packages.extend(_flatten_requirements(nested, seen))
            continue
        packages.append(line)
    return packages


missing_core_modules = [name for name in core_module_checks if importlib.util.find_spec(name) is None]

if not missing_core_modules and not force_reinstall_core:
    print("[setup step 1/7] Core libraries are already available in this kernel. Skipping install.")
else:
    if missing_core_modules:
        print(f"[setup step 1/7] Missing core modules: {', '.join(missing_core_modules)}")
    if force_reinstall_core:
        print("[setup step 1/7] force_reinstall_core=True, running install anyway.")

    if use_requirements_file_for_core and requirements_file.exists():
        packages_to_show = _flatten_requirements(requirements_file)
        print(f"[setup step 1/7] Installing from requirements file: {requirements_file}")
        # Quote the path so a requirements file location containing spaces still works.
        pip_cmd = f"install -q -U -r {shlex.quote(str(requirements_file))}"
    else:
        packages_to_show = core_packages
        print("[setup step 1/7] Installing from explicit package list.")
        pip_cmd = "install -q -U " + " ".join(shlex.quote(pkg) for pkg in core_packages)

    print("[setup step 1/7] Libraries that will be installed:")
    for pkg in packages_to_show:
        print(f"  - {pkg}")

    ip = get_ipython()
    if ip is None:
        raise RuntimeError("IPython kernel not found; cannot run %pip install.")

    ip.run_line_magic("pip", pip_cmd)
    print("[setup step 1/7] Core install command finished.")


In [None]:
#######################################################################################################

###### Setup Step 2/7: Restart Kernel After Core Install                                            ######

#######################################################################################################


RESTART_NOW = False  # Set True to trigger restart in Databricks.

if RESTART_NOW:
    try:
        dbutils.library.restartPython()
    except NameError as exc:
        raise RuntimeError(
            "Automatic restart is available in Databricks only. Use your IDE/Jupyter restart action."
        ) from exc
    except Exception as exc:
        raise RuntimeError(f"Automatic restart failed: {type(exc).__name__}: {exc}") from exc
else:
    print("[setup step 2/7] Restart the kernel now.")
    print("  - Databricks: set RESTART_NOW=True and rerun this cell (or restart from the UI).")
    print("  - VS Code/Jupyter: use Restart Kernel.")
    print("After restarting, run Setup Step 3/7.")


In [None]:
#######################################################################################################

###### Setup Step 3/7: Verify Core Imports                                                         ######

#######################################################################################################


import importlib


checks = [
    ("FAISS backend", ["faiss"]),
    ("LangGraph runtime", ["langgraph"]),
    ("Databricks integration", ["databricks_langchain"]),
    ("Wikipedia loader import", ["langchain_community.document_loaders", "langchain.document_loaders"]),
    ("Text splitter import", ["langchain_text_splitters", "langchain.text_splitter"]),
    ("Vectorstore import", ["langchain_community.vectorstores"]),
    ("In-memory docstore import", ["langchain_community.docstore.in_memory"]),
]

resolved = {}
missing = []

for label, module_candidates in checks:
    matched = None
    for module_name in module_candidates:
        try:
            importlib.import_module(module_name)
            matched = module_name
            break
        except Exception:
            pass

    if matched is None:
        missing.append((label, module_candidates))
    else:
        resolved[label] = matched

if missing:
    lines = ["[setup step 3/7] Core import checks failed:"]
    for label, module_candidates in missing:
        lines.append(f"- {label}: expected one of {', '.join(module_candidates)}")
    lines.append("")
    lines.append("Re-run setup step 1/7, then restart kernel (step 2/7), then run this check again.")
    raise ModuleNotFoundError("\n".join(lines))

print("[setup step 3/7] Core imports verified.")
for label, module_name in resolved.items():
    print(f"  - {label}: {module_name}")


In [None]:
#######################################################################################################

###### Setup Step 4/7: Install Widget Libraries                                                     ######

#######################################################################################################


import importlib.util
from IPython import get_ipython
import shlex


widget_packages = ["ipywidgets", "jupyterlab_widgets"]
widget_module_checks = ["ipywidgets", "jupyterlab_widgets"]

missing = [name for name in widget_module_checks if importlib.util.find_spec(name) is None]
if not missing:
    print("[setup step 4/7] Widget libraries are already available in this kernel.")
else:
    print("[setup step 4/7] Installing widget libraries:")
    for pkg in widget_packages:
        print(f"  - {pkg}")

    pip_cmd = "install -q -U " + " ".join(shlex.quote(pkg) for pkg in widget_packages)
    ip = get_ipython()
    if ip is None:
        raise RuntimeError("IPython kernel not found; cannot run %pip install.")

    ip.run_line_magic("pip", pip_cmd)
    print("[setup step 4/7] Widget install command finished.")


In [None]:
#######################################################################################################

###### Setup Step 5/7: Restart Kernel After Widget Install                                          ######

#######################################################################################################


RESTART_NOW = False  # Set True to trigger restart in Databricks.

if RESTART_NOW:
    try:
        dbutils.library.restartPython()
    except NameError as exc:
        raise RuntimeError(
            "Automatic restart is available in Databricks only. Use your IDE/Jupyter restart action."
        ) from exc
    except Exception as exc:
        raise RuntimeError(f"Automatic restart failed: {type(exc).__name__}: {exc}") from exc
else:
    print("[setup step 5/7] Restart the kernel now.")
    print("  - Databricks: set RESTART_NOW=True and rerun this cell (or restart from the UI).")
    print("  - VS Code/Jupyter: use Restart Kernel.")
    print("After restarting, run Setup Step 6/7.")


In [None]:
#######################################################################################################

###### Setup Step 6/7: Verify Widget Imports                                                       ######

#######################################################################################################


import importlib


checks = [
    ("ipywidgets", ["ipywidgets"]),
    ("jupyterlab_widgets", ["jupyterlab_widgets"]),
]

resolved = {}
missing = []

for label, module_candidates in checks:
    matched = None
    for module_name in module_candidates:
        try:
            importlib.import_module(module_name)
            matched = module_name
            break
        except Exception:
            pass

    if matched is None:
        missing.append((label, module_candidates))
    else:
        resolved[label] = matched

if missing:
    lines = ["[setup step 6/7] Widget import checks failed:"]
    for label, module_candidates in missing:
        lines.append(f"- {label}: expected one of {', '.join(module_candidates)}")
    lines.append("")
    lines.append("Re-run setup step 4/7, restart kernel (step 5/7), then run this check again.")
    raise ModuleNotFoundError("\n".join(lines))

print("[setup step 6/7] Widget imports verified.")
for label, module_name in resolved.items():
    print(f"  - {label}: {module_name}")


In [None]:
#######################################################################################################

###### Setup Step 7/7: Final Preflight (Dependencies + Repo Path)                                  ######

#######################################################################################################


import importlib
import os
import sys
from pathlib import Path


def _candidate_src_paths():
    candidates = []

    # Optional explicit override.
    env_src = os.environ.get("NATURALIST_COMPANION_SRC", "").strip()
    if env_src:
        candidates.append(Path(env_src))

    # Databricks notebook context path (when available).
    try:
        notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
        if "/notebooks/" in notebook_path:
            repo_workspace_path = "/Workspace" + notebook_path.split("/notebooks/", 1)[0]
            candidates.append(Path(repo_workspace_path) / "src")
    except Exception:
        pass

    cwd = Path.cwd()
    candidates.extend([
        cwd / "src",
        cwd.parent / "src",
        cwd.parent.parent / "src",
    ])

    repos_root = Path("/Workspace/Repos")
    if repos_root.exists():
        for pkg_dir in repos_root.glob("*/*/src/naturalist_companion"):
            candidates.append(pkg_dir.parent)

    deduped = []
    seen = set()
    for item in candidates:
        key = str(item)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(item)
    return deduped


for src_path in _candidate_src_paths():
    if (src_path / "naturalist_companion").exists() and str(src_path) not in sys.path:
        sys.path.insert(0, str(src_path))
        break


checks = [
    ("FAISS backend", ["faiss"]),
    ("LangGraph runtime", ["langgraph"]),
    ("Wikipedia loader module", ["langchain_community.document_loaders", "langchain.document_loaders"]),
    ("Text splitter module", ["langchain_text_splitters", "langchain.text_splitter"]),
    ("LangChain vectorstore module", ["langchain_community.vectorstores"]),
    ("LangChain in-memory docstore module", ["langchain_community.docstore.in_memory"]),
    ("Databricks integration", ["databricks_langchain"]),
    ("Naturalist stategraph module", ["naturalist_companion.stategraph_shared"]),
]

resolved = {}
missing = []

for label, module_candidates in checks:
    matched = None
    last_error = None
    for module_name in module_candidates:
        try:
            importlib.import_module(module_name)
            matched = module_name
            break
        except Exception as exc:
            last_error = f"{type(exc).__name__}: {exc}"

    if matched is not None:
        resolved[label] = matched
    else:
        missing.append((label, module_candidates, last_error))

if missing:
    non_stategraph_missing = [m for m in missing if m[0] != "Naturalist stategraph module"]
    if non_stategraph_missing:
        missing = non_stategraph_missing

    lines = ["[setup step 7/7] Missing required notebook dependencies:"]
    for label, module_candidates, last_error in missing:
        lines.append(f"- {label}: expected one of {', '.join(module_candidates)}")
        if last_error:
            lines.append(f"  last error: {last_error}")

    lines.append("")
    lines.append("Re-run setup steps 1/7 -> 3/7 for core dependencies.")
    lines.append("If widgets are missing, also re-run setup steps 4/7 -> 6/7.")
    lines.append("If this repo is synced in Databricks, you may need to set NATURALIST_COMPANION_SRC to the repo's src path.")
    raise ModuleNotFoundError("\n".join(lines))

print("[setup step 7/7] Dependency preflight passed.")
for label, module_name in resolved.items():
    print(f"  - {label}: {module_name}")


In [None]:
#######################################################################################################

###### Notebook Imports + Runtime Config                                                           ######

#######################################################################################################


import json
import os
import warnings
from threading import Event, Thread
from urllib.parse import quote, unquote, urlparse
from urllib.request import Request, urlopen

from IPython.display import Image, Markdown, display

# Mitigate common macOS OpenMP duplicate-library crashes in notebook kernels.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")

# Silence noisy tqdm widget warning in IDE notebooks when rich progress widgets are unavailable.
warnings.filterwarnings("ignore", message=".*IProgress not found.*")


# Newer LangChain releases provide WikipediaLoader from langchain_community; fall back to the legacy import path.
try:
    from langchain_community.document_loaders import WikipediaLoader
except ImportError:
    from langchain.document_loaders import WikipediaLoader

try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    from langchain.text_splitter import RecursiveCharacterTextSplitter

import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from databricks_langchain import ChatDatabricks, DatabricksEmbeddings


def _start_heartbeat(label: str, every_s: float = 8.0):
    stop = Event()

    def _run():
        elapsed = 0.0
        while not stop.wait(every_s):
            elapsed += every_s
            print(f"[{label}] still running... {elapsed:.0f}s elapsed")

    thread = Thread(target=_run, daemon=True)
    thread.start()
    return stop


WIKIPEDIA_API_ENDPOINT = "https://en.wikipedia.org/w/api.php"


def _wiki_api_get(params):
    query = "&".join(
        f"{quote(str(key), safe='')}={quote(str(value), safe='')}" for key, value in params.items()
    )
    url = f"{WIKIPEDIA_API_ENDPOINT}?{query}"
    req = Request(
        url,
        headers={"User-Agent": "naturalist-companion/0.1 (notebook image preview)"},
    )
    try:
        with urlopen(req, timeout=10) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except Exception:
        return {}


def _title_from_wikipedia_url(url):
    parsed = urlparse(str(url or ""))
    marker = "/wiki/"
    if marker not in parsed.path:
        return None
    title = unquote(parsed.path.split(marker, 1)[1]).replace("_", " ").strip()
    return title or None



def _wiki_title_from_search(query):
    payload = _wiki_api_get(
        {
            "action": "query",
            "list": "search",
            "format": "json",
            "formatversion": 2,
            "srlimit": 1,
            "srsearch": query,
        }
    )
    results = (payload.get("query") or {}).get("search") or []
    if not results:
        return None
    title = str(results[0].get("title") or "").strip()
    return title or None


def _iter_page_refs(items):
    for item in items or []:
        if isinstance(item, str):
            raw = item.strip()
            if not raw:
                continue
            title = _title_from_wikipedia_url(raw)
            if not title:
                title = _wiki_title_from_search(raw)
            if title:
                yield {"title": title, "url": f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}"}
            continue

        if isinstance(item, dict):
            title = str(item.get("title") or "").strip()
            url = str(item.get("url") or item.get("source") or "").strip()
            if not title and url:
                title = _title_from_wikipedia_url(url) or ""
            if title:
                yield {"title": title, "url": url}
            continue

        metadata = getattr(item, "metadata", None) or {}
        title = str(metadata.get("title") or "").strip()
        url = str(metadata.get("source") or "").strip()
        if not title and url:
            title = _title_from_wikipedia_url(url) or ""
        if title:
            yield {"title": title, "url": url}


def _wiki_thumbnail_for_title(title, thumb_px=640):
    payload = _wiki_api_get(
        {
            "action": "query",
            "prop": "pageimages",
            "format": "json",
            "formatversion": 2,
            "redirects": 1,
            "piprop": "thumbnail|original",
            "pithumbsize": int(thumb_px),
            "titles": title,
        }
    )
    pages = (payload.get("query") or {}).get("pages") or []
    for page in pages:
        if not isinstance(page, dict):
            continue
        thumb = page.get("thumbnail") or {}
        original = page.get("original") or {}
        source = thumb.get("source") or original.get("source")
        if source:
            return str(source)
    return None


def display_wikipedia_images_for_pages(items, max_images=4, thumb_px=640):
    seen = set()
    shown = 0
    for ref in _iter_page_refs(items):
        title = ref["title"]
        if title in seen:
            continue
        seen.add(title)

        image_url = _wiki_thumbnail_for_title(title, thumb_px=thumb_px)
        if not image_url:
            continue

        shown += 1
        page_url = ref.get("url") or f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}"
        display(Markdown(f"**Wikipedia image preview: {title}**"))
        display(Image(url=image_url, width=min(int(thumb_px), 720)))
        display(Markdown(f"[Open page]({page_url})"))

        if shown >= int(max_images):
            break

    if shown == 0:
        print("[wiki-images] No thumbnail images found for the selected pages.")


def _show_databricks_auth_status() -> None:
    host = os.environ.get("DATABRICKS_HOST", "").strip()
    has_token = bool(os.environ.get("DATABRICKS_TOKEN", "").strip())
    in_runtime = bool(os.environ.get("DATABRICKS_RUNTIME_VERSION", "").strip())

    if in_runtime:
        print("[env] Databricks runtime detected; workspace auth should be available via runtime context.")
        return

    if host and has_token:
        print(f"[env] Databricks auth env detected: DATABRICKS_HOST={host}")
    else:
        print("[env] Databricks auth env not fully set. For local runs set DATABRICKS_HOST and DATABRICKS_TOKEN.")


def _raise_databricks_hint(stage: str, endpoint: str, exc: Exception) -> None:
    message = str(exc)
    hints = [
        f"[{stage}] Databricks call failed for endpoint={endpoint!r}.",
        f"Underlying error: {message}",
    ]

    lower_msg = message.lower()
    if "invalid access token" in lower_msg or "403" in lower_msg:
        hints.append("Check that DATABRICKS_TOKEN is valid and matches DATABRICKS_HOST.")
        hints.append("In Free Edition, regenerate a PAT and retry if the old token is expired.")
    if "resource_does_not_exist" in lower_msg or "404" in lower_msg:
        hints.append("Verify the endpoint name in Databricks Serving and update notebook env overrides.")
    if "permission" in lower_msg or "can query" in lower_msg:
        hints.append("Ensure your user/token has Can Query permission on the endpoint.")

    raise RuntimeError("\n".join(hints)) from exc


_show_databricks_auth_status()


#######################################################################################################

###### Config (Define LLMs, Embeddings, Vector Store, Data Loader specs)                          ######

#######################################################################################################


# DataLoader Config
query_terms = [
    "roadcut",
    "geology",
    "sedimentary rock",
    "stratigraphy",
]
max_docs = 3  # Fast local iteration setting.

# Stage 2 chunking + batching controls (keep small for interactive runs).
chunk_size = 1200
chunk_overlap = 150
embedding_batch_size = 8


# Retriever Config
k = 1
EMBEDDING_MODEL_ENDPOINT = os.environ.get("DATABRICKS_EMBEDDING_ENDPOINT", "databricks-bge-large-en")


# LLM Config
LLM_ENDPOINT_NAME = os.environ.get("DATABRICKS_LLM_ENDPOINT", "databricks-meta-llama-3-1-8b-instruct")
TEMPERATURE = 0.0


# Response style controls (Roadside Geology audience: curious drivers, practical field learners).
RESPONSE_TONE = "field-guide"
MAX_BULLETS_PER_SECTION = 4


## Query Prompt (Edit This Cell)

Use the next code cell to set the active question(s).

Question types this notebook is designed for:
- Detour geology: legal pull-offs or short walks near a route segment
- Safety-first prompts: where to stop and what to avoid roadside
- Route constraints: city/exit anchors plus max detour minutes
- Beginner field interpretation: what visual clues to look for and why they matter

Tip: Include your nearest city or exit and your max detour time to improve stop recommendations.
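
One way to apply the tip consistently is to assemble the question from its parts. This is a minimal sketch, assuming a hypothetical `build_route_question` helper that is not defined elsewhere in this notebook. The sample values mirror the default prompt used in the next cell.

```python
def build_route_question(route, anchor, detour_minutes, interest):
    """Compose a question that states the route, a location anchor, and the detour budget.

    Hypothetical helper for illustration only.
    """
    if detour_minutes <= 0:
        raise ValueError("detour_minutes must be positive")
    return (
        f"I am on {route} near {anchor} with a {detour_minutes}-minute detour. "
        f"Where can I safely stop to observe {interest}, and what exactly should I look for?"
    )


question = build_route_question("I-81", "Hagerstown", 30, "folded Valley-and-Ridge strata")
print(question)
```

The returned string can be assigned to `example_question` or `stategraph_question` in the next cell.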


In [None]:
example_question = "I am on I-81 near Hagerstown with a 30-minute detour. Where can I safely stop to observe folded Valley-and-Ridge strata, and what exactly should I look for?"

example_questions = [
    "I am driving I-81 near Bristol, TN. Give me two legal pull-off stops where I can see clear sedimentary layering, and tell me exactly what to look for.",
    "Near I-81 between Winchester and Strasburg, where can I safely stop to see Valley-and-Ridge structure, and what field clues confirm folding?",
    "I have 45 minutes near Hagerstown, MD. What roadside geology stop gives the best payoff for a beginner, and what story does the outcrop tell?",
    "Along I-81 in the Shenandoah Valley, point me to a short-walk stop to compare rock type and landform, then explain why that match matters.",
    "On an I-81 drive day, suggest one stop where I can observe evidence of ancient seas or sediment transport, with specific visual clues.",
]

# StateGraph run can use the same prompt by default; edit independently if desired.
stategraph_question = example_question

place_image_queries = [
    "Hagerstown, Maryland",
    "Bristol, Tennessee",
    "Winchester, Virginia",
    "Strasburg, Virginia",
    "Shenandoah Valley",
]


In [None]:
#######################################################################################################

###### Stage 1/3: Wikipedia Data Load                                                            ######

#######################################################################################################


from time import perf_counter

print("[stage 1/3] Starting Wikipedia document load...")
query = " ".join(query_terms) if isinstance(query_terms, list) else query_terms
print(f"[stage 1/3] query={query!r}, max_docs={max_docs}")

heartbeat = _start_heartbeat("stage 1/3 wikipedia load", every_s=8.0)
t0 = perf_counter()
try:
    docs = WikipediaLoader(query=query, load_max_docs=max_docs).load()
finally:
    heartbeat.set()
t1 = perf_counter()

print(f"[stage 1/3] Loaded {len(docs)} document(s) in {t1 - t0:.2f}s")
if not docs:
    raise RuntimeError("No documents loaded from Wikipedia. Adjust query_terms/max_docs and re-run stage 1.")

print("[stage 1/3] Sample titles:")
for i, doc in enumerate(docs[:3], start=1):
    title = str((doc.metadata or {}).get("title") or f"doc_{i}")
    source = str((doc.metadata or {}).get("source") or "n/a")
    print(f"  {i}. {title} ({source})")

print("[stage 1/3] Wikipedia image previews from loaded pages...")
display_wikipedia_images_for_pages(docs, max_images=min(3, len(docs)))

print("[stage 1/3] Wikipedia image previews for place queries...")
display_wikipedia_images_for_pages(place_image_queries, max_images=min(5, len(place_image_queries)))


In [None]:
#######################################################################################################

###### Stage 2/3: Build + Save FAISS Index, Then Retrieve                                        ######

#######################################################################################################


import os
from pathlib import Path
from time import perf_counter

if "docs" not in globals() or not docs:
    raise RuntimeError("`docs` not found. Run Stage 1/3 first.")

print(f"[stage 2/3] Building embeddings with endpoint={EMBEDDING_MODEL_ENDPOINT!r}...")
print(
    f"[stage 2/3] Chunking docs with chunk_size={chunk_size}, chunk_overlap={chunk_overlap}, "
    f"embedding_batch_size={embedding_batch_size}"
)

splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
split_docs = splitter.split_documents(docs)
if not split_docs:
    raise RuntimeError("Chunking produced 0 documents. Adjust chunk_size/chunk_overlap and retry.")

total_chars = sum(len(str(d.page_content or "")) for d in split_docs)
print(f"[stage 2/3] Prepared {len(split_docs)} chunk(s), total_chars={total_chars}")

embeddings = DatabricksEmbeddings(endpoint=EMBEDDING_MODEL_ENDPOINT)

batch_size = max(1, int(embedding_batch_size))
vector_store = None

heartbeat = _start_heartbeat("stage 2/3 embedding/index", every_s=8.0)
t2 = perf_counter()
try:
    try:
        for start in range(0, len(split_docs), batch_size):
            batch = split_docs[start : start + batch_size]
            b0 = perf_counter()
            if vector_store is None:
                vector_store = FAISS.from_documents(batch, embeddings)
            else:
                vector_store.add_documents(batch)
            b1 = perf_counter()

            done = min(start + batch_size, len(split_docs))
            pct = (100.0 * done) / len(split_docs)
            print(
                f"[stage 2/3] Embedded batch {start // batch_size + 1}: "
                f"{done}/{len(split_docs)} chunks ({pct:.1f}%) in {b1 - b0:.2f}s"
            )
    except Exception as exc:
        _raise_databricks_hint("stage 2/3 embedding/index", EMBEDDING_MODEL_ENDPOINT, exc)
finally:
    heartbeat.set()

t3 = perf_counter()
if vector_store is None:
    raise RuntimeError("Vector store was not created.")

print(f"[stage 2/3] Built FAISS index in {t3 - t2:.2f}s")


faiss_base = os.environ.get("ANC_FAISS_DIR", "").strip()
if faiss_base:
    faiss_dir = (Path(faiss_base).expanduser() / "anc_dbrx").resolve()
else:
    faiss_dir = (Path.home() / "DATA" / "naturalist-companion" / "faiss" / "anc_dbrx").resolve()

faiss_dir.mkdir(parents=True, exist_ok=True)
vector_store.save_local(str(faiss_dir))
print(f"[stage 2/3] Saved FAISS index to: {faiss_dir}")


print(f"[stage 2/3] Running similarity search for question={example_question!r}, k={k}...")
results = vector_store.similarity_search(example_question, k=k)
print(f"[stage 2/3] Retrieved {len(results)} result(s)")

for i, res in enumerate(results, start=1):
    title = str((res.metadata or {}).get("title") or f"result_{i}")
    source = str((res.metadata or {}).get("source") or "n/a")
    snippet = str(res.page_content or "")[:220].replace("\n", " ")
    print(f"  {i}. {title} ({source})")
    print(f"     {snippet}...")

print("[stage 2/3] Wikipedia image previews from retrieved pages...")
display_wikipedia_images_for_pages(results, max_images=min(4, len(results)))


In [None]:
#######################################################################################################

###### Stage 3/3: Generate Answer with ChatDatabricks                                            ######

#######################################################################################################


from time import perf_counter

# Also catches the case where Stage 2/3 failed and left vector_store as None.
if globals().get("vector_store") is None:
    raise RuntimeError("`vector_store` not found. Run Stage 2/3 first.")

print(f"[stage 3/3] Generating answer with endpoint={LLM_ENDPOINT_NAME!r}...")
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=TEMPERATURE)

voice_instructions = f"""
You are writing in a concise Roadside Geology field-guide voice for curious drivers.
Tone:
- Plainspoken, observant, and practical (not academic).
- Emphasize what can be seen from legal/safe pull-offs or short walks.
- Explain key geology in everyday language, then add one precise term when useful.
- Include safety and access realism (do not suggest unsafe roadside behavior).
Output format:
1) "Where to stop" (up to {MAX_BULLETS_PER_SECTION} bullets)
2) "What to look for" (up to {MAX_BULLETS_PER_SECTION} bullets)
3) "Why it matters" (2-4 sentences)
4) "Citations" (Wikipedia URLs only)
""".strip()


def _context_for_question(question: str, top_k: int = 2) -> str:
    local_results = vector_store.similarity_search(question, k=max(1, top_k))
    lines = []
    for i, res in enumerate(local_results, start=1):
        title = str((res.metadata or {}).get("title") or f"result_{i}")
        source = str((res.metadata or {}).get("source") or "n/a")
        snippet = str(res.page_content or "")[:450].replace("\n", " ")
        lines.append(f"[{i}] {title} ({source}) :: {snippet}")
    return "\n".join(lines)


def answer_question(question: str) -> str:
    context_block = _context_for_question(question, top_k=max(1, k))
    prompt = (
        "Where possible, use only the provided Wikipedia-grounded context.\n\n"
        f"Question: {question}\n\n"
        f"Context:\n{context_block}\n\n"
        f"Style requirements:\n{voice_instructions}"
    )

    heartbeat = _start_heartbeat("stage 3/3 llm", every_s=8.0)
    t0 = perf_counter()
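    try:
        response = llm.invoke(prompt)
    except Exception as exc:
        _raise_databricks_hint("stage 3/3 llm", LLM_ENDPOINT_NAME, exc)
        raise  # Fallback: never continue without `response` if the helper returns.
    finally:
        heartbeat.set()  # Always signal the heartbeat, on success or failure.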
    dt = perf_counter() - t0
    print(f"[stage 3/3] LLM response received in {dt:.2f}s")
    return str(response.content)


print(f"[stage 3/3] Primary question:\n- {example_question}")
primary_answer = answer_question(example_question)
print("\nAnswer:\n")
print(primary_answer)
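
The context block that `_context_for_question` builds can be previewed without a vector store or endpoint. The sketch below is a minimal illustration, not part of the notebook pipeline: `StubDoc`, `format_context`, and the sample documents are hypothetical stand-ins that mimic the `page_content` and `metadata` attributes used above.

```python
from dataclasses import dataclass, field


@dataclass
class StubDoc:
    """Hypothetical stand-in for a retrieved document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_context(docs, max_chars=450):
    """Format documents as numbered `[i] title (source) :: snippet` lines."""
    lines = []
    for i, doc in enumerate(docs, start=1):
        # Same fallbacks as the notebook helper: placeholder title and source.
        title = str((doc.metadata or {}).get("title") or f"result_{i}")
        source = str((doc.metadata or {}).get("source") or "n/a")
        snippet = str(doc.page_content or "")[:max_chars].replace("\n", " ")
        lines.append(f"[{i}] {title} ({source}) :: {snippet}")
    return "\n".join(lines)


# Made-up sample data, for illustration only.
docs = [
    StubDoc(
        "Folded ridges\nand valleys.",
        {"title": "Example page", "source": "https://en.wikipedia.org/wiki/Example"},
    ),
    StubDoc("No metadata here."),
]
context_block = format_context(docs)
print(context_block)
```

The second stub has no metadata, so it shows the `result_2` and `n/a` fallbacks that keep the prompt well-formed when a chunk lacks a title or source.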


In [None]:
#######################################################################################################

###### Stage 3b/3: Run All Example Questions                                                     ######

#######################################################################################################


if "answer_question" not in globals():
    raise RuntimeError("`answer_question` not found. Run Stage 3/3 first.")

if "example_questions" not in globals() or not example_questions:
    raise RuntimeError("`example_questions` is missing or empty. Check the config cell.")

all_answers = []
print(f"[stage 3b/3] Running {len(example_questions)} example question(s)...")

for i, q in enumerate(example_questions, start=1):
    print("\n" + "=" * 110)
    print(f"[stage 3b/3] Question {i}/{len(example_questions)}")
    print(q)
    print("=" * 110)

    answer = answer_question(q)
    all_answers.append({"question": q, "answer": answer})

    print("\nResponse:\n")
    print(answer)

print(f"\n[stage 3b/3] Completed {len(all_answers)} question(s).")


## Canonical Workflow Diagram (StateGraph)

The StateGraph rendering is the canonical workflow diagram for this notebook. It is generated from the shared orchestration code in `naturalist_companion.stategraph_shared`, so it is less likely to drift than a separately maintained static diagram.
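
As a rough illustration of why a generated diagram tracks the code, the sketch below derives Mermaid text from an edge list, so changing an edge changes the diagram. The `edges_to_mermaid` helper and the node names are hypothetical and do not reflect the actual topology in `stategraph_shared`; the notebook itself renders the diagram from the compiled graph with `app.get_graph().draw_mermaid_png()`.

```python
def edges_to_mermaid(edges):
    """Render (source, target) pairs as Mermaid flowchart text."""
    lines = ["graph TD"]
    for source, target in edges:
        lines.append(f"    {source} --> {target}")
    return "\n".join(lines)


# Hypothetical node names, for illustration only.
workflow_edges = [
    ("retrieve", "draft_answer"),
    ("draft_answer", "quality_check"),
    ("quality_check", "retrieve"),  # illustrative retry loop
]
mermaid_text = edges_to_mermaid(workflow_edges)
print(mermaid_text)
```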


In [None]:
# Dependencies should already be installed from the setup steps above.

from pathlib import Path
import os
import sys

from IPython.display import Image, Markdown, display


def _candidate_src_paths() -> list[Path]:
    candidates: list[Path] = []

    # Optional explicit override.
    env_src = os.environ.get("NATURALIST_COMPANION_SRC", "").strip()
    if env_src:
        candidates.append(Path(env_src))

    # If running inside Databricks, derive repo path from notebook context.
    try:
        notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
        # Example notebook_path: /Repos/user/repo/notebooks/anc_dbrx
        if "/notebooks/" in notebook_path:
            repo_workspace_path = "/Workspace" + notebook_path.split("/notebooks/", 1)[0]
            candidates.append(Path(repo_workspace_path) / "src")
    except Exception:
        pass

    # Relative paths (works in local and some Databricks repo executions).
    cwd = Path.cwd()
    candidates.extend([
        cwd / "src",
        cwd.parent / "src",
        cwd.parent.parent / "src",
    ])

    # Databricks Repos filesystem fallback scan.
    repos_root = Path("/Workspace/Repos")
    if repos_root.exists():
        for pkg_dir in repos_root.glob("*/*/src/naturalist_companion"):
            candidates.append(pkg_dir.parent)

    # De-duplicate while preserving order.
    deduped: list[Path] = []
    seen: set[str] = set()
    for item in candidates:
        key = str(item)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(item)
    return deduped


STATEGRAPH_AVAILABLE = False
STATEGRAPH_SRC_PATH = None
_stategraph_import_error = None

for src_path in _candidate_src_paths():
    if not (src_path / "naturalist_companion").exists():
        continue
    if str(src_path) not in sys.path:
        sys.path.insert(0, str(src_path))
    STATEGRAPH_SRC_PATH = src_path
    break

try:
    from naturalist_companion.stategraph_shared import (
        build_stategraph_app,
        run_i81_eval_harness,
        run_stategraph,
    )
    STATEGRAPH_AVAILABLE = True
    if STATEGRAPH_SRC_PATH is not None:
        print(f"[stategraph] Loaded naturalist_companion from: {STATEGRAPH_SRC_PATH}")
    else:
        print("[stategraph] Loaded naturalist_companion from current Python path")
except Exception as exc:
    _stategraph_import_error = exc
    STATEGRAPH_AVAILABLE = False
    display(
        Markdown(
            "**StateGraph module is unavailable in this workspace.**\n\n"
            "Baseline Databricks stages (Wikipedia -> FAISS -> ChatDatabricks) still run.\n\n"
            "If this repo is synced in Databricks Repos, open the notebook from the Repo path "
            "(not a copied workspace file), or set `NATURALIST_COMPANION_SRC` to your repo `src` path."
        )
    )
    print(f"[stategraph] import error: {type(exc).__name__}: {exc}")
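
The Databricks branch of `_candidate_src_paths` turns a notebook path into a repo `src` directory. A minimal standalone sketch of that string handling follows, using the example path from the code comment. The function name `repo_src_from_notebook_path` is hypothetical, and `PurePosixPath` is used here only so the example behaves the same on any operating system.

```python
from pathlib import PurePosixPath
from typing import Optional


def repo_src_from_notebook_path(notebook_path: str) -> Optional[PurePosixPath]:
    """Return `<repo>/src` for a notebook under `<repo>/notebooks/`, else None."""
    if "/notebooks/" not in notebook_path:
        return None
    # Everything before the first "/notebooks/" segment is treated as the repo root.
    repo_root = notebook_path.split("/notebooks/", 1)[0]
    return PurePosixPath("/Workspace" + repo_root) / "src"


src_dir = repo_src_from_notebook_path("/Repos/user/repo/notebooks/anc_dbrx")
print(src_dir)  # /Workspace/Repos/user/repo/src

missing = repo_src_from_notebook_path("/Users/someone/anc_dbrx")
print(missing)  # None
```

A notebook copied outside a `notebooks/` folder yields `None`, which is the case where the `NATURALIST_COMPANION_SRC` override or the other fallback candidates are needed.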


In [None]:
if not STATEGRAPH_AVAILABLE:
    print("[stategraph] Skipping canonical diagram: naturalist_companion not available.")
else:
    provider = 'databricks'
    app = build_stategraph_app(provider=provider)
    print('Compiled StateGraph successfully for provider:', provider)

    # Render a real image (PNG bytes) instead of plain Mermaid text.
    try:
        png_bytes = app.get_graph().draw_mermaid_png()
        display(Image(data=png_bytes))
    except Exception as exc:
        display(Markdown(f'Graph render fallback (text). Error: `{type(exc).__name__}: {exc}`'))
        print(app.get_graph().draw_mermaid())


In [None]:
if not STATEGRAPH_AVAILABLE:
    print("[stategraph] Skipping run_stategraph: naturalist_companion not available.")
else:
    result = run_stategraph(
        stategraph_question,
        provider='databricks',
        config={
            'artifact_root': 'out/stategraph/notebook_runs',
            'max_retrieval_attempts': 3,
            'citation_coverage_threshold': 0.80,
        },
    )
    final_output = result['final_output']
    print('Question:', stategraph_question)
    print('Provider:', final_output['provider'])
    print('Route:', final_output['route_decision']['decision'])
    print('Quality passed:', final_output['quality']['passed'])
    print('Attempts:', final_output['retrieval_attempts'])
    print('Artifact dir:', result['artifact_dir'])
    print('Response:')
    print(final_output['answer']['response'])
    print('Citation image previews:')
    display_wikipedia_images_for_pages(final_output['answer'].get('citations', []), max_images=4)


In [None]:
if not STATEGRAPH_AVAILABLE:
    print("[stategraph] Skipping eval harness: naturalist_companion not available.")
else:
    report = run_i81_eval_harness(
        provider='databricks',
        config={
            'artifact_root': 'out/stategraph/notebook_eval',
            'max_retrieval_attempts': 3,
            'citation_coverage_threshold': 0.80,
        },
    )
    print(report['summary'])
    print(report['artifact_root'])
