In [1]:
using LinearAlgebra, Distributions, Random, Optim, LineSearches, Statistics
using CSV, DataFrames, DelimitedFiles

In [2]:
# beta is (1, 2, 3, ... p)ᵀ, given in paper.
data = CSV.read("../data/lin_reg.csv", DataFrame);
data = Matrix(data);

In [3]:
const X = data[:, 1:end - 1];
const y = data[:, end];
const 𝒩 = Normal();

In [4]:
n, p = size(X)

(10000, 1000)

In [5]:
function batch_shuffle(X, y, mb_size)
    n = size(X)[1]
    perm = randperm(n)[1:mb_size]
    X_mini = @view X[perm, :]
    y_mini = @view y[perm, :]
    return X_mini, y_mini
end

batch_shuffle (generic function with 1 method)

In [6]:
function F(β, X, y)
    res = y .- X * β
    G = -X' * res
    l = sum(abs2, res)
    
    return l, G
end

F (generic function with 1 method)

In [7]:
function backtracking(F, d, β, X, y, r = 0.5, c= 1e-4, nmax=100)
    
    #https://en.wikipedia.org/wiki/Backtracking_line_search
    α = 1
    fᵢ, gᵢ = F(β, X, y)
    βₛ = β + α * d
    fₛ, gₛ = F(βₛ, X, y)
    n = 1
    
    while fₛ > fᵢ + c * α * dot(gᵢ, d) && n < nmax
        n = n + 1
        α = α * r
        βₛ = β + α * d
        fₛ, gₛ = F(βₛ, X, y)
    end
    
    return α, fₛ, gₛ
end

backtracking (generic function with 4 methods)

In [8]:
function approxInvHess(g, S, Y, H₀, global_iter)
    
    #https://en.wikipedia.org/wiki/Limited-memory_BFGS
    n, m = size(S)
    ρ = zeros(m)
    
    upper = global_iter
    lower = global_iter - m + 2
    @inbounds for index in lower:upper
        if index < 2
            continue
        end
        i = mod1(index, m)
        ρ[i] = abs(1 / dot(Y[:, i], S[:, i]))
    end

    q = zeros(n)
    α = zeros(m)
    d = zeros(n)
    β = zero(eltype(ρ))

    copyto!(q, g)
    
    upper = global_iter
    lower = global_iter - m + 1
    @inbounds for index in upper:-1:lower
        if index < 1
            continue
        end
        
        i = mod1(index, m)
        α[i] = ρ[i] * dot(S[:, i], q)
        @. q -= α[i] * Y[:, i]
    end

    d = H₀ * q
    
    @inbounds for index in lower:upper
        if index < 1
            continue
        end
        i = mod1(index, m)
        β = ρ[i] * dot(Y[:, i], d)
        @. d = d + S[:, i] * (α[i] - β)
    end
    
    return d
end

approxInvHess (generic function with 1 method)

In [9]:
function loss_vector(β, X, y)
    # returns the loss for each data point in X instead of summing over whole dataset
    res = (y .- X * β)
    l = res .^ 2
    return l
end

loss_vector (generic function with 1 method)

In [27]:
function stop(β, βₛ, X, y, α = 0.95)
    # calculates if we get a decrease with CI α.
    
    z = quantile(𝒩, (1 + α)/2)
    
    f = loss_vector(β, X, y)
    g = loss_vector(βₛ, X, y)
    mean_f = mean(f)
    mean_g = mean(g)
    
    var_diff = var(f) + var(g) - 2 * cov(f, g)
    std_diff = sqrt(var_diff)
    
    diff = mean_f - mean_g
    
    suff_dec = diff - z * std_diff
    println("z: $(z), diff: $(diff), std_diff: $(std_diff), suff_dec: $(suff_dec)")
    if suff_dec > 0
        return true
    else
        return false
    end
end

stop (generic function with 2 methods)

In [28]:
function lbfgs!(F, βᵢ, X, y, Sₘ, Yₘ, global_iter, maxIt, m, τgrad = 1e-8, verbose = 0)
    
    local_iter = 0
    n = length(βᵢ)
    d = zeros(n)
    fᵢ, gᵢ = F(βᵢ, X, y)
    
    # use the simplest line search to find step size
    α, fₛ, gₛ = backtracking(F, -gᵢ, βᵢ, X, y)
    βₛ = βᵢ - α * gᵢ
    
    # counter
    local_iter = 1

    while true
        
        if local_iter > maxIt
            break; 
        end
        
        gnorm = norm(gᵢ)
        
        if gnorm < τgrad
            
            break; 
        end
        
        s₀ = βₛ - βᵢ
        y₀ = gₛ - gᵢ
        
        H₀ = dot(s₀, y₀)/ dot(y₀, y₀) # hessian diagonal satisfying secant condition
        i = mod1(global_iter, m)
        
        Sₘ[:, i] .= s₀
        Yₘ[:, i] .= y₀
        d = -approxInvHess(gₛ, Sₘ, Yₘ, H₀, global_iter)
            
        # new direction=p, find new step size
        α, fs, gs = backtracking(F, d, βₛ, X, y)
        
        # update for next iteration
        βᵢ .= βₛ
        gᵢ .= gₛ
        βₛ .= βₛ + α .* d
        stop(βᵢ, βₛ, X, y)
        fₛ = fs
        gₛ = gs
        local_iter = local_iter + 1
        global_iter = global_iter + 1

        if verbose == 1 
            println("Iteration: $local_iter -- loss: $fₛ gradnorm: $(norm(gᵢ)) ssize: $α")
        end
    end
    
    local_iter = local_iter - 1
    return βₛ, fₛ, local_iter, global_iter
end

lbfgs! (generic function with 3 methods)

In [29]:
m = 20
Sₘ = zeros(p, m)
Yₘ = zeros(p, m);
global_iter = 1

1

In [30]:
lbfgs!(F, zeros(p), X, y, Sₘ, Yₘ, global_iter, 45, m, 10, 1)

z: 1.6448536269514717, diff: 9.065942199443707e7, std_diff: 1.3400870494215952e8, suff_dec: -1.2976528237274364e8
Iteration: 2 -- loss: 3.996104621139921e10 gradnorm: 1.154875207226343e8 ssize: 1
z: 1.6448536269514717, diff: 3.5616079086442366e6, std_diff: 5.494923872394696e6, suff_dec: -5.476737552686404e6
Iteration: 3 -- loss: 4.344967124956835e9 gradnorm: 2.307758053626485e7 ssize: 1
z: 1.6448536269514717, diff: 309485.6055967433, std_diff: 487008.8202430072, suff_dec: -491572.6187373245
Iteration: 4 -- loss: 1.2501110689894018e9 gradnorm: 5.214471879254262e6 ssize: 1
z: 1.6448536269514717, diff: 120679.54900001573, std_diff: 181682.2150847075, suff_dec: -178161.1014346428
Iteration: 5 -- loss: 4.33155789892445e7 gradnorm: 2.7404527457160526e6 ssize: 1
z: 1.6448536269514717, diff: 112.83740850430513, std_diff: 6498.6915853059445, suff_dec: -10576.559016025187
Iteration: 6 -- loss: 4.218720490420143e7 gradnorm: 738511.9742515993 ssize: 1
z: 1.6448536269514717, diff: 4105.894681685081

([0.9999915679830766, 1.9999896321114272, 2.999986656076358, 3.999987809034542, 4.9999928744199424, 5.999983326537657, 7.000001629521358, 8.000009935585632, 8.999994714539213, 9.999996715921549  …  991.0000003148443, 992.0000069050113, 993.0000038082152, 993.9999908327774, 994.9999957442349, 995.9999963467276, 996.9999984415899, 997.9999934050252, 998.9999954328889, 999.9999863828688], 0.00041919056372776743, 16, 17)