In [1]:
import torch

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Preparing the Data

In [2]:
from torchtext.data import Field

# Source and target fields are configured identically: spaCy tokenisation,
# lower-casing, <sos>/<eos> markers around every sequence, and batch-first
# tensors.
# NOTE(review): both fields rely on the tokenizer's default language even
# though the dataset pairs German (.de) with English (.en) -- confirm that a
# single spaCy model for both sides is intended.
_field_options = dict(
    tokenize="spacy",
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
    batch_first=True,
)

SRC = Field(**_field_options)
TRG = Field(**_field_options)

In [3]:
from torchtext.datasets import Multi30k

# Load the Multi30k splits; the ".de" side is processed by SRC and the ".en"
# side by TRG.
train_data, valid_data, test_data = Multi30k.splits(
    exts=(".de", ".en"),
    fields=(SRC, TRG),
)

# Vocabularies are built from the training split only; tokens that occur
# fewer than two times are left out of the vocabulary.
for _field in (SRC, TRG):
    _field.build_vocab(train_data, min_freq=2)

In [4]:
from torchtext.data import BucketIterator

BATCH_SIZE = 128

# One iterator per split, all using the same batch size and yielding batches
# on `device`.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
)

## Building the Model

In [6]:
import torch.nn as nn

class Encoder(nn.Module):
    """Transformer encoder: token + learned positional embeddings, then a stack of layers.

    Args:
        input_dim: Number of rows in the token embedding table.
        hid_dim: Embedding size, also passed to every ``EncoderLayer``.
        n_layers: Number of ``EncoderLayer`` modules to stack.
        n_heads: Passed through to every ``EncoderLayer``.
        pf_dim: Passed through to every ``EncoderLayer``.
        dropout: Dropout probability for the embedding sum; also passed to
            every ``EncoderLayer``.
        max_length: Number of rows in the positional embedding table, i.e.
            the longest sequence ``forward`` accepts.
    """

    def __init__(self,
                 input_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 max_length=100):
        super().__init__()

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        # Kept as a plain float so the module does not depend on a global
        # `device` and keeps working after `.to(...)` / `.cuda()` / `.cpu()`.
        self.scale = hid_dim ** 0.5

        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout)
                                     for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Encode a batch of token-index sequences.

        Args:
            src: Integer tensor indexed as ``(batch, src_len)``; its values
                index the token embedding table.
            src_mask: Forwarded unchanged to every layer.

        Returns:
            Tensor of shape ``(batch, src_len, hid_dim)``.

        Raises:
            IndexError: If ``src_len`` exceeds the size of the positional
                embedding table (``max_length``).
        """
        src_len = src.shape[1]

        max_length = self.pos_embedding.num_embeddings
        if src_len > max_length:
            # Fail early with a readable message instead of an opaque
            # embedding lookup error.
            raise IndexError(
                f"source length {src_len} exceeds the {max_length} "
                f"positions supported by the positional embedding"
            )

        # Positions live on the same device as the input; shape (1, src_len)
        # broadcasts over the batch dimension in the sum below.
        pos = torch.arange(src_len, device=src.device).unsqueeze(0)

        src = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(pos))

        for layer in self.layers:
            src = layer(src, src_mask)

        return src

In [7]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a position-wise feed-forward step.

    Each sub-layer's output goes through dropout, is added back to that
    sub-layer's input, and is then layer-normalised. A single ``LayerNorm``
    instance (and a single ``Dropout`` instance) is reused for both sub-layers.

    Args:
        hid_dim: Feature size handed to both sub-layers and to ``LayerNorm``.
        n_heads: Handed to the attention sub-layer.
        pf_dim: Handed to the feed-forward sub-layer.
        dropout: Dropout probability, used here and handed to both sub-layers.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout):
        super().__init__()

        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforward(hid_dim, pf_dim, dropout)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hid_dim)

    def forward(self, src, src_mask):
        """Apply both sub-layers to ``src`` and return the transformed tensor.

        ``src`` serves as query, key and value of the attention call and
        ``src_mask`` is passed along as its mask; the second element returned
        by the attention sub-layer is discarded.
        """
        # Sub-layer 1: self-attention, then residual + norm.
        attended = self.self_attention(src, src, src, src_mask)[0]
        normed = self.layer_norm(src + self.dropout(attended))

        # Sub-layer 2: feed-forward, then residual + norm.
        fed_forward = self.positionwise_feedforward(normed)
        return self.layer_norm(normed + self.dropout(fed_forward))

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dim: Feature size of the inputs and of the output; must be
            divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights before
            they are multiplied with the values.

    Raises:
        AssertionError: If ``hid_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()

        # Fixed: the original checked the undefined name `h_heads`, which
        # raised NameError whenever the layer was constructed.
        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        # Plain float: no dependency on a global `device`, and nothing to
        # move when the module itself is moved between devices.
        self.scale = self.head_dim ** 0.5

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor reshaped here as ``(batch, query_len, hid_dim)``.
            key: Tensor reshaped here as ``(batch, key_len, hid_dim)``.
            value: Tensor reshaped here as ``(batch, value_len, hid_dim)``.
            mask: Optional tensor compared against the
                ``(batch, n_heads, query_len, key_len)`` energies; positions
                where ``mask == 0`` are filled with ``-1e10`` before softmax.

        Returns:
            A pair ``(x, attention)``: ``x`` has shape
            ``(batch, query_len, hid_dim)`` and ``attention`` holds the
            softmax weights (before dropout) with shape
            ``(batch, n_heads, query_len, key_len)``.
        """
        batch_size = query.shape[0]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        # Split the feature dimension into heads:
        # (batch, len, hid_dim) -> (batch, n_heads, len, head_dim).
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot products between every query and key position.
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim=-1)

        x = torch.matmul(self.dropout(attention), V)

        # Merge the heads back: (batch, n_heads, len, head_dim) -> (batch, len, hid_dim).
        # `contiguous()` is needed because `view` cannot follow the permute directly.
        x = x.permute(0, 2, 1, 3).contiguous()

        x = x.view(batch_size, -1, self.hid_dim)

        x = self.fc_o(x)

        return x, attention