In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

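#### Multi-head attention

Scaled dot-product attention split across `n_heads` parallel heads, with an optional softmax temperature `temp` and an optional `mask` (positions where `mask == 0` receive no attention).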

In [None]:
class MultiheadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention with an optional softmax temperature.

    Queries, keys and values are linearly projected from ``input_dim`` to
    ``hid_dim``, split into ``n_heads`` heads of size ``hid_dim // n_heads``,
    attended independently, merged back and passed through an output projection.

    Args:
        input_dim: feature size of the incoming query/key/value tensors.
        hid_dim: total projected size; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        device: kept for backward compatibility only. The scaling factor is a
            plain Python float, so the layer follows whatever device/dtype the
            module is moved to with ``.to(...)``.
    """

    def __init__(self, input_dim, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(input_dim, hid_dim)
        self.fc_k = nn.Linear(input_dim, hid_dim)
        self.fc_v = nn.Linear(input_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        # A Python float rather than a tensor created on `device`: a tensor
        # attribute would not move with `.to(...)`, and a 1-element float32
        # tensor would silently promote half/bf16/float64 scores to another dtype.
        self.scale = self.head_dim ** 0.5

    def _split_heads(self, t, batch_size):
        # [bs, len, hid dim] -> [bs, n heads, len, head dim]
        return t.reshape(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, temp=1.0, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: [batch_size, query_len, input_dim]
            key:   [batch_size, key_len, input_dim]
            value: [batch_size, value_len, input_dim] (value_len == key_len)
            temp: softmax temperature; scores are divided by it (>1 flattens,
                <1 sharpens the distribution).
            mask: optional tensor broadcastable to
                [batch_size, n_heads, query_len, key_len]; positions where
                ``mask == 0`` get (effectively) zero attention.

        Returns:
            (x, attention): x is [batch_size, query_len, hid_dim]; attention is
            the pre-dropout weights [batch_size, n_heads, query_len, key_len].
        """
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)  # [bs, n heads, query len, head dim]
        K = self._split_heads(self.fc_k(key), batch_size)    # [bs, n heads, key len, head dim]
        V = self._split_heads(self.fc_v(value), batch_size)  # [bs, n heads, value len, head dim]

        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [bs, n heads, query len, key len]
        # Apply the temperature before masking so the fill value is never
        # rescaled (dividing a huge negative by temp < 1 can overflow to -inf).
        energy = energy / temp
        if mask is not None:
            # dtype-aware sentinel: a literal such as -1e10 overflows float16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)  # [bs, n heads, query len, key len]
        x = torch.matmul(self.dropout(attention), V)  # [bs, n heads, query len, head dim]
        x = x.permute(0, 2, 1, 3).contiguous()  # [bs, query len, n heads, head dim]
        x = x.view(batch_size, -1, self.hid_dim)  # [bs, query len, hid dim]
        x = self.fc_o(x)
        return x, attention

In [None]:
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All
mha = MultiheadAttentionLayer(1, 32, 2, 0, 'cpu')
x = torch.rand((10, 4, 1))  # [batch size, seq len, input dim]

# Run the forward pass once (temp=1) and inspect both outputs' shapes.
out, attention = mha(x, x, x, 1)
out.shape, attention.shape