Commit

LSTM:get/set[Grad]HiddenState

Nicholas Leonard committed Dec 28, 2016
1 parent 94e4987 commit b756aeb
Showing 6 changed files with 164 additions and 34 deletions.
14 changes: 10 additions & 4 deletions AbstractRecurrent.lua
@@ -13,6 +13,7 @@ function AbstractRecurrent:__init(rho)
self.outputs = {}
self.gradInputs = {}
self._gradOutputs = {}
self.gradOutputs = {}

self.step = 1

@@ -100,7 +101,7 @@ end
function nn.AbstractRecurrent:clearState()
self:forget()
-- keep the first two sharedClones
- nn.utils.clear(self, '_input', '_gradOutput', '_gradOutputs', 'gradPrevOutput', 'cell', 'cells', 'gradCells', 'outputs', 'gradInputs')
+ nn.utils.clear(self, '_input', '_gradOutput', '_gradOutputs', 'gradPrevOutput', 'cell', 'cells', 'gradCells', 'outputs', 'gradInputs', 'gradOutputs')
for i, clone in ipairs(self.sharedClones) do
clone:clearState()
end
@@ -120,13 +121,17 @@ function AbstractRecurrent:forget()
self.gradInputs = {}
self.sharedClones = _.compact(self.sharedClones)
self._gradOutputs = _.compact(self._gradOutputs)
self.gradOutputs = {}
if self.cells then
self.cells = {}
self.gradCells = {}
end
end

-- forget the past inputs; restart from first step
self.step = 1


if not self.rmInSharedClones then
-- Asserts that issue 129 is solved. In forget as it is often called.
-- Asserts that self.recurrentModule is part of the sharedClones.
-- Since its used for evaluation, it should be used for training.
@@ -225,11 +230,12 @@ function AbstractRecurrent:maxBPTTstep(rho)
self.rho = rho
end

- -- get hidden state: h[t]
+ -- get stored hidden state: h[t] where h[t] = f(x[t], h[t-1])
function AbstractRecurrent:getHiddenState(step, input)
error"Not Implemented"
end

-- set stored hidden state
function AbstractRecurrent:setHiddenState(step, hiddenState)
error"Not Implemented"
end
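
For illustration only (not part of this commit): a minimal sketch of how a caller is meant to use these accessors to carry hidden state from one sequence into the next. It assumes the concrete nn.LSTM subclass below, and clones the returned tensors defensively since the module may recycle its internal buffers.

require 'rnn'

local lstm = nn.LSTM(4, 5)             -- inputsize=4, outputsize=5
local seq1 = torch.randn(7, 3, 4)      -- seqlen x batchsize x inputsize
local seq2 = torch.randn(7, 3, 4)

for step=1,seq1:size(1) do
   lstm:forward(seq1[step])
end

-- for an LSTM, h[t] is the pair {output[t], cell[t]}
local last = lstm:getHiddenState(seq1:size(1))
local carried = {last[1]:clone(), last[2]:clone()}

lstm:forget()                          -- restart step counting at 1
lstm:setHiddenState(0, carried)        -- h[0] of the new sequence := h[T] of the old one

for step=1,seq2:size(1) do
   lstm:forward(seq2[step])
end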
68 changes: 49 additions & 19 deletions LSTM.lua
@@ -139,8 +139,8 @@ end
function LSTM:getHiddenState(step, input)
local prevOutput, prevCell
if step == 0 then
- prevOutput = self.userPrevOutput or self.zeroTensor
- prevCell = self.userPrevCell or self.zeroTensor
+ prevOutput = self.userPrevOutput or self.outputs[step] or self.zeroTensor
+ prevCell = self.userPrevCell or self.cells[step] or self.zeroTensor
if input then
if input:dim() == 2 then
self.zeroTensor:resize(input:size(1), self.outputSize):zero()
@@ -156,6 +156,15 @@ function LSTM:getHiddenState(step, input)
return {prevOutput, prevCell}
end

function LSTM:setHiddenState(step, hiddenState)
assert(torch.type(hiddenState) == 'table')
assert(#hiddenState == 2)

-- previous output of this module
self.outputs[step] = hiddenState[1]
self.cells[step] = hiddenState[2]
end
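
As an illustrative sketch (sizes are assumptions, not from the commit): the state stored for an nn.LSTM at step t is the pair {output[t], cell[t]}, and setHiddenState writes straight back into self.outputs and self.cells.

require 'rnn'

local lstm = nn.LSTM(4, 5)
lstm:forward(torch.randn(3, 4))     -- step 1, batchsize 3

local state = lstm:getHiddenState(1)
assert(#state == 2)                 -- {output h[1], cell c[1]}
print(state[1]:size())              -- 3x5
print(state[2]:size())              -- 3x5

lstm:setHiddenState(1, state)       -- round-trips into outputs[1] and cells[1]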

------------------------- forward backward -----------------------------
function LSTM:updateOutput(input)
local prevOutput, prevCell = unpack(self:getHiddenState(self.step-1, input))
@@ -185,6 +194,26 @@ function LSTM:updateOutput(input)
return self.output
end

function LSTM:getGradHiddenState(step)
local gradOutput, gradCell
if step == self.step-1 then
gradOutput = self.userNextGradOutput or self.gradOutputs[step] or self.zeroTensor
gradCell = self.userNextGradCell or self.gradCells[step] or self.zeroTensor
else
gradOutput = self.gradOutputs[step]
gradCell = self.gradCells[step]
end
return {gradOutput, gradCell}
end

function LSTM:setGradHiddenState(step, gradHiddenState)
assert(torch.type(gradHiddenState) == 'table')
assert(#gradHiddenState == 2)

self.gradOutputs[step] = gradHiddenState[1]
self.gradCells[step] = gradHiddenState[2]
end
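
A hedged sketch of the gradient-side contract (the sizes and the single-step usage are assumptions, not from the commit): after backward() has run through a step, the gradient with respect to the previous hidden state can be read back at step-1 as a {gradOutput, gradCell} pair.

require 'rnn'

local lstm = nn.LSTM(4, 5)
local x, dy = torch.randn(3, 4), torch.randn(3, 5)

lstm:forward(x)      -- step 1
lstm:backward(x, dy) -- BPTT through step 1

local gradState = lstm:getGradHiddenState(0)  -- d(loss)/d(h[0]) as {gradOutput, gradCell}
assert(#gradState == 2)
print(gradState[1]:size())                    -- 3x5
print(gradState[2]:size())                    -- 3x5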

function LSTM:_updateGradInput(input, gradOutput)
assert(self.step > 1, "expecting at least one updateOutput")
local step = self.updateGradInputStep - 1
@@ -194,24 +223,23 @@ function LSTM:_updateGradInput(input, gradOutput)
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
- if self.gradPrevOutput then
- self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
+ local gradHiddenState = self:getGradHiddenState(step)
+ local _gradOutput, gradCell = gradHiddenState[1], gradHiddenState[2]
+ assert(_gradOutput and gradCell)
+
+ if _gradOutput then
+ self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], _gradOutput)
nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
gradOutput = self._gradOutputs[step]
end

- local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
- local cell = (step == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[step-1]
- local inputTable = {input, output, cell}
- local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]
+ local inputTable = self:getHiddenState(step-1)
+ table.insert(inputTable, 1, input)

local gradInputTable = recurrentModule:updateGradInput(inputTable, {gradOutput, gradCell})

- local gradInput
- gradInput, self.gradPrevOutput, gradCell = unpack(gradInputTable)
- self.gradCells[step-1] = gradCell
- if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
- if self.userPrevCell then self.userGradPrevCell = gradCell end
+ local gradInput = table.remove(gradInputTable, 1)
+ self:setGradHiddenState(step-1, gradInputTable)

return gradInput
end
Expand All @@ -224,17 +252,19 @@ function LSTM:_accGradParameters(input, gradOutput, scale)
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
- local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
- local cell = (step == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[step-1]
- local inputTable = {input, output, cell}
- local gradOutput = (step == self.step-1) and gradOutput or self._gradOutputs[step]
- local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]
- local gradOutputTable = {gradOutput, gradCell}
+ local inputTable = self:getHiddenState(step-1)
+ table.insert(inputTable, 1, input)
+ local gradOutputTable = self:getGradHiddenState(step)
+ gradOutputTable[1] = self._gradOutputs[step] or gradOutputTable[1]
recurrentModule:accGradParameters(inputTable, gradOutputTable, scale)
end

function LSTM:clearState()
self.zeroTensor:set()
if self.userPrevOutput then self.userPrevOutput:set() end
if self.userPrevCell then self.userPrevCell:set() end
if self.userGradPrevOutput then self.userGradPrevOutput:set() end
if self.userGradPrevCell then self.userGradPrevCell:set() end
return parent.clearState(self)
end

21 changes: 20 additions & 1 deletion Module.lua
@@ -1,4 +1,4 @@
local Module = nn.Module

-- You can use this to manually forget past memories in AbstractRecurrent instances
function Module:forget()
@@ -47,3 +47,22 @@ function Module:maxBPTTstep(rho)
end
end
end

function Module:getHiddenState(step)
if self.modules then
local hiddenState = {}
for i, module in ipairs(self.modules) do
hiddenState[i] = module:getHiddenState(step)
end
return hiddenState
end
end

function Module:setHiddenState(step, hiddenState)
if self.modules then
assert(torch.type(hiddenState) == 'table')
for i, module in ipairs(self.modules) do
module:setHiddenState(step, hiddenState[i])
end
end
end
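
For illustration (not part of this commit): since these container methods just recurse into self.modules, hidden state can be read from and restored into e.g. an nn.Sequential; the example below is a sketch under that assumption, where non-recurrent children simply contribute nil entries.

require 'rnn'

local net = nn.Sequential()
   :add(nn.LSTM(4, 5))
   :add(nn.Linear(5, 2))

net:forward(torch.randn(3, 4))

local state = net:getHiddenState(1)  -- state[1] is the LSTM's {output, cell}; state[2] is nil (Linear)
net:setHiddenState(1, state)         -- distributed back to the children by index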
22 changes: 20 additions & 2 deletions Recurrence.lua
@@ -98,11 +98,29 @@ function Recurrence:getHiddenState(step, input)
-- previous output of this module
prevOutput = self.outputs[step]
end
- return prevOutput
+ assert(prevOutput, "Missing hiddenState at step "..step)
+ -- call getHiddenState on recurrentModule as they may contain AbstractRecurrent instances...
+ return {prevOutput, nn.Module.getHiddenState(self, step)}
end

function Recurrence:setHiddenState(step, hiddenState)
assert(torch.type(hiddenState) == 'table')
assert(#hiddenState >= 1)
if step == 0 then
self.userPrevOutput = hiddenState[1]
else
-- previous output of this module
self.outputs[step] = hiddenState[1]
end
if hiddenState[2] then
-- call setHiddenState on recurrentModule as they may contain AbstractRecurrent instances...
nn.Module.setHiddenState(self, step, hiddenState[2])
end
end
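
An illustrative sketch (module sizes and the simple-RNN recurrentModule are assumptions taken from the package README, not from this commit): the hidden state Recurrence hands back is now a table whose first entry is its own previous output and whose second entry nests whatever its recurrentModule reports.

require 'rnn'

local inputSize, hiddenSize = 4, 5
local rm = nn.Sequential()                    -- {x[t], h[t-1]} -> h[t]
   :add(nn.ParallelTable()
      :add(nn.Linear(inputSize, hiddenSize))
      :add(nn.Linear(hiddenSize, hiddenSize)))
   :add(nn.CAddTable())
   :add(nn.Sigmoid())

local rec = nn.Recurrence(rm, hiddenSize, 1)  -- outputSize=hiddenSize, nInputDim=1
rec:forward(torch.randn(3, inputSize))

local state = rec:getHiddenState(1)
print(state[1]:size())                        -- 3x5 : output of step 1
-- state[2] : nested state collected from the recurrentModule (may be nil or empty here)

local carried = {state[1]:clone(), state[2]}
rec:forget()
rec:setHiddenState(0, carried)                -- reuse it as h[0] of a new sequence
rec:forward(torch.randn(3, inputSize))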

function Recurrence:updateOutput(input)
- local prevOutput = self:getHiddenState(self.step-1, input)
+ -- output(t-1)
+ local prevOutput = self:getHiddenState(self.step-1, input)[1]

-- output(t) = recurrentModule{input(t), output(t-1)}
local output
7 changes: 3 additions & 4 deletions examples/encoder-decoder-coupling.lua
@@ -6,7 +6,7 @@ Example of "coupled" separate encoder and decoder networks, e.g. for sequence-to

require 'rnn'

- version = 1.3 -- Added multiple layers and merged with seqLSTM example
+ version = 1.4 -- Uses [get,set]GradHiddenState for LSTM

local opt = {}
opt.learningRate = 0.1
@@ -37,8 +37,7 @@ function backwardConnect(enc, dec)
enc.lstmLayers[i].userNextGradCell = dec.lstmLayers[i].userGradPrevCell
enc.lstmLayers[i].gradPrevOutput = dec.lstmLayers[i].userGradPrevOutput
else
- enc.lstmLayers[i].userNextGradCell = nn.rnn.recursiveCopy(enc.lstmLayers[i].userNextGradCell, dec.lstmLayers[i].userGradPrevCell)
- enc.lstmLayers[i].gradPrevOutput = nn.rnn.recursiveCopy(enc.lstmLayers[i].gradPrevOutput, dec.lstmLayers[i].userGradPrevOutput)
+ enc:setGradHiddenState(opt.seqLen, dec:getGradHiddenState(0))
end
end
end
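
For a single pair of nn.LSTM layers (the updated test below does exactly this), the two halves of the coupling become a symmetric pair; a hedged sketch, where seqLen is the encoder's sequence length and the function names are illustrative:

-- forward coupling: the decoder starts from the encoder's last hidden state
local function forwardConnectLSTM(encLSTM, decLSTM, seqLen)
   decLSTM.userPrevOutput = nn.rnn.recursiveCopy(decLSTM.userPrevOutput, encLSTM.outputs[seqLen])
   decLSTM.userPrevCell = nn.rnn.recursiveCopy(decLSTM.userPrevCell, encLSTM.cells[seqLen])
   -- with the new accessors this could also be written, for a fresh/forgotten decoder, as
   -- decLSTM:setHiddenState(0, encLSTM:getHiddenState(seqLen))
end

-- backward coupling: the encoder's last step receives the decoder's gradient w.r.t. h[0]
local function backwardConnectLSTM(encLSTM, decLSTM, seqLen)
   encLSTM:setGradHiddenState(seqLen, decLSTM:getGradHiddenState(0))
end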
@@ -101,7 +100,7 @@ for i=1,opt.niter do
local decOut = dec:forward(decInSeq)
--print(decOut)
local err = criterion:forward(decOut, decOutSeq)

print(string.format("Iteration %d ; NLL err = %f ", i, err))

-- Backward pass
66 changes: 62 additions & 4 deletions test/test.lua
@@ -4878,14 +4878,13 @@ function rnntest.encoderdecoder()

--[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]--
local function forwardConnect(encLSTM, decLSTM)
decLSTM.userPrevOutput = nn.rnn.recursiveCopy(decLSTM.userPrevOutput, encLSTM.outputs[opt.inputSeqLen])
decLSTM.userPrevCell = nn.rnn.recursiveCopy(decLSTM.userPrevCell, encLSTM.cells[opt.inputSeqLen])
end

--[[ Backward coupling: Copy decoder gradients to encoder LSTM ]]--
local function backwardConnect(encLSTM, decLSTM)
- encLSTM.userNextGradCell = nn.rnn.recursiveCopy(encLSTM.userNextGradCell, decLSTM.userGradPrevCell)
- encLSTM.gradPrevOutput = nn.rnn.recursiveCopy(encLSTM.gradPrevOutput, decLSTM.userGradPrevOutput)
+ encLSTM:setGradHiddenState(opt.inputSeqLen, decLSTM:getGradHiddenState(0))
end

-- Encoder
@@ -6680,6 +6679,65 @@ function rnntest.inplaceBackward()
end
end

function rnntest.LSTM_hiddenState()
local seqlen, batchsize = 7, 3
local inputsize, outputsize = 4, 5
local lstm = nn.LSTM(inputsize, outputsize)
local input = torch.randn(seqlen*2, batchsize, inputsize)
local gradOutput = torch.randn(seqlen*2, batchsize, outputsize)
local lstm2 = lstm:clone()

-- test forward
for step=1,seqlen do -- initialize lstm2 hidden state
lstm2:forward(input[step])
end

for step=1,seqlen do
local hiddenState = lstm2:getHiddenState(seqlen+step-1)
mytester:assert(#hiddenState == 2)
lstm:setHiddenState(step-1, hiddenState)
local output = lstm:forward(input[seqlen+step])
local output2 = lstm2:forward(input[seqlen+step])
mytester:assertTensorEq(output, output2, 0.0000001)
end

-- test backward
lstm:zeroGradParameters()
lstm2:zeroGradParameters()
lstm:forget()

for step=1,seqlen do
lstm:forward(input[step])
local hs = lstm:getHiddenState(step)
local hs2 = lstm2:getHiddenState(step)
mytester:assertTensorEq(hs[1], hs2[1], 0.0000001)
mytester:assertTensorEq(hs[2], hs2[2], 0.0000001)
end

for step=seqlen*2,seqlen+1,-1 do
lstm2:backward(input[step], gradOutput[step])
end

lstm2:zeroGradParameters()

for step=seqlen,1,-1 do
local gradHiddenState = lstm2:getGradHiddenState(step)
mytester:assert(#gradHiddenState == 2)
lstm:setGradHiddenState(step, gradHiddenState)
local gradInput = lstm:backward(input[step], gradOutput[step])
local gradInput2 = lstm2:backward(input[step], gradOutput[step])
mytester:assertTensorEq(gradInput, gradInput2, 0.0000001)
end

local params, gradParams = lstm:parameters()
local params2, gradParams2 = lstm2:parameters()

for i=1,#params do
mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.00000001)
end

end

function rnn.test(tests, benchmark_)
mytester = torch.Tester()
benchmark = benchmark_
