In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
import json
import torch
from rank_bm25 import BM25Okapi
import nltk
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Download the NLTK "punkt" tokenizer data if it is not installed yet.
# nltk.data.find() reports a missing resource by raising LookupError, so that
# is the exception that must be caught for the download fallback to run.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# Load the pqa_l data (used here instead of ori_pqau.json).
file_path_pqa_l = "/content/ori_pqal.json"  # Assumed file name; change it if yours differs
try:
    with open(file_path_pqa_l, mode="r", encoding="utf-8") as f:
        pqa_l = json.load(f)  # pqa_l is used in place of pqa_u
except (FileNotFoundError, json.JSONDecodeError) as load_error:
    # A missing file and malformed JSON are reported with different errors.
    if isinstance(load_error, json.JSONDecodeError):
        raise ValueError("Invalid JSON file.")
    raise FileNotFoundError(f"File {file_path_pqa_l} not found. Please check the path.")

# Build the retrieval corpus from pqa_l: one document per entry, made by
# joining that entry's "CONTEXTS" strings with single spaces. Entries that
# have no list-valued "CONTEXTS" field are left out.
corpus = [
    " ".join(value["CONTEXTS"])
    for value in pqa_l.values()
    if "CONTEXTS" in value and isinstance(value["CONTEXTS"], list)
]
print(f"Corpus size: {len(corpus)}")

# Tokenizer used for BM25
def simple_tokenize(text):
    """Lowercase ``text`` and split it on runs of whitespace.

    Returns the resulting list of tokens; an empty or whitespace-only
    string yields an empty list.
    """
    lowered = text.lower()
    tokens = lowered.split()
    return tokens

# Tokenize every corpus document once and build the BM25 index over them.
tokenized_corpus = [simple_tokenize(doc) for doc in corpus]
if tokenized_corpus:
    bm25 = BM25Okapi(tokenized_corpus)
else:
    # BM25Okapi divides by the number of documents while building its index,
    # so an empty corpus would raise ZeroDivisionError here. Skip the index
    # instead; retrieve_evidence returns [] for an empty corpus and never
    # touches bm25 in that case.
    bm25 = None
    print("Warning: corpus is empty; BM25 index was not built.")

# Load the classifier (assumed to already exist from the original training code).
model_classify_path = "/content/results_biobert_finetuned"

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and weights are both read from the fine-tuned output directory.
tokenizer_classify = AutoTokenizer.from_pretrained(model_classify_path)
model_classify = AutoModelForSequenceClassification.from_pretrained(model_classify_path)
model_classify = model_classify.to(device)

# retrieve_evidence: BM25 retrieval over the module-level corpus
def retrieve_evidence(question, top_k=3):
    """Return the ``top_k`` corpus documents BM25 ranks highest for ``question``.

    The question is tokenized with ``simple_tokenize`` and looked up in the
    module-level ``bm25`` index.

    Args:
        question: Query text.
        top_k: Number of documents to return (default 3).

    Returns:
        A list of document strings taken from ``corpus``, or an empty list
        when ``corpus`` is empty.
    """
    if len(corpus) == 0:
        return []
    query_tokens = simple_tokenize(question)
    # get_top_n selects items from the sequence it is handed, so passing the
    # corpus itself yields the document texts directly.
    return bm25.get_top_n(query_tokens, corpus, n=top_k)

# classify_answer: yes/no/maybe prediction from the fine-tuned classifier
def classify_answer(question, evidence):
    """Predict "yes", "no" or "maybe" for ``question`` given ``evidence``.

    Args:
        question: Question text.
        evidence: Sequence of evidence strings; only the first two are used.

    Returns:
        "maybe" when ``evidence`` is empty. Otherwise the label selected by
        the argmax of the classifier logits, with class indices 0, 1, 2
        mapped to "yes", "no", "maybe".
    """
    if not evidence:
        return "maybe"

    # Single input string: the question, a literal " [SEP] " marker, then the
    # first two evidence passages joined by spaces.
    evidence_text = " ".join(evidence[:2])
    combined_input = question + " [SEP] " + evidence_text

    encoded = tokenizer_classify(
        combined_input,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )

    # Move every encoded tensor to the device the model lives on.
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    with torch.no_grad():
        logits = model_classify(**model_inputs).logits
        predicted_index = logits.argmax(dim=1).item()

    return ("yes", "no", "maybe")[predicted_index]

# Load the ground-truth labels for the test set.
test_file_path = "/content/test_ground_truth.json"
try:
    with open(test_file_path, "r", encoding="utf-8") as f:
        test_ground_truth = json.load(f)
except FileNotFoundError:
    raise FileNotFoundError(f"File {test_file_path} not found.")
except json.JSONDecodeError as exc:
    # Handle malformed JSON the same way the pqa_l loader does (ValueError),
    # but name the file so the two loaders' failures can be told apart.
    raise ValueError(f"Invalid JSON file: {test_file_path}") from exc

# Check which IDs are present in both pqa_l and test_ground_truth.
common_keys = set(pqa_l.keys()).intersection(test_ground_truth.keys())
print(f"Number of common keys between pqa_l and test_ground_truth: {len(common_keys)}")
if common_keys:
    print("Common keys found! Proceeding with evaluation.")
else:
    print("Still no common keys! Please check file names or structures.")

# Parallel lists: predictions[k] is the predicted label for the sample whose
# true label is ground_truths[k].
labels = ["yes", "no", "maybe"]
predictions, ground_truths = [], []

print("üîç Starting evaluation of classification on test set...")
for test_id, true_label in test_ground_truth.items():
    # Only IDs that also exist in pqa_l can be evaluated.
    if test_id in pqa_l:
        question = pqa_l[test_id].get("QUESTION", "")
    else:
        print(f"Warning: ID {test_id} not found in pqa_l. Skipping.")
        continue

    if not question:
        print(f"Warning: No question found for ID {test_id}. Skipping.")
        continue

    # Retrieve candidate evidence with BM25, then classify the question.
    evidence_docs = retrieve_evidence(question, top_k=3)
    pred_label = classify_answer(question, evidence_docs)

    ground_truths.append(true_label)
    predictions.append(pred_label)

    print(f"ID: {test_id} | Predicted: {pred_label} | Ground Truth: {true_label}")

# Compute metrics: overall accuracy, per-class precision/recall/F1/support,
# and the macro- and micro-averaged precision/recall/F1.
accuracy = accuracy_score(ground_truths, predictions)
precision, recall, f1, support = precision_recall_fscore_support(
    ground_truths, predictions, labels=labels, average=None
)
(macro_precision, macro_recall, macro_f1, _), (micro_precision, micro_recall, micro_f1, _) = (
    precision_recall_fscore_support(ground_truths, predictions, labels=labels, average=averaging)
    for averaging in ("macro", "micro")
)

# Summary rows are printed in a fixed order, each with four decimals.
summary_rows = [
    ("Accuracy", accuracy),
    ("Macro Precision", macro_precision),
    ("Macro Recall", macro_recall),
    ("Macro F1-Score", macro_f1),
    ("Micro Precision", micro_precision),
    ("Micro Recall", micro_recall),
    ("Micro F1-Score", micro_f1),
]
print("\nüìä CLASSIFICATION EVALUATION RESULTS:")
for metric_name, metric_value in summary_rows:
    print(f"{metric_name}: {metric_value:.4f}")

print("\nPer-Class Metrics:")
for label, cls_precision, cls_recall, cls_f1, cls_support in zip(labels, precision, recall, f1, support):
    print(f"{label}: Precision={cls_precision:.4f}, Recall={cls_recall:.4f}, F1={cls_f1:.4f}, Support={cls_support}")

print("\nDetailed Classification Report:")
print(classification_report(ground_truths, predictions, labels=labels, target_names=labels))


In [None]:
# Persist the evaluation results: a plain-text metrics summary and a JSON file
# holding the predictions, their ground-truth labels and the evaluated IDs.
import os  # needed for makedirs/path.join below; not imported by the earlier cell

results_dir = "/content/results/"
os.makedirs(results_dir, exist_ok=True)

metrics_file = os.path.join(results_dir, "classification_metrics.txt")
with open(metrics_file, "w", encoding="utf-8") as f:
    f.write("üìä CLASSIFICATION EVALUATION RESULTS:\n")
    f.write(f"Accuracy: {accuracy:.4f}\n")
    f.write(f"Macro Precision: {macro_precision:.4f}\n")
    f.write(f"Macro Recall: {macro_recall:.4f}\n")
    f.write(f"Macro F1-Score: {macro_f1:.4f}\n")
    f.write(f"Micro Precision: {micro_precision:.4f}\n")
    f.write(f"Micro Recall: {micro_recall:.4f}\n")
    f.write(f"Micro F1-Score: {micro_f1:.4f}\n\n")
    f.write("Per-Class Metrics:\n")
    for i, label in enumerate(labels):
        f.write(f"{label}: Precision={precision[i]:.4f}, Recall={recall[i]:.4f}, F1={f1[i]:.4f}, Support={support[i]}\n")
    f.write("\nDetailed Classification Report:\n")
    f.write(classification_report(ground_truths, predictions, labels=labels, target_names=labels))
print(f" Classification metrics saved to: {metrics_file}")

# The evaluation loop skips IDs that are missing from pqa_l or have no
# question, so listing every ground-truth key would misalign "test_ids" with
# "predictions"/"ground_truths". Rebuild the evaluated IDs with the same
# filter, in the same iteration order.
evaluated_ids = [
    test_id
    for test_id in test_ground_truth
    if test_id in pqa_l and pqa_l[test_id].get("QUESTION", "")
]
if len(evaluated_ids) != len(predictions):
    print(
        f"Warning: {len(evaluated_ids)} evaluated IDs but {len(predictions)} predictions; "
        "test_ids may not line up with predictions."
    )

predictions_file = os.path.join(results_dir, "predictions.json")
pred_data = {
    "predictions": predictions,
    "ground_truths": ground_truths,
    "test_ids": evaluated_ids,
}
with open(predictions_file, "w", encoding="utf-8") as f:
    json.dump(pred_data, f, indent=4)
print(f" Predictions saved to: {predictions_file}")
print(f" All results saved in directory: {results_dir}")
