In [9]:
import pandas as pd
import numpy as np
import os
import pandas as pd
import warnings
import sys
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import pipeline
import torch
from datasets import load_dataset
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
from tqdm.auto import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import torch


# Display and logging preferences for the whole notebook session.
pd.options.display.max_colwidth = None  # print text columns in full, never truncated
warnings.filterwarnings("ignore")  # NOTE: silences every warning category from here on

# Hugging Face checkpoint identifier used for the tokenizer and model loaded below.
model_name = "roberta-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Load the AG News dataset by its repository id; the analysis below reads its 'test' split.
ds = load_dataset(path="wangrongsheng/ag_news")


# Pick the first backend that is actually usable on this machine:
# CUDA GPU, then Apple Metal (MPS), then plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
# is_available() (not is_built()) is the right check: is_built() only says the
# PyTorch binary was compiled with MPS support, not that an MPS device exists.
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


In [None]:
# Build the zero-shot pipeline from the `model` and `tokenizer` objects that
# were already loaded from `model_name`. Passing the checkpoint name here would
# make the pipeline load a second copy of the same weights.
# A tokenizer must be given explicitly when `model` is an object, not a name.
classifier = pipeline(
    'zero-shot-classification',
    model=model,
    tokenizer=tokenizer,
    device=device,
)

# Smoke test: classify one sentence against a few unrelated labels; the
# pipeline's result is the cell output.
sequence_to_classify = "This is a test"
candidate_labels = ['travel', 'cooking', 'dancing', 'technician', 'teacher']
classifier(sequence_to_classify, candidate_labels)

In [None]:
def _collect_predictions(test_rows, classifier, label_map, num_samples):
    """Classify at most ``num_samples`` rows and return one result dict per row."""
    candidate_labels = list(label_map.values())

    # Clamp the progress-bar total so it still completes when the split holds
    # fewer rows than requested; fall back to num_samples for unsized iterables.
    try:
        total = min(num_samples, len(test_rows))
    except TypeError:
        total = num_samples
    total = max(total, 0)

    results = []
    for i, item in tqdm(enumerate(test_rows), total=total):
        if i >= num_samples:
            break

        text = item['text']
        true_label = label_map[item['label']]

        # The first entry of 'labels' / 'scores' is taken as the prediction
        # and its confidence.
        prediction = classifier(text, candidate_labels)
        predicted_label = prediction['labels'][0]

        results.append({
            'text': text,
            'true_label': true_label,
            'predicted_label': predicted_label,
            'confidence': prediction['scores'][0],
            'correct': true_label == predicted_label
        })
    return results


def _boxplot_by_class(df_results, column, title):
    """Show a box plot of ``column`` per true label, split by correctness."""
    plt.figure(figsize=(10, 6))
    sns.boxplot(x='true_label', y=column, hue='correct', data=df_results)
    plt.title(title)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()


def perform_error_analysis(dataset, classifier, num_samples=100, label_map=None):
    """
    Perform comprehensive error analysis on the zero-shot classifier.

    Classifies the first ``num_samples`` rows of ``dataset['test']``, prints
    the overall accuracy, a per-class classification report and the most
    confident mistakes, and shows a confusion matrix plus box plots of
    confidence and text length by class and correctness.

    Args:
        dataset: Mapping with a 'test' split whose rows expose 'text' and
            'label' (the AG News dataset in this notebook).
        classifier: The zero-shot classification pipeline; called as
            ``classifier(text, candidate_labels)`` and expected to return a
            dict with 'labels' and 'scores'.
        num_samples: Maximum number of test rows to analyze (use a smaller
            number for testing).
        label_map: Optional mapping from dataset label id to the candidate
            label string offered to the classifier. Defaults to the four
            topic names below. Pass e.g.
            ``{0: "This document is about politics", ...}`` to try
            sentence-style labels instead.

    Returns:
        pandas.DataFrame with one row per analyzed sample and the columns
        'text', 'true_label', 'predicted_label', 'confidence', 'correct' and
        'text_length'. If no sample was analyzed, the frame is empty and no
        metrics or plots are produced.
    """
    if label_map is None:
        label_map = {
            0: "politics",
            1: "sports",
            2: "economics",
            3: "science and technology"
        }

    # Candidate labels for zero-shot classification, in label-id order
    candidate_labels = list(label_map.values())

    results = _collect_predictions(dataset['test'], classifier, label_map, num_samples)

    # Fixing the columns keeps the schema stable even when nothing was collected
    result_columns = ['text', 'true_label', 'predicted_label', 'confidence', 'correct']
    df_results = pd.DataFrame(results, columns=result_columns)

    if df_results.empty:
        # Without this guard every metric below would fail on the empty frame.
        print("No samples were analyzed; skipping metrics and plots.")
        df_results['text_length'] = pd.Series(dtype='int64')
        return df_results.astype({'confidence': float, 'correct': bool})

    # 1. Overall Accuracy
    accuracy = df_results['correct'].mean() * 100
    print(f"\nOverall Accuracy: {accuracy:.2f}%")

    # 2. Per-class Performance (same label order as the confusion matrix;
    #    classes that never occur score 0 instead of raising a warning)
    print("\nPer-class Performance:")
    print(classification_report(
        df_results['true_label'],
        df_results['predicted_label'],
        labels=candidate_labels,
        zero_division=0,
    ))

    # 3. Confusion Matrix
    plt.figure(figsize=(10, 8))
    cm = confusion_matrix(df_results['true_label'], df_results['predicted_label'], labels=candidate_labels)
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=candidate_labels, yticklabels=candidate_labels)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    # 4. Confidence Analysis
    _boxplot_by_class(df_results, 'confidence', 'Confidence Distribution by Class and Correctness')

    # 5. Error Examples Analysis
    print("\nMost Confident Mistakes:")
    mistakes = df_results[~df_results['correct']].sort_values('confidence', ascending=False)
    print(mistakes[['text', 'true_label', 'predicted_label', 'confidence']].head())

    # 6. Length Analysis (length in characters)
    df_results['text_length'] = df_results['text'].str.len()
    _boxplot_by_class(df_results, 'text_length', 'Text Length Distribution by Class and Correctness')
    return df_results

In [13]:
def analyze_error_patterns(df_results):
    """
    Summarize how the classifier fails and how confidence relates to accuracy.

    Prints every "true → predicted" confusion pair with its count, most
    frequent first, then plots accuracy and coverage (both in percent) of the
    predictions whose confidence reaches each cutoff in
    ``np.arange(0.1, 1.0, 0.1)``.

    Args:
        df_results: DataFrame with 'true_label', 'predicted_label',
            'confidence' and boolean 'correct' columns.
    """
    # Tally each (true, predicted) pair among the misclassified rows
    wrong = df_results.loc[~df_results['correct']]
    pair_counts = defaultdict(int)
    for actual, guessed in zip(wrong['true_label'], wrong['predicted_label']):
        pair_counts[f"{actual} → {guessed}"] += 1

    print("\nCommon Error Patterns:")
    ranked = sorted(pair_counts.items(), key=lambda entry: entry[1], reverse=True)
    for label_pair, n_errors in ranked:
        print(f"{label_pair}: {n_errors}")

    # Accuracy and coverage at each confidence cutoff; a cutoff that keeps no
    # rows contributes no point to the curve
    cutoffs = np.arange(0.1, 1.0, 0.1)
    n_total = len(df_results)
    curve = []
    for cutoff in cutoffs:
        kept = df_results.loc[df_results['confidence'] >= cutoff, 'correct']
        if kept.empty:
            continue
        curve.append(((kept.sum() / len(kept)) * 100, (len(kept) / n_total) * 100))

    if not curve:  # nothing to plot
        print("Warning: Not enough data points to create accuracy-coverage plot")
        return

    # Cutoffs ascend, so the kept rows only shrink: the non-empty cutoffs are
    # exactly the leading ones, and slicing lines them up with the curve
    accuracies = [acc for acc, _ in curve]
    coverage = [cov for _, cov in curve]
    cutoffs = cutoffs[:len(curve)]

    plt.figure(figsize=(10, 6))
    plt.plot(cutoffs, accuracies, 'b-', label='Accuracy')
    plt.plot(cutoffs, coverage, 'r-', label='Coverage')
    plt.xlabel('Confidence Threshold')
    plt.ylabel('Percentage')
    plt.title('Accuracy vs Coverage Trade-off')
    plt.legend()
    plt.grid(True)
    plt.show()


In [None]:
# Driver: run the full error analysis on 500 test rows, then break down the errors
if __name__ == "__main__":
    # Per-sample predictions plus the accuracy report and plots
    results_df = perform_error_analysis(dataset=ds, classifier=classifier, num_samples=500)

    # Confusion pairs and the accuracy/coverage trade-off
    analyze_error_patterns(df_results=results_df)