In [5]:
import pandas as pd
import ssl
import numpy as np
ssl._create_default_https_context = ssl._create_unverified_context
df = pd.read_csv("https://raw.githubusercontent.com/cwhitz/git_intro_to_ml/master/heart.csv")

train = df.head(250)
test = df.tail(13)

In [6]:

class Gini: 
    def gini_impurity(self,node):
        return 1-sum([(np.count_nonzero(node==k)/len(node))**2 for k in self.types])
    def gini_impurity_score(self,node_left,node_right):
        m = len(node_left)+len(node_right)
        i = len(node_left)
        return i/m*self.gini_impurity(node_left) + (m-i)/m*self.gini_impurity(node_right)
class Node:
    def __init__(self,col,dtype,split,gini):
        self.col = col
        self.dtype = dtype
        self.split = split
        self.gini = gini
    def getCol(self):
        return self.col
    def getSplit(self):
        return self.split

In [9]:
class DecisionTree(Gini):
    def __init__(self):
        self.df = pd.read_csv("heart_disease.csv")
        self.types = pd.unique(self.df.target)
        cols = {col : list(self.df[col].unique()) for col in ["age"]}
        node = df
        minGini = {"gini":.5,"col":None,"dtype": None, "split": None}

        for col,vals in cols.items():
            newGini = self.findGiniImpurityFromCol(node, col,vals)
            if newGini["gini"] <= minGini["gini"]:
                minGini = newGini
                
        print(minGini)
    def findGiniImpurityFromCol(self,node, col,vals):
        if len(vals)==2:
            dtype = "binary"
            gini = self.findBinaryImpurity(node, col,vals)
            split = None
        if len(vals)>10:
            dtype= "numerical"
            gini,split = self.findNumericalImpurity(node, col,vals)
        else:
            dtype = "categorical"
            gini, split = self.findCategoricalImpurity(node, col,vals)
        return {"gini":gini,"col":col,"dtype": dtype, "split": split}

    def findBinaryImpurity(self, node, col, vals):
        node_left = df[df[col]==vals[0]].target
        node_right = df[df[col]==vals[1]].target
        gini_impurity = self.gini_impurity_score(node_left,node_right)
        return gini_impurity
    
    def findNumericalImpurity(self,node, col,vals):
        nodes = [[node[node[col]<=split].target, node[node[col]>split].target] for split in sorted(vals)[:-1]]
        impurities = [self.gini_impurity_score(node_left,node_right) for node_left, node_right in  nodes]
        impurities = dict(zip(range(len(vals[:-1])),impurities))
        split = sorted(impurities, key=impurities.get)[0]
        gini_impurity = impurities[split] 
        split = vals[:-1][split]
        return gini_impurity, split
    
    def findCategoricalImpurity(self, node, col,vals):
        combos = self.findCombos(vals)
        nodes = [[node[node[col].isin(combo1)].target, node[node[col].isin(combo2)].target] for combo1,combo2 in combos]
        impurities = [self.gini_impurity_score(node_left,node_right) for node_left, node_right in  nodes]
        impurities = dict(zip(range(len(combos)),impurities))
        split = sorted(impurities, key=impurities.get)[0]
        gini_impurity = impurities[split]
        split = combos[split]
        return gini_impurity, split
    
    def findOptions(self,values):
        if len(values) == 0:
            return [values]
        smaller = findOptions(values[1:])
        return smaller + [[values[0]] + x for x in smaller]

    def findCombos(self,vals):
        options = []
        duplicate_options = [([x for x in vals if x not in option], option) for option in self.findOptions(vals)]
        for option in duplicate_options:
            if (option[1],option[0]) not in options and len(option[0]) != len(vals) and len(option[1]) != len(vals):
                options.append(option)
        return options

In [10]:
DecisionTree()

{'gini': 0.4553348416602667, 'col': 'age', 'dtype': 'numerical', 'split': 71}


<__main__.DecisionTree at 0x7fc2dab91d60>

In [19]:
{x:3for x in range(3)}

{0: 3, 1: 3, 2: 3}

In [175]:
sorted([3,2,1])

[1, 2, 3]

In [179]:
pd.read_csv("https://raw.githubusercontent.com/cwhitz/git_intro_to_ml/master/heart.csv")

{1: 1, 2: 2}

In [91]:
hey = 0 if True else 3

In [203]:
df

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,63,1,3,145,233,1,0,150,0,2.3,0,0,1,1
1,37,1,2,130,250,0,1,187,0,3.5,0,0,2,1
2,41,0,1,130,204,0,0,172,0,1.4,2,0,2,1
3,56,1,1,120,236,0,1,178,0,0.8,2,0,2,1
4,57,0,0,120,354,0,1,163,1,0.6,2,0,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
298,57,0,0,140,241,0,1,123,1,0.2,1,0,3,0
299,45,1,3,110,264,0,1,132,0,1.2,1,0,3,0
300,68,1,0,144,193,1,1,141,0,3.4,1,2,3,0
301,57,1,0,130,131,0,1,115,1,1.2,1,1,3,0


In [245]:
max([1,2,2])

2

In [195]:
df.to_csv("heart_disease.csv",index=False)

In [None]:
if len(vals)>10:
                find_gini = [.5,vals[0]]
                for val in vals[1:]:
                    node_left = df[df[col]<val].target
                    node_right = df[df[col]>=val].target
                    impurity = self.gini_impurity_score(node_left,node_right)
                    if impurity < find_gini[0]: find_gini=[impurity,val]