In [1]:
from sklearn.svm import SVC
from gensim.models.doc2vec import Doc2Vec
import re
import jieba
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.externals import joblib 
import random



In [2]:
divorce_model=Doc2Vec.load("divorce/divorce.model")
loan_model=Doc2Vec.load("loan/loan.model")
labor_model=Doc2Vec.load("labor/labor.model")

In [3]:
data_input = {1: "divorce/data.txt", 2: "labor/data.txt", 3: "loan/data.txt"}
models={1:divorce_model,2:loan_model,3:labor_model}
data_type = {1: "divorce", 2: "labor", 3: "loan"}
label_size = 20
test_ratio = 0.3
punction = "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."

In [9]:
def line_processing(line):  # 提取每行数据的文本内容
    line = line.strip().split('\t')
    sentence = line[1]
    sentence = re.sub(r'[{}]'.format(punction),' ',sentence).split(' ')
    sent=[]
    for sub_sentence in sentence:
        if sub_sentence!='':
            sent.extend(list(jieba.cut(sub_sentence)))
    return line[0], sent, line[2]

In [10]:
def constructDataSet(data_path,model_tag): #构建X，Y的数据集
    data_file = open(data_path,'r',encoding='utf-8')
    lines = data_file.read().splitlines()
    X=[]
    Y=[]
    d2v = models[model_tag]
    for line in lines:
        _,x,y = line_processing(line)
        x=d2v.infer_vector(x)
        X.append(x)
        y = list(map(int,y.split()))
        Y.append(y)
    Y = np.array(Y).transpose()
    return X,Y

In [11]:
def splitDataSet(X,Y): #构建训练集和测试集
    X_train,X_test,Y_train,Y_test =  train_test_split(X,Y,test_size=test_ratio,shuffle=True)
    return X_train,X_test,Y_train,Y_test

In [12]:
def trainSVM(X,Y,model_path): #训练单个分类器
    X_train,X_test,Y_train,Y_test = splitDataSet(X,Y)
    classifier = SVC(gamma='auto')
    classifier.fit(X_train,Y_train)
    accuracy = classifier.score(X_test,Y_test)
    joblib.dump(classifier,model_path)
    print(model_path,accuracy)

In [13]:
def beginTrain():
    for i in range(1,4):
        print(data_input[i].split("/")[0])
        X,Y = constructDataSet(data_input[i],i)
        tag = list(range(len(Y)))
        random.shuffle(tag)
        for j in tag:
            model_path = data_input[i].split("/")[0]+'/label'+str(j+1)+".model"
            print(model_path)
            trainSVM(X,Y[j],model_path)

In [14]:
def predictSingleLabel(x,model_type,label_tag):
    model_path = model_type+'/'+'label'+str(label_tag+1)+'.model'
    model = joblib.load(model_path)
    return model.predict([x])

In [15]:
def predict(text,model_type):
    text = re.sub(r'[{}]'.format(punction),' ',text).split(' ')
    words=[]
    for word in text:
        words.extend(list(jieba.cut(word)))
    d2v = models[model_type]
    x = d2v.infer_vector(words)
    res=[]
    for i in range(label_size):
        res.append(predictSingleLabel(x,data_type[model_type],i)[0])
    return res

In [16]:
beginTrain()

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\RPJ\AppData\Local\Temp\jieba.cache


divorce


Loading model cost 0.852 seconds.
Prefix dict has been built succesfully.


divorce/label10.model
divorce/label10.model 0.9704400722364794
divorce/label12.model
divorce/label12.model 0.9840319361277445
divorce/label11.model
divorce/label11.model 0.9884991920920064
divorce/label6.model
divorce/label6.model 0.9625510882995912
divorce/label17.model
divorce/label17.model 0.9963881760288946
divorce/label8.model
divorce/label8.model 0.9684440642524474
divorce/label3.model
divorce/label3.model 0.8966828248265374
divorce/label20.model
divorce/label20.model 0.9942971200456231
divorce/label19.model
divorce/label19.model 0.9940119760479041
divorce/label9.model
divorce/label9.model 0.9672084402623324
divorce/label14.model
divorce/label14.model 0.9907803440737573
divorce/label5.model
divorce/label5.model 0.9564680163482558
divorce/label18.model
divorce/label18.model 0.9944872160441023
divorce/label13.model
divorce/label13.model 0.9870734721034122
divorce/label2.model
divorce/label2.model 0.862370497101036
divorce/label4.model
divorce/label4.model 0.9449672084402624
divorce