Skip to content

Commit

Permalink
Added test for prediction
Browse files Browse the repository at this point in the history
- Test predictOnValues for accuracy on a test stream
  • Loading branch information
freeman-lab committed Aug 19, 2014
1 parent 217b5e9 commit 32c43c2
Showing 1 changed file with 65 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}

// Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
test("streaming linear regression parameter accuracy") {
test("parameter accuracy") {

val testDir = Files.createTempDir()
val numBatches = 10
Expand All @@ -76,7 +76,6 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {

ssc.stop(stopSparkContext=false)

System.clearProperty("spark.driver.port")
Utils.deleteRecursively(testDir)

// check accuracy of final parameter estimates
Expand All @@ -91,7 +90,7 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
}

// Test that parameter estimates improve when learning Y = 10*X1 on streaming data
test("streaming linear regression parameter convergence") {
test("parameter convergence") {

val testDir = Files.createTempDir()
val batchDuration = Milliseconds(2000)
Expand Down Expand Up @@ -121,7 +120,6 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {

ssc.stop(stopSparkContext=false)

System.clearProperty("spark.driver.port")
Utils.deleteRecursively(testDir)

val deltas = history.drop(1).zip(history.dropRight(1))
Expand All @@ -132,4 +130,67 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {

}

// Test predictions on a stream
test("predictions") {

val trainDir = Files.createTempDir()
val testDir = Files.createTempDir()
val batchDuration = Milliseconds(1000)
val numBatches = 10
val nPoints = 100

val ssc = new StreamingContext(sc, batchDuration)
val data = ssc.textFileStream(trainDir.toString).map(LabeledPoint.parse)
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0, 0.0))
.setStepSize(0.1)
.setNumIterations(50)

model.trainOn(data)

ssc.start()

// write training data to a file stream
for (i <- 0 until numBatches) {
val samples = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
val file = new File(trainDir, i.toString)
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
Thread.sleep(batchDuration.milliseconds)
}

ssc.stop(stopSparkContext=false)

Utils.deleteRecursively(trainDir)

print(model.latestModel().weights.toArray.mkString(" "))
print(model.latestModel().intercept)

val ssc2 = new StreamingContext(sc, batchDuration)
val data2 = ssc2.textFileStream(testDir.toString).map(LabeledPoint.parse)

val history = new ArrayBuffer[Double](numBatches)
val predictions = model.predictOnValues(data2.map(x => (x.label, x.features)))
val errors = predictions.map(x => math.abs(x._1 - x._2))
errors.foreachRDD(rdd => history.append(rdd.reduce(_+_) / nPoints.toDouble))

ssc2.start()

// write test data to a file stream

// make a function
for (i <- 0 until numBatches) {
val samples = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))
val file = new File(testDir, i.toString)
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
Thread.sleep(batchDuration.milliseconds)
}

println(history)

ssc2.stop(stopSparkContext=false)

}

}

0 comments on commit 32c43c2

Please sign in to comment.