Merge branch 'master' into gbt_fit_validation
WeichenXu123 committed May 8, 2018
2 parents fd99f51 + 76ecd09 commit 54f73af
Showing 295 changed files with 9,994 additions and 2,556 deletions.
5 changes: 5 additions & 0 deletions R/pkg/NAMESPACE
@@ -201,6 +201,9 @@ exportMethods("%<=>%",
"approxCountDistinct",
"approxQuantile",
"array_contains",
"array_max",
"array_min",
"array_position",
"asc",
"ascii",
"asin",
@@ -245,6 +248,7 @@ exportMethods("%<=>%",
"decode",
"dense_rank",
"desc",
"element_at",
"encode",
"endsWith",
"exp",
@@ -254,6 +258,7 @@ exportMethods("%<=>%",
"expr",
"factorial",
"first",
"flatten",
"floor",
"format_number",
"format_string",
83 changes: 81 additions & 2 deletions R/pkg/R/functions.R
@@ -189,6 +189,11 @@ NULL
#' the map or array of maps.
#' \item \code{from_json}: it is the column containing the JSON string.
#' }
#' @param value A value to compute on.
#' \itemize{
#'          \item \code{array_contains}: a value to check for in the column.
#' \item \code{array_position}: a value to locate in the given array.
#' }
#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains
#' additional named properties to control how it is converted, accepts the same
#' options as the JSON data source.
@@ -201,14 +206,18 @@ NULL
#' df <- createDataFrame(cbind(model = rownames(mtcars), mtcars))
#' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp))
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21)))
#' head(select(tmp, flatten(tmp$v1)))
#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
#' head(tmp2)
#' head(select(tmp, posexplode(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
#' head(select(tmp3, map_keys(tmp3$v3)))
#' head(select(tmp3, map_values(tmp3$v3)))}
#' head(select(tmp3, map_values(tmp3$v3)))
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))}
NULL

#' Window functions for Column operations
@@ -2975,7 +2984,6 @@ setMethod("row_number",
#' \code{array_contains}: Returns null if the array is null, true if the array contains
#' the value, and false otherwise.
#'
#' @param value a value to be checked if contained in the column
#' @rdname column_collection_functions
#' @aliases array_contains array_contains,Column-method
#' @note array_contains since 1.6.0
@@ -2986,6 +2994,61 @@ setMethod("array_contains",
column(jc)
})

#' @details
#' \code{array_max}: Returns the maximum value of the array.
#'
#' @rdname column_collection_functions
#' @aliases array_max array_max,Column-method
#' @note array_max since 2.4.0
setMethod("array_max",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "array_max", x@jc)
column(jc)
})

#' @details
#' \code{array_min}: Returns the minimum value of the array.
#'
#' @rdname column_collection_functions
#' @aliases array_min array_min,Column-method
#' @note array_min since 2.4.0
setMethod("array_min",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "array_min", x@jc)
column(jc)
})
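
# A minimal usage sketch for array_max()/array_min(), assuming an active SparkR
# session; the data mirrors the tests in test_sparkSQL.R further down.
df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
head(select(df, array_max(df[[1]]), array_min(df[[1]])))
# per the tests below, the maxima are 3 and 6 and the minima are 1 and 4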

#' @details
#' \code{array_position}: Locates the position of the first occurrence of the given value
#' in the given array. Returns NA if either of the arguments is NA.
#' Note: the position is 1-based, not 0-based. Returns 0 if the given
#' value cannot be found in the array.
#'
#' @rdname column_collection_functions
#' @aliases array_position array_position,Column-method
#' @note array_position since 2.4.0
setMethod("array_position",
signature(x = "Column", value = "ANY"),
function(x, value) {
jc <- callJStatic("org.apache.spark.sql.functions", "array_position", x@jc, value)
column(jc)
})
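
# A minimal usage sketch for array_position(), assuming an active SparkR session;
# the result is a 1-based index, with 0 meaning the value was not found.
df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
head(select(df, array_position(df[[1]], 1L)))
# per the tests below: 1 for the first row, 0 for the second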

#' @details
#' \code{flatten}: Transforms an array of arrays into a single array.
#'
#' @rdname column_collection_functions
#' @aliases flatten flatten,Column-method
#' @note flatten since 2.4.0
setMethod("flatten",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "flatten", x@jc)
column(jc)
})
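
# A minimal usage sketch for flatten(), assuming an active SparkR session; the
# data mirrors the tests below. One level of nesting is collapsed.
df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))),
                           list(list(list(5L, 6L), list(7L, 8L)))))
head(select(df, flatten(df[[1]])))
# per the tests below: list(1, 2, 3, 4) and list(5, 6, 7, 8)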

#' @details
#' \code{map_keys}: Returns an unordered array containing the keys of the map.
#'
@@ -3012,6 +3075,22 @@ setMethod("map_values",
column(jc)
})

#' @details
#' \code{element_at}: Returns the element of the array at the index given by \code{extraction}
#' if \code{x} is an array, or the value for the key given by \code{extraction} if \code{x} is a map.
#' Note: the index is 1-based, not 0-based.
#'
#' @param extraction index of the array element or key of the map value to return
#' @rdname column_collection_functions
#' @aliases element_at element_at,Column-method
#' @note element_at since 2.4.0
setMethod("element_at",
signature(x = "Column", extraction = "ANY"),
function(x, extraction) {
jc <- callJStatic("org.apache.spark.sql.functions", "element_at", x@jc, extraction)
column(jc)
})
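
# A minimal usage sketch for element_at(), assuming an active SparkR session;
# the data mirrors the tests below (1-based index for arrays, key for maps).
arr <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
head(select(arr, element_at(arr[[1]], 1L)))   # 1 and 6
mp <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
head(select(mp, element_at(mp$map, "y")))     # 2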

#' @details
#' \code{explode}: Creates a new row for each element in the given array or map column.
#'
20 changes: 20 additions & 0 deletions R/pkg/R/generics.R
@@ -757,6 +757,18 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun
#' @name NULL
setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_max", function(x) { standardGeneric("array_max") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_min", function(x) { standardGeneric("array_min") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_position", function(x, value) { standardGeneric("array_position") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("ascii", function(x) { standardGeneric("ascii") })
@@ -886,6 +898,10 @@ setGeneric("decode", function(x, charset) { standardGeneric("decode") })
#' @name NULL
setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("element_at", function(x, extraction) { standardGeneric("element_at") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("encode", function(x, charset) { standardGeneric("encode") })
@@ -902,6 +918,10 @@ setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") })
#' @name NULL
setGeneric("expr", function(x) { standardGeneric("expr") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("flatten", function(x) { standardGeneric("flatten") })

#' @rdname column_datetime_diff_functions
#' @name NULL
setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") })
26 changes: 24 additions & 2 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1479,24 +1479,46 @@ test_that("column functions", {
df5 <- createDataFrame(list(list(a = "010101")))
expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15")

# Test array_contains() and sort_array()
# Test array_contains(), array_max(), array_min(), array_position(), element_at()
# and sort_array()
df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]]
expect_equal(result, c(TRUE, FALSE))

result <- collect(select(df, array_max(df[[1]])))[[1]]
expect_equal(result, c(3, 6))

result <- collect(select(df, array_min(df[[1]])))[[1]]
expect_equal(result, c(1, 4))

result <- collect(select(df, array_position(df[[1]], 1L)))[[1]]
expect_equal(result, c(1, 0))

result <- collect(select(df, element_at(df[[1]], 1L)))[[1]]
expect_equal(result, c(1, 6))

result <- collect(select(df, sort_array(df[[1]], FALSE)))[[1]]
expect_equal(result, list(list(3L, 2L, 1L), list(6L, 5L, 4L)))
result <- collect(select(df, sort_array(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L), list(4L, 5L, 6L)))

# Test map_keys() and map_values()
# Test flatten()
df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))),
list(list(list(5L, 6L), list(7L, 8L)))))
result <- collect(select(df, flatten(df[[1]])))[[1]]
expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L)))

# Test map_keys(), map_values() and element_at()
df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2)))))
result <- collect(select(df, map_keys(df$map)))[[1]]
expect_equal(result, list(list("x", "y")))

result <- collect(select(df, map_values(df$map)))[[1]]
expect_equal(result, list(list(1, 2)))

result <- collect(select(df, element_at(df$map, "y")))[[1]]
expect_equal(result, 2)

# Test that stats::lag is working
expect_equal(length(lag(ldeaths, 12)), 72)

8 changes: 8 additions & 0 deletions assembly/pom.xml
@@ -254,6 +254,14 @@
<artifactId>spark-hadoop-cloud_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<!--
Redeclare this dependency to force it into the distribution.
-->
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-util</artifactId>
<scope>${hadoop.deps.scope}</scope>
</dependency>
</dependencies>
</profile>
</profiles>
@@ -32,6 +32,7 @@
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import org.apache.commons.lang3.SystemUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -98,6 +99,7 @@ private void init(String hostToBind, int portToBind) {
.group(bossGroup, workerGroup)
.channel(NettyUtils.getServerChannelClass(ioMode))
.option(ChannelOption.ALLOCATOR, allocator)
.option(ChannelOption.SO_REUSEADDR, !SystemUtils.IS_OS_WINDOWS)
.childOption(ChannelOption.ALLOCATOR, allocator);

this.metrics = new NettyMemoryMetrics(
@@ -33,7 +33,11 @@ public static long nextPowerOf2(long num) {
}

public static int roundNumberOfBytesToNearestWord(int numBytes) {
int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
return (int)roundNumberOfBytesToNearestWord((long)numBytes);
}

public static long roundNumberOfBytesToNearestWord(long numBytes) {
long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
if (remainder == 0) {
return numBytes;
} else {
@@ -17,10 +17,12 @@

package org.apache.spark.unsafe.types;

import org.apache.spark.unsafe.Platform;

import java.util.Arrays;

import com.google.common.primitives.Ints;

import org.apache.spark.unsafe.Platform;

public final class ByteArray {

public static final byte[] EMPTY_BYTE = new byte[0];
@@ -77,17 +79,17 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) {

public static byte[] concat(byte[]... inputs) {
// Compute the total length of the result
int totalLength = 0;
long totalLength = 0;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
totalLength += inputs[i].length;
totalLength += (long)inputs[i].length;
} else {
return null;
}
}

// Allocate a new byte array, and copy the inputs one by one into it
final byte[] result = new byte[totalLength];
final byte[] result = new byte[Ints.checkedCast(totalLength)];
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
int len = inputs[i].length;
@@ -29,8 +29,8 @@
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;

import com.google.common.primitives.Ints;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
@@ -877,17 +877,17 @@ public UTF8String lpad(int len, UTF8String pad) {
*/
public static UTF8String concat(UTF8String... inputs) {
// Compute the total length of the result.
int totalLength = 0;
long totalLength = 0;
for (int i = 0; i < inputs.length; i++) {
if (inputs[i] != null) {
totalLength += inputs[i].numBytes;
totalLength += (long)inputs[i].numBytes;
} else {
return null;
}
}

// Allocate a new byte array, and copy the inputs one by one into it.
final byte[] result = new byte[totalLength];
final byte[] result = new byte[Ints.checkedCast(totalLength)];
int offset = 0;
for (int i = 0; i < inputs.length; i++) {
int len = inputs[i].numBytes;
@@ -120,6 +120,8 @@ private void check(MemoryBlock memory, Object obj, long offset, int length) {
} catch (Exception expected) {
Assert.assertThat(expected.getMessage(), containsString("should not be larger than"));
}

memory.setPageNumber(MemoryBlock.NO_PAGE_NUMBER);
}

@Test
@@ -165,11 +167,13 @@ public void testOffHeapArrayMemoryBlock() {
int length = 56;

check(memory, obj, offset, length);
memoryAllocator.free(memory);

long address = Platform.allocateMemory(112);
memory = new OffHeapMemoryBlock(address, length);
obj = memory.getBaseObject();
offset = memory.getBaseOffset();
check(memory, obj, offset, length);
Platform.freeMemory(address);
}
}
6 changes: 6 additions & 0 deletions core/pom.xml
@@ -95,6 +95,12 @@
<groupId>org.apache.curator</groupId>
<artifactId>curator-recipes</artifactId>
</dependency>
<!-- With curator 2.12 SBT/Ivy doesn't get ZK on the build classpath.
Explicitly declaring it as a dependency fixes this. -->
<dependency>
<groupId>org.apache.zookeeper</groupId>
<artifactId>zookeeper</artifactId>
</dependency>

<!-- Jetty dependencies promoted to compile here so they are shaded
and inlined into spark-core jar -->