In [51]:
import numpy as np
import pandas as pd
import math
from platform import python_version
from IPython.core.debugger import set_trace

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import random
from pprint import pprint

In [5]:
df = pd.read_csv("iris.csv")
df = df.rename(columns={"species":"label"})

In [6]:
df.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,label
0,6.7,3.3,5.7,2.1,virginica
1,4.3,3.0,1.1,0.1,setosa
2,4.8,3.4,1.6,0.2,setosa
3,5.6,2.8,4.9,2.0,virginica
4,6.2,2.8,4.8,1.8,virginica


In [7]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   sepal_length  150 non-null    float64
 1   sepal_width   150 non-null    float64
 2   petal_length  150 non-null    float64
 3   petal_width   150 non-null    float64
 4   label         150 non-null    object 
dtypes: float64(4), object(1)
memory usage: 6.0+ KB


In [8]:
def train_test_split(df, test_size):
    indices = df.index.tolist()
    test_indices = random.sample(population=indices, k=round(test_size*len(indices))) #k - number of samples required
    test_df = df.loc[test_indices]
    train_df = df.drop(test_indices)
    return train_df, test_df

In [9]:
train_df, test_df = train_test_split(df, 0.1)
print(len(test_df), len(train_df))

15 135


In [10]:
# Convert dataframe to numpy to make it run much faster since numpy uses SIMD feature
data = train_df.values #converts to a 2D array

In [11]:
def check_purity(data):
    
    label_column = data[:, -1]
    unique_classes = np.unique(label_column)

    if len(unique_classes) == 1:
        return True
    else:
        return False

In [12]:
def classify_data(data):
    label = data[:, -1]
    classes, counts = np.unique(label, return_counts=True)
#     print(classes, counts)
    classification = classes[counts.argmax()]
    return classification

In [13]:
classify_data(train_df[train_df.petal_width < 0.8].values)

'setosa'

# Potential Split

In [14]:
def get_potential_splits(data):
    
    potential_splits = {}
    _, n_columns = data.shape
    for column_index in range(n_columns - 1):        # excluding the last column which is the label
        potential_splits[column_index] = []
        values = data[:, column_index]
        unique_values = np.unique(values)

        potential_splits[column_index] = unique_values
    
    return potential_splits

In [15]:
def split_data(data, split_column, split_value):  # split value will be 1 split_col will be split_attr
    split_column_val = data[:, split_column]
    data_below = data[split_column_val <= split_value]
    data_above = data[split_column_val > split_value]
    return data_below, data_above     # data below will be data false, data above-true

In [16]:
split_column = 3
split_value = 0.8
data_below, data_above = split_data(data, split_column, split_value)

In [17]:
def calculate_entropy(data):
    label_column = data[:, -1]
    _, counts = np.unique(label_column, return_counts=True)
    probabilities = counts / counts.sum()
    entropy = sum(probabilities * -np.log2(probabilities))
    return entropy

In [18]:
def calculate_overall_entropy(data_below, data_above):
    n_datapoints = len(data_below) + len(data_above)
    p_data_below = len(data_below) / n_datapoints
    p_data_above = len(data_above) / n_datapoints
    overall_entropy = (p_data_below * calculate_entropy(data_below) 
                       + p_data_above * calculate_entropy(data_above))
    return overall_entropy

In [19]:
calculate_overall_entropy(data_below, data_above)

0.6496642534954947

In [26]:
def determine_best_split(data, potential_splits):
    best_info_gain = None
    for column_index in potential_splits:
        for value in potential_splits[column_index]:
            
            data_below, data_above = split_data(data, split_column = column_index, split_value = value)
            
            # Entropy for data below
            Info_gain = (calculate_entropy(data) - 
                        ((len(data_below)/len(data)) * calculate_entropy(data_below)) - 
                        ((len(data_above)/len(data)) * calculate_entropy(data_above)))
         
            if best_info_gain is None:
                best_info_gain = Info_gain
            elif Info_gain >= best_info_gain:
                best_info_gain = Info_gain
                best_split_column = column_index
                best_split_value = value
    return best_split_column, best_split_value

In [27]:
split_col, split_val = determine_best_split(data, get_potential_splits(data))
print(split_col, split_val)

3 0.5


In [28]:
determine_best_split(data, get_potential_splits(data))

(3, 0.5)

## Main algorithm

In [32]:
def decision_tree_algorithm(df, counter=0):
    data = df.values if counter == 0 else df

    #base case
    if check_purity(data) == True:
        classification = classify_data(data)
        return classification
        
    #recursion
    else:
        counter+=1
        
        potential_splits = get_potential_splits(data)
        split_column, split_value = determine_best_split(data, potential_splits)
        data_below, data_above = split_data(data, split_column, split_value)
        
        if len(data_below) == 0 or len(data_above) == 0:
            classification = classify_data(data)
            return classification
        
        question = f"if {train_df.columns[split_column]} <= {split_value}"
                
        #instantiate sub-tree
        sub_tree = {question: []}

        yes_answer = decision_tree_algorithm(data_below, counter)
        no_answer = decision_tree_algorithm(data_above, counter)
        
        sub_tree[question].append(yes_answer)
        sub_tree[question].append(no_answer)
        
        return sub_tree    

In [33]:
tree = decision_tree_algorithm(train_df[train_df.label != "virginica"])
tree

{'if petal_width <= 0.5': ['setosa', 'versicolor']}

In [53]:
tree = decision_tree_algorithm(train_df)
tree

True

## Classification

In [36]:
example = test_df.iloc[0]
print(example)

sepal_length           6.6
sepal_width              3
petal_length           4.4
petal_width            1.4
label           versicolor
Name: 148, dtype: object


In [39]:
def classify_example(example, tree, print_path=0):
    question = list(tree.keys())[0]
    if print_path != 0:
        print(question)
    _, feature_name, comparison_operator, value = question.split()
    
    if example[feature_name] <= float(value):
        answer = tree[question][0]
        if print_path!=0:
            print("yes")
    else:
        answer = tree[question][1]
        if print_path!=0:
            print("no")
    
    #base case
    if not isinstance(answer, dict):
        return answer
    else:
        return classify_example(example, answer, print_path)    

In [40]:
classify_example(example, tree, 1)

if petal_width <= 0.5
no
if petal_length <= 4.7
yes


'versicolor'

### Accuracy

In [43]:
def calculate_accuracy(df, tree):
    df["classification"] = df.apply(classify_example, axis=1, args=(tree, ))
    df["classification_correct"] = df.classification == df.label
    accuracy = df.classification_correct.mean()
    return accuracy

In [44]:
calculate_accuracy(train_df, tree)

1.0

In [45]:
print(train_df)

     sepal_length  sepal_width  petal_length  petal_width      label  \
0             6.7          3.3           5.7          2.1  virginica   
1             4.3          3.0           1.1          0.1     setosa   
2             4.8          3.4           1.6          0.2     setosa   
3             5.6          2.8           4.9          2.0  virginica   
4             6.2          2.8           4.8          1.8  virginica   
..            ...          ...           ...          ...        ...   
144           4.9          3.1           1.5          0.1     setosa   
145           7.7          3.0           6.1          2.3  virginica   
146           5.1          3.7           1.5          0.4     setosa   
147           5.1          3.5           1.4          0.2     setosa   
149           5.1          3.8           1.9          0.4     setosa   

    classification  classification_correct  
0        virginica                    True  
1           setosa                    True  


In [52]:
print(python_version())

3.6.10
