In [None]:
"""
SMS Spam Classifier using DistilBERT
A complete implementation for fine-tuning a transformer model on spam detection
"""

# Disable wandb to avoid login prompts
# NOTE(review): the training log captured with this notebook reports that the
# WANDB_DISABLED environment variable is deprecated and will be removed in v5;
# it recommends controlling logging integrations through `report_to`
# (for instance report_to="none") instead.
import os
os.environ["WANDB_DISABLED"] = "true"

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

# ============================================================================
# Step 1: Explore the Dataset
# ============================================================================

# Fetch the dataset, then describe what came back before touching it.
print("Loading SMS Spam Dataset...")
dataset = load_dataset("sms_spam")

print("\n=== Dataset Info ===")
print(f"Dataset structure: {dataset}")
train_split = dataset['train']
print(f"\nColumn names: {train_split.column_names}")
print(f"Number of samples: {len(train_split)}")

# Show the first three rows: raw message text and its integer label.
print("\n=== Sample Data ===")
for i in (0, 1, 2):
    print(f"Example {i+1}:")
    sample = train_split[i]
    print(f"  Text: {sample['sms']}")
    print(f"  Label: {sample['label']}")

# ============================================================================
# Step 2: Create Label Dictionary
# ============================================================================

# Integer class id -> readable class name.
label_map = dict(enumerate(('ham', 'spam')))
# Reverse lookup, derived from label_map so the two mappings cannot disagree.
id_map = {name: class_id for class_id, name in label_map.items()}

print("\n=== Label Mapping ===")
print(f"Label map (ID → Name): {label_map}")
print(f"ID map (Name → ID): {id_map}")

# ============================================================================
# Step 3: Tokenize and Preprocess
# ============================================================================

print("\n=== Loading Tokenizer ===")
# The same checkpoint name is reused further down when the model is loaded.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_function(examples):
    """Encode the ``sms`` column of a batch as fixed-length token sequences.

    Args:
        examples: Batch handed over by ``dataset.map``; its ``'sms'`` entry
            is passed to the tokenizer.

    Returns:
        The tokenizer output, with every sequence padded or truncated to
        128 tokens.
    """
    encoding_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': 128,
    }
    return tokenizer(examples['sms'], **encoding_options)

# batched=True hands the function many rows per call instead of one.
print("Tokenizing dataset...")
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# ============================================================================
# Step 4: Split Train & Evaluation Data
# ============================================================================

print("\n=== Preparing Train/Eval Split ===")

# Fixed seed: the shuffle, and therefore the split, is the same on every run.
shuffled_dataset = tokenized_dataset['train'].shuffle(seed=42)

total_size = len(shuffled_dataset)
print(f"Total dataset size: {total_size}")

# Train on 80% of the rows (capped at 5000) and evaluate on up to 1000 of the
# rows that remain. Rows beyond train_size + eval_size are left unused.
train_size = min(5000, int(total_size * 0.8))
rows_left = total_size - train_size
eval_size = min(1000, rows_left)

print(f"Using {train_size} samples for training and {eval_size} for evaluation")

# Two consecutive, non-overlapping index windows over the shuffled rows.
eval_stop = train_size + eval_size
train_dataset = shuffled_dataset.select(range(train_size))
eval_dataset = shuffled_dataset.select(range(train_size, eval_stop))

print(f"Training samples: {len(train_dataset)}")
print(f"Evaluation samples: {len(eval_dataset)}")

# ============================================================================
# Step 5: Fine-Tune DistilBERT
# ============================================================================

print("\n=== Loading Model ===")
# One output class per entry in the label mapping (two: ham and spam); both
# lookup directions are passed along with the checkpoint name.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    id2label=label_map,
    label2id=id_map,
    num_labels=len(label_map),
)

# Define metrics computation
def compute_metrics(pred):
    """Score one evaluation pass.

    Args:
        pred: Object exposing ``label_ids`` (reference labels) and
            ``predictions`` (per-class scores); the index of the largest
            score along the last axis is taken as the predicted class.

    Returns:
        dict: ``accuracy``, ``precision``, ``recall`` and ``f1``. The last
        three are computed with ``average='binary'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(axis=-1)

    # Result is (precision, recall, f1, support); support is not reported.
    prf = precision_recall_fscore_support(y_true, y_pred, average='binary')

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': prf[0],
        'recall': prf[1],
        'f1': prf[2],
    }

# Set up training arguments
training_args = TrainingArguments(
    output_dir="./results",
    # Evaluate and checkpoint on the same schedule so that every evaluated
    # epoch has a checkpoint load_best_model_at_end can restore.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    # After training, reload the checkpoint with the best eval F1
    # ("f1" is the key returned by compute_metrics).
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    logging_dir='./logs',
    logging_steps=100,
    # Keep at most two checkpoints on disk.
    save_total_limit=2,
    # Turn off experiment-tracking integrations explicitly instead of relying
    # only on the deprecated WANDB_DISABLED environment variable (the training
    # log warns it will be removed in v5 and recommends report_to).
    # NOTE: "none" disables every reporter, TensorBoard included; use
    # report_to="tensorboard" if event files in logging_dir are wanted.
    report_to="none",
)

# Wire the model, its hyper-parameters, both data splits and the metric
# callback into a single Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

# Run fine-tuning; the returned training summary is not used here.
print("\n=== Starting Training ===")
trainer.train()

# Score the model once more on the evaluation split and show the raw metrics.
print("\n=== Final Evaluation ===")
eval_results = trainer.evaluate()
print(f"Evaluation Results: {eval_results}")

# ============================================================================
# Step 6: Save Model
# ============================================================================

print("\n=== Saving Model ===")
# Model and tokenizer are written to the same directory so that it can be
# reloaded from that single path.
model_dir = "./spam_model"
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)
print("Model and tokenizer saved to './spam_model'")

# ============================================================================
# Step 7: Load Model & Make Predictions
# ============================================================================

print("\n=== Loading Saved Model for Inference ===")

# Read both artefacts back from disk rather than reusing the in-memory objects.
loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
loaded_model = AutoModelForSequenceClassification.from_pretrained(model_dir)

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
loaded_model.to(device)
# Switch the model to evaluation mode before running predictions.
loaded_model.eval()

def predict_with_label(text):
    """
    Classify a single message as spam or ham with the reloaded model.

    Args:
        text (str): Input text message

    Returns:
        dict: ``text`` (the input), ``label`` (name looked up in
        ``label_map``), ``confidence`` (probability of the predicted class)
        and ``probabilities`` (``ham`` / ``spam`` values). All probabilities
        are strings formatted to four decimal places.
    """
    # Encode with the same length settings used for the training data.
    encoded = loaded_tokenizer(
        text,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding='max_length',
    )

    # The model sits on `device`; every input tensor has to be moved there too.
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        probabilities = loaded_model(**model_inputs).logits.softmax(dim=-1)
        predicted_class = probabilities.argmax(dim=-1).item()
        confidence = probabilities[0][predicted_class].item()

    result = {'text': text, 'label': label_map[predicted_class]}
    result['confidence'] = f"{confidence:.4f}"
    result['probabilities'] = {
        'ham': f"{probabilities[0][0].item():.4f}",
        'spam': f"{probabilities[0][1].item():.4f}",
    }
    return result

# ============================================================================
# Test Examples
# ============================================================================

print("\n=== Testing Predictions ===")

# A hand-written mix of spam-looking and ordinary messages.
test_texts = [
    "Congratulations! You've won a free ticket.",
    "Hey, are we meeting tomorrow?",
    "URGENT! You have won $1000000! Click here NOW!",
    "Can you pick up some milk on your way home?",
    "FREE entry to win a brand new iPhone! Text WIN to 12345",
    "Meeting scheduled for 3pm in conference room B",
]

# Classify each message and print a short report for it.
for text in test_texts:
    result = predict_with_label(text)
    class_probs = result['probabilities']
    print(f"\nText: '{result['text']}'")
    print(f"Prediction: {result['label'].upper()}")
    print(f"Confidence: {result['confidence']}")
    print(f"Probabilities: Ham={class_probs['ham']}, Spam={class_probs['spam']}")

# Closing banner with reuse instructions, one entry per output line.
print(
    "\n=== Training Complete! ===",
    "Your spam classifier is ready to use!",
    "\nTo use it later:",
    "1. Load the model: AutoModelForSequenceClassification.from_pretrained('./spam_model')",
    "2. Load the tokenizer: AutoTokenizer.from_pretrained('./spam_model')",
    "3. Use predict_with_label(text) function for predictions",
    sep="\n",
)

Loading SMS Spam Dataset...

=== Dataset Info ===
Dataset structure: DatasetDict({
    train: Dataset({
        features: ['sms', 'label'],
        num_rows: 5574
    })
})

Column names: ['sms', 'label']
Number of samples: 5574

=== Sample Data ===
Example 1:
  Text: Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...

  Label: 0
Example 2:
  Text: Ok lar... Joking wif u oni...

  Label: 0
Example 3:
  Text: Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's

  Label: 1

=== Label Mapping ===
Label map (ID → Name): {0: 'ham', 1: 'spam'}
ID map (Name → ID): {'ham': 0, 'spam': 1}

=== Loading Tokenizer ===
Tokenizing dataset...


Map:   0%|          | 0/5574 [00:00<?, ? examples/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



=== Preparing Train/Eval Split ===
Total dataset size: 5574
Using 4459 samples for training and 1000 for evaluation
Training samples: 4459
Evaluation samples: 1000

=== Loading Model ===


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).



=== Starting Training ===


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0493,0.026105,0.995,0.97479,0.983051,0.978903
2,0.0315,0.028424,0.992,0.958333,0.974576,0.966387



=== Final Evaluation ===


Evaluation Results: {'eval_loss': 0.02610485814511776, 'eval_accuracy': 0.995, 'eval_precision': 0.9747899159663865, 'eval_recall': 0.9830508474576272, 'eval_f1': 0.9789029535864979, 'eval_runtime': 3.5367, 'eval_samples_per_second': 282.75, 'eval_steps_per_second': 17.813, 'epoch': 2.0}

=== Saving Model ===
Model and tokenizer saved to './spam_model'

=== Loading Saved Model for Inference ===

=== Testing Predictions ===

Text: 'Congratulations! You've won a free ticket.'
Prediction: HAM
Confidence: 0.9967
Probabilities: Ham=0.9967, Spam=0.0033

Text: 'Hey, are we meeting tomorrow?'
Prediction: HAM
Confidence: 0.9984
Probabilities: Ham=0.9984, Spam=0.0016

Text: 'URGENT! You have won $1000000! Click here NOW!'
Prediction: SPAM
Confidence: 0.9926
Probabilities: Ham=0.0074, Spam=0.9926

Text: 'Can you pick up some milk on your way home?'
Prediction: HAM
Confidence: 0.9979
Probabilities: Ham=0.9979, Spam=0.0021

Text: 'FREE entry to win a brand new iPhone! Text WIN to 12345'
Prediction: