diff --git a/lasagne/layers/recurrent.py b/lasagne/layers/recurrent.py index ee1e49c5..15cf7b93 100644 --- a/lasagne/layers/recurrent.py +++ b/lasagne/layers/recurrent.py @@ -429,7 +429,7 @@ def step_masked(input_n, mask_n, hid_previous, *args): # Skip over any input with mask 0 by copying the previous # hidden state; proceed normally for any input with mask 1. hid = step(input_n, hid_previous, *args) - hid_out = hid*mask_n + hid_previous*(1 - mask_n) + hid_out = T.switch(mask_n, hid, hid_previous) return [hid_out] if mask is not None: @@ -1038,9 +1038,8 @@ def step_masked(input_n, mask_n, cell_previous, hid_previous, *args): # Skip over any input with mask 0 by copying the previous # hidden state; proceed normally for any input with mask 1. - not_mask = 1 - mask_n - cell = cell*mask_n + cell_previous*not_mask - hid = hid*mask_n + hid_previous*not_mask + cell = T.switch(mask_n, cell, cell_previous) + hid = T.switch(mask_n, hid, hid_previous) return [cell, hid] @@ -1417,8 +1416,7 @@ def step_masked(input_n, mask_n, hid_previous, *args): # Skip over any input with mask 0 by copying the previous # hidden state; proceed normally for any input with mask 1. - not_mask = 1 - mask_n - hid = hid*mask_n + hid_previous*not_mask + hid = T.switch(mask_n, hid, hid_previous) return hid