In [1]:
"""
Reproduction of Section 3.5: Text Classification with SVMs.
Dataset: 20 Newsgroups (Atheism vs. Christianity).
Goal: Show that high accuracy is based on artifacts (Headers).
"""

import sys
import os
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn import metrics

# Add project root to python path so we can import src
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.explainers.lime_text import LimeTextExplainer
from src.utils.visualization import Visualizer

def run_experiment(idx=83, num_features=6, num_samples=2000, random_state=42):
    """Train an RBF-kernel SVM on 20 Newsgroups and explain one prediction with LIME.

    The experiment loads the 'alt.atheism' and 'soc.religion.christian'
    categories (headers included), fits a TF-IDF + SVC pipeline, prints the
    test-set F1 score, and then prints a LIME explanation for a single test
    document so that the reader can inspect which tokens drive the prediction.

    Args:
        idx: Non-negative position of the test document to explain.
        num_features: Number of tokens LIME should report (must be >= 1).
        num_samples: Number of perturbed samples LIME draws (must be >= 1).
        random_state: Seed passed to both the SVC (its probability
            calibration is randomized) and the LIME explainer, so that a
            fresh "Restart & Run All" reproduces the same explanation.

    Returns:
        None. All results are printed / rendered by the visualizer.

    Raises:
        ValueError: If ``idx`` is negative, or ``num_features`` /
            ``num_samples`` is smaller than 1.
        IndexError: If ``idx`` is beyond the end of the test set.
    """
    # Validate cheap-to-check arguments before the expensive load/train steps.
    if idx < 0:
        raise ValueError(f"idx must be a non-negative index, got {idx}")
    if num_features < 1:
        raise ValueError(f"num_features must be >= 1, got {num_features}")
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    print("Loading 20 Newsgroups dataset (Atheism vs. Christianity)...")
    categories = ['alt.atheism', 'soc.religion.christian']
    # Headers are intentionally NOT removed: the point of the experiment is to
    # show that the classifier latches onto them.
    newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
    newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)

    num_test_docs = len(newsgroups_test.data)
    if idx >= num_test_docs:
        raise IndexError(
            f"idx {idx} is out of range for a test set of {num_test_docs} documents"
        )

    # Order matches the sorted category names, i.e. target 0 / target 1.
    class_names = ['Atheism', 'Christianity']

    # 1. Train the Black Box Model (SVM with RBF Kernel - per paper)
    # The paper mentions "SVM with RBF kernel... trained on unigrams".
    print("Training SVM (this might take a minute)...")
    # lowercase=False keeps capitalized header tokens (e.g. "Host") distinct.
    vectorizer = TfidfVectorizer(lowercase=False)
    # probability=True is needed because LIME consumes predict_proba; the
    # calibration behind it is randomized, hence the explicit seed.
    model = SVC(kernel='rbf', probability=True, random_state=random_state)

    # Pipeline so that LIME can feed raw strings straight into predict_proba.
    c = make_pipeline(vectorizer, model)
    c.fit(newsgroups_train.data, newsgroups_train.target)

    # 2. Verify Accuracy
    pred_test = c.predict(newsgroups_test.data)
    f1 = metrics.f1_score(newsgroups_test.target, pred_test)
    print(f"Model F1 Score on Test Set: {f1:.4f}")
    print("Model seems trustworthy based on metrics... or is it?")

    # 3. LIME Explanation
    # In the paper, they show an instance where "Posting" and "Host" are key.
    # `idx` selects a fixed (not random) test document so the run is repeatable.
    text_instance = newsgroups_test.data[idx]
    true_label = newsgroups_test.target[idx]
    # Reuse the batch prediction instead of re-running the pipeline per print.
    predicted_label = pred_test[idx]

    print("\n--- Explaining Instance #{} ---".format(idx))
    print(f"True Label: {class_names[true_label]}")
    print(f"Model Prediction: {class_names[predicted_label]}")
    print(f"Text Snippet: {text_instance[:300]}...\n")

    # Initialize our LIME Explainer
    explainer = LimeTextExplainer(kernel_width=25, random_state=random_state, verbose=False)

    # Explain the predicted class
    # The pipeline.predict_proba takes a list of strings
    print("Running LIME...")
    exp = explainer.explain_instance(
        text_instance,
        c.predict_proba,
        labels=(predicted_label,),  # Explain the predicted class
        num_features=num_features,
        num_samples=num_samples
    )

    # 4. Visualize
    viz = Visualizer()
    # Only one label was requested, so the first key is the explained class.
    predicted_idx = list(exp.keys())[0]
    viz.visualize_text(exp[predicted_idx])

    print("\nANALYSIS:")
    print("If you see words like 'Subject', 'From', 'Organization', or 'Re' with high bars,")
    print("you have successfully reproduced the paper's finding: the model is overfitting to headers!")


    

In [2]:
# Run the full reproduction: train the SVM, report test F1, and print the
# LIME explanation for the selected test document.
run_experiment()

Loading 20 Newsgroups dataset (Atheism vs. Christianity)...
Training SVM (this might take a minute)...
Model F1 Score on Test Set: 0.9303
Model seems trustworthy based on metrics... or is it?

--- Explaining Instance #83 ---
True Label: Atheism
Model Prediction: Atheism
Text Snippet: From: johnchad@triton.unm.edu (jchadwic)
Subject: Another request for Darwin Fish
Organization: University of New Mexico, Albuquerque
Lines: 11
NNTP-Posting-Host: triton.unm.edu

Hello Gang,

There have been some notes recently asking where to obtain the DARWIN fish.
This is the same question I have...

Running LIME...

=== LIME Explanation ===
Target Class: 0
Local Linear Prediction: 1.0390
Features:
[92m            unm | 0.3053 ███████████████[0m
[92m            edu | 0.1125 █████[0m
[92m           Host | 0.1083 █████[0m
[92m        Posting | 0.1047 █████[0m
[92m           NNTP | 0.0904 ████[0m
[91m             or | -0.0053 [0m

ANALYSIS:
If you see words like 'Subject', 'From', 'Organization