This notebook is an example of how to use basic dot-product attention to aggregate a variable number of discrete variables into a single fixed-size vector. Dot-product attention is performed over the embedding vectors of the discrete variables.

For each embedding vector, $\mathbf{e}_i$, compute an unnormalized attention energy by taking the dot product with a model parameter vector $\mathbf{s}$:

$$
u_i = \mathbf{e}_i^\top \mathbf{s}
$$

Attention energies for all embedding vectors are computed at once using a matrix multiplication. Stack the $n$ embedding vectors as the rows of a matrix:

$$
\mathbf{E} = 
\begin{bmatrix}
-\mathbf{e}_1- \\
-\mathbf{e}_2- \\
\vdots \\
-\mathbf{e}_n-
\end{bmatrix}
$$

Then, compute a vector of attention energies $\mathbf{u}$:

$$
\mathbf{u} = \mathbf{Es}
$$

Normalize these into a probability distribution:

$$
\mathbf{n} = \text{Softmax}(\mathbf{u})
$$

Finally, compute the aggregated representation as a linear combination of the individual embedding vectors, weighted by the normalized attention weights $\mathbf{n}$:

$$
\mathbf{x} = \mathbf{E}^\top \mathbf{n}
$$

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pdb

In [2]:
# Toy data: 100 rows, each holding a random count of levels in 1..8
# (np.random.randint excludes the upper bound, 9).
# No seed is set in this cell, so the values differ from run to run.
df = pd.DataFrame({'num_levels': np.random.randint(1, 9, size=(100,))})
df.head()

Unnamed: 0,num_levels
0,4
1,3
2,1
3,2
4,8


In [3]:
# For every row, draw `num_levels` random level ids from 1..10 (upper bound 11
# is exclusive). The id 0 is never drawn by this call.
df['levels'] = df.num_levels.apply(lambda x: np.random.randint(1,11, size=(x,)))

In [4]:
# Preview the first rows, including the variable-length `levels` column.
df.head()

Unnamed: 0,num_levels,levels
0,4,"[8, 10, 4, 7]"
1,3,"[5, 7, 4]"
2,1,[8]
3,2,"[2, 4]"
4,8,"[1, 10, 2, 4, 4, 5, 9, 3]"


In [5]:
class MyDataset(Dataset):
    """Variable-length sequences of level ids, right-padded to a fixed length.

    Args:
        df: object exposing a `levels` attribute/column that iterates over
            1-D sequences of integer level ids (one sequence per sample).
        max_sz: length of every returned item. Shorter sequences are
            right-padded with 0; longer sequences keep only their first
            `max_sz` entries.

    Each item is a 1-D `torch.long` tensor of shape `(max_sz,)`.
    """

    def __init__(self, df, max_sz=10):
        # each X_i is 1d and has shape (var_len)
        # The dtype is pinned to long so that an empty (or float-typed) entry
        # cannot turn into a float tensor, which would be unusable as an
        # embedding index and would mix dtypes within a batch.
        self.X = [torch.tensor(x, dtype=torch.long) for x in df.levels]
        self.max_sz = max_sz

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, i):
        """Return sample `i` padded (or truncated) to exactly `max_sz` ids."""
        x = self.X[i]
        n = min(len(x), self.max_sz)
        # Start from an all-zero buffer (0 is the padding id) and copy in the
        # real ids; sequences longer than max_sz are cut to their first max_sz.
        out = x.new_zeros(self.max_sz)
        out[:n] = x[:n]
        return out

In [6]:
# Build the dataset with the default max_sz and look at its first item.
ds = MyDataset(df)
next(iter(ds))

tensor([ 8, 10,  4,  7,  0,  0,  0,  0,  0,  0])

In [7]:
# Batch 5 items at a time, in dataset order (no shuffling).
dl = DataLoader(ds, batch_size=5, shuffle=False)

In [8]:
# Fetch and display the first batch.
next(iter(dl))

tensor([[ 8, 10,  4,  7,  0,  0,  0,  0,  0,  0],
        [ 5,  7,  4,  0,  0,  0,  0,  0,  0,  0],
        [ 8,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 2,  4,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 10,  2,  4,  4,  5,  9,  3,  0,  0]])

In [9]:
class AttnBOWAggregator(nn.Module):
    """Aggregate a padded bag of discrete ids into one vector with attention.

    Every id is embedded, a learned linear scorer assigns each embedding an
    unnormalized attention energy, the energies are softmax-normalized over
    the non-padding positions, and the output is the resulting weighted sum
    of the embeddings.

    Args:
        n_emb: number of rows in the embedding table (ids must be < n_emb).
        emb_sz: embedding dimension, which is also the output dimension.
        pad_idx: id marking padding positions (default 0). Padding positions
            get zero attention weight, and the padding embedding is held at
            zero, so a row consisting only of padding aggregates to zeros.

    Shapes:
        input  x: [bs, max_levels] integer ids
        output  : [bs, emb_sz]
    """

    def __init__(self, n_emb, emb_sz, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        # padding_idx keeps the pad row at zero and excludes it from gradients.
        self.emb = nn.Embedding(n_emb, emb_sz, padding_idx=pad_idx)
        # Scores each embedding; the bias shifts all energies equally and so
        # does not affect the softmax result.
        self.attn_nrg = nn.Linear(emb_sz, 1)

    def forward(self, x):
        # x [bs, max_levels]
        mask = x == self.pad_idx  # [bs, max_levels]
        emb = self.emb(x)  # [bs, max_levels, emb_sz]

        # compute normalized attention energies
        nrgs = self.attn_nrg(emb).squeeze(2)  # [bs, max_levels]
        # Use the most negative finite value of the working dtype rather than a
        # fixed constant, which float16 cannot represent. A finite fill (not
        # -inf) keeps the softmax free of NaNs when a whole row is padding.
        nrgs = nrgs.masked_fill(mask, torch.finfo(nrgs.dtype).min)
        attn = F.softmax(nrgs, dim=1).unsqueeze(2)  # [bs, max_levels, 1]

        # weighted sum of embeddings: E^T n for every batch element
        out = torch.bmm(emb.transpose(1, 2), attn)  # [bs, emb_sz, 1]
        return out.squeeze(2)  # [bs, emb_sz]

In [10]:
# 11 embedding rows (ids 0..10) with embedding size 5.
m = AttnBOWAggregator(11, 5)

In [11]:
# Grab the first batch to run through the model.
xb = next(iter(dl))

In [12]:
# Forward pass on the batch: print the output shape, then display the tensor.
out = m(xb)
print(out.shape)
out

torch.Size([5, 5])


tensor([[ 0.5052,  0.6485,  0.3360,  0.2359,  0.3457],
        [ 0.6475,  0.5330, -0.4086,  0.9691,  0.7141],
        [ 0.7085, -0.2934,  0.4703, -0.5907, -1.6166],
        [ 1.1590, -0.3336, -0.0345, -0.2408,  0.4665],
        [ 0.3007,  0.4601,  0.2272, -0.0333,  0.0281]],
       grad_fn=<SqueezeBackward1>)