In [45]:
import pprint
import subprocess

import numpy as np
import pandas as pd
from anytree import Node, RenderTree
from anytree.exporter import DotExporter

In [46]:
# Feature names, in the same order as the columns of X.
column_labels = ['age', 'income', 'student', 'credit_rating']

# Categorical training samples: one row per sample, one column per feature.
X = np.array([
    ['youth', 'high', 'no', 'fair'],
    ['youth', 'high', 'no', 'excellent'],
    ['middle_aged', 'high', 'no', 'fair'],
    ['senior', 'medium', 'no', 'fair'],
    ['senior', 'low', 'yes', 'fair'],
    ['senior', 'low', 'yes', 'excellent'],
    ['middle_aged', 'low', 'yes', 'excellent'],
    ['youth', 'medium', 'no', 'fair'],
    ['youth', 'low', 'yes', 'fair'],
    ['senior', 'medium', 'yes', 'fair'],
    ['youth', 'medium', 'yes', 'excellent'],
    ['middle_aged', 'medium', 'no', 'excellent'],
    ['middle_aged', 'high', 'yes', 'fair'],
    ['senior', 'medium', 'no', 'excellent'],
])

# Class label of each row of X (same order as the rows above).
Y = np.array([
    'no', 'no', 'yes', 'yes', 'yes', 'no', 'yes',
    'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'no',
])

# Labelled view of X; this frame is what the tree-building code receives.
df = pd.DataFrame(data=X, columns=column_labels)

In [47]:
def create_node(prop_name, left_value, right_value, left_node, right_node, gini_score, belongs_to_class = None):
    """Build the plain-dict record that represents one tree node.

    Each argument is stored unchanged under a key of the same name. A new
    dict is created on every call; ``belongs_to_class`` defaults to None.
    """
    return dict(
        prop_name=prop_name,
        left_value=left_value,
        right_value=right_value,
        left_node=left_node,
        right_node=right_node,
        gini_score=gini_score,
        belongs_to_class=belongs_to_class,
    )

def calculate_gini_index(Y):
    """Return the Gini impurity of the label array ``Y``.

    Starts from 1 and subtracts the squared relative frequency of every
    distinct value found in ``Y``. An empty ``Y`` has no distinct values, so
    the initial 1 is returned unchanged (and nothing is divided by zero).
    """
    total = Y.size
    impurity = 1
    for label in np.unique(Y):
        # Share of the samples carrying this label.
        share = np.count_nonzero(Y == label) / total
        impurity -= share**2
    return impurity

def split_data(X, Y, node):
    """Find the best binary "one value vs. the rest" split of ``X``.

    Every column of ``X`` and every distinct value ``c`` in that column is
    tried as a candidate: rows whose value equals ``c`` go to the right side,
    all other rows go to the left side. Candidates are scored by the
    size-weighted Gini impurity of the two sides and the lowest score wins;
    on ties the first candidate in column/value order is kept.

    A candidate that leaves one side empty (the column holds a single value)
    separates nothing. Such candidates are ranked after every genuine split,
    so they are only returned when no genuine split exists. Without that
    rule a constant column evaluated first could win a tie against a later,
    useful split purely because of evaluation order.

    Args:
        X: pandas DataFrame of categorical features, one row per sample.
        Y: NumPy array of labels, positionally aligned with the rows of ``X``.
        node: Not used; kept so that existing callers keep working.

    Returns:
        A tuple ``(min_gini, chosen_X, left_X_class, right_X_class,
        left_class, right_class)`` where

        * ``min_gini`` is the weighted Gini impurity of the chosen split,
        * ``chosen_X`` is the name of the chosen column,
        * ``left_X_class`` is an array of the column values routed left,
        * ``right_X_class`` is a one-element list with the value routed right,
        * ``left_class`` / ``right_class`` is the single label present on
          that side, or None when the side is empty or holds several labels.

    Raises:
        ValueError: If ``X`` has no rows or no columns, so no candidate
            split can be formed.
    """
    n_samples = Y.size
    best_rank = None   # (separates_nothing, weighted_gini) of the best candidate
    best_split = None  # (column, values routed left, values routed right)

    for column in X.columns:
        values = X[column].to_numpy()
        categories = np.unique(values)

        for c in categories:
            goes_right = values == c
            left_Y = Y[~goes_right]
            right_Y = Y[goes_right]
            weighted_gini = (
                (left_Y.size / n_samples) * calculate_gini_index(left_Y)
                + (right_Y.size / n_samples) * calculate_gini_index(right_Y)
            )

            # False sorts before True, so genuine splits always outrank
            # candidates that leave one side empty; the Gini score decides
            # among candidates of the same kind.
            separates_nothing = left_Y.size == 0 or right_Y.size == 0
            rank = (separates_nothing, weighted_gini)
            if best_rank is None or rank < best_rank:
                best_rank = rank
                best_split = (column, categories[categories != c], [c])

    if best_split is None:
        raise ValueError("split_data needs at least one row and one column in X")

    chosen_X, left_X_class, right_X_class = best_split
    min_gini = best_rank[1]

    # Report the label of a side when that side is already pure, so the
    # caller can turn it into a leaf without splitting further.
    chosen_values = X[chosen_X].to_numpy()
    left_labels = np.unique(Y[np.isin(chosen_values, left_X_class)])
    right_labels = np.unique(Y[np.isin(chosen_values, right_X_class)])
    left_class = left_labels[0] if left_labels.size == 1 else None
    right_class = right_labels[0] if right_labels.size == 1 else None

    return min_gini, chosen_X, left_X_class, right_X_class, left_class, right_class

safeguard = 0  # running count of create_tree calls; informational only, never used as a limit


def create_tree(X, Y, root=None):
    """Recursively grow a binary decision tree made of node dicts.

    The best split for ``(X, Y)`` is obtained from ``split_data``; rows whose
    value in the chosen column is in ``right_X_class`` go to the right child,
    all remaining rows go to the left child, and both children are grown
    recursively. A child whose labels are already pure is created with
    ``belongs_to_class`` set and is returned as a leaf right away.

    If the chosen split puts every row on the same side, the data cannot be
    separated any further (for example identical feature rows with different
    labels). The node then becomes a leaf labelled with the most frequent
    label in ``Y`` (the first one in sorted order on ties) instead of
    recursing forever on the same data.

    Args:
        X: pandas DataFrame of categorical features, one row per sample.
        Y: NumPy array of labels, positionally aligned with the rows of ``X``.
        root: Node dict to fill in. When omitted, a fresh empty node is
            created for this call.

    Returns:
        ``root``, filled in. Internal nodes get ``left_node``, ``right_node``,
        ``gini_score``, ``prop_name``, ``left_X_class`` and ``right_X_class``;
        leaves have ``belongs_to_class`` set. A node is returned untouched
        when it already has a class or when ``X`` is empty.
    """
    global safeguard
    safeguard += 1

    if root is None:
        # Create the node per call: a node built in the signature would be a
        # single dict shared and mutated by every call that omits ``root``.
        root = create_node(None, None, None, None, None, None, None)

    if root['belongs_to_class'] is not None or X.size == 0:
        return root

    min_gini, chosen_X, left_X_class, right_X_class, left_class, right_class = split_data(X, Y, root)

    right_mask = np.isin(X[chosen_X], right_X_class)

    if right_mask.all() or not right_mask.any():
        # The split separates nothing, so a child would receive exactly the
        # same data again. Stop here with a majority-vote leaf.
        labels, counts = np.unique(Y, return_counts=True)
        root['belongs_to_class'] = labels[np.argmax(counts)]
        return root

    # A child created with a class is returned immediately by the recursive
    # call, so its rows need no special treatment before being passed on.
    left_node = create_node(None, None, None, None, None, None, left_class)
    right_node = create_node(None, None, None, None, None, None, right_class)

    root['left_node'] = create_tree(X[~right_mask], Y[~right_mask], left_node)
    root['right_node'] = create_tree(X[right_mask], Y[right_mask], right_node)
    root['gini_score'] = min_gini
    root['prop_name'] = chosen_X
    root['left_X_class'] = left_X_class
    root['right_X_class'] = right_X_class

    return root

In [48]:
# Grow the decision tree from the full feature frame and its labels.
root = create_tree(X=df, Y=Y)

In [49]:
i = 0  # running counter appended to every node name created below


def create_tree_nodes(data, parent_node=None, edge_label=None):
    """Convert a nested node dict into a tree of anytree ``Node`` objects.

    Each dict becomes one ``Node`` named ``"<caption> <counter>"``, where the
    caption is ``belongs_to_class`` when that is set and ``prop_name``
    otherwise ('not defined' if the key is missing), and the counter is the
    module-level ``i``, incremented once per created node. ``edge_label``,
    when given, is stored on the node as ``edgeattr={'label': str(edge_label)}``;
    otherwise ``edgeattr`` is None. Children are built from ``left_node`` and
    ``right_node`` with ``left_X_class`` / ``right_X_class`` as edge labels.

    Returns the created node, or None when ``data`` is None.
    """
    global i
    if data is None:
        return None
    i += 1

    leaf_class = data.get('belongs_to_class')
    caption = leaf_class if leaf_class is not None else data.get('prop_name', 'not defined')

    edge_attributes = None
    if edge_label is not None:
        edge_attributes = {'label': str(edge_label)}

    node = Node(" ".join((str(caption), str(i))), parent=parent_node, edgeattr=edge_attributes)

    # Left subtree first, then right, so the numbering follows a pre-order walk.
    for child_key, label_key in (('left_node', 'left_X_class'), ('right_node', 'right_X_class')):
        create_tree_nodes(data.get(child_key), parent_node=node, edge_label=data.get(label_key))

    return node

# Output locations for the Graphviz source and the rendered image.
dot_path = "tree_visualization.dot"
png_path = "tree_visualization.png"

# Build the anytree representation of the tree and show it as text.
root_node = create_tree_nodes(root)
print(RenderTree(root_node))

# Write the tree as a DOT file; every edge carries the label stored on the
# child node's `edgeattr`, or an empty label when the child has none.
DotExporter(
    root_node,
    nodenamefunc=lambda node: node.name,
    edgeattrfunc=lambda parent, child: 'label="{}"'.format(
        '' if child.edgeattr is None else child.edgeattr.get('label')
    ),
).to_dotfile(dot_path)

# Render the DOT file to PNG with the `dot` executable; raises if it fails.
subprocess.check_call(["dot", dot_path, "-T", "png", "-o", png_path])

Node('/age 1', edgeattr=None)
├── Node('/age 1/student 2', edgeattr={'label': "['senior' 'youth']"})
│   ├── Node('/age 1/student 2/credit_rating 3', edgeattr={'label': "['yes']"})
│   │   ├── Node('/age 1/student 2/credit_rating 3/yes 4', edgeattr={'label': "['fair']"})
│   │   └── Node('/age 1/student 2/credit_rating 3/age 5', edgeattr={'label': "['excellent']"})
│   │       ├── Node('/age 1/student 2/credit_rating 3/age 5/yes 6', edgeattr={'label': "['youth']"})
│   │       └── Node('/age 1/student 2/credit_rating 3/age 5/no 7', edgeattr={'label': "['senior']"})
│   └── Node('/age 1/student 2/age 8', edgeattr={'label': "['no']"})
│       ├── Node('/age 1/student 2/age 8/no 9', edgeattr={'label': "['youth']"})
│       └── Node('/age 1/student 2/age 8/credit_rating 10', edgeattr={'label': "['senior']"})
│           ├── Node('/age 1/student 2/age 8/credit_rating 10/yes 11', edgeattr={'label': "['fair']"})
│           └── Node('/age 1/student 2/age 8/credit_rating 10/no 12', edgeattr={'

0