Skip to content

Commit

Permalink
Add SpatialDropout, BatchNormalization, input centering and scaling
Browse files Browse the repository at this point in the history
Inputs are now normalized to have a standard deviation of 1 and a mean of 0.
  • Loading branch information
andreaskoepf committed Oct 19, 2015
1 parent 0ee4feb commit f256ba6
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 9 deletions.
15 changes: 15 additions & 0 deletions BatchIterator.lua
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ function BatchIterator:processImage(img, rois)
end
end

if cfg.normalization.centering then
for i = 1,3 do
img[i] = img[i]:add(-img[i]:mean())
end
end

if cfg.normalization.scaling then
for i = 1,3 do
local s = img[i]:std()
if s > 1e-8 then
img[i] = img[i]:div(s)
end
end
end

img[1] = self.normalization:forward(img[{{1}}]) -- normalize luminance channel img

return img, rois
Expand Down
2 changes: 1 addition & 1 deletion config/duplo.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
target_smaller_side = 450,
scales = { 32, 64, 128, 256 },
max_pixel_size = 1000,
normalization = { method = 'contrastive', width = 7 },
normalization = { method = 'contrastive', width = 7, centering = true, scaling = true },
augmentation = { vflip = 0.5, hflip = 0.5, random_scaling = 0.0, aspect_jitter = 0.0 },
color_space = 'yuv',
roi_pooling = { kw = 6, kh = 6 },
Expand Down
2 changes: 1 addition & 1 deletion config/imagenet.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
target_smaller_side = 480,
scales = { 48, 96, 192, 384 },
max_pixel_size = 1000,
normalization = { method = 'contrastive', width = 7 },
normalization = { method = 'contrastive', width = 7, centering = true, scaling = true },
augmentation = { vflip = 0, hflip = 0.25, random_scaling = 0, aspect_jitter = 0 },
color_space = 'yuv',
roi_pooling = { kw = 6, kh = 6 },
Expand Down
4 changes: 2 additions & 2 deletions main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,6 @@ function evaluation_demo(cfg, model_path, training_data_filename, network_filena

end

--graph_training(cfg, opt.model, opt.name, opt.train, opt.restore)
evaluation_demo(cfg, opt.model, opt.train, opt.restore)
graph_training(cfg, opt.model, opt.name, opt.train, opt.restore)
--evaluation_demo(cfg, opt.model, opt.train, opt.restore)

5 changes: 4 additions & 1 deletion models/model_utilities.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ function create_proposal_net(layers, anchor_nets)
container:add(nn.SpatialConvolution(nInputPlane, nOutputPlane, kW,kH, 1,1, padW,padH))
container:add(nn.PReLU())
if dropout and dropout > 0 then
container:add(nn.Dropout(dropout))
container:add(nn.SpatialDropout(dropout))
end
return container
end
Expand Down Expand Up @@ -80,6 +80,9 @@ function create_classification_net(inputs, class_count, class_layers)
local prev_input_count = inputs
for i,l in ipairs(class_layers) do
net:add(nn.Linear(prev_input_count, l.n))
if l.batch_norm then
net:add(nn.BatchNormalization(l.n))
end
net:add(nn.PReLU())
if l.dropout and l.dropout > 0 then
net:add(nn.Dropout(l.dropout))
Expand Down
4 changes: 2 additions & 2 deletions models/vgg_large.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ function vgg_large(cfg)
}

local class_layers = {
{ n=1024, dropout=0.5 },
{ n=1024, dropout=0.5 },
{ n=1024, dropout=0.5, batch_norm=true },
{ n=512, dropout=0.5 },
}

return create_model(cfg, layers, anchor_nets, class_layers)
Expand Down
4 changes: 2 additions & 2 deletions models/vgg_small.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ function vgg_small(cfg)
}

local class_layers = {
{ n=1024, dropout=0.5 },
{ n=1024, dropout=0.5 },
{ n=1024, dropout=0.5, batch_norm=true },
{ n=512, dropout=0.5 },
}

return create_model(cfg, layers, anchor_nets, class_layers)
Expand Down

0 comments on commit f256ba6

Please sign in to comment.