# In [None]:
"""
Use this file to define pytest tests that verify the outputs of the task.

This file will be copied to /tests/test_outputs.py and run by the /tests/test.sh file
from the working directory.
"""

import pytest
from pathlib import Path
import pandas as pd
import json

# --- CONFIGURATION ---
# Absolute locations inside the verifier container.
VARS_PATH: Path = Path("/logs/verifier/notebook_variables.json")  # variables captured from the agent's notebook
GROUND_TRUTH_PATH: Path = Path("/tests/sabotaged_rows.csv")  # reference list of sabotaged rows
AGENT_OUTPUT_PATH: Path = Path("/workspace/sabotaged_rows.csv")  # CSV the agent is expected to produce

# Hidden ground-truth constants
GT_SABOTEUR_ID: int = 13  # expected value of the notebook's `saboteur_id` variable

@pytest.fixture(scope="module")
def notebook_vars():
    """
    Load the dictionary of variables captured from the agent's notebook.

    Returns:
        dict: Mapping of variable name to its JSON-decoded value.

    Fails the requesting test (via ``pytest.fail``) if the variables file is
    missing, cannot be decoded as UTF-8 JSON, or does not hold a JSON object
    at the top level.
    """
    if not VARS_PATH.exists():
        pytest.fail(f"Variables file not found at {VARS_PATH}. Did the notebook execute?")

    try:
        # JSON text is UTF-8 by specification; don't depend on the locale default.
        with open(VARS_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        pytest.fail("Failed to decode notebook_variables.json")

    # Tests use `name in vars` and `vars[name]`. A list or string here would
    # give misleading membership results (substring matches) or TypeErrors.
    if not isinstance(data, dict):
        pytest.fail(
            f"Expected a JSON object in {VARS_PATH}, got {type(data).__name__}."
        )
    return data

# ==============================================================================
# TESTS
# ==============================================================================

def _coerce_exact_int(value):
    """
    Return *value* as an ``int`` if it represents an exact integer, else ``None``.

    Accepts ``int``, integral ``float`` (e.g. ``13.0``), and strings that
    ``int()`` parses (e.g. ``"13"`` or ``" 13 "``). Rejects ``bool``,
    non-integral floats (which ``int()`` would silently truncate), ``None``,
    and containers (on which ``int()`` raises ``TypeError``).
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value) if value.is_integer() else None
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return None
    return None


def test_saboteur_identification(notebook_vars):
    """
    Task 1: Test that the agent identified the correct ``saboteur_id``.

    The notebook variable may be stored as an int, an integral float, or a
    string of digits. Any other representation fails the test rather than
    being coerced (e.g. ``13.9`` is not silently truncated to ``13``).
    """
    var_name = "saboteur_id"

    # Check if variable exists
    if var_name not in notebook_vars:
        pytest.fail(f"Variable '{var_name}' not found in notebook variables.")

    # Check value (allow for string or int representation)
    agent_val = _coerce_exact_int(notebook_vars[var_name])
    if agent_val is None:
        pytest.fail(f"Variable '{var_name}' is not an integer.")

    assert agent_val == GT_SABOTEUR_ID, (
        f"Incorrect saboteur identified. Agent found {agent_val}, expected {GT_SABOTEUR_ID}."
    )

def test_output_file_structure():
    """
    Task 2 (Structure): Verify that the agent wrote a well-formed output CSV.

    The file at ``AGENT_OUTPUT_PATH`` must exist, be readable by
    ``pandas.read_csv``, expose a ``row_id`` column, and contain at least
    one data row.
    """
    csv_path = AGENT_OUTPUT_PATH

    if not csv_path.exists():
        pytest.fail(f"Output file '{csv_path}' was not created.")

    try:
        output_df = pd.read_csv(csv_path)
    except Exception as err:
        pytest.fail(f"Could not read output CSV: {err}")

    has_row_id = "row_id" in output_df.columns
    assert has_row_id, "Output CSV is missing the required 'row_id' column."

    row_count = output_df.shape[0]
    assert row_count > 0, "Output CSV is empty."

def _load_row_id_set(csv_path):
    """
    Return the set of distinct values in the ``row_id`` column of *csv_path*.

    Propagates whatever ``pandas.read_csv`` raises for an unreadable file, and
    ``KeyError`` if the ``row_id`` column is absent.
    """
    frame = pd.read_csv(csv_path)
    return set(frame['row_id'].unique())


def test_cleanup_performance_metrics():
    """
    Task 2 (Logic): Check precision (critical) and recall of the flagged rows.

    Both the ground-truth CSV and the agent's CSV are reduced to sets of
    distinct ``row_id`` values. Precision must be >= 0.95 and recall >= 0.50.
    """
    if not GROUND_TRUTH_PATH.exists():
        pytest.fail(f"Ground Truth file '{GROUND_TRUTH_PATH}' does not exist.")

    # Load ground truth
    try:
        ground_truth_set = _load_row_id_set(GROUND_TRUTH_PATH)
    except Exception as e:
        pytest.fail(f"Could not load ground truth data for metric calculation: {e}")

    # Recall is undefined without positives. Report a verifier-data problem
    # instead of blaming the agent for "missing" rows.
    if not ground_truth_set:
        pytest.fail("Ground truth contains no row_id values; cannot compute recall.")

    # Load agent output. Check existence first so a missing file produces a
    # clear message rather than a generic read error.
    if not AGENT_OUTPUT_PATH.exists():
        pytest.fail(f"Output file '{AGENT_OUTPUT_PATH}' was not created.")
    try:
        agent_set = _load_row_id_set(AGENT_OUTPUT_PATH)
    except Exception as e:
        pytest.fail(f"Could not load agent output for metric calculation: {e}")

    # Confusion counts over distinct row ids
    true_positives = len(agent_set & ground_truth_set)
    false_positives = len(agent_set - ground_truth_set)
    false_negatives = len(ground_truth_set - agent_set)

    # Precision = TP / (TP + FP); an empty submission scores 0.
    if agent_set:
        precision = true_positives / (true_positives + false_positives)
    else:
        precision = 0.0

    # Recall = TP / (TP + FN); ground_truth_set is non-empty (checked above).
    recall = true_positives / (true_positives + false_negatives)

    print(f"\n[METRICS] Precision: {precision:.4f} | Recall: {recall:.4f}")
    print(f"[COUNTS] TP: {true_positives}, FP: {false_positives}, FN: {false_negatives}")

    # --- ASSERTIONS ---

    # Constraint 1: Precision must be >= 0.95 (strict requirement)
    assert precision >= 0.95, (
        f"Precision failed. Required >= 0.95, got {precision:.4f}. "
        "You flagged too many honest rows as sabotage."
    )

    # Constraint 2: Recall must be reasonable (>= 0.50)
    assert recall >= 0.50, (
        f"Recall failed. Required >= 0.50, got {recall:.4f}. "
        "You missed too many sabotaged rows."
    )