# In [None]:
# Week 3: RAG Publishability Classifier Testing
# Test the complete RAG-powered classification system

import sys
sys.path.append('../../')

import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from config.settings import *
from src.rag_engine.rag_classifier import RAGPublishabilityClassifier

# Notebook banner: title line followed by an '=' rule.
print("Week 3: RAG Publishability Classifier Testing", "=" * 50, sep="\n")

# =============================================================================
# SECTION 1: INITIALIZE CLASSIFIER
# =============================================================================

print("\nSECTION 1: Initialize RAG Classifier", "-" * 40, sep="\n")

# One classifier instance is built here; every later section reuses `classifier`.
print("Loading RAG Publishability Classifier...")
classifier = RAGPublishabilityClassifier()
print("Classifier loaded successfully!")

# =============================================================================
# SECTION 2: TEST WITH SAMPLE ARTICLES
# =============================================================================

print("\nSECTION 2: Test with Sample Articles")
print("-" * 37)

# Hand-written test articles of different quality levels.
# Each case is a dict with:
#   'label'    - human-readable case name shown in the printed report
#   'expected' - gold decision: 'موافق' (approved) or 'مرفوض' (rejected), the
#                same two values the real-dataset evaluation derives from TrackId
#   'text'     - raw Arabic article passed verbatim to classify_article(); the
#                indentation and blank lines inside the triple-quoted strings are
#                part of the input (only the printed preview strips them)
test_articles = [
    {
        'label': 'High Quality',
        'expected': 'موافق',
        'text': """
        الرياض في 27 سبتمبر /واس/ أكد صاحب السمو الملكي الأمير محمد بن سلمان بن عبدالعزيز ولي العهد رئيس مجلس الوزراء، أهمية تعزيز التعاون الاستراتيجي بين المملكة العربية السعودية والولايات المتحدة الأمريكية في مختلف المجالات.
        
        جاء ذلك خلال استقبال سموه في قصر اليمامة بالرياض اليوم، وزير الخارجية الأمريكي أنتوني بلينكن والوفد المرافق له.
        
        وبحث الجانبان خلال اللقاء، أوجه التعاون الثنائي بين البلدين، والمستجدات الإقليمية والدولية ذات الاهتمام المشترك.
        
        وأشار ولي العهد إلى عمق العلاقات التاريخية بين المملكة والولايات المتحدة، مؤكداً حرص القيادة على تطوير هذه العلاقات في جميع المجالات.
        """
    },
    {
        'label': 'Low Quality',
        'expected': 'مرفوض',
        'text': """
        خبر مهم
        حدث شيء اليوم
        """
    },
    {
        'label': 'Medium Quality',
        'expected': 'مرفوض',
        'text': """
        أعلنت وزارة الصحة اليوم عن تسجيل حالات جديدة.
        وأوضحت الوزارة أن الأرقام في تزايد مستمر.
        """
    },
    {
        'label': 'Good with Attribution',
        'expected': 'موافق',
        'text': """
        جدة في 27 سبتمبر /واس/ قال المتحدث باسم وزارة التعليم إن الوزارة أطلقت برنامجاً جديداً لتطوير المناهج الدراسية.
        
        وأضاف المتحدث أن البرنامج يهدف إلى تحسين جودة التعليم ومواكبة التطورات العالمية في هذا المجال.
        
        وأوضح أن البرنامج سيشمل جميع المراحل التعليمية ويتضمن تدريب المعلمين على أحدث الطرق التدريسية.
        """
    }
]

# Run each hand-written case through the classifier, print a short report,
# and collect one summary record per case in `test_results`.
test_results = []

for case_no, case in enumerate(test_articles, start=1):
    print(f"\nTest Case {case_no}: {case['label']}")
    print(f"Expected: {case['expected']}")
    preview = case['text'].strip()[:100]
    print(f"Text Preview: {preview}...")

    outcome = classifier.classify_article(case['text'])

    # Failure path first: report the error and store a placeholder record
    # (predicted 'Error', confidence 0.0, never counted as correct).
    if outcome['status'] != 'success':
        print(f"Error: {outcome['error']}")
        test_results.append({
            'label': case['label'],
            'expected': case['expected'],
            'predicted': 'Error',
            'confidence': 0.0,
            'correct': False
        })
        continue

    predicted = outcome['decision']
    score = outcome['confidence']
    full_reasoning = outcome['reasoning']
    # Keep the console readable: show at most 150 characters of reasoning.
    if len(full_reasoning) > 150:
        shown = full_reasoning[:150] + "..."
    else:
        shown = full_reasoning

    print(f"AI Decision: {predicted}")
    print(f"Confidence: {score:.2f}")
    print(f"Reasoning: {shown}")

    is_match = predicted == case['expected']
    print(f"Correct: {'✓' if is_match else '✗'}")

    test_results.append({
        'label': case['label'],
        'expected': case['expected'],
        'predicted': predicted,
        'confidence': score,
        'correct': is_match
    })

# =============================================================================
# SECTION 3: TEST ON REAL DATASET SAMPLE
# =============================================================================

print("\n\nSECTION 3: Test on Real Dataset Sample")
print("-" * 42)

# Number of leading dataset rows to evaluate (one classifier call per row).
REAL_SAMPLE_SIZE = 10

# Load only the sample rows: `nrows` stops parsing after the sample instead of
# reading the whole workbook and discarding all but the first rows.
print("Loading real dataset sample...")
df_sample = pd.read_excel(MAIN_DATASET, nrows=REAL_SAMPLE_SIZE)
# A story counts as approved when its TrackId is present and not the literal string 'NULL'.
df_sample['is_approved'] = df_sample['TrackId'].notna() & (df_sample['TrackId'] != 'NULL')

real_test_results = []

print(f"Testing {len(df_sample)} real articles...")

for _, row in tqdm(df_sample.iterrows(), total=len(df_sample), desc="Testing real articles"):
    article_text = row['Story']
    actual_decision = 'موافق' if row['is_approved'] else 'مرفوض'

    # Start from the failure record; fields are overwritten on a successful classification.
    record = {
        'story_id': row['StoryId'],
        'actual': actual_decision,
        'predicted': 'Error',
        'confidence': 0.0,
        'correct': False,
        'text_length': 0,
    }

    # Empty Excel cells are read as NaN (a float): len() on it raises TypeError and
    # there is nothing to classify, so record the row as an error and move on.
    if pd.isna(article_text):
        real_test_results.append(record)
        continue

    article_text = str(article_text)
    record['text_length'] = len(article_text)

    result = classifier.classify_article(article_text)
    if result['status'] == 'success':
        record['predicted'] = result['decision']
        record['confidence'] = result['confidence']
        record['correct'] = result['decision'] == actual_decision

    real_test_results.append(record)

# =============================================================================
# SECTION 4: RESULTS ANALYSIS
# =============================================================================

print("\n\nSECTION 4: Results Analysis")
print("-" * 31)


def _accuracy(records):
    """Return the fraction of *records* whose 'correct' flag is true.

    Returns 0 for an empty list instead of dividing by zero.
    """
    if not records:
        return 0
    return sum(1 for r in records if r['correct']) / len(records)


def _mean_confidence(records):
    """Return the mean 'confidence' of records that were classified successfully.

    Failed classifications are stored with predicted == 'Error' and a
    placeholder confidence of 0.0, so they are excluded. Returns NaN (without
    numpy's "Mean of empty slice" RuntimeWarning) when no record succeeded.
    """
    scores = [r['confidence'] for r in records if r['predicted'] != 'Error']
    return np.mean(scores) if scores else float('nan')


def _format_confidence(value):
    """Render a mean confidence with two decimals, or 'N/A' when it is NaN."""
    return "N/A" if np.isnan(value) else f"{value:.2f}"


# Hand-written sample cases
print("Sample Test Results:")
correct_count = sum(1 for r in test_results if r['correct'])
accuracy = _accuracy(test_results)
print(f"  Accuracy: {accuracy:.1%} ({correct_count}/{len(test_results)})")

avg_confidence = _mean_confidence(test_results)
print(f"  Average Confidence: {_format_confidence(avg_confidence)}")

# Real dataset sample
if real_test_results:
    print("\nReal Dataset Results:")
    real_correct = sum(1 for r in real_test_results if r['correct'])
    real_accuracy = _accuracy(real_test_results)
    print(f"  Accuracy: {real_accuracy:.1%} ({real_correct}/{len(real_test_results)})")

    real_avg_confidence = _mean_confidence(real_test_results)
    print(f"  Average Confidence: {_format_confidence(real_avg_confidence)}")

    # Classifier failures are counted as incorrect above; report how many there were.
    real_errors = sum(1 for r in real_test_results if r['predicted'] == 'Error')
    print(f"  Classifier Errors: {real_errors}/{len(real_test_results)}")

    # Per-gold-label accuracy (i.e. recall for approved and for rejected stories)
    approved_articles = [r for r in real_test_results if r['actual'] == 'موافق']
    rejected_articles = [r for r in real_test_results if r['actual'] == 'مرفوض']

    if approved_articles:
        approved_accuracy = _accuracy(approved_articles)
        print(f"  Approved Articles Accuracy: {approved_accuracy:.1%}")

    if rejected_articles:
        rejected_accuracy = _accuracy(rejected_articles)
        print(f"  Rejected Articles Accuracy: {rejected_accuracy:.1%}")

# =============================================================================
# SECTION 5: VISUALIZATION
# =============================================================================

print("\n\nSECTION 5: Visualization")
print("-" * 26)


def _show_no_data(ax, title):
    """Title an empty subplot and label it 'No data' instead of leaving it blank."""
    ax.set_title(title)
    ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)


if real_test_results:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Failed classifications carry a placeholder confidence of 0.0 (predicted == 'Error');
    # keep them out of the confidence plots so they are not drawn as real scores.
    scored = [r for r in real_test_results if r['predicted'] != 'Error']

    # Top-left: gold-label distribution of the sample
    approved_actual = sum(1 for r in real_test_results if r['actual'] == 'موافق')
    rejected_actual = sum(1 for r in real_test_results if r['actual'] == 'مرفوض')
    axes[0, 0].bar(['Approved Actual', 'Rejected Actual'], [approved_actual, rejected_actual],
                   color=['green', 'red'], alpha=0.7)
    axes[0, 0].set_title('Actual Decisions Distribution')
    axes[0, 0].set_ylabel('Count')

    # Top-right: confidence distribution of successful classifications
    confidences = [r['confidence'] for r in scored]
    if confidences:
        axes[0, 1].hist(confidences, bins=10, alpha=0.7, color='blue')
        axes[0, 1].set_title('Confidence Score Distribution')
        axes[0, 1].set_xlabel('Confidence')
        axes[0, 1].set_ylabel('Frequency')
    else:
        _show_no_data(axes[0, 1], 'Confidence Score Distribution')

    # Bottom-left: confidence split by correctness. Only non-empty groups are drawn,
    # so an all-correct (or all-wrong) run still gets a histogram.
    groups = [
        ([r['confidence'] for r in scored if r['correct']], 'Correct', 'green'),
        ([r['confidence'] for r in scored if not r['correct']], 'Incorrect', 'red'),
    ]
    groups = [g for g in groups if g[0]]
    if groups:
        data, labels, colors = zip(*groups)
        axes[1, 0].hist(list(data), bins=10, alpha=0.7, label=list(labels),
                        color=list(colors))
        axes[1, 0].set_title('Confidence by Correctness')
        axes[1, 0].set_xlabel('Confidence')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].legend()
    else:
        _show_no_data(axes[1, 0], 'Confidence by Correctness')

    # Bottom-right: text length vs correctness (all records, failures included)
    text_lengths = [r['text_length'] for r in real_test_results]
    correctness = [1 if r['correct'] else 0 for r in real_test_results]

    axes[1, 1].scatter(text_lengths, correctness, alpha=0.6)
    axes[1, 1].set_title('Text Length vs Accuracy')
    axes[1, 1].set_xlabel('Text Length (characters)')
    axes[1, 1].set_ylabel('Correct (1) / Incorrect (0)')

    plt.tight_layout()
    plt.show()

# =============================================================================
# SECTION 6: SYSTEM PERFORMANCE
# =============================================================================

print("\n\nSECTION 6: System Performance Summary")
print("-" * 38)

# Query the vector store once and reuse the result for both counts,
# rather than calling get_collection_stats() twice.
collection_stats = classifier.vector_store.get_collection_stats()
approved_examples = collection_stats['approved_articles']['count']
rejected_examples = collection_stats['rejected_articles']['count']

# `real_accuracy` / `real_avg_confidence` are only assigned by the analysis section
# when the real-dataset run produced results, hence the guards below.
performance_summary = {
    'Vector Database': f"{approved_examples} approved + {rejected_examples} rejected examples",
    'LLM Model': classifier.model,
    'Sample Test Accuracy': f"{accuracy:.1%}" if test_results else "N/A",
    'Real Dataset Accuracy': f"{real_accuracy:.1%}" if real_test_results else "N/A",
    'Average Confidence': f"{real_avg_confidence:.2f}" if real_test_results else "N/A",
}

print("Performance Summary:")
for key, value in performance_summary.items():
    print(f"  {key}: {value}")

print("\nWeek 3 Status: RAG Classifier Implementation Complete!")
print("Next: Optimize accuracy and build Stage 2 (Editorial Assistant)")

# =============================================================================
# SECTION 7: SAVE RESULTS
# =============================================================================

from pathlib import Path

print("\n\nSECTION 7: Save Results")
print("-" * 25)

# Persist the per-article real-dataset results as CSV.
if real_test_results:
    results_df = pd.DataFrame(real_test_results)
    output_file = ANALYTICS_DIR / "week3_rag_classifier_results.csv"
    # to_csv does not create missing parent directories, so make sure the
    # analytics directory exists (no-op when it already does).
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(output_file, index=False)
    print(f"Results saved to: {output_file}")
else:
    print("No real-dataset results to save.")

print("Week 3 testing complete!")