Add an E2E way to train the encoder with a L2 loss
This encoder is trained with an image reconstruction (L2) loss between a real input image and its reconstruction. I add this encoder to the "other" folder, as the results obtained have been very poor.
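In short: the encoder E is trained end-to-end through a frozen, pretrained conditional generator G by minimizing the pixel-wise L2 distance ||G(E(x), y) - x||^2 between a real image x (with attribute vector y) and its reconstruction. A minimal sketch of the objective, assuming an image x, an attribute vector y, and the encoder/generator networks built and loaded in the script below:

local z = encoder:forward(x) -- E(x): image -> latent code z
local xRec = generator:forward{z, y} -- G(z, y): conditional reconstruction
local loss = nn.MSECriterion():forward(xRec, x) -- pixel-wise L2 (MSE)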
Guim3 committed Dec 19, 2016
1 parent 96ed726 commit 5b1d556
Showing 1 changed file with 351 additions and 0 deletions.
351 changes: 351 additions & 0 deletions other/alternativeEncoders/trainEncoderE2E.lua
@@ -0,0 +1,351 @@
-- This file reads the dataset generated by generateEncoderDataset.lua and
-- trains an encoder net that learns to map an image X to a noise vector Z (encoder Z, type Z)
-- or an encoder that maps an image X to an attribute vector Y (encoder Y, type Y).

require 'image'
require 'nn'
require 'optim'
require 'cunn'
require 'cudnn'
torch.setdefaulttensortype('torch.FloatTensor')
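-- Default to FloatTensor (instead of DoubleTensor) to halve CPU memory and match the precision of the CudaTensors used on the GPU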

local function getParameters()

opt = {}

-- Type of encoder must be passed as an argument to decide what kind of
-- encoder will be trained (encoder Z [type Z] or encoder Y [type Y])
opt.type = os.getenv('type')

assert(opt.type, "Parameter 'type' not specified. It is necessary to set the encoder type: 'Z' or 'Y'.\nExample: type=Z th trainEncoderE2E.lua")
assert(string.upper(opt.type)=='Z' or string.upper(opt.type)=='Y',"Parameter 'type' must be 'Z' (encoder Z) or 'Y' (encoder Y).")

-- Load parameters from config file
if string.upper(opt.type)=='Z' then
assert(loadfile("cfg/mainConfig.lua"))(1)
else
assert(loadfile("cfg/mainConfig.lua"))(2)
end

-- one-line argument parser. Parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
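-- e.g. type=Z nEpochs=15 th trainEncoderE2E.lua overrides opt.type and opt.nEpochs (any key defined in opt can be overridden this way)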
print(opt)

if opt.display then display = require 'display' end

return opt
end

local function readDataset(path, imSize)
-- For CelebA: path is expected to contain the files images.dmp and imLabels.dmp,
-- which hold the images X and the attribute vectors Y, respectively.
-- images.dmp is obtained by running data/preprocess_celebA.lua
-- imLabels.dmp is obtained by running trainGAN.lua via data/donkey_celebA.lua
-- For MNIST: the mnist luarocks package is used instead.
local X, Y
if string.lower(path) == 'mnist' or string.lower(path) == 'mnist/' then
local mnist = require 'mnist'
local trainSet = mnist.traindataset()
X = torch.Tensor(trainSet.data:size(1), 1, imSize, imSize)
local resize = trainSet.data:size(2) ~= imSize

Y = torch.IntTensor(trainSet.data:size(1), 10):fill(-1)

for i = 1,trainSet.data:size(1) do
-- Read MNIST images (convert from ByteTensor so the [-1, 1] scaling below works)
local im = trainSet.data[{{i}}]:float()
if resize then
im = image.scale(im, imSize, imSize)
end
im:div(255):mul(2):add(-1) -- change [0, 255] to [-1, 1]
X[{{i}}]:copy(im)

-- Read MNIST labels: convert the digit class (0-9) to a +1/-1 one-hot vector
local class = trainSet.label[i]
Y[{{i},{class+1}}] = 1
end

else
print('Loading images X from '..path..'images.dmp')
local data = torch.load(path..'images.dmp')
print(('Done. Loaded %.2f GB (%d images).'):format((4*data:size(1)*data:size(2)*data:size(3)*data:size(4))/2^30, data:size(1)))

if data:size(3) ~= imSize then
-- Resize images (keep the channel dimension: #samples x channels x imSize x imSize)
X = torch.Tensor(data:size(1), data:size(2), imSize, imSize)
for i = 1,data:size(1) do
local im = image.scale(data[i], imSize, imSize)
im:mul(2):add(-1) -- change [0, 1] to [-1, 1]
X[{{i}}]:copy(im)
end
else
X = data
X:mul(2):add(-1) -- change [0, 1] to [-1, 1]
end

print('Loading attributes Y from '..path..'imLabels.dmp')
Y = torch.load(path..'imLabels.dmp')
print(('Done. Loaded %d attributes'):format(Y:size(1)))
end

return X, Y
end

local function splitTrainTest(x, y, split)
local xTrain, yTrain, xTest, yTest

local nSamples = x:size(1)
local splitInd = torch.floor(split*nSamples)
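-- e.g. split = 0.9 with 1000 samples: samples 1-900 go to train, 901-1000 to test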

xTrain = x[{{1,splitInd}}]
yTrain = y[{{1,splitInd}}]

xTest = x[{{splitInd+1,nSamples}}]
yTest = y[{{splitInd+1,nSamples}}]

return xTrain, yTrain, xTest, yTest
end

local function getEncoder(inputSize, nFiltersBase, nConvLayers, FCsz)
-- Encoder architecture based on Autoencoding beyond pixels using a learned similarity metric (VAE/GAN hybrid)

local encoder = nn.Sequential()
-- Assuming nFiltersBase = 64, nConvLayers = 3
-- 1st Conv layer: 5×5 64 conv. ↓, BNorm, ReLU
-- Data: 32x32 -> 16x16
encoder:add(nn.SpatialConvolution(inputSize[1], nFiltersBase, 5, 5, 2, 2, 2, 2))
encoder:add(nn.SpatialBatchNormalization(nFiltersBase))
encoder:add(nn.ReLU(true))

-- 2nd Conv layer: 5×5 128 conv. ↓, BNorm, ReLU
-- Data: 16x16 -> 8x8
-- 3rd Conv layer: 5×5 256 conv. ↓, BNorm, ReLU
-- Data: 8x8 -> 4x4
local nFilters = nFiltersBase
for j=2,nConvLayers do
encoder:add(nn.SpatialConvolution(nFilters, nFilters*2, 5, 5, 2, 2, 2, 2))
encoder:add(nn.SpatialBatchNormalization(nFilters*2))
encoder:add(nn.ReLU(true))
nFilters = nFilters * 2
end

-- 4th FC layer: 2048 fully-connected (as in the VAE/GAN reference architecture)
-- Data: 4x4 feature maps -> flattened vector of size FCsz
encoder:add(nn.View(-1):setNumInputDims(3)) -- reshape data to 2d tensor (samples x the rest)
-- Assuming squared images and conv layers configuration (kernel, stride and padding) is not changed:
--nFilterFC = (imageSize/2^nConvLayers)²*nFiltersLastConvNet
local inputFilterFC = (inputSize[2]/2^nConvLayers)^2*nFilters
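-- e.g. for a 64x64 input with nConvLayers = 3 and nFiltersBase = 64: nFilters = 256, so inputFilterFC = (64/2^3)^2 * 256 = 16384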

if FCsz == nil then FCsz = inputFilterFC end

encoder:add(nn.Linear(inputFilterFC, FCsz))
encoder:add(nn.BatchNormalization(FCsz))
encoder:add(nn.ReLU(true))

encoder:add(nn.Linear(FCsz, 100)) -- 100 is the size of the Z vector (must match the generator's nz)

local criterion = nn.MSECriterion()

return encoder, criterion
end

local function assignBatches(batchX, batchY, x, y, batch, batchSize, shuffle)

data_tm:reset(); data_tm:resume()

batchX:copy(x:index(1, shuffle[{{batch,batch+batchSize-1}}]:long()))
batchY:copy(y:index(1, shuffle[{{batch,batch+batchSize-1}}]:long()))

data_tm:stop()

return batchX, batchY
end

local function displayConfig(disp, title)
-- initialize error display configuration
local errorData, errorDispConfig
if disp then
errorData = {}
errorDispConfig =
{
title = 'Encoder error - ' .. title,
win = 1,
labels = {'Epoch', 'Train error', 'Test error'},
ylabel = "Error",
legend='always'
}
end
return errorData, errorDispConfig
end

function main()

local opt = getParameters() -- getParameters() already prints opt

-- Set timers
local epoch_tm = torch.Timer()
local tm = torch.Timer()
data_tm = torch.Timer() -- deliberately global: used inside assignBatches

-- Read images X and labels Y from dataset
local X, Y = readDataset(opt.datasetPath, opt.loadSize)
-- X: #samples x im3 x im2 x im1
-- Y: #samples x ny

-- Split train and test
local xTrain, yTrain, xTest, yTest
xTrain, yTrain, xTest, yTest = splitTrainTest(X, Y, opt.split)

-- Set network architecture
local encoder, criterion = getEncoder(xTrain[1]:size(), opt.nf, opt.nConvLayers, opt.FCsz)

-- Initialize batches
local batchX = torch.Tensor(opt.batchSize, xTrain:size(2), xTrain:size(3), xTrain:size(4))
local batchY = torch.Tensor(opt.batchSize, yTrain:size(2))

-- Load a pretrained generator (kept frozen; it only decodes the encoder output)
local generator = torch.load('checkpoints/c_celebA_64_filt_Yconv1_noTest_wrongYFixed_24_net_G.t7')

-- Copy variables to GPU
cutorch.setDevice(opt.gpu)
batchX = batchX:cuda(); batchY = batchY:cuda()

if pcall(require, 'cudnn') then
cudnn.benchmark = true
cudnn.convert(encoder, cudnn)
end
generator:cuda()
encoder:cuda()
criterion:cuda()

generator:evaluate()
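-- evaluate() puts the generator in inference mode (BatchNorm uses its running statistics); its weights are never updated in this script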

-- If the generator was saved with optnet optimizations, remove them to enable the backward step:
--local optnet = require 'optnet'
--optnet.removeOptimization(generator)

local params, gradParams = encoder:getParameters() -- this must always be called after the cuda() conversions above

-- Define optim (general optimizer)
local errorTrain
local errorTest
local function optimFunction(params) -- This function needs to be declared here to avoid using global variables.
-- reset gradients (gradients are always accumulated, to accommodate batch methods)
gradParams:zero()

-- Encode images
local outputs = encoder:forward(batchX)
outputs:resize(outputs:size(1), outputs:size(2), 1, 1) -- Adapt dimensionality to generator input (batchSz x nz x 1 x 1)

-- Reconstruct encoded images
local reconstr = generator:forward{outputs, batchY}

-- Compute error (MSE pixel-wise image reconstruction).
-- nn criterions take (input, target): the reconstruction is the input, the real image the target
errorTrain = criterion:forward(reconstr, batchX)

-- Compute error gradients w.r.t. generator output
local df_do = criterion:backward(reconstr, batchX)

-- Compute error gradients w.r.t. generator input / encoder output
local df_dg = generator:updateGradInput({outputs, batchY}, df_do)
-- df_dg[1] contains the gradients w.r.t. the latent code Z (encoder output, generator Z input),
-- which are the ones needed for the encoder's backward step.
-- df_dg[2] contains the gradients w.r.t. Y; they are not needed, as Y is not produced by the encoder.
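-- Calling updateGradInput (rather than a full backward) computes these input gradients without
-- accumulating parameter gradients inside the generator, which keeps it frozen.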

df_dg[1]:resize(df_dg[1]:size(1), df_dg[1]:size(2)) -- Adapt dimensionality to encoder output (batchSz x nz)
outputs:resize(outputs:size(1), outputs:size(2))

-- Backpropagate generator gradients to encoder
encoder:backward(batchX, df_dg[1])

return errorTrain, gradParams
end

local optimState = {
learningRate = opt.lr,
beta1 = opt.beta1,
}
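-- Any Adam hyperparameter not set here (beta2, epsilon) keeps optim's defaults (0.999 and 1e-8)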

local nTrainSamples = xTrain:size(1)
local nTestSamples = xTest:size(1)

-- Initialize display configuration (if enabled)
local errorData, errorDispConfig = displayConfig(opt.display, opt.name)
paths.mkdir(opt.outputPath)

local dispBatchX, dispBatchY
if opt.display == 2 then
dispBatchX = torch.Tensor(opt.batchSize, xTrain:size(2), xTrain:size(3), xTrain:size(4)):cuda()
dispBatchY = torch.Tensor(opt.batchSize, yTrain:size(2)):cuda()
dispBatchX:copy(xTest[{{1,opt.batchSize}}])
dispBatchY:copy(yTest[{{1,opt.batchSize}}])
display.image(image.toDisplayTensor(dispBatchX,0,torch.round(math.sqrt(opt.batchSize))), {win=3, title='Real test images'})
end

-- Train network
local batchIterations = 0 -- for display purposes only
for epoch = 1, opt.nEpochs do
epoch_tm:reset()
local shuffle = torch.randperm(nTrainSamples)
for batch = 1, nTrainSamples-opt.batchSize+1, opt.batchSize do
tm:reset()

batchX, batchY = assignBatches(batchX, batchY, xTrain, yTrain, batch, opt.batchSize, shuffle)

-- Update encoder parameters (optim.adam evaluates optimFunction and updates params in place)
optim.adam(optimFunction, params, optimState)

-- Display train and test error
if opt.display and batchIterations % 20 == 0 then
-- Test error
batchX, batchY = assignBatches(batchX, batchY, xTest, yTest, torch.random(1,nTestSamples-opt.batchSize+1), opt.batchSize, torch.randperm(nTestSamples))

local outputs = encoder:forward(batchX)

outputs:resize(outputs:size(1), outputs:size(2), 1, 1)
local reconstr = generator:forward{outputs, batchY}

errorTest = criterion:forward(reconstr, batchX)

table.insert(errorData,
{
batchIterations/math.ceil(nTrainSamples / opt.batchSize), -- x-axis
errorTrain, -- y-axis for label1
errorTest -- y-axis for label2
})
display.plot(errorData, errorDispConfig)

if opt.display == 2 then
local outputs = encoder:forward(dispBatchX)
outputs:resize(outputs:size(1), outputs:size(2), 1, 1)
local reconstr = generator:forward{outputs, dispBatchY}
display.image(image.toDisplayTensor(reconstr,0,torch.round(math.sqrt(opt.batchSize))), {win=4, title='Reconstructions'})
end
end

-- Verbose: print progress every batch (raise the modulus below to print less often)
if ((batch-1) / opt.batchSize) % 1 == 0 then
print(('Epoch: [%d][%4d / %4d] Error (train): %.4f Error (test): %.4f '
.. ' Time: %.3f s Data time: %.3f s'):format(
epoch, ((batch-1) / opt.batchSize),
math.ceil(nTrainSamples / opt.batchSize),
errorTrain or -1,
errorTest or -1,
tm:time().real, data_tm:time().real))
end
batchIterations = batchIterations + 1
end
print(('End of epoch %d / %d \t Time Taken: %.3f s'):format(
epoch, opt.nEpochs, epoch_tm:time().real))

-- Store network
torch.save(opt.outputPath .. opt.name .. '_' .. epoch .. 'epochs.t7', encoder:clearState())
torch.save('checkpoints/' .. opt.name .. '_error.t7', errorData)
end

end

main()
