# Tensorflow Encoder와 Decoder 구현해보기
미흡한 사항
- 텍스트에 대해 vocab_size를 지정하지 않음
- 모델의 masking 처리를 제대로 해야

In [None]:
# 데이터만 keras에서 받아오기
import tensorflow as tf
url = "https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"
path = tf.keras.utils.get_file("spa-eng.zip", origin=url, cache_dir="datasets",
                               extract=True)
text = (Path(path).with_name("spa-eng") / "spa.txt").read_text()

In [3]:
text[:100]

'Go.\tVe.\nGo.\tVete.\nGo.\tVaya.\nGo.\tVáyase.\nHi.\tHola.\nRun!\t¡Corre!\nRun.\tCorred.\nWho?\t¿Quién?\nFire!\t¡Fueg'

In [4]:
import numpy as np

text = text.replace("¡", "").replace("¿", "")
pairs = [line.split("\t") for line in text.splitlines()]
np.random.seed(42)  # extra code – ensures reproducibility on CPU
np.random.shuffle(pairs)
sentences_en, sentences_es = zip(*pairs)  # separates the pairs into 2 lists

In [5]:
for i in range(3):
    print(sentences_en[i], "=>", sentences_es[i])

How boring! => Qué aburrimiento!
I love sports. => Adoro el deporte.
Would you like to swap jobs? => Te gustaría que intercambiemos los trabajos?


In [6]:
# 글자모음
print('영어문장 : ', ''.join(sorted(''.join(set(''.join(sentences_en))))))
print('스페인어문장 : ', ''.join(sorted(''.join(set(''.join(sentences_es))))))


영어문장 :   !"$%'+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz°áãèéêóöüč‘’₂€
스페인어문장 :   !"$%&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz ¨ª«°º»ÁÉÍÓÚáåèéêíñóöúüčśс​—₂€


In [7]:
# puctuation을 제거 후 진행
import string
def delete_puctuation(text, punc=string.punctuation):
    return text.translate(str.maketrans("", "", punc))

In [8]:
punc_en = string.punctuation+'°áãèéêóöüč‘’₂€'
punc_es = string.punctuation+'¨ª«°º»åèêöüč​—₂€'

word_list_en = [delete_puctuation(doc,punc_en).strip().split() for doc in sentences_en]
word_list_es = [delete_puctuation(doc,punc_es).strip().split() for doc in sentences_es]

In [9]:
word_list_es_input = [['<SOS>'] + doc for doc in word_list_es]
word_list_es_output = [doc +['<EOS>'] for doc in word_list_es]

In [10]:
print(''.join(set(''.join(np.concatenate(word_list_en)))))
print(''.join(set(''.join(np.concatenate(word_list_es)))))

0qpkxz58ZNoAwDV6t92IyTOgjmcdnJuBbLF37XMPr1RhKWfGeasQ4HiCvYSlUE
Ú0qÍpzxék5ñNÓ8oíZAwDV6t92IyTÁOgcmjdnBuJbLóF3X7MPr1RhKfWÉGeasQс4iCHvYáSślUúE


In [11]:
# 모든 단어에 인덱스 부여 (??)
en_dict = dict(zip(np.unique(np.concatenate(word_list_en)), range(1, 1+len(np.unique(np.concatenate(word_list_en))))))
es_dict = dict(zip(np.unique(np.concatenate(word_list_es+[['<SOS>','<EOS>']])), range(1, 3+len(np.unique(np.concatenate(word_list_es))))))

In [12]:
en_mapping = lambda x: [en_dict[word] for word in x]
es_mapping = lambda x: [es_dict[word] for word in x]

print(word_list_en[:5])
print([en_mapping(x) for x in word_list_en[:5]])

[['How', 'boring'], ['I', 'love', 'sports'], ['Would', 'you', 'like', 'to', 'swap', 'jobs'], ['My', 'mother', 'did', 'nothing', 'but', 'weep'], ['Croatia', 'is', 'in', 'the', 'southeastern', 'part', 'of', 'Europe']]
[[1478, 4423], [1500, 9485, 13200], [3165, 15243, 9352, 14108, 13686, 9007], [1997, 9989, 6204, 10266, 4663, 14945], [861, 8954, 8650, 13949, 13097, 10672, 10358, 1097]]


In [13]:
def pad_sequence(word_list, max_len, lang='en'):
    try:
        if lang == 'en':
            x = np.asarray(en_mapping(word_list))
        elif lang == 'es':
            x = np.asarray(es_mapping(word_list))
    except:
        print(word_list)
    if len(x) < max_len:
        x = np.concatenate([np.zeros(max_len - len(x)), x])
    return x[len(x) - max_len:]

In [14]:
max_length = 50
vec_list_en = np.array(list(map(lambda x:pad_sequence(x,max_length,'en'),word_list_en)))
vec_list_es_input = np.array(list(map(lambda x:pad_sequence(x,max_length,'es'),word_list_es_input)))
vec_list_es_output = np.array(list(map(lambda x:pad_sequence(x,max_length,'es'),word_list_es_output)))


In [15]:
# decoder의 입력은 <start>로 시작
# decoder의 출력은 <end>로 끝
X_train_enc = vec_list_en[:10000]
X_test_enc = vec_list_en[10000:]
X_train_dec = vec_list_es_input[:10000]
X_test_dec = vec_list_es_input[10000:]
y_train_dec = vec_list_es_output[:10000]
y_test_dec = vec_list_es_output[10000:]

# Encoder

In [16]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim import Adam

In [17]:
# sine/cosine positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, max_len, embed_size, dtype=torch.float32, **kwargs):
        super().__init__(**kwargs)
        p,i = torch.meshgrid(torch.arange(max_len), 2*torch.arange(embed_size//2))
        self.pos_emb = torch.zeros(1, max_len, embed_size)
        self.pos_emb[:, :, 0::2] = torch.sin(p / 10_000 ** (i / embed_size))
        self.pos_emb[:, :, 1::2] = torch.cos(p / 10_000 ** (i / embed_size))
        self.pos_emb = self.pos_emb.type(dtype)

    def forward(self, x):
        _, batch_max_length = x.size()
        return self.pos_emb[:,:batch_max_length, :]
        
class TokenAndPositionEmbedding(nn.Module):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = PositionalEncoding(maxlen, embed_dim)
        # 논문에서처럼 고정된 positional embedding을 사용하지 않고 학습 가능한 embedding을 사용할 경우
        # self.pos_emb = nn.Embedding(maxlen, embed_dim)
        
    def forward(self, x):
        maxlen = x.shape[-1]
        positions = torch.arange(start=0, end=maxlen, dtype=torch.long)
        # 다음 부분이 없으면 position은 model.to(device) 해도 gpu로 넘어가지 않음
        positions = positions.to(x.device)
        positions = positions.unsqueeze(0).expand(x.shape)
        x = self.token_emb(x) + self.pos_emb(positions).to(x.device)
        return x


In [18]:
class EncoderBlock(nn.Module):
    def __init__(self, embed_dim, heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        # MultiheadAttention의 forward : (query, key, value) -> (output, output_weights)
        # src shape: (seq_len, batch_size, d_model)
        # src_mask shape: (seq_len, seq_len)
        attn_output = self.attention(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        attn_output = self.dropout1(attn_output)
        out1 = self.norm1(x + attn_output)
        ff_output = self.ff(out1)
        ff_output = self.dropout2(ff_output)
        out2 = self.norm2(out1 + ff_output)
        return out2

# Decoder
## masking

In [19]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention1 = nn.MultiheadAttention(embed_dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention2 = nn.MultiheadAttention(embed_dim, heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
    def forward(self, x, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # MultiheadAttention의 forward : (query, key, value) -> (output, output_weights)
        # tgt shape: (seq_len, batch_size, d_model)
        # memory shape: (memory_len, batch_size, d_model), from encoder
        # tgt_mask shape: (seq_len, seq_len)
        # memory_mask shape: (seq_len, memory_len)
        # tgt_key_padding_mask shape: (batch_size, seq_len)
        # memory_key_padding_mask shape: (batch_size, memory_len)
        attn_output1 = self.attention1(x, x, x, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        attn_output1 = self.dropout1(attn_output1)
        out1 = self.norm1(x + attn_output1)
        attn_output2 = self.attention2(out1, memory, memory,attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask )[0]
        attn_output2 = self.dropout2(attn_output2)
        out2 = self.norm2(out1 + attn_output2)
        ff_output = self.ff(out2)
        ff_output = self.dropout3(ff_output)
        out3 = self.norm3(out2 + ff_output)
        return out3



In [31]:
class Transformer(nn.Module):
    def __init__(self, maxlen, vocab_size_enc, vocab_size_dec, embed_dim, num_heads, ff_dim, encoder_stack=6, decoder_stack=6, dropout=0.1,
                encoder_masking=True, decoder_masking=True, memory_masking=True):
        super().__init__()
        self.encoder_masking = encoder_masking
        self.decoder_masking = decoder_masking
        self.memory_masking = memory_masking
        self.embedding_layer_enc = TokenAndPositionEmbedding(maxlen, vocab_size_enc, embed_dim)
        self.embedding_layer_dec = TokenAndPositionEmbedding(maxlen, vocab_size_dec, embed_dim)
        self.transformer_encoder = nn.ModuleList([EncoderBlock(embed_dim,num_heads,ff_dim) for _ in range(encoder_stack)])
        self.transformer_decoder = nn.ModuleList([DecoderBlock(embed_dim,num_heads,ff_dim) for _ in range(decoder_stack)])
        self.fc = nn.Linear(embed_dim, vocab_size_dec)
        self.dropout = nn.Dropout(dropout)
        

    def forward(self, src, tgt):
        # src_mask = self.make_src_mask(src)
        tgt_mask = self.prediction_masking(tgt)
        src_emb = self.embedding_layer_enc(src)
        tgt_emb = self.embedding_layer_dec(tgt)
        # memory_mask = self.make_src_mask(src)
        if self.encoder_masking==False:
            src_mask = None
        if self.decoder_masking==False:
            tgt_mask = None
        if self.memory_masking==False:
            memory_mask = None
        for i in range(len(self.transformer_encoder)):
            # print(i)
            # print(src_emb.size())
            # print(src_mask.size())
            src_emb = self.transformer_encoder[i](src_emb, src_mask=src_mask)
        for i in range(len(self.transformer_decoder)):
            # print(tgt_mask.shape)
            tgt_emb = self.transformer_decoder[i](tgt_emb, src_emb, tgt_mask=tgt_mask, memory_mask=memory_mask)
        output = self.fc(tgt_emb)
        return output
    
    # def make_src_mask(self, src):
    #     src_mask = (src != 0).unsqueeze(-2)
    #     return src_mask
    # def make_tgt_mask(self, tgt):
    #     tgt_pad_mask = (tgt != 0).unsqueeze(-2)
    #     tgt_len = tgt.shape[1]
    #     tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
    #     tgt_mask = tgt_pad_mask & tgt_sub_mask
    #     return tgt_mask
    def prediction_masking(self, tgt):
        tgt_len = tgt.shape[0]
        tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
        return tgt_sub_mask

In [38]:
vocab_size_enc = 1+len(np.unique(np.concatenate(word_list_en)))
vocab_size_dec = 1+len(np.unique(np.concatenate(word_list_es)))

enc_temp = torch.Tensor(X_train_enc[:16]).long()
dec_temp = torch.Tensor(X_train_dec[:16]).long()
dec_y_temp = torch.Tensor(y_train_dec[:16]).long()
# src shape: (seq_len, batch_size, d_model)
# src_mask shape: (seq_len, seq_len)
transformer = Transformer(maxlen=50, vocab_size_enc=vocab_size_enc, vocab_size_dec=vocab_size_dec, 
                            embed_dim=128, num_heads=8, ff_dim=512, encoder_stack=6,decoder_stack=6, dropout=0.1,
                            encoder_masking=False, decoder_masking=True, memory_masking=False)
output = transformer(enc_temp, dec_temp)

output.shape, output.argmax(2).shape, dec_y_temp.shape

(torch.Size([16, 50, 29015]), torch.Size([16, 50]), torch.Size([16, 50]))

In [28]:
trainDS = TensorDataset(torch.Tensor(X_train_enc).long(), torch.Tensor(X_train_dec).long(), torch.Tensor(y_train_dec).long())
testDS = TensorDataset(torch.Tensor(X_test_enc).long(), torch.Tensor(X_test_dec).long(), torch.Tensor(y_test_dec).long())
trainDL = DataLoader(trainDS, batch_size=64, shuffle=True)
testDL = DataLoader(testDS, batch_size=64, shuffle=False)

In [29]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
# train
# 시간이 오래걸리는 관계로 생략
epochs = 10
transformer = transformer.to(device)

def acc_cal(y_true, y_pred):
    '''
    y_true shape : batch_size, maxlen(50)
    y_pred shape : batch_size, maxlen(50)
    '''
    result_TF = np.where(dec_y_temp!=0,output.argmax(2)==dec_y_temp, np.nan).flatten()
    result_TF = result_TF[~np.isnan(result_TF)]
    return result_TF.sum(), len(result_TF)


for epoch in range(epochs):
    train_all_correct = 0
    train_all_length = 0
    train_loss = 0
    transformer.train()
    for x_enc, x_dec, y_enc in trainDL:
        x_enc, x_dec, y_enc = x_enc.to(device), x_dec.to(device), y_enc.to(device)
        optimizer.zero_grad()
        y_pred = transformer(x_enc,x_dec)
        loss = criterion(y_pred.permute(0,2,1), y_enc)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*x_enc.size(0)
        y_pred_argmax = y_pred.argmax(2)
        train_correct, train_length = acc_cal(y_enc, y_pred_argmax)
        train_all_correct += train_correct
        train_all_length += train_length
    train_acc = train_all_correct/train_all_length
    train_loss /= len(trainDL.dataset)
    print(f'Epoch {epoch+1} of {epochs}')
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%')

    val_all_correct = 0
    val_all_length = 0
    val_loss = 0
    transformer.eval()
    with torch.no_grad():
        for x_enc, x_dec, y_enc in testDL:
            x_enc, x_dec, y_enc = x_enc.to(device), x_dec.to(device), y_enc.to(device)
            y_pred = transformer(x_enc,x_dec)
            loss = criterion(y_pred.permute(0,2,1), y_enc)
            val_loss += loss.item()*x_enc.size(0)
            y_pred_argmax = y_pred.argmax(2)
            val_correct, val_length = acc_cal(y_enc, y_pred_argmax)
            val_all_correct += val_correct
            val_all_length += val_length
    val_acc = val_all_correct/val_all_length
    val_loss /= len(testDL.dataset)
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%')
    print('---------------------------------')
