In [1]:
# INITIALIZING LIBRARIES:
import math
import torch
from torch import nn 
from d2l import torch as d2l

In [3]:
# DEFINING MASKED SOFTMAX FUNCTION:
def masked_softmax(X, valid_lens):
    """Softmax over the last axis of ``X`` that ignores positions past each valid length.

    Along the last axis, entries whose index is ``>=`` the corresponding valid
    length are replaced by ``-1e6`` before the softmax. They therefore receive
    (numerically) zero probability.

    Args:
        X: Score tensor of shape ``(batch, ..., L)`` with at least two
            dimensions. The softmax is taken over the last axis. ``X`` itself
            is not modified.
        valid_lens: One of:
            - ``None``: plain softmax.
            - a 1-D tensor of shape ``(batch,)``: one length per batch element,
              shared by every row of that element.
            - a tensor with ``X.shape[:-1]`` elements: one length per row.

    Returns:
        A tensor with the same shape as ``X`` whose last axis sums to 1.
    """
    def _sequence_mask(X, valid_len, value=0):
        # X is (rows, L) and valid_len is (rows,). Position j of row i is kept
        # iff j < valid_len[i]; every other position is replaced with `value`.
        maxlen = X.size(1)
        mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
        # masked_fill returns a new tensor. Writing into X in place would reach
        # the caller's tensor through the reshape view, and would raise under
        # autograd when the input is a leaf that requires grad.
        return X.masked_fill(~mask, value)

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    # Keep the lengths on the same device as the arange built from X.
    valid_lens = valid_lens.to(X.device)
    if valid_lens.dim() == 1:
        # One length per batch element: repeat it once per row of that element
        # (shape[1] rows for a 3-D X, a single row for a 2-D X).
        valid_lens = torch.repeat_interleave(valid_lens, math.prod(shape[1:-1]))
    else:
        valid_lens = valid_lens.reshape(-1)
    X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)


# INSPECTION: 1-D valid_lens gives one length per batch element, shared by all of its rows.
torch.manual_seed(0)  # fixed seed so the displayed probabilities are reproducible on re-run
probs_per_batch = masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
assert torch.allclose(probs_per_batch.sum(dim=-1), torch.ones(2, 2)), "each row must sum to 1"
probs_per_batch

tensor([[[0.4727, 0.5273, 0.0000, 0.0000],
         [0.4991, 0.5009, 0.0000, 0.0000]],

        [[0.2370, 0.3487, 0.4144, 0.0000],
         [0.3451, 0.2260, 0.4289, 0.0000]]])

In [5]:
# INSPECTION: 2-D valid_lens gives a separate length for every row.
torch.manual_seed(0)  # fixed seed so the displayed probabilities are reproducible on re-run
probs_per_row = masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))
assert torch.allclose(probs_per_row.sum(dim=-1), torch.ones(2, 2)), "each row must sum to 1"
probs_per_row

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2537, 0.4582, 0.2881, 0.0000]],

        [[0.5590, 0.4410, 0.0000, 0.0000],
         [0.1750, 0.3459, 0.1472, 0.3318]]])