diff --git a/src/instruct.jl b/src/instruct.jl index 4d68f8b..0e13331 100644 --- a/src/instruct.jl +++ b/src/instruct.jl @@ -153,7 +153,8 @@ function _instruct!( locs_raw::SVector, ic::IterControl, ) where {T} - work = ndims(state) == 1 ? similar(state, length(locs_raw)) : + work = + ndims(state) == 1 ? similar(state, length(locs_raw)) : similar(state, length(locs_raw), size(state, 2)) controldo(ic) do i @inbounds unrows!(state, locs_raw .+ i, U, work) diff --git a/src/register.jl b/src/register.jl index 21f67c4..2469b4a 100644 --- a/src/register.jl +++ b/src/register.jl @@ -421,7 +421,8 @@ function rand_state( nbatch::Int = 1, no_transpose_storage::Bool = false, ) where {T} - raw = nbatch == 1 || no_transpose_storage ? randn(T, 1 << n, nbatch) : + raw = + nbatch == 1 || no_transpose_storage ? randn(T, 1 << n, nbatch) : transpose(randn(T, nbatch, 1 << n)) return normalize!(ArrayReg{nbatch}(raw)) end @@ -451,7 +452,8 @@ function uniform_state( nbatch::Int = 1, no_transpose_storage::Bool = false, ) where {T} - raw = nbatch == 1 || no_transpose_storage ? ones(T, 1 << n, nbatch) : + raw = + nbatch == 1 || no_transpose_storage ? ones(T, 1 << n, nbatch) : transpose(ones(T, nbatch, 1 << n)) normalize!(ArrayReg{nbatch}(raw)) end