This repository has been archived by the owner on Dec 18, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
focus.jl
134 lines (115 loc) · 3.94 KB
/
focus.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
using YaoBase, TupleTools
export focus!, relax!, partial_tr, exchange_sysenv
"""
    contiguous_shape_orders(shape, orders)

Collapse every run of dimensions that appear consecutively (`d, d+1, ...`) in
`orders` into a single dimension. Return `(merged_shape, merged_orders)`:
`merged_shape` holds the sizes of the merged dimensions in their original storage
order, and `merged_orders` is the permutation to apply to an array of that shape.

# Example
```jldoctest; setup=:(using YaoArrayRegister)
julia> YaoArrayRegister.contiguous_shape_orders((2, 3, 4), (1, 2, 3))
([24], [1])
```
"""
function contiguous_shape_orders(shape, orders)
    merged_shape = Int[]
    group_heads = Int[]   # first original dimension of each merged group
    previous = -1
    for d in orders
        if d != previous + 1
            # `d` does not continue the current run, so it starts a new group.
            push!(group_heads, d)
            push!(merged_shape, shape[d])
        else
            merged_shape[end] *= shape[d]
        end
        previous = d
    end
    # The groups were collected in permuted order. Sorting them by their first
    # original dimension recovers storage order, and the inverse of that sort
    # is the permutation to apply to the merged array.
    storage_order = sortperm(group_heads)
    return merged_shape[storage_order], invperm(storage_order)
end
# NOTE: move_ahead returns a Tuple; building a Vector instead is much slower (see the benchmarks below).
# Before:
# julia> @benchmark move_ahead($c, $o)
# BenchmarkTools.Trial:
# memory estimate: 14.87 KiB
# allocs estimate: 392
# --------------
# minimum time: 29.062 μs (0.00% GC)
# median time: 31.316 μs (0.00% GC)
# mean time: 39.724 μs (14.37% GC)
# maximum time: 39.013 ms (99.86% GC)
# --------------
# samples: 10000
# evals/sample: 1
# After:
# julia> @benchmark move_ahead($c, $o)
# BenchmarkTools.Trial:
# memory estimate: 4.04 KiB
# allocs estimate: 24
# --------------
# minimum time: 2.848 μs (0.00% GC)
# median time: 3.045 μs (0.00% GC)
# mean time: 4.013 μs (18.20% GC)
# maximum time: 4.494 ms (99.86% GC)
# --------------
# samples: 10000
# evals/sample: 9
"""
    move_ahead(collection, orders)
    move_ahead(ndim::Int, orders)

Return a tuple that starts with `orders`, followed by the elements of
`collection` not in `orders`, kept in their original order. The `ndim` form
uses `1:ndim` as the collection.
"""
function move_ahead(collection, orders)
    remaining = setdiff(collection, orders)
    return (orders..., remaining...)
end
move_ahead(ndim::Int, orders) = move_ahead(1:ndim, orders)
"""
    group_permutedims(A::AbstractArray, orders)

Permute the dimensions of `A` by `orders`. Runs of dimensions that stay adjacent
in `orders` are merged first, so fewer, larger dimensions are permuted.

The elements come out in the same column-major order as `permutedims(A, orders)`.
However, the returned array has the *merged* shape, not `size(A)` permuted.

Throws `DimensionMismatch` if `length(orders) != ndims(A)`.
"""
function group_permutedims(A::AbstractArray, orders)
    # Validate explicitly instead of with `@assert`. Assertions may be disabled,
    # which would let a wrong-length `orders` reach the unchecked implementation.
    length(orders) == ndims(A) || throw(DimensionMismatch(
        "number of orders does not match number of dimensions: " *
        "got $(length(orders)) orders for a $(ndims(A))-dimensional array",
    ))
    return unsafe_group_permutedims(A, orders)
end
# An `NTuple{N,Int}` for an `N`-dimensional array already has the right length,
# so forward directly without the runtime check.
function group_permutedims(A::AbstractArray{T,N}, orders::NTuple{N,Int}) where {T,N}
    return unsafe_group_permutedims(A, orders)
end
# Permute `A` by `orders` after merging dimensions that stay adjacent in `orders`.
# This performs no length check on `orders`; that is left to the caller.
function unsafe_group_permutedims(A::AbstractArray, orders)
    merged_size, merged_perm = contiguous_shape_orders(size(A), orders)
    grouped = reshape(A, Tuple(merged_size))
    return permutedims(grouped, merged_perm)
end
"""
    is_order_same(locs) -> Bool

Return `true` when `locs` is exactly `1, 2, ..., length(locs)` in that order,
so no permutation is needed to bring those locations to the front.
"""
function is_order_same(locs)
    for (expected, loc) in enumerate(locs)
        loc == expected || return false
    end
    return true
end
# NOTE: locations are not the same as orders.
# locations: a subset of the wire (qubit) positions.
# orders: every wire position, listed in some order (i.e. a permutation).
function YaoBase.focus!(r::ArrayReg{B}, locs) where {B}
    # Move the qubits at `locs` to the front, then reshape the state so that
    # those `length(locs)` qubits index the rows of the state matrix.
    arr = if is_order_same(locs)
        # Already at the front and in order, so no data movement is needed.
        r.state
    else
        # `+ 1` includes one extra trailing dimension of `hypercubic(r)`
        # (presumably the batch dimension; confirm against `hypercubic`).
        perm = move_ahead(nactive(r) + 1, locs)
        group_permutedims(hypercubic(r), perm)
    end
    r.state = reshape(arr, 1 << length(locs), :)
    return r
end
function YaoBase.relax!(r::ArrayReg{B}, locs; to_nactive::Int = nqubits(r)) where {B}
    # Undo a previous `focus!(r, locs)`: restore `to_nactive` active qubits and,
    # if needed, apply the inverse of the permutation that `focus!` used.
    nrows = 1 << to_nactive
    r.state = reshape(state(r), nrows, :)
    is_order_same(locs) && return r
    inverse_perm = TupleTools.invperm(move_ahead(to_nactive + 1, locs))
    r.state = reshape(group_permutedims(hypercubic(r), inverse_perm), nrows, :)
    return r
end
function YaoBase.partial_tr(r::ArrayReg{B}, locs) where {B}
    # Keep every qubit that is not listed in `locs` and focus on those.
    kept = setdiff(1:nqubits(r), locs)
    focus!(r, kept)
    # Sum `rank3(r)` over its second dimension (presumably indexing the qubits
    # in `locs` after the focus above; confirm against `rank3`).
    # NOTE(review): this sums amplitudes and renormalizes; it is not a
    # density-matrix partial trace.
    reduced = sum(rank3(r); dims = 2)
    # Restore `r` to its original layout before returning.
    relax!(r, kept)
    return normalize!(ArrayReg(reduced))
end
"""
    exchange_sysenv(reg::ArrayReg) -> ArrayReg

Return a new register in which the system (focused qubits) and the environment
(remaining qubits) swap roles. The state data of `reg` is copied by
`permutedims`, not modified.
"""
function exchange_sysenv(reg::ArrayReg{B}) where {B}
    sys_dim = size(reg.state, 1)
    # Swap the first two axes of `rank3(reg)` and keep the third axis in place
    # (presumably the batch axis).
    swapped = permutedims(rank3(reg), (2, 1, 3))
    return ArrayReg{B}(reshape(swapped, :, sys_dim * B))
end