In [25]:
from datasets import load_dataset
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments
)
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


In [26]:
# 1.Load Dataset
dataset = load_dataset("SetFit/enron_spam")

Repo card metadata block was not found. Setting CardData to empty.


In [27]:
# Expand to contextual labels
def map_label(example):
    if example["label"] == 1:
        example["category"] = "Promotional"
    else:
        text = example["text"].lower()
        if "meeting" in text or "project" in text or "deadline" in text:
            example["category"] = "Work-related"
        elif "urgent" in text or "immediately" in text:
            example["category"] = "Urgent"
        else:
            example["category"] = "Informational"
    return example



In [28]:
dataset = dataset.map(map_label)

In [29]:
# Select a smaller subset for speed
train_ds = dataset["train"].shuffle(seed=42).select(range(400))
test_ds = dataset["test"].shuffle(seed=42).select(range(150))

# Encode label categories
unique_labels = list(set(train_ds["category"]))
label2id = {label: i for i, label in enumerate(unique_labels)}
id2label = {i: label for label, i in label2id.items()}

def encode_labels(example):
    example["label"] = label2id[example["category"]]
    return example

train_ds = train_ds.map(encode_labels)
test_ds = test_ds.map(encode_labels)


In [30]:
# 2 Tokenization

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=128)

train_enc = train_ds.map(tokenize, batched=True)
test_enc = test_ds.map(tokenize, batched=True)
train_enc.set_format("torch", columns=["input_ids", "attention_mask", "label"])
test_enc.set_format("torch", columns=["input_ids", "attention_mask", "label"])


Map:   0%|          | 0/150 [00:00<?, ? examples/s]

In [37]:
# 3️ Model Setup

model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(unique_labels),
    id2label=id2label,
    label2id=label2id
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
# 4️ Training Setup

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs"
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted"
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_enc,
    eval_dataset=test_enc,
    compute_metrics=compute_metrics
)


In [33]:

# 5 Train Model

trainer.train()




Step,Training Loss


TrainOutput(global_step=100, training_loss=0.6030052947998047, metrics={'train_runtime': 500.52, 'train_samples_per_second': 1.598, 'train_steps_per_second': 0.2, 'total_flos': 26494424678400.0, 'train_loss': 0.6030052947998047, 'epoch': 2.0})

In [34]:
# 6️ Evaluate Model

results = trainer.evaluate()
print("\n📊 Evaluation Results:", results)





📊 Evaluation Results: {'eval_loss': 0.33723804354667664, 'eval_accuracy': 0.8866666666666667, 'eval_f1': 0.8486454910551295, 'eval_precision': 0.816031746031746, 'eval_recall': 0.8866666666666667, 'eval_runtime': 21.0302, 'eval_samples_per_second': 7.133, 'eval_steps_per_second': 0.903, 'epoch': 2.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [36]:
# 8️ Custom Email Prediction

def predict_email_context(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    outputs = model(**inputs)
    pred = torch.argmax(outputs.logits).item()
    return id2label[pred]

sample_email = "We have an urgent client meeting tomorrow. Please prepare the report."
print(f"\n📧 Email: {sample_email}")
print(f"Predicted Context: {predict_email_context(sample_email)}")


📧 Email: We have an urgent client meeting tomorrow. Please prepare the report.
Predicted Context: Informational


In [35]:
# 7️ Save Model

model.save_pretrained("context_email_classifier")
tokenizer.save_pretrained("context_email_classifier")


('context_email_classifier\\tokenizer_config.json',
 'context_email_classifier\\special_tokens_map.json',
 'context_email_classifier\\vocab.txt',
 'context_email_classifier\\added_tokens.json',
 'context_email_classifier\\tokenizer.json')