# Naïve implementation of a Transformer with Cycle (Rev) parameter sharing
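[Link to paper](https://arxiv.org/abs/2104.06022)

Each encoder and decoder stack below creates $M$ independent layers and then re-uses them in reverse order. As a result, $2M$ layers run in the order $1, 2, \dots, M, M, \dots, 2, 1$, while only $M$ layers' worth of parameters are trained.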

In [None]:
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

# Data sourcing and processing

In [2]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from typing import Iterable, List
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# spaCy tokenizers keyed by language code (German source, English target).
token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}
# One torchtext vocabulary per language; populated by the vocab-building loop below.
vocab_transform = {}

def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
  """Lazily yield the tokenized ``language`` side of every sample.

  Args:
    data_iter: iterable of (src_text, tgt_text) pairs, ordered as
      (SRC_LANGUAGE, TGT_LANGUAGE) to match the ``language_pair`` given to Multi30k.
    language: SRC_LANGUAGE or TGT_LANGUAGE; selects both the tuple slot
      and the tokenizer from ``token_transform``.

  Yields:
    One list of token strings per sample.
  """
  language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
  # Both lookups are loop-invariant; resolve them once.
  column = language_index[language]
  tokenize = token_transform[language]

  for data_sample in data_iter:
    yield tokenize(data_sample[column])

# Indices of the special tokens. special_first=True below inserts
# special_symbols at the front of each vocabulary in this order, so the
# constants line up with their vocabulary ids.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
  # Re-create the training split for each language's pass over it.
  train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
  vocab_transform[ln] = build_vocab_from_iterator(
      yield_tokens(train_iter, ln),
      min_freq=1,
      specials=special_symbols,
      special_first=True,
  )
  # Unknown tokens map to UNK_IDX; without a default index the vocabulary
  # raises RuntimeError on an out-of-vocabulary lookup.
  vocab_transform[ln].set_default_index(UNK_IDX)

100%|██████████| 1.21M/1.21M [00:01<00:00, 741kB/s] 


# Seq2Seq Network using Transformer

In [3]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
import math
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
  DEVICE = torch.device('cuda')
else:
  DEVICE = torch.device('cpu')

In [4]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position encodings to token embeddings, then apply dropout.

  Column 2i of the table holds sin(pos / 10000^(2i/d)) and column 2i+1 holds
  cos of the same angle, where d = embed_size.

  Args:
    embed_size: model dimension d. Odd values are supported; the final
      (unpaired) column then carries only a sine term.
    dropout: dropout probability applied after adding the encoding.
    maxlen: longest sequence length covered by the precomputed table.
  """

  def __init__(self, embed_size: int, dropout: float, maxlen: int = 5000):
    super().__init__()
    # 10000^(-2i/d) for every even column index 2i.
    den = torch.exp(-torch.arange(0, embed_size, 2) * math.log(10000) / embed_size)
    pos = torch.arange(0, maxlen).reshape(maxlen, 1)
    angles = pos * den  # (maxlen, ceil(embed_size / 2))

    pos_embedding = torch.zeros((maxlen, embed_size))
    pos_embedding[:, 0::2] = torch.sin(angles)
    # With an odd embed_size there is one fewer odd column than angles; drop
    # the surplus angle instead of failing with a shape mismatch.
    pos_embedding[:, 1::2] = torch.cos(angles)[:, : embed_size // 2]
    # (maxlen, 1, embed_size): the singleton axis broadcasts over the batch
    # dimension of sequence-first (seq_len, batch, embed_size) inputs.
    pos_embedding = pos_embedding.unsqueeze(-2)

    self.dropout = nn.Dropout(dropout)
    # A buffer follows .to(device) and is saved in state_dict, but is not trained.
    self.register_buffer('pos_embedding', pos_embedding)

  def forward(self, token_embedding: Tensor):
    """Return ``dropout(token_embedding + table[:seq_len])``.

    Assumes sequence-first input (seq_len, batch, embed_size) with
    seq_len <= maxlen.
    """
    return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

In [5]:
class TokenEmbedding(nn.Module):
  """Embedding lookup scaled by sqrt(embed_size).

  Token ids are cast to int64 before the lookup, so tensors of any numeric
  dtype holding whole-number ids are accepted.
  """

  def __init__(self, vocab_size: int, embed_size: int):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.embed_size = embed_size

  def forward(self, tokens: Tensor):
    ids = tokens.long()
    scale = math.sqrt(self.embed_size)
    vectors = self.embedding(ids)
    return vectors * scale

In [6]:
class Seq2SeqTransformerShared(nn.Module):
  """Encoder-decoder Transformer with "Cycle (Rev)" cross-layer parameter sharing.

  Each stack is built with ``num_*_layers`` separately initialised layers
  (nn.TransformerEncoder / nn.TransformerDecoder clone the layer they are
  given). The same layer objects are then appended again in reverse order.
  With 3 layers a stack therefore runs 6 layers in the order
  L0, L1, L2, L2, L1, L0, while owning only 3 layers' worth of parameters.

  Tensors are sequence-first. ``batch_first`` is not set on the layers, so
  PyTorch's default (seq_len, batch, ...) layout applies. No ``norm`` is
  passed to either stack.

  Args:
    num_encoder_layers: number of distinct encoder layers (the encoder runs twice as many).
    num_decoder_layers: number of distinct decoder layers (the decoder runs twice as many).
    embed_size: model dimension (d_model) for embeddings and all layers.
    nhead: attention heads per layer.
    src_vocab_size: source vocabulary size (source embedding rows).
    tgt_vocab_size: target vocabulary size (target embedding rows and output logits).
    dim_ffn: hidden size of each feed-forward sublayer.
    dropout: dropout inside every layer and in the positional encoding.
  """
  def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
               embed_size: int, nhead: int, src_vocab_size: int,
               tgt_vocab_size: int, dim_ffn: int = 512, dropout: float = 0.1):
    super(Seq2SeqTransformerShared, self).__init__()
    # NOTE: submodules are created in a fixed order. Reordering them changes
    # which random draws each one's initial weights receive under a fixed seed.

    # encoder
    self.encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=nhead,
            dim_feedforward=dim_ffn,
            dropout=dropout),
        num_layers=num_encoder_layers)
    # Cycle (Rev): append the existing layers in reverse order. Slicing a
    # ModuleList returns the same module objects (no copy), so the appended
    # layers share weights with the originals.
    self.encoder.layers.extend(self.encoder.layers[::-1])

    # decoder (same mirrored-sharing scheme as the encoder)
    self.decoder = nn.TransformerDecoder(
        nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=nhead,
            dim_feedforward=dim_ffn,
            dropout=dropout),
        num_layers=num_decoder_layers)
    self.decoder.layers.extend(self.decoder.layers[::-1])

    # Projects decoder hidden states to target-vocabulary logits.
    self.generator = nn.Linear(embed_size ,tgt_vocab_size)
    self.src_tok_embed = TokenEmbedding(src_vocab_size, embed_size)
    self.tgt_tok_embed = TokenEmbedding(tgt_vocab_size, embed_size)
    # A single positional-encoding module is used for both source and target.
    self.pos_encode = PositionalEncoding(embed_size, dropout=dropout)

  def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor,
              src_padding_mask: Tensor, tgt_padding_mask: Tensor,
              memory_key_padding_mask: Tensor):
    """Teacher-forced forward pass returning target-vocabulary logits.

    Shapes follow the sequence-first layout produced by collate_fn / create_mask
    in this notebook:
      src: (src_len, batch) source ids; tgt: (tgt_len, batch) decoder-input ids.
      src_mask: (src_len, src_len); tgt_mask: (tgt_len, tgt_len) causal mask.
      src_padding_mask / tgt_padding_mask / memory_key_padding_mask:
        (batch, len), True at padding positions.

    Returns:
      Logits of shape (tgt_len, batch, tgt_vocab_size).
    """
    src_embed = self.pos_encode(self.src_tok_embed(src))
    tgt_embed = self.pos_encode(self.tgt_tok_embed(tgt))

    memory = self.encoder(src_embed, mask=src_mask, src_key_padding_mask=src_padding_mask)
    # memory_mask=None: every target position may attend to every source position.
    outs = self.decoder(tgt_embed, memory, tgt_mask=tgt_mask, memory_mask=None,
                          tgt_key_padding_mask=tgt_padding_mask,
                          memory_key_padding_mask=memory_key_padding_mask)
    return self.generator(outs)

  def encode(self, src: Tensor, src_mask: Tensor):
    """Return encoder memory for ``src``.

    Only ``src_mask`` is forwarded; no key-padding mask is passed. That is fine
    for a single unpadded sentence (as in greedy decoding), but padded batches
    would attend to their padding.
    """
    return self.encoder(self.pos_encode(self.src_tok_embed(src)),
                                    src_mask)

  def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
    """Run the decoder over ``tgt`` against ``memory`` with self-attention mask ``tgt_mask``.

    Returns decoder hidden states, not logits; apply ``self.generator`` to score tokens.
    """
    return self.decoder(self.pos_encode(self.tgt_tok_embed(tgt)),
                                    memory, tgt_mask)

In [7]:
def generate_square_subsequent_mask(sz, device=None):
  """Build a causal (look-ahead) attention mask for a sequence of length ``sz``.

  Args:
    sz: sequence length.
    device: where to allocate the mask; defaults to the notebook-wide DEVICE.

  Returns:
    float32 tensor of shape (sz, sz). Entry [i, j] is 0.0 when position i may
    attend to position j (j <= i) and -inf otherwise. Added to attention
    scores, it hides future positions.
  """
  if device is None:
    device = DEVICE
  # Everything strictly above the diagonal (the future) is -inf; the rest is 0.
  return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

def create_mask(src, tgt):
  """Build the four masks the seq2seq model's forward pass expects.

  Args:
    src: source ids; shape[0] is taken as the sequence length, so the
      layout is assumed to be sequence-first (src_len, batch).
    tgt: decoder-input ids, (tgt_len, batch).

  Returns:
    (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
    an all-False (src_len, src_len) bool mask (no source position hidden),
    the causal mask from generate_square_subsequent_mask(tgt_len), and two
    (batch, len) bool masks that are True where the id equals PAD_IDX.
  """
  src_len, tgt_len = src.shape[0], tgt.shape[0]

  tgt_mask = generate_square_subsequent_mask(tgt_len)
  src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=DEVICE)

  def padding_mask(ids):
    # (len, batch) -> (batch, len): key-padding masks are batch-major.
    return (ids == PAD_IDX).transpose(0, 1)

  return src_mask, tgt_mask, padding_mask(src), padding_mask(tgt)

In [8]:
# Seed before the model is built so its random initialisation (and later
# dropout draws) is reproducible.
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMBED_SIZE = 512   # d_model
NHEAD = 8
FFN_HID_DIM = 512  # feed-forward hidden size (here equal to EMBED_SIZE)
BATCH_SIZE = 128
# Distinct (parameter-owning) layers per stack. Seq2SeqTransformerShared
# appends them again in reverse order, so each stack runs 2x this many layers.
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformerShared(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMBED_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
# Xavier-uniform init for every matrix-shaped parameter (dim > 1); biases and
# LayerNorm parameters keep PyTorch's defaults. parameters() yields each
# shared tensor once, so mirrored layers are not re-initialised twice.
for p in transformer.parameters():
  if p.dim() > 1:
    nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

# Padding targets contribute nothing to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# Fixed learning rate; no warmup or decay schedule is used in this notebook.
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001,
                             betas=(0.9, 0.98), eps=1e-9)

# Collation

In [9]:
from torch.nn.utils.rnn import pad_sequence

def sequential_transforms(*transforms):
  """Compose ``transforms`` left to right into a single callable.

  ``sequential_transforms(f, g, h)(x)`` returns ``h(g(f(x)))``. With no
  transforms the returned callable is the identity.
  """
  def func(txt_input):
    result = txt_input
    for step in transforms:
      result = step(result)
    return result
  return func

def tensor_transform(token_ids: List[int], bos_idx=None, eos_idx=None):
  """Wrap one sentence's token ids with BOS/EOS markers as an int64 tensor.

  Args:
    token_ids: vocabulary indices for one sentence (may be empty).
    bos_idx: id prepended to the sentence; defaults to BOS_IDX.
    eos_idx: id appended to the sentence; defaults to EOS_IDX.

  Returns:
    1-D ``torch.long`` tensor ``[bos_idx, *token_ids, eos_idx]``.
  """
  if bos_idx is None:
    bos_idx = BOS_IDX
  if eos_idx is None:
    eos_idx = EOS_IDX
  # Build in one shot with an explicit dtype. The previous torch.cat of
  # torch.tensor(token_ids) mixed in a float32 tensor when token_ids was empty,
  # so an empty sentence did not yield an int64 tensor.
  return torch.tensor([bos_idx, *token_ids, eos_idx], dtype=torch.long)

# Per-language pipeline: raw string -> spaCy tokens -> vocabulary indices
# -> index tensor wrapped in BOS/EOS.
text_transform = {
    ln: sequential_transforms(token_transform[ln],  # Tokenization
                              vocab_transform[ln],  # Numericalization
                              tensor_transform)     # Add BOS/EOS and create tensor
    for ln in [SRC_LANGUAGE, TGT_LANGUAGE]
}

def collate_fn(batch):
  """Collate (src_text, tgt_text) pairs into two padded batch tensors.

  Trailing newlines are stripped, each side goes through its
  ``text_transform`` pipeline, and ``pad_sequence`` stacks the variable-length
  results with PAD_IDX. ``batch_first`` is not passed, so the outputs use
  pad_sequence's default (max_len, batch) layout.

  Returns:
    (src_batch, tgt_batch)
  """
  def _encode(text, lang):
    return text_transform[lang](text.rstrip("\n"))

  src_seqs, tgt_seqs = [], []
  for src_text, tgt_text in batch:
    src_seqs.append(_encode(src_text, SRC_LANGUAGE))
    tgt_seqs.append(_encode(tgt_text, TGT_LANGUAGE))

  return (pad_sequence(src_seqs, padding_value=PAD_IDX),
          pad_sequence(tgt_seqs, padding_value=PAD_IDX))

In [10]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
  """Run one teacher-forced training pass over the Multi30k train split.

  The target is shifted by one token: ``tgt[:-1]`` is fed to the decoder and
  ``tgt[1:]`` is the prediction target, so position t learns to predict
  token t+1. Padding targets are ignored by ``loss_fn``.

  Args:
    model: the seq2seq model, called with the Seq2SeqTransformerShared
      forward signature.
    optimizer: optimizer over ``model``'s parameters.

  Returns:
    Mean per-batch training loss, or nan if the split yields no batches.
  """
  model.train()
  losses = 0.0
  num_batches = 0
  # A fresh dataset/loader is built for every epoch.
  train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
  train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

  for src, tgt in train_dataloader:
    src = src.to(DEVICE)
    tgt = tgt.to(DEVICE)

    tgt_input = tgt[:-1, :]
    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

    # The source padding mask doubles as memory_key_padding_mask: the encoder
    # memory has one position per source token.
    logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask,
                   src_padding_mask)
    optimizer.zero_grad()
    tgt_out = tgt[1:, :]
    loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
    loss.backward()

    optimizer.step()
    losses += loss.item()
    num_batches += 1

  # Count batches while iterating: len() of a loader over an iterable-style
  # dataset (such as a datapipe) is not guaranteed to be available.
  return losses / num_batches if num_batches else float('nan')

def evaluate(model):
  """Compute the mean per-batch loss on the Multi30k validation split.

  This uses the same shifted-target setup as ``train_epoch``. It runs in eval
  mode (dropout off) and under ``torch.no_grad()``, so no autograd graph is
  built.

  Args:
    model: the seq2seq model, called with the Seq2SeqTransformerShared
      forward signature.

  Returns:
    Mean per-batch validation loss, or nan if the split yields no batches.
  """
  model.eval()
  losses = 0.0
  num_batches = 0

  val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
  val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

  # Validation never calls backward(); skipping graph construction saves
  # memory and time.
  with torch.no_grad():
    for src, tgt in val_dataloader:
      src = src.to(DEVICE)
      tgt = tgt.to(DEVICE)

      tgt_input = tgt[:-1, :]
      src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

      logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask,
                     src_padding_mask)
      tgt_out = tgt[1:, :]
      loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
      losses += loss.item()
      num_batches += 1

  # Counted during iteration; len() of an iterable-style loader is not guaranteed.
  return losses / num_batches if num_batches else float('nan')

In [None]:
from timeit import default_timer as timer
NUM_EPOCHS = 18

# The reported "Epoch time" covers the training pass only; validation runs
# after the timer is read.
for epoch in range(1, NUM_EPOCHS + 1):
  tic = timer()
  train_loss = train_epoch(transformer, optimizer)
  train_seconds = timer() - tic
  val_loss = evaluate(transformer)
  print(f'Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}. '
        f'Epoch time: {train_seconds:.3f}s')

100%|██████████| 46.3k/46.3k [00:00<00:00, 228kB/s]


Epoch: 1, Train loss: 5.623, Val loss: 4.454. Epoch time: 129.605s
Epoch: 2, Train loss: 4.091, Val loss: 3.739. Epoch time: 129.445s
Epoch: 3, Train loss: 3.608, Val loss: 3.448. Epoch time: 129.284s
Epoch: 4, Train loss: 3.310, Val loss: 3.226. Epoch time: 129.573s
Epoch: 5, Train loss: 3.065, Val loss: 3.030. Epoch time: 129.563s
Epoch: 6, Train loss: 2.823, Val loss: 2.848. Epoch time: 129.502s
Epoch: 7, Train loss: 2.621, Val loss: 2.699. Epoch time: 129.720s
Epoch: 8, Train loss: 2.436, Val loss: 2.577. Epoch time: 129.465s
Epoch: 9, Train loss: 2.275, Val loss: 2.447. Epoch time: 129.301s
Epoch: 10, Train loss: 2.124, Val loss: 2.352. Epoch time: 129.329s
Epoch: 11, Train loss: 1.987, Val loss: 2.283. Epoch time: 129.352s
Epoch: 12, Train loss: 1.862, Val loss: 2.219. Epoch time: 129.361s
Epoch: 13, Train loss: 1.756, Val loss: 2.167. Epoch time: 129.486s
Epoch: 14, Train loss: 1.652, Val loss: 2.115. Epoch time: 129.384s
Epoch: 15, Train loss: 1.556, Val loss: 2.080. Epoch time

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
  """Greedily decode one source sentence, one token at a time.

  Args:
    model: trained model exposing ``encode``, ``decode`` and ``generator``.
    src: source ids, sequence-first with a batch of one, e.g. (src_len, 1).
    src_mask: (src_len, src_len) bool source attention mask.
    max_len: maximum number of ids in the result, including start_symbol.
    start_symbol: id the output sequence starts with (BOS).

  Returns:
    int64 tensor of shape (out_len, 1). It starts with start_symbol and ends
    with EOS_IDX if that id was produced within max_len.
  """
  src = src.to(DEVICE)
  src_mask = src_mask.to(DEVICE)

  with torch.no_grad():  # inference only: no autograd graph is needed
    # Encoded once and reused at every step (previously re-moved each iteration).
    memory = model.encode(src, src_mask).to(DEVICE)
    ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
    for _ in range(max_len - 1):
      tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool).to(DEVICE)
      out = model.decode(ys, memory, tgt_mask)
      # Score only the newest position: (1, tgt_vocab_size) logits.
      prob = model.generator(out.transpose(0, 1)[:, -1])
      _, next_word = torch.max(prob, dim=1)
      next_word = next_word.item()

      next_tok = torch.full((1, 1), next_word, dtype=torch.long, device=DEVICE)
      ys = torch.cat([ys, next_tok], dim=0)
      if next_word == EOS_IDX:
        break
  return ys

def translate(model: torch.nn.Module, src_sentence: str, max_extra_tokens: int = 5):
  """Translate one source-language sentence into the target language greedily.

  Args:
    model: trained seq2seq model.
    src_sentence: raw source-language text.
    max_extra_tokens: the decoded sequence holds at most
      ``num_source_tokens + max_extra_tokens`` ids. Both counts include
      BOS/EOS. Defaults to the previously hard-coded 5.

  Returns:
    Target tokens joined by single spaces, with <bos>/<eos> removed and no
    leading or trailing whitespace.
  """
  model.eval()
  src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
  num_tokens = src.shape[0]
  # Nothing in the source is hidden from the encoder.
  src_mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
  with torch.no_grad():
    tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + max_extra_tokens,
                               start_symbol=BOS_IDX).flatten()
  tokens = vocab_transform[TGT_LANGUAGE].lookup_tokens(tgt_tokens.cpu().tolist())
  # Drop the markers as whole tokens. String replace() left stray spaces at
  # both ends of the sentence.
  return " ".join(tok for tok in tokens if tok not in ("<bos>", "<eos>"))

In [None]:
# Demo: translate one German (SRC_LANGUAGE) sentence with the trained model.
# NOTE(review): "Ward" may be a typo for "Wand" (wall); confirm the intended input.
print(translate(transformer, "Maurer bauen eine Ward ."))

 Workers are pouring wood . 
