diff --git a/other/alternativeEncoders/trainEncoderE2E.lua b/other/alternativeEncoders/trainEncoderE2E.lua new file mode 100644 index 00000000..0f153677 --- /dev/null +++ b/other/alternativeEncoders/trainEncoderE2E.lua @@ -0,0 +1,351 @@ +-- This file reads the dataset generated by generateEncoderDataset.lua and +-- trains an encoder net that learns to map an image X to a noise vector Z (encoder Z, type Z) +-- or an encoded that maps an image X to an attribute vector Y (encoder Y, type Y). + +require 'image' +require 'nn' +require 'optim' +require 'cunn' +require 'cudnn' +torch.setdefaulttensortype('torch.FloatTensor') + +local function getParameters() + + opt = {} + + -- Type of encoder must be passed as argument to decide what kind of + -- encoder will be trained (encoder Z [type Z] or encoder Y [type Y]) + opt.type = os.getenv('type') + opt.type = 'Y' + + assert(opt.type, "Parameter 'type' not specified. It is necessary to set the encoder type: 'Z' or 'Y'.\nExample: type=Z th trainEncoder.lua") + assert(string.upper(opt.type)=='Z' or string.upper(opt.type)=='Y',"Parameter 'type' must be 'Z' (encoder Z) or 'Y' (encoder Y).") + + -- Load parameters from config file + if string.upper(opt.type)=='Z' then + assert(loadfile("cfg/mainConfig.lua"))(1) + else + assert(loadfile("cfg/mainConfig.lua"))(2) + end + + -- one-line argument parser. Parses environment variables to override the defaults + for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end + print(opt) + + if opt.display then display = require 'display' end + + return opt +end + +local function readDataset(path, imSize) +-- For CelebA: there's expected to find in path a file named images.dmp and imLabels.dmp +-- which contains the images X and attribute vectors Y. +-- images.dmp is obtained running data/preprocess_celebA.lua +-- imLabels.dmp is obtained running trainGAN.lua via data/donkey_celebA.lua +-- For MNIST: It will use the mnist luarocks package + local X, Y + if string.lower(path) == 'mnist' or string.lower(path) == 'mnist/' then + local mnist = require 'mnist' + local trainSet = mnist.traindataset() + X = torch.Tensor(trainSet.data:size(1), 1, imSize, imSize) + local resize = trainSet.data:size(2) ~= imSize + + local labels = trainSet.label:int() + Y = torch.IntTensor(trainSet.data:size(1), 10):fill(-1) + + for i = 1,trainSet.data:size(1) do + -- Read MNIST images + local im = trainSet.data[{{i}}] + if resize then + im = image.scale(im, imSize, imSize):float() + end + im:div(255):mul(2):add(-1) -- change [0, 255] to [-1, 1] + X[{{i}}]:copy(im) + + --Read MNIST labels + local class = trainSet.label[i] -- Convert 0-9 to one-hot vector + Y[{{i},{class+1}}] = 1 + end + + else + print('Loading images X from '..path..'images.dmp') + local data = torch.load(path..'images.dmp') + print(('Done. Loaded %.2f GB (%d images).'):format((4*data:size(1)*data:size(2)*data:size(3)*data:size(4))/2^30, data:size(1))) + + if data:size(3) ~= imSize then + -- Resize images + X = torch.Tensor(data:size(1), imSize, imSize) + for i = 1,data:size(1) do + local im = image.scale(data[{{i}}], imSize, imSize) + im:mul(2):add(-1) -- change [0, 1] -to[-1, 1] + X[{{i}}]:copy(im) + end + else + X = data + X:mul(2):add(-1) -- change [0, 1] to [-1, 1] + end + + print('Loading attributes Y from '..path..'imLabels.dmp') + Y = torch.load(path..'imLabels.dmp') + print(('Done. Loaded %d attributes'):format(Y:size(1))) + end + + return X, Y +end + +local function splitTrainTest(x, y, split) + local xTrain, yTrain, xTest, yTest + + local nSamples = x:size(1) + local splitInd = torch.floor(split*nSamples) + + xTrain = x[{{1,splitInd}}] + yTrain = y[{{1,splitInd}}] + + xTest = x[{{splitInd+1,nSamples}}] + yTest = y[{{splitInd+1,nSamples}}] + + return xTrain, yTrain, xTest, yTest +end + +local function getEncoder(inputSize, nFiltersBase, nConvLayers, FCsz) + -- Encoder architecture based on Autoencoding beyond pixels using a learned similarity metric (VAE/GAN hybrid) + + local encoder = nn.Sequential() + -- Assuming nFiltersBase = 64, nConvLayers = 3 + -- 1st Conv layer: 5×5 64 conv. ↓, BNorm, ReLU + -- Data: 32x32 -> 16x16 + encoder:add(nn.SpatialConvolution(inputSize[1], nFiltersBase, 5, 5, 2, 2, 2, 2)) + encoder:add(nn.SpatialBatchNormalization(nFiltersBase)) + encoder:add(nn.ReLU(true)) + + -- 2nd Conv layer: 5×5 128 conv. ↓, BNorm, ReLU + -- Data: 16x16 -> 8x8 + -- 3rd Conv layer: 5×5 256 conv. ↓, BNorm, ReLU + -- Data: 8x8 -> 4x4 + local nFilters = nFiltersBase + for j=2,nConvLayers do + encoder:add(nn.SpatialConvolution(nFilters, nFilters*2, 5, 5, 2, 2, 2, 2)) + encoder:add(nn.SpatialBatchNormalization(nFilters*2)) + encoder:add(nn.ReLU(true)) + nFilters = nFilters * 2 + end + + -- 4th FC layer: 2048 fully-connected + -- Data: 4x4 -> 16 + encoder:add(nn.View(-1):setNumInputDims(3)) -- reshape data to 2d tensor (samples x the rest) + -- Assuming squared images and conv layers configuration (kernel, stride and padding) is not changed: + --nFilterFC = (imageSize/2^nConvLayers)²*nFiltersLastConvNet + local inputFilterFC = (inputSize[2]/2^nConvLayers)^2*nFilters + + if FCsz == nil then FCsz = inputFilterFC end + + encoder:add(nn.Linear(inputFilterFC, FCsz)) + encoder:add(nn.BatchNormalization(FCsz)) + encoder:add(nn.ReLU(true)) + + encoder:add(nn.Linear(FCsz, 100)) -- 100 is the size of the Z vector + + local criterion = nn.MSECriterion() + + return encoder, criterion +end + +local function assignBatches(batchX, batchY, x, y, batch, batchSize, shuffle) + + data_tm:reset(); data_tm:resume() + + batchX:copy(x:index(1, shuffle[{{batch,batch+batchSize-1}}]:long())) + batchY:copy(y:index(1, shuffle[{{batch,batch+batchSize-1}}]:long())) + + data_tm:stop() + + return batchX, batchY +end + +local function displayConfig(disp, title) + -- initialize error display configuration + local errorData, errorDispConfig + if disp then + errorData = {} + errorDispConfig = + { + title = 'Encoder error - ' .. title, + win = 1, + labels = {'Epoch', 'Train error', 'Test error'}, + ylabel = "Error", + legend='always' + } + end + return errorData, errorDispConfig +end + +function main() + + local opt = getParameters() + print(opt) + + -- Set timers + local epoch_tm = torch.Timer() + local tm = torch.Timer() + data_tm = torch.Timer() + + -- Read images X and labels Y from dataset + local X, Y = readDataset(opt.datasetPath, opt.loadSize) + -- X: #samples x im3 x im2 x im1 + -- Y: #samples x ny + +-- Split train and test + local xTrain, yTrain, xTest, yTest + xTrain, yTrain, xTest, yTest = splitTrainTest(X, Y, opt.split) + + -- Set network architecture + local encoder, criterion = getEncoder(xTrain[1]:size(), opt.nf, opt.nConvLayers, opt.FCsz) + + -- Initialize batches + local batchX = torch.Tensor(opt.batchSize, xTrain:size(2), xTrain:size(3), xTrain:size(4)) + local batchY = torch.Tensor(opt.batchSize, yTrain:size(2)) + + -- Load generator + local generator = torch.load('checkpoints/c_celebA_64_filt_Yconv1_noTest_wrongYFixed_24_net_G.t7') + + -- Copy variables to GPU + cutorch.setDevice(opt.gpu) + batchX = batchX:cuda(); batchY = batchY:cuda() + + if pcall(require, 'cudnn') then + cudnn.benchmark = true + cudnn.convert(encoder, cudnn) + end + generator:cuda() + encoder:cuda() + criterion:cuda() + + generator:evaluate() + + -- Remove optimization to enable backward step, in case the network is optimized + --local optnet = require 'optnet' + --optnet.removeOptimization(generator) + + local params, gradParams = encoder:getParameters() -- This has to be performed always after the cuda call + + -- Define optim (general optimizer) + local errorTrain + local errorTest + local function optimFunction(params) -- This function needs to be declared here to avoid using global variables. + -- reset gradients (gradients are always accumulated, to accommodat batch methods) + gradParams:zero() + + -- Encode images + local outputs = encoder:forward(batchX) + outputs:resize(outputs:size(1), outputs:size(2), 1, 1) -- Adapt dimensionality to generator input (batchSz x nz x 1 x 1) + + -- Reconstruct encoded images + local reconstr = generator:forward{outputs, batchY} + + -- Compute error (MSE pixel-wise image reconstruction) + errorTrain = criterion:forward(batchX, reconstr) + + -- Compute error gradients w.r.t. generator output + local df_do = criterion:backward(batchX, reconstr) + + -- Compute error gradients w.r.t. generator input / encoder output + local df_dg = generator:updateGradInput({outputs, batchY}, df_do) + -- df_dg[1] contains the gradients of the latent space Z (encoder output, generator Z input) + -- which are the ones you need to compute the backward step. + -- df_dg[2] contains the gradients of Y. You don't need them as they are not the encoder output. + + df_dg[1]:resize(df_dg[1]:size(1), df_dg[1]:size(2)) -- Adapt dimensionality to encoder output (batchSz x nz) + outputs:resize(outputs:size(1), outputs:size(2)) + + -- Backpropagate generator gradients to encoder + encoder:backward(batchX, df_dg[1]) + + return errorTrain, gradParams + end + + local optimState = { + learningRate = opt.lr, + beta1 = opt.beta1, + } + + local nTrainSamples = xTrain:size(1) + local nTestSamples = xTest:size(1) + + -- Initialize display configuration (if enabled) + local errorData, errorDispConfig = displayConfig(opt.display, opt.name) + paths.mkdir(opt.outputPath) + + local dispBatchX, dispBatchY + if opt.display == 2 then + dispBatchX = torch.Tensor(opt.batchSize, xTrain:size(2), xTrain:size(3), xTrain:size(4)):cuda() + dispBatchY = torch.Tensor(opt.batchSize, yTrain:size(2)):cuda() + dispBatchX:copy(xTest[{{1,opt.batchSize}}]) + dispBatchY:copy(yTest[{{1,opt.batchSize}}]) + display.image(image.toDisplayTensor(dispBatchX,0,torch.round(math.sqrt(opt.batchSize))), {win=3, title='Real test images'}) + end + + -- Train network + local batchIterations = 0 -- for display purposes only + for epoch = 1, opt.nEpochs do + epoch_tm:reset() + local shuffle = torch.randperm(nTrainSamples) + for batch = 1, nTrainSamples-opt.batchSize+1, opt.batchSize do + tm:reset() + + batchX, batchY = assignBatches(batchX, batchY, xTrain, yTrain, batch, opt.batchSize, shuffle) + + -- Update network + optim.adam(optimFunction, params, optimState) + + -- Display train and test error + if opt.display and batchIterations % 20 == 0 then + -- Test error + batchX, batchY = assignBatches(batchX, batchY, xTest, yTest, torch.random(1,nTestSamples-opt.batchSize+1), opt.batchSize, torch.randperm(nTestSamples)) + + local outputs = encoder:forward(batchX) + + outputs:resize(outputs:size(1), outputs:size(2), 1, 1) + local reconstr = generator:forward{outputs, batchY} + + errorTest = criterion:forward(batchX, reconstr) + + table.insert(errorData, + { + batchIterations/math.ceil(nTrainSamples / opt.batchSize), -- x-axis + errorTrain, -- y-axis for label1 + errorTest -- y-axis for label2 + }) + display.plot(errorData, errorDispConfig) + + if opt.display == 2 then + local outputs = encoder:forward(dispBatchX) + outputs:resize(outputs:size(1), outputs:size(2), 1, 1) + local reconstr = generator:forward{outputs, dispBatchY} + display.image(image.toDisplayTensor(reconstr,0,torch.round(math.sqrt(opt.batchSize))), {win=4, title='Reconstructions'}) + end + end + + -- Verbose + if ((batch-1) / opt.batchSize) % 1 == 0 then + print(('Epoch: [%d][%4d / %4d] Error (train): %.4f Error (test): %.4f ' + .. ' Time: %.3f s Data time: %.3f s'):format( + epoch, ((batch-1) / opt.batchSize), + math.ceil(nTrainSamples / opt.batchSize), + errorTrain and errorTrain or -1, + errorTest and errorTest or -1, + tm:time().real, data_tm:time().real)) + end + batchIterations = batchIterations + 1 + end + print(('End of epoch %d / %d \t Time Taken: %.3f s'):format( + epoch, opt.nEpochs, epoch_tm:time().real)) + + -- Store network + torch.save(opt.outputPath .. opt.name .. '_' .. epoch .. 'epochs.t7', encoder:clearState()) + torch.save('checkpoints/' .. opt.name .. '_error.t7', errorData) + end + +end + +main() \ No newline at end of file