In [None]:
# %% [markdown]
# # Kaitiaki Planner - Budget-Aware RAG Evaluation
# **Capstone Project**: Evaluating fairness in computational budget allocation
# 
# Research question: How does the budget allocation strategy affect fairness
# between English and te reo Māori queries?

# %% [markdown]
# ## 1. Setup and Imports

# %%
import os
import time
import numpy as np
import pandas as pd
import requests
from pathlib import Path
from dotenv import load_dotenv

# Populate os.environ from a .env file via python-dotenv, then require the
# Anthropic key immediately so no later cell fails partway through a paid run.
load_dotenv()
api_key = os.environ.get("ANTHROPIC_API_KEY")

if api_key:
    print("API key loaded successfully")
else:
    raise ValueError(
        "ANTHROPIC_API_KEY not found!\n"
        "   1. Create .env file in project root\n"
        "   2. Add: ANTHROPIC_API_KEY=sk-ant-your-key\n"
        "   3. Restart kernel"
    )

# %%
# Import custom modules
from claude_client import ClaudeClient
from eval_utils import (
    load_corpus,
    load_eval_tasks,
    grounded_correctness,
    fairness_gap,
    summarize_results,
    print_summary
)

# %% [markdown]
# ## 2. Initialize Claude Client

# %%
# Hand the client a $5 USD spending cap (max_spend_usd) as a safety limit.
claude = ClaudeClient(api_key=api_key, max_spend_usd=5.0)

print("Claude client initialized")
print(f"   Model: {claude.model}")
print(f"   Budget: ${claude.max_spend_usd:.2f} USD")
print(f"   Current spend: ${claude.total_cost_usd:.4f} USD")

# %% [markdown]
# ## 3. Load Data

# %%
# Load the corpus and report how many documents are tagged English ('en')
# versus Māori ('mi').
corpus = load_corpus("../data/corpus.json")
print(f"Loaded corpus: {len(corpus)} documents")
print(f"   English: {sum(doc['lang'] == 'en' for doc in corpus)}")
print(f"   Māori:   {sum(doc['lang'] == 'mi' for doc in corpus)}")

# %%
# Load the labelled evaluation tasks (resolved against the corpus loaded above).
eval_tasks = load_eval_tasks(corpus, "../eval/tasks_labeled.yaml")
print(f"\nLoaded evaluation tasks: {len(eval_tasks)} tasks")

# Count tasks per (language, complexity) pair to expose any imbalance
# between the English and Māori subsets.
df_dist = pd.DataFrame(eval_tasks)
print("\nTask distribution:")
print(
    df_dist
    .groupby(["lang", "complexity"])
    .size()
    .unstack(fill_value=0)
)

# %%
# Print the first task so its fields (id, query, lang, complexity and the
# first gold citation's doc id and character span) are visible up front.
print("\n" + "=" * 70)
print("EXAMPLE TASK")
print("=" * 70)
example = eval_tasks[0]
print(f"ID:         {example['id']}")
print(f"Query:      {example['query']}")
print(f"Language:   {example['lang']}")
print(f"Complexity: {example['complexity']}")
print(f"Gold doc:   {example['gold_citations'][0]['doc_id']}")
print(
    "Gold span:  chars "
    f"{example['gold_citations'][0]['start']}-{example['gold_citations'][0]['end']}"
)
print("=" * 70)

# %% [markdown]
# ## 4. Test with Multiple Queries (Task 3)

# %%
def test_claude_multiple_queries(n_queries=5, use_retriever=False):
    """
    Test Claude with first N queries from evaluation set.
    
    Args:
        n_queries: Number of queries to test (default: 5)
        use_retriever: If True, call retriever service. If False, use gold docs directly.
    
    Returns:
        DataFrame with results
    
    Expected cost: ~$0.015-0.025 for 5 queries
    """
    print("\n" + "="*70)
    print(f"TESTING CLAUDE WITH {n_queries} QUERIES")
    print("="*70)
    print(f"Mode: {'Retriever service' if use_retriever else 'Gold documents (testing)'}")
    print(f"Expected cost: ~${n_queries * 0.004:.3f} USD")
    print("="*70)
    
    test_results = []
    start_time = time.time()
    
    # Use first n_queries from eval_tasks
    for i, task in enumerate(eval_tasks[:n_queries], 1):
        task_id = task['id']
        query = task['query']
        lang = task['lang']
        complexity = task['complexity']
        
        print(f"\n[{i}/{n_queries}] {task_id}")
        print(f"  Query: {query[:60]}...")
        print(f"  Lang: {lang}, Complexity: {complexity}")
        
        try:
            # Get gold document
            gold_doc_id = task['gold_citations'][0]['doc_id']
            gold_doc = next((d for d in corpus if d['id'] == gold_doc_id), None)
            
            if not gold_doc:
                print(f"  Warning: Gold document not found: {gold_doc_id}")
                continue
            
            # Create passage from gold document
            passages = [{
                "doc_id": gold_doc["id"],
                "text": gold_doc["text"][:1000],
                "char_start": 0,
                "char_end": min(1000, len(gold_doc["text"]))
            }]
            
            # Generate answer with Claude
            t0 = time.time()
            response = claude.generate_answer(
                query=query,
                passages=passages[:3],
                max_tokens=150
            )
            lat_ms = (time.time() - t0) * 1000
            
            # Calculate grounded correctness
            gc = grounded_correctness(
                response['citations'],
                task['gold_citations'][0]
            )
            
            # Store result
            result = {
                "id": task_id,
                "lang": lang,
                "complexity": complexity,
                "gc": gc,
                "cost": response["cost_usd"],
                "refusal": response["refusal"],
                "lat_ms": lat_ms,
                "input_tokens": response["usage"]["input_tokens"],
                "output_tokens": response["usage"]["output_tokens"],
                "answer_preview": response["answer"][:80]
            }
            
            test_results.append(result)
            
            # Print status
            status = "PASS" if gc > 0 else "FAIL"
            refusal_str = " [REFUSAL]" if response["refusal"] else ""
            print(f"  {status}: GC={gc:.2f}, Cost=${response['cost_usd']:.6f}, "
                  f"Latency={lat_ms:.0f}ms{refusal_str}")
            
        except Exception as e:
            print(f"  ERROR: {str(e)[:60]}")
            continue
    
    elapsed = time.time() - start_time
    
    print("\n" + "="*70)
    print("TEST SUMMARY")
    print("="*70)
    
    if not test_results:
        print("No queries completed successfully")
        return None
    
    # Create DataFrame
    df_test = pd.DataFrame(test_results)
    
    # Print summary statistics
    print(f"\nCompleted: {len(test_results)}/{n_queries} queries")
    print(f"Time: {elapsed:.1f} seconds ({elapsed/len(test_results):.1f}s per query)")
    print(f"\nPerformance:")
    print(f"  Mean GC:         {df_test['gc'].mean():.3f}")
    print(f"  Total cost:      ${df_test['cost'].sum():.4f} USD")
    print(f"  Avg cost/query:  ${df_test['cost'].mean():.6f} USD")
    print(f"  Refusal rate:    {df_test['refusal'].sum()}/{len(df_test)} ({df_test['refusal'].mean():.1%})")
    print(f"  Avg latency:     {df_test['lat_ms'].mean():.0f}ms")
    
    # By language
    print("\nBy language:")
    lang_stats = df_test.groupby('lang').agg({
        'gc': 'mean',
        'cost': ['sum', 'mean'],
        'refusal': 'sum'
    }).round(4)
    print(lang_stats)
    
    # By complexity
    if 'complexity' in df_test.columns and df_test['complexity'].notna().all():
        print("\nBy complexity:")
        comp_stats = df_test.groupby('complexity').agg({
            'gc': 'mean',
            'cost': 'mean',
            'refusal': 'sum'
        }).round(4)
        print(comp_stats)
    
    print("\n" + "="*70)
    print("CLAUDE API STATISTICS")
    print("="*70)
    claude.print_stats()
    
    return df_test

# %%
# Smoke-test the pipeline on a handful of queries against gold documents
# (no retriever) before spending budget on the full evaluation.
print("Starting test with 5 queries...")
print("This will cost approximately $0.02 USD")
print("\nNote: Using gold documents (not retriever service)")

df_test = test_claude_multiple_queries(use_retriever=False, n_queries=5)

# %% [markdown]
# ## 5. Verify Results

# %%
# Check whether the smoke test produced usable results. The headline reports
# PASSED only when none of the blocking checks below fired.
if df_test is not None:
    mean_gc = df_test['gc'].mean()

    # Blocking problems vs. softer quality warnings.
    issues = []
    warnings_found = []  # avoids shadowing the stdlib `warnings` module name

    if mean_gc < 0.3:
        issues.append("Low mean GC (<0.3) - check citation extraction")
    elif mean_gc < 0.5:
        warnings_found.append("Moderate GC (0.3-0.5) - could be better")

    if df_test['refusal'].sum() > len(df_test) * 0.5:
        issues.append("High refusal rate (>50%) - check passage quality")

    if df_test['cost'].mean() > 0.01:
        issues.append("High cost per query (>$0.01) - passages may be too long")

    # Derive the headline from the checks, so a run in which every query
    # scored GC=0 is not reported as PASSED.
    headline = "MULTI-QUERY TEST PASSED!" if not issues else "MULTI-QUERY TEST RAN WITH ISSUES"
    print("\n" + "="*70)
    print(headline)
    print("="*70)

    print(f"\nTotal spend so far: ${claude.total_cost_usd:.4f} USD")
    print(f"Budget remaining: ${claude.max_spend_usd - claude.total_cost_usd:.2f} USD")

    if issues:
        print("\nIssues to investigate:")
        for issue in issues:
            print(f"  - {issue}")

    if warnings_found:
        print("\nWarnings:")
        for warning in warnings_found:
            print(f"  - {warning}")

    if not issues and not warnings_found:
        print("\nAll checks passed - system is working well!")
        print("Ready for full evaluation!")

else:
    print("\n" + "="*70)
    print("MULTI-QUERY TEST FAILED")
    print("="*70)
    print("Check errors above and debug before proceeding.")

# %%
# Show per-query results, then list the queries that earned no grounded
# credit (GC == 0).
if df_test is not None:
    print("\n" + "=" * 70)
    print("DETAILED RESULTS")
    print("=" * 70)

    # Widen the pandas display so every column and the answer preview fit.
    pd.options.display.max_columns = None
    pd.options.display.width = None
    pd.options.display.max_colwidth = 50

    display(df_test[['id', 'lang', 'complexity', 'gc', 'cost', 'refusal', 'answer_preview']])

    failed = df_test.loc[df_test['gc'] == 0.0]
    if not failed.empty:
        print(f"\n{len(failed)} queries with GC=0 (no correct citations):")
        for row in failed.itertuples(index=False):
            print(f"  - {row.id} ({row.lang}, {row.complexity})")

# %%
# Persist the smoke-test results as CSV under ../outputs (created if absent).
if df_test is not None:
    output_dir = Path("../outputs")
    output_dir.mkdir(exist_ok=True)

    output_path = output_dir.joinpath("test_results_day1.csv")
    df_test.to_csv(output_path, index=False)
    print(f"\nSaved test results to: {output_path}")

API key loaded successfully
Claude client initialized
   Model: claude-sonnet-4-5-20250929
   Budget: $5.00 USD
   Current spend: $0.0000 USD
Loaded corpus: 30 documents
   English: 15
   Māori:   15

Loaded evaluation tasks: 30 tasks

Task distribution:
complexity  complex  simple
lang                       
en                6       9
mi                3      12

EXAMPLE TASK
ID:         en_maori_language_q1
Query:      [EN] Which Act recognised Māori as an official language of New Zealand?
Language:   en
Complexity: complex
Gold doc:   en_maori_language
Gold span:  chars 305-328
Starting test with 5 queries...
This will cost approximately $0.02 USD

Note: Using gold documents (not retriever service)

TESTING CLAUDE WITH 5 QUERIES
Mode: Gold documents (testing)
Expected cost: ~$0.020 USD

[1/5] en_maori_language_q1
  Query: [EN] Which Act recognised Māori as an official language of N...
  Lang: en, Complexity: complex
  FAIL: GC=0.00, Cost=$0.001539, Latency=4219ms

[2/5] en_matariki

Unnamed: 0,id,lang,complexity,gc,cost,refusal,answer_preview
0,en_maori_language_q1,en,complex,0.0,0.001539,False,The Māori Language Act 1987 recognised Māori a...
1,en_matariki_q1,en,simple,0.0,0.002094,False,The Pleiades star cluster is known as Matariki...
2,en_marae_q1,en,simple,0.0,0.002337,False,Based on the passages provided:\n\nA marae is ...
3,en_mount_tongariro_q1,en,simple,0.0,0.001866,False,Mount Tongariro lies within the Taupō Volcanic...
4,en_moa_q1,en,simple,0.0,0.002112,False,"Based on the passage provided, moa were an ext..."



5 queries with GC=0 (no correct citations):
  - en_maori_language_q1 (en, complex)
  - en_matariki_q1 (en, simple)
  - en_marae_q1 (en, simple)
  - en_mount_tongariro_q1 (en, simple)
  - en_moa_q1 (en, simple)

Saved test results to: ../outputs/test_results_day1.csv
