In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import math
import numpy as np

In [3]:
from modules import Conv1d, ConvTranspose1d, Embedding, Linear, GradMultiply
from modules import get_mask_from_lengths, SinusoidalEncoding, Conv1dGLU


In [7]:
# Print the docstring of torch.bmm, the batched matrix product that
# AttentionLayer.forward uses for both query-key scores and the weighted sum of values.
help(torch.bmm)

Help on built-in function bmm in module torch._C:

bmm(...)
    bmm(batch1, batch2, out=None) -> Tensor
    
    Performs a batch matrix-matrix product of matrices stored in :attr:`batch1`
    and :attr:`batch2`.
    
    :attr:`batch1` and :attr:`batch2` must be 3D Tensors each containing
    the same number of matrices.
    
    If :attr:`batch1` is a `b x n x m` Tensor, :attr:`batch2` is a `b x m x p`
    Tensor, :attr:`out` will be a `b x n x p` Tensor.
    
    .. note:: This function does not :ref:`broadcast <broadcasting-semantics>`.
              For broadcasting matrix products, see :func:`torch.matmul`.
    
    Args:
        batch1 (Tensor): First batch of matrices to be multiplied
        batch2 (Tensor): Second batch of matrices to be multiplied
        out (Tensor, optional): Output tensor
    
    Example::
    
        >>> batch1 = torch.randn(10, 3, 4)
        >>> batch2 = torch.randn(10, 4, 5)
        >>> res = torch.bmm(batch1, batch2)
        >>> res.size()
        

Scaled dot-product attention

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^\top}{\sqrt{d_k}})V$

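Notation: the source (key/value) sequence has length $N$ and the target (query) sequence has length $M$. The $1/\sqrt{d_k}$ scaling is omitted in the expansion below for brevity.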

$\mathrm{softmax_{over-source}}\left\{\left(\begin{array}{c}
\mathbf{q}_1\\
\hline\\
\vdots\\
\hline\\
\mathbf{q}_M\end{array}\right)
\left(\begin{array}{c|c|c}
\mathbf{k}_1 & \dots & \mathbf{k}_N
\end{array}\right)
\right\}
\left(\begin{array}{c}
\mathbf{v}_1\\
\hline\\
\vdots\\
\hline\\
\mathbf{v}_N
\end{array}\right)$

$ = \mathrm{softmax_{row-wise}}\left(\begin{array}{ccc}
\mathbf{q}_1 \cdot \mathbf{k}_1 & \dots & \mathbf{q}_1 \cdot \mathbf{k}_N\\
\hline\\
\vdots & \ddots & \vdots\\
\hline\\
\mathbf{q}_M \cdot \mathbf{k}_1 & \dots & \mathbf{q}_M \cdot \mathbf{k}_N
\end{array}\right)
\left(\begin{array}{c}
\mathbf{v}_1\\
\hline\\
\vdots\\
\hline\\
\mathbf{v}_N
\end{array}\right)$

$ = \left(\begin{array}{ccc}
\frac{\exp(\mathbf{q}_1 \cdot \mathbf{k}_1)}{\sum_n \exp(\mathbf{q}_1 \cdot \mathbf{k}_n)} & \dots & \frac{\exp(\mathbf{q}_1 \cdot \mathbf{k}_N)}{\sum_n \exp(\mathbf{q}_1 \cdot \mathbf{k}_n)}\\
\hline\\
\vdots & \ddots & \vdots\\
\hline\\
\frac{\exp(\mathbf{q}_M \cdot \mathbf{k}_1)}{\sum_n \exp(\mathbf{q}_M \cdot \mathbf{k}_n)} & \dots & \frac{\exp(\mathbf{q}_M \cdot \mathbf{k}_N)}{\sum_n \exp(\mathbf{q}_M \cdot \mathbf{k}_n)}
\end{array}\right)
\left(\begin{array}{c}
\mathbf{v}_1\\
\hline\\
\vdots\\
\hline\\
\mathbf{v}_N
\end{array}\right)$

$ = \left(\begin{array}{c}
\mathbf{a}_1\\
\hline\\
\vdots\\
\hline\\
\mathbf{a}_M
\end{array}\right)\mathbf{V}$

$ = \left(\begin{array}{c}
\mathbf{a}_1\mathbf{V}\\
\hline\\
\vdots\\
\hline\\
\mathbf{a}_M\mathbf{V}
\end{array}\right)$

$ = \left(\begin{array}{c}
\sum_{n=1}^N a_{1n}\mathbf{v}_n\\
\hline\\
\vdots\\
\hline\\
\sum_{n=1}^N a_{Mn}\mathbf{v}_n
\end{array}\right)$

where $\mathbf{a}_m = \left(
\frac{\exp(\mathbf{q}_m \cdot \mathbf{k}_1)}{\sum_n \exp(\mathbf{q}_m \cdot \mathbf{k}_n)} 
, \dots, 
\frac{\exp(\mathbf{q}_m \cdot \mathbf{k}_N)}{\sum_n \exp(\mathbf{q}_m \cdot \mathbf{k}_n)}
\right)$ is the $m$-th row of attention weights, i.e. a probability distribution over the $N$ source positions.

In [52]:
class AttentionLayer(nn.Module):
    """Dot-product attention of decoder queries over encoder keys/values.

    The query is projected from ``conv_channels`` to ``embed_dim``, scored
    against the (optionally projected) keys with a batched matrix product,
    normalised with a softmax over the source axis, and used to take a
    weighted sum of the (optionally projected) values. The result is
    projected back to ``conv_channels`` and combined with the input query
    through a scaled residual connection.

    Args:
        conv_channels: feature size of the incoming query.
        embed_dim: feature size of keys/values and of the attention space.
        dropout: dropout probability applied to the attention weights.
        window_ahead: number of source positions, counted from
            ``last_attended``, that stay visible ahead of it.
        window_backward: number of source positions that stay visible
            behind ``last_attended``.
        key_projection: if true, pass keys through a learned linear layer.
        value_projection: if true, pass values through a learned linear layer.
    """

    def __init__(self, conv_channels, embed_dim, dropout=0.1, window_ahead=3, window_backward=1, key_projection=True, value_projection=True):
        super(AttentionLayer, self).__init__()
        self.query_projection = Linear(conv_channels, embed_dim)
        if key_projection:
            self.key_projection = Linear(embed_dim, embed_dim)
            # According to the DeepVoice3 paper, initialize weights to the same values
            # TODO: Does this really work well? not sure..
            # > We initialize the fully-connected layer weights used to compute
            #  hidden attention vectors to the same values for the query projection
            #  and the key projection.
            # Only possible when both weight matrices have the same shape.
            if conv_channels == embed_dim:
                self.key_projection.weight.data = self.query_projection.weight.data.clone()
        else:
            self.key_projection = None
        if value_projection:
            self.value_projection = Linear(embed_dim, embed_dim)
        else:
            self.value_projection = None

        self.out_projection = Linear(embed_dim, conv_channels)
        self.dropout = dropout
        # NOTE: this attribute used to be misspelled ``wiindow_ahead``, which
        # left ``self.window_ahead`` undefined when forward() needed it.
        self.window_ahead = window_ahead
        self.window_backward = window_backward

    def forward(self, query, encoder_out, mask=None, last_attended=None):
        """Attend from ``query`` over ``encoder_out``.

        Args:
            query: (B, tgt_len, conv_channels)
            encoder_out: tuple ``(keys, values)`` with
                keys: (B, embed_dim, source_length) and
                values: (B, source_length, embed_dim).
            mask: optional mask reshaped to (B, 1, source_length); positions
                where it is set are excluded from attention.
                (assumes a byte/bool tensor -- TODO confirm with callers)
            last_attended: optional source index; when given, attention is
                restricted to
                ``[last_attended - window_backward, last_attended + window_ahead)``.

        Returns:
            Tuple ``(output, attn_scores)`` where output is
            (B, tgt_len, conv_channels) and attn_scores is the pre-dropout
            softmax, (B, tgt_len, source_length).
        """
        keys, values = encoder_out
        residual = query
        if self.value_projection is not None:
            # FC(values)
            # values: (B, source_length, embed_dim)
            values = self.value_projection(values)
        # TODO: yes, this is inefficient
        if self.key_projection is not None:
            # FC(keys): the linear layer acts on the last axis, so move
            # embed_dim there and back again.
            # keys: (B, embed_dim, source_length) before and after
            keys = self.key_projection(keys.transpose(1, 2)).transpose(1, 2)

        # attention
        # FC(query)
        # x: (B, tgt_len, embed_dim)
        x = self.query_projection(query)
        # dot(query, keys)
        # keys must be a matrix made of column vectors
        # (B, tgt_len, src_len)
        x = torch.bmm(x, keys)  # batch matrix-matrix product

        mask_value = -float("inf")
        if mask is not None:
            # inference mask, broadcast over the target axis
            mask = mask.view(query.size(0), 1, -1)
            x.data.masked_fill_(mask, mask_value)

        if last_attended is not None:
            # Windowed attention: hide everything outside the window around
            # the last attended source position.
            backward = last_attended - self.window_backward
            if backward > 0:
                x[:, :, :backward] = mask_value
            ahead = last_attended + self.window_ahead
            if ahead < x.size(-1):
                x[:, :, ahead:] = mask_value

        # softmax over the source axis
        # (B, tgt_len, src_len)
        x = F.softmax(x, dim=-1)
        attn_scores = x

        x = F.dropout(x, p=self.dropout, training=self.training)

        # dot(softmax(dot(query, keys)), values)
        x = torch.bmm(x, values)

        # scale attention output; s * sqrt(1 / s) == sqrt(s), s = source_length
        s = values.size(1)
        x = x * (s * math.sqrt(1.0 / s))

        # project back to conv_channels and apply the scaled residual connection
        x = self.out_projection(x)
        x = (x + residual) * math.sqrt(0.5)
        return x, attn_scores

In [51]:
# Toy dimensions for a quick smoke test of AttentionLayer.
query_length, source_length = 4, 3
embed_dim, conv_channels = 2, 5

attention_layer = AttentionLayer(conv_channels=conv_channels, embed_dim=embed_dim)

# All-ones query of shape (1, query_length, conv_channels).
query = Variable(torch.ones(1, query_length, conv_channels))

# Encoder states 0, 1, ..., source_length * embed_dim - 1 laid out as
# (1, source_length, embed_dim).
encoder_values = np.arange(source_length * embed_dim, dtype=np.float32)
encoder_out = Variable(
    torch.from_numpy(encoder_values.reshape(1, source_length, embed_dim))
)

# Keys are the transposed encoder states, values are the states themselves.
# The cell displays the (output, attention weights) tuple.
attention_layer(query=query, encoder_out=(encoder_out.transpose(1, 2), encoder_out))

(Variable containing:
 (0 ,.,.) = 
  -11.6128   0.0791 -14.0917  -3.4851   8.3419
   -1.5029   0.6000  -1.9552  -0.0414   2.0806
  -11.6128   0.0791 -14.0917  -3.4851   8.3419
  -11.6128   0.0791 -14.0917  -3.4851   8.3419
 [torch.FloatTensor of size 1x4x5], Variable containing:
 (0 ,.,.) = 
   0.0828  0.2373  0.6799
   0.0828  0.2373  0.6799
   0.0828  0.2373  0.6799
   0.0828  0.2373  0.6799
 [torch.FloatTensor of size 1x4x3])