# 1. Spam and Ham dataset

In [1]:
import os
import shutil
import codecs

In [2]:
train_spam_path = "./data/train/spam"
train_ham_path = "./data/train/ham"
test_spam_path = "./data/test/spam"
test_ham_path = "./data/test/ham"

if not os.path.exists(train_spam_path):
    os.makedirs(train_spam_path)
if not os.path.exists(train_ham_path):
    os.makedirs(train_ham_path)
if not os.path.exists(test_spam_path):
    os.makedirs(test_spam_path)
if not os.path.exists(test_ham_path):
    os.makedirs(test_ham_path)

# 2. File Encoding Format Conversion 

In [3]:
def file_encoding_format_conversion_save(file_in, file_out, encoding_in='gb2312', encoding_out='utf-8'):
    with codecs.open(filename=file_in, mode='r', encoding=encoding_in, errors='replace') as fi:
        data = fi.read()
        with open(file_out, mode='w', encoding=encoding_out) as fo:
            fo.write(data)

In [5]:
train_spam_path_list = []
train_ham_path_list = []
test_spam_path_list = []
test_ham_path_list = []


label = "./data/trec06c/full/index"
index = 0
            
for line in open(label):
    info = line.split()
    
    if index <= 50000:
        src_file_path = './data/trec06c/data/' + info[1][-8:]
        dst_file_path = "./data/train/" + info[0] + "/" + str(index)
        # shutil.copyfile(src_file_path, dst_file_path)
        
    
        file_encoding_format_conversion_save(file_in=src_file_path, 
                                             file_out=dst_file_path, 
                                             encoding_in='gb2312', 
                                             encoding_out='utf-8')
        
        if info[0] == "spam":
            train_spam_path_list.append(dst_file_path)
        else:
            train_ham_path_list.append(dst_file_path)
            
    else:
        src_file_path = './data/trec06c/data/' + info[1][-8:]
        dst_file_path = "./data/test/" + info[0] + "/" + str(index)
        # shutil.copyfile(src_file_path, dst_file_path)
        
        file_encoding_format_conversion_save(file_in=src_file_path, 
                                             file_out=dst_file_path, 
                                             encoding_in='gb2312', 
                                             encoding_out='utf-8')
        
        if info[0] == "spam":
            test_spam_path_list.append(dst_file_path)
        else:
            test_ham_path_list.append(dst_file_path)
            
    index += 1
    if index % 5000 == 0:
        print("No. {} completed...".format(index))

No. 5000 completed...
No. 10000 completed...
No. 15000 completed...
No. 20000 completed...
No. 25000 completed...
No. 30000 completed...
No. 35000 completed...
No. 40000 completed...
No. 45000 completed...
No. 50000 completed...
No. 55000 completed...
No. 60000 completed...


In [6]:
train_spam_path_list[0]

'./data/train/spam/0'

# 3. Stop List

In [29]:
def get_stop_words(stop_list_path):
    stop_words_list = []
    for line in open(stop_list_path):
        stop_words_list.append(line[:len(line)-1])
    return stop_words_list

stop_words_list = get_stop_words(stop_list_path="./data/stop_words.txt")

In [380]:
stop_words_list[:5]

['--', '?', '“', '”', '》']

# 3. jieba Library

In [232]:
def get_the_word_set_of_single_email(email_path):
    word_list = list()
    for line in open(email_path):
        rule = re.compile(r"[^\u4e00-\u9fa5]")
        line = rule.sub("", line)
        word_list += list(jieba.cut(line))
    
    word_set = set()
    for item in word_list:
        if item not in stop_words_list and item.strip() != '' and item != None:
            word_set.add(item)
            
    return word_set

def get_the_dict_by_emails(email_paths_list):
    index = 0
    
    all_emails_word_dict = dict()
    all_emails_word_set = set()
    
    for path in tqdm(email_paths_list):
        current_email_set = get_the_word_set_of_single_email(path)
        all_emails_word_dict[index] = current_email_set
        all_emails_word_set = all_emails_word_set | current_email_set
        index += 1

    return (all_emails_word_dict, all_emails_word_set)

In [236]:
train_spam_word_dict, train_spam_word_set = get_the_dict_by_emails(train_spam_path_list)

100% 33262/33262 [06:09<00:00, 89.90it/s] 


In [237]:
train_ham_word_dict, train_ham_word_set = get_the_dict_by_emails(train_ham_path_list)

100% 16739/16739 [03:08<00:00, 88.58it/s] 


In [241]:
def get_word_frequency_dict(all_word_dict, all_word_set):
    word_frequency_dict = dict()
    for item in tqdm(all_word_set):
        index = 0
        for email_index, single_word_set in all_word_dict.items():
            if item in single_word_set:
                index += 1
        word_frequency_dict[item] = index
    return word_frequency_dict

In [249]:
train_spam_word_frequency_dict = get_word_frequency_dict(train_spam_word_dict, train_spam_word_set)

100% 78721/78721 [07:52<00:00, 166.55it/s]


In [244]:
train_ham_word_frequency_dict = get_word_frequency_dict(train_ham_word_dict, train_ham_word_set)

100% 99681/99681 [04:04<00:00, 407.95it/s]


In [253]:
train_spam_number = len(train_spam_path_list)
train_ham_number = len(train_ham_path_list)
train_spam_number, train_ham_number

(33262, 16739)

In [252]:
p_prior_spam = train_spam_number / (train_spam_number + train_ham_number)
p_prior_ham = train_ham_number / (train_spam_number + train_ham_number)
p_prior_spam, p_prior_ham

(0.6652266954660907, 0.3347733045339093)

In [323]:
def get_test_word_prob(test_email_word_set,
                       spam_word_frequency_dict=train_spam_word_frequency_dict,
                       ham_word_frequency_dict=train_ham_word_frequency_dict,
                       spam_number=train_spam_number,
                       ham_number=train_ham_number
                      ):
    
    word_prob_dict = dict()
    
    for word in test_email_word_set:
        if word in spam_word_frequency_dict.keys() and word in ham_word_frequency_dict.keys():
            p_word_spam = spam_word_frequency_dict[word] / spam_number
            p_word_ham = ham_word_frequency_dict[word] / ham_number
            word_prob_dict.setdefault(word, (p_word_spam, p_word_ham))
        elif word in spam_word_frequency_dict.keys() and word not in ham_word_frequency_dict.keys():
            p_word_spam = spam_word_frequency_dict[word] / spam_number
            p_word_ham = 0.01
            word_prob_dict.setdefault(word, (p_word_spam, p_word_ham))
        elif word not in spam_word_frequency_dict.keys() and word in ham_word_frequency_dict.keys():
            p_word_spam = 0.01
            p_word_ham = ham_word_frequency_dict[word] / ham_number
            word_prob_dict.setdefault(word, (p_word_spam, p_word_ham))
        elif word not in spam_word_frequency_dict.keys() and word not in ham_word_frequency_dict.keys():
            #若该词不在脏词词典中，概率设为0.4
            p_word_spam = 0.01
            p_word_ham = 0.01
            word_prob_dict.setdefault(word, (p_word_spam, p_word_ham))
            
    # print(sorted(word_prob_dict.items(), key=lambda d:d[1], reverse=True)[0:15])
    return word_prob_dict

In [357]:
def get_test_single_email_bayes(email_path, p_prior_spam=p_prior_spam, p_prior_ham=p_prior_ham):
    test_email_word_set = get_the_word_set_of_single_email(email_path)
    email_word_prob_dict = get_test_word_prob(test_email_word_set)
    
    email_word_prob_list = sorted(email_word_prob_dict.items(), key=lambda d:d[1], reverse=True)[0:15]
    
    p_word_spam = float(p_prior_spam) * 1e6
    p_word_ham = float(p_prior_ham) * 1e6
    
    """
    for word, (spam_prob, ham_prob) in email_word_prob_dict.items():
        p_word_spam *= (spam_prob)
        p_word_ham *= (ham_prob)
    """
    for word, (spam_prob, ham_prob) in email_word_prob_list:
        p_word_spam *= (spam_prob)
        p_word_ham *= (ham_prob)
        
    p = p_word_spam / (p_word_spam + p_word_ham + 1e-12)
    
    return p

In [372]:
def get_acc(email_paths, label="spam"):
    
    num = len(email_paths)
    
    if label == "spam":
        index = True
    else:
        index = False
        
    correct_conuts = 0
    wrong_counts = 0
    for path in tqdm(email_paths):
        p = get_test_single_email_bayes(path)
        if p >= 0.9:
            correct_conuts = correct_conuts + 1
        else:
            wrong_counts = wrong_counts + 1
    
    return (correct_conuts, wrong_counts) if index else (wrong_counts, correct_conuts)

In [373]:
get_acc(test_spam_path_list, label="spam")

100% 9592/9592 [01:21<00:00, 118.28it/s]


(8431, 1161)

In [374]:
get_acc(test_ham_path_list, label="ham")

100% 5027/5027 [00:39<00:00, 128.36it/s]


(4353, 674)

In [376]:
8431 / 9592, 1161 / 9592

(0.8789616346955796, 0.12103836530442035)

In [377]:
4353 / 5027, 674 / 5027

(0.8659240103441417, 0.13407598965585837)

In [378]:
(8473 + 4353) / (9592 + 5027)

0.8773513920240783