Synchronous and shared memory, scaled batchsize and learning rate
klensink committed Feb 21, 2018
1 parent 89c4b89 commit 96b112c
Showing 2 changed files with 24 additions and 24 deletions.
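For orientation: the "scaled batchsize and learning rate" in the title means the global minibatch is split evenly across workers while the learning rate is multiplied by the worker count, as the sgd.jl hunk below computes. A worked example with assumed illustrative values (the real values come from the SGD options, not from this commit):

nw = 4                          # assumed worker count
miniBatch = 64                  # assumed global minibatch option
learningRate = 0.01             # assumed base learning rate

batchsize = div(miniBatch, nw)  # 16, each worker's share of one minibatch
lr = learningRate * nw          # 0.04, step size scaled by the number of workers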
examples/EResNN_CIFAR10.jl (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 using MAT, Meganet
 BLAS.set_num_threads(1)

-n = 10000
+n = 512
 Y_train,C_train,Y_test,C_test = getCIFAR10(n,Pkg.dir("Meganet")*"/data/CIFAR10/");

 # using PyPlot
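The example only loads the data; for the multi-worker SGD path in sgd.jl below to see more than one worker, processes have to be added before the solver runs. A minimal, hypothetical driver sketch (Julia 0.6 era, matching the Pkg.dir call above; the addprocs call and worker count are assumptions, not part of the committed file):

addprocs(4)             # assumed: start 4 local workers so nworkers() > 1
using MAT, Meganet      # on Julia 0.6 this also loads the package code on the new workers
BLAS.set_num_threads(1)

n = 512
Y_train,C_train,Y_test,C_test = getCIFAR10(n, Pkg.dir("Meganet")*"/data/CIFAR10/");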
src/optimization/sgd.jl (46 changes: 23 additions & 23 deletions)
@@ -44,32 +44,35 @@ function solve(this::SGD{T},objFun::dnnObjFctn,xc::Array{T},Y::Array{T},C::Array

     # Distribute the data
     nw = nworkers()
-    Yd = distribute(Y, dist = (1, nw))
-    Cd = distribute(C, dist = (1, nw))
+    Ys = SharedArray(Y)
+    Cs = SharedArray(C)
     println("Using $(nw) workers...")

     while epoch <= this.maxEpochs
         tic()
         # Train on all workers
-        @sync for pid in Yd.pids
-            @async @fetchfrom pid train(this, objFun, xc, Yd, Cd, beta1, beta2)
+        #@sync for pid in Ys.pids
+        #    @async @fetchfrom pid train(this, objFun, xc, Ys, Cs, beta1, beta2)
+        #end
+        for pid in Ys.pids
+            @fetchfrom pid train(this, objFun, xc, Ys, Cs, beta1, beta2)
         end

         # we sample 2^12 images from the training set for displaying the objective.
         xc = Meganet.XC
         nex = size(Y,2)
-        n_total = min(nex,2^10)
+        n_total = min(nex,60)
         n_worker = div(n_total, nw)

         JcA = Array{T,1}(nw)
         hisFA = Array{Array{T,1},1}(nw)
         Jval = zero(T)
         pVal = Array{T,1}()
-        tmp = Array{Any}(0,0)
+        tmp = Array{Any}(0,0)

         @sync begin
-            for (i, pid) in enumerate(Yd.pids)
-                @async JcA[i], hisFA[i] = @fetchfrom pid evalObjFctn_local(objFun, xc, Yd, Cd, n_worker)
+            for (i, pid) in enumerate(Ys.pids)
+                @async JcA[i], hisFA[i] = @fetchfrom pid evalObjFctn_local(objFun, xc, Ys, Cs, n_worker)
             end

             # Currently validation data is on master so this is local
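A minimal, self-contained sketch (Julia 0.6 era; work_on_cols is an illustrative stand-in, not a Meganet function) of the pattern this hunk switches to: one SharedArray visible to all local workers plus a plain, synchronous @fetchfrom loop in place of the removed @sync/@async dispatch over a DArray:

addprocs(2)                         # local workers can map the same shared memory

@everywhere function work_on_cols(Ys::SharedArray, cols)
    # every process reads the same underlying buffer; nothing is copied per call
    return sum(Ys[:, cols])
end

Y  = rand(Float32, 8, 100)
Ys = SharedArray(Y)                 # same constructor the hunk uses

for pid in Ys.pids
    # blocks until that worker returns, so the workers run one after another
    println(@fetchfrom pid work_on_cols(Ys, 1:10))
end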
@@ -96,42 +99,39 @@ end
 Evaluate the objective function on `n` random examples from `Y`
 """
-function evalObjFctn_local(objFun::dnnObjFctn, xc::Array{T,1}, Y::DArray{T,2}, C::DArray{T,2}, n::Int) where {T<:Number}
+function evalObjFctn_local(objFun::dnnObjFctn, xc::Array{T,1}, Y::SharedArray{T,2}, C::SharedArray{T,2}, n::Int) where {T<:Number}

-    Y_local = localpart(Y)
-    C_local = localpart(C)
-
-    nex = size(Y_local,2)
+    nex = size(Y,2)
     ids = randperm(nex)
     idt = ids[1:min(n, nex)]
     tmp = Array{Any}(0,0)

-    Jc, hisF, dJ = evalObjFctn(objFun,xc,Y_local[:,idt], C_local[:,idt], tmp, false);
+    Jc, hisF, dJ = evalObjFctn(objFun,xc,Y[:,idt], C[:,idt], tmp, false);

     return Jc, hisF, dJ
 end

 """
 Train on the local part of the distributed data in Y
 """
-function train(this::SGD{T}, objFun::dnnObjFctn, xc::Array{T,1}, Y::DArray{T,2}, C::DArray{T,2}, beta1::T, beta2::T) where {T<:Number}
+function train(this::SGD{T}, objFun::dnnObjFctn, xc::Array{T,1}, Y::SharedArray{T,2}, C::SharedArray{T,2}, beta1::T, beta2::T) where {T<:Number}
     # TODO send the worker SGD and objFun only once

-    Y_local = localpart(Y)
-    C_local = localpart(C)

-    nex = size(Y_local,2)
+    nex = size(Y,2)
+    nworkers = length(Y.pids)
     ids = randperm(nex)
-    lr = this.learningRate
+    lr = this.learningRate*nworkers
     dJ = zeros(T,size(xc))
     tmp = Array{Any}(0,0)

-    for k=1:ceil(Int64,nex/this.miniBatch)
+    batchsize = div(this.miniBatch, nworkers)

+    for k=1:ceil(Int64,nex/batchsize)
         idk = ids[(k-1)*this.miniBatch+1: min(k*this.miniBatch,nex)]
         if this.nesterov && !this.ADAM
-            Jk,dummy,dJk = evalObjFctn(objFun, xc-this.momentum*dJ, Y_local[:,idk], C_local[:,idk], tmp);
+            Jk,dummy,dJk = evalObjFctn(objFun, xc-this.momentum*dJ, Y[:,idk], C[:,idk], tmp);
         else
-            Jk,dummy,dJk = evalObjFctn(objFun, xc, Y_local[:,idk], C_local[:,idk], tmp);
+            Jk,dummy,dJk = evalObjFctn(objFun, xc, Y[:,idk], C[:,idk], tmp);
         end

         if this.ADAM
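A compact sketch of the per-worker training loop shape after this change, with assumed illustrative values; a plain gradient step and a dummy quadratic objective stand in for evalObjFctn and the momentum/ADAM updates that live below this hunk, so the snippet runs on its own (Julia 0.6 era):

nex, miniBatch, nworkers = 512, 64, 4
lr = 0.01 * nworkers                        # scaled step size, as in the hunk
batchsize = div(miniBatch, nworkers)        # each worker's slice of a minibatch

xc  = zeros(10)
Y   = randn(10, nex)
ids = randperm(nex)                         # shuffle once per epoch

for k = 1:ceil(Int64, nex/batchsize)
    idk = ids[(k-1)*batchsize+1 : min(k*batchsize, nex)]   # one shuffled window
    dJk = xc - vec(mean(Y[:, idk], 2))      # gradient of the dummy objective 0.5*norm(xc - batch mean)^2
    xc .-= lr * dJk                         # plain SGD step, stand-in for the ADAM/momentum branches
end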
