Build a spam classifier (a more challenging exercise):

 • Download examples of spam and ham from Apache SpamAssassin’s public
 datasets.
 
 • Unzip the datasets and familiarize yourself with the data format.
 
 • Split the datasets into a training set and a test set.
 
 • Write a data preparation pipeline to convert each email into a feature vector.
 
 Your preparation pipeline should transform an email into a (sparse) vector
 indicating the presence or absence of each possible word. For example, if all
 emails only ever contain four words, “Hello,” “how,” “are,” “you,” then the email
 “Hello you Hello Hello you” would be converted into a vector [1, 0, 0, 1]
 (meaning [“Hello” is present, “how” is absent, “are” is absent, “you” is
 present]), or [3, 0, 0, 2] if you prefer to count the number of occurrences of
 each word.
 
 • You may want to add hyperparameters to your preparation pipeline to control
 whether or not to strip off email headers, convert each email to lowercase,
 remove punctuation, replace all URLs with “URL,” replace all numbers with
 “NUMBER,” or even perform stemming (i.e., trim off word endings; there are
 Python libraries available to do this).
 
 • Then try out several classifiers and see if you can build a great spam classifier,
 with both high recall and high precision.
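The four-word example above can be sketched in a few lines. This is only an illustration: the helper name email_to_vector and its count flag are made up here, and a real pipeline would produce a sparse vector over a much larger vocabulary.

```python
def email_to_vector(text, vocabulary, count=False):
    """Turn a whitespace-separated text into a presence or count vector."""
    words = text.split()
    if count:
        # Number of occurrences of each vocabulary word.
        return [words.count(word) for word in vocabulary]
    # 1 if the vocabulary word is present, 0 otherwise.
    return [1 if word in words else 0 for word in vocabulary]

vocabulary = ["Hello", "how", "are", "you"]
email_text = "Hello you Hello Hello you"

presence_vector = email_to_vector(email_text, vocabulary)
count_vector = email_to_vector(email_text, vocabulary, count=True)
print(presence_vector)  # [1, 0, 0, 1]
print(count_vector)     # [3, 0, 0, 2]
```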

In [33]:
import os
import tarfile
import urllib.request

DOWNLOAD_ROOT = "http://spamassassin.apache.org/old/publiccorpus/"
HAM_URL = DOWNLOAD_ROOT + "20030228_easy_ham.tar.bz2"
SPAM_URL = DOWNLOAD_ROOT + "20030228_spam.tar.bz2"
SPAM_PATH = os.path.join("datasets", "spam")

def fetch_spam_data(ham_url=HAM_URL, spam_url=SPAM_URL, spam_path=SPAM_PATH):
    # Create the dataset folder if it does not exist yet.
    os.makedirs(spam_path, exist_ok=True)
    for filename, url in (("ham.tar.bz2", ham_url), ("spam.tar.bz2", spam_url)):
        path = os.path.join(spam_path, filename)
        # Download each archive only if it is not already on disk.
        if not os.path.isfile(path):
            urllib.request.urlretrieve(url, path)
        # Extract into the folder that was passed in, not the global default.
        with tarfile.open(path) as tar_bz2_file:
            tar_bz2_file.extractall(path=spam_path)

The dataset folder is created if it does not already exist.

Each archive is downloaded only if it is not already present.

The .tar.bz2 files are then extracted into the dataset folder for further processing.
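A side note on extraction: Python 3.12 and 3.13 emit a DeprecationWarning when extractall() is called without a filter argument. The sketch below shows one way to pass filter="data" when the running Python supports it. The helper name extract_archive and the tiny demo archive are made up for this illustration; they are not part of the notebook.

```python
import os
import tarfile
import tempfile

def extract_archive(archive_path, dest_dir):
    """Extract a tar archive, using the "data" filter when it is available."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(archive_path) as tar:
        if hasattr(tarfile, "data_filter"):
            # The "data" filter rejects unsafe members such as absolute paths.
            tar.extractall(path=dest_dir, filter="data")
        else:
            # Older Python versions do not support the filter argument.
            tar.extractall(path=dest_dir)

# Self-contained demo on a tiny archive built in a temporary folder.
with tempfile.TemporaryDirectory() as tmp:
    source = os.path.join(tmp, "hello.txt")
    with open(source, "w") as f:
        f.write("hello")
    archive = os.path.join(tmp, "demo.tar.bz2")
    with tarfile.open(archive, "w:bz2") as tar:
        tar.add(source, arcname="demo/hello.txt")
    out_dir = os.path.join(tmp, "out")
    extract_archive(archive, out_dir)
    extracted = os.path.exists(os.path.join(out_dir, "demo", "hello.txt"))
    print(extracted)
```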

In [36]:
fetch_spam_data()



In [37]:
HAM_DIR = os.path.join(SPAM_PATH, "easy_ham")
SPAM_DIR = os.path.join(SPAM_PATH, "spam")
ham_filenames = [name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]
spam_filenames = [name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]

The length filter removes unwanted files (e.g., hidden system files or metadata files, which may have short names).

Email files in the dataset appear to have long, unique names (e.g., message IDs), so only names longer than 20 characters are kept.

This ensures that only real email files are processed, which avoids errors when parsing email contents.

In [39]:
len(ham_filenames)

2500

In [43]:
len(spam_filenames)

500

In [45]:
import email
import email.parser
import email.policy

def load_email(is_spam, filename, spam_path=SPAM_PATH):
    directory = "spam" if is_spam else "easy_ham"
    with open(os.path.join(spam_path, directory, filename), "rb") as f:
        return email.parser.BytesParser(policy=email.policy.default).parse(f)

Reads a spam or ham email file from the dataset.

Parses the email using Python's email module.

Returns an EmailMessage object, which contains the email body, headers, and attachments.
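To see what the parser returns without touching the dataset, here is a minimal sketch that parses a hand-written message from bytes. The message itself is made up; parsebytes() is used instead of parse() only because there is no file to read from.

```python
import email.parser
import email.policy

# A tiny hand-written email (hypothetical content, not from the dataset).
raw = (
    b"From: alice@example.com\r\n"
    b"To: bob@example.com\r\n"
    b"Subject: Lunch?\r\n"
    b"Content-Type: text/plain; charset=\"utf-8\"\r\n"
    b"\r\n"
    b"Are you free at noon?\r\n"
)

msg = email.parser.BytesParser(policy=email.policy.default).parsebytes(raw)

subject = msg["Subject"]               # headers are accessed like dictionary keys
content_type = msg.get_content_type()  # MIME type of the message
body = msg.get_content().strip()       # decoded text body
print(subject, "|", content_type, "|", body)
```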

In [48]:
ham_emails = [load_email(is_spam=False, filename=name) for name in ham_filenames]
spam_emails = [load_email(is_spam=True, filename=name) for name in spam_filenames]

In [49]:
print(ham_emails[1].get_content().strip())

Martin A posted:
Tassos Papadopoulos, the Greek sculptor behind the plan, judged that the
 limestone of Mount Kerdylio, 70 miles east of Salonika and not far from the
 Mount Athos monastic community, was ideal for the patriotic sculpture. 
 
 As well as Alexander's granite features, 240 ft high and 170 ft wide, a
 museum, a restored amphitheatre and car park for admiring crowds are
planned
---------------------
So is this mountain limestone or granite?
If it's limestone, it'll weather pretty fast.

------------------------ Yahoo! Groups Sponsor ---------------------~-->
4 DVDs Free +s&p Join Now
http://us.click.yahoo.com/pt6YBB/NXiEAA/mG3HAA/7gSolB/TM
---------------------------------------------------------------------~->

To unsubscribe from this group, send an email to:
forteana-unsubscribe@egroups.com

 

Your use of Yahoo! Groups is subject to http://docs.yahoo.com/info/terms/


In [52]:
print(spam_emails[6].get_content().strip())

Help wanted.  We are a 14 year old fortune 500 company, that is
growing at a tremendous rate.  We are looking for individuals who
want to work from home.

This is an opportunity to make an excellent income.  No experience
is required.  We will train you.

So if you are looking to be employed from home with a career that has
vast opportunities, then go:

http://www.basetel.com/wealthnow

We are looking for energetic and self motivated people.  If that is you
than click on the link and fill out the form, and one of our
employement specialist will contact you.

To be removed from our link simple go to:

http://www.basetel.com/remove.html


4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40


In [54]:
def get_email_structure(email):
    if isinstance(email, str):
        return email
    payload = email.get_payload()
    if isinstance(payload, list):
        return "multipart({})".format(", ".join([
            get_email_structure(sub_email)
            for sub_email in payload
        ]))
    else:
        return email.get_content_type()

The isinstance(email, str) check is a guard: if the function is handed a plain string instead of a message object, it returns the string as-is.

email.get_payload() extracts the content (body) of the email. The payload can be:

A string → A simple email with plain text or HTML content.

A list → A multipart email (e.g., an email with attachments, or with both HTML and text versions).

If the payload is a list, the function calls itself recursively on each sub-part and joins the results into a single string of the form "multipart(...)" that describes the structure of the email.

If the email is not multipart, the function returns its content type (e.g., "text/plain", "text/html", "application/pdf").
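The recursion is easier to follow on a message built by hand. The sketch below repeats the function so that it runs on its own, then applies it to a made-up message before and after an HTML alternative is added.

```python
from email.message import EmailMessage

def get_email_structure(email):
    if isinstance(email, str):
        return email
    payload = email.get_payload()
    if isinstance(payload, list):
        # Multipart: describe each sub-part recursively.
        return "multipart({})".format(", ".join(
            get_email_structure(sub_email) for sub_email in payload
        ))
    return email.get_content_type()

msg = EmailMessage()
msg.set_content("Hello in plain text")
simple_structure = get_email_structure(msg)

# Adding an HTML alternative turns the message into a multipart email.
msg.add_alternative("<p>Hello in HTML</p>", subtype="html")
multipart_structure = get_email_structure(msg)

print(simple_structure)     # text/plain
print(multipart_structure)  # multipart(text/plain, text/html)
```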

In [57]:
from collections import Counter

def structures_counter(emails):
    structures = Counter()
    for email in emails:
        structure = get_email_structure(email)
        structures[structure] += 1
    return structures

Counter from the collections module is used to count occurrences of different email structures.

A Counter dictionary is created to store the frequency of each email structure.
The keys in this Counter will be email structures (e.g., "text/plain", "multipart(text/plain, text/html)"), and the values will be counts.

This loops through each email object in the emails list.

The Counter dictionary is updated to increase the count for this particular email structure.

After processing all emails, the function returns the Counter dictionary containing the frequency of each structure.

In [60]:
structures_counter(ham_emails).most_common()

[('text/plain', 2408),
 ('multipart(text/plain, application/pgp-signature)', 66),
 ('multipart(text/plain, text/html)', 8),
 ('multipart(text/plain, text/plain)', 4),
 ('multipart(text/plain)', 3),
 ('multipart(text/plain, application/octet-stream)', 2),
 ('multipart(text/plain, text/enriched)', 1),
 ('multipart(text/plain, application/ms-tnef, text/plain)', 1),
 ('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',
  1),
 ('multipart(text/plain, video/mng)', 1),
 ('multipart(text/plain, multipart(text/plain))', 1),
 ('multipart(text/plain, application/x-pkcs7-signature)', 1),
 ('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',
  1),
 ('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',
  1),
 ('multipart(text/plain, application/x-java-applet)', 1)]

This method sorts the Counter dictionary from most to least frequent and returns a list of tuples.
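As a quick illustration of most_common(), here is a sketch on a handful of made-up structure strings rather than the real emails:

```python
from collections import Counter

# Hypothetical structures, standing in for get_email_structure() results.
observed = [
    "text/plain", "text/html", "text/plain",
    "multipart(text/plain, text/html)", "text/plain", "text/html",
]

structures = Counter()
for structure in observed:
    structures[structure] += 1

ranked = structures.most_common()   # all structures, most frequent first
top = structures.most_common(1)     # only the single most frequent one
print(ranked)
print(top)
```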

In [63]:
structures_counter(spam_emails).most_common()

[('text/plain', 218),
 ('text/html', 183),
 ('multipart(text/plain, text/html)', 45),
 ('multipart(text/html)', 20),
 ('multipart(text/plain)', 19),
 ('multipart(multipart(text/html))', 5),
 ('multipart(text/plain, image/jpeg)', 3),
 ('multipart(text/html, application/octet-stream)', 2),
 ('multipart(text/plain, application/octet-stream)', 1),
 ('multipart(text/html, text/plain)', 1),
 ('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 1),
 ('multipart(multipart(text/plain, text/html), image/gif)', 1),
 ('multipart/alternative', 1)]

In [65]:
for header, value in spam_emails[0].items():
    print(header,":",value)

Return-Path : <12a1mailbot1@web.de>
Delivered-To : zzzz@localhost.spamassassin.taint.org
Received : from localhost (localhost [127.0.0.1])	by phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32	for <zzzz@localhost>; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)
Received : from mail.webnote.net [193.120.211.219]	by localhost with POP3 (fetchmail-5.9.0)	for zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)
Received : from dd_it7 ([210.97.77.167])	by webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623	for <zzzz@spamassassin.taint.org>; Thu, 22 Aug 2002 13:09:41 +0100
From : 12a1mailbot1@web.de
Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7  with Microsoft SMTPSVC(5.5.1775.675.6);	 Sat, 24 Aug 2002 09:42:10 +0900
To : dcek1a1@netsgo.com
Subject : Life Insurance - Why Pay More?
Date : Wed, 21 Aug 2002 20:31:57 -1600
MIME-Version : 1.0
Message-ID : <0103c1042001882DD_IT7@dd_it7>
Content-Type : text/html; charset="iso-8859-1"
Content-Transfer-Encoding : qu

spam_emails[0] accesses the first email in the list spam_emails, which contains email objects parsed using the email library.

spam_emails[0].items() returns all headers as (header, value) pairs.

The loop iterates over the headers and prints each one in "Header : Value" format.

In [68]:
spam_emails[0]["Subject"]

'Life Insurance - Why Pay More?'

Okay, before we learn too much about the data, let's not forget to split it into a training set and a test set:

In [71]:
import numpy as np
from sklearn.model_selection import train_test_split

X = np.array(ham_emails + spam_emails, dtype=object)
y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Okay, let's start writing the preprocessing functions. First, we will need a function to convert HTML to plain text. Arguably the best way to do this would be to use the great BeautifulSoup library, but I would like to avoid adding another dependency to this project, so let's hack a quick & dirty solution using regular expressions. The following function first drops the <head> section, then converts all <a> tags to the word HYPERLINK, then gets rid of all remaining HTML tags, leaving only the plain text. For readability, it also replaces multiple newlines with single newlines, and finally it unescapes HTML entities (such as &gt; or &nbsp;):

In [82]:
import re
from html import unescape

def html_to_plain_text(html):
    # Removing the <head>...</head> section (metadata, scripts, etc.).
    text = re.sub(r'<head.*?>.*?</head>', '', html, flags=re.M | re.S | re.I)
    # Replacing <a> (anchor) tags with "HYPERLINK" to keep track of links.
    text = re.sub(r'<a\s.*?>', ' HYPERLINK ', text, flags=re.M | re.S | re.I)
    # Removing all other HTML tags to extract pure text.
    text = re.sub(r'<.*?>', '', text, flags=re.M | re.S)
    # Collapsing runs of blank lines into a single newline. In a raw string the
    # backslashes are written once: r'\\s' would look for a literal backslash.
    text = re.sub(r'(\s*\n)+', '\n', text, flags=re.M | re.S)
    # Decoding HTML entities (like &lt; → <, &amp; → &).
    return unescape(text)
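To check the function on something small, here is a sketch that runs it on a made-up HTML snippet. The function is repeated so that the block runs on its own, and it assumes the newline-collapsing pattern is written as r'(\s*\n)+'.

```python
import re
from html import unescape

def html_to_plain_text(html):
    text = re.sub(r'<head.*?>.*?</head>', '', html, flags=re.M | re.S | re.I)
    text = re.sub(r'<a\s.*?>', ' HYPERLINK ', text, flags=re.M | re.S | re.I)
    text = re.sub(r'<.*?>', '', text, flags=re.M | re.S)
    text = re.sub(r'(\s*\n)+', '\n', text, flags=re.M | re.S)
    return unescape(text)

# Hypothetical snippet: a head section, a link, an entity and blank lines.
sample_html = (
    '<html><head><title>Ignored</title></head><body>'
    '<p>Click <a href="http://example.com">here</a> &amp; win</p>'
    '\n\n\n<p>Now</p></body></html>'
)

plain = html_to_plain_text(sample_html)
print(repr(plain))
```

The title disappears with the head section, the link becomes HYPERLINK, the three newlines collapse into one, and &amp; is decoded to &.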
    

1-------------------------------------------------

Pattern	Meaning

<head	Matches the start of the opening <head> tag.

.*?	Matches any attributes inside the tag (non-greedy).

>	Matches the closing > of the opening <head> tag.

.*?	Matches everything inside the <head> section (non-greedy).

</head>	Matches the closing </head> tag.

Flag	Meaning

re.M (Multiline)	Makes ^ and $ match at the start and end of each line. These patterns contain no anchors, so the flag has no effect here.

re.S (DotAll)	Allows . to match newlines (\n), so a match can span several lines.

re.I (IgnoreCase)	Makes the match case-insensitive (e.g., <HEAD> still matches).

Effect: Removes the whole <head>...</head> section.

2-----------------------------------------------


Pattern	Meaning

<a	Matches the opening <a tag.

\s	Matches any whitespace (space, tab, newline).

.*?	Matches everything inside the tag (non-greedy).

>	Matches the closing > of the anchor tag.

Effect: Replaces <a href="..."> and similar tags with "HYPERLINK".

3----------------------------------------------

Pattern	Meaning

<	Matches the opening <.

.*?	Matches everything inside (non-greedy).

>	Matches the closing >.

Effect: Removes all remaining HTML tags, leaving only plain text.

4-----------------------------------------------

Pattern	Meaning

\s*	Matches optional whitespace before a newline.

\n	Matches a newline.

(\s*\n)+	Matches one or more consecutive blank (or whitespace-only) line endings.

Effect: Collapses runs of blank lines into a single newline.

5-----------------------------------------------

 Effect: Converts special HTML entities like:

&lt; → <

&amp; → &

&quot; → "



Let's see if it works. This is HTML spam:

In [94]:
html_spam_emails = [email for email in X_train[y_train==1]
                    if get_email_structure(email) == "text/html"]
sample_html_spam = html_spam_emails[7]
print(sample_html_spam.get_content().strip()[:1000], "...")


<HTML><HEAD><TITLE></TITLE><META http-equiv="Content-Type" content="text/html; charset=windows-1252"><STYLE>A:link {TEX-DECORATION: none}A:active {TEXT-DECORATION: none}A:visited {TEXT-DECORATION: none}A:hover {COLOR: #0033ff; TEXT-DECORATION: underline}</STYLE><META content="MSHTML 6.00.2713.1100" name="GENERATOR"></HEAD>
<BODY text="#000000" vLink="#0033ff" link="#0033ff" bgColor="#CCCC99"><TABLE borderColor="#660000" cellSpacing="0" cellPadding="0" border="0" width="100%"><TR><TD bgColor="#CCCC99" valign="top" colspan="2" height="27">
<font size="6" face="Arial, Helvetica, sans-serif" color="#660000">
<b>OTC</b></font></TD></TR><TR><TD height="2" bgcolor="#6a694f">
<font size="5" face="Times New Roman, Times, serif" color="#FFFFFF">
<b>&nbsp;Newsletter</b></font></TD><TD height="2" bgcolor="#6a694f"><div align="right"><font color="#FFFFFF">
<b>Discover Tomorrow's Winners&nbsp;</b></font></div></TD></TR><TR><TD height="25" colspan="2" bgcolor="#CCCC99"><table width="100%" border="0" 


And this is the resulting plain text:

In [97]:
print(html_to_plain_text(sample_html_spam.get_content())[:1000], "...")




OTC

 Newsletter
Discover Tomorrow's Winners 

For Immediate Release

Cal-Bay (Stock Symbol: CBYI)
Watch for analyst "Strong Buy Recommendations" and several advisory newsletters picking CBYI.  CBYI has filed to be traded on the OTCBB, share prices historically INCREASE when companies get listed on this larger trading exchange. CBYI is trading around 25 cents and should skyrocket to $2.66 - $3.25 a share in the near future.
Put CBYI on your watch list, acquire a position TODAY.

REASONS TO INVEST IN CBYI

A profitable company and is on track to beat ALL earnings estimates!

One of the FASTEST growing distributors in environmental & safety equipment instruments.

Excellent management team, several EXCLUSIVE contracts.  IMPRESSIVE client list including the U.S. Air Force, Anheuser-Busch, Chevron Refining and Mitsubishi Heavy Industries, GE-Energy & Environmental Research.

RAPIDLY GROWING INDUSTRY
Industry revenues exceed $900 million, estimates indicate that there could be as much as

In [99]:

def email_to_text(email):
    # A variable html is initialized to store HTML content
    html = None
    # email.walk() is used to iterate over all parts of the email.
    for part in email.walk():
        # This retrieves the MIME type of the part.
        ctype = part.get_content_type()
        # Skips attachments or other non-text parts (like images, PDFs, etc.).
        # Continues only if the part is "text/plain" or "text/html".
        if ctype not in ("text/plain", "text/html"):
            continue
        # part.get_content() extracts the email body.
        # If there's an encoding issue, get_payload() is used as a fallback.
        try:
            content = part.get_content()
        except Exception:  # in case of encoding issues
            content = str(part.get_payload())
        # If the email contains plain text, return it immediately
        if ctype == "text/plain":
            return content
        # If the email contains HTML, store it in html for later conversion.
        else:
            html = content
    # If only an HTML version exists, convert it to plain text
    if html:
        return html_to_plain_text(html)
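Here is a small sketch of how the function behaves on three kinds of messages: one with both plain-text and HTML versions, one with HTML only, and one with no text part at all. The messages are made up, and both functions are repeated in compact form so that the block runs on its own.

```python
import re
from html import unescape
from email.message import EmailMessage

def html_to_plain_text(html):
    text = re.sub(r'<head.*?>.*?</head>', '', html, flags=re.S | re.I)
    text = re.sub(r'<a\s.*?>', ' HYPERLINK ', text, flags=re.S | re.I)
    text = re.sub(r'<.*?>', '', text, flags=re.S)
    text = re.sub(r'(\s*\n)+', '\n', text)
    return unescape(text)

def email_to_text(email):
    html = None
    for part in email.walk():
        ctype = part.get_content_type()
        if ctype not in ("text/plain", "text/html"):
            continue
        try:
            content = part.get_content()
        except Exception:  # in case of encoding issues
            content = str(part.get_payload())
        if ctype == "text/plain":
            return content      # plain text wins as soon as it is found
        html = content          # otherwise remember the HTML version
    if html:
        return html_to_plain_text(html)

both = EmailMessage()
both.set_content("Plain version")
both.add_alternative("<p>HTML version</p>", subtype="html")

html_only = EmailMessage()
html_only.set_content("<p>Only <b>HTML</b> here</p>", subtype="html")

binary_only = EmailMessage()
binary_only.set_content(b"\x00\x01", maintype="application", subtype="octet-stream")

text_both = email_to_text(both).strip()
text_html = email_to_text(html_only).strip()
text_binary = email_to_text(binary_only)
print(text_both, "|", text_html, "|", text_binary)
```

The last case returns None, so callers that expect a string need a fallback such as an empty string.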

In [102]:
print(email_to_text(sample_html_spam)[:100], "...")




OTC

 Newsletter
Discover Tomorrow's Winners 

For Immediate Release

Cal-Bay (Stock Symbol: CBYI ...


$ pip3 install urlextract

In [105]:
%pip install urlextract

Note: you may need to restart the kernel to use updated packages.


In [107]:
try:
    import urlextract # may require an Internet connection to download root domain names
    
    url_extractor = urlextract.URLExtract()
    print(url_extractor.find_urls("Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s"))
except ImportError:
    print("Error: replacing URLs requires the urlextract module.")
    url_extractor = None

['github.com', 'https://youtu.be/7Pq-S557XQU?t=3m32s']


In [109]:
# from sklearn.base import BaseEstimator, TransformerMixin

# class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):
#     def __init__(self, strip_headers=True, lower_case=True, remove_punctuation=True,
#                  replace_urls=True, replace_numbers=True, stemming=True):
#         self.strip_headers = strip_headers
#         self.lower_case = lower_case
#         self.remove_punctuation = remove_punctuation
#         self.replace_urls = replace_urls
#         self.replace_numbers = replace_numbers
#         self.stemming = stemming
#     def fit(self, X, y=None):
#         return self
#     def transform(self, X, y=None):
#         X_transformed = []
#         for email in X:
#             text = email_to_text(email) or ""
#             if self.lower_case:
#                 text = text.lower()
#             if self.replace_urls and url_extractor is not None:
#                 urls = list(set(url_extractor.find_urls(text)))
#                 urls.sort(key=lambda url: len(url), reverse=True)
#                 for url in urls:
#                     text = text.replace(url, " URL ")
#             if self.replace_numbers:
#                 text = re.sub(r'\d+(?:\.\d*)?(?:[eE][+-]?\d+)?', 'NUMBER', text)
#             if self.remove_punctuation:
#                 text = re.sub(r'\W+', ' ', text, flags=re.M)
#             word_counts = Counter(text.split())
#             if self.stemming and stemmer is not None:
#                 stemmed_word_counts = Counter()
#                 for word, count in word_counts.items():
#                     stemmed_word = stemmer.stem(word)
#                     stemmed_word_counts[stemmed_word] += count
#                 word_counts = stemmed_word_counts
#             X_transformed.append(word_counts)
#         return np.array(X_transformed)

In [131]:
from sklearn.base import BaseEstimator, TransformerMixin
# BaseEstimator, TransformerMixin: Allows this transformer to be used in sklearn pipelines.
import re
# re: Used for regex-based text cleaning.
import numpy as np
from collections import Counter
# Used to count word occurrences.
from nltk.stem import PorterStemmer
# Used for stemming words.

class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, strip_headers=True, lower_case=True, remove_punctuation=True,
                 replace_urls=True, replace_numbers=True, stemming=True):
        # Note: strip_headers is accepted but not used by transform() below.
        self.strip_headers = strip_headers
        self.lower_case = lower_case
        self.remove_punctuation = remove_punctuation
        self.replace_urls = replace_urls
        self.replace_numbers = replace_numbers
        self.stemming = stemming
        self.stemmer = PorterStemmer()  # Initialize Porter Stemmer

    def fit(self, X, y=None):
        # Nothing to learn here; fit() exists only because scikit-learn
        # transformers are required to provide it.
        return self

    def transform(self, X, y=None):
        # X is a sequence of email objects; each one is preprocessed in turn.
        X_transformed = []
        for email in X:
            # Extract the text of the email (from its "text/plain" or "text/html" part).
            text = email_to_text(email) or ""
            if self.lower_case:
                # Convert to lowercase so that matching is case-insensitive.
                text = text.lower()
            if self.replace_urls:
                # Replace every http:// or https:// link with "URL".
                # Bare domains such as "github.com" are not matched by this regex.
                text = re.sub(r'https?://\S+', ' URL ', text)
            if self.replace_numbers:
                # Replace every run of digits with "NUMBER".
                text = re.sub(r'\d+', ' NUMBER ', text)
            if self.remove_punctuation:
                # Replace runs of non-word characters (\W+) with a single space,
                # leaving only letters, digits, and underscores.
                text = re.sub(r'\W+', ' ', text)

            # Split the text into words and count the occurrences of each.
            word_counts = Counter(text.split())

            if self.stemming:
                # Apply Porter stemming (e.g., running → run) and merge the
                # counts of words that share the same stem.
                stemmed_word_counts = Counter()
                for word, count in word_counts.items():
                    stemmed_word = self.stemmer.stem(word)  # Use self.stemmer
                    stemmed_word_counts[stemmed_word] += count
                word_counts = stemmed_word_counts

            # Collect one word-count dictionary per email.
            X_transformed.append(word_counts)

        return np.array(X_transformed)
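To see what the regex-based cleaning steps do, here is a small sketch that applies the same substitutions, in the same order, to a made-up sentence. Stemming is left out so that the block runs without NLTK, and the helper name clean_and_count is hypothetical.

```python
import re
from collections import Counter

def clean_and_count(text):
    """Apply the lowercase, URL, number and punctuation steps, then count words."""
    text = text.lower()
    text = re.sub(r'https?://\S+', ' URL ', text)   # links → URL
    text = re.sub(r'\d+', ' NUMBER ', text)         # digit runs → NUMBER
    text = re.sub(r'\W+', ' ', text)                # punctuation → space
    return Counter(text.split())

counts = clean_and_count("Visit https://example.com NOW, only 25 dollars!")
print(sorted(counts))

# A bare domain has no http:// prefix, so it is not replaced by URL.
bare_counts = clean_and_count("See github.com")
print(sorted(bare_counts))
```

The second call shows a limitation of the simple pattern: "github.com" is split into "github" and "com" instead of being replaced, which is where a dedicated library such as urlextract does better.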




Regex	Meaning

\d+	Matches one or more consecutive digits (e.g., 1234)

https?://\S+	Matches URLs (e.g., https://example.com)

\W+	Matches all non-word characters (punctuation, symbols)

\s+	Matches extra spaces or newlines
    
<.*?>	Matches HTML tags (e.g., <b>bold</b>)

Let's try this transformer on a few emails:

In [138]:
X_few = X_train[:3]
X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)
X_few_wordcounts

array([Counter({'chuck': 1, 'murcko': 1, 'wrote': 1, 'stuff': 1, 'yawn': 1, 'r': 1}),
       Counter({'the': 11, 'of': 9, 'and': 8, 'all': 3, 'christian': 3, 'to': 3, 'by': 3, 'jefferson': 2, 'i': 2, 'have': 2, 'superstit': 2, 'one': 2, 'on': 2, 'been': 2, 'ha': 2, 'half': 2, 'rogueri': 2, 'teach': 2, 'jesu': 2, 'some': 1, 'interest': 1, 'quot': 1, 'url': 1, 'thoma': 1, 'examin': 1, 'known': 1, 'word': 1, 'do': 1, 'not': 1, 'find': 1, 'in': 1, 'our': 1, 'particular': 1, 'redeem': 1, 'featur': 1, 'they': 1, 'are': 1, 'alik': 1, 'found': 1, 'fabl': 1, 'mytholog': 1, 'million': 1, 'innoc': 1, 'men': 1, 'women': 1, 'children': 1, 'sinc': 1, 'introduct': 1, 'burnt': 1, 'tortur': 1, 'fine': 1, 'imprison': 1, 'what': 1, 'effect': 1, 'thi': 1, 'coercion': 1, 'make': 1, 'world': 1, 'fool': 1, 'other': 1, 'hypocrit': 1, 'support': 1, 'error': 1, 'over': 1, 'earth': 1, 'six': 1, 'histor': 1, 'american': 1, 'john': 1, 'e': 1, 'remsburg': 1, 'letter': 1, 'william': 1, 'short': 1, 'again': 1, 'becom

This looks about right!

Now we have the word counts, and we need to convert them to vectors. For this, we will build another transformer whose fit() method will build the vocabulary (an ordered list of the most common words) and whose transform() method will use the vocabulary to convert word counts to vectors. The output is a sparse matrix.

In [141]:
from scipy.sparse import csr_matrix
# csr_matrix is an efficient sparse representation: it stores only the nonzero values.

class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, vocabulary_size=1000):
        # vocabulary_size (default 1000) limits the number of words that get their own column.
        self.vocabulary_size = vocabulary_size

    def fit(self, X, y=None):
        # Aggregate word frequencies over all emails in X, capping each word's
        # contribution at 10 per email so that one repetitive email cannot dominate.
        total_count = Counter()
        for word_count in X:
            for word, count in word_count.items():
                total_count[word] += min(count, 10)
        # Keep the vocabulary_size most common words and give each an index
        # starting from 1; index 0 is reserved for out-of-vocabulary words.
        most_common = total_count.most_common()[:self.vocabulary_size]
        self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}
        return self

    def transform(self, X, y=None):
        # Coordinate lists describing the nonzero entries of the sparse matrix.
        rows = []
        cols = []
        data = []
        for row, word_count in enumerate(X):
            for word, count in word_count.items():
                rows.append(row)                            # email index
                cols.append(self.vocabulary_.get(word, 0))  # word index (0 if out of vocabulary)
                data.append(count)                          # word frequency
        # Build a compressed sparse row (CSR) matrix of shape
        # (number_of_emails, vocabulary_size + 1). Entries that share the same
        # (row, column) are summed, so column 0 ends up holding the total count
        # of out-of-vocabulary words.
        return csr_matrix((data, (rows, cols)), shape=(len(X), self.vocabulary_size + 1))

In [121]:

vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)
# This initializes the transformer with vocabulary_size=10, meaning it will keep only the top 10 most common words.
X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)
X_few_vectors

<3x11 sparse matrix of type '<class 'numpy.int32'>'
	with 20 stored elements in Compressed Sparse Row format>

What does this matrix mean? Converting it to a dense array with X_few_vectors.toarray() displays the individual values. The 99 in the second row, first column, means that the second email contains 99 words that are not part of the vocabulary. The 11 next to it means that the first word in the vocabulary is present 11 times in this email. The 9 next to it means that the second word is present 9 times, and so on. You can look at the vocabulary to know which words we are talking about. The first word is "the", the second word is "of", etc.
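The layout of each row can be reproduced in plain Python. The sketch below is a dense, simplified stand-in for the transformer (no SciPy, no sparse matrix), with hypothetical helper names and made-up word counts; it only illustrates how column 0 collects the out-of-vocabulary words.

```python
from collections import Counter

def build_vocabulary(word_counts, vocabulary_size, cap=10):
    """Map the most common words to indices 1..vocabulary_size."""
    total = Counter()
    for counts in word_counts:
        for word, count in counts.items():
            total[word] += min(count, cap)  # cap each email's contribution
    most_common = total.most_common(vocabulary_size)
    return {word: index + 1 for index, (word, _) in enumerate(most_common)}

def to_dense_row(counts, vocabulary):
    """Column 0 sums the out-of-vocabulary counts; other columns follow the vocabulary."""
    row = [0] * (len(vocabulary) + 1)
    for word, count in counts.items():
        row[vocabulary.get(word, 0)] += count
    return row

emails = [
    Counter({"the": 11, "of": 9, "jefferson": 2}),
    Counter({"the": 2, "yawn": 1}),
]
vocabulary = build_vocabulary(emails, vocabulary_size=2)
rows = [to_dense_row(counts, vocabulary) for counts in emails]
print(vocabulary)  # {'the': 1, 'of': 2}
print(rows)        # [[2, 11, 9], [1, 2, 0]]
```

In the first row, the 2 in column 0 comes from "jefferson", which did not make it into the two-word vocabulary.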



In [124]:
vocab_transformer.vocabulary_

{'the': 1,
 'of': 2,
 'and': 3,
 'to': 4,
 'url': 5,
 'all': 6,
 'in': 7,
 'christian': 8,
 'on': 9,
 'by': 10}

We are now ready to train our first spam classifier! Let's transform the whole dataset:

In [127]:
from sklearn.pipeline import Pipeline

preprocess_pipeline = Pipeline([
    ("email_to_wordcount", EmailToWordCounterTransformer()),
    ("wordcount_to_vector", WordCounterToVectorTransformer()),
])
# EmailToWordCounterTransformer() → Converts emails into word frequency dictionaries.
# WordCounterToVectorTransformer() → Converts word frequency dictionaries into sparse vectors
X_train_transformed = preprocess_pipeline.fit_transform(X_train)

In [143]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

log_clf = LogisticRegression(solver="lbfgs", max_iter=1000, random_state=42)
score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, verbose=3)
score.mean()

# solver="lbfgs" → Uses the L-BFGS optimization algorithm, suitable for small to medium datasets.
# max_iter=1000 → Raises the iteration limit (default is 100) to give the solver room to converge.
# random_state=42 → Fixes the random seed for reproducibility (it only matters for solvers that shuffle the data, such as "sag", "saga", or "liblinear").

[CV] END ................................ score: (test=0.981) total time=   0.0s
[CV] END ................................ score: (test=0.985) total time=   0.0s
[CV] END ................................ score: (test=0.990) total time=   0.0s


0.9854166666666666

Over 98.5%, not bad for a first try! :) However, remember that we are using the "easy" dataset. You can try the harder datasets, but the results won't be so amazing. You would have to try multiple models, select the best ones, fine-tune them using cross-validation, and so on.

But you get the picture, so let's stop now, and just print out the precision/recall we get on the test set:

In [130]:

from sklearn.metrics import precision_score, recall_score

X_test_transformed = preprocess_pipeline.transform(X_test)

log_clf = LogisticRegression(solver="lbfgs", max_iter=1000, random_state=42)
log_clf.fit(X_train_transformed, y_train)

y_pred = log_clf.predict(X_test_transformed)

print("Precision: {:.2f}%".format(100 * precision_score(y_test, y_pred)))
print("Recall: {:.2f}%".format(100 * recall_score(y_test, y_pred)))

Precision: 93.94%
Recall: 97.89%


precision_score → Measures the proportion of emails predicted as spam that really are spam.

recall_score → Measures the proportion of actual spam emails that were correctly identified.
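Both metrics reduce to simple ratios of counts. Here is a minimal sketch in plain Python on made-up labels (1 = spam, 0 = ham); it is not how scikit-learn implements them, only the arithmetic behind the definitions.

```python
def precision_recall(y_true, y_pred):
    """Compute precision and recall for the positive class (label 1)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)  # spam caught
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)  # ham flagged as spam
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)  # spam missed
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return precision, recall

# Hypothetical labels: 4 spam emails, 4 ham emails.
y_true = [1, 1, 1, 0, 0, 0, 0, 1]
y_pred = [1, 1, 0, 0, 1, 1, 0, 1]

precision, recall = precision_recall(y_true, y_pred)
print("Precision: {:.2f}%".format(100 * precision))  # 60.00%
print("Recall: {:.2f}%".format(100 * recall))        # 75.00%
```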