In [1]:
using LinearAlgebra, Distributions, Random, Optim, LineSearches, Statistics
using CSV, DataFrames, DelimitedFiles

In [2]:
# beta is (1, 2, 3, ... p)ᵀ, given in paper.
data = CSV.read("../data/lin_reg.csv", DataFrame);
data = Matrix(data);

In [3]:
const X = data[:, 1:end - 1];
const y = data[:, end];
const 𝒩 = Normal();

In [4]:
n, p = size(X)

(10000, 1000)

In [5]:
function batch_shuffle(X, y, mb_size)
    n = size(X)[1]
    perm = randperm(n)[1:mb_size]
    X_mini = @view X[perm, :]
    y_mini = @view y[perm, :]
    return X_mini, y_mini
end

batch_shuffle (generic function with 1 method)

In [6]:
function F(β, X, y)
    n = length(y)
    res = y .- X * β
    G = -X' * res
    l = sum(abs2, res) / n
    G = G ./ n
    
    return l, G
end

F (generic function with 1 method)

In [7]:
function backtracking(F, d, β, X, y, r = 0.5, c= 1e-4, nmax=100)
    
    #https://en.wikipedia.org/wiki/Backtracking_line_search
    α = 1
    fᵢ, gᵢ = F(β, X, y)
    βₛ = β + α * d
    fₛ, gₛ = F(βₛ, X, y)
    n = 1
    
    while fₛ > fᵢ + c * α * dot(gᵢ, d) && n < nmax
        n = n + 1
        α = α * r
        βₛ = β + α * d
        fₛ, gₛ = F(βₛ, X, y)
    end
    
    return α, fₛ, gₛ
end

backtracking (generic function with 4 methods)

In [8]:
function approxInvHess(g, S, Y, H₀, global_iter)
    
    #https://en.wikipedia.org/wiki/Limited-memory_BFGS
    n, m = size(S)
    ρ = zeros(m)
    
    upper = global_iter
    lower = global_iter - m + 2
    @inbounds for index in lower:upper
        if index < 2
            continue
        end
        i = mod1(index, m)
        ρ[i] = abs(1 / dot(Y[:, i], S[:, i]))
    end

    q = zeros(n)
    α = zeros(m)
    d = zeros(n)
    β = zero(eltype(ρ))

    copyto!(q, g)
    
    upper = global_iter
    lower = global_iter - m + 1
    @inbounds for index in upper:-1:lower
        if index < 1
            continue
        end
        
        i = mod1(index, m)
        α[i] = ρ[i] * dot(S[:, i], q)
        @. q -= α[i] * Y[:, i]
    end

    d = H₀ * q
    
    @inbounds for index in lower:upper
        if index < 1
            continue
        end
        i = mod1(index, m)
        β = ρ[i] * dot(Y[:, i], d)
        @. d = d + S[:, i] * (α[i] - β)
    end
    
    return d
end

approxInvHess (generic function with 1 method)

In [9]:
function loss_vector(β, X, y)
    # returns the loss for each data point in X instead of summing over whole dataset
    res = (y .- X * β)
    l = vec(res .^ 2)
    return l
end

loss_vector (generic function with 1 method)

In [10]:
function stop(β, βₛ, X, y, α = 0.95)
    # calculates if we get a decrease with CI α.
    
    z = quantile(𝒩, (1 + α)/2)
    n = length(y)
    
    f = loss_vector(β, X, y)
    g = loss_vector(βₛ, X, y)
    mean_f = mean(f)
    mean_g = mean(g)
#     println("$(size(f)) $(typeof(β))")
    
    var_diff = (var(f) + var(g) - 2 * cov(f, g)) / n
    std_diff = sqrt(var_diff)
    
    diff = mean_f - mean_g
    
    suff_dec = diff - z * std_diff
    println("z: $(z), diff: $(diff), std_diff: $(std_diff), suff_dec: $(suff_dec)")
    if suff_dec < 0
        return true
    else
        return false
    end
end

stop (generic function with 2 methods)

In [11]:
function lbfgs!(F, βᵢ, X, y, Sₘ, Yₘ, global_iter, maxIt, m, τgrad = 1e-8, verbose = 0)
    
    local_iter = 0
    n = length(βᵢ)
    d = zeros(n)
    fᵢ, gᵢ = F(βᵢ, X, y)
    
    # use the simplest line search to find step size
    α, fₛ, gₛ = backtracking(F, -gᵢ, βᵢ, X, y)
    βₛ = βᵢ - α * gᵢ
    
    # counter
    local_iter = 1

    while true
        
        gnorm = norm(gᵢ)
    
        if local_iter > maxIt
            break; 
        end
        
        if gnorm < τgrad 
            break; 
        end
        
        if stop(βᵢ, βₛ, X, y)
            break;
        end
        
        s₀ = βₛ - βᵢ
        y₀ = gₛ - gᵢ
        
        H₀ = dot(s₀, y₀)/ dot(y₀, y₀) # hessian diagonal satisfying secant condition
        i = mod1(global_iter, m)
        
        Sₘ[:, i] .= s₀
        Yₘ[:, i] .= y₀
        d = -approxInvHess(gₛ, Sₘ, Yₘ, H₀, global_iter)
            
        # new direction=p, find new step size
        α, fs, gs = backtracking(F, d, βₛ, X, y)
        
        # update for next iteration
        βᵢ .= βₛ
        gᵢ .= gₛ
        βₛ .= βₛ + α .* d
        fₛ = fs
        gₛ = gs
        local_iter = local_iter + 1
        global_iter = global_iter + 1

        if verbose == 1
            println("Iteration: $local_iter -- loss: $fₛ gradnorm: $(norm(gᵢ)) ssize: $α")
        end
        


    end
    
    local_iter = local_iter - 1
    return βₛ, fₛ, local_iter, global_iter
end

lbfgs! (generic function with 3 methods)

In [57]:
m = 10
Sₘ = zeros(p, m)
Yₘ = zeros(p, m);
global_iter = 1

mb_size = 4000
X_mini, y_mini = batch_shuffle(X, y, mb_size) 

([0.7345304379615886 1.194419532332163 … -0.8337925975766404 0.7767363650416517; 0.5546372154799928 -1.291015929306516 … -1.4356116278902147 -1.1253333956724902; … ; -1.7691191329774862 -0.41010036497563407 … -0.5384284310795172 1.1860155373631533; -0.5981196367266651 1.482756460765924 … 0.07477237092761908 0.1589464326616038], [18013.960607947964; 19716.67198791676; … ; 5492.8770247251705; 27411.114012553313;;])

In [58]:
lbfgs!(F, zeros(p), X_mini, y_mini, Sₘ, Yₘ, global_iter, 45, m, 1e-6, 1)

z: 1.9599639845400576, diff: 2.3768409310449046e8, std_diff: 7.56635950598813e6, suff_dec: 2.2285430097867143e8
Iteration: 2 -- loss: 1.121393532773848e7 gradnorm: 13211.59233439968 ssize: 1
z: 1.9599639845400576, diff: 9.44565911585074e7, std_diff: 2.413496297382609e6, suff_dec: 8.972622533881672e7
Iteration: 3 -- loss: 4.375254895218422e6 gradnorm: 3423.120998393185 ssize: 1
z: 1.9599639845400576, diff: 6.838680432520061e6, std_diff: 185503.45337556416, suff_dec: 6.475100344896149e6
Iteration: 4 -- loss: 1.7312462680906192e6 gradnorm: 1423.5029101827113 ssize: 1
z: 1.9599639845400576, diff: 2.6440086271278025e6, std_diff: 68895.22625172295, suff_dec: 2.5089764649676867e6
Iteration: 5 -- loss: 235097.00723697478 gradnorm: 945.7814545671425 ssize: 1
z: 1.9599639845400576, diff: 1.4961492608536468e6, std_diff: 35596.04757328947, suff_dec: 1.4263822896180248e6
Iteration: 6 -- loss: 114650.92859321412 gradnorm: 432.2374163584864 ssize: 1
z: 1.9599639845400576, diff: 120446.07864376054, st

([1.0000000213124367; 2.0000000066059154; … ; 998.999999999791; 1000.0000000088077;;], 2.298860149635539e-13, 35, 36)

In [59]:
q = 1.2
num_inner_iter = 5
mbs = ones(Int, num_inner_iter) * 2000

for i in 2:num_inner_iter
    mbs[i] = ceil(Int, mbs[i - 1] * q)
end
mbs

5-element Vector{Int64}:
 2000
 2400
 2880
 3456
 4148

In [60]:
function retrospective_approximation(F, βᵢ, X, y, m, mbs, τgrad = 1e-6)
    # m for LBFGS init.
    S = zeros(p, m)
    Y = zeros(p, m)
    
    global_iter = 1
    local_iter = 0
    grad_calls = 0
    
    for mb_size in  mbs
        X_mini, y_mini = batch_shuffle(X, y, mb_size)
        println("IN BATCH SIZE $mb_size------------------------------------")
        β, _, local_iter, global_iter = lbfgs!(F, βᵢ, X_mini, y_mini, S, Y, global_iter, 45, m, τgrad, 1)
        βᵢ = β
        grad_calls+=(local_iter * mb_size)
    end
    return βᵢ, grad_calls
end

retrospective_approximation (generic function with 2 methods)

In [61]:
retrospective_approximation(F, zeros(p), X, y, m, mbs)

IN BATCH SIZE 2000------------------------------------
z: 1.9599639845400576, diff: 1.268944998069948e8, std_diff: 1.0470958440973582e7, suff_dec: 1.0637179837907086e8
Iteration: 2 -- loss: 2.3588594322308846e7 gradnorm: 20307.508825795165 ssize: 1
z: 1.9599639845400576, diff: 1.75767660344177e8, std_diff: 6.2115603228376005e6, suff_dec: 1.635932258236173e8
Iteration: 3 -- loss: 9.296331543667447e6 gradnorm: 5843.389117761632 ssize: 1
z: 1.9599639845400576, diff: 1.4292262778641399e7, std_diff: 611804.4548650419, suff_dec: 1.3093148081524754e7
Iteration: 4 -- loss: 6.207570109198522e6 gradnorm: 1881.1490709259358 ssize: 1
z: 1.9599639845400576, diff: 3.088761434468925e6, std_diff: 135718.79504214926, suff_dec: 2.8227574841611385e6
Iteration: 5 -- loss: 2.272010827482612e6 gradnorm: 1455.5330890920409 ssize: 1
z: 1.9599639845400576, diff: 3.935559281715911e6, std_diff: 156336.68334471242, suff_dec: 3.629145012897831e6
Iteration: 6 -- loss: 1.3276531870317408e6 gradnorm: 1128.48677929366

Iteration: 7 -- loss: 1.197806690921082e-5 gradnorm: 0.006721302173795593 ssize: 1
z: 1.9599639845400576, diff: 1.902735764764059e-5, std_diff: 8.317276613006658e-7, suff_dec: 1.7397201386545555e-5
Iteration: 8 -- loss: 3.7811732007997966e-6 gradnorm: 0.004168965912185425 ssize: 1
z: 1.9599639845400576, diff: 8.196893708411023e-6, std_diff: 2.667556219841006e-7, suff_dec: 7.674062296648605e-6
Iteration: 9 -- loss: 2.1714956853458853e-6 gradnorm: 0.001198108768231253 ssize: 1
z: 1.9599639845400576, diff: 1.6096775154539113e-6, std_diff: 5.1755653059501876e-8, suff_dec: 1.5082382994609373e-6
Iteration: 10 -- loss: 4.5789107892698314e-7 gradnorm: 0.0007637842339362159 ssize: 1
z: 1.9599639845400576, diff: 1.7136046064189022e-6, std_diff: 5.321744916546499e-8, suff_dec: 1.6093003227054996e-6
Iteration: 11 -- loss: 2.3880467821511844e-7 gradnorm: 0.0004885429922152771 ssize: 0.5
z: 1.9599639845400576, diff: 2.190864007118646e-7, std_diff: 1.0594682522025669e-8, suff_dec: 1.9832120454105825e

([0.9999999966496863; 1.9999999871661986; … ; 998.9999999940852; 1000.0000000014029;;], 141920)