In [None]:
# lambda_function.py
# Public-repo safe template (no client defaults, no secret hardcoding)

import os, json, re, base64, boto3
from urllib.parse import parse_qs, urlparse
from typing import Any, Dict, List, Optional, Tuple
from botocore.config import Config
from botocore.exceptions import ClientError

# ========= env =========
# Every setting below is read from the process environment once, at import time.
# String values are stripped of surrounding whitespace unless noted otherwise.

# Region used for both boto3 clients, and the model ARN sent to Bedrock
# (an empty MODEL_ARN is reported by _validate_runtime_config).
REGION              = os.getenv("REGION", "us-east-1").strip()
MODEL_ARN           = os.getenv("MODEL_ARN", "").strip()

# Retrieval defaults. int() raises ValueError at import if NUMBER_OF_RESULTS
# is set to something that is not an integer literal.
NUM_RESULTS_DEFAULT = int(os.getenv("NUMBER_OF_RESULTS", "20"))
SEARCH_TYPE_DEFAULT = os.getenv("OVERRIDE_SEARCH_TYPE", "SEMANTIC").strip().upper()

# Values substituted for the {brand} and {use_case} prompt placeholders.
BRAND               = os.getenv("BRAND_NAME", "YourBrand").strip()
USE_CASE            = os.getenv("USE_CASE_TITLE", "AskAnswer").strip()

# No client-specific default in the public repo: the ServiceNow instance name
# (the host prefix of "<instance>.service-now.com") must be configured.
SN_INSTANCE         = os.getenv("SN_INSTANCE", "").strip()

# CORS_ORIGIN is echoed in Access-Control-Allow-Origin. DEBUG_RAW is enabled
# only by the value "true" (case-insensitive, not stripped) and adds the
# scrubbed raw response plus debug flags to the handler output.
CORS_ORIGIN         = os.getenv("CORS_ORIGIN", "*").strip()
DEBUG_RAW           = (os.getenv("DEBUG_RAW", "false").lower() == "true")

# Role-based KB IDs (no default KB); all three are required at runtime.
CUSTOMERS_KB_ID     = os.getenv("CUSTOMERS_KB_ID", "").strip()
RESERVATION_KB_ID   = os.getenv("RESERVATION_KB_ID", "").strip()
ONESOURCE_KB_ID     = os.getenv("ONESOURCE_KB_ID", "").strip()

# Generation defaults; float()/int() raise ValueError at import on bad values.
DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
DEFAULT_TOPP        = float(os.getenv("DEFAULT_TOPP", "0.9"))
DEFAULT_MAX_TOKENS  = int(os.getenv("DEFAULT_MAX_TOKENS", "900"))

# URL behavior.
# DISABLE_URL_FALLBACK turns off the S3 HTML / key-path lookup in _extract_ref_info.
# URL_FIELD_NAME and the comma-separated ALT_URL_KEYS name the metadata keys
# that are checked first for a source URL.
# INLINE_URLS appends " (Source: <url>)" tokens inside the answer text.
DISABLE_URL_FALLBACK = (os.getenv("DISABLE_URL_FALLBACK", "false").lower() == "true")
URL_FIELD_NAME = os.getenv("URL_FIELD_NAME", "sn_url").strip()
ALT_URL_KEYS = [k.strip() for k in os.getenv("ALT_URL_KEYS", "").split(",") if k.strip()]
INLINE_URLS = (os.getenv("INLINE_URLS", "true").lower() == "true")

# Numbered superscript inline citations (HTML links); when enabled this takes
# precedence over INLINE_URLS in the handler.
INLINE_SUPERSCRIPT = (os.getenv("INLINE_SUPERSCRIPT", "false").lower() == "true")

# Upper bound on the number of bytes read from a source HTML object when
# scanning it for a KB id (default 256 KiB).
HTML_READ_BYTES = int(os.getenv("HTML_READ_BYTES", str(256 * 1024)))

# Optional verified mapping of KBB numbers to KB numbers, supplied as a JSON
# object in KB_REWRITE_JSON. Invalid JSON or a non-object value silently
# yields an empty mapping.
try:
    KB_REWRITE = json.loads(os.getenv("KB_REWRITE_JSON", "{}"))
    if not isinstance(KB_REWRITE, dict):
        KB_REWRITE = {}
except Exception:
    KB_REWRITE = {}

# Default prompt (generic; do NOT hardcode client policies here).
# The template is rendered with str.format() in lambda_handler, which supplies
# the placeholders {brand}, {use_case}, {query} and {sn_instance}. Because
# str.format() is used, a literal brace in a custom PROMPT_TEMPLATE must be
# doubled ("{{" / "}}") and any other placeholder name raises KeyError.
PROMPT_TMPL = os.getenv(
    "PROMPT_TEMPLATE",
    """You are a grounded question-answering assistant.

USER QUESTION:
{query}

Rules:
- Use ONLY information retrieved by the Knowledge Base.
- Do not invent facts. If results do not contain enough information, say so.
- Keep the answer concise and clear.

(The system may append Sources from retrieved metadata.)"""
)

# Module-level clients are created once at import and shared by all calls.
# The Bedrock client uses a 5 s connect timeout, a 25 s read timeout and
# botocore's "standard" retry mode with max_attempts=3.
client = boto3.client(
    "bedrock-agent-runtime",
    region_name=REGION,
    config=Config(connect_timeout=5, read_timeout=25, retries={"max_attempts": 3, "mode": "standard"}),
)
# S3 client used only to read source HTML when scanning for a KB id.
s3 = boto3.client("s3", region_name=REGION)

# ========= validation =========
def _config_error(message: str, missing: Optional[List[str]] = None) -> Dict[str, Any]:
    out = {"error": message}
    if missing:
        out["missing"] = missing
    return out

def _validate_runtime_config() -> Optional[Dict[str, Any]]:
    """Check that every required setting is non-empty.

    Returns None when the configuration is complete; otherwise returns the
    payload from ``_config_error`` listing the names of the empty settings,
    in the fixed order shown below.
    """
    required = (
        ("SN_INSTANCE", SN_INSTANCE),
        ("MODEL_ARN", MODEL_ARN),
        ("CUSTOMERS_KB_ID", CUSTOMERS_KB_ID),
        ("RESERVATION_KB_ID", RESERVATION_KB_ID),
        ("ONESOURCE_KB_ID", ONESOURCE_KB_ID),
    )
    absent = [name for name, value in required if not value]
    if absent:
        return _config_error("Missing required configuration.", absent)
    return None

# ========= helpers =========
def _first(d: Dict, *paths, default=None):
    for p in paths:
        cur = d
        ok = True
        for k in p:
            if isinstance(cur, dict) and k in cur:
                cur = cur[k]
            else:
                ok = False
                break
        if ok:
            return cur
    return default

def _clamp_float(x: Any, lo: float, hi: float, default: float) -> float:
    try:
        v = float(x)
    except Exception:
        return default
    if v < lo:
        return lo
    if v > hi:
        return hi
    return v

def _clamp_int(x: Any, lo: int, hi: int, default: int) -> int:
    try:
        v = int(x)
    except Exception:
        return default
    if v < lo:
        return lo
    if v > hi:
        return hi
    return v

def _get_meta_field(md: Dict, key: str) -> Optional[str]:
    """Look up ``key`` in reference metadata and return it as a string.

    A top-level str/int/float value wins. Otherwise the value is read from
    ``metadataAttributes.<key>.value.stringValue`` and, if that is missing or
    falsy, from ``...value.numberValue``. Returns None when nothing is found.
    """
    if key in md:
        flat = md[key]
        if isinstance(flat, (str, int, float)):
            return str(flat)

    wrapped = _first(md, ("metadataAttributes", key, "value", "stringValue"))
    if not wrapped:
        wrapped = _first(md, ("metadataAttributes", key, "value", "numberValue"))
    return None if wrapped is None else str(wrapped)

def _build_sn_url_from_kb(kb_number: str) -> str:
    """Return the canonical ``kb_view.do`` article URL on the configured instance."""
    host = f"{SN_INSTANCE}.service-now.com"
    return f"https://{host}/kb_view.do?sysparm_article={kb_number}"

# ===== Permissive validation + canonicalization =====
# Whole-string, case-insensitive article numbers: KB_RE matches "KB" followed
# by digits; KBB_RE matches "KBB" followed by digits (the form that
# _rewrite_to_kb translates through the KB_REWRITE mapping).
KB_RE  = re.compile(r"^KB\d+$", re.IGNORECASE)
KBB_RE = re.compile(r"^KBB\d+$", re.IGNORECASE)

def _is_valid_sn_kb_url(url: Optional[str], kb_number: Optional[str]) -> bool:
    """Tell whether ``url`` is an https article link on the configured instance.

    The path is not checked. The URL must use the ``https`` scheme, its netloc
    must equal ``<SN_INSTANCE>.service-now.com`` exactly, and its
    ``sysparm_article`` query value must equal ``kb_number`` (case-insensitive)
    or, when ``kb_number`` is empty, look like ``KB<digits>``. Any parsing
    error yields False.
    """
    if not url or not isinstance(url, str):
        return False
    try:
        parts = urlparse(url)
        on_instance = (
            parts.scheme == "https"
            and parts.netloc == f"{SN_INSTANCE}.service-now.com"
        )
        if not on_instance:
            return False
        article = (parse_qs(parts.query).get("sysparm_article", [""])[0] or "").upper()
        if not kb_number:
            return bool(KB_RE.match(article))
        return article == kb_number.upper()
    except Exception:
        return False

def _canonicalize_sn_url(url: str) -> Optional[str]:
    """Rebuild ``url`` as the canonical ``kb_view.do`` link for its article.

    Only the ``sysparm_article`` query value is inspected; the host and path
    of ``url`` are not checked here. Returns None when that value is not
    ``KB<digits>`` or when parsing fails.
    """
    try:
        articles = parse_qs(urlparse(url).query).get("sysparm_article", [""])
        article = (articles[0] or "").upper()
        return _build_sn_url_from_kb(article) if KB_RE.match(article) else None
    except Exception:
        return None

# -------- KB extraction from S3 key path (safe fallback) --------
# Unanchored, case-insensitive: finds the first "KB<digits>" run anywhere in
# an object key.
KB_IN_KEY_RE = re.compile(r"(KB\d+)", re.IGNORECASE)

def _parse_s3_uri(uri: str) -> Tuple[str, str]:
    if not uri.startswith("s3://"):
        raise ValueError("not an s3:// uri")
    without = uri[len("s3://"):]
    bucket, _, key = without.partition("/")
    if not bucket or not key:
        raise ValueError("bad s3 uri")
    return bucket, key

def _kb_from_s3_key(key: str) -> Optional[str]:
    """Return the first ``KB<digits>`` token in ``key``, upper-cased, or None."""
    found = KB_IN_KEY_RE.search(key or "")
    if found is None:
        return None
    return found.group(1).upper()

# Patterns used when scanning source HTML for an article id.
# KB_URL_RE matches an https URL on any host whose query string contains
# sysparm_article=KB<digits>; no other query parameter is required. Group 1
# is the article number.
KB_URL_RE   = re.compile(
    r"https://[a-z0-9.-]+/[^ \t\r\n]*\?[^ \t\r\n]*\bsysparm_article=(KB\d+)\b[^ \t\r\n]*",
    re.IGNORECASE,
)
# KB_TOKEN_RE matches a standalone KB<digits> token delimited by word boundaries.
KB_TOKEN_RE = re.compile(r"\b(KB\d+)\b", re.IGNORECASE)

def _scan_html_for_kb(s3_uri: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Scan the start of an S3-hosted HTML object for a KB article id.

    Up to HTML_READ_BYTES are fetched with a ranged GET. If S3 rejects the
    range (error code "InvalidRange" or "416"), the object is fetched without
    a range and at most HTML_READ_BYTES are read from the stream.

    The text is searched first for a URL carrying ``sysparm_article=KB<digits>``
    and then for a bare ``KB<digits>`` token. Only KB (not KBB) ids match.

    Returns:
        ``(kb_number, kb_url)`` with the number upper-cased and the URL
        rebuilt in canonical form, or ``(None, None)`` when nothing is found
        or any step fails. Never raises.
    """
    if not s3_uri:
        return None, None
    try:
        bkt, key = _parse_s3_uri(s3_uri)
    except Exception:
        return None, None

    try:
        obj = s3.get_object(Bucket=bkt, Key=key, Range=f"bytes=0-{max(0, HTML_READ_BYTES-1)}")
        data = _read_s3_body(obj)
    except ClientError as e:
        code = e.response.get("Error", {}).get("Code", "")
        if code not in ("InvalidRange", "416"):
            return None, None
        # Range not satisfiable (e.g. an empty object): retry without a range
        # and cap the amount read from the stream instead.
        try:
            obj = s3.get_object(Bucket=bkt, Key=key)
            data = _read_s3_body(obj, HTML_READ_BYTES)
        except Exception:
            return None, None
    except Exception:
        return None, None

    if not data:
        return None, None

    # A full article URL is the stronger signal, so it is tried before a bare token.
    for pattern in (KB_URL_RE, KB_TOKEN_RE):
        m = pattern.search(data)
        if m:
            kb = m.group(1).upper()
            return kb, _build_sn_url_from_kb(kb)

    return None, None


def _read_s3_body(obj: Dict[str, Any], limit: Optional[int] = None) -> str:
    """Read a get_object response body as UTF-8 text and close the stream.

    Reads the whole body when ``limit`` is None, otherwise at most ``limit``
    bytes. Undecodable bytes are dropped. The stream is closed even when the
    read fails or stops early, so the underlying connection is released
    promptly instead of waiting for garbage collection.
    """
    body = obj.get("Body")
    if not body:
        return ""
    try:
        raw = body.read() if limit is None else body.read(limit)
    finally:
        close = getattr(body, "close", None)
        if callable(close):
            close()
    return raw.decode("utf-8", "ignore")

def _rewrite_to_kb(number: Optional[str], url: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Translate a KBB article reference to its verified KB counterpart.

    The KBB id is taken from ``number`` when it looks like ``KBB<digits>``,
    otherwise from the ``sysparm_article`` query value of ``url``. If
    KB_REWRITE maps that id (upper-cased) to a ``KB<digits>`` value, the
    upper-cased KB number and its canonical URL are returned. In every other
    case the inputs are returned unchanged.
    """
    kbb = None
    if number and KBB_RE.match(str(number)):
        kbb = str(number).upper()
    elif url:
        try:
            article = (parse_qs(urlparse(url).query).get("sysparm_article", [""])[0] or "").upper()
        except Exception:
            article = ""
        if KBB_RE.match(article):
            kbb = article

    mapped = KB_REWRITE.get(kbb) if kbb else None
    if mapped and KB_RE.match(str(mapped).upper()):
        canonical = str(mapped).upper()
        return canonical, _build_sn_url_from_kb(canonical)
    return number, url

def _get_kb_number(md: Dict) -> Optional[str]:
    """Find a ``KB<digits>`` article number in reference metadata.

    Sources are tried in this order and the first match is returned upper-cased:
      1. metadataAttributes.filterable.number -> value.stringValue
      2. metadataAttributes.number.value.stringValue
      3. top-level "number"
      4. the first KB<digits> run inside the "html_key" field
    Returns None when none of them yields a KB number.
    """
    def _as_kb(value) -> Optional[str]:
        # Accept only strings that are entirely "KB" followed by digits.
        if isinstance(value, str) and KB_RE.match(value):
            return value.upper()
        return None

    filterable = _first(md, ("metadataAttributes", "filterable", "number"))
    if isinstance(filterable, dict):
        kb = _as_kb(_first(filterable, ("value", "stringValue")))
        if kb:
            return kb

    kb = _as_kb(_first(md, ("metadataAttributes", "number", "value", "stringValue")))
    if kb:
        return kb

    kb = _as_kb(md.get("number"))
    if kb:
        return kb

    html_key = md.get("html_key") or _get_meta_field(md, "html_key")
    if isinstance(html_key, str):
        hit = re.search(r"(KB\d+)", html_key, re.IGNORECASE)
        if hit:
            return hit.group(1).upper()

    return None

def _get_url_from_metadata(md: Dict) -> Optional[str]:
    """
    Return the first URL we can find in metadata, or None.

    A "URL" here is any string starting with "http"; no host or scheme
    validation happens in this function. The search runs in three passes and
    stops at the first hit:
      1. configured keys (URL_FIELD_NAME, then ALT_URL_KEYS) via _get_meta_field
      2. metadataAttributes.nonFilterable / filterable / attributes (or
         metadataAttributes itself when none of those sub-dicts exist),
         checking each value and each value one level below it, including
         nested {value:{stringValue|numberValue}} shapes
      3. a generic depth-limited walk over metadataAttributes
    Within a pass, results follow dict iteration order.
    """
    # Pass 1: explicitly configured field names.
    for k in [URL_FIELD_NAME] + ALT_URL_KEYS:
        u = _get_meta_field(md, k)
        if isinstance(u, str) and u.startswith("http"):
            return u

    # Pass 2: collect the attribute "buckets" to inspect.
    attrs = md.get("metadataAttributes") or {}
    buckets: List[Dict[str, Any]] = []
    if isinstance(attrs, dict):
        for sub in ("nonFilterable", "filterable", "attributes"):
            v = attrs.get(sub)
            if isinstance(v, dict):
                buckets.append(v)
        if not buckets:
            # No known sub-dict present: treat the attributes dict as the bucket.
            buckets = [attrs]

    def _val_to_url(val) -> Optional[str]:
        """Return val as a URL if it is one directly or wrapped as {value: {...}}."""
        if isinstance(val, str) and val.startswith("http"):
            return val
        if isinstance(val, dict):
            s = _first(val, ("value", "stringValue")) or _first(val, ("value", "numberValue"))
            if isinstance(s, str) and s.startswith("http"):
                return s
        return None

    for b in buckets:
        for v in b.values():
            u = _val_to_url(v)
            if u:
                return u
            # Also look one level further down inside dict values.
            if isinstance(v, dict):
                for vv in v.values():
                    uu = _val_to_url(vv)
                    if uu:
                        return uu

    def _find_url(obj, depth=0):
        """Depth-first search for an "http..." string; gives up below depth 3."""
        if depth > 3:
            return None
        if isinstance(obj, str) and obj.startswith("http"):
            return obj
        if isinstance(obj, dict):
            for vv in obj.values():
                r = _find_url(vv, depth + 1)
                if r:
                    return r
        if isinstance(obj, list):
            for vv in obj:
                r = _find_url(vv, depth + 1)
                if r:
                    return r
        return None

    # Pass 3: generic fallback over whatever metadataAttributes holds.
    return _find_url(attrs)

def _extract_ref_info(ref: Dict) -> Dict[str, Any]:
    """Build ``{"number", "url", "score"}`` for one retrieved reference.

    A URL is returned only when a KB id (``KB<digits>``) can be established.
    Sources are tried in order and the first success wins:
      1. a URL found in metadata that carries ``sysparm_article=KB<digits>``
         (returned in canonical form)
      2. a KB number found in metadata (canonical URL is built from it)
      3. unless DISABLE_URL_FALLBACK is set: the source HTML in S3, then the
         S3 object key
      4. a KBB number or URL translated through the KB_REWRITE mapping

    ``score`` is taken from ``ref["score"]`` or ``ref["relevanceScore"]`` and
    converted to float; it is None when absent or not numeric.

    Never raises: a reference or metadata value that is not a dict produces
    ``{"number": None, "url": None, ...}`` instead of an exception.
    """
    if not isinstance(ref, dict):
        return {"number": None, "url": None, "score": None}

    md = ref.get("metadata")
    if not isinstance(md, dict):
        md = {}

    score = ref.get("score") or ref.get("relevanceScore")
    try:
        score = float(score) if score is not None else None
    except Exception:
        score = None

    # Unexpected metadata shapes must not break the "never raises" contract.
    try:
        url = _get_url_from_metadata(md)
    except Exception:
        url = None

    # 1) If metadata already has a ServiceNow URL, normalize/validate it
    try:
        if url:
            canon = _canonicalize_sn_url(url)
            if canon and _is_valid_sn_kb_url(canon, None):
                kb_qs = parse_qs(urlparse(canon).query).get("sysparm_article", [""])[0].upper()
                return {"number": kb_qs, "url": canon, "score": score}
            if _is_valid_sn_kb_url(url, None):
                kb_qs = parse_qs(urlparse(url).query).get("sysparm_article", [""])[0].upper()
                if KB_RE.match(kb_qs):
                    return {"number": kb_qs, "url": url, "score": score}
    except Exception:
        pass

    # 2) If metadata provides KB#### number, build canonical URL
    try:
        kb_from_meta = _get_kb_number(md)
        if kb_from_meta and KB_RE.match(kb_from_meta):
            kb_url = _build_sn_url_from_kb(kb_from_meta)
            if _is_valid_sn_kb_url(kb_url, kb_from_meta):
                return {"number": kb_from_meta, "url": kb_url, "score": score}
    except Exception:
        pass

    # 3) Optional fallback: inspect S3 HTML / key path (can be disabled)
    if not DISABLE_URL_FALLBACK:
        try:
            # _first tolerates a missing or null "location", so the metadata
            # source-uri fallback is still consulted in that case.
            loc_uri = (
                _first(ref, ("location", "s3Location", "uri"))
                or _first(ref, ("location", "s3Uri"))
                or _get_meta_field(md, "x-amz-bedrock-kb-source-uri")
            )
            if isinstance(loc_uri, str) and loc_uri.startswith("s3://"):
                kb_from_html, url_from_html = _scan_html_for_kb(loc_uri)
                if kb_from_html and url_from_html:
                    return {"number": kb_from_html, "url": url_from_html, "score": score}

                _bkt, key = _parse_s3_uri(loc_uri)
                kb_from_key = _kb_from_s3_key(key)
                if kb_from_key and KB_RE.match(kb_from_key):
                    kb_url = _build_sn_url_from_kb(kb_from_key)
                    if _is_valid_sn_kb_url(kb_url, kb_from_key):
                        return {"number": kb_from_key, "url": kb_url, "score": score}
        except Exception:
            pass

    # 4) Optional last resort: verified KBB -> KB mapping
    try:
        num_raw = _get_meta_field(md, "number")
        number, mapped_url = _rewrite_to_kb(num_raw, url)
        if number and mapped_url and _is_valid_sn_kb_url(mapped_url, number):
            return {"number": number, "url": mapped_url, "score": score}
    except Exception:
        pass

    return {"number": None, "url": None, "score": score}

def _dedupe_refs(refs: List[Dict]) -> List[Dict[str, Any]]:
    best: Dict[str, Dict[str, Any]] = {}
    for r in refs:
        n, u, s = r.get("number"), r.get("url"), r.get("score")
        if not u:
            continue
        key = (n or u)
        prev = best.get(key)
        if prev is None or (s is not None and (prev.get("score") is None or s > prev["score"])):
            best[key] = {"number": n, "url": u, "score": s}
    return sorted(best.values(), key=lambda x: (-(x["score"] or 0.0), x.get("number") or x.get("url")))

def _corsify(body: Dict, status: int = 200) -> Dict:
    """Wrap ``body`` in a Lambda proxy response with JSON and CORS headers.

    ``body`` is serialized with ``json.dumps``; the allowed origin comes from
    CORS_ORIGIN.
    """
    headers = {
        "Content-Type": "application/json",
        "Access-Control-Allow-Origin": CORS_ORIGIN,
        "Access-Control-Allow-Headers": "Content-Type, Authorization",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
    }
    return {"statusCode": status, "headers": headers, "body": json.dumps(body)}

def _parse_body(event) -> Dict[str, Any]:
    if isinstance(event, dict) and "body" not in event:
        return event
    raw = event.get("body", "")
    if not isinstance(raw, str):
        return raw if isinstance(raw, dict) else {}
    if event.get("isBase64Encoded"):
        try:
            raw = base64.b64decode(raw).decode("utf-8", "ignore")
        except Exception:
            return {}
    s = raw.strip()
    if not s:
        return {}
    try:
        return json.loads(s)
    except Exception:
        try:
            qs = parse_qs(s)
            return {k: (v[0] if isinstance(v, list) and v else v) for k, v in qs.items()}
        except Exception:
            return {}

def _try_rag(kwargs: Dict[str, Any], search_type: str):
    """Call ``retrieve_and_generate``, retrying once with SEMANTIC search.

    If the first call fails with a message containing "HYBRID search type is
    not supported" and ``search_type`` is "HYBRID", the call is repeated with
    ``overrideSearchType`` set to "SEMANTIC". The retry works on a deep copy,
    so the caller's ``kwargs`` (including its nested configuration dicts) is
    never modified. Any other error, or a failure of the retry, propagates.
    """
    import copy  # function-scope import: the module header does not import copy

    try:
        return client.retrieve_and_generate(**kwargs)
    except Exception as e:
        hybrid_unsupported = "HYBRID search type is not supported" in str(e)
        if not (hybrid_unsupported and search_type == "HYBRID"):
            raise
        # A shallow dict() copy would share the nested configuration dicts
        # with the caller and silently rewrite its search type.
        kwargs_retry = copy.deepcopy(kwargs)
        cfg = kwargs_retry["retrieveAndGenerateConfiguration"]["knowledgeBaseConfiguration"]
        cfg["retrievalConfiguration"]["vectorSearchConfiguration"]["overrideSearchType"] = "SEMANTIC"
        return client.retrieve_and_generate(**kwargs_retry)

def _pick_best_url_from_refs(refs: List[Dict[str, Any]]) -> Optional[str]:
    """Return the URL of the first reference that resolves to one, else None.

    References are tried in the given order. A reference whose resolution
    raises is skipped, and any error while iterating ``refs`` yields None.
    """
    def _url_of(ref) -> Optional[str]:
        try:
            return _extract_ref_info(ref).get("url") or None
        except Exception:
            return None

    try:
        return next((link for link in map(_url_of, refs or []) if link), None)
    except Exception:
        return None

def _inject_inline_sources(answer_text: str, citations: List[Dict[str, Any]], superscript: bool = False) -> str:
    """Insert a source marker into ``answer_text`` for each citation.

    For every citation the first resolvable reference URL is used; citations
    without one are skipped. The marker is placed at the end offset of the
    citation's ``generatedResponsePart.textResponsePart.span`` when that span
    is a valid range within ``answer_text``, otherwise at the end of the text.
    Offsets are interpreted relative to ``answer_text`` as passed in.

    Marker format:
      * ``superscript=False``: `` (Source: <url>)``
      * ``superscript=True``: a ``<sup><a ...>[n]</a></sup>`` link, where n
        numbers distinct URLs in citation order and the URL is HTML-escaped
        for use inside the ``href`` attribute.

    Markers that share an insertion offset appear in citation order.
    Returns ``answer_text`` unchanged when there is nothing to insert.
    """
    from html import escape  # function-scope import: html is not imported by the module header

    if not answer_text or not citations:
        return answer_text

    text_len = len(answer_text)
    url_index: Dict[str, int] = {}
    insert_points: List[Tuple[int, str]] = []

    for cit in citations:
        if not isinstance(cit, dict):
            continue
        span = _first(cit, ("generatedResponsePart", "textResponsePart", "span"))
        if not isinstance(span, dict):
            span = {}
        start = span.get("start")
        end = span.get("end")

        url = _pick_best_url_from_refs(cit.get("retrievedReferences") or [])
        if not url:
            continue

        if superscript:
            number = url_index.setdefault(url, len(url_index) + 1)
            # Escape the URL so quotes or ampersands cannot break out of href.
            href = escape(url, quote=True)
            token = f'<sup><a href="{href}" target="_blank" rel="noopener noreferrer">[{number}]</a></sup>'
        else:
            token = f" (Source: {url})"

        if isinstance(start, int) and isinstance(end, int) and 0 <= start <= end <= text_len:
            insert_points.append((end, token))
        else:
            insert_points.append((text_len, token))

    if not insert_points:
        return answer_text

    # Stable ascending sort keeps citation order for equal offsets; the text
    # is then assembled in a single left-to-right pass.
    insert_points.sort(key=lambda point: point[0])
    pieces: List[str] = []
    cursor = 0
    for pos, token in insert_points:
        pieces.append(answer_text[cursor:pos])
        pieces.append(token)
        cursor = pos
    pieces.append(answer_text[cursor:])
    return "".join(pieces)

def _collect_urls_from_response(resp: Dict[str, Any]) -> List[str]:
    """List the distinct source URLs of a response, best-ranked first.

    References are gathered from every citation's ``retrievedReferences`` and
    from a top-level ``retrievedReferences`` list, resolved one by one
    (failures are skipped), de-duplicated and ranked by ``_dedupe_refs``, and
    reduced to their URLs without repeats.
    """
    def _all_refs():
        for cit in (resp.get("citations") or []):
            for ref in (cit.get("retrievedReferences") or []):
                yield ref
        for ref in (resp.get("retrievedReferences") or []):
            yield ref

    resolved: List[Dict[str, Any]] = []
    for ref in _all_refs():
        try:
            resolved.append(_extract_ref_info(ref))
        except Exception:
            pass

    # A dict keeps first-seen order while discarding repeated URLs.
    ordered: Dict[str, None] = {}
    for entry in _dedupe_refs(resolved):
        link = entry.get("url")
        if link:
            ordered.setdefault(link, None)
    return list(ordered)

# Role -> KB mapping.
# Keys are the lower-case role names accepted in requests; values are the
# knowledge base IDs captured from the environment at import time (an unset
# variable leaves an empty string here).
ROLE_TO_KB_ENV: Dict[str, str] = {
    "customers":   CUSTOMERS_KB_ID,
    "reservation": RESERVATION_KB_ID,
    "onesource":   ONESOURCE_KB_ID,
}

def _resolve_kb_id(user_role: Optional[str]) -> Optional[str]:
    """Map a role name to its knowledge base ID.

    The role is stripped and lower-cased before lookup. Returns None for a
    missing or non-string role, an unknown role, or a role whose KB ID is empty.
    """
    if isinstance(user_role, str) and user_role:
        return ROLE_TO_KB_ENV.get(user_role.strip().lower()) or None
    return None

def _scrub_raw(resp: Any) -> Any:
    """
    Remove high-risk fields from raw debug output (e.g., S3 URIs).
    Keep minimal structure for troubleshooting.
    """
    try:
        if not isinstance(resp, dict):
            return {"type": str(type(resp))}
        safe = {}
        for k in ("output", "citations", "retrievedReferences"):
            if k in resp:
                safe[k] = resp[k]
        # Remove S3 uris if present
        def _walk(o, depth=0):
            if depth > 4:
                return None
            if isinstance(o, dict):
                out = {}
                for kk, vv in o.items():
                    if isinstance(vv, str) and vv.startswith("s3://"):
                        out[kk] = "[redacted]"
                    else:
                        out[kk] = _walk(vv, depth + 1) if isinstance(vv, (dict, list)) else vv
                return out
            if isinstance(o, list):
                return [_walk(x, depth + 1) if isinstance(x, (dict, list)) else x for x in o[:50]]
            return o
        return _walk(safe, 0)
    except Exception:
        return {"error": "scrub_failed"}

# ========= main =========
def lambda_handler(event, context):
    """Answer a question from the role's knowledge base and attach sources.

    Request parameters (JSON or form body, or a direct-invocation dict):
      query / inputText (required), userRole / role / user_role (required;
      selects the knowledge base), and optional sessionId, modelArn,
      numberOfResults, overrideSearchType, brandName, useCaseTitle,
      promptTemplate, temperature, topP, maxTokens.

    Responses (all carry CORS headers):
      200  {"answer": str, "chunks": [{"number", "url", "score"}, ...]},
           plus "raw" and "debug" when DEBUG_RAW is enabled
      400  missing query, unknown role, or an unusable client promptTemplate
      500  incomplete configuration, or any unexpected failure. The failure
           is logged with its traceback and the client receives only a
           generic message.

    ``context`` is unused.
    """
    import logging  # function-scope import: the module header does not import logging

    # CORS preflight. REST API (payload v1) events expose "httpMethod";
    # HTTP API / function URL (payload v2) events expose requestContext.http.method.
    if isinstance(event, dict):
        method = event.get("httpMethod") or _first(event, ("requestContext", "http", "method"))
        if method == "OPTIONS":
            return _corsify({"ok": True})

    cfg_err = _validate_runtime_config()
    if cfg_err:
        return _corsify(cfg_err, 500)

    try:
        body = _parse_body(event)
        if not isinstance(body, dict):
            body = {}

        query = _body_text(body, ("query", "inputText"))
        if not query:
            return _corsify({"error": "Missing 'query' or 'inputText'."}, 400)

        user_role_raw = body.get("userRole") or body.get("role") or body.get("user_role")
        user_role = user_role_raw.strip().lower() if isinstance(user_role_raw, str) else None
        kb_id = _resolve_kb_id(user_role)
        if not kb_id:
            return _corsify(
                {
                    "error": "Invalid or missing userRole. Must be one of: customers, reservation, onesource.",
                    "details": {
                        "receivedUserRole": user_role_raw,
                        "allowedRoles": list(ROLE_TO_KB_ENV.keys()),
                    },
                },
                400,
            )

        session_id = _body_text(body, ("sessionId",))
        # Namespace session by role to avoid cross-role mixing.
        # NOTE(review): confirm that Bedrock accepts a prefixed id here; the
        # service normally issues sessionId values itself.
        if session_id and user_role:
            session_id = f"{user_role}:{session_id}"

        model_arn = _body_text(body, ("modelArn",), default=MODEL_ARN)
        if not model_arn:
            return _corsify({"error": "MODEL_ARN must be set."}, 500)

        num_results = _clamp_int(body.get("numberOfResults") or NUM_RESULTS_DEFAULT, 1, 35, NUM_RESULTS_DEFAULT)

        # Anything other than HYBRID (including VECTOR/KEYWORD/FULLTEXT) means SEMANTIC.
        search_type = _body_text(body, ("overrideSearchType",), default=SEARCH_TYPE_DEFAULT or "SEMANTIC").upper()
        if search_type not in ("SEMANTIC", "HYBRID"):
            search_type = "SEMANTIC"

        brand = _body_text(body, ("brandName",), default=BRAND)
        use_case = _body_text(body, ("useCaseTitle",), default=USE_CASE)

        temperature = _clamp_float(body.get("temperature", DEFAULT_TEMPERATURE), 0.0, 1.0, DEFAULT_TEMPERATURE)
        top_p       = _clamp_float(body.get("topP", DEFAULT_TOPP), 0.0, 1.0, DEFAULT_TOPP)
        max_tokens  = _clamp_int(body.get("maxTokens", DEFAULT_MAX_TOKENS), 50, 2048, DEFAULT_MAX_TOKENS)

        custom_tmpl = body.get("promptTemplate")
        has_custom_tmpl = isinstance(custom_tmpl, str) and bool(custom_tmpl)
        prompt_tmpl = custom_tmpl if has_custom_tmpl else PROMPT_TMPL
        try:
            prompt = prompt_tmpl.format(
                brand=brand,
                use_case=use_case,
                query=query,
                sn_instance=SN_INSTANCE,
            )
        except (KeyError, IndexError, ValueError, AttributeError):
            # Unknown placeholder or unbalanced braces. A client-supplied
            # template is a client error; a broken configured template is not.
            if has_custom_tmpl:
                return _corsify({"error": "Invalid 'promptTemplate'."}, 400)
            raise

        rag_cfg = {
            "type": "KNOWLEDGE_BASE",
            "knowledgeBaseConfiguration": {
                "knowledgeBaseId": kb_id,
                "modelArn": model_arn,
                "generationConfiguration": {
                    "inferenceConfig": {
                        "textInferenceConfig": {
                            "maxTokens": max_tokens,
                            "temperature": temperature,
                            "topP": top_p,
                        }
                    }
                },
                "retrievalConfiguration": {
                    "vectorSearchConfiguration": {
                        "numberOfResults": num_results,
                        "overrideSearchType": search_type,
                    }
                },
            },
        }

        kwargs: Dict[str, Any] = {
            "input": {"text": prompt},
            "retrieveAndGenerateConfiguration": rag_cfg,
        }
        if session_id:
            kwargs["sessionConfiguration"] = {"sessionId": session_id}

        resp = _try_rag(kwargs, search_type)

        answer_text = _first(resp, ("output", "text"), default="")
        answer_text = answer_text.strip() if isinstance(answer_text, str) else ""
        citations_list = resp.get("citations") or []

        # Optional inline injection (still safe; only provable SN URLs are used).
        # Best effort: a failure here leaves the answer text without inline markers.
        if citations_list:
            try:
                if INLINE_SUPERSCRIPT:
                    answer_text = _inject_inline_sources(answer_text, citations_list, superscript=True)
                elif INLINE_URLS:
                    answer_text = _inject_inline_sources(answer_text, citations_list, superscript=False)
            except Exception:
                pass

        # Resolve every reference once; both the structured chunks and the
        # Sources list are derived from the same de-duplicated result, so no
        # reference triggers a second metadata/S3 lookup.
        all_refs: List[Any] = [
            ref
            for cit in citations_list
            for ref in (cit.get("retrievedReferences") or [])
        ]
        all_refs.extend(resp.get("retrievedReferences") or [])

        raw_cites: List[Dict[str, Any]] = []
        for ref in all_refs:
            try:
                raw_cites.append(_extract_ref_info(ref))
            except Exception:
                pass
        chunks = _dedupe_refs(raw_cites)
        urls = list(dict.fromkeys(c["url"] for c in chunks if c.get("url")))

        final_answer = answer_text if answer_text else "No exact answer found in the provided results."

        if urls:
            # Always include Sources section (simple + predictable for clients)
            final_answer += "\n\n## Sources\n" + "\n".join(urls)
        else:
            final_answer += "\n\n*Note: No provable source URLs were available from retrieved references.*"

        out: Dict[str, Any] = {
            "answer": final_answer,
            "chunks": chunks,  # always provide structured refs
        }

        if DEBUG_RAW:
            out["raw"] = _scrub_raw(resp)
            out["debug"] = {
                "has_citations": bool(citations_list),
                "num_urls": len(urls),
                "inline_superscript": INLINE_SUPERSCRIPT,
                "disable_url_fallback": DISABLE_URL_FALLBACK,
            }

        return _corsify(out)

    except Exception:
        # Keep the client-facing error generic: exception text from AWS calls
        # can include ARNs, account ids or bucket names. Full details go to
        # the function log instead.
        logging.getLogger(__name__).exception("lambda_handler failed")
        return _corsify({"error": "Internal server error."}, 500)


def _body_text(body: Dict[str, Any], keys: Tuple[str, ...], default: str = "") -> str:
    """Return the first non-blank string stored under ``keys``, stripped.

    Non-string and whitespace-only values are ignored. When no key qualifies,
    ``default`` is returned stripped (or "" if it is not a string).
    """
    for key in keys:
        value = body.get(key)
        if isinstance(value, str):
            value = value.strip()
            if value:
                return value
    return default.strip() if isinstance(default, str) else ""
