// KeystoneML example pipeline: PASCAL VOC 2007 multi-label image classification
// using SIFT features, PCA reduction, Fisher Vector encoding, and a linear model.
package keystoneml.pipelines.images.voc
import java.io.File
import breeze.linalg._
import breeze.stats._
import keystoneml.evaluation.MeanAveragePrecisionEvaluator
import keystoneml.loaders.{VOCDataPath, VOCLabelPath, VOCLoader}
import keystoneml.nodes.images.external.{FisherVector, SIFTExtractor}
import keystoneml.nodes.images._
import keystoneml.nodes.learning._
import keystoneml.nodes.stats.{ColumnSampler, NormalizeRows, SignedHellingerMapper}
import keystoneml.nodes.util.{Cacher, ClassLabelIndicatorsFromIntArrayLabels, FloatToDouble, MatrixVectorizer}
import org.apache.spark.{SparkConf, SparkContext}
import keystoneml.pipelines.Logging
import scopt.OptionParser
import keystoneml.utils.Image
import keystoneml.workflow.Pipeline
object VOCSIFTFisher extends Serializable with Logging {
  val appName = "VOCSIFTFisher"

  /**
   * Trains and evaluates a VOC 2007 multi-label image classification pipeline:
   * grayscale conversion -> SIFT extraction -> PCA dimensionality reduction ->
   * Fisher Vector encoding -> signed-Hellinger normalization -> block
   * least-squares linear model. Logs per-class average precision and mean AP
   * on the test set.
   *
   * @param sc   the SparkContext to run on.
   * @param conf pipeline configuration (data locations and hyperparameters).
   * @return the fitted end-to-end pipeline mapping an Image to class scores.
   */
  def run(sc: SparkContext, conf: SIFTFisherConfig): Pipeline[Image, DenseVector[Double]] = {
    // Load the training data and extract training labels.
    val parsedRDD = VOCLoader(
      sc,
      VOCDataPath(conf.trainLocation, "VOCdevkit/VOC2007/JPEGImages/", Some(1)),
      VOCLabelPath(conf.labelPath)).repartition(conf.numParts).cache()

    val labelGrabber = MultiLabelExtractor andThen
      ClassLabelIndicatorsFromIntArrayLabels(VOCLoader.NUM_CLASSES) andThen
      new Cacher
    val trainingLabels = labelGrabber(parsedRDD)
    val trainingData = MultiLabeledImageExtractor(parsedRDD)

    val numTrainingImages = trainingData.count().toInt
    // Guard against an empty training set, which would otherwise surface as a
    // confusing ArithmeticException (division by zero) below.
    require(numTrainingImages > 0,
      "Training set is empty; check --trainLocation and --labelPath.")
    // Spread the global PCA/GMM sample budgets evenly across training images.
    val numPCASamplesPerImage = conf.numPcaSamples / numTrainingImages
    val numGMMSamplesPerImage = conf.numGmmSamples / numTrainingImages

    // Part 1: Scale and convert images to grayscale & extract SIFTs.
    val siftExtractor = PixelScaler andThen
      GrayScaler andThen
      new Cacher andThen
      new SIFTExtractor(scaleStep = conf.scaleStep)

    // Part 2: Dimensionality-reduced PCA features. Either load a precomputed
    // PCA matrix from disk, or estimate one from a sample of SIFT features.
    val pcaFeaturizer = (conf.pcaFile match {
      case Some(fname) =>
        siftExtractor andThen new BatchPCATransformer(convert(csvread(new File(fname)), Float).t)
      case None =>
        val sampler = ColumnSampler(numPCASamplesPerImage).toPipeline
        val pca = ColumnPCAEstimator(conf.descDim) withData (sampler(siftExtractor(trainingData)))
        siftExtractor andThen pca
    }) andThen new Cacher

    // Part 3: Fisher Vectors with signed-square-root normalization. Either load
    // a pretrained GMM (all three GMM files must be supplied together), or
    // estimate one from a sample of the PCA-reduced features.
    val fisherFeaturizer = (conf.gmmMeanFile match {
      case Some(meanFile) =>
        // Fail fast with a clear message if the companion GMM files are missing,
        // instead of an opaque NoSuchElementException from Option.get.
        def requiredFile(opt: Option[String], flag: String): String =
          opt.getOrElse(sys.error(s"--$flag must be provided when --gmmMeanFile is set"))
        val gmm = new GaussianMixtureModel(
          csvread(new File(meanFile)),
          csvread(new File(requiredFile(conf.gmmVarFile, "gmmVarFile"))),
          csvread(new File(requiredFile(conf.gmmWtsFile, "gmmWtsFile"))).toDenseVector)
        pcaFeaturizer andThen FisherVector(gmm)
      case None =>
        val sampler = ColumnSampler(numGMMSamplesPerImage).toPipeline
        val fisherVector = GMMFisherVectorEstimator(conf.vocabSize) withData (sampler(pcaFeaturizer(trainingData)))
        pcaFeaturizer andThen fisherVector
    }) andThen
      FloatToDouble andThen
      MatrixVectorizer andThen
      NormalizeRows andThen
      SignedHellingerMapper andThen
      NormalizeRows andThen
      new Cacher

    // Part 4: Fit a block-coordinate linear model to the Fisher Vector features.
    // Feature dimension is 2 * descDim * vocabSize (mean and variance gradients).
    val predictor = fisherFeaturizer andThen
      (new BlockLeastSquaresEstimator(4096, 1, conf.lambda, Some(2 * conf.descDim * conf.vocabSize)),
        trainingData,
        trainingLabels)

    // Now featurize and apply the model to test data.
    val testParsedRDD = VOCLoader(
      sc,
      VOCDataPath(conf.testLocation, "VOCdevkit/VOC2007/JPEGImages/", Some(1)),
      VOCLabelPath(conf.labelPath)).repartition(conf.numParts)
    val testData = MultiLabeledImageExtractor(testParsedRDD)

    logInfo("Test Cached RDD has: " + testData.count)
    val testActuals = MultiLabelExtractor(testParsedRDD)

    val predictions = predictor(testData)

    val map = new MeanAveragePrecisionEvaluator(VOCLoader.NUM_CLASSES).evaluate(predictions, testActuals)
    logInfo(s"TEST APs are: ${map.toArray.mkString(",")}")
    logInfo(s"TEST MAP is: ${mean(map)}")

    predictor
  }

  /**
   * Pipeline configuration.
   *
   * @param numPcaSamples total SIFT columns sampled (across all images) to fit PCA.
   * @param numGmmSamples total PCA-feature columns sampled (across all images) to fit the GMM.
   */
  case class SIFTFisherConfig(
    trainLocation: String = "",
    testLocation: String = "",
    labelPath: String = "",
    numParts: Int = 496,
    lambda: Double = 0.5,
    descDim: Int = 80,
    vocabSize: Int = 256,
    scaleStep: Int = 0,
    pcaFile: Option[String] = None,
    gmmMeanFile: Option[String] = None,
    gmmVarFile: Option[String] = None,
    gmmWtsFile: Option[String] = None,
    numPcaSamples: Int = 1e6.toInt,
    numGmmSamples: Int = 1e6.toInt)

  /**
   * Parses command-line arguments into a [[SIFTFisherConfig]].
   * Exits the JVM with status 1 on invalid arguments (scopt has already
   * printed the usage text), rather than throwing NoSuchElementException.
   */
  def parse(args: Array[String]): SIFTFisherConfig = new OptionParser[SIFTFisherConfig](appName) {
    head(appName, "0.1")
    help("help") text("prints this usage text")
    opt[String]("trainLocation") required() action { (x,c) => c.copy(trainLocation=x) }
    opt[String]("testLocation") required() action { (x,c) => c.copy(testLocation=x) }
    opt[String]("labelPath") required() action { (x,c) => c.copy(labelPath=x) }
    opt[Int]("numParts") action { (x,c) => c.copy(numParts=x) }
    opt[Double]("lambda") action { (x,c) => c.copy(lambda=x) }
    opt[Int]("descDim") action { (x,c) => c.copy(descDim=x) }
    opt[Int]("vocabSize") action { (x,c) => c.copy(vocabSize=x) }
    opt[Int]("scaleStep") action { (x,c) => c.copy(scaleStep=x) }
    opt[String]("pcaFile") action { (x,c) => c.copy(pcaFile=Some(x)) }
    opt[String]("gmmMeanFile") action { (x,c) => c.copy(gmmMeanFile=Some(x)) }
    opt[String]("gmmVarFile") action { (x,c) => c.copy(gmmVarFile=Some(x)) }
    opt[String]("gmmWtsFile") action { (x,c) => c.copy(gmmWtsFile=Some(x)) }
    opt[Int]("numPcaSamples") action { (x,c) => c.copy(numPcaSamples=x) }
    opt[Int]("numGmmSamples") action { (x,c) => c.copy(numGmmSamples=x) }
  }.parse(args, SIFTFisherConfig()).getOrElse(sys.exit(1))

  /**
   * The actual driver receives its configuration parameters from spark-submit usually.
   *
   * @param args command-line arguments; see [[parse]].
   */
  def main(args: Array[String]): Unit = {
    val appConfig = parse(args)

    val conf = new SparkConf().setAppName(appName)
    // Allow running outside spark-submit (e.g. from an IDE) with a local master.
    conf.setIfMissing("spark.master", "local[2]")
    val sc = new SparkContext(conf)
    run(sc, appConfig)

    sc.stop()
  }
}