In [2]:
using LinearAlgebra, Distributions, Random, Optim, LineSearches, Statistics
using CSV, DataFrames, DelimitedFiles

In [3]:
# beta is (1, 2, 3, ... p)ᵀ, given in paper.
data = CSV.read("../data/lin_reg.csv", DataFrame);
data = Matrix(data);

In [4]:
const X = data[:, 1:end - 1];
const y = data[:, end];
const 𝒩 = Normal();

In [5]:
n, p = size(X)

(10000, 1000)

In [6]:
function batch_shuffle(X, y, mb_size)
    n = size(X)[1]
    perm = randperm(n)[1:mb_size]
    X_mini = @view X[perm, :]
    y_mini = @view y[perm, :]
    return X_mini, y_mini
end

batch_shuffle (generic function with 1 method)

In [7]:
function F(β, X, y)
    n = length(y)
    res = y .- X * β
    G = -X' * res
    l = sum(abs2, res) / n
    G = G ./ n
    
    return l, G
end

F (generic function with 1 method)

In [8]:
function backtracking(F, d, β, X, y, r = 0.5, c= 1e-4, nmax=100)
    
    #https://en.wikipedia.org/wiki/Backtracking_line_search
    α = 1
    fᵢ, gᵢ = F(β, X, y)
    βₛ = β + α * d
    fₛ, gₛ = F(βₛ, X, y)
    n = 1
    
    while fₛ > fᵢ + c * α * dot(gᵢ, d) && n < nmax
        n = n + 1
        α = α * r
        βₛ = β + α * d
        fₛ, gₛ = F(βₛ, X, y)
    end
    
    return α, fₛ, gₛ
end

backtracking (generic function with 4 methods)

In [9]:
function approxInvHess(g, S, Y, H₀, global_iter)
    
    #https://en.wikipedia.org/wiki/Limited-memory_BFGS
    n, m = size(S)
    ρ = zeros(m)
    
    upper = global_iter
    lower = global_iter - m + 2
    @inbounds for index in lower:upper
        if index < 2
            continue
        end
        i = mod1(index, m)
        ρ[i] = abs(1 / dot(Y[:, i], S[:, i]))
    end

    q = zeros(n)
    α = zeros(m)
    d = zeros(n)
    β = zero(eltype(ρ))

    copyto!(q, g)
    
    upper = global_iter
    lower = global_iter - m + 1
    @inbounds for index in upper:-1:lower
        if index < 1
            continue
        end
        
        i = mod1(index, m)
        α[i] = ρ[i] * dot(S[:, i], q)
        @. q -= α[i] * Y[:, i]
    end

    d = H₀ * q
    
    @inbounds for index in lower:upper
        if index < 1
            continue
        end
        i = mod1(index, m)
        β = ρ[i] * dot(Y[:, i], d)
        @. d = d + S[:, i] * (α[i] - β)
    end
    
    return d
end

approxInvHess (generic function with 1 method)

In [10]:
function loss_vector(β, X, y)
    # returns the loss for each data point in X instead of summing over whole dataset
    res = (y .- X * β)
    l = res .^ 2
    return l
end

loss_vector (generic function with 1 method)

In [11]:
function stop(β, βₛ, X, y, α = 0.95)
    # calculates if we get a decrease with CI α.
    
    z = quantile(𝒩, (1 + α)/2)
    n = length(y)
    
    f = loss_vector(β, X, y)
    g = loss_vector(βₛ, X, y)
    mean_f = mean(f)
    mean_g = mean(g)
    
    var_diff = (var(f) + var(g) - 2 * cov(f, g)) / n
    std_diff = sqrt(var_diff)
    
    diff = mean_f - mean_g
    
    suff_dec = diff - z * std_diff
    println("z: $(z), diff: $(diff), std_diff: $(std_diff), suff_dec: $(suff_dec)")
    if suff_dec > 0
        return true
    else
        return false
    end
end

stop (generic function with 2 methods)

In [12]:
function lbfgs!(F, βᵢ, X, y, Sₘ, Yₘ, global_iter, maxIt, m, τgrad = 1e-8, verbose = 0)
    
    local_iter = 0
    n = length(βᵢ)
    d = zeros(n)
    fᵢ, gᵢ = F(βᵢ, X, y)
    
    # use the simplest line search to find step size
    α, fₛ, gₛ = backtracking(F, -gᵢ, βᵢ, X, y)
    βₛ = βᵢ - α * gᵢ
    
    # counter
    local_iter = 1

    while true
        
        if local_iter > maxIt
            break; 
        end
        
        gnorm = norm(gᵢ)
        
        if gnorm < τgrad
            
            break; 
        end
        
        s₀ = βₛ - βᵢ
        y₀ = gₛ - gᵢ
        
        H₀ = dot(s₀, y₀)/ dot(y₀, y₀) # hessian diagonal satisfying secant condition
        i = mod1(global_iter, m)
        
        Sₘ[:, i] .= s₀
        Yₘ[:, i] .= y₀
        d = -approxInvHess(gₛ, Sₘ, Yₘ, H₀, global_iter)
            
        # new direction=p, find new step size
        α, fs, gs = backtracking(F, d, βₛ, X, y)
        
        # update for next iteration
        βᵢ .= βₛ
        gᵢ .= gₛ
        βₛ .= βₛ + α .* d
        stop(βᵢ, βₛ, X, y)
        fₛ = fs
        gₛ = gs
        local_iter = local_iter + 1
        global_iter = global_iter + 1

        if verbose == 1
            println("Iteration: $local_iter -- loss: $fₛ gradnorm: $(norm(gᵢ)) ssize: $α")
        end
    end
    
    local_iter = local_iter - 1
    return βₛ, fₛ, local_iter, global_iter
end

lbfgs! (generic function with 3 methods)

In [16]:
m = 20
Sₘ = zeros(p, m)
Yₘ = zeros(p, m);
global_iter = 1

1

In [17]:
lbfgs!(F, zeros(p), X, y, Sₘ, Yₘ, global_iter, 45, m, 1e-6, 1)

z: 1.9599639845400576, diff: 3.438133903492506e7, std_diff: 537171.9546652769, suff_dec: 3.332850135027613e7
Iteration: 2 -- loss: 3.3488791470233947e6 gradnorm: 6981.805826161006 ssize: 1
z: 1.9599639845400576, diff: 2.5162279333867105e6, std_diff: 38962.53331605576, suff_dec: 2.439862771340799e6
Iteration: 3 -- loss: 832651.2136366842 gradnorm: 1779.2021469313897 ssize: 1
z: 1.9599639845400576, diff: 778825.7546616547, std_diff: 11633.90627062234, suff_dec: 756023.7173717201
Iteration: 4 -- loss: 53825.45897502951 gradnorm: 728.59967045719 ssize: 1
z: 1.9599639845400576, diff: 49625.42762736646, std_diff: 773.9400202538618, suff_dec: 48108.533061474685
Iteration: 5 -- loss: 4200.031347663051 gradnorm: 198.42819546805657 ssize: 1
z: 1.9599639845400576, diff: 2701.0388705671785, std_diff: 53.59314002221587, suff_dec: 2595.998246305223
Iteration: 6 -- loss: 1498.9924770958726 gradnorm: 72.83594985932504 ssize: 1
z: 1.9599639845400576, diff: 1414.554267703629, std_diff: 21.11934240593261

([0.9999999958734821, 1.9999999961868158, 2.9999999973956446, 3.999999998454709, 5.000000001790283, 5.99999999974061, 6.999999998780582, 7.999999999433447, 8.999999997086208, 9.999999994988105  …  991.000000003319, 991.9999999974096, 992.9999999976799, 994.0000000024258, 994.9999999983291, 995.9999999993394, 996.9999999964253, 997.9999999944831, 998.9999999997395, 1000.0000000022836], 6.077923134517591e-15, 22, 23)