In [1]:
import time

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

from keras.models import Sequential
from keras.layers import Dense
from sklearn.preprocessing import LabelEncoder
from keras.utils import np_utils

import numpy as np

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# fix random seed for reproducibility
seed = 7
np.random.seed(seed)

In [3]:
# function returns the data in the right format
def get_data():
    dataset = np.genfromtxt("connect-4.csv", dtype='str', delimiter=",")
    
    preX = dataset[:,0:42]
    preY = dataset[:,42]
    
    X = np.zeros(preX.shape)
    
    for i, row in enumerate(preX):
        for j, col in enumerate(row):
            if col == 'x':
                X[i,j] = 1.0
            if col == 'o':
                X[i,j] = -1.0
            if col == 'b':
                X[i,j] = 0.0
    
    
    encoder = LabelEncoder()
    # code: 0 - draw; 1 - loss; 2 -win
    encoded_Y = encoder.fit_transform(preY)
    
    
    # splitting the dataset into 80% training and 20% test data set
    train, test, label_train, label_test = \
            train_test_split(X, encoded_Y,test_size = 0.2)
    
    return train, label_train, test, label_test

In [4]:
from sklearn.tree._tree import TREE_LEAF

def prune_index(inner_tree, index, threshold):
    if inner_tree.value[index].min() < threshold:
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
    # if there are shildren, visit them as well
    if inner_tree.children_left[index] != TREE_LEAF:
        prune_index(inner_tree, inner_tree.children_left[index], threshold)
        prune_index(inner_tree, inner_tree.children_right[index], threshold)

In [5]:
# builds the decision tree of depth 12
def decision_tree(train, label):
    dt = DecisionTreeClassifier(max_depth = 12, min_samples_leaf=100)
    dt.fit(train, label)
    prune_index(dt.tree_, 0, 5)
    end = time.time()
    return dt

In [6]:
# builds the neural network for a given class
def neural_network(class_data):
    num_train = []
    num_label = []
    for x in class_data:
        num_train.append(x[0])
        num_label.append(x[1])
    
    num_train = np.array(num_train)
    num_label = np.array(num_label)
    
    # converting categorical variable into numerical values
    encoder = LabelEncoder()
    encoder.fit(num_label)

    # code: 0 - draw; 1 - loss; 2 -win
    encoded_Y = encoder.transform(num_label)
    final_label = np_utils.to_categorical(encoded_Y, 3)
    
    out = final_label.shape[1]
    print(out, final_label.shape, num_train.shape)
    model = Sequential()
    model.add(Dense(8, input_dim=42, activation='relu'))
    model.add(Dense(out, activation='softmax'))
    model.compile(loss='categorical_crossentropy', \
                  optimizer='adam', metrics=['accuracy'])
    model.fit(num_train, final_label, epochs=5, batch_size=5)
    return model 

In [7]:
# builds the neural shrub
def neural_shrubs(tree, train, label):
    train = np.array(train)
    label = np.array(label)
  
    # leave_id: index of the leaf that cantains the instance
    leave_id = tree.apply(train) 

    classes = dict()

    for x in range(len(train)):
        leaf = leave_id[x]
        
        # Gets the class for each leaf
        #.value: contains value of all the tree nodes
        #.value[leaf]: returns the value of the leaf
        #idx = tree.tree_.value[leaf][0][0]
        
        # insert the instance into the class
        if leaf in classes.keys():
            classes[leaf].append([train[x], label[x]])
        else:
            classes[leaf] = [[train[x], label[x]]]
    
    # stores the neural network for each class
    nn_models = dict()
    
    #stores the max time taken to build a neural network
    max_time = 0;
    for key in classes.keys():            
        start = time.time()
        model = neural_network(classes[key])
        end = time.time()
        
        time_taken = end - start
        if max_time < time_taken:
            max_time = time_taken
        
        nn_models[key] = model
   
    # returns a neural network for each class and the max 
    # time taken to build the neural network
    return nn_models, max_time

In [8]:
# The algorithm to build the neural shrub
train, train_label, test, test_label = get_data()

dt_start = time.time()
tree = decision_tree(train, train_label)
dt_end = time.time()

shrubs, max_time = neural_shrubs(tree, train, train_label)

3 (140, 3) (140, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (300, 3) (300, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (146, 3) (146, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (152, 3) (152, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (161, 3) (161, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (140, 3) (140, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (171, 3) (171, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (174, 3) (174, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (127, 3) (127, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (119, 3) (119, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (203, 3) (203, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (155, 3) (155, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (140, 3) (140, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (137, 3) (137, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (159

Epoch 4/5
Epoch 5/5
3 (1173, 3) (1173, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (222, 3) (222, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (153, 3) (153, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (206, 3) (206, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (127, 3) (127, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (105, 3) (105, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (173, 3) (173, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (120, 3) (120, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (200, 3) (200, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (150, 3) (150, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (109, 3) (109, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (133, 3) (133, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (222, 3) (222, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (116, 3) (116, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoc

3 (372, 3) (372, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (211, 3) (211, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (115, 3) (115, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (116, 3) (116, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (265, 3) (265, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (116, 3) (116, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (201, 3) (201, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (106, 3) (106, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (655, 3) (655, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (126, 3) (126, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (489, 3) (489, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (118, 3) (118, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (282, 3) (282, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (686, 3) (686, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (110

Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (160, 3) (160, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (110, 3) (110, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (197, 3) (197, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (161, 3) (161, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (131, 3) (131, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (348, 3) (348, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (1511, 3) (1511, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (1124, 3) (1124, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (104, 3) (104, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (118, 3) (118, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (153, 3) (153, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (185, 3) (185, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (348, 3) (348, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (1885, 3) (1885, 42)
Epoch 1/5
Epoch 2/5


Epoch 4/5
Epoch 5/5
3 (188, 3) (188, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (233, 3) (233, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (128, 3) (128, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (117, 3) (117, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (127, 3) (127, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (155, 3) (155, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (118, 3) (118, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (146, 3) (146, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (167, 3) (167, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (145, 3) (145, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (123, 3) (123, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (101, 3) (101, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (159, 3) (159, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (236, 3) (236, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 

3 (146, 3) (146, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (132, 3) (132, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (223, 3) (223, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (402, 3) (402, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (701, 3) (701, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (1188, 3) (1188, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (1258, 3) (1258, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (273, 3) (273, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (216, 3) (216, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (135, 3) (135, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (108, 3) (108, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (142, 3) (142, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (364, 3) (364, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (188, 3) (188, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (121, 3) (121, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (146, 3) (146, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (121, 3) (121, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (117, 3) (117, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (203, 3) (203, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (165, 3) (165, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (462, 3) (462, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (372, 3) (372, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (113, 3) (113, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (168, 3) (168, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (162, 3) (162, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (169, 3) (169, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (247, 3) (247, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (185, 3) (185, 42)
Epoch 1/5
Epoch 

Epoch 4/5
Epoch 5/5
3 (116, 3) (116, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (192, 3) (192, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (100, 3) (100, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (103, 3) (103, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (103, 3) (103, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (100, 3) (100, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (124, 3) (124, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (133, 3) (133, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (151, 3) (151, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (132, 3) (132, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (121, 3) (121, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (190, 3) (190, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (117, 3) (117, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (113, 3) (113, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 

3 (258, 3) (258, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (107, 3) (107, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (2720, 3) (2720, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (518, 3) (518, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (408, 3) (408, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (161, 3) (161, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (138, 3) (138, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (197, 3) (197, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (164, 3) (164, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (166, 3) (166, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (230, 3) (230, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (107, 3) (107, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (330, 3) (330, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (190, 3) (190, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (1

Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (130, 3) (130, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (166, 3) (166, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (195, 3) (195, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (127, 3) (127, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (112, 3) (112, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (126, 3) (126, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (137, 3) (137, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (113, 3) (113, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (110, 3) (110, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (197, 3) (197, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (108, 3) (108, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (122, 3) (122, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (149, 3) (149, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (124, 3) (124, 42)
Epoch 1/5
Epoch 2/5
Epoch 

Epoch 5/5
3 (183, 3) (183, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (135, 3) (135, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (146, 3) (146, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (157, 3) (157, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (159, 3) (159, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (102, 3) (102, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (248, 3) (248, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (307, 3) (307, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (143, 3) (143, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (224, 3) (224, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (167, 3) (167, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (203, 3) (203, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (104, 3) (104, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (183, 3) (183, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (165, 3) (165, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (112, 3) (112, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (164, 3) (164, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (102, 3) (102, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (167, 3) (167, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (102, 3) (102, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (107, 3) (107, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (135, 3) (135, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (190, 3) (190, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (173, 3) (173, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (128, 3) (128, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (227, 3) (227, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (169, 3) (169, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (157, 3) (157, 42)
Epoch 1/5
Epoch 

Epoch 4/5
Epoch 5/5
3 (166, 3) (166, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (162, 3) (162, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (169, 3) (169, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (122, 3) (122, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (166, 3) (166, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (156, 3) (156, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (104, 3) (104, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (201, 3) (201, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (155, 3) (155, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (116, 3) (116, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (193, 3) (193, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (100, 3) (100, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (379, 3) (379, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (156, 3) (156, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 

Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (194, 3) (194, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (270, 3) (270, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (129, 3) (129, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (105, 3) (105, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (378, 3) (378, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (405, 3) (405, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (172, 3) (172, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (232, 3) (232, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (106, 3) (106, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (291, 3) (291, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (226, 3) (226, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (113, 3) (113, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (137, 3) (137, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (126, 3) (126, 42)
Epoch 1/5
Epoch 2/5
Epoch 

3 (195, 3) (195, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (222, 3) (222, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (186, 3) (186, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (133, 3) (133, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (197, 3) (197, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (155, 3) (155, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (172, 3) (172, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (103, 3) (103, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (147, 3) (147, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (152, 3) (152, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (275, 3) (275, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (188, 3) (188, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (152, 3) (152, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (105, 3) (105, 42)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
3 (108

Epoch 4/5
Epoch 5/5


In [9]:
# predicts using the neural shrub
def neural_shrub_predict(tree, nn_model, test, label):
    label_test = np.array(label)
    test = np.array(test)
    
    #row - actual; col - pred
    confusion_matrix = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
    correct = 0
    
    for i in range(len(test)):
        x = test[i]
        pred_class = tree.apply([x])
        x = np.array([x])
        nn_model_class = nn_model[pred_class[0]]
        pred = np.argmax(nn_model_class.predict(x))

        confusion_matrix[label[i]][pred] = \
            confusion_matrix[label[i]][pred] + 1
        if pred == label[i]: correct = correct + 1

    acc_score = correct/len(test)
    
    return confusion_matrix, acc_score

In [10]:
# Predicting
cm, acc_score = neural_shrub_predict(tree, shrubs, test, test_label)
print("Confusion Matrix:\n\n", cm)

Confusion Matrix:

 [[  55  338  852]
 [ 136 1890 1364]
 [ 205 1287 7385]]


In [11]:
# function used to calcultate the metrics for each class
def metrics(cm, cls, size):
    cm = np.array(cm)
    tp = cm[cls][cls]
    fp = sum(cm[x, cls] for x in range(3))-cm[cls][cls]
    fn = sum(cm[cls, x] for x in range(3))-cm[cls][cls]
    tn = size - tp - fp - fn
    precision = tp/(tp+fp)
    recall = tp/(tp+fn)
    fmeasure = 2*(precision*recall)/(precision + recall)
    accuracy = (tp + tn)/size
    
    return precision, recall, fmeasure, accuracy

In [12]:
# metrics for class 0 (draw)
precision0, recall0, f0, acc0 = metrics(cm, 0, len(test))
print("                Precision Recall F-measure Accuracy")
print("Class 0 (draw): ", round(precision0, 3), "  ", round(recall0, 3), \
      " ", round(f0, 3), "   ", round(acc0,3))

                Precision Recall F-measure Accuracy
Class 0 (draw):  0.139    0.044   0.067     0.887


In [13]:
# metrics for class 1 (lose)
precision1, recall1, f1, acc1 = metrics(cm, 1, len(test))
print("                Precision Recall F-measure Accuracy")
print("Class 1 (loss): ", round(precision1, 3), "  ", round(recall1, 3), \
      " ", round(f1, 3), "   ", round(acc1,3))

                Precision Recall F-measure Accuracy
Class 1 (loss):  0.538    0.558   0.547     0.769


In [14]:
# metrics for class 2 (win)
precision2, recall2, f2, acc2 = metrics(cm, 2, len(test))
print("                Precision Recall F-measure Accuracy")
print("Class 2 (win): ", round(precision2, 3), "  ", round(recall2, 3), \
      " ", round(f2, 3), "   ", round(acc2,3))

                Precision Recall F-measure Accuracy
Class 2 (win):  0.769    0.832   0.799     0.726


In [15]:
# average metrics
avg_p = (precision0 + precision1 + precision2)/3.0
avg_r = (recall0 + recall1 + recall2) / 3.0
avg_f = (f0 + f1 + f2) / 3.0
avg_a = (acc0 + acc1 + acc2)/ 3.0
print("        Precision Recall F-measure Accuracy")
print("Average: ", round(avg_p, 3), "  ", round(avg_r, 3), \
      " ", round(avg_f, 3), "   ", round(avg_a,3))

        Precision Recall F-measure Accuracy
Average:  0.482    0.478   0.471     0.794


In [16]:
# training time
total_time_taken = dt_end - dt_start + max_time
print("Training Time: %s sec" % round(total_time_taken, 5))

Training Time: 15.54949 sec


In [17]:
# Number of instances correctly classified
print("Accuracy_score: ", round(acc_score, 5))

Accuracy_score:  0.6905
