Description
Hello, I'm working on a binary classification problem and I'm having trouble using binarycrossentropy as my loss function. Here is an example of what I am trying to do. Forgive me if I made a mistake early on and that is why it's not working.
using Flux
nb_in_nodes = 20
nb_hidden_nodes = 8
nb_out_nodes = 1
feature = rand(nb_in_nodes, 100)
response = rand([0,1], 100)
dataset = [(feature, response)]
model = Chain(Dense(nb_in_nodes, nb_hidden_nodes, relu),
              Dense(nb_hidden_nodes, nb_out_nodes),
              softmax)
loss(x, y) = Flux.binarycrossentropy(model(x), y)
opt = ADAM()
Flux.train!(loss, params(model), dataset, opt)
I get the following error:
MethodError: no method matching eps(::TrackedArray{…,Array{Float32,2}})
...
Stacktrace:
[1] binarycrossentropy(::TrackedArray{…,Array{Float32,2}}, ::Array{Int64,1}) at /home/henry/.julia/packages/Flux/qXNjB/src/layers/stateless.jl:26
So it looks like Julia's eps function doesn't work with the TrackedArrays that Flux.train! eventually passes along. I then tried the same thing but replaced the definition of the loss function with the following, setting epsilon directly so that eps() is never called.
loss(x, y) = Flux.binarycrossentropy(model(x), y; ϵ=1e-15)
Then, if I run the above, I get the following:
MethodError: no method matching +(::TrackedArray{…,Array{Float32,2}}, ::Float64)
...
Stacktrace:
[1] binarycrossentropy(::TrackedArray{…,Array{Float32,2}}, ::Array{Int64,1}) at /home/henry/.julia/packages/Flux/qXNjB/src/layers/stateless.jl:26
So it looks like Julia doesn't want to add a float (epsilon) to an Array without broadcasting, like so.
julia> [1, 2, 3] + 1.0
ERROR: MethodError: no method matching +(::Array{Int64,1}, ::Float64)
Those are the problems I'm having.
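For reference, both errors seem reproducible without Flux. The sketch below uses plain arrays and assumes they behave like a TrackedArray in this respect (that part is my assumption, not something I have confirmed). It shows that eps accepts a float type or a scalar float but has no method for an array, and that the dotted forms work elementwise.

```julia
# Plain-Julia sketch; the variable names here are only for illustration.
x = Float32[0.2, 0.7]

eps_type = eps(Float32)                  # machine epsilon of the element type
eps_each = eps.(x)                       # broadcast: eps evaluated per element
has_array_method = applicable(eps, x)    # false, there is no eps(::Array) method

v = [1, 2, 3]
w = v .+ 1.0                             # elementwise add, gives a Vector{Float64}
```

With the dots, neither call raises a MethodError.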
Is binarycrossentropy() supposed to broadcast the arithmetic and the log calls itself, or is it meant to receive only a single value for ŷ and y at a time? Here is the code for binarycrossentropy as found on GitHub and on my own machine:
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
mse(), crossentropy(), and logitcrossentropy() appear to be using broadcasting, but there could be some subtlety I'm missing in the math and how binarycrossentropy() is supposed to be used.
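If the definition is meant to be scalar, then applying it with a dot might be the intended usage. Here is a minimal sketch of that idea in plain Julia. The name bce_scalar is hypothetical (it only copies the formula quoted above, it is not Flux's own function), and the inputs are made-up values rather than output from my model.

```julia
# Hypothetical scalar helper mirroring the quoted formula; not part of Flux.
bce_scalar(ŷ, y; ϵ=eps(ŷ)) = -y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ)

ŷ = Float32[0.9, 0.2, 0.6]   # made-up predicted probabilities in (0, 1)
y = [1, 0, 1]                # binary labels with the same shape as ŷ

losses = bce_scalar.(ŷ, y)               # broadcast over matching elements
mean_loss = sum(losses) / length(losses) # reduce to a single scalar loss
```

Here eps is called on each scalar ŷ, so the default ϵ works. One thing to watch if this is applied to my example: model(x) returns a 1×100 matrix while response is a length-100 vector, and broadcasting those two shapes would produce a 100×100 result, so the labels would presumably need to be reshaped to match first.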