In [43]:
import math
import torchtext
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from collections import Counter
from torchtext.vocab import Vocab
from torch.nn import TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer
import io
import time
import pandas as pd
import numpy as np
import pickle
import sentencepiece as spm
from datasets import load_dataset

# Seed PyTorch's global RNG so runs start from the same random state.
torch.manual_seed(0)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# English/Japanese parallel corpus, fetched through `datasets.load_dataset`.
dataset = load_dataset("Sampuran01/english-japanese")

In [44]:
# Extract the training split and keep only the first N_SAMPLES sentence pairs.
N_SAMPLES = 10000
training_data = dataset['train']

# Walk the split once and stop as soon as N_SAMPLES rows have been collected,
# instead of materialising every row twice and slicing afterwards.
en, ja = [], []
for row_idx, entry in enumerate(training_data):
    if row_idx >= N_SAMPLES:
        break
    pair = entry['translation']
    en.append(pair['en'])
    ja.append(pair['ja'])

In [45]:
# Tokenizers: one SentencePiece processor per language, each loaded from a
# model file under the local ./tokenizer directory.
ja_tokenizer = spm.SentencePieceProcessor(
    model_file='./tokenizer/spm.ja.nopretok.model'
)
en_tokenizer = spm.SentencePieceProcessor(
    model_file='./tokenizer/spm.en.nopretok.model'
)

In [46]:
# Sanity check: show the first English sentence, then its subword pieces.
sample_en = en[0]
print(sample_en)
en_tokenizer.encode(sample_en, out_type=str)

Yeah, Vincent Hanna.


['▁Yeah', ',', '▁Vin', 'cent', '▁Han', 'na', '.']

In [47]:
# Sanity check: show the first Japanese sentence, then its subword pieces.
sample_ja = ja[0]
print(sample_ja)
ja_tokenizer.encode(sample_ja, out_type=str)

- ラウール - ラウールに ヴィンセント・ハンナだ


['▁-',
 '▁',
 'ラ',
 'ウール',
 '▁-',
 '▁',
 'ラ',
 'ウール',
 'に',
 '▁',
 'ヴィ',
 'ン',
 'セント',
 '・',
 'ハン',
 'ナ',
 'だ']

In [48]:
# Build Vocab object from Torchtext
def build_vocab(sentences, tokenizer):
  """Count the subword pieces of every sentence and wrap the counts in a Vocab.

  Args:
    sentences: iterable of raw strings.
    tokenizer: object whose ``encode(text, out_type=str)`` returns a list of
      string pieces (here a SentencePiece processor).

  Returns:
    A torchtext ``Vocab`` built from the piece frequencies, with the special
    tokens ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>`` registered.
  """
  piece_counts = Counter(
      piece
      for sentence in sentences
      for piece in tokenizer.encode(sentence, out_type=str)
  )
  return Vocab(piece_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
# One vocabulary per language, each built with that language's own tokenizer.
en_vocab = build_vocab(sentences=en, tokenizer=en_tokenizer)
ja_vocab = build_vocab(sentences=ja, tokenizer=ja_tokenizer)

In [49]:
# Build tensors
def data_process(ja, en):
  """Convert parallel sentence lists into pairs of token-index tensors.

  Each sentence has trailing newlines stripped, is split into pieces by the
  module-level tokenizer for its language, and every piece is looked up in
  the matching module-level vocabulary.

  Args:
    ja: iterable of Japanese sentences.
    en: iterable of English sentences, aligned with ``ja``.

  Returns:
    A list of ``(ja_tensor, en_tensor)`` tuples of dtype ``torch.long``;
    pairing stops at the shorter input, as with ``zip``.
  """
  def _to_indices(text, tokenizer, vocab):
    # Tokenise one sentence and map its pieces to vocabulary indices.
    pieces = tokenizer.encode(text.rstrip("\n"), out_type=str)
    return torch.tensor([vocab[piece] for piece in pieces], dtype=torch.long)

  return [
      (_to_indices(raw_ja, ja_tokenizer, ja_vocab),
       _to_indices(raw_en, en_tokenizer, en_vocab))
      for raw_ja, raw_en in zip(ja, en)
  ]
# Tensorised (Japanese, English) pairs for the DataLoader.
train_data = data_process(ja=ja, en=en)

In [50]:
# Create the DataLoader object to be iterated during training
BATCH_SIZE = 32
# Special-token indices are read from the Japanese vocabulary and reused for
# both languages; both vocabularies are built by build_vocab with the same
# specials list (presumably giving identical indices -- verify if that changes).
PAD_IDX, BOS_IDX, EOS_IDX = (
    ja_vocab[special] for special in ('<pad>', '<bos>', '<eos>')
)
def generate_batch(data_batch):
  """Collate function: frame each sequence with BOS/EOS and pad the batch.

  Args:
    data_batch: iterable of ``(ja_item, en_item)`` 1-D index tensors.

  Returns:
    ``(ja_batch, en_batch)`` as produced by ``pad_sequence`` with
    ``padding_value=PAD_IDX`` (default layout, i.e. sequence dimension first).
  """
  bos = torch.tensor([BOS_IDX])
  eos = torch.tensor([EOS_IDX])

  def _frame(sequence):
    # <bos> + tokens + <eos>
    return torch.cat([bos, sequence, eos], dim=0)

  framed = [(_frame(ja_item), _frame(en_item)) for ja_item, en_item in data_batch]
  ja_batch = pad_sequence([ja_seq for ja_seq, _ in framed], padding_value=PAD_IDX)
  en_batch = pad_sequence([en_seq for _, en_seq in framed], padding_value=PAD_IDX)
  return ja_batch, en_batch
# Shuffled mini-batches; generate_batch adds BOS/EOS and pads each batch.
train_iter = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    collate_fn=generate_batch,
    shuffle=True,
)

In [51]:
# Text tokens are represented by using token embeddings. Positional encoding is added to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    A table of shape ``(maxlen, 1, emb_size)`` is precomputed: even feature
    columns hold ``sin(pos * den)`` and odd columns hold ``cos(pos * den)``.
    It is registered as a buffer, so it follows the module across devices
    without being a trainable parameter.

    Args:
        emb_size: width of the embeddings; both even and odd sizes are supported.
        dropout: dropout probability applied to the sum.
        maxlen: number of positions to precompute.
    """

    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # With an odd emb_size there is one fewer odd column than even columns,
        # so trim `den` to the number of odd columns. For even sizes the slice
        # is the full vector and the table is unchanged.
        pos_embedding[:, 1::2] = torch.cos(pos * den[: emb_size // 2])
        # Insert a singleton axis so the table broadcasts over dim 1 of the input.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + positions)``.

        The first ``token_embedding.size(0)`` rows of the table are used, so
        dim 0 of the input is treated as the sequence position.
        """
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0),:])

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: width of each embedding vector.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        """Look up ``tokens`` (cast to long) and multiply by ``sqrt(emb_size)``."""
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale

In [52]:
# Sequence-to-sequence Transformer
# Number of attention heads; passed as `nhead` to every encoder and decoder
# layer that Seq2SeqTransformer builds.
NHEAD = 8

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer with token embeddings and a vocabulary projection.

    Args:
        num_encoder_layers: number of stacked encoder layers.
        num_decoder_layers: number of stacked decoder layers.
        emb_size: model width (``d_model``) shared by embeddings and layers.
        src_vocab_size: size of the source embedding table.
        tgt_vocab_size: size of the target embedding table and of the output
            projection.
        dim_feedforward: hidden width of each layer's feed-forward sub-block.
        dropout: dropout probability, applied to the positional encoding and
            to every encoder/decoder layer.

    The number of attention heads comes from the module-level ``NHEAD``.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        # Forward `dropout` to the layers as well; previously only the
        # positional encoding honoured it and the layers kept their library
        # default, so a non-default value was applied inconsistently.
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """Encode ``src``, decode ``trg`` against it, and project to vocabulary logits.

        ``src``/``trg`` are index tensors (callers in this notebook pass them
        sequence-first). The masks are handed to the encoder/decoder as the
        attention mask and key-padding masks; no memory attention mask is used.
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(
            src_emb, mask=src_mask, src_key_padding_mask=src_padding_mask)
        outs = self.transformer_decoder(
            tgt_emb, memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Run only the encoder on ``src`` (no padding mask) and return its output."""
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Run only the decoder on ``tgt`` against ``memory``.

        Returns decoder hidden states; ``generator`` is not applied here.
        """
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [53]:
# subsequent word mask to stop a target word from attending to its subsequent words.
def generate_square_subsequent_mask(sz):
    """Build the ``(sz, sz)`` additive causal mask on the module-level ``device``.

    Entries on and below the diagonal are ``0.0``; entries above it are
    ``-inf``, so row ``i`` can only attend to columns ``0..i``.
    """
    blocked = torch.full((sz, sz), float('-inf'), device=device)
    # Keep -inf strictly above the diagonal; everything else becomes 0.
    return torch.triu(blocked, diagonal=1)

# masking source and target padding tokens
def create_mask(src, tgt):
  """Build attention and padding masks for a source/target batch.

  Dim 0 of each input is used as the sequence length.

  Returns:
    ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
    ``src_mask`` is an all-False boolean square matrix (nothing masked),
    ``tgt_mask`` is the causal mask from ``generate_square_subsequent_mask``,
    and each padding mask is ``tensor == PAD_IDX`` with dims 0 and 1 swapped.
  """
  src_len, tgt_len = src.shape[0], tgt.shape[0]

  src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)
  tgt_mask = generate_square_subsequent_mask(tgt_len)

  src_padding_mask = src.eq(PAD_IDX).transpose(0, 1)
  tgt_padding_mask = tgt.eq(PAD_IDX).transpose(0, 1)
  return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [54]:
# Hyperparameters. Japanese is the source language, English the target.
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE = len(ja_vocab), len(en_vocab)
EMB_SIZE = 512
FFN_HID_DIM = 512
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 16

transformer = Seq2SeqTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    emb_size=EMB_SIZE,
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    dim_feedforward=FFN_HID_DIM,
)

# Xavier-initialise every parameter with more than one dimension (weight
# matrices); 1-D parameters keep their default initialisation.
for weight in transformer.parameters():
    if weight.dim() > 1:
        nn.init.xavier_uniform_(weight)

transformer = transformer.to(device)

# Padding positions do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)
def train_epoch(model, train_iter, optimizer):
  """Run one optimisation pass over ``train_iter`` and return the mean batch loss.

  For each batch the target is split into decoder input (all rows but the
  last) and expected output (all rows but the first); the loss comes from the
  module-level ``loss_fn``.

  Args:
    model: network called as ``model(src, tgt_input, *masks)``.
    train_iter: sized iterable of ``(src, tgt)`` batches.
    optimizer: optimiser stepping ``model``'s parameters.

  Returns:
    Sum of per-batch losses divided by ``len(train_iter)``.
  """
  model.train()
  running_loss = 0
  for src, tgt in train_iter:
    src, tgt = src.to(device), tgt.to(device)

    # Teacher forcing: feed tokens [0, n-1), predict tokens [1, n).
    tgt_input, tgt_out = tgt[:-1, :], tgt[1:, :]

    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
    logits = model(
        src, tgt_input, src_mask, tgt_mask,
        src_padding_mask, tgt_padding_mask, src_padding_mask,
    )

    optimizer.zero_grad()
    flat_logits = logits.reshape(-1, logits.shape[-1])
    loss = loss_fn(flat_logits, tgt_out.reshape(-1))
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
  return running_loss / len(train_iter)


def evaluate(model, val_iter):
  """Return the mean batch loss of ``model`` over ``val_iter`` without training.

  Mirrors ``train_epoch`` (same input/target split, same masks, same
  module-level ``loss_fn``) but puts the model in eval mode and performs no
  backward pass or optimiser step.

  Args:
    model: network called as ``model(src, tgt_input, *masks)``.
    val_iter: sized iterable of ``(src, tgt)`` batches.

  Returns:
    Sum of per-batch losses divided by ``len(val_iter)``.
  """
  model.eval()
  losses = 0
  # No gradients are needed for evaluation; skipping autograd bookkeeping
  # saves memory and time without changing the computed loss.
  with torch.no_grad():
    for src, tgt in val_iter:
      src = src.to(device)
      tgt = tgt.to(device)

      tgt_input = tgt[:-1, :]

      src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

      logits = model(src, tgt_input, src_mask, tgt_mask,
                     src_padding_mask, tgt_padding_mask, src_padding_mask)
      tgt_out = tgt[1:, :]
      loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
      losses += loss.item()
  return losses / len(val_iter)



In [64]:
# Resume from a saved checkpoint when one exists. On a fresh run the file has
# not been written yet (it is saved by a later cell), so a missing file simply
# means training starts from the freshly initialised model.
try:
    # map_location lets a checkpoint written on a GPU load on a CPU-only machine.
    checkpoint = torch.load('model_checkpoint.tar', map_location=device)
except FileNotFoundError:
    checkpoint = None

if checkpoint is not None:
    # Restore the model state
    transformer.load_state_dict(checkpoint['model_state_dict'])
    # Restore the optimizer state
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

NUM_EPOCHS = 16

# Start training
# NOTE: the checkpoint is not re-read inside the loop; the previous per-epoch
# torch.load only repeated disk I/O and its result was never used.
for epoch in range(1, NUM_EPOCHS+1):
  start_time = time.time()
  train_loss = train_epoch(transformer, train_iter, optimizer)
  end_time = time.time()
  print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s"))


Epoch: 1, Train loss: 4.080, Epoch time = 418.502s
Epoch: 2, Train loss: 3.915, Epoch time = 401.932s
Epoch: 3, Train loss: 3.756, Epoch time = 460.250s
Epoch: 4, Train loss: 3.603, Epoch time = 451.041s
Epoch: 5, Train loss: 3.452, Epoch time = 374.487s
Epoch: 6, Train loss: 3.303, Epoch time = 377.283s
Epoch: 7, Train loss: 3.145, Epoch time = 369.159s
Epoch: 8, Train loss: 3.001, Epoch time = 378.927s
Epoch: 9, Train loss: 2.854, Epoch time = 366.346s
Epoch: 10, Train loss: 2.705, Epoch time = 364.322s
Epoch: 11, Train loss: 2.560, Epoch time = 365.923s
Epoch: 12, Train loss: 2.423, Epoch time = 368.808s
Epoch: 13, Train loss: 2.280, Epoch time = 363.471s
Epoch: 14, Train loss: 2.144, Epoch time = 364.591s
Epoch: 15, Train loss: 2.004, Epoch time = 366.960s
Epoch: 16, Train loss: 1.871, Epoch time = 363.768s


In [65]:
# Persist both vocabularies with pickle (`pickle` is imported in the first cell).
# `with` closes each file even if pickle.dump raises.
for vocab_path, vocab in (('en_vocab.pkl', en_vocab), ('ja_vocab.pkl', ja_vocab)):
    with open(vocab_path, 'wb') as vocab_file:
        pickle.dump(vocab, vocab_file)

In [66]:
# Save model + optimizer state (plus epoch count and last training loss) so
# training can be resumed later.
checkpoint_state = {
  'epoch': NUM_EPOCHS,
  'model_state_dict': transformer.state_dict(),
  'optimizer_state_dict': optimizer.state_dict(),
  'loss': train_loss,
}
torch.save(checkpoint_state, 'model_checkpoint.tar')

In [67]:
# Translate Japanese to English with the trained model
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate a target sequence one token at a time, always taking the argmax.

    Args:
        model: object exposing ``encode``, ``decode`` and ``generator``.
        src: source index tensor; callers here pass shape ``(num_tokens, 1)``.
        src_mask: attention mask handed to ``model.encode``.
        max_len: upper bound on the output length, start symbol included.
        start_symbol: index used as the first target token.

    Returns:
        Long tensor with one row per generated token (starting with
        ``start_symbol``); generation stops after ``EOS_IDX`` is produced or
        once ``max_len`` tokens exist.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)
    # Inference only: disable autograd so no graph is kept across decode steps.
    with torch.no_grad():
        # Encoder output is loop-invariant, so compute and place it once.
        memory = model.encode(src, src_mask).to(device)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)
        for _ in range(max_len - 1):
            # Boolean causal mask: True (from -inf) marks positions that may not be attended.
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Score the vocabulary from the hidden state of the last position only.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            next_token = torch.full((1, 1), next_word, dtype=src.dtype, device=src.device)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS_IDX:
                break
    return ys
def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    """Translate one source sentence with greedy decoding.

    The sentence is tokenised, mapped to indices through ``src_vocab.stoi``,
    framed with BOS/EOS and decoded for at most ``num_tokens + 5`` steps.

    Returns:
        The decoded target pieces joined by single spaces, with the literal
        ``<bos>`` and ``<eos>`` markers removed. Pieces are not merged back
        into words, so tokenizer markers remain in the string.
    """
    model.eval()
    pieces = src_tokenizer.encode(src, out_type=str)
    token_ids = [BOS_IDX, *(src_vocab.stoi[piece] for piece in pieces), EOS_IDX]
    num_tokens = len(token_ids)

    src_column = torch.LongTensor(token_ids).reshape(num_tokens, 1)
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)

    decoded = greedy_decode(
        model, src_column, src_mask,
        max_len=num_tokens + 5, start_symbol=BOS_IDX,
    )
    words = [tgt_vocab.itos[tok] for tok in decoded.flatten()]
    sentence = " ".join(words)
    sentence = sentence.replace("<bos>", "")
    sentence = sentence.replace("<eos>", "")
    return sentence

In [69]:
# Model output for training sentence 2, followed by the reference translation.
hypothesis_2 = translate(transformer, ja[2], ja_vocab, en_vocab, ja_tokenizer)
print(hypothesis_2)
print(en[2])

 ▁It ' s ▁you ! 
It works!


In [70]:
# Model output for training sentence 5, followed by the reference translation.
hypothesis_5 = translate(transformer, ja[5], ja_vocab, en_vocab, ja_tokenizer)
print(hypothesis_5)
print(en[5])

 ▁Do ▁not ▁read ▁this ▁kanji ? 
Do not encrypt the data channel
