In [19]:
import numpy as np
import matplotlib.pyplot as plt
import pprint

from anytree import Node, RenderTree
from anytree.exporter import DotExporter

In [20]:
X = np.array(
    [['youth', 'high', 'no', 'fair'],
    ['youth', 'high', 'no', 'excellent'],
    ['middle_aged', 'high', 'no', 'fair'],
    ['senior', 'medium', 'no', 'fair'],
    ['senior', 'low', 'yes', 'fair'],
    ['senior', 'low', 'yes', 'excellent'],
    ['middle_aged', 'low', 'yes', 'excellent'],
    ['youth', 'medium', 'no', 'fair'],
    ['youth', 'low', 'yes', 'fair'],
    ['senior', 'medium', 'yes', 'fair'],
    ['youth', 'medium', 'yes', 'excellent'],
    ['middle_aged', 'medium', 'no', 'excellent'],
    ['middle_aged', 'high', 'yes', 'fair'],
    ['senior', 'medium', 'no', 'excellent']]
)

Y = np.array(['no', 'no', 'yes', 'yes', 'yes', 'no', 'yes', 'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'no'])

In [21]:
def create_node(prop_name, left_value, right_value, left_node, right_node, gini_score, belongs_to_class = None):
  return {
      'prop_name': prop_name,
      'left_value': left_value,
      'right_value': right_value,
      'left_node': left_node,
      'right_node': right_node,
      'gini_score': gini_score,
      'belongs_to_class': belongs_to_class
  }

def calculate_gini_index(Y):
  classes = np.unique(Y)
  result = 1
  for c in classes:
    p = len(Y[Y == c]) / len(Y)
    result -= p**2
  return result

def split_data(X, Y, node):
  chosen_X = None
  min_gini = 1
  left_X_class = None
  right_X_class = None
  X_t = X.transpose()
  for i in range(len(X_t)):
    X_i = X_t[i]
    classes = np.unique(X_i)
    local_min = 1
    local_left_class = None
    local_right_class = None

    for c in classes:
      left_Y = Y[X_i != c]
      right_Y = Y[X_i == c]
      left_gini = calculate_gini_index(left_Y)
      right_gini = calculate_gini_index(right_Y)
      local_gini = (len(left_Y) / len(Y)) * left_gini + (len(right_Y) / len(Y)) * right_gini
      if local_gini < local_min:
        local_min = local_gini
        local_left_class = np.unique(X_t[i, X_i != c])
        local_right_class = [c]

    if local_min < min_gini:
      min_gini = local_min
      chosen_X = i
      left_X_class = local_left_class
      right_X_class = local_right_class

  # these are for when Y is already clear for some branch
  left_class = None
  right_class = None
  if min_gini == 0:
    unique_Y_left = np.unique(Y[np.isin(X_t[chosen_X], left_X_class)])
    unique_Y_right = np.unique(Y[np.isin(X_t[chosen_X], right_X_class)])

    left_class = unique_Y_left[0] if unique_Y_left.size == 1 else None
    right_class = unique_Y_right[0] if unique_Y_right.size == 1 else None


  return min_gini, chosen_X, left_X_class, right_X_class, left_class, right_class

safeguard = 0
def create_tree(X, Y, root = create_node(None, None, None, None, None, None, None)):
  global safeguard
  safeguard += 1

  if root['belongs_to_class'] != None or X.size == 0:
    return root

  min_gini, chosen_X, left_X_class, right_X_class, left_class, right_class = split_data(X, Y, root)

  if left_class != None:
    #prune X and Y from left_X_class
    mask = np.isin(X[:, chosen_X], left_X_class)
    X = X[~mask]
    Y = Y[~mask]


  if right_class != None:
    #prune X and Y from left_X_class
    mask = np.isin(X[:, chosen_X], right_X_class)
    X = X[~mask]
    Y = Y[~mask]


  


  right_mask = np.isin(X[:, chosen_X], right_X_class)
  left_X = X[~right_mask]
  right_X = X[right_mask]
  left_Y = Y[~right_mask]
  right_Y = Y[right_mask]

  # if chosen_X != None:
  left_X = np.delete(left_X, chosen_X, 1)
  right_X = np.delete(right_X, chosen_X, 1)

  left_node = create_node(None, None, None, None, None, None, left_class)
  right_node = create_node(None, None, None, None, None, None, right_class)



  root['left_node'] = create_tree(left_X, left_Y, left_node)
  root['right_node'] = create_tree(right_X, right_Y, right_node)
  root['gini_score'] = min_gini
  # root['chosen_X'] = chosen_X
  root['left_X_class'] = left_X_class
  root['right_X_class'] = right_X_class


  return root

In [22]:
root =  create_tree(X, Y)

In [24]:
i = 0
def create_tree_nodes(data, parent_node=None, edge_label=None):
    """Recursively creates anytree Nodes from nested dictionary."""
    if data is None:
        return None
    global i
    i += 1
    node = Node(i, parent=parent_node,
                  edgeattr={'label': str(edge_label)} if edge_label is not None else None)

    # Create nodes for left and right subtrees, specifying edge labels
    create_tree_nodes(data.get('left_node'), parent_node=node, edge_label=data.get('left_X_class'))  
    create_tree_nodes(data.get('right_node'), parent_node=node, edge_label=data.get('right_X_class'))

    return node

# Assuming 'root' is your nested dictionary object
root_node = create_tree_nodes(root)

# Print the tree to the console
print(RenderTree(root_node))

# Export to a Graphviz dot file and then to a PNG image
# dot_exporter = DotExporter(root_node)
# dot_exporter.to_picture(r"C:\Users\vaheh\aua\Fall_2024\ML\code\CS340B\tree_visualization.png")


dot_path = "tree_visualization.dot"
DotExporter(
    root_node,
    nodenamefunc=lambda node: node.name,
    edgeattrfunc=lambda parent, child: 'label="%s"' % (child.edgeattr.get('label') if child.edgeattr is not None else '')
).to_dotfile(dot_path)

png_path = "tree_visualization.png"
subprocess.check_call(["dot", dot_path, "-T", "png", "-o", png_path])

Node('/1', edgeattr=None)
├── Node('/1/2', edgeattr={'label': "['senior' 'youth']"})
│   ├── Node('/1/2/3', edgeattr={'label': "['yes']"})
│   │   ├── Node('/1/2/3/4', edgeattr={'label': "['fair']"})
│   │   │   ├── Node('/1/2/3/4/5', edgeattr={'label': "['medium']"})
│   │   │   └── Node('/1/2/3/4/6', edgeattr={'label': "[np.str_('low')]"})
│   │   └── Node('/1/2/3/7', edgeattr={'label': "[np.str_('excellent')]"})
│   │       ├── Node('/1/2/3/7/8', edgeattr={'label': "['medium']"})
│   │       └── Node('/1/2/3/7/9', edgeattr={'label': "[np.str_('low')]"})
│   └── Node('/1/2/10', edgeattr={'label': "[np.str_('no')]"})
│       ├── Node('/1/2/10/11', edgeattr={'label': "['medium']"})
│       │   ├── Node('/1/2/10/11/12', edgeattr={'label': "['fair']"})
│       │   └── Node('/1/2/10/11/13', edgeattr={'label': "[np.str_('excellent')]"})
│       └── Node('/1/2/10/14', edgeattr={'label': "[np.str_('high')]"})
│           ├── Node('/1/2/10/14/15', edgeattr={'label': "['fair']"})
│           └

0