In [25]:
# import required libraries
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

In [2]:
# read in data - https://www.kaggle.com/uciml/zoo-animal-classification
df=pd.read_csv("zoo.csv")
# set 'animal name' to index
df.set_index('animal_name',inplace=True)
df['class_type'].replace({1:'Mammal',2:'Bird',3:'Reptile',4:'Fish',5:'Amphibian',6:'Bug',7:'Invertebrate'},inplace=True)

In [32]:
# snapshot of our data
df.head()

Unnamed: 0_level_0,hair,feathers,eggs,milk,airborne,aquatic,predator,toothed,backbone,breathes,venomous,fins,legs,tail,domestic,catsize,class_type
animal_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
aardvark,1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,Mammal
antelope,1,0,0,1,0,0,0,1,1,1,0,0,4,1,0,1,Mammal
bass,0,0,1,0,0,1,1,1,1,0,0,1,0,1,0,0,Fish
bear,1,0,0,1,0,0,1,1,1,1,0,0,4,0,0,1,Mammal
boar,1,0,0,1,0,0,1,1,1,1,0,0,4,1,0,1,Mammal


In [4]:
# creating our train and test
X_train,X_test,y_train,y_test=train_test_split(df.drop('class_type',axis=1),df['class_type'],random_state=0,test_size=0.3)

In [5]:
# create and fit the model
forest=RandomForestClassifier(max_depth=4)
forest.fit(X_train,y_train)
y_pred=forest.predict(X_test)
final_predict_df=pd.DataFrame({'Name of animal':X_test.index,'Predicted Type':y_pred,'Actual Type':y_test})

In [6]:
# prediction test data
y_pred_test = forest.predict(X_test)
# accuracy_score
accuracy_score(y_test, y_pred_test)

0.967741935483871

In [7]:
# table of predictions
final_predict_df

Unnamed: 0_level_0,Name of animal,Predicted Type,Actual Type
animal_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
frog,frog,Amphibian,Amphibian
stingray,stingray,Fish,Fish
bass,bass,Fish,Fish
oryx,oryx,Mammal,Mammal
sealion,sealion,Mammal,Mammal
vole,vole,Mammal,Mammal
crow,crow,Bird,Bird
seahorse,seahorse,Fish,Fish
opossum,opossum,Mammal,Mammal
wallaby,wallaby,Mammal,Mammal


In [8]:
# view confusion matrix for test data and predictions
confusion_matrix(y_test, y_pred_test)

array([[ 1,  0,  0,  0,  0,  0,  0],
       [ 0,  6,  0,  0,  0,  0,  0],
       [ 0,  0,  2,  0,  0,  0,  0],
       [ 0,  0,  0,  7,  0,  0,  0],
       [ 0,  0,  0,  0,  2,  0,  0],
       [ 0,  0,  0,  0,  0, 11,  0],
       [ 0,  0,  0,  1,  0,  0,  1]], dtype=int64)

In [9]:
for x in range(2,8):
    for y in [100,50,10]:
        forest=RandomForestClassifier(max_depth=x,n_estimators=y)
        forest.fit(X_train,y_train)
        y_pred=forest.predict(X_test)
        print("depth= ",x,' estimators= ',y,' accuracy= ',accuracy_score(y_test,y_pred))

depth=  2  estimators=  100  accuracy=  0.8709677419354839
depth=  2  estimators=  50  accuracy=  0.8709677419354839
depth=  2  estimators=  10  accuracy=  0.8709677419354839
depth=  3  estimators=  100  accuracy=  0.9032258064516129
depth=  3  estimators=  50  accuracy=  0.9032258064516129
depth=  3  estimators=  10  accuracy=  0.967741935483871
depth=  4  estimators=  100  accuracy=  0.9354838709677419
depth=  4  estimators=  50  accuracy=  0.967741935483871
depth=  4  estimators=  10  accuracy=  0.967741935483871
depth=  5  estimators=  100  accuracy=  0.967741935483871
depth=  5  estimators=  50  accuracy=  0.967741935483871
depth=  5  estimators=  10  accuracy=  0.967741935483871
depth=  6  estimators=  100  accuracy=  0.967741935483871
depth=  6  estimators=  50  accuracy=  0.967741935483871
depth=  6  estimators=  10  accuracy=  0.967741935483871
depth=  7  estimators=  100  accuracy=  0.967741935483871
depth=  7  estimators=  50  accuracy=  0.967741935483871
depth=  7  estimato

In [11]:
# view features of the one incorrectly predicted animal type
df.loc['seasnake']

hair                0
feathers            0
eggs                0
milk                0
airborne            0
aquatic             1
predator            1
toothed             1
backbone            1
breathes            0
venomous            1
fins                0
legs                0
tail                1
domestic            0
catsize             0
class_type    Reptile
Name: seasnake, dtype: object

In [12]:
tree = forest.estimators_[5]
from sklearn.tree import export_graphviz
import pydot
export_graphviz(tree, 
                out_file='tree.dot', 
                feature_names = df.columns[1:],
                class_names = df.index,
                rounded = True, 
                precision = 1)

In [13]:
(graph, ) = pydot.graph_from_dot_file('tree.dot')

In [14]:
graph.write_png('tree.png')

In [15]:
# let's identify feature importance
feature_list = list(df.columns)
importances = list(forest.feature_importances_)
# List of tuples with variable and importance
feature_importances = [(feature, round(importance, 2)) for feature, importance in zip(feature_list, importances)]
# Sort the feature importances by most important first
feature_importances = sorted(feature_importances, key = lambda x: x[1], reverse = True)
# Print out the feature and importances 
[print('Variable: {:20} Importance: {}'.format(*pair)) for pair in feature_importances];


Variable: milk                 Importance: 0.19
Variable: toothed              Importance: 0.13
Variable: feathers             Importance: 0.12
Variable: backbone             Importance: 0.12
Variable: hair                 Importance: 0.07
Variable: eggs                 Importance: 0.07
Variable: breathes             Importance: 0.06
Variable: fins                 Importance: 0.06
Variable: airborne             Importance: 0.05
Variable: legs                 Importance: 0.05
Variable: aquatic              Importance: 0.03
Variable: tail                 Importance: 0.03
Variable: predator             Importance: 0.02
Variable: venomous             Importance: 0.01
Variable: catsize              Importance: 0.01
Variable: domestic             Importance: 0.0
