Permalink
Browse files

More comments and better arch for CIFAR

  • Loading branch information...
1 parent 65a6dcd commit 056fc229909cab7f14039c380ba120c7918c33dd @clementfarabet clementfarabet committed Apr 19, 2012
Showing with 31 additions and 18 deletions.
  1. +31 −18 train-on-cifar/train-on-cifar.lua
@@ -17,6 +17,7 @@
require 'torch'
require 'nn'
+require 'nnx'
require 'optim'
require 'image'
require 'mattorch'
@@ -37,7 +38,7 @@ cmd:option('-full', false, 'use full dataset (50,000 samples)')
cmd:option('-visualize', false, 'visualize input data and weights during training')
cmd:option('-seed', 1, 'fixed input seed for repeatable experiments')
cmd:option('-optimization', 'SGD', 'optimization method: SGD | ASGD | CG | LBFGS')
-cmd:option('-learningRate', 1e-2, 'learning rate at t=0')
+cmd:option('-learningRate', 1e-3, 'learning rate at t=0')
cmd:option('-batchSize', 1, 'mini-batch size (1 = pure stochastic)')
cmd:option('-weightDecay', 0, 'weight decay (SGD only)')
cmd:option('-momentum', 0, 'momentum (SGD only)')
@@ -62,16 +63,8 @@ end
-- define model to train
-- on the 10-class classification problem
--
-classes = {'airplane',
- 'automobile',
- 'bird',
- 'cat',
- 'deer',
- 'dog',
- 'frog',
- 'horse',
- 'ship',
- 'truck'}
+classes = {'airplane', 'automobile', 'bird', 'cat',
+ 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}
if opt.network == '' then
-- define model to train
@@ -80,17 +73,37 @@ if opt.network == '' then
if opt.model == 'convnet' then
------------------------------------------------------------
-- convolutional network
+ -- this is a typical convolutional network for vision:
+ -- 1/ the image is transformed into Y-UV space
+ -- 2/ the Y (luminance) channel is locally normalized
+ -- 3/ the first layer allocates more filters to the Y
+ -- channel, and just a few to the U and V channels
+ -- 4/ the features of the first two stages are locally pooled
+ -- using a max-operator
+ -- 5/ a two-layer neural network is applied on the
+ -- representation
------------------------------------------------------------
- -- reshape
+ -- reshape vector into a 3-channel image (RGB)
model:add(nn.Reshape(3,32,32))
- -- stage 1 : mean suppresion -> filter bank -> squashing -> max pooling
- model:add(nn.SpatialSubtractiveNormalization(3, image.gaussian1D(15)))
- model:add(nn.SpatialConvolutionMap(nn.tables.random(3, 8, 1), 5, 5))
+ -- stage 0 : RGB -> YUV -> normalize(Y)
+ model:add(nn.SpatialColorTransform('rgb2yuv'))
+ do
+ ynormer = nn.Sequential()
+ ynormer:add(nn.Narrow(1,1,1))
+ ynormer:add(nn.SpatialContrastiveNormalization(1, image.gaussian1D(7)))
+ normer = nn.ConcatTable()
+ normer:add(ynormer)
+ normer:add(nn.Narrow(1,2,2))
+ end
+ model:add(normer)
+ model:add(nn.JoinTable(1))
+ -- stage 1 : filter bank -> squashing -> max pooling
+ local table = torch.Tensor{ {1,1},{1,2},{1,3},{1,4},{1,5},{1,6},{1,7},{1,8},{2,9},{2,10},{3,11},{3,12} }
+ model:add(nn.SpatialConvolutionMap(table, 5, 5))
model:add(nn.Tanh())
model:add(nn.SpatialMaxPooling(2, 2, 2, 2))
- -- stage 2 : mean suppresion -> filter bank -> squashing -> max pooling
- model:add(nn.SpatialSubtractiveNormalization(8, image.gaussian1D(15)))
- model:add(nn.SpatialConvolutionMap(nn.tables.random(8, 32, 1), 5, 5))
+ -- stage 2 : filter bank -> squashing -> max pooling
+ model:add(nn.SpatialConvolutionMap(nn.tables.random(12, 32, 4), 5, 5))
model:add(nn.Tanh())
model:add(nn.SpatialMaxPooling(2, 2, 2, 2))
-- stage 3 : standard 2-layer neural network

0 comments on commit 056fc22

Please sign in to comment.