In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

def teacher_bert_ag_news():
    """
    Evaluate the fabriceyhc/bert-base-uncased-ag_news model
    and generate comprehensive classification report
    """
    
    # Load model and tokenizer
    print("Loading model and tokenizer...")
    model_name = "fabriceyhc/bert-base-uncased-ag_news"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    
    # Load AG News test dataset
    print("Loading AG News test dataset...")
    dataset = load_dataset("ag_news")
    test_data = dataset["test"]
    
    # AG News class labels
    class_names = ["World", "Sports", "Business", "Science/Technology"]
    
    # Prepare predictions and true labels
    predictions = []
    true_labels = []
    
    print("Running inference on test set...")
    batch_size = 32
    
    for i in tqdm(range(0, len(test_data), batch_size)):
        batch = test_data[i:i+batch_size]
        texts = batch["text"]
        labels = batch["label"]
        
        # Tokenize batch
        inputs = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        
        # Move to device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Get predictions
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            batch_predictions = torch.argmax(logits, dim=-1).cpu().numpy()
        
        predictions.extend(batch_predictions)
        true_labels.extend(labels)
    
    # Convert to numpy arrays
    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    
    # Generate classification report
    print("\n" + "="*60)
    print("CLASSIFICATION REPORT")
    print("="*60)
    
    report = classification_report(
        true_labels, 
        predictions, 
        target_names=class_names,
        digits=4
    )
    print(report)
    
    # Generate detailed metrics
    # from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    
    # accuracy = accuracy_score(true_labels, predictions)
    # precision, recall, f1, support = precision_recall_fscore_support(
    #     true_labels, predictions, average=None
    # )
    
    # # Create detailed results DataFrame
    # results_df = pd.DataFrame({
    #     'Class': class_names,
    #     'Precision': precision,
    #     'Recall': recall,
    #     'F1-Score': f1,
    #     'Support': support
    # })
    
    # print("\nDETAILED METRICS:")
    # print(results_df.round(4))
    
    # Overall metrics
    # macro_avg_precision = np.mean(precision)
    # macro_avg_recall = np.mean(recall)
    # macro_avg_f1 = np.mean(f1)
    # weighted_avg_f1 = np.average(f1, weights=support)
    
    # print(f"\nOVERALL PERFORMANCE:")
    # print(f"Accuracy: {accuracy:.4f}")
    # print(f"Macro Avg Precision: {macro_avg_precision:.4f}")
    # print(f"Macro Avg Recall: {macro_avg_recall:.4f}")
    # print(f"Macro Avg F1-Score: {macro_avg_f1:.4f}")
    # print(f"Weighted Avg F1-Score: {weighted_avg_f1:.4f}")
    
    # # Confusion Matrix
    # cm = confusion_matrix(true_labels, predictions)
    
    # plt.figure(figsize=(10, 8))
    # sns.heatmap(
    #     cm, 
    #     annot=True, 
    #     fmt='d', 
    #     cmap='Blues',
    #     xticklabels=class_names,
    #     yticklabels=class_names
    # )
    # plt.title('Confusion Matrix - BERT AG News Classification')
    # plt.xlabel('Predicted Label')
    # plt.ylabel('True Label')
    # plt.tight_layout()
    # plt.show()
    
    # # Per-class analysis
    # print("\nPER-CLASS ANALYSIS:")
    # for i, class_name in enumerate(class_names):
    #     print(f"\n{class_name}:")
    #     print(f"  Precision: {precision[i]:.4f}")
    #     print(f"  Recall: {recall[i]:.4f}")
    #     print(f"  F1-Score: {f1[i]:.4f}")
    #     print(f"  Support: {support[i]}")
        
    #     # Calculate class-specific metrics
    #     tp = cm[i, i]
    #     fp = cm[:, i].sum() - tp
    #     fn = cm[i, :].sum() - tp
    #     tn = cm.sum() - tp - fp - fn
        
    #     specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    #     print(f"  Specificity: {specificity:.4f}")
    
    # # Error analysis - find misclassified examples
    # print("\nERROR ANALYSIS - Sample Misclassifications:")
    # misclassified_indices = np.where(predictions != true_labels)[0]
    
    # # Show first 5 misclassifications
    # for i, idx in enumerate(misclassified_indices[:5]):
    #     print(f"\nExample {i+1}:")
    #     print(f"Text: {test_data[int(idx)]['text'][:200]}...")
    #     print(f"True Label: {class_names[true_labels[idx]]}")
    #     print(f"Predicted Label: {class_names[predictions[idx]]}")
    
    # # Save results to file
    # results_summary = {
    #     'model_name': model_name,
    #     'accuracy': accuracy,
    #     'macro_avg_precision': macro_avg_precision,
    #     'macro_avg_recall': macro_avg_recall,
    #     'macro_avg_f1': macro_avg_f1,
    #     'weighted_avg_f1': weighted_avg_f1,
    #     'per_class_metrics': results_df.to_dict('records'),
    #     'confusion_matrix': cm.tolist(),
    #     'total_samples': len(true_labels),
    #     'misclassification_rate': 1 - accuracy
    # }
    
    # return results_summary, predictions, true_labels

# Run the evaluation
if __name__ == "__main__":
    results = evaluate_bert_ag_news()

Loading model and tokenizer...
Loading AG News test dataset...
Running inference on test set...


100%|███████████████████████████████████████████████| 238/238 [02:13<00:00,  1.78it/s]


CLASSIFICATION REPORT
                    precision    recall  f1-score   support

             World     0.9471    0.9526    0.9499      1900
            Sports     0.9736    0.9916    0.9825      1900
          Business     0.9091    0.9000    0.9045      1900
Science/Technology     0.9188    0.9058    0.9123      1900

          accuracy                         0.9375      7600
         macro avg     0.9372    0.9375    0.9373      7600
      weighted avg     0.9372    0.9375    0.9373      7600


DETAILED METRICS:
                Class  Precision  Recall  F1-Score  Support
0               World     0.9471  0.9526    0.9499     1900
1              Sports     0.9736  0.9916    0.9825     1900
2            Business     0.9091  0.9000    0.9045     1900
3  Science/Technology     0.9188  0.9058    0.9123     1900



