In [None]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The ``embed_size``-dimensional input is split into ``heads`` chunks of
    ``head_dim = embed_size // heads`` features each (e.g. 256 = 8 * 32).
    Each chunk is projected by linear layers whose weights are shared across
    heads, attention is computed independently per head, and the heads are
    concatenated and mixed by ``fc_out``.

    Args:
        embed_size: Total feature dimension of the inputs and the output.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # head_dim must be an exact integer split of embed_size. Raise (rather
        # than assert) so the check is not stripped under `python -O`.
        if self.head_dim * heads != embed_size:
            raise ValueError('Embed size needs to be div by heads')

        # The projections act on a single head_dim-sized chunk, so the same
        # weights are applied to every head.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Mixes the concatenated heads back into embed_size features.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys`` and aggregate ``values``.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size); key_len must
                equal value_len (both are contracted over the same index).
            query: Tensor of shape (N, query_len, embed_size). query_len may
                differ from key_len (e.g. target vs. source sentence).
            mask: ``None`` or a tensor broadcastable to
                (N, heads, query_len, key_len); positions where
                ``mask == 0`` receive (effectively) zero attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads chunks of head_dim features.
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the learned per-head projections.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # energy[n, h, q, k] = dot(queries[n, q, h, :], keys[n, k, h, :])
        # queries: (N, query_len, heads, head_dim)
        # keys:    (N, key_len, heads, head_dim)
        # energy:  (N, heads, query_len, key_len) -- how strongly each query
        #          position attends to each key position.
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            # A large finite negative (not -inf) keeps a fully masked row's
            # softmax finite instead of NaN.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), d_k being the per-head key dimension, then
        # normalise over the key axis.
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)

        # attention: (N, heads, query_len, key_len)
        # values:    (N, value_len, heads, head_dim), value_len == key_len
        # einsum ->  (N, query_len, heads, head_dim); reshape concatenates
        #            the heads into the last dimension.
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)
    
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.Linear()
        