In [111]:
from __future__ import print_function

In [112]:
import csv
from csv import reader
import math
from random import randint, seed 

In [113]:
def readData(filename):
    with open(filename, 'rt') as read_obj:
        csv_reader = csv.reader(read_obj) # Return a reader object which will
                                        # iterate over lines in the given csvfile
        training_data = list(csv_reader)
    return training_data

In [114]:
def convert_to_float(row):  
    row = [float(x.strip()) for x in row]
    return row

In [115]:
def class_counts(data_set):
    counts = {}
    for row in data_set:
        label = row[-1]
        if label not in counts:
            counts[label] = 0
        counts[label] += 1
    return counts

In [116]:
def calculate_entropy(data_set):
    counts = class_counts(data_set)
    entropy = 0
    for lbl in counts:
        prob_of_lbl = counts[lbl] / float(len(data_set))
        entropy -= math.log(prob_of_lbl, 2)
    return entropy

In [117]:
class Question:

    def __init__(self, index, value):
        self.index = index
        self.value = value

    def match(self, example):
        val = example[self.index]
        return val >= self.value

In [118]:
def make_partition(data_set, question):
    right, left = [], []
    for row in data_set:
        if question.match(row):
            right.append(row)
        else:
            left.append(row)
    return right, left

In [119]:
def calculate_gain(right, left, current_data_set_entropy):
    p = float(len(left)) / (len(left) + len(right))
    return current_data_set_entropy - p * calculate_entropy(left) - (1 - p) * calculate_entropy(right)

In [120]:
def find_best_split(data_set):
    best_gain = 0
    split_question = None
    current_data_set_entropy = calculate_entropy(data_set)
    number_of_features = len(data_set[0]) - 1

    for column in range(number_of_features):
        column_values = set([row[column] for row in data_set])
        for value in column_values:
            question = Question(column, value)
            right, left = make_partition(data_set, question)
            if len(right) == 0 or len(left) == 0:
                continue

            gain = calculate_gain(right, left, current_data_set_entropy)
            
            if gain > best_gain:
                best_gain, split_question = gain, question
    
    return best_gain, split_question

In [121]:
class leaf_node:
    def __init__(self, data_set):
       self.predictions = class_counts(data_set)

In [122]:
class decision_node:
    def __init__(self, question, right, left):
        self.question = question
        self.right = right
        self.left = left

In [123]:
def build_tree(data_set):
    info_gain, question = find_best_split(data_set)

    if info_gain == 0:
        return leaf_node(data_set)
    
    right_split_data_set, left_split_data_set = make_partition(data_set, question)
    right_branch = build_tree(right_split_data_set)
    left_branch = build_tree(left_split_data_set)
    return decision_node(question, right_branch, left_branch) 


In [124]:
def print_tree(node, spacing=""):
    if isinstance(node, leaf_node):
        print(spacing + "Node class and count: ", node.predictions)
        return
    print(spacing + 'index: ' + str(node.question.index) +
          ' value: ' + str(node.question.value))

    print(spacing + '--> greater than:')
    print_tree(node.right, spacing + "  ")

    print(spacing + '--> less than:')
    print_tree(node.left, spacing + "  ")

In [125]:
def classify(row, node):

    if isinstance(node, leaf_node):
        return node.predictions

    if node.question.match(row):
        return classify(row, node.right)
    else:
        return classify(row, node.left)

In [126]:
#### Main ####
initial_data_set = readData('wine.csv') # here dataset contains data values as strings
#so we convert the string values to floats
data_set = []
for row in initial_data_set:
    row = convert_to_float(row)
    data_set.append(row)
 
# now we will do k-fold 
# for now k=10 
k = 10
folds = []
for i in range(k):
    folds.append([])
for i in range(len(data_set)):
    folds[i % k].append(data_set[i])

# now we will do cross validation
total_accuracy = 0.0
for group in folds:
    # train test splits
    train_data = list(folds)
    train_data.remove(group)
    train_data = sum(train_data, [])
    test_data = group 

    # now build the tree using the training data and print the tree
    d_tree = build_tree(train_data)
    print("\n*********Generated Decision Tree********\n")
    print_tree(d_tree, "")

    # now test begins
    print("\n*********Testing Decision Tree********\n")
    total_row = 0
    total_matched = 0
    for row in test_data:
        total_row += 1
        classified = classify(row, d_tree)
        for lbl in classified.keys():
            if lbl == row[-1]:
                total_matched += 1
    accuracy = total_matched/total_row*100
    total_accuracy += accuracy
    print('accurary: ', accuracy, '%')

print("\n*********Final Accuracy********\n")
average_accuracy = total_accuracy / k
print('average accuracy: ', average_accuracy, "%")


*********Generated Decision Tree********

index: 6 value: 2.03
--> greater than:
  index: 12 value: 1015.0
  --> greater than:
    Node class and count:  {1.0: 37}
  --> less than:
    index: 0 value: 12.99
    --> greater than:
      index: 7 value: 0.32
      --> greater than:
        Node class and count:  {1.0: 6}
      --> less than:
        index: 0 value: 13.9
        --> greater than:
          Node class and count:  {1.0: 4}
        --> less than:
          index: 3 value: 21.5
          --> greater than:
            Node class and count:  {2.0: 3}
          --> less than:
            index: 1 value: 1.77
            --> greater than:
              Node class and count:  {1.0: 4}
            --> less than:
              Node class and count:  {2.0: 2}
    --> less than:
      Node class and count:  {2.0: 29}
--> less than:
  index: 9 value: 5.88
  --> greater than:
    Node class and count:  {3.0: 25}
  --> less than:
    index: 6 value: 1.3
    --> greater than:
      Node c