diff --git a/utils.py b/utils.py index c843d95d6..439520184 100644 --- a/utils.py +++ b/utils.py @@ -6,7 +6,7 @@ def get_mask_from_lengths(lengths): max_len = torch.max(lengths).item() ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) - mask = (ids < lengths.unsqueeze(1)).byte() + mask = (ids < lengths.unsqueeze(1)).bool() return mask