# Spam Classification with BERT + LoRA
Parameter-efficient fine-tuning of BERT for SMS spam detection.

In [None]:
!pip install torchmetrics

In [2]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np
import math

from tabulate import tabulate
from tqdm import trange
import random
from torchmetrics.classification import Recall, Accuracy, AUROC, Precision

In [None]:
!wget 'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
!unzip -o smsspamcollection.zip

In [None]:
!unzip -o smsspamcollection.zip

In [None]:
!head -10 SMSSpamCollection

In [None]:
!wget 'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
!unzip -o smsspamcollection.zip

In [8]:
file_path = 'SMSSpamCollection'
df = pd.DataFrame({'label':int(), 'text':str()}, index = [])
with open(file_path) as f:
    for line in f.readlines():
        split = line.split('\t')
        df = pd.concat([
                df,
                pd.DataFrame.from_dict({
                    'label': [1 if split[0] == 'spam' else 0],
                    'text': [split[1]]
                })
            ],
            ignore_index=True
        )
df.head()

Unnamed: 0,label,text
0,0,"Go until jurong point, crazy.. Available only ..."
1,0,Ok lar... Joking wif u oni...\n
2,1,Free entry in 2 a wkly comp to win FA Cup fina...
3,0,U dun say so early hor... U c already then say...
4,0,"Nah I don't think he goes to usf, he lives aro..."


In [9]:
text = df.text.values
labels = df.label.values

In [None]:
# Initialize BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [11]:
def print_rand_sentence():
    '''Displays the tokens and respective IDs of a random text sample'''
    index = random.randint(0, len(text)-1)
    table = np.array([tokenizer.tokenize(text[index]),
                    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text[index]))]).T
    print(tabulate(table,
                 headers = ['Tokens', 'Token IDs'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence()

╒═══════════╤═════════════╕
│ Tokens    │   Token IDs │
╞═══════════╪═════════════╡
│ do        │        2079 │
├───────────┼─────────────┤
│ you       │        2017 │
├───────────┼─────────────┤
│ realize   │        5382 │
├───────────┼─────────────┤
│ that      │        2008 │
├───────────┼─────────────┤
│ in        │        1999 │
├───────────┼─────────────┤
│ about     │        2055 │
├───────────┼─────────────┤
│ 40        │        2871 │
├───────────┼─────────────┤
│ years     │        2086 │
├───────────┼─────────────┤
│ ,         │        1010 │
├───────────┼─────────────┤
│ we        │        2057 │
├───────────┼─────────────┤
│ '         │        1005 │
├───────────┼─────────────┤
│ ll        │        2222 │
├───────────┼─────────────┤
│ have      │        2031 │
├───────────┼─────────────┤
│ thousands │        5190 │
├───────────┼─────────────┤
│ of        │        1997 │
├───────────┼─────────────┤
│ old       │        2214 │
├───────────┼─────────────┤
│ ladies    │       

In [None]:
token_id = []
attention_masks = []

def preprocessing(input_text, tokenizer):
    '''
    Tokenize a text string and return input IDs + attention mask (padded to max length 128).
    '''

  return tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=128,
        pad_to_max_length=True,
        return_attention_mask=True,
        truncation=True
    )


for sample in text:
    encoding_dict = preprocessing(sample, tokenizer)
    token_id.append(encoding_dict['input_ids'])
    attention_masks.append(encoding_dict['attention_mask'])


token_id = torch.tensor(token_id)
attention_masks = torch.tensor(attention_masks)
labels = torch.tensor(labels)

In [13]:
def print_rand_sentence_encoding():
    '''Displays tokens, token IDs and attention mask of a random text sample'''
    index = random.randint(0, len(text) - 1)
    tokens = tokenizer.tokenize(tokenizer.decode(token_id[index]))
    token_ids = [i.numpy() for i in token_id[index]]
    attention = [i.numpy() for i in attention_masks[index]]
    table = np.array([tokens, token_ids, attention]).T
    print(
        tabulate(
            table,
            headers = ['Tokens', 'Token IDs', 'Attention Mask'],
            tablefmt = 'fancy_grid')
    )

print_rand_sentence_encoding()

╒══════════╤═════════════╤══════════════════╕
│ Tokens   │   Token IDs │   Attention Mask │
╞══════════╪═════════════╪══════════════════╡
│ [CLS]    │         101 │                1 │
├──────────┼─────────────┼──────────────────┤
│ not      │        2025 │                1 │
├──────────┼─────────────┼──────────────────┤
│ yet      │        2664 │                1 │
├──────────┼─────────────┼──────────────────┤
│ chi      │        9610 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ##kk     │       19658 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ##u      │        2226 │                1 │
├──────────┼─────────────┼──────────────────┤
│ .        │        1012 │                1 │
├──────────┼─────────────┼──────────────────┤
│ .        │        1012 │                1 │
├──────────┼─────────────┼──────────────────┤
│ wat      │       28194 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ab       │       11113 │        

In [14]:
val_ratio = 0.2
# Batch size for training
batch_size = 16

# Stratified split into train/validation indices
train_idx, val_idx = train_test_split(np.arange(len(labels)), test_size=val_ratio, stratify=labels, random_state=42)

# Build TensorDatasets for train/validation
train_set = TensorDataset(token_id[train_idx], attention_masks[train_idx], labels[train_idx])

val_set = TensorDataset(token_id[val_idx], attention_masks[val_idx], labels[val_idx])

# Prepare DataLoader
train_dataloader = DataLoader(train_set, sampler=RandomSampler(train_set), batch_size=batch_size)

validation_dataloader = DataLoader(val_set, sampler=SequentialSampler(val_set), batch_size=batch_size)

Define the LoRA specific layers.

In [15]:
# Define a LoRA Layer which has A, B and alpha parameters
class LoRALayer(torch.nn.Module):
  def __init__(self, in_dim, out_dim, rank, alpha):
    super().__init__()

    self.A = torch.nn.Parameter(torch.empty(in_dim, rank))

    torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))

    self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
    self.alpha = alpha
    self.rank = rank

  def forward(self, x):
    x = x.matmul(self.A).matmul(self.B)
    return x * (self.alpha / self.rank)

# Linear layer wrapped with a LoRA adapter (residual add)
class LoRALinear(torch.nn.Module):
  def __init__(self, linear, rank, alpha):
    super().__init__()
    self.linear = linear
    self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

  def forward(self, x):
    # Base linear + LoRA adaptation
    return self.linear(x) + self.lora(x)

def lora_linear_replace(model, rank, alpha):
  # Recursively replace nn.Linear with LoRALinear
  for name, module in model.named_children():
    if isinstance(module, torch.nn.Linear):
      setattr(model, name, LoRALinear(module, rank, alpha))
    else:
      # Continue traversal into submodules
      lora_linear_replace(module, rank, alpha)

### Load specific versions of the model

In [16]:
# Load BERT for binary classification (suppress extra outputs)
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2, output_attentions=False, output_hidden_states=False
)

# Turn off all gradients of the model to start
for param in model.parameters():
    param.requires_grad = False

# Toggle LoRA or full/partial fine-tuning
use_lora  = True
# If this is False, turn off gradients
fine_tune = False
# Track the number of trainable parameters
total_parameters = 0

if use_lora:
  # Replace Linear layers with LoRA-augmented versions
  lora_linear_replace(model, rank=8, alpha=8)
  # Count LoRA trainable parameters
  for param in model.parameters():
        if param.requires_grad:
            total_parameters += param.numel()
else:
  # If fine_tune is off, turn off gradients for all layers other than classifier
  if not fine_tune:
    # Train only the classifier layer
    for param in model.classifier.parameters():
            param.requires_grad = True
            total_parameters += param.numel()
  else:
    # Fine-tune the entire model
    for param in model.parameters():
            param.requires_grad = True
            total_parameters += param.numel()

print(total_parameters)

if use_lora:
  assert(total_parameters == 1345552)
else:
  if fine_tune:
    assert(total_parameters == 109483778)
  else:
    assert(total_parameters == 1538)




Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


1345552


### Set the model to the right device

In [17]:
import platform

# Select available device
if platform.system() == 'Darwin':
    device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
elif platform.system() == 'Linux':
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
else:
    device = torch.device('cpu')
print(device)

cuda


In [18]:
_ = model.to(device)

# Number of training epochs
epochs = 2

In [19]:
# AdamW optimizer
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr = 5e-5,
    eps = 1e-08
)

### Train the model

In [20]:
# Define evaluation metrics
accuracy = Accuracy(task="binary")
recall = Recall(task="binary")
precision = Precision(task="binary")
auroc = AUROC(task="binary")

In [21]:
model.device

device(type='cuda', index=0)

In [22]:
# Training and validation loop
import tqdm
for _ in tqdm.tqdm(trange(epochs, desc = 'Epoch')):

    # ----- Training -----
    model.train()

    # Tracking variables
    tr_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0

    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        optimizer.zero_grad()

        # Forward pass
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss = outputs.loss

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update tracking variables
        tr_loss += loss.item()
        nb_tr_examples += b_labels.size(0)
        nb_tr_steps += 1

    # ----- Validation -----

    # Set model to evaluation mode
    model.eval()

    # Tracking variables
    val_accuracy = []
    val_precision = []
    val_recall = []
    val_auroc = []

    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
          # Forward pass
            eval_output = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

        # Calculate validation metrics
        labels = b_labels.cpu()
        predicted_labels = torch.argmax(eval_output.logits, dim=1).cpu()

        val_accuracy.append(accuracy(predicted_labels, labels).item())
        val_recall.append(recall(predicted_labels, labels).item())
        val_precision.append(precision(predicted_labels, labels).item())
        val_auroc.append(auroc(torch.softmax(eval_output.logits, dim=1)[:, 1].cpu(), labels).item())


    print('\n\t - Train loss: {:.4f}'.format(tr_loss / nb_tr_steps))
    print('\t - Validation Accuracy: {:.4f}'.format(sum(val_accuracy)/len(val_accuracy)))
    print('\t - Validation Precision: {:.4f}'.format(sum(val_precision)/len(val_precision)))
    print('\t - Validation Recall: {:.4f}'.format(sum(val_recall)/len(val_recall)))
    print('\t - Validation AUROC: {:.4f}\n'.format(sum(val_auroc)/len(val_auroc)))

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch:  50%|█████     | 1/2 [00:20<00:20, 20.68s/it]


	 - Train loss: 0.1959
	 - Validation Accuracy: 0.9875
	 - Validation Precision: 0.8560
	 - Validation Recall: 0.8643
	 - Validation AUROC: 0.9050




Epoch: 100%|██████████| 2/2 [00:41<00:00, 20.55s/it]
100%|██████████| 2/2 [00:41<00:00, 20.55s/it]


	 - Train loss: 0.0260
	 - Validation Accuracy: 0.9920
	 - Validation Precision: 0.8702
	 - Validation Recall: 0.8643
	 - Validation AUROC: 0.9063






### Test on a specific sentence, see the outcome

In [24]:
new_sentence = 'WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.'

# Tokenize and convert to tensors
test_ids = []
test_attention_mask = []

# Apply the tokenizer
encoding = preprocessing(new_sentence, tokenizer)

# Convert input_ids and attention_mask to tensors
encoding['input_ids'] = torch.tensor(encoding['input_ids']).unsqueeze(0)
encoding['attention_mask'] = torch.tensor(encoding['attention_mask']).unsqueeze(0)

# Extract IDs and Attention Mask
test_ids.append(encoding['input_ids'])
test_attention_mask.append(encoding['attention_mask'])
test_ids = torch.cat(test_ids, dim = 0)
test_attention_mask = torch.cat(test_attention_mask, dim = 0)

# Forward pass, calculate logit predictions
with torch.no_grad():
    output = model(test_ids.to(device), token_type_ids = None, attention_mask = test_attention_mask.to(device))

prediction = 'Spam' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'Ham'

print('Input Sentence: ', new_sentence)
print('Predicted Class: ', prediction)

Input Sentence:  WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
Predicted Class:  Spam


## Experiment: Full Fine-Tuning vs Frozen BERT vs LoRA

We trained three configurations on the SMS Spam dataset to compare adaptation strategies:

1) **Frozen BERT (feature extractor; train classifier only)**
- Val Acc: **0.8657**
- Precision: **0.0000**
- Recall: **0.0000**
- AUROC: **0.8466**
- Misclassified the spam example as **Ham**.
- Takeaway: features stay generic; poor recall for spam.

2) **Full Fine-Tuning (BERT + classifier)**
- Val Acc: **0.9893**
- Precision: **0.8476**
- Recall: **0.8210**
- AUROC: **0.9089**
- Correctly classified the spam example.
- Takeaway: task-specific adaptation boosts all metrics.

3) **LoRA Fine-Tuning (parameter-efficient)**
- Val Acc: **0.9920**
- Precision: **0.8702**
- Recall: **0.8543**
- AUROC: **0.9063**
- Correctly classified the spam example.
- Takeaway: near full-FT quality while updating far fewer params.

### Summary Table

| Config            | Accuracy | Precision | Recall | AUROC  |
|-------------------|---------:|----------:|-------:|:------:|
| Frozen BERT       |  0.8657  |   0.0000  | 0.0000 | 0.8466 |
| Full Fine-Tuning  |  0.9893  |   0.8476  | 0.8210 | 0.9089 |
| LoRA              |  0.9920  |   0.8702  | 0.8543 | 0.9063 |

**Conclusion.** Freezing BERT underperforms due to lack of task adaptation. Full fine-tuning yields strong improvements. LoRA matches (or slightly exceeds) full fine-tuning on this task with a fraction of trainable parameters, making it a practical default.
