This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Use masked_fill_ in replace_masked_values #1039

Closed
matt-gardner opened this issue Mar 27, 2018 · 0 comments · Fixed by #1651

Comments

matt-gardner (Contributor) commented Mar 27, 2018

I wrote this replace_masked_values function before I realized that masked_fill_ exists. I'm pretty sure we could just use masked_fill_ instead of what's currently here:

```python
from torch.autograd import Variable

from allennlp.common.checks import ConfigurationError


def replace_masked_values(tensor: Variable, mask: Variable, replace_with: float) -> Variable:
    """
    Replaces all masked values in ``tensor`` with ``replace_with``. ``mask`` must be broadcastable
    to the same shape as ``tensor``. We require that ``tensor.dim() == mask.dim()``, as otherwise we
    won't know which dimensions of the mask to unsqueeze.
    """
    # ``mask`` is 1 for positions to keep and 0 for positions to replace.
    # We'll build a tensor of the same shape as `tensor`, zero out masked values, then add back in
    # the `replace_with` value.
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    one_minus_mask = 1.0 - mask
    values_to_add = replace_with * one_minus_mask
    return tensor * mask + values_to_add
```
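
A minimal sketch of the suggested change, assuming current PyTorch (plain tensors instead of `Variable`, boolean masks). One detail matters: `masked_fill` writes the value wherever *its* mask is True, while `replace_masked_values` uses 1 for "keep", so the mask has to be inverted. `ValueError` stands in for AllenNLP's `ConfigurationError`, and `replace_masked_values_arithmetic` is a hypothetical name for the current formulation, kept only for comparison.

```python
import torch


def replace_masked_values(tensor: torch.Tensor, mask: torch.Tensor, replace_with: float) -> torch.Tensor:
    # Same contract as above: mask is 1 where values are kept and 0 where they are replaced,
    # and it must have as many dimensions as `tensor` (broadcasting handles the rest).
    if tensor.dim() != mask.dim():
        # ValueError stands in for AllenNLP's ConfigurationError to keep the sketch self-contained.
        raise ValueError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    # masked_fill writes `replace_with` where its own mask is True, so the keep-mask is inverted.
    # The out-of-place masked_fill is used so the caller's tensor is left unmodified.
    return tensor.masked_fill(~mask.bool(), replace_with)


def replace_masked_values_arithmetic(tensor: torch.Tensor, mask: torch.Tensor, replace_with: float) -> torch.Tensor:
    # Hypothetical name for the current multiply-and-add formulation, for comparison only.
    return tensor * mask + replace_with * (1.0 - mask)


tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

print(replace_masked_values(tensor, mask, -1.0))
print(replace_masked_values_arithmetic(tensor, mask, -1.0))
```

For finite `replace_with` the two agree. They differ when `replace_with` is infinite (for example `float("-inf")` before a softmax): the arithmetic version computes `inf * 0 = nan` at every kept position, while `masked_fill` only touches the positions being replaced.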

I'm not certain the `masked_fill_` version works correctly with gradients, but it should...
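
A quick gradient check, again a sketch assuming current PyTorch and hypothetical example values. The out-of-place `masked_fill` is an ordinary autograd op, and its gradient with respect to the input matches the arithmetic version: both equal the keep-mask (1 where the value survives, 0 where it was replaced). The in-place `masked_fill_` is the riskier choice: autograd rejects it on a leaf tensor that requires grad, and on intermediate results it can fail at backward time if the overwritten values were saved for a gradient computation.

```python
import torch

# Hypothetical example values; mask uses 1 = keep, 0 = replace, as in replace_masked_values.
tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
fill_positions = ~mask.bool()  # True where masked_fill should write the replacement value

# Out-of-place masked_fill is tracked by autograd like any other op.
tensor.masked_fill(fill_positions, -1.0).sum().backward()
masked_fill_grad = tensor.grad.clone()

# The multiply-and-add formulation, for comparison.
tensor.grad = None
(tensor * mask + -1.0 * (1.0 - mask)).sum().backward()
arithmetic_grad = tensor.grad.clone()

# Both gradients equal the keep-mask: 1 where the input survives, 0 where it was replaced.
print(masked_fill_grad)
print(arithmetic_grad)

# The in-place variant is rejected on a leaf tensor that requires grad.
try:
    tensor.masked_fill_(fill_positions, -1.0)
    in_place_rejected = False
except RuntimeError:
    in_place_rejected = True
print(in_place_rejected)
```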
