In [20]:
import numpy as np
import pandas as pd
import sys
import random

In [36]:
def dTree(X, Y):
    """Build a one-level decision tree (a decision stump) for +1/-1 labels.

    This is the pruned variant (question 19): the root splits once and both
    children are leaves carrying the majority label of their side. A fully
    grown tree would instead recurse on the two partitions.

    Args:
        X: 2-D array of samples, one row per example.
        Y: column array of labels, indexed as ``Y[mask, 0:1]``. Only the
            values 1 and -1 are counted by the majority vote.

    Returns:
        A ``Node``. When ``stop_condition(X, Y)`` holds it is a single leaf
        holding the first label; otherwise it is an internal node carrying
        the threshold and feature index chosen by ``decision_stump`` plus
        two leaf children.
    """
    if stop_condition(X, Y):
        # Store a plain scalar instead of the length-1 row ``Y[0]``: newer
        # NumPy releases deprecate converting an ndim > 0 array to a scalar,
        # which is what happens when the leaf value is later written into a
        # single cell of a prediction matrix.
        return Node(None, None, np.ravel(Y)[0])

    b, index = decision_stump(X, Y)  # scalar threshold and column index

    # Only the labels on each side are needed for the pruned tree, so the
    # sample partitions themselves are not materialised.
    feature = X[:, index]
    leftY = Y[feature < b, 0:1]
    rightY = Y[feature >= b, 0:1]

    def majority(labels):
        # Ties (including an empty side) are resolved in favour of +1.
        return 1 if np.sum(labels == 1) >= np.sum(labels == -1) else -1

    node = Node(b, index)
    node.leftNode = Node(None, None, majority(leftY))
    node.rightNode = Node(None, None, majority(rightY))
    return node

def predict(X, node):
    """Classify one sample by walking the tree from ``node`` down to a leaf.

    Args:
        X: a single sample, indexable by feature position.
        node: root of the (sub)tree. A node whose ``value`` is not None is
            treated as a leaf.

    Returns:
        The ``value`` stored on the leaf that is reached.
    """
    current = node
    # Iterative descent: go left when the tested feature is strictly below
    # the node's threshold, right otherwise.
    while current.value is None:
        threshold = current.theta
        feature = current.index
        if X[feature] < threshold:
            current = current.leftNode
        else:
            current = current.rightNode
    return current.value
    
def calculate_error(X, Y, nodes, k):
    """Compute the majority-vote error of the first ``k`` trees on a data set.

    Every tree votes with equal weight. A row is predicted +1 only when the
    +1 votes strictly outnumber the -1 votes; ties (and ``k == 0``) give -1.

    Args:
        X: 2-D array of samples, one row per example.
        Y: labels for the rows of ``X``, either as a column ``(rows, 1)`` or
            as a flat ``(rows,)`` array.
        nodes: sequence of tree roots accepted by ``predict``.
        k: number of leading trees from ``nodes`` that take part in the vote.

    Returns:
        ``(Yhat, error)`` where ``Yhat`` has shape ``(rows, k)`` and holds
        each tree's prediction, and ``error`` is the fraction of rows whose
        voted label differs from ``Y``.
    """
    row = X.shape[0]
    # Force a column layout: comparing a (rows, 1) vote column with flat
    # (rows,) labels would broadcast to (rows, rows) and miscount errors.
    labels = np.reshape(Y, (row, 1))

    Yhat = np.zeros((row, k))
    for i in range(row):
        sample = X[i, :]
        for j in range(k):
            Yhat[i, j] = predict(sample, nodes[j])

    votes_pos = np.sum(Yhat == 1, axis=1)
    votes_neg = np.sum(Yhat == -1, axis=1)
    Yhat_final = np.where(votes_pos > votes_neg, 1, -1).reshape(row, 1)
    return Yhat, np.sum(Yhat_final != labels) / row
    
def stop_condition(X, Y):
    """Tell whether tree growth should stop at this set of samples.

    Growth stops when every row of ``X`` equals the first row, when there is
    exactly one row, or when every label in ``Y`` equals the first label.

    Returns:
        True if any of the three conditions holds, False otherwise.
    """
    # All samples identical: no threshold can separate them.
    if not np.any(X != X[0]):
        return True
    # A single sample cannot be split further.
    if X.shape[0] == 1:
        return True
    # Otherwise stop only when the labels are already pure.
    return not np.any(Y != Y[0])

def Gini(Y):
    """Return the Gini impurity of a set of +1/-1 labels.

    Computed as ``1 - p(-1)**2 - p(+1)**2`` where each ``p`` is the count of
    that label divided by ``len(Y)``. An empty ``Y`` yields 1.
    """
    total = len(Y)
    if not total:
        return 1
    frac_neg = np.sum(Y == -1) / total
    frac_pos = np.sum(Y == 1) / total
    return 1 - frac_neg ** 2 - frac_pos ** 2
    
def decision_stump(X, Y):
    """Find the single-feature threshold split with the lowest Gini impurity.

    For every column, the candidate thresholds are the midpoints between
    consecutive sorted values, plus one candidate 0.05 below the column
    minimum (which sends every sample to the right side). Each candidate is
    scored by the size-weighted sum of the Gini impurities of both sides.

    Args:
        X: 2-D array of samples, one row per example.
        Y: labels aligned with the rows of ``X``.

    Returns:
        ``(threshold, column)`` of the first candidate reaching the minimum
        score, scanning columns in order and thresholds in ascending order.
        ``(0, 0)`` when there are no candidates to evaluate.
    """
    n_rows, n_cols = X.shape
    ordered = np.sort(X, 0)
    # Pair each sorted value with its predecessor (the minimum is paired
    # with itself minus 0.1) and take the midpoint of every pair.
    predecessors = np.r_[ordered[0:1, :] - 0.1, ordered[0:n_rows - 1, :]]
    candidates = (predecessors + ordered) / 2

    best_b = 0
    best_index = 0
    lowest = sys.float_info.max
    for feature in range(n_cols):
        column = X[:, feature]
        # Split the data into two parts at each candidate threshold and
        # score the split with the weighted Gini impurity.
        for b in candidates[:, feature]:
            below = Y[column < b]
            above = Y[column >= b]
            score = len(below) * Gini(below) + len(above) * Gini(above)
            if score < lowest:
                lowest = score
                best_b = b
                best_index = feature

    return best_b, best_index
    
class Node:
    """A node of the binary decision tree.

    Internal nodes carry a split rule (``theta``, ``index``) and two
    children; leaves carry a ``value`` and leave the rest as None.
    """

    def __init__(self, theta, index, value = None):
        # Split rule: threshold and the feature (column) it is applied to.
        self.theta, self.index = theta, index
        # Label returned when prediction ends at this node; None marks an
        # internal node.
        self.value = value
        # Children are attached by the tree builder after construction.
        self.leftNode = self.rightNode = None
        
def read_data(input_file):
    """Load a whitespace-separated, header-less data file.

    Args:
        input_file: path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        ``(X, Y)`` as NumPy arrays: ``X`` holds every column but the last,
        ``Y`` holds the last column kept as a 2-D column of shape
        ``(rows, 1)``.
    """
    # Raw string so that ``\s`` reaches the regex engine instead of being an
    # invalid string escape.
    data = pd.read_csv(input_file, sep=r'\s+', header=None)
    # ``DataFrame.as_matrix()`` was removed in pandas 1.0; ``to_numpy()`` is
    # its replacement.
    data = data.to_numpy()
    col = data.shape[1]
    X = data[:, 0:col - 1]
    Y = data[:, col - 1:col]
    return X, Y
    
def bootstrap_sample(X, Y, N):
    """Draw a bootstrap sample of ``N`` rows, with replacement.

    Row positions are drawn uniformly from ``0 .. N-1`` using the ``random``
    module, and the same positions are applied to both ``X`` and ``Y`` so
    samples and labels stay aligned.

    Returns:
        ``(X_sample, Y_sample)``, each with ``N`` rows.
    """
    # One independent uniform draw per output row.
    picks = [random.randrange(N) for _ in range(N)]
    return X[picks, :], Y[picks, :]
    
def main():
    """Train a bagged forest on the training file and print Ein and Eout.

    Each tree is built by ``dTree`` from its own bootstrap sample of the
    training set. Both errors are computed by ``calculate_error`` using a
    uniform vote over the first ``k`` trees.
    """
    train_path = 'train_data.dat'
    test_path = 'test_data.dat'
    X_train, Y_train = read_data(train_path)
    X_test, Y_test = read_data(test_path)

    # Random forest via bagging. 30000 rounds would take far too long, so
    # 3000 are used.
    n_trees = 3000
    n_samples = X_train.shape[0]  # size of each bootstrap sample
    k = 3000  # vote with the first k trees of the forest

    forest = []
    for _ in range(n_trees):
        forest.append(dTree(*bootstrap_sample(X_train, Y_train, n_samples)))

    # Only the error rates are reported; per-tree predictions are dropped.
    _, Ein = calculate_error(X_train, Y_train, forest, k)
    _, Eout = calculate_error(X_test, Y_test, forest, k)
    print(Ein, Eout)
    
# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()

0.11 0.149
