Skip to content

Commit

Permalink
binary conventions for sparse arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ahwillia committed Oct 18, 2016
1 parent a250132 commit 0896b4c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/supervised/supervised.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,34 @@ preallocated buffer, which has to be the same size as the parameters.
end
end

##
# Function for sparse arrays
@generated function value!{T,N,Q,Ti,M}(
buffer::AbstractArray,
loss::MarginLoss,
target::AbstractSparseArray{Q,Ti,M},
output::AbstractArray{T,N}
)
M > N && throw(ArgumentError("target has more dimensions than output; broadcasting not supported in this direction."))
quote
@_dimcheck size(buffer) == size(output)
@nexprs $M (n)->@_dimcheck(size(target,n) == size(output,n))
zeroQ = zero(Q)
negQ = Q(-1)
@simd for I in CartesianRange(size(output))
@nexprs $N n->(i_n = I[n])
tgt = @nref($M,target,i)
if tgt == zeroQ
# convention is that zeros in a sparse array are interpreted as negative one
@inbounds @nref($N,buffer,i) = value(loss, negQ, @nref($N,output,i))
else
@inbounds @nref($N,buffer,i) = value(loss, tgt, @nref($N,output,i))
end
end
buffer
end
end

"""
deriv!(buffer::AbstractArray, loss::SupervisedLoss, target::AbstractArray, output::AbstractArray)
Expand Down
39 changes: 39 additions & 0 deletions test/tst_loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,42 @@ end
end
end

@testset "Test sparse array conventions for margin-based losses" begin
@testset "sparse vector target, vector output" begin
N = 50

# sparse vector of {0,1}
sparse_target = sprand(N,0.5)
nz = sparse_target .> 0.0
sparse_target[nz] = 1.0
@test typeof(sparse_target) <: AbstractSparseArray

# dense vector of {-1,1}
target = [ t == 0.0 ? -1.0 : 1.0 for t in sparse_target ]

output = randn(N)

for loss in margin_losses
@test isapprox(value(loss,sparse_target,output), value(loss,target,output))
end
end

@testset "sparse vector target, matrix output" begin
N = 50

# sparse vector of {0,1}
sparse_target = sprand(N,0.5)
nz = sparse_target .> 0.0
sparse_target[nz] = 1.0
@test typeof(sparse_target) <: AbstractSparseArray

# dense vector of {-1,1}
target = [ t == 0.0 ? -1.0 : 1.0 for t in sparse_target ]

output = randn(N,N)

for loss in margin_losses
@test isapprox(value(loss,sparse_target,output), value(loss,target,output))
end
end
end

0 comments on commit 0896b4c

Please sign in to comment.