Skip to content

Commit

Permalink
ZXW rewrite rules + Differentiation (#93)
Browse files Browse the repository at this point in the history
* add variable substitution feature

* add find negative symbol feature

* fix einsum contraction order issue

* implement expectation value zxw diagram

*implement differentiation and integration with rewrite API

* implement inplace concatenation

* implement inplace stacking
  • Loading branch information
exAClior committed Aug 6, 2023
1 parent 5161a51 commit 6bb22d9
Show file tree
Hide file tree
Showing 11 changed files with 980 additions and 41 deletions.
6 changes: 4 additions & 2 deletions src/ZXCalculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ export SpiderType, EdgeType, ZXWSpiderType
export AbstractZXDiagram, ZXDiagram, ZXGraph
export ZXWDiagram
export Rule, Match
export CalcRule

export spider_type, phase, spiders, rem_spider!, rem_spiders!, scalar
export parameter
export push_gate!, pushfirst_gate!, tcount
export push_gate!, pushfirst_gate!, tcount, insert_wtrig!
export convert_to_chain, convert_to_zxd
export rewrite!, simplify!, clifford_simplification, full_reduction,
circuit_extraction, phase_teleportation
export random_circuit
export substitute_variables!, expval_circ!, stack_zxwd!, concat!

include("parameter.jl")
include("phase.jl")
Expand All @@ -40,6 +41,7 @@ include("to_eincode.jl")
include("utils.jl")

include("rules.jl")
include("zxw_rules.jl")
include("simplify.jl")
include("circuit_extraction.jl")
include("phase_teleportation.jl")
Expand Down
65 changes: 59 additions & 6 deletions src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Base.copy(p::Parameter) = @match p begin
Factor(f, _) => Parameter(Val(:Factor), f)
end

function Base.show(io::IO, p::Parameter)
function Base.show(io::IO, ::MIME"text/plain", p::Parameter)
@match p begin
PiUnit(pu, _) && if pu isa Number
end => print(io, "$(pu)⋅π")
Expand Down Expand Up @@ -70,9 +70,27 @@ Base.:(==)(p1::Parameter, p2::Parameter) = eqeq(p1, p2)
Base.:(==)(p1::Parameter, p2::Number) = eqeq(p1, p2)
Base.:(==)(p1::Number, p2::Parameter) = eqeq(p2, p1)

# following the same convention in Phase.jl implementation
function Base.contains(p::Parameter, θ::Symbol)
@match p begin
PiUnit(pu, pt) && if !(pu isa Number)
end => Base.contains(repr(pu), ":" * string(θ))
_ => false
end
end

function Base.contains(p::Parameter, θ::Expr)
@match p begin
PiUnit(pu, pt) && if !(pu isa Number)
end => Base.contains(repr(pu), ":(" * string(θ) * ")")
_ => false
end
end

# following the same convention in phase.jl implementation
# comparison have inconsistent, we are comparing phases to numbers
# if cause trouble, will change
#

Base.isless(p1::Parameter, p2::Number) = @match p1 begin
PiUnit(_...) => p1.pu isa Number && p1.pu < p2
_ => p1.f < p2
Expand Down Expand Up @@ -109,8 +127,9 @@ function add_param(p1, p2)
@match (p1, p2) begin
(PiUnit(pu1, _), PiUnit(pu2, _)) && if pu1 isa Number && pu2 isa Number
end => Parameter(Val(:PiUnit), pu1 + pu2)
(PiUnit(pu1, _), PiUnit(pu2, _)) && if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :+, pu1, pu2))
(PiUnit(pu1, pu1_t), PiUnit(pu2, pu2_t)) &&
if !(pu1 isa Number) || !(pu2 isa Number)
end => PiUnit(Expr(:call, :+, pu1, pu2), Base.promote_op(+, pu1_t, pu2_t))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 + f2)
(PiUnit(pu1, _), Factor(f2, _)) => Parameter(Val(:Factor), exp(im * pu1 * π) * f2)
(Factor(f1, _), PiUnit(pu2, _)) => Parameter(Val(:Factor), exp(im * pu2 * π) * f1)
Expand All @@ -134,8 +153,8 @@ function subt_param(p1, p2)
end => Parameter(Val(:PiUnit), pu1 - pu2)
(PiUnit(pu1, pu_t1), PiUnit(pu2, pu_t2)) &&
if !(pu1 isa Number) || !(pu2 isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu1, pu2))
(Factor(f1, _), Factor(f2, _)) => Factor(f1 - f2)
end => PiUnit(Expr(:call, :-, pu1, pu2), Base.promote_op(-, pu_t1, pu_t2))
(Factor(f1, _), Factor(f2, _)) => Parameter(Val(:Factor), f1 - f2)
(PiUnit(_...), Factor(_...)) => Parameter(Val(:Factor), exp(im * p1.pu * π) - p2.f)
(Factor(_...), PiUnit(_...)) => Parameter(Val(:Factor), p1.f - exp(im * p2.pu * π))
(_, PiUnit(_...)) => Parameter(Val(:PiUnit), p1 - p2.pu)
Expand All @@ -151,6 +170,29 @@ Base.:(-)(p1::Parameter, p2::Parameter) = subt_param(p1, p2)
Base.:(-)(p1::Number, p2::Parameter) = subt_param(p1, p2)
Base.:(-)(p1::Parameter, p2::Number) = add_param(p1, -p2)

# needed for tensor contraction
function mul_param(p1, p2)
@match (p1, p2) begin
(PiUnit(p1u, _), PiUnit(p2u, _)) && if p1u isa Number && p2u isa Number
end => exp(im * (p1u + p2u) * π)
(PiUnit(p1u, _), n2::Number) && if p1u isa Number
end => exp(im * p1u * π) * n2
(Factor(f1, _), PiUnit(p2u, _)) && if p2u isa Number
end => f1 * exp(im * p2u * π)
(PiUnit(p1u, _), Factor(f2, _)) && if p1u isa Number
end => f2 * exp(im * p1u * π)
(Factor(f1, _), Factor(f2, _)) => f1 * f2
(Factor(f1, _), n2::Number) => f1 * n2
_ => error(
"Invalid input '$(p1)' of type $(typeof(p1)) and '$(p2)' of type $(typeof(p2)) for ADT: *",
)
end
end

Base.:(*)(p1::Parameter, p2::Parameter) = mul_param(p1, p2)
Base.:(*)(p1::Parameter, p2::Number) = mul_param(p1, p2)
Base.:(*)(p1::Number, p2::Parameter) = mul_param(p2, p1)

function Base.rem(p::Parameter, d::Number)
@match p begin
PiUnit(pu, pu_t) && if pu isa Number
Expand All @@ -163,3 +205,14 @@ function Base.rem(p::Parameter, d::Number)
)
end
end

function Base.inv(p::Parameter)
@match p begin
PiUnit(pu, pu_t) && if pu isa Number
end => Parameter(Val(:PiUnit), -pu)
PiUnit(pu, pu_t) && if !(pu isa Number)
end => Parameter(Val(:PiUnit), Expr(:call, :-, pu))
Factor(f, _) => Parameter(Val(:Factor), inv(f))
_ => error("Invalid input '$(p)' of type $(typeof(p)) for ADT: inv")
end
end
18 changes: 13 additions & 5 deletions src/to_eincode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,25 @@ function to_eincode(zxwd::ZXWDiagram{T,P}) where {T,P}
Input(q) => nothing
Output(q) => nothing
end
if res !== nothing

if !isnothing(res)
push!(ixs, to_eincode_indices(zxwd, v))
push!(tensors, res)
else
push!(iy, to_eincode_indices(zxwd, v)[])
end
end

for v in get_outputs(zxwd)
push!(iy, to_eincode_indices(zxwd, v)[])
end

for v in get_inputs(zxwd)
push!(iy, to_eincode_indices(zxwd, v)[])
end

scalar_tensor = zeros(ComplexF64, ())

scalar_tensor[] = ZXCalculus.unwrap_scalar(scalar(zxwd))
push!(ixs, [])
push!(ixs, Tuple{T,T,T}[])
push!(tensors, scalar_tensor)
return EinCode(ixs, iy), tensors
end
Expand All @@ -60,6 +67,7 @@ function to_eincode_indices(zxwd::ZXWDiagram{T,P}, v) where {T,P}
end
return ids
end

edge_index(v1, v2, mul) = (min(v1, v2), max(v1, v2), mul)

function z_tensor(n::Int, α::Parameter)
Expand All @@ -81,7 +89,7 @@ function x_tensor(n::Int, α::Parameter)
shape = (fill(2, n)...,)
factor = @match α begin
PiUnit(pu, _) => exp(im * pu * π)
Factor(f,_) => f
Factor(f, _) => f
_ => error("Invalid parameter type for X-spider")
end
return reshape(reduce(kron, fill(pos, n)) + factor * reduce(kron, fill(neg, n)), shape)
Expand Down
Loading

0 comments on commit 6bb22d9

Please sign in to comment.