In [57]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are split into ``num_heads`` heads of width ``head_dim``.
    Each head is projected, and queries attend over keys/values. The heads
    are then concatenated and passed through a final
    ``model_dim -> model_dim`` linear layer.

    Args:
        model_dim: Width of the input/output embeddings. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, model_dim, num_heads):
        super(SelfAttention, self).__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads

        assert (self.head_dim * num_heads == model_dim), "Embed size needs to be div by heads!"

        # The embedding is split into heads before projecting:
        # (N, L, model_dim) -> (N, L, num_heads, head_dim)
        # and each Linear below acts on the last (head_dim) axis of every head.
        # NOTE(review): these head_dim x head_dim weights are shared by all
        # heads, and each head only sees its own slice of the embedding. The
        # original Transformer uses a distinct model_dim -> head_dim projection
        # per head (equivalently one model_dim -> model_dim Linear applied
        # before the split). Switching would change parameter shapes and
        # checkpoints, so it is left as-is; confirm which variant is intended.
        self.embed2Q = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.embed2K = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.embed2V = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(self.model_dim, model_dim)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K`` / ``V``.

        Args:
            Q: Queries, shape (N, query_len, model_dim).
            K: Keys, shape (N, key_len, model_dim).
            V: Values, shape (N, key_len, model_dim); its length must match K's.
            mask: Optional tensor broadcastable to
                (N, num_heads, query_len, key_len). Positions where
                ``mask == 0`` get (effectively) zero attention weight.

        Returns:
            Tensor of shape (N, query_len, model_dim).
        """
        N = Q.shape[0]
        query_len, key_len, value_len = Q.shape[1], K.shape[1], V.shape[1]

        # 1. Multi-head split. Each input keeps its own sequence length, so
        #    encoder-decoder attention with query_len != key_len works.
        Q = Q.reshape(N, query_len, self.num_heads, self.head_dim)
        K = K.reshape(N, key_len, self.num_heads, self.head_dim)
        V = V.reshape(N, value_len, self.num_heads, self.head_dim)

        Q = self.embed2Q(Q)
        K = self.embed2K(K)
        V = self.embed2V(V)

        # 2. MatMul(Q, K): (N, Q_L, h, d_h) x (N, K_L, h, d_h) -> (N, h, Q_L, K_L)
        energy = torch.einsum("nqhd,nkhd->nhqk", Q, K)
        # 3. Scale by sqrt(d_h).
        energy = energy / self.head_dim ** 0.5
        # 4. Mask (optional). Fill with the dtype's most negative finite value.
        #    A fixed -1e20 is not representable in float16 and makes
        #    masked_fill fail under half precision.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)
        # 5. Softmax over the key axis.
        attention = torch.softmax(energy, dim=-1)
        # 6. MatMul(attention, V), then concatenate the heads:
        #    (N, h, Q_L, K_L) x (N, K_L, h, d_h) -> (N, Q_L, h, d_h) -> (N, Q_L, model_dim)
        out = torch.einsum("nhqk,nkhd->nqhd", attention, V)
        out = out.reshape(N, query_len, self.model_dim)
        # 7. Final output projection.
        return self.fc_out(out)

In [None]:
class FFNSubLayer(nn.Module):
    """Position-wise feed-forward sub-layer with a residual and post-LayerNorm.

    Computes ``LayerNorm(x + Dropout(fc2(ReLU(fc1(x)))))``.

    Args:
        model_dim: Input/output feature width.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to the sub-layer output
            before the residual addition.
    """

    def __init__(self, model_dim, ff_dim, dropout):
        super().__init__()
        # Submodules are registered in the same order as before (norm, fc1,
        # fc2, ...), so state_dict keys and seeded initialisation are unchanged.
        self.norm = nn.LayerNorm(model_dim)
        self.fc1 = nn.Linear(model_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, model_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward block to ``x``, whose last dim is ``model_dim``."""
        hidden = self.relu(self.fc1(x))
        branch = self.dropout(self.fc2(hidden))
        return self.norm(x + branch)


class MultiHeadAttentionSubLayer(nn.Module):
    """Multi-head attention sub-layer with a residual and post-LayerNorm.

    Computes ``LayerNorm(Q + Dropout(SelfAttention(Q, K, V, mask)))``. This
    follows the same residual/norm pattern as ``FFNSubLayer``.

    Args:
        model_dim: Input/output feature width.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention output before
            the residual addition. The default of 0.0 disables it.
    """

    def __init__(self, model_dim, num_heads, dropout=0.0):
        super(MultiHeadAttentionSubLayer, self).__init__()
        self.norm = nn.LayerNorm(model_dim)
        self.attention = SelfAttention(model_dim, num_heads)
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K`` / ``V`` and add the result back onto ``Q``.

        Returns:
            Tensor with the same shape as ``Q``.
        """
        attention = self.dropout(self.attention(Q, K, V, mask))
        # The residual is taken from the query stream: the attention output has
        # Q's sequence length, which also holds for encoder-decoder attention.
        return self.norm(Q + attention)
        

class Encoder(nn.Module):
