This implementation of an RNN for machine translation was adapted from the following source:
*   GitHub Repository: [nus-cs4248x](https://github.com/chrisvdweth/nus-cs4248x/blob/master/3-neural-nlp/Section%203.2%20-%20RNN%20Machine%20Translation.ipynb)
*  Author: [chrisvdweth](https://github.com/chrisvdweth)

In [None]:
%pip install transformers[sentencepiece] datasets

In [None]:
%pip install torchtext

In [None]:
%pip install seaborn

In [None]:
%pip install wandb onnx -Uq

In [None]:
from tqdm import tqdm
import torch
import torchtext
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.vocab import vocab

# The following src files are retrieved from https://github.com/chrisvdweth/nus-cs4248x/tree/master/3-neural-nlp/src
from src.rnn import Encoder, Decoder, RnnAttentionSeq2Seq
from src.sampler import BaseDataset, EqualLengthsBatchSampler
from src.utils import Dict2Class, get_line_count, plot_attention_weights

# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# NOTE(review): `load_metric` is not used in the visible cells and is no longer
# available in recent `datasets` releases -- confirm before upgrading.
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer

# "wi_locness" dataset, 'wi' configuration.
raw_datasets = load_dataset("wi_locness", 'wi')

# Only the tokenizer of this checkpoint is loaded here; it supplies the
# sub-word vocabulary for the model built further below.
model_checkpoint = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


Add special tokens to the vocabulary

In [None]:
# Register the extra special tokens first, then snapshot the vocabulary once so
# that the token -> id mapping (and its length) already includes them.
# NOTE: `vocab` rebinds the name imported from `torchtext.vocab` at the top of
# the notebook; from here on it is a plain token -> id mapping.
special_tokens = {'bos_token' : "<s>", 'cls_token' : "<cls>", 'sep_token' : "<sep>"}
tokenizer.add_special_tokens(special_tokens)
vocab = tokenizer.get_vocab()


In [None]:
def preprocess_function(examples):
    """Tokenize a batch of texts and build the corrected-text labels.

    For every example, the character span covered by the (possibly truncated)
    tokenization is cut out of the original text, the edits lying entirely
    inside that span are applied to it, and the result is tokenized to obtain
    the label ids.

    Args:
        examples: Batch mapping with a ``"text"`` list and an ``"edits"`` list.
            Each edits entry provides parallel ``"start"``, ``"end"`` and
            ``"text"`` lists: character offsets into the text and the
            replacement string (``None`` is replaced by the tokenizer's
            unknown token).

    Returns:
        The tokenizer output for the input texts (offset mapping removed) with
        an added ``"labels"`` entry holding the token ids of the corrected
        texts.
    """
    inputs = examples['text']
    model_inputs = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        return_offsets_mapping=True
    )

    labels_out = []
    offset_mapping = model_inputs.pop("offset_mapping")
    for text, offsets, edits in zip(inputs, offset_mapping, examples["edits"]):
        start_idx = offsets[0][0]
        # The last token is <eos>, so the second-to-last token marks the end of
        # the covered text. A sequence holding only <eos> covers an empty span.
        end_idx = offsets[-2][1] if len(offsets) > 1 else start_idx

        corrected_text = text[start_idx:end_idx]

        # Apply the edits in reverse order so that -- assuming they are listed
        # from left to right -- the offsets of the remaining edits stay valid.
        for start, end, correction in reversed(
            list(zip(edits["start"], edits["end"], edits["text"]))
        ):
            # Skip edits that reach outside the tokenized span.
            if start < start_idx or end > end_idx:
                continue
            start_offset = start - start_idx  # >= 0
            end_offset = end - start_idx
            if correction is None:
                correction = tokenizer.unk_token
            corrected_text = (
                corrected_text[:start_offset] + correction + corrected_text[end_offset:]
            )

        labels_out.append(corrected_text)

    labels_out = tokenizer(labels_out, max_length=512, truncation=True)
    model_inputs["labels"] = labels_out["input_ids"]

    return model_inputs

In [None]:
# Tokenize every split in batches; the original columns are dropped so only
# the fields returned by preprocess_function remain.
tokenized_datasets = raw_datasets.map(
    preprocess_function,
    remove_columns=raw_datasets['train'].column_names,
    batched=True,
)

# Hold out 10% of the training split as the test split (fixed seed).
dataset_dict = tokenized_datasets["train"].train_test_split(test_size=0.1, seed=0)
for split_name in ("train", "test"):
    tokenized_datasets[split_name] = dataset_dict[split_name]

train_split = tokenized_datasets["train"]
X_train, Y_train = train_split["input_ids"], train_split["labels"]

test_split = tokenized_datasets["test"]
X_test, Y_test = test_split["input_ids"], test_split["labels"]

Convert the sequence pairs into lists of input and target tensors

In [None]:
# Turn every token-id list into a LongTensor; the sample counts are recorded
# before the conversion.
len_train = len(X_train)
X_train = list(map(torch.LongTensor, X_train))
Y_train = list(map(torch.LongTensor, Y_train))

# The held-out test split is used as the validation data from here on.
len_validation = len(X_test)
X_validation = list(map(torch.LongTensor, X_test))
Y_validation = list(map(torch.LongTensor, Y_test))

train_samples = validation_samples = None

In [None]:
def _build_loader(inputs, targets, size):
    """Create the dataset, batch sampler and data loader for one split."""
    dataset = BaseDataset(inputs, targets)
    # Batches are produced by the sampler (defined in src.sampler; presumably
    # it groups sequences of equal length, as its name suggests).
    sampler = EqualLengthsBatchSampler(size, inputs, targets)
    loader = DataLoader(dataset, batch_sampler=sampler, shuffle=False, drop_last=False)
    return dataset, sampler, loader


batch_size = 512

# Training uses batches of `batch_size`; the test loader yields one sample per batch.
dataset_train, sampler_train, loader_train = _build_loader(X_train, Y_train, batch_size)
dataset_test, sampler_test, loader_test = _build_loader(X_validation, Y_validation, 1)

Create Model

In [None]:
# Hyperparameters and special-token ids for the encoder/decoder model.
params = {
    # The decoder also generates sentences, so it must be able to move data to the correct device.
    "device": device,
    # Size of the source vocabulary: determines the input size of the encoder embedding.
    "vocab_size_encoder": len(vocab),
    # Size of the target vocabulary: determines the input size of the decoder embedding.
    "vocab_size_decoder": len(vocab),
    # Size of the word embeddings (the same for encoder and decoder here, but that is not mandatory).
    "embed_size": 300,
    # In practice GRU or LSTM will always outperform a plain RNN.
    "rnn_cell": "LSTM",
    # Size of the hidden state.
    "rnn_hidden_size": 512,
    # 1 or 2 layers are most common; more rarely brings any benefit.
    "rnn_num_layers": 2,
    # Only relevant if rnn_num_layers > 1.
    "rnn_dropout": 0.2,
    # The encoder can be bidirectional; the decoder cannot.
    "rnn_encoder_bidirectional": True,
    # Sizes of the subsequent hidden linear layers; may be [] (empty); only relevant for the decoder.
    "linear_hidden_sizes": [1024, 2048],
    # Dropout for the hidden linear layers, if any; only relevant for the decoder.
    "linear_dropout": 0.2,
    # Whether attention is used; only "DOT" is supported; None for no attention.
    "attention": "DOT",
    # Probability of the decoder using teacher forcing during training.
    "teacher_forcing_prob": 0.5,
    # Index of the special token <UNK>.
    "special_token_unk": vocab['<unk>'],
    # Index of the special token <SOS>.
    "special_token_sos": vocab['<s>'],
    # Index of the special token <EOS>.
    "special_token_eos": vocab['</s>'],
    # Clipping value that limits the gradients to prevent exploding gradients.
    "clip": 1.0,
}

# wandb.init(project='gec-baseline-lstm-rnn', config=params)

params = Dict2Class(params)

loss_function = nn.CrossEntropyLoss()
model = RnnAttentionSeq2Seq(params, loss_function).to(device)

# Encoder and decoder each get their own Adam optimizer with the same learning rate.
learning_rate = 0.0005
encoder_optimizer = optim.Adam(model.encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(model.decoder.parameters(), lr=learning_rate)

In [None]:
def train_batch(model, encoder_optimizer, decoder_optimizer, X, Y):
    """Run one optimization step on a single batch.

    Args:
        model: Callable as ``model(X, Y)`` returning the batch loss; must
            expose ``encoder`` and ``decoder`` sub-modules, each carrying
            ``params.clip``.
        encoder_optimizer: Optimizer updating the encoder parameters.
        decoder_optimizer: Optimizer updating the decoder parameters.
        X: 2-D input batch.
        Y: Target batch passed to the model together with ``X``.

    Returns:
        The loss value divided by the size of the second dimension of ``X``.
    """
    _, num_steps = X.shape
    optimizers = (encoder_optimizer, decoder_optimizer)

    # Forward pass
    loss = model(X, Y)

    # Backward pass: reset old gradients, then backpropagate
    for optimizer in optimizers:
        optimizer.zero_grad()
    loss.backward()

    # Clip encoder and decoder gradients separately, each with its own limit
    for part in (model.encoder, model.decoder):
        torch.nn.utils.clip_grad_norm_(part.parameters(), part.params.clip)

    for optimizer in optimizers:
        optimizer.step()

    return loss.item() / num_steps

def train(model, loader, encoder_optimizer, decoder_optimizer, num_epochs, verbose=False):
    """Train the model for ``num_epochs`` passes over ``loader``.

    Every batch gets an EOS column appended to inputs and targets, is moved to
    the module-level ``device`` and is handed to ``train_batch``.

    Args:
        model: Model exposing ``train()`` and ``encoder.params.special_token_eos``;
            passed through to ``train_batch``.
        loader: Iterable yielding ``(X_batch, Y_batch)`` pairs of 2-D tensors;
            must support ``len()``.
        encoder_optimizer: Optimizer for the encoder parameters.
        decoder_optimizer: Optimizer for the decoder parameters.
        num_epochs: Number of passes over ``loader``.
        verbose: If truthy, print the accumulated loss after every epoch.

    Returns:
        None.
    """
    # wandb.watch(model, log="all", log_freq=10)
    # Set model to "train" mode
    model.train()

    # The EOS index does not change during training, so look it up once.
    eos_token = model.encoder.params.special_token_eos

    print("Total Training Time (total number of epochs: {})".format(num_epochs))
    for epoch in range(1, num_epochs+1):

        # Initialize epoch loss (cumulative loss of all batches)
        epoch_loss = 0.0

        with tqdm(total=len(loader)) as progress_bar:

            for X_batch, Y_batch in loader:
                batch_size = X_batch.shape[0]

                # Add EOS token to all sequences in that batch; new_full keeps
                # the dtype and device of the tensor it is appended to.
                X_batch = torch.cat((X_batch, X_batch.new_full((batch_size, 1), eos_token)), dim=1)
                Y_batch = torch.cat((Y_batch, Y_batch.new_full((batch_size, 1), eos_token)), dim=1)

                # Move the batch to the correct device
                X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)

                # Train batch and get batch loss
                batch_loss = train_batch(model, encoder_optimizer, decoder_optimizer, X_batch, Y_batch)

                # Update epoch loss given the batch loss
                epoch_loss += batch_loss

                # Update progress bar by the number of samples in this batch.
                # NOTE(review): this matches `total` only if len(loader) counts
                # samples rather than batches -- confirm against the sampler.
                progress_bar.update(batch_size)

        if verbose:
            print("Loss:\t{:.3f} (epoch {})".format(epoch_loss, epoch))
            # wandb.log({"epoch": epoch, "loss": round(epoch_loss, 5)})

In [None]:
num_epochs = 50

# Full training run with the per-epoch loss printed.
train(model, loader_train, encoder_optimizer, decoder_optimizer, num_epochs, verbose=True)

# Total number of elements over all parameter tensors of the model.
total_params = sum(map(torch.numel, model.parameters()))
print(f"Number of parameters: {total_params}")

In [None]:
# Save or restore the model weights; pick the action by (un)commenting below.
action = "save"
#action = "load"
#action = "none"

checkpoint_path = 'wi-rnn-new.pt'

if action == "save":
    torch.save(model.state_dict(), checkpoint_path)
elif action == 'load':
    model = RnnAttentionSeq2Seq(params, nn.CrossEntropyLoss()).to(device)
    # map_location remaps the stored tensors to the current device, so a
    # checkpoint written on a GPU can also be restored on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

Testing the model

In [None]:
def translate(model, inputs, max_len=512):
    """Encode ``inputs`` and let the decoder generate an output sequence.

    Runs without gradient tracking, since nothing is backpropagated here.

    Args:
        model: Object exposing a callable ``encoder`` that returns
            ``(encoder_outputs, encoder_hidden)`` and a ``decoder`` with a
            ``generate(encoder_hidden, encoder_outputs, max_len=...)`` method.
        inputs: Input batch passed to the encoder.
        max_len: Maximum length forwarded to ``decoder.generate``.

    Returns:
        The ``(decoded_indices, attention_weights)`` pair returned by
        ``decoder.generate``.
    """
    with torch.no_grad():
        # Encode input sequence/sentence
        encoder_outputs, encoder_hidden = model.encoder(inputs)
        # Translate the input by generating/predicting the output sequence/sentence
        decoded_indices, attention_weights = model.decoder.generate(
            encoder_hidden, encoder_outputs, max_len=max_len
        )
    # Return the translation + the attention weights
    return decoded_indices, attention_weights

In [None]:
# Put the model into evaluation mode before generating output on the test data.
model.eval()

In [None]:
# Run the model on every test batch and print the source text followed by the
# generated text. The reference `targets` are not used here.
for inputs, targets in loader_test:
    # The input is the first sequence of the batch
    inputs = inputs[0:1].to(device)
    # Decode the input sequence of indices into text
    src_labels = tokenizer.decode(inputs[0].tolist())

    # Translate input sequence into predicted target sequence
    decoded_indices, attention_weights = translate(model, inputs)

    # Decode the predicted sequence of indices into text
    # (`tgt_labels` holds the model output, not the reference target)
    tgt_labels = tokenizer.decode(decoded_indices)

    print(src_labels)
    print()
    print(tgt_labels)