In [31]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.metrics import recall_score, precision_score, f1_score, accuracy_score
from _utils import u_constant
path = u_constant.PATH_ROOT + "for learn/Python/NLP_in_Action/chapter-9/classification/"
ham_path = path + "data/ham_data.txt"
spam_path = path + "data/spam_data.txt"

In [23]:
def load_data_with_label(file_path, label):
    """
     加载数据，并赋予标签
    :param file_path: 文件路径
    :param label: 待赋予的标签
    """
    corpus = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if len(line) > 0:
                corpus.append(line)
        f.close()
    data = pd.DataFrame(corpus, columns=["text"])
    data["label"] = label
    return data

In [9]:
def get_data():
    """
    分别load正例样本和负例样本，组装输出
    """
    ham_data = load_data_with_label(ham_path, 1)
    spam_data = load_data_with_label(spam_path, 0)
    data = pd.concat([ham_data, spam_data], axis=0, ignore_index=True)
    return data        

In [35]:
class Treat:
    import string
    import re
    import jieba
    from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
    
    def __init__(self, feature_extract_method="tfidf", min_df=1, ngram_range=(1, 1)):
        self.stopword_list = self.__load_stop_word()
        self.invalid_pattern = re.compile("[{}]".format(re.escape(string.punctuation)))
        self.extract_method = self.__identify_method(feature_extract_method.lower())
        self.min_df = min_df
        self.ngram_range = ngram_range
    
    def __identify_method(self, s):
        """
        识别抽取方法
        """
        if s in ["tfidf", "tf-idf", "tf_idf"]:
            return "tfidf"
        return "bow"
        
    def __load_stop_word(self):
        with open(path + "dict/stop_words.utf8", encoding="utf-8") as f:
            stopword_list = f.readlines()
            f.close()
        return stopword_list
    
    def preprocess(self, text):
        # 去除特殊符号
        text = self.invalid_pattern.sub("", text)
        # 分词 & 去停用词
        tokens = [token.strip() for token in jieba.cut(text) \
                  if token.strip() not in self.stopword_list]
        return " ".join(tokens)
    
    def fit_transform(self, corpus):
        """
        处理输入的语料，输出向量化矩阵
        """
        normed_data = list(map(self.preprocess, corpus))
        if self.extract_method == "bow":
            self.vec = CountVectorizer(min_df=self.min_df, ngram_range=self.ngram_range)
        elif self.extract_method == "tfidf":
            self.vec = TfidfVectorizer(min_df=self.min_df, ngram_range=self.ngram_range)
        features = self.vec.fit_transform(normed_data)
        return features
    
    def transform(self, corpus):
        normed_data = list(map(self.preprocess, corpus))
        return self.vec.transform(normed_data)
        

In [33]:
data = get_data()
X = data["text"].values
y = data["label"].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

In [38]:
extract_methods = ["bow", "tfidf"]
models = ["mnb", "svm", "lr"]
min_df = 1
ngram_range = (1, 1)
for extract_method in extract_methods:
    treat = Treat(extract_method, min_df, ngram_range)
    train = treat.fit_transform(X_train)
    test = treat.transform(X_test)
    print("-------------------------------------------")
    print("--------------%s based-------------" % extract_method)
    for model in models:
        if model == "mnb":
            m = MultinomialNB()
        elif model == "svm":
            m = SGDClassifier(loss="hinge", n_iter=100)
        elif model == "lr":
            m = LogisticRegression()
        m.fit(train, y_train)
        y_pred = m.predict(test)
        acc = accuracy_score(y_test, y_pred)
        recall = recall_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred)
        print("%s model" % model)
        print("Acc: %.2f\tPrecision: %.2f\tRecall: %.2f\tF1: %.2f" \
              % (acc, precision, recall, f1))

-------------------------------------------
--------------bow based-------------
mnb model
Acc: 0.99	Precision: 0.98	Recall: 0.99	F1: 0.99
svm model
Acc: 0.98	Precision: 0.99	Recall: 0.97	F1: 0.98
lr model
Acc: 0.99	Precision: 1.00	Recall: 0.98	F1: 0.99
-------------------------------------------
--------------tfidf based-------------
mnb model
Acc: 0.99	Precision: 0.98	Recall: 0.99	F1: 0.99
svm model
Acc: 0.99	Precision: 0.99	Recall: 0.99	F1: 0.99
lr model
Acc: 0.99	Precision: 0.98	Recall: 0.99	F1: 0.99
