# Case Study: CoT vs DDXPlusCausalBuilder

使用数据集 `DDXPlus_CausalQA_multistep_meta.jsonl`，对比两种方法在同一批样本上的表现：

- **CoT baseline**：单次调用 LLM，输出 A/B/C；若无法从输出中解析出选项，会再调用一次 LLM 专门抽取答案字母。
- **DDXPlusCausalBuilder**：多步构建因果图（BFS + bridge/seed）→ 抽取路径 → 让 LLM 根据图和路径做最终判断。

说明：本 notebook 默认只跑少量样本做 case study（避免耗时太长）。你可以在配置区修改 `CASE_INDICES` 或 `NUM_CASES`。


In [None]:
from __future__ import annotations

import contextlib
import io
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

import pandas as pd

from DDXPlusCausalBuilder import DDXPlusCausalBuilder

# Dataset file, resolved relative to the current working directory.
DATA_PATH = Path("DDXPlus_CausalQA_multistep_meta.jsonl")
# Model name lookup order: WIQA_MODEL_NAME, then OLLAMA_MODEL, then the built-in default.
MODEL_NAME = os.environ.get("WIQA_MODEL_NAME", os.environ.get("OLLAMA_MODEL", "llama3.1:8b"))

# These parameters strongly affect the speed/quality of DDXPlusCausalBuilder;
# for a case study it is advisable to start with a fairly small configuration.
BUILDER_PARAMS = {
    "bfs_max_depth": 4,
    "bfs_max_relations_per_node": 5,
    "bfs_max_nodes": 50,
    "bfs_beam_width": 8,
    "bridge_max_bridge_nodes": 3,
    "seed_max_parents": 6,
    "chain_max_path_length": 5,
}

# Case selection: when CASE_INDICES is empty, NUM_CASES samples are picked automatically by rule.
NUM_CASES = 5
CASE_INDICES: List[int] = []  # e.g. [0, 1, 2, 10, 42]

VERBOSE_BUILDER = False  # True prints the Builder's debug output directly into the notebook


In [None]:
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    The file is decoded as UTF-8. Each line is stripped of surrounding
    whitespace first; lines that are empty after stripping are skipped.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


# Fail fast with the absolute path so a wrong working directory is easy to spot.
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Dataset not found: {DATA_PATH.resolve()}")

rows = load_jsonl(DATA_PATH)
print("Loaded rows:", len(rows))
# An empty dataset would otherwise surface as a bare IndexError on the preview
# below; report the actual problem instead.
if not rows:
    raise ValueError(f"Dataset is empty: {DATA_PATH.resolve()}")
# Preview the first record (displayed as the cell's output).
rows[0]


In [None]:
# Optional Ollama pre-check.
try:
    import ollama

    info = ollama.list()
    print("Ollama OK")
    # The returned structure may differ slightly between ollama-python versions,
    # so just print a short excerpt of it.
    print(str(info)[:500])
except Exception as e:
    # Best-effort probe: an import or call failure is reported here, not raised.
    print("Ollama not reachable (skip if you only want to read the notebook):", e)


In [None]:
import ollama

# Maps an answer label (including spelling variants of "no effect") to its choice letter.
ANSWER_LABEL_TO_CHOICE = {"more": "A", "less": "B", "no_effect": "C", "no effect": "C", "no_change": "C"}
# Maps a choice letter back to the answer label used for comparison with the gold label.
CHOICE_TO_ANSWER_LABEL = {"A": "more", "B": "less", "C": "no_effect"}


def normalize_label(label: str) -> str:
    """Return *label* trimmed, lower-cased, and with every space replaced by ``_``.

    A falsy label (``None`` or ``""``) yields the empty string.
    """
    if not label:
        return ""
    cleaned = label.strip().lower()
    # Splitting on a literal space and re-joining is a one-for-one replacement.
    return "_".join(cleaned.split(" "))


def extract_choice_from_text(text: str) -> Optional[str]:
    """Extract the chosen option letter (``"A"``, ``"B"`` or ``"C"``) from model output.

    Strategies, tried in order:

    1. An ``answer`` / ``final answer`` statement followed by a choice letter.
       A lowercase letter is only accepted when an explicit ``:`` or ``-``
       separator is present, so prose such as "to answer a question" is not
       mistaken for choice A.
    2. A last line consisting solely of the letter, optionally wrapped in
       punctuation or markdown (``B.``, ``(B)``, ``**B**``).
    3. An ``answer`` / ``final answer`` statement followed by a label word
       (``more``, ``less``, ``no effect``, ``no change``), mapped through
       ``ANSWER_LABEL_TO_CHOICE``.

    Returns:
        The upper-case letter, or ``None`` if nothing could be extracted
        (including for empty or whitespace-only text).
    """
    stripped = text.strip() if text else ""
    if not stripped:
        # Whitespace-only output has no lines; indexing the last line below
        # would raise IndexError.
        return None

    letter_pattern = r"(?:final answer|answer)\s*([:\-]?)\s*([ABC])\b"
    for match in re.finditer(letter_pattern, text, flags=re.IGNORECASE):
        separator, letter = match.group(1), match.group(2)
        # A bare lowercase letter with no separator is most likely ordinary
        # prose (the article "a"), not a declared choice.
        if separator or letter.isupper():
            return letter.upper()

    # Tolerate decoration around a lone letter on the final line.
    last_line = stripped.splitlines()[-1].strip(" \t*_`()[].:")
    if re.fullmatch(r"[ABC]", last_line, flags=re.IGNORECASE):
        return last_line.upper()

    label_match = re.search(
        r"(?:final answer|answer)\s*[:\-]?\s*(more|less|no[_ ]effect|no[_ ]change)\b",
        text,
        flags=re.IGNORECASE,
    )
    if label_match:
        return ANSWER_LABEL_TO_CHOICE.get(normalize_label(label_match.group(1)))

    return None


def build_ddxplus_prompt(question_stem: str) -> str:
    """Render the question followed by the three fixed choices, one per line.

    The returned text ends with a trailing newline.
    """
    prompt_lines = [
        f"Question: {question_stem}",
        "Choice A: more",
        "Choice B: less",
        "Choice C: no effect",
    ]
    return "\n".join(prompt_lines) + "\n"


def _ollama_chat(model: str, prompt: str, *, temperature: float, num_predict: int, seed: int) -> str:
    """Send *prompt* as a single user message via ``ollama.chat`` and return the reply text.

    Args:
        model: Model name passed to ``ollama.chat``.
        prompt: Content of the single user message.
        temperature: Passed as the ``temperature`` option (coerced to float).
        num_predict: Passed as the ``num_predict`` option (coerced to int).
        seed: Passed as the ``seed`` option (coerced to int).

    Returns:
        The ``content`` of the response's ``message``, or ``""`` when the
        message or its content is missing or ``None``.
    """
    resp = ollama.chat(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        options={"temperature": float(temperature), "num_predict": int(num_predict), "seed": int(seed)},
    )
    message = resp.get("message") or {}
    content = message.get("content")
    # Guard against a present-but-None content: str(None) would hand the
    # literal text "None" to the answer parser.
    return "" if content is None else str(content)


def _force_extract_choice(model: str, question_prompt: str, reasoning_text: str, *, seed: int) -> str:
    """Ask the model to reduce earlier reasoning to a single choice letter.

    Builds an "answer extractor" prompt from the question and the reasoning,
    queries the model at temperature 0 with an 8-token budget, and parses the
    reply with ``extract_choice_from_text``.

    Returns:
        The extracted letter, or ``""`` when no letter could be parsed.
    """
    sections = [
        "You are an answer extractor.",
        "Given the question and a model's reasoning, output ONLY one letter: A, B, or C.",
        "",
        f"{question_prompt}",
        "Reasoning:",
        f"{reasoning_text}",
        "",
        "Output:",
    ]
    reply = _ollama_chat(model, "\n".join(sections), temperature=0.0, num_predict=8, seed=seed)
    letter = extract_choice_from_text(reply)
    return letter if letter else ""


def run_cot_case(datapoint: Dict[str, Any], *, model_name: str, seed: int) -> Dict[str, Any]:
    """Run the chain-of-thought baseline on one datapoint.

    The model is prompted once with the question and a fixed output format.
    If no choice can be parsed from its output, a second extraction call is
    made with ``seed + 10_000``.

    Args:
        datapoint: Row read for ``question_stem`` and ``answer_label``.
        model_name: Model used for both the reasoning and the extraction call.
        seed: Sampling seed for the reasoning call.

    Returns:
        A dict with keys ``method``, ``choice``, ``pred_label``,
        ``is_correct`` and ``raw_output``.
    """
    question_prompt = build_ddxplus_prompt(str(datapoint.get("question_stem", "")))
    instructions = (
        "[CoT]\n"
        "Guidance: Use chain-of-thought with a minimal causal graph.\n"
        "1) Construct a minimal causal graph.\n"
        "2) Reason briefly how changing the cause affects the outcome.\n"
        "3) Choose the best option.\n\n"
    )
    output_spec = (
        "Output format:\n"
        "Causal graph: <comma-separated edges>\n"
        "Reasoning: <1-4 sentences>\n"
        "Final answer: <A|B|C>\n"
    )
    prompt = f"{instructions}{question_prompt}\n\n{output_spec}"
    raw_output = _ollama_chat(model_name, prompt, temperature=0.0, num_predict=512, seed=seed)

    parsed = extract_choice_from_text(raw_output)
    if parsed is None:
        # Parsing failed: fall back to a dedicated extraction call.
        parsed = _force_extract_choice(model_name, question_prompt, raw_output, seed=seed + 10_000)
    final_choice = parsed or ""

    predicted_label = CHOICE_TO_ANSWER_LABEL.get(final_choice, "")
    expected = normalize_label(str(datapoint.get("answer_label", "")))
    predicted = normalize_label(predicted_label)
    # An empty prediction never counts as correct, even against an empty gold label.
    matches_gold = predicted != "" and predicted == expected

    return {
        "method": "CoT",
        "choice": final_choice,
        "pred_label": predicted_label,
        "is_correct": matches_gold,
        "raw_output": raw_output,
    }


def run_builder_case(datapoint: Dict[str, Any], *, model_name: str, params: Dict[str, Any], verbose: bool) -> Dict[str, Any]:
    """Run the DDXPlusCausalBuilder pipeline on one datapoint.

    Pipeline, in order: extract the start/target entities, expand toward the
    target, optionally bridge close hits, optionally seed parents of the
    target, extract causal chains (with a fallback to "exact_target" close-hit
    nodes when no path is found), render the chain as text, and let the
    builder reason over that description.

    Args:
        datapoint: Row passed to the builder; also read for ``cause_event``,
            ``outcome_base`` (fallbacks) and ``answer_label`` (gold label).
        model_name: Model name forwarded to ``DDXPlusCausalBuilder``.
        params: Tuning values. Keys read here (with their fallbacks):
            ``bfs_max_depth`` (4), ``bfs_max_relations_per_node`` (5),
            ``bfs_max_nodes`` (50), ``bfs_beam_width`` (8),
            ``bridge_max_bridge_nodes`` (0 = skip bridging),
            ``seed_max_parents`` (0 = skip seeding),
            ``chain_max_path_length`` (5).
        verbose: When False, everything printed to stdout inside the pipeline
            is captured and returned as ``debug_log``; when True, stdout is
            left alone and ``debug_log`` stays empty.

    Returns:
        A dict with keys ``method``, ``choice``, ``pred_label``,
        ``is_correct``, ``description``, ``reasoning``, ``num_triples``,
        ``num_visited``, ``num_paths`` and ``debug_log``. If any step raises,
        a result with empty/zero fields and ``is_correct=False`` is returned
        instead, with the message appended to ``debug_log`` and stored under
        the additional ``error`` key. Exceptions are never propagated.
    """
    builder = DDXPlusCausalBuilder(datapoint, model_name=model_name)

    # Capture target for the builder's stdout output when not running verbosely.
    buf = io.StringIO()
    ctx = contextlib.nullcontext() if verbose else contextlib.redirect_stdout(buf)
    try:
        with ctx:
            info = builder.extract_start_entity()
            # Prefer the builder's extracted entities; fall back to the datapoint's own fields.
            start = str(info.get("cause_event") or datapoint.get("cause_event") or "").strip()
            target = str(info.get("outcome_base") or datapoint.get("outcome_base") or "").strip()

            bfs = builder.expand_toward_target(
                start_X=start,
                target_Y=target,
                max_depth=int(params.get("bfs_max_depth", 4)),
                max_relations_per_node=int(params.get("bfs_max_relations_per_node", 5)),
                max_nodes=int(params.get("bfs_max_nodes", 50)),
                beam_width=int(params.get("bfs_beam_width", 8)),
            )
            # Store the BFS node-to-target relations on the builder
            # (presumably read by later builder methods — TODO confirm).
            try:
                builder.last_node_rel_to_target = bfs.get("node_rel_to_target", {}) if isinstance(bfs, dict) else {}
            except Exception:
                builder.last_node_rel_to_target = {}

            # Work on copies so the lists below can be extended without touching the BFS result.
            triples = list((bfs or {}).get("triples", []) or [])
            close_hits = list((bfs or {}).get("close_hits", []) or [])

            # Step 3: bridge close hits (optional)
            max_bridge = int(params.get("bridge_max_bridge_nodes", 0) or 0)
            if close_hits and max_bridge > 0:
                triples = builder.bridge_close_hits(triples=triples, close_hits=close_hits, Y=target, max_bridge_nodes=max_bridge)

            # Step 3: seed target parents (optional)
            max_parents = int(params.get("seed_max_parents", 0) or 0)
            if max_parents > 0:
                seed_edges = builder.find_target_parents(target, max_parents=max_parents)
                # Deduplicate on (head, relation, tail); existing triples may be dicts or sequences.
                existing = set()
                for e in triples:
                    if isinstance(e, dict):
                        existing.add((e.get("head", ""), e.get("relation", ""), e.get("tail", "")))
                    elif isinstance(e, (list, tuple)) and len(e) >= 3:
                        existing.add((e[0], e[1], e[2]))
                # NOTE(review): seed edges are accessed with .get(), i.e. assumed to be dicts.
                for e in seed_edges or []:
                    key = (e.get("head", ""), e.get("relation", ""), e.get("tail", ""))
                    if key in existing:
                        continue
                    triples.append(e)
                    existing.add(key)

            # Step 3: extract causal chain
            chain_result = builder.get_causal_chain(
                triples,
                start_X=start,
                target_Y=target,
                max_path_length=int(params.get("chain_max_path_length", 5)),
            )

            # Fallback: if no path, try using exact_target close-hit node as alternative target (same as pipeline).
            if (chain_result.get("num_paths", 0) or 0) == 0 and close_hits:
                best_alt = None
                for hit in close_hits:
                    bfs_eq = str(hit.get("bfs_equivalence") or "").strip().lower()
                    if bfs_eq != "exact_target":
                        continue
                    alt_target = str(hit.get("node") or "").strip()
                    if not alt_target:
                        continue
                    alt_chain = builder.get_causal_chain(
                        triples,
                        start_X=start,
                        target_Y=alt_target,
                        max_path_length=int(params.get("chain_max_path_length", 5)),
                    )
                    if (alt_chain.get("num_paths", 0) or 0) <= 0:
                        continue
                    # Record which node stood in for the original target.
                    alt_chain["mapped_target"] = alt_target
                    alt_chain["original_target"] = target
                    if best_alt is None:
                        best_alt = alt_chain
                        continue
                    # Keep the alternative with the shortest path; an unknown length loses to a known one.
                    cur_len = best_alt.get("shortest_path_length")
                    new_len = alt_chain.get("shortest_path_length")
                    if cur_len is None or (new_len is not None and new_len < cur_len):
                        best_alt = alt_chain
                if best_alt is not None:
                    chain_result = best_alt

            description = builder.causal_chain_to_text(chain_result, bfs)
            reasoning = builder.reason_with_description(description, chain_result=chain_result)

        pred_label = str(reasoning.get("predicted_answer", "") or "")
        pred_choice = str(reasoning.get("predicted_choice", "") or "")
        gold_label = normalize_label(str(datapoint.get("answer_label", "")))
        pred_norm = normalize_label(pred_label)
        # An empty prediction never counts as correct.
        is_correct = bool(pred_norm) and pred_norm == gold_label

        return {
            "method": "DDXPlusCausalBuilder",
            "choice": pred_choice,
            "pred_label": pred_label,
            "is_correct": bool(is_correct),
            "description": description,
            "reasoning": str(reasoning.get("reasoning", "") or ""),
            # Counted from the BFS result, not from the local `triples` list built above.
            "num_triples": int(len((bfs or {}).get("triples", []) or [])),
            "num_visited": int(len((bfs or {}).get("visited", []) or [])),
            "num_paths": int(chain_result.get("num_paths", 0) or 0),
            "debug_log": buf.getvalue(),
        }
    except Exception as e:
        # Any failure in the pipeline is turned into an "incorrect" result so the case study keeps running.
        return {
            "method": "DDXPlusCausalBuilder",
            "choice": "",
            "pred_label": "",
            "is_correct": False,
            "description": "",
            "reasoning": "",
            "num_triples": 0,
            "num_visited": 0,
            "num_paths": 0,
            "debug_log": buf.getvalue() + f"\nERROR: {e}",
            "error": str(e),
        }


In [None]:
def auto_pick_cases(rows: List[Dict[str, Any]], k: int) -> List[int]:
    """Pick up to *k* row indices, preferring harder / counter-intuitive questions.

    A row is preferred when its lower-cased ``question_stem`` contains
    `` not `` or ``less probability``. Preferred rows come first, in dataset
    order. If there are fewer than *k* of them, the remaining slots are
    filled with the other rows in dataset order, so the preferred rows are
    kept rather than discarded.

    Args:
        rows: Dataset rows.
        k: Maximum number of indices to return; values <= 0 yield ``[]``.

    Returns:
        At most ``min(k, len(rows))`` distinct indices into *rows*.
    """
    if k <= 0:
        return []

    preferred: List[int] = []
    for i, row in enumerate(rows):
        stem = str(row.get("question_stem", "")).lower()
        if " not " in stem or "less probability" in stem:
            preferred.append(i)

    if len(preferred) >= k:
        return preferred[:k]

    # Not enough preferred rows: top up with the rest instead of dropping them.
    already_chosen = set(preferred)
    filler = [i for i in range(len(rows)) if i not in already_chosen]
    return (preferred + filler)[:k]


# Explicitly configured indices take precedence; otherwise pick cases automatically.
if CASE_INDICES:
    case_indices = CASE_INDICES
else:
    case_indices = auto_pick_cases(rows, NUM_CASES)
case_indices


In [None]:
# One flat record per case, combining the gold answer with both methods' results.
runs: List[Dict[str, Any]] = []

for idx in case_indices:
    dp = rows[idx]
    gold_label = normalize_label(str(dp.get("answer_label", "")))
    gold_choice = str(dp.get("answer_label_as_choice", ""))

    # The CoT seed varies with the row index.
    cot = run_cot_case(dp, model_name=MODEL_NAME, seed=42 + idx)
    builder = run_builder_case(dp, model_name=MODEL_NAME, params=BUILDER_PARAMS, verbose=VERBOSE_BUILDER)

    # Key insertion order determines the DataFrame column order below.
    record: Dict[str, Any] = {
        "idx": idx,
        "question": dp.get("question_stem", ""),
        "gold_label": gold_label,
        "gold_choice": gold_choice,
    }
    record.update(
        cot_choice=cot.get("choice", ""),
        cot_pred=normalize_label(str(cot.get("pred_label", ""))),
        cot_correct=bool(cot.get("is_correct", False)),
        cot_raw_output=cot.get("raw_output", ""),
    )
    record.update(
        builder_choice=builder.get("choice", ""),
        builder_pred=normalize_label(str(builder.get("pred_label", ""))),
        builder_correct=bool(builder.get("is_correct", False)),
        builder_num_triples=int(builder.get("num_triples", 0)),
        builder_num_paths=int(builder.get("num_paths", 0)),
        builder_description=builder.get("description", ""),
        builder_reasoning=builder.get("reasoning", ""),
        builder_debug_log=builder.get("debug_log", ""),
        builder_error=builder.get("error", ""),
    )
    runs.append(record)

df = pd.DataFrame(runs)

# Compact view: identifiers and gold answer, then choice/pred/correct per method, then graph sizes.
summary_cols = ["idx", "gold_label", "gold_choice"]
for prefix in ("cot", "builder"):
    summary_cols += [f"{prefix}_choice", f"{prefix}_pred", f"{prefix}_correct"]
summary_cols += ["builder_num_triples", "builder_num_paths"]
df[summary_cols]


In [None]:
def _acc(flag_series: pd.Series) -> float:
    if flag_series.empty:
        return 0.0
    return float(flag_series.mean())


# Report the accuracy of each method over the selected case subset.
for method_name, flag_column in (("CoT", "cot_correct"), ("DDXPlusCausalBuilder", "builder_correct")):
    print(f"{method_name} accuracy (case subset):", _acc(df[flag_column]))


## 逐条对比（详细输出）

下面会逐条打印每个样本的题目、gold 标签、CoT 原始输出，以及 Builder 的 description 和 reasoning（若有捕获的 debug log 或错误信息，也会一并打印），方便做定性分析。


In [None]:
heavy_rule = "=" * 100
light_rule = "-" * 100

# Dump every case in full: question, gold answer, then each method's output.
for record in runs:
    case_idx = int(record.get("idx", -1))
    gold_text = record.get("gold_label", "")
    gold_letter = record.get("gold_choice", "")

    print(heavy_rule)
    print(f"IDX={case_idx} | gold={gold_text} ({gold_letter})")
    print("Question:", record.get("question", ""), sep="\n")

    # --- CoT baseline ---
    print(
        "\n[CoT] pred=",
        record.get("cot_pred", ""),
        "choice=",
        record.get("cot_choice", ""),
        "correct=",
        record.get("cot_correct", False),
    )
    print(light_rule)
    print(record.get("cot_raw_output", ""))

    # --- DDXPlusCausalBuilder ---
    print(
        "\n[DDXPlusCausalBuilder] pred=",
        record.get("builder_pred", ""),
        "choice=",
        record.get("builder_choice", ""),
        "correct=",
        record.get("builder_correct", False),
    )
    triple_count = record.get("builder_num_triples", 0)
    path_count = record.get("builder_num_paths", 0)
    print(f"triples={triple_count} paths={path_count}")
    print(light_rule)
    print("[Description]")
    print(record.get("builder_description", ""))
    print("\n[Reasoning]")
    print(record.get("builder_reasoning", ""))

    # The captured log is shown only when the builder was not already printing live;
    # it is cut to the first 2000 characters.
    captured_log = record.get("builder_debug_log")
    if captured_log and not VERBOSE_BUILDER:
        print("\n[Builder debug log] (captured)")
        print(str(captured_log)[:2000])

    failure = record.get("builder_error")
    if failure:
        print("\n[Builder error]")
        print(failure)
