Skip to content
Fetching contributors…
Cannot retrieve contributors at this time
59 lines (51 sloc) 2.38 KB
 #******************************************************************************* # Copyright 2014-2019 Intel Corporation # All Rights Reserved. # # This software is licensed under the Apache License, Version 2.0 (the # "License"), the following terms apply: # # You may not use this file except in compliance with the License. You may # obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # # See the License for the specific language governing permissions and # limitations under the License. #******************************************************************************* # daal4py Decision Forest Regression Tree Traversal example import math import daal4py as d4p from numpy import loadtxt, allclose from decision_forest_regression_batch import main as df_regression def printTree(nodes, values): def printNodes(node_id, nodes, values, level): node = nodes[node_id] value = values[node_id] if not math.isnan(node["threshold"]): print(" " * level + "Level " + str(level) + ": Feature = " + str(node["feature"]) + ", Threshold = " + str(node["threshold"])) else: print(" " * level + "Level " + str(level) + ", Value = " + str(value).replace(" ", "")) if node["left_child"] != -1: printNodes(node["left_child"], nodes, values, level + 1) if node["right_child"] != -1: printNodes(node["right_child"], nodes, values, level + 1) return printNodes(0, nodes, values, 0) return if __name__ == "__main__": from daal4py import __daal_link_version__ as dv daal_version = tuple(map(int, (dv[0:4], dv[4:8]))) if daal_version < (2019, 1): print("Need Intel(R) DAAL 2019.1 or later") else: # First get our result and model (train_result, _, _) = df_regression() # Retrieve and print all trees; encoded as in sklearn.ensamble.tree_.Tree for treeId in range(train_result.model.NumberOfTrees): treeState = d4p.getTreeState(train_result.model, treeId) printTree(treeState.node_ar, treeState.value_ar) print('Traversed {} trees.'.format(train_result.model.NumberOfTrees)) print('All looks good!')
You can’t perform that action at this time.