# Transformer for Next Character Prediction

### Importing the libraries

In [None]:
# Import necessary libraries and modules for working with neural networks, numerical operations,
# optimization, data handling, and visualization.
import functools
import flax.linen as nn        # Flax library for neural network models
import jax                     # Library for high-performance numerical computing
import jax.numpy as jnp        # JAX version of NumPy for array operations
from matplotlib import pyplot as plt  # Library for plotting and visualization
import optax                   # Optimization library compatible with JAX
import tensorflow_datasets as tfds  # Library for accessing pre-built datasets

# Report which backend JAX picked up (CPU, GPU or TPU). Knowing the platform
# up front helps interpret the training-time numbers later in the notebook.
platform_name = jax.devices()[0].platform.upper()
print("JAX running on", platform_name)

JAX running on GPU


### Parameter settings

In [None]:
# @markdown Random seed:
SEED = 42  # @param{type:"integer"}
# @markdown Learning rate passed to the optimizer:
LEARNING_RATE = 5e-3  # @param{type:"number"}
# @markdown Batch size (sequences per training/evaluation batch):
# NOTE: keep a single BATCH_SIZE assignment in this cell; a later duplicate
# silently overrides the form value (only the last assignment takes effect).
BATCH_SIZE = 512  # @param{type:"integer"}
# @markdown Number of training iterations:
N_ITERATIONS = 6_000  # @param{type:"integer"}
# @markdown Number of training iterations between two consecutive evaluations:
N_FREQ_EVAL = 2_000  # @param{type:"integer"}
# @markdown Rate for dropout in the transformer model
DROPOUT_RATE = 0.2  # @param{type:"number"}
# @markdown Context window for the transformer model
BLOCK_SIZE = 64  # @param{type:"integer"}
# @markdown Number of layers for the transformer model
NUM_LAYERS = 6  # @param{type:"integer"}
# @markdown Size of the embedding for the transformer model
EMBED_SIZE = 256  # @param{type:"integer"}
# @markdown Number of heads for the transformer model
NUM_HEADS = 8  # @param{type:"integer"}
# @markdown Size of the heads for the transformer model
HEAD_SIZE = 32  # @param{type:"integer"}

### Loading the data

In [None]:
ds = tfds.load("tiny_shakespeare")

# Training text: the "train" and "test" splits decoded and glued into one string.
text_train = "".join(
    example["text"].decode("utf-8")
    for example in ds["train"].concatenate(ds["test"]).as_numpy_iterator()
)

# Validation text: every example of the "validation" split as one string.
text_validation = "".join(
    example["text"].decode("utf-8")
    for example in ds["validation"].as_numpy_iterator()
)

### Checking the length of train and validation data

In [None]:
n_train_chars = len(text_train)
n_validation_chars = len(text_validation)
print(f"Length of text for training: {n_train_chars:_} characters")
print(f"Length of text for validation: {n_validation_chars:_} characters")

Length of text for training: 1_059_624 characters
Length of text for validation: 55_770 characters


### Sample train data

In [None]:
# Peek at the first 1000 characters of the training text.
train_preview = text_train[:1000]
print(train_preview)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



### Vocabulary of the training data

In [None]:
# Character-level vocabulary: every distinct character of the training text,
# sorted so the character -> index mapping is deterministic across runs.
vocab = sorted(set(text_train))
print("Vocabulary:", "".join(vocab))
print("Length of vocabulary:", len(vocab))

Vocabulary:,  
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Length of vocabulary:  65


### Vocabulary mapping, encoding and decoding of the data

In [None]:
# Map each character of `vocab` (the sorted list of unique training characters)
# to a dense integer id, and back.
stoi = {ch: i for i, ch in enumerate(vocab)}  # character -> index
itos = dict(enumerate(vocab))                 # index -> character


def encode(s):
    """Return the list of integer ids for the characters of string `s`.

    Raises KeyError for a character that does not appear in `vocab`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the integer ids in `l` (inverse of `encode`)."""
    return "".join(itos[i] for i in l)

# Turn both text splits into integer JAX arrays via the character mapping;
# `get_batch` later slices fixed-length windows out of these arrays.
train_data, eval_data = (
    jnp.array(encode(split_text)) for split_text in (text_train, text_validation)
)

### Batching the data

In [None]:
# `jax.lax.dynamic_slice` batched over its start index: the data array and the
# slice size are shared (in_axes=None); only the start indices vary (axis 0).
dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))

@jax.jit
def get_batch(random_key, data):
    """Sample BATCH_SIZE random (input, target) windows of length BLOCK_SIZE.

    Args:
        random_key: PRNG key that drives the choice of window start positions.
        data: 1-D array of token ids to sample windows from.

    Returns:
        (x, y): two (BATCH_SIZE, BLOCK_SIZE) arrays; `y` is the window starting
        one position later in `data` than `x`, i.e. the next-token targets.
    """
    # Exclusive upper bound on the start so the shifted target window still fits.
    upper = len(data) - BLOCK_SIZE
    starts = jax.random.randint(
        random_key, shape=(BATCH_SIZE, 1), minval=0, maxval=upper
    )
    window = (BLOCK_SIZE,)
    inputs = dynamic_slice_vmap(data, starts, window)
    targets = dynamic_slice_vmap(data, starts + 1, window)
    return inputs, targets

### Transformer Model

In [None]:
class NanoLM(nn.Module):
  """Decoder-only, causally masked Transformer language model.

  Attributes:
    vocab_size: number of distinct tokens; size of the token embedding table
      and of the output logits.
    num_layers: number of (attention + MLP) residual blocks.
    num_heads: number of attention heads per block.
    head_size: forwarded to Flax as `qkv_features` (see NOTE in `__call__`).
    dropout_rate: dropout probability in attention and in the MLP.
    embed_size: width of the residual stream.
    block_size: size of the learned position-embedding table (max context).
  """
  vocab_size: int
  num_layers: int = 6
  num_heads: int = 8
  head_size: int = 32
  dropout_rate: float = 0.2
  embed_size: int = 256
  block_size: int = 64

  @nn.compact
  def __call__(self, x, training: bool):
    """Return next-token logits with shape (batch, seq_len, vocab_size).

    Args:
      x: integer token ids of shape (batch, seq_len); seq_len should not exceed
        `block_size` (the size of the position table).
      training: when True dropout is active and a "dropout" RNG must be passed
        to `apply`.
    """
    seq_len = x.shape[1]

    # Token embeddings plus learned absolute position embeddings.
    x = nn.Embed(self.vocab_size, self.embed_size)(x) + nn.Embed(
        self.block_size, self.embed_size
    )(jnp.arange(seq_len))

    # Lower-triangular mask: position i may only attend to positions <= i.
    # It is identical for every layer, so it is built once.
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

    for _ in range(self.num_layers):
      # Pre-LayerNorm self-attention sub-block with a residual connection.
      x_norm = nn.LayerNorm()(x)
      # NOTE(review): Flax's `qkv_features` is the combined q/k/v width of all
      # heads, so each head works in head_size // num_heads dims (4 with the
      # defaults). Kept unchanged to preserve the existing architecture.
      x = x + nn.MultiHeadDotProductAttention(
          num_heads=self.num_heads,
          qkv_features=self.head_size,
          # Project back to the residual width so the addition above is
          # well-formed for any head_size / num_heads / embed_size choice.
          out_features=self.embed_size,
          dropout_rate=self.dropout_rate,
      )(
          x_norm,
          x_norm,
          mask=causal_mask,
          deterministic=not training,
      )

      # Pre-LayerNorm position-wise MLP sub-block with a residual connection.
      x = x + nn.Sequential([
          nn.Dense(4 * self.embed_size),
          nn.relu,
          nn.Dropout(self.dropout_rate, deterministic=not training),
          nn.Dense(self.embed_size),
      ])(nn.LayerNorm()(x))

    x = nn.LayerNorm()(x)
    return nn.Dense(self.vocab_size)(x)

  @functools.partial(jax.jit, static_argnames=("self", "length"))
  def generate(self, rng, params, length):
    """Autoregressively sample `length` tokens from an all-zero initial context.

    Args:
      rng: PRNG key; split once per generated token.
      params: model variables as returned by `init`.
      length: number of tokens to sample (static under jit).

    Returns:
      Array of shape (length, 1, 1) with the sampled token ids, in order.
    """
    def _scan_generate(carry, _):
      random_key, context = carry
      logits = self.apply(params, context, training=False)
      rng, rng_subkey = jax.random.split(random_key)
      # Sample from the distribution at the last position of the context.
      new_token = jax.random.categorical(
          rng_subkey, logits[:, -1, :], axis=-1, shape=(1, 1)
      )
      # Slide the window: drop the oldest token, append the new one.
      context = jnp.concatenate([context[:, 1:], new_token], axis=1)
      return (rng, context), new_token

    _, new_tokens = jax.lax.scan(
        _scan_generate,
        (rng, jnp.zeros((1, self.block_size), dtype=jnp.int32)),
        (),
        length=length,
    )
    return new_tokens

### Loss function and eval step

In [None]:
# Build the language model from the hyper-parameters set at the top of the
# notebook; the vocabulary size comes from the character vocabulary.
model = NanoLM(
    vocab_size=len(vocab),
    embed_size=EMBED_SIZE,      # width of token/position embeddings
    block_size=BLOCK_SIZE,      # context length / size of the position table
    num_layers=NUM_LAYERS,      # number of attention + MLP blocks
    num_heads=NUM_HEADS,        # attention heads per block
    head_size=HEAD_SIZE,        # forwarded to Flax attention as qkv_features
    dropout_rate=DROPOUT_RATE,  # dropout probability while training
)

def loss_fun(params, x, y, dropout_key):
  """Mean next-token cross-entropy with dropout active (training mode).

  Args:
    params: model variables as returned by `model.init`.
    x: input token ids.
    y: integer target ids aligned with `x`.
    dropout_key: PRNG key consumed by the model's dropout layers.

  Returns:
    Scalar mean softmax cross-entropy over every position in the batch.
  """
  per_token = optax.softmax_cross_entropy_with_integer_labels(
      logits=model.apply(params, x, training=True, rngs={"dropout": dropout_key}),
      labels=y,
  )
  return per_token.mean()

@jax.jit
def eval_step(params, x, y):
  """Mean next-token cross-entropy with dropout disabled (evaluation mode).

  Same computation as `loss_fun`, but with `training=False`, so no dropout RNG
  is needed.
  """
  predictions = model.apply(params, x, training=False)
  token_losses = optax.softmax_cross_entropy_with_integer_labels(
      logits=predictions, labels=y
  )
  return jnp.mean(token_losses)

### Model Initialization and Seeding for Reproducibility

In [None]:
# Initialize a PRNGKey with a fixed seed for reproducibility.
key = jax.random.PRNGKey(SEED)

# Split the key: `subkey` is consumed by parameter initialization below, while
# `key` is left unused so it can be split again for later randomness (batch
# sampling, dropout, generation). A JAX key must not be used to draw random
# numbers and then split again, which is why `init` receives `subkey`, not `key`.
key, subkey = jax.random.split(key)

# Initialize the model parameters from a dummy batch that only fixes the input
# shape and dtype.
var_params = model.init(
    subkey,  # PRNGKey for random initialization of parameters.
    jnp.ones((BATCH_SIZE, BLOCK_SIZE), dtype=jnp.int32),  # Dummy input (shape/dtype only).
    training=False,  # Dropout is deterministic here, so no "dropout" RNG is required.
)

### Number of Parameters in the Model

In [None]:
# Count every scalar in the parameter pytree: flatten the nested structure of
# arrays into its leaves and add up their element counts.
param_leaves = jax.tree_util.tree_leaves(var_params)
n_params = 0
for leaf in param_leaves:
    n_params += leaf.size

# `:_` prints underscores as thousands separators (e.g. 3_408_513).
print(f"Total number of parameters: {n_params:_}")

Total number of parameters: 3_408_513


### Optimizer

In [None]:
# AdamW (Adam with decoupled weight decay, using optax's default weight_decay).
# To experiment with plain SGD instead, replace `optax.adamw` with `optax.sgd`.
opt = optax.adamw(learning_rate=LEARNING_RATE)

# Optimizer state, initialized from (and mirroring the structure of) the params.
opt_state = opt.init(var_params)

### Training and Evaluation

In [None]:
%%time

# Loss histories, appended to by the training loop in this cell.
all_train_losses, all_eval_losses = [], []

@jax.jit
def step(key, params, opt_state):
    """Run one optimizer update on a freshly sampled training batch.

    Args:
        key: PRNG key for this step; it is only split, never consumed directly.
        params: current model parameters.
        opt_state: current optimizer state.

    Returns:
        (params, key, opt_state, loss): updated parameters, a fresh key to pass
        to the next step, updated optimizer state, and the pre-update loss.
    """
    # Three independent keys: one carried to the next step, one for batch
    # sampling, one for dropout. If the carried key were also used to sample the
    # batch, the next step would split an already-consumed key.
    key, batch_key, dropout_key = jax.random.split(key, 3)

    # Random (input, target) windows from the training data.
    x, y = get_batch(batch_key, train_data)

    # Loss and gradient w.r.t. the parameters in a single pass.
    loss, grad = jax.value_and_grad(loss_fun)(params, x, y, dropout_key)

    # Optimizer transforms the gradient into parameter updates, then apply them.
    updates, opt_state = opt.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, key, opt_state, loss

# Main loop: N_ITERATIONS optimizer steps. Every N_FREQ_EVAL steps (starting at
# step 0) one validation batch is also scored and both losses are reported.
for i in range(N_ITERATIONS):
    var_params, key, opt_state, loss = step(key, var_params, opt_state)
    all_train_losses.append(loss)

    if i % N_FREQ_EVAL != 0:
        continue

    # Fresh key for sampling the validation batch.
    key, subkey = jax.random.split(key)
    eval_x, eval_y = get_batch(subkey, eval_data)
    eval_loss = eval_step(var_params, eval_x, eval_y)
    all_eval_losses.append(eval_loss)

    print(f"Step: {i}\t train loss: {loss}\t eval loss: {eval_loss}")

Step: 0	 train loss: 4.587119102478027	 eval loss: 6.042149543762207
Step: 2000	 train loss: 1.388554573059082	 eval loss: 1.4247114658355713
Step: 4000	 train loss: 1.2833781242370605	 eval loss: 1.3967227935791016
CPU times: user 4min 16s, sys: 3min 41s, total: 7min 57s
Wall time: 7min 52s


### Sample output

In [None]:
# Sample 1000 new characters from the trained model using a fresh PRNG key.
key, subkey = jax.random.split(key)

# `generate` returns one token id per step with shape (length, 1, 1); flatten
# it to a plain Python list of ints before decoding.
generated = model.generate(subkey, var_params, 1000)
text = generated.reshape(-1).tolist()

# Map the ids back to characters and show the result.
print(decode(text))

CAMILLO:
Yes, sir, bestrictle.

LEONTES:
For which I can scoffer.

Clown:
I shall be full of bosoms in this grief,
And throw the droping sords give thee this banish'd:
The murderer whereof he's slain are so
At gates; the dread ham, on catched of wrongs:
That I trew my words from heaven's father,
His hug-lips in so face, if his world,
Made it kindly know what he most
Covenance on their lips I for link not
To our great speech.

Messenger:
My gracious liege, and then I love thee sir;
Come about the conspiracy. Thus thus they are the
vice music ere thou shalt not abuse thy head?
Be thy departing, but at night winter peace,
Should be welcome; thy life, thy trembling peace,
Comes, give me down as the new-day.

Provost:
Pardon me, cousin!

DUKE VINCENTIO:
The princes shall be gone.

Provost:
I have rather since, the drawbring state
of eldest faces is but within such same long.

OXTON:
Arise, for that yet did follow all,
Since yet over--Pluck the action! O, not some other
the present. Where
Me

# Transformers for Translation

### Necessary Installations

In [None]:
!python -m spacy download fr_core_news_sm
!python -m spacy download en_core_web_sm

Collecting fr-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/fr_core_news_sm-3.7.0/fr_core_news_sm-3.7.0-py3-none-any.whl (16.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m51.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: fr-core-news-sm
Successfully installed fr-core-news-sm-3.7.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('fr_core_news_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/

### Importing the libraries

In [None]:
import math
import torchtext
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch import Tensor
from torch.nn import (TransformerEncoder, TransformerDecoder,TransformerEncoderLayer, TransformerDecoderLayer)
import io
import time

from google.colab import drive
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): none of the cells visible here read from /content/drive (the
# Multi30k files are downloaded from GitHub), so this mount may be unnecessary;
# confirm against the rest of the notebook before removing it.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# French-English (Fr-En)

#### Data loading

In [None]:
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.fr.gz', 'train.en.gz')
val_urls = ('val.fr.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.fr.gz', 'test_2016_flickr.en.gz')


def _fetch_split(file_names):
    """Download each gzipped file under `url_base`, extract it, and return the
    first extracted path for each file, in the same order as `file_names`."""
    paths = []
    for name in file_names:
        archive = download_from_url(url_base + name)
        paths.append(extract_archive(archive)[0])
    return paths


# [French path, English path] for each split.
train_filepaths = _fetch_split(train_urls)
val_filepaths = _fetch_split(val_urls)
test_filepaths = _fetch_split(test_urls)

# spaCy-backed tokenizers for French and English text.
fr_tokenizer = get_tokenizer('spacy', language='fr_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

def build_vocab(filepath, tokenizer):
    """
    Build a torchtext vocabulary from every token in a text file.

    Each line of `filepath` (read as UTF-8) is passed to `tokenizer` as read,
    trailing newline included. Token frequencies are accumulated, and the
    special tokens '<unk>', '<pad>', '<bos>', '<eos>' are added.

    Args:
    - filepath (str): Path of the line-oriented text file.
    - tokenizer: Callable mapping a string to a list of tokens.

    Returns:
    - vocab: torchtext vocabulary built from the token counts plus the specials.
    """
    with io.open(filepath, encoding="utf8") as handle:
        token_counts = Counter(
            token for line in handle for token in tokenizer(line)
        )
    return vocab(token_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

# Vocabularies come from the training split only: element 0 of the path list is
# the French file, element 1 the English file.
fr_vocab = build_vocab(train_filepaths[0], fr_tokenizer)
en_vocab = build_vocab(train_filepaths[1], en_tokenizer)

# Any token not seen in training maps to the '<unk>' index.
for lang_vocab in (fr_vocab, en_vocab):
    lang_vocab.set_default_index(lang_vocab['<unk>'])

### Data Preprocessing

In [None]:
# Function to process data from given filepaths
def data_process(filepaths):
    """
    Convert a line-aligned French/English corpus into pairs of token-id tensors.

    Args:
    - filepaths (list): Two paths; the first is the French file, the second the
      English file, aligned line by line.

    Returns:
    - data (list): One ``(fr_tensor, en_tensor)`` tuple of ``torch.long`` ids per
      aligned line pair. Pairing stops at the end of the shorter file (``zip``).
    """
    data = []
    # `with` guarantees both files are closed, even if tokenization raises.
    with io.open(filepaths[0], encoding="utf8") as fr_file:
        with io.open(filepaths[1], encoding="utf8") as en_file:
            for raw_fr, raw_en in zip(fr_file, en_file):
                # Strip only the trailing newline. Note: rstrip("n") would strip
                # trailing letter 'n' characters, leave the newline in place, and
                # could truncate the last word of a line without a final newline.
                fr_ids = [fr_vocab[token] for token in fr_tokenizer(raw_fr.rstrip("\n"))]
                en_ids = [en_vocab[token] for token in en_tokenizer(raw_en.rstrip("\n"))]
                data.append((torch.tensor(fr_ids, dtype=torch.long),
                             torch.tensor(en_ids, dtype=torch.long)))
    return data

# Tokenize and numericalize the training, validation and test splits.
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Upper-case alias: `generate_square_subsequent_mask` looks up `DEVICE`, while
# this cell defines lower-case `device`. NOTE(review): if another cell also
# assigns DEVICE, whichever runs last wins; keep the two consistent.
DEVICE = device

# Constants for batch processing
BATCH_SIZE = 128  # Sentence pairs per batch
# Special-token indices, read from the French vocabulary. NOTE(review): they are
# presumably identical in en_vocab, since both vocabs are built with the same
# `specials` list (torchtext places specials first by default); verify if the
# vocab construction changes.
PAD_IDX = fr_vocab['<pad>']
BOS_IDX = fr_vocab['<bos>']
EOS_IDX = fr_vocab['<eos>']

### Generate data in batches

In [None]:
# DataLoader collate function
def generate_batch(data_batch):
    """
    Wrap every sentence of a batch in <bos>/<eos> and pad the batch.

    Args:
    - data_batch (list): ``(fr_tensor, en_tensor)`` pairs as produced by ``data_process``.

    Returns:
    - fr_batch (Tensor): Padded French ids, shape (max_fr_len + 2, batch_size).
    - en_batch (Tensor): Padded English ids, shape (max_en_len + 2, batch_size).
    """
    bos = torch.tensor([BOS_IDX])
    eos = torch.tensor([EOS_IDX])

    def _wrap(ids):
        # <bos> ids... <eos>
        return torch.cat([bos, ids, eos], dim=0)

    fr_seqs, en_seqs = [], []
    for fr_ids, en_ids in data_batch:
        fr_seqs.append(_wrap(fr_ids))
        en_seqs.append(_wrap(en_ids))

    # pad_sequence defaults to batch_first=False -> (seq_len, batch) layout.
    return (pad_sequence(fr_seqs, padding_value=PAD_IDX),
            pad_sequence(en_seqs, padding_value=PAD_IDX))

# DataLoaders for training, validation and test data. All three share the same
# settings: BATCH_SIZE pairs per batch, reshuffled each epoch, collated (and
# padded) by `generate_batch`.
# NOTE(review): shuffling the validation/test sets is not needed for an averaged
# loss; consider shuffle=False there if reproducible batches are wanted.
_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
train_iter = DataLoader(train_data, **_loader_kwargs)
valid_iter = DataLoader(val_data, **_loader_kwargs)
test_iter = DataLoader(test_data, **_loader_kwargs)

### Transformer Class

In [None]:
# Transformer class for sequence-to-sequence model
class Seq2SeqTransformer(nn.Module):
    """
    Transformer-based sequence-to-sequence model for machine translation.

    Tensors use the (seq_len, batch_size, ...) layout, the default
    (batch_first=False) of the PyTorch Transformer layers used here.

    Args:
    - num_encoder_layers (int): Number of layers in the encoder.
    - num_decoder_layers (int): Number of layers in the decoder.
    - emb_size (int): Embedding size for tokens (the model width d_model).
    - src_vocab_size (int): Vocabulary size of the source language.
    - tgt_vocab_size (int): Vocabulary size of the target language.
    - dim_feedforward (int): Dimension of the feedforward network in Transformer layers.
    - dropout (float): Dropout probability, used by the positional encoding and
      inside every encoder/decoder layer.
    - nhead (int, optional): Number of attention heads. When omitted, the
      notebook-level ``NHEAD`` constant is used.

    Methods:
    - forward(src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
      Forward pass of the model.
    - encode(src, src_mask):
      Encoder forward pass.
    - decode(tgt, memory, tgt_mask):
      Decoder forward pass.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward: int = 512, dropout: float = 0.1,
                 nhead=None):
        super(Seq2SeqTransformer, self).__init__()
        # Fall back to the global NHEAD so existing call sites keep working.
        if nhead is None:
            nhead = NHEAD

        # Transformer encoder; `dropout` is forwarded so the constructor
        # argument controls the layers, not only the positional encoding.
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Transformer decoder
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Generator layer: decoder states -> target-vocabulary logits.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

        # Token embeddings for source and target languages
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

        # Positional encoding (shared by source and target)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """
        Forward pass of the transformer-based sequence-to-sequence model.

        Args:
        - src (Tensor): Input tensor of shape (src_seq_len, batch_size).
        - trg (Tensor): Target tensor of shape (trg_seq_len, batch_size).
        - src_mask (Tensor): Mask for source tokens of shape (src_seq_len, src_seq_len).
        - tgt_mask (Tensor): Mask for target tokens of shape (trg_seq_len, trg_seq_len).
        - src_padding_mask (Tensor): Mask for padding tokens in source of shape (batch_size, src_seq_len).
        - tgt_padding_mask (Tensor): Mask for padding tokens in target of shape (batch_size, trg_seq_len).
        - memory_key_padding_mask (Tensor): Mask for padding in memory key of shape (batch_size, src_seq_len).

        Returns:
        - Tensor: Output tensor of shape (trg_seq_len, batch_size, tgt_vocab_size).
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, mask=src_mask,
                                          src_key_padding_mask=src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory,
                                        tgt_mask=tgt_mask,
                                        memory_mask=None,
                                        tgt_key_padding_mask=tgt_padding_mask,
                                        memory_key_padding_mask=memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """
        Encoder forward pass.

        Args:
        - src (Tensor): Input tensor of shape (src_seq_len, batch_size).
        - src_mask (Tensor): Mask for source tokens of shape (src_seq_len, src_seq_len).

        Returns:
        - Tensor: Encoded tensor of shape (src_seq_len, batch_size, emb_size).
        """
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """
        Decoder forward pass.

        Args:
        - tgt (Tensor): Input tensor of shape (trg_seq_len, batch_size).
        - memory (Tensor): Memory tensor from encoder of shape (src_seq_len, batch_size, emb_size).
        - tgt_mask (Tensor): Mask for target tokens of shape (trg_seq_len, trg_seq_len).

        Returns:
        - Tensor: Decoded tensor of shape (trg_seq_len, batch_size, emb_size).
        """
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

### Positional Encoding and Token Embedding

In [None]:
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional encodings to token embeddings, then applies dropout.

    Even columns hold sin(pos / 10000^(2i / emb_size)); odd columns hold the
    matching cos. Odd ``emb_size`` is supported: the final, unpaired column is a sin.

    Args:
    - emb_size (int): Embedding size.
    - dropout (float): Dropout probability.
    - maxlen (int): Maximum sequence length supported.

    Methods:
    - forward(token_embedding: Tensor):
      Forward pass of the positional encoding module.
    """
    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # Inverse frequencies 1 / 10000^(2i / emb_size), one per sin/cos pair.
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        angles = pos * den  # (maxlen, ceil(emb_size / 2))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For odd emb_size there is one fewer cos column than sin columns.
        pos_embedding[:, 1::2] = torch.cos(angles[:, : emb_size // 2])
        # (maxlen, 1, emb_size): broadcasts over the batch dimension.
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        # Buffer: follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """
        Forward pass of the positional encoding module.

        Args:
        - token_embedding (Tensor): Token embeddings of shape (seq_len, batch_size, emb_size).

        Returns:
        - Tensor: Token embeddings with positional encodings added, of shape (seq_len, batch_size, emb_size).
        """
        # Add the encodings of the first seq_len positions, then apply dropout.
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


class TokenEmbedding(nn.Module):
    """
    Map integer token ids to dense vectors scaled by sqrt(emb_size).

    Args:
    - vocab_size (int): Number of rows in the lookup table.
    - emb_size (int): Width of each embedding vector.

    Methods:
    - forward(tokens: Tensor):
      Look up and scale the embeddings for ``tokens``.
    """
    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        # Stored so forward() can derive the scale factor.
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """
        Embed ``tokens`` and multiply the result by sqrt(emb_size).

        Args:
        - tokens (Tensor): Integer token indices, e.g. of shape (seq_len, batch_size);
          cast to int64 before the lookup.

        Returns:
        - Tensor: Same shape as ``tokens`` with a trailing ``emb_size`` axis.
        """
        ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(ids)
        return vectors * scale

### Masking

In [None]:
def generate_square_subsequent_mask(sz, device=None):
    """
    Build a causal (look-ahead) attention mask for a sequence of length ``sz``.

    Entry ``[i, j]`` is ``0.0`` when position ``i`` may attend to position ``j``
    (``j <= i``) and ``-inf`` otherwise. Adding it to attention scores blocks
    attention to future positions.

    Args:
    - sz (int): Size of the square mask (sequence length).
    - device (torch.device | str, optional): Device on which to allocate the mask.
      Defaults to the module-level ``DEVICE`` for backward compatibility.

    Returns:
    - Tensor: Float32 mask of shape (sz, sz).
    """
    if device is None:
        # Looked up at call time, so DEVICE may be defined after this function.
        device = DEVICE
    # -inf strictly above the main diagonal, 0.0 on and below it.
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

def create_mask(src, tgt):
    """
    Build the attention and padding masks consumed by ``Seq2SeqTransformer.forward``.

    Args:
    - src (Tensor): Source token ids of shape (src_seq_len, batch_size).
    - tgt (Tensor): Target token ids of shape (tgt_seq_len, batch_size).

    Returns:
    - src_mask (Tensor): All-False bool mask of shape (src_seq_len, src_seq_len);
      nothing is masked in encoder self-attention.
    - tgt_mask (Tensor): Causal mask of shape (tgt_seq_len, tgt_seq_len).
    - src_padding_mask (Tensor): Bool mask of shape (batch_size, src_seq_len), True where src == PAD_IDX.
    - tgt_padding_mask (Tensor): Bool mask of shape (batch_size, tgt_seq_len), True where tgt == PAD_IDX.
    """
    n_src = src.size(0)
    n_tgt = tgt.size(0)

    # Encoder self-attention is unrestricted: all entries False means "not masked".
    src_mask = torch.zeros((n_src, n_src), dtype=torch.bool, device=DEVICE)
    # Decoder self-attention must not look at later positions.
    tgt_mask = generate_square_subsequent_mask(n_tgt)

    def padding_mask(seq):
        # (seq_len, batch) -> (batch, seq_len), True at padding positions.
        return (seq == PAD_IDX).transpose(0, 1)

    return src_mask, tgt_mask, padding_mask(src), padding_mask(tgt)

### Important Parameters

In [None]:
# Define vocabulary sizes and model parameters.
# French -> English direction: the source vocabulary is French, the target is English.
SRC_VOCAB_SIZE = len(fr_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 512
# Read as a global inside Seq2SeqTransformer.__init__ (nhead=NHEAD), so it must
# be defined before the model is instantiated below.
NHEAD = 8
FFN_HID_DIM = 512
# NOTE(review): DataLoader reads batch_size when it is constructed, so this value
# only affects loaders created after this cell runs — confirm that is intended.
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 10

# Determine the device (CPU or GPU) for computation
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Instantiate the Seq2SeqTransformer model (constructed on the CPU)
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)

# Re-initialize every weight matrix (dim > 1) with Xavier-uniform; 1-D parameters
# such as biases keep their default initialization. This runs before the move
# to DEVICE, so the random draws happen on the CPU.
# NOTE(review): `nn` must be torch.nn here (the notebook header imports
# flax.linen as nn) — confirm a later cell rebinds it.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# Move the model to the selected device
transformer = transformer.to(DEVICE)

# Cross-entropy loss; target positions equal to PAD_IDX do not contribute.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam optimizer over all model parameters
optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)



### Train function

In [None]:
def train_epoch(model, train_iter, optimizer, epoch, save_path):
    """
    Train ``model`` for one epoch, then save a checkpoint.

    Relies on the module-level ``device``, ``loss_fn`` and ``create_mask``.

    Args:
    - model (nn.Module): The model to be trained.
    - train_iter (iterable): Yields ``(src, tgt)`` batches of token ids, each of
      shape (seq_len, batch_size).
    - optimizer (torch.optim.Optimizer): The optimizer used for training.
    - epoch (int): The current epoch number (used in the checkpoint filename).
    - save_path (str): Directory in which ``model_epoch_{epoch}.pth`` is written.

    Returns:
    - train_loss (float): Mean per-batch loss over the epoch, or ``nan`` if
      ``train_iter`` produced no batches.
    """
    model.train()  # Enable dropout for training
    total_loss = 0.0
    num_batches = 0

    for src, tgt in train_iter:
        src = src.to(device)
        tgt = tgt.to(device)
        # Teacher forcing: the decoder reads tokens [0, T-1) and is scored
        # on predicting tokens [1, T).
        tgt_input = tgt[:-1, :]
        tgt_out = tgt[1:, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        # The source padding mask is passed again as memory_key_padding_mask so
        # cross-attention ignores padded encoder positions.
        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()
        # Flatten (seq_len, batch, vocab) -> (seq_len * batch, vocab) for the loss.
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Average over the batches actually seen; works for iterables without len().
    train_loss = total_loss / num_batches if num_batches else float('nan')

    # Save the trained model after each epoch
    model_save_path = f'{save_path}/model_epoch_{epoch}.pth'
    torch.save(model.state_dict(), model_save_path)

    return train_loss

### Evaluate function

In [None]:
def evaluate(model, val_iter, device):
    """
    Compute the mean per-batch loss of ``model`` on a validation set.

    The model is put in evaluation mode (disabling dropout), and every forward
    pass runs under ``torch.no_grad()`` so no autograd graph is built. The model
    is left in evaluation mode on return. Relies on the module-level ``loss_fn``
    and ``create_mask``.

    Args:
    - model (torch.nn.Module): The model to be evaluated.
    - val_iter (iterable): Yields ``(src, tgt)`` batches of token ids.
    - device (torch.device): The device tensors are moved to (e.g. 'cuda' or 'cpu').

    Returns:
    - val_loss (float): Mean per-batch loss over the validation set, or ``nan``
      if ``val_iter`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # Evaluation never calls backward(); disabling autograd avoids keeping
    # activations alive for a graph that would be discarded.
    with torch.no_grad():
        for src, tgt in val_iter:
            src = src.to(device)
            tgt = tgt.to(device)

            # Teacher forcing: feed tokens [0, T-1), score predictions of tokens [1, T).
            tgt_input = tgt[:-1, :]
            tgt_out = tgt[1:, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            # The source padding mask is reused as memory_key_padding_mask.
            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)

            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')

### Training and Validation Loss

In [None]:
# Directory passed to train_epoch as save_path; one checkpoint is written there per epoch.
path = '/content/drive/MyDrive/trf/'

# Loop over each epoch starting from 1 to NUM_EPOCHS
for epoch in range(1, NUM_EPOCHS+1):
    start_time = time.time()  # Record the start time of the epoch

    # Train for one epoch (this also saves a checkpoint); returns the mean training loss
    train_loss = train_epoch(transformer, train_iter, optimizer, epoch, path)

    end_time = time.time()  # Record the end time of training (evaluation is not timed)

    # Evaluate the model on the validation dataset; returns the mean validation loss
    val_loss = evaluate(transformer, valid_iter, device)

    # Print both losses; "Epoch time" covers the training pass only
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
           f"Epoch time = {(end_time - start_time):.3f}s"))

Epoch: 1, Train loss: 0.810, Val loss: 1.410, Epoch time = 19.292s
Epoch: 2, Train loss: 0.761, Val loss: 1.421, Epoch time = 19.433s
Epoch: 3, Train loss: 0.715, Val loss: 1.397, Epoch time = 19.517s
Epoch: 4, Train loss: 0.674, Val loss: 1.412, Epoch time = 19.564s
Epoch: 5, Train loss: 0.635, Val loss: 1.404, Epoch time = 19.583s
Epoch: 6, Train loss: 0.598, Val loss: 1.412, Epoch time = 19.590s
Epoch: 7, Train loss: 0.563, Val loss: 1.415, Epoch time = 19.736s
Epoch: 8, Train loss: 0.530, Val loss: 1.415, Epoch time = 19.652s
Epoch: 9, Train loss: 0.499, Val loss: 1.431, Epoch time = 19.812s
Epoch: 10, Train loss: 0.468, Val loss: 1.442, Epoch time = 19.638s


### Greedy Decoder


In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    Greedily generate a target sequence for a single source sentence.

    At each step the decoder is run on all tokens generated so far, and the
    highest-scoring next token is appended. Generation stops when ``EOS_IDX`` is
    produced or ``max_len`` tokens (including ``start_symbol``) exist. Relies
    on the module-level ``device``, ``EOS_IDX`` and
    ``generate_square_subsequent_mask``.

    Args:
    - model (nn.Module): Trained model exposing ``encode``, ``decode`` and ``generator``.
    - src (Tensor): Source token ids of shape (src_seq_len, 1). Only a batch of
      one is supported, because the next token is read with ``.item()``.
    - src_mask (Tensor): Encoder self-attention mask of shape (src_seq_len, src_seq_len).
    - max_len (int): Maximum length of the generated sequence, including the start symbol.
    - start_symbol (int): Index of the start symbol in the target vocabulary.

    Returns:
    - Tensor: Generated token ids of shape (tgt_seq_len, 1), beginning with ``start_symbol``.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)
    # Inference only: no autograd graph is needed for any step below.
    with torch.no_grad():
        # The encoder output is computed and placed on the device once, then reused every step.
        memory = model.encode(src, src_mask).to(device)
        # Start the target sequence with the start symbol, shape (1, 1).
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
        for _ in range(max_len - 1):
            # Causal mask as bool: -inf entries become True, i.e. blocked future positions.
            tgt_mask = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)  # -> (batch, tgt_seq_len, emb_size)
            # Only the last position's output predicts the next token.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
            if next_word == EOS_IDX:
                break
    return ys

### Translation function

In [None]:
def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    """
    Translate a source sentence to the target language with greedy decoding.

    Source tokens are looked up with ``src_vocab[token]``. That lookup honours the
    vocabulary's default index (e.g. ``<unk>`` after ``set_default_index``), so
    unknown words no longer raise ``KeyError``. Relies on the module-level
    ``BOS_IDX``, ``EOS_IDX`` and ``greedy_decode``.

    Args:
    - model (nn.Module): The trained model for sequence-to-sequence translation.
    - src (str): Source sentence to be translated.
    - src_vocab (Vocab): Source vocabulary object.
    - tgt_vocab (Vocab): Target vocabulary object.
    - src_tokenizer (Tokenizer): Tokenizer for the source language.

    Returns:
    - str: Space-joined target tokens with the ``<bos>``/``<eos>`` markers removed
      (surrounding spaces are left in place).
    """
    model.eval()  # Disable dropout
    # vocab[token] falls back to the default index for unknown tokens. The previous
    # get_stoi()[token] used a plain dict (KeyError on unknown words) and was
    # re-fetched for every token.
    tokens = [BOS_IDX] + [src_vocab[tok] for tok in src_tokenizer(src)] + [EOS_IDX]
    num_tokens = len(tokens)
    src = torch.LongTensor(tokens).reshape(num_tokens, 1)  # (seq_len, batch=1)
    # All-False mask: the encoder may attend to every source position.
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    with torch.no_grad():
        tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5,
                                   start_symbol=BOS_IDX).flatten()
    # Fetch the index -> token list once instead of once per generated token.
    itos = tgt_vocab.get_itos()
    translation = " ".join([itos[tok] for tok in tgt_tokens.tolist()])
    return translation.replace("<bos>", "").replace("<eos>", "")

### Sample output

In [None]:
# Run the trained French -> English model on a sample sentence and show the result
output = translate(model=transformer,
                   src="Un groupe de personnes se tient devant un igloo .",
                   src_vocab=fr_vocab, tgt_vocab=en_vocab, src_tokenizer=fr_tokenizer)
print(output)

 A group of people stand in front of an igloo . 
 


In [None]:
# Run the trained French -> English model on a second sample sentence and show the result
output = translate(model=transformer,
                   src="Il chante dans la chorale .",
                   src_vocab=fr_vocab, tgt_vocab=en_vocab, src_tokenizer=fr_tokenizer)
print(output)

 He is singing in the choir . 
 


# English-to-French Translation

#### Data Loading

In [None]:
# Multi30k task-1 raw files; every tuple lists the English file first, then the French one.
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.en.gz', 'train.fr.gz')
val_urls = ('val.en.gz', 'val.fr.gz')
test_urls = ('test_2016_flickr.en.gz', 'test_2016_flickr.fr.gz')

# Download each archive, extract it, and keep the first extracted path.
# Each resulting list is ordered [English path, French path].
train_filepaths = [extract_archive(download_from_url(f"{url_base}{name}"))[0] for name in train_urls]
val_filepaths = [extract_archive(download_from_url(f"{url_base}{name}"))[0] for name in val_urls]
test_filepaths = [extract_archive(download_from_url(f"{url_base}{name}"))[0] for name in test_urls]

# spaCy-backed tokenizers, one per language
fr_tokenizer = get_tokenizer('spacy', language='fr_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

def build_vocab(filepath, tokenizer):
    """
    Count token frequencies in a text file and wrap them in a vocabulary.

    Each line is passed to ``tokenizer`` exactly as read, including its
    trailing newline character.

    Args:
    - filepath (str): Path to a UTF-8 text file.
    - tokenizer: Callable mapping a string to a list of token strings.

    Returns:
    - vocab: Vocabulary built from the token counts, with '<unk>', '<pad>',
      '<bos>' and '<eos>' passed as special tokens.
    """
    with io.open(filepath, encoding="utf8") as handle:
        # Counter inserts tokens in first-seen order while counting them.
        frequencies = Counter(token for line in handle for token in tokenizer(line))
    return vocab(frequencies, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

# Building vocabularies for French and English using training data
fr_vocab = build_vocab(train_filepaths[1], fr_tokenizer)
en_vocab = build_vocab(train_filepaths[0], en_tokenizer)

# Setting default indices for unknown tokens in vocabularies
fr_vocab.set_default_index(fr_vocab['<unk>'])
en_vocab.set_default_index(en_vocab['<unk>'])



### Data Preprocessing

In [None]:
def data_process(filepaths):
    """
    Tokenize paired English/French text files into index tensors.

    Relies on the module-level ``en_vocab``, ``fr_vocab``, ``en_tokenizer`` and
    ``fr_tokenizer``.

    Args:
    - filepaths (list): ``[english_path, french_path]``. Index 0 is read as
      English and index 1 as French.

    Returns:
    - list: ``(en_tensor, fr_tensor)`` tuples of dtype ``torch.long``, one per
      aligned line pair. ``zip`` stops at the shorter file, so unpaired trailing
      lines are dropped.
    """
    data = []
    # Context managers close both files once processing finishes
    # (the previous version never closed them).
    with io.open(filepaths[0], encoding="utf8") as en_file, \
         io.open(filepaths[1], encoding="utf8") as fr_file:
        for raw_en, raw_fr in zip(en_file, fr_file):
            # Strip only the trailing newline. The former rstrip("n") removed trailing
            # letter 'n' characters instead and left the newline in place.
            fr_tensor_ = torch.tensor([fr_vocab[token] for token in fr_tokenizer(raw_fr.rstrip("\n"))],
                                      dtype=torch.long)
            en_tensor_ = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en.rstrip("\n"))],
                                      dtype=torch.long)
            data.append((en_tensor_, fr_tensor_))
    return data

# Process data for training, validation, and testing sets
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

# Set device to GPU if available, otherwise to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Constants for batch processing
BATCH_SIZE = 128  # Batch size for training
PAD_IDX = fr_vocab['<pad>'] # Index of the padding token in the French vocabulary
BOS_IDX = fr_vocab['<bos>'] # Index of the beginning-of-sequence token in the French vocabulary
EOS_IDX = fr_vocab['<eos>'] # Index of the end-of-sequence token in the French vocabulary

### Generate data in batches

In [None]:
# DataLoader
def generate_batch(data_batch):
    """
    Collate ``(en, fr)`` tensor pairs into two padded batches.

    Each sequence becomes ``[BOS_IDX, *tokens, EOS_IDX]`` and is then padded with
    ``PAD_IDX`` to the longest sequence of its language in the batch.

    Args:
    - data_batch (iterable): ``(en_tensor, fr_tensor)`` pairs of 1-D token ids.

    Returns:
    - Tensor: Padded English batch, shape (max_len, batch_size) with
      pad_sequence's default ``batch_first=False``.
    - Tensor: Padded French batch, same layout.
    """
    pairs = list(data_batch)

    def add_boundaries(seq):
        # Prepend the start token and append the end token to a 1-D id tensor.
        return torch.cat([torch.tensor([BOS_IDX]), seq, torch.tensor([EOS_IDX])], dim=0)

    en_batch = pad_sequence([add_boundaries(en) for en, _ in pairs], padding_value=PAD_IDX)
    fr_batch = pad_sequence([add_boundaries(fr) for _, fr in pairs], padding_value=PAD_IDX)
    return en_batch, fr_batch


# Create data loaders for train, validation, and test datasets using the generated batch function
# One shuffled DataLoader per split, all collated by generate_batch.
# NOTE(review): shuffling the validation/test loaders is not needed for loss averaging.
train_iter, valid_iter, test_iter = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generate_batch)
    for split in (train_data, val_data, test_data)
)

### Transformer Class

In [None]:
# Transformer class for sequence-to-sequence model
class Seq2SeqTransformer(nn.Module):
    """
    Transformer-based sequence-to-sequence model for machine translation.

    Token tensors are time-major: (seq_len, batch_size).

    Args:
    - num_encoder_layers (int): Number of layers in the encoder.
    - num_decoder_layers (int): Number of layers in the decoder.
    - emb_size (int): Embedding size for tokens (the Transformer's d_model).
    - src_vocab_size (int): Vocabulary size of the source language.
    - tgt_vocab_size (int): Vocabulary size of the target language.
    - dim_feedforward (int): Dimension of the feedforward network in Transformer layers.
    - dropout (float): Dropout probability, applied after the positional encoding
      and inside every encoder/decoder layer.
    - nhead (int, optional): Number of attention heads. Defaults to the
      module-level ``NHEAD`` for backward compatibility.

    Methods:
    - forward(src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
      Forward pass of the model.
    - encode(src, src_mask):
      Encoder forward pass.
    - decode(tgt, memory, tgt_mask):
      Decoder forward pass.
    """

    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1, nhead=None):
        super(Seq2SeqTransformer, self).__init__()
        if nhead is None:
            # Keep the old behaviour of reading the notebook-level head count.
            nhead = NHEAD

        # Transformer encoder. dropout is now forwarded; previously the layers
        # always used their own default regardless of the constructor argument.
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Transformer decoder
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Projects decoder outputs to target-vocabulary logits
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

        # Token embeddings for source and target languages
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

        # Positional encoding shared by the source and target streams
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """
        Forward pass of the transformer-based sequence-to-sequence model.

        Args:
        - src (Tensor): Input tensor of shape (src_seq_len, batch_size).
        - trg (Tensor): Target tensor of shape (trg_seq_len, batch_size).
        - src_mask (Tensor): Mask for source tokens of shape (src_seq_len, src_seq_len).
        - tgt_mask (Tensor): Mask for target tokens of shape (trg_seq_len, trg_seq_len).
        - src_padding_mask (Tensor): Mask for padding tokens in source of shape (batch_size, src_seq_len).
        - tgt_padding_mask (Tensor): Mask for padding tokens in target of shape (batch_size, trg_seq_len).
        - memory_key_padding_mask (Tensor): Mask for padding in memory key of shape (batch_size, src_seq_len).

        Returns:
        - Tensor: Output logits of shape (trg_seq_len, batch_size, tgt_vocab_size).
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, mask=src_mask,
                                          src_key_padding_mask=src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory,
                                        tgt_mask=tgt_mask,
                                        memory_mask=None,
                                        tgt_key_padding_mask=tgt_padding_mask,
                                        memory_key_padding_mask=memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """
        Encoder forward pass.

        Args:
        - src (Tensor): Input tensor of shape (src_seq_len, batch_size).
        - src_mask (Tensor): Mask for source tokens of shape (src_seq_len, src_seq_len).

        Returns:
        - Tensor: Encoded tensor of shape (src_seq_len, batch_size, emb_size).
        """
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """
        Decoder forward pass (no padding masks; used for unpadded greedy decoding).

        Args:
        - tgt (Tensor): Input tensor of shape (trg_seq_len, batch_size).
        - memory (Tensor): Memory tensor from encoder of shape (src_seq_len, batch_size, emb_size).
        - tgt_mask (Tensor): Mask for target tokens of shape (trg_seq_len, trg_seq_len).

        Returns:
        - Tensor: Decoded tensor of shape (trg_seq_len, batch_size, emb_size).
        """
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

### Positional Encoding and Token Embedding

In [None]:
class PositionalEncoding(nn.Module):
    """
    Positional encoding module that adds sinusoidal position information to token embeddings.

    Args:
    - emb_size (int): Embedding size. Odd sizes are supported: the final
      (even-indexed) column receives a sine term with no cosine partner.
    - dropout (float): Dropout probability applied after the encodings are added.
    - maxlen (int): Maximum sequence length for which encodings are precomputed.

    Methods:
    - forward(token_embedding: Tensor):
      Forward pass of the positional encoding module.
    """
    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # Frequencies 1 / 10000^(2i / emb_size), one per (sin, cos) column pair.
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        # Even columns hold sine terms, odd columns hold cosine terms.
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # When emb_size is odd there is one fewer odd column than frequencies,
        # so trim the cosine terms to fit (the unsliced assignment raised a
        # shape-mismatch error for odd sizes).
        pos_embedding[:, 1::2] = torch.cos(pos * den)[:, :emb_size // 2]
        # (maxlen, emb_size) -> (maxlen, 1, emb_size) so it broadcasts across the
        # batch axis of (seq_len, batch_size, emb_size) inputs.
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        # Register as a buffer: stored in state_dict and moved by .to(), but not trained.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """
        Forward pass of the positional encoding module.

        Args:
        - token_embedding (Tensor): Token embeddings of shape (seq_len, batch_size, emb_size),
          with seq_len <= maxlen.

        Returns:
        - Tensor: Token embeddings with positional encodings added (then dropout),
          of shape (seq_len, batch_size, emb_size).
        """
        # Take the first seq_len positions and add them to every batch element.
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


class TokenEmbedding(nn.Module):
    """
    Map integer token ids to dense vectors scaled by sqrt(emb_size).

    Args:
    - vocab_size (int): Number of rows in the lookup table.
    - emb_size (int): Width of each embedding vector.

    Methods:
    - forward(tokens: Tensor):
      Look up and scale the embeddings for ``tokens``.
    """
    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        # Stored so forward() can derive the scale factor.
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """
        Embed ``tokens`` and multiply the result by sqrt(emb_size).

        Args:
        - tokens (Tensor): Integer token indices, e.g. of shape (seq_len, batch_size);
          cast to int64 before the lookup.

        Returns:
        - Tensor: Same shape as ``tokens`` with a trailing ``emb_size`` axis.
        """
        ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(ids)
        return vectors * scale

### Masking

In [None]:
def generate_square_subsequent_mask(sz, device=None):
    """
    Build a causal (look-ahead) attention mask for a sequence of length ``sz``.

    Entry ``[i, j]`` is ``0.0`` when position ``i`` may attend to position ``j``
    (``j <= i``) and ``-inf`` otherwise. Adding it to attention scores blocks
    attention to future positions.

    Args:
    - sz (int): Size of the square mask (sequence length).
    - device (torch.device | str, optional): Device on which to allocate the mask.
      Defaults to the module-level ``DEVICE`` for backward compatibility.

    Returns:
    - Tensor: Float32 mask of shape (sz, sz).
    """
    if device is None:
        # Looked up at call time, so DEVICE may be defined after this function.
        device = DEVICE
    # -inf strictly above the main diagonal, 0.0 on and below it.
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

def create_mask(src, tgt):
    """
    Build the attention and padding masks consumed by ``Seq2SeqTransformer.forward``.

    Args:
    - src (Tensor): Source token ids of shape (src_seq_len, batch_size).
    - tgt (Tensor): Target token ids of shape (tgt_seq_len, batch_size).

    Returns:
    - src_mask (Tensor): All-False bool mask of shape (src_seq_len, src_seq_len);
      nothing is masked in encoder self-attention.
    - tgt_mask (Tensor): Causal mask of shape (tgt_seq_len, tgt_seq_len).
    - src_padding_mask (Tensor): Bool mask of shape (batch_size, src_seq_len), True where src == PAD_IDX.
    - tgt_padding_mask (Tensor): Bool mask of shape (batch_size, tgt_seq_len), True where tgt == PAD_IDX.
    """
    n_src = src.size(0)
    n_tgt = tgt.size(0)

    # Encoder self-attention is unrestricted: all entries False means "not masked".
    src_mask = torch.zeros((n_src, n_src), dtype=torch.bool, device=DEVICE)
    # Decoder self-attention must not look at later positions.
    tgt_mask = generate_square_subsequent_mask(n_tgt)

    def padding_mask(seq):
        # (seq_len, batch) -> (batch, seq_len), True at padding positions.
        return (seq == PAD_IDX).transpose(0, 1)

    return src_mask, tgt_mask, padding_mask(src), padding_mask(tgt)

### Important Parameters

In [None]:
# Define vocabulary sizes and model parameters.
# English -> French direction: the source vocabulary is English, the target is French.
SRC_VOCAB_SIZE = len(en_vocab)
TGT_VOCAB_SIZE = len(fr_vocab)
EMB_SIZE = 512
# Read as a global inside Seq2SeqTransformer.__init__ (nhead=NHEAD), so it must
# be defined before the model is instantiated below.
NHEAD = 8
FFN_HID_DIM = 512
# NOTE(review): the DataLoaders were already built with BATCH_SIZE in an earlier
# cell; reassigning it here does not change them — confirm that is intended.
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 10

# Determine the device (CPU or GPU) for computation
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Instantiate the Seq2SeqTransformer model (constructed on the CPU)
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS,
                                 EMB_SIZE, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)

# Re-initialize every weight matrix (dim > 1) with Xavier-uniform; 1-D parameters
# such as biases keep their default initialization. This runs before the move
# to DEVICE, so the random draws happen on the CPU.
# NOTE(review): `nn` must be torch.nn here (the notebook header imports
# flax.linen as nn) — confirm a later cell rebinds it.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# Move the model to the selected device
transformer = transformer.to(DEVICE)

# Cross-entropy loss; target positions equal to PAD_IDX do not contribute.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam optimizer over all model parameters
optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)



### Train function

In [None]:
def train_epoch(model, train_iter, optimizer, epoch, save_path, device=None):
    """
    Train the model for one epoch, then save a checkpoint of its weights.

    Args:
    - model (nn.Module): The model to be trained.
    - train_iter (iterable): Iterable yielding (src, tgt) batches. Like
      create_mask, it assumes dimension 0 of each tensor is the sequence axis.
    - optimizer (torch.optim.Optimizer): The optimizer used for training.
    - epoch (int): The current epoch number; used in the checkpoint filename.
    - save_path (str): Directory in which to save the checkpoint.
    - device (torch.device, optional): Device each batch is moved to. Defaults
      to the device of the model's parameters (CPU if it has none).

    Returns:
    - train_loss (float): Mean per-batch loss over the epoch, or NaN if
      train_iter yielded no batches.
    """
    import os  # local import: keeps this cell self-contained

    if device is None:
        # Default to wherever the model's weights live, so batches and
        # parameters always end up on the same device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()  # Set the model to training mode (enables dropout)
    total_loss = 0.0
    num_batches = 0  # counted explicitly so plain iterables (no __len__) also work

    for src, tgt in train_iter:
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: the decoder is fed tgt[:-1] and trained to
        # predict the next token, i.e. tgt[1:].
        tgt_input = tgt[:-1, :]
        tgt_out = tgt[1:, :]

        # Generate attention and padding masks for source and target
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        optimizer.zero_grad()

        # The source padding mask is passed again as the last argument —
        # presumably the memory key-padding mask for cross-attention; confirm
        # against Seq2SeqTransformer.forward.
        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Flatten (seq, batch, vocab) logits and (seq, batch) targets for the loss
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # NaN (rather than a ZeroDivisionError) makes an empty loader visible in logs
    train_loss = total_loss / num_batches if num_batches else float('nan')

    # Save the trained weights after each epoch
    model_save_path = os.path.join(save_path, f'model_epoch_{epoch}.pth')
    torch.save(model.state_dict(), model_save_path)

    return train_loss

### Evaluate function

In [None]:
def evaluate(model, val_iter, device):
    """
    Evaluate the model's loss on the validation dataset.

    The model is put in evaluation mode (disabling dropout), and gradients
    are not tracked, since no backward pass happens here.

    Args:
    - model (torch.nn.Module): The model to be evaluated.
    - val_iter (iterable): An iterable over the validation dataset. Each iteration
                           produces a pair (src, tgt) representing source and target data.
    - device (torch.device): The device tensors should be sent to (e.g., 'cuda' or 'cpu').

    Returns:
    - val_loss (float): Mean per-batch loss over the validation set, or NaN if
      val_iter yielded no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    # No autograd graph is needed for evaluation; skipping it saves memory and time.
    with torch.no_grad():
        for src, tgt in val_iter:
            src = src.to(device)
            tgt = tgt.to(device)

            # Decoder input is tgt without its last token; targets are tgt shifted by one.
            tgt_input = tgt[:-1, :]
            tgt_out = tgt[1:, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)

            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches else float('nan')

### Train and Val loss

In [None]:
# Directory passed to train_epoch as its checkpoint save_path.
path = '/content/drive/MyDrive/trf/'

# Epochs are numbered 1..NUM_EPOCHS (the number also names each checkpoint).
for epoch in range(1, NUM_EPOCHS + 1):
    # Only the training pass is timed; validation below is not included
    # in the reported "Epoch time".
    start_time = time.time()
    train_loss = train_epoch(transformer, train_iter, optimizer, epoch, path)
    end_time = time.time()

    # Mean validation loss (evaluate returns loss only, no accuracy).
    val_loss = evaluate(transformer, valid_iter, device)

    print(
        f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "
        f"Epoch time = {(end_time - start_time):.3f}s"
    )



Epoch: 1, Train loss: 3.249, Val loss: 2.723, Epoch time = 20.590s
Epoch: 2, Train loss: 2.522, Val loss: 2.206, Epoch time = 20.501s
Epoch: 3, Train loss: 2.081, Val loss: 1.908, Epoch time = 20.369s
Epoch: 4, Train loss: 1.783, Val loss: 1.713, Epoch time = 20.305s
Epoch: 5, Train loss: 1.561, Val loss: 1.576, Epoch time = 20.377s
Epoch: 6, Train loss: 1.391, Val loss: 1.479, Epoch time = 20.530s
Epoch: 7, Train loss: 1.254, Val loss: 1.413, Epoch time = 20.515s
Epoch: 8, Train loss: 1.140, Val loss: 1.357, Epoch time = 20.435s
Epoch: 9, Train loss: 1.045, Val loss: 1.313, Epoch time = 20.418s
Epoch: 10, Train loss: 0.964, Val loss: 1.263, Epoch time = 20.337s


### Greedy Decoder

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    Decode a source sentence greedily, picking the most probable token at each step.

    Args:
    - model (torch.nn.Module): The transformer model to use for decoding; must
      provide encode(), decode() and generator.
    - src (torch.Tensor): The input tensor containing the source sentence indices.
    - src_mask (torch.Tensor): The mask tensor for the source input, passed to
      model.encode.
    - max_len (int): Maximum length of the generated sequence, including the
      start symbol.
    - start_symbol (int): The index of the start symbol used to initiate decoding.

    Returns:
    - torch.Tensor: Long tensor of shape (generated_len, 1) holding the generated
      token indices, starting with start_symbol and ending with EOS_IDX if it was
      produced before max_len was reached.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)

    # Inference only: no autograd graph needs to be built for any step.
    with torch.no_grad():
        # Encode once; the memory is reused (and moved once) for every decoding step.
        memory = model.encode(src, src_mask).to(device)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            # Causal mask so each target position attends only to earlier ones.
            # NOTE(review): the bool cast presumably turns a float (-inf/0) mask
            # into True = "blocked"; confirm against generate_square_subsequent_mask.
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool).to(device)

            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Score only the last position: that is the next-token prediction.
            prob = model.generator(out[:, -1])

            next_word = prob.argmax(dim=1).item()

            # Append along the sequence dimension (dim 0).
            ys = torch.cat(
                [ys, torch.full((1, 1), next_word, dtype=torch.long, device=device)],
                dim=0,
            )

            if next_word == EOS_IDX:
                break

    return ys

### Translation function

In [None]:
def translate(model, src, src_vocab, tgt_vocab, src_tokenizer):
    """
    Translate a source sentence into the target language using a given model.

    The sentence is tokenized, converted to indices and wrapped in BOS/EOS.
    It is then decoded greedily, and the resulting indices are mapped back to
    words with the special BOS/EOS tokens dropped.

    Args:
    - model (torch.nn.Module): The trained translation model.
    - src (str): The source sentence to translate.
    - src_vocab (Vocab): Source vocabulary mapping tokens to indices. Tokens it
      does not contain map to its default index, if one is set.
    - tgt_vocab (Vocab): The target vocabulary object that maps indices to tokens.
    - src_tokenizer (function): The tokenizer function for the source language.

    Returns:
    - str: The translated sentence, tokens joined by single spaces.
    """
    model.eval()  # Disable training-specific behavior such as dropout.

    # Numericalize all tokens in one vocab call instead of rebuilding the
    # whole string->index dict (get_stoi()) once per token.
    tokens = [BOS_IDX] + src_vocab.lookup_indices(list(src_tokenizer(src))) + [EOS_IDX]
    num_tokens = len(tokens)
    src = torch.LongTensor(tokens).reshape(num_tokens, 1)  # (seq_len, batch=1)

    # All-False mask: no source position is blocked from attending to another.
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)

    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX
    ).flatten()

    # Drop special tokens by index, not by string replacement, so no stray
    # leading/trailing spaces are left behind.
    itos = tgt_vocab.get_itos()
    special_ids = {BOS_IDX, EOS_IDX}
    return " ".join(itos[idx] for idx in tgt_tokens.tolist() if idx not in special_ids)

### Sample output

In [None]:
# English -> French: run the trained `transformer` on one sentence, using the
# English vocab/tokenizer for the input and the French vocab for the output.
output = translate(
    model=transformer,
    src="A group of people talking.",
    src_vocab=en_vocab,
    tgt_vocab=fr_vocab,
    src_tokenizer=en_tokenizer,
)

# Show the French translation.
print(output)

 Un groupe de personnes parlant . 
 


In [None]:
# A second English -> French example with the same model and vocabularies.
output = translate(model=transformer, src="He sings in the choir.",
                   src_vocab=en_vocab, tgt_vocab=fr_vocab, src_tokenizer=en_tokenizer)
print(output)


 Il chante dans la chorale . 
 
