# Transformer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Input Layer

In [None]:
# Dimensions of the dummy input tensor: (batch, sequence length, word-vector width).
batch_size, seq, word_vector = 1, 4, 10

In [None]:
# torch.Tensor(*sizes) allocates *uninitialised* memory (hence the garbage values
# such as -1e18 that used to be printed); sample random values instead so the
# demo tensor is well-defined. A plain tensor like this is not learnable.
input = torch.randn(batch_size, seq, word_vector)

In [None]:
input

tensor([[[ 1.6816e-44,  0.0000e+00, -1.0133e+18,  3.2200e-41,  1.5891e-42,
           3.2200e-41, -1.0134e+18,  3.2200e-41,  1.6816e-44,  3.2200e-41],
         [-1.0133e+18,  3.2200e-41,  1.5905e-42,  3.2200e-41, -1.0134e+18,
           3.2200e-41,  1.6816e-44,  3.2200e-41, -1.0134e+18,  3.2200e-41],
         [ 1.5961e-42,  3.2200e-41, -1.0134e+18,  3.2200e-41,  1.6816e-44,
           3.2200e-41, -1.0134e+18,  3.2200e-41,  3.3631e-44,  3.2200e-41],
         [-8.6093e-34,  4.5877e-41,  1.4013e-45,  3.2200e-41,  1.5414e-44,
           1.7830e+19,  1.4013e-45,  1.1210e-44,  1.4013e-45,  1.8217e-44]]])

In [None]:
input_sentence = "i am a dream"
# Whitespace tokenisation yields exactly the hand-written token list.
tokenized_input = input_sentence.split()
# Hand-picked vocabulary ids, one per token (no real vocabulary is built here).
input_idx = [1, 34, 7, 45]

In [None]:
input_tensor = torch.tensor(input_idx)

In [None]:
word_emb = nn.Embedding(100,10) # learnable lookup table: 100 token ids -> 10-dim vectors

In [None]:
token_emb = word_emb(input_tensor)

## Positional Encoding

In [None]:
#position_enc = nn.Embedding(MAX_TOKEN, WORD_DIM)

In [None]:
position_idx = [0,1,2,3]
position_tensor = torch.tensor(position_idx)

In [None]:
#position_emb = position_enc(position_tensor)

In [None]:
#final_emb = token_emb+position_emb

In [None]:
#final_emb

In [None]:
class Config():
    """Hyper-parameters for the embedding experiments in this section."""
    num_dict = 100                # vocabulary size (rows of the token embedding table)
    dim_token_emb = 10            # width of each token / position embedding vector
    max_position_embeddings = 20  # number of learnable position slots

In [None]:
config = Config()

In [None]:
class Embedding(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        config: object exposing ``num_dict`` (vocabulary size),
            ``dim_token_emb`` (embedding width) and
            ``max_position_embeddings`` (number of position slots).
    """

    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.num_dict, config.dim_token_emb)
        # Attribute name (spelling included) kept so existing state_dict keys still load.
        self.position_embbeddings = nn.Embedding(config.max_position_embeddings, config.dim_token_emb)

    def forward(self, input):
        """Embed a tensor of token ids.

        Args:
            input: LongTensor of token ids, shape ``(seq,)`` or ``(batch, seq)``.

        Returns:
            ``(1, seq, dim)`` for 1-D input, ``(batch, seq, dim)`` for 2-D input.
        """
        word_emb = self.token_embeddings(input)
        # The sequence axis is the last one; size(0) would be the batch size
        # for batched input and produce wrongly-sized position embeddings.
        position_len = input.size(-1)
        position_idx = torch.arange(position_len, dtype=torch.long, device=input.device).unsqueeze(0)
        position_emb = self.position_embbeddings(position_idx)  # (1, seq, dim), broadcast over batch
        return word_emb + position_emb

In [None]:
emb_layer = Embedding(config)

In [None]:
emb = emb_layer(input_tensor)

In [None]:
emb.shape

torch.Size([1, 4, 10])

## Multi-Head Attention

In [None]:
class SelfAttention(nn.Module):
    """Single self-attention head projecting ``emb_dim`` features to ``out_dim``.

    Args:
        config: object exposing ``emb_dim`` and ``out_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.query = nn.Linear(config.emb_dim, config.out_dim)
        self.key = nn.Linear(config.emb_dim, config.out_dim)
        self.value = nn.Linear(config.emb_dim, config.out_dim)

    def forward(self, input):
        """Attend over a 3-D ``input`` (batch, seq, emb_dim); returns (batch, seq, out_dim)."""
        q = self.query(input)
        k = self.key(input)
        v = self.value(input)

        # Scale by sqrt(d_k), the per-head feature width (last axis);
        # q.size(1) is the sequence length and gives the wrong temperature.
        att_score = F.softmax(torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1)), -1)
        return torch.bmm(att_score, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Run ``config.num_head`` independent SelfAttention heads and concatenate them.

    The output width is ``config.num_head * config.out_dim`` on the last axis.
    """

    def __init__(self, config):
        super().__init__()
        # One module per head, so each head has its own query/key/value weights.
        # Reusing a single SelfAttention made every head compute the same output.
        self.heads = nn.ModuleList([SelfAttention(config) for _ in range(config.num_head)])

    def forward(self, input):
        # Uses the heads built from the constructor's config, not a global `config`.
        return torch.cat([head(input) for head in self.heads], dim=-1)

In [None]:
emb_dim = emb.size(2)

In [None]:
num_head = 2

In [None]:
query1 = nn.Linear(emb_dim, emb_dim//num_head)
key1 = nn.Linear(emb_dim, emb_dim//num_head)
value1 = nn.Linear(emb_dim, emb_dim//num_head)

query2 = nn.Linear(emb_dim, emb_dim//num_head)
key2 = nn.Linear(emb_dim, emb_dim//num_head)
value2 = nn.Linear(emb_dim, emb_dim//num_head)

In [None]:
q1 = query1(emb)
k1 = key1(emb)
v1 = value1(emb)

q2 = query2(emb)
k2 = key2(emb)
v2 = value2(emb)

In [None]:
# Head 1: scaled dot-product attention. Scale by sqrt(d_k), the head width on the
# last axis; q1.size(1) is the sequence length.
att_score1 = F.softmax(torch.bmm(q1, k1.transpose(1,2)) / math.sqrt(q1.size(-1)), -1)
self_att_rst1 = torch.bmm(att_score1, v1)

In [None]:
# Head 2: same computation with its own projections; scale by the head width
# (last axis), not the sequence length.
att_score2 = F.softmax(torch.bmm(q2, k2.transpose(1,2)) / math.sqrt(q2.size(-1)), -1)
self_att_rst2 = torch.bmm(att_score2, v2)

In [None]:
self_att_rst1

tensor([[[-0.5290,  0.1782, -0.3014,  0.3114,  0.3776],
         [-0.5913,  0.1063, -0.2973,  0.2388,  0.3592],
         [-0.6100,  0.2929, -0.3398,  0.4293,  0.5349],
         [ 0.0471, -0.1012, -0.0521,  0.0746, -0.2128]]],
       grad_fn=<BmmBackward0>)

In [None]:
self_att_rst2

tensor([[[-0.1992,  0.5415,  0.2673, -0.2996, -0.2124],
         [-0.1681,  0.5240,  0.2687, -0.2511, -0.3574],
         [-0.4314,  0.7083,  0.2692, -0.3083, -0.5256],
         [-0.1382,  0.4809,  0.3331, -0.1889, -0.5401]]],
       grad_fn=<BmmBackward0>)

In [None]:
concat_rst = torch.cat((self_att_rst1, self_att_rst2), -1)

In [None]:
concat_rst.size()

torch.Size([1, 4, 10])

In [None]:
class FF(nn.Module):
    """Position-wise feed-forward network: Linear -> ELU -> Linear.

    Expands from ``config.emb_dim`` to ``config.hid_dim`` and projects back.
    """

    def __init__(self, config):
        super().__init__()
        width, inner = config.emb_dim, config.hid_dim
        self.lin1 = nn.Linear(width, inner)
        self.act = nn.ELU()
        self.lin2 = nn.Linear(inner, width)

    def forward(self, input):
        return self.lin2(self.act(self.lin1(input)))

In [None]:
class TransformerEncoder(nn.Module):
    """One post-norm encoder layer.

    Multi-head self-attention and a feed-forward network, each followed by a
    residual add and LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        # LayerNorm over the embedding (last) axis. BatchNorm1d(hidden_size) used
        # axis 1 -- the sequence positions -- as channels, so it only worked when
        # seq_len == hidden_size and shared statistics across the batch.
        self.norm1 = nn.LayerNorm(config.emb_dim)
        self.norm2 = nn.LayerNorm(config.emb_dim)
        self.attention = MultiHeadAttention(config)
        self.ff = FF(config)

    def forward(self, input):
        attn_out = self.attention(input)
        nor_rst1 = self.norm1(attn_out + input)
        f_rst = self.ff(nor_rst1)
        return self.norm2(f_rst + nor_rst1)

In [None]:
add_rst = concat_rst + emb

In [None]:
# Normalise each token's feature vector (last axis). BatchNorm1d(4) treated the
# 4 sequence positions as channels and only works for length-4 sequences.
normal = nn.LayerNorm(add_rst.size(-1))

In [None]:
add_nor_rst = normal(add_rst)

In [None]:
fnn_dim = add_nor_rst.size(-1)

In [None]:
hid_dim = 50

In [None]:
lin1 = nn.Linear(fnn_dim, hid_dim)
act =  nn.ELU()
lin2 = nn.Linear(hid_dim, fnn_dim)

In [None]:
fdd_rst = lin2(act(lin1(add_nor_rst)))

In [None]:
f_add_rst = fdd_rst + add_nor_rst

In [None]:
f_add_nor_rst = normal(f_add_rst)

In [None]:
f_add_nor_rst

tensor([[[-1.2386,  0.5033, -1.4657,  0.3807,  0.3138,  0.1966,  2.0507,
           0.7117, -0.8151, -0.6374],
         [-2.0080,  0.5758, -0.3328, -0.8314, -0.4042,  0.2466,  2.0832,
           0.3920, -0.0581,  0.3370],
         [-0.9015,  0.5577,  0.2818, -0.0715,  2.5628, -0.4882, -0.2587,
          -1.1025,  0.2271, -0.8070],
         [ 0.4371,  0.5421,  0.0398, -1.6878, -0.9826, -1.0605,  2.0865,
           0.4209, -0.0043,  0.2086]]], grad_fn=<NativeBatchNormBackward0>)

# Modularization (모듈화)

In [None]:
class Config():
    """Hyper-parameters shared by the modularised encoder/decoder classes."""
    max_position_embeddings = 20  # number of learnable position slots
    dim_token_emb = 10            # width of the vectors produced by Embedding
    num_dict = 100                # vocabulary size
    # Attention/FF consume Embedding's output, so their width must equal it.
    emb_dim = dim_token_emb
    num_head = 2
    if emb_dim % num_head:
        raise ValueError("emb_dim must be divisible by num_head")
    out_dim = emb_dim // num_head  # per-head width; heads concatenate back to emb_dim
    hidden_size = 4  # channel count for BatchNorm1d(hidden_size) layers; equals the demo sequence length
    hid_dim = 50     # feed-forward inner width

In [None]:
class Embedding(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        config: object exposing ``num_dict`` (vocabulary size),
            ``dim_token_emb`` (embedding width) and
            ``max_position_embeddings`` (number of position slots).
    """

    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.num_dict, config.dim_token_emb)
        # Attribute name (spelling included) kept so existing state_dict keys still load.
        self.position_embbeddings = nn.Embedding(config.max_position_embeddings, config.dim_token_emb)

    def forward(self, input):
        """Embed a tensor of token ids.

        Args:
            input: LongTensor of token ids, shape ``(seq,)`` or ``(batch, seq)``.

        Returns:
            ``(1, seq, dim)`` for 1-D input, ``(batch, seq, dim)`` for 2-D input.
        """
        word_emb = self.token_embeddings(input)
        # The sequence axis is the last one; size(0) would be the batch size
        # for batched input and produce wrongly-sized position embeddings.
        position_len = input.size(-1)
        position_idx = torch.arange(position_len, dtype=torch.long, device=input.device).unsqueeze(0)
        position_emb = self.position_embbeddings(position_idx)  # (1, seq, dim), broadcast over batch
        return word_emb + position_emb

In [None]:
class SelfAttention(nn.Module):
    """Single self-attention head projecting ``emb_dim`` features to ``out_dim``.

    Args:
        config: object exposing ``emb_dim`` and ``out_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.query = nn.Linear(config.emb_dim, config.out_dim)
        self.key = nn.Linear(config.emb_dim, config.out_dim)
        self.value = nn.Linear(config.emb_dim, config.out_dim)

    def forward(self, input):
        """Attend over a 3-D ``input`` (batch, seq, emb_dim); returns (batch, seq, out_dim)."""
        q = self.query(input)
        k = self.key(input)
        v = self.value(input)

        # Scale by sqrt(d_k), the per-head feature width (last axis);
        # q.size(1) is the sequence length and gives the wrong temperature.
        att_score = F.softmax(torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1)), -1)
        return torch.bmm(att_score, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Run ``config.num_head`` independent SelfAttention heads and concatenate them.

    The output width is ``config.num_head * config.out_dim`` on the last axis.
    """

    def __init__(self, config):
        super().__init__()
        # One module per head, so each head has its own query/key/value weights.
        # Reusing a single SelfAttention made every head compute the same output.
        self.heads = nn.ModuleList([SelfAttention(config) for _ in range(config.num_head)])

    def forward(self, input):
        # Uses the heads built from the constructor's config, not a global `config`.
        return torch.cat([head(input) for head in self.heads], dim=-1)

In [None]:
class FF(nn.Module):
    """Position-wise feed-forward network: Linear -> ELU -> Linear.

    Expands from ``config.emb_dim`` to ``config.hid_dim`` and projects back.
    """

    def __init__(self, config):
        super().__init__()
        width, inner = config.emb_dim, config.hid_dim
        self.lin1 = nn.Linear(width, inner)
        self.act = nn.ELU()
        self.lin2 = nn.Linear(inner, width)

    def forward(self, input):
        return self.lin2(self.act(self.lin1(input)))

In [None]:
class TransformerEncoder(nn.Module):
    """One post-norm encoder layer.

    Multi-head self-attention and a feed-forward network, each followed by a
    residual add and LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        # LayerNorm over the embedding (last) axis. BatchNorm1d(hidden_size) used
        # axis 1 -- the sequence positions -- as channels, so it only worked when
        # seq_len == hidden_size and shared statistics across the batch.
        self.norm1 = nn.LayerNorm(config.emb_dim)
        self.norm2 = nn.LayerNorm(config.emb_dim)
        self.attention = MultiHeadAttention(config)
        self.ff = FF(config)

    def forward(self, input):
        attn_out = self.attention(input)
        nor_rst1 = self.norm1(attn_out + input)
        f_rst = self.ff(nor_rst1)
        return self.norm2(f_rst + nor_rst1)

In [None]:
class Encoder(nn.Module):
    """Token + position embedding followed by a single TransformerEncoder layer."""

    def __init__(self, config):
        super().__init__()
        self.emb = Embedding(config)
        self.trans = TransformerEncoder(config)

    def forward(self, input):
        embedded = self.emb(input)
        return self.trans(embedded)

# Decoder from scratch (스크래치)

In [None]:
config =Config()

In [None]:
encoder = Encoder(config)

In [None]:
enc_rst = encoder(input_tensor)

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Accepts any matching leading dims, e.g. ``(seq, d)``, ``(batch, seq, d)`` or
    ``(batch, heads, seq, d)``; for 3-D input this equals the previous bmm form.

    Args:
        query: tensor of shape (..., seq_q, d_k).
        key: tensor of shape (..., seq_k, d_k).
        value: tensor of shape (..., seq_k, d_v).
        mask: optional tensor broadcastable to (..., seq_q, seq_k); positions
            where ``mask == 0`` receive -inf before the softmax. A row that is
            entirely masked yields NaN.

    Returns:
        Tensor of shape (..., seq_q, d_v).
    """
    dim_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [None]:
emb_layer = Embedding(config)

In [None]:
emb = emb_layer(input_tensor)

In [None]:
emb_dim=emb.size(2)

In [None]:
query = nn.Linear(emb_dim, emb_dim)
key = nn.Linear(emb_dim, emb_dim)
value = nn.Linear(emb_dim, emb_dim)

In [None]:
q = query(emb)
k = key(emb)
v = value(emb)

In [None]:
q.size(-2)

4

In [None]:
import numpy as np

In [None]:
tri = np.tri(4, 4, 0)

In [None]:
tri = torch.tensor(tri)

In [None]:
f_rst= scaled_dot_product_attention(q,k,v, mask = tri)

In [None]:
f_rst.size()

torch.Size([1, 4, 10])

In [None]:
f_rst

tensor([[[ 0.2788, -0.6527,  0.9656, -0.7796, -0.1122,  1.4655, -0.5020,
          -0.1309, -0.8396,  0.0361],
         [ 0.0063,  0.4250,  0.7432,  0.2747,  0.1872, -0.3472, -0.7626,
          -0.6438,  0.2820,  0.6354],
         [-0.0512,  0.4165,  0.7532,  0.5267,  0.2384, -0.7517, -0.6756,
          -0.6845,  0.2947,  0.7588],
         [ 0.1865, -0.3995,  0.7454,  0.1813,  0.0841, -0.2564, -0.0536,
          -0.2857, -0.5558,  0.4451]]], grad_fn=<BmmBackward0>)

In [None]:
# Cross-attention: decoder states f_rst query the encoder output enc_rst, which
# serves as both keys and values. Scale by the feature width (last axis);
# f_rst.size(1) is the sequence length.
att_score = F.softmax(torch.bmm(f_rst, enc_rst.transpose(1,2)) / math.sqrt(f_rst.size(-1)), -1)
s_rst = torch.bmm(att_score, enc_rst)

In [None]:
print(s_rst.shape)

torch.Size([1, 4, 10])


In [None]:
l_dim = s_rst.size(-1)

In [None]:
hid_dim = 50

In [None]:
lin1 = nn.Linear(l_dim, hid_dim)
act =  nn.ELU()
lin2 = nn.Linear(hid_dim, l_dim)

In [None]:
l_r = lin2(act(lin1(s_rst)))

In [None]:
# Normalise over the feature (last) axis; BatchNorm1d(4) would treat the 4
# sequence positions as channels and only fit length-4 sequences.
normal = nn.LayerNorm(l_dim)

In [None]:
l_r.shape

torch.Size([1, 4, 10])

## Modularization (모듈화)

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Accepts any matching leading dims, e.g. ``(seq, d)``, ``(batch, seq, d)`` or
    ``(batch, heads, seq, d)``; for 3-D input this equals the previous bmm form.

    Args:
        query: tensor of shape (..., seq_q, d_k).
        key: tensor of shape (..., seq_k, d_k).
        value: tensor of shape (..., seq_k, d_v).
        mask: optional tensor broadcastable to (..., seq_q, seq_k); positions
            where ``mask == 0`` receive -inf before the softmax. A row that is
            entirely masked yields NaN.

    Returns:
        Tensor of shape (..., seq_q, d_v).
    """
    dim_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(dim_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [None]:
class DSelfAttention(nn.Module):
    """Single masked (causal) self-attention head for the decoder.

    Args:
        config: object exposing ``emb_dim`` and ``out_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.query = nn.Linear(config.emb_dim, config.out_dim)
        self.key = nn.Linear(config.emb_dim, config.out_dim)
        self.value = nn.Linear(config.emb_dim, config.out_dim)

    def forward(self, input):
        q = self.query(input)
        k = self.key(input)
        v = self.value(input)

        # Lower-triangular mask: position i may only attend to positions <= i,
        # so the decoder cannot look at future tokens. Without it this was plain
        # (unmasked) self-attention.
        seq_len = input.size(-2)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=input.device))
        return scaled_dot_product_attention(q, k, v, mask=causal_mask)

In [None]:
class DMultiHeadAttention(nn.Module):
    """Run ``config.num_head`` independent masked DSelfAttention heads and concatenate them.

    The output width is ``config.num_head * config.out_dim`` on the last axis.
    """

    def __init__(self, config):
        super().__init__()
        # One module per head, so each head has its own projections. Reusing a
        # single DSelfAttention made all heads produce identical outputs.
        self.heads = nn.ModuleList([DSelfAttention(config) for _ in range(config.num_head)])

    def forward(self, input):
        # Uses the heads built from the constructor's config, not a global `config`.
        return torch.cat([head(input) for head in self.heads], dim=-1)

In [None]:
class TransformerDecoder(nn.Module):
    """One post-norm decoder layer.

    The layer applies three sublayers in order, each followed by a residual add
    and LayerNorm:

    1. Masked multi-head self-attention.
    2. Cross-attention over the encoder output. This is parameter-free: the
       encoder output serves as both keys and values.
    3. A feed-forward network.
    """

    def __init__(self, config):
        super().__init__()
        # LayerNorm over the embedding axis; BatchNorm1d(hidden_size) normalised
        # over sequence positions and required seq_len == hidden_size.
        self.norm1 = nn.LayerNorm(config.emb_dim)
        self.norm2 = nn.LayerNorm(config.emb_dim)
        self.norm3 = nn.LayerNorm(config.emb_dim)
        self.attention = DMultiHeadAttention(config)
        self.ff = FF(config)

    def forward(self, input, E_output):
        """Decode ``input`` (batch, tgt_seq, emb_dim) against encoder output ``E_output``."""
        concat_rst = self.attention(input)
        nor_rst1 = self.norm1(concat_rst + input)

        # Cross-attention, scaled by the feature width (last axis), not the
        # sequence length.
        att_score = F.softmax(
            torch.bmm(nor_rst1, E_output.transpose(1, 2)) / math.sqrt(nor_rst1.size(-1)), -1
        )
        s_rst = torch.bmm(att_score, E_output)
        nor_rst2 = self.norm2(nor_rst1 + s_rst)

        f_rst = self.ff(nor_rst2)
        # The residual around the feed-forward sublayer must use that sublayer's
        # own input (nor_rst2); adding nor_rst1 discarded the cross-attention path.
        return self.norm3(f_rst + nor_rst2)

In [None]:
class Decoder(nn.Module):
    """Target-side embedding followed by one TransformerDecoder layer that also
    attends to the encoder output ``E_output``."""

    def __init__(self, config):
        super().__init__()
        self.emb = Embedding(config)
        self.trans = TransformerDecoder(config)

    def forward(self, D_input, E_output):
        return self.trans(self.emb(D_input), E_output)

In [None]:
decoder = Decoder(config)

In [None]:
decoder(input_tensor, enc_rst)

tensor([[[ 0.1298, -0.2773,  1.2603,  0.2071,  0.1946, -2.6092,  0.0442,
           1.1747,  0.0782, -0.2024],
         [ 1.4810,  1.1300, -0.2403, -1.3343, -1.7008,  0.7645,  0.3807,
           0.4151,  0.0502, -0.9460],
         [-0.4524,  0.4533,  0.9527, -0.2169,  0.6101, -2.1351,  1.5977,
           0.4272, -0.2986, -0.9380],
         [ 2.0670,  0.8960, -0.3170, -0.8146, -1.5359,  1.1211, -0.3260,
          -0.2212, -0.3585, -0.5109]]], grad_fn=<NativeBatchNormBackward0>)