In [49]:
from __future__ import print_function
import numbers

import numpy as np
import pandas as pd


In [84]:
df = pd.read_csv('winequality-red.csv')
from sklearn.model_selection import train_test_split

# NOTE(review): test_size=0.85 trains on only ~15% of the rows — presumably
# deliberate to keep the pure-Python tree builder fast; confirm.
# Both splits are whole rows, so "X_train"/"X_test" still include the label.
X_train, X_test = train_test_split(df, test_size = 0.85, random_state = 0)

training_data = X_train.values.tolist()
# header[i] names column i of every row; the label "quality" stays last
# because class_counts() reads the label from row[-1].
header = ["fixed acidity", "volatile acidity", "citric acid", "residual sugar",
          "chlorides", "free sulfur dioxide", "total sulfur dioxide", "density",
          "pH", "sulphates", "alcohol", "quality"]
# An extra CSV column (e.g. a saved index) would silently shift every header
# name relative to the row values, so fail loudly instead.
assert len(header) == df.shape[1], "header does not match the CSV's columns"

In [85]:
# Preview the first five rows of the raw frame.
df.head(5)

Unnamed: 0,fixed acidity,volatile acidity,citric acid,residual sugar,chlorides,free sulfur dioxide,total sulfur dioxide,density,pH,sulphates,alcohol,quality
0,7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5
1,7.8,0.88,0.0,2.6,0.098,25.0,67.0,0.9968,3.2,0.68,9.8,5
2,7.8,0.76,0.04,2.3,0.092,15.0,54.0,0.997,3.26,0.65,9.8,5
3,11.2,0.28,0.56,1.9,0.075,17.0,60.0,0.998,3.16,0.58,9.8,6
4,7.4,0.7,0.0,1.9,0.076,11.0,34.0,0.9978,3.51,0.56,9.4,5


In [86]:
def unique_vals(rows, col):
    """Return the set of distinct values found in column ``col`` of ``rows``."""
    return {row[col] for row in rows}

In [87]:
def class_counts(rows):
    """Count how many rows carry each label.

    The label is the last element of each row.

    Args:
        rows: iterable of row sequences.

    Returns:
        dict mapping label -> number of rows with that label, with keys in
        order of first appearance.
    """
    counts = {}
    for row in rows:
        label = row[-1]  # label is the last column
        counts[label] = counts.get(label, 0) + 1
    return counts

In [88]:
def is_numeric(value):
    """Return True if ``value`` is a real number (compared with ``>=``, not ``==``).

    Plain ``int``/``float`` (and therefore ``bool``) take the fast path. Other
    real-number types are recognised through ``numbers.Real``. This covers
    NumPy integer scalars such as ``np.int64``, which are not ``int``
    subclasses and so were previously treated as categorical values.
    """
    return isinstance(value, (int, float)) or isinstance(value, numbers.Real)

In [89]:
class Question:
    """A yes/no test on one column, used to partition a dataset.

    When both the row value and the threshold are numeric the test is
    ``row[column] >= value``; otherwise it is ``row[column] == value``.
    """

    def __init__(self, column, value):
        self.column = column  # index into each row (and into ``header``)
        self.value = value    # threshold or category to compare against

    def match(self, example):
        """Return True if ``example`` answers this question with "yes"."""
        val = example[self.column]
        # Order-compare only when both sides are numeric. A string threshold
        # such as '7' against a float row now falls back to equality, which
        # __repr__ reports, instead of raising TypeError.
        if is_numeric(val) and is_numeric(self.value):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        condition = ">=" if is_numeric(self.value) else "=="
        return "Is %s %s %s?" % (
            header[self.column], condition, str(self.value))

In [90]:
# '7' is a string, not a number, so the question is rendered with "==".
q = Question(column=0, value='7')
print(q)

Is fixed acidity == 7?


In [91]:
def partition(rows, question):
    """Split ``rows`` according to ``question.match``.

    Returns a ``(true_rows, false_rows)`` pair of new lists. Each list keeps
    the input order of its rows.
    """
    matched, unmatched = [], []
    for row in rows:
        (matched if question.match(row) else unmatched).append(row)
    return matched, unmatched

In [92]:
def gini(rows):
    """Gini impurity of the label distribution in ``rows``.

    Computed as ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction of rows
    carrying label k. An empty ``rows`` has no labels, so the result is 1.
    """
    total = float(len(rows))
    impurity = 1
    for count in class_counts(rows).values():
        impurity -= (count / total) ** 2
    return impurity

In [93]:
def info_gain(left, right, current_uncertainty):
    """Impurity reduction obtained by splitting a node into ``left``/``right``.

    ``current_uncertainty`` is the parent node's impurity. Each child's Gini
    impurity is weighted by that child's share of the rows.
    """
    n_left = len(left)
    weight_left = float(n_left) / (n_left + len(right))
    weight_right = 1 - weight_left
    return current_uncertainty - weight_left * gini(left) - weight_right * gini(right)

In [94]:
def find_best_split(rows):
    """Find the question that most reduces Gini impurity on ``rows``.

    Each distinct value of each feature column is tried as a candidate. The
    feature columns are all columns except the last, which holds the label.

    Returns:
        ``(best_gain, best_question)``. ``best_question`` is None when no
        candidate was accepted, e.g. ``rows`` is empty or every split leaves
        one side empty. Ties go to the candidate examined last, because the
        comparison is ``>=``.
    """
    best_gain = 0  # keep track of the best information gain
    best_question = None  # keep track of the feature / value that produced it
    if not rows:
        # Nothing to split; rows[0] below would raise IndexError.
        return best_gain, best_question

    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # number of feature columns (label excluded)

    for col in range(n_features):
        for val in unique_vals(rows, col):
            question = Question(col, val)
            true_rows, false_rows = partition(rows, question)

            # A split that sends every row one way does not divide the data.
            if not true_rows or not false_rows:
                continue

            gain = info_gain(true_rows, false_rows, current_uncertainty)
            if gain >= best_gain:
                best_gain, best_question = gain, question

    return best_gain, best_question

In [95]:
class Leaf:
    """Terminal tree node storing the label counts of the rows that reached it."""

    def __init__(self, rows):
        # label -> count mapping produced by class_counts()
        counts = class_counts(rows)
        self.predictions = counts

In [96]:
class Decision_Node:
    """Internal tree node: a question plus the subtree for each answer."""

    def __init__(self, question, true_branch, false_branch):
        # question, then the subtree for "yes", then the subtree for "no"
        self.question, self.true_branch, self.false_branch = (
            question, true_branch, false_branch)

In [97]:
def build_tree(rows):
    """Recursively grow a decision tree over ``rows``.

    Returns a ``Leaf`` when the best available split gains nothing. Otherwise
    returns a ``Decision_Node`` whose branches are grown from the two
    partitions of ``rows``.
    """
    gain, question = find_best_split(rows)
    if gain != 0:
        yes_rows, no_rows = partition(rows, question)
        return Decision_Node(
            question,
            build_tree(yes_rows),  # subtree for rows answering "yes"
            build_tree(no_rows),   # subtree for rows answering "no"
        )
    return Leaf(rows)

In [100]:
def print_tree(node, spacing=""):
    """Pretty-print the tree rooted at ``node``, indenting two spaces per level."""
    if isinstance(node, Leaf):
        print(spacing + "Predict", node.predictions)
    else:
        print(spacing + str(node.question))
        child_spacing = spacing + "  "
        for caption, attr in (('--> True:', 'true_branch'),
                              ('--> False:', 'false_branch')):
            print(spacing + caption)
            print_tree(getattr(node, attr), child_spacing)

In [101]:
# Fit on the training rows, then print the learned structure.
my_tree = build_tree(rows=training_data)
print_tree(node=my_tree)

Is alcohol >= 10.1?
--> True:
  Is sulphates >= 0.7?
  --> True:
    Is alcohol >= 10.8?
    --> True:
      Is citric acid >= 0.34?
      --> True:
        Is pH >= 3.02?
        --> True:
          Is chlorides >= 0.132?
          --> True:
            Predict {6.0: 1}
          --> False:
            Is citric acid >= 0.6?
            --> True:
              Predict {6.0: 1}
            --> False:
              Is volatile acidity >= 0.42?
              --> True:
                Is alcohol >= 12.3?
                --> True:
                  Predict {6.0: 1}
                --> False:
                  Predict {7.0: 1}
              --> False:
                Predict {7.0: 22}
        --> False:
          Predict {5.0: 1}
      --> False:
        Is alcohol >= 12.5?
        --> True:
          Predict {7.0: 4}
        --> False:
          Is fixed acidity >= 8.6?
          --> True:
            Is pH >= 3.18?
            --> True:
              Predict {5.0: 2}
            --> False

In [102]:
# Test rows as plain Python lists, same layout as training_data.
testing_data = X_test.to_numpy().tolist()

In [103]:
def classify(row, node):
    """Walk ``row`` down the tree from ``node``; return the reached leaf's label counts."""
    current = node
    while not isinstance(current, Leaf):
        if current.question.match(row):
            current = current.true_branch
        else:
            current = current.false_branch
    return current.predictions

In [106]:
# NOTE(review): y_pred is not filled in this cell — presumably meant to collect
# classify() results for testing_data later; confirm.
y_pred = list()
if __name__ == '__main__':
    # Same call as the earlier training cell, so the tree is rebuilt from scratch.
    my_tree = build_tree(rows=training_data)
    print_tree(node=my_tree)

Is alcohol >= 10.1?
--> True:
  Is sulphates >= 0.7?
  --> True:
    Is alcohol >= 10.8?
    --> True:
      Is citric acid >= 0.34?
      --> True:
        Is pH >= 3.02?
        --> True:
          Is chlorides >= 0.132?
          --> True:
            Predict {6.0: 1}
          --> False:
            Is citric acid >= 0.6?
            --> True:
              Predict {6.0: 1}
            --> False:
              Is volatile acidity >= 0.42?
              --> True:
                Is alcohol >= 12.3?
                --> True:
                  Predict {6.0: 1}
                --> False:
                  Predict {7.0: 1}
              --> False:
                Predict {7.0: 22}
        --> False:
          Predict {5.0: 1}
      --> False:
        Is alcohol >= 12.5?
        --> True:
          Predict {7.0: 4}
        --> False:
          Is fixed acidity >= 8.6?
          --> True:
            Is pH >= 3.18?
            --> True:
              Predict {5.0: 2}
            --> False