https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb

https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.legacy.datasets import Multi30k
from torchtext.legacy.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

import random
import math
import time

In [2]:
# Seed every random number generator used in this notebook with one shared
# value so that repeated runs start from the same state.
SEED = 12

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True

In [3]:
# Load the spaCy pipelines whose tokenizers are used by tokenize_de / tokenize_en.
spacy_de = spacy.load("de_core_news_sm")
spacy_en = spacy.load("en_core_web_sm")

In [4]:
def tokenize_de(text):
    """Return the token strings produced by the `spacy_de` tokenizer for *text*."""
    doc = spacy_de.tokenizer(text)
    return [token.text for token in doc]

def tokenize_en(text):
    """Return the token strings produced by the `spacy_en` tokenizer for *text*."""
    doc = spacy_en.tokenizer(text)
    return [token.text for token in doc]

In [5]:
# Source and target fields: text is lowercased, tokenized with the matching
# spaCy tokenizer, and batched as [batch size, seq len] (batch_first=True).
# The start/end markers must be distinct, non-empty tokens; with '' for both,
# the beginning and the end of a sequence would share a single vocabulary
# entry and could not be told apart.
SRC = Field(
    tokenize=tokenize_de,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    batch_first=True,
)
TRG = Field(
    tokenize=tokenize_en,
    init_token='<sos>',
    eos_token='<eos>',
    lower=True,
    batch_first=True,
)

In [6]:
# Load the Multi30k train / validation / test splits; the '.de' side is
# processed with SRC and the '.en' side with TRG.
train_data, valid_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG),
)

In [7]:
# Build both vocabularies from the training split only, passing min_freq=2
# to each field.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

In [8]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [9]:
BATCH_SIZE = 128

# One iterator per split, all sharing the same batch size and device.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
)

In [10]:
class Encoder(nn.Module):
    """Transformer encoder: scaled token embeddings plus learned positional
    embeddings, followed by a stack of ``EncoderLayer`` modules.

    Args:
        input_dim: number of rows in the token embedding table.
        hid_dim: embedding / hidden size used by every layer.
        n_layers: number of ``EncoderLayer`` modules in the stack.
        n_heads: forwarded to each ``EncoderLayer``.
        pf_dim: forwarded to each ``EncoderLayer``.
        dropout: dropout probability applied to the summed embeddings and
            forwarded to each ``EncoderLayer``.
        device: device on which the scaling constant is initially created;
            also forwarded to each ``EncoderLayer``.
        max_length: number of rows in the positional embedding table, i.e.
            the longest sequence that can be embedded.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length=100):
        super().__init__()

        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList(
            [EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device) for _ in range(n_layers)]
        )

        self.dropout = nn.Dropout(dropout)

        # sqrt(hid_dim) multiplier for the token embeddings. Registered as a
        # non-persistent buffer so it follows the module through .to()/.cuda()
        # while staying out of state_dict (checkpoint keys are unchanged).
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([hid_dim])).to(device),
            persistent=False,
        )

    def forward(self, src, src_mask):
        """Embed ``src`` and run it through every encoder layer.

        ``src`` is indexed as [batch size, src len]; ``src_mask`` is passed
        unchanged to each layer. Returns the output of the last layer.
        """
        # src = [batch size, src len]
        # src_mask = [batch size, 1, 1, src len]
        src_len = src.shape[1]

        # pos = [1, src len]; created directly on the input's device and
        # broadcast over the batch dimension when added below.
        pos = torch.arange(0, src_len, device=src.device).unsqueeze(0)

        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        # src = [batch size, src len, hid dim]

        for layer in self.layers:
            src = layer(src, src_mask)

        return src
        
        

In [11]:
class EncoderLayer(nn.Module):
    """Single encoder block.

    Applies self-attention and then a position-wise feed-forward sublayer.
    The output of each sublayer goes through dropout, is added back to its
    input (residual connection) and is normalised with LayerNorm.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Return ``src`` after the attention and feed-forward sublayers."""
        # Self-attention: src serves as query, key and value; the attention
        # weights returned alongside the output are not used here.
        attended, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(attended))

        # Position-wise feed-forward with the same dropout/residual/norm wrap.
        transformed = self.positionwise_feedforward(src)
        return self.ff_layer_norm(src + self.dropout(transformed))

In [12]:
class MultiheadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dim: size of the input and output feature dimension.
        n_heads: number of attention heads; must divide ``hid_dim``.
        dropout: dropout probability applied to the attention weights.
        device: device on which the scaling constant is initially created.

    ``forward`` returns ``(x, attention)`` where ``x`` is
    [batch size, query len, hid dim] and ``attention`` is
    [batch size, n heads, query len, key len].
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        # sqrt(head_dim) divisor for the attention scores. Registered as a
        # non-persistent buffer so it follows the module through .to()/.cuda()
        # while staying out of state_dict (checkpoint keys are unchanged).
        self.register_buffer(
            "scale",
            torch.sqrt(torch.FloatTensor([self.head_dim])).to(device),
            persistent=False,
        )

    def _split_heads(self, x, batch_size):
        """Reshape [batch size, len, hid dim] to [batch size, n heads, len, head dim]."""
        return x.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]

        # query = [batch size, query len, hid dim]
        # key   = [batch size, key len, hid dim]
        # value = [batch size, value len, hid dim]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        # Q = [batch size, n heads, query len, head dim]
        # K = [batch size, n heads, key len, head dim]
        # V = [batch size, n heads, value len, head dim]

        # matmul is batched over the leading (batch, head) dimensions, so the
        # product is [query len, head dim] x [head dim, key len] per head.
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # energy = [batch size, n heads, query len, key len]

        if mask is not None:
            # Positions where mask == 0 get a large negative score so that
            # softmax assigns them (numerically) zero weight.
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim=-1)
        # attention = [batch size, n heads, query len, key len]

        x = torch.matmul(self.dropout(attention), V)
        # x = [batch size, n heads, query len, head dim]

        # permute() only changes strides; contiguous() lays the data out in
        # memory in the new order so that view() below is valid.
        x = x.permute(0, 2, 1, 3).contiguous()
        # x = [batch size, query len, n heads, head dim]

        x = x.view(batch_size, -1, self.hid_dim)
        # x = [batch size, query len, hid dim]

        x = self.fc_o(x)
        # x = [batch size, query len, hid dim]

        return x, attention


# EncoderLayer in this file instantiates the class under this capitalisation;
# keep both spellings bound to the same class.
MultiHeadAttentionLayer = MultiheadAttentionLayer

In [None]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1=nn.Linear(hid_dim, pf_dim)
        self.fc_2=nn.Linear(pf_dim, hid_dim)
        
        self.dropout=nn.Dropout(dropout)
        
    def forward(self, x):
        
        x=self.