# GHZ state in Rydberg atoms

In [1]:
using Revise
using QuantumOptimalControl
using QuantumOptics
using LinearAlgebra
using Flux, DiffEqFlux
using Optim
using PlotlyJS
using ProgressMeter
using Random
using DifferentialEquations: DP5, Tsit5, BS3, Vern7
using NLopt
ProgressMeter.ijulia_behavior(:clear)

false

In [2]:
# Two-level (spin-1/2) basis describing a single atom.
bs = SpinBasis(1//2)
# Pauli-x operator on one atom; used below to build the control term H1.
sx = sigmax(bs)
# (1 + σz)/2, i.e. the projector onto the σz = +1 eigenstate; used below as the
# per-atom occupation operator.
# NOTE(review): presumably this is the Rydberg-state population — confirm.
ni = 0.5*(identityoperator(bs) + sigmaz(bs));

In [3]:
# Prefactor of the pairwise ni*nj interaction in H0 (divided there by |i-j|^6).
# NOTE(review): units look like 2π × MHz given the µs / MHz plot labels — confirm.
V = 2π*24.0
# Extra detuning that H0 applies only to the first and the last atom.
δe = -2π*4.5

-28.274333882308138

In [4]:
# Number of atoms in the system.
n_atoms = 4
# Composite basis: tensor product of one single-atom basis per atom.
bsys = tensor([bs for i in 1:n_atoms]...);

In [5]:
# Static part of the Hamiltonian.
# Pairwise ni*nj interaction over every pair i < j, scaled by V / |i - j|^6.
H0 = V * sum(
    embed(bsys, [i, j], [ni, ni]) / abs(i - j)^6
    for i in 1:n_atoms for j in (i + 1):n_atoms
)
# Subtract δe times the occupation of the two end atoms only.
H0 -= δe * sum(embed(bsys, [k], [ni]) for k in (1, n_atoms));

In [6]:
# Control operators, each built by summing a single-atom operator over all atoms.
# H1: half the total σx.
H1 = 0.5 * sum(embed(bsys, [k], [sx]) for k in 1:n_atoms)
# H2: minus the total occupation.
H2 = -sum(embed(bsys, [k], [ni]) for k in 1:n_atoms);

In [7]:
"""
    GHZ_state(n_atoms)

Return the normalised superposition of the two alternating spin patterns on
`n_atoms` atoms, `(|↓↑↓↑…⟩ + |↑↓↑↓…⟩) / √2`, built on the global single-atom
basis `bs`.

# Arguments
- `n_atoms`: number of atoms; must be a positive even integer because the state is
  assembled from two-atom (down-up / up-down) pairs.

# Throws
- `ArgumentError` if `n_atoms` is not a positive even integer.
"""
function GHZ_state(n_atoms)
    # Validate up front: the previous `Int(n_atoms/2)` failed with an opaque
    # InexactError for odd counts and deep inside `tensor` for zero.
    (isinteger(n_atoms) && n_atoms >= 2 && n_atoms % 2 == 0) ||
        throw(ArgumentError("n_atoms must be a positive even integer, got $n_atoms"))
    n_pairs = Int(n_atoms) ÷ 2
    down_up = spindown(bs) ⊗ spinup(bs)
    up_down = spinup(bs) ⊗ spindown(bs)
    state = tensor([down_up for _ in 1:n_pairs]...) +
            tensor([up_down for _ in 1:n_pairs]...)
    return state / sqrt(2.0)
end

# Product state with every one of the `n_atoms` atoms in the spin-down state.
function ground_state(n_atoms)
    return tensor(fill(spindown(bs), n_atoms)...)
end

# Control target: take the all-down product state to the GHZ state.
trans = StateTransform(ground_state(n_atoms) => GHZ_state(n_atoms));

In [8]:
# Width of each hidden layer of the pulse-generating network.
n_neurons = 8
# Logistic function scaled to the open interval (0, 2π*7), applied element-wise.
# NOTE(review): not referenced anywhere else in the visible notebook, and Flux also
# exports a function named `sigmoid` — confirm this definition is still needed.
sigmoid(x)= @. 2π*7 / (1 + exp(-x))
# Seed the global RNG ahead of the parameter initialisation below.
Random.seed!(10)
# Network with one input (time) and two outputs (later used as the two control
# values): 1 → n_neurons → n_neurons → 2, tanh on both hidden layers and no
# activation on the output layer.
ann = FastChain(FastDense(1, n_neurons, tanh), 
                FastDense(n_neurons, n_neurons, tanh), 
                FastDense(n_neurons, 2))
# Flat vector of initial network parameters.
θ = initial_params(ann)  
# Total number of network parameters (displayed as the cell output).
n_params = length(θ)

106

In [9]:
# Start and end time of the control pulse.
t0, t1 = 0.0, 0.5

# Number of samples of the reference pulses. The time grid and both target arrays
# must have exactly this length because `loss` indexes them in lockstep; it used to
# be encoded three times as separate magic numbers (t1/49, ones(32), 10/49).
n_pts = 50

# `range(...; length)` fixes the sample count explicitly instead of relying on a
# floating-point step dividing the interval evenly.
tsf32 = range(Float32(t0), Float32(t1); length=n_pts)
# Target for the first network output: ramp 0 → 4 in steps of 0.5, a plateau at 5
# filling the remaining samples, then the mirrored ramp down; all scaled by 2π.
ramp = 0:0.5:4
n_pts >= 2 * length(ramp) ||
    throw(ArgumentError("n_pts must be at least $(2 * length(ramp)), got $n_pts"))
Ωs = Vector{Float32}(2π*vcat(ramp, 5*ones(n_pts - 2*length(ramp)), reverse(ramp)))
# Target for the second network output: linear sweep from -5 to 5, scaled by 2π.
Δs = Vector{Float32}(2π*range(-5.0, 5.0; length=n_pts))
# Fail loudly if the three grids ever disagree instead of silently truncating.
length(tsf32) == length(Ωs) == length(Δs) ||
    throw(DimensionMismatch("time grid and reference pulses must have equal length"))
# Float64 copy of the time grid.
ts = Vector{Float64}(tsf32)

"""
    loss(p)

Sum of squared deviations between the network outputs for parameters `p` and the
reference pulses: the absolute value of the first output is compared with `Ωs` and
the second output with `Δs`, at every sample of `tsf32`.
"""
function loss(p)
    total = 0.0f0
    for (k, tk) in enumerate(tsf32)
        out = ann([tk], p)
        Ω_err = abs(out[1]) - Ωs[k]
        Δ_err = out[2] - Δs[k]
        total += Ω_err^2
        total += Δ_err^2
    end
    return total
end

# Pre-train the network by minimising `loss` with ADAM (step 0.1, 5000 iterations),
# so that its outputs follow the reference pulses Ωs / Δs.
# NOTE(review): this calls initial_params(ann) again instead of reusing θ from the
# network-definition cell — confirm which starting point is intended.
res = DiffEqFlux.sciml_train(loss, initial_params(ann), ADAM(0.1f0), maxiters = 5000)
# Convert the fitted parameters to Float64; they seed the first `solve` call below.
θ = Vector{Float64}(res.u);

In [10]:
"""
    coeffs(params, t)

Evaluate the network at time `t` with parameters `params` and return the two
control coefficients: the absolute value of the first output and the second output
unchanged.
"""
function coeffs(params, t)
    out = ann([t], params)
    return [abs(out[1]), out[2]]
end

# Cost for the control problem, built from two closures.
# First closure: state distance 1 - real(-x'*y), which equals 1 + Re(x'*y); for
# normalised states it is smallest when x'*y = -1.
# NOTE(review): confirm the sign is intended — a conventional overlap distance
# would be 1 - Re(x'*y).
# Second closure: penalty on the parameters p, proportional to the magnitude of the
# first network output at t0 plus five times its magnitude at t1.
cost = CostFunction((x, y) -> 1-real(-x'*y),
                     p->2e-3*(abs(ann([t0], p)[1])+ 5.0*abs(ann([t1], p)[1])))

CostFunction(var"#19#21"(), var"#20#22"())

In [11]:
# Controlled Hamiltonian built from the static part H0, the control operators H1 and
# H2, and the coefficient function `coeffs`.
# NOTE(review): presumably H(t) = H0 + c1(t)*H1 + c2(t)*H2 — confirm against
# QuantumOptimalControl.
H = Hamiltonian(H0, [H1, H2], coeffs);

In [12]:
# Optimal-control problem: Hamiltonian, state transfer, time span (t0, t1) and cost.
prob = QOCProblem(H, trans, (t0, t1), cost);

In [13]:
# Two-stage optimisation: 400 ADAM iterations at step 0.005 starting from the
# pre-trained θ, then 1400 more at step 0.01 starting from the first stage's result.
sol = solve(prob, θ, ADAM(0.005); maxiter=400)
sol1 = solve(prob, sol.params, ADAM(0.01); maxiter=1400)

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:12:08[39m
[34m  distance:    0.005951537629625747[39m
[34m  contraints:  0.00040549648642542737[39m


Solution{Float64}([-4.544822327106292, -4.068141895631946, 4.701768797716582, -6.893103439030881, -11.971791252887998, -12.988360104009205, 14.189507083340075, -4.492489429590117, 1.6220867295192256, 1.657703255124861  …  1.5339208144406142, -4.941319704606653, 3.5718494229834894, -2.942115068888514, -4.474609409244589, 17.476607464984383, 20.114562521978957, 5.118211583458067, 4.21527028863117, 1.8396845742651653], [0.23804419965813484, 0.24946634502641107, 0.2900841464038537, 0.24154403690848547, 0.2642953286454417, 0.2656963327784123, 0.24140802959771557, 0.24129662427958298, 0.2594221828507768, 0.2541560227857427  …  0.005882897350956595, 0.005889219036561144, 0.006045385263037328, 0.006096852836517397, 0.005986848483866991, 0.005895991729720884, 0.0059203648809504905, 0.0059157943578218575, 0.005873625064934496, 0.005951537629625747])

In [14]:
# Plot the trace stored in the second-stage solution.
plot(sol1.trace)

In [15]:
# Optimised pulse shapes divided by 2π, evaluated from the network with the
# final parameters.
Ω(t) = abs(ann([t], sol1.params)[1]) / (2π)
Δ(t) = ann([t], sol1.params)[2] / (2π)
# Dense time grid for plotting; also reused by the time-evolution cell below.
ts = collect(t0:0.001:t1)

f = plot(
    [
        scatter(x=ts, y=Ω.(ts), name="Ω/2π"),
        scatter(x=ts, y=Δ.(ts), name="Δ/2π"),
    ],
    Layout(
        xaxis_title_text="Time (µs)",
        yaxis_title_text="Frequency (MHz)",
        legend=attr(x=0, y=1,),
        font=attr(
            size=16
        )
    )
)
# Derive the file name from n_atoms so the figure is labelled with the actual
# system size instead of a hard-coded "4".
savefig(f, "GHZ_$(n_atoms)_atoms_wfs.eps")

"GHZ_4_atoms_wfs.eps"

In [16]:
# Evolve the all-down product state under H with the optimised parameters,
# requesting output on the time grid `ts`.
tout, psit = schroedinger_dynamic(ts, ground_state(n_atoms), H, sol1.params);

In [17]:
# Expectation value of the GHZ-state projector along the evolved trajectory.
ghz_overlap = real(expect(dm(GHZ_state(n_atoms)), psit))

f = plot(
    [
        scatter(x=ts, y=ghz_overlap),
    ],
    Layout(
        xaxis_title_text="Time (µs)",
        yaxis_title_text="Overlap (|⟨ψ|GHZ⟩|²)",
        font=attr(
            size=16
        )
    )
)
# Derive the file name from n_atoms so the figure is labelled with the actual
# system size instead of a hard-coded "4".
savefig(f, "GHZ_$(n_atoms)_atoms_overlap.eps")

"GHZ_4_atoms_overlap.eps"