In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math,copy

In [3]:
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [99]:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
#     print(math.sqrt(d_k))
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)

    if dropout is not None:
        p_attn = dropout(p_attn)
    print(torch.matmul(p_attn, value))
    return torch.matmul(p_attn, value), p_attn

In [100]:
def self_attention(query, key, value, mask=None, dp=None):
    """
    :param query: Query tensor (batch x heads x seq_len x d_k)
    :param key: Key tensor (batch x heads x seq_len x d_k)
    :param value: Value tensor (batch x heads x seq_len x d_k)
    :param mask: Optional mask, same for all heads (batch x heads x seq_len x seq_len)
    :param dp: Dropout layer
    :return: output, scores (batch x heads x seq_len x d_k), (batch x heads x seq_len x seq_len)
    """
#     print(math.sqrt(key.shape[-1]))
    logits = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(key.shape[-1])

    if mask is not None:
        logits = logits.masked_fill(mask==0, -1e9)  # NOT 1e-9. Softmax(1e-9) is still 1.
    scores = F.softmax(logits, dim=-1)

    if dp is not None:
        scores = dp(scores)
    print(torch.matmul(scores, value))
    return torch.matmul(scores, value), scores

In [110]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        for p in self.parameters():
            torch.nn.init.constant_(p, 1.0)

            
    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        
        # 1) Do all the linear projections in batch from d_model => h x d_k 
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]
        # 2) Apply attention on all the projected vectors in batch.

        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear. 
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

In [102]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, heads, hidden_size, drop_prob=0.):
        """
        :param heads: Number of attention heads to use
        :param hidden_size: Dimension of input/output vectors
        :param drop_prob: Dropout rate
        """
        super(MultiHeadSelfAttention, self).__init__()

        assert hidden_size % heads == 0, "hidden_size not a multiple of heads"

        self.d_k = hidden_size // heads
        self.heads = heads
        self.Linears = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(4)])
        self.attn = None
        self.dropout = nn.Dropout(p=drop_prob)
        for p in self.parameters():
            torch.nn.init.constant_(p, 1.0)

    def forward(self, q, k, v, mask=None):
        """
        :param q: Query tensor (batch_size x seq_len x hidden_size)
        :param k: Key tensor (batch_size x seq_len x hidden_size)
        :param v: Value tensor (batch_size x seq_len x hidden_size)
        :param mask: Optional mask (batch_size x seq_len x seq_len)
        :return: o: output tensor (batch_size x seq_len x hidden_size)
        """
        batch_size = q.shape[0]

        if mask is not None:
            mask = mask.unsqueeze(1)  # (batch_size x 1 x seq_len x seq_len)

        # Get the Q, K, V in multiple-heads form after linear layers
        q, k, v = [l(x).view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)
                   for l, x in zip(self.Linears, (q, k, v))]

        o, self.attn = self_attention(q, k, v, mask, self.dropout)  # (batch_size, heads, seq_len, d_k)

        o = o.transpose(1, 2).contiguous().view(batch_size, -1, self.heads*self.d_k)

        return self.Linears[-1](o)

In [113]:
MHA=MultiHeadedAttention(4,12)
print(MHA)

MultiHeadedAttention(
  (linears): ModuleList(
    (0): Linear(in_features=12, out_features=12, bias=True)
    (1): Linear(in_features=12, out_features=12, bias=True)
    (2): Linear(in_features=12, out_features=12, bias=True)
    (3): Linear(in_features=12, out_features=12, bias=True)
  )
  (dropout): Dropout(p=0.0)
)


In [114]:
MHSA=MultiHeadSelfAttention(4,12)
print(MHSA)

MultiHeadSelfAttention(
  (Linears): ModuleList(
    (0): Linear(in_features=12, out_features=12, bias=True)
    (1): Linear(in_features=12, out_features=12, bias=True)
    (2): Linear(in_features=12, out_features=12, bias=True)
    (3): Linear(in_features=12, out_features=12, bias=True)
  )
  (dropout): Dropout(p=0.0)
)


In [115]:
x = torch.randn((1,4,12))

In [116]:
MHSA(x,x,x)

tensor([[[[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]],

         [[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]],

         [[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]],

         [[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]]]], grad_fn=<UnsafeViewBackward>)


tensor([[[ 30.4684,  30.4684,  30.4684,  30.4684,  30.4684,  30.4684,  30.4684,
           30.4684,  30.4684,  30.4684,  30.4684,  30.4684],
         [ 31.0919,  31.0919,  31.0919,  31.0919,  31.0919,  31.0919,  31.0919,
           31.0919,  31.0919,  31.0919,  31.0919,  31.0919],
         [-78.2277, -78.2277, -78.2277, -78.2277, -78.2277, -78.2277, -78.2277,
          -78.2277, -78.2277, -78.2277, -78.2277, -78.2277],
         [ 31.0157,  31.0157,  31.0157,  31.0157,  31.0157,  31.0157,  31.0157,
           31.0157,  31.0157,  31.0157,  31.0157,  31.0157]]],
       grad_fn=<AddBackward0>)

In [117]:
MHA(x,x,x)

tensor([[[[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]],

         [[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]],

         [[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]],

         [[ 2.4557,  2.4557,  2.4557],
          [ 2.5077,  2.5077,  2.5077],
          [-6.6023, -6.6023, -6.6023],
          [ 2.5013,  2.5013,  2.5013]]]], grad_fn=<UnsafeViewBackward>)


tensor([[[ 30.4684,  30.4684,  30.4684,  30.4684,  30.4684,  30.4684,  30.4684,
           30.4684,  30.4684,  30.4684,  30.4684,  30.4684],
         [ 31.0919,  31.0919,  31.0919,  31.0919,  31.0919,  31.0919,  31.0919,
           31.0919,  31.0919,  31.0919,  31.0919,  31.0919],
         [-78.2277, -78.2277, -78.2277, -78.2277, -78.2277, -78.2277, -78.2277,
          -78.2277, -78.2277, -78.2277, -78.2277, -78.2277],
         [ 31.0157,  31.0157,  31.0157,  31.0157,  31.0157,  31.0157,  31.0157,
           31.0157,  31.0157,  31.0157,  31.0157,  31.0157]]],
       grad_fn=<AddBackward0>)

In [108]:
o1 = self_attention(x, x, x)

tensor([[[ 0.1588, -0.7460,  0.2738,  0.3562, -0.4570,  0.5218,  0.1263,
           0.4898,  1.0624, -0.3421,  2.0724,  0.5900],
         [-0.3360, -1.6072, -0.8374,  1.1898, -1.4670,  0.6100,  0.5050,
          -0.6781,  1.6863, -1.1524,  1.3355,  1.5912],
         [ 1.5464, -0.5643,  2.1567,  0.7049,  0.4197,  1.7633,  0.7826,
           0.7569,  2.5032,  0.6617,  0.0387, -1.2237],
         [-0.8238, -0.6416,  1.5324,  1.8258, -1.4826,  2.1162, -0.7381,
          -0.4709, -0.6954,  0.4675,  1.2878, -0.1611]]])


In [109]:
o2 = attention(x, x, x)

tensor([[[ 0.1588, -0.7460,  0.2738,  0.3562, -0.4570,  0.5218,  0.1263,
           0.4898,  1.0624, -0.3421,  2.0724,  0.5900],
         [-0.3360, -1.6072, -0.8374,  1.1898, -1.4670,  0.6100,  0.5050,
          -0.6781,  1.6863, -1.1524,  1.3355,  1.5912],
         [ 1.5464, -0.5643,  2.1567,  0.7049,  0.4197,  1.7633,  0.7826,
           0.7569,  2.5032,  0.6617,  0.0387, -1.2237],
         [-0.8238, -0.6416,  1.5324,  1.8258, -1.4826,  2.1162, -0.7381,
          -0.4709, -0.6954,  0.4675,  1.2878, -0.1611]]])
