Commit 9969c14: Merge remote-tracking branch 'origin/master' into sql-external-sort

JoshRosen committed Jul 6, 2015
2 parents: 5822e6f + 132e7fc
Showing 17 changed files with 672 additions and 175 deletions.
2 changes: 1 addition & 1 deletion R/README.md
@@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R

#### Build Spark

Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-PsparkR` profile to build the R package. For example to use the default Hadoop versions you can run
Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run
```
build/mvn -DskipTests -Psparkr package
```
107 changes: 107 additions & 0 deletions examples/src/main/r/data-manipulation.R
@@ -0,0 +1,107 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# For this example, we use the "flights" dataset.
# The dataset consists of every flight departing Houston in 2011:
# 227,496 rows x 14 columns.

# To run this example, use:
# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3
# examples/src/main/r/data-manipulation.R <path_to_csv>

# Load SparkR library into your R session
library(SparkR)

args <- commandArgs(trailing = TRUE)

if (length(args) != 1) {
  print("Usage: data-manipulation.R <path-to-flights.csv>")
  print("The data can be downloaded from: http://s3-us-west-2.amazonaws.com/sparkr-data/flights.csv")
  q("no")
}

## Initialize SparkContext
sc <- sparkR.init(appName = "SparkR-data-manipulation-example")

## Initialize SQLContext
sqlContext <- sparkRSQL.init(sc)

flightsCsvPath <- args[[1]]

# Create a local R dataframe
flights_df <- read.csv(flightsCsvPath, header = TRUE)
flights_df$date <- as.Date(flights_df$date)

## Filter flights whose destination is San Francisco and write to a local data frame
SFO_df <- flights_df[flights_df$dest == "SFO", ]

# Convert the local data frame into a SparkR DataFrame
SFO_DF <- createDataFrame(sqlContext, SFO_df)

# Directly create a SparkR DataFrame from the source data
flightsDF <- read.df(sqlContext, flightsCsvPath, source = "com.databricks.spark.csv", header = "true")

# Print the schema of this Spark DataFrame
printSchema(flightsDF)

# Cache the DataFrame
cache(flightsDF)

# Print the first 6 rows of the DataFrame
showDF(flightsDF, numRows = 6) ## Or
head(flightsDF)

# Show the column names in the DataFrame
columns(flightsDF)

# Show the number of rows in the DataFrame
count(flightsDF)

# Select specific columns
destDF <- select(flightsDF, "dest", "cancelled")

# Using SQL to select columns of data
# First, register the flights DataFrame as a table
registerTempTable(flightsDF, "flightsTable")
destDF <- sql(sqlContext, "SELECT dest, cancelled FROM flightsTable")

# Use collect to create a local R data frame
local_df <- collect(destDF)

# Print the newly created local data frame
head(local_df)

# Filter flights whose destination is JFK
jfkDF <- filter(flightsDF, "dest = \"JFK\"") ## OR
jfkDF <- filter(flightsDF, flightsDF$dest == "JFK")

# If the magrittr library is available, we can use it to
# chain data frame operations
if("magrittr" %in% rownames(installed.packages())) {
library(magrittr)

# Group the flights by date and then find the average daily delay
# Write the result into a DataFrame
groupBy(flightsDF, flightsDF$date) %>%
summarize(avg(flightsDF$dep_delay), avg(flightsDF$arr_delay)) -> dailyDelayDF

# Print the computed data frame
head(dailyDelayDF)
}

# Stop the SparkContext now
sparkR.stop()
28 changes: 28 additions & 0 deletions python/pyspark/sql/functions.py
@@ -395,6 +395,34 @@ def randn(seed=None):
return Column(jc)


@ignore_unicode_prefix
@since(1.5)
def hex(col):
    """Computes hex value of the given column, which could be StringType,
    BinaryType, IntegerType or LongType.

    >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect()
    [Row(hex(a)=u'414243', hex(b)=u'3')]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.hex(_to_java_column(col))
    return Column(jc)


@ignore_unicode_prefix
@since(1.5)
def unhex(col):
    """Inverse of hex. Interprets each pair of characters as a hexadecimal number
    and converts to the byte representation of the number.

    >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect()
    [Row(unhex(a)=bytearray(b'ABC'))]
    """
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.unhex(_to_java_column(col))
    return Column(jc)
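Both wrappers simply delegate to the JVM-side `functions.hex` and `functions.unhex`. For comparison, a minimal Scala sketch of the same round trip, assuming a `SQLContext` named `sqlContext` with its implicits imported (column names are illustrative):

```scala
import sqlContext.implicits._
import org.apache.spark.sql.functions.{hex, unhex}

// Mirrors the doctests above: hex of a string column is its bytes in hex,
// hex of an integral column is the number rendered in hex.
val df = Seq(("ABC", 3)).toDF("a", "b")
df.select(hex($"a"), hex($"b")).collect()  // [414243], [3]

// unhex inverts hex for string input, recovering the original bytes.
df.select(unhex(hex($"a"))).collect()      // the bytes of "ABC"
```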


@ignore_unicode_prefix
@since(1.5)
def sha1(col):
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -287,15 +287,18 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
throw new AnalysisException(s"invalid function approximate($floatLit) $udfName")
}
}
| CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
(ELSE ~> expression).? <~ END ^^ {
case casePart ~ altPart ~ elsePart =>
val branches = altPart.flatMap { case whenExpr ~ thenExpr =>
Seq(whenExpr, thenExpr)
} ++ elsePart
casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
}
)
| CASE ~> whenThenElse ^^ CaseWhen
| CASE ~> expression ~ whenThenElse ^^
{ case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
)

protected lazy val whenThenElse: Parser[List[Expression]] =
rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ {
case altPart ~ elsePart =>
altPart.flatMap { case whenExpr ~ thenExpr =>
Seq(whenExpr, thenExpr)
} ++ elsePart
}
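The refactoring pulls the shared `WHEN ... THEN ... [ELSE ...] END` tail into `whenThenElse`, which flattens each branch into consecutive condition/value expressions, with the optional ELSE value appended last. A sketch of the resulting encoding against the 1.5-era expression API; the literals are illustrative:

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseKeyWhen, CaseWhen, Literal}

// CASE WHEN c1 THEN v1 WHEN c2 THEN v2 ELSE e END
//   parses to CaseWhen(List(c1, v1, c2, v2, e))
val cw = CaseWhen(List(Literal(false), Literal("a"), Literal(true), Literal("b"), Literal("else")))

// CASE key WHEN k1 THEN v1 ELSE e END
//   parses to CaseKeyWhen(key, List(k1, v1, e))
val ckw = CaseKeyWhen(Literal(2), List(Literal(1), Literal("one"), Literal("other")))
```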

protected lazy val cast: Parser[Expression] =
CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
@@ -354,6 +357,11 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val signedPrimary: Parser[Expression] =
sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e}

protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", {
case lexical.Identifier(str) => str
case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str
})

protected lazy val primary: PackratParser[Expression] =
( literal
| expression ~ ("[" ~> expression <~ "]") ^^
@@ -364,9 +372,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| "(" ~> expression <~ ")"
| function
| dotExpressionHeader
| ident ^^ {case i => UnresolvedAttribute.quoted(i)}
| signedPrimary
| "~" ~> expression ^^ BitwiseNot
| attributeName ^^ UnresolvedAttribute.quoted
)
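With `attributeName` in the `primary` production, a column reference is no longer restricted to plain identifiers: any lexical keyword that is not also a delimiter is accepted and quoted into an `UnresolvedAttribute`. A hypothetical before/after, assuming the 1.5-era `parse` entry point inherited from `AbstractSparkSQLParser`:

```scala
import org.apache.spark.sql.catalyst.SqlParser

val parser = new SqlParser()
parser.parse("SELECT key FROM t")          // plain identifier: fine before and after
// Before this change, a keyword in column position failed to parse;
// now it resolves through the attributeName production:
parser.parse("SELECT approximate FROM t")  // APPROXIMATE is a SqlParser keyword
```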

protected lazy val dotExpressionHeader: Parser[Expression] =
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -168,7 +168,7 @@ object FunctionRegistry {
expression[Substring]("substring"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
expression[UnHex]("unhex"),
expression[Unhex]("unhex"),
expression[Upper]("upper"),

// datetime functions
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -298,6 +298,21 @@ case class Bin(child: Expression)
}
}

object Hex {
val hexDigits = Array[Char](
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
).map(_.toByte)

// lookup table to translate '0' -> 0 ... 'F'/'f' -> 15
val unhexDigits = {
val array = Array.fill[Byte](128)(-1)
(0 to 9).foreach(i => array('0' + i) = i.toByte)
(0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
(0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
array
}
}
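These tables are shared by `Hex` and `Unhex`: `hexDigits` maps a nibble (0-15) to its uppercase ASCII digit, and `unhexDigits` maps an ASCII byte back to its nibble, with -1 marking non-hex characters. A few REPL-style probes, assuming access to the `Hex` object above:

```scala
val d  = Hex.hexDigits(10)      // 'A'.toByte
val v1 = Hex.unhexDigits('a')   // 10: upper- and lowercase map to the same nibble
val v2 = Hex.unhexDigits('G')   // -1, which callers translate into a null result

// Splitting a byte into two hex digits, as the hex() loop below does:
val b  = 0xAB.toByte
val hi = Hex.hexDigits((b & 0xF0) >> 4)  // 'A'.toByte
val lo = Hex.hexDigits(b & 0x0F)         // 'B'.toByte
```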

/**
* If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
* Otherwise if the number is a STRING, it converts each character into its hex representation
@@ -307,7 +322,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
// TODO: Create code-gen version.

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, StringType, BinaryType))
Seq(TypeCollection(LongType, BinaryType, StringType))

override def dataType: DataType = StringType

@@ -319,30 +334,18 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
child.dataType match {
case LongType => hex(num.asInstanceOf[Long])
case BinaryType => hex(num.asInstanceOf[Array[Byte]])
case StringType => hex(num.asInstanceOf[UTF8String])
case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
}
}
}

/**
* Converts every character in s to two hex digits.
*/
private def hex(str: UTF8String): UTF8String = {
hex(str.getBytes)
}

private def hex(bytes: Array[Byte]): UTF8String = {
doHex(bytes, bytes.length)
}

private def doHex(bytes: Array[Byte], length: Int): UTF8String = {
private[this] def hex(bytes: Array[Byte]): UTF8String = {
val length = bytes.length
val value = new Array[Byte](length * 2)
var i = 0
while (i < length) {
value(i * 2) = Character.toUpperCase(Character.forDigit(
(bytes(i) & 0xF0) >>> 4, 16)).toByte
value(i * 2 + 1) = Character.toUpperCase(Character.forDigit(
bytes(i) & 0x0F, 16)).toByte
value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4)
value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F)
i += 1
}
UTF8String.fromBytes(value)
@@ -355,24 +358,23 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
var len = 0
do {
len += 1
value(value.length - len) =
Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte
value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt)
numBuf >>>= 4
} while (numBuf != 0)
UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length))
}
}
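The `Long` branch fills a scratch buffer from the end, emitting one nibble per iteration starting with the least significant. A standalone sketch with the digit table inlined (the names here are illustrative, not from the patch):

```scala
object HexLongSketch {
  private val hexDigits: Array[Byte] = "0123456789ABCDEF".map(_.toByte).toArray

  def hexLong(num: Long): String = {
    val value = new Array[Byte](16) // a Long has at most 16 hex digits
    var numBuf = num
    var len = 0
    do {
      len += 1
      value(value.length - len) = hexDigits((numBuf & 0xF).toInt)
      numBuf >>>= 4 // unsigned shift, so negative inputs also terminate
    } while (numBuf != 0)
    new String(value, value.length - len, len, "US-ASCII")
  }
}

// HexLongSketch.hexLong(255L) == "FF"
// HexLongSketch.hexLong(-1L)  == "FFFFFFFFFFFFFFFF"
```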


/**
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
*/
case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
// TODO: Create code-gen version.

override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

override def nullable: Boolean = true
override def dataType: DataType = BinaryType

override def eval(input: InternalRow): Any = {
@@ -384,26 +386,31 @@ case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes
}
}

private val unhexDigits = {
val array = Array.fill[Byte](128)(-1)
(0 to 9).foreach(i => array('0' + i) = i.toByte)
(0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
(0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
array
}

private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
var bytes = inputBytes
private[this] def unhex(bytes: Array[Byte]): Array[Byte] = {
val out = new Array[Byte]((bytes.length + 1) >> 1)
var i = 0
if ((bytes.length & 0x01) != 0) {
bytes = '0'.toByte +: bytes
// padding with '0'
if (bytes(0) < 0) {
return null
}
val v = Hex.unhexDigits(bytes(0))
if (v == -1) {
return null
}
out(0) = v
i += 1
}
val out = new Array[Byte](bytes.length >> 1)
// two characters form the hex value.
var i = 0
while (i < bytes.length) {
val first = unhexDigits(bytes(i))
val second = unhexDigits(bytes(i + 1))
if (first == -1 || second == -1) { return null}
if (bytes(i) < 0 || bytes(i + 1) < 0) {
return null
}
val first = Hex.unhexDigits(bytes(i))
val second = Hex.unhexDigits(bytes(i + 1))
if (first == -1 || second == -1) {
return null
}
out(i / 2) = (((first << 4) | second) & 0xFF).toByte
i += 2
}
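For odd-length input the first character is decoded on its own, as if the string were left-padded with a '0'; every following pair of characters then forms one output byte, and any non-ASCII or non-hex byte makes the whole result null. A self-contained sketch of that contract; note it writes pairs to index `(i + 1) / 2`, which lands correctly for both even and padded odd-length inputs:

```scala
object UnhexSketch {
  private val unhexDigits: Array[Byte] = {
    val array = Array.fill[Byte](128)(-1)
    (0 to 9).foreach(i => array('0' + i) = i.toByte)
    (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
    (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
    array
  }

  def unhex(bytes: Array[Byte]): Array[Byte] = {
    val out = new Array[Byte]((bytes.length + 1) >> 1)
    var i = 0
    if ((bytes.length & 0x01) != 0) {
      // odd length: decode the first nibble alone, as if padded with '0'
      if (bytes(0) < 0) return null // byte outside the ASCII range
      val v = unhexDigits(bytes(0))
      if (v == -1) return null
      out(0) = v
      i += 1
    }
    // two characters form one output byte
    while (i < bytes.length) {
      if (bytes(i) < 0 || bytes(i + 1) < 0) return null
      val first = unhexDigits(bytes(i))
      val second = unhexDigits(bytes(i + 1))
      if (first == -1 || second == -1) return null
      out((i + 1) / 2) = (((first << 4) | second) & 0xFF).toByte
      i += 2
    }
    out
  }
}

// UnhexSketch.unhex("414243".getBytes) yields the bytes of "ABC"
// UnhexSketch.unhex("F".getBytes)      yields Array[Byte](15)
```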