Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 129 additions & 13 deletions openkb/agent/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
# Prompt templates
# ---------------------------------------------------------------------------

# ``response_format`` value forwarded (via ``**kwargs``) to the litellm
# completion calls to request JSON-object output mode.
# DeepSeek/Qwen require the prompt itself to mention "json" when this kwarg
# is set; the templates below already do.
_JSON_RESPONSE_FORMAT = {"type": "json_object"}

_SYSTEM_TEMPLATE = """\
You are OpenKB's wiki compilation agent for a personal knowledge base.

Expand Down Expand Up @@ -258,6 +262,7 @@ def _llm_call(model: str, messages: list[dict], step_name: str, **kwargs) -> str

response = litellm.completion(model=model, messages=messages, **kwargs)
content = response.choices[0].message.content or ""
_warn_if_truncated(response, step_name, kwargs.get("max_tokens"))

spinner.stop(_format_usage(time.time() - t0, response.usage))
logger.debug("LLM response [%s]:\n%s", step_name, content[:500] + ("..." if len(content) > 500 else ""))
Expand All @@ -274,6 +279,7 @@ async def _llm_call_async(model: str, messages: list[dict], step_name: str, **kw

response = await litellm.acompletion(model=model, messages=messages, **kwargs)
content = response.choices[0].message.content or ""
_warn_if_truncated(response, step_name, kwargs.get("max_tokens"))

elapsed = time.time() - t0
sys.stdout.write(f" {step_name}... {_format_usage(elapsed, response.usage)}\n")
Expand All @@ -282,6 +288,25 @@ async def _llm_call_async(model: str, messages: list[dict], step_name: str, **kw
return content.strip()


def _warn_if_truncated(response, step_name: str, max_tokens: int | None) -> None:
"""Emit a warning when the LLM hit the max_tokens cap.

``json_repair`` will silently salvage the truncated prefix, so without
this the caller can't tell a short response from a cut-off one.
"""
try:
finish_reason = response.choices[0].finish_reason
except (AttributeError, IndexError):
return
if finish_reason != "length":
return
cap = f" (max_tokens={max_tokens})" if max_tokens else ""
logger.warning("LLM [%s] hit length limit%s — output may be truncated.",
step_name, cap)
sys.stdout.write(f" [WARN] {step_name} hit length limit{cap} — output may be truncated.\n")
sys.stdout.flush()


def _parse_json(text: str) -> list | dict:
"""Parse JSON from LLM response, handling fences, prose, and malformed JSON."""
from json_repair import repair_json
Expand All @@ -297,6 +322,49 @@ def _parse_json(text: str) -> list | dict:
return result


def _filter_concept_items(items: list, label: str) -> list[dict]:
"""Keep only dicts that carry a non-empty ``name``; warn about anything else."""
if not isinstance(items, list):
logger.warning("concepts plan: %s was %s, expected list — dropping",
label, type(items).__name__)
return []
valid = [c for c in items if isinstance(c, dict) and isinstance(c.get("name"), str) and c["name"].strip()]
if len(valid) < len(items):
reasons: list[str] = []
for c in items:
if not isinstance(c, dict):
reasons.append(type(c).__name__)
elif not isinstance(c.get("name"), str) or not c["name"].strip():
reasons.append("dict-missing-name")
logger.warning(
"concepts plan: dropped %d malformed %s item(s) (reasons: %s)",
len(items) - len(valid), label, ", ".join(sorted(set(reasons))),
)
return valid


def _require_nonempty_content(content, name: str) -> None:
"""Raise if a concept body is missing or whitespace-only."""
if not isinstance(content, str) or not content.strip():
raise ValueError(f"LLM returned empty content for concept {name!r}")


def _filter_related_slugs(items: list) -> list[str]:
"""Keep only non-empty string slugs; warn about anything else."""
if not isinstance(items, list):
logger.warning("concepts plan: related was %s, expected list — dropping",
type(items).__name__)
return []
valid = [s for s in items if isinstance(s, str) and s.strip()]
if len(valid) < len(items):
bad_types = sorted({type(s).__name__ for s in items if not (isinstance(s, str) and s.strip())})
logger.warning(
"concepts plan: dropped %d malformed related item(s) (types: %s)",
len(items) - len(valid), ", ".join(bad_types),
)
return valid


# ---------------------------------------------------------------------------
# File I/O helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -911,7 +979,7 @@ async def _compile_concepts(
{"role": "user", "content": _CONCEPTS_PLAN_USER.format(
concept_briefs=concept_briefs,
)},
], "concepts-plan", max_tokens=1024)
], "concepts-plan", max_tokens=2048, response_format=_JSON_RESPONSE_FORMAT)

def _write_v1_summary_stripped() -> None:
"""Fallback writer for the v1 summary on early-return paths.
Expand All @@ -935,27 +1003,54 @@ def _write_v1_summary_stripped() -> None:
try:
parsed = _parse_json(plan_raw)
except (json.JSONDecodeError, ValueError) as exc:
logger.warning("Failed to parse concepts plan: %s", exc)
logger.debug("Raw: %s", plan_raw)
preview = plan_raw[:500] + ("..." if len(plan_raw) > 500 else "")
logger.warning(
"Failed to parse concepts plan: %s. Raw output (first 500 chars): %r",
exc, preview,
)
logger.debug("Concepts plan raw output (full, %d chars): %s",
len(plan_raw), plan_raw)
sys.stdout.write(
f" [WARN] concepts plan unparseable for {doc_name} — "
f"no concept pages generated. See log (stderr) for details.\n"
)
sys.stdout.flush()
if rewrite_summary:
_write_v1_summary_stripped()
_update_index(wiki_dir, doc_name, [], doc_brief=doc_brief, doc_type=doc_type)
return

# Fallback: if LLM returns a flat list, treat all items as "create"
# Fallback: if LLM returns a flat list, treat all items as "create".
if isinstance(parsed, list):
plan = {"create": parsed, "update": [], "related": []}
plan = {"create": _filter_concept_items(parsed, "list"),
"update": [], "related": []}
else:
plan = {
"create": parsed.get("create", []),
"update": parsed.get("update", []),
"related": parsed.get("related", []),
"create": _filter_concept_items(parsed.get("create", []), "create"),
"update": _filter_concept_items(parsed.get("update", []), "update"),
"related": _filter_related_slugs(parsed.get("related", [])),
}

create_items = plan["create"]
update_items = plan["update"]
related_items = plan["related"]

# Distinguish "filters dropped everything" from "LLM emitted an empty plan".
if isinstance(parsed, list):
original_total = len(parsed)
else:
original_total = sum(
len(parsed.get(k, [])) if isinstance(parsed.get(k), list) else 0
for k in ("create", "update", "related")
)
post_filter_total = len(create_items) + len(update_items) + len(related_items)
if original_total > 0 and post_filter_total == 0:
sys.stdout.write(
f" [WARN] concepts plan for {doc_name} had {original_total} "
f"item(s), all dropped as malformed — see log (stderr).\n"
)
sys.stdout.flush()

if not create_items and not update_items and not related_items:
if rewrite_summary:
_write_v1_summary_stripped()
Expand Down Expand Up @@ -1010,13 +1105,16 @@ async def _gen_create(concept: dict) -> tuple[str, str, bool, str]:
title=title, doc_name=doc_name,
update_instruction="",
)},
], f"concept: {name}")
], f"concept: {name}", response_format=_JSON_RESPONSE_FORMAT)
try:
parsed = _parse_json(raw)
brief = parsed.get("brief", "")
content = parsed.get("content", raw)
# ``or raw``: ``.get("content", raw)`` returns None for
# ``{"content": null}`` (legal under json_object mode).
content = parsed.get("content") or raw
except (json.JSONDecodeError, ValueError):
brief, content = "", raw
_require_nonempty_content(content, name)
return name, content, False, brief

async def _gen_update(concept: dict) -> tuple[str, str, bool, str]:
Expand All @@ -1042,13 +1140,14 @@ async def _gen_update(concept: dict) -> tuple[str, str, bool, str]:
title=title, doc_name=doc_name,
existing_content=existing_content,
)},
], f"update: {name}")
], f"update: {name}", response_format=_JSON_RESPONSE_FORMAT)
try:
parsed = _parse_json(raw)
brief = parsed.get("brief", "")
content = parsed.get("content", raw)
content = parsed.get("content") or raw
except (json.JSONDecodeError, ValueError):
brief, content = "", raw
_require_nonempty_content(content, name)
return name, content, True, brief

tasks = []
Expand All @@ -1066,9 +1165,11 @@ async def _gen_update(concept: dict) -> tuple[str, str, bool, str]:

results = await asyncio.gather(*tasks, return_exceptions=True)

failure_types: list[str] = []
for r in results:
if isinstance(r, Exception):
logger.warning("Concept generation failed: %s", r)
failure_types.append(type(r).__name__)
continue
name, page_content, is_update, brief = r
pending_writes.append((name, page_content, is_update, brief))
Expand All @@ -1077,6 +1178,20 @@ async def _gen_update(concept: dict) -> tuple[str, str, bool, str]:
if brief:
concept_briefs_map[safe_name] = brief

# Include exception type names inline so the stdout line is
# self-contained — per-failure WARNINGs go to stderr.
written = len(pending_writes)
if written < total:
reason = (
", ".join(sorted(set(failure_types)))
if failure_types else "see log (stderr)"
)
sys.stdout.write(
f" [WARN] {total} concept(s) planned but only {written} written "
f"for {doc_name} ({reason}).\n"
)
sys.stdout.flush()

# Strip unresolved wikilinks from concept bodies before writing. The
# whitelist includes existing files + this round's planned slugs +
# the summary for this document.
Expand Down Expand Up @@ -1212,7 +1327,8 @@ async def compile_short_doc(
# for the plan + concept-generation calls, then rewritten into a final
# v2 (with a whitelist of known wikilink targets) inside
# _compile_concepts before being written to disk.
summary_raw = _llm_call(model, [system_msg, doc_msg], "summary")
summary_raw = _llm_call(model, [system_msg, doc_msg], "summary",
response_format=_JSON_RESPONSE_FORMAT)
try:
summary_parsed = _parse_json(summary_raw)
doc_brief = summary_parsed.get("brief", "")
Expand Down