In [35]:
import math
import torch
from torch import nn
import torch.nn.functional as F

In [39]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected to ``hidden_dim`` features, split into
    ``num_head`` heads of ``hidden_dim // num_head`` features each, attended
    per head, merged back and passed through an output projection.

    Args:
        q_dim: Feature size of the last axis of ``query``.
        k_dim: Feature size of the last axis of ``key``.
        v_dim: Feature size of the last axis of ``value``.
        hidden_dim: Total projected feature size (shared by all heads).
        num_head: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_head``.
    """

    def __init__(self, q_dim, k_dim, v_dim, hidden_dim, num_head, dropout):
        super(MultiHeadAttention, self).__init__()
        if hidden_dim % num_head != 0:
            # Otherwise the per-head reshape in forward() cannot succeed.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_head ({num_head})"
            )
        self.num_head = num_head
        self.hidden_dim = hidden_dim
        self.W_q = nn.Linear(q_dim, hidden_dim)
        self.W_k = nn.Linear(k_dim, hidden_dim)
        self.W_v = nn.Linear(v_dim, hidden_dim)
        self.W_o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, X):
        """Reshape (B, L, hidden_dim) into (B, num_head, L, head_dim)."""
        B, L, _ = X.shape
        return X.reshape(B, L, self.num_head, -1).permute(0, 2, 1, 3)

    def forward(self, query, key, value, valid_len=None):
        """Compute attention outputs.

        Args:
            query: Tensor of shape (B, L_q, q_dim).
            key: Tensor of shape (B, L_k, k_dim).
            value: Tensor of shape (B, L_k, v_dim).
            valid_len: Optional 1-D sequence of length B. For batch element
                ``b``, key positions ``>= valid_len[b]`` are masked out.

        Returns:
            Tensor of shape (B, L_q, hidden_dim).
        """
        head_dim = self.hidden_dim // self.num_head
        B, L_q, _ = query.shape
        # Keys/values may have a different length than the queries
        # (cross-attention), so their length is read from `key` itself.
        L_k = key.shape[1]

        Q = self._split_heads(self.W_q(query))
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))

        # (B, H, L_q, L_k) similarity scores, scaled by sqrt(head_dim).
        scores = Q @ K.transpose(2, 3) / (head_dim ** 0.5)
        if valid_len is not None:
            # Build the mask over *key* positions, on the same device as the scores.
            valid_len = torch.as_tensor(valid_len, device=scores.device)
            key_pos = torch.arange(L_k, device=scores.device)
            mask = key_pos[None, None, None, :] >= valid_len[:, None, None, None]
            # A large negative score drives the softmax weight of masked keys to zero.
            scores = scores.masked_fill(mask, -1e6)

        weight = self.dropout(F.softmax(scores, dim=-1))
        out = weight @ V  # (B, H, L_q, head_dim)
        # Move the head axis next to the feature axis before flattening;
        # reshaping (B, H, L_q, D) directly would interleave positions and heads.
        merged = out.permute(0, 2, 1, 3).reshape(B, L_q, self.hidden_dim)
        return self.W_o(merged)

In [40]:
# Smoke test: run self-attention on an all-ones batch and inspect the output shape.
hidden_dim, num_head = 100, 5
batchsize, num_queries = 2, 4

attention = MultiHeadAttention(
    q_dim=hidden_dim,
    k_dim=hidden_dim,
    v_dim=hidden_dim,
    hidden_dim=hidden_dim,
    num_head=num_head,
    dropout=0.5,
)
attention.eval()  # evaluation mode: dropout is inactive for this check

X = torch.ones(batchsize, num_queries, hidden_dim)
valid_len = torch.tensor([3, 2])  # one valid length per batch element
attention(X, X, X, valid_len).shape


torch.Size([2, 4, 100])

In [8]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    The table ``P`` has shape (1, max_len, hidden_dim). For position ``p`` and
    even feature index ``2i`` it stores ``sin(p / 10000^(2i / hidden_dim))``;
    the following odd index stores the matching cosine.

    Args:
        hidden_dim: Feature size of the inputs (odd sizes are supported).
        dropout: Dropout probability applied after adding the encodings.
        max_len: Number of positions precomputed in the table; inputs are
            expected to have at most this many positions along axis 1.
    """

    def __init__(self, hidden_dim, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        divisor = torch.pow(
            10000, torch.arange(0, hidden_dim, 2, dtype=torch.float32) / hidden_dim
        )
        angles = position / divisor  # (max_len, ceil(hidden_dim / 2))

        P = torch.zeros((1, max_len, hidden_dim))
        P[:, :, 0::2] = torch.sin(angles)
        # For odd hidden_dim there is one fewer odd column than even columns,
        # so trim the angle table to the number of odd columns.
        P[:, :, 1::2] = torch.cos(angles[:, : hidden_dim // 2])

        # Register as a buffer so the table follows the module across
        # .to()/.cuda() calls; non-persistent keeps state_dict unchanged.
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        """Return ``dropout(X + P[:, :X.shape[1]])`` for X of shape (B, L, hidden_dim)."""
        # .to() is a no-op when the table already lives on X's device.
        X = X + self.P[:, : X.shape[1], :].to(X.device)
        return self.dropout(X)

In [9]:
# Smoke test: add positional encodings to an all-ones sequence.
encoder_dim = 100
dropout = 0.5
tokens = 50
posEncoding = PositionalEncoding(encoder_dim, dropout)
# Switch to evaluation mode (as the attention demo does) so dropout does not
# randomly zero and rescale half of the encoded values; X is then deterministic.
posEncoding.eval()
X = posEncoding(torch.ones(1, tokens, encoder_dim))
