Fine tuning sentence-transformers/all-MiniLM-L12-v2 with  enron email dataset and LoRa with peft


In [None]:
%pip install transformers datasets torch peft pandas scikit-learn safetensors sentence-transformers
%pip install "numpy<2.0.0"
%pip install kaggle kagglehub
%pip install seaborn matplotlib


In [1]:
import pandas as pd
import torch
import re
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.metrics import classification_report, accuracy_score, f1_score, confusion_matrix
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset, Features, Value, ClassLabel
import safetensors # Required for use_safetensors=True and safe_serialization=True
from peft import LoraConfig, get_peft_model
import kagglehub
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import defaultdict


# Force CPU
torch.cuda.is_available = lambda: False
device = torch.device("cpu")

In [2]:
# download dataset, kaggle will download to cache 
path = kagglehub.dataset_download("wcukierski/enron-email-dataset")
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/wcukierski/enron-email-dataset?dataset_version_number=2...


100%|██████████| 358M/358M [01:25<00:00, 4.41MB/s] 

Extracting files...





Path to dataset files: /Users/rodneyfinkel/.cache/kagglehub/datasets/wcukierski/enron-email-dataset/versions/2


In [3]:
# move kaggle dataset to current working directory
import shutil
import os

for filename in os.listdir(path):
    src = os.path.join(path, filename)
    dst = os.path.join(".", filename)
    if os.path.isfile(src):
        shutil.copy(src, dst)
        print(f"copied {filename} to current working directory.")
        # Remove the orginal cached dataset
        cache_path = os.path.expanduser("~/.cache/kagglehub/datasets/wcukierski/enron-email-dataset")
        shutil.rmtree(cache_path)
        print("Original cached dataset removed.")
        

copied emails.csv to current working directory.
Original cached dataset removed.


In [4]:
# Load and preprocess dataset
df = pd.read_csv('emails.csv')
df = df.sample(n=20000, random_state=42)

In [5]:
pd.set_option('display.max_columns', None)
df.head()

Unnamed: 0,file,message
427616,shackleton-s/sent/1912.,Message-ID: <21013688.1075844564560.JavaMail.e...
108773,farmer-d/logistics/1066.,Message-ID: <22688499.1075854130303.JavaMail.e...
355471,parks-j/deleted_items/202.,Message-ID: <27817771.1075841359502.JavaMail.e...
457837,stokley-c/chris_stokley/iso/client_rep/41.,Message-ID: <10695160.1075858510449.JavaMail.e...
124910,germany-c/all_documents/1174.,Message-ID: <27819143.1075853689038.JavaMail.e...


In [6]:
# Clean emails and extract subject and delete dates
def clean_email(text):
    # extract subject if available
    subject_match = re.search(r'Subject: (.*?)\n', text, re.IGNORECASE)
    subject = subject_match.group(1) if subject_match else ''
    # Remove headers, signatures and dates
    text = re.sub(r'From:.*\n|To:.*\n|Subject:.*\n|Message-ID:.*\n|Date:.*\n', '', text)
    text = re.sub(r'-{2,}.*?-{2,}', '', text)
    text = re.sub(r'http[s]?://\S+', '', text)
    return text.strip(), subject

df['message'], df['subject'] = zip(*df['message'].apply(clean_email))


In [7]:
# Heursitic spam detection
def is_spam(email_text, subject):
    full_text = f"{subject.lower()} {email_text.lower()}"
    spam_keywords = {'offer': 2, 'free': 2, 'win': 3, 'lottery': 3, 'click here': 3}
    legit_keywords = {'budget': -2, 'contract': -2, 'meeting': -2}
    score = 0
    for kw, weight in spam_keywords.items():
        if re.search(r'\b' + re.escape(kw) + r'\b', full_text):
            score += weight
    for kw, weight in legit_keywords.items():
        if re.search(r'\b' + re.escape(kw) + r'\b', full_text):
            score += weight
    return score > 4

df['is_spam'] = df.apply(lambda x: is_spam(x['message'], x['subject']), axis=1)
print("Spam distribution:", df['is_spam'].value_counts())
df = df[~df['is_spam']]

Spam distribution: is_spam
False    19654
True       346
Name: count, dtype: int64


In [None]:
df.head()

In [8]:
# Prepare text for embedding (combine subject and body)
df['full_text'] = df.apply(lambda x: f"[SUBJECT] {x['subject']} [BODY] {x['message']}", axis=1)

# Load lightweight model (L6-v2 for efficiency, ~80MB, fast on CPU)
model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate embeddings (batch_size=32 to manage memory ~2GB for 20k samples)
embeddings = model.encode(df['full_text'].tolist(), batch_size=32, show_progress_bar=True)

# Cluster for categories (k=4) - adjust k if needed
kmeans_cat = KMeans(n_clusters=4, random_state=42)
df['category_cluster'] = kmeans_cat.fit_predict(embeddings)

# Cluster for priorities (k=3, or try more for sub-groups)
kmeans_prio = KMeans(n_clusters=3, random_state=42)
df['priority_cluster'] = kmeans_prio.fit_predict(embeddings)



Batches:   0%|          | 0/615 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



In [None]:
# Extract top keywords per cluster for heuristic refinement/very noisy data at this point
def get_top_keywords(texts, n=10):
    vectorizer = TfidfVectorizer(max_features=n, stop_words='english')
    tfidf = vectorizer.fit_transform(texts)
    return vectorizer.get_feature_names_out()

# For category clusters
for cluster in range(4):
    cluster_text = df[df['category_cluster'] == cluster]['full_text']
    print(f"Category Cluster {cluster} top keywords: {get_top_keywords(cluster_text)}")

# For priority clusters
for cluster in range(3):
    cluster_text = df[df['priority_cluster'] == cluster]['full_text']
    print(f"Priority Cluster {cluster} top keywords: {get_top_keywords(cluster_text)}")

Category Cluster 0 top keywords: ['20' '2001' 'bcc' 'cc' 'com' 'content' 'ect' 'enron' 'subject' 'version']
Category Cluster 1 top keywords: ['20' 'bcc' 'cc' 'com' 'content' 'ect' 'enron' 'subject' 'type' 'version']
Category Cluster 2 top keywords: ['20' 'cc' 'cn' 'com' 'content' 'ect' 'enron' 'na' 'ou' 'recipients']
Category Cluster 3 top keywords: ['bcc' 'cc' 'charset' 'content' 'folder' 'mime' 'subject' 'text' 'type'
 'version']
Priority Cluster 0 top keywords: ['20' '2001' 'bcc' 'cc' 'com' 'content' 'ect' 'enron' 'subject' 'version']
Priority Cluster 1 top keywords: ['bcc' 'cc' 'com' 'content' 'ect' 'enron' 'folder' 'text' 'type' 'version']
Priority Cluster 2 top keywords: ['20' 'cc' 'cn' 'com' 'content' 'ect' 'enron' 'na' 'ou' 'recipients']


In [10]:
# Example: Map clusters to labels (manual or heuristic-based on keywords)
cluster_to_category = {0: 0, 1: 1, 2: 2, 3: 3}  # Adjust based on keyword inspection
df['category'] = df['category_cluster'].map(cluster_to_category)

# Similarly for priority
cluster_to_priority = {0: 0, 1: 1, 2: 2}
df['priority'] = df['priority_cluster'].map(cluster_to_priority)

print("Pseudo-labeled category distribution:")
print(df['category'].value_counts())
print("Pseudo-labeled priority distribution:")
print(df['priority'].value_counts())

Pseudo-labeled category distribution:
category
0    9006
1    4606
3    3719
2    2323
Name: count, dtype: int64
Pseudo-labeled priority distribution:
priority
0    10777
1     6486
2     2391
Name: count, dtype: int64


In [None]:
# Heuristic spam labeling
# def is_potential_spam(text, subject):
#     spam_keywords = ['offer', 'free', 'win', 'click here', 'urgent', 'limited time', 'act now']
#     return any(keyword in text.lower() or keyword in subject.lower() for keyword in spam_keywords)

# df['label'] = df.apply(lambda x: is_potential_spam(x['message'], x['subject']), axis=1).astype(int)

# spam = df[df['label'] == 1]
# non_spam = df[df['label'] == 0]
# target_spam_count = int(len(df) * 0.05)

# if len(spam) > target_spam_count:
#     print(f"Reducing spam emails from {len(spam)} to {target_spam_count}")
#     spam = resample(spam, n_samples=target_spam_count, random_state=42)
# else:
#     non_spam = resample(non_spam, n_samples=int(target_spam_count * 3), random_state=42)

# df = pd.concat([spam, non_spam]).sample(frac=1, random_state=42).reset_index(drop=True)
# df.drop(columns=['label'], inplace=True)

In [11]:
print(df.shape)
df.head()

(19654, 9)


Unnamed: 0,file,message,subject,is_spam,full_text,category_cluster,priority_cluster,category,priority
427616,shackleton-s/sent/1912.,Mime-Version: 1.0\nContent-Type: text/plain; c...,Re: Credit Derivatives,False,[SUBJECT] Re: Credit Derivatives [BODY] Mime-V...,0,0,0,0
108773,farmer-d/logistics/1066.,Cc: daren.farmer@enron.com\nMime-Version: 1.0\...,Meter #1591 Lamay Gaslift,False,[SUBJECT] Meter #1591 Lamay Gaslift [BODY] Cc:...,0,0,0,0
355471,parks-j/deleted_items/202.,"wollam.erik@enron.com, corrier.brad@enron.com\...",Re: man night again?,False,[SUBJECT] Re: man night again? [BODY] wollam.e...,1,1,1,1
457837,stokley-c/chris_stokley/iso/client_rep/41.,Mime-Version: 1.0\nContent-Type: text/plain; c...,"Enron 480, 1480 charges",False,"[SUBJECT] Enron 480, 1480 charges [BODY] Mime-...",0,0,0,0
124910,germany-c/all_documents/1174.,Mime-Version: 1.0\nContent-Type: text/plain; c...,Transport Deal,False,[SUBJECT] Transport Deal [BODY] Mime-Version: ...,3,1,3,1


In [None]:
# Heuristic category labeling
# def assign_category(text, subject, file):
#     text = text.lower()
#     subject = subject.lower()
#     file = file.lower() 
#     if any(kw in text or kw in subject for kw in ['budget', 'invoice', 'payment', 'financial']) or 'finance' in file or 'accounting' in file:
#         return 0  # Finance
#     elif any(kw in text or kw in subject for kw in ['hiring', 'employee', 'benefits', 'payroll']) or 'hr' in file or 'personnel' in file:
#         return 1  # HR
#     elif any(kw in text or kw in subject for kw in ['contract', 'legal', 'compliance', 'attorney']) or 'legal' in file:
#         return 2  # Legal
#     elif any(kw in text or kw in subject for kw in ['meeting', 'schedule', 'calendar', 'agenda']) or 'meetings' in file or 'admin' in file:
#         return 3  # Admin
#     return 0  # Default to finance (most common in Enron)

# df['category'] = df.apply(lambda x: assign_category(x['message'], x['subject'], x['file']), axis=1)

In [None]:
df.head()

In [None]:
# Heuristic priority labeling
# def assign_priority(text, subject):
#     text = text.lower()
#     subject = subject.lower()
#     if any(kw in text or kw in subject for kw in ['urgent', 'asap', 'deadline', 'immediate', 'action required']) or '!' in text:
#         return 0  # High
#     elif any(kw in text or kw in subject for kw in ['fyi', 'thanks', 'no rush']):
#         return 2  # Low
#     return 1  # Medium

# df['priority'] = df.apply(lambda x: assign_priority(x['message'], x['subject']), axis=1)

In [None]:
df.head()

In [12]:
# Balance categories (~25% each)
print("Original category distribution:")
print(df['category'].value_counts())
category_counts = df['category'].value_counts()
target_count = int(len(df) * 0.25)
df_cat_balanced = pd.DataFrame()
for cat in range(4):  # 0=finance, 1=hr, 2=legal, 3=admin
    cat_df = df[df['category'] == cat]
    print(f"Category {cat}: Original size = {len(cat_df)}, Target = {target_count}")
    if len(cat_df) > target_count:
        cat_df = resample(cat_df, n_samples=target_count, random_state=42, replace=False)
    elif len(cat_df) < target_count:
        cat_df = resample(cat_df, n_samples=target_count, random_state=42, replace=True)
        if len(cat_df) < target_count / 2:
            print(f"Warning: Category {cat} has very few samples ({len(cat_df)}), upsampling may cause overfitting")
    df_cat_balanced = pd.concat([df_cat_balanced, cat_df])
print("Balanced category distribution:")
print(df_cat_balanced['category'].value_counts())

Original category distribution:
category
0    9006
1    4606
3    3719
2    2323
Name: count, dtype: int64
Category 0: Original size = 9006, Target = 4913
Category 1: Original size = 4606, Target = 4913
Category 2: Original size = 2323, Target = 4913
Category 3: Original size = 3719, Target = 4913
Balanced category distribution:
category
0    4913
1    4913
2    4913
3    4913
Name: count, dtype: int64


In [13]:
# Balance priorities (~33% each)
print("Original priority distribution:")
print(df['priority'].value_counts())
priority_counts = df['priority'].value_counts()
target_count = int(len(df) * 0.33)
df_prio_balanced = pd.DataFrame()
for prio in range(3):  # 0=high, 1=medium, 2=low
    prio_df = df[df['priority'] == prio]
    print(f"Priority {prio}: Original size = {len(prio_df)}, Target = {target_count}")
    if len(prio_df) > target_count:
        prio_df = resample(prio_df, n_samples=target_count, random_state=42, replace=False)
    elif len(prio_df) < target_count:
        prio_df = resample(prio_df, n_samples=target_count, random_state=42, replace=True)
        if len(prio_df) < target_count / 2:
            print(f"Warning: Priority {prio} has very few samples ({len(prio_df)}), upsampling may cause overfitting")
    df_prio_balanced = pd.concat([df_prio_balanced, prio_df])
print("Balanced priority distribution:")
print(df_prio_balanced['priority'].value_counts())
    

Original priority distribution:
priority
0    10777
1     6486
2     2391
Name: count, dtype: int64
Priority 0: Original size = 10777, Target = 6485
Priority 1: Original size = 6486, Target = 6485
Priority 2: Original size = 2391, Target = 6485
Balanced priority distribution:
priority
0    6485
1    6485
2    6485
Name: count, dtype: int64


In [None]:
print(df.shape)
df.head()

In [None]:
# Train model function
def train_model(dataset, model_name, num_labels, output_dir, label_column, label_names):
    # Combine subject and message for training
    dataset['text'] = dataset.apply(lambda x: f"[SUBJECT] {x['subject']} [BODY] {x['message']}", axis=1)
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        dataset['text'], dataset[label_column], test_size=0.2, random_state=42
    )
    
     # Verify labels are integers
    train_labels = train_labels.astype(int)
    val_labels = val_labels.astype(int)
    
    # Define dataset features with ClassLabel
    features = Features({
        'text': Value('string'),
        'label': ClassLabel(num_classes=num_labels, names=label_names)
    })
    
    train_dataset = Dataset.from_dict({'text': train_texts, 'label': train_labels}, features=features)
    val_dataset = Dataset.from_dict({'text': val_texts, 'label': val_labels}, features=features)
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Define id2label and label2id
    id2label = {i: label for i, label in enumerate(label_names)}
    label2id = {label: i for i, label in enumerate(label_names)}    
    
    # changed to use_safetensors = True
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, id2label=id2label, label2id=label2id, use_safetensors=True)
    
    # Ensure model is in training mode and freeze base model parameters
    model.train()
    for param in model.parameters():
        param.requires_grad_(False) # Freeze base model
    lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=[
        'query', 'key', 'value',  # Attention layers
        'dense',  # Feed-forward layers
        'classifier'  # Sequence classification head
        
        ], 
        lora_dropout=0.1, 
        bias="none", 
        task_type="SEQ_CLS",
        modules_to_save=['classifier']  # Ensure classifier is saved
        ) 
    model = get_peft_model(model, lora_config)
    
    # Enable gradient for both LoRA and classifier parameters
    trainable_params = []
    for name, param in model.named_parameters():
        if any(substr in name for substr in ["lora", "classifier"]):
            param.requires_grad = True
            trainable_params.append(name)
    
    # Debug: Print trainable parameters to verify
    print(f"Trainable parameters: {trainable_params}")
    
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128) # May need to lower max_length to 64 becuase of memory
    
    train_dataset = train_dataset.map(tokenize_function, batched=True)
    val_dataset = val_dataset.map(tokenize_function, batched=True)
    
    train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
    val_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])  
    
    # Evaluation metrics
    def compute_metrics(p):
        preds = p.predictions.argmax(axis=-1)
        labels = p.label_ids
        return {
            "accuracy": accuracy_score(labels, preds),
            "f1": f1_score(labels, preds, average="weighted")
        }
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=50,
        weight_decay=0.01,
        logging_dir=f'{output_dir}/logs',
        logging_steps=10,
        eval_strategy='steps',
        eval_steps=100,  # Evaluate every 100 steps
        max_grad_norm=1.0,
        #early_stopping_patience=3, 
        early_stopping_threshold=0.01, 
        metric_for_best_model='f1',
        save_strategy='steps',
        save_steps=200,
        gradient_checkpointing=False, # set to False since LoRA handles memory optimization
        fp16=False,
        load_best_model_at_end=True,
    )
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )
    
     # Debug: Print training dataset features
    print(f"Training dataset features: {train_dataset.features}")
    
    # Train and evaluate
    trainer.train()
    metrics = trainer.evaluate()
    print(f"Evaluation metrics for {output_dir}: {metrics}")
    
    # Confusion Matrix
    preds = trainer.predict(val_dataset)
    cm = confusion_matrix(preds.label_ids, preds.predictions.argmax(axis=1))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=label_names, yticklabels=label_names)
    plt.title(f'Confusion Matrix - {output_dir}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()
    
    model.save_pretrained(output_dir, safe_serialization=True)
    tokenizer.save_pretrained(output_dir)
    return model, tokenizer

In [17]:
# Train category classifier
# changed model to sentence-transformers/all-MiniLM-L12-v2 from microsoft/minilm-l12-h384-uncased
category_label_names = ["Finance", "HR", "Legal", "Admin"]
model_cat, tokenizer_cat = train_model(
    df_cat_balanced, 'sentence-transformers/all-MiniLM-L12-v2', 4, './fine_tuned_minilm_category', 'category', category_label_names) 

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L12-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: ['base_model.model.bert.encoder.layer.0.attention.self.query.lora_A.default.weight', 'base_model.model.bert.encoder.layer.0.attention.self.query.lora_B.default.weight', 'base_model.model.bert.encoder.layer.0.attention.self.key.lora_A.default.weight', 'base_model.model.bert.encoder.layer.0.attention.self.key.lora_B.default.weight', 'base_model.model.bert.encoder.layer.0.attention.self.value.lora_A.default.weight', 'base_model.model.bert.encoder.layer.0.attention.self.value.lora_B.default.weight', 'base_model.model.bert.encoder.layer.0.attention.output.dense.lora_A.default.weight', 'base_model.model.bert.encoder.layer.0.attention.output.dense.lora_B.default.weight', 'base_model.model.bert.encoder.layer.0.intermediate.dense.lora_A.default.weight', 'base_model.model.bert.encoder.layer.0.intermediate.dense.lora_B.default.weight', 'base_model.model.bert.encoder.layer.0.output.dense.lora_A.default.weight', 'base_model.model.bert.encoder.layer.0.output.dense.lora_B.defaul

Map:   0%|          | 0/15721 [00:00<?, ? examples/s]

Map:   0%|          | 0/3931 [00:00<?, ? examples/s]

TypeError: TrainingArguments.__init__() got an unexpected keyword argument 'early_stopping_patience'

In [None]:
# Train priority classifier
# changed model to sentence-transformers/all-MiniLM-L12-v2 from microsoft/minilm-l12-h384-uncased
priority_label_names = ["High", "Medium", "Low"]
model_prio, tokenizer_prio = train_model(
    df_prio_balanced, 'sentence-transformers/all-MiniLM-L12-v2', 3, './fine_tuned_minilm_priority', 'priority', priority_label_names
)

In [None]:
# Inference functions
def predict_category(email_text, model=model_cat, tokenizer=tokenizer_cat):
    text, subject = clean_email(email_text)
    input_text = f"[SUBJECT] {subject} [BODY] {text}"
    inputs = tokenizer(input_text, return_tensors='pt', truncation=True, padding=True, max_length=64)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return category_label_names[predicted_class]

def predict_priority(email_text, model=model_prio, tokenizer=tokenizer_prio):
    text, subject = clean_email(email_text)
    input_text = f"[SUBJECT] {subject} [BODY] {text}"
    inputs = tokenizer(input_text, return_tensors='pt', truncation=True, padding=True, max_length=64)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return priority_label_names[predicted_class]

# Test inference
sample_email = "Subject: Budget Review\nDate: Wed, 29 Nov 2000 05:40:00 -0800\nPlease review the attached budget for Q3."
print(f"Category: {predict_category(sample_email)}")
print(f"Priority: {predict_priority(sample_email)}")