Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/PortZygote/gate_learning.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using YaoExtensions, Yao
using Test, Random
using Optim: LBFGS, optimize

# port the `Matrix` function to Yao's AD.
include("zygote_patch.jl")

function loss(u, ansatz)
m = Matrix(ansatz)
sum(abs.(u .- m))
end

"""
learn_u4(u::AbstractMatrix; niter=100)

Learn a general U4 gate. The optimizer is LBFGS.
"""
function learn_u4(u::AbstractMatrix; niter=100)
ansatz = general_U4() * put(2, 1=>phase(0.0)) # initial values are 0, here, we attach a global phase.
params = parameters(ansatz)
g!(G, x) = (dispatch!(ansatz, x); G .= gradient(ansatz->loss(u, ansatz), ansatz)[1])
optimize(x->(dispatch!(ansatz, x); loss(u, ansatz)), g!, parameters(ansatz),
LBFGS(), Optim.Options(iterations=niter))
println("final loss = $(loss(u,ansatz))")
return ansatz
end

using Random
Random.seed!(2)
u = rand_unitary(4)
c = learn_u4(u)
17 changes: 17 additions & 0 deletions examples/PortZygote/simple_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
include("zygote_patch.jl")

import YaoExtensions, Random

c = YaoExtensions.variational_circuit(5)
dispatch!(c, :random)

function loss(reg::AbstractRegister, circuit::AbstractBlock{N}) where N
#copy(reg) |> circuit
reg = apply!(copy(reg), circuit)
st = state(reg)
sum(real(st.*st))
end

reg0 = zero_state(5)
paramsδ = gradient(c->loss(reg0, c), c)[1]
regδ = gradient(reg->loss(reg, c), reg0)[1]
42 changes: 42 additions & 0 deletions examples/PortZygote/zygote_patch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using Zygote
using Zygote: @adjoint
using Yao, Yao.AD

@adjoint function apply!(reg::ArrayReg, block::AbstractBlock)
out = apply!(reg, block)
out, function (outδ)
(in, inδ), paramsδ = apply_back((out, outδ), block)
return (inδ, paramsδ)
end
end

@adjoint function Matrix(block::AbstractBlock)
out = Matrix(block)
out, function (outδ)
paramsδ = mat_back(block, outδ)
return (paramsδ,)
end
end

@adjoint function ArrayReg{B}(raw::AbstractArray) where B
ArrayReg{B}(raw), adjy->(reshape(adjy.state, size(raw)),)
end

@adjoint function ArrayReg{B}(raw::ArrayReg) where B
ArrayReg{B}(raw), adjy->(adjy,)
end

@adjoint function ArrayReg(raw::AbstractArray)
ArrayReg(raw), adjy->(reshape(adjy.state, size(raw)),)
end

@adjoint function copy(reg::ArrayReg) where B
copy(reg), adjy->(adjy,)
end

@adjoint state(reg::ArrayReg) = state(reg), adjy->(ArrayReg(adjy),)
@adjoint statevec(reg::ArrayReg) = statevec(reg), adjy->(ArrayReg(adjy),)
@adjoint state(reg::AdjointArrayReg) = state(reg), adjy->(ArrayReg(adjy')',)
@adjoint statevec(reg::AdjointArrayReg) = statevec(reg), adjy->(ArrayReg(adjy')',)
@adjoint parent(reg::AdjointArrayReg) = parent(reg), adjy->(adjy',)
@adjoint Base.adjoint(reg::ArrayReg) = Base.adjoint(reg), adjy->(parent(adjy),)