### Installing and importing the required modules

In [None]:
import os
import sys
import torch
import random
import kagglehub
import numpy as np
import pandas as pd
from pathlib import Path
from evaluate import load
from typing import Dict, Any
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification, TrainingArguments, Trainer, EvalPrediction

# Add the parent directory to the system path
sys.path.append(str(Path().resolve().parent.parent))

# Import local dependencies
from src.utils import get_device, set_seed

### Constants, hyperparameters and model configurations

In [None]:
# Configuration parameters

# Seed for reproducibility
seed = 42

# Fraction of all samples held out as the test set
test_size = 0.2

# Fraction of the remaining training samples held out as the validation set
validation_size = 0.1

# The maximum length (in tokens) of the input sequences
max_length = 64

# The model ID of the BERT checkpoint to fine-tune
model_id = "bert-base-uncased"

# Whether to save the model after training
save_trained_model = False

# Destination of the trained model: <two levels above the working directory>/saved_models/spam_mails_classifier
model_path = Path().resolve().parent.parent.joinpath("saved_models", "spam_mails_classifier")

In [3]:
# Set the seed for reproducibility using the project helper from `src.utils`.
# NOTE(review): which random number generators the helper seeds is not visible here
# (presumably random / numpy / torch) — confirm against `src/utils`.
set_seed(seed)

In [4]:
# Get the device available on the system via the project helper from `src.utils`;
# the model and the input tensors are moved to this device in the cells below
device = get_device()

# Print the detected device
print(f"Detected device: {device}")

Detected device: mps


### Data loading

In [5]:
# Download the "venky73/spam-mails-dataset" dataset through kagglehub;
# `path` is used below as the directory that contains the CSV file
path = kagglehub.dataset_download("venky73/spam-mails-dataset")

In [6]:
# Locate the CSV file inside the downloaded dataset directory
csv_file = Path(path) / "spam_ham_dataset.csv"

# Load the dataset into a pandas dataframe
dataset = pd.read_csv(csv_file)

In [7]:
# Drop every row that contains a null value.
# NOTE: `inplace=True` mutates `dataset` itself, so the unfiltered dataframe is not kept around.
dataset.dropna(inplace=True)

In [8]:
# Show the first rows of the dataframe as a quick check of the available columns
dataset.head()

Unnamed: 0.1,Unnamed: 0,label,text,label_num
0,605,ham,Subject: enron methanol ; meter # : 988291\r\n...,0
1,2349,ham,"Subject: hpl nom for january 9 , 2001\r\n( see...",0
2,3624,ham,"Subject: neon retreat\r\nho ho ho , we ' re ar...",0
3,4685,spam,"Subject: photoshop , windows , office . cheap ...",1
4,2030,ham,Subject: re : indian springs\r\nthis deal is t...,0


### Tokenizer

In [9]:
# Load the tokenizer that belongs to `model_id`, configured to pad on the right side
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")

### Preprocess data

In [10]:
# Instantiate the label encoder
label_encoder = LabelEncoder()

# Encode the textual target column ("label") into int64 class ids stored in the "labels" column
encoded_labels = label_encoder.fit_transform(dataset["label"])
dataset["labels"] = encoded_labels.astype("int64")

# Extract and print the total number of classes
num_classes = len(label_encoder.classes_)
print(f"Total number of classes: {num_classes}")

Total number of classes: 2


In [11]:
# Convert the Pandas DataFrame to a Hugging Face Dataset
hf_dataset = Dataset.from_pandas(dataset)

# Hold out the test set first...
outer_split = hf_dataset.train_test_split(test_size=test_size, seed=seed)
test_dataset = outer_split["test"]

# ...then carve the validation set out of the remaining training samples
inner_split = outer_split["train"].train_test_split(test_size=validation_size, seed=seed)
train_dataset = inner_split["train"]
valid_dataset = inner_split["test"]

# Print the number of training, validation and test samples
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(valid_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

Number of training samples: 3722
Number of validation samples: 414
Number of test samples: 1035


In [12]:
# Preprocess the dataset
def preprocess(examples: Dict[str, Any]) -> Dict[str, Any]:
    """Tokenize the "text" field of a batch of examples.

    Sequences are truncated to `max_length` and padded with the
    "max_length" strategy, using the notebook-level `tokenizer`.

    Args:
        examples: Batch of examples; only the "text" entry is read.

    Returns:
        The output of the tokenizer call for the given texts.
    """
    texts = examples["text"]
    tokenizer_options = dict(truncation=True, padding="max_length", max_length=max_length)
    return tokenizer(texts, **tokenizer_options)

# Columns removed during tokenization ("text" and "labels" are kept)
columns_to_drop = ('Unnamed: 0', 'label', 'label_num')

# Tokenize the train, validation and test splits with identical settings
tokenized_train, tokenized_valid, tokenized_test = (
    split.map(preprocess, batched=True, remove_columns=list(columns_to_drop))
    for split in (train_dataset, valid_dataset, test_dataset)
)

# Display the sequence length
print(f"Sequence length: {len(tokenized_train[0]['input_ids'])}")

Map:   0%|          | 0/3722 [00:00<?, ? examples/s]

Map:   0%|          | 0/414 [00:00<?, ? examples/s]

Map:   0%|          | 0/1035 [00:00<?, ? examples/s]

Sequence length: 64


In [13]:
# Pick a random tokenized training sample and decode its token ids back to text
sample_row = random.choice(tokenized_train)
print(tokenizer.decode(sample_row['input_ids']))

[CLS] subject : calpine daily and monthly nominations > > ricky a. archer fuel supply 700 louisiana, suite 2700 houston, texas 77002 713 - 830 - 8659 direct 713 - 830 - 8722 fax - calpine daily gas nomination 1. doc - calpine monthly gas nomination [SEP]


### Building the model

In [14]:
# Load the pretrained checkpoint with a classification head sized for `num_classes` labels
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=num_classes)

# Move the model to the detected device
model = model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# Display the model (its module tree is rendered as the cell output)
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [16]:
# Count the elements of every parameter tensor and print the total
parameter_sizes = [parameter.numel() for parameter in model.parameters()]
total_params = sum(parameter_sizes)
print(f"Total number of parameters in the model: {total_params}")

Total number of parameters in the model: 109483778


### Training the model

In [17]:
# Load the "accuracy" metric through `evaluate.load`; it is used by `compute_metrics`
# during training and again for the test-set evaluation
accuracy_metric = load("accuracy")

# Define the compute_metrics function
def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute the accuracy for a batch of evaluation predictions.

    Args:
        eval_pred: Object exposing `predictions` (the logits, or a tuple
            whose first element holds the logits) and `label_ids`.

    Returns:
        The metric values reported by `accuracy_metric`, each cast to float.

    Raises:
        AssertionError: If the metric computation returns None.
    """
    raw_predictions = eval_pred.predictions
    references = eval_pred.label_ids

    # When the predictions arrive as a tuple, the logits are its first element
    scores = raw_predictions[0] if isinstance(raw_predictions, tuple) else raw_predictions

    # The predicted class is the index of the highest logit
    predicted_ids = np.argmax(scores, axis=-1)
    result = accuracy_metric.compute(predictions=predicted_ids, references=references)

    # Safety check
    assert result is not None, "Metrics computation failed."

    # Convert all metric values to float
    return {name: float(value) for name, value in result.items()}

In [None]:
# Mixed precision settings: pinned memory and bf16 are only enabled when running on CUDA
cuda_available = torch.cuda.is_available()
device_is_cuda = "cuda" in str(device).lower()
use_cuda = cuda_available and device_is_cuda
use_pin_memory = use_cuda
bf16 = bool(use_cuda and torch.cuda.is_bf16_supported())

# Define the training arguments
training_args = TrainingArguments(
    # Checkpointing: evaluate and save once per epoch, keep at most two checkpoints
    output_dir="./checkpoints/spam_mails_classifier",
    eval_strategy="epoch",
    save_strategy="epoch",
    # Optimization
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    # Logging
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=50,
    save_total_limit=2,
    # Restore the checkpoint with the highest validation accuracy when training ends
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
    report_to="none",
    # Hardware-dependent settings
    dataloader_pin_memory=use_pin_memory,
    bf16=bf16,
)

In [19]:
# Build the trainer: fit on the tokenized training split, evaluate on the validation split
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
    compute_metrics=compute_metrics,
)

# Run the fine-tuning loop
trainer_output = trainer.train()

# Print the training summary returned by the trainer
print(trainer_output)

Epoch,Training Loss,Validation Loss,Accuracy
1,0.0803,0.087466,0.980676
2,0.0484,0.086129,0.987923
3,0.0001,0.129474,0.978261
4,0.0001,0.087549,0.990338
5,0.0,0.094123,0.987923
6,0.0,0.096318,0.987923
7,0.0,0.098471,0.987923
8,0.0,0.100189,0.987923
9,0.0,0.101088,0.987923
10,0.0,0.10127,0.987923


TrainOutput(global_step=2330, training_loss=0.018922314213118175, metrics={'train_runtime': 1086.1842, 'train_samples_per_second': 34.267, 'train_steps_per_second': 2.145, 'total_flos': 1224124185062400.0, 'train_loss': 0.018922314213118175, 'epoch': 10.0})


### Save the model

In [None]:
# Save the trained model and reload it from disk (optionally 4-bit quantized on CUDA)
if save_trained_model:
    # Save the fine-tuned model to the destination path
    model.save_pretrained(model_path)

    # Define the quantization configurations of the model (only for CUDA devices)
    quantization_config = None
    if use_cuda:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_quant_type = "nf4",
            bnb_4bit_compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            bnb_4bit_use_double_quant = True
        )

    # Extra keyword arguments used only when the model is reloaded quantized
    reload_kwargs = {}
    if quantization_config is not None:
        reload_kwargs.update(dict(quantization_config=quantization_config, device_map="auto"))

    # Reload the fine-tuned model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        **reload_kwargs
    )

    # Only move the model manually when it was NOT loaded quantized: with
    # device_map="auto" the weights are already placed, and calling `.to()` on a
    # bitsandbytes-quantized model can raise depending on the installed
    # transformers / bitsandbytes versions.
    if quantization_config is None:
        model = model.to(device)

### Evaluation

In [None]:
# Release cached GPU memory before evaluation (only relevant when CUDA is available)
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:

# Set the model to evaluation mode; the trailing ";" suppresses the cell output
model.eval();

In [35]:
# Evaluate the model on the test set
eval_batch_size = training_args.per_device_eval_batch_size

# Columns of the tokenized dataset that must not be passed to the model
non_input_columns = ("text", "labels")

# Predicted class ids, collected batch by batch
predictions = []

# No gradients are needed for evaluation
with torch.no_grad():
    # Iterate over the test dataset in batches
    for start in range(0, len(tokenized_test), eval_batch_size):
        # Prepare the batch: turn every model input column into a tensor on the target device
        batch = tokenized_test[start : start + eval_batch_size]
        inputs = {
            name: torch.tensor(values).to(device)
            for name, values in batch.items()
            if name not in non_input_columns
        }

        # Forward pass and prediction extraction (index of the highest logit)
        batch_logits = model(**inputs).logits
        batch_preds = batch_logits.argmax(dim=-1)

        # Append the predictions to the list
        predictions.extend(batch_preds.cpu().numpy())

In [36]:
# Compute the accuracy of the test-set predictions against the reference labels
test_labels = tokenized_test["labels"]
eval_metrics = accuracy_metric.compute(predictions=predictions, references=test_labels)

# Display the test accuracy
assert eval_metrics is not None, "Evaluation metrics are not available."
print(f"Test Accuracy: {eval_metrics['accuracy']:.4f}")

Test Accuracy: 0.9816


### Inference

In [None]:
# Inference samples: hand-written emails used for a qualitative check of the model.
# The trailing comment on each entry is the author's annotated class.
# NOTE(review): samples 1 and 3 ask the reader to click a link under time pressure
# yet are annotated "ham" — confirm that these are the intended labels.
inference_inputs = [
    """
    Subject: urgent action required ; account suspension notice
    This is a final warning regarding your account ending in #19872.
    Please override your account settings to avoid deactivation by clicking the secure link provided. Failure to act within 24 hours will result in the suspension of services.
    """, # ham
    """
    Subject: exclusive investment opportunity ; guaranteed profits
    Dear Customer,
    We’ve identified a high-yield opportunity in cryptocurrency trading. To override your financial status, deposit $500 to start earning 300% daily profits. This offer is available for a limited time only. Act now!
    """, # spam
    """
    Subject: overdue payment ; meter #892134
    Please note that your payment for account #892134 is overdue.
    To avoid service interruption, override the pending charges by clicking here and submitting your details. Our records show this must be resolved within 12 hours.
    """, # ham
    """
    Subject: quarterly reporting update ; meter # : 772839
    As part of our quarterly review, please override the system to include the corrected readings from meter #772839. Forward the updated numbers to the finance team before 5 PM today for accurate reporting.
    """, # ham
    """
    Subject: project progress review ; data consolidation
    Hi Team,
    This is a follow-up to Monday’s meeting regarding data consolidation for the project. Kindly override any outdated entries in the shared dashboard with the updated metrics shared earlier. Let’s finalize by EOD for the client review.
    """ # ham
]

# Tokenize all inference samples with the same truncation and padding settings used for training
encoded_inputs = tokenizer(
    inference_inputs,
    truncation=True,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)

# Move the encoded tensors to the target device
inputs = encoded_inputs.to(device)

In [38]:
# Perform inference: only the forward pass needs the no-grad context
with torch.no_grad():
    outputs = model(**inputs)

# Extract the logits
logits = outputs.logits

# Class probabilities (softmax over the logits) and predicted class ids (highest logit)
probabilities = logits.softmax(dim=-1).cpu().numpy()
predictions = logits.argmax(dim=-1).cpu().numpy()

# Convert the predicted class ids back to the original label names
predicted_categories = label_encoder.inverse_transform(predictions)

In [39]:
# Pair every input with its predicted label and class probabilities
inference_results = zip(inference_inputs, predicted_categories, probabilities)

# Display the predictions together with the probability of the predicted class
for i, (input_text, category, probs) in enumerate(inference_results):
    print(f"Sample {i + 1} --> Input: {input_text.strip()} | Predicted label: {category} | Probability: {np.max(probs)}")

Sample 1 --> Input: Subject: urgent action required ; account suspension notice
    Please override your account settings to avoid deactivation by clicking the secure link provided. Failure to act within 24 hours will result in the suspension of services. | Predicted label: ham | Probability: 0.999392032623291
Sample 2 --> Input: Subject: exclusive investment opportunity ; guaranteed profits
    Dear Customer,
    We’ve identified a high-yield opportunity in cryptocurrency trading. To override your financial status, deposit $500 to start earning 300% daily profits. This offer is available for a limited time only. Act now! | Predicted label: spam | Probability: 0.9999643564224243
Sample 3 --> Input: Subject: overdue payment ; meter #892134
    Please note that your payment for account #892134 is overdue.
    To avoid service interruption, override the pending charges by clicking here and submitting your details. Our records show this must be resolved within 12 hours. | Predicted label: 