From 5131cf2bb3d11f0b19ceda79cee490d150130e61 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Tue, 13 Jun 2017 14:40:01 +0800 Subject: [PATCH 1/6] get momentum work on pure Julia model --- src/Flux.jl | 1 + src/core.jl | 8 ++-- src/optimizers.jl | 118 ++++++++++++++++++++++++++++++++++++++++++++++ src/training.jl | 6 +-- 4 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 src/optimizers.jl diff --git a/src/Flux.jl b/src/Flux.jl index 9a508002cf..ea67d79c1b 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -41,6 +41,7 @@ include("layers/shims.jl") include("backend/backend.jl") include("data.jl") +include("optimizers.jl") include("training.jl") end # module diff --git a/src/core.jl b/src/core.jl index 3100f957ab..2f635802b1 100644 --- a/src/core.jl +++ b/src/core.jl @@ -13,12 +13,12 @@ may be arrays or tuples of arrays (for multiple inputs/outputs). back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))") """ - update!(model, η) => m + update!(model, o) => model -Update the parameters of the model `m` using the accumulated gradients from -`back!`, using the learning rate `η`. +Update the parameters of the model `model` using the accumulated gradients from +`back!`, using the optimizer `o`. """ -update!(m, η) = m +update!(m, o) = m """ graph(model) => ::IVertex{Any} | nothing diff --git a/src/optimizers.jl b/src/optimizers.jl new file mode 100644 index 0000000000..c3d7e646a8 --- /dev/null +++ b/src/optimizers.jl @@ -0,0 +1,118 @@ +export SGD + +struct Optimizer + cache::Dict{Param, Vector{Function}} + steps + Optimizer(steps) = new(Dict{Param, Function}(), steps) +end + +function update!(p::Param, o::Optimizer) + steps = Base.@get!(o.cache, p, map(x->x(p), o.steps)) + foreach(f->f(p), steps) + p.x .-= p.Δx +end + +function Momentum(η) + function (p) + momentum = zeros(p.x) + + function (p) + momentum .= η .* momentum .+ p.Δx + p.Δx .= momentum + end + end +end + +function NesterovMomentum() + error("TODO") +end + +function WeightDecayConst() + error("TODO") +end + +function WeightDecayRatio() + error("TODO") +end + +function GradDecayFix(lr) + function (p::Param) + function (p::Param) + p.Δx .= lr .* p.Δx + end + end +end + +function GradDecayExp() + error("TODO") +end + +function GradDecayInv() + error("TODO") +end + +function WeightClipConst() + error("TODO") +end + +function WeightClipNorm() + error("TODO") +end + +function GradClipConst() + error("TODO") +end + +function GradClipNorm() + error("TODO") +end + +macro restrict_range(var::Symbol, range::String) + left, right = split(range, ", ") + lo = left[1] == '[' ? :>= : :> + lt = left[2:end] + ro = right[end] == ']' ? :<= : :< + rt = right[1:end-1] + + error_msg = "$var ∈ $range must be hold" + var = esc(var) + + quote + $( lt != "-∞" && :( $lo($var, $(parse(Float64, lt))) || throw(ArgumentError($error_msg)) ) ) + $( rt != "∞" && :( $ro($var, $(parse(Float64, rt))) || throw(ArgumentError($error_msg)) ) ) + end +end + +""" +Stochastic gradient descent optimizer. +Includes support for momentum, +learning rate decay, and Nesterov momentum. + +# Arguments + lr: float >= 0. Learning rate. + momentum: float >= 0. Parameter updates momentum. + decay: float >= 0. Learning rate decay over each update. + nesterov: boolean. Whether to apply Nesterov momentum. +""" +function SGD(; lr::Real=.1, + momentum::Real=0, + decay::Real=0, + nesterov::Bool=false) + + @restrict_range lr "[0, ∞)" + @restrict_range momentum "[0, 1]" + @restrict_range decay "[0, ∞)" + + steps = [] + + if momentum != 0 + nesterov ? 
push!(steps, NesterovMomentum(momentum)) : + push!(steps, Momentum(momentum)) + end + + decay != 0 && push!(steps, GradDecayInv(decay)) + + lr != 1 && push!(steps, GradDecayFix(lr)) + + Optimizer(steps) +end \ No newline at end of file diff --git a/src/training.jl b/src/training.jl index 62a34c46e8..e17e01388f 100644 --- a/src/training.jl +++ b/src/training.jl @@ -25,8 +25,8 @@ macro cb(ex, t, f) end) end -function train!(m, train; cb = [], - epoch = 1, η = 0.1, loss = mse) +function train!(m, train; cb = [], opt = SGD(), + epoch = 1, loss = mse) @progress for e in 1:epoch info("Epoch $e") @cb for (x, y) in train @@ -35,7 +35,7 @@ function train!(m, train; cb = [], any(isnan, ŷ) && error("NaN") Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) - update!(m, η) + update!(m, opt) end 5 foreach(f -> f(), cb) end return m From 7a54878b94540961cfadaa68d3a0644cfa78fcc0 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Tue, 27 Jun 2017 14:54:21 +0800 Subject: [PATCH 2/6] more optimizers --- src/optimizers.jl | 410 +++++++++++++++++++++++++++++++++------------- test/optimizer.jl | 82 +++++----- 2 files changed, 336 insertions(+), 156 deletions(-) diff --git a/src/optimizers.jl b/src/optimizers.jl index c3d7e646a8..203fae5cf8 100644 --- a/src/optimizers.jl +++ b/src/optimizers.jl @@ -1,118 +1,292 @@ -export SGD - -struct Optimizer - cache::Dict{Param, Vector{Function}} - steps - Optimizer(steps) = new(Dict{Param, Function}(), steps) -end - -function update!(p::Param, o::Optimizer) - steps = Base.@get!(o.cache, p, map(x->x(p), o.steps)) - foreach(f->f(p), steps) - p.x .-= p.Δx -end - -function Momentum(η) - function (p) - momentum = zeros(p.x) - - function (p) - momentum .= η .* momentum .+ p.Δx - p.Δx .= momentum - end - end -end - -function NesterovMomentum() - error("TODO") -end - -function WeightDecayConst() - error("TODO") -end - -function WeightDecayRatio() - error("TODO") -end - -function GradDecayFix(lr) - function (p::Param) - function (p::Param) - p.Δx .= lr .* p.Δx - end - end -end - -function GradDecayExp() - error("TODO") -end - -function GradDecayInv() - error("TODO") -end - -function WeightClipConst() - error("TODO") -end - -function WeightClipNorm() - error("TODO") -end - -function GradClipConst() - error("TODO") -end - -function GradClipNorm() - error("TODO") -end - -macro restrict_range(var::Symbol, range::String) - left, right = split(range, ", ") - lo = left[1] == '[' ? :>= : :> - lt = left[2:end] - ro = right[end] == ']' ? :<= : :< - rt = right[1:end-1] - - error_msg = "$var ∈ $range must be hold" - var = esc(var) - - quote - $( lt != "-∞" && :( $lo($var, $(parse(Float64, lt))) || throw(ArgumentError($error_msg)) ) ) - $( rt != "∞" && :( $ro($var, $(parse(Float64, rt))) || throw(ArgumentError($error_msg)) ) ) - end -end - -""" -Stochastic gradient descent optimizer. -Includes support for momentum, -learning rate decay, and Nesterov momentum. - -# Arguments - lr: float >= 0. Learning rate. - momentum: float >= 0. Parameter updates momentum. - decay: float >= 0. Learning rate decay over each update. - nesterov: boolean. Whether to apply Nesterov momentum. -""" -function SGD(; lr::Real=.1, - momentum::Real=0, - decay::Real=0, - nesterov::Bool=false) - - @restrict_range lr "[0, ∞)" - @restrict_range momentum "[0, 1]" - @restrict_range decay "[0, ∞)" - - steps = [] - - if momentum != 0 - nesterov ? 
push!(steps, NesterovMomentum(momentum)) : - push!(steps, Momentum(momentum)) - end - - decay != 0 && push!(steps, GradDecayInv(decay)) - - lr != 1 && push!(steps, GradDecayFix(lr)) - - Optimizer(steps) -end \ No newline at end of file +export SGD, AdaGrad, RMSProp, AdaDelta, Adam + +struct Optimizer + cache::Dict{Param, Vector{Function}} + steps + Optimizer(steps) = new(Dict{Param, Function}(), steps) +end + +function update!(p::Param, o::Optimizer) + steps = Base.@get!(o.cache, p, map(x->x(p), o.steps)) + foreach(f->f(p), steps) + @. p.x -= p.Δx +end + +function Momentum(η) + function (p) + momentum = zeros(p.x) + + function (p) + @. momentum = η * momentum + p.Δx + @. p.Δx = momentum + end + end +end + +function NesterovMomentum(η) + function (p) + momentum = zeros(p.x) + + function (p) + @. momentum = η * momentum + p.Δx + @. p.Δx = η * momentum + p.Δx + end + end +end + +function WeightDecayConst(γ) + function (p) + function (p) + # avoid bouncing around 0 + x = p.x - p.Δx + @. p.Δx += (abs(x) <= γ) * x + (abs(x) > γ) * γ * sign(x) + end + end +end + +function WeightDecayRatio(γ) + function (p) + function (p) + @. p.Δx += γ * p.x + end + end +end + +function GradDecayFix(lr) + function (p) + function (p) + @. p.Δx *= lr + end + end +end + +function GradDecayExp(γ) + function (p) + n_iter = 0 + + function (p) + p.Δx .*= γ ^ n_iter + n_iter += 1 + end + end +end + +function GradDecayInv(γ) + function (p) + n_iter = 0 + + function (p) + p.Δx .*= 1 / (1 + γ * n_iter) + n_iter += 1 + end + end +end + +function WeightClipConst() + error("TODO") +end + +function WeightClipNorm() + error("TODO") +end + +function GradClipConst(threshold) + function (p) + function (p) + p.Δx .= max.(min.(p.Δx, threshold), -threshold) + end + end +end + +function GradClipNorm() + error("TODO") +end + +function Accumulate(window) + function (p) + index = 0 + acc = zeros(p.x) + + function (p) + acc .+= p.Δx + + if index >= window + p.Δx .= acc + acc .= 0 + index = 0 + else + p.Δx .= 0 + index += 1 + end + end + end +end + +function _AdaGrad(ϵ) + function (p) + acc = zeros(p.x) .+ ϵ + + function (p) + @. acc += p.Δx ^ 2 + @. p.Δx /= √acc + end + end +end + +function _RMSProp(ρ, ϵ) + function (p) + acc = zeros(p.x) .+ ϵ + + function (p) + @. acc = ρ * acc + (1 - ρ) * p.Δx ^ 2 + @. p.Δx /= √acc + end + end +end + +function _AdaDelta(ρ, ϵ) + function (p) + acc = zeros(p.x) .+ ϵ + Δacc = zeros(p.x) .+ ϵ + + function (p) + @. acc = ρ * acc + (1 - ρ) * p.Δx ^ 2 + @. p.Δx *= √Δacc / √acc + @. Δacc = ρ * Δacc + (1 - ρ) * p.Δx ^ 2 + end + end +end + +function _Adam(β1, β2, ϵ) + function (p) + mt = zeros(p.x) + vt = zeros(p.x) .+ ϵ + β1p = β1 + β2p = β2 + + function (p) + @. mt = β1 * mt + (1 - β1) * p.Δx + @. vt = β2 * vt + (1 - β2) * p.Δx ^ 2 + + @. p.Δx = √(1 - β2p) / √(1 - β1p) * mt / √vt + + β1p *= β1 + β2p *= β2 + end + end +end + +macro restrict_range(var::Symbol, range::String) + left, right = split(range, ", ") + lo = left[1] == '[' ? :>= : :> + lt = left[2:end] + ro = right[end] == ']' ? 
:<= : :< + rt = right[1:end-1] + + error_msg = "$var ∈ $range must be hold" + var = esc(var) + + quote + $( lt != "-∞" && :( $lo($var, $(parse(Float64, lt))) || throw(ArgumentError($error_msg)) ) ) + $( rt != "∞" && :( $ro($var, $(parse(Float64, rt))) || throw(ArgumentError($error_msg)) ) ) + end +end + +function SGD(; lr::Real=.1, + momentum::Real=0, + decay::Real=0, + nesterov::Bool=false) + + @restrict_range lr "[0, ∞)" + @restrict_range momentum "[0, 1]" + @restrict_range decay "[0, ∞)" + + steps = [] + + if momentum != 0 + nesterov ? push!(steps, NesterovMomentum(momentum)) : + push!(steps, Momentum(momentum)) + end + + decay != 0 && push!(steps, GradDecayInv(decay)) + + lr != 1 && push!(steps, GradDecayFix(lr)) + + Optimizer(steps) +end + +function AdaGrad(; lr::Real=.001, + epsilon::Real=1e-6, + decay::Real=0.) + + @restrict_range lr "[0, ∞)" + @restrict_range epsilon "(0, ∞)" + @restrict_range decay "[0, ∞)" + + steps = Any[_AdaGrad(epsilon)] + + decay != 0 && push!(steps, GradDecayInv(decay)) + + lr != 1 && push!(steps, GradDecayFix(lr)) + + Optimizer(steps) +end + +function RMSProp(; lr::Real=.001, + rho::Real=.9, + epsilon::Real=1e-6, + decay::Real=0.) + + @restrict_range lr "[0, ∞)" + @restrict_range rho "[0, 1]" + @restrict_range epsilon "(0, ∞)" + @restrict_range decay "[0, ∞)" + + steps = Any[_RMSProp(rho, epsilon)] + + decay != 0 && push!(steps, GradDecayInv(decay)) + + lr != 1 && push!(steps, GradDecayFix(lr)) + + Optimizer(steps) +end + +function AdaDelta(; lr::Real=1., + rho::Real=.9, + epsilon::Real=1e-6, + decay::Real=0.) + + @restrict_range lr "[0, ∞)" + @restrict_range rho "[0, 1]" + @restrict_range epsilon "(0, ∞)" + @restrict_range decay "[0, ∞)" + + steps = Any[_AdaDelta(rho, epsilon)] + + decay != 0 && push!(steps, GradDecayInv(decay)) + + lr != 1 && push!(steps, GradDecayFix(lr)) + + Optimizer(steps) +end + +function Adam(; lr::Real=.1, + beta1::Real=.9, + beta2::Real=.999, + epsilon::Real=1e-6, + decay::Real=0.) + + @restrict_range lr "[0, ∞)" + @restrict_range beta1 "[0, 1]" + @restrict_range beta2 "[0, 1]" + @restrict_range epsilon "(0, ∞)" + @restrict_range decay "[0, ∞)" + + steps = Any[_Adam(beta1, beta2, epsilon)] + + decay != 0 && push!(steps, GradDecayInv(decay)) + + lr != 1 && push!(steps, GradDecayFix(lr)) + + Optimizer(steps) +end diff --git a/test/optimizer.jl b/test/optimizer.jl index 57f1d0113f..3e4c3b25b0 100644 --- a/test/optimizer.jl +++ b/test/optimizer.jl @@ -1,38 +1,44 @@ -@testset "training julia models" begin - - @testset "linear regression" begin - srand(0) - - model = Affine(10, 1) - - truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' - - data = map(1:256) do i - x = rand(Float32, 10) - x, truth * x + 3rand(Float32) - end - - Flux.train!(model, data, epoch=5) - - @test cor(reshape.((model.W.x, truth), 10)...) > .99 - end - - @testset "logistic regression" begin - srand(0) - - model = Chain(Affine(10, 1), σ) - - truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' - - data = map(1:256) do i - x = rand(Float32, 10) - x, truth * x + 2rand(Float32) > 5f0 - end - - Flux.train!(model, data, epoch=10) - - @test cor(reshape.((model.layers[1].W.x, truth), 10)...) 
> .99 - end - -end - +@testset "training julia models" begin + + @testset "linear regression" begin + srand(0) + + truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' + + data = map(1:256) do i + x = rand(Float32, 10) + x, truth * x + 3rand(Float32) + end + + # It's hard to tell if an optimizer works exactly right, but we + # can at least ensure that they all converge at the right point + for opt in (SGD(), SGD(momentum=.9, decay=.01), AdaGrad(lr=1.), RMSProp(lr=.1), AdaDelta(lr=1e3), Adam()) + model = Affine(10, 1) + + Flux.train!(model, data, epoch=10, opt=opt) + + @test cor(reshape.((model.W.x, truth), 10)...) > .99 + end + end + + @testset "logistic regression" begin + srand(0) + + truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]' + + data = map(1:256) do i + x = rand(Float32, 10) + x, truth * x + 2rand(Float32) > 5f0 + end + + for opt in (SGD(), SGD(momentum=.9, decay=.01), AdaGrad(lr=1.), RMSProp(lr=.1), AdaDelta(lr=1e3), Adam()) + model = Chain(Affine(10, 1), σ) + + Flux.train!(model, data, epoch=10) + + @test cor(reshape.((model.layers[1].W.x, truth), 10)...) > .99 + end + end + +end + From 77c812022660e383be4f42cf994204c71fe1f3b3 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Sat, 8 Jul 2017 14:47:38 +0800 Subject: [PATCH 3/6] add params API --- src/layers/affine.jl | 4 ++++ src/layers/control.jl | 2 ++ src/optimizers.jl | 15 +++++++++------ src/params.jl | 2 ++ src/training.jl | 3 ++- 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/layers/affine.jl b/src/layers/affine.jl index ca79c0044b..c4f34dc56c 100644 --- a/src/layers/affine.jl +++ b/src/layers/affine.jl @@ -22,3 +22,7 @@ function update!(m::Affine, η) update!(m.b, η) m end + +function params(m::Affine) + Param[m.W, m.b] +end diff --git a/src/layers/control.jl b/src/layers/control.jl index d0c5e61b7e..b46f57734a 100644 --- a/src/layers/control.jl +++ b/src/layers/control.jl @@ -9,6 +9,8 @@ end (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers) +params(s::Chain) = mapreduce(params, append!, s.layers) + function back!(s::Chain, Δ, x) crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer push!(crumbs, layer(crumbs[end])) diff --git a/src/optimizers.jl b/src/optimizers.jl index 203fae5cf8..bb194ee254 100644 --- a/src/optimizers.jl +++ b/src/optimizers.jl @@ -1,15 +1,18 @@ export SGD, AdaGrad, RMSProp, AdaDelta, Adam struct Optimizer - cache::Dict{Param, Vector{Function}} steps - Optimizer(steps) = new(Dict{Param, Function}(), steps) end -function update!(p::Param, o::Optimizer) - steps = Base.@get!(o.cache, p, map(x->x(p), o.steps)) - foreach(f->f(p), steps) - @. p.x -= p.Δx +function (o::Optimizer)(ps::Vector{Param}) + states = map(ps) do p + p, map(x->x(p), o.steps) + end + + () -> for (p, steps) in states + foreach(f->f(p), steps) + @. p.x -= p.Δx + end end function Momentum(η) diff --git a/src/params.jl b/src/params.jl index 7501ad4669..aad2f2788c 100644 --- a/src/params.jl +++ b/src/params.jl @@ -39,3 +39,5 @@ end Base.copy!(xs, p::Param) = copy!(xs, p.x) Base.copy!(p::Param, xs) = copy!(p.x, xs) + +params(m) = Param[] diff --git a/src/training.jl b/src/training.jl index e17e01388f..3445fd66c6 100644 --- a/src/training.jl +++ b/src/training.jl @@ -29,13 +29,14 @@ function train!(m, train; cb = [], opt = SGD(), epoch = 1, loss = mse) @progress for e in 1:epoch info("Epoch $e") + opt! 
= opt(params(m)) @cb for (x, y) in train x, y = mapt(tobatch, (x, y)) ŷ = m(x) any(isnan, ŷ) && error("NaN") Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) - update!(m, opt) + opt!() end 5 foreach(f -> f(), cb) end return m From d63df4f8f5c3c8e73b2434db66ef641a1219c4a7 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Sat, 8 Jul 2017 16:31:58 +0800 Subject: [PATCH 4/6] params API support for mxnet --- src/backend/mxnet/model.jl | 11 +++++++++++ src/training.jl | 3 ++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 3ea9ea12b9..005279ec8b 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -34,6 +34,13 @@ end loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params) storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args) +storegrads!(exec::Exec) = begin + params, grads = exec.graph.params, exec.grads + for id in intersect(keys(params), keys(grads)) + copy!(params[id].Δx, grads[id]) + grads[id].data[:] = 0 + end +end mxgroup(x) = x mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...) @@ -58,6 +65,7 @@ function executor(graph::Graph, input...; ctx = mx.cpu()) end function (exec::Exec)(input...) + loadparams!(exec) foreach(kv -> copy!(exec.args[kv[1]], kv[2]), dictt(exec.graph.input, input)) mx.forward(exec.exec, is_train = true) mxungroup(exec.graph.output, copy(exec.outs)) @@ -66,6 +74,7 @@ end function Flux.back!(exec::Exec, Δ) mapt(k -> exec.grads[k][:] = 0, exec.graph.input) mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ))) + storegrads!(exec) mapt(k -> copy(exec.grads[k]), exec.graph.input) end @@ -128,6 +137,8 @@ end Flux.update!(m::Model, η) = (update!(m.last, η); m) +Flux.params(m::Model) = collect(Flux.Param, values(m.graph.params)) + # Recurrent Models using Flux: Stateful, SeqModel diff --git a/src/training.jl b/src/training.jl index 3445fd66c6..06dc2a6b13 100644 --- a/src/training.jl +++ b/src/training.jl @@ -29,10 +29,11 @@ function train!(m, train; cb = [], opt = SGD(), epoch = 1, loss = mse) @progress for e in 1:epoch info("Epoch $e") - opt! = opt(params(m)) + opt! = nothing # `params(m)` is not valid before calling m(x) in mxnet backend @cb for (x, y) in train x, y = mapt(tobatch, (x, y)) ŷ = m(x) + opt! = opt! == nothing ? opt(params(m)) : opt! any(isnan, ŷ) && error("NaN") Δ = back!(loss, 1, ŷ, y) back!(m, Δ, x) From e166c5e7c4608e5a87f9821a5e1bb67f05bc584f Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Thu, 20 Jul 2017 20:35:20 +0800 Subject: [PATCH 5/6] training for the tensorflow backend --- src/backend/tensorflow/model.jl | 43 +++++++++++++++------------- src/backend/tensorflow/tensorflow.jl | 2 +- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index 836861a010..3c99cc2c4c 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -4,18 +4,18 @@ struct Exec session ::Session input ::Any output ::Any - grads ::Any - params ::Dict{Flux.Param,Tensor} + params ::Dict{Param,Param{Tensor}} stacks ::Dict{Any,Any} end function makesession(model, inputs; session = Session(Graph())) inputs = mapt(_ -> placeholder(Float32), inputs) params, stacks, output = tograph(model, inputs...) 
- # grads = gradients(output, [collectt(inputs)..., values(params)...]) - grads = placeholder(Float32) + params = Dict(x=>Param{Tensor}(y, gradients(output, y)) for (x, y) in params) + inputs = mapt(x->Param{Tensor}(x, gradients(output, x)), inputs) + output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output) run(session, global_variables_initializer()) - Exec(session, inputs, output, grads, params, stacks) + Exec(session, inputs, output, params, stacks) end retuple(xs) = xs @@ -23,29 +23,31 @@ retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,) dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys))) -function params(m::Exec, args...) - shapecheckt(m.input, args) - idict = dictt(m.input, args) - pdict = Dict(t => p.x for (p, t) in m.params) - merge(idict, pdict) +function Flux.params(m::Exec) + collect(keys(m.params)) end function (m::Exec)(args...) - retuple(run(m.session, m.output, params(m, args...))) + dict = merge( + Dict(y.x=>x.x for (x, y) in m.params), + Dict(x.x=>y for (x, y) in zip(m.input, args)) + ) + retuple(run(m.session, mapt(x->x.x, m.output), dict)) end -pullt!(_, xs) = shift!(xs) -pullt!(x::Tuple, xs) = map(x -> pullt!(x, xs), x) +function Flux.back!(m::Exec, Δ, args...) + dict = merge( + Dict(y.x=>x.x for (x, y) in m.params), + Dict(x.x=>y for (x, y) in zip(m.input, args)), + Dict(x.Δx=>y for (x, y) in zip(collectt(m.output), collectt(Δ))) + ) -# TODO: gradients don't work yet -# `gradients` lacks support for `grad_y`s and multiple `y`s + Δin, Δps = run(m.session, (mapt(x->x.Δx, m.input), map(x->x.Δx, values(m.params))), dict) -function Flux.back!(m::Exec, Δ, args...) - Δps = run(m.session, m.grads, params(m, args...)) - Δin = pullt!(m.input, Δps) for (p, Δ) in zip(keys(m.params), Δps) p.Δx .+= Δ end + Δin end @@ -70,8 +72,9 @@ function (m::Model)(args...) @tferr m.exec.stacks m.exec(args...) end -Flux.back!(m::Model, Δ, args...) = back!(m.exec, Δ, args...) -Flux.update!(m::Model, η) = (update!(m.exec, η); m) +Flux.back!(m::Model, Δ, args...) = Flux.back!(m.exec, Δ, args...) +Flux.update!(m::Model, η) = (Flux.update!(m.exec, η); m) +Flux.params(m::Model) = Flux.params(m.exec) # Recurrent Models diff --git a/src/backend/tensorflow/tensorflow.jl b/src/backend/tensorflow/tensorflow.jl index 74c940124c..7536bdbc52 100644 --- a/src/backend/tensorflow/tensorflow.jl +++ b/src/backend/tensorflow/tensorflow.jl @@ -1,7 +1,7 @@ module TF using ..Flux, DataFlow, TensorFlow, Juno -import Flux: accuracy, convertel +import Flux: accuracy, convertel, Param export tf From 45f0f22fd5e95e505c3e8a6e22a8821a52aa7fc7 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Thu, 20 Jul 2017 20:50:22 +0800 Subject: [PATCH 6/6] fix --- src/backend/tensorflow/model.jl | 6 ++++-- src/training.jl | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index 3c99cc2c4c..4ca52bc10a 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -11,9 +11,11 @@ end function makesession(model, inputs; session = Session(Graph())) inputs = mapt(_ -> placeholder(Float32), inputs) params, stacks, output = tograph(model, inputs...) 
- params = Dict(x=>Param{Tensor}(y, gradients(output, y)) for (x, y) in params) - inputs = mapt(x->Param{Tensor}(x, gradients(output, x)), inputs) output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output) + params = Dict(x=>Param{Tensor}(y, gradients(mapt(x->x.x, output), + y, mapt(x->x.Δx, output))) for (x, y) in params) + inputs = mapt(x->Param{Tensor}(x, gradients(mapt(x->x.x, output), + x, mapt(x->x.Δx, output))), inputs) run(session, global_variables_initializer()) Exec(session, inputs, output, params, stacks) end diff --git a/src/training.jl b/src/training.jl index 06dc2a6b13..7f752c5e60 100644 --- a/src/training.jl +++ b/src/training.jl @@ -27,9 +27,9 @@ end function train!(m, train; cb = [], opt = SGD(), epoch = 1, loss = mse) + opt! = nothing # `params(m)` is not valid before calling m(x) in both backends @progress for e in 1:epoch info("Epoch $e") - opt! = nothing # `params(m)` is not valid before calling m(x) in mxnet backend @cb for (x, y) in train x, y = mapt(tobatch, (x, y)) ŷ = m(x)
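
The series above replaces train!'s bare learning rate η with composable optimizer steps: SGD, AdaGrad, RMSProp, AdaDelta and Adam each build an Optimizer from per-parameter closures, and (as of patch 3) opt(params(m)) returns an update thunk that runs every step and then subtracts p.Δx from p.x. Below is a minimal standalone Julia sketch of that closure-composition design; ToyParam and make_optimizer are illustrative stand-ins for Flux.Param and the Optimizer callable, not code from the patches.

# --- usage sketch (illustrative, not part of the patches above) ---
# `ToyParam` is a hypothetical stand-in for Flux.Param: it pairs the parameter
# values with the gradient accumulated by back!.
mutable struct ToyParam
    x::Vector{Float64}   # parameter values
    Δx::Vector{Float64}  # accumulated gradient
end

# A "step" is a function p -> closure: the outer call allocates per-parameter
# state (momentum buffer, iteration counter, ...); the inner call rewrites p.Δx.
function Momentum(η)
    p -> begin
        momentum = zeros(length(p.x))
        p -> begin
            @. momentum = η * momentum + p.Δx
            @. p.Δx = momentum
        end
    end
end

GradDecayFix(lr) = p -> (p -> (@. p.Δx *= lr))

# Compose steps into an update thunk, mirroring `(o::Optimizer)(ps)` from patch 3.
function make_optimizer(steps, ps)
    states = [(p, [s(p) for s in steps]) for p in ps]
    () -> for (p, fs) in states
        foreach(f -> f(p), fs)   # let each step transform the gradient in place
        @. p.x -= p.Δx           # then apply the final update
    end
end

p = ToyParam(randn(3), zeros(3))
opt! = make_optimizer([Momentum(0.9), GradDecayFix(0.1)], [p])

p.Δx .= [1.0, -2.0, 0.5]   # pretend back! filled in a gradient
opt!()                     # momentum, then fixed learning rate, then x .-= Δx

In the patched API itself the same flow is driven through the new opt keyword, e.g. Flux.train!(model, data, epoch=10, opt=SGD(momentum=.9, decay=.01)), as exercised for each optimizer in test/optimizer.jl.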