Skip to content

Commit

Permalink
add tmp in evalObjFctn
Browse files Browse the repository at this point in the history
  • Loading branch information
lruthotto committed Feb 20, 2018
1 parent fd27123 commit 2eda86a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
Binary file modified examples/.EResNN_CIFAR10.jl.swp
Binary file not shown.
11 changes: 7 additions & 4 deletions src/optimization/sgd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ function solve(this::SGD{T},objFun::dnnObjFctn,xc::Array{T},Y::Array{T},C::Array
hisFA = Array{Array{T,1},1}(nw)
Jval = zero(T)
pVal = Array{T,1}()
tmp = Array{Any}(0,0)

@sync begin
for (i, pid) in enumerate(Yd.pids)
@async JcA[i], hisFA[i] = @fetchfrom pid evalObjFctn_local(objFun, xc, Yd, Cd, n_worker)
end

# Currently validation data is on master so this is local
@async Jval,pVal = @fetchfrom 1 getMisfit(objFun,xc,Yv,Cv,false);
@async Jval,pVal = @fetchfrom 1 getMisfit(objFun,xc,Yv,Cv,tmp, false);
end

Jc = sum(JcA)
Expand Down Expand Up @@ -103,8 +104,9 @@ function evalObjFctn_local(objFun::dnnObjFctn, xc::Array{T,1}, Y::DArray{T,2}, C
nex = size(Y_local,2)
ids = randperm(nex)
idt = ids[1:min(n, nex)]
tmp = Array{Any}(0,0)

Jc, hisF, dJ = evalObjFctn(objFun,xc,Y_local[:,idt], C_local[:,idt], false);
Jc, hisF, dJ = evalObjFctn(objFun,xc,Y_local[:,idt], C_local[:,idt], tmp, false);

return Jc, hisF, dJ
end
Expand All @@ -122,13 +124,14 @@ function train(this::SGD{T}, objFun::dnnObjFctn, xc::Array{T,1}, Y::DArray{T,2},
ids = randperm(nex)
lr = this.learningRate
dJ = zeros(T,size(xc))
tmp = Array{Any}(0,0)

for k=1:ceil(Int64,nex/this.miniBatch)
idk = ids[(k-1)*this.miniBatch+1: min(k*this.miniBatch,nex)]
if this.nesterov && !this.ADAM
Jk,dummy,dJk = evalObjFctn(objFun, xc-this.momentum*dJ, Y_local[:,idk], C_local[:,idk]);
Jk,dummy,dJk = evalObjFctn(objFun, xc-this.momentum*dJ, Y_local[:,idk], C_local[:,idk], tmp);
else
Jk,dummy,dJk = evalObjFctn(objFun, xc, Y_local[:,idk], C_local[:,idk]);
Jk,dummy,dJk = evalObjFctn(objFun, xc, Y_local[:,idk], C_local[:,idk], tmp);
end

if this.ADAM
Expand Down

0 comments on commit 2eda86a

Please sign in to comment.