In [22]:
from typing import List

In [23]:
import os
os.getcwd()

'/Users/ze/CS686/cs686-repo/Asg2'

In [24]:
import csv
import math
from math import *
import numpy as np  # numpy==1.19.2

import dt_global

In [25]:
def read_data(file_path: str):
    """
    Loads a CSV data set and records its schema in dt_global.

    The first row of the file is used as the list of column names; every
    other row is one example whose last column holds the label.

    Side effects: assigns dt_global.feature_names, dt_global.label_index
    and dt_global.num_label_values.

    :param file_path: The name of the data file.
    :type file_path: str
    :return: A 2d data array consisting of examples
    :rtype: List[List[int or float]]
    """
    # slurp the whole csv file into a 2d array of strings
    with open(file_path, 'r') as csv_file:
        rows = list(csv.reader(csv_file))

    # the header row defines the schema; the label is the last column
    header = rows[0]
    dt_global.feature_names = header
    dt_global.label_index = len(header) - 1

    # everything after the header is an example
    examples = rows[1:]
    # number of distinct label strings found in the last column
    dt_global.num_label_values = len(set(np.array(examples)[:, -1]))

    # input feature columns become floats (the label column is skipped)
    num_inputs = len(header) - 1
    for example in examples:
        for col in range(num_inputs):
            example[col] = float(example[col])

    # label column becomes an int
    for example in examples:
        example[-1] = int(example[-1])

    return examples

In [26]:
def preprocess(data_array, folds_num=10):
    """
    Cuts data_array into folds_num consecutive slices for cross validation.

    Every fold except the last holds floor(len(data_array) / folds_num)
    examples; the last fold additionally receives whatever is left over.

    :param data_array: a set of examples
    :type data_array: List[List[Any]]
    :param folds_num: the number of folds
    :type folds_num: int, default 10
    :return: a list of sets of length folds_num
    Each set contains the set of data for the corresponding fold.
    :rtype: List[List[List[Any]]]
    """
    fold_size = len(data_array) // folds_num
    last_fold = folds_num - 1

    return [
        # the final fold is open-ended so the remainder is not dropped
        data_array[k * fold_size:] if k == last_fold
        else data_array[k * fold_size:(k + 1) * fold_size]
        for k in range(folds_num)
    ]

In [27]:
data = read_data("./data.csv")

In [28]:
data[:5]

[[0.63, 0.56, 0.52, 0.21, 0.5, 0.0, 0.5, 0.22, 0],
 [0.59, 0.6, 0.49, 0.43, 0.5, 0.0, 0.53, 0.31, 6],
 [0.61, 0.39, 0.53, 0.14, 0.5, 0.0, 0.43, 0.26, 7],
 [0.57, 0.52, 0.46, 0.2, 0.5, 0.83, 0.52, 0.41, 8],
 [0.59, 0.48, 0.5, 0.14, 0.5, 0.0, 0.45, 0.25, 0]]

In [29]:
dt_global.feature_names

['mcg', 'gvh', 'alm', 'mit', 'erl', 'pox', 'vac', 'nuc', 'class']

In [30]:
dt_global.label_index

8

In [31]:
dt_global.num_label_values

10

In [32]:
folds = preprocess(data)

In [33]:
len(folds)

10

In [34]:
folds[0][:5]

[[0.63, 0.56, 0.52, 0.21, 0.5, 0.0, 0.5, 0.22, 0],
 [0.59, 0.6, 0.49, 0.43, 0.5, 0.0, 0.53, 0.31, 6],
 [0.61, 0.39, 0.53, 0.14, 0.5, 0.0, 0.43, 0.26, 7],
 [0.57, 0.52, 0.46, 0.2, 0.5, 0.83, 0.52, 0.41, 8],
 [0.59, 0.48, 0.5, 0.14, 0.5, 0.0, 0.45, 0.25, 0]]

In [35]:
import dt_global as G
G.num_label_values

10

In [36]:
G.feature_names.index("mit")

3

In [37]:
from collections import defaultdict
a = defaultdict(set)
a[1].add(1)

In [38]:
a

defaultdict(set, {1: {1}})

In [39]:
a[1] == {1}

True

In [40]:
b = {1}

In [41]:
a[1] == b

True

In [42]:
sorted(a.values())

[{1}]

In [43]:
def get_splits(examples: List, feature: str) -> List[float]:
    """
    Given some examples and a feature, returns a list of potential split point values for the feature.

    A candidate is the midpoint between two neighbouring distinct values of
    the feature (in sorted order). The midpoint is kept unless both values
    are associated with one and the same single label.

    :param examples: a set of examples
    :type examples: List[List[Any]]
    :param feature: a feature
    :type feature: str
    :return: a list of potential split point values, in increasing order
    :rtype: List[float]
    """
    col = G.feature_names.index(feature)

    # For every distinct feature value, collect the labels seen with it.
    labels_at = defaultdict(set)
    for example in examples:
        labels_at[example[col]].add(example[G.label_index])

    values = sorted(labels_at)
    splits = []
    for lower, upper in zip(values, values[1:]):
        lower_labels, upper_labels = labels_at[lower], labels_at[upper]
        # Nothing to separate when both neighbours carry the same lone label.
        same_lone_label = len(lower_labels) == 1 and lower_labels == upper_labels
        if not same_lone_label:
            splits.append((lower + upper) / 2)

    return splits

In [44]:
get_splits(data, "mit")

[0.02,
 0.045,
 0.065,
 0.07500000000000001,
 0.08499999999999999,
 0.095,
 0.10500000000000001,
 0.11499999999999999,
 0.125,
 0.135,
 0.14500000000000002,
 0.155,
 0.165,
 0.175,
 0.185,
 0.195,
 0.20500000000000002,
 0.215,
 0.225,
 0.235,
 0.245,
 0.255,
 0.265,
 0.275,
 0.28500000000000003,
 0.295,
 0.305,
 0.315,
 0.325,
 0.335,
 0.345,
 0.355,
 0.365,
 0.375,
 0.385,
 0.395,
 0.405,
 0.415,
 0.425,
 0.435,
 0.445,
 0.455,
 0.46499999999999997,
 0.475,
 0.485,
 0.495,
 0.505,
 0.515,
 0.525,
 0.535,
 0.545,
 0.555,
 0.565,
 0.575,
 0.595,
 0.605,
 0.625,
 0.635,
 0.645,
 0.655,
 0.665,
 0.6799999999999999,
 0.695,
 0.705,
 0.725,
 0.775,
 0.785,
 0.795,
 0.81,
 0.84,
 0.865,
 0.935]

[]

In [45]:
get_splits(data[:1], "mit")

[]

In [46]:
l = [1,2,3]

In [47]:
[(x, x+1) for x in l]

[(1, 2), (2, 3), (3, 4)]

In [48]:
l.count(lambda x: x < 2)

0

In [123]:
def choose_feature_split(examples: List, features: List[str]) -> (str, float):
    """
    Given some examples and some features,
    returns a feature and a split point value with the max expected information gain.

    If there are no valid split points for the remaining features, return None and -1.
    A valid split point is returned even when its information gain (rounded
    to 6 decimal places) is zero.

    Tie breaking rules:
    (1) With multiple split points, choose the one with the smallest value. 
    (2) With multiple features with the same info gain, choose the first feature in the list.

    :param examples: a set of examples
    :type examples: List[List[Any]]    
    :param features: a set of features
    :type features: List[str]
    :return: the best feature and the best split value
    :rtype: str, float
    """
    def _neg_entropy(probabilities):
        # sum(p * log2(p)) is the entropy with its sign flipped.
        return sum(p * math.log2(p) for p in probabilities)

    def _neg_info_gain(feat_idx, threshold):
        """Negated information gain of splitting at threshold, rounded to 6 places."""
        # Tally the labels falling on each side of the threshold.
        left_counts, right_counts = defaultdict(int), defaultdict(int)
        for row in examples:
            side = left_counts if row[feat_idx] <= threshold else right_counts
            side[row[G.label_index]] += 1

        n_left, n_right = sum(left_counts.values()), sum(right_counts.values())
        p_left = n_left / len(examples)
        probs_left = [count / n_left for count in left_counts.values()]
        probs_right = [count / n_right for count in right_counts.values()]
        probs_total = [
            (left_counts.get(label, 0) + right_counts.get(label, 0)) / (n_left + n_right)
            for label in set(left_counts.keys()).union(right_counts.keys())
        ]

        # -H(before) + weighted H(after) == -(information gain); smaller is better.
        return round(_neg_entropy(probs_total) -
                     _neg_entropy(probs_left) * p_left -
                     _neg_entropy(probs_right) * (1 - p_left), 6)

    # Start from +inf (not 0) so that a feature whose best valid split has
    # zero gain is still chosen instead of being reported as "no split".
    best_feature, best_score, best_split = None, math.inf, -1
    for feature in features:
        feat_idx = G.feature_names.index(feature)
        candidates = [(_neg_info_gain(feat_idx, split), split)
                      for split in get_splits(examples, feature)]
        if not candidates:
            continue  # this feature offers no valid split point

        # min() over (score, split) tuples: best gain first, then smallest split value.
        score, split = min(candidates)
        # Strict '<' keeps the earliest feature when scores tie.
        if score < best_score:
            best_feature, best_score, best_split = feature, score, split

    return best_feature, best_split

In [124]:
choose_feature_split(data, ["mit","erl"])

('mit', 0.325)

In [125]:
get_splits(data, "erl")

[0.75]

In [52]:
for fea in G.feature_names:
    get_splits(data, fea)

In [53]:
# from itertools import combinations as comb
# list(comb(G.feature_names, 3))

In [129]:
choose_feature_split(data, G.feature_names[:-1])

('alm', 0.435)

In [55]:
math.isclose(0.5, 0.50000000)

True

In [56]:
[1,2].copy()

[1, 2]

In [57]:
from collections import Counter
c = Counter([1,2,2,3])
sum([c[x] for x in c.keys() if x < 2.5])

3

In [58]:
Counter([(1,1), (1,-1)])

Counter({(1, 1): 1, (1, -1): 1})

In [59]:
import math
from math import *

In [60]:
math.isclose(1,1)

True

In [61]:
log2(1)

0.0

In [62]:
from typing import List

In [63]:
def split_examples(examples: List, feature: str, split: float) -> (List, List):
    """
    Given some examples, a feature, and a split point,
    splits examples into two lists and return the two lists of examples.

    The first list of examples have their feature value <= split point.
    The second list of examples have their feature value > split point.

    Each returned example is a shallow copy of the input row, so modifying
    the returned rows does not touch the rows in `examples`.

    :param examples: a set of examples
    :type examples: List[List[Any]]
    :param feature: a feature
    :type feature: str
    :param split: the split point
    :type split: float
    :return: two lists of examples split by the feature split
    :rtype: List[List[Any]], List[List[Any]]
    """
    retLeft, retRight = [], []
    indFea = G.feature_names.index(feature)

    for row in examples:
        # '<=' (not '<'): a value equal to the split point belongs to the
        # left side, matching the documented contract and the comparison
        # used when scoring splits and when predicting.
        if row[indFea] <= split:
            retLeft.append(row.copy())
        else:
            retRight.append(row.copy())

    return retLeft, retRight

In [64]:
left, right = split_examples(data, "mcg", 0.485)

In [65]:
len(left)

741

In [66]:
len(right)

743

In [67]:
from anytree import Node
root = Node("root")
s0 = Node("sub0", parent=root)
s0b = Node("sub0B", parent=s0, foo=4, bar=109)
s0a = Node("sub0A", parent=s0)
s1 = Node("sub1", parent=root)
s1a = Node("sub1A", parent=s1)
s1b = Node("sub1B", parent=s1, bar=8)
s1c = Node("sub1C", parent=s1)
s1ca = Node("sub1Ca", parent=s1c)

ModuleNotFoundError: No module named 'anytree'

In [240]:
from anytree import RenderTree

In [241]:
root.children

(Node('/root/sub0'), Node('/root/sub1'))

In [242]:
print(RenderTree(root))

Node('/root')
├── Node('/root/sub0')
│   ├── Node('/root/sub0/sub0B', bar=109, foo=4)
│   └── Node('/root/sub0/sub0A')
└── Node('/root/sub1')
    ├── Node('/root/sub1/sub1A')
    ├── Node('/root/sub1/sub1B', bar=8)
    └── Node('/root/sub1/sub1C')
        └── Node('/root/sub1/sub1C/sub1Ca')


In [243]:
root.children[0].children[0]

Node('/root/sub0/sub0B', bar=109, foo=4)

In [244]:
root.children[0].children[0].bar

109

In [245]:
root.what = 1

In [246]:
root

Node('/root', what=1)

In [247]:
print(RenderTree(root))

Node('/root', what=1)
├── Node('/root/sub0')
│   ├── Node('/root/sub0/sub0B', bar=109, foo=4)
│   └── Node('/root/sub0/sub0A')
└── Node('/root/sub1')
    ├── Node('/root/sub1/sub1A')
    ├── Node('/root/sub1/sub1B', bar=8)
    └── Node('/root/sub1/sub1C')
        └── Node('/root/sub1/sub1C/sub1Ca')


In [248]:
data[0]

[0.63, 0.56, 0.52, 0.21, 0.5, 0.0, 0.5, 0.22, 0]

In [249]:
from collections import defaultdict
count = defaultdict(lambda: 0)
count[1] += 1

In [250]:
dict(count)

{1: 1}

In [251]:
count = [4,5,1,9,0,9]
count.index(max(count))

3

In [252]:
data[:5]

[[0.63, 0.56, 0.52, 0.21, 0.5, 0.0, 0.5, 0.22, 0],
 [0.59, 0.6, 0.49, 0.43, 0.5, 0.0, 0.53, 0.31, 6],
 [0.61, 0.39, 0.53, 0.14, 0.5, 0.0, 0.43, 0.26, 7],
 [0.57, 0.52, 0.46, 0.2, 0.5, 0.83, 0.52, 0.41, 8],
 [0.59, 0.48, 0.5, 0.14, 0.5, 0.0, 0.45, 0.25, 0]]

In [253]:
choose_feature_split(data[:5], ["erl"])

(None, -1)

In [309]:
def get_majority(examples):
    """
    Returns the label that occurs most often among the examples.

    Labels are used directly as indices into a tally of length
    G.num_label_values. When several labels are tied for the highest count
    the smallest label wins; with no examples the result is 0.

    :param examples: a set of examples
    :type examples: List[List[Any]]
    :return: the majority label
    :rtype: int
    """
    tally = [0 for _ in range(G.num_label_values)]
    for example in examples:
        tally[example[G.label_index]] += 1

    # max() keeps the first maximal index, i.e. the smallest label on a tie.
    return max(range(len(tally)), key=tally.__getitem__)

In [75]:
def split_node(cur_node: Node, examples: List, features: List[str], max_depth=math.inf):
    """
    Given a tree with cur_node as the root, some examples, some features, and the max depth,
    grows a tree to classify the examples using the features by using binary splits.

    If cur_node is at max_depth, makes cur_node a leaf node with majority decision and return.
    The same happens when choose_feature_split reports no usable split.

    A leaf receives a `decision` attribute. An internal node receives
    `major` (majority label), `numExs` (number of examples), `feature` and
    `split`, plus two children: the left one first, then the right one.

    This function is recursive.

    :param cur_node: current node
    :type cur_node: Node
    :param examples: a set of examples
    :type examples: List[List[Any]]
    :param features: a set of features
    :type features: List[str]
    :param max_depth: the maximum depth of the tree
    :type max_depth: int
    """
    majority_label = get_majority(examples)

    # Depth budget used up: this node becomes a leaf.
    if max_depth == 0:
        cur_node.decision = majority_label
        return

    # No split available: this node becomes a leaf as well.
    best_feature, best_split = choose_feature_split(examples, features)
    if not best_feature:
        cur_node.decision = majority_label
        return

    # Internal node: remember the statistics predict() falls back on when pruning.
    cur_node.major = majority_label
    cur_node.numExs = len(examples)
    cur_node.feature = best_feature
    cur_node.split = best_split

    left_examples, right_examples = split_examples(examples, best_feature, best_split)

    # Create both children before recursing; the left child must come first
    # because predict() addresses the children by position.
    child_depth = cur_node.depth + 1
    left_child = Node("l-%s<%.3f" % (best_feature, best_split), cur_node, depth=child_depth)
    right_child = Node("r-%s>%.3f" % (best_feature, best_split), cur_node, depth=child_depth)

    for child, subset in ((left_child, left_examples), (right_child, right_examples)):
        split_node(child, subset, features, max_depth - 1)

In [356]:
# def split_node(cur_node: Node, examples: List, features: List[str], max_depth=math.inf):
#     """
#     Given a tree with cur_node as the root, some examples, some features, and the max depth,
#     grows a tree to classify the examples using the features by using binary splits.

#     If cur_node is at max_depth, makes cur_node a leaf node with majority decision and return.

#     This function is recursive.

#     :param cur_node: current node
#     :type cur_node: Node
#     :param examples: a set of examples
#     :type examples: List[List[Any]]
#     :param features: a set of features
#     :type features: List[str]
#     :param max_depth: the maximum depth of the tree
#     :type max_depth: int
#     """ 
#     def __get_majority():
#         count = [0] * G.num_label_values
#         for row in examples:
#             count[row[G.label_index]] += 1

#         return count.index(max(count))


#     # if not examples:
#     #     cur_node.decision = cur_node.pMajor
#     #     return
#     if max_depth == 0:
#         cur_node.decision = __get_majority()
#         return
    
#     splFea, splVal = choose_feature_split(examples, features)
#     if not splFea:
#         cur_node.decision = __get_majority()
#         return

#     cur_node.feature, cur_node.split = splFea, splVal
#     leftExs, rightExs = split_examples(examples, splFea, splVal)
#     lchild = Node("l-%s<%.3f"%(splFea, splVal), cur_node, depth=cur_node.depth + 1)
#     rchild = Node("r-%s>%.3f"%(splFea, splVal), cur_node, depth=cur_node.depth + 1)
#     split_node(lchild, leftExs, features, max_depth - 1)
#     split_node(rchild, rightExs, features, max_depth - 1)

In [357]:
dumb = Node("root", depth=0)
split_node(dumb, data, G.feature_names[:-1], 5)

In [358]:
print(RenderTree(dumb))

Node('/root', depth=0, feature='mcg', major=0, numExs=1484, split=0.485)
├── Node('/root/l-mcg<0.485', depth=1, feature='mcg', major=7, numExs=741, split=0.405)
│   ├── Node('/root/l-mcg<0.485/l-mcg<0.405', depth=2, feature='gvh', major=7, numExs=368, split=0.425)
│   │   ├── Node('/root/l-mcg<0.485/l-mcg<0.405/l-gvh<0.425', depth=3, feature='mcg', major=7, numExs=184, split=0.345)
│   │   │   ├── Node('/root/l-mcg<0.485/l-mcg<0.405/l-gvh<0.425/l-mcg<0.345', depth=4, feature='mcg', major=7, numExs=92, split=0.295)
│   │   │   │   ├── Node('/root/l-mcg<0.485/l-mcg<0.405/l-gvh<0.425/l-mcg<0.345/l-mcg<0.295', decision=7, depth=5)
│   │   │   │   └── Node('/root/l-mcg<0.485/l-mcg<0.405/l-gvh<0.425/l-mcg<0.345/r-mcg>0.295', decision=7, depth=5)
│   │   │   └── Node('/root/l-mcg<0.485/l-mcg<0.405/l-gvh<0.425/r-mcg>0.345', depth=4, feature='alm', major=7, numExs=92, split=0.535)
│   │   │       ├── Node('/root/l-mcg<0.485/l-mcg<0.405/l-gvh<0.425/r-mcg>0.345/l-alm<0.535', decision=7, depth=5)


In [359]:
full = Node("root", depth=0)
split_node(full, data, G.feature_names[:-1])

In [360]:
print(RenderTree(full))

gvh>0.735/l-mcg<0.775/r-mcg>0.745/l-gvh<0.770/r-mcg>0.755', decision=3, depth=10)
                │       │       └── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/l-mcg<0.775/r-mcg>0.745/r-gvh>0.770', depth=9, feature='gvh', major=2, numExs=3, split=0.825)
                │       │           ├── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/l-mcg<0.775/r-mcg>0.745/r-gvh>0.770/l-gvh<0.825', decision=2, depth=10)
                │       │           └── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/l-mcg<0.775/r-mcg>0.745/r-gvh>0.770/r-gvh>0.825', decision=3, depth=10)
                │       └── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/r-mcg>0.775', depth=7, feature='gvh', major=3, numExs=12, split=0.785)
                │           ├── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/r-mcg>0.775/l

In [361]:
t1 = Node("root", depth=0)
split_node(t1, data, G.feature_names[:-1], 1)

In [362]:
print(RenderTree(t1))

Node('/root', depth=0, feature='mcg', major=0, numExs=1484, split=0.485)
├── Node('/root/l-mcg<0.485', decision=7, depth=1)
└── Node('/root/r-mcg>0.485', decision=0, depth=1)


In [363]:
def learn_dt(examples: List, features: List[str], max_depth=math.inf) -> Node:
    """
    Builds a decision tree for the examples and returns its root.

    A fresh node named "root" (with depth=0) is created and handed to
    split_node, which grows the tree in place using the given features and
    depth limit.

    This function is a wrapper for split_node.

    Tie breaking rule:
    If there is a tie for majority voting, always return the label with the smallest value.

    :param examples: a set of examples
    :type examples: List[List[Any]]
    :param features: a set of features
    :type features: List[str]
    :param max_depth: the max depth of the tree
    :type max_depth: int, default math.inf
    :return: the root of the tree
    :rtype: Node
    """
    tree_root = Node("root", depth=0)
    split_node(tree_root, examples, features, max_depth)

    return tree_root

In [364]:
t2 = learn_dt(data, G.feature_names[:-1])

In [365]:
print(RenderTree(t2))

gvh>0.735/l-mcg<0.775/r-mcg>0.745/l-gvh<0.770/r-mcg>0.755', decision=3, depth=10)
                │       │       └── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/l-mcg<0.775/r-mcg>0.745/r-gvh>0.770', depth=9, feature='gvh', major=2, numExs=3, split=0.825)
                │       │           ├── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/l-mcg<0.775/r-mcg>0.745/r-gvh>0.770/l-gvh<0.825', decision=2, depth=10)
                │       │           └── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/l-mcg<0.775/r-mcg>0.745/r-gvh>0.770/r-gvh>0.825', decision=3, depth=10)
                │       └── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/r-mcg>0.775', depth=7, feature='gvh', major=3, numExs=12, split=0.785)
                │           ├── Node('/root/r-mcg>0.485/r-mcg>0.575/r-mcg>0.655/r-gvh>0.665/l-mit<0.295/r-gvh>0.735/r-mcg>0.775/l

In [366]:
root.is_leaf

False

In [367]:
def predict(cur_node: Node, example, max_depth=math.inf,
            min_num_examples=0) -> int:
    """
    Given a tree with cur_node as its root, an example, and optionally a max depth,
    returns a prediction for the example based on the tree.

    A leaf always answers with its stored decision.

    If max_depth is provided and we haven't reached a leaf node at the max depth, 
    return the majority decision at this node.

    If min_num_examples is provided and the number of examples at the node is less than min_num_examples, 
    return the majority decision at this node.

    Otherwise the example goes to the first child when its feature value is
    <= the node's split value and to the second child when it is greater.

    This function is recursive.

    Tie breaking rule:
    If there is a tie for majority voting, always return the label with the smallest value.

    :param cur_node: cur_node of a decision tree
    :type cur_node: Node
    :param example: one example
    :type example: List[Any]
    :param max_depth: the max depth
    :type max_depth: int, default math.inf
    :param min_num_examples: the minimum number of examples at a node
    :type min_num_examples: int, default 0
    :return: the decision for the given example
    :rtype: int
    """
    if cur_node.is_leaf:
        return cur_node.decision

    # Pruning on the fly: stop early and answer with the node's majority label.
    depth_exhausted = max_depth <= 0
    if depth_exhausted or cur_node.numExs < min_num_examples:
        return cur_node.major

    feature_col = G.feature_names.index(cur_node.feature)
    # children[0] is the "<= split" branch, children[1] the "> split" branch
    branch = 0 if example[feature_col] <= cur_node.split else 1

    return predict(cur_node.children[branch], example, max_depth - 1, min_num_examples)

In [368]:
predict(full, data[22])

7

In [369]:
predict(full, data[22][:-1], 1)

0

In [370]:
predict(full, data[22], min_num_examples=2)

7

In [371]:
predict(full, data[22], min_num_examples=100)

6

In [372]:
predict(full, data[22], min_num_examples=300)

0

In [373]:
t3 = learn_dt(data, G.feature_names[:-1], 3)
print(RenderTree(t3))

Node('/root', depth=0, feature='mcg', major=0, numExs=1484, split=0.485)
├── Node('/root/l-mcg<0.485', depth=1, feature='mcg', major=7, numExs=741, split=0.405)
│   ├── Node('/root/l-mcg<0.485/l-mcg<0.405', depth=2, feature='gvh', major=7, numExs=368, split=0.425)
│   │   ├── Node('/root/l-mcg<0.485/l-mcg<0.405/l-gvh<0.425', decision=7, depth=3)
│   │   └── Node('/root/l-mcg<0.485/l-mcg<0.405/r-gvh>0.425', decision=0, depth=3)
│   └── Node('/root/l-mcg<0.485/r-mcg>0.405', depth=2, feature='gvh', major=0, numExs=373, split=0.46499999999999997)
│       ├── Node('/root/l-mcg<0.485/r-mcg>0.405/l-gvh<0.465', decision=0, depth=3)
│       └── Node('/root/l-mcg<0.485/r-mcg>0.405/r-gvh>0.465', decision=0, depth=3)
└── Node('/root/r-mcg>0.485', depth=1, feature='mcg', major=0, numExs=743, split=0.575)
    ├── Node('/root/r-mcg>0.485/l-mcg<0.575', depth=2, feature='mcg', major=0, numExs=370, split=0.525)
    │   ├── Node('/root/r-mcg>0.485/l-mcg<0.575/l-mcg<0.525', decision=0, depth=3)
    │   └─

In [374]:
def get_prediction_accuracy(cur_node: Node, examples: List, max_depth=math.inf,
                            min_num_examples=0) -> float:
    """
    Given a tree with cur_node as the root, some examples, 
    and optionally the max depth or the min_num_examples, 
    returns the accuracy by predicting the examples using the tree.

    The tree may be pruned by max_depth or min_num_examples.

    Each example is passed to predict without its last element, and the
    prediction is compared against that last element (the label).

    :param cur_node: cur_node of the decision tree
    :type cur_node: Node
    :param examples: the set of examples. 
    :type examples: List[List[Any]]
    :param max_depth: the max depth
    :type max_depth: int, default math.inf
    :param min_num_examples: the minimum number of examples at a node
    :type min_num_examples: int, default 0
    :return: the prediction accuracy for the examples based on the cur_node
    :rtype: float
    """
    num_correct = sum(
        1
        for example in examples
        if predict(cur_node, example[:-1], max_depth, min_num_examples) == example[-1]
    )

    return num_correct / len(examples)

In [375]:
get_prediction_accuracy(full, data)

1.0

In [376]:
G.feature_names

['mcg', 'gvh', 'alm', 'mit', 'erl', 'pox', 'vac', 'nuc', 'class']

In [377]:
get_prediction_accuracy(dumb, data)

0.431266846361186

In [384]:
get_prediction_accuracy(full, data, max_depth=5) == get_prediction_accuracy(dumb, data)

True

In [382]:
get_prediction_accuracy(full, data, max_depth=10)

0.8402964959568733

In [68]:
examples = data

In [118]:
def test_ent(indFea, midWay):
    countL, countR = defaultdict(lambda: 0), defaultdict(lambda: 0)
    for i in range(len(examples)):
        if examples[i][indFea] <= midWay:
            countL[examples[i][G.label_index]] += 1
        else:
            countR[examples[i][G.label_index]] += 1
    sumL, sumR = sum(countL.values()), sum(countR.values())
    pLeft = sum(countL.values()) / len(examples)
    pListL = [num / sumL for num in countL.values()]
    pListR = [num / sumR for num in countR.values()]
    pListT = [(countL[key] + countR[key]) / (sumL + sumR) for key in set(countL.keys()).union(countR.keys())]
    print(pListT, sum(pListT))
    print(pListL, sum(pListL))
    print(pListR, sum(pListR))

    print()
    print(sum(list(map(lambda p: p * log2(p), pListT))),
                    sum(list(map(lambda p: p * log2(p), pListL))) * pLeft,
                    sum(list(map(lambda p: p * log2(p), pListR))) * (1 - pLeft), 6)
    
    return round(sum(list(map(lambda p: p * log2(p), pListT))) -
                    sum(list(map(lambda p: p * log2(p), pListL))) * pLeft - 
                    sum(list(map(lambda p: p * log2(p), pListR))) * (1 - pLeft), 6)

In [121]:
test_ent(3, 0.935)

[0.3119946091644205, 0.0033692722371967657, 0.02358490566037736, 0.029649595687331536, 0.03436657681940701, 0.10983827493261455, 0.16442048517520216, 0.2890835579514825, 0.013477088948787063, 0.02021563342318059] 1.0
[0.3115306810519218, 0.16453135536075522, 0.28927848954821306, 0.013486176668914362, 0.10991233985165205, 0.020229265003371546, 0.023600809170600135, 0.029669588671611596, 0.03438975050573163, 0.0033715441672285905] 1.0
[1.0] 1.0

-2.4904176362327273 -2.4892845635172383 0.0 6


-0.001133

In [122]:
test_ent(2, 0.435)

[0.3119946091644205, 0.0033692722371967657, 0.02358490566037736, 0.029649595687331536, 0.03436657681940701, 0.10983827493261455, 0.16442048517520216, 0.2890835579514825, 0.013477088948787063, 0.02021563342318059] 1.0
[0.5136986301369864, 0.017123287671232876, 0.1232876712328767, 0.11643835616438356, 0.030821917808219176, 0.030821917808219176, 0.08561643835616438, 0.07534246575342465, 0.003424657534246575, 0.003424657534246575] 1.0
[0.3808724832214765, 0.1837248322147651, 0.3414429530201342, 0.015939597315436243, 0.01761744966442953, 0.006711409395973154, 0.025167785234899327, 0.014261744966442953, 0.010906040268456376, 0.003355704697986577] 1.0

-2.4904176362327273 -0.44820939360469064 -1.666666251598837 6


-0.375542

In [127]:
choose_feature_split(data, G.feature_names[:-1])

('alm', 0.435)

In [91]:
set([1,3]).union(set([1,2]))

{1, 2, 3}

In [94]:
b = defaultdict(lambda: 0)
b[1]

0

In [101]:
c = {1:2}
d = {1:3, 2:5}
set(c.keys()).union(d.keys())

{1, 2}