Skip to content

Commit

Permalink
add tests for vectorized deriv and deriv2
Browse files Browse the repository at this point in the history
  • Loading branch information
Evizero committed Dec 2, 2016
1 parent c6f30af commit 7dd53a2
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions test/tst_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,62 @@ function test_vector_value(l::DistanceLoss, t, y)
end
end

function test_vector_deriv(l::MarginLoss, t, y)
@testset "$(l): " begin
ref = [ LossFunctions.deriv(l,t[i],y[i]) for i in 1:length(y) ]
@test LossFunctions.deriv.(l, t, y) == ref
@test t .* LossFunctions.deriv.(l, t .* y) == ref
@test l'.(t, y) == ref
@test t .* l'.(t .* y) == ref
end
end

function test_vector_deriv(l::DistanceLoss, t, y)
@testset "$(l): " begin
ref = [ LossFunctions.deriv(l,t[i],y[i]) for i in 1:length(y) ]
@test LossFunctions.deriv.(l, t, y) == ref
@test LossFunctions.deriv.(l, y - t) == ref
@test l'.(t, y) == ref
@test l'.(y - t) == ref
end
end

function test_vector_deriv2(l::MarginLoss, t, y)
@testset "$(l): " begin
ref = [ LossFunctions.deriv2(l,t[i],y[i]) for i in 1:length(y) ]
@test LossFunctions.deriv2.(l, t, y) == ref
@test LossFunctions.deriv2.(l, t .* y) == ref
@test l''.(t, y) == ref
@test l''.(t .* y) == ref
end
end

function test_vector_deriv2(l::DistanceLoss, t, y)
@testset "$(l): " begin
ref = [ LossFunctions.deriv2(l,t[i],y[i]) for i in 1:length(y) ]
@test LossFunctions.deriv2.(l, t, y) == ref
@test LossFunctions.deriv2.(l, y - t) == ref
@test l''.(t, y) == ref
@test l''.(y - t) == ref
end
end

@testset "Vectorized API" begin
targets = rand([-1,1], 10)
outputs = (rand(10)-.5) * 20

for loss in margin_losses
test_vector_value(loss, targets, outputs)
test_vector_deriv(loss, targets, outputs)
test_vector_deriv2(loss, targets, outputs)
end

targets = (rand(10)-.5) * 20
outputs = (rand(10)-.5) * 20
for loss in distance_losses
test_vector_value(loss, targets, outputs)
test_vector_deriv(loss, targets, outputs)
test_vector_deriv2(loss, targets, outputs)
end
end

Expand Down

0 comments on commit 7dd53a2

Please sign in to comment.