# functions of ID3

## calc shannon entropy

In [4]:
from math import log

def calcShannonEnt(dataSet):
    numEntries=len(dataSet) # 数据集大小
    
    # get freq count for each class, using dict
    labelCounts={}
    for featVec in dataSet:
        currentLabel=featVec[-1] # the last item is the class name
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    
    # get shannon index for each class
    shannonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        shannonEnt -= prob*log(prob, 2) # -p*log2(p)
    return shannonEnt

# test
def createDataSet():
    dataSet=[[1,1,'yes'],
            [1,1,'yes'],
            [1,0,'no'],
            [0,1,'no'],
            [0,1,'no']]
    labels=['no surfacing', 'flippers']
    return dataSet, labels
myDat, labels=createDataSet()
print(myDat)
print(labels)

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
['no surfacing', 'flippers']


In [5]:
calcShannonEnt(myDat)

0.9709505944546686

In [6]:
# change a value, and get shannon again
myDat[0][-1]='maybe'
print(myDat)
calcShannonEnt(myDat)

[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]


1.3709505944546687

## split dataSet

In [13]:
def splitDataSet(dataSet, axis, value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            subVec=featVec[:axis] + featVec[axis+1:] # 相当于去掉了该位置的数字
            #subVec.extend()
            retDataSet.append(subVec)
    return retDataSet
# test
print(myDat)
print( splitDataSet(myDat, 0, 1) )
splitDataSet(myDat, 0, 0)

[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'maybe'], [1, 'yes'], [0, 'no']]


[[1, 'no'], [1, 'no']]

## choose best feature to split

In [25]:
def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0; bestFeature=-1
    for i in range(numFeatures):
        #1. get uniq label list of each feature
        valueList=[example[i] for example in dataSet]
        uniqValues=set(valueList)
        
        #2. calc entropy for each split
        newEntropy=0.0
        for value in uniqValues:
            subDataSet=splitDataSet(dataSet, i, value)
            prob=len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        print('=======> i=',i,infoGain,baseEntropy,newEntropy)
        
        #record the best infoGain
        if infoGain > bestInfoGain:
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature
# test
myDat, labels=createDataSet()
chooseBestFeatureToSplit(myDat)



0

In [18]:
[example[0] for example in myDat]

[1, 1, 1, 0, 0]

In [19]:
set([example[0] for example in myDat])

{0, 1}

## majority class

In [27]:
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount:
            classCount[vote]=0
        classCount[vote]+=1
    # get max
    maxK=0; maxV=0;
    for k,v in classCount.items():
        if v>maxV:
            maxK=k
    return maxK
# test
test_input=[e[-1] for e in myDat]
print(test_input)
majorityCnt(test_input)

['yes', 'yes', 'no', 'no', 'no']


'no'

## get tree

In [35]:
def createTree(dataSet, labels):
    print('>>>>>> enter fn:')
    classList=[e[-1] for e in dataSet]
    
    #1. 类别相同则停止划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    
    #2. 使用完特征时，返回出现最多的类别
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    
    #3. 建树
    bestFeat=chooseBestFeatureToSplit(dataSet) #选择最佳分类的feature编号
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    
    featValues=[e[bestFeat] for e in dataSet]
    uniqueValues=set(featValues)
    for value in uniqueValues:
        subLabels = labels[:]
        subDataSet=splitDataSet(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value]=createTree( subDataSet, subLabels )
    return myTree
#test
myDat, labels=createDataSet()
createTree(myDat, labels)

>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:


{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

In [28]:
a=[0,1,2,3,0,4]
a.count(0)

2

In [30]:
len(myDat[0])

3

## plot tree //todo

# test on iris data

In [43]:
import pandas as pd
iris=pd.read_csv('G://ML_MachineLearning//iris_data/iris.csv', index_col =0)
iris.head()

Unnamed: 0,Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
1,5.1,3.5,1.4,0.2,setosa
2,4.9,3.0,1.4,0.2,setosa
3,4.7,3.2,1.3,0.2,setosa
4,4.6,3.1,1.5,0.2,setosa
5,5.0,3.6,1.4,0.2,setosa


In [86]:
# get data
iris.iloc[:,:].head()

Unnamed: 0,Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
1,5.1,3.5,1.4,0.2,setosa
2,4.9,3.0,1.4,0.2,setosa
3,4.7,3.2,1.3,0.2,setosa
4,4.6,3.1,1.5,0.2,setosa
5,5.0,3.6,1.4,0.2,setosa


In [80]:
#
def get2DArray(npDat):
    myDat=[]
    for i in range(npDat.shape[0]):
        arr_row=[]
        for j in range(npDat.shape[1]):
            arr_row.append(npDat.iloc[i, j])
        myDat.append(arr_row)
    return myDat
myDat2=get2DArray(iris.iloc[:,:])
myDat2[0:4]

[[5.1, 3.5, 1.4, 0.2, 'setosa'],
 [4.9, 3.0, 1.4, 0.2, 'setosa'],
 [4.7, 3.2, 1.3, 0.2, 'setosa'],
 [4.6, 3.1, 1.5, 0.2, 'setosa']]

In [82]:
# get labels
labels2=[]
for i in iris['Species']:
    labels2.append(i)
labels2[0:5]

['setosa', 'setosa', 'setosa', 'setosa', 'setosa']

In [84]:
# run decision tree
myTree2=createTree(myDat2, labels2)

>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter fn:
>>>>>> enter f

## ID3仅适用于分类数据

结果惨不忍睹！分类点太多！

定量数据不合适

In [85]:
myTree2

{'setosa': {1.0: 'setosa',
  1.1: 'setosa',
  1.2: 'setosa',
  1.3: 'setosa',
  1.4: {'setosa': {2.9: 'setosa',
    3.0: 'setosa',
    3.2: 'setosa',
    3.3: 'versicolor',
    3.4: 'setosa',
    3.5: 'setosa',
    3.6: 'setosa',
    4.2: 'setosa'}},
  1.5: 'setosa',
  1.6: 'setosa',
  1.7: 'setosa',
  1.9: 'setosa',
  3.0: 'versicolor',
  3.3: 'versicolor',
  3.5: 'versicolor',
  3.6: 'versicolor',
  3.7: 'versicolor',
  3.8: 'versicolor',
  3.9: 'versicolor',
  4.0: 'versicolor',
  4.1: {'setosa': {5.6: 'versicolor', 5.7: 'virginica', 5.8: 'versicolor'}},
  4.2: 'versicolor',
  4.3: 'versicolor',
  4.4: 'versicolor',
  4.5: {'setosa': {4.9: 'virginica',
    5.4: 'versicolor',
    5.6: 'versicolor',
    5.7: 'versicolor',
    6.0: 'versicolor',
    6.2: 'versicolor',
    6.4: 'versicolor'}},
  4.6: 'versicolor',
  4.7: 'versicolor',
  4.8: {'setosa': {5.9: 'versicolor',
    6.0: 'virginica',
    6.2: 'virginica',
    6.8: 'versicolor'}},
  4.9: {'setosa': {2.5: 'versicolor',
    2.7: 