# In [None]:
# EVALUATION.ipynb
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import json
import seaborn as sns # Import seaborn

# 1. Evaluasi Retrieval
def load_retrieval_metrics():
    """Load the retrieval metrics table, plot it as a bar chart and return it.

    Reads ``data/eval/retrieval_metrics.csv`` using its first column as the
    index. It plots the table with pandas' bar plot (index values on the x-axis,
    one bar per column), saves the chart to ``figures/retrieval_metrics.png``
    and displays it.

    Returns:
        pandas.DataFrame: the metrics table exactly as read from disk.
    """
    import os

    metrics_df = pd.read_csv('data/eval/retrieval_metrics.csv', index_col=0)

    # Plot a comparison of the metrics
    metrics_df.plot(kind='bar', figsize=(10, 6))
    plt.title('Perbandingan Metrik Retrieval')
    plt.ylabel('Skor')
    plt.xticks(rotation=0)
    # Keep horizontal tick labels and the title inside the saved image.
    plt.tight_layout()
    # savefig does not create missing directories; without this the first run
    # on a fresh checkout fails with FileNotFoundError.
    os.makedirs('figures', exist_ok=True)
    plt.savefig('figures/retrieval_metrics.png')
    plt.show()

    return metrics_df

# 2. Evaluasi Prediksi (Solution Reuse)
def evaluate_predictions(ground_truth=None):
    """Score solution-reuse predictions against ground-truth labels.

    Reads ``data/results/predictions.csv``, which must contain ``query`` and
    ``predicted_solution`` columns. It keeps only rows whose ``query`` has a
    ground-truth label, then computes accuracy and support-weighted
    precision/recall/F1. It draws a confusion matrix, saved to
    ``figures/confusion_matrix.png``, and writes the metrics to
    ``data/eval/prediction_metrics.csv``.

    Args:
        ground_truth: Optional mapping ``query -> expected solution``. If
            omitted, a small built-in example mapping is used. A real evaluation
            should supply a properly prepared set.

    Returns:
        tuple: ``(metrics, ground_truth)``.

        - ``metrics`` is a dict with keys ``accuracy``, ``precision``,
          ``recall`` and ``f1``. Every value is 0.0 when no prediction matches
          a labelled query.
        - ``ground_truth`` is the mapping that was actually used, so it can be
          passed on to ``error_analysis``.
    """
    import os

    if ground_truth is None:
        # Example labels only (the real implementation must prepare these).
        ground_truth = {
            "pelaku ditangkap dengan barang bukti sabu 1 gram": "Pidana Penjara",
            "terdakwa mengedarkan ganja seberat 500 gram": "Pidana Penjara",
            "pemakai narkotika jenis ekstasi": "Rehabilitasi"
        }

    pred_df = pd.read_csv('data/results/predictions.csv')

    # Keep only the queries that have a ground-truth label.
    labeled = pred_df[pred_df['query'].isin(list(ground_truth))]
    # Cast both sides to str. A missing prediction (NaN after read_csv) then
    # becomes the ordinary wrong label 'nan' instead of mixing float and str
    # labels, which sklearn rejects.
    y_true = labeled['query'].map(ground_truth).astype(str).tolist()
    y_pred = labeled['predicted_solution'].astype(str).tolist()

    if not y_true:
        print("No predictions match ground truth queries for evaluation.")
        # Keep the dict shape stable so the metrics CSV below is still written.
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
    else:
        metrics = {
            'accuracy': accuracy_score(y_true, y_pred),
            # 'weighted' averages per-class scores by support. zero_division=0
            # scores classes that are never predicted as 0 instead of warning.
            'precision': precision_score(y_true, y_pred, average='weighted', zero_division=0),
            'recall': recall_score(y_true, y_pred, average='weighted', zero_division=0),
            'f1': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        }

        # Build one sorted label list and use it for both the matrix and the
        # tick labels. Rows and columns then always line up, and the order does
        # not change between runs (set order depends on string hash seeding).
        labels = sorted(set(y_true) | set(y_pred))
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=labels, yticklabels=labels)
        plt.title('Confusion Matrix Prediksi')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.tight_layout()
        # savefig does not create missing directories.
        os.makedirs('figures', exist_ok=True)
        plt.savefig('figures/confusion_matrix.png')
        plt.show()

    # Save the metrics
    os.makedirs('data/eval', exist_ok=True)
    pd.DataFrame.from_dict(metrics, orient='index', columns=['Value']).to_csv(
        'data/eval/prediction_metrics.csv')

    return metrics, ground_truth


# 3. Analisis kesalahan
def error_analysis(ground_truth):
    """Print every labelled query whose prediction disagrees with its label.

    Reads ``data/results/predictions.csv`` and walks its rows in file order.
    Rows whose ``query`` has no entry in ``ground_truth`` are ignored. For each
    mismatch it prints the query, the predicted solution and the expected one.

    Args:
        ground_truth: mapping ``query -> expected solution``.
    """
    print("\nContoh Kasus yang Salah Prediksi:")

    predictions = pd.read_csv('data/results/predictions.csv')
    pairs = zip(predictions['query'], predictions['predicted_solution'])

    for query, predicted in pairs:
        if query not in ground_truth:
            continue
        expected = ground_truth[query]
        if predicted == expected:
            continue
        print(f"\nQuery: {query}")
        print(f"Prediksi: {predicted}")
        print(f"Sebenarnya: {expected}")


# Jalankan evaluasi
retrieval_metrics = load_retrieval_metrics()
# evaluate_predictions also returns the ground-truth mapping it scored against,
# so the error analysis below uses exactly the same labels.
prediction_metrics, ground_truth = evaluate_predictions()

for _heading, _result in (
    ("Metrik Retrieval:", retrieval_metrics),
    ("\nMetrik Prediksi:", prediction_metrics),
):
    print(_heading)
    print(_result)

error_analysis(ground_truth)