@@ -0,0 +1,19 @@
import spark._
object SleepJob {
def main(args: Array[String]) {
if (args.length != 3) {
System.err.println("Usage: SleepJob <master> <tasks> <task_duration>");
System.exit(1)
}
val sc = new SparkContext(args(0), "Sleep job")
val tasks = args(1).toInt
val duration = args(2).toInt
def task {
val start = System.currentTimeMillis
while (System.currentTimeMillis - start < duration * 1000L)
Thread.sleep(200)
}
sc.runTasks(Array.make(tasks, () => task))
}
}
@@ -0,0 +1,138 @@
import java.io.Serializable
import java.util.Random
import cern.jet.math._
import cern.colt.matrix._
import cern.colt.matrix.linalg._
import spark._
object SparkALS {
// Parameters set through command line arguments
var M = 0 // Number of movies
var U = 0 // Number of users
var F = 0 // Number of features
var ITERATIONS = 0
val LAMBDA = 0.01 // Regularization coefficient
// Some COLT objects
val factory2D = DoubleFactory2D.dense
val factory1D = DoubleFactory1D.dense
val algebra = Algebra.DEFAULT
val blas = SeqBlas.seqBlas
def generateR(): DoubleMatrix2D = {
val mh = factory2D.random(M, F)
val uh = factory2D.random(U, F)
return algebra.mult(mh, algebra.transpose(uh))
}
def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
us: Array[DoubleMatrix1D]): Double =
{
val r = factory2D.make(M, U)
for (i <- 0 until M; j <- 0 until U) {
r.set(i, j, blas.ddot(ms(i), us(j)))
}
//println("R: " + r)
blas.daxpy(-1, targetR, r)
val sumSqs = r.aggregate(Functions.plus, Functions.square)
return Math.sqrt(sumSqs / (M * U))
}
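// updateMovie and updateUser below each solve a regularized least-squares
// problem: they accumulate X^T X and X^T y over the fixed factor vectors,
// add LAMBDA * n to the diagonal, and solve
//   (X^T X + LAMBDA * n * I) w = X^T y
// by Cholesky decomposition, where n is the number of users (or movies).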
def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
R: DoubleMatrix2D) : DoubleMatrix1D =
{
val U = us.size
val F = us(0).size
val XtX = factory2D.make(F, F)
val Xty = factory1D.make(F)
// For each user that rated the movie
for (j <- 0 until U) {
val u = us(j)
// Add u * u^t to XtX
blas.dger(1, u, u, XtX)
// Add u * rating to Xty
blas.daxpy(R.get(i, j), u, Xty)
}
// Add regularization coefs to diagonal terms
for (d <- 0 until F) {
XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
}
// Solve it with Cholesky
val ch = new CholeskyDecomposition(XtX)
val Xty2D = factory2D.make(Xty.toArray, F)
val solved2D = ch.solve(Xty2D)
return solved2D.viewColumn(0)
}
def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
R: DoubleMatrix2D) : DoubleMatrix1D =
{
val M = ms.size
val F = ms(0).size
val XtX = factory2D.make(F, F)
val Xty = factory1D.make(F)
// For each movie that the user rated
for (i <- 0 until M) {
val m = ms(i)
// Add m * m^t to XtX
blas.dger(1, m, m, XtX)
// Add m * rating to Xty
blas.daxpy(R.get(i, j), m, Xty)
}
// Add regularization coefs to diagonal terms
for (d <- 0 until F) {
XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
}
// Solve it with Cholesky
val ch = new CholeskyDecomposition(XtX)
val Xty2D = factory2D.make(Xty.toArray, F)
val solved2D = ch.solve(Xty2D)
return solved2D.viewColumn(0)
}
def main(args: Array[String]) {
var host = ""
var slices = 0
args match {
case Array(m, u, f, iters, slices_, host_) => {
M = m.toInt
U = u.toInt
F = f.toInt
ITERATIONS = iters.toInt
slices = slices_.toInt
host = host_
}
case _ => {
System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <host>")
System.exit(1)
}
}
printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
val spark = new SparkContext(host, "SparkALS")
val R = generateR()
// Initialize m and u randomly
var ms = Array.fromFunction(_ => factory1D.random(F))(M)
var us = Array.fromFunction(_ => factory1D.random(F))(U)
// Iteratively update movies then users
val Rc = spark.broadcast(R)
var msb = spark.broadcast(ms)
var usb = spark.broadcast(us)
for (iter <- 1 to ITERATIONS) {
println("Iteration " + iter + ":")
ms = spark.parallelize(0 until M, slices)
.map(i => updateMovie(i, msb.value(i), usb.value, Rc.value))
.toArray
msb = spark.broadcast(ms) // Re-broadcast ms because it was updated
us = spark.parallelize(0 until U, slices)
.map(i => updateUser(i, usb.value(i), msb.value, Rc.value))
.toArray
usb = spark.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
println()
}
}
}
@@ -0,0 +1,50 @@
import java.util.Random
import Vector._
import spark._
object SparkHdfsLR {
val D = 10 // Number of dimensions
val rand = new Random(42)
case class DataPoint(x: Vector, y: Double)
def parsePoint(line: String): DataPoint = {
//val nums = line.split(' ').map(_.toDouble)
//return DataPoint(new Vector(nums.subArray(1, D+1)), nums(0))
val tok = new java.util.StringTokenizer(line, " ")
var y = tok.nextToken.toDouble
var x = new Array[Double](D)
var i = 0
while (i < D) {
x(i) = tok.nextToken.toDouble; i += 1
}
return DataPoint(new Vector(x), y)
}
def main(args: Array[String]) {
if (args.length < 3) {
System.err.println("Usage: SparkHdfsLR <master> <file> <iters>")
System.exit(1)
}
val sc = new SparkContext(args(0), "SparkHdfsLR")
val lines = sc.textFile(args(1))
val points = lines.map(parsePoint _).cache()
val ITERATIONS = args(2).toInt
// Initialize w to a random value
var w = Vector(D, _ => 2 * rand.nextDouble - 1)
println("Initial w: " + w)
for (i <- 1 to ITERATIONS) {
println("On iteration " + i)
val gradient = sc.accumulator(Vector.zeros(D))
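// Each point contributes the logistic-loss gradient
// (1 / (1 + exp(-y * (w dot x))) - 1) * y * x to the accumulator.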
for (p <- points) {
val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
gradient += scale * p.x
}
w -= gradient.value
}
println("Final w: " + w)
}
}
@@ -0,0 +1,48 @@
import java.util.Random
import Vector._
import spark._
object SparkLR {
val N = 10000 // Number of data points
val D = 10 // Number of dimensions
val R = 0.7 // Scaling factor
val ITERATIONS = 5
val rand = new Random(42)
case class DataPoint(x: Vector, y: Double)
def generateData = {
def generatePoint(i: Int) = {
val y = if(i % 2 == 0) -1 else 1
val x = Vector(D, _ => rand.nextGaussian + y * R)
DataPoint(x, y)
}
Array.fromFunction(generatePoint _)(N)
}
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: SparkLR <host> [<slices>]")
System.exit(1)
}
val sc = new SparkContext(args(0), "SparkLR")
val numSlices = if (args.length > 1) args(1).toInt else 2
val data = generateData
// Initialize w to a random value
var w = Vector(D, _ => 2 * rand.nextDouble - 1)
println("Initial w: " + w)
for (i <- 1 to ITERATIONS) {
println("On iteration " + i)
val gradient = sc.accumulator(Vector.zeros(D))
for (p <- sc.parallelize(data, numSlices)) {
val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
gradient += scale * p.x
}
w -= gradient.value
}
println("Final w: " + w)
}
}
@@ -0,0 +1,20 @@
import spark._
import SparkContext._
object SparkPi {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: SparkLR <host> [<slices>]")
System.exit(1)
}
val spark = new SparkContext(args(0), "SparkPi")
val slices = if (args.length > 1) args(1).toInt else 2
val count = spark.accumulator(0)
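// Sample points uniformly from the square [-1, 1) x [-1, 1); the fraction
// landing inside the unit circle converges to Pi / 4.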
for (i <- spark.parallelize(1 to 100000, slices)) {
val x = Math.random * 2 - 1
val y = Math.random * 2 - 1
if (x*x + y*y < 1) count += 1
}
println("Pi is roughly " + 4 * count.value / 100000.0)
}
}
@@ -0,0 +1,63 @@
@serializable class Vector(val elements: Array[Double]) {
def length = elements.length
def apply(index: Int) = elements(index)
def + (other: Vector): Vector = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
return Vector(length, i => this(i) + other(i))
}
def - (other: Vector): Vector = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
return Vector(length, i => this(i) - other(i))
}
def dot(other: Vector): Double = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
var ans = 0.0
for (i <- 0 until length)
ans += this(i) * other(i)
return ans
}
def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
def unary_- = this * -1
def sum = elements.reduceLeft(_ + _)
override def toString = elements.mkString("(", ", ", ")")
}
object Vector {
def apply(elements: Array[Double]) = new Vector(elements)
def apply(elements: Double*) = new Vector(elements.toArray)
def apply(length: Int, initializer: Int => Double): Vector = {
val elements = new Array[Double](length)
for (i <- 0 until length)
elements(i) = initializer(i)
return new Vector(elements)
}
def zeros(length: Int) = new Vector(new Array[Double](length))
def ones(length: Int) = Vector(length, _ => 1)
class Multiplier(num: Double) {
def * (vec: Vector) = vec * num
}
implicit def doubleToMultiplier(num: Double) = new Multiplier(num)
implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] {
def add(t1: Vector, t2: Vector) = t1 + t2
def zero(initialValue: Vector) = Vector.zeros(initialValue.length)
}
}
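// Illustrative usage sketch:
//   val a = Vector(1.0, 2.0, 3.0)
//   val b = Vector.ones(3)
//   a + b      // (2.0, 3.0, 4.0)
//   a dot b    // 6.0
//   2.0 * a    // (2.0, 4.0, 6.0), via the implicit Multiplier conversion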
@@ -0,0 +1,27 @@
package spark.compress.lzf;
public class LZF {
private static boolean loaded;
static {
try {
System.loadLibrary("spark_native");
loaded = true;
} catch(Throwable t) {
System.out.println("Failed to load native LZF library: " + t.toString());
loaded = false;
}
}
public static boolean isLoaded() {
return loaded;
}
public static native int compress(
byte[] in, int inOff, int inLen,
byte[] out, int outOff, int outLen);
public static native int decompress(
byte[] in, int inOff, int inLen,
byte[] out, int outOff, int outLen);
}
@@ -0,0 +1,180 @@
package spark.compress.lzf;
import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
public class LZFInputStream extends FilterInputStream {
private static final int MAX_BLOCKSIZE = 1024 * 64 - 1;
private static final int MAX_HDR_SIZE = 7;
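// Block format, as produced by LZFOutputStream: each block begins with the
// magic bytes 'Z', 'V' and a type byte. A type 0 (uncompressed) block is
//   Z V 0 <size_hi> <size_lo> <size bytes of data>
// and a type 1 (compressed) block is
//   Z V 1 <csize_hi> <csize_lo> <usize_hi> <usize_lo> <csize bytes of data>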
private byte[] inBuf; // Holds data to decompress (including header)
private byte[] outBuf; // Holds decompressed data to output
private int outPos; // Current position in outBuf
private int outSize; // Total amount of data in outBuf
private boolean closed;
private boolean reachedEof;
private byte[] singleByte = new byte[1];
public LZFInputStream(InputStream in) {
super(in);
if (in == null)
throw new NullPointerException();
inBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE];
outBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE];
outPos = 0;
outSize = 0;
}
private void ensureOpen() throws IOException {
if (closed) throw new IOException("Stream closed");
}
@Override
public int read() throws IOException {
ensureOpen();
int count = read(singleByte, 0, 1);
return (count == -1 ? -1 : singleByte[0] & 0xFF);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
ensureOpen();
if ((off | len | (off + len) | (b.length - (off + len))) < 0)
throw new IndexOutOfBoundsException();
int totalRead = 0;
// Start with the current block in outBuf, and read and decompress any
// further blocks necessary. Instead of trying to decompress directly to b
// when b is large, we always use outBuf as an intermediate holding space
// in case GetPrimitiveArrayCritical decides to copy arrays instead of
// pinning them, which would cause b to be copied repeatedly into C-land.
while (len > 0) {
if (outPos == outSize) {
readNextBlock();
if (reachedEof)
return totalRead == 0 ? -1 : totalRead;
}
int amtToCopy = Math.min(outSize - outPos, len);
System.arraycopy(outBuf, outPos, b, off, amtToCopy);
off += amtToCopy;
len -= amtToCopy;
outPos += amtToCopy;
totalRead += amtToCopy;
}
return totalRead;
}
// Read len bytes from this.in to a buffer, stopping only if EOF is reached
private int readFully(byte[] b, int off, int len) throws IOException {
int totalRead = 0;
while (len > 0) {
int amt = in.read(b, off, len);
if (amt == -1)
break;
off += amt;
len -= amt;
totalRead += amt;
}
return totalRead;
}
// Read the next block from the underlying InputStream into outBuf,
// setting outPos and outSize, or set reachedEof if the stream ends.
private void readNextBlock() throws IOException {
// Read first 5 bytes of header
int count = readFully(inBuf, 0, 5);
if (count == 0) {
reachedEof = true;
return;
} else if (count < 5) {
throw new EOFException("Truncated LZF block header");
}
// Check magic bytes
if (inBuf[0] != 'Z' || inBuf[1] != 'V')
throw new IOException("Wrong magic bytes in LZF block header");
// Read the block
if (inBuf[2] == 0) {
// Uncompressed block - read directly to outBuf
int size = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF);
if (readFully(outBuf, 0, size) != size)
throw new EOFException("EOF inside LZF block");
outPos = 0;
outSize = size;
} else if (inBuf[2] == 1) {
// Compressed block - read to inBuf and decompress
if (readFully(inBuf, 5, 2) != 2)
throw new EOFException("Truncated LZF block header");
int csize = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF);
int usize = ((inBuf[5] & 0xFF) << 8) | (inBuf[6] & 0xFF);
if (readFully(inBuf, 7, csize) != csize)
throw new EOFException("Truncated LZF block");
if (LZF.decompress(inBuf, 7, csize, outBuf, 0, usize) != usize)
throw new IOException("Corrupt LZF data stream");
outPos = 0;
outSize = usize;
} else {
throw new IOException("Unknown block type in LZF block header");
}
}
/**
* Returns 0 after EOF has been reached, otherwise always returns 1.
*
* Programs should not count on this method to return the actual number
* of bytes that could be read without blocking.
*/
@Override
public int available() throws IOException {
ensureOpen();
return reachedEof ? 0 : 1;
}
// TODO: Skip complete chunks without decompressing them?
@Override
public long skip(long n) throws IOException {
ensureOpen();
if (n < 0)
throw new IllegalArgumentException("negative skip length");
byte[] buf = new byte[512];
long skipped = 0;
while (skipped < n) {
int len = (int) Math.min(n - skipped, buf.length);
len = read(buf, 0, len);
if (len == -1) {
reachedEof = true;
break;
}
skipped += len;
}
return skipped;
}
@Override
public void close() throws IOException {
if (!closed) {
in.close();
closed = true;
}
}
@Override
public boolean markSupported() {
return false;
}
@Override
public void mark(int readLimit) {}
@Override
public void reset() throws IOException {
throw new IOException("mark/reset not supported");
}
}
@@ -0,0 +1,85 @@
package spark.compress.lzf;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
public class LZFOutputStream extends FilterOutputStream {
private static final int BLOCKSIZE = 1024 * 64 - 1;
private static final int MAX_HDR_SIZE = 7;
private byte[] inBuf; // Holds input data to be compressed
private byte[] outBuf; // Holds compressed data to be written
private int inPos; // Current position in inBuf
public LZFOutputStream(OutputStream out) {
super(out);
inBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE];
outBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE];
inPos = MAX_HDR_SIZE;
}
@Override
public void write(int b) throws IOException {
inBuf[inPos++] = (byte) b;
if (inPos == inBuf.length)
compressAndSendBlock();
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
if ((off | len | (off + len) | (b.length - (off + len))) < 0)
throw new IndexOutOfBoundsException();
// If we're given a large array, copy it piece by piece into inBuf and
// write one BLOCKSIZE at a time. This is done to prevent the JNI code
// from copying the whole array repeatedly if GetPrimitiveArrayCritical
// decides to copy instead of pinning.
while (inPos + len >= inBuf.length) {
int amtToCopy = inBuf.length - inPos;
System.arraycopy(b, off, inBuf, inPos, amtToCopy);
inPos += amtToCopy;
compressAndSendBlock();
off += amtToCopy;
len -= amtToCopy;
}
// Copy the remaining (incomplete) block into inBuf
System.arraycopy(b, off, inBuf, inPos, len);
inPos += len;
}
@Override
public void flush() throws IOException {
if (inPos > MAX_HDR_SIZE)
compressAndSendBlock();
out.flush();
}
// Send the data in inBuf, and reset inPos to start writing a new block.
private void compressAndSendBlock() throws IOException {
int us = inPos - MAX_HDR_SIZE;
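// For blocks larger than 4 bytes, only use compression if it saves at
// least 4 bytes; otherwise fall back to an uncompressed (type 0) block.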
int maxcs = us > 4 ? us - 4 : us;
int cs = LZF.compress(inBuf, MAX_HDR_SIZE, us, outBuf, MAX_HDR_SIZE, maxcs);
if (cs != 0) {
// Compression made the data smaller; use type 1 header
outBuf[0] = 'Z';
outBuf[1] = 'V';
outBuf[2] = 1;
outBuf[3] = (byte) (cs >> 8);
outBuf[4] = (byte) (cs & 0xFF);
outBuf[5] = (byte) (us >> 8);
outBuf[6] = (byte) (us & 0xFF);
out.write(outBuf, 0, 7 + cs);
} else {
// Compression didn't help; use type 0 header and uncompressed data
inBuf[2] = 'Z';
inBuf[3] = 'V';
inBuf[4] = 0;
inBuf[5] = (byte) (us >> 8);
inBuf[6] = (byte) (us & 0xFF);
out.write(inBuf, 2, 5 + us);
}
inPos = MAX_HDR_SIZE;
}
}
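// Illustrative round trip (a sketch; assumes the spark_native library is
// loadable and that data and "data.lzf" are stand-ins):
//   OutputStream os = new LZFOutputStream(new FileOutputStream("data.lzf"));
//   os.write(data); os.close();
//   InputStream is = new LZFInputStream(new FileInputStream("data.lzf"));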
@@ -0,0 +1,30 @@
CC = gcc
#JAVA_HOME = /usr/lib/jvm/java-6-sun
OS_NAME = linux
CFLAGS = -fPIC -O3 -funroll-all-loops
SPARK = ../..
LZF = $(SPARK)/third_party/liblzf-3.5
LIB = libspark_native.so
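# Typical invocation (JAVA_HOME must point at a JDK, e.g. the path above):
#   make JAVA_HOME=/usr/lib/jvm/java-6-sun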
all: $(LIB)
spark_compress_lzf_LZF.h: $(SPARK)/classes/spark/compress/lzf/LZF.class
ifeq ($(JAVA_HOME),)
$(error JAVA_HOME is not set)
else
$(JAVA_HOME)/bin/javah -classpath $(SPARK)/classes spark.compress.lzf.LZF
endif
$(LIB): spark_compress_lzf_LZF.h spark_compress_lzf_LZF.c
$(CC) $(CFLAGS) -shared -o $@ spark_compress_lzf_LZF.c \
-I $(JAVA_HOME)/include -I $(JAVA_HOME)/include/$(OS_NAME) \
-I $(LZF) $(LZF)/lzf_c.c $(LZF)/lzf_d.c
clean:
rm -f spark_compress_lzf_LZF.h $(LIB)
.PHONY: all clean
@@ -0,0 +1,90 @@
#include "spark_compress_lzf_LZF.h"
#include <lzf.h>
/* Helper function to throw an exception */
static void throwException(JNIEnv *env, const char* className) {
jclass cls = (*env)->FindClass(env, className);
if (cls != 0) /* If cls is null, an exception was already thrown */
(*env)->ThrowNew(env, cls, "");
}
/*
* Since LZF.compress() and LZF.decompress() have the same signatures
* and differ only in which lzf_ function they call, implement both in a
* single function and pass it a pointer to the correct lzf_ function.
*/
static jint callCompressionFunction
(unsigned int (*func)(const void *const, unsigned int, void *, unsigned int),
JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
jbyteArray outArray, jint outOff, jint outLen)
{
jint inCap;
jint outCap;
jbyte *inData = 0;
jbyte *outData = 0;
jint ret = 0; /* remains 0 if we bail out before calling func */
jint s;
if (!inArray || !outArray) {
throwException(env, "java/lang/NullPointerException");
goto cleanup;
}
inCap = (*env)->GetArrayLength(env, inArray);
outCap = (*env)->GetArrayLength(env, outArray);
// Check if any of the offset/length pairs is invalid; we do this by OR'ing
// things we don't want to be negative and seeing if the result is negative
s = inOff | inLen | (inOff + inLen) | (inCap - (inOff + inLen)) |
outOff | outLen | (outOff + outLen) | (outCap - (outOff + outLen));
if (s < 0) {
throwException(env, "java/lang/IndexOutOfBoundsException");
goto cleanup;
}
inData = (*env)->GetPrimitiveArrayCritical(env, inArray, 0);
outData = (*env)->GetPrimitiveArrayCritical(env, outArray, 0);
if (!inData || !outData) {
// Out of memory - JVM will throw OutOfMemoryError
goto cleanup;
}
ret = func(inData + inOff, inLen, outData + outOff, outLen);
cleanup:
if (inData)
(*env)->ReleasePrimitiveArrayCritical(env, inArray, inData, 0);
if (outData)
(*env)->ReleasePrimitiveArrayCritical(env, outArray, outData, 0);
return ret;
}
/*
* Class: spark_compress_lzf_LZF
* Method: compress
* Signature: ([B[B)I
*/
JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_compress
(JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
jbyteArray outArray, jint outOff, jint outLen)
{
return callCompressionFunction(lzf_compress, env, cls,
inArray, inOff, inLen, outArray, outOff, outLen);
}
/*
* Class: spark_compress_lzf_LZF
* Method: decompress
* Signature: ([B[B)I
*/
JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_decompress
(JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
jbyteArray outArray, jint outOff, jint outLen)
{
return callCompressionFunction(lzf_decompress, env, cls,
inArray, inOff, inLen, outArray, outOff, outLen);
}
@@ -0,0 +1,71 @@
package spark
import java.io._
import scala.collection.mutable.Map
@serializable class Accumulator[T](initialValue: T, param: AccumulatorParam[T])
{
val id = Accumulators.newId
@transient var value_ = initialValue
var deserialized = false
Accumulators.register(this)
def += (term: T) { value_ = param.add(value_, term) }
def value = this.value_
def value_= (t: T) {
if (!deserialized) value_ = t
else throw new UnsupportedOperationException("Can't use value_= in task")
}
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject
value_ = param.zero(initialValue)
deserialized = true
Accumulators.register(this)
}
override def toString = value_.toString
}
@serializable trait AccumulatorParam[T] {
def add(t1: T, t2: T): T
def zero(initialValue: T): T
}
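// Illustrative usage, given a SparkContext sc (IntAccumulatorParam is
// provided implicitly by the SparkContext companion object):
//   val sum = sc.accumulator(0)
//   for (x <- sc.parallelize(1 to 10, 2)) sum += x
//   sum.value  // 55 on the master once the tasks have finished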
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private object Accumulators
{
// TODO: Use soft references? => need to make readObject work properly then
val accums = Map[(Thread, Long), Accumulator[_]]()
var lastId: Long = 0
def newId: Long = synchronized { lastId += 1; return lastId }
def register(a: Accumulator[_]): Unit = synchronized {
accums((currentThread, a.id)) = a
}
def clear: Unit = synchronized {
accums.retain((key, accum) => key._1 != currentThread)
}
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for(((thread, id), accum) <- accums if thread == currentThread)
ret(id) = accum.value
return ret
}
def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
if (accums.contains((thread, id))) {
val accum = accums((thread, id))
accum.asInstanceOf[Accumulator[Any]] += value
}
}
}
}
@@ -0,0 +1,110 @@
package spark
import java.io._
import java.net.URI
import java.util.UUID
import com.google.common.collect.MapMaker
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
import spark.compress.lzf.{LZFInputStream, LZFOutputStream}
@serializable class Cached[T](@transient var value_ : T, local: Boolean) {
val uuid = UUID.randomUUID()
def value = value_
Cache.synchronized { Cache.values.put(uuid, value_) }
if (!local) writeCacheFile()
private def writeCacheFile() {
val out = new ObjectOutputStream(Cache.openFileForWriting(uuid))
out.writeObject(value_)
out.close()
}
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject
Cache.synchronized {
val cachedVal = Cache.values.get(uuid)
if (cachedVal != null) {
value_ = cachedVal.asInstanceOf[T]
} else {
val start = System.nanoTime
val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid))
value_ = fileIn.readObject().asInstanceOf[T]
Cache.values.put(uuid, value_)
fileIn.close()
val time = (System.nanoTime - start) / 1e9
println("Reading cached variable " + uuid + " took " + time + " s")
}
}
}
override def toString = "spark.Cached(" + uuid + ")"
}
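// Illustrative usage, given a SparkContext sc (the array contents are a
// stand-in): the value is written to the DFS once by the master, and each
// worker reads and caches it on first use.
//   val table = sc.broadcast(Array(1.0, 2.0, 3.0))
//   sc.parallelize(0 to 2).foreach(i => println(table.value(i)))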
private object Cache {
val values = new MapMaker().softValues().makeMap[UUID, Any]()
private var initialized = false
private var fileSystem: FileSystem = null
private var workDir: String = null
private var compress: Boolean = false
private var bufferSize: Int = 65536
// Will be called by SparkContext or Executor before using cache
def initialize() {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val dfs = System.getProperty("spark.dfs", "file:///")
if (!dfs.startsWith("file://")) {
val conf = new Configuration()
conf.setInt("io.file.buffer.size", bufferSize)
val rep = System.getProperty("spark.dfs.replication", "3").toInt
conf.setInt("dfs.replication", rep)
fileSystem = FileSystem.get(new URI(dfs), conf)
}
workDir = System.getProperty("spark.dfs.workdir", "/tmp")
compress = System.getProperty("spark.compress", "false").toBoolean
initialized = true
}
}
}
private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid)
def openFileForReading(uuid: UUID): InputStream = {
val fileStream = if (fileSystem != null) {
fileSystem.open(getPath(uuid))
} else {
// Local filesystem
new FileInputStream(getPath(uuid).toString)
}
if (compress)
new LZFInputStream(fileStream) // LZF stream does its own buffering
else if (fileSystem == null)
new BufferedInputStream(fileStream, bufferSize)
else
fileStream // Hadoop streams do their own buffering
}
def openFileForWriting(uuid: UUID): OutputStream = {
val fileStream = if (fileSystem != null) {
fileSystem.create(getPath(uuid))
} else {
// Local filesystem
new FileOutputStream(getPath(uuid).toString)
}
if (compress)
new LZFOutputStream(fileStream) // LZF stream does its own buffering
else if (fileSystem == null)
new BufferedOutputStream(fileStream, bufferSize)
else
fileStream // Hadoop streams do their own buffering
}
}
@@ -0,0 +1,157 @@
package spark
import scala.collection.mutable.Map
import scala.collection.mutable.Set
import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
object ClosureCleaner {
private def getClassReader(cls: Class[_]): ClassReader = new ClassReader(
cls.getResourceAsStream(cls.getName.replaceFirst("^.*\\.", "") + ".class"))
private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
return f.getType :: getOuterClasses(f.get(obj))
}
return Nil
}
private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
return f.get(obj) :: getOuterObjects(f.get(obj))
}
return Nil
}
private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
val seen = Set[Class[_]](obj.getClass)
var stack = List[Class[_]](obj.getClass)
while (!stack.isEmpty) {
val cr = getClassReader(stack.head)
stack = stack.tail
val set = Set[Class[_]]()
cr.accept(new InnerClosureFinder(set), 0)
for (cls <- set -- seen) {
seen += cls
stack = cls :: stack
}
}
return (seen - obj.getClass).toList
}
private def createNullValue(cls: Class[_]): AnyRef = {
if (cls.isPrimitive)
new java.lang.Byte(0: Byte) // Should be convertible to any primitive type
else
null
}
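// Clean the given closure so that it can be serialized: rebuild its chain
// of $outer objects from scratch, copying over only the fields that the
// closure's bytecode actually reads, so that serializing it does not drag
// in the entire enclosing scope.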
def clean(func: AnyRef): Unit = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
val outerObjects = getOuterObjects(func)
val accessedFields = Map[Class[_], Set[String]]()
for (cls <- outerClasses)
accessedFields(cls) = Set[String]()
for (cls <- func.getClass :: innerClasses)
getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
var outer: AnyRef = null
for ((cls, obj) <- (outerClasses zip outerObjects).reverse) {
outer = instantiateClass(cls, outer);
for (fieldName <- accessedFields(cls)) {
val field = cls.getDeclaredField(fieldName)
field.setAccessible(true)
val value = field.get(obj)
//println("1: Setting " + fieldName + " on " + cls + " to " + value);
field.set(outer, value)
}
}
if (outer != null) {
//println("2: Setting $outer on " + func.getClass + " to " + outer);
val field = func.getClass.getDeclaredField("$outer")
field.setAccessible(true)
field.set(func, outer)
}
}
private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = {
if (spark.repl.Main.interp == null) {
// This is a bona fide closure class, whose constructor has no effects
// other than to set its fields, so use its constructor
val cons = cls.getConstructors()(0)
val params = cons.getParameterTypes.map(createNullValue).toArray
if (outer != null)
params(0) = outer // First param is always outer object
return cons.newInstance(params: _*).asInstanceOf[AnyRef]
} else {
// Use reflection to instantiate object without calling constructor
val rf = sun.reflect.ReflectionFactory.getReflectionFactory();
val parentCtor = classOf[java.lang.Object].getDeclaredConstructor();
val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
val obj = newCtor.newInstance().asInstanceOf[AnyRef];
if (outer != null) {
//println("3: Setting $outer on " + cls + " to " + outer);
val field = cls.getDeclaredField("$outer")
field.setAccessible(true)
field.set(obj, outer)
}
return obj
}
}
}
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
override def visitFieldInsn(op: Int, owner: String, name: String,
desc: String) {
if (op == GETFIELD)
for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
output(cl) += name
}
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer"))
for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
output(cl) += name
}
}
}
}
class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
superName: String, interfaces: Array[String]) {
myName = name
}
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)
if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
&& argTypes(0).toString.startsWith("L") // is it an object?
&& argTypes(0).getInternalName == myName)
output += Class.forName(owner.replace('/', '.'), false,
Thread.currentThread.getContextClassLoader)
}
}
}
}
@@ -0,0 +1,70 @@
package spark
import java.util.concurrent.{Executors, ExecutorService}
import nexus.{ExecutorArgs, TaskDescription, TaskState, TaskStatus}
object Executor {
def main(args: Array[String]) {
System.loadLibrary("nexus")
val exec = new nexus.Executor() {
var classLoader: ClassLoader = null
var threadPool: ExecutorService = null
override def init(args: ExecutorArgs) {
// Read spark.* system properties
val props = Utils.deserialize[Array[(String, String)]](args.getData)
for ((key, value) <- props)
System.setProperty(key, value)
// Initialize cache (uses some properties read above)
Cache.initialize()
// If the REPL is in use, create a ClassLoader that will be able to
// read new classes defined by the REPL as the user types code
classLoader = this.getClass.getClassLoader
val classDir = System.getProperty("spark.repl.classdir")
if (classDir != null) {
println("Using REPL classdir: " + classDir)
classLoader = new repl.ExecutorClassLoader(classDir, classLoader)
}
Thread.currentThread.setContextClassLoader(classLoader)
// Start worker thread pool (they will inherit our context ClassLoader)
threadPool = Executors.newCachedThreadPool()
}
override def startTask(desc: TaskDescription) {
// Pull taskId and arg out of TaskDescription because it won't be a
// valid pointer after this method call (TODO: fix this in C++/SWIG)
val taskId = desc.getTaskId
val arg = desc.getArg
threadPool.execute(new Runnable() {
def run() = {
println("Running task ID " + taskId)
try {
Accumulators.clear
val task = Utils.deserialize[Task[Any]](arg, classLoader)
val value = task.run
val accumUpdates = Accumulators.values
val result = new TaskResult(value, accumUpdates)
sendStatusUpdate(new TaskStatus(
taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
println("Finished task ID " + taskId)
} catch {
case e: Exception => {
// TODO: Handle errors in tasks less dramatically
System.err.println("Exception in task ID " + taskId + ":")
e.printStackTrace
System.exit(1)
}
}
}
})
}
}
exec.run()
}
}
@@ -0,0 +1,277 @@
package spark
import java.io._
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.ConcurrentHashMap
import java.util.HashSet
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
import nexus._
import com.google.common.collect.MapMaker
import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.InputSplit
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
@serializable
abstract class DistributedFile[T, Split](@transient sc: SparkContext) {
def splits: Array[Split]
def iterator(split: Split): Iterator[T]
def prefers(split: Split, slot: SlaveOffer): Boolean
def taskStarted(split: Split, slot: SlaveOffer) {}
def sparkContext = sc
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
val tasks = splits.map(s => new ForeachTask(this, s, cleanF)).toArray
sc.runTaskObjects(tasks)
}
def toArray: Array[T] = {
val tasks = splits.map(s => new GetTask(this, s))
val results = sc.runTaskObjects(tasks)
Array.concat(results: _*)
}
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val tasks = splits.map(s => new ReduceTask(this, s, cleanF))
val results = new ArrayBuffer[T]
for (option <- sc.runTaskObjects(tasks); elem <- option)
results += elem
if (results.size == 0)
throw new UnsupportedOperationException("empty collection")
else
return results.reduceLeft(f)
}
def take(num: Int): Array[T] = {
if (num == 0)
return new Array[T](0)
val buf = new ArrayBuffer[T]
for (split <- splits; elem <- iterator(split)) {
buf += elem
if (buf.length == num)
return buf.toArray
}
return buf.toArray
}
def first: T = take(1) match {
case Array(t) => t
case _ => throw new UnsupportedOperationException("empty collection")
}
def map[U](f: T => U) = new MappedFile(this, sc.clean(f))
def filter(f: T => Boolean) = new FilteredFile(this, sc.clean(f))
def cache() = new CachedFile(this)
def count(): Long =
try { map(x => 1L).reduce(_+_) }
catch { case e: UnsupportedOperationException => 0L }
}
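// Illustrative pipeline, given a SparkContext sc (the path is a stand-in):
//   val lines = sc.textFile("hdfs://...")
//   val lengths = lines.map(_.length).cache()
//   val total = lengths.reduce(_ + _)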
@serializable
abstract class FileTask[U, T, Split](val file: DistributedFile[T, Split],
val split: Split)
extends Task[U] {
override def prefers(slot: SlaveOffer) = file.prefers(split, slot)
override def markStarted(slot: SlaveOffer) { file.taskStarted(split, slot) }
}
class ForeachTask[T, Split](file: DistributedFile[T, Split],
split: Split, func: T => Unit)
extends FileTask[Unit, T, Split](file, split) {
override def run() {
println("Processing " + split)
file.iterator(split).foreach(func)
}
}
class GetTask[T, Split](file: DistributedFile[T, Split], split: Split)
extends FileTask[Array[T], T, Split](file, split) {
override def run(): Array[T] = {
println("Processing " + split)
file.iterator(split).collect.toArray
}
}
class ReduceTask[T, Split](file: DistributedFile[T, Split],
split: Split, f: (T, T) => T)
extends FileTask[Option[T], T, Split](file, split) {
override def run(): Option[T] = {
println("Processing " + split)
val iter = file.iterator(split)
if (iter.hasNext)
Some(iter.reduceLeft(f))
else
None
}
}
class MappedFile[U, T, Split](prev: DistributedFile[T, Split], f: T => U)
extends DistributedFile[U, Split](prev.sparkContext) {
override def splits = prev.splits
override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
override def iterator(split: Split) = prev.iterator(split).map(f)
override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
class FilteredFile[T, Split](prev: DistributedFile[T, Split], f: T => Boolean)
extends DistributedFile[T, Split](prev.sparkContext) {
override def splits = prev.splits
override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
override def iterator(split: Split) = prev.iterator(split).filter(f)
override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
class CachedFile[T, Split](prev: DistributedFile[T, Split])
extends DistributedFile[T, Split](prev.sparkContext) {
val id = CachedFile.newId()
@transient val cacheLocs = Map[Split, List[Int]]()
override def splits = prev.splits
override def prefers(split: Split, slot: SlaveOffer): Boolean = {
if (cacheLocs.contains(split))
cacheLocs(split).contains(slot.getSlaveId)
else
prev.prefers(split, slot)
}
override def iterator(split: Split): Iterator[T] = {
val key = id + "::" + split.toString
val cache = CachedFile.cache
val loading = CachedFile.loading
val cachedVal = cache.get(key)
if (cachedVal != null) {
// Split is in cache, so just return its values
return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
} else {
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
while (loading.contains(key)) {
try { loading.wait() } catch { case _ => }
}
return Iterator.fromArray(cache.get(key).asInstanceOf[Array[T]])
} else {
loading.add(key)
}
}
// If we got here, we have to load the split
println("Loading and caching " + split)
val array = prev.iterator(split).collect.toArray
cache.put(key, array)
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
return Iterator.fromArray(array)
}
}
override def taskStarted(split: Split, slot: SlaveOffer) {
val oldList = cacheLocs.getOrElse(split, Nil)
val slaveId = slot.getSlaveId
if (!oldList.contains(slaveId))
cacheLocs(split) = slaveId :: oldList
}
}
private object CachedFile {
val nextId = new AtomicLong(0) // Generates IDs for mapped files (on master)
def newId() = nextId.getAndIncrement()
// Stores map results for various splits locally (on workers)
val cache = new MapMaker().softValues().makeMap[String, AnyRef]()
// Remembers which splits are currently being loaded (on workers)
val loading = new HashSet[String]
}
class HdfsSplit(@transient s: InputSplit)
extends SerializableWritable[InputSplit](s)
class HdfsTextFile(sc: SparkContext, path: String)
extends DistributedFile[String, HdfsSplit](sc) {
@transient val conf = new JobConf()
@transient val inputFormat = new TextInputFormat()
FileInputFormat.setInputPaths(conf, path)
ConfigureLock.synchronized { inputFormat.configure(conf) }
@transient val splits_ =
inputFormat.getSplits(conf, 2).map(new HdfsSplit(_)).toArray
override def splits = splits_
override def iterator(split: HdfsSplit) = new Iterator[String] {
var reader: RecordReader[LongWritable, Text] = null
ConfigureLock.synchronized {
val conf = new JobConf()
conf.set("io.file.buffer.size",
System.getProperty("spark.buffer.size", "65536"))
val tif = new TextInputFormat()
tif.configure(conf)
reader = tif.getRecordReader(split.value, conf, Reporter.NULL)
}
val lineNum = new LongWritable()
val text = new Text()
var gotNext = false
var finished = false
override def hasNext: Boolean = {
if (!gotNext) {
finished = !reader.next(lineNum, text)
gotNext = true
}
!finished
}
override def next: String = {
if (!gotNext)
finished = !reader.next(lineNum, text)
if (finished)
throw new java.util.NoSuchElementException("end of stream")
gotNext = false
text.toString
}
}
override def prefers(split: HdfsSplit, slot: SlaveOffer) =
split.value.getLocations().contains(slot.getHost)
}
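// Global lock used above to serialize input format configuration.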
object ConfigureLock {}
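// Wraps a Hadoop Writable so that it can be Java-serialized, by delegating
// to ObjectWritable in custom writeObject/readObject methods.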
@serializable
class SerializableWritable[T <: Writable](@transient var t: T) {
def value = t
override def toString = t.toString
private def writeObject(out: ObjectOutputStream) {
out.defaultWriteObject()
new ObjectWritable(t).write(out)
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
val ow = new ObjectWritable()
ow.setConf(new JobConf())
ow.readFields(in)
t = ow.get().asInstanceOf[T]
}
}
@@ -0,0 +1,65 @@
package spark
import java.util.concurrent._
import scala.collection.mutable.Map
// A simple Scheduler implementation that runs tasks locally in a thread pool.
private class LocalScheduler(threads: Int) extends Scheduler {
var threadPool: ExecutorService =
Executors.newFixedThreadPool(threads, DaemonThreadFactory)
override def start() {}
override def waitForRegister() {}
override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
val futures = new Array[Future[TaskResult[T]]](tasks.length)
for (i <- 0 until tasks.length) {
futures(i) = threadPool.submit(new Callable[TaskResult[T]]() {
def call(): TaskResult[T] = {
println("Running task " + i)
try {
// Serialize and deserialize the task so that accumulators are
// changed to thread-local ones; this adds a bit of unnecessary
// overhead but matches how the Nexus Executor works
Accumulators.clear
val bytes = Utils.serialize(tasks(i))
println("Size of task " + i + " is " + bytes.size + " bytes")
val task = Utils.deserialize[Task[T]](
bytes, currentThread.getContextClassLoader)
val value = task.run
val accumUpdates = Accumulators.values
println("Finished task " + i)
new TaskResult[T](value, accumUpdates)
} catch {
case e: Exception => {
// TODO: Do something nicer here
System.err.println("Exception in task " + i + ":")
e.printStackTrace()
System.exit(1)
null
}
}
}
})
}
val taskResults = futures.map(_.get)
for (result <- taskResults)
Accumulators.add(currentThread, result.accumUpdates)
return taskResults.map(_.value).toArray
}
override def stop() {}
}
// A ThreadFactory that creates daemon threads
private object DaemonThreadFactory extends ThreadFactory {
override def newThread(r: Runnable): Thread = {
val t = new Thread(r);
t.setDaemon(true)
return t
}
}
@@ -0,0 +1,258 @@
package spark
import java.io.File
import java.util.concurrent.Semaphore
import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus}
import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver}
import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap}
// The main Scheduler implementation, which talks to Nexus. Clients are expected
// to first call start(), then submit tasks through the runTasks method.
//
// This implementation is currently a little quick and dirty. The following
// improvements need to be made to it:
// 1) Fault tolerance should be added - if a task fails, just re-run it anywhere.
// 2) Right now, the scheduler uses a linear scan through the tasks to find a
// local one for a given node. It would be faster to have a separate list of
// pending tasks for each node.
// 3) The Callbacks way of organizing things didn't work out too well, so the
// way the scheduler keeps track of the currently active runTasks operation
// can be made cleaner.
private class NexusScheduler(
master: String, frameworkName: String, execArg: Array[Byte])
extends nexus.Scheduler with spark.Scheduler
{
// Semaphore used by runTasks to ensure only one thread can be in it
val semaphore = new Semaphore(1)
// Lock used to wait for scheduler to be registered
var isRegistered = false
val registeredLock = new Object()
// Trait representing a set of scheduler callbacks
trait Callbacks {
def slotOffer(s: SlaveOffer): Option[TaskDescription]
def taskFinished(t: TaskStatus): Unit
def error(code: Int, message: String): Unit
}
// Current callback object (may be null)
var callbacks: Callbacks = null
// Incrementing task ID
var nextTaskId = 0
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong
// Driver for talking to Nexus
var driver: SchedulerDriver = null
override def start() {
new Thread("Spark scheduler") {
setDaemon(true)
override def run {
val ns = NexusScheduler.this
ns.driver = new NexusSchedulerDriver(ns, master)
ns.driver.run()
}
}.start
}
override def getFrameworkName(d: SchedulerDriver): String = frameworkName
override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
val results = new Array[T](tasks.length)
if (tasks.length == 0)
return results
val launched = new Array[Boolean](tasks.length)
val callingThread = currentThread
var errorHappened = false
var errorCode = 0
var errorMessage = ""
// Wait for scheduler to be registered with Nexus
waitForRegister()
try {
// Acquire the runTasks semaphore
semaphore.acquire()
val myCallbacks = new Callbacks {
val firstTaskId = nextTaskId
var tasksLaunched = 0
var tasksFinished = 0
var lastPreferredLaunchTime = System.currentTimeMillis
def slotOffer(slot: SlaveOffer): Option[TaskDescription] = {
try {
if (tasksLaunched < tasks.length) {
// TODO: Add a short wait if no task with location pref is found
// TODO: Figure out why a function is needed around this to
// avoid scala.runtime.NonLocalReturnException
def findTask: Option[TaskDescription] = {
var checkPrefVals: Array[Boolean] = Array(true)
val time = System.currentTimeMillis
if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
checkPrefVals = Array(true, false) // Allow non-preferred tasks
// TODO: Make desiredCpus and desiredMem configurable
val desiredCpus = 1
val desiredMem = 750L * 1024L * 1024L
if (slot.getParams.get("cpus").toInt < desiredCpus ||
slot.getParams.get("mem").toLong < desiredMem)
return None
for (checkPref <- checkPrefVals;
i <- 0 until tasks.length;
if !launched(i) && (!checkPref || tasks(i).prefers(slot)))
{
val taskId = nextTaskId
nextTaskId += 1
printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
i, taskId, slot.getSlaveId, slot.getHost,
if(checkPref) "preferred" else "non-preferred")
tasks(i).markStarted(slot)
launched(i) = true
tasksLaunched += 1
if (checkPref)
lastPreferredLaunchTime = time
val params = new StringMap
params.set("cpus", "" + desiredCpus)
params.set("mem", "" + desiredMem)
val serializedTask = Utils.serialize(tasks(i))
return Some(new TaskDescription(taskId, slot.getSlaveId,
"task_" + taskId, params, serializedTask))
}
return None
}
return findTask
} else {
return None
}
} catch {
case e: Exception => {
e.printStackTrace
System.exit(1)
return None
}
}
}
def taskFinished(status: TaskStatus) {
println("Finished TID " + status.getTaskId)
// Deserialize task result
val result = Utils.deserialize[TaskResult[T]](status.getData)
results(status.getTaskId - firstTaskId) = result.value
// Update accumulators
Accumulators.add(callingThread, result.accumUpdates)
// Stop if we've finished all the tasks
tasksFinished += 1
if (tasksFinished == tasks.length) {
NexusScheduler.this.callbacks = null
NexusScheduler.this.notifyAll()
}
}
def error(code: Int, message: String) {
// Save the error message
errorHappened = true
errorCode = code
errorMessage = message
// Indicate to caller thread that we're done
NexusScheduler.this.callbacks = null
NexusScheduler.this.notifyAll()
}
}
this.synchronized {
this.callbacks = myCallbacks
}
driver.reviveOffers();
this.synchronized {
while (this.callbacks != null) this.wait()
}
} finally {
semaphore.release()
}
if (errorHappened)
throw new SparkException(errorMessage, errorCode)
else
return results
}
override def registered(d: SchedulerDriver, frameworkId: Int) {
println("Registered as framework ID " + frameworkId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
}
}
override def waitForRegister() {
registeredLock.synchronized {
while (!isRegistered) registeredLock.wait()
}
}
override def resourceOffer(
d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) {
synchronized {
val tasks = new TaskDescriptionVector
if (callbacks != null) {
try {
for (i <- 0 until slots.size.toInt) {
callbacks.slotOffer(slots.get(i)) match {
case Some(task) => tasks.add(task)
case None => {}
}
}
} catch {
case e: Exception => e.printStackTrace
}
}
val params = new StringMap
params.set("timeout", "1")
d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
}
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
synchronized {
if (callbacks != null && status.getState == TaskState.TASK_FINISHED) {
try {
callbacks.taskFinished(status)
} catch {
case e: Exception => e.printStackTrace
}
}
}
}
override def error(d: SchedulerDriver, code: Int, message: String) {
synchronized {
if (callbacks != null) {
try {
callbacks.error(code, message)
} catch {
case e: Exception => e.printStackTrace
}
} else {
val msg = "Nexus error: %s (error code: %d)".format(message, code)
System.err.println(msg)
System.exit(1)
}
}
}
override def stop() {
if (driver != null)
driver.stop()
}
}
@@ -0,0 +1,97 @@
package spark
abstract class ParallelArray[T](sc: SparkContext) {
def filter(f: T => Boolean): ParallelArray[T] = {
val cleanF = sc.clean(f)
new FilteredParallelArray[T](sc, this, cleanF)
}
def foreach(f: T => Unit): Unit
def map[U](f: T => U): Array[U]
}
private object ParallelArray {
def slice[T](seq: Seq[T], numSlices: Int): Array[Seq[T]] = {
if (numSlices < 1)
throw new IllegalArgumentException("Positive number of slices required")
seq match {
case r: Range.Inclusive => {
val sign = if (r.step < 0) -1 else 1
slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]],
numSlices)
}
case r: Range => {
(0 until numSlices).map(i => {
val start = ((i * r.length.toLong) / numSlices).toInt
val end = (((i+1) * r.length.toLong) / numSlices).toInt
new SerializableRange(
r.start + start * r.step, r.start + end * r.step, r.step)
}).asInstanceOf[Seq[Seq[T]]].toArray
}
case _ => {
val array = seq.toArray // To prevent O(n^2) operations for List etc
(0 until numSlices).map(i => {
val start = ((i * array.length.toLong) / numSlices).toInt
val end = (((i+1) * array.length.toLong) / numSlices).toInt
array.slice(start, end).toArray
}).toArray
}
}
}
}
private class SimpleParallelArray[T](
sc: SparkContext, data: Seq[T], numSlices: Int)
extends ParallelArray[T](sc) {
val slices = ParallelArray.slice(data, numSlices)
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
var tasks = for (i <- 0 until numSlices) yield
new ForeachRunner(i, slices(i), cleanF)
sc.runTasks[Unit](tasks.toArray)
}
def map[U](f: T => U): Array[U] = {
val cleanF = sc.clean(f)
var tasks = for (i <- 0 until numSlices) yield
new MapRunner(i, slices(i), cleanF)
return Array.concat(sc.runTasks[Array[U]](tasks.toArray): _*)
}
}
@serializable
private class ForeachRunner[T](sliceNum: Int, data: Seq[T], f: T => Unit)
extends Function0[Unit] {
def apply() = {
printf("Running slice %d of parallel foreach\n", sliceNum)
data.foreach(f)
}
}
@serializable
private class MapRunner[T, U](sliceNum: Int, data: Seq[T], f: T => U)
extends Function0[Array[U]] {
def apply(): Array[U] = {
printf("Running slice %d of parallel map\n", sliceNum)
return data.map(f).toArray
}
}
private class FilteredParallelArray[T](
sc: SparkContext, array: ParallelArray[T], predicate: T => Boolean)
extends ParallelArray[T](sc) {
val cleanPred = sc.clean(predicate)
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
array.foreach(t => if (cleanPred(t)) cleanF(t))
}
def map[U](f: T => U): Array[U] =
throw new UnsupportedOperationException(
"Map is not yet supported on FilteredParallelArray")
}
@@ -0,0 +1,9 @@
package spark
// Scheduler trait, implemented by both NexusScheduler and LocalScheduler.
private trait Scheduler {
def start()
def waitForRegister()
def runTasks[T](tasks: Array[Task[T]]): Array[T]
def stop()
}
@@ -0,0 +1,75 @@
// This is a copy of Scala 2.7.7's Range class, (c) 2006-2009, LAMP/EPFL.
// The only change here is to make it Serializable, because Ranges aren't.
// This won't be needed in Scala 2.8, where Scala's Range becomes Serializable.
package spark
@serializable
private class SerializableRange(val start: Int, val end: Int, val step: Int)
extends RandomAccessSeq.Projection[Int] {
if (step == 0) throw new Predef.IllegalArgumentException
/** Create a new range with the start and end values of this range and
* a new <code>step</code>.
*/
def by(step: Int): Range = new Range(start, end, step)
override def foreach(f: Int => Unit) {
if (step > 0) {
var i = this.start
val until = if (inInterval(end)) end + 1 else end
while (i < until) {
f(i)
i += step
}
} else {
var i = this.start
val until = if (inInterval(end)) end - 1 else end
while (i > until) {
f(i)
i += step
}
}
}
lazy val length: Int = {
if (start < end && this.step < 0) 0
else if (start > end && this.step > 0) 0
else {
val base = if (start < end) end - start
else start - end
assert(base >= 0)
val step = if (this.step < 0) -this.step else this.step
assert(step >= 0)
base / step + last(base, step)
}
}
protected def last(base: Int, step: Int): Int =
if (base % step != 0) 1 else 0
def apply(idx: Int): Int = {
if (idx < 0 || idx >= length) throw new Predef.IndexOutOfBoundsException
start + (step * idx)
}
/** a <code>Seq.contains</code>, not an <code>Iterator.contains</code>! */
def contains(x: Int): Boolean = {
inInterval(x) && (((x - start) % step) == 0)
}
/** Is the argument inside the interval defined by `start' and `end'?
* Returns true if `x' is inside [start, end).
*/
protected def inInterval(x: Int): Boolean =
if (step > 0)
(x >= start && x < end)
else
(x <= start && x > end)
//def inclusive = new Range.Inclusive(start,end,step)
override def toString = "SerializableRange(%d, %d, %d)".format(start, end, step)
}
@@ -0,0 +1,89 @@
package spark
import java.io._
import java.util.UUID
import scala.collection.mutable.ArrayBuffer
class SparkContext(master: String, frameworkName: String) {
Cache.initialize()
def parallelize[T](seq: Seq[T], numSlices: Int): ParallelArray[T] =
new SimpleParallelArray[T](this, seq, numSlices)
def parallelize[T](seq: Seq[T]): ParallelArray[T] = parallelize(seq, 2)
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
// TODO: Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = new Cached(value, local)
def textFile(path: String) = new HdfsTextFile(this, path)
val LOCAL_REGEX = """local\[([0-9]+)\]""".r
private var scheduler: Scheduler = master match {
case "local" => new LocalScheduler(1)
case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt)
case _ => { System.loadLibrary("nexus");
new NexusScheduler(master, frameworkName, createExecArg()) }
}
private val local = scheduler.isInstanceOf[LocalScheduler]
scheduler.start()
private def createExecArg(): Array[Byte] = {
// Our executor arg is an array containing all the spark.* system properties
val props = new ArrayBuffer[(String, String)]
val iter = System.getProperties.entrySet.iterator
while (iter.hasNext) {
val entry = iter.next
val (key, value) = (entry.getKey.toString, entry.getValue.toString)
if (key.startsWith("spark."))
props += (key, value)
}
return Utils.serialize(props.toArray)
}
def runTasks[T](tasks: Array[() => T]): Array[T] = {
runTaskObjects(tasks.map(f => new FunctionTask(f)))
}
private[spark] def runTaskObjects[T](tasks: Seq[Task[T]]): Array[T] = {
println("Running " + tasks.length + " tasks in parallel")
val start = System.nanoTime
val result = scheduler.runTasks(tasks.toArray)
println("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s")
return result
}
def stop() {
scheduler.stop()
scheduler = null
}
def waitForRegister() {
scheduler.waitForRegister()
}
// Clean a closure to make it ready to be serialized and sent to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
return f
}
}
object SparkContext {
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def add(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
}
implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
def add(t1: Int, t2: Int): Int = t1 + t2
def zero(initialValue: Int) = 0
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
}
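// Illustrative usage sketch:
//   val sc = new SparkContext("local[2]", "Demo")
//   sc.parallelize(1 to 4).foreach(i => println("Hello " + i))
//   val squares = sc.parallelize(1 to 4).map(i => i * i)  // Array(1, 4, 9, 16)
//   sc.stop()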
@@ -0,0 +1,7 @@
package spark
class SparkException(message: String) extends Exception(message) {
def this(message: String, errorCode: Int) {
this("%s (error code: %d)".format(message, errorCode))
}
}
@@ -0,0 +1,16 @@
package spark
import nexus._
@serializable
trait Task[T] {
def run: T
def prefers(slot: SlaveOffer): Boolean = true
def markStarted(slot: SlaveOffer) {}
}
@serializable
class FunctionTask[T](body: () => T)
extends Task[T] {
def run: T = body()
}
@@ -0,0 +1,9 @@
package spark
import scala.collection.mutable.Map
// Task result. Also contains updates to accumulator variables.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
@serializable
private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any])
@@ -0,0 +1,28 @@
package spark
import java.io._
private object Utils {
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(bos)
oos.writeObject(o)
oos.close
return bos.toByteArray
}
def deserialize[T](bytes: Array[Byte]): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis)
return ois.readObject.asInstanceOf[T]
}
def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
return ois.readObject.asInstanceOf[T]
}
}
@@ -0,0 +1,86 @@
package spark.repl
import java.io.{ByteArrayOutputStream, InputStream}
import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent.{Executors, ExecutorService}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.objectweb.asm._
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
// A ClassLoader that reads classes from a Hadoop FileSystem URL, used to load
// classes defined by the interpreter when the REPL is in use
class ExecutorClassLoader(classDir: String, parent: ClassLoader)
extends ClassLoader(parent) {
val fileSystem = FileSystem.get(new URI(classDir), new Configuration())
val directory = new URI(classDir).getPath
override def findClass(name: String): Class[_] = {
try {
//println("repl.ExecutorClassLoader resolving " + name)
val path = new Path(directory, name.replace('.', '/') + ".class")
val bytes = readAndTransformClass(name, fileSystem.open(path))
return defineClass(name, bytes, 0, bytes.length)
} catch {
case e: Exception => throw new ClassNotFoundException(name, e)
}
}
def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
if (name.startsWith("line") && name.endsWith("$iw$")) {
// Class seems to be an interpreter "wrapper" object storing a val or var.
// Replace its constructor with a dummy one that does not run the
// initialization code placed there by the REPL. The val or var will
// be initialized later through reflection when it is used in a task.
val cr = new ClassReader(in)
val cw = new ClassWriter(
ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
val cleaner = new ConstructorCleaner(name, cw)
cr.accept(cleaner, 0)
return cw.toByteArray
} else {
// Pass the class through unmodified
val bos = new ByteArrayOutputStream
val bytes = new Array[Byte](4096)
var done = false
while (!done) {
val num = in.read(bytes)
if (num >= 0)
bos.write(bytes, 0, num)
else
done = true
}
return bos.toByteArray
}
}
}
class ConstructorCleaner(className: String, cv: ClassVisitor)
extends ClassAdapter(cv) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
val mv = cv.visitMethod(access, name, desc, sig, exceptions)
if (name == "<init>" && (access & ACC_STATIC) == 0) {
// This is the constructor, time to clean it; just output some new
// instructions to mv that create the object and set the static MODULE$
// field in the class to point to it, but do nothing otherwise.
//println("Cleaning constructor of " + className)
mv.visitCode()
mv.visitVarInsn(ALOAD, 0) // load this
mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
mv.visitVarInsn(ALOAD, 0) // load this
//val classType = className.replace('.', '/')
//mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
mv.visitInsn(RETURN)
mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
mv.visitEnd()
return null
} else {
return mv
}
}
}
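A hedged usage sketch (the directory URI and class name below are hypothetical; any scheme that Hadoop's FileSystem.get understands, including file:, should work):

val loader = new ExecutorClassLoader("file:///tmp/repl-classes",
  getClass.getClassLoader)
// Classes named like "line1...$iw$" get their constructors stubbed out by
// ConstructorCleaner; all other classes are passed through unmodified.
val cls = loader.loadClass("SomeGeneratedClass") // hypothetical name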
@@ -0,0 +1,16 @@
package spark.repl
import scala.collection.mutable.Set
object Main {
private var _interp: SparkInterpreterLoop = null
def interp = _interp
private[repl] def interp_=(i: SparkInterpreterLoop) { _interp = i }
def main(args: Array[String]) {
_interp = new SparkInterpreterLoop
_interp.main(args)
}
}
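A hedged launch sketch: the shell is normally started through this object so that Main.interp is set before repl() runs its bootstrap commands; the scheduler URL comes from the MASTER environment variable (see createSparkContext below):

spark.repl.Main.main(Array[String]()) // blocks in the read-eval-print loop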

Large diffs are not rendered by default.

@@ -0,0 +1,366 @@
/* NSC -- new Scala compiler
* Copyright 2005-2009 LAMP/EPFL
* @author Alexander Spoon
*/
// $Id: InterpreterLoop.scala 16881 2009-01-09 16:28:11Z cunei $
package spark.repl
import scala.tools.nsc
import scala.tools.nsc._
import java.io.{BufferedReader, File, FileReader, PrintWriter}
import java.io.IOException
import java.lang.{ClassLoader, System}
import scala.tools.nsc.{InterpreterResults => IR}
import scala.tools.nsc.interpreter._
import spark.SparkContext
/** The
* <a href="http://scala-lang.org/" target="_top">Scala</a>
* interactive shell. It provides a read-eval-print loop around
* the Interpreter class.
* After instantiation, clients should call the <code>main()</code> method.
*
* <p>If no in0 is specified, then input will come from the console, and
* the class will attempt to provide input editing features such as
* input history.
*
* @author Moez A. Abdel-Gawad
* @author Lex Spoon
* @version 1.2
*/
class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter,
master: Option[String])
{
def this(in0: BufferedReader, out: PrintWriter, master: String) =
this(Some(in0), out, Some(master))
def this(in0: BufferedReader, out: PrintWriter) =
this(Some(in0), out, None)
def this() = this(None, new PrintWriter(Console.out), None)
/** The input stream from which interpreter commands come */
var in: InteractiveReader = _ //set by main()
/** The context class loader at the time this object was created */
protected val originalClassLoader =
Thread.currentThread.getContextClassLoader
var settings: Settings = _ // set by main()
var interpreter: SparkInterpreter = null // set by createInterpreter()
def isettings = interpreter.isettings
/** A reverse list of commands to replay if the user
* requests a :replay */
var replayCommandsRev: List[String] = Nil
/** A list of commands to replay if the user requests a :replay */
def replayCommands = replayCommandsRev.reverse
/** Record a command for replay should the user request a :replay */
def addReplay(cmd: String) =
replayCommandsRev = cmd :: replayCommandsRev
/** Close the interpreter, if there is one, and set
* interpreter to <code>null</code>. */
def closeInterpreter() {
if (interpreter ne null) {
interpreter.close
interpreter = null
Thread.currentThread.setContextClassLoader(originalClassLoader)
}
}
/** Create a new interpreter. Close the old one, if there
* is one. */
def createInterpreter() {
//closeInterpreter()
interpreter = new SparkInterpreter(settings, out) {
override protected def parentClassLoader =
classOf[SparkInterpreterLoop].getClassLoader
}
interpreter.setContextClassLoader()
}
/** Bind the settings so that evaluated code can modify them */
def bindSettings() {
interpreter.beQuietDuring {
interpreter.compileString(InterpreterSettings.sourceCodeForClass)
interpreter.bind(
"settings",
"scala.tools.nsc.InterpreterSettings",
isettings)
}
}
/** print a friendly help message */
def printHelp {
//printWelcome
out.println("This is Scala " + Properties.versionString + " (" +
System.getProperty("java.vm.name") + ", Java " + System.getProperty("java.version") + ")." )
out.println("Type in expressions to have them evaluated.")
out.println("Type :load followed by a filename to load a Scala file.")
//out.println("Type :replay to reset execution and replay all previous commands.")
out.println("Type :quit to exit the interpreter.")
}
/** Print a welcome message */
def printWelcome() {
out.println("""Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/___/ .__/\_,_/_/ /_/\_\ version 0.0
/_/
""")
out.println("Using Scala " + Properties.versionString + " (" +
System.getProperty("java.vm.name") + ", Java " +
System.getProperty("java.version") + ")." )
out.flush()
}
def createSparkContext(): SparkContext = {
val master = this.master match {
case Some(m) => m
case None => {
val prop = System.getenv("MASTER")
if (prop != null) prop else "local"
}
}
new SparkContext(master, "Spark REPL")
}
/** Prompt to print when awaiting input */
val prompt = Properties.shellPromptString
/** The main read-eval-print loop for the interpreter. It calls
* <code>command()</code> for each line of input, and stops when
* <code>command()</code> returns <code>false</code>.
*/
def repl() {
out.println("Intializing...")
out.flush()
interpreter.beQuietDuring {
command("""
spark.repl.Main.interp.out.println("Registering with Nexus...");
@transient val sc = spark.repl.Main.interp.createSparkContext();
sc.waitForRegister();
spark.repl.Main.interp.out.println("Spark context available as sc.")
""")
command("import spark.SparkContext._");
}
out.println("Type in expressions to have them evaluated.")
out.println("Type :help for more information.")
out.flush()
var first = true
while (true) {
out.flush()
val line =
if (first) {
/* For some reason, the first interpreted command always takes
* a second or two. So, wait until the welcome message
* has been printed before calling bindSettings. That way,
* the user can read the welcome message while this
* command executes.
*/
val futLine = scala.concurrent.ops.future(in.readLine(prompt))
bindSettings()
first = false
futLine()
} else {
in.readLine(prompt)
}
if (line eq null)
return () // assumes null means EOF
val (keepGoing, finalLineMaybe) = command(line)
if (!keepGoing)
return
finalLineMaybe match {
case Some(finalLine) => addReplay(finalLine)
case None => ()
}
}
}
/** interpret all lines from a specified file */
def interpretAllFrom(filename: String) {
val fileIn = try {
new FileReader(filename)
} catch {
case _:IOException =>
out.println("Error opening file: " + filename)
return
}
val oldIn = in
val oldReplay = replayCommandsRev
try {
val inFile = new BufferedReader(fileIn)
in = new SimpleReader(inFile, out, false)
out.println("Loading " + filename + "...")
out.flush
repl
} finally {
in = oldIn
replayCommandsRev = oldReplay
fileIn.close
}
}
/** create a new interpreter and replay all commands so far */
def replay() {
closeInterpreter()
createInterpreter()
for (cmd <- replayCommands) {
out.println("Replaying: " + cmd)
out.flush() // because maybe cmd will have its own output
command(cmd)
out.println
}
}
/** Run one command submitted by the user. Two values are returned:
* (1) whether to keep running, (2) the line to record for replay,
* if any. */
def command(line: String): (Boolean, Option[String]) = {
def withFile(command: String)(action: String => Unit) {
val spaceIdx = command.indexOf(' ')
if (spaceIdx <= 0) {
out.println("That command requires a filename to be specified.")
return ()
}
val filename = command.substring(spaceIdx).trim
if (! new File(filename).exists) {
out.println("That file does not exist")
return ()
}
action(filename)
}
val helpRegexp = ":h(e(l(p)?)?)?"
val quitRegexp = ":q(u(i(t)?)?)?"
val loadRegexp = ":l(o(a(d)?)?)?.*"
//val replayRegexp = ":r(e(p(l(a(y)?)?)?)?)?.*"
var shouldReplay: Option[String] = None
if (line.matches(helpRegexp))
printHelp
else if (line.matches(quitRegexp))
return (false, None)
else if (line.matches(loadRegexp)) {
withFile(line)(f => {
interpretAllFrom(f)
shouldReplay = Some(line)
})
}
//else if (line matches replayRegexp)
// replay
else if (line startsWith ":")
out.println("Unknown command. Type :help for help.")
else
shouldReplay = interpretStartingWith(line)
(true, shouldReplay)
}
/** Interpret expressions starting with the first line.
* Read lines until a complete compilation unit is available
* or until a syntax error has been seen. If a full unit is
* read, go ahead and interpret it. Return the full string
* to be recorded for replay, if any.
*/
def interpretStartingWith(code: String): Option[String] = {
interpreter.interpret(code) match {
case IR.Success => Some(code)
case IR.Error => None
case IR.Incomplete =>
if (in.interactive && code.endsWith("\n\n")) {
out.println("You typed two blank lines. Starting a new command.")
None
} else {
val nextLine = in.readLine(" | ")
if (nextLine == null)
None // end of file
else
interpretStartingWith(code + "\n" + nextLine)
}
}
}
def loadFiles(settings: Settings) {
settings match {
case settings: GenericRunnerSettings =>
for (filename <- settings.loadfiles.value) {
val cmd = ":load " + filename
command(cmd)
replayCommandsRev = cmd :: replayCommandsRev
out.println()
}
case _ =>
}
}
def main(settings: Settings) {
this.settings = settings
in =
in0 match {
case Some(in0) =>
new SimpleReader(in0, out, true)
case None =>
val emacsShell = System.getProperty("env.emacs", "") != ""
//println("emacsShell="+emacsShell) //debug
if (settings.Xnojline.value || emacsShell)
new SimpleReader()
else
InteractiveReader.createDefault()
}
createInterpreter()
loadFiles(settings)
try {
if (interpreter.reporter.hasErrors) {
return // it is broken on startup; go ahead and exit
}
printWelcome()
repl()
} finally {
closeInterpreter()
}
}
/** process command-line arguments and do as they request */
def main(args: Array[String]) {
def error1(msg: String) { out.println("scala: " + msg) }
val command = new InterpreterCommand(List.fromArray(args), error1)
if (!command.ok || command.settings.help.value || command.settings.Xhelp.value) {
// either the command line is wrong, or the user
// explicitly requested a help listing
if (command.settings.help.value) out.println(command.usageMsg)
if (command.settings.Xhelp.value) out.println(command.xusageMsg)
out.flush
}
else
main(command.settings)
}
}
@@ -0,0 +1,21 @@
package ubiquifs
import java.io.{DataInputStream, DataOutputStream}
object RequestType {
val READ = 0
val WRITE = 1
}
class Header(val requestType: Int, val path: String) {
def write(out: DataOutputStream) {
out.write(requestType)
out.writeUTF(path)
}
}
object Header {
def read(in: DataInputStream): Header = {
new Header(in.read(), in.readUTF())
}
}
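Round-tripping a header through in-memory streams illustrates the framing:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
val bos = new ByteArrayOutputStream
new Header(RequestType.WRITE, "/tmp/demo").write(new DataOutputStream(bos))
val in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
val h = Header.read(in)
assert(h.requestType == RequestType.WRITE && h.path == "/tmp/demo")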
@@ -0,0 +1,49 @@
package ubiquifs
import scala.actors.Actor
import scala.actors.Actor._
import scala.actors.remote.RemoteActor
import scala.actors.remote.RemoteActor._
import scala.actors.remote.Node
import scala.collection.mutable.{ArrayBuffer, Map, Set}
class Master(port: Int) extends Actor {
case class SlaveInfo(host: String, port: Int)
val files = Set[String]()
val slaves = new ArrayBuffer[SlaveInfo]()
def act() {
alive(port)
register('UbiquiFS, self)
println("Created UbiquiFS Master on port " + port)
loop {
react {
case RegisterSlave(host, port) =>
slaves += SlaveInfo(host, port)
sender ! RegisterSucceeded()
case Create(path) =>
if (files.contains(path)) {
sender ! CreateFailed("File already exists")
} else if (slaves.isEmpty) {
sender ! CreateFailed("No slaves registered")
} else {
files += path
sender ! CreateSucceeded(slaves(0).host, slaves(0).port)
}
case m: Any =>
println("Unknown message: " + m)
}
}
}
}
object MasterMain {
def main(args: Array[String]) {
val port = args(0).toInt
new Master(port).start()
}
}
@@ -0,0 +1,14 @@
package ubiquifs
sealed abstract class Message // abstract, so the case classes below can extend it cleanly
case class RegisterSlave(host: String, port: Int) extends Message
case class RegisterSucceeded() extends Message
case class Create(path: String) extends Message
case class CreateSucceeded(host: String, port: Int) extends Message
case class CreateFailed(message: String) extends Message
case class Read(path: String) extends Message
case class ReadSucceeded(host: String, port: Int) extends Message
case class ReadFailed(message: String) extends Message
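A hedged sketch of driving the Master from another process with these messages (assumes a Master started on localhost:9000, as in MasterMain above, and the ubiquifs package on the classpath):

import scala.actors.remote.Node
import scala.actors.remote.RemoteActor._
val master = select(Node("localhost", 9000), 'UbiquiFS)
val reply = master !? Create("/demo/file")
reply match {
  case CreateSucceeded(host, port) => println("Write to " + host + ":" + port)
  case CreateFailed(reason) => println("Create failed: " + reason)
  case other => println("Unexpected reply: " + other)
}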
@@ -0,0 +1,141 @@
package ubiquifs
import java.io.{DataInputStream, DataOutputStream, IOException}
import java.net.{InetAddress, Socket, ServerSocket}
import java.util.concurrent.locks.ReentrantLock
import scala.actors.Actor
import scala.actors.Actor._
import scala.actors.remote.RemoteActor
import scala.actors.remote.RemoteActor._
import scala.actors.remote.Node
import scala.collection.mutable.{ArrayBuffer, Map, Set}
class Slave(myPort: Int, master: String) extends Thread("UbiquiFS slave") {
val CHUNK_SIZE = 1024 * 1024
val buffers = Map[String, Buffer]()
override def run() {
// Create server socket
val socket = new ServerSocket(myPort)
// Register with master
val (masterHost, masterPort) = Utils.parseHostPort(master)
val masterActor = select(Node(masterHost, masterPort), 'UbiquiFS)
val myHost = InetAddress.getLocalHost.getHostName
val reply = masterActor !? RegisterSlave(myHost, myPort)
println("Registered with master, reply = " + reply)
while (true) {
val conn = socket.accept()
new ConnectionHandler(conn).start()
}
}
class ConnectionHandler(conn: Socket) extends Thread("ConnectionHandler") {
// Do the work in run() rather than in the constructor, so that each
// connection is actually handled on its own thread once start() is called
override def run() {
try {
val in = new DataInputStream(conn.getInputStream)
val out = new DataOutputStream(conn.getOutputStream)
val header = Header.read(in)
header.requestType match {
case RequestType.READ =>
performRead(header.path, out)
case RequestType.WRITE =>
performWrite(header.path, in)
case other =>
throw new IOException("Invalid header type " + other)
}
println("hi")
} catch {
case e: Exception => e.printStackTrace()
} finally {
conn.close()
}
}
}
def performWrite(path: String, in: DataInputStream) {
var buffer = new Buffer()
synchronized {
if (buffers.contains(path))
throw new IllegalArgumentException("Path " + path + " already exists")
buffers(path) = buffer
}
var chunk = new Array[Byte](CHUNK_SIZE)
var pos = 0
while (true) {
var numRead = in.read(chunk, pos, chunk.size - pos)
if (numRead == -1) {
buffer.addChunk(chunk.subArray(0, pos), true)
return
} else {
pos += numRead
if (pos == chunk.size) {
buffer.addChunk(chunk, false)
chunk = new Array[Byte](CHUNK_SIZE)
pos = 0
}
}
}
// TODO: launch a thread to write the data to disk, and when this finishes,
// remove the hard reference to buffer
}
def performRead(path: String, out: DataOutputStream) {
var buffer: Buffer = null
synchronized {
if (!buffers.contains(path))
throw new IllegalArgumentException("Path " + path + " doesn't exist")
buffer = buffers(path)
}
for (chunk <- buffer.iterator) {
out.write(chunk, 0, chunk.size)
}
}
class Buffer {
val chunks = new ArrayBuffer[Array[Byte]]
var finished = false
val mutex = new ReentrantLock
val chunksAvailable = mutex.newCondition()
def addChunk(chunk: Array[Byte], finish: Boolean) {
mutex.lock()
chunks += chunk
finished = finish
chunksAvailable.signalAll()
mutex.unlock()
}
def iterator = new Iterator[Array[Byte]] {
var index = 0
def hasNext: Boolean = {
mutex.lock()
try {
while (index >= chunks.size && !finished)
chunksAvailable.await()
index < chunks.size
} finally {
mutex.unlock()
}
}
def next: Array[Byte] = {
mutex.lock()
try {
// hasNext blocks until chunks(index) is available or the file ends
if (!hasNext)
throw new NoSuchElementException("End of file")
val ret = chunks(index)
index += 1
ret
} finally {
mutex.unlock() // release even on the throwing path above
}
}
}
}
}
object SlaveMain {
def main(args: Array[String]) {
val port = args(0).toInt
val master = args(1)
new Slave(port, master).start()
}
}
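A hedged end-to-end sketch of the wire protocol the handler implements (assumes a slave listening on localhost:9001; the path and payload are illustrative):

import java.io.DataOutputStream
import java.net.Socket
// Write: send a WRITE header, stream the bytes, then close to mark EOF.
val w = new Socket("localhost", 9001)
val wout = new DataOutputStream(w.getOutputStream)
new Header(RequestType.WRITE, "/demo/file").write(wout)
wout.write("hello".getBytes)
w.close()
// Read: send a READ header and consume the stream until EOF.
val r = new Socket("localhost", 9001)
new Header(RequestType.READ, "/demo/file").write(
  new DataOutputStream(r.getOutputStream))
val rin = r.getInputStream
var b = rin.read()
while (b != -1) { print(b.toChar); b = rin.read() }
r.close()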
@@ -0,0 +1,11 @@
package ubiquifs
import java.io.{InputStream, OutputStream}
class UbiquiFS(master: String) {
private val (masterHost, masterPort) = Utils.parseHostPort(master)
def create(path: String): OutputStream = null // TODO: not yet implemented
def open(path: String): InputStream = null // TODO: not yet implemented
}
@@ -0,0 +1,12 @@
package ubiquifs
private[ubiquifs] object Utils {
private val HOST_PORT_RE = "([a-zA-Z0-9.-]+):([0-9]+)".r
def parseHostPort(string: String): (String, Int) = {
string match {
case HOST_PORT_RE(host, port) => (host, port.toInt)
case _ => throw new IllegalArgumentException(string)
}
}
}
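For example (from code inside the ubiquifs package, since Utils is package-private):

val (host, port) = Utils.parseHostPort("node1.example.com:5050")
// host == "node1.example.com", port == 5050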