Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZXW rewrite rules + Differentiation #93

Merged
merged 40 commits into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
88b8ebc
include previous left change
exAClior Jul 3, 2023
88dfeab
implement first two rewrite rule
exAClior Jul 7, 2023
0090310
add revise, finish implment match
exAClior Jul 8, 2023
3edfbb1
finish test + implement basic rewrite
exAClior Jul 8, 2023
51d5696
begin diff operation implementation
exAClior Jul 9, 2023
90fa5bf
start implement diff rule for general case
exAClior Jul 10, 2023
3aeb785
finish diff implementation
exAClior Jul 11, 2023
806ab36
replace adt rules with multiple dispatch
exAClior Jul 14, 2023
7c47450
enable variable for both signs
exAClior Jul 15, 2023
923b27f
implement theorem 16 and start cor 17
exAClior Jul 15, 2023
1f3925d
finish implement circuit diff and expval diff
exAClior Jul 16, 2023
470a273
finish integrating
exAClior Jul 16, 2023
84d6309
add test to check spider numbers make sense
exAClior Jul 16, 2023
79bc45d
revise PR
exAClior Jul 17, 2023
6579c44
update variable substitution
exAClior Jul 24, 2023
4b982cb
add find negative symbol feature
exAClior Jul 24, 2023
ee905e6
fix einsum contraction order issue
exAClior Jul 25, 2023
33a35d5
fix concatenation of circuits
exAClior Jul 25, 2023
caabc29
finish test for exp val circuit
exAClior Jul 25, 2023
4216c94
construct variance ZXWDiagram test
exAClior Jul 25, 2023
1d4d79d
take care of importing vtx
exAClior Jul 27, 2023
044e2f5
implement stacking of zxwdiagram
exAClior Jul 27, 2023
b18ae24
correct triangle order
exAClior Jul 27, 2023
6805651
fix gradient calculation factor problem
exAClior Jul 28, 2023
dedf06b
test two variable integartion pass
exAClior Jul 28, 2023
b9a5b04
correct integration upto factor of 2
exAClior Jul 28, 2023
0c1e441
fix integration factors again
exAClior Jul 29, 2023
5629848
include test for variance
exAClior Jul 29, 2023
ee834f8
fix past test
exAClior Jul 29, 2023
4fe16ad
remove un-necessary dependencies
exAClior Aug 2, 2023
339e46b
correct nothing comparison
exAClior Aug 2, 2023
ab5a043
port differentiation
exAClior Aug 2, 2023
6eec454
fix PiUnit number testing quickly
exAClior Aug 2, 2023
dc3a3aa
make access output by qubit safer
exAClior Aug 2, 2023
6644b35
move helper functions to utils.jl
exAClior Aug 2, 2023
0c1c758
converted integration to rewrite API
exAClior Aug 2, 2023
504edb2
change intergration to fit rewrite API
exAClior Aug 3, 2023
e51272b
implement inplace stacking
exAClior Aug 4, 2023
da30a7a
remove un-necessary file
exAClior Aug 5, 2023
6870623
rename differentiatio rule
exAClior Aug 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Multigraphs = "7ebac608-6c66-46e6-9856-b5f43e107bac"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
ChenZhao44 marked this conversation as resolved.
Show resolved Hide resolved
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
YaoHIR = "6769671a-fce8-4286-b3f7-6099e1b1298a"
YaoLocations = "66df03fb-d475-48f7-b449-3d9064bf085b"
Expand Down
5 changes: 4 additions & 1 deletion src/ZXCalculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ export SpiderType, EdgeType, ZXWSpiderType
export AbstractZXDiagram, ZXDiagram, ZXGraph
export ZXWDiagram
export Rule, Match
export ZXWRule

export spider_type, phase, spiders, rem_spider!, rem_spiders!, scalar
export parameter
export push_gate!, pushfirst_gate!, tcount
export push_gate!, pushfirst_gate!, tcount, insert_wtrig!
export convert_to_chain, convert_to_zxd
export rewrite!, simplify!, clifford_simplification, full_reduction,
circuit_extraction, phase_teleportation
Expand All @@ -40,6 +41,8 @@ include("to_eincode.jl")
include("utils.jl")

include("rules.jl")
include("zxw_rules.jl")
include("diff.jl")
include("simplify.jl")
include("circuit_extraction.jl")
include("phase_teleportation.jl")
Expand Down
137 changes: 137 additions & 0 deletions src/diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
using ZXCalculus: contains, dagger, concat!

"""
Take derivative of ZXWDiagram with respect to a parameter

Assuming Spiders have Parameter of type PiUnit which is parameterized purely by θ
"""
function diff_diagram(zxwd::ZXWDiagram{T,P}, θ::Symbol) where {T,P}
add_global_phase!(zxwd, P(π / 2))
ChenZhao44 marked this conversation as resolved.
Show resolved Hide resolved
vs = symbol_vertices(zxwd, θ)
w_trig_vs = T[]
for v in vs
x_v = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 1.0)), [v])
w_v = add_spider!(zxwd, D, [x_v])
frac_v = @match spider_type(zxwd, v).p begin
PiUnit(pu, _) && if !(pu == θ)
end => add_spider!(zxwd, Z(Parameter(Val(:Factor), π)), [w_v])
PiUnit(pu, _) && if pu == θ
end => w_v
Factor(f, _) => error("Only supports PiUnit differentiation")

Check warning on line 20 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L20

Added line #L20 was not covered by tests
_ => error("not a valid parameter")
end
push!(w_trig_vs, frac_v)
end

head = insert_wtrig!(zxwd, w_trig_vs)

add_spider!(zxwd, X(Parameter(Val(:PiUnit), 1.0)), [head])

return zxwd
end

"""
Take derivative with of Circuit with expectation of Hamiltonian H.
"""
function diff_expval!(zxwd::ZXWDiagram{T,P}, H::String, θ::Symbol) where {T,P}
# convert U to U^\dag H U
zxwd_dag = dagger(zxwd)
for (i, h) in enumerate(H)
if h == "Z"
push_gate!(zxwd, Val(:Z), i)

Check warning on line 41 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L41

Added line #L41 was not covered by tests
elseif h == "X"
push_gate!(zxwd, Val(:X), i)

Check warning on line 43 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L43

Added line #L43 was not covered by tests
elseif h == "Y"
push_gate!(zxwd, Val(:Z), i)
push_gate!(zxwd, Val(:X), i)
add_global_phase!(zxwd, P(-π / 2))

Check warning on line 47 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L45-L47

Added lines #L45 - L47 were not covered by tests
end
end
concat!(zxwd, zxwd_dag)
return diff_diagram(zxwd, θ)
end

"""

Finds vertices of Spider that contains the parameter θ or -θ
"""
function symbol_vertices(zxwd::ZXWDiagram{T,P}, θ::Symbol) where {T,P}

matched = T[]
for v in vertices(zxwd.mg)
res = @match spider_type(zxwd, v) begin
Z(p1) && if contains(p1, θ)
end => v
X(p1) && if contains(p1, θ)
end => v
_ => nothing
end
res !== nothing && push!(matched, v)
end
return matched
end

"""
Integrate over the Spiders at locs with respect to the parameter θ.

User need to check that the parameters are indeed in the form of k * θ where k is Int
"""
function integrate!(zxwd::ZXWDiagram{T,P}, locs::Vector{T}) where {T,P}
length(locs) == 2 && return integrate2!(zxwd, locs[1], locs[2])
length(locs) == 4 && return integrate4!(zxwd, locs[1], locs[2], locs[3], locs[4])
end

function integrate2!(zxwd::ZXWDiagram{T,P}, loc1::T, loc2::T) where {T,P}
exAClior marked this conversation as resolved.
Show resolved Hide resolved
loc1 = int_prep!(zxwd, loc1)
loc2 = int_prep!(zxwd, loc2)
add_edge!(zxwd.mg, loc1, loc2)
return zxwd
end

"""
Integrate two pairs of +/- parameter. Theorem 23 of https://arxiv.org/abs/2201.13250
"""
function integrate4!(zxwd::ZXWDiagram{T,P}, loca::T, locb::T, locc::T, locd::T) where {T,P}
loca = int_prep!(zxwd, loca)
locb = int_prep!(zxwd, locb)
locc = int_prep!(zxwd, locc)
locd = int_prep!(zxwd, locd)

# a, b = + , - \theta
# c, d = + , - \theta
loca = add_spider!(zxwd, Z(Parameter(Val(:PiUnit), 0)), [loca])
locb = add_spider!(zxwd, Z(Parameter(Val(:PiUnit), 0)), [locb])
locc = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 0)), [locc])
locd = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 0)), [locd])

add_edge!(zxwd, loca, locc)
add_edge!(zxwd, locb, locd)

locm = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 0)), [loca, locb])
locm = add_spider!(zxwd, D, [locm])
locm = add_spider!(zxwd, X(Parameter(Val(:PiUnit), π)), [locm])
add_spider!(zxwd, X(Parameter(Val(:PiUnit), 0)), [locm, locc, locd])
return zxwd
end

"""
Prepare spider at loc for integration.

Perform the simplified step of zeroing out phase of spider
and readying it for integration
1. If target spider is X spider, turn it to Z by adding H to all its legs
2. Pull out the Phase of the spider
3. zero out the phase
4. change the current spider back to its original type if necessary,
will generate one extra H spider.
"""
function int_prep!(zxwd::ZXWDiagram{T,P}, loc::T) where {T,P}
set_phase!(zxwd, loc, Parameter(Val(:PiUnit), 0.0))

new_loc = @match spider_type(zxwd, loc) begin
X(_) => add_spider!(zxwd, H, [loc])
Z(_) => loc
_ => error("Not a valid Spider to integrate over")
end
return new_loc
end
29 changes: 25 additions & 4 deletions src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@
# following the same convention in Phase.jl implementation
# comparison have inconsistent, we are comparing phases to numbers
# if cause trouble, will change
#
function contains(p::Parameter, θ::Symbol)
ChenZhao44 marked this conversation as resolved.
Show resolved Hide resolved
@match p begin
PiUnit(pu, pt) && if !(pt <: Number)
end => Base.contains(repr(pu), string(θ))
_ => false
end
end

Base.isless(p1::Parameter, p2::Number) = @match p1 begin
PiUnit(_...) => p1.pu isa Number && p1.pu < p2
_ => p1.f < p2
Expand Down Expand Up @@ -109,8 +118,9 @@
@match (p1, p2) begin
(PiUnit(pu1, _), PiUnit(pu2, _)) && if pu1 isa Number && pu2 isa Number
end => Parameter(Val(:PiUnit), pu1 + pu2)
(PiUnit(pu1, _), PiUnit(pu2, _)) && if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :+, pu1, pu2))
(PiUnit(pu1, pu1_t), PiUnit(pu2, pu2_t)) &&

Check warning on line 121 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L121

Added line #L121 was not covered by tests
if !(pu1 isa Number) || !(pu2 isa Number)
end => PiUnit(Expr(:call, :+, pu1, pu2), Base.promote_op(+, pu1_t, pu2_t))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 + f2)
(PiUnit(pu1, _), Factor(f2, _)) => Parameter(Val(:Factor), exp(im * pu1 * π) * f2)
(Factor(f1, _), PiUnit(pu2, _)) => Parameter(Val(:Factor), exp(im * pu2 * π) * f1)
Expand All @@ -134,8 +144,8 @@
end => Parameter(Val(:PiUnit), pu1 - pu2)
(PiUnit(pu1, pu_t1), PiUnit(pu2, pu_t2)) &&
if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu1, pu2))
(Factor(f1, _), Factor(f2, _)) => Factor(f1 - f2)
end => PiUnit(Expr(:call, :-, pu1, pu2), Base.promote_op(-, pu_t1, pu_t2))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 - f2)

Check warning on line 148 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L147-L148

Added lines #L147 - L148 were not covered by tests
(PiUnit(_...), Factor(_...)) => Parameter(Val(:Factor), exp(im * p1.pu * π) - p2.f)
(Factor(_...), PiUnit(_...)) => Parameter(Val(:Factor), p1.f - exp(im * p2.pu * π))
(_, PiUnit(_...)) => Parameter(Val(:PiUnit), p1 - p2.pu)
Expand Down Expand Up @@ -163,3 +173,14 @@
)
end
end

function Base.inv(p::Parameter)
@match p begin
PiUnit(pu, pu_t) && if pu isa Number
end => Parameter(Val(:PiUnit), -pu)
PiUnit(pu, pu_t) && if !(pu isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu))
Factor(f, _) => Parameter(Val(:Factor), inv(f))

Check warning on line 183 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L183

Added line #L183 was not covered by tests
_ => error("Invalid input '$(p)' of type $(typeof(p)) for ADT: inv")
end
end
81 changes: 80 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

function _round_phase(p::Parameter)
@match p begin
PiUnit(_...) => rem(rem(p, 2) + 2, 2)
PiUnit(pu, pt) && if pt <: Number
exAClior marked this conversation as resolved.
Show resolved Hide resolved
end => rem(rem(p, 2) + 2, 2)
_ => p
end
end
Expand Down Expand Up @@ -415,3 +416,81 @@
add_power!(zxwd, 1)
return zxwd
end

"""

Insert W triangle on a vector of vertices

"""
function insert_wtrig!(zxwd::ZXWDiagram{T,P}, locs::Vector{T}) where {T,P}
length(locs) < 2 && return nothing

prev_w = add_spider!(zxwd, W, locs[1:2])
head = add_spider!(zxwd, W, [prev_w])

length(locs) == 2 && return head

for loc in locs[3:end]
prev_w = add_spider!(zxwd, W, [head, loc])
head = add_spider!(zxwd, W, [prev_w])
end
return head

end

"""
Convert ZXWDiagram that represents unitary U to U^†
"""
function dagger(zxwd::ZXWDiagram{T,P}) where {T,P}
zxwd_dg = copy(zxwd)
for v in vertices(zxwd_dg.mg)
@match zxwd_dg.st[v] begin
Input(q) => (zxwd_dg.st[v] = Output(q))
Output(q) => (zxwd_dg.st[v] = Input(q))
Z(p) => (zxwd_dg.st[v] = Z(inv(p)))
X(p) => (zxwd_dg.st[v] = X(inv(p)))
W => nothing
H => nothing
D => nothing

Check warning on line 454 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L452-L454

Added lines #L452 - L454 were not covered by tests
end
end
return zxwd_dg
end

"""
Concatenate two ZXWDiagrams, modify d1.

Remove outputs of d1 and inputs of d2. Then add edges between to vertices
that was conntecting to outputs of d1 and inputs of d2.
Assuming you don't concatenate two empty circuit ZXWDiagram
"""
function concat!(d1::ZXWDiagram{T,P}, d2::ZXWDiagram{T,P}) where {T,P}
v2tov1 = Dict{T,T}()
for v2 in vertices(d2.mg)
new_v = @match d2.st[v2] begin
Input(q) => nothing
Output(q) => nothing
(Z(_) || X(_) || W || H || D) => add_vertex!(d1.mg)[1]
_ => error("Unknown spider type $(d2.st[v2])")
end
if new_v !== nothing
v2tov1[v2] = new_v
d1.st[new_v] = d2.st[v2]
end
end
prior_outputs = [neighbors(d1, q_v) for q_v in d1.outputs]
for edge in edges(d2.mg)
src, dst, emul = edge.src, edge.dst, edge.mul
v1srcs, v1dst = @match (spider_type(d2, src), spider_type(d2, dst)) begin
(Input(q), _) => (prior_outputs[q], v2tov1[dst])
(_, Input(q)) => (prior_outputs[q], v2tov1[src])
(Output(q), _) => ([v2tov1[dst]], d1.outputs[q])
(_, Output(q)) => ([v2tov1[src]], d1.outputs[q])
_ => ([v2tov1[src]], v2tov1[dst])
end
for v1src in v1srcs
add_edge!(d1.mg, v1src, v1dst, emul)
end
end
return d1
end
1 change: 1 addition & 0 deletions src/zxw_diagram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,4 @@ Base.copy(zxwd::ZXWDiagram{T,P}) where {T,P} = ZXWDiagram{T,P}(
copy(zxwd.inputs),
copy(zxwd.outputs),
)

Loading
Loading