In [1]:
#Data set loading,the data is preliminarily processed
import math
import numpy as np
import re
from sklearn.metrics import accuracy_score, classification_report, precision_score, recall_score

In [2]:
def getDateSet(dataPath=''):
    with open(dataPath, encoding='utf-8') as f:
        txt_data = f.readlines()
    data=[]
    classTag=[]
    for line in txt_data:
        line_split = line.strip('\n').split('\t')
        if line_split[0]=='ham':
            data.append(line_split[1])
            classTag.append(1)
        elif line_split[0]=='spam':
            data.append(line_split[1])
            classTag.append(0)
    return data, classTag

In [3]:
class NaiveBates:
    def __init__(self):
        self.__ham_count=0
        self.__spam_count=0

        self.__ham_words_count=0
        self.__spam_words_count=0

        self.__ham_words=list()
        self.__spam_words=list()

        self.__word_dictionary_set=set()
        self.__word_dictionary_size=0

        self.__ham_map=dict()
        self.__spam_map=dict()

        self.__ham_probability=0.0
        self.__spam_probability=0.0

    
    def data_preprocess(self, sentence):
        temp_info=re.sub('\W',' ',sentence.lower())
        words=re.split(r'\s+',temp_info)
        return list(filter(lambda x:len(x)>=3,words)) 

    def fit(self, X_train, Y_train):
        words_line=[]
        for sentence in X_train:
            words_line.append(self.data_preprocess(sentence))
        self.build_word_set(words_line, Y_train)
        self.word_count()

    def build_word_set(self, X_train, y_train):
        for words, y in zip(X_train, y_train):
            if y == 0:
                self.__ham_count += 1
                self.__ham_words_count += len(words)
                for word in words:
                    self.__ham_words.append(word)
                    self.__word_dictionary_set.add(word)
            if y == 1:
                self.__spam_count += 1
                self.__spam_words_count += len(words)
                for word in words:
                    self.__spam_words.append(word)
                    self.__word_dictionary_set.add(word)

        self.__word_dictionary_size = len(self.__word_dictionary_set)

    def word_count(self):
        for word in self.__ham_words:
            self.__ham_map[word] = self.__ham_map.setdefault(word, 0) + 1

        for word in self.__spam_words:
            self.__spam_map[word] = self.__spam_map.setdefault(word, 0) + 1

        self.__ham_probability = self.__ham_count / (self.__ham_count + self.__spam_count)
        self.__spam_probability = self.__spam_count / (self.__ham_count + self.__spam_count)
        print("正常短信词频：{}".format(self.__ham_map))
        print("垃圾短信词频：{}".format(self.__spam_map))
    
    def predict(self, X_test):
        return [self.predict_one(sentence) for sentence in X_test]

    def predict_one(self, sentence):
        ham_pro = 0
        spam_pro = 0
        words = self.data_preprocess(sentence)
        for word in words:
            ham_pro += math.log(
                (self.__ham_map.get(word, 0) + 1) / 
                (self.__ham_count + self.__word_dictionary_size))

            spam_pro += math.log(
                (self.__spam_map.get(word, 0) + 1) / 
                (self.__spam_count + self.__word_dictionary_size))

        ham_pro += math.log(self.__ham_probability)
        spam_pro += math.log(self.__spam_probability)
        return int(spam_pro >= ham_pro)

In [4]:
if __name__ == '__main__':
    data, classTag = getDateSet(dataPath="./SMSSpamCollection")
    train_size = 300
    train_x, trian_y = data[:train_size], classTag[:train_size]
    test_x = data[train_size:]
    test_y = classTag[train_size:]
    nb_model =  NaiveBates()
    nb_model.fit(train_x, trian_y)
    pre_y = nb_model.predict(test_x)

    accuracy_score_value = accuracy_score(test_y, pre_y)
    recall_score_value = recall_score(test_y, pre_y)
    precision_score_value = precision_score(test_y, pre_y)
    classification_report_value = classification_report(test_y, pre_y)
    print("准确率:", accuracy_score_value)
    print("召回率:", recall_score_value)
    print("精确率:", precision_score_value)
    print(classification_report_value)

正常短信词频：{'free': 14, 'entry': 3, 'wkly': 3, 'comp': 1, 'win': 3, 'cup': 1, 'final': 2, 'tkts': 1, '21st': 1, 'may': 2, '2005': 1, 'text': 9, '87121': 2, 'receive': 5, 'question': 1, 'std': 2, 'txt': 11, 'rate': 1, 'apply': 3, '08452810075over18': 1, 'freemsg': 2, 'hey': 2, 'there': 1, 'darling': 1, 'been': 4, 'week': 3, 'now': 8, 'and': 6, 'word': 2, 'back': 1, 'like': 1, 'some': 1, 'fun': 1, 'you': 26, 'for': 10, 'still': 1, 'xxx': 1, 'chgs': 1, 'send': 4, 'rcv': 2, 'winner': 3, 'valued': 2, 'network': 3, 'customer': 6, 'have': 10, 'selected': 4, 'receivea': 1, '900': 2, 'prize': 8, 'reward': 1, 'claim': 11, 'call': 17, '09061701461': 1, 'code': 5, 'kl341': 1, 'valid': 3, 'hours': 1, 'only': 4, 'had': 1, 'your': 13, 'mobile': 8, 'months': 1, 'more': 3, 'entitled': 1, 'update': 2, 'the': 15, 'latest': 1, 'colour': 1, 'mobiles': 2, 'with': 6, 'camera': 2, '08002986030': 1, 'six': 1, 'chances': 1, 'cash': 4, 'from': 7, '100': 3, '000': 2, 'pounds': 2, 'csh11': 1, '87575': 1, 'cost': 1, '1