Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
59 lines (51 sloc) 2.38 KB
#*******************************************************************************
# Copyright 2014-2019 Intel Corporation
# All Rights Reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License"), the following terms apply:
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#*******************************************************************************
# daal4py Decision Forest Regression Tree Traversal example
import math
import daal4py as d4p
from numpy import loadtxt, allclose
from decision_forest_regression_batch import main as df_regression
def printTree(nodes, values):
    """Print every node of one tree, depth-first, starting from node 0.

    Parameters
    ----------
    nodes : indexable sequence of records
        Each record is indexed by the keys "feature", "threshold",
        "left_child" and "right_child".  A child index of -1 means the child
        is absent.  A node whose "threshold" is NaN is printed as a leaf.
        (Presumably this is the ``node_ar`` returned by
        ``daal4py.getTreeState`` -- confirm against the daal4py docs.)
    values : indexable sequence
        Per-node values indexed by the same node ids as ``nodes``.  A leaf's
        value is printed via ``str()`` with all spaces removed.

    Output is one line per node, indented by one space per depth level.
    Nodes are visited in pre-order: node, then its left subtree, then its
    right subtree.

    The traversal uses an explicit stack instead of recursion, so very deep
    trees do not hit Python's recursion limit.  An empty ``nodes`` prints
    nothing.
    """
    # len() rather than truthiness: ``nodes`` may be a numpy array, whose
    # truth value is ambiguous.
    if len(nodes) == 0:
        return
    pending = [(0, 0)]  # (node_id, depth); node 0 is the root
    while pending:
        node_id, level = pending.pop()
        node = nodes[node_id]
        value = values[node_id]
        indent = " " * level
        if not math.isnan(node["threshold"]):
            print(indent + "Level " + str(level) + ": Feature = " + str(node["feature"]) + ", Threshold = " + str(node["threshold"]))
        else:
            print(indent + "Level " + str(level) + ", Value = " + str(value).replace(" ", ""))
        # The stack is LIFO: push the right child first so that the left
        # subtree is fully printed before the right one (pre-order).
        if node["right_child"] != -1:
            pending.append((node["right_child"], level + 1))
        if node["left_child"] != -1:
            pending.append((node["left_child"], level + 1))
if __name__ == "__main__":
    # Tree-state extraction is gated on the linked DAAL version.  The first
    # four characters of the link-version string are compared as the year
    # and the next four as the update number.
    from daal4py import __daal_link_version__ as dv
    daal_version = (int(dv[0:4]), int(dv[4:8]))
    if daal_version >= (2019, 1):
        # Train via the regression example and keep only its training result.
        (train_result, _, _) = df_regression()
        model = train_result.model
        tree_count = model.NumberOfTrees
        # Retrieve and print every tree; arrays are encoded as in
        # sklearn.ensemble's tree_.Tree.
        for tree_index in range(tree_count):
            state = d4p.getTreeState(model, tree_index)
            printTree(state.node_ar, state.value_ar)
        print('Traversed {} trees.'.format(tree_count))
        print('All looks good!')
    else:
        print("Need Intel(R) DAAL 2019.1 or later")
You can’t perform that action at this time.