# Get the fidelities and runtimes for the surface codes

Expect long runtimes; the exact runtime depends on the machine used.

In [1]:
using Distributed

# Spawn one worker per entry of `Sys.cpu_info()` (one per CPU core). The guard makes
# re-running this cell a no-op, and every worker starts in the notebook's active project.
num_cores = length(Sys.cpu_info())
nprocs() == 1 && addprocs(num_cores; exeflags = `--project=$(Base.active_project())`)

@everywhere begin
    using LatticeAlgorithms
    using Counters
    using LinearAlgebra
end
using Plots
using JLD2
using LinearAlgebra
using LaTeXStrings


In [None]:
# Total number of processes (master + workers). If the setup cell added workers, this is
# expected to be num_cores + 1.
nprocs()

In [None]:
@everywhere begin

    """
        get_a(ξ_bar, l, σ)

    Truncated sum of the unnormalised Gaussian weights
    `exp(-(ξ_bar + n*l)^2 / (2σ^2))` over the even shifts `n ∈ (-2, 0, 2)`.
    """
    get_a(ξ_bar, l, σ) = sum(n -> exp(-(ξ_bar + n * l)^2 / (2 * σ^2)), -2:2:2)

    """
        get_b(ξ_bar, l, σ)

    Truncated sum of the unnormalised Gaussian weights
    `exp(-(ξ_bar + n*l)^2 / (2σ^2))` over the odd shifts `n ∈ (-3, -1, 1, 3)`.
    """
    get_b(ξ_bar, l, σ) = sum(n -> exp(-(ξ_bar + n * l)^2 / (2 * σ^2)), -3:2:3)


    """
        decode_stabilizer_GKP_code_non_exact(x, stabilizers, σ)

    Decode one quadrature of a GKP-concatenated stabilizer code with minimum-weight perfect
    matching (`mwpm`), using edge weights that depend on the noise strength `σ`.

    Every stabilizer is a graph vertex, plus one extra vertex (index
    `length(stabilizers) + 1`) for the boundary. A mode contained in two stabilizers is an
    edge between them; a mode contained in one stabilizer is an edge to the boundary vertex.
    When several modes connect the same pair of vertices, only the lightest one is kept.

    # Arguments
    - `x::Vector`: displacements of the modes in this quadrature; they are divided by `√π`
      before being rounded to integers.
    - `stabilizers::Dict{Int64, Vector{Int64}}`: stabilizer index => indices of its modes.
      NOTE(review): the keys are used directly as vertex indices, so they are assumed to
      be exactly `1:length(stabilizers)`.
    - `σ::Float64`: noise standard deviation used in the edge weights.

    # Returns
    `(cp, final_list)`:
    - `cp`: `√π` times the chosen integer per mode (the second-closest integer on modes
      flipped by the matching, the closest one otherwise);
    - `final_list = correction_list - mod.(closest_integers, 2)` per mode.

    # Throws
    `ErrorException` if a mode belongs to more than two stabilizers.
    """
    function decode_stabilizer_GKP_code_non_exact(x::Vector, stabilizers::Dict{Int64, Vector{Int64}}, σ::Float64)
        x = x / √π

        closest_integers = closest_integer.(x)
        second_closest_integers = second_closest_integer.(x)

        # Edge weight per mode: log(1/p), where p = b/(a+b) is the probability (from the
        # truncated Gaussian sums get_a/get_b) that rounding to the closest integer chose
        # the wrong parity.
        # NOTE(review): the textbook MWPM weight is log((1-p)/p); log(1/p) is kept here so
        # results stay comparable with previously generated data.
        edge_weight_list = map(zip(x, closest_integers)) do (x0, n)
            ξ_bar = x0 - n
            a = get_a(ξ_bar * √π, √π, σ)
            b = get_b(ξ_bar * √π, √π, σ)
            cond_prob = b / (a + b)
            return log(1 / cond_prob)
        end

        # Matching graph: vertex pair => (mode index, weight) of its cheapest mode.
        num_vertices = length(stabilizers) + 1   # the last vertex is the boundary
        keys_stabilizers = collect(keys(stabilizers))
        values_stabilizers = collect(values(stabilizers))
        edge_to_mode_dict = Dict{Set{Int}, Tuple{Int, Float64}}()
        for qubit in eachindex(x)
            # positions (within keys_stabilizers) of the stabilizers containing this mode
            qubit_in_stab = findall(stab -> qubit in stab, values_stabilizers)
            n_stab = length(qubit_in_stab)
            if n_stab > 2
                error("Cannot decode code where a single fault can lead more than 2 errors.")
            elseif n_stab == 0
                continue   # a mode outside every stabilizer produces no syndrome
            end
            vertex_1 = keys_stabilizers[qubit_in_stab[1]]
            vertex_2 = n_stab == 2 ? keys_stabilizers[qubit_in_stab[2]] : num_vertices

            edge = Set([vertex_1, vertex_2])
            weight = edge_weight_list[qubit]
            if !haskey(edge_to_mode_dict, edge) || weight < edge_to_mode_dict[edge][2]
                edge_to_mode_dict[edge] = (qubit, weight)
            end
        end

        # Mark unsatisfied stabilizers (odd parity of the rounded integers); the boundary
        # vertex is marked when needed so that the number of marked vertices is even.
        highlighted_vertices = zeros(Int, num_vertices)
        for (index, stabilizer) in stabilizers
            if mod(sum(closest_integers[stabilizer]), 2) == 1
                highlighted_vertices[index] = 1
            end
        end
        highlighted_vertices[num_vertices] = mod(sum(highlighted_vertices), 2)

        # Symmetric weight matrix of the graph (0 where there is no edge).
        g = zeros(num_vertices, num_vertices)
        for (edge, (_, weight)) in edge_to_mode_dict
            vertex_1, vertex_2 = collect(edge)
            g[vertex_1, vertex_2] = weight
        end
        g = g + transpose(g)

        paths = mwpm(g, highlighted_vertices)

        # Count, per mode, how many matched edges use it.
        correction_list = zeros(Int, length(x))
        for (i, j) in paths
            correction_list[edge_to_mode_dict[Set([i, j])][1]] += 1
        end

        # Decoded point: switch to the second-closest integer on corrected modes.
        cp = [correction == 1 ? second_closest_integers[i] : closest_integers[i]
              for (i, correction) in enumerate(correction_list)]
        error_list = mod.(closest_integers, 2)
        final_list = correction_list - error_list

        return cp .* √π, final_list
    end

    """
        decode_surface_code_non_exact(ξ, surface_code_x_stabilizers, surface_code_z_stabilizers, σ)

    Decode a surface-GKP code by splitting `ξ` into its two quadratures and decoding each
    with `decode_stabilizer_GKP_code_non_exact`:
    - odd entries (q) against the Z stabilizers;
    - even entries (p) against the X stabilizers.

    Returns `(y, final_list)`, both `Vector{Float64}` of length `length(ξ)`, with the
    per-quadrature results written back into the odd/even slots.

    Throws an `ErrorException` unless `length(ξ) == 2d^2` for an integer `d`.
    """
    function decode_surface_code_non_exact(
        ξ::Vector,
        surface_code_x_stabilizers::Dict{Int64, Vector{Int64}},
        surface_code_z_stabilizers::Dict{Int64, Vector{Int64}},
        σ::Float64
    )
        isinteger(√(length(ξ) / 2)) ||
            error("The length of ξ has to be 2d^2 for an integer d")

        q_slots = 1:2:length(ξ)
        p_slots = 2:2:length(ξ)

        decoded_q = decode_stabilizer_GKP_code_non_exact(ξ[q_slots], surface_code_z_stabilizers, σ)
        decoded_p = decode_stabilizer_GKP_code_non_exact(ξ[p_slots], surface_code_x_stabilizers, σ)

        y = zeros(length(ξ))
        final_list = zeros(length(ξ))
        y[q_slots], final_list[q_slots] = decoded_q
        y[p_slots], final_list[p_slots] = decoded_p

        return y, final_list
    end
end

## Get the data for sigma = 0.4-0.8 with 1e6 samples

In [None]:
drange = 3:2:29
σrange = range(0.4, 0.8, 50)
num_samples = Int(1e6)

# Run both decoders once on every process before any timing. Otherwise JIT compilation is
# charged to the first timed samples and inflates the average runtimes at small d.
@everywhere let d = 3, σ = 0.6, ξ = σ * randn(2d^2)
    decode_surface_code(ξ, surface_code_X_stabilizers(d), surface_code_Z_stabilizers(d))
    decode_surface_code_non_exact(ξ, surface_code_X_stabilizers(d), surface_code_Z_stabilizers(d), σ)
end

# One Dict per quantity, keyed by σ; each value receives one entry per d in drange.
data_exact, data_non_exact, p0list_exact, p0list_non_exact, timelist_exact, timelist_non_exact =
    (Dict(σ => [] for σ in σrange) for _ in 1:6)

for (ind_σ, σ) in enumerate(σrange)
    for (ind_d, d) in enumerate(drange)
        println(["$(ind_σ)/$(length(σrange))", "$(ind_d)/$(length(drange))"])

        surface_code_z_logicals = surface_code_Z_logicals(d)
        surface_code_x_logicals = surface_code_X_logicals(d)
        surface_code_x_stabilizers = surface_code_X_stabilizers(d)
        surface_code_z_stabilizers = surface_code_Z_stabilizers(d)
        z_logical, x_logical = surface_code_z_logicals[1], surface_code_x_logicals[1]

        # Index of the residual logical Pauli in the order (I, X, Z, Y).
        # nx: parity of final_list on the q quadratures along the Z logical.
        # nz: parity on the p quadratures along the X logical.
        pauli_index = function (final_list)
            nx = Int(mod(sum(final_list[1:2:end][z_logical]), 2))
            nz = Int(mod(sum(final_list[2:2:end][x_logical]), 2))
            return 1 + nx + 2nz
        end

        # Draw the noise on the workers and reduce there. Materialising all num_samples
        # vectors of length 2d^2 on the master would need ~13.5 GB at d = 29.
        totals = @distributed (+) for sample in 1:num_samples
            ξ = σ * randn(2d^2)
            time_exact = @elapsed final_list_exact = decode_surface_code(ξ, surface_code_x_stabilizers, surface_code_z_stabilizers)[2]
            time_non_exact = @elapsed final_list_non_exact = decode_surface_code_non_exact(ξ, surface_code_x_stabilizers, surface_code_z_stabilizers, σ)[2]

            # Slots 1-4: one-hot (I, X, Z, Y) for the exact decoder.
            # Slots 5-8: the same for the non-exact decoder.
            # Slots 9-10: the two decoding times in seconds.
            tally = zeros(10)
            tally[pauli_index(final_list_exact)] = 1
            tally[4 + pauli_index(final_list_non_exact)] = 1
            tally[9], tally[10] = time_exact, time_non_exact
            tally
        end

        averages = totals / num_samples
        p_list_exact, p_list_non_exact = averages[1:4], averages[5:8]

        push!(data_exact[σ], p_list_exact)
        push!(data_non_exact[σ], p_list_non_exact)
        push!(p0list_exact[σ], p_list_exact[1])
        push!(p0list_non_exact[σ], p_list_non_exact[1])
        push!(timelist_exact[σ], averages[9])
        push!(timelist_non_exact[σ], averages[10])
    end
end


In [None]:
# # Uncomment this cell if you want to overwrite the data

# fn = "data/surface_codes/fidelity_time_surf_$(minimum(drange))_$(maximum(drange)).jld2";
# jldsave(fn; 
#     σrange=σrange, 
#     drange=drange, 
#     num_samples=num_samples,
#     data_exact = data_exact,
#     data_non_exact = data_non_exact,
#     p0list_exact = p0list_exact,
#     p0list_non_exact = p0list_non_exact,
#     timelist_exact = timelist_exact,
#     timelist_non_exact=timelist_non_exact
# )

## Get the data for sigma = 0.5960-0.6070 with 1e7 samples

In [None]:
num_samples = Int(1e7)   # an Int (like the 1e6 sweep) so 1:num_samples is an integer range
σrange = 0.5960 : 0.0010: 0.6070
drange = [31, 33]

# Run both decoders once on every process before any timing. Otherwise JIT compilation is
# charged to the first timed samples and inflates the average runtimes.
@everywhere let d = 3, σ = 0.6, ξ = σ * randn(2d^2)
    decode_surface_code(ξ, surface_code_X_stabilizers(d), surface_code_Z_stabilizers(d))
    decode_surface_code_non_exact(ξ, surface_code_X_stabilizers(d), surface_code_Z_stabilizers(d), σ)
end

# One Dict per quantity, keyed by σ; each value receives one entry per d in drange.
data_exact, data_non_exact, p0list_exact, p0list_non_exact, timelist_exact, timelist_non_exact =
    (Dict(σ => [] for σ in σrange) for _ in 1:6)

for (ind_σ, σ) in enumerate(σrange)
    for (ind_d, d) in enumerate(drange)
        println(["$(ind_σ)/$(length(σrange))", "$(ind_d)/$(length(drange))"])

        surface_code_z_logicals = surface_code_Z_logicals(d)
        surface_code_x_logicals = surface_code_X_logicals(d)
        surface_code_x_stabilizers = surface_code_X_stabilizers(d)
        surface_code_z_stabilizers = surface_code_Z_stabilizers(d)
        z_logical, x_logical = surface_code_z_logicals[1], surface_code_x_logicals[1]

        # Index of the residual logical Pauli in the order (I, X, Z, Y).
        # nx: parity of final_list on the q quadratures along the Z logical.
        # nz: parity on the p quadratures along the X logical.
        pauli_index = function (final_list)
            nx = Int(mod(sum(final_list[1:2:end][z_logical]), 2))
            nz = Int(mod(sum(final_list[2:2:end][x_logical]), 2))
            return 1 + nx + 2nz
        end

        # Draw the noise on the workers and reduce there. Materialising all 1e7 vectors of
        # length 2d^2 on the master would need ~174 GB at d = 33.
        totals = @distributed (+) for sample in 1:num_samples
            ξ = σ * randn(2d^2)
            time_exact = @elapsed final_list_exact = decode_surface_code(ξ, surface_code_x_stabilizers, surface_code_z_stabilizers)[2]
            time_non_exact = @elapsed final_list_non_exact = decode_surface_code_non_exact(ξ, surface_code_x_stabilizers, surface_code_z_stabilizers, σ)[2]

            # Slots 1-4: one-hot (I, X, Z, Y) for the exact decoder.
            # Slots 5-8: the same for the non-exact decoder.
            # Slots 9-10: the two decoding times in seconds.
            tally = zeros(10)
            tally[pauli_index(final_list_exact)] = 1
            tally[4 + pauli_index(final_list_non_exact)] = 1
            tally[9], tally[10] = time_exact, time_non_exact
            tally
        end

        averages = totals / num_samples
        p_list_exact, p_list_non_exact = averages[1:4], averages[5:8]

        push!(data_exact[σ], p_list_exact)
        push!(data_non_exact[σ], p_list_non_exact)
        push!(p0list_exact[σ], p_list_exact[1])
        push!(p0list_non_exact[σ], p_list_non_exact[1])
        push!(timelist_exact[σ], averages[9])
        push!(timelist_non_exact[σ], averages[10])
    end
end

In [None]:
# # Uncomment this cell if you want to overwrite the data

# fn = "data/surface_codes/fidelity_time_surf_$(minimum(drange))_$(maximum(drange))_$(Int(num_samples)).jld2";
# jldsave(fn; 
#     σrange=σrange, 
#     drange=drange, 
#     num_samples=num_samples,
#     data_exact = data_exact,
#     data_non_exact = data_non_exact,
#     p0list_exact = p0list_exact,
#     p0list_non_exact = p0list_non_exact,
#     timelist_exact = timelist_exact,
#     timelist_non_exact=timelist_non_exact
# )

## Get the runtime of the surface-GKP codes

In [None]:
# Mean time per closest_point call on the surface-GKP lattice of distance d, one distance per
# pmap task. num_samples is also used by the next cell.
num_samples = 1e2 # 1e4
drange_closest_point = [3, 5]
time_closest_point = pmap(drange_closest_point) do d
    M = surface_code_M(d)
    Mperp = GKP_logical_operator_generator(M)
    basis = √(2π) * Mperp   # loop-invariant lattice generator, built outside the timed region
    ξs = [0.6 * randn(2d^2) for _ in 1:num_samples]

    # Warm-up call so that JIT compilation of closest_point on this worker is not timed;
    # with only num_samples calls it would dominate the average.
    isempty(ξs) || closest_point(first(ξs), basis)
    t2 = @elapsed ys2 = [closest_point(ξ, basis) for ξ in ξs]

    return t2/num_samples
end

In [None]:
# Mean time per decode_surface_code call, one distance per pmap task.
# Uses num_samples from the previous cell.
drange_decode_surface_code = 3 : 2 : 51 # 11, 21
time_decode_surface_code = pmap(drange_decode_surface_code) do d
    ξs = [0.6 * randn(2d^2) for _ in 1:num_samples]

    surface_code_z_stabilizers, surface_code_x_stabilizers = surface_code_Z_stabilizers(d), surface_code_X_stabilizers(d)
    # Warm-up call. Without it, the first distance handled by each worker would include
    # JIT compilation time in its average.
    isempty(ξs) || decode_surface_code(first(ξs), surface_code_x_stabilizers, surface_code_z_stabilizers)
    t1 = @elapsed ys1 = [decode_surface_code(ξ, surface_code_x_stabilizers, surface_code_z_stabilizers) for ξ in ξs]

    return t1/(num_samples)
end


In [None]:
# # Uncomment this cell if you want to overwrite the data

# fn = "data/surface_codes/runtime_surf.jld2";
# jldsave(fn; 
#     drange_decode_surface_code=drange_decode_surface_code, 
#     drange_closest_point = drange_closest_point,
#     time_decode_surface_code=time_decode_surface_code,
#     time_closest_point=time_closest_point,    
#     num_samples=num_samples)