In [1]:
import os
import json
import random
from WIQACausalBuilder import WIQACausalBuilder

# Set random seed for reproducibility
random.seed(42)

# Load data
json_path = r'E:\PHD\01\wiqa_filtered_INPARA_EFFECT.json'
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

print(f"Total datapoints in file: {len(data)}")
print(f"Processing first 10 datapoints...\n")

# Select first 10 datapoints
datapoints = data[:300]

# Store results
results = []


Total datapoints in file: 1655
Processing first 10 datapoints...



In [2]:

for idx, datapoint in enumerate(datapoints, 1):
    print("=" * 80)
    print(f"Processing datapoint {idx}/10")
    print("=" * 80)
    print(f"Question: {datapoint['question_stem']}")
    print(f"Gold answer: {datapoint['answer_label']} ({datapoint['answer_label_as_choice']})")
    print()
    
    try:
        # Initialize builder
        wiqa = WIQACausalBuilder(datapoint)
        
        # Step 1: Extract start and target entities
        print("Step 1: Extracting entities...")
        info = wiqa.extract_start_entity()
        start = info["cause_event"]
        target = info["outcome_base"]
        print(f"  Cause: '{start}'")
        print(f"  Outcome base: '{target}'")
        print()
        
        # Step 2: BFS expansion
        print("Step 2: BFS expansion...")
        bfs = wiqa.expand_toward_target(
            start_X=start, 
            target_Y=target, 
            max_depth=5, 
            max_relations_per_node=5
        )
        print(f"  Triples found: {len(bfs['triples'])}")
        print(f"  Nodes visited: {len(bfs['visited'])}")
        print(f"  Close hits: {len(bfs['close_hits'])}")
        print()
        
        # Step 3: Bridge close hits and extract causal chain
        print("Step 3: Bridging and chain extraction...")
        if bfs["close_hits"]:
            triples_with_bridges = wiqa.bridge_close_hits(
                triples=bfs["triples"],
                close_hits=bfs["close_hits"],
                Y=target,
                max_bridge_nodes=3,
            )
        else:
            triples_with_bridges = bfs["triples"]
        
        chain_result = wiqa.get_causal_chain(triples_with_bridges, start_X=start, target_Y=target)
        print(f"  Paths found: {chain_result['num_paths']}")
        print()
        
        # Step 4: Generate description
        print("Step 4: Generating description...")
        description = wiqa.causal_chain_to_text(chain_result, bfs)
        print()
        
        # Step 5: LLM reasoning
        print("Step 5: Final reasoning...")
        reasoning_result = wiqa.reason_with_description(description, chain_result=chain_result)
        
        # Check result
        gold_label = datapoint['answer_label']
        pred_label = reasoning_result['predicted_answer']
        gold_norm = gold_label.strip().lower().replace(" ", "_")
        pred_norm = pred_label.strip().lower().replace(" ", "_")
        is_correct = (pred_norm == gold_norm)
        
        print(f"\nPrediction: {pred_label}")
        print(f"Gold: {gold_label}")
        print(f"Result: {'✓ CORRECT' if is_correct else '✗ WRONG'}")
        
        # Store result
        result_entry = {
            'index': idx,
            'question': datapoint['question_stem'],
            'gold_answer': gold_label,
            'gold_choice': datapoint['answer_label_as_choice'],
            'predicted_answer': pred_label,
            'predicted_choice': reasoning_result['predicted_choice'],
            'is_correct': is_correct,
            'confidence': reasoning_result['confidence'],
            'effect_on_base': reasoning_result.get('effect_on_base', 'N/A'),
            'reasoning': reasoning_result['reasoning'],
            'num_paths': chain_result['num_paths'],
            'num_triples': len(bfs['triples']),
        }
        results.append(result_entry)
        
    except Exception as e:
        print(f"\n✗ ERROR processing datapoint {idx}: {str(e)}")
        import traceback
        traceback.print_exc()
        
        # Store error result
        result_entry = {
            'index': idx,
            'question': datapoint['question_stem'],
            'gold_answer': datapoint['answer_label'],
            'gold_choice': datapoint['answer_label_as_choice'],
            'predicted_answer': 'ERROR',
            'predicted_choice': 'ERROR',
            'is_correct': False,
            'confidence': 'N/A',
            'effect_on_base': 'N/A',
            'reasoning': f'Error: {str(e)}',
            'num_paths': 0,
            'num_triples': 0,
        }
        results.append(result_entry)
    
    print()

# Summary
print("=" * 80)
print("SUMMARY")
print("=" * 80)
correct_count = sum(1 for r in results if r['is_correct'])
total_count = len(results)
accuracy = correct_count / total_count if total_count > 0 else 0

print(f"Total processed: {total_count}")
print(f"Correct: {correct_count}")
print(f"Wrong: {total_count - correct_count}")
print(f"Accuracy: {accuracy:.2%}")
print()

# Display results table
print("Detailed Results:")
print("-" * 80)
for r in results:
    status = "✓" if r['is_correct'] else "✗"
    print(f"{status} [{r['index']}] Gold: {r['gold_answer']:10s} | Pred: {r['predicted_answer']:10s} | Conf: {r['confidence']}")
print("-" * 80)

# Save results to JSON
output_path = 'single_question_demo_results.json'
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2, ensure_ascii=False)
print(f"\nResults saved to: {output_path}")

Processing datapoint 1/10
Question: suppose the female is sterile happens, how will it affect LESS rabbits.
Gold answer: more (A)

Step 1: Extracting entities...
  Cause: 'the female is sterile happens'
  Outcome base: 'rabbits'

Step 2: BFS expansion...
find_causal_relations: X = the female is sterile happens , Y = rabbits
is_same_variable: LLM 判定 rabbit offspring <-> rabbits = part_of
find_causal_relations: X = rabbit offspring , Y = rabbits
  Triples found: 2
  Nodes visited: 3
  Close hits: 1

Step 3: Bridging and chain extraction...
  Paths found: 1

Step 4: Generating description...

Step 5: Final reasoning...
[_final_llm_decision] JSON parse failed after regex extraction. Raw output snippet:
{
  "effect_on_base": "less",
  "confidence": "high",
  "reasoning": "The only causal path shows that the female being sterile suppresses rabbit offspring, which in turn promotes rabbits. This strongly suggests a decrease in the number of rabbits.",
  "scores": {
    "more": 0.0,
    "less":