In [97]:
using Random

In [257]:
n = 5  # number of cadits (positions)
d = 3  # dimension of each cadit (values 1:d)
;

In [258]:
#Get a distribution of W's
function getW(n, d)
    Wi = rand(Float64, (n,d)) #Generate weights
    Wi ./= sum(Wi, dims=2) #Normalize to distribution
    return Wi
end

"""
    getFullDistribution(n, d, W)

Build the full joint distribution over `n` positions, each taking a value in `1:d`, as an
`n`-dimensional `d × … × d` array (a `Vector` for `n = 1`, a 0-dimensional array holding
`1.0` for `n = 0`).

Position `i` takes value `j` with probability `W[i, r]`, where `r = mod1(x₁ + … + xᵢ, d)`
is the running sum of the values at positions 1…i (residue 0 is stored at index `d`).
If every row of `W` sums to 1 (as produced by `getW`), the result sums to 1, because
`j ↦ r` is a bijection on `1:d` for each fixed prefix.

Throws `ArgumentError` for negative `n` and `DimensionMismatch` if `W` does not have `d`
columns.
"""
function getFullDistribution(n, d, W)
    n >= 0 || throw(ArgumentError("n must be non-negative, got $n"))
    size(W, 2) == d ||
        throw(DimensionMismatch("W must have d = $d columns, got $(size(W, 2))"))

    # Start from the empty prefix (0-dimensional array holding 1.0), mirroring
    # getMarginalMarkov, which starts with all mass on running-sum residue 0.
    # For i = 1 this yields W[1, j] * 1.0 == W[1, j] exactly.
    res = fill(1.0)
    for i = 1:n
        # Append a trailing axis of length d. The allocated shape must match the index
        # layout (prefix..., j) used below; prepending only worked because all axes are d.
        new_res = zeros(size(res)..., d)
        for x in CartesianIndices(res)  # from each prefix probability...
            p_prev = res[x]
            prefix = Tuple(x)
            mod_sum = isempty(prefix) ? 0 : sum(prefix)
            for j = 1:d  # ...extend by every possible next value
                # Index of the new running-sum residue (residue 0 ↦ index d).
                new_sum = mod1(mod_sum + j, d)
                new_res[CartesianIndex(prefix..., j)] = W[i, new_sum] * p_prev
            end
        end
        res = new_res
    end
    return res
end

const MarginalId = Dict{Int, Int}

"""
    getMarginalFromFull(n, d, dist, marginalId)

Marginal probability by brute force over the full joint `dist`. Each position `i` that
is a key of `marginalId` is fixed to `marginalId[i]`; every other position is left free
(`:`), and the resulting slice of `dist` is summed. Keys outside `1:n` are ignored and
`d` is unused.
"""
function getMarginalFromFull(n, d, dist, marginalId::MarginalId)
    selector = ntuple(n) do axis
        haskey(marginalId, axis) ? marginalId[axis] : Colon()
    end
    return sum(dist[selector...])
end

"""
    getMarginalMarkov(n, d, W, marginalId)

Compute the same marginal as `getMarginalFromFull` without materializing the joint array.
It propagates the distribution of the running sum (mod `d`) position by position. Positions
listed in `marginalId` are pinned to their value; all other positions are summed over.
Keys outside `1:n` are ignored. Cost is O(n·d²) rather than O(dⁿ).

Throws `ArgumentError` if a pinned value lies outside `1:d`. Without this check, the
modular index arithmetic would silently wrap (e.g. `d + 1` would behave like `1`).
"""
function getMarginalMarkov(n, d, W, marginalId::MarginalId)
    for (pos, v) in marginalId
        1 <= v <= d || throw(ArgumentError(
            "marginal value at position $pos must lie in 1:$d, got $v"))
    end

    # Distribution of the running-sum residue r, stored at index mod1(r, d).
    # The empty sum has residue 0, i.e. all mass at index d.
    sum_marginal = zeros(Float64, d)
    sum_marginal[d] = 1.0

    for i = 1:n
        new_sum_marginal = zeros(Float64, d)
        if haskey(marginalId, i)
            # Pinned position: a single deterministic shift of the running sum.
            v = marginalId[i]
            for sum_old = 1:d
                new_sum = mod1(sum_old + v, d)
                # For fixed v, sum_old ↦ new_sum is a bijection, so each slot is written once.
                new_sum_marginal[new_sum] = W[i, new_sum] * sum_marginal[sum_old]
            end
        else
            # Free position: mix over every value it could take.
            for v = 1:d, sum_old = 1:d
                new_sum = mod1(sum_old + v, d)
                new_sum_marginal[new_sum] += W[i, new_sum] * sum_marginal[sum_old]
            end
        end
        sum_marginal = new_sum_marginal
    end
    # Total mass over all possible final residues.
    return sum(sum_marginal)
end

"""
    computeExactXEBFromFull(n, d, dist)

Exact XEB of the joint `dist` against itself: `N · Σₓ p(x)² − 1` with
`N = length(dist)`. `n` and `d` are unused.
"""
computeExactXEBFromFull(n, d, dist) = length(dist) * sum(abs2, dist) - 1

"""
    computeExactXEBMarkov(n, d, W)

Closed-form exact XEB for the model built by `getFullDistribution`:
`dⁿ · Πᵢ Σⱼ W[i, j]² − 1`, using only the first `n` rows of `W`.

The formula holds because the joint factorizes position by position. For each fixed
prefix, the next value maps bijectively onto the columns of `W`, so
`Σₓ p(x)² = Πᵢ Σⱼ W[i, j]²`.
"""
function computeExactXEBMarkov(n, d, W)
    # Fold d into each factor instead of forming d^n separately:
    #  - d^n in Int wraps silently (2^64 == 0 in Int64, which made this return -1.0);
    #  - Πᵢ Σⱼ W[i,j]² underflows toward 0 for large n.
    # For rows summing to 1, each factor d·Σⱼ W[i,j]² lies in [1, d], so the product
    # stays well-scaled.
    return prod(i -> d * sum(abs2, view(W, i, :)), 1:n; init = 1.0) - 1
end

computeExactXEBMarkov (generic function with 1 method)

In [259]:
# Fixed seed so Wi, dist and the comparisons printed below are reproducible.
Random.seed!(13679)
Wi = getW(n, d)
dist = getFullDistribution(n, d, Wi);

In [260]:
# Cross-check three ways of computing a marginal on the n = 5, d = 3 example:
# brute-force slicing of the joint (getMarginalFromFull), the running-sum recursion
# (getMarginalMarkov), and direct indexing of `dist`. Each group should print one value
# three times.
# Case 1: positions 1, 3, 4, 5 pinned; only position 2 is summed out.
@show getMarginalFromFull(n, d, dist, Dict(1=>1, 3=>2, 4=>3, 5=>1))
@show getMarginalMarkov(n, d, Wi, Dict(1=>1, 3=>2, 4=>3, 5=>1))
@show sum(dist[1, :, 2, 3, 1])
println()
# Case 2: positions 1 and 2 pinned; positions 3-5 are summed out.
@show getMarginalFromFull(n, d, dist, Dict(1=>1, 2=>2))
@show getMarginalMarkov(n, d, Wi, Dict(1=>1, 2=>2))
@show sum(dist[1, 2, :, :, :])
println()
# Case 3: nothing pinned, so every method returns the total probability mass.
@show getMarginalFromFull(n, d, dist, Dict{Int,Int}())
@show getMarginalMarkov(n, d, Wi, Dict{Int,Int}())
@show sum(dist)
;

getMarginalFromFull(n, d, dist, Dict(1 => 1, 3 => 2, 4 => 3, 5 => 1)) = 0.012676114234194853
getMarginalMarkov(n, d, Wi, Dict(1 => 1, 3 => 2, 4 => 3, 5 => 1)) = 0.012676114234194853
sum(dist[1, :, 2, 3, 1]) = 0.012676114234194853

getMarginalFromFull(n, d, dist, Dict(1 => 1, 2 => 2)) = 0.4185177366841969
getMarginalMarkov(n, d, Wi, Dict(1 => 1, 2 => 2)) = 0.4185177366841969
sum(dist[1, 2, :, :, :]) = 0.4185177366841969

getMarginalFromFull(n, d, dist, Dict{Int, Int}()) = 1.0
getMarginalMarkov(n, d, Wi, Dict{Int, Int}()) = 1.0
sum(dist) = 1.0


In [261]:
# Larger check: rebind the globals to n = 15, d = 2 (a 2^15-entry joint) and confirm
# the brute-force and running-sum marginals agree.
Random.seed!(54321)
n = 15
d = 2

Wi = getW(n,d)
dist = getFullDistribution(n, d, Wi)

@show getMarginalFromFull(n, d, dist, Dict(2=>1, 5=>1, 13=>1))
@show getMarginalMarkov(n, d, Wi, Dict(2=>1, 5=>1, 13=>1))
;

getMarginalFromFull(n, d, dist, Dict(2 => 1, 5 => 1, 13 => 1)) = 0.11855488883250775
getMarginalMarkov(n, d, Wi, Dict(2 => 1, 5 => 1, 13 => 1)) = 0.11855488883250775


In [262]:
# Exact XEB two ways: brute force over the joint `dist` vs the per-position product
# formula over the rows of `Wi`. The two printed values should agree.
@show computeExactXEBFromFull(n, d, dist)
@show computeExactXEBMarkov(n, d, Wi)
;

computeExactXEBFromFull(n, d, dist) = 9.543655574914713
computeExactXEBMarkov(n, d, Wi) = 9.543655574914713
