In [1]:
using ChainRules
using DiffOpt
using Flux
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated
using OSQP
using JuMP
using LinearAlgebra

In [2]:
imgs = Flux.Data.MNIST.images()
labels = Flux.Data.MNIST.labels();

In [21]:
## prepare data
# Preprocessing
X = hcat(float.(reshape.(imgs, :))...) #stack all the images
Y = onehotbatch(labels, 0:9); # just a common way to encode categorical variables

test_X = hcat(float.(reshape.(Flux.Data.MNIST.images(:test), :))...)
test_Y = onehotbatch(Flux.Data.MNIST.labels(:test), 0:9);

## Custom Relu

In [20]:
function myRelu(y)
    N = length(y)
    # create model
    model = Model(() -> diff_optimizer(OSQP.Optimizer))
    @variable(model, x[1:N])
    @constraint(model, x .>= 0.0)
    @objective(
        model,
        Min,
        x'x -2x'y + y'y,
    )

    optimize!(model)

    x̂ = value.(x)
    return x̂
end

myRelu (generic function with 2 methods)

In [8]:
function rrule(::typeof(myRelu), dx::Array{Float64})
    N = length(y)
    # create model
    model = Model(() -> diff_optimizer(OSQP.Optimizer))
    @variable(model, x[1:N])
    @constraint(model, x .>= 0.0)
    @objective(
        model,
        Min,
        x'x -2x'y + y'y,
    )

    optimize!(model)

    x̂ = value.(x)
    
    function _pullback(dx)
        MOI.set.(
            model,
            DiffOpt.BackwardIn{MOI.VariablePrimal}(), 
            x,
            dx
        ) 

        # find grad
        DiffOpt.backward(model)
        
        
#         MOI.get.(
#             model,
#             DiffOpt.BackwardOut{DiffOpt.ConstraintCoefficient}(), 
#             x,
#             cons,
#         )
        
        dy = MOI.get.(
            model,
            DiffOpt.BackwardOut{DiffOpt.ConstraintConstant}(), 
            cons, 
        )
        return NO_FIELDS, dy
    end
    return x̂, _pullback
end

rrule (generic function with 1 method)

## Define the NN

In [22]:
m = Chain(
    Dense(784, 64, myRelu),
    Dense(64, 10),
    softmax
)

Chain(Dense(784, 64, myRelu), Dense(64, 10), NNlib.softmax)

In [23]:
loss(x, y) = crossentropy(m(x), y) 
opt = ADAM(); # popular stochastic gradient descent variant

accuracy(x, y) = mean(onecold(m(x)) .== onecold(y)) # cute way to find average of correct guesses

dataset = repeated((X,Y), 2) # repeat the data set
evalcb = () -> @show(loss(X, Y)) # callback to show loss

#31 (generic function with 1 method)

In [24]:
Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 5)); #took me ~5 minutes to train on CPU

MethodError: MethodError: no method matching zero(::Type{Adjoint{GenericAffExpr{Float64,VariableRef},Array{GenericAffExpr{Float64,VariableRef},1}}})
Closest candidates are:
  zero(!Matched::Type{LibGit2.GitHash}) at /buildworker/worker/package_linux32/build/usr/share/julia/stdlib/v1.0/LibGit2/src/oid.jl:220
  zero(!Matched::Type{Pkg.Resolve.VersionWeights.VersionWeight}) at /buildworker/worker/package_linux32/build/usr/share/julia/stdlib/v1.0/Pkg/src/resolve/VersionWeights.jl:19
  zero(!Matched::Type{Pkg.Resolve.MaxSum.FieldValues.FieldValue}) at /buildworker/worker/package_linux32/build/usr/share/julia/stdlib/v1.0/Pkg/src/resolve/FieldValues.jl:44
  ...

In [128]:
@show accuracy(X,Y)
@show accuracy(test_X, test_Y);

accuracy(X, Y) = 0.24961666666666665
accuracy(test_X, test_Y) = 0.2567
