From f0404997266a8a9907204f910268924a3acf5e17 Mon Sep 17 00:00:00 2001 From: Magic Date: Thu, 24 Dec 2015 17:25:09 +0200 Subject: [PATCH] Replace masked code in all recurrent layers by T.switch command --- lasagne/layers/recurrent.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lasagne/layers/recurrent.py b/lasagne/layers/recurrent.py index e3d3f695..1c54ab2d 100644 --- a/lasagne/layers/recurrent.py +++ b/lasagne/layers/recurrent.py @@ -439,7 +439,7 @@ def step_masked(input_n, mask_n, hid_previous, *args): # Skip over any input with mask 0 by copying the previous # hidden state; proceed normally for any input with mask 1. hid = step(input_n, hid_previous, *args) - hid_out = hid*mask_n + hid_previous*(1 - mask_n) + hid_out = T.switch(mask_n, hid, hid_previous) return [hid_out] if mask is not None: @@ -1074,9 +1074,8 @@ def step_masked(input_n, mask_n, cell_previous, hid_previous, *args): # Skip over any input with mask 0 by copying the previous # hidden state; proceed normally for any input with mask 1. - not_mask = 1 - mask_n - cell = cell*mask_n + cell_previous*not_mask - hid = hid*mask_n + hid_previous*not_mask + cell = T.switch(mask_n, cell, cell_previous) + hid = T.switch(mask_n, hid, hid_previous) return [cell, hid] @@ -1471,8 +1470,7 @@ def step_masked(input_n, mask_n, hid_previous, *args): # Skip over any input with mask 0 by copying the previous # hidden state; proceed normally for any input with mask 1. - not_mask = 1 - mask_n - hid = hid*mask_n + hid_previous*not_mask + hid = T.switch(mask_n, hid, hid_previous) return hid