diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0e31c7ed58df8..deac390130128 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -177,9 +177,10 @@ object GradientBoostedTrees extends Logging { treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) - } + true + } else false timer.stop("init") @@ -265,6 +266,9 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + + if (persistedInput) input.unpersist() + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo,