# NLI and Zero-Shot Classification

Learn how to use Natural Language Inference for zero-shot text classification.

In [None]:
from transformers import pipeline
from sentence_transformers import CrossEncoder
import numpy as np
import torch

## What is NLI?

**Natural Language Inference** predicts the relationship between two texts:

- **Entailment**: The hypothesis is true given the premise
- **Contradiction**: The hypothesis is false given the premise
- **Neutral**: The premise neither supports nor rules out the hypothesis

In [None]:
# Load NLI cross-encoder
nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-small')

def predict_nli(premise, hypothesis):
    """Predict the NLI relationship and its probability."""
    # apply_softmax=True converts the raw logits into probabilities
    scores = nli_model.predict([(premise, hypothesis)], apply_softmax=True)
    # Label order used by this model
    labels = ['contradiction', 'entailment', 'neutral']
    # scores has shape (1, 3): one row per pair, one column per label
    max_idx = scores[0].argmax()
    return labels[max_idx], scores[0][max_idx]

# Example 1: Entailment
premise = "I have a dog named Max"
hypothesis = "I have a pet"
label, score = predict_nli(premise, hypothesis)
print(f"Premise: {premise}")
print(f"Hypothesis: {hypothesis}")
print(f"Prediction: {label} ({score:.3f})\n")

# Example 2: Contradiction
premise = "The weather is sunny"
hypothesis = "It's raining outside"
label, score = predict_nli(premise, hypothesis)
print(f"Premise: {premise}")
print(f"Hypothesis: {hypothesis}")
print(f"Prediction: {label} ({score:.3f})\n")

# Example 3: Neutral
premise = "I went to the store"
hypothesis = "I bought milk"
label, score = predict_nli(premise, hypothesis)
print(f"Premise: {premise}")
print(f"Hypothesis: {hypothesis}")
print(f"Prediction: {label} ({score:.3f})")
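`CrossEncoder.predict` returns raw logits by default; passing `apply_softmax=True` normalizes them into probabilities. The conversion is a softmax over the three class logits. A minimal stdlib sketch, assuming the `['contradiction', 'entailment', 'neutral']` order used in `predict_nli`; the logit values are made up for illustration:

```python
import math

# Label order used by predict_nli for cross-encoder/nli-deberta-v3-small
LABELS = ["contradiction", "entailment", "neutral"]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logits_to_label(logits):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs[best]

# Illustrative logits, not real model output
label, prob = logits_to_label([-2.1, 3.4, -0.8])
print(f"{label} ({prob:.3f})")
```

Softmax never changes which class wins, so the predicted label is the same with or without it; only the printed score becomes interpretable as a probability.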

## Zero-Shot Classification

Convert classification into NLI:
- **Premise** = Text to classify
- **Hypothesis** = A template filled with each candidate label, e.g. "This text is about {label}."
- Predict the label whose hypothesis gets the highest **entailment** score
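The conversion amounts to pairing the text with one hypothesis per label. A minimal sketch; `build_nli_pairs` is a hypothetical helper name and the template is only an example. Each resulting pair would then be scored by an NLI model such as `nli_model.predict`:

```python
def build_nli_pairs(text, labels, template="This text is about {}."):
    """Pair the text (premise) with one templated hypothesis per label."""
    return [(text, template.format(label)) for label in labels]

pairs = build_nli_pairs("The stock market fell sharply today", ["finance", "sports"])
for premise, hypothesis in pairs:
    print(f"{premise!r} -> {hypothesis!r}")
```

Note that zero-shot classification with N labels therefore costs N forward passes through the NLI model.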

In [None]:
# Use Hugging Face's zero-shot pipeline (default hypothesis_template: "This example is {}.")
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=0 if torch.cuda.is_available() else -1
)

# Example: Sentiment classification
text = "I absolutely love this product! It exceeded all my expectations."
labels = ["positive", "negative", "neutral"]

result = classifier(text, candidate_labels=labels)

print("Text:", text)
print("\nClassification:")
for label, score in zip(result['labels'], result['scores']):
    print(f"  {label}: {score:.3f}")

## How It Works Behind the Scenes

In [None]:
# Manual zero-shot: a simplified version of what the pipeline does
# (here with the DeBERTa cross-encoder and a custom template)
text = "The system crashed and all data was lost."
labels = ["positive", "negative", "neutral"]

# Create a hypothesis for each label
template = "This text expresses {} sentiment."

scores = []
for label in labels:
    hypothesis = template.format(label)
    # Score the (premise, hypothesis) pair; returns raw logits
    result = nli_model.predict([(text, hypothesis)])
    # The entailment logit is at index 1 for this model
    entailment_score = result[0][1]
    scores.append(entailment_score)
    print(f"{label}: entailment logit {entailment_score:.3f}")

# Best label
best_idx = np.argmax(scores)
print(f"\nPredicted: {labels[best_idx]}")
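The loop picks the label with the largest entailment logit. In its default single-label mode (with more than one candidate label), the Hugging Face pipeline additionally applies a softmax to the entailment logits across all candidate labels, so the scores sum to 1 while the ranking stays the same. A stdlib sketch of that normalization step; `normalize_across_labels` is a hypothetical name and the logits are illustrative, not model output:

```python
import math

def normalize_across_labels(entailment_logits):
    """Softmax the entailment logits across candidate labels (single-label mode)."""
    m = max(entailment_logits)
    exps = [math.exp(x - m) for x in entailment_logits]
    total = sum(exps)
    return [e / total for e in exps]

labels = ["positive", "negative", "neutral"]
entailment_logits = [-1.5, 2.8, 0.3]  # illustrative values only
probs = normalize_across_labels(entailment_logits)
for label, p in sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True):
    print(f"{label}: {p:.3f}")
```

Because the scores are forced to sum to 1, a high score in this mode means "best among these labels", not "the text clearly matches this label".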

## Multi-Label Classification

Allow text to belong to multiple categories.

In [None]:
text = "URGENT: Payment system is down. All customer transactions failing."
labels = ["urgent", "technical", "payment", "customer_facing", "feature_request"]

# Multi-label classification
result = classifier(
    text,
    candidate_labels=labels,
    multi_label=True  # Allow multiple labels
)

print("Text:", text)
print("\nLabels (threshold > 0.5):")
for label, score in zip(result['labels'], result['scores']):
    if score > 0.5:
        print(f"  ✓ {label}: {score:.3f}")
    else:
        print(f"  ✗ {label}: {score:.3f}")
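With `multi_label=True`, the pipeline scores each label independently: for every hypothesis it applies a softmax over just the contradiction and entailment logits (neutral is dropped) and reports the entailment probability. The scores therefore no longer sum to 1, which is what makes a fixed threshold such as 0.5 meaningful. A sketch of that per-label step; the helper names and logits are hypothetical and illustrative:

```python
import math

def entailment_vs_contradiction(contradiction_logit, entailment_logit):
    """Two-way softmax; returns P(entailment) for a single hypothesis."""
    return 1.0 / (1.0 + math.exp(contradiction_logit - entailment_logit))

def independent_label_scores(logit_pairs):
    """logit_pairs maps label -> (contradiction_logit, entailment_logit)."""
    return {label: entailment_vs_contradiction(c, e) for label, (c, e) in logit_pairs.items()}

# Illustrative (contradiction, entailment) logits only
logit_pairs = {
    "urgent": (-2.0, 3.0),
    "payment": (-1.0, 2.5),
    "feature_request": (2.5, -1.5),
}
scores = independent_label_scores(logit_pairs)
threshold = 0.5
selected = [label for label, s in scores.items() if s > threshold]
print(selected)
```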

## Custom Hypothesis Templates

A template that reads naturally with every label and matches the kind of text can improve accuracy.

In [None]:
text = "The meeting is scheduled for tomorrow at 2pm"

# Poor template (generic)
result1 = classifier(
    text,
    candidate_labels=["scheduling", "reminder", "general"],
    hypothesis_template="This text is about {}."
)

print("Generic template:")
for label, score in zip(result1['labels'][:3], result1['scores'][:3]):
    print(f"  {label}: {score:.3f}")

# More specific template (reads naturally with every label)
result2 = classifier(
    text,
    candidate_labels=["scheduling", "reminder", "general"],
    hypothesis_template="This is a {} message."
)

print("\nSpecific template:")
for label, score in zip(result2['labels'][:3], result2['scores'][:3]):
    print(f"  {label}: {score:.3f}")
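Because each label is substituted into the template verbatim, it helps to print the resulting hypotheses before scoring them: a template that suits one kind of label can produce ungrammatical sentences for another. The pipeline itself raises a `ValueError` when the template has no `{}` placeholder. A minimal stdlib check in the same spirit; `preview_hypotheses` is a hypothetical helper:

```python
def preview_hypotheses(template, labels):
    """Render one hypothesis per label; reject templates without a placeholder."""
    if template.format("LABEL") == template:
        raise ValueError(f"Template {template!r} has no '{{}}' placeholder")
    return [template.format(label) for label in labels]

labels = ["scheduling", "reminder", "general"]
for template in ["This text is about {}.", "This is a {} message."]:
    for hypothesis in preview_hypotheses(template, labels):
        print(hypothesis)
    print()
```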

## Use Cases

### 1. Content Moderation

In [None]:
comments = [
    "Great product, highly recommend!",
    "This is spam! Buy now at spamsite.com",
    "You're an idiot and I hate you",
    "How do I reset my password?"
]

moderation_labels = ["appropriate", "spam", "toxic", "helpful"]

for comment in comments:
    result = classifier(comment, candidate_labels=moderation_labels)
    top_label = result['labels'][0]
    top_score = result['scores'][0]
    print(f"[{top_label:12s} {top_score:.2f}] {comment[:40]}...")

### 2. Ticket Routing

In [None]:
tickets = [
    "My credit card was charged twice for the same order",
    "The app crashes when I try to upload images",
    "Can you add dark mode to the settings?",
    "I forgot my password and can't log in"
]

teams = ["billing", "technical", "product", "support"]

print("Ticket Routing:")
for ticket in tickets:
    result = classifier(ticket, candidate_labels=teams)
    assigned_team = result['labels'][0]
    confidence = result['scores'][0]
    print(f"→ {assigned_team:10s} ({confidence:.2f}) | {ticket[:40]}...")
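A router usually needs a fallback for low-confidence predictions rather than always trusting the top label. A minimal sketch that works on the dictionary the pipeline returns (`labels` and `scores`, sorted by score); the `route_ticket` name, the 0.5 threshold, and the `"manual_triage"` queue are hypothetical choices to tune on your own tickets:

```python
def route_ticket(result, min_confidence=0.5, fallback="manual_triage"):
    """Return the top team if confident enough, otherwise a fallback queue."""
    top_label, top_score = result["labels"][0], result["scores"][0]
    return top_label if top_score >= min_confidence else fallback

# Pipeline-style results with illustrative scores
confident = {"labels": ["billing", "support", "technical", "product"],
             "scores": [0.81, 0.10, 0.06, 0.03]}
uncertain = {"labels": ["support", "technical", "billing", "product"],
             "scores": [0.34, 0.31, 0.20, 0.15]}
print(route_ticket(confident))
print(route_ticket(uncertain))
```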

### 3. Intent Detection

In [None]:
messages = [
    "What's the weather like today?",
    "Order a pizza with extra cheese",
    "Cancel my subscription",
    "Tell me a joke"
]

intents = ["query", "command", "cancel", "entertainment"]

for msg in messages:
    result = classifier(msg, candidate_labels=intents)
    intent = result['labels'][0]
    print(f"[{intent:15s}] {msg}")

## Limitations

### 1. Overlapping Labels

In [None]:
# Problem: a specific topic label competes with generic sentiment labels
# in one softmax, so 'product_review' can absorb most of the probability
text = "The product is okay"
labels = ["positive", "negative", "neutral", "product_review"]

result = classifier(text, candidate_labels=labels)
print("⚠️  Check whether 'product_review' crowds out the sentiment labels:")
for label, score in zip(result['labels'], result['scores']):
    print(f"  {label}: {score:.3f}")

# Solution: use hierarchical classification
# First: classify as product_review or not
# Then: if product_review, classify sentiment

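The two-stage idea can be sketched as follows. `classify` stands in for any function with the pipeline's call shape (text plus `candidate_labels`) that returns a dict with `labels` and `scores`, so the real `classifier` can be passed in. The stage labels and the keyword stub are hypothetical, used only so the sketch runs without a model:

```python
def hierarchical_classify(text, classify):
    """Stage 1: is this a product review? Stage 2: if so, classify its sentiment."""
    stage1 = classify(text, candidate_labels=["product_review", "other"])
    if stage1["labels"][0] != "product_review":
        return ["other"]
    stage2 = classify(text, candidate_labels=["positive", "negative", "neutral"])
    return ["product_review", stage2["labels"][0]]

def stub_review_classify(text, candidate_labels):
    """Hypothetical stand-in for the zero-shot pipeline (keyword rules, not NLI)."""
    if "product_review" in candidate_labels:
        top = "product_review" if "product" in text.lower() else "other"
    else:
        top = "neutral" if "okay" in text.lower() else "positive"
    rest = [label for label in candidate_labels if label != top]
    return {"labels": [top] + rest, "scores": [0.9] + [0.1 / len(rest)] * len(rest)}

print(hierarchical_classify("The product is okay", stub_review_classify))
```

Each stage only compares labels from one dimension, so a topic label can no longer compete with sentiment labels for the same probability mass.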
### 2. Ambiguous Categories

In [None]:
# Ambiguous text
text = "The system is running"
labels = ["positive", "neutral", "status_update"]

result = classifier(text, candidate_labels=labels)
print("Ambiguous classification:")
for label, score in zip(result['labels'], result['scores']):
    print(f"  {label}: {score:.3f}")
print("\n⚠️  If the top scores are close, treat the prediction as ambiguous")
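Rather than eyeballing the scores, ambiguity can be flagged with a margin between the top two labels. A minimal sketch on a pipeline-style result; `is_ambiguous` and the 0.1 margin are hypothetical choices, and the scores below are illustrative:

```python
def is_ambiguous(result, margin=0.1):
    """Flag a prediction when the top two scores are within `margin` of each other."""
    scores = result["scores"]
    return len(scores) > 1 and (scores[0] - scores[1]) < margin

# Illustrative pipeline-style outputs
close_call = {"labels": ["status_update", "neutral", "positive"], "scores": [0.41, 0.37, 0.22]}
clear_call = {"labels": ["negative", "positive"], "scores": [0.93, 0.07]}
print(is_ambiguous(close_call), is_ambiguous(clear_call))
```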

### 3. Sarcasm and Negation

In [None]:
# Sarcasm (hard for NLI)
text = "Oh great, another system outage. Just perfect!"
labels = ["positive", "negative"]

result = classifier(text, candidate_labels=labels)
print("Sarcasm test:")
for label, score in zip(result['labels'], result['scores']):
    print(f"  {label}: {score:.3f}")

# NLI models often miss sarcasm because the literal wording is positive
if result['labels'][0] == 'positive':
    print("\n⚠️  Model missed sarcasm (classified as positive)")
else:
    print("\n✓ Model caught sarcasm")

## Best Practices

1. **Use specific hypothesis templates**
   - "This is a {} message" > "This is about {}"

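2. **Start with 3-5 labels**
   - Each label adds one NLI forward pass, so cost grows linearly with the number of labels
   - In single-label mode, more labels spread the probability mass thinner and make overlaps likelier
   - Use hierarchical classification for larger label sets

3. **Set thresholds for multi-label**
   - Treat scores below roughly 0.5-0.6 as unreliable, and tune the threshold on labeled examples

4. **Validate on edge cases**
   - Sarcasm, negation, ambiguity

5. **Consider fine-tuning if you have data**
   - Zero-shot is great for bootstrapping
   - Fine-tuned models are usually more accurate on in-domain data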
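Practices 3 and 4 can be checked with a small labeled set of edge cases. A minimal sketch; `evaluate_edge_cases`, the example texts, and the keyword stub are hypothetical, and in practice you would pass the real `classifier` instead of the stub:

```python
def evaluate_edge_cases(cases, classify, labels):
    """Return accuracy and the failing (text, expected, predicted) triples."""
    failures = []
    for text, expected in cases:
        predicted = classify(text, candidate_labels=labels)["labels"][0]
        if predicted != expected:
            failures.append((text, expected, predicted))
    accuracy = 1 - len(failures) / len(cases)
    return accuracy, failures

def stub_sentiment_classify(text, candidate_labels):
    """Hypothetical keyword stand-in for the pipeline; it misses sarcasm by design."""
    lowered = text.lower()
    top = "positive" if ("great" in lowered or "good" in lowered) else "negative"
    rest = [label for label in candidate_labels if label != top]
    return {"labels": [top] + rest, "scores": [0.8] + [0.2 / len(rest)] * len(rest)}

cases = [
    ("Oh great, another outage. Just perfect!", "negative"),  # sarcasm
    ("The update is not good at all", "negative"),            # negation
    ("Works great, no complaints", "positive"),
]
accuracy, failures = evaluate_edge_cases(cases, stub_sentiment_classify, ["positive", "negative"])
print(f"accuracy: {accuracy:.2f}")
for text, expected, predicted in failures:
    print(f"  expected {expected}, got {predicted}: {text}")
```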

## Model Comparison

In [None]:
# Compare different models
models = [
    "facebook/bart-large-mnli",
    "cross-encoder/nli-deberta-v3-small",
]

text = "This movie was absolutely terrible"
labels = ["positive", "negative"]

for model_name in models:
    clf = pipeline("zero-shot-classification", model=model_name)
    result = clf(text, candidate_labels=labels)
    print(f"{model_name}:")
    print(f"  Predicted: {result['labels'][0]} ({result['scores'][0]:.3f})")
    print()

## Summary

✅ NLI models predict entailment/contradiction/neutral  
✅ Zero-shot classification converts labels to hypotheses  
✅ No task-specific training data needed  
✅ Works across many domains  
✅ Great for prototyping and bootstrapping  
⚠️  May struggle with sarcasm and ambiguity  
⚠️  Specific labels can dominate generic ones

**Next:** Combine cross-encoders and NLI in production tasks!