Skip to content

Commit

Permalink
Add LimitNode and UnionNode
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Aug 26, 2015
1 parent 99433c1 commit 4e101ee
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute


case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {

override def output: Seq[Attribute] = child.output

override def execute(): OpenCloseRowIterator = new OpenCloseRowIterator {

private var count = 0

private val childIter = child.execute()

override def open(): Unit = childIter.open()

override def close(): Unit = childIter.close()

override def getRow: InternalRow = childIter.getRow

override def advanceNext(): Boolean = {
if (count < limit) {
count += 1
childIter.advanceNext()
} else {
false
}
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute

case class UnionNode(children: Seq[LocalNode]) extends LocalNode {

override def output: Seq[Attribute] = children.head.output

override def execute(): OpenCloseRowIterator = new OpenCloseRowIterator {

private var currentIter: OpenCloseRowIterator = _

private var nextChildIndex: Int = _

override def open(): Unit = {
currentIter = children.head.execute()
currentIter.open()
nextChildIndex = 1
}

private def advanceToNextChild(): Boolean = {
var found = false
var exit = false
while (!exit && !found) {
if (currentIter != null) {
currentIter.close()
}
if (nextChildIndex >= children.size) {
found = false
exit = true
} else {
currentIter = children(nextChildIndex).execute()
nextChildIndex += 1
currentIter.open()
found = currentIter.advanceNext()
}
}
found
}

override def close(): Unit = {
if (currentIter != null) {
currentIter.close()
}
}

override def getRow: InternalRow = currentIter.getRow

override def advanceNext(): Boolean = {
if (currentIter.advanceNext()) {
true
} else {
advanceToNextChild()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext

class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {

test("basic") {
checkAnswer(testData,
node => LimitNode(10, node),
testData.limit(10).collect()
)
}

test("empty") {
checkAnswer(emptyTestData,
node => LimitNode(10, node),
emptyTestData.limit(10).collect()
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.test.SharedSQLContext

class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {

test("basic") {
checkAnswer2(
testData,
testData,
(node1, node2) => UnionNode(Seq(node1, node2)),
testData.unionAll(testData).collect()
)
}

test("empty") {
checkAnswer2(
emptyTestData,
emptyTestData,
(node1, node2) => UnionNode(Seq(node1, node2)),
emptyTestData.unionAll(emptyTestData).collect()
)
}

test("complicated union") {
val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
emptyTestData, emptyTestData, testData, emptyTestData)
doCheckAnswer(
dfs,
nodes => UnionNode(nodes),
dfs.reduce(_.unionAll(_)).collect()
)
}

}

0 comments on commit 4e101ee

Please sign in to comment.