# 02 - Hooks and Exit Conditions

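This notebook demonstrates runtime controls for a tool-calling agent loop that you can adapt from this repo:
- the `session_start` hook injects extra context before the first model call
- the `pre_tool_use` hook can allow or deny each tool call before it runs
- the `post_tool_use` hook can append corrective context after a tool returns
- the `stop` hook can block completion and force another loop turn

The loop exits when the `stop` hook accepts a reply that has no tool calls, or after `max_iterations` model calls.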


In [None]:
import json
import random
import re
from typing import Any

import requests

# Ollama chat endpoint on the local machine; every model call in this notebook goes here.
OLLAMA_URL: str = "http://localhost:11434/api/chat"
# Default model tag for requests. TOOLS are sent with each request, so the
# model should support tool calling.
MODEL: str = "llama3.1:8b-instruct"


In [None]:
# Shared, module-level game state read by the tools and passed to the hooks.
# advance_clock updates it in place, so changes persist across loop runs
# until this cell is re-executed.
world_state: dict[str, Any] = dict(
    location="Moonfall Keep",
    time="20:00",  # 24-hour "HH:MM" string, parsed by advance_clock
    threat_level="medium",
)

def get_world_state() -> dict[str, Any]:
    """Return a snapshot of the current world state.

    A shallow copy is returned, so editing the returned dict cannot
    accidentally change the shared module-level ``world_state``. State
    changes should go through dedicated tools such as ``advance_clock``.
    """
    return dict(world_state)

def advance_clock(hours: int) -> dict[str, Any]:
    """Move the world clock forward by ``hours`` and return the new state.

    The clock is the 24-hour ``"HH:MM"`` string in ``world_state["time"]``.
    It wraps past midnight and does not track days (simple toy clock).

    Args:
        hours: Hours to advance. Integer-like strings such as ``"3"`` are
            accepted, because models often emit numeric tool arguments as
            strings. This matches how ``DMHooks.pre_tool_use`` reads the
            same argument with ``int()``.

    Returns:
        A shallow copy of the updated ``world_state``.

    Raises:
        ValueError: If ``hours`` cannot be converted to an int or is
            negative. The state is left unchanged in that case.
    """
    try:
        steps = int(hours)
    except (TypeError, ValueError):
        raise ValueError(f"hours must be an integer, got {hours!r}") from None
    if steps < 0:
        # "Advance" must not rewind time; the modulo below would silently do so.
        raise ValueError(f"hours must be non-negative, got {steps}")

    h, m = map(int, world_state["time"].split(":"))
    world_state["time"] = f"{(h + steps) % 24:02d}:{m:02d}"
    return dict(world_state)

def roll_check(dc: int, modifier: int = 0) -> dict[str, Any]:
    """Roll 1d20 + ``modifier`` against difficulty class ``dc``.

    Both arguments are coerced with ``int()``, so integer-like strings
    emitted by the model (e.g. ``"12"``) work instead of raising TypeError.

    Returns:
        ``{"roll", "modifier", "total", "dc", "success"}``, where
        ``success`` is ``total >= dc``.

    Raises:
        ValueError / TypeError: If an argument is not integer-like.
    """
    dc = int(dc)
    modifier = int(modifier)
    roll = random.randint(1, 20)
    total = roll + modifier
    return {"roll": roll, "modifier": modifier, "total": total, "dc": dc, "success": total >= dc}

def _function_tool(name: str, description: str, parameters: dict[str, Any]) -> dict[str, Any]:
    """Wrap one JSON-schema function description in the ``{"type": "function"}`` envelope."""
    return {
        "type": "function",
        "function": {"name": name, "description": description, "parameters": parameters},
    }


# Tool descriptions sent with every chat request. Each name must match a key
# of TOOL_IMPL, because execute_tool dispatches on it.
# NOTE(review): the advance_clock schema allows up to 24 hours, while
# DMHooks.pre_tool_use denies more than 8. Presumably this is intentional,
# so the pre_tool_use guard has something to block.
TOOLS = [
    _function_tool(
        "get_world_state",
        "Read current world state.",
        {"type": "object", "properties": {}, "required": []},
    ),
    _function_tool(
        "advance_clock",
        "Advance world clock by N hours.",
        {
            "type": "object",
            "properties": {"hours": {"type": "integer", "minimum": 1, "maximum": 24}},
            "required": ["hours"],
        },
    ),
    _function_tool(
        "roll_check",
        "Roll 1d20 plus modifier against a DC.",
        {
            "type": "object",
            "properties": {
                "dc": {"type": "integer"},
                "modifier": {"type": "integer", "default": 0},
            },
            "required": ["dc"],
        },
    ),
]

# Dispatch table: tool name (as advertised in TOOLS) -> Python implementation.
TOOL_IMPL = dict(
    get_world_state=get_world_state,
    advance_clock=advance_clock,
    roll_check=roll_check,
)


In [None]:
class DMHooks:
    def session_start(self, state: dict[str, Any]) -> str | None:
        return (
            "Hook context: keep continuity with prior events and keep danger proportional to threat_level. "
            f"Current threat_level is {state.get('threat_level')}."
        )

    def pre_tool_use(self, tool_name: str, args: dict[str, Any], state: dict[str, Any]) -> dict[str, Any]:
        # Example hard guard: disallow huge time skips in one move
        if tool_name == "advance_clock" and int(args.get("hours", 0)) > 8:
            return {"allow": False, "reason": "Do not skip more than 8 hours in a single turn."}
        return {"allow": True}

    def post_tool_use(self, tool_name: str, tool_result: dict[str, Any], state: dict[str, Any]) -> str | None:
        if not tool_result.get("ok", True):
            return "A tool failed. Recover by explaining the constraint and offering alternatives."
        return None

    def stop(self, assistant_text: str, state: dict[str, Any], stop_hook_active: bool) -> str | None:
        # Require clear user choices before ending turn
        text = (assistant_text or "").lower()
        has_choices = ("what do you do" in text) or ("choice" in text) or ("options" in text)
        if not has_choices:
            return "Before completing, provide 2-3 concrete player choices."

        # Optional second-pass stricter rule
        if stop_hook_active and len(assistant_text.strip()) < 80:
            return "Your response is too brief. Provide richer scene detail before ending."

        return None


hooks = DMHooks()


In [None]:
def ollama_chat(
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str = MODEL,
    timeout: float = 120,
) -> dict[str, Any]:
    """Send one non-streaming chat request to ``OLLAMA_URL`` and return the decoded JSON.

    Args:
        messages: The chat transcript to send.
        tools: Tool descriptions to advertise. Omitted from the payload when
            empty or None.
        model: Model tag to use (defaults to ``MODEL``).
        timeout: Seconds to wait for the server. This used to be a hard-coded
            120; it is now a parameter because slow local models can need more.

    Raises:
        requests.HTTPError: On a non-2xx response (via ``raise_for_status``).
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": False,
    }
    if tools:
        payload["tools"] = tools

    response = requests.post(OLLAMA_URL, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.json()


def extract_tool_calls(message: dict[str, Any]) -> list[dict[str, Any]]:
    """Normalize the tool calls in a chat response message.

    Returns one ``{"id", "name", "arguments"}`` dict per entry in
    ``message["tool_calls"]`` (empty list if there are none).

    ``arguments`` is always a dict. JSON-string arguments are decoded, and
    anything missing, undecodable, or not a JSON object (``null``, a list, a
    number) becomes ``{}``. This lets hooks call ``.get`` on it and lets
    ``execute_tool`` splat it with ``**``. A missing ``id`` becomes
    ``call_<index>``.
    """
    calls: list[dict[str, Any]] = []
    for i, call in enumerate(message.get("tool_calls") or []):
        # "function" may be present but null; treat it like a missing key.
        fn = call.get("function") or {}
        args = fn.get("arguments")
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                args = {}
        if not isinstance(args, dict):
            args = {}
        calls.append({
            "id": call.get("id") or f"call_{i}",
            "name": fn.get("name") or call.get("name"),
            "arguments": args,
        })
    return calls


def execute_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]:
    """Dispatch a tool call through ``TOOL_IMPL`` and wrap the outcome.

    Returns ``{"ok": True, "result": ...}`` on success. Returns
    ``{"ok": False, "error": ...}`` when the tool is unknown or raises any
    exception. Errors are reported back to the model rather than raised,
    so one bad call cannot abort the loop.
    """
    impl = TOOL_IMPL.get(name)
    if not impl:
        return {"ok": False, "error": f"Unknown tool: {name}"}

    try:
        value = impl(**arguments)
    except Exception as exc:
        return {"ok": False, "error": str(exc)}
    return {"ok": True, "result": value}


In [None]:
# DM persona instructions. The model is told to use tools for world facts
# rather than inventing tool results.
SYSTEM_PROMPT = " ".join([
    "You are a tactical and cinematic DnD DM.",
    "Use tools for world checks and world updates.",
    "Do not fabricate tool results.",
])


def _run_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run normalized tool calls through the pre/post hooks; return the messages to append.

    All ``tool`` result messages come first, in call order, followed by any
    ``post_tool_use`` notes. This keeps the tool results directly after the
    assistant message that requested them.
    """
    tool_messages: list[dict[str, Any]] = []
    hook_notes: list[dict[str, Any]] = []
    for call in tool_calls:
        pre = hooks.pre_tool_use(call["name"], call["arguments"], world_state)
        if pre.get("allow", True):
            tool_result = execute_tool(call["name"], call["arguments"])
        else:
            # A vetoed call is reported to the model as a failed tool result.
            tool_result = {"ok": False, "error": pre.get("reason", "Blocked by pre_tool_use")}

        tool_messages.append({
            "role": "tool",
            "tool_name": call["name"],
            "content": json.dumps(tool_result),
        })

        post_ctx = hooks.post_tool_use(call["name"], tool_result, world_state)
        if post_ctx:
            hook_notes.append({"role": "user", "content": f"Hook note: {post_ctx}"})
    return tool_messages + hook_notes


def run_loop_with_hooks(user_prompt: str, max_iterations: int = 12) -> dict[str, Any]:
    """Run the DM tool-calling loop for one player prompt, applying ``hooks``.

    Each iteration sends the transcript to the model:

    * If the reply contains tool calls, each one passes through
      ``hooks.pre_tool_use`` (which may block it) and is executed, and
      ``hooks.post_tool_use`` may add a note. The loop then continues.
    * If the reply has no tool calls, ``hooks.stop`` decides whether the turn
      may end. A returned reason is fed back to the model and the loop
      continues with ``stop_hook_active=True``.

    Args:
        user_prompt: The player's message.
        max_iterations: Maximum number of model calls before giving up.

    Returns:
        A dict with ``status`` (``"completed"`` or ``"max_iterations"``),
        ``final_answer``, a per-iteration ``rounds`` trace, and the full
        ``messages`` transcript (including the final assistant reply).
    """
    messages: list[dict[str, Any]] = [
        {"role": "system", "content": SYSTEM_PROMPT},
    ]

    start_ctx = hooks.session_start(world_state)
    if start_ctx:
        messages.append({"role": "user", "content": start_ctx})

    messages.append({"role": "user", "content": user_prompt})

    rounds: list[dict[str, Any]] = []
    stop_hook_active = False

    for i in range(max_iterations):
        data = ollama_chat(messages, TOOLS)
        message = data["message"]
        # Content can be null on tool-call replies; normalize to a string.
        assistant_text = message.get("content") or ""
        tool_calls = extract_tool_calls(message)

        rounds.append({
            "iteration": i + 1,
            "assistant_text": assistant_text,
            "tool_calls": tool_calls,
            "stop_hook_active": stop_hook_active,
        })

        if not tool_calls:
            # Record the reply before consulting the stop hook. If the hook
            # blocks, the model must see the draft the note refers to. If it
            # allows completion, the transcript ends with the final answer.
            messages.append({"role": "assistant", "content": assistant_text})

            reason = hooks.stop(assistant_text, world_state, stop_hook_active)
            if reason:
                stop_hook_active = True
                messages.append({
                    "role": "user",
                    "content": f"You were about to finish, but a stop hook blocked completion: {reason}",
                })
                continue

            return {
                "status": "completed",
                "final_answer": assistant_text,
                "rounds": rounds,
                "messages": messages,
            }

        messages.append({
            "role": "assistant",
            "content": assistant_text,
            "tool_calls": message.get("tool_calls", []),
        })
        messages.extend(_run_tool_calls(tool_calls))

    return {
        "status": "max_iterations",
        "final_answer": "Stopped due to max iterations.",
        "rounds": rounds,
        "messages": messages,
    }


In [None]:
# Play one scene through the hooked loop (calls the local Ollama server).
result = run_loop_with_hooks("We sneak into Moonfall Keep at night and scout the west tower.")

# Outcome and how many model calls it took.
print(f"status: {result['status']}")
print(f"rounds: {len(result['rounds'])}")
print("\nFINAL ANSWER\n")
print(result["final_answer"])

# One line per iteration: whether the stop hook had already fired, and which tools were requested.
print("\nTRACE\n")
for r in result["rounds"]:
    tool_names = [c["name"] for c in r["tool_calls"]]
    print(f"iter={r['iteration']} stop_hook_active={r['stop_hook_active']} tools={tool_names}")
