# A Comprehensive Guide to PyTorch's nn.Transformer() Module: Morse Code Translation Example

## Overview and Learning Objectives

Welcome to an in-depth exploration of the Transformer architecture using PyTorch, with a fascinating practical application: Morse Code Translation! This tutorial will guide you through building a neural machine translation model using the revolutionary Transformer architecture introduced in the seminal paper ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) by Vaswani et al.
 
 ### Key Learning Objectives:
 
 - Understand the core components of the Transformer architecture.
 - Implement a Transformer model for sequence-to-sequence translation.
 - Learn how to preprocess and prepare data for sequence translation.
 - Gain practical experience with PyTorch's `nn.Transformer()` module.
 
 ### Background: Transformers and Their Revolution
 
 The Transformer architecture, introduced in 2017, fundamentally changed how we approach sequence-to-sequence tasks. Unlike recurrent neural networks (RNNs), which process a sequence one step at a time, Transformers rely entirely on attention mechanisms, enabling:
 - Parallel computation across all positions of a sequence during training.
 - Effective handling of long-range dependencies, since any two positions are connected by a single attention step.
 
 ### Why Morse Code Translation?
 Morse code provides an intuitive, constrained domain for demonstrating sequence translation:
 - A binary alphabet (dots and dashes) keeps the input vocabulary small.
 - A fixed, one-to-one mapping between letters and codes.
 - Historical significance as a communication method.
 - A practical example for showcasing sequence-to-sequence learning principles.
 ## Prerequisites and Dependencies
 
 Before diving into the code, ensure the following libraries are installed:
 - `PyTorch`: Core deep learning framework.
 - `NumPy`: For numerical computations.
 - `Matplotlib`: For plotting metrics.
 - `torchview` (optional): For model visualization.
 - `graphviz` (optional): Used by `torchview` to render graphs; the Graphviz system binaries must also be installed.
 - `torchinfo`: For the layer-by-layer model summary.
 
 **Note**: This implementation is designed for educational purposes and demonstrates core Transformer principles. Real-world applications require advanced techniques and larger datasets.

**Note**: You may freely use or reproduce our work, but please cite it.


## Imports and Environment Setup

In the first code cell, we'll import the necessary libraries and set up a consistent random seed for reproducibility. Seeding helps ensure that our random operations produce the same results across different runs.

### Key Considerations:
- `torch.manual_seed()` ensures reproducible random number generation
- `torch.backends.cudnn.deterministic = True` makes cuDNN select deterministic algorithms
- We use NumPy's random seed to maintain consistency across libraries

In [1]:
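import random
import subprocess
import sys

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Install the visualization/summary packages on first run.
# Calling pip through `python -m pip` avoids the unsupported `pip.main()` API.
try:
    from torchview import draw_graph
    import graphviz
    from torchinfo import summary
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torchview", "graphviz", "torchinfo"])
    from torchview import draw_graph
    import graphviz
    from torchinfo import summary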



In [None]:

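def seed_everything(seed=512):
    """
    Seed the Python, NumPy, and PyTorch random number generators.
    """
    random.seed(seed)  # generate_dataset() draws from Python's `random`
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # Seeds every GPU; silently ignored without CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # Disable autotuning so cuDNN's algorithm choice does not vary between runs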

seed_everything()
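A quick way to confirm that seeding works is to reseed and draw again: every generator should replay the same values. The sketch below uses a hypothetical `reseed()` helper (a trimmed-down stand-in for `seed_everything()`) and draws from Python's `random`, NumPy, and PyTorch.

```python
import random

import numpy as np
import torch


def reseed(seed=512):
    """Hypothetical helper: seed the three CPU generators used in this notebook."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def draw_samples():
    """Draw a few values from each generator."""
    return (
        [random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ") for _ in range(5)],
        np.random.randint(0, 100, size=5).tolist(),
        torch.randint(0, 100, (5,)).tolist(),
    )


reseed(512)
first = draw_samples()
reseed(512)
second = draw_samples()
print(first == second)  # True: the same seed replays the same streams
```

This covers the CPU generators only; some GPU kernels can remain non-deterministic even with the seeds fixed.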


## Morse Code Dictionary and Preprocessing

Next, we define our Morse code dictionary, which maps each letter to its Morse code written in binary. This mapping is the ground truth the model has to learn.

### Morse Code Encoding:
- Each letter is represented by a unique binary sequence of 1 to 4 symbols
- '0' represents a short signal (dot)
- '1' represents a long signal (dash)

In [None]:

# Define Morse Code Dictionary
MORSE_CODE_DICT = {
    'A': '01', 'B': '1000', 'C': '1010', 'D': '100', 'E': '0',
    'F': '0010', 'G': '110', 'H': '0000', 'I': '00', 'J': '0111',
    'K': '101', 'L': '0100', 'M': '11', 'N': '10', 'O': '111',
    'P': '0110', 'Q': '1101', 'R': '010', 'S': '000', 'T': '1',
    'U': '001', 'V': '0001', 'W': '011', 'X': '1001', 'Y': '1011',
    'Z': '1100',
}
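To see the encoding in action, here is a small sketch that turns a word into its binary Morse form, joining codes with a single space as the dataset generator does. It uses a hypothetical `encode_text()` helper and a subset of the dictionary so the block runs on its own.

```python
# Subset of MORSE_CODE_DICT ('0' = dot, '1' = dash)
MORSE_SUBSET = {'A': '01', 'E': '0', 'O': '111', 'S': '000', 'T': '1'}


def encode_text(text, separator=' '):
    """Hypothetical helper: map each letter to its code and join with a separator."""
    return separator.join(MORSE_SUBSET[ch] for ch in text.upper())


print(encode_text("SOS"))  # 000 111 000
print(encode_text("eat"))  # 0 01 1
```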

### Special Tokens:
We introduce special tokens to mark sequence boundaries and structure:
- `SOS_TOKEN`: Start of Sequence
- `EOS_TOKEN`: End of Sequence
- `PADDING_TOKEN`: Pads sequences to a uniform length
- `SPACE_CHAR`: Separates the codes of consecutive characters in a Morse sequence

In [None]:
# Special tokens
SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
PADDING_TOKEN = '<PAD>'
SPACE_CHAR = ' '



## Dataset Generation Function

The `generate_dataset()` function creates synthetic training data for our Morse code translation task. Let's break down its key components:

### Data Generation Strategy:
- Randomly select characters from our Morse code dictionary
- Convert characters to their Morse code representations
- Insert a space between consecutive Morse codes; without a separator the input is ambiguous (for example, `01` could be `A` or `ET`)
- Draw each sequence length uniformly from 1 to `max_len` characters

### Key Design Considerations:
- Generates an effectively unlimited supply of labeled samples
- Allows control over dataset size and complexity through `num_samples` and `max_len`
- Produces variable-length inputs, so the model must learn to handle padding

In [None]:

# Reverse dictionary for decoding
REVERSE_MORSE_CODE_DICT = {v: k for k, v in MORSE_CODE_DICT.items()}

# Prepare the dataset
def generate_dataset(num_samples=1000, max_len=10):
    """Generates random Morse code sequences with corresponding text translations."""
    dataset = []
    for _ in range(num_samples):
        seq_length = random.randint(1, max_len)
        text_seq = [random.choice(list(MORSE_CODE_DICT.keys())) for _ in range(seq_length)]
        morse_seq = [MORSE_CODE_DICT[char] for char in text_seq]
        morse_seq = f'{SPACE_CHAR}'.join(morse_seq)  # Add a space token between characters
        dataset.append((morse_seq, text_seq))
    return dataset
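The space separator is not cosmetic. Without it, a string of dots and dashes can often be split into letters in several ways. The sketch below counts those splits with a small dynamic program; `count_parses()` is a hypothetical helper, and `CODES` mirrors `MORSE_CODE_DICT`.

```python
from functools import lru_cache

# Mirrors MORSE_CODE_DICT ('0' = dot, '1' = dash)
CODES = {
    'A': '01', 'B': '1000', 'C': '1010', 'D': '100', 'E': '0',
    'F': '0010', 'G': '110', 'H': '0000', 'I': '00', 'J': '0111',
    'K': '101', 'L': '0100', 'M': '11', 'N': '10', 'O': '111',
    'P': '0110', 'Q': '1101', 'R': '010', 'S': '000', 'T': '1',
    'U': '001', 'V': '0001', 'W': '011', 'X': '1001', 'Y': '1011',
    'Z': '1100',
}
VALID_CODES = set(CODES.values())
LONGEST_CODE = max(len(code) for code in VALID_CODES)  # 4 symbols


def count_parses(bits):
    """Hypothetical helper: count the ways an unseparated string splits into letter codes."""

    @lru_cache(maxsize=None)
    def ways(start):
        if start == len(bits):
            return 1  # Consumed the whole string: one complete parse
        total = 0
        for end in range(start + 1, min(start + LONGEST_CODE, len(bits)) + 1):
            if bits[start:end] in VALID_CODES:
                total += ways(end)
        return total

    return ways(0)


print(count_parses("01"))   # 2: "A" or "ET"
print(count_parses("000"))  # 4: "S", "EI", "IE", "EEE"
```

With the separator, `'01 1'` has exactly one reading (`AT`), which is what gives each training pair a single correct answer.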

## Data Preparation and Vocabulary Creation

The `prepare_data()` function transforms our raw Morse code and text sequences into tensor representations suitable for neural network training.

### Preprocessing Steps:
1. Convert sequences to vocabulary indices
2. Add special tokens (SOS, EOS)
3. Pad sequences to uniform length
4. Convert to PyTorch tensors

In [None]:

# Convert Morse and text to indexed tensors
def prepare_data(dataset, morse_vocab, text_vocab, max_morse_len, max_text_len):
    """Converts sequences into padded tensors with indexes."""
    morse_sequences, text_sequences = [], []
    for morse_seq, text_seq in dataset:
        morse_tensor = [morse_vocab[SOS_TOKEN]] + [morse_vocab[ch] for ch in morse_seq] + [morse_vocab[EOS_TOKEN]]
        text_tensor = [text_vocab[SOS_TOKEN]] + [text_vocab[ch] for ch in text_seq] + [text_vocab[EOS_TOKEN]]
        
        morse_tensor += [morse_vocab[PADDING_TOKEN]] * (max_morse_len - len(morse_tensor))
        text_tensor += [text_vocab[PADDING_TOKEN]] * (max_text_len - len(text_tensor))
        
        morse_sequences.append(morse_tensor)
        text_sequences.append(text_tensor)
    return torch.tensor(morse_sequences, dtype=torch.long), torch.tensor(text_sequences, dtype=torch.long)
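Here is what `prepare_data()` does to a single sequence, worked through on the input `'01 1'` (the letters `A T`). The block rebuilds the Morse vocabulary in the same order the notebook uses below (`<PAD>`, `<SOS>`, `<EOS>`, space, `'0'`, `'1'`); `to_indices()` and `demo_vocab` are hypothetical names used only for this illustration.

```python
import torch

PADDING_TOKEN, SOS_TOKEN, EOS_TOKEN, SPACE_CHAR = '<PAD>', '<SOS>', '<EOS>', ' '

# Same ordering as morse_vocab: <PAD>=0, <SOS>=1, <EOS>=2, ' '=3, '0'=4, '1'=5
demo_vocab = {ch: idx for idx, ch in enumerate([PADDING_TOKEN, SOS_TOKEN, EOS_TOKEN, SPACE_CHAR, '0', '1'])}


def to_indices(morse_seq, max_len):
    """Hypothetical helper: wrap in <SOS>/<EOS>, map to indices, right-pad with <PAD>."""
    ids = [demo_vocab[SOS_TOKEN]] + [demo_vocab[ch] for ch in morse_seq] + [demo_vocab[EOS_TOKEN]]
    ids += [demo_vocab[PADDING_TOKEN]] * (max_len - len(ids))
    return ids


ids = to_indices('01 1', max_len=8)
print(ids)  # [1, 4, 5, 3, 5, 2, 0, 0]

batch = torch.tensor([ids, to_indices('0', max_len=8)], dtype=torch.long)
print(batch.shape)  # torch.Size([2, 8])
```

If a sequence is longer than `max_len`, the padding multiplier is negative, Python produces an empty list, and the row is not truncated; `torch.tensor()` then raises an error on the ragged batch.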

### Vocabulary Mapping:
We create two key vocabularies:
- `morse_vocab`: Maps Morse code tokens to unique indices
- `text_vocab`: Maps text characters to unique indices

In [None]:
# Define vocabulary
morse_vocab = {ch: idx for idx, ch in enumerate([PADDING_TOKEN, SOS_TOKEN, EOS_TOKEN, SPACE_CHAR, '0', '1'])}
text_vocab = {ch: idx for idx, ch in enumerate([PADDING_TOKEN, SOS_TOKEN, EOS_TOKEN] + list(MORSE_CODE_DICT.keys()))}
reverse_text_vocab = {idx: ch for ch, idx in text_vocab.items()}

### Splitting Data:
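- Divide the dataset into training and validation sets
- Use 80% of the data for training and 20% for validation
- The validation set lets us check whether the model generalizes beyond the training data
- Because samples are generated independently at random, slicing off the last 20% gives an unbiased split, although very short sequences (such as single letters) inevitably appear in both sets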

In [None]:

# Dataset parameters
NUM_SAMPLES = 80000
TRAIN_RATIO = 0.8
NUM_TRAIN = int(NUM_SAMPLES * TRAIN_RATIO)
MAX_MORSE_LEN = 100
MAX_TEXT_LEN = 12

data = generate_dataset(NUM_SAMPLES)
morse_tensor, text_tensor = prepare_data(data, morse_vocab, text_vocab, MAX_MORSE_LEN, MAX_TEXT_LEN)

train_morse, val_morse = morse_tensor[:NUM_TRAIN], morse_tensor[NUM_TRAIN:]
train_text, val_text = text_tensor[:NUM_TRAIN], text_tensor[NUM_TRAIN:]
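The two length limits can be checked with a little arithmetic. With `max_len=10` (the default used by `generate_dataset()`), the longest possible Morse input is ten 4-symbol codes, nine separators, and the `<SOS>`/`<EOS>` pair. The helper names below are hypothetical.

```python
def morse_budget(max_chars, longest_code=4):
    """Worst-case Morse token count: codes + separators + <SOS> + <EOS>."""
    return max_chars * longest_code + (max_chars - 1) + 2


def text_budget(max_chars):
    """Worst-case text token count: characters + <SOS> + <EOS>."""
    return max_chars + 2


print(morse_budget(10))  # 51, comfortably below MAX_MORSE_LEN = 100
print(text_budget(10))   # 12, exactly MAX_TEXT_LEN
```

So `MAX_TEXT_LEN` leaves no slack: raising `max_len` in `generate_dataset()` also requires raising `MAX_TEXT_LEN`.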


## Positional Encoding: Adding Sequence Context

The `PositionalEncoding` class is critical in helping Transformers understand the order of tokens. Unlike RNNs, Transformers process entire sequences simultaneously, so we need to inject position information.

### Positional Encoding Techniques:
- Use sine and cosine functions of the token position to build position embeddings
- Each pair of embedding dimensions uses a different frequency, so every position receives a distinct pattern
- Add these encodings to the input embeddings

### Why Positional Encoding?
- Provides sequence order information
- Allows model to distinguish between tokens based on their position
- Enables parallel processing while maintaining sequence context

In [None]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1), :].to(x.device)
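The class above is a vectorized form of the sinusoidal formula from the original paper, `PE(pos, 2i) = sin(pos / 10000^(2i/d_model))` and `PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))`, where `div_term` equals `10000^(-2i/d_model)` computed as `exp(-2i * ln(10000) / d_model)`. The sketch below checks that the vectorized computation matches a direct loop over the formula; both function names are hypothetical, and `d_model` is assumed to be even.

```python
import math

import torch


def sinusoidal_table(max_len, d_model):
    """Vectorized form, as in PositionalEncoding.__init__."""
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


def sinusoidal_reference(max_len, d_model):
    """Direct transcription of the formula, one entry at a time (i steps over even dims)."""
    pe = torch.zeros(max_len, d_model)
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            pe[pos, i] = math.sin(angle)
            pe[pos, i + 1] = math.cos(angle)
    return pe


fast = sinusoidal_table(50, 16)
slow = sinusoidal_reference(50, 16)
print(torch.allclose(fast, slow, atol=1e-4))  # True
```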


## Transformer Model Architecture

The `TransformerModel` class represents our complete neural translation model. It encapsulates the entire Transformer architecture, including embedding layers, positional encoding, and the core Transformer module.

### Model Components:
1. **Input Embedding**: Convert Morse code tokens to dense vector representations
2. **Target Embedding**: Convert text tokens to dense vector representations
3. **Positional Encoding**: Add sequence position information
4. **Transformer Module**: Core sequence-to-sequence translation mechanism
5. **Output Layer**: Project transformer outputs to vocabulary space

### Key Design Principles:
- Separate embeddings for input (Morse) and target (text) sequences
- Scale embeddings by `sqrt(d_model)` before adding positional encodings, as in the original paper
- Use multiple attention heads for capturing different representation aspects

### Attention Mechanism
- The `generate_square_subsequent_mask()` method creates a causal mask
- Prevents decoder from attending to future tokens during training
- Crucial for autoregressive generation

In [None]:
# Transformer Model
class TransformerModel(nn.Module):
    def __init__(self, input_vocab_size, target_vocab_size, d_model=512, nhead=8, num_layers=3, dim_feedforward=2048):
        super(TransformerModel, self).__init__()
        self.d_model = d_model

        # Embedding layers
        self.input_embedding = nn.Embedding(input_vocab_size, d_model)
        self.target_embedding = nn.Embedding(target_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead, num_encoder_layers=num_layers, 
            num_decoder_layers=num_layers, dim_feedforward=dim_feedforward
        )

        # Output layer
        self.fc_out = nn.Linear(d_model, target_vocab_size)

    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
        return mask

    def forward(self, src, tgt):
        src_mask = None
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(src.device)
        
        src_emb = self.positional_encoding(self.input_embedding(src) * np.sqrt(self.d_model))
        tgt_emb = self.positional_encoding(self.target_embedding(tgt) * np.sqrt(self.d_model))

        output = self.transformer(src_emb.transpose(0, 1), tgt_emb.transpose(0, 1),
                                  src_mask=src_mask, tgt_mask=tgt_mask)
        output = self.fc_out(output.transpose(0, 1))
        return output
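`nn.Transformer` defaults to `batch_first=False`, so it expects tensors shaped `(sequence, batch, d_model)`. That is why `forward()` transposes the batch-first embeddings going in and transposes the output back before `fc_out`. The toy configuration below (small `d_model`, one layer each) is an assumption made only to keep the shape check fast.

```python
import torch
import torch.nn as nn

d_model, batch, src_len, tgt_len = 32, 2, 10, 7
transformer = nn.Transformer(
    d_model=d_model, nhead=4, num_encoder_layers=1,
    num_decoder_layers=1, dim_feedforward=64,
)

src = torch.randn(batch, src_len, d_model)  # Batch-first, like the notebook's embeddings
tgt = torch.randn(batch, tgt_len, d_model)
tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)

out = transformer(src.transpose(0, 1), tgt.transpose(0, 1), tgt_mask=tgt_mask)
print(out.shape)                  # torch.Size([7, 2, 32]): (tgt_len, batch, d_model)
print(out.transpose(0, 1).shape)  # torch.Size([2, 7, 32]): back to batch-first
```

Passing `batch_first=True` to `nn.Transformer` is an alternative that removes the transposes; the notebook keeps the default.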

In [None]:
class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's
    validation loss is less than the previous least less, then save the
    model state.
    """
    def __init__(
        self, best_valid_loss=float('inf')
    ):
        self.best_valid_loss = best_valid_loss

    def __call__(
        self, current_valid_loss,
        epoch, model, optimizer
    ):
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch}\n")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),

                }, 'bstmdl.pth')



## Training and Evaluation Infrastructure

We've designed comprehensive training and evaluation functions:

### `train()` Function:
- Manages model training loop
- Tracks training and validation losses
- Implements best model checkpoint saving
- Provides real-time training progress monitoring

In [None]:
def train(model, train_data, val_data, optimizer, criterion, num_epochs=10, batch_size=32):

    savebest = SaveBestModel()
    
    train_morse, train_text = train_data
    val_morse, val_text = val_data
    
    # Create DataLoader for training data
    train_dataset = torch.utils.data.TensorDataset(train_morse, train_text)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    Losses_Train = torch.empty(num_epochs)
    Losses_Valid = torch.empty(num_epochs)
    Accs_Train = torch.empty(num_epochs)
    Accs_Valid = torch.empty(num_epochs)
    
    for epoch in range(num_epochs):
        model.train()  # evaluate() switches to eval mode, so re-enable training mode every epoch
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (src, tgt) in enumerate(train_loader):
            src, tgt = src.to(device), tgt.to(device)
            tgt_input = tgt[:, :-1]  # Input to the decoder (everything except the last token)
            tgt_output = tgt[:, 1:]  # Ground truth (everything except the first token)

            optimizer.zero_grad()
            output = model(src, tgt_input)

            # Calculate loss
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

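            # Token accuracy over non-padding positions (the loss ignores padding too)
            pred = output.argmax(2)  # Predicted tokens
            non_pad = tgt_output != criterion.ignore_index
            correct += ((pred == tgt_output) & non_pad).sum().item()
            total += non_pad.sum().item()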

            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], "
                      f"Loss: {loss.item():.4f}")

        val_loss, val_acc = evaluate(model, val_data, criterion, batch_size)

        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}, "
              f"Accuracy: {correct / total:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")
        savebest(val_loss, epoch + 1, model, optimizer)  # Checkpoint on validation loss
        
        Losses_Train[epoch] = total_loss / len(train_loader)
        Losses_Valid[epoch] = val_loss
        Accs_Train[epoch] = correct / total  
        Accs_Valid[epoch] = val_acc
    
    history = {"Loss_Train"     : Losses_Train, 
               "Loss_Valid"     : Losses_Valid, 
               "Accuracy_Train" : Accs_Train,  
               "Accuracy_Valid" : Accs_Valid,}
    return history
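The one-token shift at the top of the batch loop is teacher forcing: the decoder receives the ground-truth prefix and is trained to predict the next token at every position in parallel. Here it is worked on the target `<SOS> A N <EOS> <PAD>`, using the notebook's `text_vocab` indices (`<PAD>`=0, `<SOS>`=1, `<EOS>`=2, and the letters `A` through `Z` at 3 through 28, so `N`=16):

```python
import torch

# <SOS> A N <EOS> <PAD> encoded with text_vocab
tgt = torch.tensor([[1, 3, 16, 2, 0]])

tgt_input = tgt[:, :-1]   # What the decoder sees:  <SOS> A N <EOS>
tgt_output = tgt[:, 1:]   # What it must predict:   A N <EOS> <PAD>

for seen, target in zip(tgt_input[0].tolist(), tgt_output[0].tolist()):
    print(f"after token {seen:2d} -> predict {target:2d}")
```

The last pair (predicting `<PAD>` after `<EOS>`) contributes nothing to the loss because `<PAD>` is the criterion's `ignore_index`; together with the causal mask, position *i* can only use tokens 0 through *i*.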

### `evaluate()` Function:
- Performs model evaluation on validation dataset
- Computes loss and accuracy metrics
- Uses no-gradient context for efficiency

In [None]:

def evaluate(model, val_data, criterion, batch_size=32):
    model.eval()
    val_morse, val_text = val_data
    
    # Create DataLoader for validation data
    val_dataset = torch.utils.data.TensorDataset(val_morse, val_text)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)
    
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for src, tgt in val_loader:
            src, tgt = src.to(device), tgt.to(device)
            tgt_input = tgt[:, :-1]  # Input to the decoder
            tgt_output = tgt[:, 1:]  # Ground truth

            output = model(src, tgt_input)

            # Calculate loss
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
            total_loss += loss.item()

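            # Token accuracy over non-padding positions (the loss ignores padding too)
            pred = output.argmax(2)
            non_pad = tgt_output != criterion.ignore_index
            correct += ((pred == tgt_output) & non_pad).sum().item()
            total += non_pad.sum().item()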

    return total_loss / len(val_loader), correct / total
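Whether padded positions count toward token accuracy makes a large difference on short targets, because predicting `<PAD>` after the sequence has ended is easy. The toy example below (made-up predictions, `<PAD>` = 0 as in `text_vocab`) compares the two ways of counting.

```python
import torch

PAD = 0
tgt_output = torch.tensor([[5, 7, 2, 0, 0, 0]])  # 3 real tokens, then padding
pred = torch.tensor([[5, 9, 9, 0, 0, 0]])        # Made-up predictions: 1 real token correct

matches = pred == tgt_output
acc_all = matches.sum().item() / matches.numel()  # Counts padding slots too

non_pad = tgt_output != PAD
acc_real = (matches & non_pad).sum().item() / non_pad.sum().item()

print(f"{acc_all:.3f}")   # 0.667
print(f"{acc_real:.3f}")  # 0.333
```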

### Optimization Strategies:
- Adam optimizer for adaptive learning rates
- Cross-entropy loss with padding token ignore
- Batch processing for computational efficiency
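With `ignore_index`, `nn.CrossEntropyLoss` drops padded targets entirely, and the default `reduction='mean'` averages only over the remaining positions. The check below compares the masked loss with a loss computed on the non-padded positions alone; the random logits are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(5, 4)               # 5 target positions, vocabulary of 4
targets = torch.tensor([2, 3, 1, 0, 0])  # Last two positions are padding (index 0)

masked = nn.CrossEntropyLoss(ignore_index=0)(logits, targets)
manual = nn.CrossEntropyLoss()(logits[:3], targets[:3])  # Real positions only
print(torch.allclose(masked, manual))  # True
```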

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Instantiate the model
model = TransformerModel(
    input_vocab_size=len(morse_vocab),
    target_vocab_size=len(text_vocab),
    d_model=256,  # Reduced for faster training
    nhead=16,
    num_layers=3,
    dim_feedforward=256  # Reduced for faster training
).to(device)

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=text_vocab[PADDING_TOKEN])


In [None]:
batch_size = 256  # Adjust based on your GPU memory
train_data = (train_morse, train_text)
val_data = (val_morse, val_text)
num_epochs = 500

print(train_morse[1], train_text[1])
print(data[1])



## Visualization and Model Inspection

We've included additional tools for model understanding:
- `torchview` for computational graph visualization
- `torchinfo` for detailed model architecture summary
- Matplotlib plots for training and validation metrics

In [None]:

graphviz.set_jupyter_format('png')

model_graph = draw_graph(model, input_size=[(2, MAX_MORSE_LEN), (2, MAX_TEXT_LEN - 1)], expand_nested=True, device=device, dtypes=[torch.long, torch.long])
model_graph.visual_graph

In [None]:
summary(model, input_size=[(2, MAX_MORSE_LEN), (2, MAX_TEXT_LEN - 1)], device=device, dtypes=[torch.long, torch.long])

In [None]:
print(sum(p.numel() for p in model.parameters()))

In [None]:
history = train(model, train_data, val_data, optimizer, criterion, num_epochs, batch_size)

In [None]:

plt.plot(history["Loss_Train"])
plt.plot(history["Loss_Valid"])
plt.show()

In [None]:
plt.plot(history["Accuracy_Train"])
plt.plot(history["Accuracy_Valid"])
plt.show()

In [None]:
# Evaluation
val_loss, val_acc = evaluate(model, val_data, criterion)
print(f"Final Validation Loss: {val_loss:.4f}, Final Validation Accuracy: {val_acc:.4f}")

## Inference and Translation

The `translate_sequence()` function demonstrates **how to use our trained model** for actual Morse code to text translation:

### Translation Process:
1. Preprocess input Morse sequence
2. Generate text tokens one at a time with greedy (argmax) decoding, appending each prediction to the decoder input
3. Stop at End-of-Sequence (EOS) token
4. Convert token indices back to characters

In [None]:
def translate_sequence(binary_sequence, model, morse_vocab, text_vocab, reverse_text_vocab, max_morse_len):
    """
    Translates a binary Morse sequence (with spaces) into text using the trained model.
    
    Args:
        binary_sequence (str): Input Morse sequence in binary (e.g., "01 10 111").
        model (nn.Module): Trained Transformer model.
        morse_vocab (dict): Vocabulary mapping for Morse code.
        text_vocab (dict): Vocabulary mapping for text.
        reverse_text_vocab (dict): Reverse mapping from text indices to characters.
        max_morse_len (int): Maximum length of Morse sequences (for padding).
    
    Returns:
        str: Translated text sequence.
    """
    # Preprocess the input binary sequence
    
    morse_tensor = [morse_vocab[SOS_TOKEN]] + [morse_vocab[ch] for ch in binary_sequence if ch in morse_vocab] + [morse_vocab[EOS_TOKEN]]
    morse_tensor += [morse_vocab[PADDING_TOKEN]] * (max_morse_len - len(morse_tensor))
    morse_tensor = torch.tensor(morse_tensor, dtype=torch.long).unsqueeze(0).to(device)
    
    # Prepare the target input (start with <SOS>)
    tgt_input = torch.tensor([text_vocab[SOS_TOKEN]], dtype=torch.long).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        for _ in range(MAX_TEXT_LEN):  # Generate until max text length
            output = model(morse_tensor, tgt_input)
            next_token = output.argmax(2)[:, -1:]  # Get the next token
            tgt_input = torch.cat((tgt_input, next_token), dim=1)  # Append the next token

            if next_token.item() == text_vocab[EOS_TOKEN]:  # Stop if <EOS> is generated
                break

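    # Convert output indices to text: drop the leading <SOS>, and drop the last
    # token only if it is <EOS> (generation can also stop at the length limit)
    generated = tgt_input[0, 1:].tolist()
    if generated and generated[-1] == text_vocab[EOS_TOKEN]:
        generated = generated[:-1]
    translated_text = ''.join(reverse_text_vocab[idx] for idx in generated)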
    return translated_text

# Example usage
binary_sequence = "00 10 1000 000 11 01 101"  # Example Morse input (binary with spaces)
translated_text = translate_sequence(binary_sequence, model, morse_vocab, text_vocab, reverse_text_vocab, MAX_MORSE_LEN)
print(f"Input Morse Sequence: {binary_sequence}")
print(f"Translated Text: {translated_text}")

### Learning Outcome
By the end of this tutorial, you'll have:
- Understood Transformer architecture fundamentals
- Implemented a sequence-to-sequence translation model
- Learned practical PyTorch implementation techniques
- Explored an interesting domain-specific translation task

## References
- ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) - Original Transformer paper
- [Detailed PyTorch Transformer Guide](https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1) - The current implementation is inspired by this work. 

