To perform this operation, we first expand our `attention_mask` tensor so that it has the same shape as `embeddings`:

In [7]:
attention_mask = tokens['attention_mask']
attention_mask.shape

torch.Size([4, 128])

In [8]:
mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
mask.shape

torch.Size([4, 128, 768])
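A minimal sketch of what `unsqueeze(-1).expand(...)` does, using a toy batch. The sizes (2 sentences, 4 token positions, hidden size 3) and the random `toy_embeddings` are assumptions chosen for readability. They stand in for the real `[4, 128, 768]` tensor:

```python
import torch

# Toy attention mask: 2 sentences, 4 token positions; 1 = real token, 0 = padding
toy_attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
# Stand-in for the model's token embeddings (hidden size 3 instead of 768)
toy_embeddings = torch.randn(2, 4, 3)

# Add a trailing axis: [2, 4] -> [2, 4, 1]
unsqueezed = toy_attention_mask.unsqueeze(-1)
# Repeat that axis up to the hidden size: [2, 4, 1] -> [2, 4, 3], then cast to float
toy_expanded = unsqueezed.expand(toy_embeddings.size()).float()

print(unsqueezed.shape)     # torch.Size([2, 4, 1])
print(toy_expanded.shape)   # torch.Size([2, 4, 3])
print(toy_expanded[1])      # rows 0-1 are all ones, rows 2-3 are all zeros
```

Every token's single 0/1 mask value is copied across the hidden dimension, so the mask can be multiplied element-wise with the embeddings.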

In [9]:
mask

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

Each vector above represents the attention mask of a single token. Each token now has a vector of size 768 representing its *attention_mask* status: all ones for real tokens and all zeros for padding. Then we multiply the two tensors to apply the attention mask:

In [11]:
masked_embeddings = embeddings * mask
masked_embeddings.shape

torch.Size([4, 128, 768])

In [12]:
masked_embeddings

tensor([[[-0.0692,  0.6230,  0.0354,  ...,  0.8033,  1.6314,  0.3281],
         [ 0.0367,  0.6842,  0.1946,  ...,  0.0848,  1.4747, -0.3008],
         [-0.0121,  0.6543, -0.0727,  ..., -0.0326,  1.7717, -0.6812],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000]],

        [[-0.3212,  0.8251,  1.0554,  ..., -0.1855,  0.1517,  0.3937],
         [-0.7146,  1.0297,  1.1217,  ...,  0.0331,  0.2382, -0.1563],
         [-0.2352,  1.1353,  0.8594,  ..., -0.4310, -0.0272, -0.2968],
         ...,
         [-0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
         [-0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
         [-0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000]],

        [[-0.7576,  0.8399, -0.3792,  ...,  0.1271,  1.2514,  0.1365],
         [-0.6591,  0.7613, -0.4662,  ...,  0
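Here is a minimal sketch of the multiplication step using toy tensors. The shapes and values are assumptions made for illustration. It checks that padded positions become zero vectors. It also shows that PyTorch broadcasting with a `[batch, seq, 1]` mask gives the same result as the explicitly expanded mask:

```python
import torch

# Toy inputs: 2 sentences, 4 positions, hidden size 3 (illustrative only)
toy_attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
toy_embeddings = torch.randn(2, 4, 3)

# Explicitly expanded mask, as in the walkthrough
toy_mask = toy_attention_mask.unsqueeze(-1).expand(toy_embeddings.size()).float()
toy_masked = toy_embeddings * toy_mask

# Broadcasting alternative: a [2, 4, 1] mask is stretched across the hidden axis
toy_broadcast_masked = toy_embeddings * toy_attention_mask.unsqueeze(-1).float()

# Padding positions are now zero vectors (negative inputs print as -0.)
print(toy_masked[0, 3])
print(torch.equal(toy_masked, toy_broadcast_masked))  # True
```

The explicit `expand` makes the shapes obvious when reading the code. Broadcasting skips materialising the full `[4, 128, 768]` mask.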

Then we sum the remaining embeddings along axis `1`, the token axis:

In [13]:
summed = torch.sum(masked_embeddings, 1)
summed.shape

torch.Size([4, 768])
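A toy sketch of the sum shows that summing the masked embeddings over axis `1` gives the same result as summing only each sentence's real tokens. The sizes and random values are assumptions for illustration:

```python
import torch

torch.manual_seed(0)
# Toy inputs: 2 sentences with 3 and 2 real tokens, hidden size 3
toy_attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
toy_embeddings = torch.randn(2, 4, 3)
toy_mask = toy_attention_mask.unsqueeze(-1).expand(toy_embeddings.size()).float()

# Sum over axis 1 (the token axis): [2, 4, 3] -> [2, 3]
toy_summed = torch.sum(toy_embeddings * toy_mask, 1)

# Same quantity computed by slicing out only the real tokens of each sentence
manual = torch.stack([
    toy_embeddings[0, :3].sum(0),  # sentence 0 has 3 real tokens
    toy_embeddings[1, :2].sum(0),  # sentence 1 has 2 real tokens
])
print(toy_summed.shape)                               # torch.Size([2, 3])
print(torch.allclose(toy_summed, manual, atol=1e-6))  # True
```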

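Next, we sum the mask along the same axis. This counts how many values must be given attention in each position of the tensor. We clamp the result to a minimum of `1e-9` so that the division in the final step can never be a division by zero: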

In [14]:
summed_mask = torch.clamp(mask.sum(1), min=1e-9)
summed_mask.shape

torch.Size([4, 768])

In [15]:
summed_mask

tensor([[15., 15., 15.,  ..., 15., 15., 15.],
        [22., 22., 22.,  ..., 22., 22., 22.],
        [15., 15., 15.,  ..., 15., 15., 15.],
        [14., 14., 14.,  ..., 14., 14., 14.]])
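Each row of `summed_mask` is the number of real tokens in that sentence, repeated across all 768 columns. In the output above, those counts are 15, 22, 15 and 14. The sketch below checks this on a toy mask. It also shows what the clamp is for. The all-padding third row is a hypothetical edge case added only to illustrate it:

```python
import torch

# Toy mask; the third, all-padding row is a hypothetical edge case
toy_attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0],
                                   [0, 0, 0, 0]])
hidden = 3  # stand-in for 768
toy_mask = toy_attention_mask.unsqueeze(-1).expand(3, 4, hidden).float()

# Real-token count per sentence, repeated across the hidden axis: [3, 4, 3] -> [3, 3]
unclamped = toy_mask.sum(1)
toy_summed_mask = torch.clamp(unclamped, min=1e-9)
print(toy_summed_mask[:, 0])  # values 3, 2 and ~1e-9

# Each column of the count equals attention_mask summed over tokens
print(torch.equal(unclamped[:, 0], toy_attention_mask.sum(1).float()))  # True

# For the empty row, 0 / 0 would be nan; the clamp turns it into 0 / 1e-9 = 0
zero_sum = torch.zeros(hidden)
print(torch.isnan(zero_sum / unclamped[2]).all())  # tensor(True)
print(zero_sum / toy_summed_mask[2])               # all zeros
```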

Finally, we calculate the mean. It is the sum of the embedding activations, `summed`, divided by the number of values that should be given attention in each position, `summed_mask`:

In [16]:
mean_pooled = summed / summed_mask
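The computation can also be written as a formula; the symbols are introduced here for illustration. Take sentence $b$ and embedding dimension $d$. Let $m_{b,t}$ be the attention mask value of token $t$, let $e_{b,t,d}$ be its embedding activation, and let $T = 128$ be the padded sequence length. Each entry of `mean_pooled` is then

$$
\bar{e}_{b,d} = \frac{\sum_{t=1}^{T} m_{b,t}\, e_{b,t,d}}{\max\!\left(\sum_{t=1}^{T} m_{b,t},\ 10^{-9}\right)}
$$

The numerator is `summed`, and the clamped denominator is `summed_mask`.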

In [17]:
mean_pooled

tensor([[ 0.0745,  0.8637,  0.1795,  ...,  0.7734,  1.7247, -0.1803],
        [-0.3715,  0.9729,  1.0840,  ..., -0.2552, -0.2759,  0.0358],
        [-0.5030,  0.7950, -0.1240,  ...,  0.1441,  0.9704, -0.1791],
        [-0.2131,  1.0175, -0.8833,  ...,  0.7371,  0.1947, -0.3011]],
       grad_fn=<DivBackward0>)
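The individual steps can be collected into one function. This is a sketch, not the notebook's own code: the name `mean_pool` is hypothetical, and the toy tensors are made up for the check. It compares the result with a plain `.mean(0)` over each sentence's real tokens:

```python
import torch


def mean_pool(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool token embeddings [batch, seq, hidden] into sentence vectors
    [batch, hidden], ignoring padding positions where attention_mask == 0."""
    # [batch, seq] -> [batch, seq, hidden] float mask
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    # Zero out padding embeddings and sum over the token axis
    summed = torch.sum(embeddings * mask, 1)
    # Number of real tokens per sentence, clamped to avoid division by zero
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    return summed / summed_mask


torch.manual_seed(0)
# Toy check: 2 sentences with 3 and 2 real tokens, hidden size 3
toy_attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
toy_embeddings = torch.randn(2, 4, 3)
pooled = mean_pool(toy_embeddings, toy_attention_mask)

# Plain mean over each sentence's real tokens, for comparison
expected = torch.stack([toy_embeddings[0, :3].mean(0),
                        toy_embeddings[1, :2].mean(0)])
print(pooled.shape)                                # torch.Size([2, 3])
print(torch.allclose(pooled, expected, atol=1e-6))  # True
```

Called on the notebook's own tensors, `mean_pool(embeddings, tokens['attention_mask'])` would follow the same steps as cells 8 to 16 above.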