<a href="https://colab.research.google.com/github/DavoodSZ1993/Dive_into_Deep_Learning/blob/main/11_3_attention_scoring_functions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install d2l==1.0.0-alpha1.post0 --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.0/93.0 KB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.0/121.0 KB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.6/83.6 KB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[?25h

## 11.3 Attention Scoring Functions

In [2]:
import torch
import math
from torch import nn
from d2l import torch as d2l



### 11.3.2 Convenience Functions

#### Masked Softmax Operation

In [5]:
def masked_softmax(X, valid_lens):
  """Perform softmax operation by masking elements on the last axis.

  Args:
    X: 3D tensor of scores. Softmax is always taken over its last axis.
    valid_lens: None, a 1D tensor, or a 2D tensor.
      * None: no masking, plain softmax over the last axis.
      * 1D: one valid length per index of X's first axis; it is repeated
        X.shape[1] times so every row under that index shares it.
      * 2D: flattened, giving one valid length per (first-axis, second-axis)
        row of X.

  Returns:
    A new tensor with the same shape as X. In each row, positions at index
    >= the row's valid length receive (numerically) zero weight. The input
    tensor X is not modified.
  """
  # X: 3D tensor, valid_lens: 1D or 2D tensor
  def _sequence_mask(X, valid_len, value=0):
    """Return a copy of 2D X with columns >= valid_len[row] set to value."""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    # Out-of-place fill: the reshaped X can be a view of the caller's tensor,
    # so an in-place assignment would overwrite the caller's data (and fails
    # on leaf tensors that require grad).
    return X.masked_fill(~mask, value)

  if valid_lens is None:
    # Normalize over the last axis, consistent with the masked branch below.
    return nn.functional.softmax(X, dim=-1)

  shape = X.shape
  if valid_lens.dim() == 1:
    valid_lens = torch.repeat_interleave(valid_lens, shape[1])
  else:
    valid_lens = valid_lens.reshape(-1)

  # On the last axis, replace masked elements with a very large negative value,
  # whose exponentiation outputs 0
  X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
  return nn.functional.softmax(X.reshape(shape), dim=-1)

In [6]:
# 1D valid_lens: one valid length per example in the batch, shared by both of
# that example's rows (2 keys kept for the first example, 3 for the second).
masked_softmax(X=torch.rand(2, 2, 4), valid_lens=torch.tensor([2, 3]))

tensor([[[0.6211, 0.3789, 0.0000, 0.0000],
         [0.6359, 0.3641, 0.0000, 0.0000]],

        [[0.3451, 0.2957, 0.3592, 0.0000],
         [0.2842, 0.2535, 0.4623, 0.0000]]])

In [7]:
# 2D valid_lens: a separate valid length for every row of every example.
masked_softmax(
    X=torch.rand(2, 2, 4),
    valid_lens=torch.tensor([[1, 3], [2, 4]]),
)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3596, 0.3400, 0.3004, 0.0000]],

        [[0.4090, 0.5910, 0.0000, 0.0000],
         [0.2142, 0.3390, 0.2524, 0.1944]]])