Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,24 @@ julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficientl
3 6 15 3 9 3 12 3 6 15 3
```
"""
onehotbatch(ls, labels, default...) = _onehotbatch(ls, length(labels) < 32 ? Tuple(labels) : labels, default...)
# NB function barier:
_onehotbatch(ls, labels, default...) = batch([onehot(l, labels, default...) for l in ls])
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)

function _onehotbatch(data, labels)
indices = UInt32[something(_findval(i, labels), 0) for i in data]
if 0 in indices
for x in data
isnothing(_findval(x, labels)) && error("Value $x not found in labels")
end
end
return OneHotArray(indices, length(labels))
end

function _onehotbatch(data, labels, default)
default_index = _findval(default, labels)
isnothing(default_index) && error("Default value $default is not in labels")
indices = UInt32[something(_findval(i, labels), default_index) for i in data]
return OneHotArray(indices, length(labels))
end

"""
onecold(y::AbstractArray, labels = 1:size(y,1))
Expand Down
2 changes: 2 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using Test
@test onehotbatch("abc", 'a':'c') == Bool[1 0 0; 0 1 0; 0 0 1]
@test onehotbatch("zbc", ('a', 'b', 'c'), 'a') == Bool[1 0 0; 0 1 0; 0 0 1]

@test onehotbatch([10, 20], [30, 40, 50], 30) == Bool[1 1; 0 0; 0 0]

@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c])
@test_throws Exception onehotbatch([:a, :d], (:a, :b, :c))
@test_throws Exception onehotbatch([:a, :d], [:a, :b, :c], :e)
Expand Down