-
Notifications
You must be signed in to change notification settings - Fork 116
/
NewsgroupsPipeline.scala
executable file
·78 lines (62 loc) · 2.66 KB
/
NewsgroupsPipeline.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package keystoneml.pipelines.text
import breeze.linalg.SparseVector
import keystoneml.evaluation.MulticlassClassifierEvaluator
import keystoneml.loaders.NewsgroupsDataLoader
import keystoneml.nodes.learning.NaiveBayesEstimator
import keystoneml.nodes.nlp._
import keystoneml.nodes.stats.TermFrequency
import keystoneml.nodes.util.{CommonSparseFeatures, MaxClassifier}
import org.apache.spark.{SparkConf, SparkContext}
import keystoneml.pipelines.Logging
import scopt.OptionParser
import keystoneml.workflow.Pipeline
object NewsgroupsPipeline extends Logging {
  val appName = "NewsgroupsPipeline"

  /**
   * Trains an n-gram Naive Bayes classifier on the 20 Newsgroups training set,
   * evaluates it on the test set (logging a per-class summary), and returns the
   * fitted pipeline.
   *
   * @param sc   the SparkContext to run on
   * @param conf data locations, n-gram order, and common-feature count
   * @return a fitted pipeline mapping a raw document string to a class index
   */
  def run(sc: SparkContext, conf: NewsgroupsConfig): Pipeline[String, Int] = {
    val trainData = NewsgroupsDataLoader(sc, conf.trainLocation)
    val numClasses = NewsgroupsDataLoader.classes.length

    // Build the classifier estimator: normalize and tokenize the text, extract
    // 1..nGrams n-grams, map every n-gram to weight 1 (binary term presence,
    // via TermFrequency(x => 1)), keep only the `commonFeatures` most frequent
    // features, fit Naive Bayes on the training data, and take the argmax class.
    logInfo("Training classifier")
    val predictor = Trim andThen
        LowerCase() andThen
        Tokenizer() andThen
        NGramsFeaturizer(1 to conf.nGrams) andThen
        TermFrequency(x => 1) andThen
        (CommonSparseFeatures[Seq[String]](conf.commonFeatures), trainData.data) andThen
        (NaiveBayesEstimator[SparseVector[Double]](numClasses), trainData.data, trainData.labels) andThen
        MaxClassifier

    // Evaluate the classifier on the held-out test set.
    logInfo("Evaluating classifier")
    val testData = NewsgroupsDataLoader(sc, conf.testLocation)
    val testLabels = testData.labels
    val testResults = predictor(testData.data)
    val eval = new MulticlassClassifierEvaluator(numClasses).evaluate(testResults, testLabels)

    logInfo("\n" + eval.summary(NewsgroupsDataLoader.classes))
    predictor
  }

  /**
   * Pipeline configuration.
   *
   * @param trainLocation  path to the training data (required)
   * @param testLocation   path to the test data (required)
   * @param nGrams         maximum n-gram order to featurize (1 to nGrams)
   * @param commonFeatures number of most-frequent sparse features to keep
   */
  case class NewsgroupsConfig(
    trainLocation: String = "",
    testLocation: String = "",
    nGrams: Int = 2,
    commonFeatures: Int = 100000)

  /**
   * Parses command-line arguments into a [[NewsgroupsConfig]].
   *
   * Exits the process (after scopt has printed the usage message) when the
   * arguments are invalid, rather than throwing an opaque
   * NoSuchElementException from `Option.get`.
   */
  def parse(args: Array[String]): NewsgroupsConfig = new OptionParser[NewsgroupsConfig](appName) {
    head(appName, "0.1")
    opt[String]("trainLocation") required() action { (x,c) => c.copy(trainLocation=x) }
    opt[String]("testLocation") required() action { (x,c) => c.copy(testLocation=x) }
    opt[Int]("nGrams") action { (x,c) => c.copy(nGrams=x) }
    opt[Int]("commonFeatures") action { (x,c) => c.copy(commonFeatures=x) }
  }.parse(args, NewsgroupsConfig()).getOrElse {
    // scopt has already reported the error and usage text on stderr.
    sys.exit(1)
  }

  /**
   * The actual driver receives its configuration parameters from spark-submit usually.
   *
   * @param args command-line arguments forwarded to [[parse]]
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName(appName)
    conf.setIfMissing("spark.master", "local[2]") // This is a fallback if things aren't set via spark submit.
    val sc = new SparkContext(conf)
    try {
      val appConfig = parse(args)
      run(sc, appConfig)
    } finally {
      // Always release the SparkContext, even if parsing or training fails.
      sc.stop()
    }
  }
}