In [1]:
import torch

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

## Preparing the Data

In [2]:
from torchtext.data import Field

SRC = Field(tokenize="spacy",
            init_token="<sos>",
            eos_token="<eos>",
            lower=True, 
            batch_first=True)

TRG = Field(tokenize="spacy",
            init_token="<sos>",
            eos_token="<eos>",
            lower=True,
            batch_first=True)

In [3]:
from torchtext.datasets import Multi30k

train_data, valid_data, test_data = Multi30k.splits(exts=(".de", ".en"), 
                                                   fields=(SRC, TRG))

SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

In [4]:
from torchtext.data import BucketIterator

BATCH_SIZE = 256

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    datasets=(train_data, valid_data, test_data), 
    batch_size=BATCH_SIZE, 
    device=device)

## Building the Model

### Encoder

In [5]:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, 
                input_dim, 
                hid_dim, 
                n_layers, 
                n_heads, 
                pf_dim, 
                dropout, 
                max_length=100):
        
        super().__init__()
        
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.scale = torch.sqrt(torch.tensor([hid_dim], dtype=torch.float)).to(device)
        
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        
        self.layers = nn.ModuleList([EncoderLayer(hid_dim, 
                                                 n_heads, 
                                                 pf_dim, 
                                                 dropout)
                                    for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src, src_mask):
        # (in)  src: [batch_size, src_len]
        # (in)  src_mask: [batch_size, 1, 1, src_len]
        # (out) src: [batch_size, src_len, hid_dim]
        
        batch_size = src.shape[0]
        src_len = src.shape[1]
        
        # (out) pos: [batch_size, src_len]
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(device)
        
        # (in)  src
        # (in)  pos
        # (out) src: [batch_size, src_len, hid_dim]
        src = self.dropout(self.tok_embedding(src) * self.scale 
                           + self.pos_embedding(pos))
        
        for layer in self.layers:
            src = layer(src, src_mask)
            
        return src

### Encoder Layer

In [6]:
class EncoderLayer(nn.Module):
    def __init__(self, 
                hid_dim, 
                n_heads, 
                pf_dim, 
                dropout):
        
        super().__init__()
        
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hid_dim)
        
    def forward(self, src, src_mask):
        # (in)  src: [batch_size, src_len, hid_dim]
        # (in)  src_mask: [batch_size, 1, 1, src_len]
        # (out) src: [batch_size, src_len, hid_dim]
        
        _src, _ = self.self_attention(src, src, src, src_mask)
        
        src = self.layer_norm(src + self.dropout(_src))
        
        _src = self.positionwise_feedforward(src)
        
        src = self.layer_norm(src + self.dropout(_src))
        
        return src

### Multi Head Attention Layer

In [7]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        
        super().__init__()
        
        assert hid_dim % n_heads == 0
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.tensor([self.head_dim], dtype=torch.float)).to(device)
        
    def forward(self, query, key, value, mask=None):
        # (in)  qeury: [batch_size, query_len, hid_dim]
        # (in)  key: [batch_size, query_len, hid_dim]
        # (in)  value: [batch_size, query_len, hid_dim]
        
        batch_size = query.shape[0]
        
        # (in)  query
        # (in)  key
        # (in)  value
        # (out) Q: [batch_size, seq_len, hid_dim]
        # (out) K: [batch_size, seq_len, hid_dim]
        # (out) V: [batch_size, seq_len, hid_dim]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        # (out) Q: [batch_size, n_heads, query_len, head_dim]
        # (out) K: [batch_size, n_heads, key_len, head_dim]
        # (out) V: [batch_size, n_heads, value_len, head_dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # (in)  Q
        # (in)  K
        # (out) energy: [batch_size, n_heads, query_len, key_len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        if mask is not None:
            energy = energy.masked_fill(mask==0, -1e10)
        
        # (in)  energy
        # (out) attention: [batch_size, n_heads, query_len, key_len]
        attention = torch.softmax(energy, dim=-1)
        
        # (in)  attention
        # (out) x: [batch_size, n_heads, query_len, head_dim]
        x = torch.matmul(self.dropout(attention), V)
        
        # (in)  x 
        # (out) x: [batch_size, query_len, n_heads, head_dim]
        x = x.permute(0, 2, 1, 3).contiguous()
        
        # (in)  x
        # (out) x: [batch_size, query_len, hid_dim]
        x = x.view(batch_size, -1, self.hid_dim)
        
        # (in)  x
        # (out) x: [batch_size, query_len, hid_dim]
        x = self.fc_o(x)
        
        return x, attention

### Position-wise Feedforward Layer

In [8]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        
        super().__init__()
        
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        x = self.dropout(torch.relu(self.fc_1(x)))
        
        x = self.fc_2(x)
        
        return x

### Decoder

In [9]:
class Decoder(nn.Module):
    def __init__(self, 
                output_dim, 
                hid_dim, 
                n_layers, 
                n_heads, 
                pf_dim, 
                dropout, 
                max_length=100):
        
        super().__init__()
        
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.scale = torch.sqrt(torch.tensor([hid_dim], dtype=torch.float)).to(device)
        
        self.layers = nn.ModuleList([DecoderLayer(hid_dim, 
                                                 n_heads, 
                                                 pf_dim, 
                                                 dropout)
                                    for _ in range(n_layers)])
        
        self.dropout = nn.Dropout(dropout)
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
                
    def forward(self, trg, enc_src, trg_mask, src_mask):
        # (in)  trg: [batch_size, trg_len]
        # (in)  enc_src: [batch_size, src_len. hid_dim]
        # (in) trg_mask: [batch_size, 1, trg_len, trg_len]
        # (in) src_mask: [batch_size, 1, 1, src_len]
        # (out) output: [batch_size, trg_len, trg_vocab_size]
        # (out) attention: [batch_size, n_heads, trg_len, src_len]
        
        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(device)
        
        # (in)  trg
        # (in)  pos
        # (out) trg: [batch_size, trg_len, hid_dim]
        trg = self.dropout(self.tok_embedding(trg) * self.scale + self.pos_embedding(pos))
        
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
            
        # In PyTorch, the softmax operation is contained within our loss function, 
        # so we do not explicitly need to use a softmax layer here.
        output = self.fc_out(trg)
        
        return output, attention

### Decoder Layer

In [10]:
class DecoderLayer(nn.Module):
    def __init__(self, 
                hid_dim, 
                n_heads, 
                pf_dim, 
                dropout):
        
        super().__init__()
        
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hid_dim)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        # (in)  trg: [batch_size, trg_len, hid_dim]
        # (in)  enc_src: [batch_size, src_len, hid_dim]
        # (in)  trg_mask: [batch_size, 1, trg_len, trg_len]
        # (in)  src_mask: [batch_size, 1, 1, src_len]
        # (out) trg: [batch_size, trg_len, hid_dim]
        # (out) attention: [batch_size, n_heads, trg_len, src_len]
        
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        
        trg = self.layer_norm(trg + self.dropout(_trg))
        
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        
        trg = self.layer_norm(trg + self.dropout(_trg))
        
        _trg = self.positionwise_feedforward(trg)
        
        trg = self.layer_norm(trg + self.dropout(_trg))
        
        return trg, attention

### Seq2Seq

In [11]:
class Seq2Seq(nn.Module):
    def __init__(self, 
                encoder, 
                decoder, 
                src_pad_idx, 
                trg_pad_idx):
        
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        
    def forward(self, src, trg):
        # (in)  src: [batch_size, src_len]
        # (in)  trg: [batch_size, trg_len]
        # (out) output: [batch_size, trg_len, trg_vocab_size]
        # (out) attention: [batch_size, dec_heads, trg_len, src_len]
        
        # (in)  src
        # (out) src_mask: [batch_size, 1, 1, src_len]
        src_mask = self.make_src_mask(src)
        # (in)  trg
        # (out) trg_mask: [batch_size, 1, trg_len, trg_len]
        trg_mask = self.make_trg_mask(trg)
        
        # (in)  src
        # (in)  src_mask
        # (out) enc_src: [batch_size, src_len, hid_dim]
        enc_src = self.encoder(src, src_mask)
        
        # (in)  trg
        # (in)  enc_src
        # (in)  trg_mask
        # (in)  src_mask
        # (out) output: [batch_size, trg_len, trg_vocab_size]
        # (out) attention: [batch_size, dec_heads, trg_len, src_len]
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        return output, attention
        
    def make_src_mask(self, src):
        # (in) src: [batch_size, src_len]
        # (out) src_mask: [batch_size, 1, 1, src_len]
        
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        
        return src_mask
    
    def make_trg_mask(self, trg):
        # (in) trg: [batch_size, trg_len]
        # (out) trg_mask: [batch_size, 1, trg_len, trg_len]
        
        # (in) trg
        # (out) trg_pad_mask: [batch_size, 1, trg_len, 1]
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
        #?trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        
        trg_len = trg.shape[1]
        # (out) trg_sub_mask: [trg_len, trg_len]
        trg_sub_mask = torch.tril(
            torch.ones((trg_len, trg_len), device=device)).bool()
        
        # (in)  trg_pad_mask
        # (in)  trg_sub_mask
        # (out) trg_mask: [batch_size, 1, trg_len, trg_len]
        trg_mask = trg_pad_mask & trg_sub_mask
        
        return trg_mask

## Training the Seq2Seq Model

In [12]:
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

In [13]:
enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT)

dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT)

SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX).to(device)

In [14]:
def initialize_weights(m):
    if hasattr(m, "weight") and m.weight.dim() > 1:  # dim <= 1: layernorm
        nn.init.xavier_uniform_(m.weight.data)
        
model.apply(initialize_weights)

Seq2Seq(
  (encoder): Encoder(
    (tok_embedding): Embedding(7873, 256)
    (pos_embedding): Embedding(100, 256)
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attention): MultiHeadAttentionLayer(
          (fc_q): Linear(in_features=256, out_features=256, bias=True)
          (fc_k): Linear(in_features=256, out_features=256, bias=True)
          (fc_v): Linear(in_features=256, out_features=256, bias=True)
          (fc_o): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (fc_1): Linear(in_features=256, out_features=512, bias=True)
          (fc_2): Linear(in_features=512, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      )
      (1): EncoderLayer(
        (sel

In [15]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 9,038,853 trainable parameters


### Optimizer

In [16]:
# Note that the learning rate needs to be lower than the default used by Adam or 
# else learning is unstable.
LEARNING_RATE = 0.0005

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

### Criterion

In [17]:
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

### Training and Validating

In [18]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for batch in iterator:
        
        optimizer.zero_grad()
        
        src = batch.src
        trg = batch.trg
        
        # As we want our model to predict the <eos> token but not have it be an input 
        # into our model we simply slice the <eos> token off the end of the sequence. 
        output, _ = model(src, trg[:,:-1])
            
        output = output.contiguous().view(-1, output.shape[-1])
        # We then calculate our loss using the original trg tensor with the <sos> token 
        # sliced off the front, leaving the <eos> token:
        trg = trg[:,1:].contiguous().view(-1)
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [19]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
        
        for batch in iterator:
            
            src = batch.src
            trg = batch.trg
            
            output, _ = model(src, trg[:, :-1])
            
            output = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)
            loss = criterion(output, trg)
            
            epoch_loss += loss.item()
            
    return epoch_loss / len(iterator)

In [20]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time % 60)
    
    return elapsed_mins, elapsed_secs

In [21]:
import time
import math

N_EPOCHS = 20
CLIP = 1

best_valid_loss = float("inf")

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
#     if valid_loss < best_valid_loss:
#         best_valid_loss = valid_loss
#         torch.save(model.state_dict(), 'transformer.pt')
        
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

tensor([   2,  164,  237,    6,   26,  208,   42,   23, 2206,   66,   31,   11,
          25,  138,    5,    3,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1])
tensor([[ 0.4260,  0.3596,  0.0826,  ...,  0.0540, -0.6501, -0.0638],
        [ 0.4659,  0.0971,  0.1377,  ...,  0.3734, -0.3828,  0.0092],
        [ 0.7424,  0.0877, -0.2175,  ...,  0.5225, -0.3839,  0.0826],
        ...,
        [ 0.4996,  0.0581, -0.0723,  ...,  0.2517, -0.0008,  0.0458],
        [ 0.5585,  0.0750,  0.0469,  ...,  0.3147, -0.5993,  0.0691],
        [ 0.6011, -0.0545, -0.1059,  ...,  0.0806, -0.3568, -0.1677]],
       grad_fn=<SelectBackward>)
tensor([   2,    4,   39, 1139,  557,  697,   10, 2000,  548,  245,   28,  489,
          27,  511,    5,    3,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1])
tensor([[ 0.3045,  0.5706,  0.5888,  ..., -0.4267, -0.4822,  0.1649],
        [ 0.2779,  0.3809,  0.5398,  ..., -0.3839, -0.3013,  0.0699],
 

KeyboardInterrupt: 