Skip to content

Commit

Permalink
MNIST unit tested
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Feb 3, 2016
1 parent 87d40d5 commit 0d5be3a
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 11 deletions.
13 changes: 10 additions & 3 deletions DataLoader.lua
Expand Up @@ -7,7 +7,7 @@ end

function DataLoader:sample(batchsize, inputs, targets)
self._indices = self._indices or torch.LongTensor()
self._indices:resize(batchsize):random(1,self:nSample())
self._indices:resize(batchsize):random(1,self:size())
return self:index(self._indices, inputs, targets)
end

Expand All @@ -25,15 +25,22 @@ function DataLoader:split(ratio)
error"Not Implemented"
end

-- number of samples
function DataLoader:size()
error"Not Implemented"
end

function DataLoader:inputSize()
-- size of inputs
function DataLoader:isize(excludedim)
-- by default, batch dimension is excluded
excludedim = excludedim == nil and 1 or excludedim
error"Not Implemented"
end

function DataLoader:targetSize()
-- size of targets
function DataLoader:tsize(excludedim)
-- by default, batch dimension is excluded
excludedim = excludedim == nil and 1 or excludedim
error"Not Implemented"
end

2 changes: 1 addition & 1 deletion MNIST.lua
Expand Up @@ -38,7 +38,7 @@ function dl.loadMNIST(datapath, validratio, scale, srcurl)
targets:add(1)

-- from bhwc to bchw
inputs:resize(inputs:size(1), 28, 28, 1)
inputs:resize(inputs:size(1), 1, 28, 28)

-- wrap into loader
local loader = dl.TensorLoader(inputs, targets)
Expand Down
9 changes: 5 additions & 4 deletions TensorLoader.lua
Expand Up @@ -16,8 +16,9 @@ end

function TensorLoader:shuffle()
local indices = torch.LongTensor():randperm(self:size())
self.inputs = self.inputs:index(1, indices)
self.targets = self.targets:index(1, indices)
self.inputs = torchx.recursiveIndex(nil, self.inputs, 1, indices)
self.targets = torchx.recursiveIndex(nil, self.targets, 1, indices)
return self, indices
end

function TensorLoader:split(ratio)
Expand All @@ -42,13 +43,13 @@ function TensorLoader:size()
return torchx.recursiveBatchSize(self.inputs)
end

function TensorLoader:inputSize(excludedim)
function TensorLoader:isize(excludedim)
-- by default, batch dimension is excluded
excludedim = excludedim == nil and 1 or excludedim
return torchx.recursiveSize(self.inputs, excludedim)
end

function TensorLoader:targetSize()
function TensorLoader:tsize(excludedim)
-- by default, batch dimension is excluded
excludedim = excludedim == nil and 1 or excludedim
return torchx.recursiveSize(self.targets, excludedim)
Expand Down
36 changes: 35 additions & 1 deletion test.lua
Expand Up @@ -6,11 +6,45 @@ local precision_backward = 1e-6
local nloop = 50
local mytester

--e.g. usage: th -e "dl = require "dataload"; dl.test()"
--e.g. usage: th -e "dl = require 'dataload'; dl.test()"

function dltest.loadMNIST()
-- this unit test also tests TensorLoader to some extent.
-- To test download, the data/mnist directory should be deleted
local train, valid, test = dl.loadMNIST()

-- test size and split
mytester:assert(train:size()+valid:size()+test:size() == 70000)
mytester:assert(torch.pointer(train.inputs:storage():data()) == torch.pointer(valid.inputs:storage():data()))

-- test sub (and index incidently)
local inputs, targets = train:sub(1,100)
mytester:assertTableEq(inputs:size():totable(), {100,1,28,28}, 0.000001)
mytester:assertTableEq(targets:size():totable(), {100}, 0.000001)
mytester:assert(targets:min() >= 1)
mytester:assert(targets:max() <= 10)

-- test sample (and index)
local inputs_, targets_ = inputs, targets
inputs, targets = train:sample(100, inputs, targets)
mytester:assert(torch.pointer(inputs:storage():data()) == torch.pointer(inputs_:storage():data()))
mytester:assert(torch.pointer(targets:storage():data()) == torch.pointer(targets_:storage():data()))
mytester:assertTableEq(inputs:size():totable(), {100,1,28,28}, 0.000001)
mytester:assertTableEq(targets:size():totable(), {100}, 0.000001)
mytester:assert(targets:min() >= 1)
mytester:assert(targets:max() <= 10)
mytester:assert(inputs:view(100,-1):sum(2):min() > 0)

-- test shuffle
local isum, tsum = train.inputs:sum(), train.targets:sum()
train:shuffle()
mytester:assert(math.abs(isum - train.inputs:sum()) < 0.0000001)
mytester:assert(math.abs(tsum - train.targets:sum()) < 0.0000001)

-- test inputSize and outputSize
local isize, tsize = train:isize(), train:tsize()
mytester:assertTableEq(isize, {1,28,28}, 0.0000001)
mytester:assert(#tsize == 0)
end


Expand Down
4 changes: 2 additions & 2 deletions utils.lua
Expand Up @@ -19,11 +19,11 @@ function dl.downloadfile(dstdir, srcurl, existfile)
dl.withcwd(
dstdir,
function()
local protocol, scpurl, filename = url:match('(.-)://(.*)/(.-)$')
local protocol, scpurl, filename = srcurl:match('(.-)://(.*)/(.-)$')
if protocol == 'scp' then
os.execute(string.format('%s %s %s', 'scp', scpurl .. '/' .. filename, filename))
else
os.execute('wget ' .. url)
os.execute('wget ' .. srcurl)
end
end
)
Expand Down

0 comments on commit 0d5be3a

Please sign in to comment.