In [1]:
import numpy as np

In [2]:
class DecisionTreeLearner:
    """Regression decision tree stored as a flat NumPy node table.

    Each row of ``self.tree`` is one node with four columns::

        [feature_index, split_value, left_offset, right_offset]

    Leaf rows are ``[LEAF, prediction, LEAF, LEAF]`` where ``prediction`` is
    the mean target of the training rows that reached the leaf. Offsets are
    relative to the current row: the left child is always the next row
    (offset 1) and the right child follows the whole left subtree.

    Each split picks the feature with the largest absolute Pearson
    correlation with the target and splits it at its median; rows with
    ``x[feature] <= split_value`` go left, the rest go right.
    """

    def __init__(self, leaf_size=1):
        """Create an untrained learner.

        Args:
            leaf_size: a node holding at most this many training rows
                becomes a leaf.
        """
        self.LEAF = -1
        self.leaf_size = leaf_size
        # np.empty() requires a shape; start with an empty (0, 4) node table.
        self.tree = np.empty((0, 4), dtype=np.float64)

    def _leaf(self, data):
        """Return a single (1, 4) leaf row predicting the mean target of *data*."""
        return np.array(
            [[self.LEAF, np.mean(data[:, -1]), self.LEAF, self.LEAF]],
            dtype=np.float64,
        )

    def build_tree(self, data):
        """Recursively build the node table for *data*.

        Args:
            data: 2-D array whose last column is the target and whose other
                columns are features.

        Returns:
            An (n_nodes, 4) float64 array in the layout described on the class.
        """
        targets = data[:, -1]
        # Stop when the node is small enough or every target is identical
        # (compare against the first target, not the column with itself).
        if data.shape[0] <= self.leaf_size or np.all(targets == targets[0]):
            return self._leaf(data)

        feature = self.find_split(data)
        split_val = np.median(data[:, feature])
        goes_left = data[:, feature] <= split_val
        # A split that sends every row to one side would recurse forever.
        if goes_left.all() or not goes_left.any():
            return self._leaf(data)

        left_tree = self.build_tree(data[goes_left])
        right_tree = self.build_tree(data[~goes_left])
        # Left child is the next row; right child comes after the left subtree.
        root = np.array(
            [[feature, split_val, 1, left_tree.shape[0] + 1]], dtype=np.float64
        )
        return np.vstack((root, left_tree, right_tree))

    def find_split(self, data):
        """Return the index of the feature most correlated with the target.

        Scores each feature column by the absolute Pearson correlation with
        the last (target) column. A constant column has an undefined (NaN)
        correlation; it is scored 0 instead, because ``np.argmax`` treats NaN
        as the maximum and would otherwise always pick it.
        """
        target = data[:, -1]
        n_features = data.shape[1] - 1
        coef_vals = np.zeros(n_features)
        # Suppress the 0/0 RuntimeWarning corrcoef emits for constant columns.
        with np.errstate(divide="ignore", invalid="ignore"):
            for k in range(n_features):
                coef_vals[k] = np.abs(np.corrcoef(data[:, k], target)[0, 1])
        coef_vals = np.where(np.isnan(coef_vals), 0.0, coef_vals)
        return int(np.argmax(coef_vals))

    def train(self, data):
        """Fit the tree to *data* (features in all but the last column, target last)."""
        self.tree = self.build_tree(np.asarray(data, dtype=np.float64))

    def query(self, points):
        """Predict a target value for each row of *points*.

        Args:
            points: 2-D array-like of feature rows (no target column).

        Returns:
            1-D float64 array with one prediction per row.

        Raises:
            ValueError: if the learner has not been trained.
        """
        if self.tree.shape[0] == 0:
            raise ValueError("DecisionTreeLearner has not been trained")
        predictions = []
        for point in np.asarray(points, dtype=np.float64):
            row = 0
            node = self.tree[row]
            while node[0] != self.LEAF:
                feature, split_val, left_off, right_off = node
                # Same rule as build_tree: values <= split go left.
                if point[int(feature)] <= split_val:
                    row += int(left_off)
                else:
                    row += int(right_off)
                node = self.tree[row]
            predictions.append(node[1])
        return np.array(predictions, dtype=np.float64)