Skip to content

Commit

Permalink
Merge pull request #554 from grammarly/master
Browse files Browse the repository at this point in the history
Replace masked code in all recurrent layers by T.switch command
  • Loading branch information
f0k committed Dec 31, 2015
2 parents 5520c06 + f040499 commit 3dc3ccd
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions lasagne/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def step_masked(input_n, mask_n, hid_previous, *args):
# Skip over any input with mask 0 by copying the previous
# hidden state; proceed normally for any input with mask 1.
hid = step(input_n, hid_previous, *args)
hid_out = hid*mask_n + hid_previous*(1 - mask_n)
hid_out = T.switch(mask_n, hid, hid_previous)
return [hid_out]

if mask is not None:
Expand Down Expand Up @@ -1038,9 +1038,8 @@ def step_masked(input_n, mask_n, cell_previous, hid_previous, *args):

# Skip over any input with mask 0 by copying the previous
# hidden state; proceed normally for any input with mask 1.
not_mask = 1 - mask_n
cell = cell*mask_n + cell_previous*not_mask
hid = hid*mask_n + hid_previous*not_mask
cell = T.switch(mask_n, cell, cell_previous)
hid = T.switch(mask_n, hid, hid_previous)

return [cell, hid]

Expand Down Expand Up @@ -1417,8 +1416,7 @@ def step_masked(input_n, mask_n, hid_previous, *args):

# Skip over any input with mask 0 by copying the previous
# hidden state; proceed normally for any input with mask 1.
not_mask = 1 - mask_n
hid = hid*mask_n + hid_previous*not_mask
hid = T.switch(mask_n, hid, hid_previous)

return hid

Expand Down

0 comments on commit 3dc3ccd

Please sign in to comment.