In [7]:
# Toy training set. Each row is [color, diameter, label]; the label is always
# the last element. The second and fifth rows have identical features but
# different labels.
training_data = [
    ['Green', 3, 'Apple'],
    ['Yellow', 3, 'Apple'],
    ['Red', 1, 'Grape'],
    ['Red', 1, 'Grape'],
    ['Yellow', 3, 'Lemon'],
]

In [8]:
# Column names, in row order; Question.__repr__ looks them up by column index.
header = ['color','diameter','label']

In [9]:
def uniques_values(rows,col):
  """Return the set of distinct values found at index ``col`` across ``rows``."""
  return {record[col] for record in rows}

In [10]:
def is_numeric(value):
  """Return True when ``value`` is an ``int`` or ``float`` instance.

  ``bool`` is a subclass of ``int``, so booleans also count as numeric here.
  """
  return isinstance(value, (int, float))

In [26]:
class Question:
  """A yes/no test on one column of a row, used to partition a dataset.

  Attributes (set in ``__init__``):
    col: index of the column the question inspects.
    val: value the column entry is compared against.
  """

  def __init__(self,col,val):
    self.col = col
    self.val = val

  def match(self,example):
    """Return whether ``example`` satisfies this question.

    ``>=`` is used only when both the example's value and the question's
    value are numeric; every other combination falls back to ``==``.  This
    keeps a column that mixes numbers with strings or ``None`` from raising
    ``TypeError`` on an unorderable comparison.
    """
    v = example[self.col]
    if is_numeric(v) and is_numeric(self.val):
      return v>=self.val
    return v==self.val

  def __repr__(self):
    """Render the question as text, e.g. ``Is diameter >= 3?``."""
    condition = '>=' if is_numeric(self.val) else '=='
    return 'Is %s %s %s?'%(header[self.col],condition,str(self.val))

In [12]:
def partition(rows,question):
  """Split ``rows`` by whether each row matches ``question``.

  Returns:
    ``(true_rows, false_rows)``: the rows for which ``question.match`` was
    truthy and the remaining rows, each list in input order.
  """
  matched = []
  unmatched = []
  for record in rows:
    # One match() call per row; pick the destination list from its answer.
    bucket = matched if question.match(record) else unmatched
    bucket.append(record)
  return matched, unmatched

In [14]:
def info_gain(left,right,current_uncertainity):
  """Return how much a split reduces impurity.

  The result is ``current_uncertainity`` minus the Gini impurity of the two
  sides, each weighted by its share of the combined row count.
  """
  n_left = len(left)
  # True division already yields a float weight for the left side.
  weight = n_left/(n_left+len(right))
  return current_uncertainity - weight*gini(left) - (1-weight)*gini(right)

In [15]:
class Leaf:
  """Terminal tree node holding the label counts of the rows it was given."""

  def __init__(self,rows):
    # Mapping of label -> occurrence count, as produced by class_count.
    label_counts = class_count(rows)
    self.predictions = label_counts

In [23]:
class Decision_Node:
  """Internal tree node: a question plus one child for each answer."""

  def __init__(self,question,true_branch,false_branch):
    # Store the question and both children exactly as passed in.
    self.question, self.true_branch, self.false_branch = (
        question, true_branch, false_branch)

In [17]:
def class_count(rows):
  """Count how many rows carry each label (the last element of a row).

  Returns:
    dict mapping label -> number of occurrences, keyed in first-seen order.
  """
  tally = {}
  for record in rows:
    key = record[-1]
    tally[key] = tally.get(key, 0) + 1
  return tally


def gini(rows):
  """Compute the Gini impurity of the labels in ``rows``.

  The result is one minus the sum of squared label frequencies, so rows that
  all share one label score 0.  Empty input yields 1 because no frequency
  term is subtracted.
  """
  total = len(rows)
  squared_sum = 0
  # Accumulate term by term, in label order, to keep float rounding stable.
  for count in class_count(rows).values():
    squared_sum += (count/total)**2
  return 1 - squared_sum



def find_best_split(rows):
  """Find the question that yields the highest information gain on ``rows``.

  Every distinct value in every feature column (all columns except the last,
  which holds the label) is tried as a ``Question``.  Splits that leave one
  side empty are skipped.

  Candidate values are visited in first-seen order rather than set order, so
  ties between equally good questions resolve the same way on every run
  (set order for strings can differ between interpreter sessions).  Because
  the comparison is ``>=``, the last candidate among equals wins.

  Returns:
    ``(best_gain, best_question)``.  ``(0, None)`` when ``rows`` is empty or
    no candidate split was accepted.
  """
  # Nothing to split: report "no gain" instead of failing on rows[0].
  if len(rows)==0:
    return 0,None

  best_gain=0
  best_question=None
  current_uncertainity = gini(rows)
  n_features = len(rows[0])-1

  for col in range(n_features):
    # dict.fromkeys removes duplicates while keeping first-seen order.
    values = dict.fromkeys(row[col] for row in rows)
    for val in values:
      question = Question(col,val)

      true_rows, false_rows = partition(rows,question)
      # A split that sends every row to one side separates nothing.
      if len(true_rows)==0 or len(false_rows)==0:
        continue
      gain = info_gain(true_rows,false_rows,current_uncertainity)

      if gain>=best_gain:
        best_gain,best_question = gain,question
  return best_gain,best_question

In [18]:
# Recursively build the decision tree.
def build_tree(rows):
  """Recursively grow a decision tree from ``rows``.

  Returns a ``Leaf`` when ``find_best_split`` reports zero gain; otherwise a
  ``Decision_Node`` whose true and false branches are grown from the rows
  that do and do not match the best question.
  """
  best_gain,best_question = find_best_split(rows)

  if best_gain!=0:
    matched,unmatched = partition(rows,best_question)
    # Arguments evaluate left to right, so the true branch is grown first.
    return Decision_Node(best_question,build_tree(matched),build_tree(unmatched))

  return Leaf(rows)

In [19]:
def print_tree(node, spacing=""):
    """Print the tree rooted at ``node``, indenting each level by two spaces.

    A ``Leaf`` prints its ``predictions``.  Any other node prints its question,
    then its true subtree, then its false subtree.
    """
    if not isinstance(node, Leaf):
        # Children are printed one level deeper than this node.
        deeper = spacing + "  "

        print(spacing + str(node.question))

        print(spacing + '--> True:')
        print_tree(node.true_branch, deeper)

        print(spacing + '--> False:')
        print_tree(node.false_branch, deeper)
        return

    # Leaf reached: show the label counts stored on it.
    print(spacing + "Predict", node.predictions)

In [27]:
# Grow the tree from the toy training set.
my_tree = build_tree(training_data)

In [28]:
# Show the learned tree as indented text.
print_tree(my_tree)

Is diameter >= 3?
--> True:
  Is color == Yellow?
  --> True:
    Predict {'Apple': 1, 'Lemon': 1}
  --> False:
    Predict {'Apple': 1}
--> False:
  Predict {'Grape': 2}
