In [None]:
import os
import pickle
import email_read_util

## Download 2007 TREC Public Spam Corpus
1. Read the "Agreement for use"
   https://plg.uwaterloo.ca/~gvcormac/treccorpus07/

2. Download the 255 MB corpus (trec07p.tgz) and untar it into the 'chapter1/datasets' directory

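3. Check that the paths assigned to 'DATA_DIR' and 'LABELS_FILE' in the next cell exist (they are relative paths, so they are resolved against the notebook's working directory)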

In [None]:
# Locations inside the extracted TREC 2007 corpus. Both are relative paths, so
# they are resolved against the notebook's current working directory.
# Directory whose entries are listed to obtain the email files.
DATA_DIR = 'trec07p/data/'
# Index file read line by line as whitespace-separated "<label> <path>" pairs.
LABELS_FILE = 'trec07p/full/index'
# Fraction of the file listing used for training; the remainder is the test set.
TRAINING_SET_RATIO = 0.7

In [None]:
# Maps an email file name to its class: 1 for ham, 0 otherwise (spam).
labels = {}
# NOTE(review): the two sets below are not referenced anywhere else in the
# visible notebook — confirm whether they are still needed.
spam_words = set()
ham_words = set()

In [None]:
# Read the labels.
# Each non-empty line of the index file is "<label> <path>". The email is keyed
# by the last path component and encoded as 1 for ham, 0 for anything else (spam).
with open(LABELS_FILE) as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line (e.g. a trailing newline at the end of the file) would
            # make the two-way unpacking below raise ValueError.
            continue
        label, key = line.split()
        labels[key.split('/')[-1]] = 1 if label.lower() == 'ham' else 0

In [None]:
# Split corpus into train and test sets.
# os.listdir returns entries in arbitrary order, so the listing is sorted to make
# the split reproducible across runs and machines. Without a stable order, a
# cached 'spam_ngrams.pkl' could have been trained on files that now fall into
# the test set.
filelist = sorted(os.listdir(DATA_DIR))
# Compute the boundary once so both slices are guaranteed to use the same index.
split_index = int(len(filelist) * TRAINING_SET_RATIO)
X_train = filelist[:split_index]
X_test = filelist[split_index:]

In [None]:
import nltk
# Fetch the NLTK 'punkt_tab' resource.
# NOTE(review): presumably required by the tokenizer used inside
# email_read_util — confirm.
nltk.download('punkt_tab')

In [None]:
from collections import defaultdict
from nltk import ngrams

In [None]:
# Parameters
# Number of consecutive stems that make up one n-gram.
NGRAM_SIZE = 2  # Use bigrams (you can change this to 3 for trigrams, etc.)
# An n-gram must be counted at least this many times in spam training emails
# before it can be selected as spam-indicative.
MIN_FREQ = 5    # Minimum frequency for n-grams to be considered
# NOTE: the training cell reuses 'spam_ngrams.pkl' whenever that file exists, so
# delete it after changing either value to force retraining.

In [None]:

# Train the model - find frequent n-grams in spam.
# The selected n-gram set is cached in 'spam_ngrams.pkl' so later runs skip training.
# NOTE: the cache is not keyed on NGRAM_SIZE, MIN_FREQ or the train/test split;
# delete the file after changing any of them, otherwise a stale set is reused.
if not os.path.exists('spam_ngrams.pkl'):
    spam_ngram_counts = defaultdict(int)
    ham_ngram_counts = defaultdict(int)

    for filename in X_train:
        # Files without an entry in the label index cannot be used for training.
        if filename not in labels:
            continue

        path = os.path.join(DATA_DIR, filename)
        stems = email_read_util.load(path)
        if not stems:
            continue

        # Label encoding: 0 = spam, anything else = ham. Every n-gram occurrence
        # is counted under the class of the email it came from.
        label = labels[filename]
        counts = spam_ngram_counts if label == 0 else ham_ngram_counts
        for ng in ngrams(stems, NGRAM_SIZE):
            counts[ng] += 1

    # Keep n-grams seen at least MIN_FREQ times in spam whose ham count is
    # less than half of their spam count.
    spam_indicative_ngrams = {
        ng
        for ng, count in spam_ngram_counts.items()
        if count >= MIN_FREQ and ham_ngram_counts.get(ng, 0) < count / 2
    }

    # Context manager ensures the cache file is flushed and closed.
    with open('spam_ngrams.pkl', 'wb') as cache_file:
        pickle.dump(spam_indicative_ngrams, cache_file)
else:
    # SECURITY NOTE: pickle.load can execute arbitrary code. Only load a cache
    # file that was written by this notebook, never one from an untrusted source.
    with open('spam_ngrams.pkl', 'rb') as cache_file:
        spam_indicative_ngrams = pickle.load(cache_file)


print(f'Found {len(spam_indicative_ngrams)} spam-indicative {NGRAM_SIZE}-grams')


In [None]:

# Evaluate the classifier on the held-out emails.
# Label encoding: 1 = ham, 0 = spam; spam is treated as the "positive" class.
tn = fp = fn = tp = 0

for filename in X_test:
    path = os.path.join(DATA_DIR, filename)
    # Emails missing from the label index cannot be scored.
    if filename not in labels:
        continue

    true_label = labels[filename]
    stems = email_read_util.load(path)
    if not stems:
        continue

    # Distinct n-grams of this email
    stem_ngrams = set(ngrams(stems, NGRAM_SIZE))

    # How many of them belong to the spam-indicative set
    spam_score = len(stem_ngrams.intersection(spam_indicative_ngrams))

    # A single spam-indicative n-gram is enough to call the email spam
    if spam_score > 0:
        predicted_label = 0
    else:
        predicted_label = 1

    # Tally the (actual, predicted) pair in the confusion matrix
    outcome = (true_label, predicted_label)
    if outcome == (1, 1):
        tn += 1
    elif outcome == (1, 0):
        fp += 1
    elif outcome == (0, 1):
        fn += 1
    elif outcome == (0, 0):
        tp += 1


In [None]:
# Display results
from IPython.display import HTML, display


def _html_table(rows):
    """Return an HTML table string with one <tr> per row and one <td> per cell."""
    body = '</tr><tr>'.join(
        '<td>{}</td>'.format('</td><td>'.join(str(cell) for cell in row))
        for row in rows)
    return '<table><tr>{}</tr></table>'.format(body)


def _percent(numerator, denominator):
    """Format numerator/denominator as a one-decimal percentage.

    Returns 'n/a' when the denominator is zero (e.g. no emails were evaluated,
    or nothing was predicted as spam) instead of raising ZeroDivisionError.
    """
    if denominator == 0:
        return "n/a"
    return "{:.1%}".format(numerator / denominator)


# Rows are the actual class (ham, spam); columns are the predicted class (ham, spam).
conf_matrix = [[tn, fp],
               [fn, tp]]
display(HTML(_html_table(conf_matrix)))

# Same matrix expressed as a share of all evaluated emails.
count = tn + tp + fn + fp
percent_matrix = [[_percent(tn, count), _percent(fp, count)],
                  [_percent(fn, count), _percent(tp, count)]]
display(HTML(_html_table(percent_matrix)))

print("Classification accuracy: {}".format(_percent(tp + tn, count)))
print("Precision (spam): {}".format(_percent(tp, tp + fp)))
print("Recall (spam): {}".format(_percent(tp, tp + fn)))