Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZXW rewrite rules + Differentiation #93

Merged
merged 40 commits into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
88b8ebc
include previous left change
exAClior Jul 3, 2023
88dfeab
implement first two rewrite rule
exAClior Jul 7, 2023
0090310
add revise, finish implment match
exAClior Jul 8, 2023
3edfbb1
finish test + implement basic rewrite
exAClior Jul 8, 2023
51d5696
begin diff operation implementation
exAClior Jul 9, 2023
90fa5bf
start implement diff rule for general case
exAClior Jul 10, 2023
3aeb785
finish diff implementation
exAClior Jul 11, 2023
806ab36
replace adt rules with multiple dispatch
exAClior Jul 14, 2023
7c47450
enable variable for both signs
exAClior Jul 15, 2023
923b27f
implement theorem 16 and start cor 17
exAClior Jul 15, 2023
1f3925d
finish implement circuit diff and expval diff
exAClior Jul 16, 2023
470a273
finish integrating
exAClior Jul 16, 2023
84d6309
add test to check spider numbers make sense
exAClior Jul 16, 2023
79bc45d
revise PR
exAClior Jul 17, 2023
6579c44
update variable substitution
exAClior Jul 24, 2023
4b982cb
add find negative symbol feature
exAClior Jul 24, 2023
ee905e6
fix einsum contraction order issue
exAClior Jul 25, 2023
33a35d5
fix concatenation of circuits
exAClior Jul 25, 2023
caabc29
finish test for exp val circuit
exAClior Jul 25, 2023
4216c94
construct variance ZXWDiagram test
exAClior Jul 25, 2023
1d4d79d
take care of importing vtx
exAClior Jul 27, 2023
044e2f5
implement stacking of zxwdiagram
exAClior Jul 27, 2023
b18ae24
correct triangle order
exAClior Jul 27, 2023
6805651
fix gradient calculation factor problem
exAClior Jul 28, 2023
dedf06b
test two variable integartion pass
exAClior Jul 28, 2023
b9a5b04
correct integration upto factor of 2
exAClior Jul 28, 2023
0c1e441
fix integration factors again
exAClior Jul 29, 2023
5629848
include test for variance
exAClior Jul 29, 2023
ee834f8
fix past test
exAClior Jul 29, 2023
4fe16ad
remove un-necessary dependencies
exAClior Aug 2, 2023
339e46b
correct nothing comparison
exAClior Aug 2, 2023
ab5a043
port differentiation
exAClior Aug 2, 2023
6eec454
fix PiUnit number testing quickly
exAClior Aug 2, 2023
dc3a3aa
make access output by qubit safer
exAClior Aug 2, 2023
6644b35
move helper functions to utils.jl
exAClior Aug 2, 2023
0c1c758
converted integration to rewrite API
exAClior Aug 2, 2023
504edb2
change intergration to fit rewrite API
exAClior Aug 3, 2023
e51272b
implement inplace stacking
exAClior Aug 4, 2023
da30a7a
remove un-necessary file
exAClior Aug 5, 2023
6870623
rename differentiatio rule
exAClior Aug 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/ZXCalculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ export SpiderType, EdgeType, ZXWSpiderType
export AbstractZXDiagram, ZXDiagram, ZXGraph
export ZXWDiagram
export Rule, Match
export CalcRule

export spider_type, phase, spiders, rem_spider!, rem_spiders!, scalar
export parameter
export push_gate!, pushfirst_gate!, tcount
export push_gate!, pushfirst_gate!, tcount, insert_wtrig!
export convert_to_chain, convert_to_zxd
export rewrite!, simplify!, clifford_simplification, full_reduction,
circuit_extraction, phase_teleportation
export random_circuit
export substitute_variables!, expval_circ!, stack_zxwd!, concat!

include("parameter.jl")
include("phase.jl")
Expand All @@ -40,6 +41,7 @@ include("to_eincode.jl")
include("utils.jl")

include("rules.jl")
include("zxw_rules.jl")
include("simplify.jl")
include("circuit_extraction.jl")
include("phase_teleportation.jl")
Expand Down
65 changes: 59 additions & 6 deletions src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
Factor(f, _) => Parameter(Val(:Factor), f)
end

function Base.show(io::IO, p::Parameter)
function Base.show(io::IO, ::MIME"text/plain", p::Parameter)

Check warning on line 45 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L45

Added line #L45 was not covered by tests
@match p begin
PiUnit(pu, _) && if pu isa Number
end => print(io, "$(pu)⋅π")
Expand Down Expand Up @@ -70,9 +70,27 @@
Base.:(==)(p1::Parameter, p2::Number) = eqeq(p1, p2)
Base.:(==)(p1::Number, p2::Parameter) = eqeq(p2, p1)

# following the same convention in Phase.jl implementation
function Base.contains(p::Parameter, θ::Symbol)
@match p begin
PiUnit(pu, pt) && if !(pu isa Number)
end => Base.contains(repr(pu), ":" * string(θ))
_ => false
end
end

function Base.contains(p::Parameter, θ::Expr)
@match p begin
PiUnit(pu, pt) && if !(pu isa Number)
end => Base.contains(repr(pu), ":(" * string(θ) * ")")
_ => false
end
end

# following the same convention in phase.jl implementation
# comparison have inconsistent, we are comparing phases to numbers
# if cause trouble, will change
#

Base.isless(p1::Parameter, p2::Number) = @match p1 begin
PiUnit(_...) => p1.pu isa Number && p1.pu < p2
_ => p1.f < p2
Expand Down Expand Up @@ -109,8 +127,9 @@
@match (p1, p2) begin
(PiUnit(pu1, _), PiUnit(pu2, _)) && if pu1 isa Number && pu2 isa Number
end => Parameter(Val(:PiUnit), pu1 + pu2)
(PiUnit(pu1, _), PiUnit(pu2, _)) && if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :+, pu1, pu2))
(PiUnit(pu1, pu1_t), PiUnit(pu2, pu2_t)) &&
if !(pu1 isa Number) || !(pu2 isa Number)
end => PiUnit(Expr(:call, :+, pu1, pu2), Base.promote_op(+, pu1_t, pu2_t))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 + f2)
(PiUnit(pu1, _), Factor(f2, _)) => Parameter(Val(:Factor), exp(im * pu1 * π) * f2)
(Factor(f1, _), PiUnit(pu2, _)) => Parameter(Val(:Factor), exp(im * pu2 * π) * f1)
Expand All @@ -134,8 +153,8 @@
end => Parameter(Val(:PiUnit), pu1 - pu2)
(PiUnit(pu1, pu_t1), PiUnit(pu2, pu_t2)) &&
if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu1, pu2))
(Factor(f1, _), Factor(f2, _)) => Factor(f1 - f2)
end => PiUnit(Expr(:call, :-, pu1, pu2), Base.promote_op(-, pu_t1, pu_t2))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 - f2)
(PiUnit(_...), Factor(_...)) => Parameter(Val(:Factor), exp(im * p1.pu * π) - p2.f)
(Factor(_...), PiUnit(_...)) => Parameter(Val(:Factor), p1.f - exp(im * p2.pu * π))
(_, PiUnit(_...)) => Parameter(Val(:PiUnit), p1 - p2.pu)
Expand All @@ -151,6 +170,29 @@
Base.:(-)(p1::Number, p2::Parameter) = subt_param(p1, p2)
Base.:(-)(p1::Parameter, p2::Number) = add_param(p1, -p2)

# needed for tensor contraction
function mul_param(p1, p2)
@match (p1, p2) begin
(PiUnit(p1u, _), PiUnit(p2u, _)) && if p1u isa Number && p2u isa Number
end => exp(im * (p1u + p2u) * π)
(PiUnit(p1u, _), n2::Number) && if p1u isa Number
end => exp(im * p1u * π) * n2
(Factor(f1, _), PiUnit(p2u, _)) && if p2u isa Number
end => f1 * exp(im * p2u * π)
(PiUnit(p1u, _), Factor(f2, _)) && if p1u isa Number
end => f2 * exp(im * p1u * π)
(Factor(f1, _), Factor(f2, _)) => f1 * f2
(Factor(f1, _), n2::Number) => f1 * n2
_ => error(

Check warning on line 186 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L174-L186

Added lines #L174 - L186 were not covered by tests
"Invalid input '$(p1)' of type $(typeof(p1)) and '$(p2)' of type $(typeof(p2)) for ADT: *",
)
end
end

Base.:(*)(p1::Parameter, p2::Parameter) = mul_param(p1, p2)
Base.:(*)(p1::Parameter, p2::Number) = mul_param(p1, p2)
Base.:(*)(p1::Number, p2::Parameter) = mul_param(p2, p1)

Check warning on line 194 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L192-L194

Added lines #L192 - L194 were not covered by tests

function Base.rem(p::Parameter, d::Number)
@match p begin
PiUnit(pu, pu_t) && if pu isa Number
Expand All @@ -163,3 +205,14 @@
)
end
end

function Base.inv(p::Parameter)
@match p begin
PiUnit(pu, pu_t) && if pu isa Number
end => Parameter(Val(:PiUnit), -pu)
PiUnit(pu, pu_t) && if !(pu isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu))
Factor(f, _) => Parameter(Val(:Factor), inv(f))

Check warning on line 215 in src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter.jl#L215

Added line #L215 was not covered by tests
_ => error("Invalid input '$(p)' of type $(typeof(p)) for ADT: inv")
end
end
18 changes: 13 additions & 5 deletions src/to_eincode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,25 @@ function to_eincode(zxwd::ZXWDiagram{T,P}) where {T,P}
Input(q) => nothing
Output(q) => nothing
end
if res !== nothing

if !isnothing(res)
push!(ixs, to_eincode_indices(zxwd, v))
push!(tensors, res)
else
push!(iy, to_eincode_indices(zxwd, v)[])
end
end

for v in get_outputs(zxwd)
push!(iy, to_eincode_indices(zxwd, v)[])
end

for v in get_inputs(zxwd)
push!(iy, to_eincode_indices(zxwd, v)[])
end

scalar_tensor = zeros(ComplexF64, ())

scalar_tensor[] = ZXCalculus.unwrap_scalar(scalar(zxwd))
push!(ixs, [])
push!(ixs, Tuple{T,T,T}[])
push!(tensors, scalar_tensor)
return EinCode(ixs, iy), tensors
end
Expand All @@ -60,6 +67,7 @@ function to_eincode_indices(zxwd::ZXWDiagram{T,P}, v) where {T,P}
end
return ids
end

edge_index(v1, v2, mul) = (min(v1, v2), max(v1, v2), mul)

function z_tensor(n::Int, α::Parameter)
Expand All @@ -81,7 +89,7 @@ function x_tensor(n::Int, α::Parameter)
shape = (fill(2, n)...,)
factor = @match α begin
PiUnit(pu, _) => exp(im * pu * π)
Factor(f,_) => f
Factor(f, _) => f
_ => error("Invalid parameter type for X-spider")
end
return reshape(reduce(kron, fill(pos, n)) + factor * reduce(kron, fill(neg, n)), shape)
Expand Down
Loading
Loading