-
Notifications
You must be signed in to change notification settings - Fork 116
/
MnistRandomFFT.scala
115 lines (95 loc) · 4.24 KB
/
MnistRandomFFT.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package keystoneml.pipelines.images.mnist
import breeze.linalg.DenseVector
import breeze.stats.distributions.{RandBasis, ThreadLocalRandomGenerator}
import keystoneml.evaluation.MulticlassClassifierEvaluator
import keystoneml.loaders.{CsvDataLoader, LabeledData}
import keystoneml.nodes.learning.BlockLeastSquaresEstimator
import keystoneml.nodes.stats.{LinearRectifier, PaddedFFT, RandomSignNode}
import keystoneml.nodes.util._
import org.apache.commons.math3.random.MersenneTwister
import org.apache.spark.{SparkConf, SparkContext}
import keystoneml.pipelines._
import scopt.OptionParser
import keystoneml.utils.Image
import keystoneml.workflow.Pipeline
object MnistRandomFFT extends Serializable with Logging {
  val appName = "MnistRandomFFT"

  /**
   * Trains and evaluates a random-FFT featurization pipeline on MNIST.
   *
   * Featurization: each of `conf.numFFTs` branches applies a random sign flip,
   * a zero-padded FFT, and a ReLU; the branch outputs are concatenated and fed
   * to a block-coordinate least-squares classifier.
   *
   * @param sc   the SparkContext to run on
   * @param conf parsed job configuration (data locations, featurization and solver sizes)
   * @return the fitted end-to-end pipeline mapping a raw pixel vector to a predicted digit
   */
  def run(sc: SparkContext, conf: MnistRandomFFTConfig): Pipeline[DenseVector[Double], Int] = {
    // This is a property of the MNIST Dataset (digits 0 - 9)
    val numClasses = 10

    // Seeded RNG so the random sign flips are reproducible across runs.
    val randomSignSource = new RandBasis(new ThreadLocalRandomGenerator(new MersenneTwister(conf.seed)))

    // The number of pixels in an MNIST image (28 x 28 = 784)
    // Because the mnistImageSize is 784, we get 512 PaddedFFT features per FFT.
    val mnistImageSize = 784

    val startTime = System.nanoTime()

    val train = LabeledData(
      CsvDataLoader(sc, conf.trainLocation, conf.numPartitions)
        // The pipeline expects 0-indexed class labels, but the labels in the file are 1-indexed
        .map(x => (x(0).toInt - 1, x(1 until x.length)))
        .cache())

    // One-hot encode the integer labels for the least-squares solver.
    val labels = ClassLabelIndicatorsFromIntLabels(numClasses).apply(train.labels)

    // Build numFFTs independent random-FFT branches and concatenate their outputs.
    val featurizer = Pipeline.gather {
      Seq.fill(conf.numFFTs) {
        RandomSignNode(mnistImageSize, randomSignSource) andThen PaddedFFT() andThen LinearRectifier(0.0)
      }
    } andThen VectorCombiner()

    // No regularization by default; use 0.0 explicitly rather than relying on
    // Int -> Double weak conformance (deprecated in newer Scala versions).
    val pipeline = featurizer andThen
      (new BlockLeastSquaresEstimator(conf.blockSize, 1, conf.lambda.getOrElse(0.0)),
        train.data, labels) andThen
      MaxClassifier

    val test = LabeledData(
      CsvDataLoader(sc, conf.testLocation, conf.numPartitions)
        // The pipeline expects 0-indexed class labels, but the labels in the file are 1-indexed
        .map(x => (x(0).toInt - 1, x(1 until x.length)))
        .cache())

    // Calculate train error
    val evaluator = new MulticlassClassifierEvaluator(numClasses)
    val trainEval = evaluator.evaluate(pipeline(train.data), train.labels)
    logInfo("TRAIN Error is " + (100 * trainEval.totalError) + "%")

    // Calculate test error
    val testEval = evaluator.evaluate(pipeline(test.data), test.labels)
    logInfo("TEST Error is " + (100 * testEval.totalError) + "%")

    val endTime = System.nanoTime()
    logInfo(s"Pipeline took ${(endTime - startTime)/1e9} s")

    pipeline
  }

  /**
   * Job configuration.
   *
   * @param trainLocation path to the training CSV (required)
   * @param testLocation  path to the test CSV (required)
   * @param numFFTs       number of independent random-FFT feature branches
   * @param blockSize     solver block size; must be a multiple of 512 (the per-FFT feature count)
   * @param numPartitions number of Spark partitions to load the data into
   * @param lambda        optional L2 regularization strength (None means 0.0)
   * @param seed          RNG seed for the random sign flips
   */
  case class MnistRandomFFTConfig(
      trainLocation: String = "",
      testLocation: String = "",
      numFFTs: Int = 200,
      blockSize: Int = 2048,
      numPartitions: Int = 10,
      lambda: Option[Double] = None,
      seed: Long = 0)

  /**
   * Parses command-line arguments into a [[MnistRandomFFTConfig]].
   * Exits the process (via scopt) if required options are missing or invalid.
   */
  def parse(args: Array[String]): MnistRandomFFTConfig = new OptionParser[MnistRandomFFTConfig](appName) {
    head(appName, "0.1")
    help("help") text("prints this usage text")
    opt[String]("trainLocation") required() action { (x,c) => c.copy(trainLocation=x) }
    opt[String]("testLocation") required() action { (x,c) => c.copy(testLocation=x) }
    opt[Int]("numFFTs") action { (x,c) => c.copy(numFFTs=x) }
    opt[Int]("blockSize") validate { x =>
      // Each FFT branch contributes 512 features, so the solver's block size
      // must be a multiple of 512 to align with feature-block boundaries.
      if (x % 512 == 0) {
        success
      } else {
        failure("Option --blockSize must be divisible by 512")
      }
    } action { (x,c) => c.copy(blockSize=x) }
    opt[Int]("numPartitions") action { (x,c) => c.copy(numPartitions=x) }
    opt[Double]("lambda") action { (x,c) => c.copy(lambda=Some(x)) }
    opt[Long]("seed") action { (x,c) => c.copy(seed=x) }
  }.parse(args, MnistRandomFFTConfig()).get

  /**
   * The actual driver receives its configuration parameters from spark-submit usually.
   *
   * @param args command-line arguments forwarded by spark-submit
   */
  def main(args: Array[String]): Unit = {
    val appConfig = parse(args)
    val conf = new SparkConf().setAppName(appName)
    // Fall back to a local 2-core master when none was provided (e.g. running outside spark-submit).
    conf.setIfMissing("spark.master", "local[2]")
    val sc = new SparkContext(conf)
    run(sc, appConfig)
    sc.stop()
  }
}