forked from soumith/dcgan.torch
Add an E2E way to train the encoder with an L2 loss
This encoder is trained with an image reconstruction loss computed between a real input image and its reconstruction. I add this encoder to the "other" folder, as the obtained results have been very poor.
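The training signal is end-to-end: the encoder E maps a real image x to a latent code z, a pretrained and frozen conditional generator G maps (z, y) back to image space, and the MSE between x and its reconstruction G(E(x), y) is backpropagated through G into E. Below is a minimal runnable Torch7 sketch of that objective, with hypothetical linear stand-ins for E and G and the conditioning on Y omitted; the actual script that follows uses conv nets and a pretrained cGAN checkpoint.

    require 'nn'
    require 'optim'

    -- Toy stand-ins for the real conv nets; nz and imPixels are illustrative.
    local nz, imPixels = 100, 64*64*3
    local encoder   = nn.Sequential():add(nn.Linear(imPixels, nz))
    local generator = nn.Sequential():add(nn.Linear(nz, imPixels)) -- pretrained and frozen in the real script

    local criterion = nn.MSECriterion()
    local params, gradParams = encoder:getParameters()
    local x = torch.randn(16, imPixels) -- a batch of "real images"

    local function feval()
        gradParams:zero()
        local z = encoder:forward(x)                      -- E(x) -> z
        local reconstr = generator:forward(z)             -- G(z) -> x_hat
        local err = criterion:forward(reconstr, x)        -- L2 reconstruction loss
        local df_do = criterion:backward(reconstr, x)     -- gradient w.r.t. the reconstruction
        local df_dz = generator:updateGradInput(z, df_do) -- gradient w.r.t. z only; generator weights untouched
        encoder:backward(x, df_dz)                        -- backpropagate through the encoder
        return err, gradParams
    end

    optim.adam(feval, params, {learningRate = 2e-4})      -- one optimization step on the encoder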
Showing 1 changed file with 351 additions and 0 deletions.
-- This file reads the dataset generated by generateEncoderDataset.lua and
-- trains an encoder net that learns to map an image X to a noise vector Z (encoder Z, type Z)
-- or an encoder that maps an image X to an attribute vector Y (encoder Y, type Y).

require 'image'
require 'nn'
require 'optim'
require 'cunn'
require 'cudnn'
torch.setdefaulttensortype('torch.FloatTensor')

local function getParameters()

    opt = {}

    -- The type of encoder must be passed as an argument to decide what kind of
    -- encoder will be trained (encoder Z [type Z] or encoder Y [type Y])
    opt.type = os.getenv('type') or 'Y' -- defaults to encoder Y when the environment variable is not set

    assert(opt.type, "Parameter 'type' not specified. It is necessary to set the encoder type: 'Z' or 'Y'.\nExample: type=Z th trainEncoder.lua")
    assert(string.upper(opt.type)=='Z' or string.upper(opt.type)=='Y', "Parameter 'type' must be 'Z' (encoder Z) or 'Y' (encoder Y).")

    -- Load parameters from config file
    if string.upper(opt.type)=='Z' then
        assert(loadfile("cfg/mainConfig.lua"))(1)
    else
        assert(loadfile("cfg/mainConfig.lua"))(2)
    end

    -- One-line argument parser. Parses environment variables to override the defaults
    for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
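    -- e.g. type=Z lr=0.0002 th trainEncoder.lua would override opt.type and opt.lr (illustrative values)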
    print(opt)

    if opt.display then display = require 'display' end

    return opt
end

local function readDataset(path, imSize)
    -- For CelebA: the path is expected to contain files named images.dmp and imLabels.dmp,
    -- which hold the images X and attribute vectors Y.
    -- images.dmp is obtained by running data/preprocess_celebA.lua
    -- imLabels.dmp is obtained by running trainGAN.lua via data/donkey_celebA.lua
    -- For MNIST: it will use the mnist luarocks package.
    local X, Y
    if string.lower(path) == 'mnist' or string.lower(path) == 'mnist/' then
        local mnist = require 'mnist'
        local trainSet = mnist.traindataset()
        X = torch.Tensor(trainSet.data:size(1), 1, imSize, imSize)
        local resize = trainSet.data:size(2) ~= imSize

        Y = torch.IntTensor(trainSet.data:size(1), 10):fill(-1)

        for i = 1,trainSet.data:size(1) do
            -- Read MNIST images (converted to float so the in-place arithmetic below works)
            local im = trainSet.data[{{i}}]:float()
            if resize then
                im = image.scale(im, imSize, imSize)
            end
            im:div(255):mul(2):add(-1) -- change [0, 255] to [-1, 1]
            X[{{i}}]:copy(im)

            -- Read MNIST labels
            local class = trainSet.label[i] -- convert 0-9 to a one-hot vector
            Y[{{i},{class+1}}] = 1
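            -- note: Y rows are "one-hot" over {-1, 1}: the true class gets 1, every other entry stays -1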
        end

    else
        print('Loading images X from '..path..'images.dmp')
        local data = torch.load(path..'images.dmp')
        print(('Done. Loaded %.2f GB (%d images).'):format((4*data:size(1)*data:size(2)*data:size(3)*data:size(4))/2^30, data:size(1)))

        if data:size(3) ~= imSize then
            -- Resize images (keeping the channel dimension of the loaded data)
            X = torch.Tensor(data:size(1), data:size(2), imSize, imSize)
            for i = 1,data:size(1) do
                local im = image.scale(data[i], imSize, imSize)
                im:mul(2):add(-1) -- change [0, 1] to [-1, 1]
                X[{{i}}]:copy(im)
            end
        else
            X = data
            X:mul(2):add(-1) -- change [0, 1] to [-1, 1]
        end

        print('Loading attributes Y from '..path..'imLabels.dmp')
        Y = torch.load(path..'imLabels.dmp')
        print(('Done. Loaded %d attributes'):format(Y:size(1)))
    end

    return X, Y
end

local function splitTrainTest(x, y, split)
    local xTrain, yTrain, xTest, yTest

    local nSamples = x:size(1)
    local splitInd = torch.floor(split*nSamples)
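    -- e.g. split = 0.9 with 1000 samples gives splitInd = 900: samples 1-900 for train, 901-1000 for test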

    xTrain = x[{{1,splitInd}}]
    yTrain = y[{{1,splitInd}}]

    xTest = x[{{splitInd+1,nSamples}}]
    yTest = y[{{splitInd+1,nSamples}}]

    return xTrain, yTrain, xTest, yTest
end

local function getEncoder(inputSize, nFiltersBase, nConvLayers, FCsz)
    -- Encoder architecture based on "Autoencoding beyond pixels using a learned similarity metric" (VAE/GAN hybrid)

    local encoder = nn.Sequential()
    -- Assuming nFiltersBase = 64, nConvLayers = 3
    -- 1st conv layer: 5×5 64 conv. ↓, BNorm, ReLU
    -- Data: 32x32 -> 16x16
    encoder:add(nn.SpatialConvolution(inputSize[1], nFiltersBase, 5, 5, 2, 2, 2, 2))
    encoder:add(nn.SpatialBatchNormalization(nFiltersBase))
    encoder:add(nn.ReLU(true))

    -- 2nd conv layer: 5×5 128 conv. ↓, BNorm, ReLU
    -- Data: 16x16 -> 8x8
    -- 3rd conv layer: 5×5 256 conv. ↓, BNorm, ReLU
    -- Data: 8x8 -> 4x4
    local nFilters = nFiltersBase
    for j=2,nConvLayers do
        encoder:add(nn.SpatialConvolution(nFilters, nFilters*2, 5, 5, 2, 2, 2, 2))
        encoder:add(nn.SpatialBatchNormalization(nFilters*2))
        encoder:add(nn.ReLU(true))
        nFilters = nFilters * 2
    end

    -- 4th layer: fully-connected (size FCsz, 2048 in the reference architecture)
    -- Data: 4x4 feature maps -> flattened vector
    encoder:add(nn.View(-1):setNumInputDims(3)) -- reshape data to a 2D tensor (samples x the rest)
    -- Assuming square images and that the conv layer configuration (kernel, stride and padding) is unchanged:
    -- inputFilterFC = (imageSize/2^nConvLayers)^2 * nFiltersLastConvNet
    local inputFilterFC = (inputSize[2]/2^nConvLayers)^2*nFilters
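    -- e.g. for 32x32 inputs with nConvLayers = 3 and nFiltersBase = 64: nFilters = 256 after the loop,
    -- so inputFilterFC = (32/2^3)^2 * 256 = 4*4*256 = 4096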

    if FCsz == nil then FCsz = inputFilterFC end

    encoder:add(nn.Linear(inputFilterFC, FCsz))
    encoder:add(nn.BatchNormalization(FCsz))
    encoder:add(nn.ReLU(true))

    encoder:add(nn.Linear(FCsz, 100)) -- 100 is the size of the Z vector
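    -- (this output size must match the noise dimensionality nz expected by the generator loaded in main())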

    local criterion = nn.MSECriterion()

    return encoder, criterion
end

local function assignBatches(batchX, batchY, x, y, batch, batchSize, shuffle)

    data_tm:reset(); data_tm:resume()

    batchX:copy(x:index(1, shuffle[{{batch,batch+batchSize-1}}]:long()))
    batchY:copy(y:index(1, shuffle[{{batch,batch+batchSize-1}}]:long()))

    data_tm:stop()

    return batchX, batchY
end

local function displayConfig(disp, title)
    -- Initialize error display configuration
    local errorData, errorDispConfig
    if disp then
        errorData = {}
        errorDispConfig =
        {
            title = 'Encoder error - ' .. title,
            win = 1,
            labels = {'Epoch', 'Train error', 'Test error'},
            ylabel = "Error",
            legend = 'always'
        }
    end
    return errorData, errorDispConfig
end

function main()

    local opt = getParameters()
    print(opt)

    -- Set timers
    local epoch_tm = torch.Timer()
    local tm = torch.Timer()
    data_tm = torch.Timer()

    -- Read images X and labels Y from dataset
    local X, Y = readDataset(opt.datasetPath, opt.loadSize)
    -- X: #samples x im3 x im2 x im1
    -- Y: #samples x ny

    -- Split train and test
    local xTrain, yTrain, xTest, yTest
    xTrain, yTrain, xTest, yTest = splitTrainTest(X, Y, opt.split)

    -- Set network architecture
    local encoder, criterion = getEncoder(xTrain[1]:size(), opt.nf, opt.nConvLayers, opt.FCsz)

    -- Initialize batches
    local batchX = torch.Tensor(opt.batchSize, xTrain:size(2), xTrain:size(3), xTrain:size(4))
    local batchY = torch.Tensor(opt.batchSize, yTrain:size(2))

    -- Load generator
    local generator = torch.load('checkpoints/c_celebA_64_filt_Yconv1_noTest_wrongYFixed_24_net_G.t7')
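    -- NOTE: hardcoded checkpoint path; point this at your own trained conditional generator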

    -- Copy variables to GPU
    cutorch.setDevice(opt.gpu)
    batchX = batchX:cuda(); batchY = batchY:cuda()

    if pcall(require, 'cudnn') then
        cudnn.benchmark = true
        cudnn.convert(encoder, cudnn)
    end
    generator:cuda()
    encoder:cuda()
    criterion:cuda()

    generator:evaluate()

    -- Remove optimization to enable the backward step, in case the network is optimized
    --local optnet = require 'optnet'
    --optnet.removeOptimization(generator)

    local params, gradParams = encoder:getParameters() -- This must always be done after the cuda call

    -- Define optim (general optimizer)
    local errorTrain
    local errorTest
    local function optimFunction(params) -- This function needs to be declared here to avoid using global variables.
        -- Reset gradients (gradients are always accumulated, to accommodate batch methods)
        gradParams:zero()

        -- Encode images
        local outputs = encoder:forward(batchX)
        outputs:resize(outputs:size(1), outputs:size(2), 1, 1) -- adapt dimensionality to the generator input (batchSz x nz x 1 x 1)

        -- Reconstruct encoded images
        local reconstr = generator:forward{outputs, batchY}

        -- Compute error (MSE pixel-wise image reconstruction; the criterion takes (input, target))
        errorTrain = criterion:forward(reconstr, batchX)

        -- Compute error gradients w.r.t. the generator output
        local df_do = criterion:backward(reconstr, batchX)

        -- Compute error gradients w.r.t. the generator input / encoder output
        local df_dg = generator:updateGradInput({outputs, batchY}, df_do)
        -- df_dg[1] contains the gradients of the latent space Z (encoder output, generator Z input),
        -- which are the ones needed to compute the backward step.
        -- df_dg[2] contains the gradients of Y. They are not needed, as Y is not the encoder output.
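        -- updateGradInput (rather than a full backward) returns these gradients without accumulating
        -- weight gradients inside the generator, so the generator stays frozen.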

        df_dg[1]:resize(df_dg[1]:size(1), df_dg[1]:size(2)) -- adapt dimensionality to the encoder output (batchSz x nz)
        outputs:resize(outputs:size(1), outputs:size(2))

        -- Backpropagate generator gradients to the encoder
        encoder:backward(batchX, df_dg[1])

        return errorTrain, gradParams
    end

    local optimState = {
        learningRate = opt.lr,
        beta1 = opt.beta1,
    }

    local nTrainSamples = xTrain:size(1)
    local nTestSamples = xTest:size(1)

    -- Initialize display configuration (if enabled)
    local errorData, errorDispConfig = displayConfig(opt.display, opt.name)
    paths.mkdir(opt.outputPath)

    local dispBatchX, dispBatchY
    if opt.display == 2 then
        dispBatchX = torch.Tensor(opt.batchSize, xTrain:size(2), xTrain:size(3), xTrain:size(4)):cuda()
        dispBatchY = torch.Tensor(opt.batchSize, yTrain:size(2)):cuda()
        dispBatchX:copy(xTest[{{1,opt.batchSize}}])
        dispBatchY:copy(yTest[{{1,opt.batchSize}}])
        display.image(image.toDisplayTensor(dispBatchX,0,torch.round(math.sqrt(opt.batchSize))), {win=3, title='Real test images'})
    end

    -- Train network
    local batchIterations = 0 -- for display purposes only
    for epoch = 1, opt.nEpochs do
        epoch_tm:reset()
        local shuffle = torch.randperm(nTrainSamples)
        for batch = 1, nTrainSamples-opt.batchSize+1, opt.batchSize do
            tm:reset()

            batchX, batchY = assignBatches(batchX, batchY, xTrain, yTrain, batch, opt.batchSize, shuffle)

            -- Update network
            optim.adam(optimFunction, params, optimState)

            -- Display train and test error
            if opt.display and batchIterations % 20 == 0 then
                -- Test error
                batchX, batchY = assignBatches(batchX, batchY, xTest, yTest, torch.random(1,nTestSamples-opt.batchSize+1), opt.batchSize, torch.randperm(nTestSamples))

                local outputs = encoder:forward(batchX)

                outputs:resize(outputs:size(1), outputs:size(2), 1, 1)
                local reconstr = generator:forward{outputs, batchY}

                errorTest = criterion:forward(reconstr, batchX)

                table.insert(errorData,
                    {
                        batchIterations/math.ceil(nTrainSamples / opt.batchSize), -- x-axis
                        errorTrain, -- y-axis for label1
                        errorTest   -- y-axis for label2
                    })
                display.plot(errorData, errorDispConfig)

                if opt.display == 2 then
                    local outputs = encoder:forward(dispBatchX)
                    outputs:resize(outputs:size(1), outputs:size(2), 1, 1)
                    local reconstr = generator:forward{outputs, dispBatchY}
                    display.image(image.toDisplayTensor(reconstr,0,torch.round(math.sqrt(opt.batchSize))), {win=4, title='Reconstructions'})
                end
            end

            -- Verbose: print progress every batch
            print(('Epoch: [%d][%4d / %4d] Error (train): %.4f Error (test): %.4f '
                .. ' Time: %.3f s Data time: %.3f s'):format(
                epoch, ((batch-1) / opt.batchSize),
                math.ceil(nTrainSamples / opt.batchSize),
                errorTrain or -1,
                errorTest or -1,
                tm:time().real, data_tm:time().real))
            batchIterations = batchIterations + 1
        end
        print(('End of epoch %d / %d \t Time Taken: %.3f s'):format(
            epoch, opt.nEpochs, epoch_tm:time().real))

        -- Store network
        torch.save(opt.outputPath .. opt.name .. '_' .. epoch .. 'epochs.t7', encoder:clearState())
        torch.save('checkpoints/' .. opt.name .. '_error.t7', errorData)
    end

end

main()