Skip to content

Commit

Permalink
Merge pull request #12 from jordanekay/naive-bayes
Browse files Browse the repository at this point in the history
Convert naive Bayes classifier to Swift
  • Loading branch information
ayanonagon committed Feb 20, 2015
2 parents 890b903 + f410dfe commit 6dfc853
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 293 deletions.
7 changes: 4 additions & 3 deletions Parsimmon/Example/ClassifierViewController.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

#import "ClassifierViewController.h"
#import "Parsimmon.h"
#import "Parsimmon-Swift.h"

@interface ClassifierViewController ()
@property (weak, nonatomic) IBOutlet UITextField *messageTextField;
@property (weak, nonatomic) IBOutlet UILabel *resultLabel;
@property (strong, nonatomic) ParsimmonNaiveBayesClassifier *classifier;
@property (strong, nonatomic) NaiveBayesClassifier *classifier;
@end

@implementation ClassifierViewController
Expand Down Expand Up @@ -64,10 +65,10 @@ - (IBAction)spamOrHamAction:(id)sender

#pragma mark - Properties

- (ParsimmonNaiveBayesClassifier *)classifier
- (NaiveBayesClassifier *)classifier
{
if (!_classifier) {
_classifier = [[ParsimmonNaiveBayesClassifier alloc] init];
_classifier = [[NaiveBayesClassifier alloc] init];
}
return _classifier;
}
Expand Down
18 changes: 12 additions & 6 deletions Parsimmon/Parsimmon.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

/* Begin PBXBuildFile section */
45166BDE1A94265800D0E013 /* ParsimmonTokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = B6139F6919442FB700FC6CAA /* ParsimmonTokenizer.swift */; };
459B01491A9534B0000859A1 /* NaiveBayesClassifier.swift in Sources */ = {isa = PBXBuildFile; fileRef = 459B01481A9534B0000859A1 /* NaiveBayesClassifier.swift */; };
459B014B1A954B98000859A1 /* NaiveBayesClassifierTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 459B014A1A954B98000859A1 /* NaiveBayesClassifierTests.swift */; };
459B014D1A955E3D000859A1 /* Functions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 459B014C1A955E3D000859A1 /* Functions.swift */; };
B6139F6A19442FB700FC6CAA /* ParsimmonTokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = B6139F6919442FB700FC6CAA /* ParsimmonTokenizer.swift */; };
B6139F701944D59F00FC6CAA /* ParsimmonTokenizerTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B6139F6E194433BA00FC6CAA /* ParsimmonTokenizerTests.swift */; };
B63E18C418E618160006BD3E /* InfoPlist.strings in Resources */ = {isa = PBXBuildFile; fileRef = B63E18BC18E618160006BD3E /* InfoPlist.strings */; };
Expand All @@ -32,7 +35,6 @@
B67005C4180A0A1D00CFF860 /* ParsimmonSeed.m in Sources */ = {isa = PBXBuildFile; fileRef = B67005C3180A0A1D00CFF860 /* ParsimmonSeed.m */; };
B6A43FD318837077000F61BA /* ParsimmonDecisionTree.m in Sources */ = {isa = PBXBuildFile; fileRef = B6A43FD218837077000F61BA /* ParsimmonDecisionTree.m */; };
B6A43FD618837CF6000F61BA /* ParsimmonNode.m in Sources */ = {isa = PBXBuildFile; fileRef = B6A43FD518837CF6000F61BA /* ParsimmonNode.m */; };
B6B05E08180A85B500D7F34F /* ParsimmonNaiveBayesClassifier.m in Sources */ = {isa = PBXBuildFile; fileRef = B6B05E07180A85B500D7F34F /* ParsimmonNaiveBayesClassifier.m */; };
B6B05E36180B633F00D7F34F /* ClassifierViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = B6B05E35180B633F00D7F34F /* ClassifierViewController.m */; };
/* End PBXBuildFile section */

Expand All @@ -47,6 +49,9 @@
/* End PBXContainerItemProxy section */

/* Begin PBXFileReference section */
459B01481A9534B0000859A1 /* NaiveBayesClassifier.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = NaiveBayesClassifier.swift; sourceTree = "<group>"; };
459B014A1A954B98000859A1 /* NaiveBayesClassifierTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = NaiveBayesClassifierTests.swift; sourceTree = "<group>"; };
459B014C1A955E3D000859A1 /* Functions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Functions.swift; sourceTree = "<group>"; };
B6139F6919442FB700FC6CAA /* ParsimmonTokenizer.swift */ = {isa = PBXFileReference; fileEncoding = 4; indentWidth = 2; lastKnownFileType = sourcecode.swift; path = ParsimmonTokenizer.swift; sourceTree = "<group>"; tabWidth = 2; usesTabs = 0; };
B6139F6B1944318F00FC6CAA /* ParsimmonTests-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "ParsimmonTests-Bridging-Header.h"; sourceTree = "<group>"; };
B6139F6E194433BA00FC6CAA /* ParsimmonTokenizerTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; indentWidth = 2; lastKnownFileType = sourcecode.swift; path = ParsimmonTokenizerTests.swift; sourceTree = "<group>"; tabWidth = 2; usesTabs = 0; };
Expand Down Expand Up @@ -85,8 +90,6 @@
B6A43FD218837077000F61BA /* ParsimmonDecisionTree.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = ParsimmonDecisionTree.m; sourceTree = "<group>"; };
B6A43FD418837CF6000F61BA /* ParsimmonNode.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ParsimmonNode.h; sourceTree = "<group>"; };
B6A43FD518837CF6000F61BA /* ParsimmonNode.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = ParsimmonNode.m; sourceTree = "<group>"; };
B6B05E06180A85B500D7F34F /* ParsimmonNaiveBayesClassifier.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ParsimmonNaiveBayesClassifier.h; sourceTree = "<group>"; };
B6B05E07180A85B500D7F34F /* ParsimmonNaiveBayesClassifier.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = ParsimmonNaiveBayesClassifier.m; sourceTree = "<group>"; };
B6B05E34180B633F00D7F34F /* ClassifierViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ClassifierViewController.h; sourceTree = "<group>"; };
B6B05E35180B633F00D7F34F /* ClassifierViewController.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = ClassifierViewController.m; sourceTree = "<group>"; };
/* End PBXFileReference section */
Expand Down Expand Up @@ -118,6 +121,7 @@
isa = PBXGroup;
children = (
B63E18BC18E618160006BD3E /* InfoPlist.strings */,
459B014A1A954B98000859A1 /* NaiveBayesClassifierTests.swift */,
B63E18BE18E618160006BD3E /* ParsimmonTests-Info.plist */,
B63E18C018E618160006BD3E /* ParsimmonDecisionTreeTests.m */,
B63E18C118E618160006BD3E /* ParsimmonLemmatizerTests.m */,
Expand Down Expand Up @@ -196,6 +200,7 @@
B67005BA1809CD5600CFF860 /* ParsimmonTagger.m */,
B67005BF180A05FC00CFF860 /* ParsimmonLemmatizer.h */,
B67005C0180A05FC00CFF860 /* ParsimmonLemmatizer.m */,
459B014C1A955E3D000859A1 /* Functions.swift */,
B6B05E05180A858000D7F34F /* Classifiers */,
B67005C6180A19BD00CFF860 /* Seedlings */,
);
Expand All @@ -217,8 +222,7 @@
B6B05E05180A858000D7F34F /* Classifiers */ = {
isa = PBXGroup;
children = (
B6B05E06180A85B500D7F34F /* ParsimmonNaiveBayesClassifier.h */,
B6B05E07180A85B500D7F34F /* ParsimmonNaiveBayesClassifier.m */,
459B01481A9534B0000859A1 /* NaiveBayesClassifier.swift */,
B6A43FD118837077000F61BA /* ParsimmonDecisionTree.h */,
B6A43FD218837077000F61BA /* ParsimmonDecisionTree.m */,
B6A43FD418837CF6000F61BA /* ParsimmonNode.h */,
Expand Down Expand Up @@ -336,13 +340,14 @@
B67005911807D79500CFF860 /* TaggerViewController.m in Sources */,
B6A43FD318837077000F61BA /* ParsimmonDecisionTree.m in Sources */,
B670058B1807D79500CFF860 /* AppDelegate.m in Sources */,
459B014D1A955E3D000859A1 /* Functions.swift in Sources */,
B67005871807D79500CFF860 /* main.m in Sources */,
B6139F6A19442FB700FC6CAA /* ParsimmonTokenizer.swift in Sources */,
B67005C1180A05FC00CFF860 /* ParsimmonLemmatizer.m in Sources */,
B67005BE1809CE2F00CFF860 /* ParsimmonTaggedToken.m in Sources */,
B6B05E36180B633F00D7F34F /* ClassifierViewController.m in Sources */,
B6B05E08180A85B500D7F34F /* ParsimmonNaiveBayesClassifier.m in Sources */,
B67005BB1809CD5600CFF860 /* ParsimmonTagger.m in Sources */,
459B01491A9534B0000859A1 /* NaiveBayesClassifier.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand All @@ -351,6 +356,7 @@
buildActionMask = 2147483647;
files = (
B6139F701944D59F00FC6CAA /* ParsimmonTokenizerTests.swift in Sources */,
459B014B1A954B98000859A1 /* NaiveBayesClassifierTests.swift in Sources */,
B63E18CC18E6196D0006BD3E /* ParsimmonLemmatizerTests.m in Sources */,
B63E18CD18E619710006BD3E /* ParsimmonTaggerTests.m in Sources */,
45166BDE1A94265800D0E013 /* ParsimmonTokenizer.swift in Sources */,
Expand Down
16 changes: 16 additions & 0 deletions Parsimmon/Parsimmon/Functions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//
// Functions.swift
// Parsimmon
//
// Created by Jordan Kay on 2/18/15.
//
//

import Foundation

func argmax<T, U: Comparable>(elements: [(T, U)]) -> T? {
if let start = elements.first {
return elements.reduce(start) { $0.1 > $1.1 ? $0 : $1 }.0
}
return nil
}
158 changes: 158 additions & 0 deletions Parsimmon/Parsimmon/NaiveBayesClassifier.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
//
// NaiveBayesClassifier.swift
// Parsimmon
//
// Created by Jordan Kay on 2/18/15.
//
//

/**
## Sample usage
let NaiveBayesClassifier classifier = NaiveBayesClassifier()
Train the classifier with some ham examples.
classifier.trainWithText("nom nom ham", category: "ham")
classifier.trainWithText("make sure to get the ham", category: "ham")
classifier.trainWithText("please put the eggs in the fridge", category: "ham")
Train the classifier with some spam examples.
classifier.trainWithText("spammy spam spam", category: "spam")
classifier.trainWithText("what does the fox say?", category: "spam")
classifier.trainWithText("and fish go blub", category: "spam")
Classify some new text. Is it ham or spam? In practice, you'd want to train with more examples first.
let firstExample = "use the eggs in the fridge."
let secondExample = "what does the fish say?"
println("'\(firstExample)' => \(classifier.classify(firstExample)")
println("'\(secondExample)' => \(classifier.classify(secondExample)")
Output:
'use the eggs in the fridge.' => ham
'what does the fish say?' => spam
*/

import Foundation

private let smoothingParameter = 1.0

public class NaiveBayesClassifier: NSObject {
public typealias Word = String
public typealias Category = String

private let tokenizer: ParsimmonTokenizer

private var categoryOccurrences: [Category: Int] = [:]
private var wordOccurrences: [Word: [Category: Int]] = [:]
private var trainingCount = 0
private var wordCount = 0

public init(tokenizer: ParsimmonTokenizer) {
self.tokenizer = tokenizer
}

public convenience override init() {
self.init(tokenizer: ParsimmonTokenizer())
}

// MARK: - Training

/**
Trains the classifier with text and its category.
@param text The text
@param category The category of the text
*/
public func trainWithText(text: String, category: Category) {
let tokens = tokenizer.tokenize(text)
trainWithTokens(tokens, category: category)
}

/**
Trains the classifier with tokenized text and its category.
This is useful if you wish to use your own tokenization method.
@param tokens The tokenized text
@param category The category of the text
*/
public func trainWithTokens(tokens: [Word], category: Category) {
let words = Set(tokens)
for word in words {
incrementWord(word, category: category)
}
incrementCategory(category)
trainingCount++
}

// MARK: - Classifying

/**
Classifies the given text based on its training data.
@param text The text to classify
@return The category classification
*/
public func classify(text: String) -> Category? {
let tokens = tokenizer.tokenize(text)
return classifyTokens(tokens)
}

/**
Classifies the given tokenized text based on its training data.
@param text The tokenized text to classify
@return The category classification if one was found, or nil if one wasn’t
*/
public func classifyTokens(tokens: [Word]) -> Category? {
// Compute argmax_cat [log(P(C=cat)) + sum_token(log(P(W=token|C=cat)))]
return argmax(map(categoryOccurrences) { (category, count) -> (Category, Double) in
let pCategory = P(category)
let score = tokens.reduce(log(pCategory)) { [wordCount] (total, token) in
total + log((self.P(category, token) + smoothingParameter) / (pCategory + smoothingParameter + Double(wordCount)))
}
return (category, score)
})
}

// MARK: - Probabilites

private func P(category: Category, _ word: Word) -> Double {
if let occurrences = wordOccurrences[word] {
let count = occurrences[category] ?? 0
return Double(count) / Double(trainingCount)
}
return 0.0
}

private func P(category: Category) -> Double {
return Double(totalOccurrencesOfCategory(category)) / Double(trainingCount)
}

// MARK: - Counting

private func incrementWord(word: Word, category: Category) {
if wordOccurrences[word] == nil {
wordCount++
wordOccurrences[word] = [:]
}

let count = wordOccurrences[word]?[category] ?? 0
wordOccurrences[word]?[category] = count + 1
}

private func incrementCategory(category: Category) {
categoryOccurrences[category] = totalOccurrencesOfCategory(category) + 1
}

private func totalOccurrencesOfWord(word: Word) -> Int {
if let occurrences = wordOccurrences[word] {
return Array(occurrences.values).reduce(0, combine: +)
}
return 0
}

private func totalOccurrencesOfCategory(category: Category) -> Int {
return categoryOccurrences[category] ?? 0
}
}
1 change: 0 additions & 1 deletion Parsimmon/Parsimmon/Parsimmon.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,3 @@

#import "ParsimmonTagger.h"
#import "ParsimmonLemmatizer.h"
#import "ParsimmonNaiveBayesClassifier.h"
63 changes: 0 additions & 63 deletions Parsimmon/Parsimmon/ParsimmonNaiveBayesClassifier.h

This file was deleted.

Loading

0 comments on commit 6dfc853

Please sign in to comment.