Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZXW rewrite rules + Differentiation #93

Merged
merged 40 commits into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
88b8ebc
include previous left change
exAClior Jul 3, 2023
88dfeab
implement first two rewrite rule
exAClior Jul 7, 2023
0090310
add revise, finish implment match
exAClior Jul 8, 2023
3edfbb1
finish test + implement basic rewrite
exAClior Jul 8, 2023
51d5696
begin diff operation implementation
exAClior Jul 9, 2023
90fa5bf
start implement diff rule for general case
exAClior Jul 10, 2023
3aeb785
finish diff implementation
exAClior Jul 11, 2023
806ab36
replace adt rules with multiple dispatch
exAClior Jul 14, 2023
7c47450
enable variable for both signs
exAClior Jul 15, 2023
923b27f
implement theorem 16 and start cor 17
exAClior Jul 15, 2023
1f3925d
finish implement circuit diff and expval diff
exAClior Jul 16, 2023
470a273
finish integrating
exAClior Jul 16, 2023
84d6309
add test to check spider numbers make sense
exAClior Jul 16, 2023
79bc45d
revise PR
exAClior Jul 17, 2023
6579c44
update variable substitution
exAClior Jul 24, 2023
4b982cb
add find negative symbol feature
exAClior Jul 24, 2023
ee905e6
fix einsum contraction order issue
exAClior Jul 25, 2023
33a35d5
fix concatenation of circuits
exAClior Jul 25, 2023
caabc29
finish test for exp val circuit
exAClior Jul 25, 2023
4216c94
construct variance ZXWDiagram test
exAClior Jul 25, 2023
1d4d79d
take care of importing vtx
exAClior Jul 27, 2023
044e2f5
implement stacking of zxwdiagram
exAClior Jul 27, 2023
b18ae24
correct triangle order
exAClior Jul 27, 2023
6805651
fix gradient calculation factor problem
exAClior Jul 28, 2023
dedf06b
test two variable integartion pass
exAClior Jul 28, 2023
b9a5b04
correct integration upto factor of 2
exAClior Jul 28, 2023
0c1e441
fix integration factors again
exAClior Jul 29, 2023
5629848
include test for variance
exAClior Jul 29, 2023
ee834f8
fix past test
exAClior Jul 29, 2023
4fe16ad
remove un-necessary dependencies
exAClior Aug 2, 2023
339e46b
correct nothing comparison
exAClior Aug 2, 2023
ab5a043
port differentiation
exAClior Aug 2, 2023
6eec454
fix PiUnit number testing quickly
exAClior Aug 2, 2023
dc3a3aa
make access output by qubit safer
exAClior Aug 2, 2023
6644b35
move helper functions to utils.jl
exAClior Aug 2, 2023
0c1c758
converted integration to rewrite API
exAClior Aug 2, 2023
504edb2
change intergration to fit rewrite API
exAClior Aug 3, 2023
e51272b
implement inplace stacking
exAClior Aug 4, 2023
da30a7a
remove un-necessary file
exAClior Aug 5, 2023
6870623
rename differentiatio rule
exAClior Aug 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Multigraphs = "7ebac608-6c66-46e6-9856-b5f43e107bac"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
exAClior marked this conversation as resolved.
Show resolved Hide resolved
YaoHIR = "6769671a-fce8-4286-b3f7-6099e1b1298a"
YaoLocations = "66df03fb-d475-48f7-b449-3d9064bf085b"

Expand Down
7 changes: 5 additions & 2 deletions src/ZXCalculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ export SpiderType, EdgeType, ZXWSpiderType
export AbstractZXDiagram, ZXDiagram, ZXGraph
export ZXWDiagram
export Rule, Match
export ZXWRule

export spider_type, phase, spiders, rem_spider!, rem_spiders!, scalar
export parameter
export push_gate!, pushfirst_gate!, tcount
export push_gate!, pushfirst_gate!, tcount, insert_wtrig!
export convert_to_chain, convert_to_zxd
export rewrite!, simplify!, clifford_simplification, full_reduction,
circuit_extraction, phase_teleportation
export random_circuit
export substitute_variables!, expval_circ!, diff_diagram!, stack_zxwd, concat!, integrate!

include("parameter.jl")
include("phase.jl")
Expand All @@ -40,6 +41,8 @@ include("to_eincode.jl")
include("utils.jl")

include("rules.jl")
include("zxw_rules.jl")
include("diff.jl")
include("simplify.jl")
include("circuit_extraction.jl")
include("phase_teleportation.jl")
Expand Down
202 changes: 202 additions & 0 deletions src/diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
using ZXCalculus: contains, dagger, concat!, stack

"""
Take derivative of ZXWDiagram with respect to a parameter

Assuming Spiders have Parameter of type PiUnit which is parameterized purely by θ
"""
function diff_diagram!(zxwd::ZXWDiagram{T,P}, θ::Symbol) where {T,P}
exAClior marked this conversation as resolved.
Show resolved Hide resolved
vs_pos = symbol_vertices(zxwd, θ)
vs_neg = symbol_vertices(zxwd, θ; neg = true)

length(vs_pos) + length(vs_neg) == 0 && return zxwd

add_global_phase!(zxwd, P(π / 2))
ChenZhao44 marked this conversation as resolved.
Show resolved Hide resolved
w_trig_vs = T[]

for v in vs_pos

h_v = @match spider_type(zxwd,v) begin
X(_) => add_spider!(zxwd, H, [v])
Z(_) => v
end

x_v = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 1.0)), [h_v])
w_v = add_spider!(zxwd, D, [x_v])

frac_v = @match parameter(zxwd, v) begin
PiUnit(_,_) => w_v
Factor(_,_) => error("Only supports PiUnit differentiation")
_ => error("not a valid parameter")

Check warning on line 30 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L29-L30

Added lines #L29 - L30 were not covered by tests
end
push!(w_trig_vs, frac_v)
end

for v in vs_neg

h_v = @match spider_type(zxwd,v) begin
X(_) => add_spider!(zxwd, H, [v])
Z(_) => v
end

x_v = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 1.0)), [h_v])
w_v = add_spider!(zxwd, D, [x_v])

frac_v = @match spider_type(zxwd, v).p begin
PiUnit(_, _) => add_spider!(zxwd, Z(Parameter(Val(:PiUnit), 1.0)), [w_v])
Factor(_, _) => error("Only supports PiUnit differentiation")
_ => error("not a valid parameter")

Check warning on line 48 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L47-L48

Added lines #L47 - L48 were not covered by tests
end
push!(w_trig_vs, frac_v)
end

head = insert_wtrig!(zxwd, w_trig_vs)

add_spider!(zxwd, X(Parameter(Val(:PiUnit), 1.0)), [head])
# our definition of x_tensor exceeds one power of sqrt(2)
add_power!(zxwd, -1)

return zxwd
end

"""
Construct ZXW Diagram for representing the expectation value circuit
"""
function expval_circ!(zxwd::ZXWDiagram{T,P}, H::String) where {T,P}
# convert U to U H U^\dagger
zxwd_dag = dagger(zxwd)
for (i, h) in enumerate(H)
if h == 'Z'
push_gate!(zxwd, Val(:Z), i, 1.0)
elseif h == 'X'
push_gate!(zxwd, Val(:X), i, 1.0)

Check warning on line 72 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L72

Added line #L72 was not covered by tests
elseif h == 'Y'
push_gate!(zxwd, Val(:Z), i, 1.0)
push_gate!(zxwd, Val(:X), i, 1.0)
add_global_phase!(zxwd, P(π / 2))

Check warning on line 76 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L74-L76

Added lines #L74 - L76 were not covered by tests
elseif h == 'I'
continue
else
error("Invalid Hamiltonian, enter only Z, X, Y")

Check warning on line 80 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L80

Added line #L80 was not covered by tests
end
end
concat!(zxwd, zxwd_dag)
return zxwd
end

"""

Finds vertices of Spider that contains the parameter θ or -θ
"""
function symbol_vertices(zxwd::ZXWDiagram{T,P}, θ::Symbol; neg::Bool = false) where {T,P}
if neg
target = Expr(:call, :-, θ)
else
target = θ
end
matched = T[]
for v in vertices(zxwd.mg)
res = @match spider_type(zxwd, v) begin
Z(p1) && if contains(p1, target)
end => v
X(p1) && if contains(p1, target)
end => v
_ => nothing
end
res !== nothing && push!(matched, v)
end
return matched
end

"""
Replace symbols in ZXW Diagram with specific values
"""
function substitute_variables!(
zxwd::ZXWDiagram{T,P},
sbd::Dict{Symbol,<:Number},
) where {T,P}
for (θ, val) in sbd
for negative in [false, true]
matched_pos = symbol_vertices(zxwd, θ; neg = negative)
val = negative ? -val : val
for idx in matched_pos
p = spider_type(zxwd, idx).p
@match p begin
PiUnit(pu, _) => set_phase!(zxwd, idx, Parameter(Val(:PiUnit), val))
Factor(pf, _) => set_phase!(zxwd, idx, Parameter(Val(:Factor), val))

Check warning on line 126 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L126

Added line #L126 was not covered by tests
end
end
end
end
return zxwd
end

"""
Integrate over the Spiders at locs with respect to the parameter θ.

User need to check that the parameters are indeed in the form of k * θ where k is Int
"""
function integrate!(zxwd::ZXWDiagram{T,P}, locs::Vector{T}) where {T,P}
length(locs) == 2 && return integrate2!(zxwd, locs[1], locs[2])
length(locs) == 4 && return integrate4!(zxwd, locs[1], locs[2], locs[3], locs[4])
end

function integrate2!(zxwd::ZXWDiagram{T,P}, loc1::T, loc2::T) where {T,P}
exAClior marked this conversation as resolved.
Show resolved Hide resolved
loc1 = int_prep!(zxwd, loc1)
loc2 = int_prep!(zxwd, loc2)
add_edge!(zxwd.mg, loc1, loc2)
return zxwd
end

"""
Integrate two pairs of +/- parameter. Theorem 23 of https://arxiv.org/abs/2201.13250
"""
function integrate4!(zxwd::ZXWDiagram{T,P}, loca::T, locb::T, locc::T, locd::T) where {T,P}
loca = int_prep!(zxwd, loca)
locb = int_prep!(zxwd, locb)
locc = int_prep!(zxwd, locc)
locd = int_prep!(zxwd, locd)

# a, b = + , - \theta
# c, d = + , - \theta
loca = add_spider!(zxwd, Z(Parameter(Val(:PiUnit), 0)), [loca])
locb = add_spider!(zxwd, Z(Parameter(Val(:PiUnit), 0)), [locb])
locc = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 0)), [locc])
locd = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 0)), [locd])

add_edge!(zxwd, loca, locc)
add_edge!(zxwd, locb, locd)

locm = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 0)), [loca, locb])
locm = add_spider!(zxwd, D, [locm])
locm = add_spider!(zxwd, X(Parameter(Val(:PiUnit), 1.0)), [locm])
add_spider!(zxwd, Z(Parameter(Val(:PiUnit), 0)), [locm, locc, locd])

# pink spider is different from red spider, we had three of them
# each with three legs, 3 * (3-2)/2 powers of 2 need to be added
# see 2307.01803
add_power!(zxwd,3)
return zxwd
end

"""
Prepare spider at loc for integration.

Perform the simplified step of zeroing out phase of spider
and readying it for integration
1. If target spider is X spider, turn it to Z by adding H to all its legs
2. Pull out the Phase of the spider
3. zero out the phase
4. change the current spider back to its original type if necessary,
will generate one extra H spider.
"""
function int_prep!(zxwd::ZXWDiagram{T,P}, loc::T) where {T,P}
set_phase!(zxwd, loc, Parameter(Val(:PiUnit), 0.0))

new_loc = @match spider_type(zxwd, loc) begin
X(_) => add_spider!(zxwd, H, [loc])
Z(_) => loc

Check warning on line 198 in src/diff.jl

View check run for this annotation

Codecov / codecov/patch

src/diff.jl#L198

Added line #L198 was not covered by tests
_ => error("Not a valid Spider to integrate over")
end
return new_loc
end
65 changes: 59 additions & 6 deletions src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
Factor(f, _) => Parameter(Val(:Factor), f)
end

function Base.show(io::IO, p::Parameter)
function Base.show(io::IO, ::MIME"text/plain", p::Parameter)

Check warning on line 45 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L45

Added line #L45 was not covered by tests
@match p begin
PiUnit(pu, _) && if pu isa Number
end => print(io, "$(pu)⋅π")
Expand Down Expand Up @@ -70,9 +70,27 @@
Base.:(==)(p1::Parameter, p2::Number) = eqeq(p1, p2)
Base.:(==)(p1::Number, p2::Parameter) = eqeq(p2, p1)

# following the same convention in Phase.jl implementation
function Base.contains(p::Parameter, θ::Symbol)
@match p begin
PiUnit(pu, pt) && if !(pt <: Number)
end => Base.contains(repr(pu), ":" * string(θ))
_ => false
end
end

function Base.contains(p::Parameter, θ::Expr)
@match p begin
PiUnit(pu, pt) && if !(pt <: Number)
end => Base.contains(repr(pu), ":(" * string(θ) * ")")
_ => false
end
end

# following the same convention in phase.jl implementation
# comparison have inconsistent, we are comparing phases to numbers
# if cause trouble, will change
#

Base.isless(p1::Parameter, p2::Number) = @match p1 begin
PiUnit(_...) => p1.pu isa Number && p1.pu < p2
_ => p1.f < p2
Expand Down Expand Up @@ -109,8 +127,9 @@
@match (p1, p2) begin
(PiUnit(pu1, _), PiUnit(pu2, _)) && if pu1 isa Number && pu2 isa Number
end => Parameter(Val(:PiUnit), pu1 + pu2)
(PiUnit(pu1, _), PiUnit(pu2, _)) && if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :+, pu1, pu2))
(PiUnit(pu1, pu1_t), PiUnit(pu2, pu2_t)) &&
if !(pu1 isa Number) || !(pu2 isa Number)
end => PiUnit(Expr(:call, :+, pu1, pu2), Base.promote_op(+, pu1_t, pu2_t))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 + f2)
(PiUnit(pu1, _), Factor(f2, _)) => Parameter(Val(:Factor), exp(im * pu1 * π) * f2)
(Factor(f1, _), PiUnit(pu2, _)) => Parameter(Val(:Factor), exp(im * pu2 * π) * f1)
Expand All @@ -134,8 +153,8 @@
end => Parameter(Val(:PiUnit), pu1 - pu2)
(PiUnit(pu1, pu_t1), PiUnit(pu2, pu_t2)) &&
if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu1, pu2))
(Factor(f1, _), Factor(f2, _)) => Factor(f1 - f2)
end => PiUnit(Expr(:call, :-, pu1, pu2), Base.promote_op(-, pu_t1, pu_t2))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 - f2)
(PiUnit(_...), Factor(_...)) => Parameter(Val(:Factor), exp(im * p1.pu * π) - p2.f)
(Factor(_...), PiUnit(_...)) => Parameter(Val(:Factor), p1.f - exp(im * p2.pu * π))
(_, PiUnit(_...)) => Parameter(Val(:PiUnit), p1 - p2.pu)
Expand All @@ -151,6 +170,29 @@
Base.:(-)(p1::Number, p2::Parameter) = subt_param(p1, p2)
Base.:(-)(p1::Parameter, p2::Number) = add_param(p1, -p2)

# needed for tensor contraction
function mul_param(p1, p2)
@match (p1, p2) begin
(PiUnit(p1u, _), PiUnit(p2u, _)) && if p1u isa Number && p2u isa Number
end => exp(im * (p1u + p2u) * π)
(PiUnit(p1u, _), n2::Number) && if p1u isa Number
end => exp(im * p1u * π) * n2
(Factor(f1, _), PiUnit(p2u, _)) && if p2u isa Number
end => f1 * exp(im * p2u * π)
(PiUnit(p1u, _), Factor(f2, _)) && if p1u isa Number
end => f2 * exp(im * p1u * π)
(Factor(f1, _), Factor(f2, _)) => f1 * f2
(Factor(f1, _), n2::Number) => f1 * n2
_ => error(

Check warning on line 186 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L174-L186

Added lines #L174 - L186 were not covered by tests
"Invalid input '$(p1)' of type $(typeof(p1)) and '$(p2)' of type $(typeof(p2)) for ADT: *",
)
end
end

Base.:(*)(p1::Parameter, p2::Parameter) = mul_param(p1, p2)
Base.:(*)(p1::Parameter, p2::Number) = mul_param(p1, p2)
Base.:(*)(p1::Number, p2::Parameter) = mul_param(p2, p1)

Check warning on line 194 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L192-L194

Added lines #L192 - L194 were not covered by tests

function Base.rem(p::Parameter, d::Number)
@match p begin
PiUnit(pu, pu_t) && if pu isa Number
Expand All @@ -163,3 +205,14 @@
)
end
end

function Base.inv(p::Parameter)
@match p begin
PiUnit(pu, pu_t) && if pu isa Number
end => Parameter(Val(:PiUnit), -pu)
PiUnit(pu, pu_t) && if !(pu isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu))
Factor(f, _) => Parameter(Val(:Factor), inv(f))

Check warning on line 215 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L215

Added line #L215 was not covered by tests
_ => error("Invalid input '$(p)' of type $(typeof(p)) for ADT: inv")
end
end
18 changes: 13 additions & 5 deletions src/to_eincode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,25 @@ function to_eincode(zxwd::ZXWDiagram{T,P}) where {T,P}
Input(q) => nothing
Output(q) => nothing
end
if res !== nothing

if res != nothing
exAClior marked this conversation as resolved.
Show resolved Hide resolved
push!(ixs, to_eincode_indices(zxwd, v))
push!(tensors, res)
else
push!(iy, to_eincode_indices(zxwd, v)[])
end
end

for v in get_outputs(zxwd)
push!(iy, to_eincode_indices(zxwd, v)[])
end

for v in get_inputs(zxwd)
push!(iy, to_eincode_indices(zxwd, v)[])
end

scalar_tensor = zeros(ComplexF64, ())

scalar_tensor[] = ZXCalculus.unwrap_scalar(scalar(zxwd))
push!(ixs, [])
push!(ixs, Tuple{T,T,T}[])
push!(tensors, scalar_tensor)
return EinCode(ixs, iy), tensors
end
Expand All @@ -60,6 +67,7 @@ function to_eincode_indices(zxwd::ZXWDiagram{T,P}, v) where {T,P}
end
return ids
end

edge_index(v1, v2, mul) = (min(v1, v2), max(v1, v2), mul)

function z_tensor(n::Int, α::Parameter)
Expand All @@ -81,7 +89,7 @@ function x_tensor(n::Int, α::Parameter)
shape = (fill(2, n)...,)
factor = @match α begin
PiUnit(pu, _) => exp(im * pu * π)
Factor(f,_) => f
Factor(f, _) => f
_ => error("Invalid parameter type for X-spider")
end
return reshape(reduce(kron, fill(pos, n)) + factor * reduce(kron, fill(neg, n)), shape)
Expand Down
Loading
Loading