# Choice Reasoning - LLM驱动的多选项推理实验

本notebook实现了一个完全由LLM驱动的推理流程：
1. 对每个选项（more/less/no_effect）分别调用LLM进行独立思考
2. LLM为每个选项提供支持理由和置信度
3. 结合build_causal_chain、反思机制和meta分析
4. 最终由LLM综合所有选项的分析做出决策

In [10]:
import os, json, random
import pandas as pd
from datetime import datetime
from typing import Dict, Any, List

import importlib
import ollama

import question_parser, ego_expansion_builder
importlib.reload(question_parser)
importlib.reload(ego_expansion_builder)
from question_parser import QuestionParser
from ego_expansion_builder import EgoExpansionCausalBuilder

print('✓ Modules loaded successfully')

✓ Modules loaded successfully


## 配置参数

In [11]:
# ========== User-configurable parameters ==========

# Number of dataset samples to evaluate.
NUM_SAMPLES = 5

# True: print every reasoning detail; False: print only the final results.
PRINT_DETAILS = True

# Confidence threshold (0.0-1.0): only causal relations at or above this
# value are used. Higher is stricter.
CONFIDENCE_THRESHOLD = 0.3

# Model and graph-expansion settings.
MODEL = 'gemma2:27b'
MAX_EXPANSION_DEPTH = 2
MAX_NEIGHBORS_PER_SEED = 5
MAX_RELATIONS_PER_ENTITY = 5

# Echo the active configuration, one line per setting.
for _config_line in (
    'Configuration:',
    f'  - Samples: {NUM_SAMPLES}',
    f'  - Print Details: {PRINT_DETAILS}',
    f'  - Confidence Threshold: {CONFIDENCE_THRESHOLD}',
    f'  - Model: {MODEL}',
):
    print(_config_line)

Configuration:
  - Samples: 5
  - Print Details: True
  - Confidence Threshold: 0.3
  - Model: gemma2:27b


## 初始化组件

In [12]:
# Create the question parser and the causal-graph builder. PRINT_DETAILS is
# passed as `verbose` to both, so it decides whether they log in detail.
PARSER = QuestionParser(verbose=PRINT_DETAILS, model_name=MODEL)

_builder_options = {
    'model_name': MODEL,
    'max_neighbors_per_seed': MAX_NEIGHBORS_PER_SEED,
    'max_expansion_depth': MAX_EXPANSION_DEPTH,
    'max_relations_per_entity': MAX_RELATIONS_PER_ENTITY,
    'verbose': PRINT_DETAILS,
}
BUILDER = EgoExpansionCausalBuilder(**_builder_options)

print('✓ Parser and Builder initialized')

✓ Parser and Builder initialized


## 加载数据

In [13]:
def load_wiqa_local(path='wiqa_train_data.json', limit=10, seed=42):
    """Load a shuffled sample of WIQA questions from a local JSON-lines file.

    Every line of the file is parsed on its own as JSON. A line is skipped
    when it is blank or malformed, when it is not a JSON object, or when it
    lacks a usable question / label pair.

    Args:
        path: Path to the JSON-lines file (read as UTF-8).
        limit: Maximum number of items to return (slice semantics, so
            ``None`` returns every item).
        seed: Seed for the deterministic shuffle.

    Returns:
        A list of ``{'question': ..., 'gold': ...}`` dicts, where ``gold``
        is the stripped, lower-cased label.

    Raises:
        OSError: If the file cannot be opened.
    """
    items = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                # Best-effort loader: ignore lines that are not valid JSON.
                continue
            if not isinstance(obj, dict):
                # Valid JSON that is not an object (number, list, ...) has no fields.
                continue
            q = obj.get('question_stem') or obj.get('question') or ''
            lbl = obj.get('answer_label') or obj.get('label') or ''
            if not isinstance(q, str) or not isinstance(lbl, str):
                continue
            lbl = lbl.strip().lower()
            if q and lbl:
                items.append({'question': q, 'gold': lbl})
    random.Random(seed).shuffle(items)
    return items[:limit]

# Load up to NUM_SAMPLES shuffled examples from the default local WIQA file.
SAMPLES = load_wiqa_local(limit=NUM_SAMPLES)
print(f'✓ Loaded {len(SAMPLES)} samples')

✓ Loaded 5 samples


## 核心推理函数

### 1. 识别观察对象

In [14]:
def _first_json_object(text):
    """Return the first JSON object embedded in *text*, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text[start:])
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        # This "{" did not open a valid object; try the next one.
        start = text.find("{", start + 1)
    return None


def identify_observable_outcome(
    question: str,
    question_structure: Dict[str, Any],
    model: str
) -> Dict[str, Any]:
    """
    Ask the LLM which specific, measurable outcome the question is about.

    Args:
        question: The original question text.
        question_structure: Parsed question structure; the keys
            'intervention', 'target_phrase', 'target_direction' and
            'question_type' are read (missing keys are shown as 'N/A').
        model: Name of the model passed to ``ollama.generate``.

    Returns:
        A dict with keys 'observable_outcome', 'reasoning',
        'alternative_interpretations' and 'raw_response'. If the reply
        contains no parseable JSON object, 'observable_outcome' falls back
        to the target phrase and the raw reply is kept. If the LLM call
        itself fails, 'reasoning' carries the error text and
        'raw_response' is empty.
    """
    prompt = f"""You are a causal reasoning expert. Your task is to identify the SPECIFIC OBSERVABLE OUTCOME being asked about in this question.

QUESTION:
{question}

QUESTION STRUCTURE:
- Intervention/Change: {question_structure.get('intervention', 'N/A')}
- Target Phrase: {question_structure.get('target_phrase', 'N/A')}
- Target Direction: {question_structure.get('target_direction', 'N/A')}
- Question Type: {question_structure.get('question_type', 'N/A')}

TASK:
Clarify what SPECIFIC, MEASURABLE outcome or behavior is being asked about.

Examples:
- "How will it affect MORE vegetables" → What does "MORE vegetables" specifically mean?
  * Eating more vegetables?
  * Growing more vegetables?
  * Buying more vegetables?
  * Having more vegetables available in stores?
  
- "How will it affect LESS water" → What does "LESS water" mean?
  * Using less water?
  * Having less water available?
  * Less rainfall?
  * Less water consumption?

IMPORTANT GUIDELINES:
1. Choose the MOST PLAUSIBLE interpretation based on the intervention context
2. The outcome should be OBSERVABLE and MEASURABLE
3. Consider what makes sense given the domain and intervention
4. Be specific and concrete

Return in JSON format:
{{
  "observable_outcome": "Specific description of what is being measured/observed",
  "reasoning": "Why this interpretation makes most sense given the intervention (2-3 sentences)",
  "alternative_interpretations": ["other possible interpretation 1", "other possible interpretation 2"]
}}

Return ONLY JSON, nothing else."""

    try:
        response = ollama.generate(model=model, prompt=prompt)
        llm_response = response.get("response", "").strip()

        # Decode a whole JSON object rather than regex-matching up to the
        # first "}", which truncates nested objects and strings containing "}".
        result = _first_json_object(llm_response)
        if result is None:
            result = {
                "observable_outcome": question_structure.get('target_phrase', 'unclear'),
                "reasoning": "Could not parse response",
                "alternative_interpretations": []
            }

        return {
            "observable_outcome": result.get("observable_outcome", "unclear"),
            "reasoning": result.get("reasoning", ""),
            "alternative_interpretations": result.get("alternative_interpretations", []),
            "raw_response": llm_response
        }

    except Exception as e:
        # Best-effort boundary around the LLM call: report the failure in
        # the result instead of aborting the whole sample.
        return {
            "observable_outcome": question_structure.get('target_phrase', 'error'),
            "reasoning": f"Error: {str(e)}",
            "alternative_interpretations": [],
            "raw_response": ""
        }

print('✓ Observable outcome identification function defined')

✓ Observable outcome identification function defined


### 2. 构建因果链路径

In [15]:
### 2. 对每个choice进行正向推理和反事实推理

### 3. 对每个choice进行正向推理和反事实推理

In [16]:
def _first_json_object(text):
    """Return the first JSON object embedded in *text*, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text[start:])
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        # This "{" did not open a valid object; try the next one.
        start = text.find("{", start + 1)
    return None


def _bounded_float(value, default, low, high):
    """Coerce *value* to a float clamped to [low, high]; use *default* if it is not numeric."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    if number != number:  # NaN never compares equal to itself
        return default
    return max(low, min(high, number))


def forward_reasoning(
    question: str,
    choice: str,
    observable_outcome: str,
    causal_chains_text: str,
    question_structure: Dict[str, Any],
    model: str
) -> Dict[str, Any]:
    """
    Forward reasoning: assume `choice` is correct and ask the LLM for supporting evidence.

    Args:
        question: The original question text.
        choice: The option being evaluated: 'more', 'less' or 'no_effect'.
        observable_outcome: The previously identified observable outcome.
        causal_chains_text: Text description of the available causal chains.
        question_structure: Parsed question structure; 'intervention',
            'target_phrase' and 'question_type' are read.
        model: Name of the model passed to ``ollama.generate``.

    Returns:
        A dict with keys 'confidence' (always a float in [0, 1]),
        'supporting_evidence', 'causal_path', 'path_strength', 'rationale'
        and 'raw_response'. If the reply has no parseable JSON object the
        confidence is 0.0 and the raw reply becomes the rationale. If the
        LLM call fails, 'path_strength' is 'error'.

    Raises:
        KeyError: If `choice` is not one of the three supported options.
    """
    choice_desc = {
        "more": "MORE (increase/greater)",
        "less": "LESS (decrease/fewer)",
        "no_effect": "NO EFFECT (no change)"
    }
    
    prompt = f"""You are a causal reasoning expert. Perform FORWARD REASONING to evaluate if "{choice}" could be the correct answer.

QUESTION:
{question}

OBSERVABLE OUTCOME (what we are measuring):
{observable_outcome}

QUESTION STRUCTURE:
- Intervention: {question_structure.get('intervention', 'N/A')}
- Target: {question_structure.get('target_phrase', 'N/A')}
- Question Type: {question_structure.get('question_type', 'N/A')}

AVAILABLE CAUSAL CHAINS:
{causal_chains_text}

EVALUATING CHOICE: {choice_desc[choice]}

TASK - FORWARD REASONING:
Assume "{choice}" IS the correct answer. Find evidence that SUPPORTS this choice.

1. Look for causal chains that connect the intervention to the observable outcome
2. Evaluate if these chains support the "{choice}" outcome
3. Consider the strength and directness of the causal path
4. Assess your confidence that "{choice}" is correct based on the evidence

CRITICAL REQUIREMENTS:
- ONLY use the causal chains provided above
- REJECT indirect/weak causal chains (more than 2 hops or involving uncertain steps)
- AVOID speculative reasoning without evidence
- For "no_effect": HIGH confidence only if there's clearly NO plausible causal path

IMPORTANT FOR META-LEVEL QUESTIONS:
- If asking about "MORE X" or "LESS X", you're reasoning about the PHENOMENON, not the entity
- Example: "LESS rabbits" means the phenomenon of "having fewer rabbits"
  * If intervention causes rabbits to decrease → "LESS rabbits" phenomenon INCREASES → answer: MORE

Return in JSON format:
{{
  "confidence": 0.0-1.0 (How confident that "{choice}" is CORRECT),
  "supporting_evidence": ["evidence 1", "evidence 2", ...],
  "causal_path": "Description of the causal path from intervention to outcome",
  "path_strength": "strong|moderate|weak|none",
  "rationale": "Explain your reasoning (3-5 sentences)"
}}

Return ONLY JSON, nothing else."""

    try:
        response = ollama.generate(model=model, prompt=prompt)
        llm_response = response.get("response", "").strip()

        # Decode a whole JSON object rather than regex-matching up to the
        # first "}", which truncates nested objects and strings containing "}".
        result = _first_json_object(llm_response)
        if result is None:
            result = {
                "confidence": 0.0,
                "supporting_evidence": [],
                "causal_path": "Could not parse",
                "path_strength": "none",
                "rationale": llm_response
            }

        return {
            # The model may return the number as a string or out of range;
            # callers add to it and format it with ':.2f', so normalise here.
            "confidence": _bounded_float(result.get("confidence"), 0.0, 0.0, 1.0),
            "supporting_evidence": result.get("supporting_evidence", []),
            "causal_path": result.get("causal_path", ""),
            "path_strength": result.get("path_strength", "unknown"),
            "rationale": result.get("rationale", ""),
            "raw_response": llm_response
        }

    except Exception as e:
        # Best-effort boundary around the LLM call: report the failure in
        # the result instead of aborting the whole sample.
        return {
            "confidence": 0.0,
            "supporting_evidence": [],
            "causal_path": "",
            "path_strength": "error",
            "rationale": f"Error: {str(e)}",
            "raw_response": ""
        }


def _first_json_object(text):
    """Return the first JSON object embedded in *text*, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text[start:])
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        # This "{" did not open a valid object; try the next one.
        start = text.find("{", start + 1)
    return None


def _bounded_float(value, default, low, high):
    """Coerce *value* to a float clamped to [low, high]; use *default* if it is not numeric."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    if number != number:  # NaN never compares equal to itself
        return default
    return max(low, min(high, number))


def counterfactual_reasoning(
    question: str,
    choice: str,
    observable_outcome: str,
    causal_chains_text: str,
    question_structure: Dict[str, Any],
    model: str
) -> Dict[str, Any]:
    """
    Counterfactual reasoning: ask the LLM what happens to the outcome if the intervention does not occur.

    Args:
        question: The original question text.
        choice: The option being evaluated: 'more', 'less' or 'no_effect'.
        observable_outcome: The previously identified observable outcome.
        causal_chains_text: Text description of the available causal chains.
        question_structure: Parsed question structure; 'intervention' and
            'target_phrase' are read.
        model: Name of the model passed to ``ollama.generate``.

    Returns:
        A dict with keys 'test_result', 'counterfactual_scenario',
        'outcome_without_intervention', 'confidence_adjustment' (always a
        float within the -0.3..+0.3 range the prompt requests),
        'rationale' and 'raw_response'. If the reply has no parseable JSON
        object the test result is 'unclear' with a zero adjustment. If the
        LLM call fails, 'test_result' is 'error'.

    Raises:
        KeyError: If `choice` is not one of the three supported options.
    """
    choice_desc = {
        "more": "MORE (increase/greater)",
        "less": "LESS (decrease/fewer)",
        "no_effect": "NO EFFECT (no change)"
    }
    
    prompt = f"""You are a causal reasoning expert. Perform COUNTERFACTUAL REASONING to test if "{choice}" is the correct answer.

QUESTION:
{question}

OBSERVABLE OUTCOME (what we are measuring):
{observable_outcome}

QUESTION STRUCTURE:
- Intervention: {question_structure.get('intervention', 'N/A')}
- Target: {question_structure.get('target_phrase', 'N/A')}

AVAILABLE CAUSAL CHAINS:
{causal_chains_text}

EVALUATING CHOICE: {choice_desc[choice]}

TASK - COUNTERFACTUAL REASONING:
Test the causal relationship by considering the COUNTERFACTUAL scenario.

COUNTERFACTUAL TEST:
Ask: "If the intervention ({question_structure.get('intervention')}) did NOT happen, would the outcome ({observable_outcome}) still change?"

For "{choice}" to be correct:
- If choice is "more" or "less": The intervention MUST cause the outcome to change
  * Counterfactual: No intervention → outcome should NOT change (or change differently)
  * Test PASSES if: intervention is NECESSARY for the outcome change
  
- If choice is "no_effect": The intervention should NOT affect the outcome
  * Counterfactual: No intervention → outcome stays the same
  * Test PASSES if: removing intervention doesn't change anything

EVALUATION CRITERIA:
1. Is there a necessary causal link between intervention and outcome?
2. Could the outcome occur without the intervention?
3. Does removing the intervention eliminate or significantly change the outcome?

Return in JSON format:
{{
  "test_result": "pass|fail|unclear",
  "counterfactual_scenario": "Description of what happens without the intervention",
  "outcome_without_intervention": "What happens to the observable outcome",
  "confidence_adjustment": -0.3 to +0.3 (adjustment to confidence based on this test),
  "rationale": "Explain the counterfactual reasoning (3-5 sentences)"
}}

Return ONLY JSON, nothing else."""

    try:
        response = ollama.generate(model=model, prompt=prompt)
        llm_response = response.get("response", "").strip()

        # Decode a whole JSON object rather than regex-matching up to the
        # first "}", which truncates nested objects and strings containing "}".
        result = _first_json_object(llm_response)
        if result is None:
            result = {
                "test_result": "unclear",
                "counterfactual_scenario": "Could not parse",
                "outcome_without_intervention": "unclear",
                "confidence_adjustment": 0.0,
                "rationale": llm_response
            }

        return {
            "test_result": result.get("test_result", "unclear"),
            "counterfactual_scenario": result.get("counterfactual_scenario", ""),
            "outcome_without_intervention": result.get("outcome_without_intervention", ""),
            # The model may return the number as a string (e.g. "+0.2") or
            # exceed the requested range; callers add it to a confidence
            # and format it with ':+.2f', so normalise here.
            "confidence_adjustment": _bounded_float(
                result.get("confidence_adjustment"), 0.0, -0.3, 0.3
            ),
            "rationale": result.get("rationale", ""),
            "raw_response": llm_response
        }

    except Exception as e:
        # Best-effort boundary around the LLM call: report the failure in
        # the result instead of aborting the whole sample.
        return {
            "test_result": "error",
            "counterfactual_scenario": "",
            "outcome_without_intervention": "",
            "confidence_adjustment": 0.0,
            "rationale": f"Error: {str(e)}",
            "raw_response": ""
        }

print('✓ Forward and counterfactual reasoning functions defined')

✓ Forward and counterfactual reasoning functions defined


### 4. 综合评估并做出最终决策

In [17]:
def _first_json_object(text):
    """Return the first JSON object embedded in *text*, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text[start:])
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        # This "{" did not open a valid object; try the next one.
        start = text.find("{", start + 1)
    return None


def _bounded_float(value, default, low, high):
    """Coerce *value* to a float clamped to [low, high]; use *default* if it is not numeric."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    if number != number:  # NaN never compares equal to itself
        return default
    return max(low, min(high, number))


def make_final_decision(
    question: str,
    observable_outcome: str,
    choice_evaluations: List[Dict[str, Any]],
    model: str
) -> Dict[str, Any]:
    """
    Pick the final answer from the forward and counterfactual results of every choice.

    The per-choice results are summarised into a prompt and the LLM is
    asked to decide. Local scores are used as a fallback when the LLM call
    fails, when its reply has no parseable JSON object, or when the answer
    it names is not one of the three options.

    Args:
        question: The original question text.
        observable_outcome: The identified observable outcome.
        choice_evaluations: One dict per choice with the keys 'choice',
            'forward_reasoning' and 'counterfactual_reasoning'.
        model: Name of the model passed to ``ollama.generate``.

    Returns:
        A dict with keys 'final_answer' ('more', 'less' or 'no_effect'),
        'confidence' (float in [0, 1]), 'reasoning', 'decision_basis' and
        'raw_response'. With no evaluations to fall back on, the answer
        defaults to 'no_effect'.
    """
    # Normalise every evaluation once, so that non-numeric confidences from
    # the LLM cannot crash the summary or the fallback rankings below.
    scored = []
    eval_summary = f"Observable Outcome: {observable_outcome}\n\n"

    for eval_data in choice_evaluations:
        choice = eval_data['choice']
        forward = eval_data['forward_reasoning']
        counterfactual = eval_data['counterfactual_reasoning']

        forward_conf = _bounded_float(forward.get('confidence'), 0.0, 0.0, 1.0)
        adjustment = _bounded_float(counterfactual.get('confidence_adjustment'), 0.0, -1.0, 1.0)
        final_conf = max(0.0, min(1.0, forward_conf + adjustment))  # Clamp to [0, 1]
        test_result = counterfactual.get('test_result', 'unclear')

        scored.append({
            'choice': choice,
            'raw': forward_conf + adjustment,
            'final': final_conf,
            'passed': test_result == 'pass',
        })

        eval_summary += f"CHOICE: {choice.upper()}\n"
        eval_summary += f"  Forward Confidence: {forward_conf:.2f}\n"
        eval_summary += f"  Path Strength: {forward.get('path_strength', 'unknown')}\n"
        eval_summary += f"  Counterfactual Test: {test_result}\n"
        eval_summary += f"  Confidence Adjustment: {adjustment:+.2f}\n"
        eval_summary += f"  Final Confidence: {final_conf:.2f}\n"
        eval_summary += f"  Forward Rationale: {str(forward.get('rationale', ''))[:150]}...\n"
        eval_summary += f"  Counterfactual Rationale: {str(counterfactual.get('rationale', ''))[:150]}...\n\n"

    def best_by(key):
        """Return the first highest-ranked entry, or None when nothing was evaluated."""
        return max(scored, key=key) if scored else None

    def raw_score(entry):
        return entry['raw']

    def boosted_score(entry):
        # Passing the counterfactual test earns a fixed ranking boost.
        return entry['raw'] + (0.5 if entry['passed'] else 0.0)

    prompt = f"""You are a causal reasoning expert. Make a FINAL DECISION based on forward and counterfactual reasoning for all three choices.

QUESTION:
{question}

EVALUATION SUMMARY:
{eval_summary}

DECISION CRITERIA (in priority order):
1. **Counterfactual Test Result**: Choices that PASS the test are strongly preferred
   - "pass" means the causal relationship is validated
   - "fail" means the causal relationship is questionable
   - "unclear" means insufficient information

2. **Final Confidence Score**: After counterfactual adjustment
   - Choose the option with HIGHEST final confidence
   - But heavily weight counterfactual test results

3. **Path Strength**: Prefer "strong" > "moderate" > "weak" > "none"

SPECIAL RULES:
- If "no_effect" PASSES counterfactual test with confidence > 0.6 → STRONGLY prefer it
- If all choices have weak path strength and low confidence → default to "no_effect"
- If a choice FAILS counterfactual test → significantly lower its priority

Return in JSON format:
{{
  "final_answer": "more|less|no_effect",
  "confidence": 0.0-1.0,
  "reasoning": "Explain your decision considering both forward and counterfactual reasoning (3-5 sentences)",
  "decision_basis": "counterfactual_test|highest_confidence|path_strength|default_no_effect"
}}

Return ONLY JSON, nothing else."""

    try:
        response = ollama.generate(model=model, prompt=prompt)
        llm_response = response.get("response", "").strip()
    except Exception as e:
        # Error fallback: the LLM call failed, so rank the choices locally.
        fallback = best_by(raw_score)
        return {
            "final_answer": fallback['choice'] if fallback else "no_effect",
            "confidence": 0.5,
            "reasoning": f"Error fallback: {str(e)}",
            "decision_basis": "error_fallback",
            "raw_response": ""
        }

    result = _first_json_object(llm_response)
    if result is None:
        # Fallback: prefer choices that passed the counterfactual test, then
        # the highest confidence. The reported confidence is the clamped
        # final confidence, not the boosted ranking score.
        fallback = best_by(boosted_score)
        result = {
            "final_answer": fallback['choice'] if fallback else "no_effect",
            "confidence": fallback['final'] if fallback else 0.0,
            "reasoning": "Fallback decision based on scores",
            "decision_basis": "highest_confidence"
        }

    answer = result.get("final_answer", "")
    if isinstance(answer, str):
        # Accept spelling variants such as "No Effect" or "no-effect".
        final_answer = answer.strip().lower().replace(" ", "_").replace("-", "_")
    else:
        final_answer = ""
    if final_answer not in ("more", "less", "no_effect"):
        fallback = best_by(raw_score)
        final_answer = fallback['choice'] if fallback else "no_effect"

    return {
        "final_answer": final_answer,
        "confidence": _bounded_float(result.get("confidence"), 0.0, 0.0, 1.0),
        "reasoning": result.get("reasoning", ""),
        "decision_basis": result.get("decision_basis", "highest_confidence"),
        "raw_response": llm_response
    }

print('✓ Final decision function defined')

✓ Final decision function defined


In [18]:
def predict_with_choice_reasoning_v2(
    question: str,
    gold_answer: str,
    parser,
    builder,
    model: str,
    confidence_threshold: float = 0.3,
    print_details: bool = True
) -> Dict[str, Any]:
    """
    Run the V2 pipeline on one question and compare the prediction with the gold answer.

    The observable outcome is read from the builder's result. Each choice
    is then evaluated with forward and counterfactual reasoning, and a
    final decision is made from those evaluations.

    Args:
        question: The input question.
        gold_answer: The correct answer.
        parser: Question parser; ``parse_question_structure`` is called on it.
        builder: Causal-chain builder; ``build_causal_chain`` and
            ``get_formatted_causal_chains`` are called on it.
        model: LLM model name passed to the reasoning functions.
        confidence_threshold: Minimum edge confidence, used both for
            formatting the chains and for counting high-confidence relations.
        print_details: Whether to print the step-by-step progress.

    Returns:
        A dict with the question, gold answer, prediction, correctness
        flag, confidence, decision basis, high-confidence relation count,
        observable outcome, and a nested 'result' dict holding the
        per-choice evaluations, the final decision and the builder statistics.
    """
    rule = "=" * 60

    # Step 1: parse the question structure.
    if print_details:
        print("\n" + rule)
        print("Question:", question)
        print(rule)
        print("\n[Step 1] Parsing question structure...")

    parsed = parser.parse_question_structure(question)
    intervention = parsed.get('intervention')
    target_entity = parsed.get('target_entity')

    # Step 2: build the causal chain; the builder's result carries the
    # observable-outcome information read below.
    if print_details:
        print(f"  Intervention: {intervention}")
        print(f"  Target: {target_entity}")
        print("\n[Step 2] Building causal chain with observable outcome identification...")

    graph = builder.build_causal_chain(
        question=question,
        intervention=intervention,
        target=target_entity
    )

    observable_info = graph.get('observable_outcome', {})
    observable_outcome = observable_info.get('observable_outcome', target_entity or 'unclear')

    # Step 3: fetch the formatted causal chains.
    if print_details:
        print(f"\n  Observable Outcome: {observable_outcome}")
        print(f"  Reasoning: {observable_info.get('reasoning', '')[:150]}...")
        print(f"  Seeds found: {len(graph.get('seeds', []))}")
        print(f"  Total edges: {len(graph.get('edges', []))}")
        print("\n[Step 3] Formatting causal chains...")

    causal_chains_text = builder.get_formatted_causal_chains(
        result=graph,
        intervention=intervention,
        target=target_entity,
        min_confidence=confidence_threshold,
        max_paths=10
    )

    if print_details:
        print("  Causal Chains:")
        chain_lines = causal_chains_text.split('\n')
        for chain_line in chain_lines[:5]:  # show only the first five lines
            print(f"    {chain_line}")
        if len(chain_lines) > 5:
            print(f"    ... and more")

    # Count the edges whose confidence reaches the threshold.
    high_conf_relations = len([
        edge for edge in graph.get('edges', [])
        if edge.get('confidence', 0) >= confidence_threshold
    ])

    # Step 4: evaluate every choice with two separate LLM calls.
    if print_details:
        print("\n[Step 4] Evaluating each choice...")

    choice_evaluations = []

    for choice in ('more', 'less', 'no_effect'):
        shared_inputs = dict(
            question=question,
            choice=choice,
            observable_outcome=observable_outcome,
            causal_chains_text=causal_chains_text,
            question_structure=parsed,
            model=model,
        )

        # 4.1 Forward reasoning.
        if print_details:
            print(f"\n  --- Evaluating: {choice.upper()} ---")
            print(f"    Forward reasoning...")

        forward_result = forward_reasoning(**shared_inputs)

        # 4.2 Counterfactual reasoning.
        if print_details:
            print(f"      Confidence: {forward_result['confidence']:.2f}")
            print(f"      Path Strength: {forward_result['path_strength']}")
            print(f"    Counterfactual reasoning...")

        counterfactual_result = counterfactual_reasoning(**shared_inputs)

        if print_details:
            print(f"      Test Result: {counterfactual_result['test_result']}")
            print(f"      Confidence Adjustment: {counterfactual_result.get('confidence_adjustment', 0):+.2f}")

        choice_evaluations.append({
            'choice': choice,
            'forward_reasoning': forward_result,
            'counterfactual_reasoning': counterfactual_result
        })

    # Step 5: final decision.
    if print_details:
        print("\n[Step 5] Making final decision...")

    final_decision = make_final_decision(
        question=question,
        observable_outcome=observable_outcome,
        choice_evaluations=choice_evaluations,
        model=model
    )

    prediction = final_decision['final_answer']
    correct = (prediction == gold_answer)

    if print_details:
        print(f"\n  Final Answer: {prediction.upper()}")
        print(f"  Gold Answer: {gold_answer.upper()}")
        print(f"  Result: {'✓ CORRECT' if correct else '✗ WRONG'}")
        print(f"  Confidence: {final_decision['confidence']:.2f}")

    builder_stats = {
        'num_seeds': len(graph.get('seeds', [])),
        'num_entities': len(graph.get('entities', [])),
        'num_edges': len(graph.get('edges', [])),
        'num_chains': len(graph.get('chains', []))
    }

    return {
        'question': question,
        'gold': gold_answer,
        'prediction': prediction,
        'correct': correct,
        'confidence': final_decision['confidence'],
        'decision_basis': final_decision.get('decision_basis', 'unknown'),
        'high_conf_relations': high_conf_relations,
        'observable_outcome': observable_outcome,
        'result': {
            'observable_info': observable_info,
            'choice_evaluations': choice_evaluations,
            'final_decision': final_decision,
            'builder_stats': builder_stats
        }
    }

print('✓ Main prediction function (V2) defined')

# Run the main reasoning loop over every loaded sample.
import traceback

_rule = "=" * 60

print("\n" + _rule)
print("Running Choice Reasoning with Observable Outcome Integration")
print(_rule)

results = []

for i, sample in enumerate(SAMPLES, 1):
    print("\n" + _rule)
    print(f"Processing Sample {i}/{len(SAMPLES)}")
    print(_rule)

    try:
        result = predict_with_choice_reasoning_v2(
            question=sample['question'],
            gold_answer=sample['gold'],
            parser=PARSER,
            builder=BUILDER,
            model=MODEL,
            confidence_threshold=CONFIDENCE_THRESHOLD,
            print_details=PRINT_DETAILS
        )
    except Exception as e:
        print(f"\n⚠️ Error processing sample {i}: {e}")
        traceback.print_exc()

        # Record a placeholder so a failed sample still counts (as wrong)
        # in the statistics; 'no_effect' is the default answer.
        result = {
            'question': sample['question'],
            'gold': sample['gold'],
            'prediction': 'no_effect',
            'correct': False,
            'confidence': 0.0,
            'decision_basis': 'error',
            'high_conf_relations': 0,
            'observable_outcome': 'error',
            'result': {'error': str(e)}
        }

    results.append(result)

print("\n" + _rule)
print("All samples processed!")
print(_rule)

✓ Main prediction function (V2) defined

Running Choice Reasoning with Observable Outcome Integration

Processing Sample 1/5

Question: suppose getting a storm over the coast from the ocean happens, how will it affect MORE erosion by the ocean.

[Step 1] Parsing question structure...
  Intervention: getting a storm over the coast from the ocean happens
  Target: erosion by the ocean

[Step 2] Building causal chain with observable outcome identification...
[17:47:20] 
[17:47:20] STEP 0: Identify Observable Outcome
[17:47:22] 
LLM Response:
```json
{
  "observable_outcome": "increased rate of shoreline erosion",
  "reasoning": "Storms are a known factor contributing to coastal erosion. The question implies a causal relationship between the storm and a measurable increase in the physical process of erosion along the coastline."
}
```
[17:47:22] 
✅ Observable Outcome: increased rate of shoreline erosion
[17:47:22]    Reasoning: Storms are a known factor contributing to coastal erosion. The

## 结果汇总

In [19]:
# Aggregate accuracy, confidence and per-group statistics over `results`.


def _tally(records, field, default):
    """Group records by `field`, counting total and correct per group (first-seen order)."""
    table = {}
    for rec in records:
        bucket = table.setdefault(rec.get(field, default), {'total': 0, 'correct': 0})
        bucket['total'] += 1
        if rec['correct']:
            bucket['correct'] += 1
    return table


def _mean(values):
    """Arithmetic mean of `values`, or 0 when there are none."""
    values = list(values)
    return sum(values) / len(values) if len(values) > 0 else 0


# Overall accuracy
total = len(results)
correct = len([r for r in results if r['correct']])
accuracy = correct / total if total > 0 else 0

print("\n" + "=" * 60)
print("Overall Results")
print("=" * 60)
print(f"Total Samples: {total}")
print(f"Correct: {correct}")
print(f"Accuracy: {accuracy:.2%}")

# Accuracy per gold label (labels with no samples are not printed)
print("\nBy Label:")
for label in ['more', 'less', 'no_effect']:
    label_results = [r for r in results if r['gold'] == label]
    if not label_results:
        continue
    label_correct = len([r for r in label_results if r['correct']])
    label_acc = label_correct / len(label_results)
    print(f"  {label}: {label_correct}/{len(label_results)} = {label_acc:.2%}")

# Accuracy per decision basis
print("\nBy Decision Basis:")
basis_counts = _tally(results, 'decision_basis', 'unknown')
for basis, counts in basis_counts.items():
    acc = counts['correct'] / counts['total'] if counts['total'] > 0 else 0
    print(f"  {basis}: {counts['correct']}/{counts['total']} = {acc:.2%}")

# Mean confidence overall, and for correct and wrong predictions
avg_confidence = _mean(r['confidence'] for r in results)
avg_confidence_correct = _mean(r['confidence'] for r in results if r['correct'])
avg_confidence_wrong = _mean(r['confidence'] for r in results if not r['correct'])

print("\nConfidence Analysis:")
print(f"  Average Confidence: {avg_confidence:.2f}")
print(f"  Correct Predictions: {avg_confidence_correct:.2f}")
print(f"  Wrong Predictions: {avg_confidence_wrong:.2f}")

# Mean number of high-confidence relations per sample
avg_high_conf_rels = _mean(r.get('high_conf_relations', 0) for r in results)
print(f"\nAverage High-Confidence Relations: {avg_high_conf_rels:.1f}")

# Accuracy per observable outcome: the ten most frequent, ties in first-seen order
print("\nObservable Outcomes:")
outcomes = _tally(results, 'observable_outcome', 'unclear')
_ranked_outcomes = sorted(outcomes.items(), key=lambda x: x[1]['total'], reverse=True)
for outcome, counts in _ranked_outcomes[:10]:
    acc = counts['correct'] / counts['total'] if counts['total'] > 0 else 0
    print(f"  {outcome[:50]}: {counts['correct']}/{counts['total']} = {acc:.2%}")


Overall Results
Total Samples: 5
Correct: 3
Accuracy: 60.00%

By Label:
  more: 1/1 = 100.00%
  less: 0/1 = 0.00%
  no_effect: 2/3 = 66.67%

By Decision Basis:
  counterfactual_test: 2/3 = 66.67%
  highest_confidence: 1/2 = 50.00%

Confidence Analysis:
  Average Confidence: 0.95
  Correct Predictions: 0.96
  Wrong Predictions: 0.95

Average High-Confidence Relations: 43.2

Observable Outcomes:
  increased rate of shoreline erosion: 1/1 = 100.00%
  Number of babies born: 1/1 = 100.00%
  Eating fewer vegetables: 0/1 = 0.00%
  the concentration of toxins in the liver: 1/1 = 100.00%
  The amount of recycling processed by the plant: 0/1 = 0.00%


## 详细结果查看

In [20]:
# Build a DataFrame with one display row per result.


def _display_row(r):
    """Format one result for the table; questions over 50 characters are cut and suffixed with '...'."""
    question_text = r['question']
    if len(question_text) > 50:
        question_text = question_text[:50] + '...'
    return {
        '问题': question_text,
        '正确答案': r['gold'],
        '预测': r['prediction'],
        '正确': '✓' if r['correct'] else '✗',
        '置信度': f"{r['confidence']:.2f}"
    }


df_results = pd.DataFrame([_display_row(r) for r in results])

print("\n详细结果表格:")
print(df_results.to_string(index=False))


详细结果表格:
                                                   问题      正确答案        预测 正确  置信度
suppose getting a storm over the coast from the oc...      more      more  ✓ 0.87
suppose use less batter per pancake happens, how w... no_effect no_effect  ✓ 1.00
suppose Having a stomach bug happens, how will it ... no_effect      less  ✗ 0.90
suppose more caterpillar eats happens, how will it... no_effect no_effect  ✓ 1.00
suppose The plant is given a new contract for recy...      less no_effect  ✗ 1.00


## 查看单个样本的完整推理过程

In [21]:
# Index (0-based) of the sample whose complete reasoning trace is printed.
SAMPLE_INDEX = 0


def _clip(text, limit):
    """Return ``text`` as a string cut to at most ``limit`` characters.

    An ellipsis is appended only when something was actually cut off, which
    matches how the question column is shortened in the results table.
    ``None`` (a missing field) is rendered as 'N/A'.
    """
    text = 'N/A' if text is None else str(text)
    return text if len(text) <= limit else text[:limit] + '...'


def _print_banner(title, char='=', width=60):
    """Print ``title`` framed by two rules, preceded by a blank line."""
    rule = char * width
    print("\n" + rule)
    print(title)
    print(rule)


# Reject negative indices explicitly: Python would otherwise wrap around and
# show a sample from the end of the list under a wrong "Sample N" header.
if 0 <= SAMPLE_INDEX < len(results):
    sample_result = results[SAMPLE_INDEX]
    detail = sample_result['result']

    _print_banner(f"Sample {SAMPLE_INDEX + 1} - Complete Reasoning Process")

    print(f"\nQuestion: {sample_result['question']}")
    print(f"Gold Answer: {sample_result['gold']}")
    print(f"Prediction: {sample_result['prediction']}")
    print(f"Result: {'✓ Correct' if sample_result['correct'] else '✗ Wrong'}")

    # Observable outcome identified for the question
    observable_info = detail.get('observable_info', {})
    _print_banner("OBSERVABLE OUTCOME")
    print(f"Identified: {observable_info.get('observable_outcome', 'unclear')}")
    print(f"Reasoning: {observable_info.get('reasoning', 'N/A')}")

    # Size of the causal graph that was built for this sample
    builder_stats = detail.get('builder_stats', {})
    _print_banner("CAUSAL GRAPH STATISTICS")
    print(f"Seeds: {builder_stats.get('num_seeds', 0)}")
    print(f"Entities: {builder_stats.get('num_entities', 0)}")
    print(f"Edges: {builder_stats.get('num_edges', 0)}")
    print(f"Chains: {builder_stats.get('num_chains', 0)}")
    print(f"High-Confidence Relations (≥{CONFIDENCE_THRESHOLD}): {sample_result.get('high_conf_relations', 0)}")

    # Per-choice evaluations. Every field is read with a default so that one
    # incomplete entry does not abort the whole printout.
    _print_banner("CHOICE EVALUATIONS", char='-')

    for eval_data in detail.get('choice_evaluations', []):
        choice = eval_data.get('choice', 'unknown')
        forward = eval_data.get('forward_reasoning', {})
        counterfactual = eval_data.get('counterfactual_reasoning', {})

        forward_conf = forward.get('confidence', 0.0)
        adjustment = counterfactual.get('confidence_adjustment', 0)
        # Forward confidence plus counterfactual adjustment, clamped to [0, 1].
        final_conf = max(0.0, min(1.0, forward_conf + adjustment))

        _print_banner(f"CHOICE: {str(choice).upper()}", width=40)

        print("\n  [Forward Reasoning]")
        print(f"    Confidence: {forward_conf:.2f}")
        print(f"    Path Strength: {forward.get('path_strength', 'N/A')}")
        print(f"    Causal Path: {forward.get('causal_path', 'N/A')}")
        print(f"    Rationale: {_clip(forward.get('rationale'), 200)}")

        if forward.get('supporting_evidence'):
            print("    Evidence:")
            # Only the first three pieces of evidence are listed.
            for i, ev in enumerate(forward['supporting_evidence'][:3], 1):
                print(f"      {i}. {ev}")

        print("\n  [Counterfactual Reasoning]")
        print(f"    Test Result: {counterfactual.get('test_result', 'N/A')}")
        print(f"    Confidence Adjustment: {adjustment:+.2f}")
        print(f"    Final Confidence: {final_conf:.2f}")
        print(f"    Scenario: {_clip(counterfactual.get('counterfactual_scenario'), 150)}")
        print(f"    Rationale: {_clip(counterfactual.get('rationale'), 200)}")

    # Final decision
    _print_banner("FINAL DECISION")

    final = detail.get('final_decision', {})
    print(f"\nAnswer: {str(final.get('final_answer', 'unknown')).upper()}")
    print(f"Confidence: {final.get('confidence', 0.0):.2f}")
    print(f"Decision Basis: {final.get('decision_basis', 'unknown')}")
    print("\nReasoning:")
    print(f"{final.get('reasoning', 'N/A')}")

elif results:
    print(f"Error: Sample index {SAMPLE_INDEX} out of range (0-{len(results)-1})")
else:
    # With no results there is no valid range to report.
    print("Error: results is empty, there is no sample to display")


Sample 1 - Complete Reasoning Process

Question: suppose getting a storm over the coast from the ocean happens, how will it affect MORE erosion by the ocean.
Gold Answer: more
Prediction: more
Result: ✓ Correct

OBSERVABLE OUTCOME
Identified: increased rate of shoreline erosion
Reasoning: Storms are a known factor contributing to coastal erosion. The question implies a causal relationship between the storm and a measurable increase in the physical process of erosion along the coastline.

CAUSAL GRAPH STATISTICS
Seeds: 5
Entities: 27
Edges: 53
Chains: 10
High-Confidence Relations (≥0.3): 53

------------------------------------------------------------
CHOICE EVALUATIONS
------------------------------------------------------------

CHOICE: MORE

  [Forward Reasoning]
    Confidence: 0.00
    Path Strength: none
    Causal Path: 
    Rationale: There are no available causal chains provided to connect a storm over the coast to an increase in shoreline erosion. Without any causal evidence,

## 导出结果

In [22]:
# Save the run (configuration, summary metrics and the per-sample reasoning)
# to a timestamped JSON file.

# Take one timestamp so the file name and the recorded config timestamp
# always refer to the same instant.
export_time = datetime.now()
output_file = f'choice_reasoning_results_{export_time.strftime("%Y%m%d_%H%M%S")}.json'

export_data = {
    'config': {
        'num_samples': NUM_SAMPLES,
        'model': MODEL,
        'confidence_threshold': CONFIDENCE_THRESHOLD,
        'max_expansion_depth': MAX_EXPANSION_DEPTH,
        'max_neighbors_per_seed': MAX_NEIGHBORS_PER_SEED,
        # Recorded together with the other builder limits so the export
        # carries every configured parameter of the run.
        'max_relations_per_entity': MAX_RELATIONS_PER_ENTITY,
        'timestamp': export_time.isoformat()
    },
    'summary': {
        'total': total,
        'correct': correct,
        'accuracy': accuracy,
        'avg_confidence': avg_confidence,
        'avg_high_conf_relations': avg_high_conf_rels
    },
    'results': [{
        'question': r['question'],
        'gold': r['gold'],
        'prediction': r['prediction'],
        'correct': r['correct'],
        'confidence': r['confidence'],
        'decision_basis': r.get('decision_basis', 'unknown'),
        'observable_outcome': r.get('observable_outcome', 'unclear'),
        'high_conf_relations': r.get('high_conf_relations', 0),
        'observable_info': r['result'].get('observable_info', {}),
        'builder_stats': r['result'].get('builder_stats', {}),
        'choice_evaluations': r['result'].get('choice_evaluations', []),
        'final_decision': r['result'].get('final_decision', {})
    } for r in results]
}

# Serialize before opening the file: if some value is not JSON-serializable
# the error is raised here and no truncated file is left on disk.
serialized = json.dumps(export_data, indent=2, ensure_ascii=False)

with open(output_file, 'w', encoding='utf-8') as f:
    f.write(serialized)

print(f"\n结果已保存到: {output_file}")
print(f"包含 {len(results)} 个样本的完整推理过程")


结果已保存到: choice_reasoning_results_20251107_175918.json
包含 5 个样本的完整推理过程
