# In [1]:
# Fixed-step gradient descent for the Bures–Wasserstein barycenter of the matrices in `covs`.
# INPUT:
#     nIter = Number of iterations
#     covs = d × d × n array containing covariance matrices
#     sqrt_covs = d × d × n array containing square roots of matrices in covs
#     objective = length nIter vector; entry i is overwritten with the barycenter objective
#                 (mean over j of (W₂)² between the iterate and covs[:,:,j]) at iteration i
#     times = length nIter vector; entry i is overwritten with seconds elapsed since the
#             call started, measured after iteration i's update
#     η = step size
#     distances = length nIter vector; entry i is overwritten with (W₂)² between the iterate
#                 and (sqrt_best)², measured before iteration i's update
#     sqrt_best = square root of a d × d matrix that we calculate distances to throughout training (ideally taken to be the true barycenter)
# OUTPUT:
#     square root (candidate_best^0.5) of the iterate that achieved the lowest objective value
#     over training; the starting iterate is covs[:,:,1]
function GD!(nIter, covs, sqrt_covs, objective, times, η, distances, sqrt_best)
    start = time()
    d = size(covs, 1)
    n = size(covs, 3)
    X = zeros(d, d)
    X .= covs[:, :, 1]

    # tr(Σⱼ) does not change between iterations, so compute it once up front.
    trcovs = [tr(covs[:, :, j]) for j in 1:n]

    # Preallocated accumulator for the (averaged) transport map.
    T = zeros(d, d)

    bestval = Inf
    candidate_best = zeros(d, d)

    for i in 1:nIter
        fill!(T, 0)
        # Accumulate locally: writing `objective[i] += ...` made the result depend on
        # whatever values the caller left in `objective`.
        obj = 0.0

        for j in 1:n
            sq = @view sqrt_covs[:, :, j]

            e = eigen(Symmetric(sq * X * sq))
            # (W₂)²(X, Σⱼ) = tr(X) + tr(Σⱼ) - 2 tr((Σⱼ^½ X Σⱼ^½)^½); the last trace is the sum of
            # square roots of the eigenvalues, and tr(X) is added once after the loop.
            obj += trcovs[j] - 2 * sum(sqrt, e.values)

            # Σⱼ^½ (Σⱼ^½ X Σⱼ^½)^(-½) Σⱼ^½
            T .+= sq * e.vectors * Diagonal(e.values .^ (-0.5)) * e.vectors' * sq
        end

        objective[i] = obj / n + tr(X)
        if objective[i] < bestval
            candidate_best .= X
            bestval = objective[i]
        end

        T ./= n
        distances[i] = bures(sqrt_best, X)
        # X ← M X M with M = (1 - η)I + ηT
        M = (1 - η) * I + η * T
        X .= Symmetric(M * X * M)
        times[i] = time() - start
    end
    return candidate_best^0.5
end

# Euclidean gradient descent with eigenvalue clipping for the Bures–Wasserstein barycenter.
# INPUT:
#     nIter = Number of iterations
#     covs = d × d × n array containing covariance matrices
#     sqrt_covs = d × d × n array containing square roots of matrices in covs
#     objective = length nIter vector; entry i is overwritten with the barycenter objective
#                 (mean over j of (W₂)² between the iterate and covs[:,:,j]) at iteration i
#     times = length nIter vector; entry i is overwritten with seconds elapsed since the
#             call started, measured before iteration i's update
#     η = step size
#     α = lower eigenvalue to threshold at (should be ∼ average minimum eigenvalue of covs)
#     β = upper eigenvalue to threshold at (should be ∼ average maximum eigenvalue of covs)
#     distances = length nIter vector; entry i is overwritten with (W₂)² between the iterate
#                 and (sqrt_best)², measured before iteration i's update
#     sqrt_best = square root of a d × d matrix that we calculate distances to throughout training (ideally taken to be the true barycenter)
# OUTPUT:
#     the final iterate X (d × d) after the last update and clipping; a copy of covs[:,:,1]
#     when nIter == 0 (previously the result was discarded and `nothing` was returned)
function EGD!(nIter, covs, sqrt_covs, objective, times, η, α, β, distances, sqrt_best)
    start = time()
    d = size(covs, 1)
    n = size(covs, 3)

    X = zeros(d, d)
    X .= covs[:, :, 1]

    # tr(Σⱼ) does not change between iterations, so compute it once up front.
    trcovs = [tr(covs[:, :, j]) for j in 1:n]

    # Preallocated accumulator for the (averaged) transport map.
    T = zeros(d, d)

    for i in 1:nIter
        fill!(T, 0)
        # Accumulate locally so the result does not depend on prior contents of `objective`.
        obj = 0.0

        for j in 1:n
            sq = @view sqrt_covs[:, :, j]

            e = eigen(Symmetric(sq * X * sq))
            # tr(Σⱼ) - 2 tr((Σⱼ^½ X Σⱼ^½)^½); tr(X) is added once after the loop.
            obj += trcovs[j] - 2 * sum(sqrt, e.values)

            # Σⱼ^½ (Σⱼ^½ X Σⱼ^½)^(-½) Σⱼ^½
            T .+= sq * e.vectors * Diagonal(e.values .^ (-0.5)) * e.vectors' * sq
        end

        objective[i] = obj / n + tr(X)
        times[i] = time() - start
        distances[i] = bures(sqrt_best, X)

        T ./= n
        # Euclidean step X ← X - η(I - T); I - T is the gradient of the objective above.
        X .= Symmetric(X .- η .* (I(d) .- T))

        # Keep the eigenvalues of the iterate inside [α, β].
        clip!(X, α, β)
    end
    return X
end

# Stochastic gradient descent: one update per covariance matrix, visiting them in order.
# INPUT
#     covs = d × d × n array containing covariance matrices (only its dimensions are read here)
#     sqrt_covs = d × d × n array containing square roots of matrices in covs
#     X = starting covariance matrix; it is updated in place and holds the final iterate on return
#     objective = length n vector; entry i receives (W₂)² between the current iterate and
#                 sqrt_bary², measured before the i-th update
#     times = length n vector; entry i receives seconds elapsed since the call began,
#             measured before the i-th update
#     ηs = array of length n of stepsizes, or single number (in which case that step size is used for all steps)
#     sqrt_bary = square root of true barycenter; the identity I(d) is used when it is `nothing`
function SGD!(covs, sqrt_covs, X, objective, times, ηs; sqrt_bary = nothing)
    t0 = time()
    dim = size(covs, 1)
    nsamples = size(covs, 3)
    reference = isnothing(sqrt_bary) ? I(dim) : sqrt_bary

    # Preallocated storage for the transport map of the current sample.
    Tmap = zeros(dim, dim)

    for k in 1:nsamples
        η = length(ηs) == 1 ? ηs[1] : ηs[k]
        S = @view sqrt_covs[:, :, k]
        F = eigen(Symmetric(S * X * S))
        # S (S X S)^(-½) S
        Tmap .= S * F.vectors * diagm(F.values .^ (-0.5)) * F.vectors' * S

        times[k] = time() - t0
        objective[k] = bures(reference, X)

        # X ← G X G with G = (1 - η)I + ηT
        G = (1 - η) .* I(dim) .+ η .* Tmap
        X .= Symmetric(G * X * G)
    end
end


# Stochastic Euclidean gradient descent with eigenvalue clipping: one update per covariance
# matrix, visiting them in order, starting from covs[:,:,1].
# INPUT
#     covs = d × d × n array containing covariance matrices
#     sqrt_covs = d × d × n array containing square roots of matrices in covs
#     objective = length n vector; entry i is overwritten with (W₂)² between the current
#                 iterate and sqrt_bary², measured before the i-th update
#     times = length n vector; entry i is overwritten with seconds elapsed since the call
#             began, measured before the i-th update
#     ηs = array of length n of stepsizes, or single number (in which case that step size is used for all steps)
#     α = lower eigenvalue to threshold at (should be ∼ average minimum eigenvalue of covs)
#     β = upper eigenvalue to threshold at (should be ∼ average maximum eigenvalue of covs)
#     sqrt_bary (keyword) = square root of the matrix that `objective` distances are measured
#                 to (ideally the true barycenter). It defaults to the identity I(d), the
#                 previously hard-coded reference, and matches the keyword in SGD!.
# OUTPUT
#     the final iterate X (d × d) after the last update and clipping (previously discarded)
function ESGD!(covs, sqrt_covs, objective, times, ηs, α, β; sqrt_bary = nothing)
    start = time()
    d = size(covs, 1)
    n = size(covs, 3)
    reference = isnothing(sqrt_bary) ? I(d) : sqrt_bary

    X = zeros(d, d)
    X .= covs[:, :, 1]

    # Preallocated storage for the transport map of the current sample.
    T = zeros(d, d)

    for i in 1:n
        η = length(ηs) == 1 ? ηs[1] : ηs[i]
        sq = @view sqrt_covs[:, :, i]
        e = eigen(Symmetric(sq * X * sq))
        # Σᵢ^½ (Σᵢ^½ X Σᵢ^½)^(-½) Σᵢ^½
        T .= sq * e.vectors * Diagonal(e.values .^ (-0.5)) * e.vectors' * sq
        times[i] = time() - start
        objective[i] = bures(reference, X)
        # Euclidean step X ← X - η(I - T), then keep the eigenvalues inside [α, β].
        X .= Symmetric(X .- η .* (I(d) .- T))
        clip!(X, α, β)
    end
    return X
end

# Compute the barycenter of the covariance matrices directly as a semidefinite program.
# For each i, the block constraint [covs_i S_i; S_i' Σ] ⪰ 0 bounds tr(S_i) by
# tr((covs_i^½ Σ covs_i^½)^½). Minimizing tr(Σ) - 2·mean tr(S_i) therefore minimizes the mean
# (W₂)² from Σ to the covs_i, up to the dropped constant mean tr(covs_i).
# INPUT
#     covs = d × d × n array containing covariance matrices
#     verbose = boolean indicating whether SDP solver should be verbose
#     maxIter = maximum number of iterations (passed to SCS as the "max_iters" parameter)
# OUTPUT
#     the barycenter (d × d matrix)
function SDP(covs; verbose=false, maxIter = 5000)
    d = size(covs)[1]
    n = size(covs)[3]
    
    # NOTE(review): Variable/minimize/solve!/⪰ presumably come from Convex.jl; confirm the imports.
    Σ = Variable(d,d)
    # One auxiliary d × d variable per input covariance matrix.
    Ss = [Variable(d,d) for _ in 1:n]
    constr = [([covs[:,:,i] Ss[i]; Ss[i]' Σ] ⪰ 0) for i in 1:n]
    problem = minimize(tr(Σ) - 2*mean(tr.(Ss)))
    # NOTE(review): `+=` on the constraint vector relies on Convex.jl overloads for appending
    # constraints, which may not exist in newer releases — confirm against the pinned version.
    problem.constraints += constr
    problem.constraints += (Σ ⪰ 0)
    optimizer = SCS.Optimizer(verbose = verbose)
    # NOTE(review): MOI.RawParameter was renamed MOI.RawOptimizerAttribute in
    # MathOptInterface 1.0 — confirm the MOI version in use.
    MOI.set(optimizer, MOI.RawParameter("max_iters"), maxIter)
    solve!(problem, optimizer)
    # NOTE(review): the solver status is not checked here. If SCS stops early (e.g. at maxIter),
    # this is presumably its last iterate — check problem.status if accuracy matters.
    return Σ.value
end


# Clip the eigenvalues of the symmetric matrix X to the interval [α, β], in place, and return X.
# Only the upper triangle of X is read (it is wrapped in `Symmetric`). This guarantees a real
# eigendecomposition even when round-off has made X very slightly asymmetric; a plain
# `eigen(X)` would then switch to the general, possibly complex, solver and `clamp` would fail.
function clip!(X, α, β)
    e = eigen(Symmetric(X))
    X .= Symmetric(e.vectors * Diagonal(clamp.(e.values, α, β)) * e.vectors')
    return X
end

# Calculates (W₂)² between (sq)^2 and x, i.e.
#     tr(x) + tr(sq²) - 2 tr((sq x sq)^½),
# where sq and x are expected to be symmetric positive semidefinite.
function bures(sq, x)
    # sq*x*sq is symmetric in exact arithmetic but generally not after round-off. Wrapping it
    # in `Symmetric` forces the real symmetric eigensolver; the general solver could return
    # complex eigenvalues and hence a complex "distance".
    λ = eigvals(Symmetric(sq * x * sq))
    # tr(M^½) is the sum of the square roots of M's eigenvalues. Round-off can produce tiny
    # negative eigenvalues, which would make `sqrt` throw a DomainError, so clamp them at zero.
    return tr(x) + tr(sq * sq) - 2 * sum(v -> sqrt(max(v, zero(v))), λ)
end

# Barycenter functional of X: the average of (W₂)² between X and each sqrt_covs[:,:,k]²,
# taken over k in 1:size(sqrt_covs, 3).
function barycenter_functional(sqrt_covs, X)
    nmats = size(sqrt_covs, 3)
    sq_dists = (bures(sqrt_covs[:, :, k], X) for k in 1:nmats)
    return mean(sq_dists)
end

# Output: barycenter_functional (generic function with 1 method)