In [29]:
from random import seed
from random import randrange
from csv import reader

# Load a CSV file
def load_csv(filename):
    """Read a CSV file and return its rows as a list of lists of strings.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A list with one entry per CSV record; each entry is the list of
        string fields produced by ``csv.reader``.
    """
    # The context manager closes the handle even if parsing fails.
    # newline="" lets the csv module do its own newline handling, so line
    # breaks embedded in quoted fields are preserved unchanged.
    with open(filename, "r", newline="") as file:
        return list(reader(file))

# Convert string column to float
def str_column_to_float(dataset, column):
    """Convert one column of every row from text to float, in place.

    Args:
        dataset: Iterable of mutable rows (e.g. lists) holding strings.
        column: Index of the field to convert in each row.

    Surrounding whitespace is stripped from the field before conversion.
    Nothing is returned; the rows themselves are modified.
    """
    for record in dataset:
        trimmed = record[column].strip()
        record[column] = float(trimmed)


# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
    """Return the size-weighted Gini index of a candidate split.

    Args:
        groups: Sequence of groups, each a list of rows whose last element
            is the class label.
        classes: The class labels to score each group against.

    Returns:
        The sum over non-empty groups of ``(1 - sum(p_class ** 2))``
        weighted by the group's share of all rows, as a float.
    """
    # Total number of rows across every group of the split.
    total_rows = float(sum(len(members) for members in groups))
    weighted_impurity = 0.0
    for members in groups:
        member_count = float(len(members))
        # Empty groups contribute nothing (and would divide by zero).
        if member_count != 0:
            purity = 0.0
            for label in classes:
                share = [r[-1] for r in members].count(label) / member_count
                purity += share * share
            # Weight the group's impurity by its relative size.
            weighted_impurity += (1.0 - purity) * (member_count / total_rows)
    return weighted_impurity

# Split a dataset based on an attribute and an attribute value
def split_into_groups(index, value, dataset):
    """Partition rows into two groups by comparing one attribute to a threshold.

    Args:
        index: Column (attribute) index to compare.
        value: Split value for that attribute.
        dataset: Iterable of rows.

    Returns:
        A ``(left, right)`` tuple of lists: ``left`` holds rows whose
        attribute is strictly below ``value``; ``right`` holds all rows whose
        attribute is greater than or equal to it. Row order is preserved.
    """
    below, at_or_above = [], []
    for record in dataset:
        bucket = below if record[index] < value else at_or_above
        bucket.append(record)
    return below, at_or_above

# Select the best split point for a dataset
def get_split(dataset):
    """Find the split with the lowest Gini index and describe it as a node.

    Every value of every attribute column (all columns except the last,
    which holds the class label) is tried as a split point.

    Args:
        dataset: Non-empty list of rows with at least one attribute column
            followed by the class label.

    Returns:
        A dict with the keys:
            'index': column index of the chosen attribute,
            'value': chosen split value,
            'groups': the ``(left, right)`` row groups of the chosen split,
            'count_value_class': the sizes of those two groups,
            'class': the most common class label among the rows of
                ``dataset`` (as selected by ``to_terminal``).
    """
    class_values = list(set(row[-1] for row in dataset))
    b_index, b_value, b_score, b_groups = 999, 999, float('inf'), None
    for index in range(len(dataset[0]) - 1):
        for row in dataset:
            groups = split_into_groups(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {
        'index': b_index,
        'value': b_value,
        'groups': b_groups,
        'count_value_class': [len(g) for g in b_groups],
        # Majority label of the rows reaching this node. It must be computed
        # from rows, not from the list of groups, and must not depend on
        # whichever split happened to be evaluated last in the loop.
        'class': to_terminal(dataset),
    }

# Select the most common class value in a group. Used at a tree leaf to make a prediction.
def to_terminal(group):
    """Return the most frequent class label among the rows of a group.

    Args:
        group: Non-empty list of rows whose last element is the class label.

    Returns:
        The label that occurs most often; on a tie, the label that appears
        first in ``group`` wins.
    """
    labels = []
    for record in group:
        labels.append(record[-1])
    return max(labels, key=labels.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
    """Recursively turn a split node's row groups into children, in place.

    Args:
        node: Dict produced by ``get_split``; its 'groups' entry is consumed
            and replaced by 'left' and 'right' entries.
        max_depth: Maximum depth of the tree.
        min_size: Minimum number of rows a group needs to be split further.
        depth: Depth of ``node`` in the tree.

    Each child is either a class label (leaf) or another node dict.
    """
    left, right = node['groups']
    del node['groups']

    # One side is empty: there was no real split, both children share one leaf.
    if not (left and right):
        leaf = to_terminal(left + right)
        node['left'] = leaf
        node['right'] = leaf
        return

    # Depth limit reached: both sides become leaves.
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return

    # Left child first, then right; a side that is large enough is split
    # again, otherwise it becomes a leaf.
    for side, rows in (('left', left), ('right', right)):
        if len(rows) > min_size:
            node[side] = get_split(rows)
            split(node[side], max_depth, min_size, depth + 1)
        else:
            node[side] = to_terminal(rows)
        
# Build a decision tree
def build_tree(train, max_depth, min_size):
    """Build a decision tree from training rows and return its root node.

    Args:
        train: Training rows passed to ``get_split`` to create the root.
        max_depth: Maximum depth passed on to ``split``.
        min_size: Minimum group size passed on to ``split``.

    Returns:
        The root node dict, with children grown starting at depth 1.
    """
    tree_root = get_split(train)
    initial_depth = 1
    split(tree_root, max_depth, min_size, initial_depth)
    return tree_root

# Print a decision tree
def print_tree(node, depth=0):
    """Print a decision tree, one line per node, indented by depth.

    Args:
        node: Either a node dict (with 'index', 'value',
            'count_value_class', 'left' and 'right') or a leaf value.
        depth: Current nesting level; each level indents by three spaces.
    """
    indent = depth * '   '
    # Anything that is not a dict is a leaf: print its value and stop.
    if not isinstance(node, dict):
        print('%s[%s]' % (indent, node))
        return
    # Attribute numbers are shown 1-based (X1, X2, ...).
    print('%s[X%d < %.3f]' % (indent, node['index'] + 1, node['value']), node["count_value_class"])
    for side in ('left', 'right'):
        print_tree(node[side], depth + 1)

# Toy training set: each row is [attribute X1, attribute X2, class label].
# The class label is the last element of a row, which is where the
# tree-building functions read it from (row[-1]).
dataset = [[2.771244718,1.784783929,0],
	[1.728571309,1.169761413,0],
	[3.678319846,2.81281357,0],
	[3.961043357,2.61995032,0],
	[2.999208922,2.209014212,0],
	[7.497545867,3.162953546,1],
	[9.00220326,3.339047188,1],
	[7.444542326,0.476683375,1],
	[10.12493903,3.234550982,1],
	[6.642287351,3.319983761,1]]



# Grow a tree with max_depth=5 and min_size=1, then print its structure.
tree = build_tree(dataset, 5, 1)
print_tree(tree)

[X1 < 6.642] [5, 5]
   [X1 < 2.771] [1, 4]
      [0]
      [X1 < 2.771] [0, 4]
         [0]
         [0]
   [X1 < 7.498] [2, 3]
      [X1 < 7.445] [1, 1]
         [1]
         [1]
      [X1 < 7.498] [0, 3]
         [1]
         [1]


In [30]:
import re

class BinaryTree:
    """Binary tree node holding a value and optional left/right subtrees.

    ``str(tree)`` renders the node's value followed by its children (right
    child first, then left), each nesting level indented by one tab. A node
    that has children gets a ':' appended to its value.
    """

    # Nodes whose __str__ is currently in progress. Shared by all instances
    # so that a node reachable from itself renders as 'deep...' instead of
    # recursing without end.
    nodesWriter = set()

    def __init__(self, rootObj):
        """Create a childless node holding ``rootObj``."""
        self.root = rootObj
        self.leftChild = None
        self.rightChild = None

    def insertLeft(self, newNode):
        """Insert a node holding ``newNode`` as the left child.

        An existing left subtree is pushed down and becomes the left child
        of the newly inserted node.
        """
        t = BinaryTree(newNode)
        t.leftChild = self.leftChild
        self.leftChild = t

    def insertRight(self, newNode):
        """Insert a node holding ``newNode`` as the right child.

        An existing right subtree is pushed down and becomes the right child
        of the newly inserted node.
        """
        t = BinaryTree(newNode)
        t.rightChild = self.rightChild
        self.rightChild = t

    def getRightChild(self):
        """Return the right subtree, or None."""
        return self.rightChild

    def getLeftChild(self):
        """Return the left subtree, or None."""
        return self.leftChild

    def setRootVal(self, obj):
        """Replace the value stored in this node."""
        self.root = obj

    def getRootVal(self):
        """Return the value stored in this node."""
        return self.root

    def count(self):
        """Return how many direct children this node has (0, 1 or 2)."""
        children = (self.getLeftChild(), self.getRightChild())
        return sum(child is not None for child in children)

    def toStringUnsafe(self):
        """Render this node and its subtrees without cycle protection.

        Returns:
            The node's value (converted with ``str``; ':' appended when the
            node has children) followed by the children's rendering with
            every line indented by one tab. Lines that are empty or consist
            only of tabs and spaces are dropped.
        """
        header = str(self.getRootVal()) + (':' if self.count() > 0 else '')
        lines = header.split('\n')
        children = self.ChildrenToString()
        if children:
            # Children are already rendered; one more tab per line nests
            # them under this node.
            lines.extend('\t' + line for line in children.split('\n'))
        # Drop only lines that are blank or pure indentation. A value such
        # as 't' or '|' is real content and must be kept.
        return '\n'.join(line for line in lines if line.strip('\t '))

    def __str__(self):
        """Render the tree; a node met again while rendering shows as 'deep...'."""
        if self in BinaryTree.nodesWriter:
            return 'deep...'
        BinaryTree.nodesWriter.add(self)
        try:
            return self.toStringUnsafe()
        finally:
            BinaryTree.nodesWriter.remove(self)

    def ChildrenToString(self):
        """Return the rendered children joined by newlines, right child first."""
        children = (self.getRightChild(), self.getLeftChild())
        return '\n'.join(str(child) for child in children if child is not None)
                
                


#      a
#    /   \
#   b     c
#  / \   / \
# d   e f   g

In [31]:
# Build the seven-node tree sketched above: root 'a', children 'b' (left)
# and 'c' (right), then 'd'/'e' under 'b' and 'f'/'g' under 'c'.
r = BinaryTree('a')
r.insertLeft('b')
r.insertRight('c')
#print('Root:', r.getRootVal())
#print('Left child:', r.getLeftChild())
#print('Tree:', r)
r.getLeftChild().insertLeft('d')
r.getLeftChild().insertRight('e')
#print('Tree:', r)
r.getRightChild().insertLeft('f')
r.getRightChild().insertRight('g')
# print() renders the tree through BinaryTree.__str__.
print('Tree:', r)

Tree: a:
	c:
		g
			
		f
			
	b:
		e
			
		d
			
