From b855f177fe4a017d167c43de1b1606cd891a58ef Mon Sep 17 00:00:00 2001 From: Sachin Goel Date: Sat, 16 May 2015 19:11:42 +0530 Subject: [PATCH 1/5] Histogram implementation done with tests --- .../org/apache/flink/ml/math/Histogram.scala | 82 +++++ .../flink/ml/math/OnlineHistogram.scala | 287 ++++++++++++++++++ .../flink/ml/math/OnlineHistogramSuite.scala | 156 ++++++++++ 3 files changed, 525 insertions(+) create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala create mode 100644 flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala new file mode 100644 index 0000000000000..360c1119023c5 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.ml.math
+
+/** Base trait for histograms made of bins, where bin `i` holds a value `v_i` and a
+  * counter `m_i`.
+  *
+  */
+trait Histogram {
+  /** Number of bins in the histogram
+    *
+    * @return number of bins
+    */
+  def bins: Int
+
+  /** Returns the lower limit on values of the histogram
+    *
+    * @return lower limit on values
+    */
+  def lower: Double
+
+  /** Returns the upper limit on values of the histogram
+    *
+    * @return upper limit on values
+    */
+  def upper: Double
+
+  /** Bin value access function
+    *
+    * @param bin bin number to access
+    * @return `v_bin` = value of bin
+    */
+  def getValue(bin: Int): Double
+
+  /** Bin counter access function
+    *
+    * @param bin bin number to access
+    * @return `m_bin` = counter of bin
+    */
+  def getCounter(bin: Int): Int
+
+  /** Adds a new instance to this histogram and appropriately updates counters and values
+    *
+    * @param value value to be added
+    */
+  def add(value: Double): Unit
+
+  /** Returns the estimated number of points in the interval `(-\infty,b]`
+    *
+    * @param b upper end of the interval
+    * @return Number of values in the interval `(-\infty,b]`
+    */
+  def sum(b: Double): Int
+
+  /** Merges the histogram with h and returns a histogram with B bins
+    *
+    * @param h histogram to be merged
+    * @param B final size of the merged histogram
+    * @return the merged histogram
+    */
+  def merge(h: Histogram, B: Int): Histogram
+
+  /** Returns a list `u_1,u_2,\ldots,u_{B-1}` such that the number of points in each of
+    * `(-\infty,u_1],[u_1,u_2],\ldots,[u_{B-1},\infty)` is `\frac{1}{B} \sum_{i=0}^{bins-1} m_i`.
+    *
+    * @param B number of intervals required
+    * @return the `B - 1` interval boundaries `u_1,\ldots,u_{B-1}`
+    */
+  def uniform(B: Int): Array[Double]
+}
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala
new file mode 100644
index 0000000000000..eaae05e3d6fe5
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.apache.flink.api.java.tuple.Tuple2
+
+import java.util
+
+/** Histogram with a bounded number of bins that is updated one value at a time.
+  *
+  * Every bin is a (value, counter) pair and the bins are kept in strictly ascending order of
+  * value. Whenever there are more than `capacity` bins, the two bins whose values are closest
+  * are replaced by a single bin holding their counter-weighted mean value and summed counter.
+  *
+  * The list passed as `data` is used as the backing store and is modified in place, both by
+  * the constructor (when it holds more than `capacity` bins) and by [[add]].
+  *
+  * @param capacity maximum number of bins, must be positive
+  * @param min lower limit on values, must be strictly smaller than `max` and every bin value
+  * @param max upper limit on values, must be strictly larger than `min` and every bin value
+  * @param data initial bins, strictly ascending by value and with positive counters
+  */
+class OnlineHistogram(
+    capacity: Int,
+    // Double.MIN_VALUE is the smallest positive double, not the most negative one. Using it
+    // for both limits made `lower >= upper` and every default-constructed histogram invalid.
+    min: Double = -java.lang.Double.MAX_VALUE,
+    max: Double = java.lang.Double.MAX_VALUE,
+    data: util.ArrayList[Tuple2[Double, Int]] = new util.ArrayList[Tuple2[Double, Int]]
+  ) extends Histogram with Serializable {
+  require(checkSanity, "Invalid data provided")
+
+  /** Number of bins in the histogram
+    *
+    * @return number of bins
+    */
+  def bins: Int = data.size()
+
+  /** Returns the lower limit on values of the histogram
+    *
+    * @return lower limit on values
+    */
+  def lower: Double = min
+
+  /** Returns the upper limit on values of the histogram
+    *
+    * @return upper limit on values
+    */
+  def upper: Double = max
+
+  /** Bin value access function
+    *
+    * @param bin bin number to access, must be in `[0, bins)`
+    * @return `v_bin` = value of bin
+    */
+  def getValue(bin: Int): Double = {
+    require(0 <= bin && bin < bins, s"$bin not in [0, $bins)")
+    data.get(bin).f0
+  }
+
+  /** Bin counter access function
+    *
+    * @param bin bin number to access, must be in `[0, bins)`
+    * @return `m_bin` = counter of bin
+    */
+  def getCounter(bin: Int): Int = {
+    require(0 <= bin && bin < bins, s"$bin not in [0, $bins)")
+    data.get(bin).f1
+  }
+
+  /** Adds a new instance to this histogram and appropriately updates counters and values
+    *
+    * A value equal to an existing bin value increments that bin's counter. Any other value
+    * becomes a new bin with counter one, after which the two closest bins are merged if the
+    * capacity is exceeded.
+    *
+    * @param value value to be added, must lie strictly between `lower` and `upper`
+    */
+  def add(value: Double): Unit = {
+    require(value > lower && value < upper, s"$value not in ($lower,$upper)")
+    val position = find(value)
+    if (position > 0 && getValue(position - 1) == value) {
+      // Exact hit: inserting a second bin with the same value would break the strict ordering
+      // of the bins and lead to zero-width intervals in sum and uniform.
+      set(position - 1, value, getCounter(position - 1) + 1)
+    } else {
+      data.add(position, new Tuple2[Double, Int](value, 1))
+      if (bins > capacity) mergeElements()
+    }
+  }
+
+  /** Returns the estimated number of points in the interval `(-\infty,b]`
+    *
+    * The estimate interpolates linearly between the counters of the two bins surrounding `b`.
+    * The limits `lower` and `upper` act as bins with counter zero.
+    *
+    * @param b upper end of the interval
+    * @return Number of values in the interval `(-\infty,b]`
+    */
+  def sum(b: Double): Int = {
+    require(bins > 0, "Histogram is empty")
+    if (b < lower) {
+      0
+    } else if (b > upper) {
+      sum(upper)
+    } else {
+      val index = find(b) - 1
+      if (index == -1) {
+        // b lies between the lower limit and the first bin
+        val width = getValue(0) - lower
+        val mb = getCounter(0) * (b - lower) / width
+        (mb * (b - lower) / (2 * width)).toInt
+      } else {
+        val value = getValue(index)
+        val counter = getCounter(index)
+        // right neighbour of b: the next bin, or the upper limit with counter zero
+        val isLast = index == bins - 1
+        val nextValue = if (isLast) upper else getValue(index + 1)
+        val nextCounter = if (isLast) 0 else getCounter(index + 1)
+        val mb = counter + (nextCounter - counter) * (b - value) / (nextValue - value)
+        val partial = (counter + mb) * (b - value) / (2 * (nextValue - value))
+        val withLowerBins = (0 until index).foldLeft(partial)((acc, i) => acc + getCounter(i))
+        // NOTE(review): `counter / 2` is an integer division, so half a point is dropped for
+        // odd counters. OnlineHistogramSuite pins the resulting values, hence kept as is.
+        (withLowerBins + counter / 2).toInt
+      }
+    }
+  }
+
+  /** Merges the histogram with h and returns a histogram with B bins
+    *
+    * Bins of both histograms are combined in ascending order of value. Bins with equal values
+    * are collapsed into one bin with the summed counter, because the resulting histogram
+    * requires strictly ascending values. Neither this histogram nor `h` is modified.
+    *
+    * @param h histogram to be merged
+    * @param B final size of the merged histogram
+    * @return the merged histogram, limited to the union of both value ranges
+    */
+  def merge(h: Histogram, B: Int): Histogram = {
+    val m = bins
+    val n = h.bins
+    val merged = new util.ArrayList[Tuple2[Double, Int]](m + n)
+    var i, j = 0
+    while (i < m || j < n) {
+      val fromThis = j >= n || (i < m && getValue(i) <= h.getValue(j))
+      val value = if (fromThis) getValue(i) else h.getValue(j)
+      val counter = if (fromThis) getCounter(i) else h.getCounter(j)
+      if (fromThis) i += 1 else j += 1
+      val last = merged.size() - 1
+      if (last >= 0 && merged.get(last).f0 == value) {
+        merged.set(last, new Tuple2[Double, Int](value, merged.get(last).f1 + counter))
+      } else {
+        merged.add(new Tuple2[Double, Int](value, counter))
+      }
+    }
+    new OnlineHistogram(B, Math.min(lower, h.lower), Math.max(upper, h.upper), merged)
+  }
+
+  /** Returns a list `u_1,u_2,\ldots,u_{B-1}` such that the number of points in each of
+    * `(-\infty,u_1],[u_1,u_2],\ldots,[u_{B-1},\infty)` is `\frac{1}{B} \sum_{i=0}^{bins-1} m_i`.
+    *
+    * @param B number of intervals required, must be at least two
+    * @return the `B - 1` interval boundaries
+    */
+  def uniform(B: Int): Array[Double] = {
+    require(bins > 0, "Histogram is empty")
+    require(B > 1, "Cannot equalize in less than two intervals")
+    val total = sum(upper)
+    val ret = new Array[Double](B - 1)
+    for (j <- 0 to B - 2) {
+      val s = (j + 1) * total / B
+      val search = searchSum(s)
+      val i: Int = search.f1
+      val d: Int = s - search.f0
+      // Interval containing the boundary: (lower, v_0), (v_i, v_{i+1}) or (v_last, upper).
+      // The limits count as bins with counter zero.
+      val (leftValue, leftCounter, rightValue, rightCounter) =
+        if (i == -1) {
+          (lower, 0, getValue(0), getCounter(0))
+        } else if (i == bins - 1) {
+          (getValue(i), getCounter(i), upper, 0)
+        } else {
+          (getValue(i), getCounter(i), getValue(i + 1), getCounter(i + 1))
+        }
+      val z = solveQuadratic(rightCounter - leftCounter, 2 * leftCounter, -2 * d)
+      ret(j) = leftValue + (rightValue - leftValue) * z
+    }
+    ret
+  }
+
+  /** Returns the string representation of the histogram.
+    *
+    */
+  override def toString: String = s"Size:$bins $data"
+
+  /** Solves `a z^2 + b z + c = 0` for the root `(-b + sqrt(b^2 - 4ac)) / (2a)`.
+    *
+    * For `a == 0` (equal counters on both sides of the interval) the equation is linear and
+    * the quadratic formula would divide by zero, so `-c / b` is returned instead. A slightly
+    * negative discriminant caused by rounding is treated as zero.
+    */
+  private def solveQuadratic(a: Double, b: Double, c: Double): Double = {
+    if (a == 0) {
+      if (b == 0) 0 else -c / b
+    } else {
+      val discriminant = Math.max(b * b - 4 * a * c, 0.0)
+      (-b + Math.sqrt(discriminant)) / (2 * a)
+    }
+  }
+
+  /** Updates the given bin with the provided value and counter. Sets `v_bin`=value and
+    * `m_bin`=counter
+    *
+    * @param bin bin to be updated
+    * @param value value to be set at bin
+    * @param counter counter to be set at bin
+    */
+  private def set(bin: Int, value: Double, counter: Int): Unit = {
+    require(0 <= bin && bin < bins, s"$bin not in [0, $bins)")
+    require(value > lower && value < upper, s"$value not in ($lower,$upper)")
+    data.set(bin, new Tuple2[Double, Int](value, counter))
+  }
+
+  /** Searches for the first index i such that `sum(v_i) <= s < sum(v_{i+1})`. For the last
+    * bin only `sum(v_i) <= s` is required.
+    *
+    * @param s target sum
+    * @return a tuple of `sum(v_i)` and index i, or `(0, -1)` if no bin qualifies
+    */
+  private def searchSum(s: Double): Tuple2[Int, Int] = {
+    val size = bins
+    val sums = (0 until size).map(i => sum(getValue(i)))
+    (0 until size).find(i => s >= sums(i) && (i + 1 >= size || s < sums(i + 1))) match {
+      case Some(i) => new Tuple2[Int, Int](sums(i), i)
+      case None => new Tuple2[Int, Int](0, -1)
+    }
+  }
+
+  /** Searches for value in the histogram
+    *
+    * @param p value to search for
+    * @return the bin with value just greater than p. If `p >= v_{bins-1}`, return bins
+    */
+  private def find(p: Double): Int = {
+    val size = bins
+    (0 until size)
+      .find(i => p >= getValue(i) && (i + 1 >= size || p < getValue(i + 1)))
+      .map(_ + 1)
+      .getOrElse(0)
+  }
+
+  /** Merges the closest two elements in the histogram. On ties the pair with the lowest
+    * index is merged. Must only be called when there are at least two bins.
+    *
+    */
+  private def mergeElements(): Unit = {
+    var index = -1
+    var smallestGap = java.lang.Double.MAX_VALUE
+    for (i <- 0 to bins - 2) {
+      val gap = getValue(i + 1) - getValue(i)
+      if (gap < smallestGap) {
+        smallestGap = gap
+        index = i
+      }
+    }
+    val merged = mergeBins(index)
+    set(index, merged.f0 / merged.f1, merged.f1)
+    data.remove(index + 1)
+  }
+
+  /** Returns the merging of the bin b and its next bin
+    *
+    * @param b the bin to be merged with bin b+1
+    * @return the tuple (`v_b.m_b + v_{b+1}.m_{b+1}`,`m_b+m_{b+1}`)
+    */
+  private def mergeBins(b: Int): Tuple2[Double, Int] = {
+    new Tuple2[Double, Int](
+      getValue(b + 1) * getCounter(b + 1) + getValue(b) * getCounter(b),
+      getCounter(b + 1) + getCounter(b))
+  }
+
+  /** Checks the constructor arguments and, if they are valid, merges bins until at most
+    * `capacity` are left.
+    *
+    * Valid means: `lower < upper`, positive capacity, all bin values strictly between the
+    * limits and strictly ascending, and all counters positive.
+    *
+    * @return true if the arguments are valid
+    */
+  private def checkSanity: Boolean = {
+    val valid = lower < upper && capacity > 0 && (data.isEmpty || (
+      lower < getValue(0) &&
+      getValue(bins - 1) < upper &&
+      (0 until bins - 1).forall(i => getValue(i) < getValue(i + 1)) &&
+      (0 until bins).forall(i => getCounter(i) > 0)))
+    if (valid) {
+      while (bins > capacity) {
+        mergeElements()
+      }
+    }
+    valid
+  }
+}
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala
new file mode 100644
index 0000000000000..924e88be6e933
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import java.util
+import org.apache.flink.api.java.tuple.Tuple2
+
+import org.scalatest.{Matchers, FlatSpec}
+
+/** Tests for OnlineHistogram: argument validation, bin access, sum, merge and uniform. */
+class OnlineHistogramSuite extends FlatSpec with Matchers {
+
+  behavior of "Flink's OnlineHistogram"
+
+  /** Builds a fresh bin list for the OnlineHistogram constructor from (value, counter) pairs.
+    * A new list per histogram matters because the constructor merges bins in place.
+    */
+  private def binList(entries: (Double, Int)*): util.ArrayList[Tuple2[Double, Int]] = {
+    val list = new util.ArrayList[Tuple2[Double, Int]]()
+    entries.foreach { case (value, counter) => list.add(new Tuple2[Double, Int](value, counter)) }
+    list
+  }
+
+  /** The four bins (1,3), (3,7), (7,1), (10,2) shared by most tests */
+  private def sampleBins: util.ArrayList[Tuple2[Double, Int]] = {
+    binList((1.0, 3), (3.0, 7), (7.0, 1), (10.0, 2))
+  }
+
+  it should "fail if capacity is non-positive" in {
+    intercept[IllegalArgumentException] {
+      // explicit valid limits, so that the capacity is the only invalid argument
+      new OnlineHistogram(0, 0, 1)
+    }
+  }
+
+  it should "fail if min>=max" in {
+    intercept[IllegalArgumentException] {
+      new OnlineHistogram(2, 1, 1)
+    }
+  }
+
+  it should "fail if list isn't sorted" in {
+    intercept[IllegalArgumentException] {
+      new OnlineHistogram(2, 1, 5, binList((4.0, 2), (2.0, 2)))
+    }
+  }
+
+  it should "fail if the list has any zero counters" in {
+    intercept[IllegalArgumentException] {
+      // The upper limit lies above the last value. With an upper limit of 5 the histogram
+      // was rejected for touching the limit and the zero counter was never looked at.
+      new OnlineHistogram(2, 1, 10, binList((4.0, 2), (5.0, 0)))
+    }
+  }
+
+  it should "succeed if the data is okay and access proper parameters, values and counters" in {
+    var h = new OnlineHistogram(2, 1, 10, binList())
+    h.bins should equal(0)
+
+    h = new OnlineHistogram(2, 1, 10, binList((4.0, 2), (5.0, 6)))
+    h.bins should equal(2)
+    h.getValue(0) should equal(4)
+    h.getCounter(1) should equal(6)
+
+    // four bins with capacity three: the closest pair (1,3), (3,7) is merged into (2.4,10)
+    h = new OnlineHistogram(3, 0, 11, sampleBins)
+    h.bins should equal(3)
+    h.getValue(0) should equal(2.4)
+    h.getCounter(1) should equal(1)
+    h.getCounter(2) should equal(2)
+    h.add(5)
+    h.getValue(1) should equal(6)
+
+    h = new OnlineHistogram(2, 0, 11, binList())
+    h.add(1)
+    h.add(2)
+    h.add(3)
+    h.add(5)
+    h.add(1)
+    h.bins should equal(2)
+    h.getValue(1) should equal(5)
+    h.getCounter(0) should equal(4)
+  }
+
+  it should "fail in accessing out of bound values and counters" in {
+    val h = new OnlineHistogram(3, 0, 11, sampleBins)
+    intercept[IllegalArgumentException] {
+      h.getValue(3)
+    }
+    intercept[IllegalArgumentException] {
+      h.getCounter(3)
+    }
+  }
+
+  it should "succeed in computing sum values" in {
+    val h = new OnlineHistogram(4, 0, 11, sampleBins)
+
+    h.sum(1) should equal(1)
+    h.sum(3) should equal(6)
+    h.sum(11) should equal(13)
+    h.sum(4) should equal(7)
+  }
+
+  it should "succeed in merging two histograms" in {
+    val h = new OnlineHistogram(3, 0, 11, sampleBins)
+    val h1 = new OnlineHistogram(2, 1, 10, binList((4.0, 2), (5.0, 6)))
+
+    val h3 = h1.merge(h, 3)
+    h3.getValue(1) should equal(5)
+    h3.getCounter(2) should equal(2)
+  }
+
+  it should "succeed in generating an equalization list" in {
+    val h = new OnlineHistogram(3, 0, 11, sampleBins)
+    val eqArr = h.uniform(3)
+    h.sum(eqArr(0)) should equal(4)
+    h.sum(eqArr(1)) should equal(8)
+  }
+}
From c64ea0452be867e5fcb0bcb3f8401dbc74ce8fa6 Mon Sep 17 00:00:00 2001
From: Sachin Goel
Date: Thu, 21 May 2015 19:15:24 +0530
Subject: [PATCH 2/5] Decision tree implemented. For continuous data. Only Gini. Tested on Iris

---
 .../ml/classification/DecisionTree.scala | 603 ++++++++++++++++++
 .../org/apache/flink/ml/math/Histogram.scala | 18 +-
 .../flink/ml/math/OnlineHistogram.scala | 177 +++--
 .../org/apache/flink/ml/tree/FieldStats.scala | 41 ++
 .../scala/org/apache/flink/ml/tree/Node.scala | 46 ++
 .../org/apache/flink/ml/tree/SplitValue.scala | 58 ++
 .../scala/org/apache/flink/ml/tree/Tree.scala | 58 ++
 .../flink/ml/tree/TreeConfiguration.scala | 75 +++
 .../ml/classification/Classification.scala | 158 +++++
 .../ml/classification/DecisionTreeSuite.scala | 42 ++
 .../flink/ml/math/OnlineHistogramSuite.scala | 35 +-
 11 files changed, 1183 insertions(+), 128 deletions(-)
 create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala
 create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala
 create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala
 create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala
 create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala
 create mode 100644 flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala
 create mode 100644 flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala new file mode 100644 index 0000000000000..b6c459e76e053 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala @@ -0,0 +1,603 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification + +import java.util + +import org.apache.flink.api.common.functions.RichMapFunction +import org.apache.flink.api.scala._ +import org.apache.flink.configuration.Configuration +import org.apache.flink.ml.common.FlinkTools.ModuloKeyPartitioner +import org.apache.flink.ml.common._ +import org.apache.flink.ml.math.{Histogram, OnlineHistogram, Vector} +import org.apache.flink.ml.tree._ + +import scala.collection.mutable + +/** Companion object of Decision Tree. Contains convenience functions and the parameter type definitions + * of the algorithm. 
+  *
+  */
+object DecisionTree {
+  // name under which the current tree is read as a broadcast variable by the histogram mappers
+  val DECISION_TREE = "decision_tree"
+  // NOTE(review): presumably the broadcast name of the tree configuration - no use is visible
+  // in this part of the file, confirm against the callers
+  val DECISION_TREE_CONFIG = "decision_tree_configuration"
+
+  /** Creates a decision tree learner with default parameters */
+  def apply(): DecisionTree = new DecisionTree()
+
+  /** Maximum allowed depth of the tree, 30 unless set */
+  case object Depth extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(30)
+  }
+
+  /** Splitting strategy, "Gini" unless set */
+  case object SplitStrategy extends Parameter[String] {
+    val defaultValue: Option[String] = Some("Gini")
+  }
+
+  /** Minimum number of instances per node, 1 unless set */
+  case object MinInstancesPerNode extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(1)
+  }
+
+  /** Whether to prune the tree after building, false unless set */
+  case object Pruning extends Parameter[Boolean] {
+    val defaultValue: Option[Boolean] = Some(false)
+  }
+
+  /** Maximum number of bins used for calculating splits, 100 unless set */
+  case object MaxBins extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(100)
+  }
+
+  /** Dimension of the data, 2 unless set */
+  case object Dimension extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(2)
+  }
+
+  /** Indices of the fields considered categorical, empty unless set */
+  case object Category extends Parameter[Array[Int]] {
+    val defaultValue: Option[Array[Int]] = Some(Array.empty[Int])
+  }
+
+  /** Number of classes in the data, 2 unless set */
+  case object Classes extends Parameter[Int] {
+    val defaultValue: Option[Int] = Some(2)
+  }
+
+}
+
+class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serializable {
+
+  import DecisionTree._
+
+  /** Sets the maximum allowed depth of the tree.
+    * Currently only allowed values up to 30
+    *
+    * @param depth maximum depth of the tree, between 0 and 30
+    * @return itself
+    */
+  def setDepth(depth: Int): DecisionTree = {
+    // A negative depth never equals the depth of any node, which would silently switch off
+    // the depth limit that the upper bound is there to enforce.
+    require(depth >= 0, "Depth cannot be negative")
+    require(depth <= 30, "Maximum depth allowed: 30")
+    parameters.add(Depth, depth)
+    this
+  }
+
+  /** Sets minimum number of instances that must be present at a node for its parent to split
+    *
+    * @param minInstancesPerNode minimum number of instances per node, at least one
+    * @return itself
+    */
+  def setMinInstancePerNode(minInstancesPerNode: Int): DecisionTree = {
+    require(minInstancesPerNode >= 1,
+      "Every node must have at least one instance associated with it")
+    parameters.add(MinInstancesPerNode, minInstancesPerNode)
+    this
+  }
+
+  /** Sets whether or not to prune the tree after building
+    *
+    * @param prune true if the tree is to be pruned
+    * @return itself
+    */
+  def setPruning(prune: Boolean): DecisionTree = {
+    parameters.add(Pruning, prune)
+    this
+  }
+
+  /** Sets maximum number of bins to be used for calculating splits.
+    *
+    * @param maxBins maximum number of bins, at least one
+    * @return itself
+    */
+  def setMaxBins(maxBins: Int): DecisionTree = {
+    require(maxBins >= 1, "Maximum bins used must be at least one")
+    parameters.add(MaxBins, maxBins)
+    this
+  }
+
+  /** Sets the splitting strategy. Gini and Entropy supported.
+    *
+    * @param splitStrategy either "Gini" or "Entropy"
+    * @return itself
+    */
+  def setSplitStrategy(splitStrategy: String): DecisionTree = {
+    require(splitStrategy == "Gini" || splitStrategy == "Entropy",
+      "Algorithm " + splitStrategy + " not supported")
+    parameters.add(SplitStrategy, splitStrategy)
+    this
+  }
+
+  /** Sets the dimension of data. Will be cross checked with the data later
+    *
+    * @param dimension dimension of the data, at least one
+    * @return itself
+    */
+  def setDimension(dimension: Int): DecisionTree = {
+    require(dimension >= 1, "Dimension cannot be less than one")
+    parameters.add(Dimension, dimension)
+    this
+  }
+
+  /** Sets which fields are to be considered categorical.
Array of field indices + * + * *@param category + * @return itself + */ + def setCategory(category: Array[Int]): DecisionTree = { + parameters.add(Category, category) + this + } + + /** Sets how many classes there are in the data [will be cross checked with the data later] + * + * *@param numClasses + * @return itself + */ + def setClasses(numClasses: Int): DecisionTree = { + require(numClasses > 1, "There must be at least two classes in the data") + parameters.add(Classes, numClasses) + this + } + + + /** Trains a Decision Tree + * + * @param input Training data set + * @param fitParameters Parameter values + * @return Trained Decision Tree Model + */ + override def fit(input: DataSet[LabeledVector], fitParameters: ParameterMap): DecisionTreeModel = { + val resultingParameters = this.parameters ++ fitParameters + val depth = resultingParameters(Depth) + val minInstancePerNode = resultingParameters(MinInstancesPerNode) + val pruneTree = resultingParameters(Pruning) + val maxBins = resultingParameters(MaxBins) + val splitStrategy = resultingParameters(SplitStrategy) + val dimension = resultingParameters(Dimension) + val category = resultingParameters(Category) + val numClasses = resultingParameters(Classes) + + require(category.length == 0, "Only continuous fields supported right now") + require(splitStrategy == "Gini", "Only Gini algorithm implemented right now") + var tree = createInitialTree(input, new TreeConfiguration(maxBins, minInstancePerNode, depth, + pruneTree, splitStrategy, numClasses, dimension, category)) + + var treeCopy = tree.collect().toArray.apply(0) + + val numberVectors = tree.getExecutionEnvironment.fromElements(treeCopy.config.numTrainVector) + + tree = input.getExecutionEnvironment.fromElements(treeCopy) // remake the tree as a dataset + + // Group the input data into blocks in round robin fashion + val blockedInputNumberElements = FlinkTools.block(input, input.getParallelism, Some(ModuloKeyPartitioner)). + cross(numberVectors). 
+ map { x => x } + + var any_unlabeled_left = true + while (any_unlabeled_left) { + + // next iteration will only happen if we happen to split an unlabeled leaf in this iteration + any_unlabeled_left = false + + // histograms from each node with key (tree_node, dimension, label) + val localHists = localHistUpdate( + tree, + blockedInputNumberElements + ) + + // merge histograms over key + val combinedHists = localHists.reduce( + (a, b) => { + b.iterator.foreach( + x => { + a.get(x._1) match { + case Some(hist) => a.put(x._1, a.get(x._1).get.merge(x._2, maxBins + 1)) + case None => a.put(x._1, x._2) + } + } + ) + a + } + ) + // now collect the tree. We're gonna need direct access to it. + treeCopy = tree.collect().toArray.apply(0) + + val finalHists = combinedHists.collect().toArray.apply(0) + val fieldStats = treeCopy.config.fieldStats + val labels = treeCopy.config.labels + val labelAddition = treeCopy.config.labelAddition + + // now, find splits across each dimension by merging (tree_node, dimension, label_1),(tree_node, dimension, label_2)... 
+ + // keep a list of all the nodes we'll have to find splits for later + // we only get stuff for leafs which are unlabeled, so nodes only has those which are unlabeled leafs + + val nodeDimensionSplits = new mutable.HashMap[(Int, Int), Array[Double]] + val nodes = new mutable.HashMap[Int, Int] + + finalHists.keysIterator.foreach( + x => { + nodes.put(x._1, -1) + if (nodeDimensionSplits.get((x._1, x._2)).isEmpty) { + val min_value = fieldStats.apply(x._2).fieldMinValue + val max_value = fieldStats.apply(x._2).fieldMaxValue + var tmp_hist: Histogram = new OnlineHistogram(maxBins, 2 * min_value - max_value, 2 * max_value - min_value) + labels.iterator.foreach( + c => { + if (finalHists.get((x._1, x._2, c)).nonEmpty) { + tmp_hist = tmp_hist.merge(finalHists.get((x._1, x._2, c)).get, maxBins) + } + } + ) + // find maxBins quantiles, or if the size of histogram is less than that, just find those many + val actualBins = Math.min(maxBins, tmp_hist.bins - 1) + val quantileArray = new Array[Double](actualBins) + for (i <- 1 to actualBins) { + quantileArray(i - 1) = tmp_hist.quantile((i + 0.0) / (actualBins + 1)) + } + nodeDimensionSplits.put((x._1, x._2), quantileArray) + } + } + ) + // first, every unlabeled leaf node must have sent something. If not, its sibling will be stuck forever + + treeCopy.nodes.valuesIterator.foreach( + this_node => { + if (this_node.predict == -1 && this_node.split.isEmpty && nodes.get(this_node.id).isEmpty) { + // we're in trouble + var sibling_id = 0 + if (this_node.id % 2 == 0) + sibling_id = this_node.id + 1 + else + sibling_id = this_node.id - 1 + treeCopy.nodes.remove(this_node.id) // this node is pointless. Remove it from the tree + // we're not going to split the sibling anymore. 
+ nodes.put(sibling_id, 1) // we'll check for this '1' later + } + } + ) + // now, for each node, for each dimension, evaluate splits based on the above histograms + nodes.keysIterator.foreach( + node_id => { + var node: Node = treeCopy.nodes.get(node_id).get + // since the count of classes across any dimension is same, pick 0 + // for calculating Gini index, we need count(c)^2 and \sum count(c)^2 + // also maintain which class occurred most frequently in case we need to mark this as a leaf node + var sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount = 0.0 + labels.iterator.foreach( + x => { + val h = finalHists.get((node_id, 0, x)) + if (h.nonEmpty) { + val countOfClass = h.get.sum(h.get.upper) + totalNumPointsHere = totalNumPointsHere + countOfClass + sumClassSquare = sumClassSquare + countOfClass * countOfClass + if (countOfClass > maxClassCount) { + maxClassCount = countOfClass + maxClassCountLabel = x + } + } + } + ) + + // now see if this node has become pure or if it's depth is at a maximum + if (totalNumPointsHere * totalNumPointsHere == sumClassSquare || node.getDepth == depth) { + node.predict = maxClassCountLabel + labelAddition + // sanity check. If we're declaring something a leaf, it better not have any children + require(node.split.isEmpty, "An unexpected error occurred") + } else if (nodes.get(node_id).get == 1) { + // this node is meaningless. The parent didn't do anything at all. 
+ // just remove this node from the tree and set the label of parent, which would be same as the label of this + node = treeCopy.nodes.get(node_id / 2).get + node.predict = maxClassCountLabel + labelAddition + node.split = None + treeCopy.nodes.remove(node_id) + } else { + val giniParent = 1 - sumClassSquare / Math.pow(totalNumPointsHere, 2) + var best_gini = -java.lang.Double.MAX_VALUE + var best_dimension = -1 + var best_split_value, best_left_total, best_right_total = 0.0 + // consider all splits across all dimensions and pick the best one + for (j <- 0 to dimension - 1) { + val splitsArray = nodeDimensionSplits.get((node_id, j)).get + val actualSplits = splitsArray.length + for (k <- 0 to actualSplits - 1) { + // maintain how many instances go left and right for Gini + var total_left, total_right, left_sum_sqr, right_sum_sqr = 0.0 + for (l <- 0 to numClasses - 1) { + val h = finalHists.get((node_id, j, labels.apply(l))) + if (h.nonEmpty) { + val left = h.get.sum(splitsArray.apply(k)) + val right = h.get.sum(h.get.upper) - left + total_left = total_left + left + total_right = total_right + right + left_sum_sqr = left_sum_sqr + left * left + right_sum_sqr = right_sum_sqr + right * right + } + } + // ensure that the split is allowed by user. We need at least this many instances everywhere + if (total_left >= minInstancePerNode && total_right >= minInstancePerNode) { + // use a balancing term to ensure the tree is balanced more on the top + // the exponential term in scaling makes sure that as we go deeper, Gini becomes more important in splits + // this makes sense. 
We want balanced tree on top and fine-grained splitting at deeper levels + val scaling = Math.pow(0.1, node.getDepth + 1) + val balancing = Math.abs(total_left - total_right) / totalNumPointsHere + val this_gini = (1 - scaling) * (giniParent - total_left * (1 - left_sum_sqr / Math.pow(total_left, 2)) / totalNumPointsHere + - total_right * (1 - right_sum_sqr / Math.pow(total_right, 2)) / totalNumPointsHere) + scaling * balancing + if (this_gini > best_gini) { + best_gini = this_gini + best_dimension = j + best_split_value = nodeDimensionSplits.get((node_id, j)).get.apply(k) + best_left_total = total_left + best_right_total = total_right + } + } + } + } + if (best_dimension != -1) { + node.split = Some(new SplitValue(best_dimension, true, best_split_value)) + treeCopy.nodes.put(node_id * 2, new Node(node_id * 2, treeCopy.treeID, None)) + treeCopy.nodes.put(node_id * 2 + 1, new Node(node_id * 2 + 1, treeCopy.treeID, None)) + any_unlabeled_left = true + } else { + node.predict = maxClassCountLabel + labelAddition + } + } + } + ) + tree = tree.getExecutionEnvironment.fromElements(treeCopy) + } + + DecisionTreeModel(tree) + } + + private def localHistUpdate( + tree: DataSet[Tree], + blockedInputNumberElements: DataSet[(Block[LabeledVector], Int)]) + : DataSet[mutable.HashMap[(Int, Int, Double), Histogram]] = { + + /** Rich mapper calculating histograms for each data block. We use a RichMapFunction here, + * because we broadcast the current value of the tree to all mappers. 
+ * + */ + val localUpdate = new RichMapFunction[(Block[LabeledVector], Int), mutable.HashMap[(Int, Int, Double), Histogram]] { + + var tree: Tree = _ + var config: TreeConfiguration = _ + var histograms: mutable.HashMap[(Int, Int, Double), Histogram] = new mutable.HashMap[(Int, Int, Double), Histogram]() + + override def open(parameters: Configuration): Unit = { + // get the tree + tree = getRuntimeContext.getBroadcastVariable(DECISION_TREE).get(0) + config = tree.config + } + + override def map(blockNumberElements: (Block[LabeledVector], Int)) + : mutable.HashMap[(Int, Int, Double), Histogram] = { + // for all instances in the block + val (block, _) = blockNumberElements + val numLocalDataPoints = block.values.length + for (i <- 0 to numLocalDataPoints - 1) { + val LabeledVector(label, vector) = block.values(i) + // find where this instance goes + val (node, predict) = tree.filter(vector) + if (predict.round.toInt == -1) { + // we can be sure that this is an unlabeled leaf and not some internal node. 
That's how filter works + for (j <- 0 to tree.config.dimension - 1) { + val min_value = config.fieldStats.apply(j).fieldMinValue + val max_value = config.fieldStats.apply(j).fieldMaxValue + histograms.get((node, j, label)) match { + // if this histogram already exists, add a new entry to it + case Some(hist) => hist.add(vector.apply(j)) + case None => + // otherwise, create the histogram and put the entry in it + histograms.put((node, j, label), new OnlineHistogram( + config.MaxBins + 1, 2 * min_value - max_value, 2 * max_value - min_value)) + histograms.get((node, j, label)).get.add(vector.apply(j)) + } + } + } + } + histograms + } + } + // map using the RichMapFunction with the tree broadcasted + blockedInputNumberElements.map(localUpdate).withBroadcastSet(tree, DECISION_TREE) + } + + /** Creates the initial root + * + * @return initial decision tree, after making sanity checks on data and evaluating statistics + */ + private def createInitialTree(input: DataSet[LabeledVector], config: TreeConfiguration): + DataSet[Tree] = { + + val initTree = new RichMapFunction[LabeledVector, (Array[FieldStats], Int)] { + + var config: TreeConfiguration = _ + + override def open(parameters: Configuration): Unit = { + config = getRuntimeContext.getBroadcastVariable(DECISION_TREE_CONFIG).get(0) + } + + override def map(labeledVector: LabeledVector): (Array[FieldStats], Int) = { + require(labeledVector.vector.size == config.dimension, "Specify the dimension of data correctly") + val ret = new Array[FieldStats](config.dimension + 1) // we also include the label field for now + for (i <- 0 to config.dimension - 1) { + if (config.category.indexOf(i) >= 0) { + // if this is a categorical field + val h = new mutable.HashMap[Double, Int]() + h.put(labeledVector.vector.apply(i), 1) + ret(i) = new FieldStats(false, fieldCategories = h) + } else { + // if continuous field + ret(i) = new FieldStats(true, labeledVector.vector.apply(i), labeledVector.vector.apply(i)) + } + } + val h = new 
mutable.HashMap[Double, Int]() + h.put(labeledVector.label, 1) + ret(config.dimension) = new FieldStats(false, fieldCategories = h) // the label field is always categorical + (ret, 1) + } + + } + + // now reduce to merge all fieldStats and to get the total number of points in the data set + val combinedStats_tuple = input.map(initTree).withBroadcastSet( + input.getExecutionEnvironment.fromElements(config), DECISION_TREE_CONFIG + ).reduce( + (x, y) => { + val a = x._1 + val b = y._1 + val c = new Array[FieldStats](a.length) // to hold the merged fieldStats + for (i <- 0 to a.length - 1) { + if (a(i).fieldType) { + val min = Math.min(a(i).fieldMinValue, b(i).fieldMinValue) + val max = Math.max(a(i).fieldMaxValue, b(i).fieldMaxValue) + c(i) = new FieldStats(true, min, max) + } else { + // merge both hashmaps. We don't care about the values, just the keys + b(i).fieldCategories.keysIterator.foreach( + cat => a(i).fieldCategories.put(cat, 1) + ) + c(i) = new FieldStats(false, fieldCategories = a(i).fieldCategories) + } + } + (c, x._2 + y._2) // counting number of instances + } + ).collect().toArray.apply(0) + + val combinedStats = combinedStats_tuple._1 + config.setNumTrainVector(combinedStats_tuple._2) + + // cross check user-specified number of classes + require(combinedStats.apply(config.dimension).fieldCategories.size == config.numClasses, "Specify number of classes correctly") + + // now copy the label list to an array of double + // Find the minimum value, negative of which will be set to labelAddition + var min_label = java.lang.Double.MAX_VALUE + val labels_list = new util.ArrayList[Double](combinedStats.apply(config.dimension).fieldCategories.size) + combinedStats.apply(config.dimension).fieldCategories.keysIterator.foreach( + x => { + if (x < min_label) min_label = x + labels_list.add(x) + } + ) + val labels = new Array[Double](labels_list.size) + for (i <- 0 to labels.length - 1) { + labels(i) = labels_list.get(i) + } + + // the root node of the tree. 
Tree ID set to 1 by default + val h = new collection.mutable.HashMap[Int, Node] + h.put(1, new Node(1, 1, None, -1)) + + config.setFieldStats(combinedStats.slice(0, config.dimension)) + config.setLabelAddition(-min_label) + config.setLabels(labels) + + input.getExecutionEnvironment.fromElements(new Tree(1, h, config)) + } +} + +/** Resulting Tree model calculated by the Decision Tree algorithm. + * + * @param tree the final tree generated by fit which is to be used for predictions + */ +case class DecisionTreeModel(tree: DataSet[Tree]) + extends Transformer[Vector, LabeledVector] + with Serializable { + + import DecisionTree.DECISION_TREE + + /** Calculates the label for the input set using the tree + * + * @param input [[DataSet]] containing the vector for which to calculate the predictions + * @param parameters Parameter values for the algorithm + * @return [[DataSet]] containing the labeled vectors + */ + override def transform(input: DataSet[Vector], parameters: ParameterMap): + DataSet[LabeledVector] = { + input.map(new DecisionTreePredictionMapper).withBroadcastSet(tree, DECISION_TREE) + } + + /** Calculates the accuracy of classification on te test data set input + * + * @param input [[DataSet]] containing the vector for which to calculate the predictions + * @return Percentage accuracy, using a 0-1 loss function + */ + + def testAccuracy(input: DataSet[LabeledVector]): Double = { + + val accuracyMapper = new RichMapFunction[LabeledVector, (Int, Int)] { + var tree: Tree = _ + + override def open(parameters: Configuration): Unit = { + tree = getRuntimeContext.getBroadcastVariable(DECISION_TREE).get(0) + } + + override def map(labeledVector: LabeledVector): (Int, Int) = { + val label = tree.filter(labeledVector.vector)._2 - tree.config.labelAddition + if (label == labeledVector.label) (1, 1) + else (0, 1) + } + } + val result = input.map(accuracyMapper).withBroadcastSet(tree, DECISION_TREE). + reduce((a, b) => (a._1 + b._1, a._2 + b._2)). 
+ collect().toArray.apply(0) + 100 * (result._1 + 0.0) / result._2 + } +} + +/** Mapper to calculate the value of the prediction function. This is a RichMapFunction, because + * we broadcast the tree to all mappers. + * + */ +class DecisionTreePredictionMapper + extends RichMapFunction[Vector, LabeledVector] { + + import DecisionTree.DECISION_TREE + + var tree: Tree = _ + var labelAddition: Double = 0 + + @throws(classOf[Exception]) + override def open(configuration: Configuration): Unit = { + // get the Tree + tree = getRuntimeContext.getBroadcastVariable[Tree](DECISION_TREE).get(0) + labelAddition = tree.config.labelAddition + } + + override def map(vector: Vector): LabeledVector = { + // calculate the predicted label + val label = tree.filter(vector)._2 + LabeledVector(label - labelAddition, vector) + } +} diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala index 360c1119023c5..40b7497a6b6bd 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Histogram.scala @@ -60,12 +60,6 @@ trait Histogram { */ def add(value: Double): Unit - /** Returns the estimated number of points in the interval `(-\infty,b]` - * - * @return Number of values in the interval `(-\infty,b]` - */ - def sum(b: Double): Int - /** Merges the histogram with h and returns a histogram with B bins * * @param h histogram to be merged @@ -73,10 +67,14 @@ trait Histogram { */ def merge(h: Histogram, B: Int): Histogram - /** Returns a list `u_1,u_2,\ldots,u_{B-1}` such that the number of points in - * `(-\infty,u_1],[u_1,u_2],\ldots,[u_{B-1},\infty)` is `\frac_{1}{B} \sum_{i=0}^{bins-1} m_i`. 
+ /** Returns the qth quantile of the histogram * - * @param B number of intervals required */ - def uniform(B: Int): Array[Double] + def quantile(q: Double): Double + + /** Returns the estimated number of points in the interval `(-\infty,b]` + * + * @return Number of values in the interval `(-\infty,b]` + */ + def sum(b: Double): Int } diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala index eaae05e3d6fe5..e08a0cb70217e 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala @@ -16,6 +16,11 @@ * limitations under the License. */ +/** Implementation of an online histogram + * Adapted from Ben-Haim and Yom-Tov + * Refer http://www.jmlr.org/papers/volume11/ben-haim10a/ben-haim10a.pdf + * + */ package org.apache.flink.ml.math import org.apache.flink.api.java.tuple.Tuple2 @@ -23,10 +28,10 @@ import org.apache.flink.api.java.tuple.Tuple2 import java.util class OnlineHistogram( - capacity: Int, - min: Double = java.lang.Double.MIN_VALUE, - max: Double = java.lang.Double.MIN_VALUE, - data: util.ArrayList[Tuple2[Double, Int]] = new util.ArrayList[Tuple2[Double, Int]] + val capacity: Int, + val min: Double = -java.lang.Double.MAX_VALUE, + val max: Double = java.lang.Double.MAX_VALUE, + val data: util.ArrayList[Tuple2[Double, Int]] = new util.ArrayList[Tuple2[Double, Int]] ) extends Histogram with Serializable { require(checkSanity, "Invalid data provided") @@ -82,35 +87,7 @@ class OnlineHistogram( require(value > lower && value < upper, value + " not in (" + lower + "," + upper + ")") val search = find(value) data.add(search, new Tuple2[Double, Int](value, 1)) - if (bins > capacity) mergeElements() - } - - /** Returns the estimated number of points in the interval `(-\infty,b]` - * - * @return Number of values 
in the interval `(-\infty,b]` - */ - def sum(b: Double): Int = { - require(bins > 0, "Histogram is empty") - if (b < lower) return 0 - if (b > upper) return sum(upper) - val index = find(b) - 1 - var m_b, s: Double = 0 - if (index == -1) { - m_b = getCounter(index + 1) * (b - lower) / (getValue(index + 1) - lower) - s = m_b * (b - lower) / (2 * (getValue(index + 1) - lower)) - return s.toInt - } else if (index == bins - 1) { - m_b = getCounter(index) + (-getCounter(index)) * (b - getValue(index)) / (upper - getValue(index)) - s = (getCounter(index) + m_b) * (b - getValue(index)) / (2 * (upper - getValue(index))) - } else { - m_b = getCounter(index) + (getCounter(index + 1) - getCounter(index)) * (b - getValue(index)) / (getValue(index + 1) - getValue(index)) - s = (getCounter(index) + m_b) * (b - getValue(index)) / (2 * (getValue(index + 1) - getValue(index))) - } - for (i <- 0 to index - 1) { - s = s + getCounter(i) - } - s = s + getCounter(index) / 2 - s.toInt + mergeElements() // if the value we just added is already there, mergeElements will take care of this } /** Merges the histogram with h and returns a histogram with B bins @@ -143,43 +120,35 @@ class OnlineHistogram( new OnlineHistogram(B, Math.min(lower, h.lower), Math.max(upper, h.upper), tmp_list) } - /** Returns a list `u_1,u_2,\ldots,u_{B-1}` such that the number of points in - * `(-\infty,u_1],[u_1,u_2],\ldots,[u_{B-1},\infty)` is `\frac_{1}{B} \sum_{i=0}^{bins-1} m_i`. 
+ /** Returns the qth quantile of the histogram * - * @param B number of intervals required */ - def uniform(B: Int): Array[Double] = { - require(bins > 0, "Histogram is empty") - require(B > 1, "Cannot equalize in less than two intervals") - val ret: Array[Double] = new Array[Double](B - 1) - val total: Int = sum(upper) - for (j <- 0 to B - 2) { - val s: Int = (j + 1) * total / B - val search: Tuple2[Int, Int] = searchSum(s) - val i: Int = search.getField(1).asInstanceOf[Int] - val d: Int = s - search.getField(0).asInstanceOf[Int] - var a, b, c: Double = 0 - if (i == -1) { - a = getCounter(i + 1) - b = 0 - c = -2 * d - val z: Double = (-b + Math.sqrt(b * b - 4 * a * c)) / (2 * a) - ret(j) = lower + (getValue(i + 1) - lower) * z - } else if (i == bins - 1) { - a = -getCounter(i) - b = 2 * getCounter(i) - c = -2 * d - val z: Double = (-b + Math.sqrt(b * b - 4 * a * c)) / (2 * a) - ret(j) = getValue(i) + (upper - getValue(i)) * z - } else { - a = getCounter(i + 1) - getCounter(i) - b = 2 * getCounter(i) - c = -2 * d - val z: Double = (-b + Math.sqrt(b * b - 4 * a * c)) / (2 * a) - ret(j) = getValue(i) + (getValue(i + 1) - getValue(i)) * z - } + def quantile(q: Double): Double = { + require(bins > 0, "Histogram is currently empty. 
Can't find a quantile value") + var total = 0 + for (i <- 0 to bins - 1) total = total + getCounter(i) + val wantedSum = (q * total).round.toInt + var currSum = sum(getValue(0)) + if (wantedSum < currSum) { + require(lower > -java.lang.Double.MAX_VALUE, "Set a lower bound before proceeding") + return Math.sqrt(2 * wantedSum * Math.pow(getValue(0) - lower, 2) / getCounter(0)) + lower } - ret + for (i <- 1 to bins - 1) { + val tmpSum = sum(getValue(i)) + if (currSum <= wantedSum && wantedSum < tmpSum) { + val neededSum = wantedSum - currSum + val a: Double = getCounter(i) - getCounter(i - 1) + val b: Double = 2 * getCounter(i - 1) + val c: Double = -2 * neededSum + if (a == 0) { + return getValue(i - 1) + (getValue(i) - getValue(i - 1)) * (-c / b) + } else return getValue(i - 1) + (getValue(i) - getValue(i - 1)) * (-b + Math.sqrt(b * b - 4 * a * c)) / (2 * a) + } else currSum = tmpSum + } + require(upper < java.lang.Double.MAX_VALUE, "Set an upper bound before proceeding") + // this means wantedSum > sum(getValue(bins-1)) + // this will likely fail to return a bounded value. Make sure you set some proper limits on min and max. + getValue(bins - 1) + Math.sqrt(Math.pow(upper - getValue(bins - 1), 2) * 2 * (wantedSum - currSum) / getCounter(bins - 1)) } /** Returns the string representation of the histogram. 
@@ -189,6 +158,40 @@ class OnlineHistogram( s"Size:" + bins + " " + data.toString } + /** Returns the estimated number of points in the interval `(-\infty,b]` + * + * @return Number of values in the interval `(-\infty,b]` + */ + def sum(b: Double): Int = { + require(bins > 0, "Histogram is empty") + if (b < lower) return 0 + if (b >= upper) { + var ret = 0 + for (i <- 0 to bins - 1) { + ret = ret + getCounter(i) + } + return ret + } + val index = find(b) - 1 + var m_b, s: Double = 0 + if (index == -1) { + m_b = getCounter(index + 1) * (b - lower) / (getValue(index + 1) - lower) + s = m_b * (b - lower) / (2 * (getValue(index + 1) - lower)) + return s.round.toInt + } else if (index == bins - 1) { + m_b = getCounter(index) + (-getCounter(index)) * (b - getValue(index)) / (upper - getValue(index)) + s = (getCounter(index) + m_b) * (b - getValue(index)) / (2 * (upper - getValue(index))) + } else { + m_b = getCounter(index) + (getCounter(index + 1) - getCounter(index)) * (b - getValue(index)) / (getValue(index + 1) - getValue(index)) + s = (getCounter(index) + m_b) * (b - getValue(index)) / (2 * (getValue(index + 1) - getValue(index))) + } + for (i <- 0 to index - 1) { + s = s + getCounter(i) + } + s = s + getCounter(index) / 2 + s.round.toInt + } + /** Updates the given bin with the provided value and counter. 
Sets `v_bin`=value and `m_bin`=counter * * @param bin bin to be updated @@ -201,25 +204,6 @@ class OnlineHistogram( data.set(bin, new Tuple2[Double, Int](value, counter)) } - /** Searches for an index i such that sum(v_i) < s < sum(v_{i+1}) - * - * *@param s - * @return a tuple of sum(v_i) and index i - */ - private def searchSum(s: Double): Tuple2[Int, Int] = { - val size: Int = bins - var curr_sum: Int = sum(getValue(0)) - for (i <- 0 to size - 1) { - var tmp_sum: Int = 0 - if (i + 1 < size) tmp_sum = sum(getValue(i + 1)) - if (s >= curr_sum && (i + 1 >= size || s < tmp_sum)) { - return new Tuple2[Int, Int](curr_sum, i) - } - curr_sum = tmp_sum - } - new Tuple2[Int, Int](0, -1) - } - /** Searches for value in the histogram * * @param p value to search for @@ -235,10 +219,11 @@ class OnlineHistogram( 0 } - /** Merges the closest two elements in the histogram + /** Merges the closest two elements in the histogram. Definitely merge if we're over capacity. + * Otherwise, merge only if two elements are really close * */ - private def mergeElements(): Unit = { + private def mergeElements(): Boolean = { var index: Int = -1 val size: Int = bins var diff: Double = java.lang.Double.MAX_VALUE @@ -249,9 +234,12 @@ class OnlineHistogram( index = i } } - val merged_tuple: Tuple2[Double, Int] = mergeBins(index) - set(index, merged_tuple.getField(0).asInstanceOf[Double] / merged_tuple.getField(1).asInstanceOf[Int], merged_tuple.getField(1)) - data.remove(index + 1) + if (bins > capacity || diff < 1e-9) { + val merged_tuple: Tuple2[Double, Int] = mergeBins(index) + set(index, merged_tuple.getField(0).asInstanceOf[Double] / merged_tuple.getField(1).asInstanceOf[Int], merged_tuple.getField(1)) + data.remove(index + 1) + true + } else false } /** Returns the merging of the bin b and its next bin @@ -276,12 +264,11 @@ class OnlineHistogram( if (lower >= getValue(0)) return false if (upper <= getValue(bins - 1)) return false for (i <- 0 to bins - 2) { - if (getValue(i + 1) <= 
getValue(i)) return false + if (getValue(i + 1) < getValue(i)) return false // equality will get merged later on if (getCounter(i) <= 0) return false } if (getCounter(bins - 1) <= 0) return false - while (bins > capacity) - mergeElements() + while (mergeElements()) {} true } -} +} \ No newline at end of file diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala new file mode 100644 index 0000000000000..671307c28663b --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.flink.ml.tree + +/** Keeps useful statistics about a field + * fieldType is false for categorical fields, true for continuous fields + * For continuous field, minimum and maximum values. 
Usually, min-(max-min) and max+(max-min) from the data should suffice + * For categorical field, list of categories + * + */ + +class FieldStats( + val fieldType: Boolean, + val fieldMinValue: Double = -java.lang.Double.MAX_VALUE, + val fieldMaxValue: Double = java.lang.Double.MAX_VALUE, + val fieldCategories: collection.mutable.HashMap[Double, Int] = new collection.mutable.HashMap[Double, Int]) { + + override def toString: String = { + if (fieldType) + s"Continuous field: Range: ($fieldMinValue,$fieldMaxValue)" + else + s"Categorical field: Number of categories: $fieldCategories" + } +} \ No newline at end of file diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala new file mode 100644 index 0000000000000..ebe75cdbedcea --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.tree + +/** If the node has been trained, it will have: + * a. predict >=0, in this case, split should be empty. This is fully grown node and we can't go further down + * b. 
predict = -1, in this case if split is empty, we need to split, otherwise, this is an internal node + * + * ID starts from 1 for the root node. + * treeID is the tree to which this node belongs + * + */ + +class Node( + val id: Int, + val treeID: Int, + var split: Option[SplitValue], + var predict: Double = -1 + ) extends Serializable { + + override def toString: String = { + s"ID=$id, Tree ID=$treeID, predict=$predict, split=$split" + } + + def getDepth: Int = { + // taking log base 2 of the node id + // depth starts from one. A matter of convention really + java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(id)) + 1 + } +} \ No newline at end of file diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala new file mode 100644 index 0000000000000..86ae751b60567 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.flink.ml.tree + +import java.util + +import org.apache.flink.ml.math.Vector + +/** + * Implements a split value class which determines whether to send an instance to the + * left tree or the right tree + * + * Attribute is the index of the feature at which to split. Starts from 0 to {d-1} where d + * is the dimensionality + * splitType is true if the split is done using "<= splitValueDouble" for continuous fields + * splitType is false if the split is done using "in splitValueList" for categorical fields + * + * getSplitDirection returns true if the instance should go the left tree and false if it should + * go to the right tree + **/ + +class SplitValue( + val attribute: Int, + val splitType: Boolean, + val splitValueDouble: Double = 0, + val splitValueList: util.ArrayList[Double] = new util.ArrayList[Double]) { + + override def toString: String = { + if (splitType) + s"Attribute Index: $attribute, Split: Continuous Value at $splitValueDouble" + else + s"Attribute Index: $attribute, Split: Categorical at $splitValueList" + } + + def getSplitDirection(vector: Vector): Boolean = { + if (splitType) + vector.apply(attribute) <= splitValueDouble // go left if less than equal to + else + splitValueList.contains(vector.apply(attribute)) // go left is exists + } +} \ No newline at end of file diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala new file mode 100644 index 0000000000000..c44f28b9aa978 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.tree + +import org.apache.flink.ml.math.Vector + +import scala.collection.mutable + +/** Tree structure. This is kind of maintained in an unconventional way. We provide direct access to all nodes + * The obvious assumption is that child of node i will be 2*i and 2*i+1, while parent of i will be i/2 + * + */ +class Tree( + val treeID: Int, + val nodes: mutable.HashMap[Int, Node], + val config: TreeConfiguration + ) extends Serializable { + + override def toString: String = { + var ret = s"Tree ID=$treeID\nConfiguration:\n$config \nTree Structure:\n" + for (i <- 1 to Math.pow(2, 20).toInt) { + if (nodes.get(i).nonEmpty) + ret = ret + nodes.get(i).get.toString + "\n" + } + ret + } + + /** Determines which node of the tree this vector will go to + * If predict at any node is -1 and it has a split, we'll go down recursively + * + */ + def filter(vector: Vector): (Int, Double) = { + var node: Node = nodes.get(1).get + while (node.predict.round.toInt == -1 && node.split.nonEmpty) { + if (node.split.get.getSplitDirection(vector)) + node = nodes.get(2 * node.id).get + else + node = nodes.get(2 * node.id + 1).get + } + (node.id, node.predict) + } +} \ No newline at end of file diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala 
new file mode 100644 index 0000000000000..88ea141f7c8c8 --- /dev/null +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.flink.ml.tree + +/** Holds the configuration of the tree. User specified and automatically detected from data + * All the following fields are user specified unless otherwise mentioned + * @param MaxBins maximum splits to be considered while training + * @param MinInstancePerNode Minimum instances that should be present at every leaf node + * @param Depth Maximum allowed depth of the tree. (Maximum 30) + * @param Pruning Whether to prune the tree after training or not + * @param splitStrategy Which algorithm to use for splitting. Gini or Entropy + * @param numClasses Number of classes in data (cross checked with data) + * @param dimension Dimensionality of data(cross checked with data) + * @param category Which fields are to be considered as categorical. 
Array of field indexes + * @param fieldStats Field maximum and minimum values, list of categories [Automatically] + * @param labels Array of labels slash classes in data [Automatically] + * @param labelAddition Addition term to make all labels >=0 [Automatically][Only for internal use. Not visible to user] + * @param numTrainVector Number of training instances [Automatically] + */ +class TreeConfiguration( + val MaxBins: Int, + val MinInstancePerNode: Int, + val Depth: Int, + val Pruning: Boolean, + val splitStrategy: String, + val numClasses: Int, + val dimension: Int, + val category: Array[Int], + var fieldStats: Array[FieldStats] = Array.ofDim(0), + var labels: Array[Double] = Array.ofDim(0), + var labelAddition: Double = 0, + var numTrainVector: Int = 0 + ) { + + override def toString: String = { + var ret = s"Maximum Binning: $MaxBins, Minimum Instance per leaf node: $MinInstancePerNode, Maximum Depth: $Depth, Pruning:$Pruning" + + s"\nSplit Strategy: $splitStrategy, Number of classes: $numClasses, Dimension of data: $dimension, Number of training vectors: $numTrainVector\n" + + s"categorical fields: " + java.util.Arrays.toString(category) + "\nLabels in data: " + java.util.Arrays.toString(labels) + "\nField stats:" + fieldStats.iterator.foreach(x => ret = ret + x.toString) + ret + s"\nLabel Addition: $labelAddition" + } + + def setNumTrainVector(count: Int): Unit = { + numTrainVector = count + } + + def setFieldStats(stats: Array[FieldStats]): Unit = { + fieldStats = stats + } + + def setLabels(labels: Array[Double]): Unit = { + this.labels = labels + } + + def setLabelAddition(label_add: Double): Unit = { + labelAddition = label_add + } +} \ No newline at end of file diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala index c9dd00f489e0f..81892f8bde0dc 100644 --- 
a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala @@ -130,4 +130,162 @@ object Classification { ) val expectedWeightVector = DenseVector(-1.95, -3.45) + + // the IRIS data set, for testing the decision tree implementation + + val IrisTrainingData = Seq[LabeledVector]( + LabeledVector(1,DenseVector(5.1,3.5,1.4,0.2)), + LabeledVector(1,DenseVector(4.9,3.0,1.4,0.2)), + LabeledVector(1,DenseVector(4.7,3.2,1.3,0.2)), + LabeledVector(1,DenseVector(4.6,3.1,1.5,0.2)), + LabeledVector(1,DenseVector(5.0,3.6,1.4,0.2)), + LabeledVector(1,DenseVector(5.4,3.9,1.7,0.4)), + LabeledVector(1,DenseVector(4.6,3.4,1.4,0.3)), + LabeledVector(1,DenseVector(4.3,3.0,1.1,0.1)), + LabeledVector(1,DenseVector(5.8,4.0,1.2,0.2)), + LabeledVector(1,DenseVector(5.7,4.4,1.5,0.4)), + LabeledVector(1,DenseVector(5.4,3.9,1.3,0.4)), + LabeledVector(1,DenseVector(5.1,3.5,1.4,0.3)), + LabeledVector(1,DenseVector(5.1,3.3,1.7,0.5)), + LabeledVector(1,DenseVector(4.8,3.4,1.9,0.2)), + LabeledVector(1,DenseVector(5.0,3.0,1.6,0.2)), + LabeledVector(1,DenseVector(5.0,3.4,1.6,0.4)), + LabeledVector(1,DenseVector(5.2,3.5,1.5,0.2)), + LabeledVector(1,DenseVector(5.2,3.4,1.4,0.2)), + LabeledVector(1,DenseVector(4.7,3.2,1.6,0.2)), + LabeledVector(1,DenseVector(4.8,3.1,1.6,0.2)), + LabeledVector(1,DenseVector(5.4,3.4,1.5,0.4)), + LabeledVector(1,DenseVector(5.2,4.1,1.5,0.1)), + LabeledVector(1,DenseVector(5.5,4.2,1.4,0.2)), + LabeledVector(1,DenseVector(4.9,3.1,1.5,0.1)), + LabeledVector(1,DenseVector(5.0,3.2,1.2,0.2)), + LabeledVector(1,DenseVector(5.5,3.5,1.3,0.2)), + LabeledVector(1,DenseVector(4.9,3.1,1.5,0.1)), + LabeledVector(1,DenseVector(4.4,3.0,1.3,0.2)), + LabeledVector(1,DenseVector(5.1,3.4,1.5,0.2)), + LabeledVector(1,DenseVector(5.0,3.5,1.3,0.3)), + LabeledVector(1,DenseVector(4.5,2.3,1.3,0.3)), + LabeledVector(1,DenseVector(4.4,3.2,1.3,0.2)), + 
LabeledVector(1,DenseVector(5.0,3.5,1.6,0.6)), + LabeledVector(1,DenseVector(5.1,3.8,1.9,0.4)), + LabeledVector(1,DenseVector(4.8,3.0,1.4,0.3)), + LabeledVector(1,DenseVector(5.1,3.8,1.6,0.2)), + LabeledVector(1,DenseVector(4.6,3.2,1.4,0.2)), + LabeledVector(1,DenseVector(5.3,3.7,1.5,0.2)), + LabeledVector(1,DenseVector(5.0,3.3,1.4,0.2)), + LabeledVector(2,DenseVector(7.0,3.2,4.7,1.4)), + LabeledVector(2,DenseVector(6.4,3.2,4.5,1.5)), + LabeledVector(2,DenseVector(6.9,3.1,4.9,1.5)), + LabeledVector(2,DenseVector(5.5,2.3,4.0,1.3)), + LabeledVector(2,DenseVector(6.5,2.8,4.6,1.5)), + LabeledVector(2,DenseVector(5.7,2.8,4.5,1.3)), + LabeledVector(2,DenseVector(6.3,3.3,4.7,1.6)), + LabeledVector(2,DenseVector(4.9,2.4,3.3,1.0)), + LabeledVector(2,DenseVector(6.6,2.9,4.6,1.3)), + LabeledVector(2,DenseVector(5.2,2.7,3.9,1.4)), + LabeledVector(2,DenseVector(5.0,2.0,3.5,1.0)), + LabeledVector(2,DenseVector(5.9,3.0,4.2,1.5)), + LabeledVector(2,DenseVector(6.0,2.2,4.0,1.0)), + LabeledVector(2,DenseVector(6.1,2.9,4.7,1.4)), + LabeledVector(2,DenseVector(5.6,2.9,3.6,1.3)), + LabeledVector(2,DenseVector(6.7,3.1,4.4,1.4)), + LabeledVector(2,DenseVector(5.6,3.0,4.5,1.5)), + LabeledVector(2,DenseVector(5.8,2.7,4.1,1.0)), + LabeledVector(2,DenseVector(6.2,2.2,4.5,1.5)), + LabeledVector(2,DenseVector(5.6,2.5,3.9,1.1)), + LabeledVector(2,DenseVector(5.9,3.2,4.8,1.8)), + LabeledVector(2,DenseVector(6.1,2.8,4.0,1.3)), + LabeledVector(2,DenseVector(6.3,2.5,4.9,1.5)), + LabeledVector(2,DenseVector(6.7,3.0,5.0,1.7)), + LabeledVector(2,DenseVector(6.0,2.9,4.5,1.5)), + LabeledVector(2,DenseVector(5.7,2.6,3.5,1.0)), + LabeledVector(2,DenseVector(5.5,2.4,3.8,1.1)), + LabeledVector(2,DenseVector(5.5,2.4,3.7,1.0)), + LabeledVector(2,DenseVector(5.8,2.7,3.9,1.2)), + LabeledVector(2,DenseVector(6.0,2.7,5.1,1.6)), + LabeledVector(2,DenseVector(5.4,3.0,4.5,1.5)), + LabeledVector(2,DenseVector(6.0,3.4,4.5,1.6)), + LabeledVector(2,DenseVector(6.7,3.1,4.7,1.5)), + 
LabeledVector(2,DenseVector(6.1,3.0,4.6,1.4)), + LabeledVector(2,DenseVector(5.8,2.6,4.0,1.2)), + LabeledVector(2,DenseVector(5.0,2.3,3.3,1.0)), + LabeledVector(2,DenseVector(5.6,2.7,4.2,1.3)), + LabeledVector(2,DenseVector(5.7,3.0,4.2,1.2)), + LabeledVector(2,DenseVector(5.7,2.9,4.2,1.3)), + LabeledVector(2,DenseVector(6.2,2.9,4.3,1.3)), + LabeledVector(2,DenseVector(5.1,2.5,3.0,1.1)), + LabeledVector(2,DenseVector(5.7,2.8,4.1,1.3)), + LabeledVector(3,DenseVector(6.3,3.3,6.0,2.5)), + LabeledVector(3,DenseVector(5.8,2.7,5.1,1.9)), + LabeledVector(3,DenseVector(7.1,3.0,5.9,2.1)), + LabeledVector(3,DenseVector(6.3,2.9,5.6,1.8)), + LabeledVector(3,DenseVector(6.5,3.0,5.8,2.2)), + LabeledVector(3,DenseVector(7.6,3.0,6.6,2.1)), + LabeledVector(3,DenseVector(4.9,2.5,4.5,1.7)), + LabeledVector(3,DenseVector(7.3,2.9,6.3,1.8)), + LabeledVector(3,DenseVector(6.7,2.5,5.8,1.8)), + LabeledVector(3,DenseVector(7.2,3.6,6.1,2.5)), + LabeledVector(3,DenseVector(6.5,3.2,5.1,2.0)), + LabeledVector(3,DenseVector(6.4,2.7,5.3,1.9)), + LabeledVector(3,DenseVector(6.8,3.0,5.5,2.1)), + LabeledVector(3,DenseVector(5.7,2.5,5.0,2.0)), + LabeledVector(3,DenseVector(5.8,2.8,5.1,2.4)), + LabeledVector(3,DenseVector(6.4,3.2,5.3,2.3)), + LabeledVector(3,DenseVector(6.5,3.0,5.5,1.8)), + LabeledVector(3,DenseVector(7.7,3.8,6.7,2.2)), + LabeledVector(3,DenseVector(7.7,2.6,6.9,2.3)), + LabeledVector(3,DenseVector(6.0,2.2,5.0,1.5)), + LabeledVector(3,DenseVector(6.9,3.2,5.7,2.3)), + LabeledVector(3,DenseVector(6.7,3.3,5.7,2.1)), + LabeledVector(3,DenseVector(7.2,3.2,6.0,1.8)), + LabeledVector(3,DenseVector(6.2,2.8,4.8,1.8)), + LabeledVector(3,DenseVector(6.1,3.0,4.9,1.8)), + LabeledVector(3,DenseVector(6.4,2.8,5.6,2.1)), + LabeledVector(3,DenseVector(7.2,3.0,5.8,1.6)), + LabeledVector(3,DenseVector(7.4,2.8,6.1,1.9)), + LabeledVector(3,DenseVector(7.9,3.8,6.4,2.0)), + LabeledVector(3,DenseVector(6.4,2.8,5.6,2.2)), + LabeledVector(3,DenseVector(6.3,2.8,5.1,1.5)), + 
LabeledVector(3,DenseVector(6.1,2.6,5.6,1.4)), + LabeledVector(3,DenseVector(7.7,3.0,6.1,2.3)), + LabeledVector(3,DenseVector(6.3,3.4,5.6,2.4)), + LabeledVector(3,DenseVector(6.7,3.0,5.2,2.3)), + LabeledVector(3,DenseVector(6.3,2.5,5.0,1.9)), + LabeledVector(3,DenseVector(6.5,3.0,5.2,2.0)), + LabeledVector(3,DenseVector(6.2,3.4,5.4,2.3)), + LabeledVector(3,DenseVector(5.9,3.0,5.1,1.8)) + ) + + val IrisTestingData = Seq[LabeledVector]( + LabeledVector(1,DenseVector(5.0,3.4,1.5,0.2)), + LabeledVector(1,DenseVector(4.4,2.9,1.4,0.2)), + LabeledVector(1,DenseVector(4.9,3.1,1.5,0.1)), + LabeledVector(1,DenseVector(5.4,3.7,1.5,0.2)), + LabeledVector(1,DenseVector(4.8,3.4,1.6,0.2)), + LabeledVector(1,DenseVector(4.8,3.0,1.4,0.1)), + LabeledVector(1,DenseVector(5.7,3.8,1.7,0.3)), + LabeledVector(1,DenseVector(5.1,3.8,1.5,0.3)), + LabeledVector(1,DenseVector(5.4,3.4,1.7,0.2)), + LabeledVector(1,DenseVector(5.1,3.7,1.5,0.4)), + LabeledVector(1,DenseVector(4.6,3.6,1.0,0.2)), + LabeledVector(3,DenseVector(6.4,3.1,5.5,1.8)), + LabeledVector(3,DenseVector(6.0,3.0,4.8,1.8)), + LabeledVector(3,DenseVector(6.9,3.1,5.4,2.1)), + LabeledVector(3,DenseVector(6.7,3.1,5.6,2.4)), + LabeledVector(3,DenseVector(6.9,3.1,5.1,2.3)), + LabeledVector(3,DenseVector(5.8,2.7,5.1,1.9)), + LabeledVector(3,DenseVector(6.8,3.2,5.9,2.3)), + LabeledVector(3,DenseVector(6.7,3.3,5.7,2.5)), + LabeledVector(3,DenseVector(5.6,2.8,4.9,2.0)), + LabeledVector(2,DenseVector(6.1,2.8,4.7,1.2)), + LabeledVector(2,DenseVector(6.4,2.9,4.3,1.3)), + LabeledVector(2,DenseVector(6.6,3.0,4.4,1.4)), + LabeledVector(2,DenseVector(6.8,2.8,4.8,1.4)), + LabeledVector(3,DenseVector(7.7,2.8,6.7,2.0)), + LabeledVector(3,DenseVector(6.3,2.7,4.9,1.8)), + LabeledVector(2,DenseVector(6.3,2.3,4.4,1.3)), + LabeledVector(2,DenseVector(5.6,3.0,4.1,1.3)), + LabeledVector(2,DenseVector(5.5,2.5,4.0,1.3)), + LabeledVector(2,DenseVector(5.5,2.6,4.4,1.2)) + ) } diff --git 
a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala new file mode 100644 index 0000000000000..b1d13cc94e4e6 --- /dev/null +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification + +import org.scalatest.{FlatSpec, Matchers} + +import org.apache.flink.api.scala._ +import org.apache.flink.test.util.FlinkTestBase + +class DecisionTreeSuite extends FlatSpec with Matchers with FlinkTestBase { + + behavior of "The Decision Tree implementation" + + it should "train a decision tree" in { + val env = ExecutionEnvironment.getExecutionEnvironment + + val learner = DecisionTree().setMaxBins(10).setDepth(20).setDimension(4).setClasses(3) + + val trainingDS = env.fromCollection(Classification.IrisTrainingData).setParallelism(4) + + val model = learner.fit(trainingDS) + + val predict = model.testAccuracy(env.fromCollection(Classification.IrisTestingData).setParallelism(4)) + println(s"Testing accuracy: $predict%") + } +} diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala index 924e88be6e933..128b6826c0990 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala @@ -110,20 +110,6 @@ class OnlineHistogramSuite extends FlatSpec with Matchers { } } - it should "succeed in computing sum values" in { - val l: util.ArrayList[Tuple2[Double, Int]] = new util.ArrayList[Tuple2[Double, Int]]() - l.add(new Tuple2[Double, Int](1, 3)) - l.add(new Tuple2[Double, Int](3, 7)) - l.add(new Tuple2[Double, Int](7, 1)) - l.add(new Tuple2[Double, Int](10, 2)) - val h = new OnlineHistogram(4, 0, 11, l) - - h.sum(1) should equal(1) - h.sum(3) should equal(6) - h.sum(11) should equal(13) - h.sum(4) should equal(7) - } - it should "succeed in merging two histograms" in { val l: util.ArrayList[Tuple2[Double, Int]] = new util.ArrayList[Tuple2[Double, Int]]() l.add(new Tuple2[Double, Int](1, 3)) @@ -142,15 +128,18 @@ class OnlineHistogramSuite extends 
FlatSpec with Matchers { h3.getCounter(2) should equal(2) } - it should "succeed in generating an equalization list" in { + it should "succeed in computing quantile" in { val l: util.ArrayList[Tuple2[Double, Int]] = new util.ArrayList[Tuple2[Double, Int]]() - l.add(new Tuple2[Double, Int](1, 3)) - l.add(new Tuple2[Double, Int](3, 7)) - l.add(new Tuple2[Double, Int](7, 1)) + l.add(new Tuple2[Double, Int](1, 5)) + l.add(new Tuple2[Double, Int](3, 4)) + l.add(new Tuple2[Double, Int](7, 3)) l.add(new Tuple2[Double, Int](10, 2)) - val h = new OnlineHistogram(3, 0, 11, l) - val eqArr = h.uniform(3) - h.sum(eqArr(0)) should equal(4) - h.sum(eqArr(1)) should equal(8) + l.add(new Tuple2[Double, Int](11, 1)) + val h = new OnlineHistogram(5,0,12, data= l) + + h.sum(h.quantile(0.05)) should equal(1) + h.sum(h.quantile(0.4)) should equal(6) + h.sum(h.quantile(0.8)) should equal(12) + h.sum(h.quantile(0.95)) should equal(14) } -} +} \ No newline at end of file From f7d161513d3230794144a9d77ecfa60185077d90 Mon Sep 17 00:00:00 2001 From: Sachin Goel Date: Thu, 21 May 2015 21:13:50 +0530 Subject: [PATCH 3/5] Fixed scalastyle problems. 
Travis should pass now --- .../ml/classification/DecisionTree.scala | 380 +++++++++++------- .../flink/ml/math/OnlineHistogram.scala | 31 +- .../org/apache/flink/ml/tree/FieldStats.scala | 6 +- .../scala/org/apache/flink/ml/tree/Node.scala | 6 +- .../scala/org/apache/flink/ml/tree/Tree.scala | 8 +- .../flink/ml/tree/TreeConfiguration.scala | 12 +- 6 files changed, 274 insertions(+), 169 deletions(-) diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala index b6c459e76e053..bfdde73e13503 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala @@ -30,8 +30,8 @@ import org.apache.flink.ml.tree._ import scala.collection.mutable -/** Companion object of Decision Tree. Contains convenience functions and the parameter type definitions - * of the algorithm. +/** Companion object of Decision Tree. + * Contains convenience functions and the parameter type definitions of the algorithm. 
* */ object DecisionTree { @@ -98,7 +98,8 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial * @return itself */ def setMinInstancePerNode(minInstancesPerNode: Int): DecisionTree = { - require(minInstancesPerNode >= 1, "Every node must have at least one instance associated with it") + require(minInstancesPerNode >= 1, + "Every node must have at least one instance associated with it") parameters.add(MinInstancesPerNode, minInstancesPerNode) this } @@ -130,7 +131,8 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial * @return itself */ def setSplitStrategy(splitStrategy: String): DecisionTree = { - require(splitStrategy == "Gini" || splitStrategy == "Entropy", "Algorithm " + splitStrategy + " not supported") + require(splitStrategy == "Gini" || splitStrategy == "Entropy", + "Algorithm " + splitStrategy + " not supported") parameters.add(SplitStrategy, splitStrategy) this } @@ -167,14 +169,14 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial this } - /** Trains a Decision Tree * * @param input Training data set * @param fitParameters Parameter values * @return Trained Decision Tree Model */ - override def fit(input: DataSet[LabeledVector], fitParameters: ParameterMap): DecisionTreeModel = { + override def fit(input: DataSet[LabeledVector], fitParameters: ParameterMap): + DecisionTreeModel = { val resultingParameters = this.parameters ++ fitParameters val depth = resultingParameters(Depth) val minInstancePerNode = resultingParameters(MinInstancesPerNode) @@ -197,7 +199,8 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial tree = input.getExecutionEnvironment.fromElements(treeCopy) // remake the tree as a dataset // Group the input data into blocks in round robin fashion - val blockedInputNumberElements = FlinkTools.block(input, input.getParallelism, Some(ModuloKeyPartitioner)). 
+ val blockedInputNumberElements = FlinkTools.block( + input, input.getParallelism, Some(ModuloKeyPartitioner)). cross(numberVectors). map { x => x } @@ -231,151 +234,222 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial treeCopy = tree.collect().toArray.apply(0) val finalHists = combinedHists.collect().toArray.apply(0) - val fieldStats = treeCopy.config.fieldStats - val labels = treeCopy.config.labels - val labelAddition = treeCopy.config.labelAddition - - // now, find splits across each dimension by merging (tree_node, dimension, label_1),(tree_node, dimension, label_2)... - - // keep a list of all the nodes we'll have to find splits for later - // we only get stuff for leafs which are unlabeled, so nodes only has those which are unlabeled leafs - - val nodeDimensionSplits = new mutable.HashMap[(Int, Int), Array[Double]] - val nodes = new mutable.HashMap[Int, Int] - - finalHists.keysIterator.foreach( - x => { - nodes.put(x._1, -1) - if (nodeDimensionSplits.get((x._1, x._2)).isEmpty) { - val min_value = fieldStats.apply(x._2).fieldMinValue - val max_value = fieldStats.apply(x._2).fieldMaxValue - var tmp_hist: Histogram = new OnlineHistogram(maxBins, 2 * min_value - max_value, 2 * max_value - min_value) - labels.iterator.foreach( - c => { - if (finalHists.get((x._1, x._2, c)).nonEmpty) { - tmp_hist = tmp_hist.merge(finalHists.get((x._1, x._2, c)).get, maxBins) - } + + any_unlabeled_left = evaluateSplits(treeCopy, finalHists) + tree = tree.getExecutionEnvironment.fromElements(treeCopy) + } + + DecisionTreeModel(tree) + } + + private def evaluateSplits( + tree: Tree, + finalHists: mutable.HashMap[(Int, Int, Double), Histogram]) + : Boolean = { + val fieldStats = tree.config.fieldStats + val labels = tree.config.labels + val maxBins = tree.config.MaxBins + + var any_split_done = false + + val (nodeDimensionSplits, nodes) = calculateSplits(finalHists, fieldStats, maxBins, labels) + + // first, every unlabeled leaf node must have sent 
something. + // If not, its sibling will be stuck forever + + tree.nodes.valuesIterator.foreach( + this_node => { + if (this_node.predict == -1 && this_node.split.isEmpty && + nodes.get(this_node.id).isEmpty) { + // we're in trouble + var sibling_id = 0 + if (this_node.id % 2 == 0) + sibling_id = this_node.id + 1 + else + sibling_id = this_node.id - 1 + // this node is pointless. Remove it from the tree + tree.nodes.remove(this_node.id) + // we're not going to split the sibling anymore. + nodes.put(sibling_id, 1) // we'll check for this '1' later + } + } + ) + // now, for each node, for each dimension, evaluate splits based on the above histograms + any_split_done = evaluateNodes(nodes, tree, nodeDimensionSplits, finalHists) + return any_split_done + } + + /** Merge all histograms (node_id, dim, _) in finalHists and return an array of splits + * that need to be considered at node node_id on dimension dim. + * Also, return a hashmap containing entries for whichever nodes have some instances allocated + * to them + * + */ + private def calculateSplits( + finalHists: mutable.HashMap[(Int, Int, Double), Histogram], + fieldStats: Array[FieldStats], + maxBins: Int, + labels: Array[Double]): + (mutable.HashMap[(Int, Int), Array[Double]], mutable.HashMap[Int, Int]) = { + + // keep a list of all the nodes we'll have to find splits for later + // we only get stuff for leafs which are unlabeled, so nodes only has those which are + // unlabeled leafs + + val nodeDimensionSplits = new mutable.HashMap[(Int, Int), Array[Double]] + val nodes = new mutable.HashMap[Int, Int] + + finalHists.keysIterator.foreach( + x => { + nodes.put(x._1, -1) + if (nodeDimensionSplits.get((x._1, x._2)).isEmpty) { + val min_value = fieldStats.apply(x._2).fieldMinValue + val max_value = fieldStats.apply(x._2).fieldMaxValue + var tmp_hist: Histogram = + new OnlineHistogram(maxBins, 2 * min_value - max_value, 2 * max_value - min_value) + labels.iterator.foreach( + c => { + if (finalHists.get((x._1, x._2, 
c)).nonEmpty) { + tmp_hist = tmp_hist.merge(finalHists.get((x._1, x._2, c)).get, maxBins) } - ) - // find maxBins quantiles, or if the size of histogram is less than that, just find those many - val actualBins = Math.min(maxBins, tmp_hist.bins - 1) - val quantileArray = new Array[Double](actualBins) - for (i <- 1 to actualBins) { - quantileArray(i - 1) = tmp_hist.quantile((i + 0.0) / (actualBins + 1)) } - nodeDimensionSplits.put((x._1, x._2), quantileArray) + ) + // find maxBins quantiles, or if the size of histogram is less than that, just those many + val actualBins = Math.min(maxBins, tmp_hist.bins - 1) + val quantileArray = new Array[Double](actualBins) + for (i <- 1 to actualBins) { + quantileArray(i - 1) = tmp_hist.quantile((i + 0.0) / (actualBins + 1)) } + nodeDimensionSplits.put((x._1, x._2), quantileArray) } - ) - // first, every unlabeled leaf node must have sent something. If not, its sibling will be stuck forever - - treeCopy.nodes.valuesIterator.foreach( - this_node => { - if (this_node.predict == -1 && this_node.split.isEmpty && nodes.get(this_node.id).isEmpty) { - // we're in trouble - var sibling_id = 0 - if (this_node.id % 2 == 0) - sibling_id = this_node.id + 1 - else - sibling_id = this_node.id - 1 - treeCopy.nodes.remove(this_node.id) // this node is pointless. Remove it from the tree - // we're not going to split the sibling anymore. 
- nodes.put(sibling_id, 1) // we'll check for this '1' later + } + ) + (nodeDimensionSplits, nodes) + } + + private def findGini( + node_id: Int, + finalHists: mutable.HashMap[(Int, Int, Double), Histogram], + labels: Array[Double]): + (Double, Double, Double, Double) = { + var sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount = 0.0 + // since the count of classes across any dimension is same, pick 0 + // for calculating Gini index, we need count(c)^2 and \sum count(c)^2 + // also maintain which class occurred most frequently in case we need to mark this as a leaf + // node + labels.iterator.foreach( + x => { + val h = finalHists.get((node_id, 0, x)) + if (h.nonEmpty) { + val countOfClass = h.get.sum(h.get.upper) + totalNumPointsHere = totalNumPointsHere + countOfClass + sumClassSquare = sumClassSquare + countOfClass * countOfClass + if (countOfClass > maxClassCount) { + maxClassCount = countOfClass + maxClassCountLabel = x } } - ) - // now, for each node, for each dimension, evaluate splits based on the above histograms - nodes.keysIterator.foreach( - node_id => { - var node: Node = treeCopy.nodes.get(node_id).get - // since the count of classes across any dimension is same, pick 0 - // for calculating Gini index, we need count(c)^2 and \sum count(c)^2 - // also maintain which class occurred most frequently in case we need to mark this as a leaf node - var sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount = 0.0 - labels.iterator.foreach( - x => { - val h = finalHists.get((node_id, 0, x)) - if (h.nonEmpty) { - val countOfClass = h.get.sum(h.get.upper) - totalNumPointsHere = totalNumPointsHere + countOfClass - sumClassSquare = sumClassSquare + countOfClass * countOfClass - if (countOfClass > maxClassCount) { - maxClassCount = countOfClass - maxClassCountLabel = x - } - } - } - ) + } + ) + (sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount) + } - // now see if this node has become pure or if it's depth is at 
a maximum - if (totalNumPointsHere * totalNumPointsHere == sumClassSquare || node.getDepth == depth) { - node.predict = maxClassCountLabel + labelAddition - // sanity check. If we're declaring something a leaf, it better not have any children - require(node.split.isEmpty, "An unexpected error occurred") - } else if (nodes.get(node_id).get == 1) { - // this node is meaningless. The parent didn't do anything at all. - // just remove this node from the tree and set the label of parent, which would be same as the label of this - node = treeCopy.nodes.get(node_id / 2).get - node.predict = maxClassCountLabel + labelAddition - node.split = None - treeCopy.nodes.remove(node_id) - } else { - val giniParent = 1 - sumClassSquare / Math.pow(totalNumPointsHere, 2) - var best_gini = -java.lang.Double.MAX_VALUE - var best_dimension = -1 - var best_split_value, best_left_total, best_right_total = 0.0 - // consider all splits across all dimensions and pick the best one - for (j <- 0 to dimension - 1) { - val splitsArray = nodeDimensionSplits.get((node_id, j)).get - val actualSplits = splitsArray.length - for (k <- 0 to actualSplits - 1) { - // maintain how many instances go left and right for Gini - var total_left, total_right, left_sum_sqr, right_sum_sqr = 0.0 - for (l <- 0 to numClasses - 1) { - val h = finalHists.get((node_id, j, labels.apply(l))) - if (h.nonEmpty) { - val left = h.get.sum(splitsArray.apply(k)) - val right = h.get.sum(h.get.upper) - left - total_left = total_left + left - total_right = total_right + right - left_sum_sqr = left_sum_sqr + left * left - right_sum_sqr = right_sum_sqr + right * right - } + private def evaluateNodes( + nodes: mutable.HashMap[Int, Int], + tree: Tree, + nodeDimensionSplits: mutable.HashMap[(Int, Int), Array[Double]], + finalHists: mutable.HashMap[(Int, Int, Double), Histogram]): + Boolean = { + val depth = tree.config.Depth + val labelAddition = tree.config.labelAddition + val dimension = tree.config.dimension + val numClasses = 
tree.config.numClasses + val minInstancePerNode = tree.config.MinInstancePerNode + val labels = tree.config.labels + var any_split_done = false + + nodes.keysIterator.foreach( + node_id => { + + var node: Node = tree.nodes.get(node_id).get + + // find the gini index of this node + val (sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount) = + findGini(node_id, finalHists, labels) + + // now see if this node has become pure or if it's depth is at a maximum + if (totalNumPointsHere * totalNumPointsHere == sumClassSquare || node.getDepth == depth) { + node.predict = maxClassCountLabel + labelAddition + // sanity check. If we're declaring something a leaf, it better not have any children + require(node.split.isEmpty, "An unexpected error occurred") + } + else if (nodes.get(node_id).get == 1) { + // this node is meaningless. The parent didn't do anything at all. + // just remove this node from the tree and set the label of parent, + // which would be same as the label of this + node = tree.nodes.get(node_id / 2).get + node.predict = maxClassCountLabel + labelAddition + node.split = None + tree.nodes.remove(node_id) + } + else { + val giniParent = 1 - sumClassSquare / Math.pow(totalNumPointsHere, 2) + var best_gini = -java.lang.Double.MAX_VALUE + var best_dimension = -1 + var best_split_value, best_left_total, best_right_total = 0.0 + // consider all splits across all dimensions and pick the best one + for (j <- 0 to dimension - 1) { + val splitsArray = nodeDimensionSplits.get((node_id, j)).get + val actualSplits = splitsArray.length + for (k <- 0 to actualSplits - 1) { + // maintain how many instances go left and right for Gini + var total_left, total_right, left_sum_sqr, right_sum_sqr = 0.0 + for (l <- 0 to numClasses - 1) { + val h = finalHists.get((node_id, j, labels.apply(l))) + if (h.nonEmpty) { + val left = h.get.sum(splitsArray.apply(k)) + val right = h.get.sum(h.get.upper) - left + total_left = total_left + left + total_right = total_right + 
right + left_sum_sqr = left_sum_sqr + left * left + right_sum_sqr = right_sum_sqr + right * right } - // ensure that the split is allowed by user. We need at least this many instances everywhere - if (total_left >= minInstancePerNode && total_right >= minInstancePerNode) { - // use a balancing term to ensure the tree is balanced more on the top - // the exponential term in scaling makes sure that as we go deeper, Gini becomes more important in splits - // this makes sense. We want balanced tree on top and fine-grained splitting at deeper levels - val scaling = Math.pow(0.1, node.getDepth + 1) - val balancing = Math.abs(total_left - total_right) / totalNumPointsHere - val this_gini = (1 - scaling) * (giniParent - total_left * (1 - left_sum_sqr / Math.pow(total_left, 2)) / totalNumPointsHere - - total_right * (1 - right_sum_sqr / Math.pow(total_right, 2)) / totalNumPointsHere) + scaling * balancing - if (this_gini > best_gini) { - best_gini = this_gini - best_dimension = j - best_split_value = nodeDimensionSplits.get((node_id, j)).get.apply(k) - best_left_total = total_left - best_right_total = total_right - } + } + // ensure that the split is allowed by user. We need at least this many instances + if (total_left >= minInstancePerNode && total_right >= minInstancePerNode) { + // use a balancing term to ensure the tree is balanced more on the top + // the exponential term in scaling makes sure that as we go deeper, + // Gini becomes more important in splits + // this makes sense. 
We want balanced tree on top and fine-grained splitting at + // deeper levels + val scaling = Math.pow(0.1, node.getDepth + 1) + val balancing = Math.abs(total_left - total_right) / totalNumPointsHere + val this_gini = (1 - scaling) * (giniParent - + total_left * (1 - left_sum_sqr / Math.pow(total_left, 2)) / totalNumPointsHere - + total_right * (1 - right_sum_sqr / Math.pow(total_right, 2)) / totalNumPointsHere + ) + scaling * balancing + if (this_gini > best_gini) { + best_gini = this_gini + best_dimension = j + best_split_value = nodeDimensionSplits.get((node_id, j)).get.apply(k) + best_left_total = total_left + best_right_total = total_right } } } - if (best_dimension != -1) { - node.split = Some(new SplitValue(best_dimension, true, best_split_value)) - treeCopy.nodes.put(node_id * 2, new Node(node_id * 2, treeCopy.treeID, None)) - treeCopy.nodes.put(node_id * 2 + 1, new Node(node_id * 2 + 1, treeCopy.treeID, None)) - any_unlabeled_left = true - } else { - node.predict = maxClassCountLabel + labelAddition - } + } + if (best_dimension != -1) { + node.split = Some(new SplitValue(best_dimension, true, best_split_value)) + tree.nodes.put(node_id * 2, new Node(node_id * 2, tree.treeID, None)) + tree.nodes.put(node_id * 2 + 1, new Node(node_id * 2 + 1, tree.treeID, None)) + any_split_done = true + } else { + node.predict = maxClassCountLabel + labelAddition } } - ) - tree = tree.getExecutionEnvironment.fromElements(treeCopy) - } - - DecisionTreeModel(tree) + } + ) + return any_split_done } private def localHistUpdate( @@ -387,11 +461,13 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial * because we broadcast the current value of the tree to all mappers. 
* */ - val localUpdate = new RichMapFunction[(Block[LabeledVector], Int), mutable.HashMap[(Int, Int, Double), Histogram]] { + val localUpdate = new RichMapFunction[(Block[LabeledVector], Int), + mutable.HashMap[(Int, Int, Double), Histogram]] { var tree: Tree = _ var config: TreeConfiguration = _ - var histograms: mutable.HashMap[(Int, Int, Double), Histogram] = new mutable.HashMap[(Int, Int, Double), Histogram]() + var histograms: mutable.HashMap[(Int, Int, Double), Histogram] = + new mutable.HashMap[(Int, Int, Double), Histogram]() override def open(parameters: Configuration): Unit = { // get the tree @@ -409,7 +485,8 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial // find where this instance goes val (node, predict) = tree.filter(vector) if (predict.round.toInt == -1) { - // we can be sure that this is an unlabeled leaf and not some internal node. That's how filter works + // we can be sure that this is an unlabeled leaf and not some internal node. 
+ // That's how filter works for (j <- 0 to tree.config.dimension - 1) { val min_value = config.fieldStats.apply(j).fieldMinValue val max_value = config.fieldStats.apply(j).fieldMaxValue @@ -448,8 +525,10 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial } override def map(labeledVector: LabeledVector): (Array[FieldStats], Int) = { - require(labeledVector.vector.size == config.dimension, "Specify the dimension of data correctly") - val ret = new Array[FieldStats](config.dimension + 1) // we also include the label field for now + require(labeledVector.vector.size == config.dimension, + "Specify the dimension of data correctly") + // we also include the label field for now + val ret = new Array[FieldStats](config.dimension + 1) for (i <- 0 to config.dimension - 1) { if (config.category.indexOf(i) >= 0) { // if this is a categorical field @@ -458,12 +537,14 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial ret(i) = new FieldStats(false, fieldCategories = h) } else { // if continuous field - ret(i) = new FieldStats(true, labeledVector.vector.apply(i), labeledVector.vector.apply(i)) + ret(i) = new FieldStats( + true, labeledVector.vector.apply(i), labeledVector.vector.apply(i)) } } val h = new mutable.HashMap[Double, Int]() h.put(labeledVector.label, 1) - ret(config.dimension) = new FieldStats(false, fieldCategories = h) // the label field is always categorical + // the label field is always categorical + ret(config.dimension) = new FieldStats(false, fieldCategories = h) (ret, 1) } @@ -498,12 +579,15 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial config.setNumTrainVector(combinedStats_tuple._2) // cross check user-specified number of classes - require(combinedStats.apply(config.dimension).fieldCategories.size == config.numClasses, "Specify number of classes correctly") + require(combinedStats.apply(config.dimension).fieldCategories.size == config.numClasses, + 
"Specify number of classes correctly") // now copy the label list to an array of double // Find the minimum value, negative of which will be set to labelAddition var min_label = java.lang.Double.MAX_VALUE - val labels_list = new util.ArrayList[Double](combinedStats.apply(config.dimension).fieldCategories.size) + val labels_list = new util.ArrayList[Double]( + combinedStats.apply(config.dimension).fieldCategories.size) + combinedStats.apply(config.dimension).fieldCategories.keysIterator.foreach( x => { if (x < min_label) min_label = x diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala index e08a0cb70217e..3f72c230e95e9 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala @@ -31,7 +31,8 @@ class OnlineHistogram( val capacity: Int, val min: Double = -java.lang.Double.MAX_VALUE, val max: Double = java.lang.Double.MAX_VALUE, - val data: util.ArrayList[Tuple2[Double, Int]] = new util.ArrayList[Tuple2[Double, Int]] + val data: util.ArrayList[Tuple2[Double, Int]] = + new util.ArrayList[Tuple2[Double, Int]] ) extends Histogram with Serializable { require(checkSanity, "Invalid data provided") @@ -87,7 +88,8 @@ class OnlineHistogram( require(value > lower && value < upper, value + " not in (" + lower + "," + upper + ")") val search = find(value) data.add(search, new Tuple2[Double, Int](value, 1)) - mergeElements() // if the value we just added is already there, mergeElements will take care of this + // if the value we just added is already there, mergeElements will take care of this + mergeElements() } /** Merges the histogram with h and returns a histogram with B bins @@ -142,13 +144,16 @@ class OnlineHistogram( val c: Double = -2 * neededSum if (a == 0) { return getValue(i - 1) + (getValue(i) - getValue(i - 
1)) * (-c / b) - } else return getValue(i - 1) + (getValue(i) - getValue(i - 1)) * (-b + Math.sqrt(b * b - 4 * a * c)) / (2 * a) + } else return getValue(i - 1) + + (getValue(i) - getValue(i - 1)) * (-b + Math.sqrt(b * b - 4 * a * c)) / (2 * a) } else currSum = tmpSum } require(upper < java.lang.Double.MAX_VALUE, "Set an upper bound before proceeding") // this means wantedSum > sum(getValue(bins-1)) - // this will likely fail to return a bounded value. Make sure you set some proper limits on min and max. - getValue(bins - 1) + Math.sqrt(Math.pow(upper - getValue(bins - 1), 2) * 2 * (wantedSum - currSum) / getCounter(bins - 1)) + // this will likely fail to return a bounded value. + // Make sure you set some proper limits on min and max. + getValue(bins - 1) + Math.sqrt( + Math.pow(upper - getValue(bins - 1), 2) * 2 * (wantedSum - currSum) / getCounter(bins - 1)) } /** Returns the string representation of the histogram. @@ -179,11 +184,15 @@ class OnlineHistogram( s = m_b * (b - lower) / (2 * (getValue(index + 1) - lower)) return s.round.toInt } else if (index == bins - 1) { - m_b = getCounter(index) + (-getCounter(index)) * (b - getValue(index)) / (upper - getValue(index)) - s = (getCounter(index) + m_b) * (b - getValue(index)) / (2 * (upper - getValue(index))) + m_b = getCounter(index) + + (-getCounter(index)) * (b - getValue(index)) / (upper - getValue(index)) + s = (getCounter(index) + m_b) * + (b - getValue(index)) / (2 * (upper - getValue(index))) } else { - m_b = getCounter(index) + (getCounter(index + 1) - getCounter(index)) * (b - getValue(index)) / (getValue(index + 1) - getValue(index)) - s = (getCounter(index) + m_b) * (b - getValue(index)) / (2 * (getValue(index + 1) - getValue(index))) + m_b = getCounter(index) + (getCounter(index + 1) - getCounter(index)) * + (b - getValue(index)) / (getValue(index + 1) - getValue(index)) + s = (getCounter(index) + m_b) * + (b - getValue(index)) / (2 * (getValue(index + 1) - getValue(index))) } for (i <- 0 to index - 
1) { s = s + getCounter(i) @@ -236,7 +245,9 @@ class OnlineHistogram( } if (bins > capacity || diff < 1e-9) { val merged_tuple: Tuple2[Double, Int] = mergeBins(index) - set(index, merged_tuple.getField(0).asInstanceOf[Double] / merged_tuple.getField(1).asInstanceOf[Int], merged_tuple.getField(1)) + set(index, + merged_tuple.getField(0).asInstanceOf[Double] / merged_tuple.getField(1).asInstanceOf[Int], + merged_tuple.getField(1)) data.remove(index + 1) true } else false diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala index 671307c28663b..58415920e6e7c 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala @@ -21,7 +21,8 @@ package org.apache.flink.ml.tree /** Keeps useful statistics about a field * fieldType is false for categorical fields, true for continuous fields - * For continuous field, minimum and maximum values. Usually, min-(max-min) and max+(max-min) from the data should suffice + * For continuous field, minimum and maximum values. 
+ * Usually, min-(max-min) and max+(max-min) from the data should suffice * For categorical field, list of categories * */ @@ -30,7 +31,8 @@ class FieldStats( val fieldType: Boolean, val fieldMinValue: Double = -java.lang.Double.MAX_VALUE, val fieldMaxValue: Double = java.lang.Double.MAX_VALUE, - val fieldCategories: collection.mutable.HashMap[Double, Int] = new collection.mutable.HashMap[Double, Int]) { + val fieldCategories: collection.mutable.HashMap[Double, Int] = + new collection.mutable.HashMap[Double, Int]) { override def toString: String = { if (fieldType) diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala index ebe75cdbedcea..b9466bb25c782 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala @@ -19,8 +19,10 @@ package org.apache.flink.ml.tree /** If the node has been trained, it will have: - * a. predict >=0, in this case, split should be empty. This is fully grown node and we can't go further down - * b. predict = -1, in this case if split is empty, we need to split, otherwise, this is an internal node + * a. predict >=0, in this case, split should be empty. This is fully grown node and we can't go + * further down + * b. predict = -1, in this case if split is empty, we need to split, otherwise, this is an + * internal node * * ID starts from 1 for the root node. 
* treeID is the tree to which this node belongs diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala index c44f28b9aa978..2c9b13645b41d 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala @@ -18,12 +18,13 @@ package org.apache.flink.ml.tree -import org.apache.flink.ml.math.Vector +import org.apache.flink.ml.math.{Histogram, Vector} import scala.collection.mutable -/** Tree structure. This is kind of maintained in an unconventional way. We provide direct access to all nodes - * The obvious assumption is that child of node i will be 2*i and 2*i+1, while parent of i will be i/2 +/** Tree structure. This is kind of maintained in an unconventional way. + * We provide direct access to all nodes + * The obvious assumption is that child of i are 2*i and 2*i+1, while parent of i is i/2 * */ class Tree( @@ -32,6 +33,7 @@ class Tree( val config: TreeConfiguration ) extends Serializable { + override def toString: String = { var ret = s"Tree ID=$treeID\nConfiguration:\n$config \nTree Structure:\n" for (i <- 1 to Math.pow(2, 20).toInt) { diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala index 88ea141f7c8c8..36f5c4135a06e 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala @@ -31,7 +31,8 @@ package org.apache.flink.ml.tree * @param category Which fields are to be considered as categorical. 
Array of field indexes * @param fieldStats Field maximum and minimum values, list of categories [Automatically] * @param labels Array of labels slash classes in data [Automatically] - * @param labelAddition Addition term to make all labels >=0 [Automatically][Only for internal use. Not visible to user] + * @param labelAddition Addition term to make all labels >=0 [Automatically] + * [Only for internal use. Not visible to user] * @param numTrainVector Number of training instances [Automatically] */ class TreeConfiguration( @@ -50,9 +51,12 @@ class TreeConfiguration( ) { override def toString: String = { - var ret = s"Maximum Binning: $MaxBins, Minimum Instance per leaf node: $MinInstancePerNode, Maximum Depth: $Depth, Pruning:$Pruning" + - s"\nSplit Strategy: $splitStrategy, Number of classes: $numClasses, Dimension of data: $dimension, Number of training vectors: $numTrainVector\n" + - s"categorical fields: " + java.util.Arrays.toString(category) + "\nLabels in data: " + java.util.Arrays.toString(labels) + "\nField stats:" + var ret = s"Maximum Binning: $MaxBins, Minimum Instance per leaf node: $MinInstancePerNode, " + + s"Maximum Depth: $Depth, Pruning:$Pruning\nSplit strategy: $splitStrategy, Number of " + + s"classes: $numClasses, Dimension of data: $dimension, Number of training vectors:"+ + s"$numTrainVector\n"+s"categorical fields: " + java.util.Arrays.toString(category) + + "\nLabels in data: " + java.util.Arrays.toString(labels) + "\nField stats:" + fieldStats.iterator.foreach(x => ret = ret + x.toString) ret + s"\nLabel Addition: $labelAddition" } From 52a644f45ad8837b84d9ab17499f0001b4ce3b21 Mon Sep 17 00:00:00 2001 From: Sachin Goel Date: Thu, 21 May 2015 21:21:51 +0530 Subject: [PATCH 4/5] Fixed two more scalastyle errors --- .../main/scala/org/apache/flink/ml/math/OnlineHistogram.scala | 3 ++- .../org/apache/flink/ml/classification/DecisionTreeSuite.scala | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git 
a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala index 3f72c230e95e9..a34d415e96a81 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala @@ -201,7 +201,8 @@ class OnlineHistogram( s.round.toInt } - /** Updates the given bin with the provided value and counter. Sets `v_bin`=value and `m_bin`=counter + /** Updates the given bin with the provided value and counter. + * Sets `v_bin`=value and `m_bin`=counter * * @param bin bin to be updated * @param value value to be set at bin diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala index b1d13cc94e4e6..94ee30b42ded9 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/DecisionTreeSuite.scala @@ -36,7 +36,8 @@ class DecisionTreeSuite extends FlatSpec with Matchers with FlinkTestBase { val model = learner.fit(trainingDS) - val predict = model.testAccuracy(env.fromCollection(Classification.IrisTestingData).setParallelism(4)) + val predict = model.testAccuracy(env.fromCollection( + Classification.IrisTestingData).setParallelism(4)) println(s"Testing accuracy: $predict%") } } From 169f8e911e113d23f20ef64f3611ea6b294f08d6 Mon Sep 17 00:00:00 2001 From: Sachin Goel Date: Thu, 21 May 2015 22:12:38 +0530 Subject: [PATCH 5/5] More styling errors fixed. 
I'm the wielder of the flame of Anor and ye shall pass --- .../ml/classification/DecisionTree.scala | 60 ++++++++++--------- .../flink/ml/math/OnlineHistogram.scala | 2 +- .../org/apache/flink/ml/tree/FieldStats.scala | 8 ++- .../scala/org/apache/flink/ml/tree/Node.scala | 2 +- .../org/apache/flink/ml/tree/SplitValue.scala | 14 +++-- .../scala/org/apache/flink/ml/tree/Tree.scala | 11 ++-- .../flink/ml/tree/TreeConfiguration.scala | 10 ++-- .../flink/ml/math/OnlineHistogramSuite.scala | 2 +- 8 files changed, 60 insertions(+), 49 deletions(-) diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala index bfdde73e13503..88377055ee4ae 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/DecisionTree.scala @@ -263,10 +263,12 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial nodes.get(this_node.id).isEmpty) { // we're in trouble var sibling_id = 0 - if (this_node.id % 2 == 0) + if (this_node.id % 2 == 0) { sibling_id = this_node.id + 1 - else + } + else { sibling_id = this_node.id - 1 + } // this node is pointless. Remove it from the tree tree.nodes.remove(this_node.id) // we're not going to split the sibling anymore. 
@@ -327,33 +329,6 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial (nodeDimensionSplits, nodes) } - private def findGini( - node_id: Int, - finalHists: mutable.HashMap[(Int, Int, Double), Histogram], - labels: Array[Double]): - (Double, Double, Double, Double) = { - var sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount = 0.0 - // since the count of classes across any dimension is same, pick 0 - // for calculating Gini index, we need count(c)^2 and \sum count(c)^2 - // also maintain which class occurred most frequently in case we need to mark this as a leaf - // node - labels.iterator.foreach( - x => { - val h = finalHists.get((node_id, 0, x)) - if (h.nonEmpty) { - val countOfClass = h.get.sum(h.get.upper) - totalNumPointsHere = totalNumPointsHere + countOfClass - sumClassSquare = sumClassSquare + countOfClass * countOfClass - if (countOfClass > maxClassCount) { - maxClassCount = countOfClass - maxClassCountLabel = x - } - } - } - ) - (sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount) - } - private def evaluateNodes( nodes: mutable.HashMap[Int, Int], tree: Tree, @@ -452,6 +427,33 @@ class DecisionTree extends Learner[LabeledVector, DecisionTreeModel] with Serial return any_split_done } + private def findGini( + node_id: Int, + finalHists: mutable.HashMap[(Int, Int, Double), Histogram], + labels: Array[Double]): + (Double, Double, Double, Double) = { + var sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount = 0.0 + // since the count of classes across any dimension is same, pick 0 + // for calculating Gini index, we need count(c)^2 and \sum count(c)^2 + // also maintain which class occurred most frequently in case we need to mark this as a leaf + // node + labels.iterator.foreach( + x => { + val h = finalHists.get((node_id, 0, x)) + if (h.nonEmpty) { + val countOfClass = h.get.sum(h.get.upper) + totalNumPointsHere = totalNumPointsHere + countOfClass + sumClassSquare = 
sumClassSquare + countOfClass * countOfClass + if (countOfClass > maxClassCount) { + maxClassCount = countOfClass + maxClassCountLabel = x + } + } + } + ) + (sumClassSquare, totalNumPointsHere, maxClassCountLabel, maxClassCount) + } + private def localHistUpdate( tree: DataSet[Tree], blockedInputNumberElements: DataSet[(Block[LabeledVector], Int)]) diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala index a34d415e96a81..5e745ff710410 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/OnlineHistogram.scala @@ -283,4 +283,4 @@ class OnlineHistogram( while (mergeElements()) {} true } -} \ No newline at end of file +} diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala index 58415920e6e7c..dee7713916215 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/FieldStats.scala @@ -35,9 +35,11 @@ class FieldStats( new collection.mutable.HashMap[Double, Int]) { override def toString: String = { - if (fieldType) + if (fieldType) { s"Continuous field: Range: ($fieldMinValue,$fieldMaxValue)" - else + } + else { s"Categorical field: Number of categories: $fieldCategories" + } } -} \ No newline at end of file +} diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala index b9466bb25c782..a064359beb092 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Node.scala @@ -45,4 +45,4 @@ class Node( // depth 
starts from one. A matter of convention really java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(id)) + 1 } -} \ No newline at end of file +} diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala index 86ae751b60567..4894cb49a309e 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/SplitValue.scala @@ -43,16 +43,20 @@ class SplitValue( val splitValueList: util.ArrayList[Double] = new util.ArrayList[Double]) { override def toString: String = { - if (splitType) + if (splitType){ s"Attribute Index: $attribute, Split: Continuous Value at $splitValueDouble" - else + } + else{ s"Attribute Index: $attribute, Split: Categorical at $splitValueList" + } } def getSplitDirection(vector: Vector): Boolean = { - if (splitType) + if (splitType){ vector.apply(attribute) <= splitValueDouble // go left if less than equal to - else + } + else{ splitValueList.contains(vector.apply(attribute)) // go left is exists + } } -} \ No newline at end of file +} diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala index 2c9b13645b41d..1f6c1e55f1e54 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/Tree.scala @@ -37,8 +37,9 @@ class Tree( override def toString: String = { var ret = s"Tree ID=$treeID\nConfiguration:\n$config \nTree Structure:\n" for (i <- 1 to Math.pow(2, 20).toInt) { - if (nodes.get(i).nonEmpty) + if (nodes.get(i).nonEmpty){ ret = ret + nodes.get(i).get.toString + "\n" + } } ret } @@ -50,11 +51,13 @@ class Tree( def filter(vector: Vector): (Int, Double) = { var node: Node = nodes.get(1).get while 
(node.predict.round.toInt == -1 && node.split.nonEmpty) { - if (node.split.get.getSplitDirection(vector)) + if (node.split.get.getSplitDirection(vector)){ node = nodes.get(2 * node.id).get - else + } + else{ node = nodes.get(2 * node.id + 1).get + } } (node.id, node.predict) } -} \ No newline at end of file +} diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala index 36f5c4135a06e..a38ca890b89eb 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/tree/TreeConfiguration.scala @@ -51,10 +51,10 @@ class TreeConfiguration( ) { override def toString: String = { - var ret = s"Maximum Binning: $MaxBins, Minimum Instance per leaf node: $MinInstancePerNode, " + - s"Maximum Depth: $Depth, Pruning:$Pruning\nSplit strategy: $splitStrategy, Number of " + - s"classes: $numClasses, Dimension of data: $dimension, Number of training vectors:"+ - s"$numTrainVector\n"+s"categorical fields: " + java.util.Arrays.toString(category) + + var ret = s"Maximum Binning: $MaxBins, Minimum Instance per leaf node: $MinInstancePerNode, " + + s"Maximum Depth: $Depth, Pruning:$Pruning\nSplit strategy: $splitStrategy, Number of " + + s"classes: $numClasses, Dimension of data: $dimension, Number of training vectors:" + + s"$numTrainVector\n Categorical fields: " + java.util.Arrays.toString(category) + "\nLabels in data: " + java.util.Arrays.toString(labels) + "\nField stats:" fieldStats.iterator.foreach(x => ret = ret + x.toString) @@ -76,4 +76,4 @@ class TreeConfiguration( def setLabelAddition(label_add: Double): Unit = { labelAddition = label_add } -} \ No newline at end of file +} diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala 
b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala index 128b6826c0990..099ea6ac41779 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/OnlineHistogramSuite.scala @@ -142,4 +142,4 @@ class OnlineHistogramSuite extends FlatSpec with Matchers { h.sum(h.quantile(0.8)) should equal(12) h.sum(h.quantile(0.95)) should equal(14) } -} \ No newline at end of file +}