In [1]:
import math
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
# Goal: positions at or beyond a per-row maximum index (the valid length) must end up with
# zero weight after the softmax. Much like bounding a Type I / Type II region in calculus,
# we are delimiting the range of values that are actually useful.
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X`` that ignores positions past a valid length.

    Positions whose index along the last axis is >= the corresponding valid
    length are replaced by a large negative score (-1e6) before the softmax,
    so their resulting weight is effectively zero.

    The input ``X`` is never modified: masking is done out of place, which also
    keeps the function usable on leaf tensors that require gradients.

    Args:
        X: Score tensor. When ``valid_lens`` is given, ``X`` is flattened to
            rows of length ``X.shape[-1]`` for masking and then restored to
            its original shape.
        valid_lens: One of
            * ``None`` -- plain softmax, nothing is masked;
            * a 1-D tensor -- each length is repeated ``X.shape[1]`` times, so
              it is shared by every row of the same leading entry of ``X``
              (this matches a 3-D ``X`` with one length per ``X.shape[0]``);
            * any other tensor -- flattened and used as one length per row of
              the flattened ``X``.

    Returns:
        A tensor with the same shape as ``X`` whose last axis sums to 1.
    """
    def _sequence_mask(X, valid_len, value=0):
        """Return a copy of 2-D ``X`` with columns >= ``valid_len[row]`` set to ``value``."""
        maxlen = X.size(1)
        # Broadcast (1, maxlen) column indices against (rows, 1) lengths.
        mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
        # masked_fill is out of place: an indexed assignment here would write
        # through the reshape view into the caller's tensor.
        return X.masked_fill(~mask, value)

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)

    # A large negative score (rather than 0) is used so that exp() of the
    # masked entries underflows to 0 inside the softmax.
    masked = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)

In [6]:
# 1-D valid_lens: one length per leading entry of X (2 for the first, 4 for the second),
# shared by both rows inside that entry.
masked_softmax(X=torch.rand(2, 2, 5), valid_lens=torch.tensor([2, 4]))

tensor([[[0.4343, 0.5657, 0.0000, 0.0000, 0.0000],
         [0.4381, 0.5619, 0.0000, 0.0000, 0.0000]],

        [[0.1422, 0.2521, 0.2758, 0.3299, 0.0000],
         [0.2163, 0.2127, 0.2418, 0.3292, 0.0000]]])

In [9]:
# 2-D valid_lens: an individual length for every row (1 and 5, then 2 and 4).
# A length of 5 equals the row width, so that row is left unmasked.
masked_softmax(X=torch.rand(2, 2, 5), valid_lens=torch.tensor([[1, 5], [2, 4]]))

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2339, 0.1292, 0.3221, 0.1390, 0.1758]],

        [[0.5080, 0.4920, 0.0000, 0.0000, 0.0000],
         [0.1828, 0.1550, 0.2761, 0.3861, 0.0000]]])

In [11]:
# 2-D valid_lens again; rows with length 1 keep only their first position,
# which therefore receives all of the weight.
masked_softmax(X=torch.rand(2, 2, 5), valid_lens=torch.tensor([[1, 1], [2, 4]]))

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.5090, 0.4910, 0.0000, 0.0000, 0.0000],
         [0.3383, 0.3431, 0.1511, 0.1676, 0.0000]]])