Optimizers #26

Closed · wants to merge 6 commits
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -41,6 +41,7 @@ include("layers/shims.jl")
include("backend/backend.jl")

include("data.jl")
include("optimizers.jl")
include("training.jl")

end # module
11 changes: 11 additions & 0 deletions src/backend/mxnet/model.jl
@@ -34,6 +34,13 @@ end

loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)
storegrads!(exec::Exec) = begin
  params, grads = exec.graph.params, exec.grads
  for id in intersect(keys(params), keys(grads))
    copy!(params[id].Δx, grads[id])
    grads[id].data[:] = 0
  end
end

mxgroup(x) = x
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
@@ -58,6 +65,7 @@ function executor(graph::Graph, input...; ctx = mx.cpu())
end

function (exec::Exec)(input...)
  loadparams!(exec)
  foreach(kv -> copy!(exec.args[kv[1]], kv[2]), dictt(exec.graph.input, input))
  mx.forward(exec.exec, is_train = true)
  mxungroup(exec.graph.output, copy(exec.outs))
@@ -66,6 +74,7 @@ end
function Flux.back!(exec::Exec, Δ)
  mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
  mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ)))
  storegrads!(exec)
  mapt(k -> copy(exec.grads[k]), exec.graph.input)
end

@@ -128,6 +137,8 @@ end

Flux.update!(m::Model, η) = (update!(m.last, η); m)

Flux.params(m::Model) = collect(Flux.Param, values(m.graph.params))

# Recurrent Models

using Flux: Stateful, SeqModel
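For orientation: after this change, a backend `back!` call leaves the gradients accumulated on each parameter's `Δx` (that is what `storegrads!` above does), so an optimizer only ever needs the list returned by `Flux.params`. A minimal sketch of that contract, assuming `Param` carries the value in `x` and the gradient in `Δx` as used in this PR; `sgd_step!` is a hypothetical helper for illustration, not something this PR defines:

```julia
# Hypothetical helper, not part of this PR: a plain gradient-descent step
# over the parameters exposed by `Flux.params`, consuming the gradients
# that `storegrads!` copied into each Param's Δx and then clearing them.
function sgd_step!(ps, η)
  for p in ps
    p.x  .-= η .* p.Δx   # apply the accumulated gradient
    p.Δx .=  0           # reset the accumulator for the next batch
  end
  ps
end
```

Something along these lines is presumably what the new src/optimizers.jl provides, but that file is not part of the hunks shown here.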
45 changes: 25 additions & 20 deletions src/backend/tensorflow/model.jl
@@ -4,48 +4,52 @@ struct Exec
  session ::Session
  input   ::Any
  output  ::Any
  grads   ::Any
  params  ::Dict{Flux.Param,Tensor}
  params  ::Dict{Param,Param{Tensor}}
  stacks  ::Dict{Any,Any}
end

function makesession(model, inputs; session = Session(Graph()))
  inputs = mapt(_ -> placeholder(Float32), inputs)
  params, stacks, output = tograph(model, inputs...)
  # grads = gradients(output, [collectt(inputs)..., values(params)...])
  grads = placeholder(Float32)
  output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output)
  params = Dict(x=>Param{Tensor}(y, gradients(mapt(x->x.x, output),
                                              y, mapt(x->x.Δx, output))) for (x, y) in params)
  inputs = mapt(x->Param{Tensor}(x, gradients(mapt(x->x.x, output),
                                              x, mapt(x->x.Δx, output))), inputs)
  run(session, global_variables_initializer())
  Exec(session, inputs, output, grads, params, stacks)
  Exec(session, inputs, output, params, stacks)
end

retuple(xs) = xs
retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,)

dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))

function params(m::Exec, args...)
  shapecheckt(m.input, args)
  idict = dictt(m.input, args)
  pdict = Dict(t => p.x for (p, t) in m.params)
  merge(idict, pdict)
function Flux.params(m::Exec)
  collect(keys(m.params))
end

function (m::Exec)(args...)
  retuple(run(m.session, m.output, params(m, args...)))
  dict = merge(
    Dict(y.x=>x.x for (x, y) in m.params),
    Dict(x.x=>y for (x, y) in zip(m.input, args))
  )
  retuple(run(m.session, mapt(x->x.x, m.output), dict))
end

pullt!(_, xs) = shift!(xs)
pullt!(x::Tuple, xs) = map(x -> pullt!(x, xs), x)
function Flux.back!(m::Exec, Δ, args...)
  dict = merge(
    Dict(y.x=>x.x for (x, y) in m.params),
    Dict(x.x=>y for (x, y) in zip(m.input, args)),
    Dict(x.Δx=>y for (x, y) in zip(collectt(m.output), collectt(Δ)))
  )

  # TODO: gradients don't work yet
  # `gradients` lacks support for `grad_y`s and multiple `y`s
  Δin, Δps = run(m.session, (mapt(x->x.Δx, m.input), map(x->x.Δx, values(m.params))), dict)

function Flux.back!(m::Exec, Δ, args...)
  Δps = run(m.session, m.grads, params(m, args...))
  Δin = pullt!(m.input, Δps)
  for (p, Δ) in zip(keys(m.params), Δps)
    p.Δx .+= Δ
  end

  Δin
end

@@ -70,8 +74,9 @@ function (m::Model)(args...)
  @tferr m.exec.stacks m.exec(args...)
end

Flux.back!(m::Model, Δ, args...) = back!(m.exec, Δ, args...)
Flux.update!(m::Model, η) = (update!(m.exec, η); m)
Flux.back!(m::Model, Δ, args...) = Flux.back!(m.exec, Δ, args...)
Flux.update!(m::Model, η) = (Flux.update!(m.exec, η); m)
Flux.params(m::Model) = Flux.params(m.exec)

# Recurrent Models

2 changes: 1 addition & 1 deletion src/backend/tensorflow/tensorflow.jl
@@ -1,7 +1,7 @@
module TF

using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy, convertel
import Flux: accuracy, convertel, Param

export tf

8 changes: 4 additions & 4 deletions src/core.jl
@@ -13,12 +13,12 @@ may be arrays or tuples of arrays (for multiple inputs/outputs).
back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")

"""
    update!(model, η) => m
    update!(model, o) => model

Update the parameters of the model `m` using the accumulated gradients from
`back!`, using the learning rate `η`.
Update the parameters of the model `model` using the accumulated gradients from
`back!`, using the optimizer `o`.
"""
update!(m, η) = m
update!(m, o) = m

"""
graph(model) => ::IVertex{Any} | nothing
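The revised `update!` docstring above switches the second argument from a raw learning rate η to an optimizer object `o`. A rough sketch of a training step under that API; `SGD(0.1)`, `model`, `data` and `loss_grad` are stand-ins for illustration only, since this diff does not show what the new src/optimizers.jl exports:

```julia
opt = SGD(0.1)                           # assumed constructor; the real name lives in optimizers.jl

for (x, y) in data
  ŷ = model(x)                           # forward pass
  Flux.back!(model, loss_grad(ŷ, y), x)  # accumulate gradients on the parameters
  Flux.update!(model, opt)               # the optimizer consumes and clears them
end
```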
4 changes: 4 additions & 0 deletions src/layers/affine.jl
@@ -22,3 +22,7 @@ function update!(m::Affine, η)
  update!(m.b, η)
  m
end

function params(m::Affine)
  Param[m.W, m.b]
end
2 changes: 2 additions & 0 deletions src/layers/control.jl
@@ -9,6 +9,8 @@ end
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)

params(s::Chain) = mapreduce(params, append!, s.layers)

function back!(s::Chain, Δ, x)
  crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
    push!(crumbs, layer(crumbs[end]))
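With `params` defined for `Affine` and folded over the layers of a `Chain` via `mapreduce`, a whole model's parameters come back as one flat list. A quick sanity-check sketch, assuming the usual `Affine(in, out)` constructor; the layer sizes are arbitrary:

```julia
m  = Chain(Affine(10, 5), Affine(5, 2))
ps = params(m)      # [W₁, b₁, W₂, b₂] collected in layer order as Flux.Param objects
length(ps) == 4     # two parameters per Affine layer
```

The backends above return the same kind of list (`Flux.params(m::Model)` for MXNet and TensorFlow), so an optimizer written against `Param` should work unchanged across backends.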