// Fetching contributors…
// Cannot retrieve contributors at this time
// 205 lines (171 sloc) 11 KB
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
 * Binary Naive Bayes classifier (United States, United Kingdom) example for an out-of-sample document
 * based on a model trained on the wikipedia xml dump.
 * NOTE: As of version 0.10.0 Mahout uses MapReduce seq2sparse to vectorize large text corpora.
 *
 * To run this example first run:
 *
 *    $MAHOUT_HOME/examples/bin/classify-wikipedia.sh --> option 2
 *
 * then from the mahout spark-shell:
 *
 *    :load {MAHOUT_HOME}/examples/spark-document-classifier.mscala
 */
import org.apache.mahout.classifier.naivebayes._
import org.apache.mahout.classifier.stats._
import org.apache.mahout.nlp.tfidf._
// root directory under which the vectorized corpus, dictionary and df-count files are read
val pathToData = "/tmp/mahout-work-wiki/"

// read in our full set as vectorized by seq2sparse in classify-wikipedia.sh
val fullData = drmDfsRead(pathToData + "wikipediaVecs/tfidf-vectors")

// uncomment if you want to train and test on the split "fullData" set and adjust below as necessary
//val trainData = drmDfsRead(pathToData + "training")
//val testData = drmDfsRead(pathToData + "testing")

// build a standard NaiveBayes model using the full dataset (training + testing)
// NOTE(review): the trailing `false` presumably selects the standard (non-complementary) variant
val (labelIndex, aggregatedObservations) = SparkNaiveBayes.extractLabelsAndAggregateObservations(fullData)
val model = NaiveBayes.train(aggregatedObservations, labelIndex, false)

// self test on the full set
val resAnalyzer = NaiveBayes.test(model, fullData, false)

// display the confusion matrix
println(resAnalyzer)
// read in the dictionary and document frequency count
val dictionary = sdc.sequenceFile(pathToData + "wikipediaVecs/dictionary.file-0", classOf[Text], classOf[IntWritable])
val documentFrequencyCount = sdc.sequenceFile(pathToData + "wikipediaVecs/df-count", classOf[IntWritable], classOf[LongWritable])

// setup the dictionary and document frequency count as maps:
// first unwrap the Hadoop Writable keys/values into plain Scala values ...
val dictionaryRDD = dictionary.map { case (wKey, wVal) => wKey.asInstanceOf[Text].toString() -> wVal.get() }
val documentFrequencyCountRDD = documentFrequencyCount.map { case (wKey, wVal) => wKey.asInstanceOf[IntWritable].get() -> wVal.get() }

// ... then collect them to the driver as in-memory maps (term -> index, index -> document frequency)
val dictionaryMap = dictionaryRDD.collect.map(x => x._1.toString -> x._2.toInt).toMap
val dfCountMap = documentFrequencyCountRDD.collect.map(x => x._1.toInt -> x._2.toLong).toMap
/**
 * Vectorize a raw text document into a TF-IDF weighted sparse vector.
 *
 * For this simple example the document is tokenized into unigrams using native string methods
 * and vectorized using our dictionary and document frequencies. You could also use a lucene
 * analyzer for bigrams, trigrams, etc.
 *
 * @param document      raw text to vectorize
 * @param dictionaryMap term -> vector index; tokens that are not in the dictionary are ignored
 * @param dfMap         vector index -> document frequency; the entry under key -1 is read as the
 *                      total document count (a missing key raises NoSuchElementException)
 * @return a sparse vector of cardinality `dictionaryMap.size` holding one TF-IDF weight per known term
 */
def vectorizeDocument(document: String,
                      dictionaryMap: Map[String,Int],
                      dfMap: Map[Int,Long]): Vector = {
  // collapse every run of non letter/digit characters to a single space, lowercase, then count tokens
  val wordCounts = document.replaceAll("[^\\p{L}\\p{Nd}]+", " ").toLowerCase.split(" ").groupBy(identity).mapValues(_.length)
  val vec = new RandomAccessSparseVector(dictionaryMap.size)
  val totalDFSize = dfMap(-1)
  // document length as passed to TFIDF: the number of distinct tokens
  val docSize = wordCounts.size
  // one weight calculator is enough for the whole document
  val tfidf: TFIDF = new TFIDF()

  for ((term, termFreq) <- wordCounts) {
    dictionaryMap.get(term).foreach { dictIndex =>
      // use the dfMap argument (not the script-level dfCountMap) so the function honours its parameters
      val docFreq = dfMap(dictIndex)
      val currentTfIdf = tfidf.calculate(termFreq, docFreq.toInt, docSize, totalDFSize.toInt)
      vec.setQuick(dictIndex, currentTfIdf)
    }
  }
  vec
}
// label bookkeeping taken from the trained model
val labelMap = model.labelIndex
val numLabels = model.numLabels
// labelMap with keys and values swapped, so a classification index can be looked up directly
val reverseLabelMap = labelMap.map(x => x._2 -> x._1)
// instantiate the correct type of classifier for the trained model
val classifier =
  if (model.isComplementary) new ComplementaryNBClassifier(model)
  else new StandardNBClassifier(model)
/**
 * The label with the highest score wins the classification for a given document.
 *
 * @param v vector of per-label scores
 * @return (index, score) of the largest element; on ties the lowest index wins. For an empty
 *         vector (or one with no comparable score) the result is
 *         (Integer.MIN_VALUE, Double.NegativeInfinity).
 */
def argmax(v: Vector): (Int, Double) = {
  // start from negative infinity so that arbitrarily low scores can still be selected
  (0 until v.size).foldLeft((Integer.MIN_VALUE, Double.NegativeInfinity)) {
    case (best @ (_, bestScore), i) =>
      val score = v(i)
      if (score > bestScore) (i, score) else best
  }
}
/**
 * Our final classifier: score a vectorized document against every label and return the label
 * whose score is highest.
 *
 * @param clvec vectorized document to classify
 * @return the label looked up in reverseLabelMap for the best-scoring index
 */
def classifyDocument(clvec: Vector) : String = {
  val cvec = classifier.classifyFull(clvec)
  // only the winning index is needed; the score itself is not used
  val (bestIdx, _) = argmax(cvec)
  reverseLabelMap(bestIdx)
}
// A random United States football article, used as an out-of-sample document for the classifier.
// (A plain string expression is used; wrapping a literal in `new String(...)` is redundant.)
val UStextToClassify =
  "(Reuters) - Super Bowl security officials acknowledge the NFL championship game represents" +
  " a high profile target on a world stage but are unaware of any specific credible threats against" +
  " Sunday's showcase. In advance of one of the world's biggest single day sporting events, Homeland" +
  " Security Secretary Jeh Johnson was in Glendale on Wednesday to review security preparations and" +
  " tour University of Phoenix Stadium where the Seattle Seahawks and New England Patriots will battle." +
  " Deadly shootings in Paris and arrest of suspects in Belgium, Greece and Germany heightened fears of" +
  " more attacks around the world and social media accounts linked to Middle East militant groups have" +
  " carried a number of threats to attack high-profile U.S. events. There is no specific credible" +
  " threat, said Johnson, who has appointed a federal coordination team to work with local, state and" +
  " federal agencies to ensure safety of fans, players and other workers associated with the Super Bowl." +
  " I'm confident we will have a safe and secure and successful event. Sunday's game has been given a" +
  " Special Event Assessment Rating (SEAR) 1 rating, the same as in previous years, except for the year" +
  " after the Sept. 11, 2001 attacks, when a higher level was declared. But security will be tight and" +
  " visible around Super Bowl-related events as well as during the game itself. All fans will pass through" +
  " metal detectors and pat downs. Over 4,000 private security personnel will be deployed and the almost" +
  " 3,000 member Phoenix police force will be on Super Bowl duty. Nuclear device sniffing teams will be" +
  " deployed and a network of Bio-Watch detectors will be set up to provide a warning in the event of " +
  " a biological attack. The Department of Homeland Security (DHS) said in a press release it had held " +
  " special cyber-security and anti-sniper training sessions. A U.S. official said the Transportation " +
  " Security Administration, which is responsible for screening airline passengers, will add screeners " +
  " and checkpoint lanes at airports. Federal air marshals, behavior detection officers and dog teams " +
  " will help to secure transportation systems in the area. We will be ramping it (security) up on Sunday," +
  " there is no doubt about that, said Federal Coordinator Matthew Allen, the DHS point of contact for " +
  " planning and support. I have every confidence the public safety agencies that represented in the " +
  " planning process are going to have their best and brightest out there this weekend and we will have" +
  " a very safe Super Bowl."
// A random United Kingdom football article, used as an out-of-sample document for the classifier.
// (A plain string expression is used; wrapping a literal in `new String(...)` is redundant.)
val UKtextToClassify =
  "(Reuters) - Manchester United have signed a sponsorship deal with online financial trading company" +
  " Swissquote, expanding the commercial partnerships that have helped to make the English club one of" +
  " the richest teams in world soccer. United did not give a value for the deal, the club's first in the" +
  " sector, but said on Monday it was a multi-year agreement. The Premier League club, 20 times English" +
  " champions, claim to have 659 million followers around the globe, making the United name attractive to" +
  " major brands like Chevrolet cars and sportswear group Adidas. Swissquote said the global deal would" +
  " allow it to use United's popularity in Asia to help it meet its targets for expansion in China. Among" +
  " benefits from the deal, Swissquote's clients will have a chance to meet United players and get behind" +
  " the scenes at the Old Trafford stadium. Swissquote is a Geneva-based online trading company that allows" +
  " retail investors to buy and sell foreign exchange, equities, bonds and other asset classes. Like other" +
  " retail FX brokers, Swissquote was left nursing losses on the Swiss franc after Switzerland's central bank" +
  " stunned markets this month by abandoning its cap on the currency. The fallout from the abrupt move put rival" +
  " and West Ham United shirt sponsor Alpari UK into administration. Swissquote itself was forced to book a 25 " +
  " million Swiss francs ($28 million) provision for its clients who were left out of pocket following the" +
  " franc's surge. United's ability to grow revenues off the pitch has made them the second richest club in" +
  " the world behind Spain's Real Madrid, despite a downturn in their playing fortunes. United Managing" +
  " Director Richard Arnold said there was still lots of scope for United to develop sponsorships in" +
  " other areas of business. The last quoted statistics that we had showed that of the top 25 sponsorship" +
  " categories, we were only active in 15 of those, Arnold told Reuters. I think there is a huge potential" +
  " still for the club, and the other thing we have seen is there is very significant growth even within" +
  " categories. United have endured a tricky transition following the retirement of manager Alex Ferguson" +
  " in 2013, finishing seventh in the Premier League last season and missing out on a place in the lucrative" +
  " Champions League. ($1 = 0.8910 Swiss francs) (Writing by Neil Maidment, additional reporting by Jemima Kelly;" +
  " editing by Keith Weir)"
// vectorize both sample articles with the dictionary and document frequencies
val usVec = vectorizeDocument(UStextToClassify, dictionaryMap, dfCountMap)
val ukVec = vectorizeDocument(UKtextToClassify, dictionaryMap, dfCountMap)

// classify each article; each call evaluates to the predicted label
println("Classifying the news article about superbowl security (united states)")
classifyDocument(usVec)

println("Classifying the news article about Manchester United (united kingdom)")
classifyDocument(ukVec)
/**
 * To classify new text, tie everything together in a single method: vectorize the text with the
 * script-level dictionary and document-frequency maps, then classify the resulting vector.
 *
 * @param txt raw text to classify
 * @return the label returned by classifyDocument for the vectorized text
 */
def classifyText(txt: String): String = {
  val v = vectorizeDocument(txt, dictionaryMap, dfCountMap)
  classifyDocument(v)
}
// now we can simply call our classifyText method on any string;
// each call evaluates to the label (a String) chosen for the given text
classifyText("Hello world from Queens")
classifyText("Hello world from London")