Skip to content

Commit

Permalink
Replace masked code in all recurrent layers by T.switch command
Browse files Browse the repository at this point in the history
  • Loading branch information
avostryakov committed Dec 24, 2015
1 parent c5ea6a4 commit f040499
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions lasagne/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def step_masked(input_n, mask_n, hid_previous, *args):
# Skip over any input with mask 0 by copying the previous
# hidden state; proceed normally for any input with mask 1.
hid = step(input_n, hid_previous, *args)
hid_out = hid*mask_n + hid_previous*(1 - mask_n)
hid_out = T.switch(mask_n, hid, hid_previous)
return [hid_out]

if mask is not None:
Expand Down Expand Up @@ -1074,9 +1074,8 @@ def step_masked(input_n, mask_n, cell_previous, hid_previous, *args):

# Skip over any input with mask 0 by copying the previous
# hidden state; proceed normally for any input with mask 1.
not_mask = 1 - mask_n
cell = cell*mask_n + cell_previous*not_mask
hid = hid*mask_n + hid_previous*not_mask
cell = T.switch(mask_n, cell, cell_previous)
hid = T.switch(mask_n, hid, hid_previous)

return [cell, hid]

Expand Down Expand Up @@ -1471,8 +1470,7 @@ def step_masked(input_n, mask_n, hid_previous, *args):

# Skip over any input with mask 0 by copying the previous
# hidden state; proceed normally for any input with mask 1.
not_mask = 1 - mask_n
hid = hid*mask_n + hid_previous*not_mask
hid = T.switch(mask_n, hid, hid_previous)

return hid

Expand Down

0 comments on commit f040499

Please sign in to comment.