Skip to content

Commit

Permalink
Decrease save file sizes; fix nparams messages
Browse files Browse the repository at this point in the history
  • Loading branch information
aleju committed Nov 21, 2015
1 parent aad7af2 commit 79e6bc7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 6 deletions.
3 changes: 2 additions & 1 deletion pretrain_g.lua
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function main()
print("G autoencoder:")
print(G_AUTOENCODER)
print(string.format('Number of free parameters in G (total): %d', NN_UTILS.getNumberOfParameters(G_AUTOENCODER)))
if OPT.gpu ~= false then
if OPT.gpu == false then
print(string.format('... encoder: %d', NN_UTILS.getNumberOfParameters(G_AUTOENCODER:get(1))))
print(string.format('... decoder: %d', NN_UTILS.getNumberOfParameters(G_AUTOENCODER:get(2))))
else
Expand Down Expand Up @@ -204,6 +204,7 @@ function epoch()
print(string.format("<trainer> saving network to %s", filename))

-- Clone the autoencoder and deactivate cuda mode
NN_UTILS.prepareNetworkForSave(G_AUTOENCODER)
local G2 = G_AUTOENCODER:clone()
G2:float()
G2 = NN_UTILS.deactivateCuda(G2)
Expand Down
6 changes: 4 additions & 2 deletions train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ function main()
MODEL_D = tmp.D
MODEL_G = tmp.G
OPTSTATE = tmp.optstate
EPOCH = tmp.epoch
EPOCH = tmp.epoch + 1
if NORMALIZE then
NORMALIZE_MEAN = tmp.normalize_mean
NORMALIZE_STD = tmp.normalize_std
Expand Down Expand Up @@ -255,7 +255,9 @@ function saveAs(filename)
os.execute(string.format("mv %s %s.old", filename, filename))
end
print(string.format("<trainer> saving network to %s", filename))
torch.save(filename, {D = MODEL_D, G = MODEL_G, opt = OPT, plot_data = PLOT_DATA, epoch = EPOCH+1, normalize_mean=NORMALIZE_MEAN, normalize_std=NORMALIZE_STD})
NN_UTILS.prepareNetworkForSave(MODEL_G)
NN_UTILS.prepareNetworkForSave(MODEL_D)
torch.save(filename, {D = MODEL_D, G = MODEL_G, opt = OPT, plot_data = PLOT_DATA, epoch = EPOCH, normalize_mean=NORMALIZE_MEAN, normalize_std=NORMALIZE_STD})
end

main()
4 changes: 1 addition & 3 deletions train_v.lua
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,7 @@ function epoch()
os.execute(string.format("mkdir -p %s", sys.dirname(filename)))
print(string.format("<trainer> saving network to %s", filename))

-- apparently something in the OPTSTATE is a CudaTensor, so saving it and then loading
-- in CPU mode would cause an error
--torch.save(filename, {V=NN_UTILS.deactivateCuda(V), opt=OPT, EPOCH=EPOCH+1}) --, optstate=OPTSTATE
NN_UTILS.prepareNetworkForSave(V)
torch.save(filename, {V=V, opt=OPT, EPOCH=EPOCH+1})
end

Expand Down
36 changes: 36 additions & 0 deletions utils/nn_utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,42 @@ function nn_utils.normalize(data, mean_, std_)
return 0.5, 0.5
end

-- from https://github.com/torch/DEPRECEATED-torch7-distro/issues/47
-- Recursively replace every tensor in `data` (a tensor, or a possibly nested
-- array-table of tensors) with an empty tensor of the same type, so that the
-- structure serializes to (almost) nothing. Non-tensor values pass through
-- unchanged. Returns the shrunken value; tables are modified in place.
function nn_utils.zeroDataSize(data)
    local kind = type(data)
    if kind == 'table' then
        -- Shrink each element of the sequence part in place.
        for i = 1, #data do
            data[i] = nn_utils.zeroDataSize(data[i])
        end
    elseif kind == 'userdata' then
        -- Torch tensors are userdata: swap in a fresh zero-size tensor of
        -- the same type (CudaTensor, FloatTensor, ...).
        data = torch.Tensor():typeAs(data)
    end
    return data
end

-- from https://github.com/torch/DEPRECEATED-torch7-distro/issues/47
-- Zero out the temporary buffers of a network (output, gradInput, finput,
-- fgradInput) so that the serialized file on disk is much smaller. Recurses
-- into container modules (nn.Sequential etc.). The network stays usable:
-- nn modules re-allocate these buffers on the next forward/backward pass.
-- NOTE(review): mutates the network in place — callers that clone before
-- saving (see pretrain_g.lua) should clone first or re-run a forward pass.
-- @param node An nn module or container.
function nn_utils.prepareNetworkForSave(node)
    if node.output ~= nil then
        node.output = nn_utils.zeroDataSize(node.output)
    end
    if node.gradInput ~= nil then
        node.gradInput = nn_utils.zeroDataSize(node.gradInput)
    end
    if node.finput ~= nil then
        node.finput = nn_utils.zeroDataSize(node.finput)
    end
    -- fgradInput is the other large scratch buffer of convolution modules
    -- (the counterpart of finput); clearing it too further shrinks the
    -- on-disk size of convolutional networks.
    if node.fgradInput ~= nil then
        node.fgradInput = nn_utils.zeroDataSize(node.fgradInput)
    end
    -- Recurse into the children of container modules.
    if node.modules ~= nil and type(node.modules) == 'table' then
        for i = 1, #node.modules do
            nn_utils.prepareNetworkForSave(node.modules[i])
        end
    end
    collectgarbage()
end

function nn_utils.getNumberOfParameters(net)
local nparams = 0
local dModules = net:listModules()
Expand Down

0 comments on commit 79e6bc7

Please sign in to comment.