#### Environment Setup

In [1]:
import os

workspace_dir = '/content/spam-detection'
branch = 'feature/bilstm-training'
current_dir = os.getcwd()
if not os.path.exists(workspace_dir) and current_dir != workspace_dir:
    !git clone https://github.com/RationalEar/spam-detection.git
    os.chdir(workspace_dir)
    !git checkout $branch
    !ls -al
    !pip install -q transformers==4.48.0 scikit-learn pandas numpy
    !pip install -q torch --index-url https://download.pytorch.org/whl/cu126
else:
    os.chdir(workspace_dir)
    !git pull origin $branch

Cloning into 'spam-detection'...
remote: Enumerating objects: 116, done.[K
remote: Counting objects: 100% (116/116), done.[K
remote: Compressing objects: 100% (78/78), done.[K
remote: Total 116 (delta 43), reused 101 (delta 29), pack-reused 0 (from 0)[K
Receiving objects: 100% (116/116), 135.78 KiB | 739.00 KiB/s, done.
Resolving deltas: 100% (43/43), done.
Branch 'feature/cnn-model' set up to track remote branch 'feature/cnn-model' from 'origin'.
Switched to a new branch 'feature/cnn-model'
total 48
drwxr-xr-x 10 root root 4096 Jun  8 16:26 .
drwxr-xr-x  1 root root 4096 Jun  8 16:26 ..
drwxr-xr-x  2 root root 4096 Jun  8 16:26 docs
drwxr-xr-x  8 root root 4096 Jun  8 16:26 .git
-rw-r--r--  1 root root   30 Jun  8 16:26 .gitignore
drwxr-xr-x  2 root root 4096 Jun  8 16:26 integrations
drwxr-xr-x  2 root root 4096 Jun  8 16:26 metrics
drwxr-xr-x  2 root root 4096 Jun  8 16:26 models
drwxr-xr-x  2 root root 4096 Jun  8 16:26 preprocess
-rw-r--r--  1 root root 2441 Jun  8 16:26 requi

In [2]:
# Colab-only step: attach Google Drive to the VM's filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
import pandas as pd
from utils.functions import set_seed
from utils.constants import DATA_PATH, MODEL_SAVE_PATH

# Seed value handed to the project's set_seed helper before any data or model work.
SEED = 42
set_seed(SEED)

DATA_PATH: /content/drive/MyDrive/Projects/spam-detection-data
WORKSPACE_DIR: /content/spam-detection


#### Load the preprocessed data

In [4]:
# Load the train/validation/test splits.
# pd.read_pickle takes a path and opens/closes the file itself; passing an
# `open(...)` handle (as before) left three file objects unclosed.
# NOTE(review): unpickling can execute arbitrary code — only load files you produced yourself.
processed_dir = f"{DATA_PATH}/data/processed"
train_df = pd.read_pickle(f"{processed_dir}/train.pkl")
val_df = pd.read_pickle(f"{processed_dir}/val.pkl")
test_df = pd.read_pickle(f"{processed_dir}/test.pkl")
train_df.head()

Unnamed: 0,subject,text,label,source,sender_hash,reply_to_hash,date
0,Personal Finance: Resolutions You Can Keep,personal finance resolutions you keep motley f...,0,hard_ham,bb339a04eb35de16f6386c5ca0d57fd88b20916663bd84...,3d0448fc6a4d02914e3adf6812ede7310a82838909afac...,"Wed, 02 Jan 2002 13:55:00 -0500"
1,Please help a newbie compile mplayer :-),please help newbie compile mplayer hello i jus...,0,easy_ham,2f890790e67625bdfd8e3c7cca018bf511c2cbca431554...,492368811b79453838d5e7e3692f607adee8d7e71ddd2e...,"Thu, 31 Jan 2002 22:44:14 -0700"
2,Re: Please help a newbie compile mplayer :-),re please help newbie compile mplayer make sur...,0,easy_ham,d83f5738686fa88436e12f3710c15b270666e3061ba627...,492368811b79453838d5e7e3692f607adee8d7e71ddd2e...,"Fri, 01 Feb 2002 00:53:41 -0600"
3,Re: Please help a newbie compile mplayer :-),re please help newbie compile mplayer lance wr...,0,easy_ham,2f890790e67625bdfd8e3c7cca018bf511c2cbca431554...,492368811b79453838d5e7e3692f607adee8d7e71ddd2e...,"Fri, 01 Feb 2002 02:01:44 -0700"
4,Re: Please help a newbie compile mplayer :-),re please help newbie compile mplayer once upo...,0,easy_ham,f9579e33dbc2d625e2ba35d53c611b8c3bd09cca4c7760...,492368811b79453838d5e7e3692f607adee8d7e71ddd2e...,"Fri, 01 Feb 2002 10:29:23 +0100"


In [5]:
# Vocabulary is derived from the training split's text column only.
from utils.functions import build_vocab

train_texts = train_df['text']
word2idx, idx2word = build_vocab(train_texts)

In [6]:
from preprocess.data_loader import load_glove_embeddings

# Embedding width and maximum sequence length used by the cells below.
embedding_dim = 300
max_len = 200

# Load GloVe embeddings.
# The file name is derived from embedding_dim so the vector file and the
# dimension passed to the loader cannot drift apart.
GLOVE_PATH = os.path.join(DATA_PATH, 'data/raw/glove.6B', f'glove.6B.{embedding_dim}d.txt')
if not os.path.isfile(GLOVE_PATH):
    # Fail early with a clear message instead of an error from inside the loader.
    raise FileNotFoundError(f"GloVe vectors not found at {GLOVE_PATH}")
pretrained_embeddings = load_glove_embeddings(GLOVE_PATH, word2idx, embedding_dim)

#### Train the BiLSTM model

In [None]:
from training.trainer import train_model

# Keyword settings for this run, gathered in one place before the clock starts.
train_kwargs = dict(
    embedding_dim=embedding_dim,
    pretrained_embeddings=pretrained_embeddings,
    model_save_path=MODEL_SAVE_PATH,
    max_len=max_len,
    evaluate=True,
)

# Timestamps bracket only the train_model call; they are used afterwards to report duration.
start_time = pd.Timestamp.now()
model = train_model('bilstm', train_df, val_df, test_df, **train_kwargs)
end_time = pd.Timestamp.now()

In [None]:
# Elapsed time between the two timestamps recorded around the training call.
training_time = end_time - start_time
print("Training completed in: {}".format(training_time))

In [None]:
# Display metrics in a nice DataFrame format
# NOTE(review): the metrics helper lives in `metrics.cnn_metrics` but is used here for the
# BiLSTM — confirm it is model-agnostic.
from metrics.cnn_metrics import compute_metrics
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Prepare test data and make predictions
from utils.functions import encode
import torch

# Encode every test text with the training vocabulary; labels become float32.
X_test_tensor = torch.tensor([encode(t, word2idx, max_len) for t in test_df['text']])
y_test_tensor = torch.tensor(test_df['label'].values, dtype=torch.float32)

# Get device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move data to device
X_test_tensor = X_test_tensor.to(device)
y_test_tensor = y_test_tensor.to(device)

# Put the model on the same device as the inputs and switch it to inference mode so
# train-time behaviour (e.g. dropout) is disabled while predicting.
# Assumes `model` is a torch.nn.Module as returned by train_model — TODO confirm.
model.to(device)
model.eval()

# Get predictions (no gradients needed); 0.5 is the decision threshold on the model output.
with torch.no_grad():
    y_pred_probs = model(X_test_tensor)
    y_pred = (y_pred_probs > 0.5).float()

# Compute metrics using our enhanced function that returns a DataFrame
metrics_df = compute_metrics(y_test_tensor.cpu().numpy(), y_pred.cpu().numpy(), y_pred_probs.cpu().numpy())

# Display metrics
metrics_df

In [None]:
# Visualize the confusion matrix
# The counts are read from the 'confusion_matrix' entry of metrics_df.attrs
# (a mapping with 'tn', 'fp', 'fn', 'tp' keys).
confusion_matrix = metrics_df.attrs['confusion_matrix']
tn, fp = confusion_matrix['tn'], confusion_matrix['fp']
fn, tp = confusion_matrix['fn'], confusion_matrix['tp']

# Create a visual confusion matrix: rows are actual classes, columns are predicted classes.
cm_values = np.array([
    [tn, fp],
    [fn, tp]
])

plt.figure(figsize=(8, 6))
sns.heatmap(cm_values, annot=True, fmt='d', cmap='Blues',
           xticklabels=['Predicted Negative (Ham)', 'Predicted Positive (Spam)'],
           yticklabels=['Actual Negative (Ham)', 'Actual Positive (Spam)'])
plt.title('Confusion Matrix')
plt.ylabel('Actual Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

# Calculate and display key performance indicators.
# Every ratio falls back to 0 when its denominator is 0 instead of raising ZeroDivisionError.
total = tn + fp + fn + tp
accuracy = (tp + tn) / total if total > 0 else 0
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
# The zero guard must be real code: previously it sat inside the f-string and was printed as text.
f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

print("Model Performance Summary:")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall (Spam Catch Rate): {recall:.4f}")
print(f"F1-Score: {f1_score:.4f}")