In [None]:
import torch
from torch import nn

In [None]:
def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis.

    In each last-axis row, entries at positions >= that row's valid length are
    replaced with -1e6 before the softmax, so they receive (numerically) zero
    probability. The caller's tensor ``X`` is not modified.

    Args:
        X: 3D tensor; the softmax is taken over its last axis.
        valid_lens: one of
            - None: plain softmax over the last axis;
            - a 1D tensor of length ``X.shape[0]``: one valid length per
              ``X[i]``, shared by all of its ``X.shape[1]`` rows;
            - a 2D tensor of shape ``X.shape[:2]``: one valid length per row.

    Returns:
        Tensor with the same shape as ``X`` whose last-axis rows sum to 1.

    Note:
        A row with valid length 0 is filled entirely with -1e6, so its output
        is a uniform distribution, not all zeros.
    """
    def _sequence_mask(X, valid_len, value=0):
        # X: 2D (rows, maxlen); valid_len: 1D, one entry per row of X.
        maxlen = X.size(1)
        # mask[r, c] is True for the kept columns c < valid_len[r].
        mask = torch.arange(maxlen, dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        # Work on a copy. X may be a view of the caller's tensor (reshape of a
        # contiguous tensor), so writing into it directly would overwrite the
        # caller's data. It would also fail under autograd when that tensor is
        # a leaf that requires grad.
        X = X.clone()
        X[~mask] = value
        return X

    if valid_lens is None:
        # Same axis as the masked path below: the last one.
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per X[i]; repeat it once for each of its shape[1] rows.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # Already one length per row; flatten to match the 2D view below.
        valid_lens = valid_lens.reshape(-1)
    # On the last axis, replace masked elements with a very large negative
    # value, whose exponentiation outputs 0
    X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)

In [None]:
# Seed the RNG so the random scores (and the displayed output) are reproducible
# under Restart & Run All.
torch.manual_seed(0)
scores = masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
# X[0] keeps its first 2 columns and X[1] its first 3; the rest get zero weight.
assert torch.allclose(scores.sum(dim=-1), torch.ones(2, 2))
assert torch.all(scores[0, :, 2:] == 0) and torch.all(scores[1, :, 3:] == 0)
scores, scores.shape

In [None]:
# NOTE(review): leftover scratch cell that only echoed the bare literal
# (3, 4, 512). It is presumably a tensor shape the author meant to try with
# masked_softmax (3D input); confirm the intent, or delete the cell.
EXAMPLE_SHAPE = (3, 4, 512)
EXAMPLE_SHAPE