diff --git a/examples/PortZygote/gate_learning.jl b/examples/PortZygote/gate_learning.jl new file mode 100644 index 0000000..ec4e9d5 --- /dev/null +++ b/examples/PortZygote/gate_learning.jl @@ -0,0 +1,31 @@ +using YaoExtensions, Yao +using Test, Random +using Optim: LBFGS, optimize + +# port the `Matrix` function to Yao's AD. +include("zygote_patch.jl") + +function loss(u, ansatz) + m = Matrix(ansatz) + sum(abs.(u .- m)) +end + +""" + learn_u4(u::AbstractMatrix; niter=100) + +Learn a general U4 gate. The optimizer is LBFGS. +""" +function learn_u4(u::AbstractMatrix; niter=100) + ansatz = general_U4() * put(2, 1=>phase(0.0)) # initial values are 0, here, we attach a global phase. + params = parameters(ansatz) + g!(G, x) = (dispatch!(ansatz, x); G .= gradient(ansatz->loss(u, ansatz), ansatz)[1]) + optimize(x->(dispatch!(ansatz, x); loss(u, ansatz)), g!, parameters(ansatz), + LBFGS(), Optim.Options(iterations=niter)) + println("final loss = $(loss(u,ansatz))") + return ansatz +end + +using Random +Random.seed!(2) +u = rand_unitary(4) +c = learn_u4(u) diff --git a/examples/PortZygote/simple_example.jl b/examples/PortZygote/simple_example.jl new file mode 100644 index 0000000..ae10024 --- /dev/null +++ b/examples/PortZygote/simple_example.jl @@ -0,0 +1,17 @@ +include("zygote_patch.jl") + +import YaoExtensions, Random + +c = YaoExtensions.variational_circuit(5) +dispatch!(c, :random) + +function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where N + #copy(reg) |> circuit + reg = apply!(copy(reg), circuit) + st = state(reg) + sum(real(st.*st)) +end + +reg0 = zero_state(5) +paramsδ = gradient(c->loss(reg0, c), c)[1] +regδ = gradient(reg->loss(reg, c), reg0)[1] diff --git a/examples/PortZygote/zygote_patch.jl b/examples/PortZygote/zygote_patch.jl new file mode 100644 index 0000000..f577146 --- /dev/null +++ b/examples/PortZygote/zygote_patch.jl @@ -0,0 +1,42 @@ +using Zygote +using Zygote: @adjoint +using Yao, Yao.AD + +@adjoint function apply!(reg::ArrayReg, block::AbstractBlock) + out = apply!(reg, block) + out, function (outδ) + (in, inδ), paramsδ = apply_back((out, outδ), block) + return (inδ, paramsδ) + end +end + +@adjoint function Matrix(block::AbstractBlock) + out = Matrix(block) + out, function (outδ) + paramsδ = mat_back(block, outδ) + return (paramsδ,) + end +end + +@adjoint function ArrayReg{B}(raw::AbstractArray) where B + ArrayReg{B}(raw), adjy->(reshape(adjy.state, size(raw)),) +end + +@adjoint function ArrayReg{B}(raw::ArrayReg) where B + ArrayReg{B}(raw), adjy->(adjy,) +end + +@adjoint function ArrayReg(raw::AbstractArray) + ArrayReg(raw), adjy->(reshape(adjy.state, size(raw)),) +end + +@adjoint function copy(reg::ArrayReg) where B + copy(reg), adjy->(adjy,) +end + +@adjoint state(reg::ArrayReg) = state(reg), adjy->(ArrayReg(adjy),) +@adjoint statevec(reg::ArrayReg) = statevec(reg), adjy->(ArrayReg(adjy),) +@adjoint state(reg::AdjointArrayReg) = state(reg), adjy->(ArrayReg(adjy')',) +@adjoint statevec(reg::AdjointArrayReg) = statevec(reg), adjy->(ArrayReg(adjy')',) +@adjoint parent(reg::AdjointArrayReg) = parent(reg), adjy->(adjy',) +@adjoint Base.adjoint(reg::ArrayReg) = Base.adjoint(reg), adjy->(parent(adjy),)