# Zero-Shot Text Classification with Transformers

**Author:** Olivier Robert-Duboille

## 1. Introduction
Zero-shot learning (ZSL) lets a model assign inputs to classes for which it saw no labeled examples during training. This is particularly useful in NLP, where a large pre-trained language model can score the semantic relationship between a text sequence and each candidate label.

### Objectives:
- Implement a Zero-Shot Classifier pipeline using Hugging Face `transformers`.
- Create a custom dataset of "unseen" topics (here, customer support tickets for a SaaS company).
- Evaluate the model's ability to categorize these queries without any fine-tuning.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import pipeline
from sklearn.metrics import confusion_matrix, classification_report

# Visualization settings
sns.set_theme(style="darkgrid")
plt.rcParams['figure.figsize'] = (10, 6)

## 2. The Model
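We will use `facebook/bart-large-mnli`, a BART-large model fine-tuned on the Multi-Genre Natural Language Inference (MNLI) corpus. The pipeline treats classification as an entailment problem: given a premise (the input text) and a hypothesis built from a candidate label (for example, "This text is about {label}"), does the premise entail the hypothesis? The hypothesis wording comes from a template; the pipeline's default is `This example is {}.`, and it can be changed with the `hypothesis_template` argument.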

In [None]:
# Initialize the zero-shot classification pipeline
# Note: This will download the model weights (~1.6GB) if not cached.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
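
To make the entailment framing concrete, the sketch below shows how per-label NLI logits could be turned into the scores the pipeline reports. The logit values are made up for illustration and the function names are hypothetical; the index order `[contradiction, neutral, entailment]` is an assumption that should be checked against the model's configuration.

```python
import math

# Hypothetical NLI logits per candidate label, ordered as
# [contradiction, neutral, entailment]. Real values would come from the model.
nli_logits = {
    "billing": [-2.0, 0.5, 3.0],
    "sales": [1.0, 0.5, -0.5],
    "technical_support": [2.5, 0.0, -1.5],
}

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def single_label_scores(logits_by_label, entail_idx=2):
    """Softmax over the entailment logits across labels; scores sum to 1."""
    labels = list(logits_by_label)
    probs = softmax([logits_by_label[name][entail_idx] for name in labels])
    return dict(zip(labels, probs))

def multi_label_scores(logits_by_label, contra_idx=0, entail_idx=2):
    """Per label, softmax over (contradiction, entailment); labels are independent."""
    return {
        name: softmax([row[contra_idx], row[entail_idx]])[1]
        for name, row in logits_by_label.items()
    }

single = single_label_scores(nli_logits)
multi = multi_label_scores(nli_logits)
print(max(single, key=single.get))
```

In the single-label variant the labels compete for probability mass, while in the multi-label variant each score can be read on its own. This is intended to mirror the effect of the pipeline's `multi_label` argument.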

## 3. Creating a "Mock" Unseen Dataset
Let's simulate a scenario where we are analyzing customer support tickets for a Tech SaaS company. The model has not been explicitly trained on these categories.

In [None]:
data = [
    ("I cannot reset my password, the link is broken.", "access_issue"),
    ("How do I integrate the API with my Python backend?", "technical_support"),
    ("The billing dashboard is showing an incorrect amount for March.", "billing"),
    ("Is there a discount for non-profit organizations?", "sales"),
    ("The server returns a 500 error when I upload a large CSV.", "technical_support"),
    ("I want to cancel my subscription immediately.", "billing"),
    ("Can I add more seats to my team plan?", "sales"),
    ("My 2FA code is not arriving on my phone.", "access_issue")
]

df = pd.DataFrame(data, columns=['text', 'true_label'])
candidate_labels = ["access_issue", "billing", "sales", "technical_support", "feature_request"]

df.head()
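
Because the model reads each candidate label as text inside the hypothesis, identifiers such as `access_issue` may be harder for it to interpret than plain phrases. The sketch below is one possible approach, not part of the workflow above: the `label_phrases` mapping, its wording, and the `build_hypotheses` helper are all assumptions made for illustration.

```python
# Hypothetical mapping from dataset label ids to natural-language phrases.
label_phrases = {
    "access_issue": "an account access issue",
    "billing": "billing",
    "sales": "sales",
    "technical_support": "technical support",
    "feature_request": "a feature request",
}

def build_hypotheses(phrases, template="This text is about {}."):
    """Fill the template once per phrase, keyed by the original label id."""
    return {label: template.format(phrase) for label, phrase in phrases.items()}

hypotheses = build_hypotheses(label_phrases)

# Reverse lookup to map a predicted phrase back to the dataset's label id.
phrase_to_label = {phrase: label for label, phrase in label_phrases.items()}

print(hypotheses["access_issue"])
```

Under this approach, the phrases would be passed as the candidate labels together with `hypothesis_template="This text is about {}."`, and `phrase_to_label` would convert predictions back before comparing them with `true_label`. Whether this improves accuracy on a given dataset has to be checked empirically.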

## 4. Inference
We will now run the classifier on our dataset. With the default `multi_label=False`, the scores form a probability distribution over the candidate labels, so they sum to 1 for each input.

In [None]:
predictions = []
confidence_scores = []

print(f"Classifying {len(df)} examples...")

for text in df['text']:
    result = classifier(text, candidate_labels)
    # The result is sorted by score, so the first element is the top prediction
    top_label = result['labels'][0]
    top_score = result['scores'][0]
    
    predictions.append(top_label)
    confidence_scores.append(top_score)

df['predicted_label'] = predictions
df['confidence'] = confidence_scores

df
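
The pipeline also accepts a list of texts and then returns one result dictionary per text, which avoids the explicit loop. The helper below is a sketch with a hypothetical name (`top_predictions`); the `sample_results` values are invented to show the expected shape (`sequence`, `labels`, `scores`) and are not real model output.

```python
def top_predictions(results):
    """Return (label, score) for the highest-scoring label of each result dict."""
    pairs = []
    for result in results:
        scores = result["scores"]
        # Pick the best score explicitly instead of relying on the ordering.
        best = max(range(len(scores)), key=scores.__getitem__)
        pairs.append((result["labels"][best], scores[best]))
    return pairs

# Invented results that mimic the pipeline's output format.
sample_results = [
    {"sequence": "I want to cancel my subscription immediately.",
     "labels": ["billing", "sales", "access_issue"],
     "scores": [0.81, 0.12, 0.07]},
    {"sequence": "Can I add more seats to my team plan?",
     "labels": ["sales", "billing", "access_issue"],
     "scores": [0.64, 0.30, 0.06]},
]

pairs = top_predictions(sample_results)
print(pairs)

# With the real pipeline, the call would look like:
# results = classifier(df["text"].tolist(), candidate_labels)
# df["predicted_label"], df["confidence"] = zip(*top_predictions(results))
```

Passing the whole list in one call lets the pipeline handle iteration internally; for larger datasets, its `batch_size` argument may be worth experimenting with.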

## 5. Evaluation & Analysis
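Let's see how well our zero-shot model performed without any training data. With only eight examples, the metrics below are illustrative rather than a reliable estimate of accuracy.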

In [None]:
# Confusion Matrix
cm = confusion_matrix(df['true_label'], df['predicted_label'], labels=candidate_labels)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=candidate_labels, yticklabels=candidate_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Zero-Shot Classification Confusion Matrix')
plt.show()

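# Classification Report
# Use the same label set as the confusion matrix; zero_division=0 avoids
# warnings for labels that never occur or are never predicted.
print(classification_report(df['true_label'], df['predicted_label'],
                            labels=candidate_labels, zero_division=0))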

## 6. Confidence Analysis
It is useful to check how confident the model is in its top prediction. Low-confidence predictions might require human review in a production system.

In [None]:
plt.figure(figsize=(8, 4))
sns.histplot(df['confidence'], bins=10, kde=True, color='purple')
plt.title('Distribution of Model Confidence Scores')
plt.xlabel('Confidence Score')
plt.show()

# Show low confidence examples
threshold = 0.7
low_conf = df[df['confidence'] < threshold]
if not low_conf.empty:
    print(f"Examples with confidence < {threshold}:")
    print(low_conf[['text', 'predicted_label', 'confidence']])
else:
    print(f"All predictions have confidence >= {threshold}")