In [1]:
from math import log

In [2]:
class Node:
    def __init__(self,dataset,attrList):
        self.dataset = dataset
        self.attrList = attrList
        self.name = str(attrList)
        self.isLeaf = False
        self.className = None
        self.splitAttr = None
        self.edgeList = []
        self.size = len(dataset)
        self.classGroups = self.dataset.groupby('class')
        self.nClass = self.dataset['class'].unique()
        self.split()
        
    def split(self):
        if len(self.nClass) == 1:
            self.isLeaf = True
            self.className = self.nClass[0]
            return
        if len(self.attrList) == 0:
            self.isLeaf = True
            self.className = self.dataset.groupby('class').size().idxmax()
            return
        self.splitAttr = self.selectionAttribute()
        self.newList = self.attrList[:]
        self.newList.remove(self.splitAttr)
        
        for i,j in self.dataset.groupby(self.splitAttr):
            newNode = Node(j.drop(self.splitAttr,axis=1),self.newList)
            self.edgeList.append(Edge(newNode,i))
        
    def selectionAttribute(self):
        hOfx = 0
        for name in self.nClass:
            x = float(self.classGroups.size()[name]/float(self.size))
            hOfx+= - x * log(x,2)
        
        gain = -1
        splitChoice = None
        for attr in self.attrList:
            temp = hOfx - self.entropy(attr)
            if temp > gain:
                gain = temp
                splitChoice = attr
        return splitChoice
                
        
    def entropy(self,column):
        groupCol = self.dataset.groupby(column)
        groupCom = self.dataset.groupby([column,'class'])
        columnVal = groupCol.groups.keys()
        pairClassColumn = groupCom.groups.keys()
        
        result = 0
        for value in columnVal:
            hOfy = 0
            size = groupCol.size()[value]
            for name in self.nClass:
                if (value,name) in pairClassColumn:
                    y = float(groupCom.size()[value][name]/float(size))
                    hOfy+= - y * log(y,2)
            result += (size/self.size) * hOfy
        
        return result

In [3]:
class Edge:
    def __init__(self,name,value):
        self.name = name
        self.value = value
        self.size = self.name.size
        
    def move(self,value):
        if value == self.value:
            return self.name
        else:
            return None

In [4]:
class Tree:
    def __init__(self,dataset,attrList):
        self.dataset = dataset
        self.attrList = attrList
        self.root =  Node(self.dataset,self.attrList)
        
    def test(self,row):
        current = self.root
        while True:
            if current.isLeaf:
                return current.className
            else:
                found = False
                for edge in current.edgeList:
                    newNode = edge.move(row[current.splitAttr])
                    if newNode is not None:
                        current = newNode
                        found = True
                        break
                if not found:
                    bestChoice = None
                    maxPP = 0
                    for edge in current.edgeList:
                        if edge.size > maxPP:
                            bestChoice = edge.name
                            max_dataset = edge.size
                    current = bestChoice

In [5]:
import pandas
from sklearn.model_selection import KFold
import random,time

random.seed(time.time())

car = "car.txt"
car_names = ['buying','maint','doors','persons','lug_boot','safety','class']

chess = "kr-vs-kp.data.txt"
chess_names = ['bkblk', 'bknwy', 'bkon8', 'bkona', 'bkspr', 'bkxbq', 'bkxcr', 'bkxwp', 'blxwp', 'bxqsq', 'cntxt', 'dsopp', 'dwipd', 'hdchk', 'katri', 'mulch', 'qxmsq', 'r2ar8', 'reskd', 'reskr', 'rimmx', 'rkxwp', 'rxmsq', 'simpl', 'skach', 'skewr', 'skrxp', 'spcop', 'stlmt', 'thrsk', 'wkcti', 'wkna8', 'wknck', 'wkovl', 'wkpos', 'wtoeg', 'class']

connect = "connect-4.txt"
connect_names = ['a1','a2','a3','a4','a5','a6','b1','b2','b3','b4','b5','b6','c1','c2','c3','c4','c5','c6','d1','d2','d3','d4','d5','d6','e1','e2','e3','e4','e5','e6','f1','f2','f3','f4','f5','f6','g1','g2','g3','g4','g5','g6','class']

mushroom = "agaricus-lepiota.txt"
mushroom_names = ['class','cap-shape','cap-surface','cap-color','bruises?','odor','gill-attachment','gill-spacing','gill-size','gill-color','stalk-shape','stalk-root','stalk-surface-above-ring','stalk-surface-below-ring','stalk-color-above-ring ','stalk-color-below-ring ','veil-type','veil-color','ring-number','ring-type','spore-print-color','population','habitat']

nursery = "nursery.txt"
nursery_names = ['parents','has_nurs','form','children','housing','finance','social','health','class']

abalone = "abalone.txt"
abalone_names = ['class','Length','Diameter','Height','Whole weight','Shucked weight','Viscera weight','Shell weight','Rings']

balance = "balance.txt"
balance_names = ['class','Left-Weight','Left-Distance','Right-Weight','Right-Distance']

dataset = connect
nameSet = connect_names

dataset = pandas.read_csv(dataset, names = nameSet)
print(len(dataset))
# print(dataset)
dataset = dataset.head(1000)
dataset_split = KFold(n_splits=10,random_state=random.randint(1,100),shuffle=True)
count = 1

avg = 0
for train, test in dataset_split.split(dataset):
    newNames = nameSet[:]
    newNames.remove('class')
    tree = Tree(dataset.iloc[train], newNames)
    testSet = dataset.iloc[test]
    valid, invalid, total = 0,0,0
    for index, row in testSet.iterrows():
        total += 1
        if(tree.test(row) == row['class']):
            valid += 1
        else:
            invalid += 1
    
    print(str(count)+".Accuracy: "+str((valid/float(total))*100))
    avg += ((valid/float(total))*100)
    count += 1
    
avg /= count-1
print("@@@@@@@@@@@@@@@@@@@@@@@@@\nAverage Accuracy: "+str(avg))



67557
1.Accuracy: 80.0
2.Accuracy: 80.0
3.Accuracy: 73.0
4.Accuracy: 83.0
5.Accuracy: 80.0
6.Accuracy: 80.0
7.Accuracy: 70.0
8.Accuracy: 77.0
9.Accuracy: 67.0
10.Accuracy: 86.0
@@@@@@@@@@@@@@@@@@@@@@@@@
Average Accuracy: 77.6
