Skip to content

Commit

Permalink
option for fixed/random validation + fix to accuracy during training
Browse files Browse the repository at this point in the history
  • Loading branch information
anewell committed Jan 4, 2017
1 parent c25da4b commit f61d3c6
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/opts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ local function parse(arg)
cmd:option('-trainBatch', 6, 'Mini-batch size')
cmd:option('-validIters', 1000, 'Number of validation iterations per epoch')
cmd:option('-validBatch', 1, 'Mini-batch size for validation')
cmd:option('-nValidImgs', 1000, 'Number of images to use for validation')
cmd:option('-nValidImgs', 1000, 'Number of images to use for validation. Only relevant if randomValid is set to true')
cmd:option('-randomValid', false, 'Whether or not to use a fixed validation set of 2958 images (same as Tompson et al. 2015)')
cmd:text()
cmd:text(' ---------- Data options ---------------------------------------')
cmd:text()
Expand Down
4 changes: 4 additions & 0 deletions src/ref.lua
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,9 @@ if not ref.alreadyChecked then

printDims("Input is a ", ref.inputDim)
printDims("Output is a ", ref.outputDim)

print("# of training images:", opt.idxRef.train:size(1))
print("# of validation images:", opt.idxRef.valid:size(1))

ref.alreadyChecked = true
end
33 changes: 29 additions & 4 deletions src/util/dataset/mpii.lua
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,35 @@ function Dataset:__init()
opt.idxRef.test = allIdxs[annot.istrain:eq(0)]
opt.idxRef.train = allIdxs[annot.istrain:eq(1)]

-- Set up training/validation split
local perm = torch.randperm(opt.idxRef.train:size(1)):long()
opt.idxRef.valid = opt.idxRef.train:index(1, perm:sub(1,opt.nValidImgs))
opt.idxRef.train = opt.idxRef.train:index(1, perm:sub(opt.nValidImgs+1,-1))
if not opt.randomValid then
-- Use same validation set as used in our paper (and same as Tompson et al)
tmpAnnot = annot.index:cat(annot.person, 2):long()
tmpAnnot:add(-1)

local validAnnot = hdf5.open(paths.concat(projectDir, 'data/mpii/annot/valid.h5'),'r')
local tmpValid = validAnnot:read('index'):all():cat(validAnnot:read('person'):all(),2):long()
opt.idxRef.valid = torch.zeros(tmpValid:size(1))
opt.nValidImgs = opt.idxRef.valid:size(1)
opt.idxRef.train = torch.zeros(opt.idxRef.train:size(1) - opt.nValidImgs)

-- Loop through to get proper index values
local validCount = 1
local trainCount = 1
for i = 1,annot.index:size(1) do
if validCount <= tmpValid:size(1) and tmpAnnot[i]:equal(tmpValid[validCount]) then
opt.idxRef.valid[validCount] = i
validCount = validCount + 1
elseif annot.istrain[i] == 1 then
opt.idxRef.train[trainCount] = i
trainCount = trainCount + 1
end
end
else
-- Set up random training/validation split
local perm = torch.randperm(opt.idxRef.train:size(1)):long()
opt.idxRef.valid = opt.idxRef.train:index(1, perm:sub(1,opt.nValidImgs))
opt.idxRef.train = opt.idxRef.train:index(1, perm:sub(opt.nValidImgs+1,-1))
end

torch.save(opt.save .. '/options.t7', opt)
end
Expand Down
2 changes: 2 additions & 0 deletions src/util/eval.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ function getPreds(hm)
local preds = torch.repeatTensor(idx, 1, 1, 2):float()
preds[{{}, {}, 1}]:apply(function(x) return (x - 1) % hm:size(4) + 1 end)
preds[{{}, {}, 2}]:add(-1):div(hm:size(3)):floor():add(1)
local predMask = max:gt(0):repeatTensor(1, 1, 2):float()
preds:add(-1):cmul(predMask):add(1)
return preds
end

Expand Down

0 comments on commit f61d3c6

Please sign in to comment.