In [1]:
import csv
import os
import sys

# Column names of the wine-quality file, in file order. "quality" is the
# regression target; every other column is an input feature.
keys_ = [
    "fixed_acidity",
    "volatile_acidity",
    "citric_acid",
    "residual_sugar",
    "chlorides",
    "free_sulfur_dioxide",
    "total_sulfur_dioxide",
    "density",
    "pH",
    "sulphates",
    "alcohol",
    "quality",
]
# Per-column accumulator kept from the original cell; nothing in this cell fills it.
data = {k_: [] for k_ in keys_}

# Fraction of the rows (taken from the start of the file) used for training.
TRAIN_FRACTION = .9

# The file is semicolon-separated, so let the csv module split on ';' directly
# instead of reading each line as one comma-"field" and re-splitting it by hand.
with open('winequality-red.csv', newline='') as f:
    csv_file = csv.reader(f, delimiter=';')
    next(csv_file, None)  # skip the header row (no-op on an empty file)
    data_set = []
    for row in csv_file:
        if not row:
            continue  # tolerate blank lines, e.g. a trailing newline
        # Indexing (rather than zip) keeps the original strictness: a row with
        # too few columns raises IndexError instead of yielding a partial dict.
        data_set.append({k: float(row[i]) for i, k in enumerate(keys_)})

# Ordered (unshuffled) hold-out split: leading rows train, trailing rows test.
split_at = int(len(data_set) * TRAIN_FRACTION)
train_data, test_data = data_set[:split_at], data_set[split_at:]


In [20]:
class TreeNode:
    """One node of a binary regression tree.

    ``examples`` is a sequence of dicts that map feature names to numbers and
    carry the regression target under the ``'quality'`` key.

    Attributes:
        examples: the examples that reached this node (a private copy).
        left / right: child nodes, or ``None`` while the node is a leaf.
        split_point: ``None`` for a leaf, otherwise a dict with the keys
            ``"feature"``, ``"value"`` (threshold), ``"mse"`` and
            ``"split_index"`` (number of examples sent to the left child).
    """

    def __init__(self, examples):
        # Own a shallow copy: split() sorts this list in place, and the
        # caller's list (e.g. the training set) should keep its order.
        self.examples = list(examples)
        self.left = None
        self.right = None
        self.split_point = None

    def split(self):
        """Recursively grow the subtree rooted at this node.

        Every midpoint between consecutive sorted values of every feature is
        tried as a threshold; the one with the lowest weighted MSE wins (the
        first one found wins ties). The node stays a leaf when it holds at
        most one example or when no threshold separates the examples.
        """
        # Nothing to separate (also guards the examples[0] lookup below).
        if len(self.examples) <= 1:
            return

        best_split_point = {
            "feature": None,
            "value": None,
            "mse": float("inf"),
            "split_index": None,
        }
        # Each example is a dict; its keys are the candidate features.
        for feature in self.examples[0].keys():
            if feature == 'quality':
                continue

            self.examples.sort(key=lambda example: example[feature])

            for i in range(len(self.examples) - 1):
                split_point_value = (self.examples[i][feature] + self.examples[i + 1][feature]) / 2
                mse, split_index = self.get_split_point_mse(feature, split_point_value)
                if mse is not None and best_split_point["mse"] > mse:
                    best_split_point = {
                        "feature": feature,
                        "value": split_point_value,
                        "mse": mse,
                        "split_index": split_index,
                    }

        # All examples share the same value for every feature: no threshold
        # puts examples on both sides, so this node remains a leaf.
        if best_split_point["feature"] is None:
            return

        self.split_point = best_split_point

        # Re-sort by the winning feature so that the first `split_index`
        # examples are exactly those with feature <= threshold.
        self.examples.sort(key=lambda example: example[self.split_point["feature"]])
        self.left = TreeNode(self.examples[:self.split_point["split_index"]])
        self.left.split()
        self.right = TreeNode(self.examples[self.split_point["split_index"]:])
        self.right.split()

    def get_split_point_mse(self, feature, split_point_value):
        """Score the split ``feature <= split_point_value`` on this node's examples.

        Returns:
            ``(mse, split_index)`` where ``mse`` is the size-weighted mean of
            the two sides' MSEs and ``split_index`` is the size of the left
            side, or ``(None, None)`` when either side would be empty.
        """
        left_split_labels = [example['quality'] for example in self.examples if example[feature] <= split_point_value]
        right_split_labels = [example['quality'] for example in self.examples if example[feature] > split_point_value]

        # A threshold that leaves one side empty is not a split at all.
        if not left_split_labels or not right_split_labels:
            return None, None

        left_split_mse = get_mse(left_split_labels)
        right_split_mse = get_mse(right_split_labels)
        num_samples = len(left_split_labels) + len(right_split_labels)
        mse = ((len(left_split_labels) * left_split_mse) + (len(right_split_labels) * right_split_mse)) / num_samples
        split_index = len(left_split_labels)

        return mse, split_index
        
def get_mse(values):
    """Return the mean squared deviation of ``values`` from their own average."""
    center = get_average(values)
    squared_errors = ((v - center) ** 2 for v in values)
    return sum(squared_errors) / len(values)
    
    

def get_average(values):
    """Return the arithmetic mean of ``values`` (ZeroDivisionError if empty)."""
    total = sum(values)
    count = len(values)
    return total / count
    
class RegressionTree:
    """Regression tree that is fitted to ``examples`` as soon as it is built."""

    def __init__(self, examples):
        self.root = TreeNode(examples)
        self.train()

    def train(self):
        """Fit the tree by asking the root node to split itself."""
        self.root.split()

    def predict(self, example):
        """Return the mean 'quality' of the leaf that ``example`` falls into.

        Starting at the root, go left when the example's value for the node's
        split feature is <= the node's threshold, otherwise go right, until a
        node lacking a left or a right child is reached.
        """
        current = self.root

        while True:
            if not current.left or not current.right:
                break
            rule = current.split_point
            goes_left = example[rule['feature']] <= rule["value"]
            current = current.left if goes_left else current.right

        qualities = [member['quality'] for member in current.examples]
        return sum(qualities) / len(qualities)

In [22]:
# RegressionTree(train_data).predict(test_data[0])

In [16]:
class TreeNode:
    """A node of a binary regression tree built from example dicts.

    Each example maps feature names to numeric values and stores the target
    under ``'quality'``. A node is a leaf until ``split()`` finds a threshold;
    it then records the winning rule in ``split_point`` (keys ``"feature"``,
    ``"value"``, ``"mse"``, ``"split_index"``) and creates ``left`` / ``right``.
    """

    def __init__(self, examples):
        # Work on a shallow copy: split() sorts the list in place and must not
        # reorder the list the caller handed in.
        self.examples = list(examples)
        self.left = None
        self.right = None
        self.split_point = None

    def split(self):
        """Pick the lowest-MSE threshold for this node and recurse into both halves.

        Candidate thresholds are the midpoints between consecutive examples
        after sorting by each feature; on equal MSE the first candidate found
        is kept. Nodes with fewer than two examples, or with no threshold that
        puts examples on both sides, are left as leaves.
        """
        # Fewer than two examples cannot be separated; this also protects the
        # self.examples[0] lookup below from an empty list.
        if len(self.examples) <= 1:
            return

        best_split_point = {
            "feature": None,
            "value": None,
            "mse": float('inf'),
            "split_index": None,
        }

        for feature in self.examples[0].keys():
            if feature == 'quality':
                continue  # the target is never used as a split feature
            self.examples.sort(key=lambda example: example[feature])
            for i in range(len(self.examples) - 1):
                split_point_value = (self.examples[i][feature] + self.examples[i + 1][feature]) / 2
                mse, split_index = self.get_split_point_mse(feature, split_point_value)
                if mse is not None and best_split_point['mse'] > mse:
                    best_split_point = {
                        "feature": feature,
                        "value": split_point_value,
                        "mse": mse,
                        "split_index": split_index,
                    }

        # No usable threshold (every feature is constant here): stay a leaf.
        if best_split_point['feature'] is None:
            return

        self.split_point = best_split_point
        # Sorting by the chosen feature places the `split_index` examples with
        # feature <= threshold at the front of the list.
        self.examples.sort(key=lambda example: example[self.split_point['feature']])
        self.left = TreeNode(self.examples[:self.split_point['split_index']])
        self.left.split()
        self.right = TreeNode(self.examples[self.split_point['split_index']:])
        self.right.split()

    def get_split_point_mse(self, feature, split_point_value):
        """Evaluate splitting this node's examples at ``feature <= split_point_value``.

        Returns:
            ``(mse, split_index)``: the size-weighted average of the left and
            right MSEs, and the number of examples on the left side; or
            ``(None, None)`` if the threshold leaves either side empty.
        """
        left_split_labels = [example['quality'] for example in self.examples if example[feature] <= split_point_value]
        right_split_labels = [example['quality'] for example in self.examples if example[feature] > split_point_value]

        # make sure there are right and left examples
        if not left_split_labels or not right_split_labels:
            return None, None

        left_split_mse = get_mse(left_split_labels)
        right_split_mse = get_mse(right_split_labels)
        num_samples = len(left_split_labels) + len(right_split_labels)
        mse = ((len(left_split_labels) * left_split_mse) + (len(right_split_labels) * right_split_mse)) / num_samples
        split_index = len(left_split_labels)

        return mse, split_index



def get_mse(values):
    """Mean of the squared differences between each value and the average."""
    mean = get_average(values)
    count = len(values)
    return sum((item - mean) ** 2 for item in values) / count

def get_average(values):
    """Arithmetic mean of ``values``; an empty input raises ZeroDivisionError."""
    count = len(values)
    total = sum(values)
    return total / count


class RegressionTree:
    """Regression tree whose constructor builds the root node and trains it."""

    def __init__(self, examples):
        # Don't change the following two lines of code.
        self.root = TreeNode(examples)
        self.train()

    def train(self):
        """Fit the tree; all the work is delegated to the root node's split()."""
        # Don't edit this line.
        self.root.split()

    def predict(self, example):
        """Predict 'quality' for ``example`` as the mean label of its leaf.

        Descends from the root: left when the example's value for the split
        feature is <= the stored threshold, right otherwise. The walk stops
        at the first node that is missing a left or a right child.
        """
        cursor = self.root
        while True:
            if not cursor.left or not cursor.right:
                break
            rule = cursor.split_point
            if example[rule['feature']] <= rule['value']:
                cursor = cursor.left
            else:
                cursor = cursor.right

        labels_in_leaf = [item['quality'] for item in cursor.examples]
        total = sum(labels_in_leaf)
        return total / len(labels_in_leaf)


In [None]:
# Fit the regression tree on the training split; the constructor calls train().
tree = RegressionTree(train_data)

In [15]:
# Show the tree's prediction next to the true label for every held-out example.
for example in test_data:
    predicted = tree.predict(example)
    actual = example['quality']
    print(f"pred: {predicted} actual:{actual}")

pred: 5.0 actual:6.0
pred: 6.0 actual:7.0
pred: 5.0 actual:6.0
pred: 5.0 actual:5.0
pred: 6.0 actual:5.0
pred: 5.0 actual:6.0
pred: 5.0 actual:6.0
pred: 5.0 actual:5.0
pred: 5.0 actual:5.0
pred: 6.0 actual:5.0
pred: 6.0 actual:8.0
pred: 6.0 actual:7.0
pred: 6.0 actual:7.0
pred: 6.0 actual:7.0
pred: 6.0 actual:5.0
pred: 5.0 actual:6.0
pred: 5.0 actual:6.0
pred: 5.0 actual:6.0
pred: 6.0 actual:5.0
pred: 7.0 actual:5.0
pred: 7.0 actual:7.0
pred: 5.0 actual:6.0
pred: 5.0 actual:4.0
pred: 6.0 actual:6.0
pred: 5.0 actual:6.0
pred: 5.0 actual:5.0
pred: 5.0 actual:5.0
pred: 5.0 actual:7.0
pred: 3.0 actual:4.0
pred: 5.0 actual:7.0
pred: 4.0 actual:3.0
pred: 5.0 actual:5.0
pred: 6.0 actual:5.0
pred: 6.0 actual:6.0
pred: 6.0 actual:5.0
pred: 6.0 actual:5.0
pred: 8.0 actual:7.0
pred: 6.0 actual:5.0
pred: 8.0 actual:7.0
pred: 5.0 actual:3.0
pred: 5.0 actual:5.0
pred: 5.0 actual:4.0
pred: 5.0 actual:5.0
pred: 5.0 actual:4.0
pred: 6.0 actual:5.0
pred: 6.0 actual:4.0
pred: 6.0 actual:5.0
pred: 6.0 act