# In [None]:
import json
import numpy as np
import torch
import os
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from collections import Counter

# Module-level setup: choose the compute device and load the tokenizer and
# classification model once, so the functions below can reuse them.

# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üì± Using device: {device}")

# Path handed to from_pretrained() to load the classification model.
model_classify_path = "/content/results_biobert_finetuned"

print(" Loading tokenizer and model...")
# NOTE(review): the tokenizer is loaded from the base BioBERT hub ID while the
# model weights come from model_classify_path -- presumably the vocabulary was
# not changed during fine-tuning; confirm.
tokenizer_classify = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model_classify = AutoModelForSequenceClassification.from_pretrained(model_classify_path)
model_classify.to(device)
# Put the model in evaluation mode before running inference.
model_classify.eval()
print(" Model loaded successfully!")

def classify_non_rag_answer(question):
    """Classify a question as "yes", "no" or "maybe" from its text alone (no RAG).

    Args:
        question: Question text handed to the module-level tokenizer.

    Returns:
        The entry of ("yes", "no", "maybe") at the arg-max logit index. If
        anything raises, the error is printed and "maybe" is returned.
    """
    # NOTE(review): this index -> label order is assumed to match the label
    # ids used when the model was fine-tuned -- confirm against the training
    # code or the checkpoint's config.
    label_names = ["yes", "no", "maybe"]

    try:
        encoded = tokenizer_classify(
            question,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt",
        )
        # Move every tensor returned by the tokenizer onto the model's device.
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model_classify(**on_device).logits
            best_index = torch.argmax(logits, dim=1).item()

        return label_names[best_index]
    except Exception as err:
        print(f" Classification error: {err}")
        return "maybe"

def load_pqa_questions(pqa_file_path):
    """Load the original PQA question data from a JSON file.

    Args:
        pqa_file_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        The parsed JSON content. If reading, parsing or counting the items
        raises, the error is printed and an empty dict is returned.
    """
    try:
        with open(pqa_file_path, "r", encoding="utf-8") as handle:
            records = json.load(handle)
        # Kept inside the try: len() fails for a JSON scalar, and that case
        # must also fall back to the empty dict.
        print(f" Loaded PQA data: {len(records)} items")
    except Exception as err:
        print(f" Error loading PQA data: {err}")
        return {}
    return records

def load_test_ground_truth(ground_truth_file_path):
    """Load the test-set ground truth from a JSON file.

    Args:
        ground_truth_file_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        The parsed JSON content. If reading, parsing or counting the items
        raises, the error is printed and an empty dict is returned.
    """
    try:
        with open(ground_truth_file_path, "r", encoding="utf-8") as handle:
            labels_by_id = json.load(handle)
        # Kept inside the try: len() fails for a JSON scalar, and that case
        # must also fall back to the empty dict.
        print(f" Loaded ground truth: {len(labels_by_id)} items")
    except Exception as err:
        print(f" Error loading ground truth: {err}")
        return {}
    return labels_by_id

def map_questions_to_ground_truth(pqa_data, ground_truth):
    """Join ground-truth labels with their question text by ID.

    Args:
        pqa_data: Mapping from ID to a record; the question text is read from
            each record's "QUESTION" key.
        ground_truth: Mapping from ID to the true label.

    Returns:
        A list of dicts with keys "id", "question" and "true_label", in the
        iteration order of ground_truth. IDs missing from pqa_data, and IDs
        whose record has an empty or absent "QUESTION", are left out; both
        kinds of omission are counted and reported on stdout.
    """
    mapped_data = []
    not_found_count = 0
    empty_question_count = 0

    for item_id, true_label in ground_truth.items():
        if item_id not in pqa_data:
            not_found_count += 1
            continue

        question = pqa_data[item_id].get("QUESTION", "")
        if not question:
            # Count these too, so a shrinking sample count is explained
            # instead of silently losing entries.
            empty_question_count += 1
            continue

        mapped_data.append({
            "id": item_id,
            "question": question,
            "true_label": true_label,
        })

    if not_found_count > 0:
        print(f" {not_found_count} IDs not found in PQA data")
    if empty_question_count > 0:
        print(f" {empty_question_count} IDs skipped: empty or missing QUESTION")

    return mapped_data

def evaluate_classification_model(test_data, sample_size=None):
    """Run the non-RAG classifier over test samples and collect label pairs.

    Args:
        test_data: List of dicts, each with "question" and "true_label" keys.
        sample_size: Optional cap. When truthy and smaller than
            len(test_data), a reproducible random subset of that size
            (seed 42) is evaluated; otherwise every sample is used in order.

    Returns:
        (true_labels, pred_labels): two parallel lists covering the samples
        that were processed without raising. A sample that raises is reported
        on stdout and skipped.
    """
    if sample_size and sample_size < len(test_data):
        import random

        # A private generator seeded with 42 draws the same subset as seeding
        # the module-level generator would, without resetting the global
        # random state for the rest of the program.
        test_data = random.Random(42).sample(test_data, sample_size)

    total = len(test_data)
    print(f" Evaluating on {total} samples (without RAG)...")

    true_labels = []
    pred_labels = []

    for i, item in enumerate(test_data):
        try:
            question = item["question"]
            true_label = item["true_label"]

            pred_label = classify_non_rag_answer(question)

            true_labels.append(true_label)
            pred_labels.append(pred_label)

            # Progress report with one example every 20 samples.
            if (i + 1) % 20 == 0:
                print(f" Processed {i+1}/{total} samples")
                print(f"   Sample: Q: '{question[:60]}...'")
                print(f"   True: {true_label} | Pred: {pred_label}")

            # Release cached CUDA memory every 50 samples.
            if (i + 1) % 50 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            print(f" Error on sample {i}: {e}")
            continue

    print(f"üìã Successfully processed {len(true_labels)} samples")
    return true_labels, pred_labels

def calculate_metrics(true_labels, pred_labels):
    """Print evaluation metrics for the predictions and return the key ones.

    Prints overall accuracy, the per-class classification report, a
    confusion matrix, error/correct counts and the label distributions.

    Args:
        true_labels: Sequence of ground-truth labels.
        pred_labels: Sequence of predicted labels, parallel to true_labels.

    Returns:
        (accuracy, cm): the accuracy score and the confusion matrix with rows
        and columns ordered "yes", "no", "maybe". Returns (0, None) when
        true_labels is empty.
    """
    if len(true_labels) == 0:
        print("‚ùå No data to evaluate")
        return 0, None

    # One fixed class order shared by the report and the confusion matrix.
    class_names = ["yes", "no", "maybe"]

    print("\n" + "="*50)
    print("üìä CLASSIFICATION EVALUATION RESULTS (NON-RAG)")
    print("="*50)

    accuracy = accuracy_score(true_labels, pred_labels)
    print(f"üéØ Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    print("\nüìà Detailed Classification Report:")
    # labels= must be given explicitly: without it the report uses the sorted
    # labels found in the data ("maybe", "no", "yes"), so target_names would
    # be attached to the wrong rows, and a class absent from the data would
    # make the number of names disagree with the number of classes.
    print(classification_report(
        true_labels,
        pred_labels,
        labels=class_names,
        target_names=class_names,
    ))

    print("\nüîÑ Confusion Matrix:")
    cm = confusion_matrix(true_labels, pred_labels, labels=class_names)
    print("True \\ Pred |   yes  |   no   | maybe ")
    print("-" * 45)
    for name, counts in zip(class_names, cm):
        cells = "".join("  {:4d}  |".format(count) for count in counts)
        print("   {:5s}   |".format(name) + cells)

    errors = sum(1 for true, pred in zip(true_labels, pred_labels) if true != pred)
    print(f"\n Total Errors: {errors}/{len(true_labels)}")
    print(f" Correct Predictions: {len(true_labels)-errors}/{len(true_labels)}")

    true_dist = Counter(true_labels)
    pred_dist = Counter(pred_labels)
    print(f"\nüìä Class Distribution - True: {dict(true_dist)}")
    print(f"üìä Class Distribution - Pred: {dict(pred_dist)}")

    return accuracy, cm

def main_classification_evaluation(
    pqa_path="/content/ori_pqal.json",
    ground_truth_path="/content/test_ground_truth.json",
    max_samples=200,
    output_path="non_rag_classification_results.json",
):
    """Run the full non-RAG classification evaluation and save the results.

    Loads the PQA questions and the ground-truth labels, joins them by ID,
    classifies up to max_samples questions, prints the metrics and writes a
    JSON summary to output_path.

    Args:
        pqa_path: JSON file with the PQA records.
        ground_truth_path: JSON file with the ground-truth labels.
        max_samples: Upper bound on the number of samples evaluated; None
            evaluates every mapped sample.
        output_path: Where the JSON summary is written.

    Returns:
        The results dict that was written to output_path, or None when the
        data could not be loaded, nothing could be mapped, or no sample was
        successfully classified.
    """
    print("üìÅ Loading data...")

    pqa_data = load_pqa_questions(pqa_path)
    ground_truth = load_test_ground_truth(ground_truth_path)

    # Both loaders return an empty dict on failure.
    if not pqa_data or not ground_truth:
        print(" Cannot load data files")
        return None

    print(f"üìä PQA data: {len(pqa_data)} items")
    print(f"üìä Ground truth: {len(ground_truth)} items")

    print("üîó Mapping questions to ground truth...")
    test_data = map_questions_to_ground_truth(pqa_data, ground_truth)

    print(f"üìã Mapped {len(test_data)} samples")

    if not test_data:
        print(" No samples mapped! Check file paths and data structure.")
        return None

    # Show the first few mapped samples as a sanity check.
    print("\nüîç Sample mapped data:")
    for position, sample in enumerate(test_data[:3], start=1):
        print(f"  {position}. ID: {sample['id']}")
        print(f"     Q: {sample['question'][:80]}...")
        print(f"     A: {sample['true_label']}")

    if max_samples is None:
        sample_size = len(test_data)
    else:
        sample_size = min(max_samples, len(test_data))
    print(f"\nüî¨ Using sample size: {sample_size}")

    true_labels, pred_labels = evaluate_classification_model(test_data, sample_size)

    if not true_labels:
        print(" No valid samples to evaluate!")
        return None

    accuracy, cm = calculate_metrics(true_labels, pred_labels)

    results = {
        "sample_size": len(true_labels),
        "accuracy": accuracy,
        "confusion_matrix": cm.tolist(),
        "class_names": ["yes", "no", "maybe"],
        "data_source": "PQA mapped questions",
        "model_type": "Non-RAG Classification",
        "model_config": {
            "base_model": "dmis-lab/biobert-base-cased-v1.1",
            "fine_tuned_path": model_classify_path,
            "device": str(device),
        },
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"\nüíæ Results saved to {output_path}")

    return results

if __name__ == "__main__":
    # Script entry point: run the evaluation and report whether it produced
    # a (truthy) results object.
    print("üöÄ STARTING NON-RAG CLASSIFICATION EVALUATION")
    results = main_classification_evaluation()

    print(" Evaluation completed successfully!" if results else " Evaluation failed!")