Skip to content

Commit be1e11d

Browse files
committed
Finalize Hunt's algorithm script
1 parent 454807d commit be1e11d

File tree

5 files changed

+293
-4
lines changed

5 files changed

+293
-4
lines changed

huntsDT/README.md

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,152 @@
11
# Hunt's algorithm to build Decision Trees
22

3-
Build a Decision Tree by splitting based on GINI index. Also outputs the GINI values to explain why the algo chose some attribute.
3+
Build a Decision Tree by splitting based on GINI index. Also outputs the GINI values to explain why the algo chose some attribute.
4+
5+
## Example
6+
7+
Input:
8+
9+
*Note*: The last column must be the class identifier
10+
11+
```bash
12+
python3 huntDT_driver.py -v --table_csv table.csv
13+
```
14+
15+
Output:
16+
17+
*Note*: Each node in the tree is represented as (attribute column number that this node was previously splitted: attribute value). That's why the root node is always (0: None).
18+
19+
```
20+
Gini split for 0: 0.48
21+
Gini split for 1: 0.16250000000000003
22+
Gini split for 2: 0.4789285714285715
23+
Choose att 1
24+
Gini split for 0: 0.375
25+
Gini split for 2: 0.0
26+
Choose att 2
27+
Gini split for 0: 0.0
28+
Choose att 0
29+
Gini split for 0: 0.0
30+
Choose att 0
31+
Gini split for 0: 0.0
32+
Choose att 0
33+
Gini split for 0: 0.21428571428571433
34+
Gini split for 2: 0.16666666666666666
35+
Choose att 2
36+
Gini split for 0: 0.3333333333333333
37+
Choose att 0
38+
Gini split for 0: 0.0
39+
Choose att 0
40+
Gini split for 0: 0.0
41+
Choose att 0
42+
Gini split for 0: 0.0
43+
Gini split for 2: 0.0
44+
Choose att 0
45+
Gini split for 2: 0.0
46+
Choose att 2
47+
Gini split for 2: 0.0
48+
Choose att 2
49+
|-> (0: None)
50+
51+
|-> (1: High)
52+
53+
|-> (2: Poor)
54+
55+
|-> (0: F)
56+
57+
|-> (0: M)
58+
59+
|-> (B: 1)
60+
61+
|-> (A: 0)
62+
63+
|-> (2: Fair)
64+
65+
|-> (0: F)
66+
67+
|-> (0: M)
68+
69+
|-> (B: 0)
70+
71+
|-> (A: 1)
72+
73+
|-> (2: Excellent)
74+
75+
|-> (0: F)
76+
77+
|-> (0: M)
78+
79+
|-> (B: 2)
80+
81+
|-> (A: 0)
82+
83+
|-> (1: Low)
84+
85+
|-> (2: Poor)
86+
87+
|-> (0: F)
88+
89+
|-> (B: 1)
90+
91+
|-> (A: 1)
92+
93+
|-> (0: M)
94+
95+
|-> (B: 1)
96+
97+
|-> (A: 0)
98+
99+
|-> (2: Fair)
100+
101+
|-> (0: F)
102+
103+
|-> (B: 2)
104+
105+
|-> (A: 0)
106+
107+
|-> (0: M)
108+
109+
|-> (2: Excellent)
110+
111+
|-> (0: F)
112+
113+
|-> (B: 3)
114+
115+
|-> (A: 0)
116+
117+
|-> (0: M)
118+
119+
|-> (1: Medium)
120+
121+
|-> (0: F)
122+
123+
|-> (2: Poor)
124+
125+
|-> (2: Fair)
126+
127+
|-> (B: 0)
128+
129+
|-> (A: 2)
130+
131+
|-> (2: Excellent)
132+
133+
|-> (B: 0)
134+
135+
|-> (A: 1)
136+
137+
|-> (0: M)
138+
139+
|-> (2: Poor)
140+
141+
|-> (B: 0)
142+
143+
|-> (A: 3)
144+
145+
|-> (2: Fair)
146+
147+
|-> (2: Excellent)
148+
149+
|-> (B: 0)
150+
151+
|-> (A: 2)
152+
```

huntsDT/huntDS.py

Whitespace-only changes.

huntsDT/huntDT.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import numpy as np
2+
3+
4+
class Node:
5+
def __init__(self, records, visited=None):
6+
self.records = records
7+
self.attribute_id = 0
8+
self.attribute_val = None
9+
self.children = []
10+
self.visited = visited or [False for i in range(len(records)-1)]
11+
12+
def set_attribute(self, att_id, att_val):
13+
self.attribute_id = att_id
14+
self.attribute_val = att_val
15+
16+
def add_child(self, node):
17+
self.children.append(node)
18+
19+
def display(self, level=0):
20+
out = level * 5 * ' '
21+
out += f"|-> ({self.attribute_id}: {self.attribute_val})"
22+
out += "\n"
23+
print(out)
24+
for child in self.children:
25+
child.display(level=level+1)
26+
27+
28+
class HuntDT:
29+
def __init__(self, records=[], verbose=False):
30+
self.records = records
31+
self.verbose = verbose
32+
# build domain of each attribute: att id -> list of str
33+
self.domains = [self.domain(i) for i in range(len(records[0]))]
34+
35+
# return set of different values the attribute can take
36+
def domain(self, att_id):
37+
s = set()
38+
for r in self.records:
39+
s.add(r[att_id])
40+
return s
41+
42+
# build tree
43+
def build(self):
44+
root = Node(self.records)
45+
root = self.build_tree(root)
46+
return root
47+
48+
def build_tree(self, node, level=0):
49+
split_att = self.select(node, level)
50+
if split_att == -1: return node
51+
visited = node.visited.copy()
52+
visited[split_att] = True
53+
for val in self.domains[split_att]:
54+
group = filter(lambda r: r[split_att]==val, node.records)
55+
child = Node(list(group), visited=visited)
56+
child.set_attribute(split_att, val)
57+
node.add_child(self.build_tree(child, level+1))
58+
return node
59+
60+
# select based on gini split index
61+
def select(self, node, level=0):
62+
records = node.records
63+
if len(records) == 0: return -1
64+
min_gini = np.inf
65+
chosen = -1
66+
for att_id in range(len(self.domains)-1):
67+
if node.visited[att_id]: continue
68+
gini = self.gini_index(records, att_id)
69+
if self.verbose:
70+
margin = level * 5 * ' '
71+
# print(f"{margin}At state {node.visited}")
72+
print(f"{margin}Gini split for {att_id}: {gini}")
73+
if gini < min_gini:
74+
min_gini = gini
75+
chosen = att_id
76+
if chosen != -1 and self.verbose: print(f"{level*5*' '}Choose att {chosen}")
77+
else:
78+
# find the frequency of classes
79+
for class_id in self.domains[-1]:
80+
count = len(list(filter(lambda r: r[-1]==class_id, records)))
81+
child = Node(records)
82+
child.set_attribute(class_id, count)
83+
node.add_child(child)
84+
return chosen
85+
86+
def gini_index(self, records, split_att_id):
87+
n_records = len(records)
88+
gini = 0.0
89+
# split records by attribute
90+
for val in self.domains[split_att_id]:
91+
group = list(filter(lambda r: r[split_att_id]==val, records))
92+
if len(group) == 0:
93+
continue
94+
squared_sum = 0.0
95+
classes = [r[-1] for r in group]
96+
for class_id in self.domains[-1]:
97+
p = classes.count(class_id) / len(group)
98+
squared_sum += p * p
99+
gini += (1.0 - squared_sum) * (len(group)/n_records)
100+
return gini
101+

huntsDT/huntDT_driver.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,31 @@
11
import argparse
2-
from huntDS import *
2+
import csv
3+
from huntDT import *
34

45
def main():
56
parser = argparse.ArgumentParser(description="Hunt's algorithm walkthrough. By Roundofthree.")
67
parser.add_argument("-v", action='store_true', help="Print the intermediate GINI indexes.")
7-
parser.add_argument("--table_csv", type=str, help="File path to a .csv file with the records.")
8+
parser.add_argument("--table_csv", type=str, required=True, help="File path to a .csv file with the records.")
89
arg = parser.parse_args()
9-
verbose = arg.v
10+
verbose = arg.v
11+
table_csv = arg.table_csv
12+
records = []
13+
14+
with open(table_csv, 'r') as f:
15+
f = csv.reader(f, delimiter=',')
16+
for line in f:
17+
r = [i for i in line]
18+
records.append(r)
19+
20+
# n_records = len(records)
21+
# n_attributes = len(records[0]) - 1 # minus class column
22+
23+
huntDT = HuntDT(records=records, verbose=verbose)
24+
root = huntDT.build()
25+
root.display() # display a tree with
26+
# |-> (att_id: att_val)
27+
# |-> (att_id: att_val)
28+
1029

1130
if __name__ == '__main__':
1231
main()

huntsDT/table.csv

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
M,High,Fair,A
2+
M,Medium,Excellent,A
3+
M,Medium,Excellent,A
4+
M,Medium,Poor,A
5+
M,Medium,Poor,A
6+
M,Medium,Poor,A
7+
F,Medium,Fair,A
8+
F,Medium,Fair,A
9+
F,Medium,Excellent,A
10+
F,Low,Poor,A
11+
M,High,Poor,B
12+
M,High,Excellent,B
13+
M,High,Excellent,B
14+
M,Low,Poor,B
15+
F,Low,Fair,B
16+
F,Low,Fair,B
17+
F,Low,Excellent,B
18+
F,Low,Excellent,B
19+
F,Low,Excellent,B
20+
F,Low,Poor,B

0 commit comments

Comments
 (0)