[SPARK-28153][PYTHON] Use AtomicReference at InputFileBlockHolder (to support input_file_name with Python UDF)

## What changes were proposed in this pull request?

This PR proposes to use `AtomicReference` so that parent and child threads can access the same file block holder.

Python UDF expressions are turned into a plan, and that plan launches a separate thread to consume the input iterator. In that separate child thread, the iterator calls `InputFileBlockHolder.set` before the parent thread does, so the parent thread is unable to read the value later.

1. If the child thread happens to call `InputFileBlockHolder.set` first, before the parent's thread local has been initialized (initialization happens at the first call of `ThreadLocal.get()`), the child thread ends up calling its own `initialValue` to initialize its thread local.

2. After that, the parent thread calls its own `initialValue` at its first call of `ThreadLocal.get()`.

3. The two threads now hold two different references, so an update made in the child thread is not reflected in the parent thread (a minimal sketch of this follows below).
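
Below is a minimal, self-contained sketch of that failure mode (hypothetical names and a `String` stand-in for the file block; not the actual Spark code): with a plain `InheritableThreadLocal`, the child thread writes to its own thread-local copy, and the parent's first `get()` just triggers its own `initialValue`.

```scala
object PlainThreadLocalSketch {
  private val holder = new InheritableThreadLocal[String] {
    override protected def initialValue(): String = "unset"
  }

  def main(args: Array[String]): Unit = {
    // The child thread sets the value before the parent ever calls get().
    val child = new Thread(new Runnable {
      override def run(): Unit = holder.set("file:/tmp/part-00000")
    })
    child.start()
    child.join()
    // The parent's first get() triggers its own initialValue: prints "unset",
    // so the child's update is lost from the parent's point of view.
    println(holder.get())
  }
}
```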

This PR fixes it by initializing the parent's thread local with an `AtomicReference` that holds the file block, so that the same reference is shared within each task and updates from child threads are reflected in the parent.

I also tried to explain this a bit more at #24958 (comment).
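
For comparison, here is a sketch of the pattern this PR adopts (again with hypothetical names, not the actual Spark code): the thread local stores an `AtomicReference`, the parent thread initializes it once before any child thread is created (this is what the new `InputFileBlockHolder.initialize()` call in `Task` is for), and the child thread therefore inherits and mutates the very same reference.

```scala
import java.util.concurrent.atomic.AtomicReference

object AtomicReferenceThreadLocalSketch {
  private val holder = new InheritableThreadLocal[AtomicReference[String]] {
    override protected def initialValue(): AtomicReference[String] =
      new AtomicReference("unset")
  }

  def main(args: Array[String]): Unit = {
    // The parent initializes its thread local *before* creating the child,
    // so the child inherits the same AtomicReference instance.
    holder.get()
    val child = new Thread(new Runnable {
      override def run(): Unit = holder.get().set("file:/tmp/part-00000")
    })
    child.start()
    child.join()
    // The parent reads through the shared reference: prints "file:/tmp/part-00000".
    println(holder.get().get())
  }
}
```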

## How was this patch tested?

Manually tested, and a unit test was added.

Closes #24958 from HyukjinKwon/SPARK-28153.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
HyukjinKwon authored and cloud-fan committed Jul 31, 2019
1 parent d03ec65 commit b8e13b0
Showing 3 changed files with 29 additions and 7 deletions.
core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala
@@ -17,6 +17,8 @@

package org.apache.spark.rdd

+import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.unsafe.types.UTF8String

/**
@@ -40,26 +42,33 @@ private[spark] object InputFileBlockHolder {
/**
* The thread variable for the name of the current file being read. This is used by
* the InputFileName function in Spark SQL.
+ *
+ * @note `inputBlock` works somewhat complicatedly. It guarantees that `initialValue`
+ * is called at the start of a task. Therefore, one atomic reference is created in the task
+ * thread. After that, read and write happen to the same atomic reference across the parent and
+ * children threads. This is in order to support a case where write happens in a child thread
+ * but read happens at its parent thread, for instance, Python UDF execution. See SPARK-28153.
*/
-private[this] val inputBlock: InheritableThreadLocal[FileBlock] =
-  new InheritableThreadLocal[FileBlock] {
-    override protected def initialValue(): FileBlock = new FileBlock
+private[this] val inputBlock: InheritableThreadLocal[AtomicReference[FileBlock]] =
+  new InheritableThreadLocal[AtomicReference[FileBlock]] {
+    override protected def initialValue(): AtomicReference[FileBlock] =
+      new AtomicReference(new FileBlock)
}

/**
* Returns the holding file name or empty string if it is unknown.
*/
-def getInputFilePath: UTF8String = inputBlock.get().filePath
+def getInputFilePath: UTF8String = inputBlock.get().get().filePath

/**
* Returns the starting offset of the block currently being read, or -1 if it is unknown.
*/
-def getStartOffset: Long = inputBlock.get().startOffset
+def getStartOffset: Long = inputBlock.get().get().startOffset

/**
* Returns the length of the block being read, or -1 if it is unknown.
*/
-def getLength: Long = inputBlock.get().length
+def getLength: Long = inputBlock.get().get().length

/**
* Sets the thread-local input block.
@@ -68,11 +77,17 @@
require(filePath != null, "filePath cannot be null")
require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative")
require(length >= 0, s"length ($length) cannot be negative")
-inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
+inputBlock.get().set(new FileBlock(UTF8String.fromString(filePath), startOffset, length))
}

/**
* Clears the input file block to default value.
*/
def unset(): Unit = inputBlock.remove()

+/**
+ * Initializes thread local by explicitly getting the value. It triggers ThreadLocal's
+ * initialValue in the parent thread.
+ */
+def initialize(): Unit = inputBlock.get()
}
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -104,6 +104,7 @@ private[spark] abstract class Task[T](
taskContext
}

+InputFileBlockHolder.initialize()
TaskContext.setTaskContext(context)
taskThread = Thread.currentThread()

6 changes: 6 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
@@ -304,6 +304,12 @@ def test_array_repeat(self):
df.select(array_repeat("id", lit(3))).toDF("val").collect(),
)

+    def test_input_file_name_udf(self):
+        df = self.spark.read.text('python/test_support/hello/hello.txt')
+        df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file'))
+        file_name = df.collect()[0].file
+        self.assertTrue("python/test_support/hello/hello.txt" in file_name)


if __name__ == "__main__":
import unittest
