# Industrialized Zero-Shot Classification Pipeline

**Author:** Olivier Robert-Duboille
**Date:** 2024-05-22
**Version:** 2.0 (Industrialized)

## 1. Abstract
This notebook builds a zero-shot learning (ZSL) text-classification pipeline with production use in mind. Beyond what basic tutorials cover, it focuses on:
-   **Multi-Model Comparison**: Benchmarking `bart-large-mnli` vs. `deberta-v3-base-tasksource-nli`.
-   **Prompt Engineering**: Optimizing the hypothesis template to improve entailment accuracy.
-   **Uncertainty Quantification**: Implementing calibration checks and "I don't know" thresholds.
-   **Top-K Evaluation**: Checking whether the true label appears among the top-k predictions.

## 2. Methodology
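Zero-shot text classification is framed as a Natural Language Inference (NLI) problem. The input text is the premise, and each candidate label $c \in C$ is inserted into a hypothesis template, e.g. $\text{template}(c) =$ "This text is about $c$." An NLI model estimates $P(\text{Entailment} \mid \text{text}, \text{template}(c))$ for every candidate, and the prediction is the label whose hypothesis is most strongly entailed: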

$$
\hat{y} = \arg\max_{c \in C} P(\text{Entailment} | \text{text}, \text{template}(c))
$$
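In practice, the Hugging Face `zero-shot-classification` pipeline ranks labels slightly differently from the equation above. As far as we understand the `transformers` implementation, in the default single-label mode it takes each candidate's entailment logit $z_c$ and applies a softmax across candidates, $s_c = e^{z_c} / \sum_{c' \in C} e^{z_{c'}}$, so the reported scores sum to 1; with `multi_label=True` it instead applies a softmax over the [contradiction, entailment] logits of each label independently. The NumPy sketch below reproduces both normalizations on made-up logits; it does not load a model, and the function names are hypothetical.

```python
import numpy as np


def single_label_scores(entail_logits):
    """Softmax over the entailment logits of all candidate labels (scores sum to 1)."""
    z = np.asarray(entail_logits, dtype=float)
    z = z - z.max()  # shift for numerical stability; does not change the softmax
    e = np.exp(z)
    return e / e.sum()


def multi_label_scores(contra_entail_logits):
    """Independent softmax over [contradiction, entailment] per label; returns P(entailment)."""
    z = np.asarray(contra_entail_logits, dtype=float)  # shape (n_labels, 2)
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e[:, 1] / e.sum(axis=1)


# Made-up logits for three candidate labels (no model involved).
entail = [2.0, 0.5, -1.0]
pairs = [[-1.0, 2.0], [0.0, 0.5], [1.0, -1.0]]

print(single_label_scores(entail))  # relative scores, sum to 1 across labels
print(multi_label_scores(pairs))    # per-label scores in (0, 1), need not sum to 1
```

This matters for the thresholding in Section 6: in single-label mode every text receives a distribution over the candidates, so an out-of-scope query can still get a high top score, whereas `multi_label=True` allows every label to score low, which is closer to an "I don't know" signal.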


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import torch

# Visualization settings
sns.set_theme(style="whitegrid", palette="pastel")
plt.rcParams['figure.figsize'] = (12, 6)

device = 0 if torch.cuda.is_available() else -1
print(f"Running on device: {'GPU' if device == 0 else 'CPU'}")

## 3. Dataset Construction
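We build a small, hand-labeled customer-support dataset (11 queries across 6 classes) that deliberately includes ambiguous cases, such as a cancellation request that could be read as either billing or retention, to probe model robustness.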

In [None]:
# Define a more challenging dataset with some ambiguity
data = [
    ("I cannot reset my password, the link is broken.", "access_issue"),
    ("How do I integrate the API with my Python backend?", "developer_support"), # Technical integration question
    ("The billing dashboard is showing an incorrect amount for March.", "billing"),
    ("Is there a discount for non-profit organizations?", "sales"),
    ("The server returns a 500 error when I upload a large CSV.", "developer_support"),
    ("I want to cancel my subscription immediately.", "retention"), # Tricky: billing or retention?
    ("Can I add more seats to my team plan?", "sales"), # Upsell
    ("My 2FA code is not arriving on my phone.", "access_issue"),
    ("I really love the new dark mode, great job!", "feedback"),
    ("This tool is garbage, I want a refund.", "retention"), # Angry churn
    ("Where can I find the documentation for webhooks?", "developer_support")
]

df = pd.DataFrame(data, columns=['text', 'true_label'])

# Candidate labels (schema)
candidate_labels = [
    "access_issue", 
    "billing", 
    "sales", 
    "developer_support", 
    "retention",
    "feedback"
]

print(f"Dataset Size: {len(df)}")
df.head()
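The labels above are snake_case identifiers, so the template produces hypotheses such as "This text is about developer_support.", which an NLI model may parse poorly. One common prompt-engineering step is to give each label a natural-language description and map predictions back to the schema afterwards. The sketch below assumes descriptions of our own choosing: `LABEL_DESCRIPTIONS`, `verbalized_labels`, and `to_schema_labels` are hypothetical names, and the wording has not been tuned on this data.

```python
# Hypothetical natural-language descriptions for each schema label.
# The wording is an illustrative assumption, not tuned on this dataset.
LABEL_DESCRIPTIONS = {
    "access_issue": "account access or login problems",
    "billing": "billing and invoices",
    "sales": "pricing, discounts, or plan upgrades",
    "developer_support": "API integration or technical errors",
    "retention": "cancelling a subscription or asking for a refund",
    "feedback": "product feedback or praise",
}

# Reverse lookup: description -> schema label (descriptions must be unique).
DESCRIPTION_TO_LABEL = {desc: label for label, desc in LABEL_DESCRIPTIONS.items()}
assert len(DESCRIPTION_TO_LABEL) == len(LABEL_DESCRIPTIONS), "descriptions must be unique"

# Candidate labels to pass to the zero-shot classifier instead of the raw identifiers.
verbalized_labels = list(LABEL_DESCRIPTIONS.values())


def to_schema_labels(predicted_descriptions):
    """Map predicted descriptions back to the original snake_case labels."""
    return [DESCRIPTION_TO_LABEL[desc] for desc in predicted_descriptions]


print(to_schema_labels(["billing and invoices", "product feedback or praise"]))
# ['billing', 'feedback']
```

To use it, one could pass `verbalized_labels` as the `labels` argument of `run_inference` and convert the returned predictions with `to_schema_labels` before comparing them with `true_label`.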

## 4. Model Factory & Prompt Engineering
Accuracy depends on both the NLI backbone and the hypothesis template, so inference is wrapped in a function parameterized by model name and template.

In [None]:
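def run_inference(model_name, dataset, labels, template="This example is {}."):
    """Classify dataset['text'] with a zero-shot NLI model.

    Returns the top label, its score, and a {label: score} dict for every text.
    """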
    print(f"Loading {model_name}...")
    classifier = pipeline("zero-shot-classification", model=model_name, device=device)
    
    preds = []
    scores = []
    all_scores_packed = []
    
    print(f"Running inference with template: '{template}'...")
    results = classifier(dataset['text'].tolist(), labels, hypothesis_template=template)
    
    for res in results:
        preds.append(res['labels'][0])
        scores.append(res['scores'][0])
        # Pack all scores for top-k analysis later
        score_dict = {label: score for label, score in zip(res['labels'], res['scores'])}
        all_scores_packed.append(score_dict)
        
    return preds, scores, all_scores_packed

# Experiment 1: BART-large-MNLI with a topic-style template (the pipeline default is "This example is {}.")
preds_bart, scores_bart, all_scores_bart = run_inference(
    "facebook/bart-large-mnli", 
    df, 
    candidate_labels, 
    template="This text is about {}."
)

df['pred_bart'] = preds_bart
df['conf_bart'] = scores_bart
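The abstract mentions optimizing the hypothesis template, but the cell above runs a single template. A minimal way to compare templates is to score each one against the labeled data. The sketch below accepts any callable with the pipeline's call shape, `classify(texts, labels, hypothesis_template=...)`, that returns dicts whose `labels` are sorted by descending score; `compare_templates` and `stub_classifier` are hypothetical names, and the stub only exists so the example runs without downloading a model.

```python
def compare_templates(classify, texts, y_true, labels, templates):
    """Return {template: top-1 accuracy} for a zero-shot classifier callable."""
    results = {}
    for template in templates:
        outputs = classify(texts, labels, hypothesis_template=template)
        preds = [out["labels"][0] for out in outputs]  # labels are sorted by score
        correct = sum(p == t for p, t in zip(preds, y_true))
        results[template] = correct / len(y_true)
    return results


def stub_classifier(texts, labels, hypothesis_template="This example is {}."):
    """Toy stand-in for a real pipeline (illustration only).

    Ranks labels that occur in the text first, but only when the template
    contains "about"; otherwise it keeps the input order.
    """
    outputs = []
    for text in texts:
        if "about" in hypothesis_template:
            ranked = sorted(labels, key=lambda lab: lab not in text)
        else:
            ranked = list(labels)
        outputs.append({"labels": ranked, "scores": [1 / len(ranked)] * len(ranked)})
    return outputs


demo_texts = ["a billing question", "a sales question"]
demo_truth = ["billing", "sales"]
demo_labels = ["sales", "billing"]
print(compare_templates(stub_classifier, demo_texts, demo_truth, demo_labels,
                        ["This example is {}.", "This text is about {}."]))
# {'This example is {}.': 0.5, 'This text is about {}.': 1.0}
```

With a real model, one could build the pipeline once, e.g. `clf = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)`, and call `compare_templates(clf, df['text'].tolist(), df['true_label'].tolist(), candidate_labels, templates)`, which also avoids reloading the model on every call as `run_inference` does. With 11 examples, one extra correct prediction moves accuracy by about 9 points, so small gaps between templates should not be over-interpreted.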

## 5. Evaluation & Error Analysis
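The confusion matrix shows which classes the model mixes up (e.g., retention vs. billing), and the classification report gives per-class precision, recall, and F1. With one to three examples per class, these numbers are indicative only.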

In [None]:
def plot_confusion_matrix(y_true, y_pred, labels, title="Confusion Matrix"):
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

plot_confusion_matrix(df['true_label'], df['pred_bart'], candidate_labels, "BART-Large-MNLI Results")

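# zero_division=0 avoids warnings for classes that are never predicted on this tiny set
print(classification_report(df['true_label'], df['pred_bart'], labels=candidate_labels, zero_division=0))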
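The classification report only scores the top-1 prediction. For the top-k evaluation listed in the abstract, the per-example `{label: score}` dictionaries returned as the third value of `run_inference` are sufficient. The helper below (`top_k_accuracy` is a hypothetical name) counts how often the true label is among the k highest-scoring labels; the scores in the example are made up.

```python
def top_k_accuracy(score_dicts, y_true, k=2):
    """Fraction of examples whose true label is among the k highest-scoring labels."""
    hits = 0
    for scores, truth in zip(score_dicts, y_true):
        top_k = sorted(scores, key=scores.get, reverse=True)[:k]
        hits += truth in top_k  # bool counts as 0 or 1
    return hits / len(y_true)


# Made-up scores for two examples (illustration only).
toy_scores = [
    {"billing": 0.5, "retention": 0.4, "sales": 0.1},
    {"sales": 0.6, "billing": 0.3, "retention": 0.1},
]
toy_truth = ["retention", "billing"]
print(top_k_accuracy(toy_scores, toy_truth, k=1))  # 0.0
print(top_k_accuracy(toy_scores, toy_truth, k=2))  # 1.0
```

Applied to this notebook, it would be called as `top_k_accuracy(all_scores, df['true_label'], k=2)`, where `all_scores` is that third return value. scikit-learn's `top_k_accuracy_score` computes the same metric from a score matrix if a library function is preferred.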

## 6. Uncertainty Quantification (Human-in-the-Loop)
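In production, low-confidence predictions should not be acted on automatically. We route predictions whose top score is below `THRESHOLD` to a 'Manual Review' bucket. In the pipeline's default single-label mode, the scores are normalized across the candidate labels (they sum to 1), so the top score measures confidence relative to the other candidates; with 6 labels, a uniform guess would score about 0.17. The value 0.6 is a starting heuristic and should be tuned on labeled data.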
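Whether a fixed threshold is meaningful depends on how well the scores are calibrated, i.e. whether predictions made with confidence around 0.7 are right about 70% of the time. A standard check is the expected calibration error (ECE): bin predictions by confidence and average the gap between accuracy and mean confidence in each bin, weighted by bin size. The sketch below is a minimal version with equal-width bins; `expected_calibration_error` is a hypothetical helper, and the toy inputs are made up.

```python
import numpy as np


def expected_calibration_error(confidences, correct, n_bins=5):
    """Bin-size-weighted mean of |accuracy - mean confidence| over equal-width bins."""
    conf = np.asarray(confidences, dtype=float)
    corr = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Bins are half-open [lo, hi), except the last one, which also includes 1.0.
        in_bin = (conf >= lo) & ((conf < hi) | (hi == 1.0))
        if in_bin.any():
            gap = abs(corr[in_bin].mean() - conf[in_bin].mean())
            ece += in_bin.mean() * gap  # weight by fraction of samples in the bin
    return ece


toy_conf = [0.9, 0.85, 0.3, 0.35]
toy_correct = [1, 0, 0, 1]
print(f"ECE = {expected_calibration_error(toy_conf, toy_correct):.3f}")  # ECE = 0.275
```

On this notebook's data it could be called as `expected_calibration_error(df['conf_bart'], df['pred_bart'] == df['true_label'])`, although with 11 predictions the estimate is too noisy to support strong conclusions.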

In [None]:
THRESHOLD = 0.6

plt.figure(figsize=(10, 4))
sns.histplot(df['conf_bart'], bins=10, kde=True, color='purple')
plt.axvline(THRESHOLD, color='red', linestyle='--', label=f'Threshold {THRESHOLD}')
plt.title('Confidence Distribution')
plt.legend()
plt.show()

# Flag for review
df['action'] = df['conf_bart'].apply(lambda x: 'Automate' if x >= THRESHOLD else 'Manual Review')

print("Action Distribution:")
print(df['action'].value_counts())

print("\n--- Examples flagged for Manual Review ---")
review_queue = df[df['action'] == 'Manual Review'][['text', 'true_label', 'pred_bart', 'conf_bart']]
display(review_queue)
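A single threshold hides the trade-off it makes: raising it sends more tickets to manual review (lower coverage) in exchange for higher accuracy on the automated share. Sweeping the threshold makes that trade-off explicit. The helper below (`coverage_accuracy`, a hypothetical name) applies the same `>= threshold` rule as the `action` column above; the toy confidences and correctness flags are invented.

```python
import numpy as np


def coverage_accuracy(confidences, correct, thresholds):
    """For each threshold, return (threshold, share automated, accuracy on that share)."""
    conf = np.asarray(confidences, dtype=float)
    corr = np.asarray(correct, dtype=bool)
    rows = []
    for t in thresholds:
        auto = conf >= t  # same rule as the 'Automate' action above
        coverage = auto.mean()
        acc = corr[auto].mean() if auto.any() else float("nan")
        rows.append((float(t), float(coverage), float(acc)))
    return rows


demo_conf = [0.95, 0.7, 0.55, 0.3]
demo_correct = [True, True, False, False]
for t, cov, acc in coverage_accuracy(demo_conf, demo_correct, [0.0, 0.5, 0.6, 0.9]):
    print(f"threshold={t:.1f} coverage={cov:.2f} automated_accuracy={acc:.2f}")
```

With the notebook's data, `coverage_accuracy(df['conf_bart'], df['pred_bart'] == df['true_label'], np.linspace(0.2, 0.9, 8))` would give one row per candidate threshold, from which a threshold meeting a target accuracy can be chosen.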

## 7. Conclusion
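- **Accuracy**: `bart-large-mnli` provides a zero-shot baseline without any task-specific training; with only 11 examples, the measured accuracy is a rough indication rather than a benchmark.
- **Ambiguity**: Classes like 'retention' and 'billing' overlap semantically.
- **Safety**: The confidence threshold routes uncertain predictions (e.g., complex or ambiguous queries) to manual review instead of automating them.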

**Next Steps:**
1.  Fine-tune a smaller SetFit model for latency reduction.
2.  Evaluate `cross-encoder/nli-deberta-v3-base` as a potentially more accurate backbone, measuring its latency against BART before switching.
