In [2]:
"""
    GD!(nIter, covs, sqrt_covs, X, objective, times, η)

Run `nIter` iterations of the update `X ← M X M`, with `M = (1 - η) I + η T̄`, in place
on `X`. Here `T̄` is the average over `j` of the transport map
`sq_j (sq_j X sq_j)^{-1/2} sq_j`, and `sq_j = sqrt_covs[:, :, j]`.

# Arguments
- `nIter`: number of iterations.
- `covs`: `d × d × n` array; slice `covs[:, :, j]` is the j-th covariance matrix.
- `sqrt_covs`: `d × d × n` array whose slice `j` is the square root of `covs[:, :, j]`.
- `X`: `d × d` mutable matrix holding the starting point; overwritten with each iterate.
- `objective`: vector of length ≥ `nIter`. `objective[i]` is overwritten with
  `mean_j(tr(covs_j) + tr(X) - 2 tr((sq_j X sq_j)^{1/2}))`, evaluated at the iterate
  that iteration `i` starts from. It does not need to be pre-zeroed.
- `times`: vector of length ≥ `nIter`. `times[i]` is overwritten with the seconds
  elapsed since the call began, measured after the update of iteration `i`.
- `η`: step size.

# Throws
- `ArgumentError` if `covs` contains no matrices (`size(covs, 3) == 0`).
"""
function GD!(nIter, covs, sqrt_covs, X, objective, times, η)
    start = time()
    d = size(covs, 1)
    n = size(covs, 3)
    n > 0 || throw(ArgumentError("covs must contain at least one matrix (got n = 0)"))

    # Running sum (then average) of transport maps; reused across iterations.
    T = zeros(d, d)

    for i in 1:nIter
        fill!(T, 0)
        acc = 0.0  # local accumulator so caller-provided `objective` contents never leak in

        for j in 1:n
            sq = @view sqrt_covs[:, :, j]
            # sq*X*sq is symmetric in exact arithmetic; Symmetric selects the real solver.
            e = eigen(Symmetric(sq * X * sq))
            # tr((sq X sq)^{1/2}) is the sum of square-rooted eigenvalues; no need to
            # rebuild the matrix square root just to take its trace.
            acc += tr(@view covs[:, :, j]) - 2 * sum(sqrt, e.values)
            # Transport map from X to covs[:, :, j]: sq (sq X sq)^{-1/2} sq
            T .+= sq * e.vectors * Diagonal(e.values .^ (-0.5)) * e.vectors' * sq
        end

        objective[i] = acc / n + tr(X)

        T ./= n
        M = (1 - η) * I + η * T
        X .= Symmetric(M * X * M)
        times[i] = time() - start
    end
    return nothing
end

"""
    EGD!(nIter, covs, sqrt_covs, X, objective, times, η, α, β)

Run `nIter` iterations of the update `X ← clip!(X - η (I - T̄), α, β)` in place on `X`.
Here `T̄` is the average over `j` of the transport map `sq_j (sq_j X sq_j)^{-1/2} sq_j`,
with `sq_j = sqrt_covs[:, :, j]`. After each step, `clip!` clamps the eigenvalues of
`X` to `[α, β]`.

# Arguments
- `nIter`: number of iterations.
- `covs`: `d × d × n` array; slice `covs[:, :, j]` is the j-th covariance matrix.
- `sqrt_covs`: `d × d × n` array whose slice `j` is the square root of `covs[:, :, j]`.
- `X`: `d × d` mutable matrix holding the starting point; overwritten with each iterate.
- `objective`: vector of length ≥ `nIter`. `objective[i]` is overwritten with
  `mean_j(tr(covs_j) + tr(X) - 2 tr((sq_j X sq_j)^{1/2}))`, evaluated at the iterate
  that iteration `i` starts from. It does not need to be pre-zeroed.
- `times`: vector of length ≥ `nIter`. `times[i]` is overwritten with the seconds
  elapsed since the call began, measured before the update of iteration `i`.
- `η`: step size.
- `α`, `β`: eigenvalue bounds passed to `clip!`.

# Throws
- `ArgumentError` if `covs` contains no matrices (`size(covs, 3) == 0`).
"""
function EGD!(nIter, covs, sqrt_covs, X, objective, times, η, α, β)
    start = time()
    d = size(covs, 1)
    n = size(covs, 3)
    n > 0 || throw(ArgumentError("covs must contain at least one matrix (got n = 0)"))

    # Running sum (then average) of transport maps; reused across iterations.
    T = zeros(d, d)

    for i in 1:nIter
        fill!(T, 0)
        acc = 0.0  # local accumulator so caller-provided `objective` contents never leak in

        for j in 1:n
            sq = @view sqrt_covs[:, :, j]
            # sq*X*sq is symmetric in exact arithmetic; Symmetric selects the real solver.
            e = eigen(Symmetric(sq * X * sq))
            # tr((sq X sq)^{1/2}) is the sum of square-rooted eigenvalues.
            acc += tr(@view covs[:, :, j]) - 2 * sum(sqrt, e.values)
            # Transport map from X to covs[:, :, j]: sq (sq X sq)^{-1/2} sq
            T .+= sq * e.vectors * Diagonal(e.values .^ (-0.5)) * e.vectors' * sq
        end

        objective[i] = acc / n + tr(X)
        times[i] = time() - start

        T ./= n
        X .= Symmetric(X - η * (I - T))

        # Keep the spectrum of X inside [α, β].
        clip!(X, α, β)
    end
    return nothing
end

# Compute the barycenter by solving a semidefinite program with Convex.jl and SCS.
# For each input covariance C_i = covs[:, :, i], a d×d matrix variable S_i is added
# and the block matrix [C_i S_i; S_i' Σ] is constrained to be PSD. The program then
# minimizes tr(Σ) - 2 * mean_i tr(S_i) subject to Σ ⪰ 0.
#
# Arguments:
#   covs    — array whose slices covs[:, :, i] (along dim 3) are the input matrices
#   verbose — passed to the SCS.Optimizer constructor
#   maxIter — passed to SCS as the raw "max_iters" parameter
# Returns Σ.value, the value Convex.jl stores for Σ after solve!.
#
# NOTE(review): despite the trailing `!`, covs is only read here; nothing is mutated.
# NOTE(review): the solver status is never checked. Confirm callers tolerate a
#   non-converged solution (e.g. when max_iters is reached).
# NOTE(review): MOI.RawParameter was renamed MOI.RawOptimizerAttribute in
#   MathOptInterface 0.10, and keyword construction of SCS.Optimizer is from older
#   SCS.jl releases. Confirm against the pinned package versions.
function SDP!(covs; verbose=false, maxIter = 5000)
    d = size(covs)[1]
    n = size(covs)[3]
    
    # Σ is the barycenter variable; Ss[i] is the cross term paired with covs[:, :, i].
    Σ = Variable(d,d)
    Ss = [Variable(d,d) for _ in 1:n]
    # One PSD block constraint per input matrix, coupling Ss[i] with Σ.
    constr = [([covs[:,:,i] Ss[i]; Ss[i]' Σ] ⪰ 0) for i in 1:n]
    problem = minimize(tr(Σ) - 2*mean(tr.(Ss)))
    problem.constraints += constr
    problem.constraints += (Σ ⪰ 0)
    optimizer = SCS.Optimizer(verbose = verbose)
    MOI.set(optimizer, MOI.RawParameter("max_iters"), maxIter)
    solve!(problem, optimizer)
    return Σ.value
end


"""
    clip!(X, α, β)

Clamp the eigenvalues of the symmetric matrix `X` to the interval `[α, β]`, in place,
keeping its eigenvectors. Return `X`.

`X` is treated as symmetric: only its upper triangle is read, via `Symmetric`. The
result written back is exactly symmetric.

# Throws
- `ArgumentError` if `α > β`.
"""
function clip!(X, α, β)
    α <= β || throw(ArgumentError("clip! requires α <= β, got α = $α, β = $β"))
    # Symmetric guarantees the real symmetric eigensolver even when rounding has left
    # X slightly asymmetric. The general solver could return complex eigenvalues,
    # which `clamp` cannot handle.
    e = eigen(Symmetric(X))
    X .= Symmetric(e.vectors * Diagonal(clamp.(e.values, α, β)) * e.vectors')
    return X
end

"""
    bures(sq, x)

Return the squared Wasserstein distance between two centered Gaussians with covariances
`sq * sq` and `x`:

    tr(x) + tr(sq * sq) - 2 tr((sq * x * sq)^{1/2})

`sq` is expected to be a symmetric square root of the first covariance, and `x` a
symmetric PSD matrix.
"""
function bures(sq, x)
    # sq*x*sq is symmetric in exact arithmetic. Without Symmetric, rounding can push
    # eigen onto the general solver: its eigenvectors are not orthonormal, and the
    # result may become complex.
    λ = eigvals(Symmetric(sq * x * sq))
    # The matrix is PSD in exact arithmetic; clamp rounding-level negative eigenvalues
    # to zero instead of throwing a DomainError from sqrt.
    return tr(x) + tr(sq * sq) - 2 * sum(v -> sqrt(max(v, zero(v))), λ)
end

# bary = I(d)
# for i in 1:n
#     A = 0.05*randn(d,d)
#     A = (A + A')/2
#     X = bary + A + lyapc(bary, A)*bary*lyapc(bary, A)
#     covs[:,:,i] .= X
#     e = eigen(X)
#     sqrt_covs[:,:,i] = e.vectors*diagm(e.values.^(0.5))*e.vectors'
# end

# Average of bures(sqrt_covs[:, :, k], X) over every slice k along the third dimension
# of sqrt_covs, i.e. the mean squared distance from X to each input Gaussian.
function barycenter_functional(sqrt_covs, X)
    per_slice = (bures(sqrt_covs[:, :, k], X) for k in axes(sqrt_covs, 3))
    return mean(per_slice)
end

barycenter_functional (generic function with 1 method)