<img align="right" src="notebook_resources/decision_tree_header_image.png" width=300 height=300>

# How Do Decision Trees Work?

### by Matt Britton

#### The goal of this notebook is to teach you about Decision Trees, a common model type in machine learning.
#### We'll do this through a combination of thought experiments, data visualizations, and a mini-exercise. 

#### By the end of this tutorial, you should be able to:

- Describe the structure and purpose of a decision tree
- Describe how the decision trees we make in the real world are similar to and different from machine learning models.
- Construct a small decision tree with locally optimal splits using the concepts of _Entropy_ and _Information Gain_.

### Accessing This Notebook

You can access this notebook at its GitHub repository and download it to run on your own machine. To view the notebook and interact with its charts without downloading it, use the NBViewer version.

| Source        | Link          | QR Code|
| ------------- | --------------- | ----- |
| GitHub     | https://github.com/MattJBritton/InteractiveDecisionTrees | <img src="notebook_resources/GitHub_QR_code.png" width=100 height=100> |
| NBViewer     | https://nbviewer.jupyter.org/github/MattJBritton/InteractiveDecisionTrees/blob/master/InteractiveDecisionTrees.ipynb      |   <img src="notebook_resources/NBViewer_QR_code.png" width=100 height=100> |

### Background Knowledge

This notebook builds on basic knowledge of machine learning, such as:
- Supervised Learning 
- Classification
- Model Evaluation

It's also helpful to understand how at least one other classification model type works, such as:
- Logistic Regression
- K-Nearest Neighbors

Lastly, this tutorial leverages several techniques used in EDA for understanding feature distributions, including bar charts, line charts, and kernel density estimate (KDE) plots. Knowledge of how to read these charts would be helpful.

## What is a Decision Tree?

Think back to other supervised classification algorithms. A decision tree is another type of predictive model, so you work with it the same way you work with Logistic Regression. As a data scientist, you will:
-  __Train__ it based on data.
-  __Interpret__ the mathematical model of the relationships between the predictors and the target.
- __Predict__ what will happen to a new data point.
- __Evaluate__ its accuracy on known data.

Each type of model represents a way of thinking about the world. Whereas Logistic Regression describes how the odds of Y change as X increases, decision trees are best at capturing a different kind of phenomenon: if/then choices.

#### Decision Trees in the Real World

You’ve undoubtedly made a decision tree at some point in your life. Let’s take the example of packing for a vacation. Should you bring a bathing suit?
Your decision process might look a little like this:
- Is there a body of water where you’re going?
    - If no, then don’t.
    - If yes, is the temperature there above 60 degrees?
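Written as code, a decision process like this is just a set of nested if/else statements. The list above stops at the temperature question, so the sketch below assumes one possible ending: pack the suit above 60 degrees and skip it otherwise. The function name and the Fahrenheit unit are illustrative assumptions.

```python
def should_pack_bathing_suit(has_body_of_water: bool, temperature_f: float) -> bool:
    """Hypothetical hand-built decision tree for the vacation example."""
    # Root question: is there a body of water where you're going?
    if not has_body_of_water:
        return False  # "If no, then don't."
    # Second question, asked only on the "yes" branch.
    # Assumption: the original list leaves this branch open; here, above 60 means pack it.
    return temperature_f > 60


print(should_pack_bathing_suit(True, 75))   # True
print(should_pack_bathing_suit(True, 50))   # False
print(should_pack_bathing_suit(False, 90))  # False
```

Each `if` is an internal node and each `return` is a leaf. The temperature question is only reached on one branch. A machine learning decision tree has the same shape, but it learns its questions and thresholds from data.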

## Imports

In [194]:
import pandas as pd
import numpy as np
import altair as alt

from scipy.stats import entropy
from sklearn.tree import DecisionTreeClassifier, export_text

### Data Loading and Cleaning

In [2]:
titanic_df = pd.read_csv("data/titanic_train.csv")
titanic_df["Family_Size"] = titanic_df["SibSp"] + titanic_df["Parch"]
titanic_df["Class"] = titanic_df["Pclass"].replace({1:"First", 2:"Second", 3: "Third"})
titanic_df["Age"] = titanic_df["Age"].round(0)
titanic_df = titanic_df.drop(["PassengerId", "Name", "Ticket", "Cabin", "Fare", "SibSp", "Parch", "Pclass"], axis=1)
titanic_df = titanic_df.dropna(axis=0, how="any")

In [4]:
categorical_features = ["Class", "Sex", "Embarked"]
target = "Survived"

### Define Functions Used in Vis

In [184]:
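def get_split_entropy(data, feature, split_val, side):
    """Return (entropy in bits, row count) for one side of a candidate split.

    Categorical features split on equality (left: == split_val, right: != split_val);
    numeric features split on a threshold (left: <= split_val, right: > split_val).
    """
    if feature in categorical_features:
        operator = "==" if side == "left" else "!="
    else:
        operator = "<=" if side == "left" else ">"

    # Backticks let DataFrame.query handle column names that are not valid identifiers
    side_data = data.query(f"`{feature}` {operator} @split_val")
    subset_size = len(side_data)
    if subset_size == 0:
        # An empty side has no rows, so it contributes nothing to the weighted entropy
        return np.float64(0.0), 0
    side_percent_target = side_data[target].mean()  # share of rows where target == 1
    side_entropy = entropy([side_percent_target, 1. - side_percent_target], base=2)
    return side_entropy, subset_size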
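`get_split_entropy` relies on `scipy.stats.entropy` with `base=2`. For a two-class subset in which a fraction p survived, this computes the binary entropy H(p) = -p log2(p) - (1 - p) log2(1 - p). The standard-library sketch below computes the same quantity by hand. The helper name `binary_entropy` is illustrative and not part of the notebook. It shows why a pure subset scores 0 bits and a 50/50 subset scores 1 bit.

```python
import math


def binary_entropy(p: float) -> float:
    """Entropy in bits of a two-class subset where a fraction p is the positive class."""
    total = 0.0
    for q in (p, 1.0 - p):
        # By convention 0 * log2(0) = 0, so empty classes are skipped
        if q > 0:
            total -= q * math.log2(q)
    return total


for p in (0.0, 0.25, 0.5, 1.0):
    print(p, round(binary_entropy(p), 3))
```

The loop prints 0.0, 0.811, 1.0, and 0.0 bits. Entropy is highest when survivors and non-survivors are evenly mixed, and it drops to zero when a subset contains only one class.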

In [262]:
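def build_IG_Table(data, feature):
    """Score every candidate split of `feature`.

    Returns one row per split value with the left/right entropies and the information
    gain, all multiplied by 100 (i.e. in hundredths of a bit).
    """
    # Entropy of the parent node before any split
    percent_target = data[target].mean()
    dataset_entropy = entropy([percent_target, 1. - percent_target], base=2)
    entropies = []
    if feature in categorical_features:
        # Categorical: one "feature == value" split per observed value
        split_val_range = np.unique(data[feature])
    else:
        # Numeric: every integer threshold from the minimum up to (not including) the maximum,
        # since a split at the maximum would leave the right side empty
        split_val_range = np.arange(np.min(data[feature]).astype(int), np.max(data[feature]).astype(int))
    for split_val in split_val_range:
        left_entropy, left_size = get_split_entropy(data, feature, split_val, "left")
        right_entropy, right_size = get_split_entropy(data, feature, split_val, "right")

        # Information gain = parent entropy - size-weighted average entropy of the two sides
        information_gain = dataset_entropy - ((left_size*left_entropy) + (right_size*right_entropy))/len(data)
        entropies.append((split_val, (100.*left_entropy).round(0), (100.*right_entropy).round(0), (100.*information_gain).round(2)))

    entropy_df = pd.DataFrame(entropies, columns = ["Split_Val", "Left_Entropy", "Right_Entropy", "Information_Gain"])
    entropy_df["Feature"] = feature
    return entropy_df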
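`build_IG_Table` scores each candidate split as the parent entropy minus the size-weighted entropy of the two sides: IG = H(parent) - (n_left / n) H(left) - (n_right / n) H(right). For numeric features it tries every integer threshold from the minimum up to, but not including, the maximum. The standard-library sketch below mirrors that logic on plain lists. The function names and the toy ages and labels are made up for illustration. Unlike the notebook, which multiplies values by 100, it returns raw bits.

```python
import math


def label_entropy(labels):
    """Entropy in bits of a list of 0/1 labels (0.0 for an empty list)."""
    if not labels:
        return 0.0
    p = sum(labels) / len(labels)
    return sum(-q * math.log2(q) for q in (p, 1 - p) if q > 0)


def information_gain(values, labels, threshold):
    """Parent entropy minus the size-weighted entropy of the <= and > sides."""
    left = [y for x, y in zip(values, labels) if x <= threshold]
    right = [y for x, y in zip(values, labels) if x > threshold]
    weighted = (len(left) * label_entropy(left) + len(right) * label_entropy(right)) / len(labels)
    return label_entropy(labels) - weighted


def best_threshold(values, labels):
    """Scan integer thresholds from min(values) up to max(values) - 1 and keep the best."""
    candidates = range(int(min(values)), int(max(values)))
    return max(candidates, key=lambda t: information_gain(values, labels, t))


# Hypothetical toy data: younger passengers survive (1), older ones do not (0)
ages = [4, 9, 12, 20, 35, 50]
survived = [1, 1, 1, 0, 0, 0]
t = best_threshold(ages, survived)
print(t, information_gain(ages, survived, t))  # 12 1.0
```

Any threshold from 12 to 19 separates the toy data perfectly, and `max` keeps the first one it finds. In `build_IG_Table` the same gain would appear scaled by 100, so an `Information_Gain` of 100 means one full bit.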

In [328]:
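def build_interactive_decision_tree(train_dataset, show_information_gain=False, show_splits=False, subquery=None):
    """Build one linked Altair panel per feature for exploring candidate splits.

    show_information_gain adds a line chart of information gain for each split value;
    show_splits adds a bar chart of the left/right partition and the information gain
    for the currently selected split value; subquery, if given, is a DataFrame.query
    string used to filter train_dataset before any calculations. Features with only
    one distinct value are skipped.
    """
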
    chart = alt.hconcat()
    if subquery:
        data = train_dataset.query(subquery).copy()
    else:
        data = train_dataset.copy()
    
    ig_df_list = []
    # get all information gain calculations into a single table
    features_to_use = [x for x in data.columns if x != target and np.unique(data[x]).shape[0] > 1]
    for feature in features_to_use:
        ig_df_list.append(build_IG_Table(data, feature))
    
    master_ig_df = pd.concat(ig_df_list)
    
    CATEGORICAL_WIDTH = 250
    QUANTITATIVE_WIDTH = 250
    SPLIT_WIDTH = 100
    TEXT_WIDTH = 120
    MAIN_CHART_HEIGHT = 200
    IG_LINE_HEIGHT = 75
    for feature in features_to_use:
        if feature in categorical_features:
            
            component_stack = []

            chart_data = data.loc[:,[feature, target]]
            selector = alt.selection_single(fields=[feature])

            # build IG table
            ig_df = master_ig_df.query("Feature == @feature")
            
            # Top Component

            feature_component = alt.Chart(chart_data).mark_bar().encode(
                x = alt.X(feature, axis=alt.Axis(labelAngle=0), title=None),
                y = "count()",
                color = alt.Color(target, type="nominal", sort=[1,0]),
                opacity = alt.condition(selector, alt.value(1), alt.value(0.3))
            ).properties(
                width = CATEGORICAL_WIDTH,
                height = MAIN_CHART_HEIGHT
            ).add_selection(
                selector
            )
            
            component_stack.append(feature_component)
            
            # End Top Component
            
            # Middle Component
            
            if show_information_gain:
            
                ig_line = alt.Chart(ig_df).mark_line().encode(
                    x = alt.X("Split_Val", title=None),
                    y = alt.Y("Information_Gain", scale = alt.Scale(domain = [0, np.max(master_ig_df["Information_Gain"])]))
                ).properties(
                    width = QUANTITATIVE_WIDTH,
                    height = IG_LINE_HEIGHT
                )
                
                component_stack.append(ig_line)
            
            # End Middle Component
            
            # Bottom Component
            
            if show_splits:
                split_component = alt.Chart(chart_data).mark_bar().encode(
                    y = "Split:N",
                    x = alt.X("count()", scale=alt.Scale(domain = [0, len(data)])),
                    color = alt.Color(target, type="nominal", sort=[1,0])
                ).transform_calculate(
                    split_val = selector[feature]
                ).transform_calculate(
                    Split = "datum." + feature + " == datum.split_val?'Left':'Right'"
                ).properties(
                    width = SPLIT_WIDTH
                )

                split_entropies = alt.Chart(chart_data).mark_text(dx=10).encode(
                    y = "Split:N",
                    x = alt.X("count()", scale=alt.Scale(domain = [0, len(data)])),
                    text = "Entropy:Q"
                ).transform_calculate(
                    split_val = selector[feature]
                ).transform_calculate(
                    Split = "datum." + feature + " == datum.split_val?'Left':'Right'"
                ).transform_lookup(
                    lookup = "split_val",
                    from_ = alt.LookupData(
                        data = ig_df,
                        key = "Split_Val",
                        fields = ["Left_Entropy", "Right_Entropy"]
                    )
                ).transform_calculate(
                    Entropy = "datum.Split == 'Left'?datum.Left_Entropy:datum.Right_Entropy"
                ) 

                split_text = alt.Chart(ig_df).mark_text().encode(
                    text = "chart_text:N",
                ).transform_calculate(
                    selected_val = selector[feature]
                ).transform_filter(
                    (alt.datum.Split_Val <= alt.datum.selected_val) & (alt.datum.Split_Val >= alt.datum.selected_val)
                ).transform_calculate(
                    chart_text = "For " + feature + "=" + alt.datum.selected_val
                ).properties(
                    width = TEXT_WIDTH
                )

                ig_text = alt.Chart(ig_df).mark_text().encode(
                    text = "chart_text:N",
                ).transform_calculate(
                    selected_val = selector[feature]
                ).transform_filter(
                    (alt.datum.Split_Val <= alt.datum.selected_val) & (alt.datum.Split_Val >= alt.datum.selected_val)
                ).transform_calculate(
                    chart_text = "Information Gain: " + alt.datum.Information_Gain
                ).properties(
                    width = TEXT_WIDTH
                )
                
                component_stack.append((split_component + split_entropies) | (split_text & ig_text))
            
            # End Bottom Component

            chart |= alt.vconcat(*component_stack).properties(title=feature)
        else:
            
            component_stack = []

            chart_data = data.loc[:,[feature, target]].melt(id_vars=target)
            selector = alt.selection_single(fields = [feature], on="mouseover", nearest=True)

            # build IG table
            ig_df = master_ig_df.query("Feature == @feature")

            # Top Component
            
            feature_component = alt.Chart(chart_data).transform_density(
                density='value',
                bandwidth=0.5,
                groupby=['variable', target],
                steps=20
            ).transform_joinaggregate(
                max_density = 'max(density)'
            ).mark_line().encode(
                alt.X('value:Q', title=None),
                alt.Y('density:Q', axis=alt.Axis(labelAngle=0, titleAngle=0)),
                alt.Color(target, type="nominal")
            ).properties(
                width=QUANTITATIVE_WIDTH,
                height = MAIN_CHART_HEIGHT
            )

            selection_bar_df = pd.DataFrame(
                np.arange(np.min(data[feature]).astype(int), np.max(data[feature]).astype(int)),
                columns = [feature]
            )
            selection_bar_df["Height"] = 1

            selection_bars = alt.Chart(selection_bar_df).mark_bar(
                size=.8*QUANTITATIVE_WIDTH/len(selection_bar_df),
                binSpacing=0,
                align="left"
            ).transform_calculate(
                key=selector[feature]
            ).transform_calculate(
                val='datum.'+feature
            ).encode(
                x = alt.X(feature, title = None),
                y = alt.Y("Height", axis=None),
                color = alt.value("lightgrey"),
                opacity = alt.condition(alt.datum.val <= alt.datum.key, alt.value(0.3), alt.value(0))
            ).properties(
                height = MAIN_CHART_HEIGHT
            ).add_selection(
                selector
            )    
            
            component_stack.append(
                (selection_bars + feature_component).resolve_scale(y="independent")
            )
            
            # End Top Component

            # Middle Component
            if show_information_gain:
                ig_line = alt.Chart(ig_df).mark_line().encode(
                    x = alt.X("Split_Val", title=None),
                    y = alt.Y("Information_Gain", scale = alt.Scale(domain = [0, np.max(master_ig_df["Information_Gain"])]))
                ).properties(
                    width = QUANTITATIVE_WIDTH,
                    height = IG_LINE_HEIGHT
                )
                component_stack.append(ig_line)
                
            # End Middle Component
            
            # Bottom Component
            
            if show_splits:
                split_component = alt.Chart(chart_data).mark_bar().encode(
                    y = "Split:N",
                    x = alt.X("count()", scale=alt.Scale(domain = [0, len(data)])),
                    color = alt.Color(target, type="nominal", sort=[1,0])
                ).transform_calculate(
                    split_val = selector[feature]
                ).transform_calculate(
                    Split = "datum.value <= datum.split_val?'Left':'Right'"
                ).properties(
                    width = SPLIT_WIDTH
                )   

                split_entropies = alt.Chart(chart_data).mark_text(dx=10).encode(
                    y = "Split:N",
                    x = alt.X("count()", scale=alt.Scale(domain = [0, len(data)])),
                    text = "Entropy:Q"
                ).transform_calculate(
                    split_val = selector[feature]
                ).transform_calculate(
                    Split = "datum.value <= datum.split_val?'Left':'Right'"
                ).transform_lookup(
                    lookup = "split_val",
                    from_ = alt.LookupData(
                        data = ig_df,
                        key = "Split_Val",
                        fields = ["Left_Entropy", "Right_Entropy"]
                    )
                ).transform_calculate(
                    Entropy = "datum.Split == 'Left'?datum.Left_Entropy:datum.Right_Entropy"
                )           

                split_text = alt.Chart(ig_df).mark_text().encode(
                    text = "chart_text:N",
                ).transform_calculate(
                    selected_val = selector[feature]
                ).transform_filter(
                    (alt.datum.Split_Val <= alt.datum.selected_val) & (alt.datum.Split_Val >= alt.datum.selected_val)
                ).transform_calculate(
                    chart_text = "For " + feature + "<=" + alt.datum.selected_val
                ).properties(
                    width = TEXT_WIDTH
                )

                ig_text = alt.Chart(ig_df).mark_text().encode(
                    text = "chart_text:N",
                ).transform_calculate(
                    selected_val = selector[feature]
                ).transform_filter(
                    (alt.datum.Split_Val <= alt.datum.selected_val) & (alt.datum.Split_Val >= alt.datum.selected_val)
                ).transform_calculate(
                    chart_text = "Information Gain: " + alt.datum.Information_Gain
                ).properties(
                    width = TEXT_WIDTH
                )
                
                component_stack.append((split_component) | (split_text & ig_text))
                
                #End Bottom Component

            chart |= alt.vconcat(*component_stack).resolve_scale(x="shared").properties(title=feature)

    return chart.configure_title(anchor="middle")

In [329]:
build_interactive_decision_tree(titanic_df, show_information_gain=True, show_splits=True)

In [347]:
data_for_ml = pd.get_dummies(titanic_df, drop_first = True)
X = data_for_ml.drop(target, axis=1)
y = data_for_ml[target]
sklearn_dt = DecisionTreeClassifier(criterion="entropy", max_depth=2)
sklearn_dt.fit(X, y)
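# Accuracy on the same data the tree was trained on; format the float instead of calling
# .round(), which is unavailable when score() returns a plain Python float
print(f"Model Accuracy: {sklearn_dt.score(X, y):.3f}")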

Model Accuracy: 0.798
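`DecisionTreeClassifier(criterion="entropy", max_depth=2)` builds its tree greedily. At each node it picks the split with the largest entropy reduction (information gain), then repeats on each side until it hits the depth limit or a node is pure. The standard-library sketch below shows that recursion on a few hand-made rows with 0/1 features. It is a simplified illustration, not scikit-learn's actual implementation: every split is a yes/no question on one feature, and leaves predict the majority class. The feature names, rows, and helper functions are hypothetical.

```python
import math
from collections import Counter


def entropy_bits(labels):
    """Entropy in bits of a list of class labels."""
    n = len(labels)
    return sum(-c / n * math.log2(c / n) for c in Counter(labels).values())


def build_tree(rows, labels, features, max_depth):
    """Greedy, depth-limited tree over 0/1 features; leaves hold the majority class."""
    majority = Counter(labels).most_common(1)[0][0]
    if max_depth == 0 or len(set(labels)) == 1:
        return majority  # leaf: out of depth, or the node is pure
    best = None
    for f in features:
        left = [y for r, y in zip(rows, labels) if r[f] == 0]
        right = [y for r, y in zip(rows, labels) if r[f] == 1]
        if not left or not right:
            continue  # this feature does not split the node
        weighted = (len(left) * entropy_bits(left) + len(right) * entropy_bits(right)) / len(labels)
        gain = entropy_bits(labels) - weighted
        if best is None or gain > best[0]:
            best = (gain, f)
    if best is None:
        return majority  # no feature can split this node
    f = best[1]
    # Recurse on each side of the locally best split
    sides = {}
    for value in (0, 1):
        sub = [(r, y) for r, y in zip(rows, labels) if r[f] == value]
        sides[value] = build_tree([r for r, _ in sub], [y for _, y in sub], features, max_depth - 1)
    return {"feature": f, "if_0": sides[0], "if_1": sides[1]}


def predict(tree, row):
    """Follow the yes/no questions down to a leaf."""
    while isinstance(tree, dict):
        tree = tree["if_1"] if row[tree["feature"]] == 1 else tree["if_0"]
    return tree


rows = [
    {"male": 0, "child": 0}, {"male": 0, "child": 0}, {"male": 0, "child": 1},
    {"male": 1, "child": 0}, {"male": 1, "child": 0}, {"male": 1, "child": 1},
]
labels = [1, 1, 1, 0, 0, 1]
tree = build_tree(rows, labels, ["male", "child"], max_depth=2)
print(tree)  # {'feature': 'male', 'if_0': 1, 'if_1': {'feature': 'child', 'if_0': 0, 'if_1': 1}}
```

Each split is chosen only for the node in front of it. The result is therefore locally (greedily) optimal rather than guaranteed to be the best possible tree.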


In [348]:
print("Structure of Decision Tree")
print(
    export_text(
        sklearn_dt,
        feature_names = list(X.columns)
    ).replace(
        "class: 1", "Survivor"
    ).replace(
        "class: 0", "Non-Survivor"
    ).replace(
        "<= 0.50", "is False"
    ).replace(
        ">  0.50", "is True"
    )
)

Structure of Decision Tree
|--- Sex_male is False
|   |--- Class_Third is False
|   |   |--- Survivor
|   |--- Class_Third is True
|   |   |--- Non-Survivor
|--- Sex_male is True
|   |--- Age <= 13.00
|   |   |--- Survivor
|   |--- Age >  13.00
|   |   |--- Non-Survivor
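Because the fitted tree has depth 2, its printed structure translates directly into nested if/else rules. The sketch below hand-copies the splits shown above: `Sex_male`, `Class_Third`, and the `Age <= 13.00` threshold. The function name and its plain-Python arguments are hypothetical stand-ins for the dummy-encoded columns in `X`.

```python
def predict_survival(sex_male: bool, class_third: bool, age: float) -> int:
    """Hand-translated rules from the depth-2 tree printed above (1 = survivor)."""
    if not sex_male:
        # Female passengers: the second question is passenger class
        return 0 if class_third else 1
    # Male passengers: the second question is age
    return 1 if age <= 13.0 else 0


print(predict_survival(sex_male=False, class_third=False, age=30))  # 1
print(predict_survival(sex_male=False, class_third=True, age=30))   # 0
print(predict_survival(sex_male=True, class_third=False, age=8))    # 1
print(predict_survival(sex_male=True, class_third=True, age=40))    # 0
```

Reading the rules this way is the __Interpret__ step from earlier. Sex is the first question asked, class decides the female branch, and age decides the male branch.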

