Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add onehot_batch #5

Merged
merged 2 commits into from May 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/bit_str.jl
@@ -1,4 +1,4 @@
export BitStr, @bit_str, bcat, bit_literal, bit, to_location
export BitStr, @bit_str, bcat, bit_literal, bit, to_location, onehot, onehot_batch

"""
BitStr{T}
Expand Down Expand Up @@ -281,13 +281,17 @@ Base.repeat(s::BitStr, n::Integer) = bcat(s for i in 1:n)
Base.show(io::IO, bitstr::BitStr) = print(io, string(bitstr.val, base=2, pad=length(bitstr)), " (", bitstr.val, ")")

"""
onehot([T=Float64], bit_str)
onehot([T=Float64], bit_str[, nbatch])

Returns an onehot vector of type `Vector{T}`, where the `bit_str`-th element is one.
Returns an onehot vector in type `Vector{T}`, or a batch of onehot
vector in type `Matrix{T}`, where the `bit_str`-th element is one.
"""
onehot(::Type{T}, n::BitStr) where T = onehot(T, length(n), n.val)
onehot(n::BitStr) = onehot(Float64, n)

onehot(::Type{T}, n::BitStr, nbatch::Int) where T = onehot(T, length(n), n.val, nbatch)
onehot(n::BitStr, nbatch::Int) = onehot(Float64, n, nbatch)

# conversions
for IntType in [:Int8, :Int16, :Int32, :Int64, :Int128, :BigInt]
@eval Base.convert(::Type{$IntType}, x::BitStr) = $IntType(x.val)
Expand Down
13 changes: 11 additions & 2 deletions src/utils.jl
Expand Up @@ -17,9 +17,10 @@ Return number of different bits.
bdistance(i::Ti, j::Ti) where Ti<:Integer = count_ones(i ⊻ j)

"""
onehot([T=Float64], nbits, x::Integer)
onehot([T=Float64], nbits, x::Integer[, nbatch::Int])

Returns an onehot vector of type `Vector{T}`, where index `x + 1` is one.
Create an onehot vector in type `Vector{T}` or a batch of onehot vector in type `Matrix{T}`,
where index `x + 1` is one.
"""
function onehot(::Type{T}, nbits::Int, x::Integer) where T
v = zeros(T, 1 << nbits)
Expand All @@ -29,6 +30,14 @@ end

onehot(nbits::Int, x::Integer) = onehot(Float64, nbits, x)

function onehot(::Type{T}, nbits::Int, x::Integer, nbatch::Int) where T
v = zeros(T, 1 << nbits, nbatch)
v[x + 1, :] .= 1
return v
end

onehot(nbits::Int, x::Integer, nbatch::Int) = onehot(Float64, nbits, x, nbatch)

"""
unsafe_sub(a::UnitRange, b::NTuple{N}) -> NTuple{N}

Expand Down
1 change: 1 addition & 0 deletions test/utils.jl
Expand Up @@ -3,6 +3,7 @@ using Test, BitBasis
@test bsizeof(ind) == sizeof(Int) * 8
@test onehot(ComplexF64, 2, 2) == [0, 0, 1, 0]
@test bdistance(1,7) == 2
@test onehot(ComplexF64, bit"01", 2) == transpose(ComplexF64[0 1 0 0;0 1 0 0])
@test log2dim1(rand(4, 4)) == log2i(4)

@testset "log2i" begin
Expand Down