# Batch Evaluation: WIQA Dataset Comparison

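This notebook scores LLM-based predictors against the gold labels of the WIQA validation split:
1. **Gold Label** - Ground truth from the dataset
2. **Baseline** - Direct LLM prediction without causal triples
3. **Meta-Informed LLM** and **Combined-Context LLM** - predictors from `llm_predictors`, which are given the question parser and the causal-graph builder

A causal-triple **Pipeline** (graph building → ranking → decision, `predict_pipeline`) is also defined below, but the main evaluation loop does not run it. Some saved outputs further down (those showing `Pipeline` columns) were produced by an earlier version of the loop; run all cells again to refresh them.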

Outputs accuracy and detailed results for comparison.

In [1]:
import sys, os, importlib, json
from tqdm import tqdm
import pandas as pd
from datetime import datetime

sys.path.append(os.path.abspath('01'))

from datasets import load_dataset
import ollama

import semantic_ranker, triple_ranker, triple_selector, effect_decider, ego_expansion_builder
importlib.reload(semantic_ranker)
importlib.reload(triple_ranker)
importlib.reload(triple_selector)
importlib.reload(effect_decider)
importlib.reload(ego_expansion_builder)
from ego_expansion_builder import EgoExpansionCausalBuilder

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM / dataset
MODEL = 'gemma2:27b'            # Ollama model used by every predictor in this notebook
SPLIT = 'validation'            # WIQA split to evaluate
CONFIDENCE_THRESHOLD = 0.7      # NOTE(review): not referenced by the visible cells

# Triple-ranking knobs (passed to triple_ranker.rank_triples by predict_pipeline)
NUM_VARIATIONS = 10
TOP_M = 3
BACKEND = 'auto'
KEEP_FRACTION = 0.5             # NOTE(review): not referenced by the visible cells

# EgoExpansionCausalBuilder settings
MAX_EXPANSION_DEPTH = 2
MAX_NEIGHBORS_PER_SEED = 3
MAX_RELATIONS_PER_ENTITY = 3

# One shared builder, reused by predict_pipeline and by the LLM predictors
BUILDER = EgoExpansionCausalBuilder(
    model_name=MODEL,
    max_expansion_depth=MAX_EXPANSION_DEPTH,
    max_neighbors_per_seed=MAX_NEIGHBORS_PER_SEED,
    max_relations_per_entity=MAX_RELATIONS_PER_ENTITY,
)

# Batch settings
NUM_SAMPLES = 100   # Set to None to process all samples
START_INDEX = 0
SAVE_INTERVAL = 10  # Save results every N samples

samples_label = NUM_SAMPLES if NUM_SAMPLES else 'ALL'
print(f"Configuration loaded. Will process {samples_label} samples from {SPLIT} split.")

  from .autonotebook import tqdm as notebook_tqdm


Configuration loaded. Will process 100 samples from validation split.


In [2]:
# Load dataset
# NOTE(review): trust_remote_code=True lets `datasets` execute the loading script
# published with allenai/wiqa; keep it only if that script is trusted.
ds = load_dataset('allenai/wiqa', split=SPLIT, trust_remote_code=True)
print(f"Total samples in {SPLIT} split: {len(ds)}")

# Determine sample range. Compare against None explicitly so NUM_SAMPLES = 0
# selects nothing rather than expanding to the whole split.
if NUM_SAMPLES is None:
    end_index = len(ds)
else:
    end_index = min(START_INDEX + NUM_SAMPLES, len(ds))
# Never let the end fall before the start (e.g. START_INDEX beyond the split).
end_index = max(end_index, START_INDEX)
print(f"Will process samples {START_INDEX} to {end_index-1} (total: {end_index-START_INDEX} samples)")

Total samples in validation split: 6894
Will process samples 0 to 99 (total: 100 samples)


In [3]:
# Helper functions
def get_question(ex):
    """Return the question text of a WIQA example, or '' when none is present.

    Candidate fields are tried in a fixed priority order and the first one with a
    truthy value is used. A dict value is unwrapped via its 'stem' entry when it
    has one; whatever remains is converted with ``str``.
    """
    candidate_fields = ('question', 'question_stem', 'query', 'what_if', 'question_text')
    found = next((ex[field] for field in candidate_fields if field in ex and ex[field]), None)
    if found is None:
        return ''
    if isinstance(found, dict) and 'stem' in found:
        found = found['stem']
    return str(found)

def get_label(ex):
    """Return the example's gold label, stripped and lower-cased, or None.

    The first of 'answer_label', 'label', 'effect_label' that is present and not
    None is used; its value is converted with ``str`` before normalising case.
    """
    present = (
        ex[field]
        for field in ('answer_label', 'label', 'effect_label')
        if field in ex and ex[field] is not None
    )
    raw = next(present, None)
    return None if raw is None else str(raw).strip().lower()

def normalize_label(lbl):
    """Map a raw label onto one of 'more', 'less' or 'no_effect'.

    Matching is case-insensitive and tolerant of the decoration LLM replies often
    carry. Surrounding whitespace, quotes, markdown emphasis, brackets and
    trailing punctuation are stripped ("**More**", "less." -> 'more', 'less').
    Spaces and hyphens are treated like underscores ("no effect",
    "no-effect" -> 'no_effect').

    Args:
        lbl: Any value (converted with ``str``); ``None`` is passed through.

    Returns:
        'more', 'less' or 'no_effect', or None if ``lbl`` is None or unrecognised.
    """
    if lbl is None:
        return None
    text = str(lbl).lower().strip(" \t\r\n.,;:!?'\"`*_()[]")
    # Collapse any run of whitespace/hyphens into a single underscore.
    text = "_".join(text.replace("-", " ").split())
    return text if text in ('more', 'less', 'no_effect') else None

print("Helper functions defined.")

Helper functions defined.


In [4]:
# Method 1: Baseline - Direct LLM prediction
def _first_label_mention(text):
    """Return the label mentioned earliest in free-form LLM text, or None.

    Matching uses whole words, so "unless"/"regardless" are not read as "less".
    The earliest mention wins, so a reply such as
    "no effect (more or less unchanged)" resolves to 'no_effect' rather than 'more'.
    """
    import re  # local import keeps this cell self-contained

    patterns = {
        # "no effect", "no_effect", "no significant effect", "no side-effects"
        'no_effect': r"\bno(?:_|\b)[\w\s-]*?effects?\b",
        'more': r"\bmore\b",
        'less': r"\bless\b",
    }
    positions = {}
    for label, pattern in patterns.items():
        match = re.search(pattern, text)
        if match:
            positions[label] = match.start()
    return min(positions, key=positions.get) if positions else None


def predict_baseline(question, model=MODEL):
    """Direct LLM prediction without causal triple pipeline.

    Args:
        question: WIQA question text ("suppose X happens, how will it affect Y").
        model: Ollama model name passed to ``ollama.generate``.

    Returns:
        'more', 'less' or 'no_effect' when the reply can be parsed, 'uncertain'
        when no label is recognisable, or 'error' if the LLM call raised.
    """
    try:
        baseline_prompt = f"""Based on the following question, directly predict whether the effect is MORE, LESS, or NO_EFFECT.

Question: {question}

Consider:
1. What is the initial state/change mentioned?
2. What is the outcome/effect being asked about?
3. Does logic or common sense suggest the outcome increases (MORE), decreases (LESS), or stays the same (NO_EFFECT)?

Return ONLY the label in one of these formats:
- more
- less
- no_effect

Do NOT provide explanation, just the label:"""

        response = ollama.generate(model=model, prompt=baseline_prompt)
        baseline_response = response.get("response", "").strip().lower()

        # Parse and normalize; fall back to scanning free text for the first label.
        baseline_label = normalize_label(baseline_response)
        if baseline_label is None:
            baseline_label = _first_label_mention(baseline_response) or "uncertain"

        return baseline_label
    except Exception as e:
        # Best-effort: report and score the sample as an error rather than abort the run.
        print(f"Baseline prediction error: {e}")
        return "error"

print("Baseline prediction function defined.")

Baseline prediction function defined.


In [5]:
# Method 2: Pipeline - Full causal triple approach
def predict_pipeline(question, model=MODEL):
    """Predict the effect label via the causal-graph pipeline.

    Steps:
    1. Build a causal graph for the question with the shared ``BUILDER``.
    2. Rank its triples with ``triple_ranker.rank_triples``.
    3. Let ``effect_decider.decide_effect`` choose a label.

    ``triple_selector`` is imported in the first cell but is not called here, so
    no separate selection step is applied.
    NOTE(review): ``model`` is currently unused; the builder's model is fixed
    when ``BUILDER`` is constructed.

    Returns:
        The decider's 'decision' value ('uncertain' if the key is missing),
        'no_effect' when the builder produced no edges, or 'error' if any step raised.
    """
    try:
        graph = BUILDER.build_causal_chain(question)
        # An edgeless graph gives the ranker nothing to work with.
        if not graph.get('edges'):
            return "no_effect"

        ranked_triples = triple_ranker.rank_triples(
            graph,
            question,
            num_variations=NUM_VARIATIONS,
            backend=BACKEND,
            top_m=TOP_M,
        )
        verdict = effect_decider.decide_effect(
            question,
            ranked_triples,
            target=None,
            weight_avg=0.7,
            weight_confidence=0.3,
        )
        return verdict.get('decision', 'uncertain')
    except Exception as err:
        print(f"Pipeline prediction error: {err}")
        return "error"

print("Pipeline prediction function defined.")

Pipeline prediction function defined.


In [None]:
# LLM predictors that combine question parsing with the shared causal builder.
# Reload the local modules (as the first cell does for its own modules) so edits
# to the .py files are picked up without restarting the kernel.
import question_parser, llm_predictors
importlib.reload(question_parser)
importlib.reload(llm_predictors)
from question_parser import QuestionParser
from llm_predictors import predict_meta_informed_llm, predict_combined_context_llm

PARSER = QuestionParser(model_name=MODEL)
print("Parser and LLM predictor functions ready.")


In [None]:
# Quick demo: run both new predictors on a few samples before the full batch.
SAMPLE_N = 3
divider = "=" * 80
print("\n" + divider)
print(f"Quick demo on first {SAMPLE_N} samples")
print(divider)

demo_stop = min(START_INDEX + SAMPLE_N, end_index)
for sample_idx in range(START_INDEX, demo_stop):
    example = ds[sample_idx]
    q_text, gold_label = get_question(example), normalize_label(get_label(example))
    print("\n" + "-" * 40)
    print(f"Index: {sample_idx}\nQ: {q_text}\nGold: {gold_label}")
    # Predictor calls and result access share one guard so any failure is reported, not raised.
    try:
        meta_out = predict_meta_informed_llm(q_text, PARSER, BUILDER, MODEL)
        combined_out = predict_combined_context_llm(q_text, PARSER, BUILDER, MODEL)
        print(f"\nMeta-Informed LLM => {meta_out.get('final_answer')}")
        print(f"Combined-Context LLM => {combined_out.get('final_answer')}")
    except Exception as exc:
        print(f"Error running demo predictors: {exc}")


In [None]:
# Main evaluation loop: score the baseline and both LLM predictors on every sample.


def _safe_final_answer(predict_fn, question):
    """Run one of the llm_predictors functions and return its normalised label.

    Args:
        predict_fn: predict_meta_informed_llm or predict_combined_context_llm.
        question: WIQA question text.

    Returns:
        normalize_label(result.get('final_answer')), which may be None if the
        answer is unrecognised. Returns 'error' if the predictor raised, so one
        failing sample is scored as wrong instead of aborting the whole batch.
        predict_baseline handles its own failures the same way.
    """
    try:
        result = predict_fn(question, PARSER, BUILDER, MODEL)
        return normalize_label(result.get('final_answer'))
    except Exception as e:
        tqdm.write(f"{getattr(predict_fn, '__name__', 'predictor')} error: {e}")
        return "error"


results = []
baseline_correct = 0
meta_correct = 0
combined_correct = 0
total_processed = 0

print(f"\nStarting batch evaluation at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("="*80)

for idx in tqdm(range(START_INDEX, end_index), desc="Processing samples"):
    ex = ds[idx]
    question = get_question(ex)
    gold = normalize_label(get_label(ex))

    # Skip if no valid question or gold label
    if not question or gold is None:
        continue

    # Get predictions
    baseline_pred = predict_baseline(question)
    meta_pred = _safe_final_answer(predict_meta_informed_llm, question)
    combined_pred = _safe_final_answer(predict_combined_context_llm, question)

    # Check correctness
    baseline_match = (baseline_pred == gold)
    meta_match = (meta_pred == gold)
    combined_match = (combined_pred == gold)

    if baseline_match:
        baseline_correct += 1
    if meta_match:
        meta_correct += 1
    if combined_match:
        combined_correct += 1
    total_processed += 1

    # Store result
    results.append({
        'index': idx,
        'question': question,
        'gold': gold,
        'baseline_pred': baseline_pred,
        'meta_pred': meta_pred,
        'combined_pred': combined_pred,
        'baseline_correct': baseline_match,
        'meta_correct': meta_match,
        'combined_correct': combined_match
    })

    # Periodic save and progress update. A falsy SAVE_INTERVAL disables this
    # instead of raising ZeroDivisionError; tqdm.write keeps the progress bar intact.
    if SAVE_INTERVAL and total_processed % SAVE_INTERVAL == 0:
        baseline_acc = (baseline_correct / total_processed) * 100
        meta_acc = (meta_correct / total_processed) * 100
        combined_acc = (combined_correct / total_processed) * 100
        tqdm.write(f"\n[Progress] {total_processed} samples | Baseline: {baseline_acc:.2f}% | Meta: {meta_acc:.2f}% | Combined: {combined_acc:.2f}%")

        # Save intermediate results
        df_temp = pd.DataFrame(results)
        df_temp.to_csv(f'batch_results_temp_{total_processed}.csv', index=False)

print("\n" + "="*80)
print(f"Evaluation completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Total samples processed: {total_processed}")

In [7]:
# Calculate final accuracy; each method's percentage is relative to the samples actually scored.
if total_processed <= 0:
    print("No samples were processed.")
else:
    baseline_accuracy, meta_accuracy, combined_accuracy = (
        (n_correct / total_processed) * 100
        for n_correct in (baseline_correct, meta_correct, combined_correct)
    )

    rule = "=" * 80
    print("\n" + rule)
    print("FINAL RESULTS")
    print(rule)
    print(f"Total samples:      {total_processed}")

    # (section title, correct count, accuracy, whether to show the delta vs. baseline)
    report_rows = [
        ("Baseline (Direct LLM)", baseline_correct, baseline_accuracy, False),
        ("Meta-Informed LLM", meta_correct, meta_accuracy, True),
        ("Combined-Context LLM", combined_correct, combined_accuracy, True),
    ]
    for title, n_right, accuracy, versus_baseline in report_rows:
        print(f"\n{title}:")
        print(f"  Correct:          {n_right}")
        print(f"  Accuracy:         {accuracy:.2f}%")
        if versus_baseline:
            print(f"  vs Baseline:       {accuracy - baseline_accuracy:+.2f}%")
    print(rule)


FINAL RESULTS
Total samples:      100

Baseline (Direct LLM):
  Correct:          53
  Accuracy:         53.00%

Pipeline (Causal Triples):
  Correct:          31
  Accuracy:         31.00%

Improvement:        -22.00%


In [8]:
# Save detailed results.
# Explicit columns keep the schema stable even when no sample was scored, so the
# breakdown and error-analysis cells get an empty frame instead of a KeyError.
RESULT_COLUMNS = [
    'index', 'question', 'gold',
    'baseline_pred', 'meta_pred', 'combined_pred',
    'baseline_correct', 'meta_correct', 'combined_correct',
]
df_results = pd.DataFrame(results, columns=RESULT_COLUMNS)
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f'batch_results_{SPLIT}_{START_INDEX}_{end_index-1}_{run_stamp}.csv'
df_results.to_csv(output_filename, index=False)
print(f"\nDetailed results saved to: {output_filename}")

# Display first few results
print("\nSample results (first 5):")
print(df_results.head())


Detailed results saved to: batch_results_validation_0_99_20251103_150625.csv

Sample results (first 5):
   index                                           question       gold  \
0      0  suppose squirrels get sick happens, how will i...       more   
1      1  suppose the female is sterile happens, how wil...       more   
2      2  suppose there is no sunlight for the tree to g...  no_effect   
3      3  suppose less water vapor forms happens, how wi...       less   
4      4  suppose less gasoline is loaded onto tank truc...  no_effect   

  baseline_pred pipeline_pred  baseline_correct  pipeline_correct  
0          more     no_effect              True             False  
1          less          more             False              True  
2     no_effect     no_effect              True              True  
3          less          more              True             False  
4          less          more             False             False  


In [9]:
# Breakdown by label: per-gold-class accuracy of each method and its delta vs. the baseline.
rule = "=" * 80
print("\n" + rule)
print("ACCURACY BREAKDOWN BY LABEL")
print(rule)

for label in ('more', 'less', 'no_effect'):
    label_subset = df_results[df_results['gold'] == label]
    n_label = len(label_subset)
    if not n_label:
        continue  # label absent from this run
    baseline_acc, meta_acc, combined_acc = [
        (label_subset[col].sum() / n_label) * 100
        for col in ('baseline_correct', 'meta_correct', 'combined_correct')
    ]
    print(f"\nLabel: {label.upper()}")
    print(f"  Count:            {n_label}")
    print(f"  Baseline Acc:     {baseline_acc:.2f}%")
    print(f"  Meta Acc:         {meta_acc:.2f}%")
    print(f"  Combined Acc:     {combined_acc:.2f}%")
    print(f"  Meta vs Base:     {meta_acc - baseline_acc:+.2f}%")
    print(f"  Comb vs Base:     {combined_acc - baseline_acc:+.2f}%")


ACCURACY BREAKDOWN BY LABEL

Label: MORE
  Count:            38
  Baseline Acc:     42.11%
  Pipeline Acc:     42.11%
  Improvement:      +0.00%

Label: LESS
  Count:            28
  Baseline Acc:     71.43%
  Pipeline Acc:     0.00%
  Improvement:      -71.43%

Label: NO_EFFECT
  Count:            34
  Baseline Acc:     50.00%
  Pipeline Acc:     44.12%
  Improvement:      -5.88%


In [10]:
# Confusion analysis: where do the methods disagree with the baseline / each other?


def show_disagreements(df, title, mask, pred_col, pred_name, n_examples=3):
    """Print how many rows of `df` satisfy `mask`, plus up to `n_examples` examples.

    Each example shows the question (first 80 characters), the gold label, the
    baseline prediction and the value of `pred_col` labelled as `pred_name`.

    Returns:
        The filtered DataFrame, so callers can inspect it further.
    """
    subset = df[mask]
    print(f"\n{title}: {len(subset)} cases")
    if len(subset) > 0:
        print("Examples:")
        for _, row in subset.head(n_examples).iterrows():
            print(f"  Q: {row['question'][:80]}...")
            print(f"     Gold: {row['gold']} | Baseline: {row['baseline_pred']} | {pred_name}: {row.get(pred_col)}\n")
    return subset


print("\n" + "="*80)
print("ERROR ANALYSIS")
print("="*80)

base_right = df_results['baseline_correct'].eq(True)
base_wrong = df_results['baseline_correct'].eq(False)
meta_right = df_results['meta_correct'].eq(True)
meta_wrong = df_results['meta_correct'].eq(False)
comb_right = df_results['combined_correct'].eq(True)
comb_wrong = df_results['combined_correct'].eq(False)

base_meta = show_disagreements(df_results, "Baseline correct, Meta wrong", base_right & meta_wrong, 'meta_pred', 'Meta')
base_comb = show_disagreements(df_results, "Baseline correct, Combined wrong", base_right & comb_wrong, 'combined_pred', 'Combined')
comb_only = show_disagreements(df_results, "Combined correct, Baseline wrong", base_wrong & comb_right, 'combined_pred', 'Combined')
meta_only = show_disagreements(df_results, "Meta correct, Baseline wrong", base_wrong & meta_right, 'meta_pred', 'Meta')

# Both LLM-direct methods wrong
both_wrong = df_results[meta_wrong & comb_wrong]
print(f"\nBoth wrong (Meta & Combined): {len(both_wrong)} cases")


ERROR ANALYSIS

Baseline correct, Pipeline wrong: 39 cases
Examples:
  Q: suppose squirrels get sick happens, how will it affect squirrels need more food....
     Gold: more | Baseline: more | Pipeline: no_effect

  Q: suppose less water vapor forms happens, how will it affect MORE clouds forming....
     Gold: less | Baseline: less | Pipeline: more

  Q: suppose more oil is processed happens, how will it affect MORE oil arriving at g...
     Gold: more | Baseline: more | Pipeline: no_effect


Pipeline correct, Baseline wrong: 17 cases
Examples:
  Q: suppose the female is sterile happens, how will it affect LESS rabbits....
     Gold: more | Baseline: less | Pipeline: more

  Q: suppose the volcano has become inactive happens, how will it affect tree rings w...
     Gold: no_effect | Baseline: more | Pipeline: no_effect

  Q: suppose less oil delivered happens, how will it affect more paper available....
     Gold: no_effect | Baseline: less | Pipeline: no_effect


Both wrong: 30 case