## Test the simple_tree visualization with two datasets:
1. Titanic dataset: Binary classification Decision Tree
2. Iris dataset: Multiclass classification Decision Tree

## Titanic: Binary classifier

In [1]:
import pandas as pd
import numpy as np

from sklearn.tree import DecisionTreeClassifier

In [2]:
# Load the Titanic training set (path is relative to the notebook's working directory).
titanic = pd.read_csv('data/titanic_train.csv')

# Impute missing values.
# Series.median() skips NaN by default, so no explicit dropna() is needed.
titanic["Age"] = titanic["Age"].fillna(titanic["Age"].median())
# NOTE(review): 'S' is presumably the most frequent embarkation port -- confirm against the data.
titanic["Embarked"] = titanic["Embarked"].fillna("S")

# Encode categorical features.
# Sex: 1 for 'male', 0 for anything else (vectorized comparison instead of a per-row lambda).
titanic['Sex'] = titanic['Sex'].eq('male').astype('int64')
# Embarked: one-hot indicator columns named Embarked_<value>.
titanic = pd.get_dummies(titanic, columns=['Embarked'])

# Features used to train the tree.
features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked_C', 'Embarked_Q', 'Embarked_S']

In [3]:
dt = DecisionTreeClassifier(random_state=24, max_leaf_nodes=20)
%time dt.fit(titanic[features], titanic['Survived'])

Wall time: 6 ms


DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=20,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=24,
            splitter='best')

### Visualize the tree model

In [6]:
import sys
# Put the relative 'src/' directory first on the module search path so that the
# `import simple_tree` below can resolve (presumably simple_tree lives in src/).
# The path is relative, so this depends on the notebook's working directory.
sys.path.insert(0, 'src/')
import simple_tree

In [7]:
%%time
simple_tree.generate_simple_tree(tree_title='Titanic_Tree', tree_model=dt, X=titanic[features], 
                                 target_names=['Not Survived', 'Survived'], target_colors = None,
                                 color_map=None, width=1500, height=1000)

The output is in simple_tree_output/simple_tree_Titanic_Tree.html. Enjoy!
Wall time: 45 ms




## Iris: multiclass classifier

In [10]:
from sklearn.datasets import load_iris
import pandas as pd

# Load the bundled Iris dataset (feature matrix in .data, class labels in .target).
iris = load_iris()

# Fix random_state so the fitted tree is reproducible across runs: scikit-learn
# permutes features at each split, so ties between equally good splits can
# otherwise be broken differently. Uses the same seed as the Titanic model above.
clf = DecisionTreeClassifier(random_state=24).fit(iris.data, iris.target)

In [12]:
%%time
simple_tree.generate_simple_tree(tree_title='Iris_Tree', tree_model=clf, 
                                 X=pd.DataFrame(iris.data, columns=iris.feature_names), 
                                 target_names=list(iris.target_names), target_colors = None,
                                 color_map='Vega10', width=1200, height=1000)

The output is in simple_tree_output/simple_tree_Iris_Tree.html. Enjoy!
Wall time: 13 ms
