In [6]:
!python -m agents.answer_agent \
    --input_file "dataset/truth_teller_liar_questions.json" \
    --output_file "outputs/answers_test.json" \
    --batch_size 20 \
    --verbose

Loading checkpoint shards: 100%|██████████████████| 3/3 [00:03<00:00,  1.08s/it]
STEPS: : 35batch [05:18,  9.10s/batch]                                          

=== Question 1 ===
Question: A farmer has three sons, Tom, Dick, and Harry. Each son has a favorite animal, a horse, a cow, and a chicken respectively. One day, the farmer asks each of his sons to tell him the truth about their favorite animal. Tom says that his horse loves to eat hay, Dick says that his cow loves to eat grass, and Harry says that his chicken loves to eat seeds. However, the farmer knows that one of his sons is a liar and the other two always tell the truth. Determine the favorite animal of the liar son and the food it eats. The choices are: A) Tom's horse - hay, B) Dick's cow - grass, C) Harry's chicken - seeds.
Expected: C) Harry's chicken - seeds
Model Answer:
{"answer": "D", "reasoning": "If A is true, Tom tells the truth, and the liar must be Dick or Harry. If B is true, Dick tells the truth, and the lia

In [12]:
import json
import re
from typing import Dict, List

def load_json_file(file_path: str) -> List[Dict]:
    """Load JSON data from file"""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: File {file_path} not found")
        return []
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in {file_path}")
        return []

def extract_choice_letter(text: str) -> str:
    """Extract choice letter A, B, C, or D from text"""
    if not text:
        return ''
    # Look for pattern like "A)", "B)", etc.
    pattern = re.search(r'^([ABCD])\)', text.strip())
    if pattern:
        return pattern.group(1)
    # Look for isolated letters
    pattern = re.search(r'\b([ABCD])\b', text.upper())
    return pattern.group(1) if pattern else ''

def diagnose_data_structure(ground_truth_file: str, predictions_file: str):
    """Diagnose the structure and compare first letters of answers"""
    print("=" * 80)
    print("DIAGNOSTIC ANALYSIS (First Letter Comparison)")
    print("=" * 80)
    
    # Load data
    gt_data = load_json_file(ground_truth_file)
    pred_data = load_json_file(predictions_file)
    
    if not gt_data or not pred_data:
        return
    
    # Filter None values
    gt_data = [item for item in gt_data if item is not None]
    pred_data = [item for item in pred_data if item is not None]
    
    print(f"Ground Truth: {len(gt_data)} items")
    print(f"Predictions: {len(pred_data)} items")
    
    # Show first 3 comparisons
    print("\n" + "-" * 50)
    print("INDEX-BASED FIRST LETTER COMPARISON (First 3 items)")
    print("-" * 50)
    
    min_len = min(len(gt_data), len(pred_data), 3)
    for i in range(min_len):
        gt_item = gt_data[i] if isinstance(gt_data[i], dict) else {}
        pred_item = pred_data[i] if isinstance(pred_data[i], dict) else {}
        
        gt_answer = gt_item.get('answer', '')
        pred_answer = pred_item.get('answer', '')
        
        gt_choice = extract_choice_letter(gt_answer)
        pred_choice = extract_choice_letter(pred_answer)
        
        print(f"\nIndex {i}:")
        print(f"  GT Answer: {repr(gt_answer)} (Choice: {gt_choice})")
        print(f"  Pred Answer: {repr(pred_answer)} (Choice: {pred_choice})")
        print(f"  Match: {gt_choice == pred_choice and gt_choice != ''}")

def calculate_first_letter_accuracy(ground_truth_file: str, predictions_file: str):
    """Calculate accuracy based on first letter (choice letter) matching"""
    print("\n" + "=" * 80)
    print("FIRST LETTER ACCURACY CALCULATION")
    print("=" * 80)
    
    # Load data
    gt_data = load_json_file(ground_truth_file)
    pred_data = load_json_file(predictions_file)
    
    if not gt_data or not pred_data:
        return {"error": "Failed to load data"}
    
    # Filter None values
    gt_data = [item for item in gt_data if item is not None]
    pred_data = [item for item in pred_data if item is not None]
    
    min_len = min(len(gt_data), len(pred_data))
    print(f"Comparing {min_len} items")
    
    correct = 0
    total = 0
    detailed_results = []
    
    for i in range(min_len):
        gt_item = gt_data[i] if isinstance(gt_data[i], dict) else {}
        pred_item = pred_data[i] if isinstance(pred_data[i], dict) else {}
        
        gt_answer = gt_item.get('answer', '')
        pred_answer = pred_item.get('answer', '')
        
        if gt_answer and pred_answer:  # Only count if both have answers
            gt_choice = extract_choice_letter(gt_answer)
            pred_choice = extract_choice_letter(pred_answer)
            
            if gt_choice and pred_choice:  # Only count if both have valid choice letters
                is_correct = gt_choice == pred_choice
                if is_correct:
                    correct += 1
                total += 1
                
                detailed_results.append({
                    'index': i,
                    'gt_answer': gt_answer,
                    'pred_answer': pred_answer,
                    'gt_choice': gt_choice,
                    'pred_choice': pred_choice,
                    'correct': is_correct
                })
    
    accuracy = correct / total if total > 0 else 0
    
    print(f"Total compared: {total}")
    print(f"Correct: {correct}")
    print(f"Accuracy: {accuracy:.2%}")
    
    # Show first 5 comparisons
    print("\nFirst 5 comparisons:")
    for result in detailed_results[:5]:
        print(f"Index {result['index']}: {'✓' if result['correct'] else '✗'}")
        print(f"  GT: {result['gt_choice']} - {result['gt_answer'][:50]}...")
        print(f"  Pred: {result['pred_choice']} - {result['pred_answer'][:50]}...")
        print()
    
    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'detailed_results': detailed_results
    }

if __name__ == "__main__":
    ground_truth_file = "dataset/truth_teller_liar_questions.json"
    predictions_file = "outputs/answers_test.json"
    
    diagnose_data_structure(ground_truth_file, predictions_file)
    results = calculate_first_letter_accuracy(ground_truth_file, predictions_file)

DIAGNOSTIC ANALYSIS (First Letter Comparison)
Ground Truth: 668 items
Predictions: 647 items

--------------------------------------------------
INDEX-BASED FIRST LETTER COMPARISON (First 3 items)
--------------------------------------------------

Index 0:
  GT Answer: "C) Harry's chicken - seeds" (Choice: C)
  Pred Answer: 'D' (Choice: D)
  Match: False

Index 1:
  GT Answer: 'B) Y' (Choice: B)
  Pred Answer: 'B' (Choice: B)
  Match: True

Index 2:
  GT Answer: 'A) A is telling the truth' (Choice: A)
  Pred Answer: 'A' (Choice: A)
  Match: True

FIRST LETTER ACCURACY CALCULATION
Comparing 647 items
Total compared: 647
Correct: 177
Accuracy: 27.36%

First 5 comparisons:
Index 0: ✗
  GT: C - C) Harry's chicken - seeds...
  Pred: D - D...

Index 1: ✓
  GT: B - B) Y...
  Pred: B - B...

Index 2: ✓
  GT: A - A) A is telling the truth...
  Pred: A - A...

Index 3: ✗
  GT: C - C) C...
  Pred: B - B...

Index 4: ✓
  GT: B - B) Mary is telling the truth and John is lying...
  Pred: B - B...



In [16]:
import json
import re
from typing import Dict, List

def load_json_file(file_path: str) -> List[Dict]:
    """Load JSON data from file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: File {file_path} not found")
        return []
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in {file_path}")
        return []

def extract_choice_letter(text: str) -> str:
    """Extract choice letter A, B, C, or D from text."""
    if not text:
        return ''
    # Look for pattern like "A)", "B)", etc.
    pattern = re.search(r'^([ABCD])\)', text.strip())
    if pattern:
        return pattern.group(1)
    # Look for isolated letters
    pattern = re.search(r'\b([ABCD])\b', text.upper())
    return pattern.group(1) if pattern else ''

def find_discrepancies(ground_truth_file: str, predictions_file: str):
    """Find discrepancies between ground truth and predictions files."""
    print("=" * 80)
    print("DISCREPANCY ANALYSIS")
    print("=" * 80)
    
    # Load data
    gt_data = load_json_file(ground_truth_file)
    pred_data = load_json_file(predictions_file)
    
    if not gt_data:
        print("Error: Failed to load ground truth file")
        return
    if not pred_data:
        print("Error: Failed to load predictions file")
        return
    
    # Filter None values
    gt_data = [item for item in gt_data if item is not None]
    pred_data = [item for item in pred_data if item is not None]
    
    # Report counts
    expected_len = 668
    print(f"Ground Truth: {len(gt_data)} items")
    print(f"Predictions: {len(pred_data)} items")
    
    # Check for count mismatch
    discrepancies = []
    if len(gt_data) != expected_len:
        discrepancies.append(f"Ground Truth has {len(gt_data)} items, expected {expected_len}")
    
    # Identify missing indices
    missing_indices = []
    if len(pred_data) < len(gt_data):
        missing_indices = list(range(len(pred_data), len(gt_data)))
        discrepancies.append(f"Predictions missing {len(gt_data) - len(pred_data)} items. Missing indices: {missing_indices}")
    elif len(pred_data) > len(gt_data):
        discrepancies.append(f"Predictions has {len(pred_data)} items, expected {len(gt_data)}. Extra indices: {list(range(len(gt_data), len(pred_data)))}")
        pred_data = pred_data[:len(gt_data)]  # Truncate to ground truth length
    
    # Pad predictions with empty dicts to match ground truth length for comparison
    pred_data.extend([{}] * (len(gt_data) - len(pred_data)))
    
    # Compare answers at each index
    print("\n" + "-" * 50)
    print("DETAILED DISCREPANCIES")
    print("-" * 50)
    
    for i in range(len(gt_data)):
        gt_item = gt_data[i] if isinstance(gt_data[i], dict) else {}
        pred_item = pred_data[i] if isinstance(pred_data[i], dict) else {}
        
        gt_answer = gt_item.get('answer', '')
        pred_answer = pred_item.get('answer', '')
        
        gt_choice = extract_choice_letter(gt_answer)
        pred_choice = extract_choice_letter(pred_answer)
        
        # Check for discrepancies
        is_discrepant = False
        discrepancy_detail = {
            'index': i,
            'gt_answer': gt_answer,
            'pred_answer': pred_answer,
            'gt_choice': gt_choice,
            'pred_choice': pred_choice,
            'issue': []
        }
        
        # Missing prediction
        if i in missing_indices:
            discrepancy_detail['issue'].append("Missing prediction (empty entry)")
            is_discrepant = True
        
        # Invalid ground truth choice
        if not gt_choice and gt_answer:
            discrepancy_detail['issue'].append(f"Invalid ground truth answer: {repr(gt_answer)}")
            is_discrepant = True
        
        # Invalid prediction choice
        if not pred_choice and pred_answer:
            discrepancy_detail['issue'].append(f"Invalid prediction answer: {repr(pred_answer)}")
            is_discrepant = True
        
        # Mismatched choices
        if gt_choice and pred_choice and gt_choice != pred_choice:
            discrepancy_detail['issue'].append(f"Mismatched choices: GT={gt_choice}, Pred={pred_choice}")
            is_discrepant = True
        
        # No prediction choice (empty or missing)
        if gt_choice and not pred_choice:
            discrepancy_detail['issue'].append(f"No valid prediction choice: GT={gt_choice}, Pred={pred_choice}")
            is_discrepant = True
        
        if is_discrepant:
            discrepancies.append(discrepancy_detail)
    
    # Output discrepancies
    if discrepancies:
        for discrepancy in discrepancies:
            if isinstance(discrepancy, str):
                print(discrepancy)
            else:
                print(f"Index {discrepancy['index']}:")
                print(f"  Ground Truth: {repr(discrepancy['gt_answer'])} (Choice: {discrepancy['gt_choice']})")
                print(f"  Prediction: {repr(discrepancy['pred_answer'])} (Choice: {discrepancy['pred_choice']})")
                print(f"  Issues: {', '.join(discrepancy['issue'])}")
                print()
    else:
        print("No discrepancies found.")
    
    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"Total discrepancies: {len([d for d in discrepancies if isinstance(d, dict)])}")
    print(f"Missing indices: {missing_indices if missing_indices else 'None'}")
    
    return discrepancies

if __name__ == "__main__":
    ground_truth_file = "dataset/truth_teller_liar_questions.json"
    predictions_file = "outputs/answers_test.json"
    
    # Find and report discrepancies
    discrepancies = find_discrepancies(ground_truth_file, predictions_file)

DISCREPANCY ANALYSIS
Ground Truth: 668 items
Predictions: 647 items

--------------------------------------------------
DETAILED DISCREPANCIES
--------------------------------------------------
Predictions missing 21 items. Missing indices: [647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667]
Index 0:
  Ground Truth: "C) Harry's chicken - seeds" (Choice: C)
  Prediction: 'D' (Choice: D)
  Issues: Mismatched choices: GT=C, Pred=D

Index 3:
  Ground Truth: 'C) C' (Choice: C)
  Prediction: 'B' (Choice: B)
  Issues: Mismatched choices: GT=C, Pred=B

Index 8:
  Ground Truth: 'B) B and C' (Choice: B)
  Prediction: 'C' (Choice: C)
  Issues: Mismatched choices: GT=B, Pred=C

Index 9:
  Ground Truth: 'D) B' (Choice: D)
  Prediction: 'B' (Choice: B)
  Issues: Mismatched choices: GT=D, Pred=B

Index 11:
  Ground Truth: 'B) 𝑆3 must be true' (Choice: B)
  Prediction: 'C' (Choice: C)
  Issues: Mismatched choices: GT=B, Pred=C

Index 12:
  Ground T

In [18]:
import json
import re
from typing import Dict, List, Optional, Set

def parse_malformed_json(file_path: str) -> List[Dict]:
    """Parse malformed JSON file with missing structure"""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # First try normal JSON parsing
        try:
            data = json.loads(content)
            if isinstance(data, list):
                return data
        except json.JSONDecodeError:
            pass
        
        # Custom parsing for malformed JSON
        lines = content.strip().split('\n')
        parsed_data = []
        current_item = {}
        
        for line in lines:
            line = line.strip()
            if not line:
                continue
                
            # Look for answer pattern
            if line.startswith('answer"') or line.startswith('"answer"'):
                if current_item:  # Save previous item
                    parsed_data.append(current_item)
                    current_item = {}
                
                # Extract answer
                answer_match = re.search(r'answer"([^"]*)"', line)
                if answer_match:
                    current_item['answer'] = answer_match.group(1)
                    
            # Look for reasoning pattern  
            elif line.startswith('reasoning"') or line.startswith('"reasoning"'):
                reasoning_match = re.search(r'reasoning"([^"]*)"', line)
                if reasoning_match:
                    current_item['reasoning'] = reasoning_match.group(1)
        
        # Don't forget the last item
        if current_item:
            parsed_data.append(current_item)
            
        return parsed_data
        
    except FileNotFoundError:
        print(f"Error: File {file_path} not found")
        return []

def extract_choice_letter(text: str) -> str:
    """Extract choice letter A, B, C, or D from text"""
    if not text:
        return ''
    
    text = text.strip()
    
    # Single letter (most common)
    if len(text) == 1 and text.upper() in 'ABCD':
        return text.upper()
    
    # Letter followed by parenthesis
    pattern = re.search(r'^([ABCD])\)', text.upper())
    if pattern:
        return pattern.group(1)
    
    # Letter at the beginning
    pattern = re.search(r'^([ABCD])\b', text.upper())
    if pattern:
        return pattern.group(1)
    
    # Isolated letter anywhere
    pattern = re.search(r'\b([ABCD])\b', text.upper())
    if pattern:
        return pattern.group(1)
    
    return ''

def find_missing_prediction_indices(ground_truth_file: str, predictions_file: str):
    """Find exact indices of missing predictions"""
    print("="*80)
    print("MISSING PREDICTION INDICES ANALYSIS")
    print("="*80)
    
    # Load data
    try:
        with open(ground_truth_file, 'r', encoding='utf-8') as f:
            gt_data = json.load(f)
            gt_data = [item for item in gt_data if item is not None]
    except Exception as e:
        print(f"Error loading ground truth: {e}")
        return
    
    pred_data = parse_malformed_json(predictions_file)
    
    print(f"Ground Truth items: {len(gt_data)}")
    print(f"Prediction items: {len(pred_data)}")
    
    # Expected length
    expected_len = 668
    
    # Find missing indices
    missing_indices = []
    empty_prediction_indices = []
    no_choice_indices = []
    
    # Pad predictions to match ground truth length
    if len(pred_data) < expected_len:
        pred_data.extend([{}] * (expected_len - len(pred_data)))
    
    print(f"\nAnalyzing {expected_len} positions...")
    
    for i in range(expected_len):
        gt_item = gt_data[i] if i < len(gt_data) else {}
        pred_item = pred_data[i] if i < len(pred_data) else {}
        
        # Get answers
        gt_answer = gt_item.get('answer', '') if isinstance(gt_item, dict) else ''
        pred_answer = pred_item.get('answer', '') if isinstance(pred_item, dict) else ''
        
        # Check if ground truth has valid choice
        gt_choice = extract_choice_letter(gt_answer)
        pred_choice = extract_choice_letter(pred_answer)
        
        if gt_choice:  # Only check items with valid ground truth
            # Check for completely missing predictions
            if not pred_answer or not pred_answer.strip():
                missing_indices.append(i)
                empty_prediction_indices.append(i)
            
            # Check for predictions without valid choice
            elif not pred_choice:
                no_choice_indices.append(i)
                print(f"Index {i}: Has prediction '{pred_answer}' but no valid choice extracted")
    
    # Results
    print("\n" + "="*80)
    print("RESULTS")
    print("="*80)
    
    print(f"Total missing/empty predictions: {len(empty_prediction_indices)}")
    print(f"Missing indices: {empty_prediction_indices}")
    
    print(f"\nPredictions with no valid choice: {len(no_choice_indices)}")
    if len(no_choice_indices) <= 20:
        print(f"No choice indices: {no_choice_indices}")
    else:
        print(f"No choice indices (first 20): {no_choice_indices[:20]}")
        print(f"... and {len(no_choice_indices)-20} more")
    
    print(f"\nAll problematic indices: {len(missing_indices + no_choice_indices)}")
    all_problematic = sorted(list(set(missing_indices + no_choice_indices)))
    
    # Show ranges for easier reading
    def format_indices_as_ranges(indices):
        if not indices:
            return "None"
        
        ranges = []
        start = indices[0]
        end = indices[0]
        
        for i in range(1, len(indices)):
            if indices[i] == end + 1:
                end = indices[i]
            else:
                if start == end:
                    ranges.append(str(start))
                else:
                    ranges.append(f"{start}-{end}")
                start = end = indices[i]
        
        # Add the last range
        if start == end:
            ranges.append(str(start))
        else:
            ranges.append(f"{start}-{end}")
        
        return ", ".join(ranges)
    
    print(f"\nEmpty prediction ranges: {format_indices_as_ranges(empty_prediction_indices)}")
    print(f"No choice ranges: {format_indices_as_ranges(no_choice_indices)}")
    print(f"All problematic ranges: {format_indices_as_ranges(all_problematic)}")
    
    return {
        'missing_indices': empty_prediction_indices,
        'no_choice_indices': no_choice_indices,
        'all_problematic_indices': all_problematic
    }

def analyze_prediction_patterns(predictions_file: str):
    """Analyze patterns in predictions to understand the data structure"""
    print("\n" + "="*80)
    print("PREDICTION PATTERN ANALYSIS")
    print("="*80)
    
    pred_data = parse_malformed_json(predictions_file)
    
    # Count different answer patterns
    patterns = {
        'single_letter': 0,
        'letter_with_parenthesis': 0,
        'full_text': 0,
        'empty': 0,
        'other': 0
    }
    
    sample_answers = {
        'single_letter': [],
        'letter_with_parenthesis': [],
        'full_text': [],
        'empty': [],
        'other': []
    }
    
    for i, item in enumerate(pred_data):
        if isinstance(item, dict):
            answer = item.get('answer', '')
        else:
            answer = ''
        
        if not answer or not answer.strip():
            patterns['empty'] += 1
            if len(sample_answers['empty']) < 5:
                sample_answers['empty'].append(f"Index {i}: {repr(answer)}")
        elif len(answer.strip()) == 1 and answer.upper() in 'ABCD':
            patterns['single_letter'] += 1
            if len(sample_answers['single_letter']) < 5:
                sample_answers['single_letter'].append(f"Index {i}: {repr(answer)}")
        elif re.match(r'^[ABCD]\)', answer.strip().upper()):
            patterns['letter_with_parenthesis'] += 1
            if len(sample_answers['letter_with_parenthesis']) < 5:
                sample_answers['letter_with_parenthesis'].append(f"Index {i}: {repr(answer)}")
        elif len(answer.strip()) > 10:
            patterns['full_text'] += 1
            if len(sample_answers['full_text']) < 5:
                sample_answers['full_text'].append(f"Index {i}: {repr(answer[:50])}...")
        else:
            patterns['other'] += 1
            if len(sample_answers['other']) < 5:
                sample_answers['other'].append(f"Index {i}: {repr(answer)}")
    
    print("Answer Pattern Distribution:")
    for pattern, count in patterns.items():
        print(f"  {pattern}: {count}")
        
    print("\nSample answers by pattern:")
    for pattern, samples in sample_answers.items():
        if samples:
            print(f"\n{pattern.upper()}:")
            for sample in samples:
                print(f"    {sample}")

def create_missing_indices_file(missing_indices: List[int], output_file: str = "missing_indices.txt"):
    """Create a file with missing indices for easy reference"""
    with open(output_file, 'w') as f:
        f.write("Missing Prediction Indices:\n")
        f.write("=" * 30 + "\n\n")
        
        f.write(f"Total missing: {len(missing_indices)}\n\n")
        
        f.write("Individual indices:\n")
        for idx in missing_indices:
            f.write(f"{idx}\n")
        
        f.write(f"\nAs comma-separated list:\n")
        f.write(f"{', '.join(map(str, missing_indices))}\n")
        
        f.write(f"\nAs Python list:\n")
        f.write(f"{missing_indices}\n")
    
    print(f"✓ Missing indices saved to {output_file}")

def generate_rerun_script(ground_truth_file: str, missing_indices: List[int], output_script: str = "rerun_missing.py"):
    """Generate a script to rerun predictions for missing indices"""
    script_content = f'''#!/usr/bin/env python3
"""
Script to regenerate missing predictions
Generated automatically from missing indices analysis

Missing indices: {len(missing_indices)} total
Indices: {missing_indices[:20]}{"..." if len(missing_indices) > 20 else ""}
"""

import json

def load_missing_questions(ground_truth_file: str, missing_indices: list):
    """Load questions that need predictions"""
    with open(ground_truth_file, 'r', encoding='utf-8') as f:
        gt_data = json.load(f)
    
    missing_questions = []
    for idx in missing_indices:
        if idx < len(gt_data) and gt_data[idx]:
            item = gt_data[idx]
            missing_questions.append({{
                'index': idx,
                'question': item.get('question', item.get('prompt', '')),
                'choices': item.get('choices', []),
                'ground_truth': item.get('answer', '')
            }})
    
    return missing_questions

def main():
    ground_truth_file = "{ground_truth_file}"
    missing_indices = {missing_indices}
    
    print(f"Loading {{len(missing_indices)}} missing questions...")
    questions = load_missing_questions(ground_truth_file, missing_indices)
    
    print(f"Found {{len(questions)}} questions to process")
    
    # TODO: Add your model inference code here
    # Process each question in the questions list
    # Generate predictions and save them
    
    # Example structure for results:
    results = []
    for q in questions:
        # Your model prediction code here
        prediction = "A"  # Replace with actual model call
        
        results.append({{
            'index': q['index'],
            'answer': prediction,
            'reasoning': 'Generated by rerun script'
        }})
    
    # Save results
    with open('missing_predictions.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    
    print(f"Saved {{len(results)}} predictions to missing_predictions.json")

if __name__ == "__main__":
    main()
'''
    
    with open(output_script, 'w', encoding='utf-8') as f:
        f.write(script_content)
    
    print(f"✓ Rerun script saved to {output_script}")

if __name__ == "__main__":
    ground_truth_file = "dataset/truth_teller_liar_questions.json"
    predictions_file = "outputs/answers_test.json"
    
    # Find missing indices
    results = find_missing_prediction_indices(ground_truth_file, predictions_file)
    
    if results:
        # Analyze prediction patterns
        analyze_prediction_patterns(predictions_file)
        
        # Create output files
        if results['missing_indices']:
            create_missing_indices_file(results['missing_indices'])
            generate_rerun_script(ground_truth_file, results['missing_indices'])
        
        print("\n" + "="*80)
        print("SUMMARY")
        print("="*80)
        print(f"Empty predictions: {len(results['missing_indices'])}")
        print(f"No valid choice: {len(results['no_choice_indices'])}")
        print(f"Total problematic: {len(results['all_problematic_indices'])}")
        
        if results['missing_indices']:
            print(f"\nEmpty prediction indices: {results['missing_indices']}")
        
        if results['no_choice_indices'][:10]:  # Show first 10
            print(f"No choice indices (sample): {results['no_choice_indices'][:10]}")

MISSING PREDICTION INDICES ANALYSIS
Ground Truth items: 668
Prediction items: 668

Analyzing 668 positions...

RESULTS
Total missing/empty predictions: 21
Missing indices: [44, 52, 186, 221, 227, 253, 316, 325, 349, 376, 389, 390, 415, 423, 436, 478, 501, 506, 518, 571, 643]

Predictions with no valid choice: 0
No choice indices: []

All problematic indices: 21

Empty prediction ranges: 44, 52, 186, 221, 227, 253, 316, 325, 349, 376, 389-390, 415, 423, 436, 478, 501, 506, 518, 571, 643
No choice ranges: None
All problematic ranges: 44, 52, 186, 221, 227, 253, 316, 325, 349, 376, 389-390, 415, 423, 436, 478, 501, 506, 518, 571, 643

PREDICTION PATTERN ANALYSIS
Answer Pattern Distribution:
  single_letter: 647
  letter_with_parenthesis: 0
  full_text: 0
  empty: 21
  other: 0

Sample answers by pattern:

SINGLE_LETTER:
    Index 0: 'D'
    Index 1: 'B'
    Index 2: 'A'
    Index 3: 'B'
    Index 4: 'B'

EMPTY:
    Index 44: ''
    Index 52: ''
    Index 186: ''
    Index 221: ''
    Inde

In [19]:
import json
import re
from typing import Dict, List, Optional, Set

def parse_malformed_json(file_path: str) -> List[Dict]:
    """Parse malformed JSON file with missing structure"""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # First try normal JSON parsing
        try:
            data = json.loads(content)
            if isinstance(data, list):
                return data
        except json.JSONDecodeError:
            pass
        
        # Custom parsing for malformed JSON
        lines = content.strip().split('\n')
        parsed_data = []
        current_item = {}
        
        for line in lines:
            line = line.strip()
            if not line:
                continue
                
            # Look for answer pattern
            if line.startswith('answer"') or line.startswith('"answer"'):
                if current_item:  # Save previous item
                    parsed_data.append(current_item)
                    current_item = {}
                
                # Extract answer
                answer_match = re.search(r'answer"([^"]*)"', line)
                if answer_match:
                    current_item['answer'] = answer_match.group(1)
                    
            # Look for reasoning pattern  
            elif line.startswith('reasoning"') or line.startswith('"reasoning"'):
                reasoning_match = re.search(r'reasoning"([^"]*)"', line)
                if reasoning_match:
                    current_item['reasoning'] = reasoning_match.group(1)
        
        # Don't forget the last item
        if current_item:
            parsed_data.append(current_item)
            
        return parsed_data
        
    except FileNotFoundError:
        print(f"Error: File {file_path} not found")
        return []

def extract_choice_letter(text: str) -> str:
    """Extract choice letter A, B, C, or D from text"""
    if not text:
        return ''
    
    text = text.strip()
    
    # Single letter (most common)
    if len(text) == 1 and text.upper() in 'ABCD':
        return text.upper()
    
    # Letter followed by parenthesis
    pattern = re.search(r'^([ABCD])\)', text.upper())
    if pattern:
        return pattern.group(1)
    
    # Letter at the beginning
    pattern = re.search(r'^([ABCD])\b', text.upper())
    if pattern:
        return pattern.group(1)
    
    # Isolated letter anywhere
    pattern = re.search(r'\b([ABCD])\b', text.upper())
    if pattern:
        return pattern.group(1)
    
    return ''

def calculate_accuracy(ground_truth_file: str, predictions_file: str):
    """Calculate accuracy for available predictions"""
    print("="*80)
    print("ACCURACY CALCULATION")
    print("="*80)
    
    # Load data
    try:
        with open(ground_truth_file, 'r', encoding='utf-8') as f:
            gt_data = json.load(f)
            gt_data = [item for item in gt_data if item is not None]
    except Exception as e:
        print(f"Error loading ground truth: {e}")
        return None
    
    pred_data = parse_malformed_json(predictions_file)
    
    print(f"Ground Truth items: {len(gt_data)}")
    print(f"Prediction items: {len(pred_data)}")
    
    # Expected length
    expected_len = min(len(gt_data), 668)
    
    # Pad predictions to match ground truth length if needed
    if len(pred_data) < expected_len:
        pred_data.extend([{}] * (expected_len - len(pred_data)))
    
    # Accuracy calculation variables
    total_with_gt = 0  # Total items with valid ground truth
    total_with_predictions = 0  # Total items with valid predictions
    correct_predictions = 0  # Correct predictions
    available_predictions = 0  # Predictions that could be evaluated
    
    # Detailed tracking
    correct_indices = []
    incorrect_indices = []
    missing_prediction_indices = []
    no_choice_indices = []
    invalid_gt_indices = []
    
    print(f"\nAnalyzing {expected_len} positions...")
    
    for i in range(expected_len):
        gt_item = gt_data[i] if i < len(gt_data) else {}
        pred_item = pred_data[i] if i < len(pred_data) else {}
        
        # Get answers
        gt_answer = gt_item.get('answer', '') if isinstance(gt_item, dict) else ''
        pred_answer = pred_item.get('answer', '') if isinstance(pred_item, dict) else ''
        
        # Extract choices
        gt_choice = extract_choice_letter(gt_answer)
        pred_choice = extract_choice_letter(pred_answer)
        
        # Count items with valid ground truth
        if gt_choice:
            total_with_gt += 1
            
            # Check if we have a valid prediction
            if pred_answer and pred_answer.strip():
                total_with_predictions += 1
                
                if pred_choice:
                    available_predictions += 1
                    
                    # Check if prediction is correct
                    if gt_choice == pred_choice:
                        correct_predictions += 1
                        correct_indices.append(i)
                    else:
                        incorrect_indices.append(i)
                else:
                    # Has prediction text but no valid choice
                    no_choice_indices.append(i)
            else:
                # Missing prediction
                missing_prediction_indices.append(i)
        else:
            # Invalid ground truth
            invalid_gt_indices.append(i)
    
    # Calculate accuracies
    accuracy_of_available = correct_predictions / available_predictions if available_predictions > 0 else 0
    accuracy_of_total = correct_predictions / total_with_gt if total_with_gt > 0 else 0
    coverage = available_predictions / total_with_gt if total_with_gt > 0 else 0
    
    # Results
    print("\n" + "="*80)
    print("ACCURACY RESULTS")
    print("="*80)
    
    print(f"Total items with valid ground truth: {total_with_gt}")
    print(f"Items with predictions available: {total_with_predictions}")
    print(f"Items with valid choice predictions: {available_predictions}")
    print(f"Correct predictions: {correct_predictions}")
    print(f"Incorrect predictions: {len(incorrect_indices)}")
    print(f"Missing predictions: {len(missing_prediction_indices)}")
    print(f"Predictions with no valid choice: {len(no_choice_indices)}")
    print(f"Items with invalid ground truth: {len(invalid_gt_indices)}")
    
    print(f"\n" + "-"*50)
    print(f"ACCURACY METRICS")
    print(f"-"*50)
    print(f"Accuracy (of available predictions): {accuracy_of_available:.4f} ({accuracy_of_available*100:.2f}%)")
    print(f"Accuracy (of total with GT): {accuracy_of_total:.4f} ({accuracy_of_total*100:.2f}%)")
    print(f"Coverage: {coverage:.4f} ({coverage*100:.2f}%)")
    
    # Show sample correct/incorrect predictions
    print(f"\n" + "-"*50)
    print(f"SAMPLE ANALYSIS")
    print(f"-"*50)
    
    if correct_indices:
        print(f"\nSample correct predictions (first 5):")
        for i, idx in enumerate(correct_indices[:5]):
            if idx < len(gt_data) and idx < len(pred_data):
                gt_item = gt_data[idx]
                pred_item = pred_data[idx]
                gt_choice = extract_choice_letter(gt_item.get('answer', ''))
                pred_choice = extract_choice_letter(pred_item.get('answer', ''))
                print(f"  Index {idx}: GT={gt_choice}, Pred={pred_choice} ✓")
    
    if incorrect_indices:
        print(f"\nSample incorrect predictions (first 5):")
        for i, idx in enumerate(incorrect_indices[:5]):
            if idx < len(gt_data) and idx < len(pred_data):
                gt_item = gt_data[idx]
                pred_item = pred_data[idx]
                gt_choice = extract_choice_letter(gt_item.get('answer', ''))
                pred_choice = extract_choice_letter(pred_item.get('answer', ''))
                print(f"  Index {idx}: GT={gt_choice}, Pred={pred_choice} ✗")
    
    if no_choice_indices:
        print(f"\nSample predictions with no valid choice (first 3):")
        for i, idx in enumerate(no_choice_indices[:3]):
            if idx < len(pred_data):
                pred_item = pred_data[idx]
                pred_answer = pred_item.get('answer', '')
                print(f"  Index {idx}: '{pred_answer}'")
    
    return {
        'total_with_gt': total_with_gt,
        'total_with_predictions': total_with_predictions,
        'available_predictions': available_predictions,
        'correct_predictions': correct_predictions,
        'accuracy_of_available': accuracy_of_available,
        'accuracy_of_total': accuracy_of_total,
        'coverage': coverage,
        'correct_indices': correct_indices,
        'incorrect_indices': incorrect_indices,
        'missing_prediction_indices': missing_prediction_indices,
        'no_choice_indices': no_choice_indices,
        'invalid_gt_indices': invalid_gt_indices
    }

def find_missing_prediction_indices(ground_truth_file: str, predictions_file: str):
    """Find exact indices of missing predictions"""
    print("="*80)
    print("MISSING PREDICTION INDICES ANALYSIS")
    print("="*80)
    
    # Load data
    try:
        with open(ground_truth_file, 'r', encoding='utf-8') as f:
            gt_data = json.load(f)
            gt_data = [item for item in gt_data if item is not None]
    except Exception as e:
        print(f"Error loading ground truth: {e}")
        return
    
    pred_data = parse_malformed_json(predictions_file)
    
    print(f"Ground Truth items: {len(gt_data)}")
    print(f"Prediction items: {len(pred_data)}")
    
    # Expected length
    expected_len = 668
    
    # Find missing indices
    missing_indices = []
    empty_prediction_indices = []
    no_choice_indices = []
    
    # Pad predictions to match ground truth length
    if len(pred_data) < expected_len:
        pred_data.extend([{}] * (expected_len - len(pred_data)))
    
    print(f"\nAnalyzing {expected_len} positions...")
    
    for i in range(expected_len):
        gt_item = gt_data[i] if i < len(gt_data) else {}
        pred_item = pred_data[i] if i < len(pred_data) else {}
        
        # Get answers
        gt_answer = gt_item.get('answer', '') if isinstance(gt_item, dict) else ''
        pred_answer = pred_item.get('answer', '') if isinstance(pred_item, dict) else ''
        
        # Check if ground truth has valid choice
        gt_choice = extract_choice_letter(gt_answer)
        pred_choice = extract_choice_letter(pred_answer)
        
        if gt_choice:  # Only check items with valid ground truth
            # Check for completely missing predictions
            if not pred_answer or not pred_answer.strip():
                missing_indices.append(i)
                empty_prediction_indices.append(i)
            
            # Check for predictions without valid choice
            elif not pred_choice:
                no_choice_indices.append(i)
                print(f"Index {i}: Has prediction '{pred_answer}' but no valid choice extracted")
    
    # Results
    print("\n" + "="*80)
    print("RESULTS")
    print("="*80)
    
    print(f"Total missing/empty predictions: {len(empty_prediction_indices)}")
    print(f"Missing indices: {empty_prediction_indices}")
    
    print(f"\nPredictions with no valid choice: {len(no_choice_indices)}")
    if len(no_choice_indices) <= 20:
        print(f"No choice indices: {no_choice_indices}")
    else:
        print(f"No choice indices (first 20): {no_choice_indices[:20]}")
        print(f"... and {len(no_choice_indices)-20} more")
    
    print(f"\nAll problematic indices: {len(missing_indices + no_choice_indices)}")
    all_problematic = sorted(list(set(missing_indices + no_choice_indices)))
    
    # Show ranges for easier reading
    def format_indices_as_ranges(indices):
        if not indices:
            return "None"
        
        ranges = []
        start = indices[0]
        end = indices[0]
        
        for i in range(1, len(indices)):
            if indices[i] == end + 1:
                end = indices[i]
            else:
                if start == end:
                    ranges.append(str(start))
                else:
                    ranges.append(f"{start}-{end}")
                start = end = indices[i]
        
        # Add the last range
        if start == end:
            ranges.append(str(start))
        else:
            ranges.append(f"{start}-{end}")
        
        return ", ".join(ranges)
    
    print(f"\nEmpty prediction ranges: {format_indices_as_ranges(empty_prediction_indices)}")
    print(f"No choice ranges: {format_indices_as_ranges(no_choice_indices)}")
    print(f"All problematic ranges: {format_indices_as_ranges(all_problematic)}")
    
    return {
        'missing_indices': empty_prediction_indices,
        'no_choice_indices': no_choice_indices,
        'all_problematic_indices': all_problematic
    }

def analyze_prediction_patterns(predictions_file: str):
    """Analyze patterns in predictions to understand the data structure"""
    print("\n" + "="*80)
    print("PREDICTION PATTERN ANALYSIS")
    print("="*80)
    
    pred_data = parse_malformed_json(predictions_file)
    
    # Count different answer patterns
    patterns = {
        'single_letter': 0,
        'letter_with_parenthesis': 0,
        'full_text': 0,
        'empty': 0,
        'other': 0
    }
    
    sample_answers = {
        'single_letter': [],
        'letter_with_parenthesis': [],
        'full_text': [],
        'empty': [],
        'other': []
    }
    
    for i, item in enumerate(pred_data):
        if isinstance(item, dict):
            answer = item.get('answer', '')
        else:
            answer = ''
        
        if not answer or not answer.strip():
            patterns['empty'] += 1
            if len(sample_answers['empty']) < 5:
                sample_answers['empty'].append(f"Index {i}: {repr(answer)}")
        elif len(answer.strip()) == 1 and answer.upper() in 'ABCD':
            patterns['single_letter'] += 1
            if len(sample_answers['single_letter']) < 5:
                sample_answers['single_letter'].append(f"Index {i}: {repr(answer)}")
        elif re.match(r'^[ABCD]\)', answer.strip().upper()):
            patterns['letter_with_parenthesis'] += 1
            if len(sample_answers['letter_with_parenthesis']) < 5:
                sample_answers['letter_with_parenthesis'].append(f"Index {i}: {repr(answer)}")
        elif len(answer.strip()) > 10:
            patterns['full_text'] += 1
            if len(sample_answers['full_text']) < 5:
                sample_answers['full_text'].append(f"Index {i}: {repr(answer[:50])}...")
        else:
            patterns['other'] += 1
            if len(sample_answers['other']) < 5:
                sample_answers['other'].append(f"Index {i}: {repr(answer)}")
    
    print("Answer Pattern Distribution:")
    for pattern, count in patterns.items():
        print(f"  {pattern}: {count}")
        
    print("\nSample answers by pattern:")
    for pattern, samples in sample_answers.items():
        if samples:
            print(f"\n{pattern.upper()}:")
            for sample in samples:
                print(f"    {sample}")

def create_missing_indices_file(missing_indices: List[int], output_file: str = "missing_indices.txt"):
    """Create a file with missing indices for easy reference"""
    with open(output_file, 'w') as f:
        f.write("Missing Prediction Indices:\n")
        f.write("=" * 30 + "\n\n")
        
        f.write(f"Total missing: {len(missing_indices)}\n\n")
        
        f.write("Individual indices:\n")
        for idx in missing_indices:
            f.write(f"{idx}\n")
        
        f.write(f"\nAs comma-separated list:\n")
        f.write(f"{', '.join(map(str, missing_indices))}\n")
        
        f.write(f"\nAs Python list:\n")
        f.write(f"{missing_indices}\n")
    
    print(f"✓ Missing indices saved to {output_file}")

def generate_rerun_script(ground_truth_file: str, missing_indices: List[int], output_script: str = "rerun_missing.py"):
    """Generate a script to rerun predictions for missing indices"""
    script_content = f'''#!/usr/bin/env python3
"""
Script to regenerate missing predictions
Generated automatically from missing indices analysis

Missing indices: {len(missing_indices)} total
Indices: {missing_indices[:20]}{"..." if len(missing_indices) > 20 else ""}
"""

import json

def load_missing_questions(ground_truth_file: str, missing_indices: list):
    """Load questions that need predictions"""
    with open(ground_truth_file, 'r', encoding='utf-8') as f:
        gt_data = json.load(f)
    
    missing_questions = []
    for idx in missing_indices:
        if idx < len(gt_data) and gt_data[idx]:
            item = gt_data[idx]
            missing_questions.append({{
                'index': idx,
                'question': item.get('question', item.get('prompt', '')),
                'choices': item.get('choices', []),
                'ground_truth': item.get('answer', '')
            }})
    
    return missing_questions

def main():
    ground_truth_file = "{ground_truth_file}"
    missing_indices = {missing_indices}
    
    print(f"Loading {{len(missing_indices)}} missing questions...")
    questions = load_missing_questions(ground_truth_file, missing_indices)
    
    print(f"Found {{len(questions)}} questions to process")
    
    # TODO: Add your model inference code here
    # Process each question in the questions list
    # Generate predictions and save them
    
    # Example structure for results:
    results = []
    for q in questions:
        # Your model prediction code here
        prediction = "A"  # Replace with actual model call
        
        results.append({{
            'index': q['index'],
            'answer': prediction,
            'reasoning': 'Generated by rerun script'
        }})
    
    # Save results
    with open('missing_predictions.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    
    print(f"Saved {{len(results)}} predictions to missing_predictions.json")

if __name__ == "__main__":
    main()
'''
    
    with open(output_script, 'w', encoding='utf-8') as f:
        f.write(script_content)
    
    print(f"✓ Rerun script saved to {output_script}")

if __name__ == "__main__":
    ground_truth_file = "dataset/truth_teller_liar_questions.json"
    predictions_file = "outputs/answers_test.json"
    
    # First calculate accuracy for available predictions
    accuracy_results = calculate_accuracy(ground_truth_file, predictions_file)
    
    # Then find missing indices
    missing_results = find_missing_prediction_indices(ground_truth_file, predictions_file)
    
    if missing_results:
        # Analyze prediction patterns
        analyze_prediction_patterns(predictions_file)
        
        # Create output files
        if missing_results['missing_indices']:
            create_missing_indices_file(missing_results['missing_indices'])
            generate_rerun_script(ground_truth_file, missing_results['missing_indices'])
        
        print("\n" + "="*80)
        print("FINAL SUMMARY")
        print("="*80)
        
        if accuracy_results:
            print(f"ACCURACY METRICS:")
            print(f"  Accuracy (available): {accuracy_results['accuracy_of_available']:.4f} ({accuracy_results['accuracy_of_available']*100:.2f}%)")
            print(f"  Accuracy (total): {accuracy_results['accuracy_of_total']:.4f} ({accuracy_results['accuracy_of_total']*100:.2f}%)")
            print(f"  Coverage: {accuracy_results['coverage']:.4f} ({accuracy_results['coverage']*100:.2f}%)")
            print(f"  Correct: {accuracy_results['correct_predictions']}")
            print(f"  Available: {accuracy_results['available_predictions']}")
            print(f"  Total with GT: {accuracy_results['total_with_gt']}")
        
        print(f"\nMISSING DATA:")
        print(f"  Empty predictions: {len(missing_results['missing_indices'])}")
        print(f"  No valid choice: {len(missing_results['no_choice_indices'])}")
        print(f"  Total problematic: {len(missing_results['all_problematic_indices'])}")

ACCURACY CALCULATION
Ground Truth items: 668
Prediction items: 668

Analyzing 668 positions...

ACCURACY RESULTS
Total items with valid ground truth: 668
Items with predictions available: 647
Items with valid choice predictions: 647
Correct predictions: 238
Incorrect predictions: 409
Missing predictions: 21
Predictions with no valid choice: 0
Items with invalid ground truth: 0

--------------------------------------------------
ACCURACY METRICS
--------------------------------------------------
Accuracy (of available predictions): 0.3679 (36.79%)
Accuracy (of total with GT): 0.3563 (35.63%)
Coverage: 0.9686 (96.86%)

--------------------------------------------------
SAMPLE ANALYSIS
--------------------------------------------------

Sample correct predictions (first 5):
  Index 1: GT=B, Pred=B ✓
  Index 2: GT=A, Pred=A ✓
  Index 4: GT=B, Pred=B ✓
  Index 5: GT=D, Pred=D ✓
  Index 6: GT=B, Pred=B ✓

Sample incorrect predictions (first 5):
  Index 0: GT=C, Pred=D ✗
  Index 3: GT=C, Pred