Merge remote-tracking branch 'upstream/master' into refreshInsertIntoHiveTable
gatorsmile committed Jan 14, 2017
2 parents 203e36c + ad0dada commit d2d751b
Showing 14 changed files with 558 additions and 43 deletions.
20 changes: 14 additions & 6 deletions R/pkg/R/SQLContext.R
@@ -184,8 +184,11 @@ getDefaultSqlSource <- function() {
#'
#' Converts R data.frame or list into SparkDataFrame.
#'
#' @param data an RDD or list or data.frame.
#' @param data a list or data.frame.
#' @param schema a list of column names or named list (StructType), optional.
#' @param samplingRatio Currently not used.
#' @param numPartitions the number of partitions of the SparkDataFrame. Defaults to 1; this is
#' capped by the length of the list or the number of rows of the data.frame
#' @return A SparkDataFrame.
#' @rdname createDataFrame
#' @export
@@ -195,12 +198,14 @@ getDefaultSqlSource <- function() {
#' df1 <- as.DataFrame(iris)
#' df2 <- as.DataFrame(list(3,4,5,6))
#' df3 <- createDataFrame(iris)
#' df4 <- createDataFrame(cars, numPartitions = 2)
#' }
#' @name createDataFrame
#' @method createDataFrame default
#' @note createDataFrame since 1.4.0
# TODO(davies): support sampling and infer type from NA
createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0,
numPartitions = NULL) {
sparkSession <- getSparkSession()

if (is.data.frame(data)) {
@@ -233,7 +238,11 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {

if (is.list(data)) {
sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
rdd <- parallelize(sc, data)
if (!is.null(numPartitions)) {
rdd <- parallelize(sc, data, numSlices = numToInt(numPartitions))
} else {
rdd <- parallelize(sc, data, numSlices = 1)
}
} else if (inherits(data, "RDD")) {
rdd <- data
} else {
@@ -283,14 +292,13 @@ createDataFrame <- function(x, ...) {
dispatchFunc("createDataFrame(data, schema = NULL)", x, ...)
}

#' @param samplingRatio Currently not used.
#' @rdname createDataFrame
#' @aliases createDataFrame
#' @export
#' @method as.DataFrame default
#' @note as.DataFrame since 1.6.0
as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
createDataFrame(data, schema)
as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) {
createDataFrame(data, schema, samplingRatio, numPartitions)
}

#' @param ... additional argument(s).
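A quick usage sketch of the new numPartitions argument described above (not part of the commit). It assumes an active SparkR session started with sparkR.session(); toRDD and getNumPartitions are the internal helpers the test suite below uses to inspect partitioning, so outside the package namespace they need the SparkR::: prefix:

df <- createDataFrame(cars, numPartitions = 2)    # request 2 partitions
SparkR:::getNumPartitions(SparkR:::toRDD(df))     # 2
df <- createDataFrame(cars, numPartitions = 60)   # cars has only 50 rows,
SparkR:::getNumPartitions(SparkR:::toRDD(df))     # so the request is capped at 50
df <- as.DataFrame(list(3, 4, 5, 6))              # no numPartitions given:
SparkR:::getNumPartitions(SparkR:::toRDD(df))     # defaults to 1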
39 changes: 34 additions & 5 deletions R/pkg/R/context.R
@@ -91,6 +91,16 @@ objectFile <- function(sc, path, minPartitions = NULL) {
#' will write it to disk and send the file name to the JVM. Also, to make sure each slice is not
#' larger than that limit, the number of slices may be increased.
#'
#' In 2.2.0 we are changing how numSlices is used/computed to handle
#' 1 < (length(coll) / numSlices) << length(coll) better, and to get the exact number of slices.
#' This change affects both createDataFrame and spark.lapply.
#' In the specific case where it is used to convert an R native object into a SparkDataFrame, it
#' has always been kept at the default of 1. When the object is large, we explicitly set the
#' parallelism to numSlices (which is still 1).
#'
#' Specifically, we are changing to split positions to match the calculation in positions() of
#' ParallelCollectionRDD in Spark.
#'
#' @param sc SparkContext to use
#' @param coll collection to parallelize
#' @param numSlices number of partitions to create in the RDD
@@ -107,6 +117,8 @@ parallelize <- function(sc, coll, numSlices = 1) {
# TODO: bound/safeguard numSlices
# TODO: unit tests for if the split works for all primitives
# TODO: support matrix, data frame, etc

# Note: for a data.frame, createDataFrame turns it into a list before calling this function.
# nolint start
# suppress lintr warning: Place a space before left parenthesis, except in a function call.
if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) {
@@ -128,12 +140,29 @@ parallelize <- function(sc, coll, numSlices = 1) {
objectSize <- object.size(coll)

# For large objects we make sure the size of each slice is also smaller than sizeLimit
numSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
if (numSlices > length(coll))
numSlices <- length(coll)
numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
if (numSerializedSlices > length(coll))
numSerializedSlices <- length(coll)

# Generate the slice ids to assign each row to.
# For instance, with numSerializedSlices = 22 and a length of 50:
# [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
# [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
# Notice that the slice groups of size 3 (i.e. 6, 15, 22) are roughly evenly spaced.
# We are trying to reimplement the calculation in the positions method of ParallelCollectionRDD in Spark.
splits <- if (numSerializedSlices > 0) {
unlist(lapply(0: (numSerializedSlices - 1), function(x) {
# nolint start
start <- trunc((x * length(coll)) / numSerializedSlices)
end <- trunc(((x + 1) * length(coll)) / numSerializedSlices)
# nolint end
rep(start, end - start)
}))
} else {
1
}

sliceLen <- ceiling(length(coll) / numSlices)
slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)])
slices <- split(coll, splits)

# Serialize each slice: obtain a list of raws, or a list of lists (slices) of
# 2-tuples of raws
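For illustration only (not part of the commit), the split-position arithmetic above can be reproduced standalone in R. slicePositions is a hypothetical helper name; n and k stand in for length(coll) and numSerializedSlices, and the call reproduces the worked example from the comment:

slicePositions <- function(n, k) {
  # Mirrors the positions() calculation of ParallelCollectionRDD: slice x covers
  # the half-open row range [trunc(x * n / k), trunc((x + 1) * n / k)).
  unlist(lapply(0:(k - 1), function(x) {
    start <- trunc((x * n) / k)
    end <- trunc(((x + 1) * n) / k)
    rep(start, end - start)   # every row in the range gets slice id `start`
  }))
}

splits <- slicePositions(50, 22)
length(unique(splits))   # 22 distinct slice ids, so split(coll, splits) yields exactly 22 slices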
4 changes: 2 additions & 2 deletions R/pkg/inst/tests/testthat/test_rdd.R
@@ -381,8 +381,8 @@ test_that("aggregateRDD() on RDDs", {
test_that("zipWithUniqueId() on RDDs", {
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
actual <- collectRDD(zipWithUniqueId(rdd))
expected <- list(list("a", 0), list("b", 3), list("c", 1),
list("d", 4), list("e", 2))
expected <- list(list("a", 0), list("b", 1), list("c", 4),
list("d", 2), list("e", 5))
expect_equal(actual, expected)

rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
23 changes: 22 additions & 1 deletion R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -196,6 +196,26 @@ test_that("create DataFrame from RDD", {
expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
expect_equal(as.list(collect(where(df, df$name == "John"))),
list(name = "John", age = 19L, height = 176.5))
expect_equal(getNumPartitions(toRDD(df)), 1)

df <- as.DataFrame(cars, numPartitions = 2)
expect_equal(getNumPartitions(toRDD(df)), 2)
df <- createDataFrame(cars, numPartitions = 3)
expect_equal(getNumPartitions(toRDD(df)), 3)
# validate limit by num of rows
df <- createDataFrame(cars, numPartitions = 60)
expect_equal(getNumPartitions(toRDD(df)), 50)
# validate when 1 < (length(coll) / numSlices) << length(coll)
df <- createDataFrame(cars, numPartitions = 20)
expect_equal(getNumPartitions(toRDD(df)), 20)

df <- as.DataFrame(data.frame(0))
expect_is(df, "SparkDataFrame")
df <- createDataFrame(list(list(1)))
expect_is(df, "SparkDataFrame")
df <- as.DataFrame(data.frame(0), numPartitions = 2)
# no data to partition, goes to 1
expect_equal(getNumPartitions(toRDD(df)), 1)

setHiveContext(sc)
sql("CREATE TABLE people (name string, age double, height float)")
@@ -213,7 +233,8 @@ test_that("createDataFrame uses files for large objects", {
# To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value
conf <- callJMethod(sparkSession, "conf")
callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100")
df <- suppressWarnings(createDataFrame(iris))
df <- suppressWarnings(createDataFrame(iris, numPartitions = 3))
expect_equal(getNumPartitions(toRDD(df)), 3)

# Resetting the conf back to default value
callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10))
@@ -835,6 +835,190 @@ public UTF8String translate(Map<Character, Character> dict) {
return fromString(sb.toString());
}

private int getDigit(byte b) {
if (b >= '0' && b <= '9') {
return b - '0';
}
throw new NumberFormatException(toString());
}

/**
* Parses this UTF8String to long.
*
* Note that in this method we accumulate the result as a negative value, and convert it to a
* positive value at the end if this string does not start with '-'. This is because the minimum
* value has a larger magnitude than the maximum value, e.g. Integer.MAX_VALUE is '2147483647'
* while Integer.MIN_VALUE is '-2147483648'.
*
* This code is mostly copied from LazyLong.parseLong in Hive.
*/
public long toLong() {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
}

byte b = getByte(0);
final boolean negative = b == '-';
int offset = 0;
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
}
}

final byte separator = '.';
final int radix = 10;
final long stopValue = Long.MIN_VALUE / radix;
long result = 0;

while (offset < numBytes) {
b = getByte(offset);
offset++;
if (b == separator) {
// We allow decimals and will return a truncated integral in that case.
// Therefore we won't throw an exception here (checking the fractional
// part happens below.)
break;
}

int digit = getDigit(b);
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than stopValue (Long.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than Long.MIN_VALUE, so we can stop and throw an exception.
if (result < stopValue) {
throw new NumberFormatException(toString());
}

result = result * radix - digit;
// Since the previous result is greater than or equal to stopValue (Long.MIN_VALUE / radix), any
// overflow in `result * radix - digit` wraps around to a positive value, so we can just use
// `result > 0` to check for overflow. If the result overflows, we stop and throw an exception.
if (result > 0) {
throw new NumberFormatException(toString());
}
}

// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
}
offset++;
}

if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
}
}

return result;
}

/**
* Parses this UTF8String to int.
*
* Note that in this method we accumulate the result as a negative value, and convert it to a
* positive value at the end if this string does not start with '-'. This is because the minimum
* value has a larger magnitude than the maximum value, e.g. Integer.MAX_VALUE is '2147483647'
* while Integer.MIN_VALUE is '-2147483648'.
*
* This code is mostly copied from LazyInt.parseInt in Hive.
*
* Note that this method is almost the same as `toLong`, but we leave it duplicated for performance
* reasons, as Hive does.
*/
public int toInt() {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
}

byte b = getByte(0);
final boolean negative = b == '-';
int offset = 0;
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
}
}

final byte separator = '.';
final int radix = 10;
final int stopValue = Integer.MIN_VALUE / radix;
int result = 0;

while (offset < numBytes) {
b = getByte(offset);
offset++;
if (b == separator) {
// We allow decimals and will return a truncated integral in that case.
// Therefore we won't throw an exception here (checking the fractional
// part happens below.)
break;
}

int digit = getDigit(b);
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than stopValue (Integer.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than Integer.MIN_VALUE, so we can stop and throw an exception.
if (result < stopValue) {
throw new NumberFormatException(toString());
}

result = result * radix - digit;
// Since the previous result is greater than or equal to stopValue (Integer.MIN_VALUE / radix),
// any overflow in `result * radix - digit` wraps around to a positive value, so we can just use
// `result > 0` to check for overflow. If the result overflows, we stop and throw an exception.
if (result > 0) {
throw new NumberFormatException(toString());
}
}

// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
}
offset++;
}

if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
}
}

return result;
}

public short toShort() {
int intValue = toInt();
short result = (short) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
}

return result;
}

public byte toByte() {
int intValue = toInt();
byte result = (byte) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
}

return result;
}

@Override
public String toString() {
return new String(getBytes(), StandardCharsets.UTF_8);
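A worked illustration of the negative-accumulation note in the comments above (not part of the commit): parsing "-2147483648" with result = result * radix - digit keeps every intermediate value representable as an int:

0 -> -2 -> -21 -> -214 -> -2147 -> -21474 -> -214748 -> -2147483 -> -21474836 -> -214748364 -> -2147483648

Accumulating a positive running value instead would require 2147483648 on the last digit, which already exceeds Integer.MAX_VALUE (2147483647). That is why both toLong and toInt build the negative value first and only negate it at the end when the input does not start with '-'.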
2 changes: 1 addition & 1 deletion python/pyspark/sql/context.py
@@ -73,7 +73,7 @@ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None):
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
if sparkSession is None:
sparkSession = SparkSession(sparkContext)
sparkSession = SparkSession.builder.getOrCreate()
if jsqlContext is None:
jsqlContext = sparkSession._jwrapped
self.sparkSession = sparkSession
7 changes: 6 additions & 1 deletion python/pyspark/sql/tests.py
@@ -47,7 +47,7 @@
import unittest

from pyspark import SparkContext
from pyspark.sql import SparkSession, HiveContext, Column, Row
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
@@ -206,6 +206,11 @@ def tearDownClass(cls):
cls.spark.stop()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)

def test_sqlcontext_reuses_sparksession(self):
sqlContext1 = SQLContext(self.sc)
sqlContext2 = SQLContext(self.sc)
self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)

def test_row_should_be_read_only(self):
row = Row(a=1, b=2)
self.assertEqual(1, row.a)
