In [1]:
"""
    DIIS

Mutable state for DIIS-style mixing.

Holds a circular history of at most `max_vecs` flattened trial, direction and error
vectors (all `Vector{Float64}` of equal length). `iter` is the number of vector sets
recorded so far; `compute_new_vector` uses `iter % max_vecs + 1` as the slot to
overwrite next.
"""
mutable struct DIIS
    iter::Int                                 # vector sets recorded so far
    max_vecs::Int                             # capacity of the circular history
    trial_vecs::Vector{Vector{Float64}}       # stored flattened trial vectors
    direction_vecs::Vector{Vector{Float64}}   # stored flattened direction vectors
    error_vecs::Vector{Vector{Float64}}       # stored flattened error vectors
end

"""
    DIIS(shape; max_vecs=3)

Create an empty `DIIS` mixer for arrays with `prod(shape)` elements (e.g. `shape`
may be a size tuple such as `(n, n, n, n)`), keeping a history of `max_vecs`
zero-initialised vectors per kind. The default of 3 matches the previous
hard-coded history length.

# Throws
- `ArgumentError` if `max_vecs < 1`.
"""
function DIIS(shape; max_vecs::Integer=3)
    max_vecs >= 1 || throw(ArgumentError("max_vecs must be at least 1, got $max_vecs"))
    vec_length = prod(shape)
    # Each call builds independent vectors so slots never alias one another.
    blank_history() = [zeros(vec_length) for _ in 1:max_vecs]
    return DIIS(0, max_vecs, blank_history(), blank_history(), blank_history())
end;

In [315]:
"""
    compute_new_vector(mixer::DIIS, trial, direction, error)

Record one iteration in `mixer` and return the extrapolated next trial array.

The flattened `trial`, `direction` and `error` arrays overwrite slot
`mixer.iter % mixer.max_vecs + 1` of the mixer's circular history, and `mixer.iter`
is incremented. Weights `c` are then obtained from the bordered system

    [ B   -1 ] [c]   [ 0]
    [ -1'  0 ] [λ] = [-1],    B[i, j] = dot(error_i, error_j),

over the slots that hold data. The result is `Σ c_i (trial_i + direction_i)`,
reshaped to `size(trial)`.

# Throws
- `DimensionMismatch` if any input does not have as many elements as the mixer's
  stored vectors. The mixer is left unmodified in that case.
"""
function compute_new_vector(mixer::DIIS, trial, direction, error)
    (; iter, max_vecs, trial_vecs, direction_vecs, error_vecs) = mixer

    # Validate every input before touching the mixer, so a bad call cannot leave it
    # half-updated (counter bumped, only some history slots overwritten).
    vec_length = length(trial_vecs[1])
    for (name, arr) in (("trial", trial), ("direction", direction), ("error", error))
        length(arr) == vec_length || throw(DimensionMismatch(
            "$name has $(length(arr)) elements; mixer expects $vec_length"))
    end

    # Circular history: overwrite the oldest of the `max_vecs` slots.
    index = iter % max_vecs + 1
    iter += 1
    mixer.iter = iter

    error_vecs[index] .= vec(error)
    trial_vecs[index] .= vec(trial)
    direction_vecs[index] .= vec(direction)

    # Setting up the equations; only the first `num_used` slots hold real data.
    num_used = min(iter, max_vecs)
    B_dim = num_used + 1
    B = zeros(B_dim, B_dim)
    for i in 1:num_used
        for j in 1:i
            B[i, j] = B[j, i] = la.dot(error_vecs[i], error_vecs[j])
        end
        B[i, B_dim] = -1
        B[B_dim, i] = -1
    end

    # Symmetric diagonal scaling P*B*P (P = Diagonal(pre_condition)) to improve
    # conditioning. It is skipped when any diagonal entry is <= 0 (e.g. an all-zero
    # error vector), where 1/sqrt would be undefined.
    diag_B = la.diag(B)[1:num_used]
    pre_condition = ones(B_dim)
    if !any(<=(0), diag_B)
        pre_condition[1:num_used] .= 1 ./ sqrt.(diag_B)
    end
    B .*= pre_condition .* pre_condition'

    # Solving the equations. The right-hand side is [0, …, 0, -1], so the solution
    # is minus the last column of pinv(B), which equals the last row because B is
    # symmetric. pinv tolerates the (near-)singular B that arises when stored error
    # vectors are (nearly) linearly dependent.
    weights = -la.pinv(B)[end, :]
    weights .*= pre_condition   # undo the scaling: x = P * y

    # Using the solution to get the new vector (the Lagrange multiplier is ignored).
    new_trial = zero(trial_vecs[1])
    for i in 1:num_used
        new_trial .+= weights[i] .* (trial_vecs[i] .+ direction_vecs[i])
    end

    return reshape(new_trial, size(trial))
end

compute_new_vector (generic function with 1 method)

In [316]:
import LinearAlgebra as la

In [317]:
# Edge length of the 4-index sample arrays; the mixer is sized for n^4 elements.
n = 2
mixer = DIIS(ntuple(_ -> n, 4));

In [323]:
# Sample inputs: consecutive integers starting at 1, 2 and 3, laid out as
# n×n×n×n arrays with the first two axes swapped.
trial, direction, error = map((1, 2, 3)) do offset
    permutedims(reshape(collect(offset:offset + n^4 - 1), (n, n, n, n)), (2, 1, 3, 4))
end;

In [320]:
# Alternative sample inputs: odd numbers from 5, integers from -2, and integers
# from 0, each laid out as n×n×n×n with the first two axes swapped.
let dims = (n, n, n, n), swap12 = (2, 1, 3, 4), idx = 0:n^4-1
    global trial = permutedims(reshape(collect(2 .* idx .+ 5), dims), swap12)
    global direction = permutedims(reshape(collect(idx .- 2), dims), swap12)
    global error = permutedims(reshape(collect(idx), dims), swap12)
end;

In [326]:
# One mixing step on the current global `trial`, `direction` and `error`. This
# mutates `mixer` (advances its counter and overwrites one history slot); the
# returned array is shown as the cell output.
compute_new_vector(mixer, trial, direction, error)

4

2×2×2×2 Array{Float64, 4}:
[:, :, 1, 1] =
 3.0  5.0
 7.0  9.0

[:, :, 2, 1] =
 11.0  13.0
 15.0  17.0

[:, :, 1, 2] =
 19.0  21.0
 23.0  25.0

[:, :, 2, 2] =
 27.0  29.0
 31.0  33.0

In [332]:
"""
    testys(a, b)

Print `a + b` to `stdout` without a trailing newline and return `nothing`.
"""
testys(a, b) = print(a + b)

LoadError: syntax: optional positional arguments must occur at end around In[332]:1

In [331]:
# Call the function actually defined in this file (`testy` was never defined;
# running the file top to bottom would throw UndefVarError).
testys(4, 8)

12