# Implementation of Decision Tree Algorithm
## Imports

In [93]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

from src.util import *

## Sample data 

In [94]:
# load the wine dataset once, then separate the target from the features
wine_data = pd.read_csv('data/wine_dataset.csv')
# the 'type' column is used as the target vector y
y = wine_data['type']
# every remaining column is kept as a feature in X
X = wine_data.drop(columns='type')

## Helper Functions

In [95]:
def learn(X, y, impurity_measure='entropy'):
    """Validate the arguments for learning a decision tree.

    Args:
        X: Collection of samples; only its length is inspected here.
        y: Collection of labels; only its length is inspected here.
        impurity_measure: Name of the impurity measure, either
            'entropy' or 'gini'.

    Raises:
        ValueError: If ``impurity_measure`` is not supported, or if
            ``X`` and ``y`` differ in length.
    """
    supported_measures = ('entropy', 'gini')
    # the impurity measure is validated before the data is touched
    if impurity_measure not in supported_measures:
        raise ValueError('impurity_measure must be either "entropy" or "gini"')

    n_samples, n_labels = len(X), len(y)
    if n_samples != n_labels:
        raise ValueError('X and y must have the same length')


In [96]:
# check if data is pure, meaning that all samples belong to the same class
def check_purity(y):
    """Return True when all labels in ``y`` belong to exactly one class.

    Args:
        y: Array-like of class labels.

    Returns:
        bool: True if ``y`` contains exactly one distinct label,
        otherwise False (this includes an empty ``y``).
    """
    # np.unique collapses the labels to their distinct values
    return np.unique(y).size == 1

In [97]:
def classify_data(y):
    """Return the most frequent class label in ``y``.

    Args:
        y: Non-empty array-like of class labels.

    Returns:
        The label with the highest count. On a tie, the label that comes
        first in np.unique's sorted order wins, because argmax returns
        the first maximum.
    """
    labels, frequencies = np.unique(y, return_counts=True)
    return labels[np.argmax(frequencies)]

In [98]:
def get_potential_splits(X):
    """Collect the candidate split thresholds for every column of ``X``.

    For each column, the candidates are the midpoints between
    consecutive distinct values (in sorted order).

    Args:
        X: 2-D numpy array of numeric features, one sample per row.

    Returns:
        dict: Maps a column index to its list of candidate thresholds.
        Columns with fewer than two distinct values get no entry.
    """
    splits_by_column = {}
    _, n_features = X.shape
    for feature in range(n_features):
        # np.unique returns the distinct values already sorted
        distinct = np.unique(X[:, feature])
        # pair each distinct value with its successor and take the average
        midpoints = [
            (upper + lower) / 2
            for lower, upper in zip(distinct[:-1], distinct[1:])
        ]
        if midpoints:
            splits_by_column[feature] = midpoints
    return splits_by_column

In [99]:
def split_data(X, y, split_column, split_value):
    """Partition the samples by comparing one column with a threshold.

    Args:
        X: 2-D numpy array of features, one sample per row.
        y: numpy array of labels aligned with the rows of ``X``.
        split_column: Index of the column to compare.
        split_value: Threshold; rows with a value <= threshold go to the
            first partition, rows with a value > threshold to the second.

    Returns:
        tuple: ``(X_less_than, y_less_than, X_greater_than, y_greater_than)``.
    """
    feature = X[:, split_column]
    # positions of the rows on each side of the threshold
    lower_rows = np.nonzero(feature <= split_value)[0]
    upper_rows = np.nonzero(feature > split_value)[0]
    return X[lower_rows], y[lower_rows], X[upper_rows], y[upper_rows]

### Entropy

Entropy measures how mixed the class labels in a set of samples are. The lower the entropy, the purer the data; an entropy of 0 means that all samples belong to the same class.

In [100]:
def calculate_entropy(y):
    """Compute the entropy (in bits) of the class labels in ``y``.

    Args:
        y: Array-like of class labels.

    Returns:
        The sum over all classes of ``p * -log2(p)``, where ``p`` is the
        share of samples belonging to the class.
    """
    class_counts = np.unique(y, return_counts=True)[1]
    # element-wise division turns the counts into per-class shares
    class_shares = class_counts / len(y)
    information = -np.log2(class_shares)
    return sum(class_shares * information)

In [101]:
def calculate_overall_entropy(y_less_than, y_greater_than):
    """Compute the size-weighted entropy of the two sides of a split.

    Args:
        y_less_than: Labels of the samples at or below the threshold.
        y_greater_than: Labels of the samples above the threshold.

    Returns:
        The entropy of each partition, weighted by that partition's
        share of the total number of samples, summed over both.
    """
    n_less, n_greater = len(y_less_than), len(y_greater_than)
    n_total = n_less + n_greater
    # each partition contributes in proportion to its size
    weight_less = n_less / n_total
    weight_greater = n_greater / n_total
    return (
        weight_less * calculate_entropy(y_less_than)
        + weight_greater * calculate_entropy(y_greater_than)
    )

In [102]:
def _gini_impurity(y):
    """Return the Gini impurity ``1 - sum(p_k ** 2)`` of the labels in ``y``."""
    _, counts = np.unique(y, return_counts=True)
    probabilities = counts / len(y)
    return 1 - np.sum(probabilities ** 2)


def _weighted_gini(y_less_than, y_greater_than):
    """Return the size-weighted Gini impurity of the two sides of a split."""
    number_of_samples = len(y_less_than) + len(y_greater_than)
    return (
        len(y_less_than) / number_of_samples * _gini_impurity(y_less_than)
        + len(y_greater_than) / number_of_samples * _gini_impurity(y_greater_than)
    )


def determine_best_split(X, y, potential_splits, impurity_measure='entropy'):
    """Find the candidate split with the lowest weighted impurity.

    Args:
        X: 2-D numpy array of features, one sample per row.
        y: numpy array of labels aligned with the rows of ``X``.
        potential_splits: dict mapping a column index to an iterable of
            candidate thresholds for that column.
        impurity_measure: 'entropy' or 'gini'.

    Returns:
        tuple: ``(best_split_column, best_split_value)``.

    Raises:
        ValueError: If ``impurity_measure`` is not supported, or if
            ``potential_splits`` contains no candidate threshold.
    """
    if impurity_measure not in ('entropy', 'gini'):
        raise ValueError('impurity_measure must be either "entropy" or "gini"')

    # start from infinity so any real impurity value is accepted
    best_impurity = float('inf')
    best_split_column = None
    best_split_value = None

    for column_index, candidate_values in potential_splits.items():
        for value in candidate_values:
            # only the label partitions are needed to score the split
            _, y_less_than, _, y_greater_than = split_data(
                X, y, split_column=column_index, split_value=value
            )
            if impurity_measure == 'entropy':
                current_impurity = calculate_overall_entropy(
                    y_less_than, y_greater_than
                )
            else:
                current_impurity = _weighted_gini(y_less_than, y_greater_than)
            # "<=" means that among equally good candidates the last one wins
            if current_impurity <= best_impurity:
                best_impurity = current_impurity
                best_split_column = column_index
                best_split_value = value

    if best_split_column is None:
        raise ValueError('potential_splits contains no usable candidate split')
    return best_split_column, best_split_value

## Decision Tree Algorithm

In [119]:
def decision_tree_algorithm(X, y, impurity_measure='entropy'):
    """Recursively build a decision tree from the samples.

    Args:
        X: 2-D numpy array of numeric features, one sample per row.
        y: numpy array of class labels aligned with the rows of ``X``.
        impurity_measure: Impurity measure passed on to
            ``determine_best_split``.

    Returns:
        Either a class label (leaf), or a dict with a single key of the
        form ``"<column> <= <value>"`` whose value is the list
        ``[yes_answer, no_answer]``; each answer is again a leaf or a
        nested dict.

    Raises:
        ValueError: If ``y`` is empty.
    """
    if len(y) == 0:
        raise ValueError('cannot build a decision tree from an empty dataset')

    # pure data needs no further question: the single class is the answer
    if check_purity(y):
        return classify_data(y)

    potential_splits = get_potential_splits(X)
    # samples with identical features but different labels cannot be
    # separated by any threshold, so fall back to the majority class
    if not potential_splits:
        return classify_data(y)

    split_column, split_value = determine_best_split(
        X, y, potential_splits, impurity_measure
    )
    X_less_than, y_less_than, X_greater_than, y_greater_than = split_data(
        X, y, split_column, split_value
    )

    question = "{} <= {}".format(split_column, split_value)
    # find the answers for both sides of the question (recursion)
    yes_answer = decision_tree_algorithm(
        X_less_than, y_less_than, impurity_measure
    )
    no_answer = decision_tree_algorithm(
        X_greater_than, y_greater_than, impurity_measure
    )

    # if both answers are the same there is no point in asking the question
    if yes_answer == no_answer:
        return yes_answer
    return {question: [yes_answer, no_answer]}

# build the tree from the raw numpy arrays of the full dataset and show it
fitted_tree = decision_tree_algorithm(X.values, y.values)
print(fitted_tree)


Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  1
Classification:  0
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  0
Classification:  1
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  0
Classification:  1
Classification:  0
Classification:  0
Classification:  0
Classification:  1
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  1
Classification:  0
Classification:  0
Classification:  1
Classification:  0
Classification:  1
Classification:  0
Classification:  0
Classificati

## Prediction

In [104]:
def predict(x, tree):
    """Predict the class label of one sample by walking a decision tree.

    Args:
        x: Feature values of a single sample, indexable by integer
            column position (e.g. a list or a 1-D numpy array).
        tree: Tree as built by ``decision_tree_algorithm``: either a
            class label (leaf), or a dict with a single key
            ``"<column> <= <value>"`` mapping to
            ``[yes_answer, no_answer]``.

    Returns:
        The class label of the leaf reached for ``x``.
    """
    node = tree
    # descend until a leaf (anything that is not a question dict) is reached
    while isinstance(node, dict):
        question = next(iter(node))
        column_text, _, threshold_text = question.partition(' <= ')
        yes_answer, no_answer = node[question]
        if x[int(column_text)] <= float(threshold_text):
            node = yes_answer
        else:
            node = no_answer
    return node

## Execution

In [105]:
# split_dataset comes from src.util; presumably test_size=0.2 holds out 20%
# of the samples for validation and random_state fixes the shuffle -- confirm
X_train, X_val, y_train, y_val = split_dataset(X, y, test_size=0.2, random_state=42)

# sep="\n" prints each object on its own line, like separate print calls
print("Training data:", X_train, y_train, sep="\n")
print("\nValidation data:", X_val, y_val, sep="\n")


Training data:
      citric acid  residual sugar    pH  sulphates  alcohol
3034         0.38             8.1  3.30       0.54      9.8
2576         0.23             6.2  3.34       0.43      9.6
533          0.04             2.5  3.53       0.55      9.5
1061         0.20             3.0  3.23       0.59      9.5
2626         0.32            16.2  3.17       0.37     11.2
...           ...             ...   ...        ...      ...
1095         0.59            11.8  3.17       0.46      8.9
1130         0.30             1.2  2.96       0.36     12.5
1294         0.00             2.2  3.40       0.58     10.9
860          0.14             2.4  3.66       0.65      9.8
3174         0.39             3.2  3.37       0.71     11.5

[2558 rows x 5 columns]
3034    0
2576    0
533     1
1061    1
2626    0
       ..
1095    0
1130    0
1294    1
860     1
3174    1
Name: type, Length: 2558, dtype: int64

Validation data:
      citric acid  residual sugar    pH  sulphates  alcohol
2440         

In [106]:
# validate the full dataset against the entropy impurity measure
learn(X, y, impurity_measure='entropy')