In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class Initialized_Conv1d(nn.Module):
    """1-D convolution with constant-initialised weights and an optional ReLU.

    Every convolution weight is set to 1, so the layer's output is
    deterministic; Kaiming / Xavier initialisation is not applied.  When
    ``bias=True`` the bias keeps ``nn.Conv1d``'s default initialisation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups, bias:
            forwarded unchanged to ``nn.Conv1d``.
        relu: the activation is enabled only when this is the literal ``True``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=1, stride=1, padding=0, groups=1,
                 relu=False, bias=False):
        super().__init__()
        conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding,
                         groups=groups, bias=bias)
        nn.init.constant_(conv.weight, 1.)
        self.out = conv
        # Identity check on purpose: truthy non-bool values do not enable ReLU.
        self.relu = relu is True

    def forward(self, x):
        """Run the convolution and, if enabled, a ReLU on its result."""
        y = self.out(x)
        return F.relu(y) if self.relu is True else y

In [3]:
def mask_logits(target, mask):
    """Replace masked-out logits with a huge negative value before a softmax.

    Args:
        target: tensor of logits.
        mask: tensor broadcastable to ``target``; it is cast to float32.
            Where it is 1 the logit is kept, where it is 0 the logit is
            replaced by ``-1e30``.

    Returns:
        ``target * mask + (1 - mask) * (-1e30)``.
    """
    keep = mask.type(torch.float32)
    # Multiplying ``target`` by ``keep`` zeroes the masked logits first, so a
    # masked (finite) entry ends up as -1e30 itself instead of target - 1e30.
    return target * keep + (1 - keep) * (-1e30)

In [4]:
class SelfAttention(nn.Module):
    """Multi-head self-attention operating on channels-first input.

    Keys and values come from one kernel-size-1 convolution (``mem_conv``,
    ``d_model -> 2 * d_model``) and queries from another (``query_conv``,
    ``d_model -> d_model``).  Queries are scaled by ``(d_model // num_head) ** -0.5``
    before the dot product.

    Args:
        d_model: channel dimension D of the input and output.
        num_head: number of attention heads H; must divide ``d_model``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not a multiple of ``num_head``.
    """

    def __init__(self, d_model, num_head, dropout):
        super().__init__()
        # Fail at construction time instead of with an opaque ``view`` error
        # inside ``forward``.
        if d_model % num_head != 0:
            raise ValueError("d_model not a multiple of num_head")
        self.d_model = d_model
        self.num_head = num_head
        self.dropout = dropout
        self.mem_conv = Initialized_Conv1d(in_channels=d_model, out_channels=d_model*2, kernel_size=1, relu=False, bias=False)
        self.query_conv = Initialized_Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=1, relu=False, bias=False)

        # Scalar logit bias, initialised to 0; only used when
        # ``dot_product_attention`` is called with ``bias=True``.
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, queries, mask):
        """Self-attend over ``queries``.

        Args:
            queries: tensor of shape B x D x L.
            mask: tensor with B * L elements (reshaped to B x 1 x 1 x L);
                1 marks positions that may be attended to, 0 marks padding.

        Returns:
            Tensor of shape B x D x L.
        """
        memory = self.mem_conv(queries).transpose(1, 2)    # B x L x 2D
        query = self.query_conv(queries).transpose(1, 2)   # B x L x D

        Q = self.split_last_dim(query, self.num_head)      # B x H x L x D//H
        # ``memory`` holds keys and values side by side along the last axis.
        K, V = [self.split_last_dim(part, self.num_head)
                for part in torch.split(memory, self.d_model, dim=2)]

        key_depth_per_head = self.d_model // self.num_head
        Q = Q * key_depth_per_head ** -0.5                 # scale by 1 / sqrt(d_k)

        x = self.dot_product_attention(Q, K, V, mask=mask)  # B x H x L x D//H
        return self.combine_last_two_dim(x.permute(0, 2, 1, 3)).transpose(1, 2)

    def dot_product_attention(self, q, k, v, bias=False, mask=None):
        """Dot-product attention.

        Args:
            q: a Tensor with shape [batch, heads, length_q, depth_k]
            k: a Tensor with shape [batch, heads, length_kv, depth_k]
            v: a Tensor with shape [batch, heads, length_kv, depth_v]
            bias: when truthy, the learned scalar ``self.bias`` is added to
                the logits.
            mask: optional tensor with batch * length_kv elements; 1 keeps a
                key position, 0 masks it out.

        Returns:
            A Tensor with shape [batch, heads, length_q, depth_v].
        """
        logits = torch.matmul(q, k.permute(0, 1, 3, 2))
        # Printed on purpose: the notebook compares these raw logits with the
        # ones printed by the other attention implementation.
        print(logits)
        if bias:
            logits += self.bias
        if mask is not None:
            # ``reshape`` (not ``view``) so non-contiguous masks are accepted.
            mask = mask.reshape(logits.size(0), 1, 1, logits.size(-1))
            logits = mask_logits(logits, mask)

        weights = F.softmax(logits, dim=-1)
        # dropping out the attention links for each of the heads
        weights = F.dropout(weights, p=self.dropout, training=self.training)
        return torch.matmul(weights, v)

    def split_last_dim(self, x, n):
        """Split the last dimension into ``n`` heads and move heads forward.

        Args:
            x: a Tensor with shape [B, L, m]
            n: an integer that divides m.

        Returns:
            a Tensor with shape [B, n, L, m // n]
        """
        lead, last = list(x.size())[:-1], x.size(-1)
        ret = x.view(lead + [n, last // n])   # B x L x n x m//n
        return ret.permute(0, 2, 1, 3)        # B x n x L x m//n

    def combine_last_two_dim(self, x):
        """Reshape x so that the last two dimensions become one.

        Args:
            x: a Tensor with shape [..., a, b]

        Returns:
            a Tensor with shape [..., a * b]
        """
        old_shape = list(x.size())
        a, b = old_shape[-2:]
        # ``contiguous`` is required because callers pass a permuted tensor.
        return x.contiguous().view(old_shape[:-2] + [a * b])

In [5]:
def make_mask(masks, decode=False):
    """Build an attention mask from a padding mask.

    :param masks: 0 for pad, 1 for non-pad (batch x seq_len); any numeric dtype
    :param decode: decoders are Auto-Regressive (can't see future words), so a
                   lower-triangular (causal) mask is AND-ed in
    :return: mask: long tensor, batch x seq_len x seq_len when ``decode`` is
             true, otherwise batch x 1 x seq_len
    """
    # Insert the query axis so that key positions (columns) are masked, rather
    # than zeroing the whole row of a pad word.
    masks = masks.unsqueeze(-2).long()  # batch x 1 x seq_len
    if decode:
        seq_len = masks.shape[-1]
        # Built with torch on the mask's own device; casting to long first
        # means float padding masks work too (``&`` is undefined for floats).
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long, device=masks.device))
        masks = masks & causal  # broadcasts to batch x seq_len x seq_len
    return masks

In [6]:
def self_attention(query, key, value, mask=None, dp=None):
    """Scaled dot-product attention with a deliberately WRONG scale.

    NOTE: the query is divided by ``d_k ** -0.5``, i.e. multiplied by
    sqrt(d_k), instead of being divided by sqrt(d_k).  This is the mistake the
    notebook sets out to demonstrate, so it is kept as-is.

    :param query: Query tensor (batch x heads x seq_len x d_k)
    :param key: Key tensor (batch x heads x seq_len x d_k)
    :param value: Value tensor (batch x heads x seq_len x d_k)
    :param mask: Optional mask broadcastable to the logits
                 (batch x heads x seq_len x seq_len); 0 marks masked positions
    :param dp: Optional callable (e.g. a dropout layer) applied to the scores
    :return: output, scores (batch x heads x seq_len x d_k), (batch x heads x seq_len x seq_len)
    """
    d_k = key.shape[-1]
    wrongly_scaled_query = query / (d_k ** (-.5))  # THIS IS WRONG! (on purpose)
    logits = torch.matmul(wrongly_scaled_query, key.transpose(-1, -2))
    print(logits)  # raw logits, shown for comparison in the notebook

    if mask is not None:
        # Fill with -1e9, NOT 1e-9: exp(1e-9) is ~1, so 1e-9 would mask nothing.
        logits = logits.masked_fill(mask == 0, -1e9)

    scores = F.softmax(logits, dim=-1)
    scores = scores if dp is None else dp(scores)
    return torch.matmul(scores, value), scores

In [7]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head attention built on ``self_attention``.

    Three bias-free linear projections (for q, k and v, in that order) have
    all weights set to 1.  No output projection is applied after the heads
    are concatenated.
    """

    def __init__(self, heads, hidden_size, drop_prob=0.):
        """
        :param heads: Number of attention heads to use
        :param hidden_size: Dimension of input/output vectors
        :param drop_prob: Dropout rate
        """
        super().__init__()

        assert hidden_size % heads == 0, "hidden_size not a multiple of heads"

        self.d_k = hidden_size // heads
        self.heads = heads

        projections = []
        for _ in range(3):
            layer = nn.Linear(hidden_size, hidden_size, bias=False)
            nn.init.constant_(layer.weight, 1.)
            projections.append(layer)
        self.Linears = nn.ModuleList(projections)

        self.attn = None  # attention scores of the most recent forward pass
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, q, k, v, mask=None):
        """
        :param q: Query tensor (batch_size x seq_len x hidden_size)
        :param k: Key tensor (batch_size x seq_len x hidden_size)
        :param v: Value tensor (batch_size x seq_len x hidden_size)
        :param mask: Optional mask (batch_size x seq_len x seq_len)
        :return: o: output tensor (batch_size x seq_len x hidden_size)
        """
        batch_size = q.shape[0]

        if mask is not None:
            # Add a heads axis so one mask is shared by every head.
            mask = mask.unsqueeze(1)

        def to_heads(layer, tensor):
            # Project, then reshape to (batch_size, heads, seq_len, d_k).
            projected = layer(tensor)
            return projected.view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)

        q, k, v = (to_heads(layer, tensor)
                   for layer, tensor in zip(self.Linears, (q, k, v)))

        # o: (batch_size, heads, seq_len, d_k)
        o, self.attn = self_attention(q, k, v, mask, self.dropout)

        # Concatenate the heads back into the hidden dimension.  Some
        # implementations apply a final linear projection here; this one
        # does not.
        return o.transpose(1, 2).contiguous().view(batch_size, -1, self.heads * self.d_k)

In [8]:
# Same configuration for both implementations: depth 6, 2 heads, no dropout.
SA = SelfAttention(d_model=6, num_head=2, dropout=0.)
MHSA = MultiHeadSelfAttention(heads=2, hidden_size=6, drop_prob=0.)

In [9]:
# Constant input, B x L x D.
Q = torch.ones(3, 4, 6)
# Padding mask, B x L: first two positions are 1 (valid), last two are 0 (pad).
M = torch.cat([torch.ones(3, 2), torch.zeros(3, 2)], dim=1)

In [10]:
# SA takes B x D x L input, so transpose in and transpose the result back.
SA(Q.transpose(1, 2),M).transpose(1, 2)  # B x L x D

tensor([[[[62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538]],

         [[62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538]]],


        [[[62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538]],

         [[62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538]]],


        [[[62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538],
          [62.3538, 62.3538, 62.3538, 62.3538]],

         [[62.3538, 62.353

tensor([[[6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.]],

        [[6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.]],

        [[6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.]]], grad_fn=<TransposeBackward0>)

In [12]:
# Self-attention: the same tensor is used as query, key and value.
MHSA(Q, Q, Q, make_mask(M))  # B x L x D

tensor([[[[187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615]],

         [[187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615]]],


        [[[187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615]],

         [[187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615]]],


        [[[187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],
          [187.0615, 187.0615, 187.0615, 187.0615],


tensor([[[6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.]],

        [[6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.]],

        [[6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6., 6.]]], grad_fn=<ViewBackward>)

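## As seen above, even with the wrong scale (multiplying by sqrt(d_k) instead of dividing by it) the attention outputs are identical. This is only because the inputs are non-random: all logits in a row are equal, so the softmax is uniform whatever the scale. The wrong scale does make the dot-products much larger (187.06 vs. 62.35, a factor of d_k = 3), which for varied inputs makes the softmax sharp and peaky.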

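## Why divide by sqrt(d_k)? The logit is $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$. Assume the components $q_i$ and $k_i$ are independent random variables with mean 0 and variance 1. Then $E[q_i k_i] = E[q_i]\,E[k_i] = 0$ and $\mathrm{Var}(q_i k_i) = E[q_i^2]\,E[k_i^2] = 1$. In general $E[X + Y] = E[X] + E[Y]$ and $\mathrm{Var}(X \pm Y) = \mathrm{Var}(X) + \mathrm{Var}(Y) \pm 2\,\mathrm{Cov}(X, Y)$, where $\mathrm{Cov}(X, Y) = 0$ if $X$ and $Y$ are independent. Therefore $\mathrm{Var}(q \cdot k) = \sum_{i=1}^{d_k} \mathrm{Var}(q_i k_i) = d_k$ and $\mathrm{std}(q \cdot k) = \sqrt{d_k}$, so dividing the logits by $\sqrt{d_k}$ restores unit variance.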

## The "convolution" applied by QANet-PyTorch is equivalent to a linear layer: its kernel size is 1, i.e. it is a pointwise (1x1) convolution, which applies the same linear map independently at every sequence position.

In [13]:
# Seed the global RNG so that Restart & Run All regenerates the same random
# input (without it, every run compares the two modules on different data).
torch.manual_seed(0)
Q = torch.rand((3,4,6))  # B x L x D, uniform in [0, 1)

In [14]:
# Same call as before, now on the random input; SA takes B x D x L.
SA(Q.transpose(1, 2),M).transpose(1, 2)  # B x L x D

tensor([[[[27.5361, 14.9369, 24.5258, 23.2872],
          [14.9369,  8.1025, 13.3040, 12.6321],
          [24.5258, 13.3040, 21.8446, 20.7414],
          [23.2872, 12.6321, 20.7414, 19.6940]],

         [[27.5361, 14.9369, 24.5258, 23.2872],
          [14.9369,  8.1025, 13.3040, 12.6321],
          [24.5258, 13.3040, 21.8446, 20.7414],
          [23.2872, 12.6321, 20.7414, 19.6940]]],


        [[[16.9138,  8.8713, 17.3691, 16.5678],
          [ 8.8713,  4.6529,  9.1100,  8.6898],
          [17.3691,  9.1100, 17.8366, 17.0138],
          [16.5678,  8.6898, 17.0138, 16.2289]],

         [[16.9138,  8.8713, 17.3691, 16.5678],
          [ 8.8713,  4.6529,  9.1100,  8.6898],
          [17.3691,  9.1100, 17.8366, 17.0138],
          [16.5678,  8.6898, 17.0138, 16.2289]]],


        [[[ 7.6209, 12.1175, 16.4530, 12.8870],
          [12.1175, 19.2673, 26.1609, 20.4909],
          [16.4530, 26.1609, 35.5210, 27.8223],
          [12.8870, 20.4909, 27.8223, 21.7922]],

         [[ 7.6209, 12.117

tensor([[[3.9872, 3.9872, 3.9872, 3.9872, 3.9872, 3.9872],
         [3.9853, 3.9853, 3.9853, 3.9853, 3.9853, 3.9853],
         [3.9872, 3.9872, 3.9872, 3.9872, 3.9872, 3.9872],
         [3.9872, 3.9872, 3.9872, 3.9872, 3.9872, 3.9872]],

        [[3.1245, 3.1245, 3.1245, 3.1245, 3.1245, 3.1245],
         [3.1034, 3.1034, 3.1034, 3.1034, 3.1034, 3.1034],
         [3.1245, 3.1245, 3.1245, 3.1245, 3.1245, 3.1245],
         [3.1244, 3.1244, 3.1244, 3.1244, 3.1244, 3.1244]],

        [[3.3216, 3.3216, 3.3216, 3.3216, 3.3216, 3.3216],
         [3.3343, 3.3343, 3.3343, 3.3343, 3.3343, 3.3343],
         [3.3352, 3.3352, 3.3352, 3.3352, 3.3352, 3.3352],
         [3.3346, 3.3346, 3.3346, 3.3346, 3.3346, 3.3346]]],
       grad_fn=<TransposeBackward0>)

In [15]:
# Same call as before, now on the random input (query = key = value).
MHSA(Q, Q, Q, make_mask(M))  # B x L x D

tensor([[[[ 82.6084,  44.8108,  73.5774,  69.8617],
          [ 44.8108,  24.3076,  39.9120,  37.8964],
          [ 73.5774,  39.9120,  65.5338,  62.2243],
          [ 69.8617,  37.8964,  62.2243,  59.0819]],

         [[ 82.6084,  44.8108,  73.5774,  69.8617],
          [ 44.8108,  24.3076,  39.9120,  37.8964],
          [ 73.5774,  39.9120,  65.5338,  62.2243],
          [ 69.8617,  37.8964,  62.2243,  59.0819]]],


        [[[ 50.7414,  26.6138,  52.1073,  49.7034],
          [ 26.6138,  13.9588,  27.3301,  26.0693],
          [ 52.1073,  27.3301,  53.5099,  51.0413],
          [ 49.7034,  26.0693,  51.0413,  48.6866]],

         [[ 50.7414,  26.6138,  52.1073,  49.7034],
          [ 26.6138,  13.9588,  27.3301,  26.0693],
          [ 52.1073,  27.3301,  53.5099,  51.0413],
          [ 49.7034,  26.0693,  51.0413,  48.6866]]],


        [[[ 22.8627,  36.3525,  49.3591,  38.6611],
          [ 36.3525,  57.8019,  78.4827,  61.4726],
          [ 49.3591,  78.4827, 106.5630,  83.4668],


tensor([[[3.9872, 3.9872, 3.9872, 3.9872, 3.9872, 3.9872],
         [3.9872, 3.9872, 3.9872, 3.9872, 3.9872, 3.9872],
         [3.9872, 3.9872, 3.9872, 3.9872, 3.9872, 3.9872],
         [3.9872, 3.9872, 3.9872, 3.9872, 3.9872, 3.9872]],

        [[3.1249, 3.1249, 3.1249, 3.1249, 3.1249, 3.1249],
         [3.1249, 3.1249, 3.1249, 3.1249, 3.1249, 3.1249],
         [3.1249, 3.1249, 3.1249, 3.1249, 3.1249, 3.1249],
         [3.1249, 3.1249, 3.1249, 3.1249, 3.1249, 3.1249]],

        [[3.3353, 3.3353, 3.3353, 3.3353, 3.3353, 3.3353],
         [3.3353, 3.3353, 3.3353, 3.3353, 3.3353, 3.3353],
         [3.3353, 3.3353, 3.3353, 3.3353, 3.3353, 3.3353],
         [3.3353, 3.3353, 3.3353, 3.3353, 3.3353, 3.3353]]],
       grad_fn=<ViewBackward>)

## As seen above, when random inputs are fed, the wrongly-scaled attention returns (almost) the same output for every query position: the logits have all been scaled so large that the softmax saturates onto the same key for every query.