Skip to content

Commit

Permalink
reformatted naive bayes to more idiomatic ruby
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry committed Jan 23, 2014
1 parent 9d4278d commit d45bdc8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 39 deletions.
43 changes: 22 additions & 21 deletions lib/ai4r/classifiers/naive_bayes.rb
Expand Up @@ -57,8 +57,8 @@ module Classifiers

class NaiveBayes < Classifier

parameters_info :m => "Default value is set to 0. It may be set to a value greater than " +
"0 when the size of the dataset is relatively small"
parameters_info :m => 'Default value is set to 0. It may be set to a value greater than ' +
'0 when the size of the dataset is relatively small'

def initialize
@m = 0
Expand All @@ -75,7 +75,7 @@ def initialize
# b.eval(["Red", "SUV", "Domestic"])
# => 'No'
def eval(data)
prob = @class_prob.map {|cp| cp}
prob = @class_prob.dup
prob = calculate_class_probabilities_for_entry(data, prob)
index_to_klass(prob.index(prob.max))
end
Expand All @@ -90,27 +90,28 @@ def eval(data)
# b.get_probability_map(["Red", "SUV", "Domestic"])
# => {"Yes"=>0.4166666666666667, "No"=>0.5833333333333334}
def get_probability_map(data)
prob = @class_prob.map {|cp| cp}
prob = @class_prob.dup
prob = calculate_class_probabilities_for_entry(data, prob)
prob = normalize_class_probability prob
probability_map = {}
prob.each_with_index { |p, i| probability_map[index_to_klass(i)] = p }
return probability_map

probability_map
end

# counts values of the attribute instances and calculates the probability of the classes
# and the conditional probabilities
# Parameter data has to be an instance of CsvDataSet
def build(data)
raise "Error instance must be passed" unless data.is_a?(Ai4r::Data::DataSet)
raise "Data should not be empty" if data.data_items.length == 0
raise 'Error instance must be passed' unless data.is_a?(Ai4r::Data::DataSet)
raise 'Data should not be empty' if data.data_items.length == 0

initialize_domain_data(data)
initialize_klass_index
initialize_pc
calculate_probabilities

return self
self
end

private
Expand All @@ -128,7 +129,7 @@ def initialize_domain_data(data)
# probability of every attribute in condition to a specific class
# this is repeated for every class
def calculate_class_probabilities_for_entry(data, prob)
prob.each_with_index do |prob_entry, prob_index|
0.upto(prob.length - 1) do |prob_index|
data.each_with_index do |att, index|
next if value_index(att, index).nil?
prob[prob_index] *= @pcp[index][value_index(att, index)][prob_index]
Expand All @@ -140,13 +141,13 @@ def calculate_class_probabilities_for_entry(data, prob)
def normalize_class_probability(prob)
prob_sum = sum(prob)
prob_sum > 0 ?
prob.map {|prob_entry| prob_entry / prob_sum } :
prob.map { |prob_entry| prob_entry / prob_sum } :
prob
end

# sums an array up; returns a number of type Float
def sum(array)
array.inject(0.0){|b, i| b+i}
array.inject(0.0) { |b, i| b + i }
end

# returns the name of the class when the index is found
Expand All @@ -160,7 +161,7 @@ def initialize_klass_index
@klass_index[dl] = index
end

@data_labels.each_with_index do |dl, index|
0.upto(@data_labels.length - 1) do |index|
@values[index] = {}
@domains[index].each_with_index do |d, d_index|
@values[index][d] = d_index
Expand All @@ -180,27 +181,27 @@ def value_index(value, dl_index)

# builds an array of the form:
# array[attributes][values][classes]
def build_array(dl, index)
def build_array(index)
domains = Array.new(@domains[index].length)
domains.map do |p1|
pl = Array.new @klasses.length, 0
domains.map do
Array.new @klasses.length, 0
end
end

# initializes the two array for storing the count and conditional probabilities of
# the attributes
def initialize_pc
@data_labels.each_with_index do |dl, index|
@pcc << build_array(dl, index)
@pcp << build_array(dl, index)
0.upto(@data_labels.length - 1) do |index|
@pcc << build_array(index)
@pcp << build_array(index)
end
end

# calculates the occurrences of a class and the instances of a certain value of a
# certain attribute and the assigned class.
# In addition to that, it also calculates the conditional probabilities and values
def calculate_probabilities
@klasses.each {|dl| @class_counts[klass_index(dl)] = 0}
@klasses.each { |dl| @class_counts[klass_index(dl)] = 0 }

calculate_class_probabilities
count_instances
Expand All @@ -220,7 +221,7 @@ def calculate_class_probabilities
# counts the instances of a certain value of a certain attribute and the assigned class
def count_instances
@data_items.each do |item|
@data_labels.each_with_index do |dl, dl_index|
0.upto(@data_labels.length - 1) do |dl_index|
@pcc[dl_index][value_index(item[dl_index], dl_index)][klass_index(item.klass)] += 1
end
end
Expand All @@ -231,7 +232,7 @@ def calculate_conditional_probabilities
@pcc.each_with_index do |attributes, a_index|
attributes.each_with_index do |values, v_index|
values.each_with_index do |klass, k_index|
@pcp[a_index][v_index][k_index] = (klass.to_f + @m * @class_prob[k_index]) / (@class_counts[k_index] + @m).to_f
@pcp[a_index][v_index][k_index] = (klass.to_f + @m * @class_prob[k_index]) / (@class_counts[k_index] + @m)
end
end
end
Expand Down
36 changes: 18 additions & 18 deletions test/classifiers/naive_bayes_test.rb
Expand Up @@ -7,37 +7,37 @@

class NaiveBayesTest < Test::Unit::TestCase

@@data_labels = [ "Color","Type","Origin","Stolen?" ]
@@data_labels = %w(Color Type Origin Stolen?)

@@data_items = [
["Red", "Sports", "Domestic", "Yes"],
["Red", "Sports", "Domestic", "No"],
["Red", "Sports", "Domestic", "Yes"],
["Yellow","Sports", "Domestic", "No"],
["Yellow","Sports", "Imported", "Yes"],
["Yellow","SUV", "Imported", "No"],
["Yellow","SUV", "Imported", "Yes"],
["Yellow","Sports", "Domestic", "No"],
["Red", "SUV", "Imported", "No"],
["Red", "Sports", "Imported", "Yes"]
]
%w(Red Sports Domestic Yes),
%w(Red Sports Domestic No),
%w(Red Sports Domestic Yes),
%w(Yellow Sports Domestic No),
%w(Yellow Sports Imported Yes),
%w(Yellow SUV Imported No),
%w(Yellow SUV Imported Yes),
%w(Yellow Sports Domestic No),
%w(Red SUV Imported No),
%w(Red Sports Imported Yes)
]

def setup
@data_set = DataSet.new
@data_set = DataSet.new(:data_items => @@data_items, :data_labels => @@data_labels)
@b = NaiveBayes.new.set_parameters({:m=>3}).build @data_set
@b = NaiveBayes.new.set_parameters({:m => 3}).build @data_set
end

def test_eval
result = @b.eval(["Red", "SUV", "Domestic"])
assert_equal "No", result
result = @b.eval(%w(Red SUV Domestic))
assert_equal 'No', result
end

def test_get_probability_map
map = @b.get_probability_map(["Red", "SUV", "Domestic"])
map = @b.get_probability_map(%w(Red SUV Domestic))
assert_equal 2, map.keys.length
assert_in_delta 0.42, map["Yes"], 0.1
assert_in_delta 0.58, map["No"], 0.1
assert_in_delta 0.42, map['Yes'], 0.1
assert_in_delta 0.58, map['No'], 0.1
end

end

0 comments on commit d45bdc8

Please sign in to comment.