<a href="https://colab.research.google.com/github/Mohamed-Devp/spam_detection_with_lstms/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview
This notebook demonstrates the training of a deep learning model, a character-level convolutional neural network (CharCNN) built with PyTorch, capable of detecting spam messages with high accuracy.

# Data loading
The dataset used in this notebook is the [SMS Spam Collection Dataset](https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset/data) from Kaggle.

## Import necessary modules

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import re
from collections import Counter

import nltk
from nltk.corpus import stopwords

# Use seaborn's 'whitegrid' style for the plots in this notebook
sns.set_style('whitegrid')
# Fetch the NLTK stop-word corpus that the tokenization step reads via stopwords.words('english')
nltk.download('stopwords')

## Load and read the dataset

In [None]:
import kagglehub # kaggle api

# Download the dataset; `path` is used further down as the directory that contains spam.csv
path = kagglehub.dataset_download("uciml/sms-spam-collection-dataset")
# List what is stored at the download location (IPython shell magic)
%ls "{path}"

In [9]:
# Read the raw CSV, keep only the label / text columns and give them readable names
df = (
    pd.read_csv(f"{path}/spam.csv", encoding = 'latin-1')
    [['v1', 'v2']]
    .rename(columns = {'v1': 'Label', 'v2': 'Message'})
)

df.head()

Unnamed: 0,Label,Message
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."


In [None]:
# Remove incomplete rows, then repeated rows, then renumber the index from zero
df = df.dropna()
df = df.drop_duplicates()
df = df.reset_index(drop = True)

# Sanity check: report what is left after cleaning
missing_per_column = df.isna().sum()
num_duplicates = df.duplicated().sum()

print(f"Number of Missing values:\n{missing_per_column}")
print(f"Number of duplicate rows:{num_duplicates}")

# Data exploration

In [None]:
# Print a summary of the frame (columns, non-null counts, dtypes)
df.info()

In [None]:
# Plot the label distribution.
# `data` is passed by keyword so the call does not depend on the position of
# the `data` parameter in seaborn's signature.
ax = sns.countplot(data = df, x = 'Label')

# Label the figure so it can be read on its own
ax.set_title('Label distribution')
ax.set_xlabel('Label')
ax.set_ylabel('Number of messages')
plt.show()

# Data preprocessing

## Tokenization

In [None]:
stop_words = set(stopwords.words('english'))

# Maximum number of characters kept per message; longer messages are truncated
max_len = 248

sequences = []
labels = []

# Walk the two columns in lockstep instead of calling df.iloc[row] twice per
# message, which materialises a full row Series on every lookup
for sentence, label in zip(df['Message'], df['Label']):
  # Remove punctuation & special characters
  tokenized = re.sub(r"[^\w\s]", " ", sentence.lower()).split()

  # Remove stop words
  tokenized = [token for token in tokenized if token not in stop_words]
  if not tokenized:
    # Nothing left after filtering: the message is skipped entirely
    continue

  # Flatten the remaining tokens into one character stream (word boundaries
  # are not kept) and truncate it; slicing is a no-op for short messages
  chars = [char for token in tokenized for char in token]
  sequences.append(chars[:max_len])
  labels.append(1 if label == 'spam' else 0)

print(f"Sequence: {sequences[0][:16]} - Label: {labels[0]}")

## Convert tokens into indices

In [43]:
# Build the vocabulary: ids 0 and 1 are reserved for padding / unknown
# characters, every distinct character gets the next id in order of first
# appearance (starting at 2)
char_to_index = {}
for sequence in sequences:
  for char in sequence:
    if char not in char_to_index:
      char_to_index[char] = len(char_to_index) + 2
char_to_index['<PAD>'] = 0
char_to_index['<OOV>'] = 1

pad_id = char_to_index['<PAD>']
oov_id = char_to_index['<OOV>']

targets = torch.tensor(labels, dtype = torch.float32)

# Start from an all-padding matrix and overwrite the leading positions of
# each row with the ids of that message's characters
inputs = torch.full((len(sequences), max_len), pad_id, dtype = torch.int32)
for row, sequence in enumerate(sequences):
  ids = [char_to_index.get(char, oov_id) for char in sequence]
  inputs[row, :len(ids)] = torch.tensor(ids, dtype = torch.int32)

## Split the data

In [44]:
def split(tensor, perc = .8):
  """Cut `tensor` into a leading and a trailing part.

  The first `int(len(tensor) * perc)` items form the leading part and the
  remaining items form the trailing part. Order is preserved; nothing is
  shuffled.

  Returns:
    A `(leading, trailing)` pair of slices of `tensor`.
  """
  cut = int(perc * len(tensor))
  head, tail = tensor[:cut], tensor[cut:]
  return head, tail

# First 80% of the samples for training; the remaining 20% is halved into
# a validation set and a test set
X_train, X_holdout = split(inputs)
y_train, y_holdout = split(targets)

X_valid, X_test = split(X_holdout, perc = .5)
y_valid, y_test = split(y_holdout, perc = .5)

# Build and Train the classifier

## Define the model architecture




In [45]:
class CharCNN(nn.Module):
  """Character-level CNN producing one binary-classification score per sample.

  Pipeline: embedding -> 1D convolution + ReLU -> max pooling over the whole
  sequence axis -> linear layer with a single output.

  The module also owns its loss (`criterion`) and optimizer (`optimizer`), so
  one training step is a single call to `backward`.

  Args:
    vocab_size: number of rows in the embedding table.
    embedding_dim: size of each character embedding.
    num_kernels: number of convolution filters.
    kernel_size: width of each convolution filter (default 5).
    pos_weight: weight applied to the positive class in the loss (default 5.0).
    lr: learning rate of the SGD optimizer (default 0.01).
  """

  def __init__(
      self,
      vocab_size,
      embedding_dim,
      num_kernels,
      kernel_size = 5,
      pos_weight = 5.0,
      lr = 0.01
  ):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, embedding_dim)

    self.conv = nn.Conv1d(
        in_channels=embedding_dim,
        out_channels=num_kernels,
        kernel_size=kernel_size
    )

    self.output = nn.Linear(num_kernels, 1)

    # The loss works on raw logits; `pos_weight` scales the positive-class term
    self.criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(float(pos_weight)))
    self.optimizer = SGD(self.parameters(), lr = lr)

  def forward(self, inputs, probs = False):
    """Score a batch of index sequences.

    Args:
      inputs: integer tensor of character indices, shape (batch, seq_len).
        seq_len must be at least the convolution kernel size.
      probs: when True return sigmoid probabilities, otherwise raw logits.

    Returns:
      Tensor of shape (batch,).
    """
    # Conv1d expects (batch, channels, length): move the embedding axis forward
    embedded = self.embedding(inputs).permute(0, 2, 1)

    conv_out = F.relu(self.conv(embedded))
    # Pool over the full length: keep the strongest activation of each filter
    conv_out = F.max_pool1d(conv_out, kernel_size=conv_out.shape[-1]).squeeze(-1)

    output = self.output(conv_out).squeeze(-1)

    return torch.sigmoid(output) if probs else output

  def backward(self, inputs, targets):
    """Run one optimization step on a batch.

    Args:
      inputs: integer tensor of character indices, shape (batch, seq_len).
      targets: float tensor of 0/1 labels, shape (batch,).

    Returns:
      The batch loss as a Python float.
    """
    self.optimizer.zero_grad()

    outputs = self.forward(inputs)
    loss = self.criterion(outputs, targets)

    loss.backward()
    self.optimizer.step()

    return loss.item()

# Seed the global torch RNG so the weight initialisation below, and random
# draws made afterwards from the same generator, repeat on a fresh
# Restart & Run All
torch.manual_seed(42)

model = CharCNN(
    vocab_size=len(char_to_index),
    embedding_dim=32,
    num_kernels=64
)

# Per-epoch loss history; the training loop appends one value per epoch to each list
losses = {
    'train': [],
    'valid': []
}

## Train the model

In [None]:
# Fit the model to the data
batch_size = 64
num_epochs = 125

for epoch in range(num_epochs):
  epoch_loss = num_batches = 0

  # Shuffle the data before each epoch
  shuffled = torch.randperm(len(X_train))

  for start in range(0, len(X_train), batch_size):
    # Take only this batch's slice of the permutation. Indexing the whole
    # training set with `shuffled` first would copy every sample once per
    # batch. Slicing clamps at the end, so the last batch may be shorter.
    batch_idx = shuffled[start:start + batch_size]

    X_batch = X_train[batch_idx]
    y_batch = y_train[batch_idx]

    epoch_loss += model.backward(X_batch, y_batch)
    num_batches += 1

  # Compute the validation loss (no gradients needed)
  with torch.no_grad():
    outputs = model.forward(X_valid)
    valid_loss = model.criterion(outputs, y_valid)

  # Training loss is the mean of the per-batch losses of this epoch
  losses['train'].append(epoch_loss / num_batches)
  losses['valid'].append(valid_loss.item())

  # Report the first epoch and then every fifth one
  if epoch == 0 or (epoch + 1) % 5 == 0:
    print(f"Epoch: {epoch + 1} - Training: {losses['train'][-1]:.2f} - Validation: {losses['valid'][-1]:.2f}")

## Evaluate the model

In [None]:
# Plot the loss curve: one point per epoch for each history list
ax = sns.lineplot(data = losses['train'], label = 'Training')
sns.lineplot(data = losses['valid'], label = 'Validation', ax = ax)

# Label the figure so it can be read on its own
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs. validation loss')
plt.show()

In [None]:
# Inference only: disable gradient tracking so no autograd graph is built
# for the test-set forward pass
with torch.no_grad():
  probs = model.forward(X_test, probs = True)

# A message is predicted as spam when its probability reaches the threshold
threshold = .5
preds = (probs >= threshold).type(torch.int32)

print(f"Accuracy score: {accuracy_score(y_test, preds):.2f}")
print(f"Recall score: {recall_score(y_test, preds):.2f}")
print(f"Precision score: {precision_score(y_test, preds):.2f}")
print(f"F1 score: {f1_score(y_test, preds):.2f}")

In [None]:
# Confusion matrix of the test-set predictions, drawn as an annotated heatmap
conf_matrix = confusion_matrix(y_test, preds)

class_names = ['ham', 'spam']

fig, ax = plt.subplots(figsize=(6, 4))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    cbar=False,
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title("Confusion Matrix Heatmap")
plt.show()