# [Workshop] Knowledge Discovery by Decision Tree

# 0. Installation (one time job)

In [1]:
# One-time setup: uncomment and run these lines to install the packages this
# notebook imports.
# !pip install scikit-learn==0.23.1
# !pip install pandas
# !pip install matplotlib
# NOTE(review): the notebook also imports `graphviz` and calls `graph.render(...)`,
# which is not covered by the list above. It presumably needs:
# !pip install graphviz
# plus the Graphviz system binaries (`dot`) on PATH for rendering — confirm.

# 1. Import Library

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor
import graphviz
from sklearn.tree import plot_tree, export_graphviz
from sklearn import tree

# 2. Import ASD Data

In [3]:
# Read data: load the toddler autism screening CSV (the path is relative to the
# notebook's working directory) and preview the first rows.
ASD_data = pd.read_csv('./Toddler Autism dataset.csv')
ASD_data.head()

Unnamed: 0,Case_No,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,Age_Mons,Qchat-10-Score,Sex,Ethnicity,Jaundice,Family_mem_with_ASD,Who completed the test,Class/ASD Traits
0,1,0,0,0,0,0,0,1,1,0,1,28,3,f,middle eastern,yes,no,family member,No
1,2,1,1,0,0,0,1,1,0,0,0,36,4,m,White European,yes,no,family member,Yes
2,3,1,0,0,0,0,0,1,1,0,1,36,4,m,middle eastern,yes,no,family member,Yes
3,4,1,1,1,1,1,1,1,1,1,1,24,10,m,Hispanic,no,no,family member,Yes
4,5,1,1,0,1,1,1,1,1,1,1,20,9,f,White European,no,yes,family member,Yes


In [4]:
# Replace the special characters '-', '/' and ' ' in the column names with '_'
# so that every column can be reached through attribute access later on.
# The target column ends up as 'Class_ASD_Traits_' (trailing underscore; the raw
# header presumably ends with a space).
ASD_data.columns = ASD_data.columns.str.replace(r'[-/ ]', '_', regex=True)
ASD_data.head()

Unnamed: 0,Case_No,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,Age_Mons,Qchat_10_Score,Sex,Ethnicity,Jaundice,Family_mem_with_ASD,Who_completed_the_test,Class_ASD_Traits_
0,1,0,0,0,0,0,0,1,1,0,1,28,3,f,middle eastern,yes,no,family member,No
1,2,1,1,0,0,0,1,1,0,0,0,36,4,m,White European,yes,no,family member,Yes
2,3,1,0,0,0,0,0,1,1,0,1,36,4,m,middle eastern,yes,no,family member,Yes
3,4,1,1,1,1,1,1,1,1,1,1,24,10,m,Hispanic,no,no,family member,Yes
4,5,1,1,0,1,1,1,1,1,1,1,20,9,f,White European,no,yes,family member,Yes


In [5]:
# Observe the shape of the data as (number of rows, number of columns).
ASD_data.shape

(1054, 19)

# 3. Data Preprocessing

## 3.1. Choose appropriate features 

In [6]:
def _as_integer_codes(column):
    """Return `column` re-encoded as a Categorical built from its integer category codes."""
    return pd.Categorical(pd.Categorical(column).codes)


# Label-encode the text-valued columns that are kept (four features + the target).
Sex = _as_integer_codes(ASD_data['Sex'])
Ethnicity = _as_integer_codes(ASD_data['Ethnicity'])
Jaundice = _as_integer_codes(ASD_data['Jaundice'])
Family_mem_with_ASD = _as_integer_codes(ASD_data['Family_mem_with_ASD'])
Class_ASD_Traits_ = _as_integer_codes(ASD_data['Class_ASD_Traits_'])

# Write the encoded values back over the original text columns.
encoded_columns = {
    'Sex': Sex,
    'Ethnicity': Ethnicity,
    'Jaundice': Jaundice,
    'Family_mem_with_ASD': Family_mem_with_ASD,
    'Class_ASD_Traits_': Class_ASD_Traits_,
}
for column_name, encoded_values in encoded_columns.items():
    ASD_data[column_name] = encoded_values

# Check the data and think why we drop these variables?
# ('Class_ASD_Traits_' is the prediction target, so it must not remain in X.)
excluded_columns = ['Case_No', 'Who_completed_the_test', 'Qchat_10_Score', 'Class_ASD_Traits_']
X = ASD_data.drop(columns=excluded_columns)

In [7]:
# List the feature columns that remain in X after the drop above.
X.columns

Index(['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'Age_Mons',
       'Sex', 'Ethnicity', 'Jaundice', 'Family_mem_with_ASD'],
      dtype='object')

## 3.2 Choose target

In [8]:
# Classification target: the label-encoded ASD-traits column.
Y_classification = ASD_data['Class_ASD_Traits_']

# 4. Build Classification Tree

## 4.1 Split the dataset into training set and test set

In [9]:
# Hold out one third of the rows for testing. The split is seeded so it is
# reproducible, and stratified on the target so both parts keep its class balance.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y_classification,
    test_size=1 / 3,
    random_state=42,
    stratify=Y_classification,
)
# Print the two shapes on separate lines.
print(X_train.shape, X_test.shape, sep='\n')

(702, 15)
(352, 15)


In [10]:
# Preview the first rows of the training features.
X_train.head()

Unnamed: 0,A1,A2,A3,A4,A5,A6,A7,A8,A9,A10,Age_Mons,Sex,Ethnicity,Jaundice,Family_mem_with_ASD
89,1,1,1,1,1,1,1,1,1,0,34,1,6,0,0
388,0,0,0,0,1,0,1,0,0,1,36,1,6,1,0
198,1,0,1,1,1,1,1,0,1,1,20,1,6,0,0
323,1,1,1,1,1,1,1,1,1,1,32,1,5,1,0
298,1,1,1,1,1,1,1,0,1,1,17,1,0,1,1


In [11]:
# Preview the first rows of the training target; the index labels should line
# up with the rows of X_train.
y_train.head()

89     1
388    0
198    1
323    1
298    1
Name: Class_ASD_Traits_, dtype: category
Categories (2, int64): [0, 1]

## 4.2 Build classification tree

In [12]:
# Fit a Gini-based classification tree, capped at depth 8 and seeded so the
# fitted tree is reproducible. `fit` returns the estimator itself.
dt = DecisionTreeClassifier(
    criterion='gini',
    random_state=0,
    max_depth=8,
).fit(X_train, y_train)

# Mean accuracy on the data the tree was fitted on versus the held-out data.
train_accuracy = dt.score(X_train, y_train)
test_accuracy = dt.score(X_test, y_test)
print("Accuracy on training set: {:.3f}".format(train_accuracy))
print("Accuracy on test set: {:.3f}".format(test_accuracy))

# Display the fitted estimator as the cell output.
dt

Accuracy on training set: 0.996
Accuracy on test set: 0.932


DecisionTreeClassifier(max_depth=8, random_state=0)

In [13]:
# Export the fitted tree as Graphviz DOT source (returned as a string because
# out_file is None) and wrap it in a graphviz.Source object.
export_options = {
    'feature_names': X.columns,
    'class_names': ['Not ASD', 'ASD'],  # "0": Not ASD ; "1": ASD
    'filled': True,
    'rounded': True,
    'special_characters': True,
}
dot_data = export_graphviz(dt, out_file=None, **export_options)
graph = graphviz.Source(dot_data)

In [14]:
# Render the tree with Graphviz under the base name "tree_rules_image" and open
# the rendered file in the system's default viewer (view=True).
graph.render("tree_rules_image", view=True)

# Save the raw DOT source as text. The context manager closes the file even if
# the write fails; plain "w" is enough because the file is never read back here.
with open("tree_rules.txt", "w") as f:
    f.write(dot_data)

In [15]:
# # Visualize the tree
# from IPython.display import display
# display(graph)

# 6. Extract all the rules (Decision Tree's level = 8) from the classification tree

In [None]:
# if (A5 > 0.5) and (A6 > 0.5) and (A9 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 <= 0.5) and (A3 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 > 0.5) and (Family_mem_with_ASD <= 0.5) then class: 1 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 > 0.5) and (Age_Mons > 13.0) then class: 1 
# if (A5 > 0.5) and (A6 > 0.5) and (A9 <= 0.5) and (A1 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 <= 0.5) and (A2 > 0.5) and (Ethnicity > 0.5) then class: 1 
# if (A5 > 0.5) and (A6 > 0.5) and (A9 <= 0.5) and (A1 <= 0.5) and (A8 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 <= 0.5) and (A3 > 0.5) and (A8 <= 0.5) then class: 0 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 > 0.5) and (Ethnicity > 2.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 <= 0.5) and (A6 <= 0.5) and (Ethnicity > 3.0) then class: 0 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) and (A2 <= 0.5) and (Ethnicity > 4.0) and (Age_Mons > 15.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 > 0.5) and (A3 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 <= 0.5) and (A2 <= 0.5) and (A8 <= 0.5) and (A3 <= 0.5) and (Jaundice <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 <= 0.5) and (A8 > 0.5) and (A2 <= 0.5) and (A6 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 > 0.5) and (A2 <= 0.5) and (A8 <= 0.5) and (A3 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 > 0.5) and (Family_mem_with_ASD > 0.5) and (A2 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 > 0.5) and (A2 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 > 0.5) and (A3 <= 0.5) and (A8 > 0.5) and (Ethnicity > 2.5) then class: 1
# if (A5 > 0.5) and (A6 > 0.5) and (A9 <= 0.5) and (A1 <= 0.5) and (A8 <= 0.5) and (A4 > 0.5) then class: 1 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) and (A2 > 0.5) and (A1 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 > 0.5) and (A10 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 <= 0.5) and (A2 <= 0.5) and (A8 > 0.5) then class: 1 
# if (A5 > 0.5) and (A6 > 0.5) and (A9 <= 0.5) and (A1 <= 0.5) and (A8 <= 0.5) and (A4 <= 0.5) and (A10 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 > 0.5) and (A2 <= 0.5) and (A8 > 0.5) then class: 1 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 > 0.5) and (Age_Mons <= 13.0) and (A9 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 <= 0.5) and (A8 > 0.5) and (A2 <= 0.5) and (A6 > 0.5) and (Age_Mons > 18.0) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 > 0.5) and (A3 <= 0.5) and (A8 <= 0.5) and (A2 <= 0.5) and (A6 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 <= 0.5) and (A8 > 0.5) and (A2 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 <= 0.5) and (A3 > 0.5) and (A8 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 > 0.5) and (A3 <= 0.5) and (A8 <= 0.5) and (A2 <= 0.5) and (A6 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 <= 0.5) and (A2 <= 0.5) and (A8 <= 0.5) and (A3 <= 0.5) and (Jaundice > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 > 0.5) and (A2 <= 0.5) and (A8 <= 0.5) and (A3 > 0.5) and (Age_Mons > 19.0) and (Sex > 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 <= 0.5) and (A2 > 0.5) and (Ethnicity <= 0.5) and (A10 > 0.5) then class: 1 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 > 0.5) and (Ethnicity <= 2.5) and (A2 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 <= 0.5) and (A6 > 0.5) and (A10 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 > 0.5) and (A2 <= 0.5) and (A8 <= 0.5) and (A3 > 0.5) and (Age_Mons <= 19.0) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 <= 0.5) and (A6 > 0.5) and (A10 <= 0.5) then class: 0 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) and (A2 <= 0.5) and (Ethnicity > 4.0) and (Age_Mons <= 15.5) and (Jaundice <= 0.5) then class: 0 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) and (A2 > 0.5) and (A1 <= 0.5) and (A7 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 <= 0.5) and (A2 <= 0.5) and (A8 <= 0.5) and (A3 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 > 0.5) and (A10 <= 0.5) and (A3 <= 0.5) then class: 0 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) and (A2 <= 0.5) and (Ethnicity > 4.0) and (Age_Mons <= 15.5) and (Jaundice > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 <= 0.5) and (A6 <= 0.5) and (Ethnicity <= 3.0) and (A4 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 <= 0.5) and (A6 <= 0.5) and (Ethnicity <= 3.0) and (A4 <= 0.5) then class: 0 
# if (A5 > 0.5) and (A6 > 0.5) and (A9 <= 0.5) and (A1 <= 0.5) and (A8 <= 0.5) and (A4 <= 0.5) and (A10 > 0.5) then class: 1 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 > 0.5) and (Age_Mons <= 13.0) and (A9 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 > 0.5) and (A2 <= 0.5) and (A8 <= 0.5) and (A3 > 0.5) and (Age_Mons > 19.0) and (Sex <= 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 <= 0.5) and (A9 <= 0.5) and (A2 > 0.5) and (A8 > 0.5) and (A10 <= 0.5) and (A3 > 0.5) then class: 1 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) and (A2 <= 0.5) and (Ethnicity <= 4.0) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 > 0.5) and (Family_mem_with_ASD > 0.5) and (A2 <= 0.5) then class: 0 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 > 0.5) and (Ethnicity <= 2.5) and (A2 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 <= 0.5) and (A8 > 0.5) and (A2 <= 0.5) and (A6 > 0.5) and (Age_Mons <= 18.0) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 > 0.5) and (A6 <= 0.5) and (A2 > 0.5) and (Ethnicity <= 0.5) and (A10 <= 0.5) then class: 0 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 > 0.5) and (A3 <= 0.5) and (A8 <= 0.5) and (A2 > 0.5) then class: 1 
# if (A5 <= 0.5) and (A7 > 0.5) and (A1 <= 0.5) and (A4 > 0.5) and (A3 <= 0.5) and (A8 > 0.5) and (Ethnicity <= 2.5) then class: 0 
# if (A5 > 0.5) and (A6 <= 0.5) and (A4 <= 0.5) and (A8 <= 0.5) and (A2 > 0.5) and (A1 <= 0.5) and (A7 > 0.5) then class: 1 