# DecisionTree Classifier
This is a basic implementation of a decision tree classifier.

In [1]:
import numpy as np
import pandas as pd

In [2]:
from MyDecisionTree import build_tree

## Datasets
Three standard datasets (iris, mpg, and flights, from the seaborn-data repository) are loaded to test the algorithm.

In [3]:
iris_df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')

In [4]:
mpg_df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/mpg.csv')

In [5]:
flights_df = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/flights.csv')

In [6]:
# Shuffle the rows, since some datasets are sorted by label

iris_df = iris_df.sample(frac=1).reset_index(drop=True)
mpg_df = mpg_df.sample(frac=1).reset_index(drop=True)
flights_df = flights_df.sample(frac=1).reset_index(drop=True)
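
Because `sample` draws a new permutation on every run, the outputs shown later in this notebook will differ between executions. A minimal sketch of a reproducible shuffle, assuming a fixed seed is acceptable; the toy dataframe and the seed value `42` are illustrative and not part of the original notebook:

```python
import pandas as pd

# Toy stand-in for one of the datasets above (illustrative values only).
toy_df = pd.DataFrame({
    "feature": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    "label": ["a", "a", "a", "b", "b", "b"],
})

# random_state fixes the permutation, so re-running gives the same order.
shuffled_a = toy_df.sample(frac=1, random_state=42).reset_index(drop=True)
shuffled_b = toy_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Both shuffles contain the same rows in the same order.
print(shuffled_a.equals(shuffled_b))
```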

## Building The Tree
The module is straightforward to use: call the `build_tree` function, passing a dataframe containing the dataset as the first argument and the name of the label column as the second. The function returns a `Node` object that is the root of the tree.

In [7]:
tree_1 = build_tree(iris_df[:-5], iris_df.columns[-1])

In [8]:
# This dataset is quite complex for the algorithm, so only the first 100 rows are used for testing purposes
tree_2 = build_tree(mpg_df[:100], mpg_df.columns[-1])

In [9]:
tree_3 = build_tree(flights_df[:-5], flights_df.columns[-1])
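
Trees 1 and 3 are built on all rows except the last five, and tree 2 on the first hundred rows, so only rows outside those slices are unseen by the corresponding tree. A minimal sketch of keeping the two parts explicit; `holdout_split` is a hypothetical helper, not part of `MyDecisionTree`, and the toy dataframe is illustrative:

```python
import pandas as pd

def holdout_split(df, n_test):
    """Return (train, test), where test is the last n_test rows of df."""
    if not 0 < n_test < len(df):
        raise ValueError("n_test must be between 1 and len(df) - 1")
    return df.iloc[:-n_test], df.iloc[-n_test:]

# Toy dataframe standing in for an already shuffled dataset.
toy_df = pd.DataFrame({"x": range(10), "label": list("aabbaabbab")})

train_df, test_df = holdout_split(toy_df, 3)
print(len(train_df), len(test_df))
```

With the data in this notebook, the same idea would read `train_df, test_df = holdout_split(iris_df, 5)` followed by `build_tree(train_df, 'species')`, assuming the call pattern shown in the cells above.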

## Making Predictions
Once a tree has been built, call its `predict` method, passing a dataframe containing the rows to predict. For each input row, the method returns the predicted class and the Gini impurity of the node that produced the prediction.
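
For reference, the Gini impurity of a set of labels is one minus the sum of the squared class proportions, so a value of 0 means that every row reaching the node has the same label. The sketch below uses this standard definition; `gini_impurity` is an illustrative helper and may differ from how `MyDecisionTree` computes the value internally.

```python
from collections import Counter

def gini_impurity(labels):
    """Standard Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0  # convention chosen here for an empty node
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

pure = gini_impurity(["setosa", "setosa", "setosa"])   # single class
mixed = gini_impurity(["setosa", "versicolor"])        # two classes, evenly mixed
print(pure, mixed)
```

Under this definition, three equally frequent classes give a value close to 0.667, which is consistent with the root node of the iris tree printed in the visualisation section.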

In [10]:
iris_df[-10:]

index,sepal_length,sepal_width,petal_length,petal_width,species
140,5.4,3.9,1.3,0.4,setosa
141,6.8,2.8,4.8,1.4,versicolor
142,5.1,3.8,1.6,0.2,setosa
143,6.2,3.4,5.4,2.3,virginica
144,4.4,3.2,1.3,0.2,setosa
145,5.6,3.0,4.5,1.5,versicolor
146,6.4,3.2,5.3,2.3,virginica
147,5.5,2.5,4.0,1.3,versicolor
148,5.7,2.8,4.5,1.3,versicolor
149,4.8,3.1,1.6,0.2,setosa


In [11]:
tree_1.predict(iris_df[iris_df.columns[:-1]][-10:])

[('setosa', 0.0),
 ('versicolor', 0.0),
 ('setosa', 0.0),
 ('virginica', 0.0),
 ('setosa', 0.0),
 ('versicolor', 0.0),
 ('virginica', 0.0),
 ('versicolor', 0.0),
 ('versicolor', 0.0),
 ('setosa', 0.0)]
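
Since `predict` returns `(label, gini)` tuples, the predicted labels have to be unpacked before they can be compared with the true ones. A small sketch, using the ten predictions and the `species` values printed above; `accuracy` is a hypothetical helper, not part of the module:

```python
def accuracy(predictions, true_labels):
    """Fraction of (label, gini) predictions whose label matches the truth."""
    if not true_labels or len(predictions) != len(true_labels):
        raise ValueError("inputs must be non-empty and of equal length")
    hits = sum(pred == truth
               for (pred, _gini), truth in zip(predictions, true_labels))
    return hits / len(true_labels)

# Output of tree_1.predict(...) as printed above.
predictions = [
    ("setosa", 0.0), ("versicolor", 0.0), ("setosa", 0.0),
    ("virginica", 0.0), ("setosa", 0.0), ("versicolor", 0.0),
    ("virginica", 0.0), ("versicolor", 0.0), ("versicolor", 0.0),
    ("setosa", 0.0),
]
# The species column of iris_df[-10:] as printed above.
true_labels = [
    "setosa", "versicolor", "setosa", "virginica", "setosa",
    "versicolor", "virginica", "versicolor", "versicolor", "setosa",
]

score = accuracy(predictions, true_labels)
print(score)
```

Note that `tree_1` was built on `iris_df[:-5]`, so rows 140 to 144 were part of the data the tree was built on and only the last five rows are unseen.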

In [12]:
mpg_df[95:105]

index,mpg,cylinders,displacement,horsepower,weight,acceleration,model_year,origin,name
95,11.0,8,318.0,210.0,4382,13.5,70,usa,dodge d200
96,25.0,4,140.0,75.0,2542,17.0,74,usa,chevrolet vega
97,23.0,4,97.0,54.0,2254,23.5,72,europe,volkswagen type 3
98,36.0,4,135.0,84.0,2370,13.0,82,usa,dodge charger 2.2
99,24.5,4,151.0,88.0,2740,16.0,77,usa,pontiac sunbird coupe
100,25.0,6,181.0,110.0,2945,16.4,82,usa,buick century limited
101,24.0,4,119.0,97.0,2545,17.0,75,japan,datsun 710
102,14.0,8,455.0,225.0,3086,10.0,70,usa,buick estate wagon (sw)
103,12.0,8,455.0,225.0,4951,11.0,73,usa,buick electra 225 custom
104,21.0,4,122.0,86.0,2226,16.5,72,usa,ford pinto runabout


In [13]:
tree_2.predict(mpg_df[mpg_df.columns[:-1]][95:105])

[('dodge d200', 0.0),
 ('chevrolet vega', 0.0),
 ('volkswagen type 3', 0.0),
 ('dodge charger 2.2', 0.0),
 ('pontiac sunbird coupe', 0.0),
 ('buick century', 0.0),
 ('toyota corolla tercel', 0.0),
 ('ford galaxie 500', 0.0),
 ('ford country', 0.0),
 ('plymouth reliant', 0.0)]

In [14]:
flights_df[-10:]

index,year,month,passengers
134,1950,June,149
135,1949,September,136
136,1957,June,422
137,1952,October,191
138,1959,September,463
139,1956,April,313
140,1954,May,234
141,1955,July,364
142,1957,May,355
143,1958,September,404


In [15]:
tree_3.predict(flights_df[["year", "month"]][-10:])

[(149, 0.0),
 (136, 0.0),
 (422, 0.0),
 (191, 0.0),
 (463, 0.0),
 (348, 0.0),
 (229, 0.0),
 (548, 0.0),
 (270, 0.0),
 (463, 0.0)]

## Visualising The Tree
The `print_tree` method prints a text representation of the tree. Each non-leaf node shows an attribute and a value; these are the ones the algorithm selected to split the data. The left child of a node holds the rows for which the test `attribute == value` is true, and the right child holds the rows for which it is false.
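
The split rule described above can be sketched as follows. This is a simplified illustration on a toy dataframe; `split_on_equality` is a hypothetical name, and the module's actual internals may differ.

```python
import pandas as pd

def split_on_equality(df, attribute, value):
    """Left child: rows where df[attribute] == value; right child: the rest."""
    mask = df[attribute] == value
    return df[mask], df[~mask]

# Toy rows with illustrative values, not taken from the real dataset.
toy_df = pd.DataFrame({
    "petal_width": [0.2, 0.2, 1.3, 1.5, 2.3],
    "species": ["setosa", "setosa", "versicolor", "versicolor", "virginica"],
})

left, right = split_on_equality(toy_df, "petal_width", 0.2)
print(left["species"].tolist())
print(right["species"].tolist())
```

An equality test on a numeric attribute peels off one distinct value per split, which is consistent with the long chain of `petal_width` nodes in the iris tree printed below.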

In [16]:
tree_1.print_tree()

 (R) attribute: petal_width value: 0.2 gini: 0.66653983353151
       (l) gini: 0.0 decision: setosa (Leaf)
       (r) attribute: petal_width value: 0.3 gini: 0.6310176053765798
             (l) gini: 0.0 decision: setosa (Leaf)
             (r) attribute: petal_width value: 0.4 gini: 0.6028099173553719
                   (l) gini: 0.0 decision: setosa (Leaf)
                   (r) attribute: petal_width value: 0.1 gini: 0.5608445659345838
                         (l) gini: 0.0 decision: setosa (Leaf)
                         (r) attribute: petal_width value: 1.3 gini: 0.5195751770095793
                               (l) gini: 0.0 decision: versicolor (Leaf)
                               (r) attribute: petal_width value: 1.0 gini: 0.5110318404016383
                                     (l) gini: 0.0 decision: versicolor (Leaf)
                                     (r) attribute: petal_width value: 1.5 gini: 0.4928125
                                           (l) attribute: petal_lengt

In [17]:
tree_2.print_tree()

 (R) attribute: acceleration value: 19.2 gini: 0.9890000000000012
       (l) gini: 0.0 decision: datsun 210 (Leaf)
       (r) attribute: displacement value: 232.0 gini: 0.9889629321116198
             (l) attribute: acceleration value: 16.0 gini: 0.625
                   (l) gini: 0.0 decision: amc hornet (Leaf)
                   (r) attribute: mpg value: 18.0 gini: 0.5
                         (l) gini: 0.0 decision: amc matador (Leaf)
                         (r) gini: 0.0 decision: amc pacer (Leaf)
             (r) attribute: origin value: europe gini: 0.9889090086011767
                   (l) attribute: cylinders value: 4 gini: 0.9462809917355364
                         (l) attribute: displacement value: 97.0 gini: 0.9273356401384076
                               (l) attribute: mpg value: 34.3 gini: 0.6666666666666665
                                     (l) gini: 0.0 decision: audi 4000 (Leaf)
                                     (r) attribute: mpg value: 27.0 gini: 0.5
       

In [None]:
tree_3.print_tree()