In [1]:
abstract GradFunction
abstract ProxFunction

# loss functions
type SquareLoss <: GradFunction
end

function grad(::SquareLoss, X, y, theta)
    return 2*(X'*theta - y)
end
function loss_eval(::SquareLoss, X, y, theta)
    return norm(X'*theta - y)^2
end

square_loss = SquareLoss()


# regularizers
type SquareReg <: ProxFunction
end

function prox(::SquareReg, t, z)
    return z/(t+1)
end

square_reg = SquareReg()

type TrivialReg <: ProxFunction
end

function prox(::TrivialReg, t, z)
    return z
end

trivial_reg = TrivialReg()

TrivialReg()

In [60]:
# Empirical risk minimization problem
type Erm
    X::Array{Float64, 2}
    y::Array{Float64, 1}
    loss::GradFunction
    reg::ProxFunction
    opt_val::Float64
    opt_x::Array{Float64, 1}
end

# Constructor

# without regularizer
function Erm(X::Array{Float64, 2}, y::Array{Float64, 1},
        loss::GradFunction)
    return Erm(X, y, loss, trivial_reg, 0, zeros(size(X, 1)))
end

# with regularizer
function Erm(X::Array{Float64, 2}, y::Array{Float64, 1},
        loss::GradFunction, reg::ProxFunction)
    return Erm(X, y, loss, reg, 0, zeros(size(X, 1)))
end

function solve!(erm::Erm, x0=nothing, tol=1e-5, beta=.5, max_iter=20, verbose=true)
    converged = false
    
    if x0 == nothing
        x0 = zeros(size(erm.X, 1))
    end
    
    x = x0
    
    for i in 1:max_iter
        lambda = 1
        curr_grad = grad(erm.loss, erm.X, erm.y, x)
        curr_eval = loss_eval(erm.loss, erm.X, erm.y, x)
        info("current x is $(x)")
        
        # Prox iteration
        while true
            z = prox(erm.reg, lambda, x - lambda*curr_grad)
            delta = z-x
            z_loss = loss_eval(erm.loss, erm.X, erm.y, z)
            if z_loss <= curr_eval + dot(curr_grad,delta) + 1/(2*lambda)*norm(delta)^2
                x = z
                break
            end
            lambda *= beta
        end
    end
    if !converged
        warn("Failed to converge after $(max_iter) iterations")
    end
    
    erm.opt_x = x
    erm.opt_val = loss_eval(erm.loss, erm.X, erm.y, x)
end



solve! (generic function with 6 methods)

In [61]:
X = rand(4, 4)
y = rand(4)
erm = Erm(X, y, square_loss, trivial_reg)

Erm([0.336185 0.140731 0.0133459 0.428493; 0.84794 0.405326 0.771097 0.965971; 0.973543 0.321114 0.88306 0.346935; 0.98983 0.733295 0.279469 0.726687],[0.611797,0.385271,0.442484,0.734189],SquareLoss(),TrivialReg(),0.0,[0.0,0.0,0.0,0.0])

In [62]:
solve!(erm)

[1m[34mINFO: current x is [0.0,0.0,0.0,0.0]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.192635,0.221242,0.367095]
[0m[1m[34mINFO: current x is [0.305898,0.

0.06561182064297796

In [63]:
norm(X'*(pinv(X')*y) - y)^2

1.1093356479670479e-31

In [56]:
pinv(X')*y

4-element Array{Float64,1}:
 -1.20004 
 -0.108752
  1.49606 
  0.567302

In [57]:
erm.opt_x

4-element Array{Float64,1}:
 0.371057
 0.215693
 0.181848
 0.316558

In [59]:
loss_eval(square_loss, erm.X, erm.y, erm.opt_x)

0.3197000351687596