In [1]:
import requests  
import random
from collections import Counter

# Download the UCI Iris dataset and save it locally as iris.data.
iris_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
# A timeout keeps the cell from hanging forever on a stalled connection.
r = requests.get(iris_url, timeout=30)
# Fail loudly on HTTP errors instead of silently writing an error page to
# iris.data, which would only surface later as a confusing parse error.
r.raise_for_status()
with open('iris.data', 'wb') as f:
  f.write(r.content)

In [2]:
# Parse iris.data: each non-empty line holds four float features followed by
# a class label, comma-separated.
vectors = []
answers = []
with open('iris.data', 'r') as f:
  for line in f:
    line = line.strip()
    if not line:
      # The file ends with blank line(s); skip blanks wherever they occur
      # instead of stopping at the first one and dropping any later rows.
      continue
    items = line.split(",")
    vectors.append(tuple(float(i) for i in items[:4]))
    answers.append(items[4])

# Shuffle with a dedicated, seeded RNG so the train/test split (and hence the
# reported accuracies) is reproducible on Restart & Run All, without touching
# the global `random` state.
RANDOM_SEED = 42
zipped = list(zip(vectors, answers))
random.Random(RANDOM_SEED).shuffle(zipped)
# 80% train / 20% test split.
train_size = int(len(zipped) * 0.8)
train_x, train_y = zip(*zipped[:train_size])
test_x, test_y = zip(*zipped[train_size:])

In [3]:
def gini_score(items):
  """Return the Gini impurity of a collection of labels.

  Computes 1 - sum(p_k ** 2), where p_k is the fraction of elements equal to
  label k. `items` may be any iterable, including a one-shot generator,
  because it is consumed exactly once (by Counter). An empty input yields 1.
  """
  label_counts = Counter(items)
  # Total from the Counter rather than len(items), so generators are accepted.
  n_total = sum(label_counts.values())
  impurity = 1
  for n_label in label_counts.values():
    impurity -= (n_label / n_total) ** 2
  return impurity


def find_split_point_of_a_field(pairs):
  """Find the best threshold for a single feature.

  Args:
    pairs: iterable of (feature_value, label) tuples; may be a one-shot
      iterator such as a zip object.

  Returns:
    (split_point, weighted_gini): the midpoint between the two adjacent
    distinct sorted feature values that minimises the size-weighted Gini
    impurity of the two sides, together with that impurity. If no split is
    possible (fewer than two distinct values), the sentinel (0, 99) is
    returned.
  """
  ordered = sorted(pairs, key=lambda pair: pair[0])
  n = len(ordered)
  labels = [label for _, label in ordered]

  best_point, best_gini = 0, 99
  for cut in range(1, n):
    lo_value, hi_value = ordered[cut - 1][0], ordered[cut][0]
    # Only cut between distinct values; a cut inside a run of equal values
    # cannot be expressed as a threshold.
    if lo_value == hi_value:
      continue
    left_weight = cut / n
    weighted = (gini_score(labels[:cut]) * left_weight
                + gini_score(labels[cut:]) * (1 - left_weight))
    # Strict comparison keeps the first (lowest-valued) cut among ties.
    if weighted < best_gini:
      best_point, best_gini = (lo_value + hi_value) / 2, weighted
  return best_point, best_gini

def split_data(X, Y):
  """Choose the best (feature, threshold) split of a dataset by Gini impurity.

  Args:
    X: sequence of feature vectors (indexable, all the same length).
    Y: sequence of labels aligned with X.

  Returns:
    A dict node with keys:
      'sp'    - threshold on the chosen feature (-1 if no field can be split),
      'fid'   - index of the chosen feature (-1 if no field can be split),
      'gini'  - weighted Gini impurity after the split (99 if none),
      'left'  - (xs, ys) for rows with x[fid] < sp, or () if there are none,
      'right' - (xs, ys) for rows with x[fid] >= sp, or () if there are none.
  """
  best_fid, best_sp, best_gini = -1, -1, 99
  for fid in range(len(X[0])):
    column = [row[fid] for row in X]
    sp, gini = find_split_point_of_a_field(zip(column, Y))
    # Strict '<' keeps the lowest-index feature among ties.
    if gini < best_gini:
      best_fid, best_sp, best_gini = fid, sp, gini

  left_rows, right_rows = [], []
  for row, label in zip(X, Y):
    value = row[best_fid]
    if value < best_sp:
      left_rows.append((row, label))
    elif value >= best_sp:
      right_rows.append((row, label))

  return {
    'sp': best_sp,
    'fid': best_fid,
    'gini': best_gini,
    # Transpose the list of (x, y) pairs back into (xs, ys).
    'left': tuple(zip(*left_rows)),
    'right': tuple(zip(*right_rows)),
  }

def decision_tree(X, Y, threshold, max_depth=None):
  """Recursively grow a binary classification tree using Gini impurity.

  Args:
    X: sequence of feature vectors.
    Y: sequence of class labels aligned with X (must be non-empty).
    threshold: minimum impurity decrease required to split. A node whose best
      split reduces Gini impurity by `threshold` or less becomes a leaf.
    max_depth: optional cap on the number of split levels below this node.
      None (the default) means unlimited, which matches the original
      behaviour; 0 forces this node to be a leaf.

  Returns:
    Either a leaf tuple (majority_label, fraction_of_Y_with_that_label), or a
    split-node dict as produced by split_data whose 'left'/'right' entries
    have been replaced by recursively built subtrees.
  """
  def _leaf():
    # Predict the most common label; also report its share of this node.
    counter = Counter(Y)
    ans, c = counter.most_common(1)[0]
    return (ans, c / sum(counter.values()))

  if max_depth is not None and max_depth <= 0:
    return _leaf()

  original_gini = gini_score(Y)
  node = split_data(X, Y)

  if original_gini - node['gini'] <= threshold:
    # The best split does not improve impurity enough: stop splitting.
    return _leaf()

  # Split: recurse into both sides, allowing one less level of depth.
  child_depth = None if max_depth is None else max_depth - 1
  node['left'] = decision_tree(*node['left'], threshold, child_depth)
  node['right'] = decision_tree(*node['right'], threshold, child_depth)
  return node

def predict(x, tree):
  """Return the leaf reached by routing feature vector `x` through `tree`.

  Args:
    x: a feature vector indexable by each split node's 'fid'.
    tree: a tree as returned by decision_tree. This is either a split-node
      dict (with 'fid', 'sp', 'left', 'right') or a leaf tuple
      (label, confidence).

  Returns:
    The leaf reached. Values equal to a node's 'sp' go right.
  """
  node = tree
  # Check the node *type* rather than `'fid' not in node`. On a leaf tuple,
  # `in` searches the tuple's elements, so a class literally labelled 'fid'
  # would be mistaken for a split node and crash on node['fid'].
  while isinstance(node, dict) and 'fid' in node:
    if x[node['fid']] < node['sp']:
      node = node['left']
    else:
      node = node['right']
  return node


def accuracy(X, Y, tree):
  """Return the fraction of (x, y) pairs whose predicted label equals y."""
  correct = sum(1 for x, y in zip(X, Y) if predict(x, tree)[0] == y)
  return correct / len(X)


# threshold=0: keep splitting while a split strictly reduces Gini impurity.
tree = decision_tree(train_x, train_y, 0)

train_accuracy = accuracy(train_x, train_y, tree)
print("train_accuracy : ",train_accuracy)

test_accuracy = accuracy(test_x, test_y, tree)
print("test_accuracy : ",test_accuracy)

train_accuracy :  1.0
test_accuracy :  0.9333333333333333
