In [3]:
import numpy as np 
import pandas as pd
from sklearn.datasets import load_iris

In [5]:
iris = load_iris()
x=iris.data
y=iris.target
data = np.c_[x,y]

In [45]:
class Ques:
    """
    this class when given proper params gives us the proper question. (helps in printing the tree too)
    """
    def __init__(self, col, val):
        self.col = col
        self.val = val
    
    def match(self, data):
        val = data[self.col]
        return val >= self.val
    
    def __repr__(self):
        return "Is %s >= %s" %(self.col, self.val)

In [46]:
Ques(0,5)

Is 0 >= 5

In [47]:
class DecisionNode:
    """
    class which represents a node about which we will split i.e. decision making node of the tree.
    """
    def __init__(self,question,left,right):
        self.question = question
        self.left = left
        self.right = right

In [48]:
class Leaf:
    """
    Leaf node representaion of the tree
    """
    def __init__(self,rows):
        self.predictions=count_values(rows)

In [49]:
def count_values(rows):
    """
    count the number of occurrences of the predicting class 
    """
    count = {}
    for row in rows: 
        label = row[-1]
        
        if label not in count: 
            count[label] = 0
        count[label]+=1
    return count

In [50]:
count_values(data)

{0.0: 50, 1.0: 50, 2.0: 50}

In [51]:
def partition(rows,question):
    """
    function that helps in partitioning the data passed to the function.
    """
    left_row,right_row=[],[]
    for row in rows:
        if question.match(row):
            left_row.append(row)
        else:
            right_row.append(row)
    return left_row,right_row

In [52]:
print(Ques(0,5))
t_l,t_r = partition(data,Ques(0,5))
t_r

Is 0 >= 5


[array([4.9, 3. , 1.4, 0.2, 0. ]),
 array([4.7, 3.2, 1.3, 0.2, 0. ]),
 array([4.6, 3.1, 1.5, 0.2, 0. ]),
 array([4.6, 3.4, 1.4, 0.3, 0. ]),
 array([4.4, 2.9, 1.4, 0.2, 0. ]),
 array([4.9, 3.1, 1.5, 0.1, 0. ]),
 array([4.8, 3.4, 1.6, 0.2, 0. ]),
 array([4.8, 3. , 1.4, 0.1, 0. ]),
 array([4.3, 3. , 1.1, 0.1, 0. ]),
 array([4.6, 3.6, 1. , 0.2, 0. ]),
 array([4.8, 3.4, 1.9, 0.2, 0. ]),
 array([4.7, 3.2, 1.6, 0.2, 0. ]),
 array([4.8, 3.1, 1.6, 0.2, 0. ]),
 array([4.9, 3.1, 1.5, 0.1, 0. ]),
 array([4.9, 3.1, 1.5, 0.1, 0. ]),
 array([4.4, 3. , 1.3, 0.2, 0. ]),
 array([4.5, 2.3, 1.3, 0.3, 0. ]),
 array([4.4, 3.2, 1.3, 0.2, 0. ]),
 array([4.8, 3. , 1.4, 0.3, 0. ]),
 array([4.6, 3.2, 1.4, 0.2, 0. ]),
 array([4.9, 2.4, 3.3, 1. , 1. ]),
 array([4.9, 2.5, 4.5, 1.7, 2. ])]

In [87]:
def gini(rows):
    """
    calculates the gini impurity
    """
    count=count_values(rows)
    impurity=1
    for label in count:
        probab_of_label=count[label]/float(len(rows))
        impurity-=probab_of_label**2
    return impurity

def entropy(rows):
    """
    calculates the entropy used in ID3 
    """
    entropy=0
    count=count_values(rows)
    for label in count:
        p=count[label]/float(len(rows))
        entropy-=p*np.log2(p)
    return entropy

def info_gain_gini(current,left,right):
    p =float(len(left))/len(left)+len(right)
    return current-p*gini(left)-(1-p)*gini(right)

def info_gain_entropy(current,left,right):
    p =float(len(left))/len(left)+len(right)
    return current-p*entropy(left)-(1-p)*entropy(right)

In [88]:
def best_split_gini(rows):
    best_gain=0
    best_question=None
    current=gini(rows)
    features=len(rows[0])-1
    for col in range(features):
        values=set([row[col] for row in rows])
        for val in values:
            question=Ques(col,val)
            left,right=partition(rows,question)
            if len(left)==0 or len(right) ==0:
                continue
            gain=info_gain_gini(current,left,right)
            if gain>=best_gain:
                best_gain,best_question=gain,question
    return best_gain,best_question

def best_split_entropy(rows):
    best_gain=0
    best_question=None
    current=entropy(rows)
    features=len(rows[0])-1
    for col in range(features):
        values=set([row[col] for row in rows])
        for val in values:
            question=Ques(col,val)
            left,right=partition(rows,question)
            if len(left)==0 or len(right)==0:
                continue
            gain=info_gain_entropy(current,left,right)
            if(gain>=best_gain):
                best_gain,best_question=gain,question
    return best_gain,best_question

In [96]:
def build_tree(rows):
    gain,question=best_split_gini(rows)
    if gain==0:
        return Leaf(rows)
    
    left, false_rows = partition(rows, question)
    true_branch = build_tree(left)
    false_branch = build_tree(false_rows)
    return DecisionNode(question, true_branch, false_branch)

def print_tree(node,indentation=""):
    if isinstance(node,Leaf):
        print(indentation+"PREDICTION",node.predictions)
        return 
    print(indentation + str(node.question))
    print(indentation+ "Left Branch")
    print_tree(node.left,indentation + " ")
    print(indentation+ "Right Branch")
    print_tree(node.right,indentation + " ")

In [97]:
def fit(X,y,max_depth):
    data = np.c_[X,y] 
    tree = build_tree(data)
    print_tree(tree)

In [98]:
fit(x,y,10)

ValueError: all the input array dimensions except for the concatenation axis must match exactly

In [99]:
ds = pd.read_csv('titanic.csv')
ds.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
PassengerId    891 non-null int64
Survived       891 non-null int64
Pclass         891 non-null int64
Name           891 non-null object
Sex            891 non-null object
Age            714 non-null float64
SibSp          891 non-null int64
Parch          891 non-null int64
Ticket         891 non-null object
Fare           891 non-null float64
Cabin          204 non-null object
Embarked       889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.6+ KB


In [100]:
cols_to_drop = [
    'PassengerId',
    'Name',
    'Ticket',
    'Cabin',
    'Embarked',
]

df = ds.drop(cols_to_drop, axis=1)
def convert_sex_to_num(s):
    if s=='male':
        return 0
    elif s=='female':
        return 1
    else:
        return s

df.Sex = df.Sex.map(convert_sex_to_num)
data = df.dropna()
input_cols = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']
out_cols = ['Survived']

X = data[input_cols]
y = data[out_cols]

In [101]:
fit(X,y,10)

Is 4 >= 6.0
Left Branch
 PREDICTION {0.0: 1}
Right Branch
 Is 2 >= 80.0
 Left Branch
  PREDICTION {1.0: 1}
 Right Branch
  Is 2 >= 74.0
  Left Branch
   PREDICTION {0.0: 1}
  Right Branch
   Is 2 >= 71.0
   Left Branch
    PREDICTION {0.0: 2}
   Right Branch
    Is 2 >= 70.5
    Left Branch
     PREDICTION {0.0: 1}
    Right Branch
     Is 2 >= 70.0
     Left Branch
      PREDICTION {0.0: 2}
     Right Branch
      Is 2 >= 66.0
      Left Branch
       PREDICTION {0.0: 1}
      Right Branch
       Is 2 >= 65.0
       Left Branch
        PREDICTION {0.0: 3}
       Right Branch
        Is 2 >= 64.0
        Left Branch
         PREDICTION {0.0: 2}
        Right Branch
         Is 2 >= 63.0
         Left Branch
          PREDICTION {1.0: 2}
         Right Branch
          Is 3 >= 5.0
          Left Branch
           PREDICTION {0.0: 5}
          Right Branch
           Is 5 >= 512.3292
           Left Branch
            PREDICTION {1.0: 3}
           Right Branch
            Is 4 >= 4.0
  