# Zero-Shot Text Classification (Label Semantics via TF‑IDF)

This notebook implements a **zero-shot** classifier without large transformers (CPU-friendly):

- labels are represented by **natural-language descriptions**
- both texts and label descriptions are embedded in the **same TF‑IDF vector space**
- classification = argmax of the cosine similarity between a text and each label description

Industrial aspects:
- deterministic dataset generation
- calibration via abstention threshold
- confusion matrix + error analysis
- caching of vectorizer + label vectors

The cell outputs shown below were saved when the notebook was executed.
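One way the caching bullet above might be realized is sketched here. The helper name `build_or_load`, the cache file name, and the toy corpus are assumptions made for illustration, not part of the notebook. The sketch pickles the fitted vectorizer together with the label vectors and reuses them on the next call.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer


def build_or_load(cache_path, corpus, label_desc):
    """Return (vectorizer, label_vectors), reusing a pickle if one exists.

    Hypothetical helper: it does not invalidate the cache when the corpus
    or the label descriptions change.
    """
    cache_path = Path(cache_path)
    if cache_path.exists():
        with cache_path.open('rb') as f:
            return pickle.load(f)
    vectorizer = TfidfVectorizer().fit(corpus)
    label_vectors = vectorizer.transform(label_desc)
    with cache_path.open('wb') as f:
        pickle.dump((vectorizer, label_vectors), f)
    return vectorizer, label_vectors


# Toy data, assumed for the example only
corpus = ['refund payment please', 'server crash timeout']
label_desc = ['refund payment invoice', 'crash timeout error']

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'tfidf_cache.pkl'
    vec1, lab1 = build_or_load(path, corpus, label_desc)  # fits and writes
    cache_written = path.exists()
    vec2, lab2 = build_or_load(path, corpus, label_desc)  # reads from disk
    same_vectors = bool(np.allclose(lab1.toarray(), lab2.toarray()))

print(cache_written, same_vectors)
```

Pickled scikit-learn objects are only reliable when loaded with the same library version that wrote them, and pickle files from untrusted sources should not be loaded.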

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity

SEED = 1337
rng = np.random.default_rng(SEED)
pd.set_option('display.max_columns', 50)

## 1) Dataset
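Synthetic support tickets. Each ticket mixes 2 to 4 filler words shared by all labels with 3 to 5 words from its own label's vocabulary and, with probability 0.4, one word from another label's vocabulary. The shared and borrowed words make the task non-trivial.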

In [2]:
labels = {
  'billing': 'questions about invoices, refunds, payments, subscription charges, receipts, pricing',
  'technical': 'bugs, errors, api failures, crashes, timeouts, performance, database issues, deployment problems',
  'security': 'account takeover, phishing, malware, breach, suspicious login, token leaks, encryption, 2fa',
}
common = ['please','help','urgent','issue','account','today','cannot','support','thanks','team']
vocab = {
  'billing': ['invoice','refund','receipt','payment','charge','subscription','price','billing'],
  'technical': ['error','crash','bug','timeout','latency','api','server','deploy','database'],
  'security': ['phishing','malware','breach','2fa','token','encryption','password','alert','suspicious'],
}

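def make_text(y: str) -> str:
  """Build one synthetic ticket for label y using the module-level rng."""
  w = []
  # 2-4 filler words shared by every label
  w += list(rng.choice(common, size=int(rng.integers(2,5)), replace=False))
  # 3-5 words from the label's own vocabulary
  w += list(rng.choice(vocab[y], size=int(rng.integers(3,6)), replace=False))
  # with probability 0.4, add one noise word from a different label's vocabulary
  if rng.random() < 0.4:
    other = [k for k in vocab.keys() if k != y]
    w += list(rng.choice(vocab[rng.choice(other)], size=1, replace=False))
  rng.shuffle(w)
  return ' '.join(w)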

rows=[]
for y in labels.keys():
  for _ in range(1200):
    rows.append((make_text(y), y))
df = pd.DataFrame(rows, columns=['text','label'])
df.sample(5, random_state=SEED), df['label'].value_counts()

(                                                   text     label
 893   urgent refund thanks payment price team suppor...   billing
 2479  support urgent malware phishing encryption thanks  security
 3015   help team encryption 2fa today suspicious cannot  security
 2973  breach support token cannot encryption please ...  security
 2979  breach encryption suspicious support token pas...  security,
 label
 billing      1200
 technical    1200
 security     1200
 Name: count, dtype: int64)
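Because `make_text` draws from a module-level `rng`, re-running the dataset cell on its own continues the random stream and produces different tickets. A minimal sketch of a generator that owns its seed is shown below. The function `make_texts` and the shortened word lists are hypothetical and only illustrate the idea.

```python
import numpy as np

# Shortened word lists, assumed for the example only
COMMON = ['please', 'help', 'urgent', 'issue', 'account']
TOPIC = ['invoice', 'refund', 'receipt', 'payment', 'charge']


def make_texts(seed, n):
    """Generate n synthetic texts from a generator created inside the call."""
    rng = np.random.default_rng(seed)
    texts = []
    for _ in range(n):
        words = list(rng.choice(COMMON, size=2, replace=False))
        words += list(rng.choice(TOPIC, size=3, replace=False))
        rng.shuffle(words)
        texts.append(' '.join(words))
    return texts


first = make_texts(seed=1337, n=5)
second = make_texts(seed=1337, n=5)
print(first == second)  # same seed, same texts
```

With this shape, the dataset depends only on the seed passed in, not on how many times earlier cells have been run.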

## 2) Zero-shot classifier (label semantics)

In [3]:
X_train, X_test, y_train, y_test = train_test_split(
  df['text'], df['label'], test_size=0.25, random_state=SEED, stratify=df['label']
)

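# Zero-shot: the labels are never used for fitting (y_test is used only for evaluation).
# The TF-IDF vocabulary and IDF weights are fit on the unlabeled texts alone.
# The test texts are included in that fit, which makes the setup transductive.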
vectorizer = TfidfVectorizer(ngram_range=(1,2), min_df=2, max_df=0.98)
vectorizer.fit(pd.concat([X_train, X_test], ignore_index=True))

label_names = list(labels.keys())
label_desc = [labels[k] for k in label_names]
V_labels = vectorizer.transform(label_desc)

V_test = vectorizer.transform(X_test)
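S = cosine_similarity(V_test, V_labels)
# argmax returns the first index on ties, so a text with zero similarity
# to every label is assigned label_names[0] ('billing')
pred_idx = S.argmax(axis=1)
pred = [label_names[i] for i in pred_idx]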

print(classification_report(y_test, pred))

              precision    recall  f1-score   support

     billing       0.81      0.80      0.80       300
    security       0.75      0.96      0.84       300
   technical       0.94      0.70      0.80       300

    accuracy                           0.82       900
   macro avg       0.83      0.82      0.82       900
weighted avg       0.83      0.82      0.82       900
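A label description can only score against a text through tokens the two share, so it can help to check which description tokens occur in the ticket vocabulary at all. The sketch below is a rough diagnostic under stated assumptions. It copies `labels`, `common` and `vocab` from the dataset cell, uses a regex that mimics the default `TfidfVectorizer` token pattern, and ignores bigrams and the `min_df`/`max_df` filters.

```python
import re

labels = {
    'billing': 'questions about invoices, refunds, payments, subscription charges, receipts, pricing',
    'technical': 'bugs, errors, api failures, crashes, timeouts, performance, database issues, deployment problems',
    'security': 'account takeover, phishing, malware, breach, suspicious login, token leaks, encryption, 2fa',
}
common = ['please', 'help', 'urgent', 'issue', 'account', 'today', 'cannot', 'support', 'thanks', 'team']
vocab = {
    'billing': ['invoice', 'refund', 'receipt', 'payment', 'charge', 'subscription', 'price', 'billing'],
    'technical': ['error', 'crash', 'bug', 'timeout', 'latency', 'api', 'server', 'deploy', 'database'],
    'security': ['phishing', 'malware', 'breach', '2fa', 'token', 'encryption', 'password', 'alert', 'suspicious'],
}


def tokens(s):
    """Lowercase and keep tokens of two or more word characters."""
    return set(re.findall(r"\b\w\w+\b", s.lower()))


# Every unigram a generated ticket can contain
ticket_words = set(common)
for words in vocab.values():
    ticket_words |= set(words)

# Description tokens that can actually match a ticket
matched = {name: sorted(tokens(desc) & ticket_words) for name, desc in labels.items()}
for name, hits in matched.items():
    print(name, hits)
```

Under these assumptions the billing description shares only `subscription` with the tickets and the technical description only `api` and `database`, because plural forms such as `invoices` or `bugs` do not match the singular ticket words. The security description shares eight tokens, including the filler word `account`, which may contribute to its higher recall.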



## 3) Abstention calibration
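The classifier can abstain when the maximum similarity falls below a threshold, trading coverage for precision (safety / precision mode). Here the threshold is the 15th percentile of the test-set maximum similarities, so coverage is 85% by construction.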

In [4]:
max_sim = S.max(axis=1)
thr = float(np.quantile(max_sim, 0.15))  # abstain on bottom 15% confidence
pred2 = [p if m >= thr else 'abstain' for p, m in zip(pred, max_sim)]

# report on non-abstained subset
mask = np.array([p!='abstain' for p in pred2])
coverage = float(mask.mean())
print('threshold', thr, 'coverage', coverage)
print(classification_report(np.array(y_test)[mask], np.array(pred2)[mask]))

threshold 0.01927188855310078 coverage 0.85
              precision    recall  f1-score   support

     billing       0.92      0.72      0.81       211
    security       0.77      0.96      0.85       300
   technical       0.94      0.82      0.88       254

    accuracy                           0.85       765
   macro avg       0.87      0.83      0.84       765
weighted avg       0.86      0.85      0.85       765
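A single quantile gives one operating point. To see the trade-off, coverage and accuracy on the kept examples can be tabulated over several thresholds, ideally on a held-out split rather than on the test scores themselves. The sketch below assumes a hypothetical helper `coverage_accuracy` and made-up scores. It is not the notebook's procedure.

```python
import numpy as np


def coverage_accuracy(max_sim, correct, thresholds):
    """Return (threshold, coverage, accuracy on kept rows) for each threshold."""
    max_sim = np.asarray(max_sim, dtype=float)
    correct = np.asarray(correct, dtype=bool)
    rows = []
    for t in thresholds:
        keep = max_sim >= t  # same rule as the notebook: abstain below t
        coverage = float(keep.mean())
        accuracy = float(correct[keep].mean()) if keep.any() else float('nan')
        rows.append((float(t), coverage, accuracy))
    return rows


# Made-up scores, assumed for the example only
max_sim = [0.00, 0.00, 0.05, 0.10, 0.12, 0.15, 0.20, 0.25]
correct = [False, True, False, True, True, True, False, True]

curve = coverage_accuracy(max_sim, correct, thresholds=[0.0, 0.05, 0.12])
for t, cov, acc in curve:
    print(f'thr={t:.2f} coverage={cov:.2f} accuracy={acc:.3f}')
```

On these toy values, raising the threshold from 0.0 to 0.12 halves coverage while accuracy on the kept rows rises from 0.625 to 0.75.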



## 4) Confusion matrix + top errors

In [5]:
labs = sorted(labels.keys())
cm = confusion_matrix(y_test, pred, labels=labs)
pd.DataFrame(cm, index=[f'true:{l}' for l in labs], columns=[f'pred:{l}' for l in labs])

| | pred:billing | pred:security | pred:technical |
|---|---|---|---|
| true:billing | 239 | 52 | 9 |
| true:security | 7 | 288 | 5 |
| true:technical | 48 | 43 | 209 |
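The per-class figures in the classification report can be recovered from this matrix: recall divides the diagonal by the row sums, precision by the column sums. The short check below uses the counts from the table above. Only the variable names are assumptions.

```python
import numpy as np

labs = ['billing', 'security', 'technical']
# Rows are true labels, columns are predicted labels
cm = np.array([
    [239, 52, 9],
    [7, 288, 5],
    [48, 43, 209],
])

recall = cm.diagonal() / cm.sum(axis=1)     # correct / all true members of the class
precision = cm.diagonal() / cm.sum(axis=0)  # correct / all predictions of the class
accuracy = cm.trace() / cm.sum()

for name, p, r in zip(labs, precision, recall):
    print(f'{name:10s} precision={p:.2f} recall={r:.2f}')
print(f'accuracy={accuracy:.2f}')
```

The rounded values agree with the report in section 2: precision 0.81 / 0.75 / 0.94, recall 0.80 / 0.96 / 0.70, and accuracy 0.82.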


In [6]:
err = pd.DataFrame({
  'text': X_test.values,
  'true': y_test.values,
  'pred': pred,
  'max_sim': max_sim,
})
err = err[err['true'] != err['pred']].sort_values('max_sim', ascending=False).head(12)
err

| | text | true | pred | max_sim |
|---|---|---|---|---|
| 217 | today password alert subscription breach help ... | security | billing | 0.176977 |
| 510 | server subscription database help bug thanks | technical | billing | 0.164892 |
| 675 | timeout issue cannot latency subscription serv... | technical | billing | 0.161467 |
| 757 | latency timeout server help deploy subscriptio... | technical | billing | 0.160055 |
| 22 | deploy please crash cannot subscription api ti... | technical | billing | 0.155877 |
| 79 | suspicious subscription phishing issue token e... | security | billing | 0.151198 |
| 345 | latency help database crash today subscription... | technical | billing | 0.148733 |
| 400 | phishing token support password api team | security | technical | 0.148362 |
| 763 | password issue urgent team token subscription ... | security | billing | 0.147611 |
| 177 | cannot subscription breach 2fa thanks support ... | security | billing | 0.145395 |
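To see what drives individual errors, the tokens a misclassified text shares with each label description can be listed. The sketch below does this for the two rows whose text is shown in full (510 and 400). The helper `overlap` is hypothetical, the regex mimics the default `TfidfVectorizer` token pattern, and bigrams and TF-IDF weights are ignored.

```python
import re

labels = {
    'billing': 'questions about invoices, refunds, payments, subscription charges, receipts, pricing',
    'technical': 'bugs, errors, api failures, crashes, timeouts, performance, database issues, deployment problems',
    'security': 'account takeover, phishing, malware, breach, suspicious login, token leaks, encryption, 2fa',
}


def tokens(s):
    """Lowercase and keep tokens of two or more word characters."""
    return set(re.findall(r"\b\w\w+\b", s.lower()))


def overlap(text):
    """Map each label to the sorted tokens it shares with the text."""
    return {name: sorted(tokens(text) & tokens(desc)) for name, desc in labels.items()}


row_510 = 'server subscription database help bug thanks'  # true technical, predicted billing
row_400 = 'phishing token support password api team'      # true security, predicted technical

hits_510 = overlap(row_510)
hits_400 = overlap(row_400)
print(hits_510)
print(hits_400)
```

In row 510 the single noise word `subscription` is the only link to billing, while `bug` fails to match the description's `bugs`, leaving `database` as the only technical hit. Overlap counts are not scores: row 400 shares two tokens with security and one with technical yet is predicted technical, which is plausibly an effect of TF-IDF weighting and of how each description vector is normalized.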
