In [None]:
import pandas as pd
import math
import matplotlib.pyplot as plt

In [None]:
## Find best threshold
def get_entropy(data):
    """
    Calculate the binary (base-2) Shannon entropy of the ``y`` column.

    Only rows with ``y == 0`` or ``y == 1`` are counted.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame containing a ``y`` column.

    Returns
    -------
    float
        Entropy in bits. Returns 0 for an empty frame, or when no row has a
        0/1 label (there is no distribution to measure).
    """
    if data.empty:
        return 0
    n0 = int((data.y == 0).sum())
    n1 = int((data.y == 1).sum())
    total = n0 + n1
    if total == 0:
        # Non-empty frame without any 0/1 labels: avoid dividing by zero.
        return 0
    entropy = 0.0
    for count in (n0, n1):
        # A class with zero members contributes 0 (lim p->0 of p*log p = 0).
        if count:
            p = count / total
            entropy -= p * math.log2(p)
    return entropy

def info_gain(data, threshold, feature):
    """
    Calculate the information gain of splitting ``data`` on ``feature >= threshold``.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame containing ``feature`` and a ``y`` column.
    threshold : comparable scalar
        Split value; rows with ``feature >= threshold`` form the left
        partition, rows with ``feature < threshold`` the right one.
    feature : str
        Name of the column to split on.

    Returns
    -------
    float
        ``H(y) - sum_k (|part_k| / |data|) * H(y | part_k)``; 0 for an empty frame.
    """
    if data.empty:
        # No rows: nothing to split, and the partition weights would divide by zero.
        return 0
    subset = data[[feature, 'y']]
    l_data = subset[subset[feature] >= threshold]  # left partition: feature >= threshold
    r_data = subset[subset[feature] < threshold]   # right partition: feature < threshold
    n = len(data)
    l_weight = len(l_data) / n
    r_weight = len(r_data) / n
    conditional = l_weight * get_entropy(l_data) + r_weight * get_entropy(r_data)
    return get_entropy(data) - conditional

def find_best_threshold(data):
    """
    Find the (feature, threshold) split of ``data`` with the highest information gain.

    Every column other than ``y`` is treated as a feature. For each feature the
    distinct non-missing values are tried as thresholds in ascending order. A
    candidate replaces the current best only if its gain is strictly larger,
    so ties keep the earliest feature and the smallest threshold.

    Returns
    -------
    tuple
        ``(largest_gain, threshold, feature)``. When no split has positive
        gain the result is ``(0, 0, None)``.
    """
    largest_gain = 0
    threshold = 0
    feature = None
    for f in data.columns.drop('y'):
        # Duplicate values produce identical partitions, so each distinct value
        # is evaluated only once. NaN is skipped: every comparison with NaN is
        # False, which would leave both partitions empty and fake a maximal gain.
        candidates = data[f].dropna().drop_duplicates().sort_values()
        for candidate in candidates:
            gain = info_gain(data, candidate, f)
            if gain > largest_gain:
                largest_gain, threshold, feature = gain, candidate, f
    return largest_gain, threshold, feature


In [None]:
## Building tree
class Leaf:
    """
    Terminal tree node that stores a majority-vote class label.

    ``prediction`` is 1 when the node's rows contain at least as many
    ``y == 1`` labels as ``y == 0`` labels (ties, including an empty node,
    resolve to 1); otherwise it is 0.
    """

    def __init__(self, data):
        ones = (data.y == 1).sum()
        zeros = (data.y == 0).sum()
        self.prediction = 1 if ones >= zeros else 0
class Stump:
    """
    Internal decision node that splits on ``feature >= threshold``.

    Constructed by ``build_tree`` as ``Stump(threshold, feature, left_branch, right_branch)``.
    """

    def __init__(self, threshold, feature, left_branch, right_branch):
        # Value the feature is compared against.
        self.threshold = threshold
        # Name of the column this node splits on.
        self.feature = feature
        # Subtree for rows where feature >= threshold (the "Then" branch).
        self.left_branch = left_branch
        # Subtree for rows where feature < threshold (the "Else" branch).
        self.right_branch = right_branch
def build_tree(data):
    """
    Recursively grow a decision tree on ``data``.

    A node becomes a ``Leaf`` when it has no rows, or when the best split found
    by ``find_best_threshold`` has zero gain. Otherwise it becomes a ``Stump``.
    Its left subtree is grown on rows with ``feature >= threshold`` and its right
    subtree on rows with ``feature < threshold``, each with a fresh 0..n-1 index.
    """
    if data.empty:
        return Leaf(data)

    best_gain, split_value, split_feature = find_best_threshold(data)
    if best_gain == 0:
        # No split improves on the current node.
        return Leaf(data)

    column = data[split_feature]
    upper = data[column >= split_value].reset_index(drop=True)
    lower = data[column < split_value].reset_index(drop=True)
    # Left ("true") branch is built first, then the right ("false") branch.
    return Stump(split_value, split_feature, build_tree(upper), build_tree(lower))

In [None]:
## predict and evaluate

def predict_from_single_data(d_tree, data):
    """
    Classify the first row of ``data`` by walking ``d_tree`` from the root.

    At each internal node the row goes to ``left_branch`` when
    ``row[feature] >= threshold`` and to ``right_branch`` otherwise (including
    when the value is NaN). The walk stops at a ``Leaf``.

    Parameters
    ----------
    d_tree : Leaf or internal node
        Root of a tree as built by ``build_tree``.
    data : pandas.DataFrame
        Only the first row is used; extra columns such as ``y`` are ignored.

    Returns
    -------
    The ``prediction`` attribute of the Leaf that is reached.

    Raises
    ------
    KeyError
        If ``data`` lacks a feature column the tree splits on.
    IndexError
        If ``data`` has no rows and the root is not a Leaf.
    """
    node = d_tree
    while not isinstance(node, Leaf):
        value = data[node.feature].iloc[0]
        node = node.left_branch if value >= node.threshold else node.right_branch
    return node.prediction

def predict(d_tree, data):
    """
    Predict a label for every row of ``data``, in row order.

    Each row is turned into its own one-row DataFrame (index reset to 0) and
    classified with ``predict_from_single_data``. Returns a list.
    """
    def _one_row_frame(position):
        # Series -> single-row frame with the original columns, index 0.
        return pd.DataFrame(data.iloc[position]).transpose().reset_index(drop=True)

    return [predict_from_single_data(d_tree, _one_row_frame(pos)) for pos in range(len(data))]

def evaluate(d_tree, data):
    """
    Return the fraction of rows in ``data`` whose ``y`` matches the tree's prediction.

    Raises ZeroDivisionError when ``data`` has no rows.
    """
    predicted = predict(d_tree, data)
    actual = data.y.tolist()
    hits = sum(1 for guess, label in zip(predicted, actual) if guess == label)
    return hits / len(predicted)

In [None]:
## visualization

def print_tree(stump, spacing=""):
    """
    Print the tree in Weka-style plain text.

    Internal nodes print their ``feature >= threshold?`` question followed by
    a ``Then:`` (left) and an ``Else:`` (right) subtree, each indented by an
    extra ``"  |"``. Leaves print their prediction.
    """
    if not isinstance(stump, Leaf):
        print(spacing + str(stump.feature) + " >= " + str(stump.threshold) + "?")
        for label, branch in (("Then:", stump.left_branch), ("Else:", stump.right_branch)):
            print(spacing + label)
            print_tree(branch, spacing + "  |")
        return

    print(spacing + "Predict ", stump.prediction)
    
def print_boundry(stump, fig, ax, x1min, x1max, x2min, x2max):
    """
    Recursively draw the axis-aligned decision regions of a tree split on ``x1``/``x2``.

    Each Leaf's region ``[x1min, x1max] x [x2min, x2max]`` is outlined with
    black lines on ``ax``. An internal node cuts the current box at its
    threshold: the left branch gets the part with feature >= threshold and the
    right branch gets the rest. Nodes splitting on any other feature are not
    descended into. ``fig`` is only passed through. Returns ``ax``.
    """
    if isinstance(stump, Leaf):
        ax.hlines(y=x2min, xmin=x1min, xmax=x1max, color='black')
        ax.hlines(y=x2max, xmin=x1min, xmax=x1max, color='black')
        ax.vlines(x=x1min, ymin=x2min, ymax=x2max, color='black')
        ax.vlines(x=x1max, ymin=x2min, ymax=x2max, color='black')
        return ax

    cut = stump.threshold
    if stump.feature == 'x1':
        children = [(stump.left_branch, (cut, x1max, x2min, x2max)),
                    (stump.right_branch, (x1min, cut, x2min, x2max))]
    elif stump.feature == 'x2':
        children = [(stump.left_branch, (x1min, x1max, cut, x2max)),
                    (stump.right_branch, (x1min, x1max, x2min, cut))]
    else:
        children = []

    # Left (true) branch is drawn before the right (false) branch.
    for child, bounds in children:
        print_boundry(child, fig, ax, *bounds)

    return ax

def get_nodesnumber(stump):
    """
    Return the total number of nodes in the tree rooted at ``stump``.

    Both internal split nodes and Leaves are counted; a single Leaf counts as 1.
    """
    # Base case: we've reached a leaf
    if isinstance(stump, Leaf):
        return 1

    # This node plus every node in the true (left) and false (right) subtrees.
    return 1 + get_nodesnumber(stump.left_branch) + get_nodesnumber(stump.right_branch)
    
def read_txt(filename):
    """
    Load a whitespace-separated, header-less data file into a DataFrame.

    Each line must hold three fields, which become the columns ``x1``, ``x2``
    and ``y``. Any run of spaces or tabs separates fields, so files aligned
    with repeated spaces or tabs parse correctly instead of producing empty
    columns.
    """
    data = pd.read_csv(filename, sep=r"\s+", header=None)
    data.columns = ["x1", "x2", "y"]
    return data