# WIQA CDCR-SFT Test Evaluation
## Using the WIQACausalBuilder Method

This notebook evaluates the `WIQACausalBuilder` pipeline on the CDCR-SFT `wiqa_test.csv` split and breaks down accuracy by question type (`EXOGENOUS_EFFECT` vs `INPARA_EFFECT`).

In [None]:
import os
import json
import pandas as pd
from WIQACausalBuilder import WIQACausalBuilder
from tqdm import tqdm

## 1. Load Data

In [None]:
# Load the CSV file. The path can be overridden with the WIQA_TEST_CSV
# environment variable so the notebook runs on machines other than the author's.
csv_path = os.environ.get(
    'WIQA_TEST_CSV',
    r'E:\PHD\01\other_code\CDCR-SFT\data\wiqa_test.csv',
)
df = pd.read_csv(csv_path)

# Fail fast if columns that later cells index directly are missing. Otherwise
# every row would silently become an ERROR result in the evaluation loop.
REQUIRED_COLUMNS = {'question_stem', 'answer_label', 'answer_label_as_choice', 'question_type'}
missing_columns = REQUIRED_COLUMNS - set(df.columns)
if missing_columns:
    raise ValueError(f"CSV {csv_path!r} is missing required columns: {sorted(missing_columns)}")

print(f"Total datapoints in CSV: {len(df)}")
print("\nQuestion types distribution:")
print(df['question_type'].value_counts())
print(f"\nColumn names: {list(df.columns)}")
print("\nFirst few rows:")
df.head(3)

## 2. Process All Datapoints

In [None]:
# -----------------------------
# Multithread evaluation runner: configuration
# -----------------------------
from concurrent.futures import ThreadPoolExecutor, as_completed
import contextlib
import time

# Limit the number of rows for a quick smoke test; None processes the full CSV.
N_SAMPLES = None
if N_SAMPLES:
    df = df.head(N_SAMPLES)


def read_max_workers(default=4):
    """Return the thread-pool size from the WIQA_MAX_WORKERS env var.

    Falls back to ``default`` when the variable is unset, is not an integer,
    or is not positive (ThreadPoolExecutor rejects max_workers <= 0).
    """
    raw = os.environ.get('WIQA_MAX_WORKERS')
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        return default
    return value if value > 0 else default


# Honour the environment override. It used to be parsed and then
# unconditionally overwritten with 4.
MAX_WORKERS = read_max_workers()

# Set to False if you want to see per-question pipeline prints (not recommended in multi-thread).
SUPPRESS_PIPELINE_OUTPUT = True

class _NullWriter:
    """File-like sink that discards everything written to it."""

    def write(self, s):
        # Report the full length as written, like a real stream would.
        return len(s)

    def flush(self):
        pass

_NULL = _NullWriter()


def _process_record(record, model_name="mistral:7b"):
    """Run the WIQACausalBuilder pipeline on one CSV row and return a result dict.

    Args:
        record: one row of ``df.to_dict('records')``.
        model_name: model identifier passed to WIQACausalBuilder.

    Returns:
        A flat dict with the row's metadata plus ``is_correct``, ``cause_event``
        and ``outcome_base``. On any exception, ``predicted_answer`` is
        ``'ERROR'`` and ``error`` holds the message, so one bad row cannot
        abort the pool.

    Output suppression is deliberately NOT done here. ``contextlib.redirect_stdout``
    swaps the process-wide ``sys.stdout``, so entering it concurrently from
    several worker threads can leave the null writer installed permanently.
    The runner below redirects once around the whole pool instead.
    """
    # Row metadata shared by the success and error results.
    base = {
        'csv_id': record.get('id', ''),
        'question': record.get('question_stem', ''),
        'question_type': record.get('question_type', ''),
        'improved_question': record.get('improved_question', ''),
        'gold_answer': record.get('answer_label', ''),
        'gold_choice': record.get('answer_label_as_choice', ''),
    }
    try:
        # Convert CSV record to the format expected by WIQACausalBuilder
        datapoint = {
            'question_stem': record['question_stem'],
            'answer_label': record['answer_label'],
            'answer_label_as_choice': record['answer_label_as_choice'],
            'choices': {
                'text': ['more', 'less', 'no_effect'],
                'label': ['A', 'B', 'C']
            }
        }

        wiqa = WIQACausalBuilder(datapoint, model_name=model_name)
        is_correct = wiqa.run_wiqa_pipeline()

        return {
            **base,
            'is_correct': bool(is_correct),
            'cause_event': getattr(wiqa, 'cause_event', ''),
            'outcome_base': getattr(wiqa, 'outcome_base', ''),
        }

    except Exception as e:
        # Best-effort: record the failure and keep evaluating the other rows.
        return {
            **base,
            'predicted_answer': 'ERROR',
            'predicted_choice': '',
            'is_correct': False,
            'cause_event': '',
            'outcome_base': '',
            'error': str(e),
        }

records = df.to_dict('records')
results = [None] * len(records)
t0 = time.time()

import sys  # needed only to grab the real stderr for the progress bar

# Capture the real stderr before any redirection so tqdm stays visible while
# pipeline output is discarded.
_real_stderr = sys.stderr

with contextlib.ExitStack() as _quiet:
    if SUPPRESS_PIPELINE_OUTPUT:
        # A single process-wide redirect for the entire pool. It is entered and
        # exited from this thread only, so the original streams are restored
        # exactly once, after all workers have finished.
        _quiet.enter_context(contextlib.redirect_stdout(_NULL))
        _quiet.enter_context(contextlib.redirect_stderr(_NULL))
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(_process_record, r): i for i, r in enumerate(records)}
        progress = tqdm(
            as_completed(futures),
            total=len(futures),
            desc=f'Processing WIQA test ({MAX_WORKERS} threads)',
            file=_real_stderr,
        )
        for fut in progress:
            # Store by original row index so results stay aligned with the CSV.
            results[futures[fut]] = fut.result()

results = [r for r in results if r is not None]

elapsed = time.time() - t0
print()
print(f'Processing complete! Total results: {len(results)} (elapsed={elapsed:.1f}s, workers={MAX_WORKERS})')

## 3. Save Results

In [14]:
import math


def _json_safe(value):
    """Map float NaN (pandas' marker for empty CSV cells) to None.

    ``json.dump`` would otherwise write a bare ``NaN`` token, which strict JSON
    parsers reject.
    """
    if isinstance(value, float) and math.isnan(value):
        return None
    return value


def _json_default(obj):
    """Serialise objects json cannot handle natively (e.g. numpy scalars)."""
    item = getattr(obj, 'item', None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save to JSON (sanitised copy; `results` itself is left untouched)
output_json = 'wiqa_test_results_by_type_mistral.json'
json_ready = [{key: _json_safe(val) for key, val in r.items()} for r in results]
with open(output_json, 'w', encoding='utf-8') as f:
    json.dump(json_ready, f, indent=2, ensure_ascii=False, default=_json_default)
print(f"Full results saved to: {output_json}")

# Save to CSV
results_df = pd.DataFrame(results)
output_csv = 'wiqa_test_results_by_type_mistral.csv'
results_df.to_csv(output_csv, index=False, encoding='utf-8')
print(f"CSV results saved to: {output_csv}")

## 4. Overall Statistics

In [15]:
# Calculate overall statistics
total_count = len(results)
correct_count = sum(1 for r in results if r['is_correct'])
error_count = sum(1 for r in results if r.get('predicted_answer') == 'ERROR')
accuracy = correct_count / total_count if total_count > 0 else 0

print("="*80)
print("OVERALL STATISTICS")
print("="*80)
print(f"Total processed: {total_count}")
print(f"Correct: {correct_count}")
print(f"Wrong: {total_count - correct_count - error_count}")
print(f"Errors: {error_count}")
print(f"Accuracy: {accuracy:.2%}")

## 5. Statistics by Question Type (EXOGENOUS vs INPARA)

In [None]:
print("="*80)
print("STATISTICS BY QUESTION TYPE")
print("="*80)

# Report the two headline types first, then any other question_type present
# in the results instead of silently dropping it. Missing (NaN) types are skipped.
KNOWN_QUESTION_TYPES = ['EXOGENOUS_EFFECT', 'INPARA_EFFECT']
present_types = [t for t in dict.fromkeys(r['question_type'] for r in results) if pd.notna(t)]
question_types = KNOWN_QUESTION_TYPES + [t for t in present_types if t not in KNOWN_QUESTION_TYPES]

# Statistics by question type
for qtype in question_types:
    type_results = [r for r in results if r['question_type'] == qtype]
    if not type_results:
        continue
    type_total = len(type_results)
    type_correct = sum(1 for r in type_results if r['is_correct'])
    type_errors = sum(1 for r in type_results if r.get('predicted_answer') == 'ERROR')
    type_accuracy = type_correct / type_total

    print(f"\n{qtype}:")
    print(f"  Total: {type_total}")
    print(f"  Correct: {type_correct}")
    print(f"  Wrong: {type_total - type_correct - type_errors}")
    print(f"  Errors: {type_errors}")
    print(f"  Accuracy: {type_accuracy:.2%}")

## 6. Visualize Results

In [None]:
import matplotlib.pyplot as plt

# Prepare data for visualization
stats_by_type = {}
for qtype in ['EXOGENOUS_EFFECT', 'INPARA_EFFECT']:
    type_results = [r for r in results if r['question_type'] == qtype]
    if type_results:
        type_total = len(type_results)
        type_correct = sum(1 for r in type_results if r['is_correct'])
        type_accuracy = type_correct / type_total if type_total > 0 else 0
        stats_by_type[qtype] = {
            'total': type_total,
            'correct': type_correct,
            'accuracy': type_accuracy
        }

# Create bar plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Accuracy by Type
types = list(stats_by_type.keys())
accuracies = [stats_by_type[t]['accuracy'] * 100 for t in types]
colors = ['#FF6B6B', '#4ECDC4']

bars1 = ax1.bar(types, accuracies, color=colors, alpha=0.7, edgecolor='black')
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.set_title('Accuracy by Question Type', fontsize=14, fontweight='bold')
ax1.set_ylim(0, 100)
ax1.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar in bars1:
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
             f'{height:.1f}%',
             ha='center', va='bottom', fontweight='bold')

# Plot 2: Sample counts
totals = [stats_by_type[t]['total'] for t in types]
corrects = [stats_by_type[t]['correct'] for t in types]
wrongs = [stats_by_type[t]['total'] - stats_by_type[t]['correct'] for t in types]

x = range(len(types))
width = 0.35

bars2 = ax2.bar([i - width/2 for i in x], corrects, width, label='Correct', color='#2ECC71', alpha=0.7, edgecolor='black')
bars3 = ax2.bar([i + width/2 for i in x], wrongs, width, label='Wrong', color='#E74C3C', alpha=0.7, edgecolor='black')

ax2.set_ylabel('Count', fontsize=12)
ax2.set_title('Correct vs Wrong by Question Type', fontsize=14, fontweight='bold')
ax2.set_xticks(x)
ax2.set_xticklabels(types)
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

# Add value labels
for bar in bars2:
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
             f'{int(height)}',
             ha='center', va='bottom', fontsize=10)

for bar in bars3:
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
             f'{int(height)}',
             ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('wiqa_test_accuracy_by_type.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nVisualization saved to: wiqa_test_accuracy_by_type_.png")

## 7. Detailed Results Table

In [None]:
# Display results as a DataFrame
# Compact per-question view: identifier, type, gold label and correctness.
summary_columns = ['csv_id', 'question_type', 'gold_answer', 'is_correct']
display_df = results_df.loc[:, summary_columns]
display_df

## 8. Error Analysis

In [None]:
# Show wrong predictions by type
# Show wrong predictions by type (pipeline ERROR records are excluded).
MAX_EXAMPLES_PER_TYPE = 5

print("="*80)
print("ERROR ANALYSIS")
print("="*80)

# Always list the two headline types, then any other non-missing type present.
KNOWN_QUESTION_TYPES = ['EXOGENOUS_EFFECT', 'INPARA_EFFECT']
present_types = [t for t in dict.fromkeys(r['question_type'] for r in results) if pd.notna(t)]
question_types = KNOWN_QUESTION_TYPES + [t for t in present_types if t not in KNOWN_QUESTION_TYPES]

for qtype in question_types:
    wrong_results = [
        r for r in results
        if r['question_type'] == qtype
        and not r['is_correct']
        and r.get('predicted_answer') != 'ERROR'
    ]

    print(f"\n{qtype} - Wrong Predictions: {len(wrong_results)}")
    print("-" * 80)

    for r in wrong_results[:MAX_EXAMPLES_PER_TYPE]:
        print(f"ID {r['csv_id']}: Gold={r['gold_answer']}")
        print(f"  Question: {r['question']}")
        print()