
Merge branch 'mesos'

2 parents df9ae8a + 548856a · commit 97e242067b9d75abd88543c759d8fc0aebd9eb8c · @haitaoyao committed Jan 24, 2013
Showing with 14,401 additions and 2,297 deletions.
  1. +2 −0 .gitignore
  2. +11 −0 bagel/pom.xml
  3. +2 −2 bagel/src/test/resources/log4j.properties
  4. +16 −1 core/pom.xml
  5. +32 −9 core/src/main/scala/spark/Accumulators.scala
  6. +0 −118 core/src/main/scala/spark/BoundedMemoryCache.scala
  7. +65 −0 core/src/main/scala/spark/CacheManager.scala
  8. +0 −238 core/src/main/scala/spark/CacheTracker.scala
  9. +0 −18 core/src/main/scala/spark/DaemonThreadFactory.scala
  10. +3 −5 core/src/main/scala/spark/HttpFileServer.scala
  11. +8 −1 core/src/main/scala/spark/HttpServer.scala
  12. +72 −138 core/src/main/scala/spark/KryoSerializer.scala
  13. +1 −2 core/src/main/scala/spark/Logging.scala
  14. +42 −18 core/src/main/scala/spark/MapOutputTracker.scala
  15. +55 −31 core/src/main/scala/spark/PairRDDFunctions.scala
  16. +15 −9 core/src/main/scala/spark/ParallelCollection.scala
  17. +4 −0 core/src/main/scala/spark/Partitioner.scala
  18. +159 −38 core/src/main/scala/spark/RDD.scala
  19. +105 −0 core/src/main/scala/spark/RDDCheckpointData.scala
  20. +7 −1 core/src/main/scala/spark/SequenceFileRDDFunctions.scala
  21. +10 −3 core/src/main/scala/spark/SizeEstimator.scala
  22. +100 −57 core/src/main/scala/spark/SparkContext.scala
  23. +23 −16 core/src/main/scala/spark/SparkEnv.scala
  24. +25 −0 core/src/main/scala/spark/SparkFiles.java
  25. +1 −2 core/src/main/scala/spark/TaskContext.scala
  26. +66 −59 core/src/main/scala/spark/Utils.scala
  27. +10 −0 core/src/main/scala/spark/api/java/JavaPairRDD.scala
  28. +33 −0 core/src/main/scala/spark/api/java/JavaRDDLike.scala
  29. +97 −8 core/src/main/scala/spark/api/java/JavaSparkContext.scala
  30. +11 −0 core/src/main/scala/spark/api/java/StorageLevels.java
  31. +48 −0 core/src/main/scala/spark/api/python/PythonPartitioner.scala
  32. +293 −0 core/src/main/scala/spark/api/python/PythonRDD.scala
  33. +1 −1 core/src/main/scala/spark/broadcast/Broadcast.scala
  34. +25 −1 core/src/main/scala/spark/broadcast/HttpBroadcast.scala
  35. +2 −2 core/src/main/scala/spark/deploy/DeployMessage.scala
  36. +2 −1 core/src/main/scala/spark/deploy/JobDescription.scala
  37. +78 −0 core/src/main/scala/spark/deploy/JsonProtocol.scala
  38. +1 −1 core/src/main/scala/spark/deploy/client/TestClient.scala
  39. +9 −6 core/src/main/scala/spark/deploy/master/Master.scala
  40. +42 −16 core/src/main/scala/spark/deploy/master/MasterWebUI.scala
  41. +5 −1 core/src/main/scala/spark/deploy/master/WorkerInfo.scala
  42. +7 −0 core/src/main/scala/spark/deploy/master/WorkerState.scala
  43. +0 −5 core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
  44. +2 −2 core/src/main/scala/spark/deploy/worker/Worker.scala
  45. +19 −3 core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
  46. +15 −4 core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
  47. +18 −16 core/src/main/scala/spark/executor/Executor.scala
  48. +0 −3 core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
  49. +5 −2 core/src/main/scala/spark/network/Connection.scala
  50. +8 −9 core/src/main/scala/spark/network/ConnectionManager.scala
  51. +16 −8 core/src/main/scala/spark/network/ConnectionManagerTest.scala
  52. +9 −11 core/src/main/scala/spark/rdd/BlockRDD.scala
  53. +36 −11 core/src/main/scala/spark/rdd/CartesianRDD.scala
  54. +128 −0 core/src/main/scala/spark/rdd/CheckpointRDD.scala
  55. +48 −22 core/src/main/scala/spark/rdd/CoGroupedRDD.scala
  56. +35 −12 core/src/main/scala/spark/rdd/CoalescedRDD.scala
  57. +11 −6 core/src/main/scala/spark/rdd/FilteredRDD.scala
  58. +5 −5 core/src/main/scala/spark/rdd/FlatMappedRDD.scala
  59. +7 −7 core/src/main/scala/spark/rdd/GlommedRDD.scala
  60. +8 −7 core/src/main/scala/spark/rdd/HadoopRDD.scala
  61. +8 −6 core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
  62. +8 −6 core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
  63. +6 −5 core/src/main/scala/spark/rdd/MappedRDD.scala
  64. +8 −11 core/src/main/scala/spark/rdd/NewHadoopRDD.scala
  65. +9 −9 core/src/main/scala/spark/rdd/PipedRDD.scala
  66. +15 −14 core/src/main/scala/spark/rdd/SampledRDD.scala
  67. +14 −14 core/src/main/scala/spark/rdd/ShuffledRDD.scala
  68. +28 −17 core/src/main/scala/spark/rdd/UnionRDD.scala
  69. +36 −24 core/src/main/scala/spark/rdd/ZippedRDD.scala
  70. +67 −23 core/src/main/scala/spark/scheduler/DAGScheduler.scala
  71. +1 −1 core/src/main/scala/spark/scheduler/MapStatus.scala
  72. +94 −8 core/src/main/scala/spark/scheduler/ResultTask.scala
  73. +15 −9 core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
  74. +18 −13 core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
  75. +2 −1 core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
  76. +5 −1 core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
  77. +25 −19 core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
  78. +6 −10 core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
  79. +3 −7 core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
  80. +117 −114 core/src/main/scala/spark/storage/BlockManager.scala
  81. +70 −0 core/src/main/scala/spark/storage/BlockManagerId.scala
  82. +109 −618 core/src/main/scala/spark/storage/BlockManagerMaster.scala
  83. +401 −0 core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
  84. +100 −0 core/src/main/scala/spark/storage/BlockManagerMessages.scala
  85. +16 −0 core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
  86. +1 −1 core/src/main/scala/spark/storage/BlockMessage.scala
  87. +6 −1 core/src/main/scala/spark/storage/BlockStore.scala
  88. +4 −1 core/src/main/scala/spark/storage/DiskStore.scala
  89. +3 −3 core/src/main/scala/spark/storage/MemoryStore.scala
  90. +60 −20 core/src/main/scala/spark/storage/StorageLevel.scala
  91. +9 −4 core/src/main/scala/spark/storage/ThreadingTest.scala
  92. +1 −0 core/src/main/scala/spark/util/AkkaUtils.scala
  93. +14 −0 core/src/main/scala/spark/util/IdGenerator.scala
  94. +44 −0 core/src/main/scala/spark/util/MetadataCleaner.scala
  95. +62 −0 core/src/main/scala/spark/util/RateLimitedOutputStream.scala
  96. +93 −0 core/src/main/scala/spark/util/TimeStampedHashMap.scala
  97. +69 −0 core/src/main/scala/spark/util/TimeStampedHashSet.scala
  98. +1 −0 core/src/main/twirl/spark/deploy/master/worker_row.scala.html
  99. +1 −0 core/src/main/twirl/spark/deploy/master/worker_table.scala.html
  100. +2 −2 core/src/test/resources/log4j.properties
  101. +0 −58 core/src/test/scala/spark/BoundedMemoryCacheSuite.scala
  102. +0 −131 core/src/test/scala/spark/CacheTrackerSuite.scala
  103. +357 −0 core/src/test/scala/spark/CheckpointSuite.scala
  104. +2 −0 core/src/test/scala/spark/ClosureCleanerSuite.scala
  105. +69 −0 core/src/test/scala/spark/DistributedSuite.scala
  106. +31 −0 core/src/test/scala/spark/DriverSuite.scala
  107. +8 −5 core/src/test/scala/spark/FileServerSuite.scala
  108. +98 −0 core/src/test/scala/spark/JavaAPISuite.java
  109. +46 −10 core/src/test/scala/spark/MapOutputTrackerSuite.scala
  110. +26 −0 core/src/test/scala/spark/PartitioningSuite.scala
  111. +51 −9 core/src/test/scala/spark/RDDSuite.scala
  112. +7 −0 core/src/test/scala/spark/ShuffleSuite.scala
  113. +26 −22 core/src/test/scala/spark/SizeEstimatorSuite.scala
  114. +42 −0 core/src/test/scala/spark/scheduler/TaskContextSuite.scala
  115. +123 −45 core/src/test/scala/spark/storage/BlockManagerSuite.scala
  116. +23 −0 core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
  117. +5 −3 docs/README.md
  118. +10 −1 docs/_layouts/global.html
  119. +19 −2 docs/_plugins/copy_api_dirs.rb
  120. +4 −2 docs/api.md
  121. +27 −0 docs/configuration.md
  122. +3 −1 docs/ec2-scripts.md
  123. +12 −5 docs/index.md
  124. +2 −1 docs/java-programming-guide.md
  125. +110 −0 docs/python-programming-guide.md
  126. +49 −1 docs/quick-start.md
  127. +2 −1 docs/scala-programming-guide.md
  128. +21 −22 docs/spark-standalone.md
  129. +313 −0 docs/streaming-programming-guide.md
  130. +16 −14 docs/tuning.md
  131. +28 −0 examples/pom.xml
  132. +1 −1 examples/src/main/scala/spark/examples/LocalLR.scala
  133. +20 −39 examples/src/main/scala/spark/examples/SparkALS.scala
  134. +43 −0 examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
  135. +36 −0 examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala
  136. +50 −0 examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java
  137. +62 −0 examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java
  138. +62 −0 examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java
  139. +69 −0 examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
  140. +36 −0 examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
  141. +39 −0 examples/src/main/scala/spark/streaming/examples/QueueStream.scala
  142. +46 −0 examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
  143. +85 −0 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala
  144. +84 −0 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
  145. +60 −0 examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala
  146. +71 −0 examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala
  147. +43 −2 pom.xml
  148. +24 −8 project/SparkBuild.scala
  149. +39 −0 pyspark
  150. +2 −0 python/.gitignore
  151. +19 −0 python/epydoc.conf
  152. +71 −0 python/examples/als.py
  153. +54 −0 python/examples/kmeans.py
  154. +57 −0 python/examples/logistic_regression.py
  155. +21 −0 python/examples/pi.py
  156. +50 −0 python/examples/transitive_closure.py
  157. +19 −0 python/examples/wordcount.py
  158. +27 −0 python/lib/PY4J_LICENSE.txt
  159. +1 −0 python/lib/PY4J_VERSION.txt
  160. BIN python/lib/py4j0.7.egg
  161. BIN python/lib/py4j0.7.jar
  162. +27 −0 python/pyspark/__init__.py
  163. +187 −0 python/pyspark/accumulators.py
  164. +48 −0 python/pyspark/broadcast.py
  165. +974 −0 python/pyspark/cloudpickle.py
  166. +258 −0 python/pyspark/context.py
  167. +38 −0 python/pyspark/files.py
  168. +38 −0 python/pyspark/java_gateway.py
  169. +92 −0 python/pyspark/join.py
  170. +761 −0 python/pyspark/rdd.py
  171. +83 −0 python/pyspark/serializers.py
  172. +17 −0 python/pyspark/shell.py
  173. +112 −0 python/pyspark/tests.py
  174. +52 −0 python/pyspark/worker.py
  175. +35 −0 python/run-tests
  176. +1 −0 python/test_support/hello.txt
  177. +7 −0 python/test_support/userlibrary.py
  178. +14 −2 repl-bin/pom.xml
  179. +2 −2 repl-bin/src/deb/control/control
  180. +35 −0 repl/pom.xml
  181. +2 −2 repl/src/test/resources/log4j.properties
  182. +18 −9 run
  183. +3 −1 run2.cmd
  184. BIN streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
  185. +1 −0 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
  186. +1 −0 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
  187. +9 −0 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
  188. +1 −0 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
  189. +1 −0 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
  190. +12 −0 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
  191. +1 −0 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
  192. +1 −0 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
  193. +155 −0 streaming/pom.xml
  194. +118 −0 streaming/src/main/scala/spark/streaming/Checkpoint.scala
  195. +657 −0 streaming/src/main/scala/spark/streaming/DStream.scala
  196. +134 −0 streaming/src/main/scala/spark/streaming/DStreamGraph.scala
  197. +62 −0 streaming/src/main/scala/spark/streaming/Duration.scala
  198. +41 −0 streaming/src/main/scala/spark/streaming/Interval.scala
  199. +24 −0 streaming/src/main/scala/spark/streaming/Job.scala
  200. +33 −0 streaming/src/main/scala/spark/streaming/JobManager.scala
  201. +151 −0 streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
  202. +562 −0 streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
  203. +77 −0 streaming/src/main/scala/spark/streaming/Scheduler.scala
  204. +411 −0 streaming/src/main/scala/spark/streaming/StreamingContext.scala
  205. +42 −0 streaming/src/main/scala/spark/streaming/Time.scala
  206. +91 −0 streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala
  207. +183 −0 streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
  208. +638 −0 streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
  209. +346 −0 streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
  210. +40 −0 streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
  211. +19 −0 streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala
  212. +102 −0 streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
  213. +21 −0 streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala
  214. +20 −0 streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala
  215. +20 −0 streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala
  216. +137 −0 streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
  217. +28 −0 streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala
  218. +17 −0 streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala
  219. +19 −0 streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
  220. +200 −0 streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
  221. +21 −0 streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala
  222. +21 −0 streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala
  223. +20 −0 streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala
  224. +254 −0 streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
  225. +41 −0 streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
  226. +91 −0 streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
  227. +149 −0 streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
  228. +27 −0 streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala
  229. +103 −0 streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
  230. +84 −0 streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
  231. +19 −0 streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala
  232. +40 −0 streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala
  233. +40 −0 streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala
  234. +84 −0 streaming/src/main/scala/spark/streaming/util/Clock.scala
  235. +98 −0 streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala
Note: the diff was too large to display in full; the per-file hunks below cover only part of the change set.
2 .gitignore
@@ -12,6 +12,7 @@ third_party/libmesos.so
third_party/libmesos.dylib
conf/java-opts
conf/spark-env.sh
+conf/streaming-env.sh
conf/log4j.properties
docs/_site
docs/api
@@ -31,6 +32,7 @@ project/plugins/src_managed/
logs/
log/
spark-tests.log
+streaming-tests.log
dependency-reduced-pom.xml
.ensime
.ensime_lucene
11 bagel/pom.xml
@@ -45,6 +45,11 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
@@ -72,6 +77,12 @@
</profile>
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
4 bagel/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the file bagel/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=bagel/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
17 core/pom.xml
@@ -72,6 +72,10 @@
<artifactId>spray-server</artifactId>
</dependency>
<dependency>
+ <groupId>cc.spray</groupId>
+ <artifactId>spray-json_${scala.version}</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.tomdz.twirl</groupId>
<artifactId>twirl-api</artifactId>
</dependency>
@@ -159,6 +163,11 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -211,6 +220,12 @@
</profile>
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -267,4 +282,4 @@
</build>
</profile>
</profiles>
-</project>
+</project>
41 core/src/main/scala/spark/Accumulators.scala
@@ -25,8 +25,7 @@ class Accumulable[R, T] (
extends Serializable {
val id = Accumulators.newId
- @transient
- private var value_ = initialValue // Current value on master
+ @transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
@@ -39,19 +38,36 @@ class Accumulable[R, T] (
def += (term: T) { value_ = param.addAccumulator(value_, term) }
/**
+ * Add more data to this accumulator / accumulable
+ * @param term the data to add
+ */
+ def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
+ /**
* Merge two accumulable objects together
- *
+ *
* Normally, a user will not want to use this version, but will instead call `+=`.
- * @param term the other Accumulable that will get merged with this
+ * @param term the other `R` that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
/**
+ * Merge two accumulable objects together
+ *
+ * Normally, a user will not want to use this version, but will instead call `add`.
+ * @param term the other `R` that will get merged with this
+ */
+ def merge(term: R) { value_ = param.addInPlace(value_, term)}
+
+ /**
* Access the accumulator's current value; only allowed on master.
*/
- def value = {
- if (!deserialized) value_
- else throw new UnsupportedOperationException("Can't read accumulator value in task")
+ def value: R = {
+ if (!deserialized) {
+ value_
+ } else {
+ throw new UnsupportedOperationException("Can't read accumulator value in task")
+ }
}
/**
@@ -68,10 +84,17 @@ class Accumulable[R, T] (
/**
* Set the accumulator's value; only allowed on master.
*/
- def value_= (r: R) {
- if (!deserialized) value_ = r
+ def value_= (newValue: R) {
+ if (!deserialized) value_ = newValue
else throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
+
+ /**
+ * Set the accumulator's value; only allowed on master
+ */
+ def setValue(newValue: R) {
+ this.value = newValue
+ }
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
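The hunk above adds Java-friendly `add`/`merge`/`setValue` aliases alongside the existing `+=`/`++=` operators, and types the `value` getter. A minimal usage sketch under assumptions (app name and local master are hypothetical; `sc.accumulator` and its implicit `AccumulatorParam` as in Spark releases of this era):

```scala
import spark.SparkContext
import spark.SparkContext._   // brings the implicit IntAccumulatorParam into scope

val sc = new SparkContext("local", "AccumulatorExample")   // hypothetical app
val sum = sc.accumulator(0)                                 // Accumulator[Int]

sc.parallelize(1 to 100).foreach(x => sum += x)             // tasks may only add
// sum.add(x) is the new alias that is also callable from Java code

println(sum.value)  // reading .value is only allowed on the master/driver
```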
118 core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -1,118 +0,0 @@
-package spark
-
-import java.util.LinkedHashMap
-
-/**
- * An implementation of Cache that estimates the sizes of its entries and attempts to limit its
- * total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using
- * SizeEstimator, which has limitations; most notably, we will overestimate total memory used if
- * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
- * when most of the space is used by arrays of primitives or of simple classes.
- */
-private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
- logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
-
- def this() {
- this(BoundedMemoryCache.getMaxBytes)
- }
-
- private var currentBytes = 0L
- private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
-
- override def get(datasetId: Any, partition: Int): Any = {
- synchronized {
- val entry = map.get((datasetId, partition))
- if (entry != null) {
- entry.value
- } else {
- null
- }
- }
- }
-
- override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
- val key = (datasetId, partition)
- logInfo("Asked to add key " + key)
- val size = estimateValueSize(key, value)
- synchronized {
- if (size > getCapacity) {
- return CachePutFailure()
- } else if (ensureFreeSpace(datasetId, size)) {
- logInfo("Adding key " + key)
- map.put(key, new Entry(value, size))
- currentBytes += size
- logInfo("Number of entries is now " + map.size)
- return CachePutSuccess(size)
- } else {
- logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
- return CachePutFailure()
- }
- }
- }
-
- override def getCapacity: Long = maxBytes
-
- /**
- * Estimate sizeOf 'value'
- */
- private def estimateValueSize(key: (Any, Int), value: Any) = {
- val startTime = System.currentTimeMillis
- val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Estimated size for key %s is %d".format(key, size))
- logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
- size
- }
-
- /**
- * Remove least recently used entries from the map until at least space bytes are free, in order
- * to make space for a partition from the given dataset ID. If this cannot be done without
- * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
- * that a lock is held on the BoundedMemoryCache.
- */
- private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
- logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
- datasetId, space, currentBytes, maxBytes))
- val iter = map.entrySet.iterator // Will give entries in LRU order
- while (maxBytes - currentBytes < space && iter.hasNext) {
- val mapEntry = iter.next()
- val (entryDatasetId, entryPartition) = mapEntry.getKey
- if (entryDatasetId == datasetId) {
- // Cannot make space without removing part of the same dataset, or a more recently used one
- return false
- }
- reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
- currentBytes -= mapEntry.getValue.size
- iter.remove()
- }
- return true
- }
-
- protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- // TODO: remove BoundedMemoryCache
-
- val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)]
- innerDatasetId match {
- case rddId: Int =>
- SparkEnv.get.cacheTracker.dropEntry(rddId, partition)
- case broadcastUUID: java.util.UUID =>
- // TODO: Maybe something should be done if the broadcasted variable falls out of cache
- case _ =>
- }
- }
-}
-
-// An entry in our map; stores a cached object and its size in bytes
-private[spark] case class Entry(value: Any, size: Long)
-
-private[spark] object BoundedMemoryCache {
- /**
- * Get maximum cache capacity from system configuration
- */
- def getMaxBytes: Long = {
- val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
- (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
- }
-}
-
65 core/src/main/scala/spark/CacheManager.scala
@@ -0,0 +1,65 @@
+package spark
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
+import spark.storage.{BlockManager, StorageLevel}
+
+
+/** Spark class responsible for passing RDDs split contents to the BlockManager and making
+ sure a node doesn't load two copies of an RDD at once.
+ */
+private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
+ private val loading = new HashSet[String]
+
+ /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */
+ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ : Iterator[T] = {
+ val key = "rdd_%d_%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
+ }
+ try {
+ // If we got here, we have to load the split
+ val elements = new ArrayBuffer[Any]
+ logInfo("Computing partition " + split)
+ elements ++= rdd.compute(split, context)
+ // Try to put this block in the blockManager
+ blockManager.put(key, elements, storageLevel, true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ }
+ }
+ }
+}
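The new `CacheManager` takes over the get-or-compute path that previously lived in `CacheTracker` (removed below), storing each split under the block key `rdd_<rddId>_<splitIndex>` and guarding recomputation with a `loading` set so two threads on one node never compute the same split at once. A rough user-facing sketch of what drives it (app name is hypothetical; the storage-level name is assumed from contemporary releases):

```scala
import spark.SparkContext
import spark.storage.StorageLevel

val sc = new SparkContext("local", "CacheExample")   // hypothetical app
val doubled = sc.parallelize(1 to 1000).map(_ * 2)

doubled.persist(StorageLevel.MEMORY_ONLY)  // iterator() now routes through CacheManager.getOrCompute
doubled.count()  // first action: each split is computed, buffered, and put into the BlockManager
doubled.count()  // second action: splits are served from blocks named "rdd_<id>_<split>"
```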
238 core/src/main/scala/spark/CacheTracker.scala
@@ -1,238 +0,0 @@
-package spark
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-import spark.storage.BlockManager
-import spark.storage.StorageLevel
-
-private[spark] sealed trait CacheTrackerMessage
-
-private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
-private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
-private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
-private[spark] case object GetCacheStatus extends CacheTrackerMessage
-private[spark] case object GetCacheLocations extends CacheTrackerMessage
-private[spark] case object StopCacheTracker extends CacheTrackerMessage
-
-private[spark] class CacheTrackerActor extends Actor with Logging {
- // TODO: Should probably store (String, CacheType) tuples
- private val locs = new HashMap[Int, Array[List[String]]]
-
- /**
- * A map from the slave's host name to its cache size.
- */
- private val slaveCapacity = new HashMap[String, Long]
- private val slaveUsage = new HashMap[String, Long]
-
- private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
- private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
- private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
-
- def receive = {
- case SlaveCacheStarted(host: String, size: Long) =>
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- sender ! true
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- sender ! true
-
- case AddedToCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) + size)
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- sender ! true
-
- case DroppedFromCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) - size)
- // Do a sanity check to make sure usage is greater than 0.
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- sender ! true
-
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- for ((id, locations) <- locs) {
- for (i <- 0 until locations.length) {
- locations(i) = locations(i).filterNot(_ == host)
- }
- }
- sender ! true
-
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
-
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host, capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- sender ! status
-
- case StopCacheTracker =>
- logInfo("Stopping CacheTrackerActor")
- sender ! true
- context.stop(self)
- }
-}
-
-private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
- extends Logging {
-
- // Tracker actor on the master, or remote reference to it on workers
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "CacheTracker"
-
- val timeout = 10.seconds
-
- var trackerActor: ActorRef = if (isMaster) {
- val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
- logInfo("Registered CacheTrackerActor actor")
- actor
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
- actorSystem.actorFor(url)
- }
-
- val registeredRddIds = new HashSet[Int]
-
- // Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[String]
-
- // Send a message to the trackerActor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
- try {
- val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with CacheTracker", e)
- }
- }
-
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
- if (askTracker(message) != true) {
- throw new SparkException("Error reply received from CacheTracker")
- }
- }
-
- // Registers an RDD (on master only)
- def registerRDD(rddId: Int, numPartitions: Int) {
- registeredRddIds.synchronized {
- if (!registeredRddIds.contains(rddId)) {
- logInfo("Registering RDD ID " + rddId + " with cache")
- registeredRddIds += rddId
- communicate(RegisterRDD(rddId, numPartitions))
- }
- }
- }
-
- // For BlockManager.scala only
- def cacheLost(host: String) {
- communicate(MemoryCacheLost(host))
- logInfo("CacheTracker successfully removed entries on " + host)
- }
-
- // Get the usage status of slave caches. Each tuple in the returned sequence
- // is in the form of (host name, capacity, usage).
- def getCacheStatus(): Seq[(String, Long, Long)] = {
- askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
- }
-
- // For BlockManager.scala only
- def notifyFromBlockManager(t: AddedToCache) {
- communicate(t)
- }
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- }
-
- // Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
- : Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
- logInfo("Cache key is " + key)
- blockManager.get(key) match {
- case Some(cachedValues) =>
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedValues.asInstanceOf[Iterator[T]]
-
- case None =>
- // Mark the split as loading (unless someone else marks it first)
- loading.synchronized {
- if (loading.contains(key)) {
- logInfo("Loading contains " + key + ", waiting...")
- while (loading.contains(key)) {
- try {loading.wait()} catch {case _ =>}
- }
- logInfo("Loading no longer contains " + key + ", so returning cached result")
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- blockManager.get(key) match {
- case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
- case None =>
- logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
- loading.add(key)
- }
- } else {
- loading.add(key)
- }
- }
- // If we got here, we have to load the split
- // Tell the master that we're doing so
- //val host = System.getProperty("spark.hostname", Utils.localHostName)
- //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
- // TODO: fetch any remote copy of the split that may be available
- // TODO: also register a listener for when it unloads
- logInfo("Computing partition " + split)
- val elements = new ArrayBuffer[Any]
- elements ++= rdd.compute(split, context)
- try {
- // Try to put this block in the blockManager
- blockManager.put(key, elements, storageLevel, true)
- //future.apply() // Wait for the reply from the cache tracker
- } finally {
- loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
- }
- }
- return elements.iterator.asInstanceOf[Iterator[T]]
- }
- }
-
- // Called by the Cache to report that an entry has been dropped from it
- def dropEntry(rddId: Int, partition: Int) {
- communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
- }
-
- def stop() {
- communicate(StopCacheTracker)
- registeredRddIds.clear()
- trackerActor = null
- }
-}
18 core/src/main/scala/spark/DaemonThreadFactory.scala
@@ -1,18 +0,0 @@
-package spark
-
-import java.util.concurrent.ThreadFactory
-
-/**
- * A ThreadFactory that creates daemon threads
- */
-private object DaemonThreadFactory extends ThreadFactory {
- override def newThread(r: Runnable): Thread = new DaemonThread(r)
-}
-
-private class DaemonThread(r: Runnable = null) extends Thread {
- override def run() {
- if (r != null) {
- r.run()
- }
- }
-}
8 core/src/main/scala/spark/HttpFileServer.scala
@@ -1,9 +1,7 @@
package spark
-import java.io.{File, PrintWriter}
-import java.net.URL
-import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.FileUtil
+import java.io.{File}
+import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging {
@@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}
def addFileToDir(file: File, dir: File) : String = {
- Utils.copyFile(file, new File(dir, file.getName))
+ Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
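The file server now copies files with Guava instead of the removed `Utils.copyFile` helper. For reference, the Guava call being introduced (paths are hypothetical):

```scala
import java.io.File
import com.google.common.io.Files

val src = new File("/tmp/source.txt")          // hypothetical source file
val dst = new File("/tmp/served/source.txt")   // destination inside the server's base directory
Files.copy(src, dst)                            // copies the file contents, overwriting dst if present
```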
9 core/src/main/scala/spark/HttpServer.scala
@@ -4,6 +4,7 @@ import java.io.File
import java.net.InetAddress
import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
@@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
- server = new Server(0)
+ server = new Server()
+ val connector = new SocketConnector
+ connector.setMaxIdleTime(60*1000)
+ connector.setSoLingerTime(-1)
+ connector.setPort(0)
+ server.addConnector(connector)
+
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)
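Instead of `new Server(0)`, the server now builds an explicit `SocketConnector` so it can tune idle and linger settings while still binding an ephemeral port. A minimal sketch of the same pattern, assuming the Jetty 7-era API used in the hunk (including `getLocalPort` to recover the port the OS chose):

```scala
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.util.thread.QueuedThreadPool

val server = new Server()
val connector = new SocketConnector
connector.setMaxIdleTime(60 * 1000)   // drop idle connections after a minute
connector.setSoLingerTime(-1)         // disable SO_LINGER
connector.setPort(0)                  // 0 = let the OS pick a free port
server.addConnector(connector)

val pool = new QueuedThreadPool
pool.setDaemon(true)                  // don't keep the JVM alive just for this server
server.setThreadPool(pool)

server.start()
val boundPort = connector.getLocalPort  // the ephemeral port actually bound
```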
210 core/src/main/scala/spark/KryoSerializer.scala
@@ -9,153 +9,80 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
-import com.esotericsoftware.kryo.serialize.ClassSerializer
-import com.esotericsoftware.kryo.serialize.SerializableSerializer
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
+import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import serializer.{SerializerInstance, DeserializationStream, SerializationStream}
import spark.broadcast._
import spark.storage._
-/**
- * Zig-zag encoder used to write object sizes to serialization streams.
- * Based on Kryo's integer encoder.
- */
-private[spark] object ZigZag {
- def writeInt(n: Int, out: OutputStream) {
- var value = n
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- out.write(value)
- }
+private[spark]
+class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
- def readInt(in: InputStream): Int = {
- var offset = 0
- var result = 0
- while (offset < 32) {
- val b = in.read()
- if (b == -1) {
- throw new EOFException("End of stream")
- }
- result |= ((b & 0x7F) << offset)
- if ((b & 0x80) == 0) {
- return result
- }
- offset += 7
- }
- throw new SparkException("Malformed zigzag-encoded integer")
- }
-}
-
-private[spark]
-class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
-extends SerializationStream {
- val channel = Channels.newChannel(out)
+ val output = new KryoOutput(outStream)
def writeObject[T](t: T): SerializationStream = {
- kryo.writeClassAndObject(threadBuffer, t)
- ZigZag.writeInt(threadBuffer.position(), out)
- threadBuffer.flip()
- channel.write(threadBuffer)
- threadBuffer.clear()
+ kryo.writeClassAndObject(output, t)
this
}
- def flush() { out.flush() }
- def close() { out.close() }
+ def flush() { output.flush() }
+ def close() { output.close() }
}
-private[spark]
-class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
-extends DeserializationStream {
+private[spark]
+class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
+
+ val input = new KryoInput(inStream)
+
def readObject[T](): T = {
- val len = ZigZag.readInt(in)
- objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
+ try {
+ kryo.readClassAndObject(input).asInstanceOf[T]
+ } catch {
+ // DeserializationStream uses the EOF exception to indicate stopping condition.
+ case e: com.esotericsoftware.kryo.KryoException => throw new java.io.EOFException
+ }
}
- def close() { in.close() }
+ def close() {
+ // Kryo's Input automatically closes the input stream it is using.
+ input.close()
+ }
}
private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
- val kryo = ks.kryo
- val threadBuffer = ks.threadBuffer.get()
- val objectBuffer = ks.objectBuffer.get()
+
+ val kryo = ks.kryo.get()
+ val output = ks.output.get()
+ val input = ks.input.get()
def serialize[T](t: T): ByteBuffer = {
- // Write it to our thread-local scratch buffer first to figure out the size, then return a new
- // ByteBuffer of the appropriate size
- threadBuffer.clear()
- kryo.writeClassAndObject(threadBuffer, t)
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
+ output.clear()
+ kryo.writeClassAndObject(output, t)
+ ByteBuffer.wrap(output.toBytes)
}
def deserialize[T](bytes: ByteBuffer): T = {
- kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array)
+ kryo.readClassAndObject(input).asInstanceOf[T]
}
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
- val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array)
+ val obj = kryo.readClassAndObject(input).asInstanceOf[T]
kryo.setClassLoader(oldClassLoader)
obj
}
def serializeStream(s: OutputStream): SerializationStream = {
- threadBuffer.clear()
- new KryoSerializationStream(kryo, threadBuffer, s)
+ new KryoSerializationStream(kryo, s)
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(objectBuffer, s)
- }
-
- override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
- threadBuffer.clear()
- while (iterator.hasNext) {
- val element = iterator.next()
- // TODO: Do we also want to write the object's size? Doesn't seem necessary.
- kryo.writeClassAndObject(threadBuffer, element)
- }
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
- }
-
- override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
- buffer.rewind()
- new Iterator[Any] {
- override def hasNext: Boolean = buffer.remaining > 0
- override def next(): Any = kryo.readClassAndObject(buffer)
- }
+ new KryoDeserializationStream(kryo, s)
}
}
@@ -171,18 +98,19 @@ trait KryoRegistrator {
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
*/
class KryoSerializer extends spark.serializer.Serializer with Logging {
- // Make this lazy so that it only gets called once we receive our first task on each executor,
- // so we can pull out any custom Kryo registrator from the user's JARs.
- lazy val kryo = createKryo()
- val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
+ val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
- val objectBuffer = new ThreadLocal[ObjectBuffer] {
- override def initialValue = new ObjectBuffer(kryo, bufferSize)
+ val kryo = new ThreadLocal[Kryo] {
+ override def initialValue = createKryo()
}
- val threadBuffer = new ThreadLocal[ByteBuffer] {
- override def initialValue = ByteBuffer.allocate(bufferSize)
+ val output = new ThreadLocal[KryoOutput] {
+ override def initialValue = new KryoOutput(bufferSize)
+ }
+
+ val input = new ThreadLocal[KryoInput] {
+ override def initialValue = new KryoInput(bufferSize)
}
def createKryo(): Kryo = {
@@ -213,41 +141,44 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo.register(obj.getClass)
}
- // Register the following classes for passing closures.
- kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
- kryo.setRegistrationOptional(true)
-
// Allow sending SerializableWritable
- kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer())
- kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer())
+ kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
+ kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
- class SingletonSerializer(obj: AnyRef) extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T]
+ class SingletonSerializer[T](obj: T) extends KSerializer[T] {
+ override def write(kryo: Kryo, output: KryoOutput, obj: T) {}
+ override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj
}
- kryo.register(None.getClass, new SingletonSerializer(None))
- kryo.register(Nil.getClass, new SingletonSerializer(Nil))
+ kryo.register(None.getClass, new SingletonSerializer[AnyRef](None))
+ kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil))
// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
- extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {
+ extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {
+ override def write(
+ kryo: Kryo,
+ output: KryoOutput,
+ obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
- kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer])
+ kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer])
for ((k, v) <- map) {
- kryo.writeClassAndObject(buf, k)
- kryo.writeClassAndObject(buf, v)
+ kryo.writeClassAndObject(output, k)
+ kryo.writeClassAndObject(output, v)
}
}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = {
- val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue
+ override def read (
+ kryo: Kryo,
+ input: KryoInput,
+ cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
+ : Array[(Any, Any)] => scala.collection.Map[Any, Any] = {
+ val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue
val elems = new Array[(Any, Any)](size)
for (i <- 0 until size)
- elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf))
- buildMap(elems).asInstanceOf[T]
+ elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input))
+ buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]]
}
}
kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _))
@@ -275,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo
}
- def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
+ def newInstance(): SerializerInstance = {
+ this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
+ new KryoSerializerInstance(this)
+ }
}
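The serializer is rewritten against the Kryo 2.x `Input`/`Output` API, with one `Kryo`, `KryoInput`, and `KryoOutput` instance per thread and a smaller 2 MB default buffer. User code still opts in through system properties and an optional `KryoRegistrator`; a sketch, assuming the `spark.serializer` and `spark.kryo.registrator` property names from the tuning docs of this era (`MyRecord` and `MyRegistrator` are hypothetical):

```scala
import com.esotericsoftware.kryo.Kryo
import spark.KryoRegistrator

case class MyRecord(id: Int, name: String)   // hypothetical user class

class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[MyRecord])
  }
}

// Set before creating the SparkContext:
System.setProperty("spark.serializer", "spark.KryoSerializer")
System.setProperty("spark.kryo.registrator", "MyRegistrator")
System.setProperty("spark.kryoserializer.buffer.mb", "2")   // per-thread buffer size, the new default
```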
3 core/src/main/scala/spark/Logging.scala
@@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
- @transient
- private var log_ : Logger = null
+ @transient private var log_ : Logger = null
// Method to get or create the logger for this object
protected def log: Logger = {
60 core/src/main/scala/spark/MapOutputTracker.scala
@@ -17,6 +17,7 @@ import akka.util.duration._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -44,7 +45,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
val timeout = 10.seconds
- var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -53,7 +54,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
+ val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
@@ -64,6 +65,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
actorSystem.actorFor(url)
}
+ val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
@@ -84,14 +87,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.get(shuffleId) != null) {
+ if (mapStatuses.get(shuffleId) != None) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
@@ -108,7 +111,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses(shuffleId)
if (array != null) {
array.synchronized {
if (array(mapId) != null && array(mapId).address == bmAddress) {
@@ -126,7 +129,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
- val statuses = mapStatuses.get(shuffleId)
+ val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
@@ -139,8 +142,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case e: InterruptedException =>
}
}
- return mapStatuses.get(shuffleId).map(status =>
- (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
} else {
fetching += shuffleId
}
@@ -156,27 +158,27 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
- if (fetchedStatuses.contains(null)) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing an output location for shuffle " + shuffleId))
- }
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
- return fetchedStatuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
} else {
- return statuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
+ def cleanup(cleanupTime: Long) {
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
+ }
+
def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
+ metadataCleaner.cancel()
trackerActor = null
}
@@ -202,7 +204,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
@@ -220,7 +222,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case Some(bytes) =>
return bytes
case None =>
- statuses = mapStatuses.get(shuffleId)
+ statuses = mapStatuses(shuffleId)
generationGotten = generation
}
}
@@ -258,6 +260,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
+ // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
+ // any of the statuses is null (indicating a missing location due to a failed mapper),
+ // throw a FetchFailedException.
+ def convertMapStatuses(
+ shuffleId: Int,
+ reduceId: Int,
+ statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ if (statuses == null) {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing all output locations for shuffle " + shuffleId))
+ }
+ statuses.map {
+ status =>
+ if (status == null) {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing an output location for shuffle " + shuffleId))
+ } else {
+ (status.address, decompressSize(status.compressedSizes(reduceId)))
+ }
+ }
+ }
+
/**
* Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
* We do this by encoding the log base 1.1 of the size as an integer, which can support
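The hunk is truncated just as it reaches the size-compression helpers, but the surviving comment describes the scheme: each map output size is reported as a single byte holding its log base 1.1, trading roughly 10% precision for an 8-bit encoding. A self-contained sketch of that scheme as described (not necessarily the exact Spark implementation):

```scala
object SizeCodec {
  private val LOG_BASE = 1.1

  /** Encode a size in bytes into one byte: ceil(log_1.1(size)), capped at 255. */
  def compressSize(size: Long): Byte = {
    if (size == 0) {
      0
    } else if (size <= 1L) {
      1  // keep size 1 distinct from the "empty" code 0
    } else {
      math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
    }
  }

  /** Decode back to an (over-)estimate of the original size. */
  def decompressSize(compressed: Byte): Long = {
    if (compressed == 0) 0L else math.pow(LOG_BASE, compressed & 0xFF).toLong
  }
}

// e.g. compressSize(1000000) == 145; decompressSize(145) is about 1.0e6, within ~10%
```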
86 core/src/main/scala/spark/PairRDDFunctions.scala
@@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (mapSideCombine) {
@@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+
+ if (getKeyClass().isArray) {
+ throw new SparkException("reduceByKeyLocally() does not support array keys")
+ }
+
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
for ((k, v) <- iter) {
@@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* be set to true.
*/
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
@@ -178,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
- * Merge the values for each key using an associative reduce function. This will also perform
- * the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce.
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues {
@@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
partitioner)
@@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]],
other1.asInstanceOf[RDD[(_, _)]],
@@ -438,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner")
+ self.filter(_._1 == key).map(_._2).collect
}
}
@@ -466,20 +493,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
path: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
- saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
- }
-
- /**
- * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
- * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
- */
- def saveAsNewAPIHadoopFile(
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
- conf: Configuration) {
+ conf: Configuration = self.context.hadoopConfiguration) {
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
@@ -530,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
@@ -588,6 +603,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.cleanup()
}
+ /**
+ * Return an RDD with the keys of each tuple.
+ */
+ def keys: RDD[K] = self.map(_._1)
+
+ /**
+ * Return an RDD with the values of each tuple.
+ */
+ def values: RDD[V] = self.map(_._2)
+
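A quick usage sketch for the new `keys` and `values` helpers; the local SparkContext and sample pairs are assumptions made for illustration:

```scala
import spark.SparkContext
import spark.SparkContext._   // pair-RDD operations, including keys and values

val sc = new SparkContext("local", "keys-values-example")
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))

pairs.keys.collect()     // Array(a, b, c)
pairs.values.collect()   // Array(1, 2, 3)
```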
private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
@@ -624,24 +649,23 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
- override def compute(split: Split, taskContext: TaskContext) =
- prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))}
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
+ extends RDD[(K, U)](prev) {
+
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
}
private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
- extends RDD[(K, U)](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
+ extends RDD[(K, U)](prev) {
- override def compute(split: Split, taskContext: TaskContext) = {
- prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) }
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) = {
+ firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}
24 core/src/main/scala/spark/ParallelCollection.scala
@@ -2,6 +2,7 @@ package spark
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long,
@@ -22,28 +23,33 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
- sc: SparkContext,
+ @transient sc: SparkContext,
@transient data: Seq[T],
- numSlices: Int)
- extends RDD[T](sc) {
+ numSlices: Int,
+ locationPrefs: Map[Int,Seq[String]])
+ extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead.
+ // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient
- val splits_ = {
+ @transient var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
- override def splits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_
- override def compute(s: Split, taskContext: TaskContext) =
+ override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
- override def preferredLocations(s: Split): Seq[String] = Nil
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ locationPrefs.getOrElse(s.index, Nil)
+ }
- override val dependencies: List[Dependency[_]] = Nil
+ override def clearDependencies() {
+ splits_ = null
+ }
}
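The new `locationPrefs` parameter maps a split index to the hosts that split would prefer to run on, and `getPreferredLocations` falls back to `Nil` when no entry exists. A minimal sketch of that lookup, with made-up host names (`ParallelCollection` itself is `private[spark]`, so this only mirrors the data shape):

```scala
// Shape of the new locationPrefs argument: split index -> preferred host names.
// The host names are made up for illustration.
val locationPrefs: Map[Int, Seq[String]] = Map(
  0 -> Seq("host-a", "host-b"),   // split 0 prefers either of these hosts
  2 -> Seq("host-c")              // split 2 prefers host-c; split 1 has no entry
)

// Mirrors getPreferredLocations above: fall back to Nil for splits with no preference.
def preferredLocationsFor(splitIndex: Int): Seq[String] =
  locationPrefs.getOrElse(splitIndex, Nil)

preferredLocationsFor(1)   // Seq() -- no preference recorded
preferredLocationsFor(2)   // Seq(host-c)
```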
private object ParallelCollection {
4 core/src/main/scala/spark/Partitioner.scala
@@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable {
/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ *
+ * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
+ * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
+ * produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
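A short snippet showing why the warning above matters: two arrays with identical contents hash differently, so hash-based partitioning (approximated below by a simplified `partitionOf`, not the real `HashPartitioner.getPartition`) can send "equal" keys to different partitions:

```scala
// Two arrays with identical contents...
val a = Array(1, 2, 3)
val b = Array(1, 2, 3)

a.sameElements(b)            // true  -- same contents
a.hashCode == b.hashCode     // almost certainly false -- Java arrays hash by identity

// Simplified stand-in for hash-based partitioning (not the real HashPartitioner.getPartition).
val numPartitions = 4
def partitionOf(key: AnyRef): Int = math.abs(key.hashCode % numPartitions)

// "Equal" array keys may land in different partitions, breaking joins, groupBys and lookups.
partitionOf(a) == partitionOf(b)   // not guaranteed to be true
```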
197 core/src/main/scala/spark/RDD.scala
@@ -1,10 +1,8 @@
package spark
-import java.io.EOFException
-import java.io.ObjectInputStream
+import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL
-import java.util.Random
-import java.util.Date
+import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
@@ -13,6 +11,7 @@ import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
@@ -73,41 +72,42 @@ import SparkContext._
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
-abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable {
+abstract class RDD[T: ClassManifest](
+ @transient var sc: SparkContext,
+ var dependencies_ : List[Dependency[_]]
+ ) extends Serializable with Logging {
- // Methods that must be implemented by subclasses:
- /** Set of partitions in this RDD. */
- def splits: Array[Split]
+ def this(@transient oneParent: RDD[_]) =
+ this(oneParent.context, List(new OneToOneDependency(oneParent)))
+
+ // =======================================================================
+ // Methods that should be implemented by subclasses of RDD
+ // =======================================================================
/** Function for computing a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
- /** How this RDD depends on any parent RDDs. */
- @transient val dependencies: List[Dependency[_]]
+ /** Set of partitions in this RDD. */
+ protected def getSplits(): Array[Split]
- // Methods available on all RDDs:
+ /** How this RDD depends on any parent RDDs. */
+ protected def getDependencies(): List[Dependency[_]] = dependencies_
- /** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ /** Optionally overridden by subclasses to specify placement preferences. */
+ protected def getPreferredLocations(split: Split): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
- /** Optionally overridden by subclasses to specify placement preferences. */
- def preferredLocations(split: Split): Seq[String] = Nil
-
- /** The [[spark.SparkContext]] that this RDD was created on. */
- def context = sc
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ // =======================================================================
+ // Methods and fields available on all RDDs
+ // =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
- // Variables relating to persistence
- private var storageLevel: StorageLevel = StorageLevel.NONE
-
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
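Under the reorganized contract above, a subclass overrides `getSplits` and `compute` (and optionally `getPreferredLocations` or `getDependencies`), while `splits`, `dependencies`, and `preferredLocations` become checkpoint-aware wrappers. A minimal, hypothetical subclass might look like the sketch below; `DoubledRDD` is not part of Spark, and it keeps a direct `prev` reference rather than the `firstParent` pattern the real subclasses in this diff use to stay checkpoint-friendly:

```scala
import spark.{RDD, Split, TaskContext}

// Minimal, hypothetical subclass under the new contract: implement compute() and
// getSplits(); the one-parent constructor wires up a OneToOneDependency for us.
class DoubledRDD(prev: RDD[Int]) extends RDD[Int](prev) {

  // Reuse the parent's partitioning unchanged.
  override def getSplits: Array[Split] = prev.splits

  // Produce this RDD's elements for one split by transforming the parent's iterator.
  override def compute(split: Split, context: TaskContext): Iterator[Int] =
    prev.iterator(split, context).map(_ * 2)
}

// Usage: new DoubledRDD(sc.parallelize(1 to 5)).collect() returns Array(2, 4, 6, 8, 10)
```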
@@ -131,22 +131,39 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
- if (!level.useDisk && level.replication < 2) {
- throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
+ /**
+ * Get the preferred location of a split, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def preferredLocations(split: Split): Seq[String] = {
+ if (isCheckpointed) {
+ checkpointData.get.getPreferredLocations(split)
+ } else {
+ getPreferredLocations(split)
}
+ }
- // This is a hack. Ideally this should re-use the code used by the CacheTracker
- // to generate the key.
- def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-
- persist(level)
- sc.runJob(this, (iter: Iterator[T]) => {} )
-
- val p = this.partitioner
+ /**
+ * Get the array of splits of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def splits: Array[Split] = {
+ if (isCheckpointed) {
+ checkpointData.get.getSplits
+ } else {
+ getSplits
+ }
+ }
- new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
- override val partitioner = p
+ /**
+ * Get the list of dependencies of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def dependencies: List[Dependency[_]] = {
+ if (isCheckpointed) {
+ dependencies_
+ } else {
+ getDependencies
}
}
@@ -156,8 +173,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* subclasses of RDD.
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
- if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
+ if (isCheckpointed) {
+ checkpointData.get.iterator(split, context)
+ } else if (storageLevel != StorageLevel.NONE) {
+ SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
compute(split, context)
}
@@ -185,9 +204,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int = splits.size): RDD[T] =
+ def distinct(numSplits: Int): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
+ def distinct(): RDD[T] = distinct(splits.size)
+
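The split of `distinct` into a one-argument and a no-argument form keeps the same implementation idea: map each element to a key, merge duplicates with `reduceByKey`, and drop the dummy value. A small sketch, using `1` as the dummy value instead of `null` to keep it simple (local SparkContext assumed):

```scala
import spark.SparkContext
import spark.SparkContext._   // reduceByKey comes from the pair-RDD conversions

val sc = new SparkContext("local", "distinct-example")
val xs = sc.parallelize(Seq(1, 2, 2, 3, 3, 3))

xs.distinct().collect().sorted   // Array(1, 2, 3)

// Hand-written equivalent of the implementation above, with 1 standing in for null.
xs.map(x => (x, 1)).reduceByKey((x, y) => x).map(_._1).collect().sorted   // Array(1, 2, 3)
```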
/**
* Return a sampled subset of this RDD.
*/
@@ -328,6 +349,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def toArray(): Array[T] = collect()
/**
+ * Return an RDD containing only the elements on which the partial function `f` is defined, with `f` applied to each of them.
+ */
+ def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ filter(f.isDefinedAt).map(f)
+ }
+
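The new `collect` overload filters by `isDefinedAt` and then maps, so it behaves like Scala's `collect` on ordinary collections. A small illustrative sketch (local SparkContext and sample data assumed):

```scala
import spark.SparkContext

val sc = new SparkContext("local", "collect-pf-example")
val mixed = sc.parallelize(Seq(1, -2, 3, -4, 5))

// Keep only the elements the partial function is defined on, transformed by it,
// i.e. filter(f.isDefinedAt).map(f) as in the implementation above.
val squaresOfPositives = mixed.collect { case x if x > 0 => x * x }
squaresOfPositives.collect()   // Array(1, 9, 25)
```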
+ /**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
@@ -415,6 +443,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValue() does not support arrays")
+ }
// TODO: This should perhaps be distributed by default.
def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
val map = new OLMap[T]
@@ -443,6 +474,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValueApprox() does not support arrays")
+ }
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
val map = new OLMap[T]
while (iter.hasNext) {
@@ -502,8 +536,95 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
.saveAsSequenceFile(path)
}
+ /**
+ * Creates tuples of the elements in this RDD by applying `f` to each element to compute its key.
+ */
+ def keyBy[K](f: T => K): RDD[(K, T)] = {
+ map(x => (f(x), x))
+ }
+
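`keyBy` is a thin wrapper over `map`, pairing each element with a derived key. A quick sketch (local SparkContext and sample words assumed):

```scala
import spark.SparkContext

val sc = new SparkContext("local", "keyby-example")
val words = sc.parallelize(Seq("spark", "mesos", "rdd"))

// keyBy pairs each element with a derived key; here, the word length.
val byLength = words.keyBy(_.length)
byLength.collect()   // Array((5,spark), (5,mesos), (3,rdd))
```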
/** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory; otherwise, saving it to a file will require recomputation.
+ */
+ def checkpoint() {
+ if (context.checkpointDir.isEmpty) {
+ throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ } else if (checkpointData.isEmpty) {
+ checkpointData = Some(new RDDCheckpointData(this))
+ checkpointData.get.markForCheckpoint()
+ }
+ }
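Putting the `checkpoint` contract above into a short, hypothetical driver: the checkpoint directory is set first, the RDD is persisted as recommended, `checkpoint()` is called before any job runs on the RDD, and a job is then run so the marked RDD can actually be written. The directory path, the single-argument form of `setCheckpointDir`, and the use of `cache()` are assumptions for the sketch:

```scala
import spark.SparkContext

val sc = new SparkContext("local", "checkpoint-example")

// The checkpoint directory must be set before checkpoint() is called;
// the path here is an arbitrary assumption.
sc.setCheckpointDir("/tmp/spark-checkpoints")

val rdd = sc.parallelize(1 to 1000).map(_ * 2)
rdd.cache()        // recommended above, so the checkpoint write does not recompute the lineage
rdd.checkpoint()   // must happen before any job has been executed on this RDD
rdd.count()        // running a job lets the marked RDD actually be saved
```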