In [None]:
import os
import pickle
from collections import defaultdict

from matplotlib import pyplot as plt

import email_read_util


## Download 2007 TREC Public Spam Corpus
1. Read the "Agreement for use"
   https://plg.uwaterloo.ca/~gvcormac/treccorpus07/

2. Download the 255 MB corpus (`trec07p.tgz`) and extract it into the `chapter1/datasets` directory.

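3. Check that the paths assigned to `DATA_DIR` and `LABELS_FILE` in the next cell exist. They are relative to the notebook's working directory, so either run the notebook from the directory that contains `trec07p/` or adjust the paths.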

In [None]:
# Corpus locations (relative to the notebook's working directory) and the
# fraction of the file list used for training; the remainder is the test set.
DATA_DIR, LABELS_FILE = 'trec07p/data/', 'trec07p/full/index'
TRAINING_SET_RATIO = 0.7  # 70% train / 30% test

In [None]:
# Shared state: email filename -> label (filled from the index file below),
# plus word sets. `spam_words` is rebuilt for every threshold during evaluation.
# NOTE(review): `ham_words` is not read by any cell shown here.
labels = dict()
spam_words, ham_words = set(), set()

In [None]:
# Read the labels. Each non-blank line of the index is "<label> <path>"; the
# email is keyed by the last path component so it matches the bare filenames
# that os.listdir(DATA_DIR) yields in the split cell.
# Encoding used by every later cell: 1 = ham, 0 = spam (any non-'ham' label).
with open(LABELS_FILE) as f:
    for line in f:
        fields = line.split()  # split() with no argument also drops surrounding whitespace
        if not fields:
            # Tolerate empty lines (e.g. extra blank lines at the end of the file)
            # instead of failing on the two-value unpack below.
            continue
        label, key = fields
        labels[key.split('/')[-1]] = 1 if label.lower() == 'ham' else 0

In [None]:
# Split corpus into train and test sets.
# os.listdir returns entries in an arbitrary, filesystem-dependent order, so the
# names are sorted first; otherwise the split (and every metric computed from
# it) could differ between machines or runs.
filelist = sorted(os.listdir(DATA_DIR))
split_index = int(len(filelist) * TRAINING_SET_RATIO)
X_train = filelist[:split_index]
X_test = filelist[split_index:]

In [None]:
# Fetch NLTK's 'punkt_tab' tokenizer data (a network download).
# NOTE(review): presumably required by the tokenization inside
# email_read_util.load, so it must run before the counting cell — confirm.
import nltk
nltk.download('punkt_tab')

In [None]:
# Per-class document frequencies: for each word, how many training emails of
# that class contain it (each word counted at most once per email), plus the
# number of usable training emails per class.
spam_word_counts = defaultdict(int)
ham_word_counts = defaultdict(int)
total_spam = 0
total_ham = 0

# Count word occurrences over labelled, non-empty training emails.
for filename in X_train:
    path = os.path.join(DATA_DIR, filename)
    if filename not in labels:
        continue
    label = labels[filename]
    stems = email_read_util.load(path)
    if not stems:
        continue

    unique_stems = set(stems)
    if label == 0:  # spam
        total_spam += 1
        target_counts = spam_word_counts
    else:  # ham
        total_ham += 1
        target_counts = ham_word_counts
    for word in unique_stems:
        target_counts[word] += 1

In [None]:
# Calculate word statistics: for every word seen in training, the percentage of
# spam emails and of ham emails that contain it at least once.
def percent_of(count, total):
    """Return count/total as a percentage; 0.0 when total is 0.

    A class with no training emails would otherwise raise ZeroDivisionError
    as soon as a word from the other class is processed.
    """
    return (count / total) * 100 if total else 0.0


word_stats = {}  # word -> (spam_percent, ham_percent)
all_words = spam_word_counts.keys() | ham_word_counts.keys()
for word in all_words:
    spam_percent = percent_of(spam_word_counts.get(word, 0), total_spam)
    ham_percent = percent_of(ham_word_counts.get(word, 0), total_ham)
    word_stats[word] = (spam_percent, ham_percent)

In [None]:
# Evaluate different thresholds from 1% to 20%.
# Classifier: an email is predicted spam if it contains at least one word that
# appears in >= min_spam_percent of training spam emails. The word's ham
# percentage is NOT considered, so words common to both classes also count.
thresholds = range(1, 21)
results = []

# Loading and stemming is the expensive step, so do it once per test email
# here rather than once per threshold. "Contains a word with spam_p >= t" is
# equivalent to "the highest spam_p among the email's words is >= t", so each
# email is reduced to that single number.
test_examples = []  # (true_label, highest spam percentage among the email's words)
for filename in X_test:
    if filename not in labels:
        continue
    stems = email_read_util.load(os.path.join(DATA_DIR, filename))
    if not stems:
        continue
    top_spam_percent = max(
        (word_stats[word][0] for word in set(stems) if word in word_stats),
        default=float('-inf'),  # no word seen in training -> never predicted spam
    )
    test_examples.append((labels[filename], top_spam_percent))

for min_spam_percent in thresholds:
    # Words meeting the current threshold; reported as spam_words_count.
    # After the loop, `spam_words` holds the set for the last threshold.
    spam_words = {word for word, (spam_p, _ham_p) in word_stats.items()
                  if spam_p >= min_spam_percent}

    # Confusion matrix with spam (label 0) as the positive class.
    tp = fp = fn = tn = 0
    for true_label, top_spam_percent in test_examples:
        is_spam = true_label == 0
        predicted_spam = top_spam_percent >= min_spam_percent
        if predicted_spam and is_spam:
            tp += 1
        elif predicted_spam:
            fp += 1
        elif is_spam:
            fn += 1
        else:
            tn += 1

    # Calculate metrics (0 when a denominator is empty).
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    results.append({
        'threshold': min_spam_percent,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'spam_words_count': len(spam_words)
    })

In [None]:
# Plot the results: one panel per metric against the spam-percentage threshold.
# (results key, display name, line color); 'C0' is the first color of the
# default property cycle, i.e. what an unstyled plot would use.
METRIC_PANELS = [
    ('accuracy', 'Accuracy', 'C0'),
    ('precision', 'Precision', 'orange'),
    ('recall', 'Recall', 'green'),
    ('f1', 'F1 Score', 'red'),
]

fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# axes.flat walks the grid row by row: top-left, top-right, bottom-left, bottom-right.
for ax, (key, display_name, color) in zip(axes.flat, METRIC_PANELS):
    ax.plot(thresholds, [r[key] for r in results], marker='o', color=color)
    ax.set_title(f'{display_name} vs Spam Percentage Threshold')
    ax.set_xlabel('Minimum Spam Percentage (%)')
    ax.set_ylabel(display_name)
    ax.grid(True)

fig.tight_layout()
plt.show()

# Print summary table. Every value is right-aligned to its header's width so
# the columns line up, and the separator matches the header length.
header = "Threshold% | Spam Words | Accuracy | Precision | Recall | F1 Score"
print("\nPerformance Summary:")
print(header)
print("-" * len(header))
for r in results:
    print(f"{r['threshold']:>9}% | {r['spam_words_count']:>10} | {r['accuracy']:>8.3f} | "
          f"{r['precision']:>9.3f} | {r['recall']:>6.3f} | {r['f1']:>8.3f}")