In [15]:
import re
import html
import pandas as pd
import nltk
import networkx as nx
from collections import defaultdict
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from bs4 import BeautifulSoup

lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

class GraphBasedSpamFilter:
    def __init__(self, similarity_threshold=0.2,
                 min_df=5,
                 max_df_ratio=0.8,
                 top_k=10):
        self.graph = nx.Graph()
        self.messages = []               # list of (full_text, label)
        self.message_tokens = []         # list of precomputed token sets
        self.token_index = defaultdict(set)
        self.similarity_threshold = similarity_threshold
        self.min_df = min_df
        self.max_df_ratio = max_df_ratio
        self.top_k = top_k

    def preprocess(self, text):
        try:
            text = "" if text is None else str(text)
            text = html.unescape(text.lower())
            if "<" in text and ">" in text:
                text = BeautifulSoup(text, "html.parser").get_text(" ")
            text = re.sub(r"<[^>]+>", " ", text)
            text = re.sub(r"\S+@\S+", " ", text)
            text = re.sub(r"http\S+|www\.\S+", " ", text)
            text = re.sub(r"[^a-z0-9\s]", " ", text)
            text = re.sub(r"\s+", " ", text).strip()
            tokens = nltk.word_tokenize(text)
            clean = [
                lemmatizer.lemmatize(w)
                for w in tokens
                if w not in stop_words and len(w) > 1
            ]
            return set(clean)
        except Exception as e:
            print("Preprocess error:", e)
            return set()

    def _meta_tokens(self, meta: dict):
        toks = set()
        # domain token
        dom = meta.get("from_domain")
        if dom:
            toks.add(f"dom_{dom}")
        # attachment
        if meta.get("has_attachment", False):
            toks.add("has_attachment")
        # subject length bucket
        sl = meta.get("subject_length", 0)
        toks.add(f"subj_len_{(sl//50)*50}")
        # to_count bucket
        tc = meta.get("to_count", 0)
        toks.add(f"to_cnt_{min(tc,10)}")
        # url_count bucket
        uc = meta.get("url_count", 0)
        toks.add(f"url_cnt_{min(uc,5)}")
        # hour/weekday
        hr = meta.get("hour")
        if hr is not None and hr >= 0:
            toks.add(f"hour_{hr}")
            
        wd = meta.get("weekday")
        if wd is not None and wd >= 0:
            toks.add(f"wkday_{wd}")
        return toks

    def cosine_similarity(self, set1, set2):
        inter = len(set1 & set2)
        return inter / ((len(set1)*len(set2))**0.5 + 1e-9)

    def add_message(self, subject, body, label, metadata=None):
        # text tokens
        subj_toks = self.preprocess(subject)
        body_toks = self.preprocess(body)
        tokens = subj_toks | body_toks
        # metadata tokens
        if metadata:
            tokens |= self._meta_tokens(metadata)

        idx = len(self.messages)
        self.messages.append((f"{subject} {body}", label))
        self.message_tokens.append(tokens)
        self.graph.add_node(idx, subject=subject, body=body, label=label)

        # connect to prior messages
        candidates = {i for t in tokens for i in self.token_index[t]}
        for i in candidates:
            sim = self.cosine_similarity(tokens, self.message_tokens[i])
            if sim > self.similarity_threshold:
                self.graph.add_edge(idx, i, weight=sim)
        for t in tokens:
            self.token_index[t].add(idx)

    def predict_spam_subject_body(self, subject, body, metadata=None):
        subj_toks = self.preprocess(subject)
        body_toks = self.preprocess(body)
        tokens = subj_toks | body_toks
        if metadata:
            tokens |= self._meta_tokens(metadata)

        scores = []
        candidates = {i for t in tokens for i in self.token_index.get(t, [])}
        for i in candidates:
            sim = self.cosine_similarity(tokens, self.message_tokens[i])
            if sim > self.similarity_threshold:
                scores.append(self.messages[i][1])
        return 1.0 if sum(scores)/max(1, len(scores)) > 0.5 else 0.0

    def train_with_dataframe(self, df):
        # Expect df columns: from, to, date, subject, body, label
        for _, row in df.iterrows():
            meta = {
                "from_domain": row["from"].split("@")[-1] if pd.notnull(row["from"]) and "@" in row["from"] else None,
                "to_count": len(str(row["to"]).split(",")) if pd.notnull(row["to"]) else 0,
                "subject_length": len(str(row["subject"] or "")),
                "has_attachment": False,
                "hour": pd.to_datetime(row["date"], errors="coerce").hour if pd.notnull(row["date"]) else None,
                "weekday": pd.to_datetime(row["date"], errors="coerce").weekday() if pd.notnull(row["date"]) else None,
                "url_count": len(re.findall(r"http[s]?://\S+", str(row["body"] or "")))
            }
            self.add_message(row["subject"], row["body"], row["label"], metadata=meta)

        N = len(self.messages)
        # filter tokens by df
        df_counts = {t: len(idxs) for t, idxs in self.token_index.items()}
        valid = {t for t, c in df_counts.items()
                 if c >= self.min_df and c <= self.max_df_ratio * N}

        # rebuild index & tokens
        self.token_index = defaultdict(set)
        for i, toks in enumerate(self.message_tokens):
            filt = toks & valid
            self.message_tokens[i] = filt
            for t in filt:
                self.token_index[t].add(i)

        # prune edges to top_k
        for u in list(self.graph.nodes()):
            nbrs = [(v, self.graph[u][v]['weight']) for v in self.graph[u]]
            nbrs.sort(key=lambda x: -x[1])
            keep = {v for v, _ in nbrs[:self.top_k]}
            for v in list(self.graph[u]):
                if v not in keep:
                    self.graph.remove_edge(u, v)

In [16]:
import pandas as pd
df = pd.read_csv('emails_dataset.csv')

In [17]:
df.sample(5)

Unnamed: 0,subject,from,to,date,body,label
7956,Re: RH 8 no DMA for DVD drive,Chris Kloiber <ckloiber@ckloiber.com>,rpm-zzzlist@freshrpms.net,"Tue, 08 Oct 2002 23:23:19 -0400","On Tue, 2002-10-08 at 04:48, Panu Matilainen w...",0
16752,Bosnia goes it alone for first full poll,guardian <rssfeeds@example.com>,yyyy@example.com,"Sat, 05 Oct 2002 08:00:50 -0000","URL: http://www.newsisfree.com/click/-6,857278...",0
12172,try viagra for free,,,,always wanted to try the drug the world has be...,1
1681,[ILUG] Join the Web's Fastest Growing Singles ...,"""RankMyPix.com"" <marjani@email2.qves.net>",ilug@linux.ie,"Fri, 23 Aug 2002 04:10:14 -0600",1) Fight The Risk of Cancer!\nhttp://www.adcli...,1
7489,[Spambayes] Current histograms,tim.one@comcast.net,,"Mon, 09 Sep 2002 23:18:25 -0400",We've not only reduced the f-p and f-n rates i...,0


In [18]:
df = df.dropna(subset=['body','label','subject'])

In [19]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(df,random_state=42)

In [22]:
from sklearn.metrics import confusion_matrix,accuracy_score,f1_score,recall_score,precision_score

graph_filter = GraphBasedSpamFilter(
    similarity_threshold=0.2,
    min_df=5,
    max_df_ratio=0.7,
    top_k=10
)
graph_filter.train_with_dataframe(train)

# 4) Helper to extract metadata from a row
def build_meta(row):
    return {
        "from_domain": row['from'].split('@')[-1] if pd.notnull(row['from']) and '@' in row['from'] else None,
        "to_count": len(str(row['to']).split(',')) if pd.notnull(row['to']) else 0,
        "subject_length": len(str(row['subject'] or "")),
        "has_attachment": False,
        "hour": pd.to_datetime(row['date'], errors='coerce').hour if pd.notnull(row['date']) else None,
        "weekday": pd.to_datetime(row['date'], errors='coerce').weekday() if pd.notnull(row['date']) else None,
        "url_count": len(re.findall(r'http[s]?://\S+', str(row['body'] or "")))
    }

# 5) Predict on test set
y_true = test['label'].tolist()
y_pred = []
for _, row in test.iterrows():
    meta = build_meta(row)
    y_pred.append(
        graph_filter.predict_spam_subject_body(
            row['subject'],
            row['body'],
            metadata=meta
        )
    )

# 6) Compute and print metrics
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()
acc  = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, zero_division=0)
rec  = recall_score(y_true, y_pred, zero_division=0)
f1   = f1_score(y_true, y_pred, zero_division=0)
spec = tn / (tn + fp) if (tn + fp)>0 else 0

print(f"Acc={acc:.3f}, Prec={prec:.3f}, Rec={rec:.3f}, F1={f1:.3f}, Spec={spec:.3f}")

Acc=0.950, Prec=0.928, Rec=0.914, F1=0.921, Spec=0.967
