| @@ -0,0 +1,19 @@ | ||
| import spark._ | ||
object SleepJob {
  /**
   * Trivial benchmark job: launches <tasks> parallel tasks, each of which
   * simply waits for <task_duration> seconds (napping in 200 ms increments)
   * before finishing.
   */
  def main(args: Array[String]) {
    if (args.length != 3) {
      System.err.println("Usage: SleepJob <master> <tasks> <task_duration>")
      System.exit(1)
    }
    val Array(master, numTasksStr, secondsStr) = args
    val context = new SparkContext(master, "Sleep job")
    val numTasks = numTasksStr.toInt
    val seconds = secondsStr.toInt
    // A single task: poll the clock until the deadline has passed.
    def sleepTask {
      val deadline = System.currentTimeMillis + seconds * 1000L
      while (System.currentTimeMillis < deadline)
        Thread.sleep(200)
    }
    context.runTasks(Array.make(numTasks, () => sleepTask))
  }
}
| @@ -0,0 +1,138 @@ | ||
| import java.io.Serializable | ||
| import java.util.Random | ||
| import cern.jet.math._ | ||
| import cern.colt.matrix._ | ||
| import cern.colt.matrix.linalg._ | ||
| import spark._ | ||
/**
 * Alternating Least Squares matrix factorization on Spark.
 * Factors a synthetic M x U ratings matrix R into movie factors `ms` and
 * user factors `us` (each of rank F), alternating exact least-squares
 * solves for one side while holding the other fixed.
 */
object SparkALS {
  // Parameters set through command line arguments
  var M = 0 // Number of movies
  var U = 0 // Number of users
  var F = 0 // Number of features
  var ITERATIONS = 0
  val LAMBDA = 0.01 // Regularization coefficient

  // Some COLT objects
  val factory2D = DoubleFactory2D.dense
  val factory1D = DoubleFactory1D.dense
  val algebra = Algebra.DEFAULT
  val blas = SeqBlas.seqBlas

  // Builds the synthetic target matrix R = mh * uh^T from random M x F and
  // U x F factors, so R has an exact rank-F factorization by construction.
  def generateR(): DoubleMatrix2D = {
    val mh = factory2D.random(M, F)
    val uh = factory2D.random(U, F)
    return algebra.mult(mh, algebra.transpose(uh))
  }

  // Root-mean-squared error between targetR and the ratings predicted by
  // the current movie factors `ms` and user factors `us`.
  def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
           us: Array[DoubleMatrix1D]): Double =
  {
    val r = factory2D.make(M, U)
    // Predicted rating (i, j) is the dot product of the two factor vectors
    for (i <- 0 until M; j <- 0 until U) {
      r.set(i, j, blas.ddot(ms(i), us(j)))
    }
    //println("R: " + r)
    // r := r - targetR, then sum the squared residuals
    blas.daxpy(-1, targetR, r)
    val sumSqs = r.aggregate(Functions.plus, Functions.square)
    return Math.sqrt(sumSqs / (M * U))
  }

  // Recomputes the factor vector for movie i by solving the regularized
  // normal equations (X^T X + LAMBDA*U*I) w = X^T y, where X stacks the user
  // factor vectors and y is row i of R.
  // NOTE(review): the current factor `m` is unused — the solve is exact, so
  // no warm start is needed.
  def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
                  R: DoubleMatrix2D) : DoubleMatrix1D =
  {
    val U = us.size
    val F = us(0).size
    val XtX = factory2D.make(F, F)
    val Xty = factory1D.make(F)
    // For each user that rated the movie
    for (j <- 0 until U) {
      val u = us(j)
      // Add u * u^t to XtX
      blas.dger(1, u, u, XtX)
      // Add u * rating to Xty
      blas.daxpy(R.get(i, j), u, Xty)
    }
    // Add regularization coefs to diagonal terms
    for (d <- 0 until F) {
      XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
    }
    // Solve it with Cholesky (XtX is symmetric positive definite thanks to
    // the regularization term)
    val ch = new CholeskyDecomposition(XtX)
    val Xty2D = factory2D.make(Xty.toArray, F)
    val solved2D = ch.solve(Xty2D)
    return solved2D.viewColumn(0)
  }

  // Symmetric counterpart of updateMovie: recomputes the factor vector for
  // user j from the fixed movie factors `ms` and column j of R.
  // NOTE(review): the current factor `u` is likewise unused.
  def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
                 R: DoubleMatrix2D) : DoubleMatrix1D =
  {
    val M = ms.size
    val F = ms(0).size
    val XtX = factory2D.make(F, F)
    val Xty = factory1D.make(F)
    // For each movie that the user rated
    for (i <- 0 until M) {
      val m = ms(i)
      // Add m * m^t to XtX
      blas.dger(1, m, m, XtX)
      // Add m * rating to Xty
      blas.daxpy(R.get(i, j), m, Xty)
    }
    // Add regularization coefs to diagonal terms
    for (d <- 0 until F) {
      XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
    }
    // Solve it with Cholesky
    val ch = new CholeskyDecomposition(XtX)
    val Xty2D = factory2D.make(Xty.toArray, F)
    val solved2D = ch.solve(Xty2D)
    return solved2D.viewColumn(0)
  }

  def main(args: Array[String]) {
    var host = ""
    var slices = 0
    // All six arguments are required, in this exact order
    args match {
      case Array(m, u, f, iters, slices_, host_) => {
        M = m.toInt
        U = u.toInt
        F = f.toInt
        ITERATIONS = iters.toInt
        slices = slices_.toInt
        host = host_
      }
      case _ => {
        System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <host>")
        System.exit(1)
      }
    }
    printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
    val spark = new SparkContext(host, "SparkALS")
    val R = generateR()
    // Initialize m and u randomly
    var ms = Array.fromFunction(_ => factory1D.random(F))(M)
    var us = Array.fromFunction(_ => factory1D.random(F))(U)
    // Iteratively update movies then users. R never changes, so it is
    // broadcast once; the factor arrays are re-broadcast after each update.
    val Rc = spark.broadcast(R)
    var msb = spark.broadcast(ms)
    var usb = spark.broadcast(us)
    for (iter <- 1 to ITERATIONS) {
      println("Iteration " + iter + ":")
      // Update each movie factor in parallel against the fixed user factors
      ms = spark.parallelize(0 until M, slices)
                .map(i => updateMovie(i, msb.value(i), usb.value, Rc.value))
                .toArray
      msb = spark.broadcast(ms) // Re-broadcast ms because it was updated
      // Then update each user factor against the new movie factors
      us = spark.parallelize(0 until U, slices)
                .map(i => updateUser(i, usb.value(i), msb.value, Rc.value))
                .toArray
      usb = spark.broadcast(us) // Re-broadcast us because it was updated
      println("RMSE = " + rmse(R, ms, us))
      println()
    }
  }
}
| @@ -0,0 +1,50 @@ | ||
| import java.util.Random | ||
| import Vector._ | ||
| import spark._ | ||
/**
 * Logistic regression by batch gradient descent, reading training points
 * from a text file (e.g. on HDFS): one point per line, label first, then
 * D feature values, space separated.
 */
object SparkHdfsLR {
  val D = 10 // Number of dimensions
  val rand = new Random(42)

  case class DataPoint(x: Vector, y: Double)

  /** Parses "label f1 f2 ... fD" into a DataPoint. */
  def parsePoint(line: String): DataPoint = {
    // StringTokenizer is used rather than split() for speed on large files
    val tok = new java.util.StringTokenizer(line, " ")
    val label = tok.nextToken.toDouble
    val features = new Array[Double](D)
    for (i <- 0 until D)
      features(i) = tok.nextToken.toDouble
    DataPoint(new Vector(features), label)
  }

  def main(args: Array[String]) {
    if (args.length < 3) {
      System.err.println("Usage: SparkHdfsLR <master> <file> <iters>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkHdfsLR")
    // Parse once and cache the points for reuse across iterations
    val points = sc.textFile(args(1)).map(parsePoint _).cache()
    val numIterations = args(2).toInt
    // Initialize w to a random value in [-1, 1)^D
    var w = Vector(D, _ => 2 * rand.nextDouble - 1)
    println("Initial w: " + w)
    for (iteration <- 1 to numIterations) {
      println("On iteration " + iteration)
      // Sum the logistic-loss gradient over all points via an accumulator
      val gradient = sc.accumulator(Vector.zeros(D))
      for (p <- points) {
        val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
        gradient += scale * p.x
      }
      w -= gradient.value
    }
    println("Final w: " + w)
  }
}
| @@ -0,0 +1,48 @@ | ||
| import java.util.Random | ||
| import Vector._ | ||
| import spark._ | ||
/**
 * Logistic regression by batch gradient descent on synthetic data: N points
 * in D dimensions, labels alternating -1/+1, features Gaussian around y*R.
 */
object SparkLR {
  val N = 10000 // Number of data points
  val D = 10 // Number of dimensions
  val R = 0.7 // Scaling factor
  val ITERATIONS = 5
  val rand = new Random(42)

  case class DataPoint(x: Vector, y: Double)

  /** Builds the synthetic data set: even indices labeled -1, odd +1. */
  def generateData = {
    def generatePoint(i: Int) = {
      val label = if (i % 2 == 0) -1 else 1
      DataPoint(Vector(D, _ => rand.nextGaussian + label * R), label)
    }
    Array.fromFunction(generatePoint _)(N)
  }

  def main(args: Array[String]) {
    if (args.length == 0) {
      System.err.println("Usage: SparkLR <host> [<slices>]")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkLR")
    val numSlices = if (args.length > 1) args(1).toInt else 2
    val data = generateData
    // Initialize w to a random value in [-1, 1)^D
    var w = Vector(D, _ => 2 * rand.nextDouble - 1)
    println("Initial w: " + w)
    for (iteration <- 1 to ITERATIONS) {
      println("On iteration " + iteration)
      // Sum the logistic-loss gradient over the distributed data set
      val gradient = sc.accumulator(Vector.zeros(D))
      for (p <- sc.parallelize(data, numSlices)) {
        val scale = (1 / (1 + Math.exp(-p.y * (w dot p.x))) - 1) * p.y
        gradient += scale * p.x
      }
      w -= gradient.value
    }
    println("Final w: " + w)
  }
}
| @@ -0,0 +1,20 @@ | ||
| import spark._ | ||
| import SparkContext._ | ||
/**
 * Estimates Pi by Monte Carlo sampling: draws 100000 points uniformly in
 * the unit square [-1, 1]^2 (in parallel over `slices` partitions) and
 * counts the fraction that fall inside the unit circle.
 */
object SparkPi {
  def main(args: Array[String]) {
    if (args.length == 0) {
      // Bug fix: the usage message previously named the wrong program
      // ("SparkLR" instead of "SparkPi").
      System.err.println("Usage: SparkPi <host> [<slices>]")
      System.exit(1)
    }
    val spark = new SparkContext(args(0), "SparkPi")
    val slices = if (args.length > 1) args(1).toInt else 2
    // Count of samples landing inside the circle, summed across tasks
    var count = spark.accumulator(0)
    for (i <- spark.parallelize(1 to 100000, slices)) {
      val x = Math.random * 2 - 1
      val y = Math.random * 2 - 1
      if (x*x + y*y < 1) count += 1
    }
    // Area ratio circle/square = Pi/4, so Pi ~= 4 * hits / samples
    println("Pi is roughly " + 4 * count.value / 100000.0)
  }
}
| @@ -0,0 +1,63 @@ | ||
/**
 * A simple immutable dense vector of doubles with basic arithmetic.
 * Marked @serializable so instances can be shipped to remote tasks.
 */
@serializable class Vector(val elements: Array[Double]) {
  def length = elements.length

  def apply(index: Int) = elements(index)

  /** Guard shared by the binary operations below. */
  private def checkLength(other: Vector) {
    if (length != other.length)
      throw new IllegalArgumentException("Vectors of different length")
  }

  /** Element-wise sum; both vectors must have the same length. */
  def + (other: Vector): Vector = {
    checkLength(other)
    Vector(length, i => elements(i) + other(i))
  }

  /** Element-wise difference; both vectors must have the same length. */
  def - (other: Vector): Vector = {
    checkLength(other)
    Vector(length, i => elements(i) - other(i))
  }

  /** Inner product; both vectors must have the same length. */
  def dot(other: Vector): Double = {
    checkLength(other)
    var total = 0.0
    var i = 0
    while (i < length) {
      total += elements(i) * other(i)
      i += 1
    }
    total
  }

  /** Scalar multiplication. */
  def * (scale: Double): Vector = Vector(length, i => elements(i) * scale)

  def unary_- = this * -1

  def sum = elements.reduceLeft(_ + _)

  override def toString = elements.mkString("(", ", ", ")")
}

object Vector {
  def apply(elements: Array[Double]) = new Vector(elements)

  def apply(elements: Double*) = new Vector(elements.toArray)

  /** Builds a vector of the given length from an index-based initializer. */
  def apply(length: Int, initializer: Int => Double): Vector = {
    val data = new Array[Double](length)
    var i = 0
    while (i < length) {
      data(i) = initializer(i)
      i += 1
    }
    new Vector(data)
  }

  def zeros(length: Int) = new Vector(new Array[Double](length))

  def ones(length: Int) = Vector(length, _ => 1)

  /** Enables `scalar * vector` syntax via an implicit conversion. */
  class Multiplier(num: Double) {
    def * (vec: Vector) = vec * num
  }

  implicit def doubleToMultiplier(num: Double) = new Multiplier(num)

  /** Lets Vectors be used in Spark accumulators (sum with a zeros identity). */
  implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] {
    def add(t1: Vector, t2: Vector) = t1 + t2
    def zero(initialValue: Vector) = Vector.zeros(initialValue.length)
  }
}
| @@ -0,0 +1,27 @@ | ||
| package spark.compress.lzf; | ||
/**
 * JNI bindings to the liblzf compression library. The native methods are
 * implemented in spark_compress_lzf_LZF.c, compiled into libspark_native.
 */
public class LZF {
  // True iff the native library loaded successfully.
  private static boolean loaded;

  // Try to load the JNI library eagerly; failure is tolerated so callers
  // can check isLoaded() and fall back to uncompressed I/O.
  static {
    try {
      System.loadLibrary("spark_native");
      loaded = true;
    } catch(Throwable t) {
      System.out.println("Failed to load native LZF library: " + t.toString());
      loaded = false;
    }
  }

  /** Returns whether the native library was loaded successfully. */
  public static boolean isLoaded() {
    return loaded;
  }

  // Compresses in[inOff, inOff+inLen) into out[outOff, ...).
  // NOTE(review): per liblzf convention (and how LZFOutputStream uses it),
  // the return value appears to be the compressed size, with 0 meaning the
  // output did not fit in outLen bytes — confirm against lzf.h.
  public static native int compress(
      byte[] in, int inOff, int inLen,
      byte[] out, int outOff, int outLen);

  // Decompresses in[inOff, inOff+inLen) into out[outOff, ...).
  // NOTE(review): LZFInputStream treats a return value != expected
  // uncompressed size as corruption — confirm against lzf.h.
  public static native int decompress(
      byte[] in, int inOff, int inLen,
      byte[] out, int outOff, int outLen);
}
| @@ -0,0 +1,180 @@ | ||
| package spark.compress.lzf; | ||
| import java.io.EOFException; | ||
| import java.io.FilterInputStream; | ||
| import java.io.IOException; | ||
| import java.io.InputStream; | ||
/**
 * InputStream that decompresses data written by LZFOutputStream.
 *
 * The stream is a sequence of blocks, each starting with the magic bytes
 * 'Z' 'V', a type byte (0 = uncompressed, 1 = LZF-compressed), and
 * big-endian 16-bit size field(s): type 0 has one size (payload length),
 * type 1 has two (compressed size, then uncompressed size).
 */
public class LZFInputStream extends FilterInputStream {
  // Maximum payload per block (matches LZFOutputStream.BLOCKSIZE)
  private static final int MAX_BLOCKSIZE = 1024 * 64 - 1;

  // Maximum header size: 2 magic + 1 type + 2x2 size bytes
  private static final int MAX_HDR_SIZE = 7;

  private byte[] inBuf;   // Holds data to decompress (including header)
  private byte[] outBuf;  // Holds decompressed data to output
  private int outPos;     // Current position in outBuf
  private int outSize;    // Total amount of data in outBuf

  private boolean closed;
  private boolean reachedEof;

  // Scratch buffer so read() can delegate to read(byte[], int, int)
  private byte[] singleByte = new byte[1];

  public LZFInputStream(InputStream in) {
    super(in);
    if (in == null)
      throw new NullPointerException();
    inBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE];
    outBuf = new byte[MAX_BLOCKSIZE + MAX_HDR_SIZE];
    outPos = 0;
    outSize = 0;
  }

  // Fails fast if close() has already been called.
  private void ensureOpen() throws IOException {
    if (closed) throw new IOException("Stream closed");
  }

  @Override
  public int read() throws IOException {
    ensureOpen();
    int count = read(singleByte, 0, 1);
    return (count == -1 ? -1 : singleByte[0] & 0xFF);
  }

  // NOTE(review): unlike the general InputStream contract, this loops until
  // it has filled `len` bytes or hit EOF, rather than returning after the
  // first available chunk.
  @Override
  public int read(byte[] b, int off, int len) throws IOException {
    ensureOpen();
    // Standard trick: any negative offset/length/overflow makes the OR negative
    if ((off | len | (off + len) | (b.length - (off + len))) < 0)
      throw new IndexOutOfBoundsException();
    int totalRead = 0;
    // Start with the current block in outBuf, and read and decompress any
    // further blocks necessary. Instead of trying to decompress directly to b
    // when b is large, we always use outBuf as an intermediate holding space
    // in case GetPrimitiveArrayCritical decides to copy arrays instead of
    // pinning them, which would cause b to be copied repeatedly into C-land.
    while (len > 0) {
      if (outPos == outSize) {
        readNextBlock();
        if (reachedEof)
          return totalRead == 0 ? -1 : totalRead;
      }
      int amtToCopy = Math.min(outSize - outPos, len);
      System.arraycopy(outBuf, outPos, b, off, amtToCopy);
      off += amtToCopy;
      len -= amtToCopy;
      outPos += amtToCopy;
      totalRead += amtToCopy;
    }
    return totalRead;
  }

  // Read len bytes from this.in to a buffer, stopping only if EOF is reached
  private int readFully(byte[] b, int off, int len) throws IOException {
    int totalRead = 0;
    while (len > 0) {
      int amt = in.read(b, off, len);
      if (amt == -1)
        break;
      off += amt;
      len -= amt;
      totalRead += amt;
    }
    return totalRead;
  }

  // Read the next block from the underlying InputStream into outBuf,
  // setting outPos and outSize, or set reachedEof if the stream ends.
  private void readNextBlock() throws IOException {
    // Read first 5 bytes of header (magic, type, first size field)
    int count = readFully(inBuf, 0, 5);
    if (count == 0) {
      // Clean EOF exactly on a block boundary
      reachedEof = true;
      return;
    } else if (count < 5) {
      throw new EOFException("Truncated LZF block header");
    }

    // Check magic bytes
    if (inBuf[0] != 'Z' || inBuf[1] != 'V')
      throw new IOException("Wrong magic bytes in LZF block header");

    // Read the block
    if (inBuf[2] == 0) {
      // Uncompressed block - read directly to outBuf
      int size = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF);
      if (readFully(outBuf, 0, size) != size)
        throw new EOFException("EOF inside LZF block");
      outPos = 0;
      outSize = size;
    } else if (inBuf[2] == 1) {
      // Compressed block - read to inBuf and decompress.
      // First fetch the remaining 2 header bytes (uncompressed size).
      if (readFully(inBuf, 5, 2) != 2)
        throw new EOFException("Truncated LZF block header");
      int csize = ((inBuf[3] & 0xFF) << 8) | (inBuf[4] & 0xFF);
      int usize = ((inBuf[5] & 0xFF) << 8) | (inBuf[6] & 0xFF);
      if (readFully(inBuf, 7, csize) != csize)
        throw new EOFException("Truncated LZF block");
      if (LZF.decompress(inBuf, 7, csize, outBuf, 0, usize) != usize)
        throw new IOException("Corrupt LZF data stream");
      outPos = 0;
      outSize = usize;
    } else {
      throw new IOException("Unknown block type in LZF block header");
    }
  }

  /**
   * Returns 0 after EOF has been reached, otherwise always return 1.
   *
   * Programs should not count on this method to return the actual number
   * of bytes that could be read without blocking.
   */
  @Override
  public int available() throws IOException {
    ensureOpen();
    return reachedEof ? 0 : 1;
  }

  // TODO: Skip complete chunks without decompressing them?
  // Skips by reading and discarding decompressed bytes through a scratch buffer.
  @Override
  public long skip(long n) throws IOException {
    ensureOpen();
    if (n < 0)
      throw new IllegalArgumentException("negative skip length");
    byte[] buf = new byte[512];
    long skipped = 0;
    while (skipped < n) {
      int len = (int) Math.min(n - skipped, buf.length);
      len = read(buf, 0, len);
      if (len == -1) {
        reachedEof = true;
        break;
      }
      skipped += len;
    }
    return skipped;
  }

  @Override
  public void close() throws IOException {
    if (!closed) {
      in.close();
      closed = true;
    }
  }

  // mark/reset are not supported for compressed streams
  @Override
  public boolean markSupported() {
    return false;
  }

  @Override
  public void mark(int readLimit) {}

  @Override
  public void reset() throws IOException {
    throw new IOException("mark/reset not supported");
  }
}
| @@ -0,0 +1,85 @@ | ||
| package spark.compress.lzf; | ||
| import java.io.FilterOutputStream; | ||
| import java.io.IOException; | ||
| import java.io.OutputStream; | ||
/**
 * OutputStream that compresses data into the block format read by
 * LZFInputStream: each block starts with magic 'Z' 'V', a type byte
 * (0 = stored, 1 = LZF-compressed), and big-endian 16-bit size field(s).
 * Input is buffered and emitted one BLOCKSIZE block at a time.
 */
public class LZFOutputStream extends FilterOutputStream {
  // Maximum payload per block (fits in a 16-bit size field)
  private static final int BLOCKSIZE = 1024 * 64 - 1;

  // Maximum header size: 2 magic + 1 type + 2x2 size bytes
  private static final int MAX_HDR_SIZE = 7;

  private byte[] inBuf;   // Holds input data to be compressed
  private byte[] outBuf;  // Holds compressed data to be written
  private int inPos;      // Current position in inBuf

  public LZFOutputStream(OutputStream out) {
    super(out);
    inBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE];
    outBuf = new byte[BLOCKSIZE + MAX_HDR_SIZE];
    // Payload starts after the header area, which is filled in lazily by
    // compressAndSendBlock() for the uncompressed (type 0) case.
    inPos = MAX_HDR_SIZE;
  }

  @Override
  public void write(int b) throws IOException {
    inBuf[inPos++] = (byte) b;
    if (inPos == inBuf.length)
      compressAndSendBlock();
  }

  @Override
  public void write(byte[] b, int off, int len) throws IOException {
    // Standard trick: any negative offset/length/overflow makes the OR negative
    if ((off | len | (off + len) | (b.length - (off + len))) < 0)
      throw new IndexOutOfBoundsException();

    // If we're given a large array, copy it piece by piece into inBuf and
    // write one BLOCKSIZE at a time. This is done to prevent the JNI code
    // from copying the whole array repeatedly if GetPrimitiveArrayCritical
    // decides to copy instead of pinning.
    while (inPos + len >= inBuf.length) {
      int amtToCopy = inBuf.length - inPos;
      System.arraycopy(b, off, inBuf, inPos, amtToCopy);
      inPos += amtToCopy;
      compressAndSendBlock();
      off += amtToCopy;
      len -= amtToCopy;
    }

    // Copy the remaining (incomplete) block into inBuf
    System.arraycopy(b, off, inBuf, inPos, len);
    inPos += len;
  }

  // Flushes any buffered (partial) block, then flushes the wrapped stream.
  // Note that flushing mid-stream produces a shorter block, which slightly
  // reduces the compression ratio.
  @Override
  public void flush() throws IOException {
    if (inPos > MAX_HDR_SIZE)
      compressAndSendBlock();
    out.flush();
  }

  // Send the data in inBuf, and reset inPos to start writing a new block.
  private void compressAndSendBlock() throws IOException {
    int us = inPos - MAX_HDR_SIZE;
    // Only accept compressed output if it saves at least 4 bytes (the extra
    // header size plus a little margin); otherwise store the block raw.
    int maxcs = us > 4 ? us - 4 : us;
    int cs = LZF.compress(inBuf, MAX_HDR_SIZE, us, outBuf, MAX_HDR_SIZE, maxcs);
    if (cs != 0) {
      // Compression made the data smaller; use type 1 header
      outBuf[0] = 'Z';
      outBuf[1] = 'V';
      outBuf[2] = 1;
      outBuf[3] = (byte) (cs >> 8);
      outBuf[4] = (byte) (cs & 0xFF);
      outBuf[5] = (byte) (us >> 8);
      outBuf[6] = (byte) (us & 0xFF);
      out.write(outBuf, 0, 7 + cs);
    } else {
      // Compression didn't help; use type 0 header and uncompressed data.
      // The 5-byte header is written into the reserved area just before the
      // payload so header + data go out in a single write.
      inBuf[2] = 'Z';
      inBuf[3] = 'V';
      inBuf[4] = 0;
      inBuf[5] = (byte) (us >> 8);
      inBuf[6] = (byte) (us & 0xFF);
      out.write(inBuf, 2, 5 + us);
    }
    inPos = MAX_HDR_SIZE;
  }
}
| @@ -0,0 +1,30 @@ | ||
# Builds libspark_native.so, the JNI wrapper around liblzf used by
# spark.compress.lzf.LZF. Requires JAVA_HOME to point at a JDK.
CC = gcc
#JAVA_HOME = /usr/lib/jvm/java-6-sun
OS_NAME = linux
CFLAGS = -fPIC -O3 -funroll-all-loops

# Location of the Spark tree and the bundled liblzf sources
SPARK = ../..
LZF = $(SPARK)/third_party/liblzf-3.5
LIB = libspark_native.so

all: $(LIB)

# Regenerate the JNI header whenever the compiled LZF class changes
spark_compress_lzf_LZF.h: $(SPARK)/classes/spark/compress/lzf/LZF.class
ifeq ($(JAVA_HOME),)
	$(error JAVA_HOME is not set)
else
	$(JAVA_HOME)/bin/javah -classpath $(SPARK)/classes spark.compress.lzf.LZF
endif

# Compile the JNI glue together with liblzf's compress/decompress sources
$(LIB): spark_compress_lzf_LZF.h spark_compress_lzf_LZF.c
	$(CC) $(CFLAGS) -shared -o $@ spark_compress_lzf_LZF.c \
	  -I $(JAVA_HOME)/include -I $(JAVA_HOME)/include/$(OS_NAME) \
	  -I $(LZF) $(LZF)/lzf_c.c $(LZF)/lzf_d.c

clean:
	rm -f spark_compress_lzf_LZF.h $(LIB)

.PHONY: all clean
| @@ -0,0 +1,90 @@ | ||
| #include "spark_compress_lzf_LZF.h" | ||
| #include <lzf.h> | ||
| /* Helper function to throw an exception */ | ||
/* Helper function to throw an exception of the class named by `className`
 * (a JNI descriptor such as "java/lang/NullPointerException") with an
 * empty message. */
static void throwException(JNIEnv *env, const char* className) {
  jclass cls = (*env)->FindClass(env, className);
  if (cls != 0) /* If cls is null, an exception was already thrown */
    (*env)->ThrowNew(env, cls, "");
}
| /* | ||
| * Since LZF.compress() and LZF.decompress() have the same signatures | ||
| * and differ only in which lzf_ function they call, implement both in a | ||
| * single function and pass it a pointer to the correct lzf_ function. | ||
| */ | ||
| static jint callCompressionFunction | ||
| (unsigned int (*func)(const void *const, unsigned int, void *, unsigned int), | ||
| JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen, | ||
| jbyteArray outArray, jint outOff, jint outLen) | ||
| { | ||
| jint inCap; | ||
| jint outCap; | ||
| jbyte *inData = 0; | ||
| jbyte *outData = 0; | ||
| jint ret; | ||
| jint s; | ||
| if (!inArray || !outArray) { | ||
| throwException(env, "java/lang/NullPointerException"); | ||
| goto cleanup; | ||
| } | ||
| inCap = (*env)->GetArrayLength(env, inArray); | ||
| outCap = (*env)->GetArrayLength(env, outArray); | ||
| // Check if any of the offset/length pairs is invalid; we do this by OR'ing | ||
| // things we don't want to be negative and seeing if the result is negative | ||
| s = inOff | inLen | (inOff + inLen) | (inCap - (inOff + inLen)) | | ||
| outOff | outLen | (outOff + outLen) | (outCap - (outOff + outLen)); | ||
| if (s < 0) { | ||
| throwException(env, "java/lang/IndexOutOfBoundsException"); | ||
| goto cleanup; | ||
| } | ||
| inData = (*env)->GetPrimitiveArrayCritical(env, inArray, 0); | ||
| outData = (*env)->GetPrimitiveArrayCritical(env, outArray, 0); | ||
| if (!inData || !outData) { | ||
| // Out of memory - JVM will throw OutOfMemoryError | ||
| goto cleanup; | ||
| } | ||
| ret = func(inData + inOff, inLen, outData + outOff, outLen); | ||
| cleanup: | ||
| if (inData) | ||
| (*env)->ReleasePrimitiveArrayCritical(env, inArray, inData, 0); | ||
| if (outData) | ||
| (*env)->ReleasePrimitiveArrayCritical(env, outArray, outData, 0); | ||
| return ret; | ||
| } | ||
| /* | ||
| * Class: spark_compress_lzf_LZF | ||
| * Method: compress | ||
| * Signature: ([B[B)I | ||
| */ | ||
/*
 * Class:     spark_compress_lzf_LZF
 * Method:    compress
 * Signature: ([BII[BII)I
 * Thin JNI entry point delegating to callCompressionFunction with
 * liblzf's lzf_compress.
 */
JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_compress
  (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
   jbyteArray outArray, jint outOff, jint outLen)
{
  return callCompressionFunction(lzf_compress, env, cls,
      inArray, inOff, inLen, outArray,outOff, outLen);
}
| /* | ||
| * Class: spark_compress_lzf_LZF | ||
| * Method: decompress | ||
| * Signature: ([B[B)I | ||
| */ | ||
/*
 * Class:     spark_compress_lzf_LZF
 * Method:    decompress
 * Signature: ([BII[BII)I
 * Thin JNI entry point delegating to callCompressionFunction with
 * liblzf's lzf_decompress.
 */
JNIEXPORT jint JNICALL Java_spark_compress_lzf_LZF_decompress
  (JNIEnv *env, jclass cls, jbyteArray inArray, jint inOff, jint inLen,
   jbyteArray outArray, jint outOff, jint outLen)
{
  return callCompressionFunction(lzf_decompress, env, cls,
      inArray, inOff, inLen, outArray,outOff, outLen);
}
| @@ -0,0 +1,71 @@ | ||
| package spark | ||
| import java.io._ | ||
| import scala.collection.mutable.Map | ||
/**
 * A shared variable that tasks can only add to via `+=`. Each deserialized
 * copy starts from the `zero` element and accumulates locally; the driver
 * reads the merged result through `value`.
 */
@serializable class Accumulator[T](initialValue: T, param: AccumulatorParam[T])
{
  val id = Accumulators.newId
  // Not serialized: each task-side copy is reset to zero in readObject
  @transient var value_ = initialValue
  var deserialized = false

  Accumulators.register(this)

  // Merge a new term into the local partial value
  def += (term: T) { value_ = param.add(value_, term) }
  def value = this.value_
  def value_= (t: T) {
    // Only the original (driver-side) instance may be assigned directly
    if (!deserialized) value_ = t
    else throw new UnsupportedOperationException("Can't use value_= in task")
  }

  // Called by Java when deserializing an object: reset to the zero element
  // and register this copy with the local Accumulators registry.
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject
    value_ = param.zero(initialValue)
    deserialized = true
    Accumulators.register(this)
  }

  override def toString = value_.toString
}
/**
 * Strategy trait telling an Accumulator how to merge two values and what
 * the zero (identity) element is for a freshly-deserialized copy.
 */
@serializable trait AccumulatorParam[T] {
  def add(t1: T, t2: T): T
  def zero(initialValue: T): T
}
| // TODO: The multi-thread support in accumulators is kind of lame; check | ||
| // if there's a more intuitive way of doing it right | ||
// Process-wide registry of accumulators, keyed by (owning thread, id) so
// that concurrent jobs running in different threads of one JVM do not see
// each other's accumulators. All access is synchronized on this object.
private object Accumulators
{
  // TODO: Use soft references? => need to make readObject work properly then
  val accums = Map[(Thread, Long), Accumulator[_]]()
  var lastId: Long = 0

  // Hands out unique accumulator ids
  def newId: Long = synchronized { lastId += 1; return lastId }

  // Registers `a` as belonging to the calling thread
  def register(a: Accumulator[_]): Unit = synchronized {
    accums((currentThread, a.id)) = a
  }

  // Drops all accumulators registered by the calling thread
  def clear: Unit = synchronized {
    accums.retain((key, accum) => key._1 != currentThread)
  }

  // Snapshot of the calling thread's accumulator values, keyed by id
  def values: Map[Long, Any] = synchronized {
    val ret = Map[Long, Any]()
    for(((thread, id), accum) <- accums if thread == currentThread)
      ret(id) = accum.value
    return ret
  }

  // Merges task-side partial values into the accumulators owned by `thread`;
  // ids with no registered accumulator are silently ignored.
  def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
    for ((id, value) <- values) {
      if (accums.contains((thread, id))) {
        val accum = accums((thread, id))
        accum.asInstanceOf[Accumulator[Any]] += value
      }
    }
  }
}
| @@ -0,0 +1,110 @@ | ||
| package spark | ||
| import java.io._ | ||
| import java.net.URI | ||
| import java.util.UUID | ||
| import com.google.common.collect.MapMaker | ||
| import org.apache.hadoop.conf.Configuration | ||
| import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} | ||
| import spark.compress.lzf.{LZFInputStream, LZFOutputStream} | ||
/**
 * A broadcast-style wrapper for a value: on creation the value is memoized
 * in the local Cache and (unless `local`) serialized to a shared cache file
 * identified by a random UUID. Deserializing JVMs lazily reload the value
 * from the local cache or, failing that, from the file.
 */
@serializable class Cached[T](@transient var value_ : T, local: Boolean) {
  val uuid = UUID.randomUUID()
  def value = value_

  // Memoize locally and persist for remote readers
  Cache.synchronized { Cache.values.put(uuid, value_) }
  if (!local) writeCacheFile()

  // Serializes the value to the shared cache file for this UUID
  private def writeCacheFile() {
    val out = new ObjectOutputStream(Cache.openFileForWriting(uuid))
    out.writeObject(value_)
    out.close()
  }

  // Called by Java when deserializing an object
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject
    Cache.synchronized {
      val cachedVal = Cache.values.get(uuid)
      if (cachedVal != null) {
        // Another task in this JVM already loaded the value
        value_ = cachedVal.asInstanceOf[T]
      } else {
        // First access in this JVM: load from the cache file and memoize
        val start = System.nanoTime
        val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid))
        value_ = fileIn.readObject().asInstanceOf[T]
        Cache.values.put(uuid, value_)
        fileIn.close()
        val time = (System.nanoTime - start) / 1e9
        println("Reading cached variable " + uuid + " took " + time + " s")
      }
    }
  }

  override def toString = "spark.Cached(" + uuid + ")"
}
/**
 * Companion machinery for Cached: an in-memory soft-valued map of already
 * loaded values, plus helpers for opening the per-UUID cache files on either
 * a Hadoop filesystem (when spark.dfs is set) or the local filesystem.
 */
private object Cache {
  // Soft values let the JVM reclaim cached objects under memory pressure
  val values = new MapMaker().softValues().makeMap[UUID, Any]()

  private var initialized = false
  private var fileSystem: FileSystem = null
  private var workDir: String = null
  private var compress: Boolean = false
  private var bufferSize: Int = 65536

  // Will be called by SparkContext or Executor before using cache.
  // Reads configuration from system properties; idempotent.
  def initialize() {
    synchronized {
      if (!initialized) {
        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
        val dfs = System.getProperty("spark.dfs", "file:///")
        // A non-local DFS URI means we go through the Hadoop FileSystem API;
        // otherwise fileSystem stays null and plain java.io streams are used
        if (!dfs.startsWith("file://")) {
          val conf = new Configuration()
          conf.setInt("io.file.buffer.size", bufferSize)
          val rep = System.getProperty("spark.dfs.replication", "3").toInt
          conf.setInt("dfs.replication", rep)
          fileSystem = FileSystem.get(new URI(dfs), conf)
        }
        workDir = System.getProperty("spark.dfs.workdir", "/tmp")
        compress = System.getProperty("spark.compress", "false").toBoolean
        initialized = true
      }
    }
  }

  // Cache file location for a given Cached value's UUID
  private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid)

  // Opens the cache file for `uuid`, wrapping it for decompression and/or
  // buffering as configured.
  def openFileForReading(uuid: UUID): InputStream = {
    val fileStream = if (fileSystem != null) {
      fileSystem.open(getPath(uuid))
    } else {
      // Local filesystem
      new FileInputStream(getPath(uuid).toString)
    }
    if (compress)
      new LZFInputStream(fileStream) // LZF stream does its own buffering
    else if (fileSystem == null)
      new BufferedInputStream(fileStream, bufferSize)
    else
      fileStream // Hadoop streams do their own buffering
  }

  // Creates the cache file for `uuid`, wrapping it for compression and/or
  // buffering as configured (mirror of openFileForReading).
  def openFileForWriting(uuid: UUID): OutputStream = {
    val fileStream = if (fileSystem != null) {
      fileSystem.create(getPath(uuid))
    } else {
      // Local filesystem
      new FileOutputStream(getPath(uuid).toString)
    }
    if (compress)
      new LZFOutputStream(fileStream) // LZF stream does its own buffering
    else if (fileSystem == null)
      new BufferedOutputStream(fileStream, bufferSize)
    else
      fileStream // Hadoop streams do their own buffering
  }
}
| @@ -0,0 +1,157 @@ | ||
| package spark | ||
| import scala.collection.mutable.Map | ||
| import scala.collection.mutable.Set | ||
| import org.objectweb.asm.{ClassReader, MethodVisitor, Type} | ||
| import org.objectweb.asm.commons.EmptyVisitor | ||
| import org.objectweb.asm.Opcodes._ | ||
| object ClosureCleaner { | ||
  // Loads the bytecode of `cls` into an ASM ClassReader by reading the
  // .class resource that sits next to the class on the classpath.
  private def getClassReader(cls: Class[_]): ClassReader = new ClassReader(
    cls.getResourceAsStream(cls.getName.replaceFirst("^.*\\.", "") + ".class"))
| private def getOuterClasses(obj: AnyRef): List[Class[_]] = { | ||
| for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { | ||
| f.setAccessible(true) | ||
| return f.getType :: getOuterClasses(f.get(obj)) | ||
| } | ||
| return Nil | ||
| } | ||
| private def getOuterObjects(obj: AnyRef): List[AnyRef] = { | ||
| for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { | ||
| f.setAccessible(true) | ||
| return f.get(obj) :: getOuterObjects(f.get(obj)) | ||
| } | ||
| return Nil | ||
| } | ||
  // Transitively discovers all inner closure classes referenced from obj's
  // class by scanning bytecode with ASM (via InnerClosureFinder), following
  // newly found classes breadth-first. Returns the set found, excluding
  // obj's own class.
  private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
    val seen = Set[Class[_]](obj.getClass)
    var stack = List[Class[_]](obj.getClass)
    while (!stack.isEmpty) {
      val cr = getClassReader(stack.head)
      stack = stack.tail
      val set = Set[Class[_]]()
      cr.accept(new InnerClosureFinder(set), 0)
      // Queue classes we have not scanned yet
      for (cls <- set -- seen) {
        seen += cls
        stack = cls :: stack
      }
    }
    return (seen - obj.getClass).toList
  }
| private def createNullValue(cls: Class[_]): AnyRef = { | ||
| if (cls.isPrimitive) | ||
| new java.lang.Byte(0: Byte) // Should be convertible to any primitive type | ||
| else | ||
| null | ||
| } | ||
  // "Cleans" a closure so it can be serialized: rebuilds its chain of $outer
  // objects from scratch, copying across only the fields that the closure
  // (or its inner closures) actually reads, so that unreferenced state in
  // the enclosing scopes is not dragged along. Mutates `func` in place.
  def clean(func: AnyRef): Unit = {
    // TODO: cache outerClasses / innerClasses / accessedFields
    val outerClasses = getOuterClasses(func)
    val innerClasses = getInnerClasses(func)
    val outerObjects = getOuterObjects(func)

    // Scan the bytecode of the closure and its inner closures to find which
    // fields of each outer class are actually accessed
    val accessedFields = Map[Class[_], Set[String]]()
    for (cls <- outerClasses)
      accessedFields(cls) = Set[String]()
    for (cls <- func.getClass :: innerClasses)
      getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)

    // Rebuild the $outer chain from the outermost scope inwards, copying
    // only the accessed fields from the original objects
    var outer: AnyRef = null
    for ((cls, obj) <- (outerClasses zip outerObjects).reverse) {
      outer = instantiateClass(cls, outer);
      for (fieldName <- accessedFields(cls)) {
        val field = cls.getDeclaredField(fieldName)
        field.setAccessible(true)
        val value = field.get(obj)
        //println("1: Setting " + fieldName + " on " + cls + " to " + value);
        field.set(outer, value)
      }
    }

    // Finally point the closure itself at the rebuilt chain
    if (outer != null) {
      //println("2: Setting $outer on " + func.getClass + " to " + outer);
      val field = func.getClass.getDeclaredField("$outer")
      field.setAccessible(true)
      field.set(func, outer)
    }
  }
| private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = { | ||
| if (spark.repl.Main.interp == null) { | ||
| // This is a bona fide closure class, whose constructor has no effects | ||
| // other than to set its fields, so use its constructor | ||
| val cons = cls.getConstructors()(0) | ||
| val params = cons.getParameterTypes.map(createNullValue).toArray | ||
| if (outer != null) | ||
| params(0) = outer // First param is always outer object | ||
| return cons.newInstance(params: _*).asInstanceOf[AnyRef] | ||
| } else { | ||
| // Use reflection to instantiate object without calling constructor | ||
| val rf = sun.reflect.ReflectionFactory.getReflectionFactory(); | ||
| val parentCtor = classOf[java.lang.Object].getDeclaredConstructor(); | ||
| val newCtor = rf.newConstructorForSerialization(cls, parentCtor) | ||
| val obj = newCtor.newInstance().asInstanceOf[AnyRef]; | ||
| if (outer != null) { | ||
| //println("3: Setting $outer on " + cls + " to " + outer); | ||
| val field = cls.getDeclaredField("$outer") | ||
| field.setAccessible(true) | ||
| field.set(obj, outer) | ||
| } | ||
| return obj | ||
| } | ||
| } | ||
| } | ||
// ASM visitor that records, for each class appearing as a key in `output`,
// which of that class's fields are read (GETFIELD) by the visited bytecode.
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    return new EmptyVisitor {
      // Record plain field reads against the owning class.
      override def visitFieldInsn(op: Int, owner: String, name: String,
          desc: String) {
        if (op == GETFIELD)
          for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
            output(cl) += name
      }

      // REPL line wrappers ($iwC classes) expose values through accessor
      // methods rather than raw fields, so count virtual calls on them as
      // accesses too — except the synthetic $outer accessors.
      override def visitMethodInsn(op: Int, owner: String, name: String,
          desc: String) {
        if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer"))
          for (cl <- output.keys if cl.getName == owner.replace('/', '.'))
            output(cl) += name
      }
    }
  }
}
// ASM visitor that collects closure classes constructed inside the visited
// class: any constructor call whose first argument type is the visited class
// itself (i.e. the new object's $outer pointer).
class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
  // Internal (slash-separated) name of the class being visited.
  var myName: String = null

  override def visit(version: Int, access: Int, name: String, sig: String,
      superName: String, interfaces: Array[String]) {
    myName = name
  }

  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    return new EmptyVisitor {
      override def visitMethodInsn(op: Int, owner: String, name: String,
          desc: String) {
        val argTypes = Type.getArgumentTypes(desc)
        if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
            && argTypes(0).toString.startsWith("L") // is it an object?
            && argTypes(0).getInternalName == myName)
          // Load the class WITHOUT initializing it, through the context
          // loader so REPL-defined classes resolve.
          output += Class.forName(owner.replace('/', '.'), false,
            Thread.currentThread.getContextClassLoader)
      }
    }
  }
}
| @@ -0,0 +1,70 @@ | ||
| package spark | ||
| import java.util.concurrent.{Executors, ExecutorService} | ||
| import nexus.{ExecutorArgs, TaskDescription, TaskState, TaskStatus} | ||
// Entry point for the worker-side executor process launched by Nexus.
// Deserializes incoming Spark tasks and runs each on a cached thread pool,
// reporting serialized TaskResults back through status updates.
object Executor {
  def main(args: Array[String]) {
    // Native library backing the nexus Java bindings.
    System.loadLibrary("nexus")

    val exec = new nexus.Executor() {
      var classLoader: ClassLoader = null
      var threadPool: ExecutorService = null

      override def init(args: ExecutorArgs) {
        // Read spark.* system properties (shipped from the master as the
        // serialized executor argument).
        val props = Utils.deserialize[Array[(String, String)]](args.getData)
        for ((key, value) <- props)
          System.setProperty(key, value)

        // Initialize cache (uses some properties read above)
        Cache.initialize()

        // If the REPL is in use, create a ClassLoader that will be able to
        // read new classes defined by the REPL as the user types code
        classLoader = this.getClass.getClassLoader
        val classDir = System.getProperty("spark.repl.classdir")
        if (classDir != null) {
          println("Using REPL classdir: " + classDir)
          classLoader = new repl.ExecutorClassLoader(classDir, classLoader)
        }
        Thread.currentThread.setContextClassLoader(classLoader)

        // Start worker thread pool (they will inherit our context ClassLoader)
        threadPool = Executors.newCachedThreadPool()
      }

      override def startTask(desc: TaskDescription) {
        // Pull taskId and arg out of TaskDescription because it won't be a
        // valid pointer after this method call (TODO: fix this in C++/SWIG)
        val taskId = desc.getTaskId
        val arg = desc.getArg
        threadPool.execute(new Runnable() {
          def run() = {
            println("Running task ID " + taskId)
            try {
              // Reset this thread's accumulators before running the task.
              Accumulators.clear
              val task = Utils.deserialize[Task[Any]](arg, classLoader)
              val value = task.run
              val accumUpdates = Accumulators.values
              val result = new TaskResult(value, accumUpdates)
              sendStatusUpdate(new TaskStatus(
                taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
              println("Finished task ID " + taskId)
            } catch {
              case e: Exception => {
                // TODO: Handle errors in tasks less dramatically
                // (one failed task currently kills the entire executor)
                System.err.println("Exception in task ID " + taskId + ":")
                e.printStackTrace
                System.exit(1)
              }
            }
          }
        })
      }
    }
    exec.run()
  }
}
| @@ -0,0 +1,277 @@ | ||
| package spark | ||
| import java.io._ | ||
| import java.util.concurrent.atomic.AtomicLong | ||
| import java.util.concurrent.ConcurrentHashMap | ||
| import java.util.HashSet | ||
| import scala.collection.mutable.ArrayBuffer | ||
| import scala.collection.mutable.Map | ||
| import nexus._ | ||
| import com.google.common.collect.MapMaker | ||
| import org.apache.hadoop.io.ObjectWritable | ||
| import org.apache.hadoop.io.LongWritable | ||
| import org.apache.hadoop.io.Text | ||
| import org.apache.hadoop.io.Writable | ||
| import org.apache.hadoop.mapred.FileInputFormat | ||
| import org.apache.hadoop.mapred.InputSplit | ||
| import org.apache.hadoop.mapred.JobConf | ||
| import org.apache.hadoop.mapred.TextInputFormat | ||
| import org.apache.hadoop.mapred.RecordReader | ||
| import org.apache.hadoop.mapred.Reporter | ||
// A distributed dataset made up of typed splits that can be processed in
// parallel. Concrete subclasses define how splits are listed, iterated, and
// where they are preferably placed.
@serializable
abstract class DistributedFile[T, Split](@transient sc: SparkContext) {
  def splits: Array[Split]
  def iterator(split: Split): Iterator[T]
  // Whether `slot` is a preferred (e.g. data-local) place to process `split`.
  def prefers(split: Split, slot: SlaveOffer): Boolean
  // Hook invoked when a task for `split` is launched on `slot`.
  def taskStarted(split: Split, slot: SlaveOffer) {}
  def sparkContext = sc

  // Applies f to every element, running one task per split.
  def foreach(f: T => Unit) {
    val cleanF = sc.clean(f)
    val tasks = splits.map(s => new ForeachTask(this, s, cleanF)).toArray
    sc.runTaskObjects(tasks)
  }

  // Gathers the whole dataset back to the master as a single array.
  def toArray: Array[T] = {
    val tasks = splits.map(s => new GetTask(this, s))
    val results = sc.runTaskObjects(tasks)
    Array.concat(results: _*)
  }

  // Reduces the dataset with f (one partial reduce per split, then a final
  // reduce of the per-split results on the master).
  // Throws UnsupportedOperationException on an empty collection.
  def reduce(f: (T, T) => T): T = {
    val cleanF = sc.clean(f)
    // BUG FIX: the tasks previously captured the uncleaned `f`, so the
    // closure shipped to workers never had its $outer chain pruned.
    val tasks = splits.map(s => new ReduceTask(this, s, cleanF))
    val results = new ArrayBuffer[T]
    for (option <- sc.runTaskObjects(tasks); elem <- option)
      results += elem
    if (results.size == 0)
      throw new UnsupportedOperationException("empty collection")
    else
      return results.reduceLeft(f)
  }

  // Returns the first `num` elements. Splits are read sequentially on the
  // master (no parallelism), stopping as soon as `num` elements are found.
  def take(num: Int): Array[T] = {
    if (num == 0)
      return new Array[T](0)
    val buf = new ArrayBuffer[T]
    for (split <- splits; elem <- iterator(split)) {
      buf += elem
      if (buf.length == num)
        return buf.toArray
    }
    return buf.toArray
  }

  // First element; throws UnsupportedOperationException if empty.
  def first: T = take(1) match {
    case Array(t) => t
    case _ => throw new UnsupportedOperationException("empty collection")
  }

  def map[U](f: T => U) = new MappedFile(this, sc.clean(f))
  def filter(f: T => Boolean) = new FilteredFile(this, sc.clean(f))
  // Returns a dataset that caches each split in worker memory on first use.
  def cache() = new CachedFile(this)

  // Number of elements; an empty dataset yields 0 rather than failing.
  def count(): Long =
    try { map(x => 1L).reduce(_+_) }
    catch { case e: UnsupportedOperationException => 0L }
}
// Base class for tasks operating on a single split of a DistributedFile;
// delegates placement preference and start notification to the file itself.
@serializable
abstract class FileTask[U, T, Split](val file: DistributedFile[T, Split],
    val split: Split)
  extends Task[U] {
  override def prefers(slot: SlaveOffer) = file.prefers(split, slot)
  override def markStarted(slot: SlaveOffer) { file.taskStarted(split, slot) }
}
// Task that applies a side-effecting function to each element of one split.
class ForeachTask[T, Split](file: DistributedFile[T, Split],
    split: Split, func: T => Unit)
  extends FileTask[Unit, T, Split](file, split) {
  override def run() {
    println("Processing " + split)
    for (elem <- file.iterator(split))
      func(elem)
  }
}
// Task that materializes one split into an in-memory array on the worker.
class GetTask[T, Split](file: DistributedFile[T, Split], split: Split)
  extends FileTask[Array[T], T, Split](file, split) {
  override def run(): Array[T] = {
    println("Processing " + split)
    val elements = file.iterator(split).collect
    elements.toArray
  }
}
// Task that reduces the elements of one split with f, yielding None when
// the split is empty.
class ReduceTask[T, Split](file: DistributedFile[T, Split],
    split: Split, f: (T, T) => T)
  extends FileTask[Option[T], T, Split](file, split) {
  override def run(): Option[T] = {
    println("Processing " + split)
    val elems = file.iterator(split)
    if (!elems.hasNext)
      None
    else
      Some(elems.reduceLeft(f))
  }
}
// Result of DistributedFile.map: applies f lazily while iterating a split,
// forwarding splits and placement decisions to the parent file.
class MappedFile[U, T, Split](prev: DistributedFile[T, Split], f: T => U)
  extends DistributedFile[U, Split](prev.sparkContext) {
  override def splits = prev.splits
  override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
  override def iterator(split: Split) = prev.iterator(split).map(f)
  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
// Result of DistributedFile.filter: applies the predicate lazily while
// iterating a split, forwarding splits and placement to the parent file.
class FilteredFile[T, Split](prev: DistributedFile[T, Split], f: T => Boolean)
  extends DistributedFile[T, Split](prev.sparkContext) {
  override def splits = prev.splits
  override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
  override def iterator(split: Split) = prev.iterator(split).filter(f)
  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
// Result of DistributedFile.cache(): the first time a split is computed on a
// worker its elements are stored in that worker's in-process cache, and the
// scheduler is steered towards slaves already holding a cached copy.
class CachedFile[T, Split](prev: DistributedFile[T, Split])
  extends DistributedFile[T, Split](prev.sparkContext) {
  // Unique ID namespacing this file's entries in the shared worker cache.
  val id = CachedFile.newId()
  // Master-side map from split to slave IDs believed to have cached it
  // (@transient: maintained only where tasks are scheduled, not shipped).
  @transient val cacheLocs = Map[Split, List[Int]]()

  override def splits = prev.splits

  // Prefer slaves with a cached copy; before anything is cached, fall back
  // to the parent's preference (e.g. HDFS block locality).
  override def prefers(split: Split, slot: SlaveOffer): Boolean = {
    if (cacheLocs.contains(split))
      cacheLocs(split).contains(slot.getSlaveId)
    else
      prev.prefers(split, slot)
  }

  // Returns the split's elements, computing and caching them on first use.
  // Concurrent requests for the same split wait for the first loader rather
  // than computing it twice.
  override def iterator(split: Split): Iterator[T] = {
    val key = id + "::" + split.toString
    val cache = CachedFile.cache
    val loading = CachedFile.loading
    val cachedVal = cache.get(key)
    if (cachedVal != null) {
      // Split is in cache, so just return its values
      return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
    } else {
      // Mark the split as loading (unless someone else marks it first)
      loading.synchronized {
        if (loading.contains(key)) {
          // Another thread is computing this split: wait for it to publish
          // the result, then read it from the cache.
          // NOTE(review): the bare `case _ =>` also swallows interrupts.
          while (loading.contains(key)) {
            try {loading.wait()} catch {case _ =>}
          }
          return Iterator.fromArray(cache.get(key).asInstanceOf[Array[T]])
        } else {
          loading.add(key)
        }
      }
      // If we got here, we have to load the split
      println("Loading and caching " + split)
      val array = prev.iterator(split).collect.toArray
      cache.put(key, array)
      // Publish the result and wake any waiting threads.
      loading.synchronized {
        loading.remove(key)
        loading.notifyAll()
      }
      return Iterator.fromArray(array)
    }
  }

  // Record that `split` is being computed (and hence cached) on this slave.
  override def taskStarted(split: Split, slot: SlaveOffer) {
    val oldList = cacheLocs.getOrElse(split, Nil)
    val slaveId = slot.getSlaveId
    if (!oldList.contains(slaveId))
      cacheLocs(split) = slaveId :: oldList
  }
}
// Companion holding the per-JVM cache state shared by all CachedFile
// instances: ID generation on the master, cached data and in-flight
// bookkeeping on the workers.
private object CachedFile {
  val nextId = new AtomicLong(0) // Generates IDs for mapped files (on master)
  def newId() = nextId.getAndIncrement()

  // Stores map results for various splits locally (on workers).
  // Soft values let the JVM drop cached splits under memory pressure.
  val cache = new MapMaker().softValues().makeMap[String, AnyRef]()

  // Remembers which splits are currently being loaded (on workers);
  // all access is guarded by synchronizing on this set.
  val loading = new HashSet[String]
}
// Serializable wrapper around a Hadoop InputSplit (InputSplit is a Writable,
// not java.io.Serializable, so it must travel via SerializableWritable).
class HdfsSplit(@transient s: InputSplit)
  extends SerializableWritable[InputSplit](s)
// Distributed file over a text file in HDFS (or any Hadoop-supported
// filesystem), yielding one String element per line.
class HdfsTextFile(sc: SparkContext, path: String)
  extends DistributedFile[String, HdfsSplit](sc) {
  // Job configuration and split list are computed once on the master;
  // @transient because neither is shipped to workers.
  @transient val conf = new JobConf()
  @transient val inputFormat = new TextInputFormat()

  FileInputFormat.setInputPaths(conf, path)
  // InputFormat configuration is serialized globally via ConfigureLock.
  ConfigureLock.synchronized { inputFormat.configure(conf) }

  // Ask for at least 2 splits so small files still parallelize.
  @transient val splits_ =
    inputFormat.getSplits(conf, 2).map(new HdfsSplit(_)).toArray

  override def splits = splits_

  // One-pass iterator over the lines of a split, backed by a Hadoop
  // RecordReader created locally (on the worker) per iteration.
  override def iterator(split: HdfsSplit) = new Iterator[String] {
    var reader: RecordReader[LongWritable, Text] = null
    ConfigureLock.synchronized {
      val conf = new JobConf()
      conf.set("io.file.buffer.size",
        System.getProperty("spark.buffer.size", "65536"))
      val tif = new TextInputFormat()
      tif.configure(conf)
      reader = tif.getRecordReader(split.value, conf, Reporter.NULL)
    }
    // Reusable Writables that the reader fills in on each next() call.
    val lineNum = new LongWritable()
    val text = new Text()
    // gotNext: `text` holds a record that has been read but not yet returned.
    var gotNext = false
    var finished = false

    // Reads one record ahead if necessary; idempotent between next() calls.
    override def hasNext: Boolean = {
      if (!gotNext) {
        finished = !reader.next(lineNum, text)
        gotNext = true
      }
      !finished
    }

    override def next: String = {
      if (!gotNext)
        finished = !reader.next(lineNum, text)
      if (finished)
        throw new java.util.NoSuchElementException("end of stream")
      gotNext = false
      text.toString
    }
  }

  // Prefer slaves that hold a replica of the split's data.
  override def prefers(split: HdfsSplit, slot: SlaveOffer) =
    split.value.getLocations().contains(slot.getHost)
}
| object ConfigureLock {} | ||
// Makes a Hadoop Writable usable with Java serialization: the wrapped value
// is marked @transient and instead written/read through ObjectWritable in
// the custom writeObject/readObject hooks below.
@serializable
class SerializableWritable[T <: Writable](@transient var t: T) {
  def value = t
  override def toString = t.toString

  private def writeObject(out: ObjectOutputStream) {
    out.defaultWriteObject()
    new ObjectWritable(t).write(out)
  }

  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject()
    val ow = new ObjectWritable()
    // ObjectWritable needs a Configuration to resolve the value's class.
    ow.setConf(new JobConf())
    ow.readFields(in)
    t = ow.get().asInstanceOf[T]
  }
}
| @@ -0,0 +1,65 @@ | ||
| package spark | ||
| import java.util.concurrent._ | ||
| import scala.collection.mutable.Map | ||
| // A simple Scheduler implementation that runs tasks locally in a thread pool. | ||
// A simple Scheduler implementation that runs tasks locally in a thread pool.
private class LocalScheduler(threads: Int) extends Scheduler {
  // Daemon threads, so a finished program doesn't hang on the pool.
  var threadPool: ExecutorService =
    Executors.newFixedThreadPool(threads, DaemonThreadFactory)

  // No setup or registration needed for local execution.
  override def start() {}
  override def waitForRegister() {}

  // Submits every task to the pool, blocks until all complete, merges each
  // task's accumulator updates into the calling thread, and returns the
  // values in task order.
  override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
    val futures = new Array[Future[TaskResult[T]]](tasks.length)
    for (i <- 0 until tasks.length) {
      futures(i) = threadPool.submit(new Callable[TaskResult[T]]() {
        def call(): TaskResult[T] = {
          println("Running task " + i)
          try {
            // Serialize and deserialize the task so that accumulators are
            // changed to thread-local ones; this adds a bit of unnecessary
            // overhead but matches how the Nexus Executor works
            Accumulators.clear
            val bytes = Utils.serialize(tasks(i))
            println("Size of task " + i + " is " + bytes.size + " bytes")
            val task = Utils.deserialize[Task[T]](
              bytes, currentThread.getContextClassLoader)
            val value = task.run
            val accumUpdates = Accumulators.values
            println("Finished task " + i)
            new TaskResult[T](value, accumUpdates)
          } catch {
            case e: Exception => {
              // TODO: Do something nicer here
              // (one failed task currently kills the whole process)
              System.err.println("Exception in task " + i + ":")
              e.printStackTrace()
              System.exit(1)
              null
            }
          }
        }
      })
    }
    // Future.get blocks until the corresponding task has finished.
    val taskResults = futures.map(_.get)
    for (result <- taskResults)
      Accumulators.add(currentThread, result.accumUpdates)
    return taskResults.map(_.value).toArray
  }

  override def stop() {}
}
| // A ThreadFactory that creates daemon threads | ||
| private object DaemonThreadFactory extends ThreadFactory { | ||
| override def newThread(r: Runnable): Thread = { | ||
| val t = new Thread(r); | ||
| t.setDaemon(true) | ||
| return t | ||
| } | ||
| } |
| @@ -0,0 +1,258 @@ | ||
| package spark | ||
| import java.io.File | ||
| import java.util.concurrent.Semaphore | ||
| import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus} | ||
| import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver} | ||
| import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap} | ||
| // The main Scheduler implementation, which talks to Nexus. Clients are expected | ||
| // to first call start(), then submit tasks through the runTasks method. | ||
| // | ||
| // This implementation is currently a little quick and dirty. The following | ||
| // improvements need to be made to it: | ||
| // 1) Fault tolerance should be added - if a task fails, just re-run it anywhere. | ||
| // 2) Right now, the scheduler uses a linear scan through the tasks to find a | ||
| // local one for a given node. It would be faster to have a separate list of | ||
| // pending tasks for each node. | ||
| // 3) The Callbacks way of organizing things didn't work out too well, so the | ||
| // way the scheduler keeps track of the currently active runTasks operation | ||
| // can be made cleaner. | ||
// The main Scheduler implementation, which talks to Nexus. Clients are expected
// to first call start(), then submit tasks through the runTasks method.
//
// This implementation is currently a little quick and dirty. The following
// improvements need to be made to it:
// 1) Fault tolerance should be added - if a task fails, just re-run it anywhere.
// 2) Right now, the scheduler uses a linear scan through the tasks to find a
//    local one for a given node. It would be faster to have a separate list of
//    pending tasks for each node.
// 3) The Callbacks way of organizing things didn't work out too well, so the
//    way the scheduler keeps track of the currently active runTasks operation
//    can be made cleaner.
private class NexusScheduler(
    master: String, frameworkName: String, execArg: Array[Byte])
  extends nexus.Scheduler with spark.Scheduler
{
  // Semaphore used by runTasks to ensure only one thread can be in it
  val semaphore = new Semaphore(1)

  // Lock used to wait for scheduler to be registered
  var isRegistered = false
  val registeredLock = new Object()

  // Trait representing a set of scheduler callbacks for one runTasks call
  trait Callbacks {
    def slotOffer(s: SlaveOffer): Option[TaskDescription]
    def taskFinished(t: TaskStatus): Unit
    def error(code: Int, message: String): Unit
  }

  // Current callback object (may be null): non-null exactly while a runTasks
  // call is active. Reads and writes are guarded by synchronizing on `this`.
  var callbacks: Callbacks = null

  // Incrementing task ID
  var nextTaskId = 0

  // Maximum time to wait to run a task in a preferred location (in ms)
  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong

  // Driver for talking to Nexus
  var driver: SchedulerDriver = null

  // Runs the Nexus driver event loop on a daemon thread.
  override def start() {
    new Thread("Spark scheduler") {
      setDaemon(true)
      override def run {
        val ns = NexusScheduler.this
        ns.driver = new NexusSchedulerDriver(ns, master)
        ns.driver.run()
      }
    }.start
  }

  override def getFrameworkName(d: SchedulerDriver): String = frameworkName

  // Workers launch tasks through the spark-executor script in the cwd.
  override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
    new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)

  // Submits tasks and blocks until all of them finish or an error occurs.
  // Only one runTasks may be active at a time (enforced by `semaphore`);
  // actual placement happens in the Callbacks object below, invoked from the
  // Nexus driver thread via resourceOffer/statusUpdate.
  override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
    val results = new Array[T](tasks.length)
    if (tasks.length == 0)
      return results

    // launched(i) is true once tasks(i) has been handed to Nexus.
    val launched = new Array[Boolean](tasks.length)

    val callingThread = currentThread

    // Error state filled in by the callbacks and rethrown below.
    var errorHappened = false
    var errorCode = 0
    var errorMessage = ""

    // Wait for scheduler to be registered with Nexus
    waitForRegister()

    try {
      // Acquire the runTasks semaphore
      semaphore.acquire()

      val myCallbacks = new Callbacks {
        val firstTaskId = nextTaskId
        var tasksLaunched = 0
        var tasksFinished = 0
        // Time of the last launch on a preferred node; once LOCALITY_WAIT ms
        // pass without one, tasks may be launched on non-preferred nodes too.
        var lastPreferredLaunchTime = System.currentTimeMillis

        // Called from resourceOffer (with the scheduler monitor held) for
        // each slave offer; picks an unlaunched task for the slave.
        def slotOffer(slot: SlaveOffer): Option[TaskDescription] = {
          try {
            if (tasksLaunched < tasks.length) {
              // TODO: Add a short wait if no task with location pref is found
              // TODO: Figure out why a function is needed around this to
              // avoid scala.runtime.NonLocalReturnException
              def findTask: Option[TaskDescription] = {
                var checkPrefVals: Array[Boolean] = Array(true)
                val time = System.currentTimeMillis
                if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
                  checkPrefVals = Array(true, false) // Allow non-preferred tasks
                // TODO: Make desiredCpus and desiredMem configurable
                val desiredCpus = 1
                val desiredMem = 750L * 1024L * 1024L
                // Reject offers that are too small for any task.
                if (slot.getParams.get("cpus").toInt < desiredCpus ||
                    slot.getParams.get("mem").toLong < desiredMem)
                  return None
                // First pass considers only preferred tasks; the second pass
                // (when enabled above) takes any unlaunched task.
                for (checkPref <- checkPrefVals;
                     i <- 0 until tasks.length;
                     if !launched(i) && (!checkPref || tasks(i).prefers(slot)))
                {
                  val taskId = nextTaskId
                  nextTaskId += 1
                  printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
                    i, taskId, slot.getSlaveId, slot.getHost,
                    if(checkPref) "preferred" else "non-preferred")
                  tasks(i).markStarted(slot)
                  launched(i) = true
                  tasksLaunched += 1
                  if (checkPref)
                    lastPreferredLaunchTime = time
                  val params = new StringMap
                  params.set("cpus", "" + desiredCpus)
                  params.set("mem", "" + desiredMem)
                  val serializedTask = Utils.serialize(tasks(i))
                  return Some(new TaskDescription(taskId, slot.getSlaveId,
                    "task_" + taskId, params, serializedTask))
                }
                return None
              }
              return findTask
            } else {
              return None
            }
          } catch {
            case e: Exception => {
              e.printStackTrace
              System.exit(1)
              return None
            }
          }
        }

        // Called from statusUpdate (scheduler monitor held): records the
        // task's result and wakes the calling thread when all are done.
        def taskFinished(status: TaskStatus) {
          println("Finished TID " + status.getTaskId)
          // Deserialize task result
          val result = Utils.deserialize[TaskResult[T]](status.getData)
          results(status.getTaskId - firstTaskId) = result.value
          // Update accumulators
          Accumulators.add(callingThread, result.accumUpdates)
          // Stop if we've finished all the tasks
          tasksFinished += 1
          if (tasksFinished == tasks.length) {
            NexusScheduler.this.callbacks = null
            NexusScheduler.this.notifyAll()
          }
        }

        // Called from error() (scheduler monitor held): stash the message so
        // runTasks can rethrow it on the calling thread, and wake it up.
        def error(code: Int, message: String) {
          // Save the error message
          errorHappened = true
          errorCode = code
          errorMessage = message
          // Indicate to caller thread that we're done
          NexusScheduler.this.callbacks = null
          NexusScheduler.this.notifyAll()
        }
      }

      // Publish the callbacks so driver events start flowing to them.
      this.synchronized {
        this.callbacks = myCallbacks
      }

      // Ask Nexus to send us resource offers right away.
      driver.reviveOffers();

      // Block until the callbacks object clears itself (all done, or error).
      this.synchronized {
        while (this.callbacks != null) this.wait()
      }
    } finally {
      semaphore.release()
    }

    if (errorHappened)
      throw new SparkException(errorMessage, errorCode)
    else
      return results
  }

  // Driver callback: mark registration complete and wake waitForRegister.
  override def registered(d: SchedulerDriver, frameworkId: Int) {
    println("Registered as framework ID " + frameworkId)
    registeredLock.synchronized {
      isRegistered = true
      registeredLock.notifyAll()
    }
  }

  // Blocks until registered() has been invoked by the driver thread.
  override def waitForRegister() {
    registeredLock.synchronized {
      while (!isRegistered) registeredLock.wait()
    }
  }

  // Driver callback: offer each slot to the active runTasks (if any) and
  // reply to Nexus with whatever tasks were chosen.
  override def resourceOffer(
      d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) {
    synchronized {
      val tasks = new TaskDescriptionVector
      if (callbacks != null) {
        try {
          for (i <- 0 until slots.size.toInt) {
            callbacks.slotOffer(slots.get(i)) match {
              case Some(task) => tasks.add(task)
              case None => {}
            }
          }
        } catch {
          case e: Exception => e.printStackTrace
        }
      }
      val params = new StringMap
      params.set("timeout", "1")
      d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
    }
  }

  // Driver callback: forward finished-task notifications to the active
  // runTasks call; other task states are currently ignored.
  override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
    synchronized {
      if (callbacks != null && status.getState == TaskState.TASK_FINISHED) {
        try {
          callbacks.taskFinished(status)
        } catch {
          case e: Exception => e.printStackTrace
        }
      }
    }
  }

  // Driver callback: route errors to the active runTasks call, or exit if
  // none is in progress.
  override def error(d: SchedulerDriver, code: Int, message: String) {
    synchronized {
      if (callbacks != null) {
        try {
          callbacks.error(code, message)
        } catch {
          case e: Exception => e.printStackTrace
        }
      } else {
        val msg = "Nexus error: %s (error code: %d)".format(message, code)
        System.err.println(msg)
        System.exit(1)
      }
    }
  }

  override def stop() {
    if (driver != null)
      driver.stop()
  }
}
| @@ -0,0 +1,97 @@ | ||
| package spark | ||
// Base class for parallel collections driven through a SparkContext.
abstract class ParallelArray[T](sc: SparkContext) {
  // Returns a view of this array restricted to elements satisfying f.
  def filter(f: T => Boolean): ParallelArray[T] =
    new FilteredParallelArray[T](sc, this, sc.clean(f))

  def foreach(f: T => Unit): Unit
  def map[U](f: T => U): Array[U]
}
private object ParallelArray {
  // Splits seq into numSlices contiguous sub-sequences of near-equal size.
  // Ranges are sliced into (serializable) sub-ranges without materializing
  // their elements; all other seqs are copied into an array first.
  def slice[T](seq: Seq[T], numSlices: Int): Array[Seq[T]] = {
    if (numSlices < 1)
      throw new IllegalArgumentException("Positive number of slices required")
    seq match {
      case r: Range.Inclusive => {
        // Convert to an exclusive range (extend the end by one step in the
        // direction of travel) and re-dispatch to the Range case below.
        val sign = if (r.step < 0) -1 else 1
        slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]],
          numSlices)
      }
      case r: Range => {
        (0 until numSlices).map(i => {
          // Long arithmetic so i * length can't overflow Int.
          val start = ((i * r.length.toLong) / numSlices).toInt
          val end = (((i+1) * r.length.toLong) / numSlices).toInt
          new SerializableRange(
            r.start + start * r.step, r.start + end * r.step, r.step)
        }).asInstanceOf[Seq[Seq[T]]].toArray
      }
      case _ => {
        val array = seq.toArray // To prevent O(n^2) operations for List etc
        (0 until numSlices).map(i => {
          val start = ((i * array.length.toLong) / numSlices).toInt
          val end = (((i+1) * array.length.toLong) / numSlices).toInt
          array.slice(start, end).toArray
        }).toArray
      }
    }
  }
}
// Parallel array backed by an in-memory Seq, split eagerly into numSlices
// slices on the master; each slice becomes one task.
private class SimpleParallelArray[T](
    sc: SparkContext, data: Seq[T], numSlices: Int)
  extends ParallelArray[T](sc) {
  val slices = ParallelArray.slice(data, numSlices)

  // Runs f on every element, one task per slice.
  def foreach(f: T => Unit) {
    val cleanF = sc.clean(f)
    // FIX: `tasks` was declared as a `var` although never reassigned.
    val tasks = for (i <- 0 until numSlices) yield
      new ForeachRunner(i, slices(i), cleanF)
    sc.runTasks[Unit](tasks.toArray)
  }

  // Maps f over every element in parallel and concatenates the per-slice
  // result arrays in slice order.
  def map[U](f: T => U): Array[U] = {
    val cleanF = sc.clean(f)
    val tasks = for (i <- 0 until numSlices) yield
      new MapRunner(i, slices(i), cleanF)
    return Array.concat(sc.runTasks[Array[U]](tasks.toArray): _*)
  }
}
// Serializable task applying a side-effecting function to one slice.
@serializable
private class ForeachRunner[T](sliceNum: Int, data: Seq[T], f: T => Unit)
  extends Function0[Unit] {
  def apply() = {
    printf("Running slice %d of parallel foreach\n", sliceNum)
    for (elem <- data)
      f(elem)
  }
}
// Serializable task mapping a function over one slice and returning the
// results as an array.
@serializable
private class MapRunner[T, U](sliceNum: Int, data: Seq[T], f: T => U)
  extends Function0[Array[U]] {
  def apply(): Array[U] = {
    printf("Running slice %d of parallel map\n", sliceNum)
    data.map(f).toArray
  }
}
// Lazily-filtered view of another ParallelArray: the predicate is applied
// inside each downstream operation rather than eagerly.
private class FilteredParallelArray[T](
    sc: SparkContext, array: ParallelArray[T], predicate: T => Boolean)
  extends ParallelArray[T](sc) {
  // Cleaned once at construction; reused by every operation.
  val cleanPred = sc.clean(predicate)

  def foreach(f: T => Unit) {
    val cleanF = sc.clean(f)
    array.foreach(t => if (cleanPred(t)) cleanF(t))
  }

  def map[U](f: T => U): Array[U] = {
    // FIX: removed dead code — the previous version cleaned `f` into an
    // unused local before unconditionally throwing.
    throw new UnsupportedOperationException(
      "Map is not yet supported on FilteredParallelArray")
  }
}
| @@ -0,0 +1,9 @@ | ||
| package spark | ||
| // Scheduler trait, implemented by both NexusScheduler and LocalScheduler. | ||
// Scheduler trait, implemented by both NexusScheduler and LocalScheduler.
private trait Scheduler {
  // Starts the scheduler; must be called before running tasks.
  def start()
  // Blocks until the scheduler is registered and ready to accept tasks.
  def waitForRegister()
  // Runs the given tasks and returns their results in the same order.
  def runTasks[T](tasks: Array[Task[T]]): Array[T]
  // Shuts the scheduler down.
  def stop()
}
| @@ -0,0 +1,75 @@ | ||
| // This is a copy of Scala 2.7.7's Range class, (c) 2006-2009, LAMP/EPFL. | ||
| // The only change here is to make it Serializable, because Ranges aren't. | ||
| // This won't be needed in Scala 2.8, where Scala's Range becomes Serializable. | ||
| package spark | ||
@serializable
private class SerializableRange(val start: Int, val end: Int, val step: Int)
  extends RandomAccessSeq.Projection[Int] {
  if (step == 0) throw new Predef.IllegalArgumentException

  /** Create a new range with the start and end values of this range and
   *  a new <code>step</code>.
   *  NOTE(review): returns a plain (non-serializable) Range, losing the one
   *  property this class exists for — confirm no caller serializes it.
   */
  def by(step: Int): Range = new Range(start, end, step)

  // Iterates the range without materializing it; the end bound is widened
  // by one step when `end` itself lies in the interval (inclusive ranges).
  override def foreach(f: Int => Unit) {
    if (step > 0) {
      var i = this.start
      val until = if (inInterval(end)) end + 1 else end
      while (i < until) {
        f(i)
        i += step
      }
    } else {
      var i = this.start
      val until = if (inInterval(end)) end - 1 else end
      while (i > until) {
        f(i)
        i += step
      }
    }
  }

  // Number of elements in the range; 0 when start/end/step are inconsistent.
  lazy val length: Int = {
    if (start < end && this.step < 0) 0
    else if (start > end && this.step > 0) 0
    else {
      val base = if (start < end) end - start
                 else start - end
      assert(base >= 0)
      val step = if (this.step < 0) -this.step else this.step
      assert(step >= 0)
      base / step + last(base, step)
    }
  }

  // 1 when the span isn't an exact multiple of step (extra final element).
  protected def last(base: Int, step: Int): Int =
    if (base % step != 0) 1 else 0

  def apply(idx: Int): Int = {
    if (idx < 0 || idx >= length) throw new Predef.IndexOutOfBoundsException
    start + (step * idx)
  }

  /** a <code>Seq.contains</code>, not a <code>Iterator.contains</code>! */
  def contains(x: Int): Boolean = {
    inInterval(x) && (((x - start) % step) == 0)
  }

  /** Is the argument inside the interval defined by `start' and `end'?
   *  Returns true if `x' is inside [start, end).
   */
  protected def inInterval(x: Int): Boolean =
    if (step > 0)
      (x >= start && x < end)
    else
      (x <= start && x > end)

  //def inclusive = new Range.Inclusive(start,end,step)

  override def toString = "SerializableRange(%d, %d, %d)".format(start, end, step)
}
| @@ -0,0 +1,89 @@ | ||
| package spark | ||
| import java.io._ | ||
| import java.util.UUID | ||
| import scala.collection.mutable.ArrayBuffer | ||
/** Main entry point for Spark functionality: creates parallel collections,
 *  accumulators and broadcast values, and runs tasks on a cluster or locally.
 *
 *  @param master cluster URL: "local", "local[n]" for n threads, or a
 *                Nexus master address
 *  @param frameworkName human-readable name for this computation
 */
class SparkContext(master: String, frameworkName: String) {
  // Initialize the node-local cache used by broadcast values / task results.
  Cache.initialize()

  /** Distribute a local collection across numSlices partitions. */
  def parallelize[T](seq: Seq[T], numSlices: Int): ParallelArray[T] =
    new SimpleParallelArray[T](this, seq, numSlices)

  /** Distribute a local collection using the default of 2 partitions. */
  def parallelize[T](seq: Seq[T]): ParallelArray[T] = parallelize(seq, 2)

  /** Create an accumulator that tasks can add to, with addition and zero
   *  defined by the implicit AccumulatorParam. */
  def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
    new Accumulator(initialValue, param)

  // TODO: Keep around a weak hash map of values to Cached versions?
  /** Wrap a read-only value for efficient reuse across tasks. */
  def broadcast[T](value: T) = new Cached(value, local)

  /** Create a distributed dataset from an HDFS text file. */
  def textFile(path: String) = new HdfsTextFile(this, path)

  // Matches master strings of the form "local[N]".
  val LOCAL_REGEX = """local\[([0-9]+)\]""".r

  // Pick a scheduler implementation based on the master string; anything that
  // is not local runs through the Nexus native library.
  private var scheduler: Scheduler = master match {
    case "local" => new LocalScheduler(1)
    case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt)
    case _ => { System.loadLibrary("nexus");
                new NexusScheduler(master, frameworkName, createExecArg()) }
  }

  // True when running in-process rather than on a cluster.
  private val local = scheduler.isInstanceOf[LocalScheduler]

  scheduler.start()

  /** Serialize all spark.* system properties as the executor argument. */
  private def createExecArg(): Array[Byte] = {
    // Our executor arg is an array containing all the spark.* system properties
    val props = new ArrayBuffer[(String, String)]
    val iter = System.getProperties.entrySet.iterator
    while (iter.hasNext) {
      val entry = iter.next
      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
      if (key.startsWith("spark."))
        props += (key, value)
    }
    return Utils.serialize(props.toArray)
  }

  /** Run an array of zero-argument functions in parallel and return their
   *  results in the same order. */
  def runTasks[T](tasks: Array[() => T]): Array[T] = {
    runTaskObjects(tasks.map(f => new FunctionTask(f)))
  }

  // Run Task objects through the scheduler, timing the overall execution.
  private[spark] def runTaskObjects[T](tasks: Seq[Task[T]]): Array[T] = {
    println("Running " + tasks.length + " tasks in parallel")
    val start = System.nanoTime
    val result = scheduler.runTasks(tasks.toArray)
    println("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s")
    return result
  }

  /** Shut down the scheduler; the context is unusable afterwards. */
  def stop() {
    scheduler.stop()
    scheduler = null
  }

  /** Block until the scheduler has registered with the cluster master. */
  def waitForRegister() {
    scheduler.waitForRegister()
  }

  // Clean a closure to make it ready to be serialized and sent to tasks
  // (removes unreferenced variables in $outer's, updates REPL variables)
  private[spark] def clean[F <: AnyRef](f: F): F = {
    ClosureCleaner.clean(f)
    return f
  }
}
/** Companion object holding implicit AccumulatorParam instances for the
 *  common numeric types, so accumulator() calls resolve them automatically. */
object SparkContext {
  /** Accumulates Int values by addition, starting from 0. */
  implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
    def add(t1: Int, t2: Int): Int = t1 + t2
    def zero(initialValue: Int): Int = 0
  }

  /** Accumulates Double values by addition, starting from 0.0. */
  implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
    def add(t1: Double, t2: Double): Double = t1 + t2
    def zero(initialValue: Double): Double = 0.0
  }

  // TODO: Add AccumulatorParams for other types, e.g. lists and strings
}
| @@ -0,0 +1,7 @@ | ||
| package spark | ||
/** Exception type for errors raised by Spark itself. */
class SparkException(message: String) extends Exception(message) {
  /** Build an exception whose message embeds a numeric error code. */
  def this(message: String, errorCode: Int) =
    this("%s (error code: %d)".format(message, errorCode))
}
| @@ -0,0 +1,16 @@ | ||
| package spark | ||
| import nexus._ | ||
// A unit of work that can be serialized and executed on a slave.
@serializable
trait Task[T] {
  // Execute the task and return its result.
  def run: T
  // Whether this task would prefer to run on the given slot (default: any).
  def prefers(slot: SlaveOffer): Boolean = true
  // Callback invoked when the task has been assigned to a slot.
  def markStarted(slot: SlaveOffer) {}
}
// Task that simply invokes a closure and returns its result.
@serializable
class FunctionTask[T](body: () => T)
extends Task[T] {
  def run: T = body()
}
| @@ -0,0 +1,9 @@ | ||
| package spark | ||
| import scala.collection.mutable.Map | ||
// Task result. Also contains updates to accumulator variables.
// accumUpdates presumably maps accumulator IDs to the values added on the
// slave — TODO(review): confirm against the Accumulator implementation.
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
@serializable
private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any])
| @@ -0,0 +1,28 @@ | ||
| package spark | ||
| import java.io._ | ||
// Helpers for converting objects to and from byte arrays with standard Java
// serialization.
private object Utils {
  /** Serialize an object into a byte array. */
  def serialize[T](o: T): Array[Byte] = {
    val bos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(bos)
    try {
      oos.writeObject(o)
    } finally {
      // Close even when writeObject throws, flushing buffered bytes.
      oos.close()
    }
    bos.toByteArray
  }

  /** Deserialize an object from a byte array using the default class loader. */
  def deserialize[T](bytes: Array[Byte]): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try {
      ois.readObject.asInstanceOf[T]
    } finally {
      ois.close()
    }
  }

  /** Deserialize an object from a byte array, resolving classes through the
   *  given class loader (needed e.g. for classes defined by the REPL). */
  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
      override def resolveClass(desc: ObjectStreamClass) =
        Class.forName(desc.getName, false, loader)
    }
    try {
      ois.readObject.asInstanceOf[T]
    } finally {
      ois.close()
    }
  }
}
| @@ -0,0 +1,86 @@ | ||
| package spark.repl | ||
| import java.io.{ByteArrayOutputStream, InputStream} | ||
| import java.net.{URI, URL, URLClassLoader} | ||
| import java.util.concurrent.{Executors, ExecutorService} | ||
| import org.apache.hadoop.conf.Configuration | ||
| import org.apache.hadoop.fs.{FileSystem, Path} | ||
| import org.objectweb.asm._ | ||
| import org.objectweb.asm.commons.EmptyVisitor | ||
| import org.objectweb.asm.Opcodes._ | ||
// A ClassLoader that reads classes from a Hadoop FileSystem URL, used to load
// classes defined by the interpreter when the REPL is in use
class ExecutorClassLoader(classDir: String, parent: ClassLoader)
extends ClassLoader(parent) {
  // Filesystem (e.g. HDFS) hosting the directory of REPL-generated classes.
  val fileSystem = FileSystem.get(new URI(classDir), new Configuration())
  val directory = new URI(classDir).getPath

  /** Locate the .class file for name under the class directory, transform it
   *  if necessary, and define it. Any failure surfaces as
   *  ClassNotFoundException with the original exception as its cause. */
  override def findClass(name: String): Class[_] = {
    try {
      //println("repl.ExecutorClassLoader resolving " + name)
      val path = new Path(directory, name.replace('.', '/') + ".class")
      val bytes = readAndTransformClass(name, fileSystem.open(path))
      return defineClass(name, bytes, 0, bytes.length)
    } catch {
      case e: Exception => throw new ClassNotFoundException(name, e)
    }
  }

  /** Read a class's bytes from in, rewriting the constructors of REPL
   *  "$iw" wrapper classes so they do not run initialization code. */
  def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
    if (name.startsWith("line") && name.endsWith("$iw$")) {
      // Class seems to be an interpreter "wrapper" object storing a val or var.
      // Replace its constructor with a dummy one that does not run the
      // initialization code placed there by the REPL. The val or var will
      // be initialized later through reflection when it is used in a task.
      val cr = new ClassReader(in)
      val cw = new ClassWriter(
        ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
      val cleaner = new ConstructorCleaner(name, cw)
      cr.accept(cleaner, 0)
      return cw.toByteArray
    } else {
      // Pass the class through unmodified
      val bos = new ByteArrayOutputStream
      val bytes = new Array[Byte](4096)
      var done = false
      while (!done) {
        val num = in.read(bytes)
        if (num >= 0)
          bos.write(bytes, 0, num)
        else
          done = true
      }
      return bos.toByteArray
    }
  }
}
// ASM class visitor that replaces the constructor of a REPL wrapper class
// with one that only invokes the superclass constructor, skipping all field
// initialization.
class ConstructorCleaner(className: String, cv: ClassVisitor)
extends ClassAdapter(cv) {
  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    val mv = cv.visitMethod(access, name, desc, sig, exceptions)
    if (name == "<init>" && (access & ACC_STATIC) == 0) {
      // This is the constructor, time to clean it; just output some new
      // instructions to mv that create the object and set the static MODULE$
      // field in the class to point to it, but do nothing otherwise.
      //println("Cleaning constructor of " + className)
      mv.visitCode()
      mv.visitVarInsn(ALOAD, 0) // load this
      mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
      mv.visitVarInsn(ALOAD, 0) // load this
      //val classType = className.replace('.', '/')
      //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
      mv.visitInsn(RETURN)
      mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
      mv.visitEnd()
      // Returning null tells ASM to discard the original constructor body.
      return null
    } else {
      return mv
    }
  }
}
| @@ -0,0 +1,16 @@ | ||
| package spark.repl | ||
| import scala.collection.mutable.Set | ||
/** Entry point for the Spark REPL. Holds a global reference to the active
 *  interpreter loop so that interpreted code can reach back into it. */
object Main {
  private var _interp: SparkInterpreterLoop = null

  /** The currently active interpreter loop (null before main() runs). */
  def interp = _interp

  private[repl] def interp_=(i: SparkInterpreterLoop) { _interp = i }

  /** Create the interpreter loop, publish it, and run it with args. */
  def main(args: Array[String]) {
    val loop = new SparkInterpreterLoop
    _interp = loop
    loop.main(args)
  }
}
| @@ -0,0 +1,366 @@ | ||
| /* NSC -- new Scala compiler | ||
| * Copyright 2005-2009 LAMP/EPFL | ||
| * @author Alexander Spoon | ||
| */ | ||
| // $Id: InterpreterLoop.scala 16881 2009-01-09 16:28:11Z cunei $ | ||
| package spark.repl | ||
| import scala.tools.nsc | ||
| import scala.tools.nsc._ | ||
| import java.io.{BufferedReader, File, FileReader, PrintWriter} | ||
| import java.io.IOException | ||
| import java.lang.{ClassLoader, System} | ||
| import scala.tools.nsc.{InterpreterResults => IR} | ||
| import scala.tools.nsc.interpreter._ | ||
| import spark.SparkContext | ||
| /** The | ||
| * <a href="http://scala-lang.org/" target="_top">Scala</a> | ||
| * interactive shell. It provides a read-eval-print loop around | ||
| * the Interpreter class. | ||
| * After instantiation, clients should call the <code>main()</code> method. | ||
| * | ||
| * <p>If no in0 is specified, then input will come from the console, and | ||
| * the class will attempt to provide input editing feature such as | ||
| * input history. | ||
| * | ||
| * @author Moez A. Abdel-Gawad | ||
| * @author Lex Spoon | ||
| * @version 1.2 | ||
| */ | ||
class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter,
                           master: Option[String])
{
  def this(in0: BufferedReader, out: PrintWriter, master: String) =
    this(Some(in0), out, Some(master))

  def this(in0: BufferedReader, out: PrintWriter) =
    this(Some(in0), out, None)

  def this() = this(None, new PrintWriter(Console.out), None)

  /** The input stream from which interpreter commands come */
  var in: InteractiveReader = _ //set by main()

  /** The context class loader at the time this object was created */
  protected val originalClassLoader =
    Thread.currentThread.getContextClassLoader

  var settings: Settings = _ // set by main()
  var interpreter: SparkInterpreter = null // set by createInterpreter()

  // The interpreter's own settings object, exposed to evaluated code.
  def isettings = interpreter.isettings

  /** A reverse list of commands to replay if the user
   * requests a :replay */
  var replayCommandsRev: List[String] = Nil

  /** A list of commands to replay if the user requests a :replay */
  def replayCommands = replayCommandsRev.reverse

  /** Record a command for replay should the user request a :replay */
  def addReplay(cmd: String) =
    replayCommandsRev = cmd :: replayCommandsRev

  /** Close the interpreter, if there is one, and set
   * interpreter to <code>null</code>. */
  def closeInterpreter() {
    if (interpreter ne null) {
      interpreter.close
      interpreter = null
      // Restore the class loader that was active before the REPL started.
      Thread.currentThread.setContextClassLoader(originalClassLoader)
    }
  }

  /** Create a new interpreter. Close the old one, if there
   * is one. */
  def createInterpreter() {
    //closeInterpreter()
    interpreter = new SparkInterpreter(settings, out) {
      override protected def parentClassLoader =
        classOf[SparkInterpreterLoop].getClassLoader
    }
    interpreter.setContextClassLoader()
  }

  /** Bind the settings so that evaluated code can modify them */
  def bindSettings() {
    interpreter.beQuietDuring {
      interpreter.compileString(InterpreterSettings.sourceCodeForClass)
      interpreter.bind(
        "settings",
        "scala.tools.nsc.InterpreterSettings",
        isettings)
    }
  }

  /** print a friendly help message */
  def printHelp {
    //printWelcome
    out.println("This is Scala " + Properties.versionString + " (" +
      System.getProperty("java.vm.name") + ", Java " + System.getProperty("java.version") + ")." )
    out.println("Type in expressions to have them evaluated.")
    out.println("Type :load followed by a filename to load a Scala file.")
    //out.println("Type :replay to reset execution and replay all previous commands.")
    out.println("Type :quit to exit the interpreter.")
  }

  /** Print a welcome message */
  def printWelcome() {
    out.println("""Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/ '_/
   /___/ .__/\_,_/_/ /_/\_\   version 0.0
      /_/
""")
    out.println("Using Scala " + Properties.versionString + " (" +
      System.getProperty("java.vm.name") + ", Java " +
      System.getProperty("java.version") + ")." )
    out.flush()
  }

  /** Create a SparkContext for the REPL, using the master passed on the
   *  command line, else the MASTER environment variable, else "local". */
  def createSparkContext(): SparkContext = {
    val master = this.master match {
      case Some(m) => m
      case None => {
        val prop = System.getenv("MASTER")
        if (prop != null) prop else "local"
      }
    }
    new SparkContext(master, "Spark REPL")
  }

  /** Prompt to print when awaiting input */
  val prompt = Properties.shellPromptString

  /** The main read-eval-print loop for the interpreter. It calls
   * <code>command()</code> for each line of input, and stops when
   * <code>command()</code> returns <code>false</code>.
   */
  def repl() {
    out.println("Intializing...")
    out.flush()
    // Quietly create the SparkContext so its startup output does not clutter
    // the console; interpreted code sees it as "sc".
    interpreter.beQuietDuring {
      command("""
        spark.repl.Main.interp.out.println("Registering with Nexus...");
        @transient val sc = spark.repl.Main.interp.createSparkContext();
        sc.waitForRegister();
        spark.repl.Main.interp.out.println("Spark context available as sc.")
      """)
      command("import spark.SparkContext._");
    }
    out.println("Type in expressions to have them evaluated.")
    out.println("Type :help for more information.")
    out.flush()
    var first = true
    while (true) {
      out.flush()
      val line =
        if (first) {
          /* For some reason, the first interpreted command always takes
           * a second or two. So, wait until the welcome message
           * has been printed before calling bindSettings. That way,
           * the user can read the welcome message while this
           * command executes.
           */
          val futLine = scala.concurrent.ops.future(in.readLine(prompt))
          bindSettings()
          first = false
          futLine()
        } else {
          in.readLine(prompt)
        }
      if (line eq null)
        return () // assumes null means EOF
      val (keepGoing, finalLineMaybe) = command(line)
      if (!keepGoing)
        return
      finalLineMaybe match {
        case Some(finalLine) => addReplay(finalLine)
        case None => ()
      }
    }
  }

  /** interpret all lines from a specified file */
  def interpretAllFrom(filename: String) {
    val fileIn = try {
      new FileReader(filename)
    } catch {
      case _:IOException =>
        out.println("Error opening file: " + filename)
        return
    }
    val oldIn = in
    val oldReplay = replayCommandsRev
    try {
      val inFile = new BufferedReader(fileIn)
      in = new SimpleReader(inFile, out, false)
      out.println("Loading " + filename + "...")
      out.flush
      repl
    } finally {
      // Restore the interactive reader and replay history even on failure.
      in = oldIn
      replayCommandsRev = oldReplay
      fileIn.close
    }
  }

  /** create a new interpreter and replay all commands so far */
  def replay() {
    closeInterpreter()
    createInterpreter()
    for (cmd <- replayCommands) {
      out.println("Replaying: " + cmd)
      out.flush() // because maybe cmd will have its own output
      command(cmd)
      out.println
    }
  }

  /** Run one command submitted by the user. Two values are returned:
   * (1) whether to keep running, (2) the line to record for replay,
   * if any. */
  def command(line: String): (Boolean, Option[String]) = {
    // Extract the filename argument of a command, check that the file
    // exists, and invoke action on it.
    def withFile(command: String)(action: String => Unit) {
      val spaceIdx = command.indexOf(' ')
      if (spaceIdx <= 0) {
        out.println("That command requires a filename to be specified.")
        return ()
      }
      val filename = command.substring(spaceIdx).trim
      if (! new File(filename).exists) {
        out.println("That file does not exist")
        return ()
      }
      action(filename)
    }

    // Commands may be abbreviated down to their first letter, e.g. ":h".
    val helpRegexp = ":h(e(l(p)?)?)?"
    val quitRegexp = ":q(u(i(t)?)?)?"
    val loadRegexp = ":l(o(a(d)?)?)?.*"
    //val replayRegexp = ":r(e(p(l(a(y)?)?)?)?)?.*"

    var shouldReplay: Option[String] = None

    if (line.matches(helpRegexp))
      printHelp
    else if (line.matches(quitRegexp))
      return (false, None)
    else if (line.matches(loadRegexp)) {
      withFile(line)(f => {
        interpretAllFrom(f)
        shouldReplay = Some(line)
      })
    }
    //else if (line matches replayRegexp)
    //  replay
    else if (line startsWith ":")
      out.println("Unknown command. Type :help for help.")
    else
      shouldReplay = interpretStartingWith(line)

    (true, shouldReplay)
  }

  /** Interpret expressions starting with the first line.
   * Read lines until a complete compilation unit is available
   * or until a syntax error has been seen. If a full unit is
   * read, go ahead and interpret it. Return the full string
   * to be recorded for replay, if any.
   */
  def interpretStartingWith(code: String): Option[String] = {
    interpreter.interpret(code) match {
      case IR.Success => Some(code)
      case IR.Error => None
      case IR.Incomplete =>
        if (in.interactive && code.endsWith("\n\n")) {
          out.println("You typed two blank lines. Starting a new command.")
          None
        } else {
          // Keep reading continuation lines until the unit parses.
          val nextLine = in.readLine("     | ")
          if (nextLine == null)
            None // end of file
          else
            interpretStartingWith(code + "\n" + nextLine)
        }
    }
  }

  /** Interpret any files requested on the command line (e.g. via -i). */
  def loadFiles(settings: Settings) {
    settings match {
      case settings: GenericRunnerSettings =>
        for (filename <- settings.loadfiles.value) {
          val cmd = ":load " + filename
          command(cmd)
          replayCommandsRev = cmd :: replayCommandsRev
          out.println()
        }
      case _ =>
    }
  }

  /** Run the REPL given already-parsed compiler settings. */
  def main(settings: Settings) {
    this.settings = settings

    in =
      in0 match {
        case Some(in0) =>
          new SimpleReader(in0, out, true)
        case None =>
          // Under emacs (or with -Xnojline) fall back to a plain reader with
          // no line editing or history.
          val emacsShell = System.getProperty("env.emacs", "") != ""
          //println("emacsShell="+emacsShell) //debug
          if (settings.Xnojline.value || emacsShell)
            new SimpleReader()
          else
            InteractiveReader.createDefault()
      }

    createInterpreter()
    loadFiles(settings)

    try {
      if (interpreter.reporter.hasErrors) {
        return // it is broken on startup; go ahead and exit
      }
      printWelcome()
      repl()
    } finally {
      closeInterpreter()
    }
  }

  /** process command-line arguments and do as they request */
  def main(args: Array[String]) {
    def error1(msg: String) { out.println("scala: " + msg) }
    val command = new InterpreterCommand(List.fromArray(args), error1)

    if (!command.ok || command.settings.help.value || command.settings.Xhelp.value) {
      // either the command line is wrong, or the user
      // explicitly requested a help listing
      if (command.settings.help.value) out.println(command.usageMsg)
      if (command.settings.Xhelp.value) out.println(command.xusageMsg)
      out.flush
    }
    else
      main(command.settings)
  }
}
| @@ -0,0 +1,21 @@ | ||
| package ubiquifs | ||
| import java.io.{DataInputStream, DataOutputStream} | ||
// Wire codes identifying the kind of request carried by a connection header.
object RequestType {
  val READ = 0
  val WRITE = 1
}
/** Header sent at the start of every slave connection: a request type code
 *  (see RequestType) followed by the path to be read or written. */
class Header(val requestType: Int, val path: String) {
  /** Serialize this header: one type byte, then the path in modified UTF. */
  def write(out: DataOutputStream) {
    out.write(requestType)
    out.writeUTF(path)
  }
}

object Header {
  /** Parse a header in the format produced by Header.write. */
  def read(in: DataInputStream): Header = {
    val requestType = in.read()
    val path = in.readUTF()
    new Header(requestType, path)
  }
}
| @@ -0,0 +1,49 @@ | ||
| package ubiquifs | ||
| import scala.actors.Actor | ||
| import scala.actors.Actor._ | ||
| import scala.actors.remote.RemoteActor | ||
| import scala.actors.remote.RemoteActor._ | ||
| import scala.actors.remote.Node | ||
| import scala.collection.mutable.{ArrayBuffer, Map, Set} | ||
// Master metadata server: tracks registered slaves and the set of created
// file paths, and tells clients which slave to contact for a new file.
class Master(port: Int) extends Actor {
  case class SlaveInfo(host: String, port: Int)

  // Paths of all files created so far.
  val files = Set[String]()
  // Slaves that have registered, in registration order.
  val slaves = new ArrayBuffer[SlaveInfo]()

  def act() {
    // Make this actor remotely reachable on port under the name 'UbiquiFS.
    alive(port)
    register('UbiquiFS, self)
    println("Created UbiquiFS Master on port " + port)

    loop {
      react {
        case RegisterSlave(host, port) =>
          slaves += SlaveInfo(host, port)
          sender ! RegisterSucceeded()
        case Create(path) =>
          if (files.contains(path)) {
            sender ! CreateFailed("File already exists")
          } else if (slaves.isEmpty) {
            sender ! CreateFailed("No slaves registered")
          } else {
            // NOTE(review): always places new files on the first registered
            // slave — presumably a placeholder for a real placement policy.
            files += path
            sender ! CreateSucceeded(slaves(0).host, slaves(0).port)
          }
        case m: Any =>
          println("Unknown message: " + m)
      }
    }
  }
}
/** Command-line entry point: "MasterMain <port>" starts a master actor. */
object MasterMain {
  def main(args: Array[String]) {
    val listenPort = args(0).toInt
    val master = new Master(listenPort)
    master.start()
  }
}
| @@ -0,0 +1,14 @@ | ||
| package ubiquifs | ||
// Messages exchanged between UbiquiFS clients, the master and the slaves.
// The base type is a sealed abstract class rather than a case class:
// extending one case class from another breaks equals/hashCode and is
// deprecated, while a sealed abstract base still gives exhaustiveness
// checking in pattern matches.
sealed abstract class Message
case class RegisterSlave(host: String, port: Int) extends Message
case class RegisterSucceeded() extends Message
case class Create(path: String) extends Message
case class CreateSucceeded(host: String, port: Int) extends Message
case class CreateFailed(message: String) extends Message
case class Read(path: String) extends Message
case class ReadSucceeded(host: String, port: Int) extends Message
case class ReadFailed(message: String) extends Message
| @@ -0,0 +1,141 @@ | ||
| package ubiquifs | ||
| import java.io.{DataInputStream, DataOutputStream, IOException} | ||
| import java.net.{InetAddress, Socket, ServerSocket} | ||
| import java.util.concurrent.locks.ReentrantLock | ||
| import scala.actors.Actor | ||
| import scala.actors.Actor._ | ||
| import scala.actors.remote.RemoteActor | ||
| import scala.actors.remote.RemoteActor._ | ||
| import scala.actors.remote.Node | ||
| import scala.collection.mutable.{ArrayBuffer, Map, Set} | ||
// Storage slave: registers with the master, then serves read/write requests
// for in-memory file buffers over raw sockets.
class Slave(myPort: Int, master: String) extends Thread("UbiquiFS slave") {
  val CHUNK_SIZE = 1024 * 1024

  // Path -> in-memory buffer of its contents; access guarded by
  // synchronized blocks on the Slave instance.
  val buffers = Map[String, Buffer]()

  override def run() {
    // Create server socket
    val socket = new ServerSocket(myPort)

    // Register with master
    val (masterHost, masterPort) = Utils.parseHostPort(master)
    val masterActor = select(Node(masterHost, masterPort), 'UbiquiFS)
    val myHost = InetAddress.getLocalHost.getHostName
    val reply = masterActor !? RegisterSlave(myHost, myPort)
    println("Registered with master, reply = " + reply)

    // Accept connections forever, handing each to its own handler thread.
    while (true) {
      val conn = socket.accept()
      new ConnectionHandler(conn).start()
    }
  }

  /** Handles a single client connection: reads the header, then dispatches
   *  to performRead or performWrite. */
  class ConnectionHandler(conn: Socket) extends Thread("ConnectionHandler") {
    // BUG FIX: this logic previously sat in the constructor body, so it ran
    // on the accepting thread at construction time and start() did nothing.
    // Wrapping it in run() makes each connection actually use its own thread.
    override def run() {
      try {
        val in = new DataInputStream(conn.getInputStream)
        val out = new DataOutputStream(conn.getOutputStream)
        val header = Header.read(in)
        header.requestType match {
          case RequestType.READ =>
            performRead(header.path, out)
          case RequestType.WRITE =>
            performWrite(header.path, in)
          case other =>
            throw new IOException("Invalid header type " + other)
        }
      } catch {
        case e: Exception => e.printStackTrace()
      } finally {
        conn.close()
      }
    }
  }

  /** Handle a WRITE request: create a buffer for path and fill it with
   *  fixed-size chunks read from the socket until EOF. */
  def performWrite(path: String, in: DataInputStream) {
    var buffer = new Buffer()
    synchronized {
      if (buffers.contains(path))
        throw new IllegalArgumentException("Path " + path + " already exists")
      buffers(path) = buffer
    }
    var chunk = new Array[Byte](CHUNK_SIZE)
    var pos = 0
    while (true) {
      var numRead = in.read(chunk, pos, chunk.size - pos)
      if (numRead == -1) {
        // EOF: store the final partial chunk and mark the file finished.
        buffer.addChunk(chunk.subArray(0, pos), true)
        return
      } else {
        pos += numRead
        if (pos == chunk.size) {
          buffer.addChunk(chunk, false)
          chunk = new Array[Byte](CHUNK_SIZE)
          pos = 0
        }
      }
    }
    // TODO: launch a thread to write the data to disk, and when this finishes,
    // remove the hard reference to buffer
  }

  /** Handle a READ request: stream path's chunks to the client, blocking as
   *  needed while a concurrent writer is still producing them. */
  def performRead(path: String, out: DataOutputStream) {
    var buffer: Buffer = null
    synchronized {
      if (!buffers.contains(path))
        throw new IllegalArgumentException("Path " + path + " doesn't exist")
      buffer = buffers(path)
    }
    for (chunk <- buffer.iterator) {
      out.write(chunk, 0, chunk.size)
    }
  }

  /** Growable sequence of byte-array chunks with blocking iteration: one
   *  writer appends while readers stream the data as it arrives. */
  class Buffer {
    val chunks = new ArrayBuffer[Array[Byte]]
    var finished = false
    val mutex = new ReentrantLock
    val chunksAvailable = mutex.newCondition()

    /** Append a chunk; finish=true marks the buffer as complete. */
    def addChunk(chunk: Array[Byte], finish: Boolean) {
      mutex.lock()
      try {
        chunks += chunk
        finished = finish
        chunksAvailable.signalAll()
      } finally {
        // BUG FIX: release in finally so an exception cannot leak the lock.
        mutex.unlock()
      }
    }

    /** Iterator over chunks; hasNext blocks until a chunk is available or
     *  the buffer is finished. */
    def iterator = new Iterator[Array[Byte]] {
      var index = 0
      def hasNext: Boolean = {
        mutex.lock()
        try {
          while (index >= chunks.size && !finished)
            chunksAvailable.await()
          index < chunks.size
        } finally {
          mutex.unlock()
        }
      }
      def next: Array[Byte] = {
        mutex.lock()
        try {
          // ReentrantLock is reentrant, so calling hasNext here is safe.
          if (!hasNext)
            throw new NoSuchElementException("End of file")
          val ret = chunks(index)
          index += 1
          ret
        } finally {
          // BUG FIX: previously the lock was never released when
          // NoSuchElementException was thrown, deadlocking other readers.
          mutex.unlock()
        }
      }
    }
  }
}
/** Command-line entry point: "SlaveMain <port> <masterHost:masterPort>". */
object SlaveMain {
  def main(args: Array[String]) {
    val listenPort = args(0).toInt
    val masterAddress = args(1)
    new Slave(listenPort, masterAddress).start()
  }
}
| @@ -0,0 +1,11 @@ | ||
| package ubiquifs | ||
| import java.io.{InputStream, OutputStream} | ||
// Client-facing facade for the filesystem. Both operations are currently
// unimplemented stubs that return null.
class UbiquiFS(master: String) {
  // Parsed master address; parseHostPort throws on a malformed string.
  private val (masterHost, masterPort) = Utils.parseHostPort(master)

  // TODO: stub — should return a stream that writes the file to a slave.
  def create(path: String): OutputStream = null

  // TODO: stub — should return a stream that reads the file from a slave.
  def open(path: String): InputStream = null
}
| @@ -0,0 +1,12 @@ | ||
| package ubiquifs | ||
// Small internal helpers shared by UbiquiFS components.
private[ubiquifs] object Utils {
  // Whole-string pattern: alphanumeric/dot/dash host, colon, decimal port.
  private val HOST_PORT_RE = "([a-zA-Z0-9.-]+):([0-9]+)".r

  /** Split a "host:port" string into its host and numeric port.
   *  Throws IllegalArgumentException if the string does not match. */
  def parseHostPort(string: String): (String, Int) = string match {
    case HOST_PORT_RE(hostPart, portPart) => (hostPart, portPart.toInt)
    case badAddress => throw new IllegalArgumentException(badAddress)
  }
}