Updates based on Marcelo's review feedback
JoshRosen committed Oct 14, 2014
1 parent 7a1417f commit e8e2867
Showing 3 changed files with 30 additions and 50 deletions.
11 changes: 6 additions & 5 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -17,6 +17,7 @@

package org.apache.spark

+ import java.util.Collections
import java.util.concurrent.TimeUnit

import org.apache.spark.api.java.JavaFutureAction
@@ -285,12 +286,12 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S

override def isDone: Boolean = {
// According to java.util.Future's Javadoc, this returns True if the task was completed,
- // whether that completion was due to succesful execution, an exception, or a cancellation.
+ // whether that completion was due to successful execution, an exception, or a cancellation.
futureAction.isCancelled || futureAction.isCompleted
}

override def jobIds(): java.util.List[java.lang.Integer] = {
- new java.util.ArrayList(futureAction.jobIds.map(x => new Integer(x)).asJava)
+ Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava)
}
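For context on the two changes in jobIds() above (boxing with Integer.valueOf and wrapping the result in Collections.unmodifiableList), here is a minimal standalone Java sketch; it is not part of this commit and the class name is illustrative:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class UnmodifiableJobIdsSketch {
  public static void main(String[] args) {
    // Integer.valueOf reuses cached boxes for small values instead of allocating
    // a new object on every call the way `new Integer(...)` does.
    List<Integer> jobIds =
        Collections.unmodifiableList(Arrays.asList(Integer.valueOf(0), Integer.valueOf(1)));
    try {
      jobIds.add(2); // callers only get a read-only view of the wrapper's job IDs
    } catch (UnsupportedOperationException expected) {
      System.out.println("read-only view: " + expected);
    }
  }
}
```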

private def getImpl(timeout: Duration): T = {
@@ -300,10 +301,10 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
case scala.util.Success(value) => converter(value)
case Failure(exception) =>
if (isCancelled) {
- throw new CancellationException("Job cancelled: ${exception.message}");
+ throw new CancellationException("Job cancelled").initCause(exception)
} else {
// java.util.Future.get() wraps exceptions in ExecutionException
- throw new ExecutionException("Exception thrown by job: ", exception)
+ throw new ExecutionException("Exception thrown by job", exception)
}
}
}
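The rewritten branch above follows the java.util.concurrent.Future.get() contract: cancellation surfaces as CancellationException, with the original failure attached via initCause because CancellationException has no cause-taking constructor, and any other failure is wrapped in ExecutionException. A small Java sketch of that behaviour, independent of Spark and not part of this commit:

```java
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;

public class FutureGetContractSketch {
  static void rethrow(boolean cancelled, Throwable cause) throws ExecutionException {
    if (cancelled) {
      // CancellationException only has () and (String) constructors, so the
      // underlying failure has to be attached afterwards via initCause().
      CancellationException ce = new CancellationException("Job cancelled");
      ce.initCause(cause);
      throw ce;
    }
    // Future.get() is specified to wrap computation failures in ExecutionException.
    throw new ExecutionException("Exception thrown by job", cause);
  }

  public static void main(String[] args) {
    try {
      rethrow(true, new RuntimeException("original failure"));
    } catch (CancellationException | ExecutionException e) {
      System.out.println(e + " <- caused by " + e.getCause());
    }
  }
}
```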
@@ -313,7 +314,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
override def get(timeout: Long, unit: TimeUnit): T =
getImpl(Duration.fromNanos(unit.toNanos(timeout)))

- override def cancel(mayInterruptIfRunning: Boolean): Boolean = {
+ override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized {
if (isDone) {
// According to java.util.Future's Javadoc, this should return false if the task is completed.
false
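The isDone and cancel overrides in this file mirror the standard java.util.concurrent.Future contract: isDone is true once the task finishes for any reason, cancel returns false if the task has already completed, and a successful cancel leaves both isCancelled and isDone true. A quick sketch of that contract using a plain ExecutorService rather than Spark (not part of the commit):

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class CancelContractSketch {
  public static void main(String[] args) throws Exception {
    ExecutorService pool = Executors.newCachedThreadPool();

    Future<Integer> finished = pool.submit(() -> 42);
    finished.get();                             // wait until the task completes
    System.out.println(finished.cancel(true));  // false: already done

    Future<Integer> running = pool.submit(() -> { Thread.sleep(10_000); return 0; });
    System.out.println(running.cancel(true));   // true: cancellation succeeded
    System.out.println(running.isCancelled());  // true
    System.out.println(running.isDone());       // true: a cancelled task counts as done

    pool.shutdownNow();
  }
}
```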
13 changes: 5 additions & 8 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -21,12 +21,14 @@ import java.util.{Comparator, List => JList, Iterator => JIterator}
import java.lang.{Iterable => JIterable, Long => JLong}

import scala.collection.JavaConversions._
+ import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec

import org.apache.spark._
+ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
@@ -578,34 +580,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* future for counting the number of elements in this RDD.
*/
def countAsync(): JavaFutureAction[JLong] = {
- import org.apache.spark.SparkContext._
- new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), x => new JLong(x))
+ new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf)
}

/**
* The asynchronous version of `collect`, which returns a future for
* retrieving an array containing all of the elements in this RDD.
*/
def collectAsync(): JavaFutureAction[JList[T]] = {
- import org.apache.spark.SparkContext._
- new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => new java.util.ArrayList(x))
+ new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava)
}

/**
* The asynchronous version of the `take` action, which returns a
* future for retrieving the first `num` elements of this RDD.
*/
def takeAsync(num: Int): JavaFutureAction[JList[T]] = {
- import org.apache.spark.SparkContext._
- new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => new java.util.ArrayList(x))
+ new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava)
}

/**
* The asynchronous version of the `foreach` action, which
* applies a function f to all the elements of this RDD.
*/
def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = {
- import org.apache.spark.SparkContext._
new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)),
{ x => null.asInstanceOf[Void] })
}
@@ -615,7 +613,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* applies a function f to each partition of this RDD.
*/
def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = {
- import org.apache.spark.SparkContext._
new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)),
{ x => null.asInstanceOf[Void] })
}
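Taken together, the methods above give Java callers non-blocking versions of count, collect, take, and foreach. A hypothetical caller-side sketch, assuming a local-mode JavaSparkContext with this patch applied; the class name and configuration are illustrative and not part of the commit:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaFutureAction;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class AsyncActionsSketch {
  public static void main(String[] args) throws Exception {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("async-actions-sketch");
    JavaSparkContext sc = new JavaSparkContext(conf);
    try {
      JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));

      JavaFutureAction<Long> count = rdd.countAsync();             // Long boxed via JLong.valueOf
      JavaFutureAction<List<Integer>> values = rdd.collectAsync(); // Seq exposed as a java.util.List

      // Both jobs run concurrently; get() blocks like any java.util.concurrent.Future.
      System.out.println("count  = " + count.get());
      System.out.println("values = " + values.get());
    } finally {
      sc.stop();
    }
  }
}
```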
56 changes: 19 additions & 37 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -30,6 +30,7 @@
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
+ import com.google.common.base.Throwables;
import com.google.common.base.Optional;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
@@ -1306,21 +1307,6 @@ public void collectUnderlyingScalaRDD() {
Assert.assertEquals(data.size(), collected.length);
}

- private static final class IdentityWithDelay<T> implements Function<T, T> {
-
-   final int delayMillis;
-
-   IdentityWithDelay(int delayMillis) {
-     this.delayMillis = delayMillis;
-   }
-
-   @Override
-   public T call(T x) throws Exception {
-     Thread.sleep(delayMillis);
-     return x;
-   }
- }

private static final class BuggyMapFunction<T> implements Function<T, T> {

@Override
@@ -1333,62 +1319,59 @@ public T call(T x) throws Exception {
public void collectAsync() throws Exception {
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
- JavaFutureAction<List<Integer>> future =
-   rdd.map(new IdentityWithDelay<Integer>(200)).collectAsync();
- Assert.assertFalse(future.isCancelled());
- Assert.assertFalse(future.isDone());
+ JavaFutureAction<List<Integer>> future = rdd.collectAsync();
List<Integer> result = future.get();
- Assert.assertEquals(result, data);
+ Assert.assertEquals(data, result);
Assert.assertFalse(future.isCancelled());
Assert.assertTrue(future.isDone());
- Assert.assertEquals(future.jobIds().size(), 1);
+ Assert.assertEquals(1, future.jobIds().size());
}
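The swapped assertEquals arguments above (and in the tests below) follow JUnit's (expected, actual) convention, which only matters for the failure message. A tiny illustration, deliberately failing so the message is visible; it is not part of the commit:

```java
import org.junit.Assert;

public class AssertOrderDemo {
  public static void main(String[] args) {
    int expected = 5;
    int actual = 4; // deliberately wrong to show the failure message
    try {
      // With (expected, actual) the report reads "expected:<5> but was:<4>";
      // swapping the arguments would misleadingly report "expected:<4> but was:<5>".
      Assert.assertEquals(expected, actual);
    } catch (AssertionError e) {
      System.out.println(e.getMessage());
    }
  }
}
```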

@Test
public void foreachAsync() throws Exception {
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
- JavaFutureAction<Void> future = rdd.map(new IdentityWithDelay<Integer>(200)).foreachAsync(
+ JavaFutureAction<Void> future = rdd.foreachAsync(
new VoidFunction<Integer>() {
@Override
public void call(Integer integer) throws Exception {
// intentionally left blank.
}
}
);
- Assert.assertFalse(future.isCancelled());
- Assert.assertFalse(future.isDone());
future.get();
Assert.assertFalse(future.isCancelled());
Assert.assertTrue(future.isDone());
- Assert.assertEquals(future.jobIds().size(), 1);
+ Assert.assertEquals(1, future.jobIds().size());
}

@Test
public void countAsync() throws Exception {
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
- JavaFutureAction<Long> future = rdd.map(new IdentityWithDelay<Integer>(200)).countAsync();
- Assert.assertFalse(future.isCancelled());
- Assert.assertFalse(future.isDone());
+ JavaFutureAction<Long> future = rdd.countAsync();
long count = future.get();
- Assert.assertEquals(count, data.size());
+ Assert.assertEquals(data.size(), count);
Assert.assertFalse(future.isCancelled());
Assert.assertTrue(future.isDone());
- Assert.assertEquals(future.jobIds().size(), 1);
+ Assert.assertEquals(1, future.jobIds().size());
}

@Test
public void testAsyncActionCancellation() throws Exception {
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
- JavaFutureAction<Long> future = rdd.map(new IdentityWithDelay<Integer>(200)).countAsync();
- Thread.sleep(200);
+ JavaFutureAction<Void> future = rdd.foreachAsync(new VoidFunction<Integer>() {
+   @Override
+   public void call(Integer integer) throws Exception {
+     Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled.
+   }
+ });
future.cancel(true);
Assert.assertTrue(future.isCancelled());
Assert.assertTrue(future.isDone());
try {
- long count = future.get(2000, TimeUnit.MILLISECONDS);
+ future.get(2000, TimeUnit.MILLISECONDS);
Assert.fail("Expected future.get() for cancelled job to throw CancellationException");
} catch (CancellationException ignored) {
// pass
@@ -1400,12 +1383,11 @@ public void testAsyncActionErrorWrapping() throws Exception {
List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
JavaRDD<Integer> rdd = sc.parallelize(data, 1);
JavaFutureAction<Long> future = rdd.map(new BuggyMapFunction<Integer>()).countAsync();
- Thread.sleep(200);
try {
- long count = future.get(2000, TimeUnit.MILLISECONDS);
+ long count = future.get(2, TimeUnit.SECONDS);
Assert.fail("Expected future.get() for failed job to throw ExcecutionException");
- } catch (ExecutionException ignored) {
-   // pass
+ } catch (ExecutionException ee) {
+   Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
}
Assert.assertTrue(future.isDone());
}
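The assertion above searches the rendered stack trace rather than checking ee.getCause().getMessage() because the user-level failure may be nested several causes deep by the time the wrapper adds its ExecutionException. A standalone sketch of that layering, with the intermediate Spark layer simulated by a plain RuntimeException (names and messages other than "Custom exception!" are illustrative):

```java
import java.util.concurrent.ExecutionException;

import com.google.common.base.Throwables;

public class NestedCauseSketch {
  public static void main(String[] args) {
    // The map function's failure...
    RuntimeException userError = new RuntimeException("Custom exception!");
    // ...may be wrapped by the scheduler before the Java wrapper ever sees it...
    RuntimeException schedulerLayer =
        new RuntimeException("Job aborted due to stage failure", userError);
    // ...and is wrapped once more to satisfy the Future.get() contract.
    ExecutionException ee = new ExecutionException("Exception thrown by job", schedulerLayer);

    // getStackTraceAsString renders the full cause chain, so the original message
    // is found even when it is not the immediate cause.
    System.out.println(Throwables.getStackTraceAsString(ee).contains("Custom exception!")); // true
  }
}
```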
