
Commit

Merge branch 'master' of github.com:apache/spark into fix-input-metrics-coalesce

Andrew Or committed Jan 29, 2016
2 parents c31a410 + b9dfdcc commit c5a97fc
Showing 19 changed files with 590 additions and 76 deletions.
@@ -115,15 +115,45 @@ int getVersionNumber() {
public abstract long totalCount();

/**
* Adds 1 to {@code item}.
* Increments {@code item}'s count by one.
*/
public abstract void add(Object item);

/**
* Adds {@code count} to {@code item}.
* Increments {@code item}'s count by {@code count}.
*/
public abstract void add(Object item, long count);

/**
* Increments {@code item}'s count by one.
*/
public abstract void addLong(long item);

/**
* Increments {@code item}'s count by {@code count}.
*/
public abstract void addLong(long item, long count);

/**
* Increments {@code item}'s count by one.
*/
public abstract void addString(String item);

/**
* Increments {@code item}'s count by {@code count}.
*/
public abstract void addString(String item, long count);

/**
* Increments {@code item}'s count by one.
*/
public abstract void addBinary(byte[] item);

/**
* Increments {@code item}'s count by {@code count}.
*/
public abstract void addBinary(byte[] item, long count);

/**
* Returns the estimated frequency of {@code item}.
*/
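For context, a minimal usage sketch of the new typed add methods. It assumes the module's CountMinSketch.create(depth, width, seed) factory and estimateCount(item), neither of which is shown in this hunk:

import org.apache.spark.util.sketch.CountMinSketch

// Assumed factory: depth rows, width counters per row, fixed hash seed.
val sketch = CountMinSketch.create(10, 1000, 42)

sketch.addLong(42L)          // increment the count of 42 by one
sketch.addLong(42L, 10L)     // ...and by ten more
sketch.addString("spark")    // strings are hashed via their UTF-8 bytes
sketch.addBinary(Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))

// A count-min sketch never undercounts; collisions can only inflate estimates.
assert(sketch.estimateCount(42L) >= 11L)
assert(sketch.estimateCount("spark") >= 1L)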
@@ -25,7 +25,6 @@
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.Random;

@@ -146,27 +145,49 @@ public void add(Object item, long count) {
}
}

private void addString(String item, long count) {
@Override
public void addString(String item) {
addString(item, 1);
}

@Override
public void addString(String item, long count) {
addBinary(Utils.getBytesFromUTF8String(item), count);
}

@Override
public void addLong(long item) {
addLong(item, 1);
}

@Override
public void addLong(long item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][buckets[i]] += count;
table[i][hash(item, i)] += count;
}

totalCount += count;
}

private void addLong(long item, long count) {
@Override
public void addBinary(byte[] item) {
addBinary(item, 1);
}

@Override
public void addBinary(byte[] item, long count) {
if (count < 0) {
throw new IllegalArgumentException("Negative increments not implemented");
}

int[] buckets = getHashBuckets(item, depth, width);

for (int i = 0; i < depth; ++i) {
table[i][hash(item, i)] += count;
table[i][buckets[i]] += count;
}

totalCount += count;
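Both overloads above follow the same update pattern: each of the depth rows bumps one counter at a per-row hash position, and a query takes the minimum across rows. A toy, self-contained Scala illustration of that pattern (not Spark's implementation; the row hash is a stand-in):

class ToyCountMin(depth: Int, width: Int) {
  private val table = Array.ofDim[Long](depth, width)

  // Stand-in per-row hash; the real implementation seeds each row differently.
  private def hash(item: Long, row: Int): Int =
    (((item * (2 * row + 1) + row) % width + width) % width).toInt

  def add(item: Long, count: Long = 1L): Unit = {
    require(count >= 0, "Negative increments not implemented")
    for (i <- 0 until depth) table(i)(hash(item, i)) += count
  }

  // Minimum across rows: each row overcounts on collisions, never undercounts.
  def estimate(item: Long): Long =
    (0 until depth).map(i => table(i)(hash(item, i))).min
}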
@@ -21,7 +21,7 @@ import java.sql.Connection
import java.util.Properties

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.tags.DockerTest

@DockerTest
@@ -39,20 +39,21 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override def dataPreparation(conn: Connection): Unit = {
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.setCatalog("foo")
conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate()
conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+ "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+ "c10 integer[], c11 text[], c12 real[])").executeUpdate()
+ "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate()
conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+ """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}')""").executeUpdate()
+ """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate()
}

test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass)
assert(types.length == 13)
assert(types.length == 14)
assert(classOf[String].isAssignableFrom(types(0)))
assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
@@ -66,22 +67,24 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
assert(classOf[Seq[String]].isAssignableFrom(types(11)))
assert(classOf[Seq[Double]].isAssignableFrom(types(12)))
assert(classOf[String].isAssignableFrom(types(13)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
assert(rows(0).getLong(3) == 123456789012345L)
assert(rows(0).getBoolean(4) == false)
assert(!rows(0).getBoolean(4))
// BIT(10) values come back as strings of ten ASCII '0' and '1' characters...
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5),
Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49)))
assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6),
Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
assert(rows(0).getBoolean(7) == true)
assert(rows(0).getBoolean(7))
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
assert(rows(0).getSeq(10) == Seq(1, 2))
assert(rows(0).getSeq(11) == Seq("a", null, "b"))
assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f))
assert(rows(0).getString(13) == "d1")
}

test("Basic write test") {
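Outside the suite, reading the new enum column needs nothing special; a minimal sketch assuming the bar table above and a reachable PostgreSQL jdbcUrl:

import java.util.Properties
import org.apache.spark.sql.SQLContext

// Assumes the `bar` table created in dataPreparation above.
def readEnumValues(sqlContext: SQLContext, jdbcUrl: String): Array[String] = {
  val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
  // The enum_type column c13 surfaces as a plain string column.
  df.select("c13").collect().map(_.getString(0))  // Array("d1")
}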
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.optim

import org.apache.spark.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD

/**
* Model fitted by [[IterativelyReweightedLeastSquares]].
* @param coefficients model coefficients
* @param intercept model intercept
*/
private[ml] class IterativelyReweightedLeastSquaresModel(
val coefficients: DenseVector,
val intercept: Double) extends Serializable

/**
* Implements the method of iteratively reweighted least squares (IRLS), which solves certain
* optimization problems iteratively. Each iteration solves a weighted least squares (WLS)
* problem by [[WeightedLeastSquares]].
* It can be used to find maximum likelihood estimates of a generalized linear model (GLM),
* to find M-estimators in robust regression, and to solve other optimization problems.
*
* @param initialModel the initial guess model.
* @param reweightFunc the reweight function which is used to update offsets and weights
* at each iteration.
* @param fitIntercept whether to fit intercept.
* @param regParam L2 regularization parameter used by WLS.
* @param maxIter maximum number of iterations.
* @param tol the convergence tolerance.
*
* @see [[http://www.jstor.org/stable/2345503 P. J. Green, Iteratively Reweighted Least Squares
* for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives,
* Journal of the Royal Statistical Society. Series B, 1984.]]
*/
private[ml] class IterativelyReweightedLeastSquares(
val initialModel: WeightedLeastSquaresModel,
val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double),
val fitIntercept: Boolean,
val regParam: Double,
val maxIter: Int,
val tol: Double) extends Logging with Serializable {

def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = {

var converged = false
var iter = 0

var model: WeightedLeastSquaresModel = initialModel
var oldModel: WeightedLeastSquaresModel = null

while (iter < maxIter && !converged) {

oldModel = model

// Update offsets and weights using reweightFunc
val newInstances = instances.map { instance =>
val (newOffset, newWeight) = reweightFunc(instance, oldModel)
Instance(newOffset, newWeight, instance.features)
}

// Estimate new model
model = new WeightedLeastSquares(fitIntercept, regParam, standardizeFeatures = false,
standardizeLabel = false).fit(newInstances)

// Check convergence
val oldCoefficients = oldModel.coefficients
val coefficients = model.coefficients
BLAS.axpy(-1.0, coefficients, oldCoefficients)
val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) =>
math.max(math.abs(x), math.abs(y))
}
val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept))

if (maxTol < tol) {
converged = true
logInfo(s"IRLS converged in $iter iterations.")
}

logInfo(s"Iteration $iter : relative tolerance = $maxTol")
iter = iter + 1

if (iter == maxIter) {
logInfo(s"IRLS reached the max number of iterations: $maxIter.")
}

}

new IterativelyReweightedLeastSquaresModel(model.coefficients, model.intercept)
}
}
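To make the reweightFunc contract concrete, here is a hypothetical reweighting for least-absolute-deviation (robust L1) regression: each instance keeps its label as the offset and gets weight 1/|residual|, using the predict helper this commit adds to WeightedLeastSquaresModel. The clamp constant is an assumption for numerical stability, not part of this commit:

import org.apache.spark.ml.feature.Instance

// Hypothetical reweight function for L1 regression via IRLS.
// Returns the (newOffset, newWeight) pair consumed by fit() above.
val l1Reweight: (Instance, WeightedLeastSquaresModel) => (Double, Double) =
  (instance, model) => {
    val residual = instance.label - model.predict(instance.features)
    // Weight ~ 1 / |residual|, clamped so near-zero residuals do not blow up.
    (instance.label, 1.0 / math.max(math.abs(residual), 1e-6))
  }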
@@ -31,7 +31,12 @@ import org.apache.spark.rdd.RDD
private[ml] class WeightedLeastSquaresModel(
val coefficients: DenseVector,
val intercept: Double,
val diagInvAtWA: DenseVector) extends Serializable
val diagInvAtWA: DenseVector) extends Serializable {

def predict(features: Vector): Double = {
BLAS.dot(coefficients, features) + intercept
}
}

/**
* Weighted least squares solver via normal equation.
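The new predict helper is just the linear form dot(coefficients, features) + intercept; a small worked sketch with made-up numbers:

import org.apache.spark.mllib.linalg.Vectors

// Hypothetical values; mirrors BLAS.dot(coefficients, features) + intercept.
val coefficients = Vectors.dense(0.5, -1.0)
val features = Vectors.dense(4.0, 1.0)
val intercept = 2.0

val prediction =
  coefficients.toArray.zip(features.toArray).map { case (c, x) => c * x }.sum + intercept
// 0.5 * 4.0 + (-1.0) * 1.0 + 2.0 = 3.0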