In [None]:
using LinearAlgebra
using Convex, SCS
using Plots
using Statistics

In [None]:
"""
    randomkraus(D, M)

Generate `M` random real `D × D` Kraus operators `K₁, …, K_M` that satisfy the
completeness relation `Σᵢ Kᵢ'Kᵢ = I`.

Each operator is drawn with i.i.d. standard-normal entries and then
right-multiplied by `(Σᵢ Kᵢ'Kᵢ)^(-1/2)`, which enforces the relation.

# Returns
A vector of `M` real matrices of size `D × D`.
"""
function randomkraus(D, M)
    Ks = [randn(D, D) for _ in 1:M]
    # Wrap the Gram sum in `Symmetric` so that `sqrt` and `inv` use the
    # symmetric code paths. Taking `sqrt` of a plain `inv(a)` works on a matrix
    # that is only symmetric up to round-off, which can send `sqrt` down the
    # general (Schur) path and, on some Julia versions, yield a complex result.
    gram = Symmetric(sum(K' * K for K in Ks))
    normalizer = inv(sqrt(gram))
    return [K * normalizer for K in Ks]
end

"""
    randomdensity(D)

Draw a random rank-one `D × D` density matrix: the outer product `ψψ'` of a
standard-normal real vector `ψ` rescaled to unit 2-norm.
"""
function randomdensity(D)
    ψ = randn(D)
    ψ ./= norm(ψ)           # unit 2-norm, so the outer product has unit trace
    return ψ * adjoint(ψ)
end

"""
    applychannel(ρ, Ks)

Apply the channel described by the Kraus operators `Ks` to `ρ`, i.e. return
`Σ K * ρ * K'` summed over all `K` in `Ks`.
"""
function applychannel(ρ, Ks)
    terms = (K * ρ * adjoint(K) for K in Ks)
    return sum(terms)
end

In [None]:
# Objectives

# Extend Convex's `square` with a method for plain Float64 scalars, so the
# objective helpers in this cell also work on already-evaluated numbers
function Convex.square(x::Float64)
    return x * x
end

# opsdiff: sum over paired operators of the squared 2-norm of their
# (vectorized) difference
function opsdiff(Ksleft, Ksright)
    paired = zip(Ksleft, Ksright)
    return sum(square(norm(vec(Kl - Kr), 2)) for (Kl, Kr) in paired)
end

# singleobj: squared 2-norm of the (vectorized) difference between the channel
# output for ρ₀ and the target state ρ₁
function singleobj(ρ₀, ρ₁, Ks)
    output = sum(K * ρ₀ * K' for K in Ks)
    residual = vec(output - ρ₁)
    return square(norm(residual, 2))
end

# obj: total (not averaged) `singleobj` score over all paired input/target states
function obj(ρ0s, ρ1s, Ks)
    statepairs = zip(ρ0s, ρ1s)
    return sum(singleobj(ρ0, ρ1, Ks) for (ρ0, ρ1) in statepairs)
end

# singlediffobj: squared 2-norm of the difference between the target state ρ₁
# and the output Σ Kl * ρ₀ * Kr' built from separate left/right operator lists
function singlediffobj(Ksleft, ρ₀, Ksright, ρ₁)
    output = sum(Kl * ρ₀ * Kr' for (Kl, Kr) in zip(Ksleft, Ksright))
    residual = vec(output - ρ₁)
    return square(norm(residual, 2))
end

# diffobj: `singlediffobj` score averaged over all paired input/target states
function diffobj(Ksleft, ρ0s, Ksright, ρ1s)
    weight = 1 / length(ρ0s)
    total = sum(singlediffobj(Ksleft, ρ0, Ksright, ρ1) for (ρ0, ρ1) in zip(ρ0s, ρ1s))
    return weight * total
end

# krausnorm: Σ K'K for a single operator list, or Σ Kr'Kl when the left and
# right operator lists differ
function krausnorm(Ks)
    return sum(adjoint(K) * K for K in Ks)
end

function krausnorm(Ksleft, Ksright)
    paired = zip(Ksleft, Ksright)
    return sum(adjoint(Kr) * Kl for (Kl, Kr) in paired)
end

# regobj: fit term `diffobj` plus a penalty on the mismatch between the left and
# right Kraus operators. The regularization strength λ is divided by the number
# of Kraus operators.
function regobj(Ksleft, ρ0s, Ksright, ρ1s, λ)
    fit = diffobj(Ksleft, ρ0s, Ksright, ρ1s)
    penalty = λ / length(Ksleft) * opsdiff(Ksleft, Ksright)
    return fit + penalty
end

"""
    seqoptmult(ρ0s, ρ1s, Ksinit; λ, niters)

Alternating optimization of Kraus operators. Starting from `Ksinit`, every
iteration solves two problems with SCS: first the operators acting on the right
of the density are optimized with the left ones held at the current values,
then the roles are swapped. Each problem minimizes `regobj` subject to
`krausnorm(left, right) ≤ I`.

# Keywords
- `λ`: regularization strength passed to `regobj`.
- `niters`: number of iterations (each consisting of a right and a left solve).

# Returns
`(Ks, objvals, regobjvals, opsdiffs)`, where `Ks` holds the final operator
values, `objvals` the unregularized objective `obj` (one initial entry plus one
per solve), and `regobjvals` / `opsdiffs` the regularized objective and the
left/right operator mismatch after each solve.
"""
function seqoptmult(ρ0s, ρ1s, Ksinit; λ, niters)
    dim = size(ρ0s[1], 1)
    nops = length(Ksinit)

    Kscurr = copy(Ksinit)
    objvals = Any[obj(ρ0s, ρ1s, Kscurr)]   # objective values over iterations
    regobjvals = []                        # objective values including regularization
    opsdiffs = []                          # left/right operator difference scores
    Ksvar = [Variable(dim, dim) for _ in 1:nops]   # optimization variables

    # Solve one half-step. Exactly one of the two lists is `Ksvar`; the other
    # holds the current numeric operators. Records the diagnostics and returns
    # the new numeric operators.
    # NOTE(review): `≤` on a Convex expression is presumably an elementwise
    # inequality rather than the semidefinite order — confirm this is intended.
    function halfstep(Ksleft, Ksright)
        stepobj = regobj(Ksleft, ρ0s, Ksright, ρ1s, λ)
        constraints = [krausnorm(Ksleft, Ksright) ≤ I(dim)]
        problem = minimize(stepobj, constraints)
        solve!(problem, SCS.Optimizer, silent_solver = true)

        push!(regobjvals, evaluate(stepobj))
        push!(opsdiffs, evaluate(opsdiff(Ksleft, Ksright)))
        Ksnew = [K.value for K in Ksvar]
        push!(objvals, obj(ρ0s, ρ1s, Ksnew))
        return Ksnew
    end

    for _ in 1:niters
        Kscurr = halfstep(Kscurr, Ksvar)   # optimize operators right of the density
        Kscurr = halfstep(Ksvar, Kscurr)   # optimize operators left of the density
    end
    return Kscurr, objvals, regobjvals, opsdiffs
end

In [None]:
"""
    getobjsamples(params, nsamples; niters=100)

Run `nsamples` independent experiments for one parameter tuple
`params = (D, M, P, λ)` and return the objective trace of each.

Every experiment draws `P` random input states and a random channel with `M`
Kraus operators of dimension `D`, pushes the inputs through that channel to get
the targets, and runs `seqoptmult` from a fresh random initialization.
"""
function getobjsamples(params, nsamples; niters=100)
    D, M, P, λ = params

    # One experiment; the order of the random draws matches the data flow:
    # inputs, true channel, then the initial guess.
    function runsample()
        inputs = [randomdensity(D) for _ in 1:P]
        Kstrue = randomkraus(D, M)
        targets = [applychannel(ρ, Kstrue) for ρ in inputs]
        Ksinit = randomkraus(D, M)
        _, objvals, _, _ = seqoptmult(inputs, targets, Ksinit; λ, niters)
        return objvals
    end

    return Any[runsample() for _ in 1:nsamples]
end

# sampleparams: run `getobjsamples` for every parameter tuple in `allparams`,
# keeping the shape of `allparams` in the result
function sampleparams(allparams, nsamples; niters)
    runone(params) = getobjsamples(params, nsamples; niters)
    return [runone(params) for params in allparams]
end

In [None]:
# Wide sweep over the regularization strength λ for one (D, M, P) setting
Ds = [2^1]      # dimensions
Ms = [2]        # numbers of Kraus operators
Ps = [10]       # numbers of datapoints
nsamples = 5    # experiments per parameter setting

# Regularization: nλs values, logarithmically spaced from λmin to λmax
nλs = 10
λmin = 0.001
λmax = 10
λs = [10^e for e in LinRange(log10(λmin), log10(λmax), nλs)]

allparams = Iterators.product(Ds, Ms, Ps, λs)
allobjsN1λs = sampleparams(allparams, nsamples; niters=100);

In [None]:
# One panel per regularization strength λ, each showing the objective traces of
# all samples for that λ
labels = ["λ = " * string(round(λ, digits=3)) for (_, _, _, λ) in allparams]

# Two columns of panels. The row count is rounded up (`cld`) so that an odd
# number of λ values still gets a panel for every label, and the figure height
# follows the number of labels rather than a separately maintained counter.
p = plot(layout=(cld(length(labels), 2), 2), size=(1000, length(labels) * 150), left_margin=8Plots.mm, dpi=400)
for i in eachindex(labels)
    λobjs = allobjsN1λs[i]   # traces of all samples for the i-th λ
    plot!(p, λobjs, subplot=i, yscale=:log10, title=labels[i], xlabel="Iteration", ylabel="Value", label=nothing)
end
p

In [None]:
# Save the current figure as a PNG named "N1-M2-P10-lambda-sweep"
png(current(), "N1-M2-P10-lambda-sweep")

In [None]:
# 2 × 2 figure with a hand-picked subset of the λ sweep
λkeep = [1, 4, 7, 10]   # positions of the kept λ values within the sweep
subsetparams = collect(allparams)[:, :, :, λkeep]
labels = ["λ = $(round(last(params), digits=3))" for params in subsetparams]

p = plot(layout=(2, 2), size=(1400, 800), margin=10Plots.mm, dpi=400)
for (panel, λidx) in enumerate(λkeep)
    # `panel` selects the subplot and label, `λidx` the matching sweep result
    plot!(
        p, allobjsN1λs[λidx];
        subplot=panel, yscale=:log10, title=labels[panel],
        xlabel="Iteration", ylabel="Objective value", label=nothing,
    )
end
p

In [None]:
# Save the current figure as a PNG named "N1-M2-lambda-subset"
png(current(), "N1-M2-lambda-subset")

In [None]:
# Single-row figure with three hand-picked λ values from the sweep
λkeep = [1, 7, 10]   # positions of the kept λ values within the sweep
subsetparams = collect(allparams)[:, :, :, λkeep]
labels = ["λ = $(round(last(params), digits=3))" for params in subsetparams]

p = plot(layout=(1, 3), size=(1600, 400), margin=8Plots.mm, dpi=400)
for (panel, λidx) in enumerate(λkeep)
    # `panel` selects the subplot and label, `λidx` the matching sweep result
    plot!(
        p, allobjsN1λs[λidx];
        subplot=panel, yscale=:log10, title=labels[panel],
        xlabel="Iteration", ylabel="Objective value", label=nothing,
    )
end
p

In [None]:
# Save the current figure as a PNG named "N1-M2-lambda-subset3"
png(current(), "N1-M2-lambda-subset3")

In [None]:
# Narrower λ sweep (0.05 to 1) with more samples per setting
Ds = [2^1]       # dimensions
Ms = [2]         # numbers of Kraus operators
Ps = [10]        # numbers of datapoints
nsamples = 10    # experiments per parameter setting

# Regularization: nλs values, logarithmically spaced from λmin to λmax
nλs = 10
λmin = 0.05
λmax = 1
λs = [10^e for e in LinRange(log10(λmin), log10(λmax), nλs)]

allparams = Iterators.product(Ds, Ms, Ps, λs)
allobjsN1λsnarrow = sampleparams(allparams, nsamples; niters=100);

In [None]:
# One panel per regularization strength λ of the narrow sweep, each showing the
# objective traces of all samples for that λ
labels = ["λ = " * string(round(λ, digits=3)) for (_, _, _, λ) in allparams]

# Two columns of panels. The row count is rounded up (`cld`) so that an odd
# number of λ values still gets a panel for every label, and the figure height
# follows the number of labels rather than a separately maintained counter.
p = plot(layout=(cld(length(labels), 2), 2), size=(1000, length(labels) * 150), left_margin=8Plots.mm, dpi=400)
for i in eachindex(labels)
    λobjs = allobjsN1λsnarrow[i]   # traces of all samples for the i-th λ
    plot!(p, λobjs, subplot=i, yscale=:log10, title=labels[i], xlabel="Iteration", ylabel="Objective value", label=nothing)
end
p

In [None]:
# Save the current figure as a PNG named "N1-M2-P10-lambda05to1"
png(current(), "N1-M2-P10-lambda05to1")

In [None]:
# Sweep over the number of Kraus operators at D = 2 with fixed λ
Ds = [2^1]       # dimensions
Ms = [4]         # numbers of Kraus operators
Ps = [10]        # numbers of datapoints
nsamples = 10    # experiments per parameter setting
λs = [0.3]       # regularization strengths

allparams = Iterators.product(Ds, Ms, Ps, λs)
allobjsN1Ms = sampleparams(allparams, nsamples; niters=100);

In [None]:
# One panel per number of Kraus operators M, each showing the objective traces
# of all samples for that M
labels = ["M = " * string(M) for (_, M, _, _) in allparams]

# Use at most two columns and round the row count up. The previous
# `length(labels) ÷ 2` gave zero rows for a single M, leaving no subplot to
# draw into. The height is 400 px per row instead of depending on `nλs`, a
# leftover from the λ-sweep cells (four M values give the same 800 px).
p = let n = length(labels), ncols = clamp(n, 1, 2), nrows = max(cld(n, ncols), 1)
    plot(layout=(nrows, ncols), size=(1000, 400 * nrows), left_margin=8Plots.mm, dpi=400)
end
for i in eachindex(labels)
    Mobjs = allobjsN1Ms[i]   # traces of all samples for the i-th M
    plot!(p, Mobjs, subplot=i, yscale=:log10, title=labels[i], xlabel="Iteration", ylabel="Objective value", label=nothing)
end
p

In [None]:
# Save the current figure as a PNG named "N1-M1to4-P10-lambda03".
# NOTE(review): the name says M = 1 to 4, but the sweep cell above only sets
# Ms = [4] — confirm the name is still the intended one.
png(current(), "N1-M1to4-P10-lambda03")

In [None]:
# Sweep over the number of Kraus operators at D = 4 with fixed λ
Ds = [2^2]        # dimensions
Ms = [2, 4, 8]    # numbers of Kraus operators
Ps = [10]         # numbers of datapoints
nsamples = 5      # experiments per parameter setting
λs = [0.05]       # regularization strengths

allparams = Iterators.product(Ds, Ms, Ps, λs)
allobjsN2Ms = sampleparams(allparams, nsamples; niters=100);

In [None]:
# Single-row figure with one panel per number of Kraus operators M
labels = ["M = $(params[2])" for params in allparams]

p = plot(layout=(1, length(labels)), size=(1500, 400), margin=8Plots.mm, dpi=400)
for panel in eachindex(labels)
    # objective traces of all samples for the M shown in this panel
    plot!(
        p, allobjsN2Ms[panel];
        subplot=panel, yscale=:log10, title=labels[panel],
        xlabel="Iteration", ylabel="Objective value", label=nothing,
    )
end
p

In [None]:
# Save the current figure as a PNG named "N2-M248-P10-lambda03".
# NOTE(review): the name says lambda03, but the sweep cell above uses
# λs = [0.05] — confirm the name is still the intended one.
png(current(), "N2-M248-P10-lambda03")

In [None]:
# Single setting at D = 8 with M = 8 Kraus operators and fixed λ
Ds = [2^3]      # dimensions
Ms = [8]        # numbers of Kraus operators
Ps = [10]       # numbers of datapoints
nsamples = 5    # experiments per parameter setting
λs = [0.05]     # regularization strengths

allparams = Iterators.product(Ds, Ms, Ps, λs)
allobjsN3Ms = sampleparams(allparams, nsamples; niters=100);

In [None]:
# Single-row figure with one panel per number of Kraus operators M
labels = ["M = $(params[2])" for params in allparams]

p = plot(layout=(1, length(labels)), size=(1500, 400), margin=8Plots.mm, dpi=400)
for panel in eachindex(labels)
    # objective traces of all samples for the M shown in this panel
    plot!(
        p, allobjsN3Ms[panel];
        subplot=panel, yscale=:log10, title=labels[panel],
        xlabel="Iteration", ylabel="Objective value", label=nothing,
    )
end
p

In [None]:
# Save the current figure as a PNG named "N3-M1to4-P10-lambda03".
# NOTE(review): the name says M = 1 to 4 and lambda03, but the sweep cell above
# uses Ms = [8] and λs = [0.05] — confirm the name is still the intended one.
png(current(), "N3-M1to4-P10-lambda03")

In [None]:
# Sweep over the number of datapoints P at D = 4, M = 8 with fixed λ
Ds = [2^2]      # dimensions
Ms = [8]        # numbers of Kraus operators
# six P values spaced evenly on a log2 scale from 8 to 256 (stored as floats)
Ps = [2^e for e in LinRange(log2(8), log2(256), 6)]
nsamples = 5    # experiments per parameter setting
λs = [0.3]      # regularization strengths

allparams = Iterators.product(Ds, Ms, Ps, λs)
# The sweep itself is left disabled here; the plotting cells that follow read
# `allobjsnN2Ps`, so it has to exist already when they run.
# allobjsnN2Ps = sampleparams(allparams, nsamples; niters=50);

In [None]:
# 2 × 3 figure with one panel per number of datapoints P
labels = ["P = $(floor(Int, params[3]))" for params in allparams]

p = plot(layout=(2, 3), size=(1500, 800), margin=8Plots.mm, top_margin=0Plots.mm, dpi=400)
for panel in eachindex(labels)
    # objective traces of all samples for the P shown in this panel
    plot!(
        p, allobjsnN2Ps[panel];
        subplot=panel, yscale=:log10, title=labels[panel],
        xlabel="Iteration", ylabel="Objective value", label=nothing,
    )
end
p

In [None]:
# Save the current figure as a PNG named "N2-M8-P16to256-lambda03".
# NOTE(review): the name says P = 16 to 256, but the sweep cell above starts
# the P range at 8 — confirm the name is still the intended one.
png(current(), "N2-M8-P16to256-lambda03")

In [None]:
# Single-row figure with a hand-picked subset of the P sweep
Pkeep = [2, 4, 6]   # positions of the kept P values within the sweep
subparams = collect(allparams)[:, :, Pkeep, :]
labels = ["P = " * string(Int(floor(P))) for (_, _, P, _) in subparams]

p = plot(layout=(1, 3), size=(1600, 400), margin=8Plots.mm, top_margin=0Plots.mm, dpi=400)
for (i, Pidx) in enumerate(Pkeep)
    # Index the results with the sweep position `Pidx`, not the panel number
    # `i`: the labels describe sweep entries 2, 4 and 6, so taking entries
    # 1, 2 and 3 would plot data that does not match the panel titles.
    objs = allobjsnN2Ps[Pidx]
    plot!(p, objs, subplot=i, yscale=:log10, title=labels[i], xlabel="Iteration", ylabel="Objective value", label=nothing)
end
p

In [None]:
# Save the current figure as a PNG named "N2-M8-P16to256-lambda03-subset".
# NOTE(review): the name says P = 16 to 256, but the sweep cell that defines
# `Ps` starts the range at 8 — confirm the name is still the intended one.
png(current(), "N2-M8-P16to256-lambda03-subset")

In [None]:
# Summary figure: the last parameter setting of each of the three D sweeps
N1M4 = last(allobjsN1Ms)
N2M8 = last(allobjsN2Ms)
N3M8 = last(allobjsN3Ms)

p = plot(layout=(1, 3), size=(1500, 400), left_margin=8Plots.mm, bottom_margin=8Plots.mm, dpi=400)
for (panel, (objs, heading)) in enumerate((
    (N1M4, "D = 2, M = 4"),
    (N2M8, "D = 4, M = 8"),
    (N3M8, "D = 8, M = 8"),
))
    plot!(
        p, objs;
        subplot=panel, yscale=:log10, title=heading,
        xlabel="Iteration", ylabel="Objective value", label=nothing,
    )
end
p

In [None]:
# Save the current figure as a PNG named "N123"
png(current(), "N123")