LSTM:clearState and trailing spaces
Nicholas Leonard committed Dec 13, 2016
1 parent 6aee4ee commit aee194a
Showing 3 changed files with 55 additions and 31 deletions.
11 changes: 11 additions & 0 deletions AbstractSequencer.lua
@@ -27,3 +27,14 @@ function AbstractSequencer:remember(remember)
return self
end

function AbstractSequencer:hasMemory()
local _ = require 'moses'
if (self.train ~= false) and _.contains({'both','train'}, self._remember) then -- train (defaults to nil...)
return true
elseif (self.train == false) and _.contains({'both','eval'}, self._remember) then -- evaluate
return true
else
return false
end
end
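
The hasMemory method added above reports whether the sequencer will carry its hidden state over to the next call, given the current train/eval mode and the remember() setting. A minimal usage sketch, assuming the standard rnn package API (Sequencer derives from AbstractSequencer); this is illustration only, not part of the commit:

require 'rnn'

local lstm = nn.LSTM(10, 10)
local seq = nn.Sequencer(lstm)

seq:remember('train')     -- keep state between calls only while training
print(seq:hasMemory())    -- true : self.train defaults to nil, which counts as training

seq:evaluate()            -- sets self.train = false
print(seq:hasMemory())    -- false : 'train' mode forgets during evaluation

seq:remember('both')
print(seq:hasMemory())    -- true : remembers in both training and evaluation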

57 changes: 35 additions & 22 deletions LSTM.lua
@@ -20,11 +20,11 @@ function LSTM:__init(inputSize, outputSize, rho, cell2gate)
self.recurrentModule = self:buildModel()
-- make it work with nn.Container
self.modules[1] = self.recurrentModule
self.sharedClones[1] = self.recurrentModule
self.sharedClones[1] = self.recurrentModule

-- for output(0), cell(0) and gradCell(T)
self.zeroTensor = torch.Tensor()
self.zeroTensor = torch.Tensor()

self.cells = {}
self.gradCells = {}
end
@@ -39,7 +39,7 @@ function LSTM:buildGate()
local input2gate = nn.Linear(self.inputSize, self.outputSize)
local output2gate = nn.LinearNoBias(self.outputSize, self.outputSize)
local para = nn.ParallelTable()
para:add(input2gate):add(output2gate)
para:add(input2gate):add(output2gate)
if self.cell2gate then
para:add(nn.CMul(self.outputSize)) -- diagonal cell to gate weight matrix
end
@@ -76,7 +76,7 @@ end

function LSTM:buildCell()
-- build
self.inputGate = self:buildInputGate()
self.inputGate = self:buildInputGate()
self.forgetGate = self:buildForgetGate()
self.hiddenLayer = self:buildHidden()
-- forget = forgetGate{input, output(t-1), cell(t-1)} * cell(t-1)
@@ -99,16 +99,16 @@ function LSTM:buildCell()
cell:add(nn.CAddTable())
self.cellLayer = cell
return cell
end
end

function LSTM:buildOutputGate()
self.outputGate = self:buildGate()
return self.outputGate
end

-- cell(t) = cellLayer{input, output(t-1), cell(t-1)}
-- output(t) = outputGate{input, output(t-1), cell(t)}*tanh(cell(t))
-- output of Model is table : {output(t), cell(t)}
-- output of Model is table : {output(t), cell(t)}
function LSTM:buildModel()
-- build components
self.cellLayer = self:buildCell()
@@ -118,7 +118,7 @@ function LSTM:buildModel()
concat:add(nn.NarrowTable(1,2)):add(self.cellLayer)
local model = nn.Sequential()
model:add(concat)
-- output of concat is {{input, output}, cell(t)},
-- output of concat is {{input, output}, cell(t)},
-- so flatten to {input, output, cell(t)}
model:add(nn.FlattenTable())
local cellAct = nn.Sequential()
@@ -152,7 +152,7 @@ function LSTM:updateOutput(input)
prevOutput = self.outputs[self.step-1]
prevCell = self.cells[self.step-1]
end

-- output(t), cell(t) = lstm{input(t), output(t-1), cell(t-1)}
local output, cell
if self.train ~= false then
@@ -163,13 +163,13 @@
else
output, cell = unpack(self.recurrentModule:updateOutput{input, prevOutput, prevCell})
end

self.outputs[self.step] = output
self.cells[self.step] = cell

self.output = output
self.cell = cell

self.step = self.step + 1
self.gradPrevOutput = nil
self.updateGradInputStep = nil
@@ -182,40 +182,40 @@ function LSTM:_updateGradInput(input, gradOutput)
assert(self.step > 1, "expecting at least one updateOutput")
local step = self.updateGradInputStep - 1
assert(step >= 1)

-- set the output/gradOutput states of current Module
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
if self.gradPrevOutput then
self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], self.gradPrevOutput)
nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
gradOutput = self._gradOutputs[step]
end

local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
local cell = (step == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[step-1]
local inputTable = {input, output, cell}
local gradCell = (step == self.step-1) and (self.userNextGradCell or self.zeroTensor) or self.gradCells[step]

local gradInputTable = recurrentModule:updateGradInput(inputTable, {gradOutput, gradCell})

local gradInput
gradInput, self.gradPrevOutput, gradCell = unpack(gradInputTable)
self.gradCells[step-1] = gradCell
if self.userPrevOutput then self.userGradPrevOutput = self.gradPrevOutput end
if self.userPrevCell then self.userGradPrevCell = gradCell end

return gradInput
end

function LSTM:_accGradParameters(input, gradOutput, scale)
local step = self.accGradParametersStep - 1
assert(step >= 1)

-- set the output/gradOutput states of current Module
local recurrentModule = self:getStepModule(step)

-- backward propagate through this step
local output = (step == 1) and (self.userPrevOutput or self.zeroTensor) or self.outputs[step-1]
local cell = (step == 1) and (self.userPrevCell or self.zeroTensor) or self.cells[step-1]
@@ -226,3 +226,16 @@ function LSTM:_accGradParameters(input, gradOutput, scale)
recurrentModule:accGradParameters(inputTable, gradOutputTable, scale)
end

function LSTM:clearState()
self.zeroTensor:set()
return parent.clearState(self)
end

function LSTM:type(type, ...)
if type then
self:forget()
self:clearState()
self.zeroTensor = self.zeroTensor:type(type)
end
return parent.type(self, type, ...)
end
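
The clearState and type overrides added above make sure the zeroTensor buffer is released and converted along with the rest of the module. A minimal sketch of how they would typically be exercised, assuming standard rnn/torch usage; illustration only, not part of the commit:

require 'rnn'

local lstm = nn.LSTM(4, 5)

-- run a few steps so internal buffers (outputs, cells, step clones) accumulate
for t = 1, 3 do
   lstm:forward(torch.randn(2, 4))
end

-- clearState empties the buffers (zeroTensor is released via zeroTensor:set()),
-- which keeps the serialized file small
lstm:clearState()
torch.save('lstm.t7', lstm)

-- type conversion first forgets the stored time-steps and clears state,
-- then converts zeroTensor together with the parameters
lstm:float()   -- same as lstm:type('torch.FloatTensor')
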
18 changes: 9 additions & 9 deletions MaskZero.lua
@@ -12,7 +12,7 @@ function MaskZero:__init(module, nInputDim, silent)
print("Warning : you are most likely using MaskZero the wrong way. "
.."You should probably use AbstractRecurrent:maskZero() so that "
.."it wraps the internal AbstractRecurrent.recurrentModule instead of "
.."wrapping the AbstractRecurrent module itself.")
.."wrapping the AbstractRecurrent module itself.")
end
assert(torch.type(nInputDim) == 'number', 'Expecting nInputDim number at arg 1')
self.nInputDim = nInputDim
@@ -36,7 +36,7 @@ function MaskZero:recursiveMask(output, input, mask)
else
assert(torch.isTensor(input))
output = torch.isTensor(output) and output or input.new()

-- make sure mask has the same dimension as the input tensor
local inputSize = input:size():fill(1)
if self.batchmode then
@@ -51,7 +51,7 @@ function MaskZero:recursiveMask(output, input, mask)
return output
end

function MaskZero:updateOutput(input)
function MaskZero:updateOutput(input)
-- recurrent module input is always the first one
local rmi = self:recursiveGetFirst(input):contiguous()
if rmi:dim() == self.nInputDim then
@@ -63,9 +63,9 @@ function MaskZero:updateOutput(input)
else
error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
end

-- build mask
local vectorDim = rmi:dim()
local vectorDim = rmi:dim()
self._zeroMask = self._zeroMask or rmi.new()
self._zeroMask:norm(rmi, 2, vectorDim)
self.zeroMask = self.zeroMask or (
@@ -74,9 +74,9 @@
or torch.ByteTensor()
)
self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)

-- forward through decorated module
local output = self.module:updateOutput(input)
local output = self.modules[1]:updateOutput(input)

self.output = self:recursiveMask(self.output, output, self.zeroMask)
return self.output
@@ -85,8 +85,8 @@ end
function MaskZero:updateGradInput(input, gradOutput)
-- zero gradOutputs before backpropagating through decorated module
self.gradOutput = self:recursiveMask(self.gradOutput, gradOutput, self.zeroMask)
self.gradInput = self.module:updateGradInput(input, self.gradOutput)

self.gradInput = self.modules[1]:updateGradInput(input, self.gradOutput)
return self.gradInput
end
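
The MaskZero changes above switch to accessing the decorated module through self.modules[1] (the nn.Container convention) rather than self.module. Per the constructor warning, zero-masking is normally set up through AbstractRecurrent:maskZero(), which decorates the internal recurrentModule. A brief sketch of that recommended usage, assuming the standard rnn API; illustration only, not part of the commit:

require 'rnn'

local lstm = nn.LSTM(4, 5)
lstm:maskZero(1)            -- nInputDim = 1 : each sample is a 1D vector of size 4

local seq = nn.Sequencer(lstm)

-- a zero input at a time-step zeroes that step's output and blocks its gradient,
-- which is how variable-length sequences are padded
local inputs = {torch.randn(2, 4), torch.zeros(2, 4), torch.randn(2, 4)}
local outputs = seq:forward(inputs)
print(outputs[2]:sum())     -- 0 : the output at the fully masked step is zeroed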

