Skip to content

Commit

Permalink
merged changes with master
Browse files Browse the repository at this point in the history
  • Loading branch information
bararchy committed Dec 26, 2017
2 parents 8bfc7c5 + dc8dc48 commit 3ab8072
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 138 deletions.
41 changes: 36 additions & 5 deletions spec/network_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@ describe SHAInet::Network do

xor = SHAInet::Network.new
xor.add_layer(:input, 2, :memory)
1.times { |x| xor.add_layer(:hidden, 2, :memory) }
1.times { |x| xor.add_layer(:hidden, 3, :memory) }
xor.add_layer(:output, 1, :memory)
xor.fully_connect

# data, cost_function, activation_function, epochs, error_threshold, learning_rate, momentum)
xor.train(training_data, :mse, :sigmoid, 10000, 0.000001)
# data, training_type, cost_function, activation_function, epochs, error_threshold (sum of errors), learning_rate, momentum)
xor.train(training_data, :sgdm, :mse, :sigmoid, 10000, 0.001)

(xor.run([0, 0]).first < 0.1).should eq(true)
(xor.run([1, 0]).first > 0.9).should eq(true)
(xor.run([0, 1]).first > 0.9).should eq(true)
(xor.run([1, 1]).first < 0.1).should eq(true)
end

it "works on iris dataset" do
Expand All @@ -36,7 +39,6 @@ describe SHAInet::Network do
iris = SHAInet::Network.new
iris.add_layer(:input, 4, :memory)
iris.add_layer(:hidden, 5, :memory)
iris.add_layer(:hidden, 5, :memory)
iris.add_layer(:output, 3, :memory)
iris.fully_connect

Expand All @@ -53,7 +55,36 @@ describe SHAInet::Network do
normalized = SHAInet::TrainingData.new(inputs, outputs)
normalized.normalize_min_max
puts normalized
iris.train(normalized.data, :mse, :sigmoid, 10000, 0.000001)
iris.train(normalized.data, :sgdm, :mse, :sigmoid, 20000, 0.1)
iris.run(normalized.normalized_inputs.first)
puts "Expected output is: [0,0,1]"
end

it "works on iris dataset with batch train with Rprop" do
label = {
"setosa" => [0.to_f64, 0.to_f64, 1.to_f64],
"versicolor" => [0.to_f64, 1.to_f64, 0.to_f64],
"virginica" => [1.to_f64, 0.to_f64, 0.to_f64],
}
iris = SHAInet::Network.new
iris.add_layer(:input, 4, :memory)
iris.add_layer(:hidden, 5, :memory)
iris.add_layer(:output, 3, :memory)
iris.fully_connect

outputs = Array(Array(Float64)).new
inputs = Array(Array(Float64)).new
CSV.each_row(File.read(__DIR__ + "/test_data/iris.csv")) do |row|
row_arr = Array(Float64).new
row[0..-2].each do |num|
row_arr << num.to_f64
end
inputs << row_arr
outputs << label[row[-1]]
end
normalized = SHAInet::TrainingData.new(inputs, outputs)
normalized.normalize_min_max
iris.train_batch(normalized.data, :rprop, :mse, :sigmoid, 20000, 0.1)
result = iris.run(normalized.normalized_inputs.first)
((result.first < 0.1) && (result[1] < 0.1) && (result.last > 0.9)).should eq(true)
end
Expand Down
13 changes: 13 additions & 0 deletions src/shainet/functions.cr
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,17 @@ module SHAInet

return input_size, vocabulary_v, payloads_v
end

# # Other # #

# Used in Rprop
def self.sign(input : GenNum)
if input > 0
return +1
elsif input < 0
return -1
else
return 0
end
end
end
Loading

0 comments on commit 3ab8072

Please sign in to comment.