In [1]:
from random import choice
import pandas as pd
import numpy as np

In [23]:
data = {
    "age": [25, 32, 47, 51, 62],
    "income": [3000, 4500, 5000, 6200, 7200],
    "married": [1, 0, 1, 1, 0],
    "owns_house": [0, 1, 1, 1, 0],
    "purchased": [0, 1, 1, 1, 0]  # zmienna docelowa
}
test = pd.DataFrame.from_dict(data)

data2 = {
    "age": [21, 25, 32, 34, 47, 50, 51, 54, 57, 60, 62, 65, 67, 71, 74],
    "owns_house": [0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1],
    "married": [0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1],
    "purchased": [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1],
}

test2 = pd.DataFrame.from_dict(data2)

In [587]:
class Tree:
    def __init__(self, split = None, data_left = None, data_right = None):
        self.node  = split
        self.left  = data_left
        self.right = data_right
    def __repr__(self):
        def print_tree(tree,ident = "\t"):
            if isinstance(tree,Tree):
                return f"Node : {tree.node}\n{ident}Leaf 1 :\n{print_tree(tree.left,2*ident)}\n{ident}Leaf 2 :\n{print_tree(tree.right,2*ident)}"
                
            if isinstance(tree, pd.DataFrame):
                if tree.empty: 
                    return ""
                return tree.to_string()
            if tree == None: return ""
            else: return str(tree)
                
                
        return f"Node : {repr(self.node)}\nLeaf 1 :\n\t{print_tree(self.left)}\nLeaf 2 :\n\t{print_tree(self.right)}"

In [856]:
class DecisionTree:
    def __init__(self,dataset : pd.DataFrame,columns : list[int])-> None:
        self.data = dataset.iloc[:,columns+[-1]] 
        self.target = dataset.iloc[:,[-1]]
        self.pred_tree = Tree()
        self.raw_tree = Tree()
        self.split_info = {i:self.continuous_value(i) for i in self.data.columns[:-1]}
        self.acc = 0
        
    def repr(self):
        return pd.DataFrame.to_string(self.data)
    def __str__(self):
        return pd.DataFrame.to_string(self.data)

    def pred(self,data):
        return data.iloc[:,-1:].value_counts().idxmax()[0]

    def continuous_value(self, split,cont_cap = 5):
        values = self.data[split].unique()
        return len(values) > cont_cap
        
    def _raw_to_pred(self,tree):
        if isinstance(tree,Tree):
            return Tree(tree.node,self._raw_to_pred(tree.left),self._raw_to_pred(tree.right))
        elif isinstance(tree, pd.DataFrame):
            return self.pred(tree)

    def _gini_index(self, proportions : list[int]):
        return 1 - sum(i**2 for i in proportions)
        
    def choose_split(self,data,used_splits = []):
        '''
        chooses the best split for data.
        returns split, value to split by, left split, right split
        '''
        splits = data.columns[:-1]
        choose_from = []
        for split in splits:
            values = data[split].unique()
            split_by = self._choose_category(data,split,values,used_splits)
            if split_by == None:
                continue
            choose_from.append(split_by)
        best_split = min(choose_from, key= lambda x : x["gini total"])
        return best_split

        
    def _choose_category(self,data,split,values, used_splits = [],continuous = 5):
        '''
        given split and list of distinct values in that split we choose which split 
        returns the best(smallest) gini index and we return that split with its corresponding gini index, additionally 
        we return both splits(data = split, data != split)
        '''
        if self.continuous_value(split, continuous):
            sorted_values = np.sort(values)
            average_values = [(sorted_values[i]+sorted_values[i+1])/2 for i in range(len(values)-1)]
            splits = [(i,data[data[split] <= i],data[data[split] > i]) for i in average_values if (split,i) not in used_splits]
        else:
            splits = [(i,data[data[split] == i],data[data[split] != i]) for i in values if (split,i) not in used_splits]
            
        if splits == []: return None
        info = []
        for i,data,other in splits:
            gini_data = self._gini_for_split(data)
            gini_other = self._gini_for_split(other)
            total = gini_data[1] + gini_other[1]
            gini_total = gini_data[0] * (gini_data[1]/total) + gini_other[0] * (gini_other[1]/total)
            info.append({"split" : split, "value" : i,"gini total" : gini_total,"left": data,"right" :other})
        return min(info,key= lambda x : x["gini total"])
        
    def _gini_for_split(self, split):
        target = split.iloc[:,-1]
        dist_value_counts = target.value_counts()
        split_count = dist_value_counts.sum()
        proportions = [i/split_count for i in dist_value_counts]
        gini = self._gini_index(proportions)
        return gini,split_count
    
    def create_tree(self, limit = 1,depth_limit = 5)-> None:
        def create_rec(data,depth = 0 ,used_splits = [] ):
            if data.empty: return None
            if depth > depth_limit or data.empty:
                return data
            elif len(data) <= limit:
                return data
            else:
                
                split_dict = self.choose_split(data,used_splits)
                split = {"split" : split_dict["split"], "value" : split_dict["value"]}
                used_splits.append((split["split"],split["value"]))
                left = create_rec(split_dict["left"], depth + 1, used_splits)
                right = create_rec(split_dict["right"], depth + 1, used_splits)
                return Tree(split, left, right)
        self.raw_tree = create_rec(self.data)
        self.pred_tree = self._raw_to_pred(self.raw_tree)

    def pred_new(self,to_pred):
        def go_down_tree(tree):
            if isinstance(tree,Tree):
                split,value = tree.node["split"],tree.node["value"]
                cont = self.split_info[split]
                
                if cont:
                    if (to_pred[split] < value).iloc[0]:
                        return go_down_tree(tree.left)
                    else:
                        return go_down_tree(tree.right)
                else:
                    if (to_pred[split] == value).iloc[0]:
                        return go_down_tree(tree.left)
                    else:
                        return go_down_tree(tree.right)
            else:
                return tree
        return go_down_tree(self.pred_tree)

    def tree_acc(self):
        def helper(tree_preds,tree_raw):
            if isinstance(tree_preds,Tree):
                return helper(tree_preds.left,tree_raw.left) + helper(tree_preds.right,tree_raw.right)
            else:
                return (tree_raw.iloc[:,-1:] == tree_preds).sum()
        return helper(self.pred_tree,self.raw_tree)/len(self.data)

In [858]:
tree2 = DecisionTree(test2,[0,1,2])

In [860]:
tree2.create_tree(depth_limit= 4,limit = 1)

In [862]:
print(tree2.raw_tree)

Node : {'split': 'married', 'value': 0}
Leaf 1 :
	Node : {'split': 'age', 'value': 50.5}
	Leaf 1 :
Node : {'split': 'age', 'value': 28.5}
		Leaf 1 :
Node : {'split': 'age', 'value': 23.0}
				Leaf 1 :
   age  owns_house  married  purchased
0   21           0        0          0
				Leaf 2 :
   age  owns_house  married  purchased
1   25           1        0          0
		Leaf 2 :
Node : {'split': 'age', 'value': 39.5}
				Leaf 1 :
   age  owns_house  married  purchased
2   32           0        0          1
				Leaf 2 :
   age  owns_house  married  purchased
4   47           1        0          1
	Leaf 2 :
Node : {'split': 'age', 'value': 57.0}
		Leaf 1 :
   age  owns_house  married  purchased
7   54           0        0          0
		Leaf 2 :
Node : {'split': 'age', 'value': 63.5}
				Leaf 1 :
   age  owns_house  married  purchased
9   60           0        0          0
				Leaf 2 :
Node : {'split': 'age', 'value': 69.0}
								Leaf 1 :
    age  owns_house  married  purchased
12   67   

In [864]:
# print(tree2.pred_tree)

In [866]:
new_val = {"age" : [32], 
           "owns_house" : [1],
          "married" : [0],
          }
to_guess = pd.DataFrame.from_dict(new_val)
# print(to_guess)
# print(test2)

In [868]:
# tree2.pred_new(to_guess)

In [870]:
# to_guess["purchased"] = tree2.pred_new(to_guess)

In [872]:
# print(tree2.raw_tree)

In [874]:
print(tree2.tree_acc())

purchased    1.0
dtype: float64
