# 221. Transformers from scratch 

- positional encoding은 약식으로 처리: 고정된 sinusoidal encoding 대신 학습 가능한 위치 임베딩(`nn.Embedding(max_length, embed_size)`)을 사용

- 나머지는 모두 직접 구현. `222. supplementary`와 함께 참조하여 볼 것

In [20]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is split into ``heads`` chunks of ``head_dim`` features.
    Each chunk is projected by the value/key/query layers. The same
    ``head_dim -> head_dim`` weights are shared by every head. Attention is
    then computed independently per head, and the concatenated heads are
    mixed by ``fc_out``.

    Args:
        embed_size: Model width; must be a multiple of ``heads``.
        heads: Number of attention heads.

    Raises:
        ValueError: If ``embed_size`` is not a multiple of ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.head_dim * heads != embed_size:
            raise ValueError("Embed size는 heads의 배수로 지정")

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(self.heads * self.head_dim, self.embed_size)

    def forward(self, v, k, q, mask):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            v: Values, shape (N, v_len, embed_size).
            k: Keys, shape (N, k_len, embed_size); k_len must equal v_len.
            q: Queries, shape (N, q_len, embed_size).
            mask: ``None``, or a tensor broadcastable to
                (N, heads, q_len, k_len). Key positions where
                ``mask == 0`` receive (numerically) zero attention weight.

        Returns:
            Tensor of shape (N, q_len, embed_size).
        """
        n = q.shape[0]  # batch size
        v_len, k_len, q_len = v.shape[1], k.shape[1], q.shape[1]

        # Split the embedding into heads -> (N, len, heads, head_dim), then
        # apply the per-head projections.
        values = self.values(v.reshape(n, v_len, self.heads, self.head_dim))
        keys = self.keys(k.reshape(n, k_len, self.heads, self.head_dim))
        queries = self.queries(q.reshape(n, q_len, self.heads, self.head_dim))

        # QK^T for every head: (N, q_len, h, d) x (N, k_len, h, d)
        # -> (N, h, q_len, k_len). The *projected* tensors must be used here;
        # otherwise the projection layers are dead weights.
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        if mask is not None:
            # The dtype's most negative finite value behaves like -inf after
            # softmax. Unlike a hard-coded -1e20, it does not overflow in
            # float16/bfloat16.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # Scale by sqrt(d_k), where d_k is the per-head width (Vaswani et al.).
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # (N, h, q_len, v_len) x (N, v_len, h, d) -> (N, q_len, h, d), then
        # concatenate the heads along the feature axis.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            n, q_len, self.heads * self.head_dim
        )
        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """Post-norm pair of sublayers: multi-head attention, then feed-forward.

    Each sublayer computes ``LayerNorm(residual + Dropout(sublayer_output))``.

    Args:
        embed_size: Model width.
        heads: Number of attention heads.
        dropout: Dropout probability applied to each sublayer output.
        forward_expansion: Hidden width of the feed-forward net, as a
            multiple of ``embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        hidden = forward_expansion * embed_size
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, v, k, q, mask):
        """Attend ``q`` over ``k``/``v`` (with ``q`` as the residual), then apply the FFN.

        Returns a tensor with the same shape as ``q``.
        """
        # Dropout is applied to the sublayer output *before* the residual add.
        # Dropping after the LayerNorm (as before) zeroed parts of the residual
        # stream itself (Vaswani et al., Sec. 5.4).
        x = self.norm1(q + self.dropout(self.attention(v, k, q, mask)))
        return self.norm2(x + self.dropout(self.feed_forward(x)))

class Encoder(nn.Module):
    """Token embedding + learned position embedding, followed by TransformerBlocks.

    Args:
        src_vocab_size: Source vocabulary size.
        embed_size: Model width.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        heads: Attention heads per layer.
        device: Stored on the module for backward compatibility. Position ids
            are created on the input's device instead, so the encoder keeps
            working after ``.to(...)``.
        forward_expansion: Feed-forward expansion factor.
        dropout: Dropout probability.
        max_length: Number of learned positions (longest accepted input).
    """

    def __init__(self, src_vocab_size, embed_size, num_layers, heads, device,
                 forward_expansion, dropout, max_length):
        super().__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        # Learned positions: a simplification of the paper's sinusoidal encoding.
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads,
                                 dropout=dropout,
                                 forward_expansion=forward_expansion)
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode token ids ``x`` of shape (N, seq_length).

        Args:
            x: Integer token ids, shape (N, seq_length).
            mask: Attention mask passed unchanged to every layer (may be None).

        Returns:
            Tensor of shape (N, seq_length, embed_size).

        Raises:
            IndexError: If ``seq_length`` exceeds ``max_length``.
        """
        _, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                f"sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Build the ids on the input's device. The (seq_length, embed_size)
        # position embedding broadcasts over the batch axis.
        positions = torch.arange(seq_length, device=x.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            # Self-attention: values, keys and queries all come from `out`.
            out = layer(out, out, out, mask)

        return out

class DecoderBlock(nn.Module):
    """Masked self-attention over the target, then cross-attention + FFN.

    Args:
        embed_size: Model width.
        heads: Number of attention heads.
        forward_expansion: Feed-forward expansion factor.
        dropout: Dropout probability applied to sublayer outputs.
        device: Accepted for signature compatibility; not used.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, v, k, src_mask, trg_mask):
        """Decode one layer.

        Args:
            x: Target-side hidden states (used as queries, keys and values of
                the self-attention).
            v, k: Encoder outputs used as cross-attention values/keys.
            src_mask: Mask for the cross-attention.
            trg_mask: Mask for the target self-attention.
        """
        # Dropout on the sublayer output before the residual add
        # (post-norm, Vaswani et al., Sec. 5.4).
        query = self.norm(x + self.dropout(self.attention(x, x, x, trg_mask)))
        return self.transformer_block(v, k, query, src_mask)

class Decoder(nn.Module):
    """Target embedding + learned positions, DecoderBlocks, and a vocabulary projection.

    Args:
        trg_vocab_size: Target vocabulary size (width of the output logits).
        embed_size: Model width.
        num_layers: Number of stacked ``DecoderBlock`` layers.
        heads: Attention heads per layer.
        forward_expansion: Feed-forward expansion factor.
        dropout: Dropout probability.
        device: Stored for backward compatibility. Position ids are created
            on the input's device.
        max_length: Number of learned positions (longest accepted input).
    """

    def __init__(self, trg_vocab_size, embed_size, num_layers, heads,
                 forward_expansion, dropout, device, max_length):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList([
            DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode target ids ``x`` of shape (N, seq_length) into logits.

        Args:
            x: Integer target token ids, shape (N, seq_length).
            enc_out: Encoder output, used as cross-attention keys and values.
            src_mask: Cross-attention mask.
            trg_mask: Target self-attention mask.

        Returns:
            Logits of shape (N, seq_length, trg_vocab_size).

        Raises:
            IndexError: If ``seq_length`` exceeds ``max_length``.
        """
        _, seq_length = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_length > max_length:
            raise IndexError(
                f"sequence length {seq_length} exceeds max_length {max_length}"
            )

        # Ids on the input's device. The (seq_length, embed_size) embedding
        # broadcasts over the batch axis.
        positions = torch.arange(seq_length, device=x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(x)

class Transformer(nn.Module):
    """Encoder-decoder Transformer returning target-vocabulary logits.

    Args:
        src_vocab: Source vocabulary size.
        trg_vocab: Target vocabulary size.
        src_pad_idx: Padding token id in source sequences.
        trg_pad_idx: Padding token id in target sequences.
        embed_size, num_layers, forward_expansion, heads, dropout, max_length:
            Hyper-parameters forwarded to ``Encoder`` and ``Decoder``.
        device: Device handed to ``Encoder``/``Decoder``. ``None`` selects CUDA
            when available, otherwise CPU. Parameters are not moved here; call
            ``.to(device)`` on the model.
    """

    def __init__(self, src_vocab, trg_vocab, src_pad_idx, trg_pad_idx,
                 embed_size=256, num_layers=6, forward_expansion=4, heads=8,
                 dropout=0, device=None, max_length=100):
        super().__init__()
        # The default used to be a notebook-global ``device``, which raised
        # NameError unless that global existed before this class was defined.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Use the constructor's own vocab sizes, not same-named globals.
        self.encoder = Encoder(src_vocab, embed_size, num_layers, heads,
                               device, forward_expansion, dropout, max_length)
        self.decoder = Decoder(trg_vocab, embed_size, num_layers, heads,
                               forward_expansion, dropout, device, max_length)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a bool mask (N, 1, 1, src_len) that is True for non-padding source tokens.

        The singleton axes broadcast over heads and query positions.
        """
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a bool mask (N, 1, trg_len, trg_len) for target self-attention.

        Position (i, j) is True when j <= i (causal) and token j is not padding.
        """
        _, trg_len = trg.shape
        causal = torch.tril(
            torch.ones((trg_len, trg_len), dtype=torch.bool, device=trg.device)
        )
        # (N, 1, 1, trg_len): hide padded keys, which previously went unmasked
        # because trg_pad_idx was never used.
        not_pad = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        return causal & not_pad

    def forward(self, src, trg):
        """Map source ids (N, src_len) and target ids (N, trg_len) to logits (N, trg_len, trg_vocab)."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, src_mask, trg_mask)

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Two toy source sequences and their targets; token id 0 is padding.
x = torch.tensor(
    [[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]], device=device
)
trg = torch.tensor(
    [[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]], device=device
)

src_pad_idx = trg_pad_idx = 0
src_vocab_size = trg_vocab_size = 10

model = Transformer(
    src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device
).to(device)

# Decoder input is the target without its last token, so the logits have
# shape (batch, trg_len - 1, trg_vocab_size).
out = model(x, trg[:, :-1])
print(out.shape)

cpu
torch.Size([2, 7, 10])


In [32]:
# Display the learned positional-embedding table: its module repr and weight matrix.
(
    model.encoder.position_embedding,
    model.encoder.position_embedding.weight,
)

(Embedding(100, 256),
 Parameter containing:
 tensor([[-0.6555,  1.0932, -0.9335,  ...,  0.3213, -0.3497, -1.2016],
         [ 0.3508,  1.8792, -0.1046,  ...,  0.0624,  0.8227, -0.4507],
         [-1.6179, -0.1419,  1.2900,  ..., -0.0896, -0.3227,  0.3197],
         ...,
         [ 0.8739, -0.2360, -1.9173,  ..., -2.3945, -0.7202, -0.3905],
         [-0.7188, -0.1021,  0.5598,  ..., -0.3301,  2.1073,  0.9309],
         [-0.4889, -0.8939, -1.7693,  ..., -0.4780, -0.2769,  0.1592]],
        requires_grad=True))