Skip to content
This repository has been archived by the owner on Dec 18, 2021. It is now read-only.

Commit

Permalink
Merge a1e660d into 7cdb707
Browse files Browse the repository at this point in the history
  • Loading branch information
GiggleLiu authored Sep 18, 2019
2 parents 7cdb707 + a1e660d commit 722a4bb
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
4 changes: 2 additions & 2 deletions benchmark/runbench.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using PkgBenchmark

current = BenchmarkConfig(id="multithreading", env = Dict("JULIA_NUM_THREADS"=>4), juliacmd=`julia -O3`)
baseline = BenchmarkConfig(id="master", env = Dict("JULIA_NUM_THREADS"=>1), juliacmd=`julia -O3`)
current = BenchmarkConfig(id="transpose_storage", juliacmd=`julia -O3`)
baseline = BenchmarkConfig(id="master", juliacmd=`julia -O3`)
results = judge("YaoArrayRegister", current, baseline)
export_markdown("report.md", results)
15 changes: 13 additions & 2 deletions src/register.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using YaoBase, BitBasis
import BitBasis: BitStr, BitStr64
using LinearAlgebra: Transpose

export ArrayReg,
AdjointArrayReg,
Expand Down Expand Up @@ -143,6 +144,9 @@ function Base.copyto!(dst::AdjointArrayReg, src::AdjointArrayReg)
return dst
end

Base.convert(::Type{Transpose{T, Matrix{T}}}, arr::AbstractMatrix{T}) where T = transpose(Matrix(transpose(arr)))
Base.convert(t::Type{Transpose{T, Matrix{T}}}, arr::Transpose{T}) where T = invoke(convert, Tuple{Type{Transpose{T, Matrix{T}}}, Transpose}, t, arr)

# register interface
YaoBase.nqubits(r::ArrayReg{B}) where B = log2i(length(r.state) ÷ B)
YaoBase.nactive(r::ArrayReg) = log2dim1(r.state)
Expand Down Expand Up @@ -333,7 +337,14 @@ product_state(total::Int, bit_config::Integer; nbatch::Int=1) = product_state(Co
product_state(::Type{T}, bit_str::BitStr; nbatch::Int=1) where T = ArrayReg{nbatch}(T, bit_str)

function product_state(::Type{T}, total::Int, bit_config::Integer; nbatch::Int=1) where T
return ArrayReg{nbatch}(onehot(T, total, bit_config, nbatch))
if nbatch == 1
raw = onehot(T, total, bit_config, nbatch)
else
raw = zeros(T, nbatch, 1<<total)
raw[:,Int(bit_config)+1] .= 1
raw = transpose(raw)
end
return ArrayReg{nbatch}(raw)
end

"""
Expand Down Expand Up @@ -386,7 +397,7 @@ ArrayReg{2, Complex{Float64}, Array...}
rand_state(n::Int; nbatch::Int=1) = rand_state(ComplexF64, n; nbatch=nbatch)

function rand_state(::Type{T}, n::Int; nbatch::Int=1) where T
raw = randn(T, 1<<n, nbatch)
raw = nbatch == 1 ? randn(T, 1<<n, nbatch) : transpose(randn(T, nbatch, 1<<n))
return normalize!(ArrayReg{nbatch}(raw))
end

Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,9 @@ end
end
state
end

#### Yao Base patch ####
using YaoBase
function YaoBase.batched_kron!(C::Array{T, 3}, A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T, T1, T2}
YaoBase.batched_kron!(C, convert(Array{T,3}, A), convert(Array{T,3}, B))
end

0 comments on commit 722a4bb

Please sign in to comment.