In [1]:
from sklearn import model_selection
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pylab import *
%matplotlib inline 
%config InlineBackend.figure_format="retina" 

In [2]:
# Check whether every sample in the current data set has the same class label.
def same_category(D):
    """Report whether all rows of ``D`` carry one and the same class label.

    Parameters
    ----------
    D : 2-D array supporting ``D[:, -1]``
        Data set whose last column is read as the class label.

    Returns
    -------
    (flag, category) : tuple
        ``flag`` is True when ``D`` has at least one row and every label is
        equal; ``category`` is then that label (a plain Python value, as
        produced by ``tolist()``). In every other case, including an empty
        data set, the result is ``(False, None)`` and ``category`` must not
        be interpreted.
    """
    classList = D[:, -1].tolist()

    # An empty data set has no common category; previously this path raised
    # because the loop variable was never bound.
    if not classList:
        return False, None

    # Compare everything against the first label: a single O(n) pass that
    # needs no sentinel value, so labels such as -1 are handled correctly.
    first = classList[0]
    if all(label == first for label in classList):
        return True, first
    return False, None

# Find the class with the most samples in data set D.
def findMost(D):
    """Return the name of the majority class in ``D``.

    Parameters
    ----------
    D : pandas.DataFrame or 2-D array
        For a DataFrame the labels are read from the ``'好瓜'`` column.
        Any other input is treated as an array whose last column holds the
        labels, which is the layout the rest of this notebook slices with
        ``D[:, -1]``.

    Returns
    -------
    str
        ``'好瓜'`` (good melon) when the number of labels equal to 1 is at
        least the number of labels equal to 0, otherwise ``'坏瓜'``
        (bad melon). Ties therefore resolve to ``'好瓜'``.
    """
    if isinstance(D, pd.DataFrame):
        labels = D['好瓜'].to_numpy()
    else:
        # Array input: the tree-building code hands over plain ndarrays.
        labels = np.asarray(D)[:, -1]

    good = np.count_nonzero(labels == 1)
    bad = np.count_nonzero(labels == 0)
    return '好瓜' if good >= bad else '坏瓜'



# Split the data set according to beta.
def getDv_b(D,beta):
    """Partition the rows of ``D`` by the sign of their linear score.

    The score of a row is the dot product of all its columns except the
    last one with ``beta``.

    Returns
    -------
    (dataSet1, dataSet2) : tuple of arrays
        ``dataSet1`` holds the rows whose score is ``<= 0`` and
        ``dataSet2`` the rows whose score is ``> 0``.
    """
    features = D[:, :-1]
    scores = np.dot(features, beta).reshape(-1)
    non_positive = scores <= 0
    positive = scores > 0
    return D[non_positive], D[positive]



In [3]:
# Data processing
def ReadData(pathname):
    """Load the CSV at ``pathname`` and return it in encoded numeric form.

    Steps, in order:
      1. read the file with ``pd.read_csv``;
      2. rewrite the ``'好瓜'`` column so that ``'是'`` becomes 1 and
         ``'否'`` becomes 0;
      3. drop the first column (presumably a row id - verify against the
         CSV);
      4. pass the frame and its ``'好瓜'`` column to ``oneHotData``.

    Returns
    -------
    (data, attribute) : tuple
        ``data`` is whatever ``oneHotData`` returns; ``attribute`` is the
        index of column names of the frame after step 3, minus the last
        column.
    """
    frame = pd.read_csv(pathname)

    # Encode the textual label in place, '是' first and then '否'.
    for text, code in (('是', 1), ('否', 0)):
        frame.loc[frame['好瓜'] == text, '好瓜'] = code

    frame = frame.iloc[:, 1:]
    attribute = frame.columns[:-1]

    encoded = oneHotData(frame, frame['好瓜'])
    return encoded, attribute

def oneHotData(df,classLabel):
    """Build the numeric design matrix used by the tree.

    Each of the six categorical columns (color, root, knock sound, texture,
    navel, touch) is expanded with ``pd.get_dummies`` using the column name
    as prefix; the two continuous columns (density, sugar content) are
    appended unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 色泽, 根蒂, 敲声, 纹理, 脐部, 触感, 密度
        and 含糖率.
    classLabel : array-like
        One label per row of ``df``; converted with ``dtype="int"``.

    Returns
    -------
    numpy.ndarray
        Column layout: a leading column of ones, then the one-hot blocks in
        the order listed above, then 密度 and 含糖率, and finally the class
        label as the last column.
    """
    categorical = ("色泽", "根蒂", "敲声", "纹理", "脐部", "触感")
    continuous = ["密度", "含糖率"]

    # One dummy-encoded frame per categorical column, followed by the raw
    # continuous columns.
    pieces = [pd.get_dummies(df[name], prefix=name) for name in categorical]
    pieces.append(df[continuous])

    features = np.asarray(pd.concat(pieces, axis=1), dtype="float64")
    labels = np.asarray(classLabel, dtype="int").reshape(-1, 1)

    # The constant first column lets the intercept be learned as beta[0].
    bias = np.ones((df.shape[0], 1))
    return np.concatenate((bias, features, labels), axis=1)

In [4]:
def sigmoid(Z):
    """Element-wise logistic function ``1 / (1 + exp(-Z))``.

    ``Z`` may be a scalar or a numpy array; the result has the same shape.
    """
    decay = np.exp(-Z)
    return 1.0 / (decay + 1)

def gradDescent(data,label,eta=0.1,n_iters=500):
    """Fit logistic-regression weights with full-batch gradient descent.

    Parameters
    ----------
    data : 2-D numpy array
        One row per sample, one column per feature.
    label : numpy array
        Target values; reshaped to a column vector.
    eta : float
        Step size applied to every update.
    n_iters : int
        Number of update steps.

    Returns
    -------
    numpy.ndarray of shape (n_features, 1)
        Weights after ``n_iters`` updates, starting from all ones.
    """
    _, n_features = data.shape
    targets = label.reshape(-1, 1)

    beta = np.ones((n_features, 1))

    for _ in range(n_iters):
        # Residual between the predicted probability and the target.
        residual = sigmoid(data.dot(beta)) - targets
        # X^T (p - y) is the unaveraged gradient of the log-loss.
        gradient = data.T.dot(residual)
        beta = beta - gradient * eta

    return beta


In [5]:
def _majority_leaf(D):
    """Return '好瓜' if labels equal to 1 are at least as frequent as labels
    equal to 0 in the last column of ``D``, otherwise '坏瓜'."""
    labels = D[:, -1]
    good = np.count_nonzero(labels == 1)
    bad = np.count_nonzero(labels == 0)
    return '好瓜' if good >= bad else '坏瓜'


def treeGenerate(D,root,lastNode,lastBeta):
    """Recursively grow a multivariate decision tree into ``root``.

    Every internal node is a linear split learned by ``gradDescent`` on the
    feature columns of ``D`` (all columns but the last) against the label
    column (the last one).

    Parameters
    ----------
    D : 2-D numpy array
        Samples for the current node; the last column is the label.
    root : dict
        Dictionary that receives the node built for ``D``. It is mutated
        in place; nothing is returned.
    lastNode : dict or None
        Parent node dictionary. When ``D`` turns out to be a leaf, the leaf
        label is written to ``lastNode[lastBeta]``. The top-level call in
        this notebook passes None, so a data set that is already pure at
        the top level cannot be recorded and raises TypeError.
    lastBeta : str or None
        Branch key (``"是"`` or ``"否"``) under which this subtree hangs in
        ``lastNode``.

    Resulting structure
    -------------------
    ``root[nodeTxt]`` is a dict with the keys ``nodeTxt`` (an empty dict),
    ``"是"`` (rows with score <= 0) and ``"否"`` (rows with score > 0).
    Each branch value is either a nested dict or one of the leaf strings
    ``'好瓜'`` / ``'坏瓜'``.
    """
    # Nothing to learn from an empty data set.
    if len(D) == 0:
        return

    flag, category = same_category(D)
    if flag:
        lastNode[lastBeta] = '好瓜' if category == 1 else '坏瓜'
        return

    # Only the label column is left: fall back to a majority vote.
    if len(D[0]) == 1:
        lastNode[lastBeta] = _majority_leaf(D)
        return

    bestBeta = gradDescent(D[:, :-1], D[:, -1])

    # Node label: one line per weight (the intercept bestBeta[0] is skipped
    # and moved to the right-hand side of the inequality).
    weight_lines = [
        "w" + str(i) + " " + str(bestBeta[i][0]) + ' \n '
        for i in range(1, len(bestBeta))
    ]
    nodeTxt = "".join(weight_lines) + "<=" + str(-bestBeta[0][0])

    root[nodeTxt] = {nodeTxt: {}}
    node = root[nodeTxt]

    # Split the data set according to beta.
    Dv_b1, Dv_b2 = getDv_b(D, bestBeta)

    # If the hyperplane puts every sample on one side, recursing would make
    # no progress (and the empty side cannot be evaluated at all), so both
    # branches become majority-vote leaves of the current data set.
    degenerate = len(Dv_b1) == 0 or len(Dv_b2) == 0

    for branch, subset in (("是", Dv_b1), ("否", Dv_b2)):
        if degenerate:
            node[branch] = _majority_leaf(D)
            continue
        node[branch] = {}
        treeGenerate(subset, node[branch], node, branch)

In [6]:
if __name__ == '__main__':
    # The tree is accumulated in place inside this initially empty dict.
    root={}

    # The top-level call has neither a parent node nor a branch key.
    lastNode = lastA = None

    data,attribute=ReadData('../../data/watermelon_3.csv')

    treeGenerate(data,root,lastNode,lastA)
    print(root)


{'w1 1.111893106641747 \n w2 -1.1959732358553834 \n w3 -0.05708423182290375 \n w4 -0.832401400006815 \n w5 -1.8906991065001257 \n w6 2.581936145470397 \n w7 -1.500237135644533 \n w8 2.191474174614811 \n w9 -0.832401400006815 \n w10 -2.1785770731647216 \n w11 2.176078711661651 \n w12 -0.1386659995334692 \n w13 2.2607505797621132 \n w14 -3.599223476240943 \n w15 1.197308535442288 \n w16 -1.0268188239240008 \n w17 -0.11434553711253821 \n w18 0.38751251737701736 \n w19 -3.1495818501742896 \n <=2.141164361036537': {'w1 1.111893106641747 \n w2 -1.1959732358553834 \n w3 -0.05708423182290375 \n w4 -0.832401400006815 \n w5 -1.8906991065001257 \n w6 2.581936145470397 \n w7 -1.500237135644533 \n w8 2.191474174614811 \n w9 -0.832401400006815 \n w10 -2.1785770731647216 \n w11 2.176078711661651 \n w12 -0.1386659995334692 \n w13 2.2607505797621132 \n w14 -3.599223476240943 \n w15 1.197308535442288 \n w16 -1.0268188239240008 \n w17 -0.11434553711253821 \n w18 0.38751251737701736 \n w19 -3.149581850174