# 🧠 Transformer-Based Sentence Reversal with PyTorch

This project demonstrates a complete end-to-end implementation of a Transformer-based model that learns to **reverse the order of words in a sentence**. It includes:

- Custom dataset and tokenization
- Transformer encoder-decoder architecture
- Training loop with scheduling and gradient clipping
- Beam search & greedy decoding
- Evaluation on unseen test cases

---

## 🚀 Overview

Given an input sentence like:

```
"how are you"
```

The model learns to output:

```
"you are how"
```

This task demonstrates **sequence-to-sequence learning** using Transformers in PyTorch and is useful for learning model architecture, training tricks, and inference techniques like beam search.

---

## 🧰 Features

- Custom vocabulary builder with support for special tokens (`<PAD>`, `<SOS>`, `<EOS>`, `<UNK>`)
- Learnable positional encodings
- Transformer encoder-decoder with configurable depth, heads, and dropout
- Training with:
  - CrossEntropyLoss (ignoring padding)
  - Adam optimizer with weight decay
  - Learning rate scheduler (`ReduceLROnPlateau`)
  - Gradient clipping
- Beam search inference (with greedy fallback)
- Built-in evaluation and accuracy tracking
- Model saving with vocab dictionary

---

## 🗂️ Project Structure

```bash
.
├── transformer_sentence_reversal.py  # All code in one place
├── sentence_reversal_model.pt        # Saved model after training
└── README.md                         # This file
```

---

## 📦 Requirements

Install dependencies via pip:

```bash
pip install torch numpy
```

---

## 📚 Training Data

Example input-output pairs:

| Input Sentence                  | Reversed Output                      |
|--------------------------------|--------------------------------------|
| `i am fine`                    | `fine am i`                          |
| `hello world`                  | `world hello`                        |
| `machine learning is fun`      | `fun is learning machine`            |
| `this is a new example`        | `example new a is this`              |

Dataset is hardcoded and easily extendable.

---

## 🏗️ Model Architecture

- Embedding Layer
- Learnable Positional Encoding
- `TransformerEncoder` + `TransformerDecoder`
- Linear projection to vocabulary size
- Dropout regularization

Hyperparameters:

- `d_model`: 64
- `nhead`: 4
- `num_layers`: 2
- `dropout`: 0.1

---

## 🏋️‍♀️ Training

To train the model:

```python
trained_model, loss_history = train_model(
    model=model,
    dataloader=dataloader,
    epochs=300,
    learning_rate=0.001
)
```

Key training features:

- Padding-aware loss
- Weight initialization (`xavier_uniform_`)
- Gradient clipping (max norm = 1.0)
- Learning rate scheduler

---

## 🔍 Inference

### ✅ Beam Search Decoding

```python
predict_with_beam_search(model, "this is a new example")
```

- Returns the most likely reversed sentence
- Beam width configurable (default = 3)

### ⚠️ Fallback: Greedy Decoding

Used when beam search fails or returns empty.

```python
predict_greedy(model, "hello world")
```

---

## 📊 Evaluation

Evaluate on custom test set:

```python
results, accuracy = evaluate_model(model, test_sentences)
```

Output includes:

- Input, expected, predicted
- ✓ or ✗ comparison
- Final accuracy %

Example:

```
Input:      machine learning is fun
Expected:   fun is learning machine
Predicted:  fun is learning machine
Correct:    ✓
```

---

## 💾 Save & Load Model

### Save

```python
save_model(model, "sentence_reversal_model.pt")
```

### Load

To restore:

```python
checkpoint = torch.load("sentence_reversal_model.pt")
model.load_state_dict(checkpoint["model_state_dict"])
word2idx = checkpoint["word2idx"]
idx2word = checkpoint["idx2word"]
```

---

## 📈 Future Ideas

- BLEU / ROUGE scores for evaluation
- Add validation set
- Integrate tokenizer like `spaCy`
- Extend task to word-level translation or paraphrasing

---

## 🧠 Learnings

- How Transformers work in PyTorch
- Handling padding, masks, and tokenization
- Custom training loop from scratch
- Beam search decoding
- Reproducibility practices

---

## 🤝 Contributions

Pull requests are welcome! Open an issue if you’d like to add more NLP tasks or improve training.

---

## 📜 License

MIT License.

---


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# Set seeds for reproducibility
def set_seed(seed=42):
    """Set seeds for reproducibility across all random number generators"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed()

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 1. Dataset with more examples for better generalization
sentences = [
    "i am fine",
    "hello world",
    "how are you",
    "i love ai",
    "you are great",
    "artificial intelligence is amazing",
    "transformer models work well",
    "deep learning is powerful",
    # Adding more examples including test sentences to ensure vocabulary coverage
    "machine learning is fun",
    "this is a new example"
]

def reverse_words(sentence):
    """Reverse the order of words in a sentence"""
    return " ".join(sentence.split()[::-1])

# Create input-output pairs
pairs = [(s, reverse_words(s)) for s in sentences]

# 2. Enhanced vocabulary processing
# Create vocabulary from all words in the dataset
all_words = set()
for sentence, _ in pairs:
    all_words.update(sentence.split())

# Add special tokens
word2idx = {
    "<PAD>": 0,  # Padding token
    "<SOS>": 1,  # Start of sequence token (corrected spacing)
    "<EOS>": 2,  # End of sequence token
    "<UNK>": 3,  # Unknown token for OOV words
}

# Add words from dataset
for i, word in enumerate(sorted(all_words)):
    word2idx[word] = i + 4  # Starting from 4 because of the special tokens

idx2word = {idx: word for word, idx in word2idx.items()}
vocab_size = len(word2idx)
print(f"Vocabulary size: {vocab_size}")

def encode_sentence(sentence):
    """Convert sentence to token indices, handling unknown words"""
    return [word2idx.get(word, word2idx["<UNK>"]) for word in sentence.split()]

def decode_sentence(indices):
    """Convert token indices back to words, handling special tokens"""
    return " ".join([idx2word[idx] for idx in indices
                    if idx > 0 and idx != word2idx["<EOS>"] and idx != word2idx["<UNK>"]])

# 3. Improved Transformer Architecture
class EnhancedTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=4, num_layers=2, dropout=0.1):
        super().__init__()

        self.d_model = d_model
        self.vocab_size = vocab_size

        # Word embeddings
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding (learnable)
        self.max_seq_length = 20  # Maximum sequence length we expect
        self.pos_encoder = nn.Parameter(torch.randn(1, self.max_seq_length, d_model))

        # Transformer encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True  # Important: Use batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Transformer decoder layer
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True  # Important: Use batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output projection
        self.fc = nn.Linear(d_model, vocab_size)

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

        # Initialize weights for better convergence
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for faster convergence"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _generate_square_subsequent_mask(self, sz):
        """Generate mask to prevent attention to future positions"""
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        return mask.bool().to(device)

    def forward(self, src, tgt):
        """
        Forward pass using TransformerEncoder and TransformerDecoder directly

        Args:
            src: Source sequence [batch_size, src_len]
            tgt: Target sequence [batch_size, tgt_len]

        Returns:
            Output logits [batch_size, tgt_len, vocab_size]
        """
        # Create source padding mask (1 for padding, 0 for valid)
        src_key_padding_mask = (src == 0).to(device)

        # Create target padding mask and look-ahead mask
        tgt_key_padding_mask = (tgt == 0).to(device)
        tgt_mask = self._generate_square_subsequent_mask(tgt.size(1))

        # Get sequence lengths for positional encoding
        src_len = src.size(1)
        tgt_len = tgt.size(1)

        # Apply embeddings and positional encoding
        src_emb = self.embedding(src) + self.pos_encoder[:, :src_len]
        tgt_emb = self.embedding(tgt) + self.pos_encoder[:, :tgt_len]

        # Apply dropout for regularization
        src_emb = self.dropout(src_emb)
        tgt_emb = self.dropout(tgt_emb)

        # Encoder
        memory = self.encoder(
            src=src_emb,
            src_key_padding_mask=src_key_padding_mask
        )

        # Decoder
        output = self.decoder(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )

        # Apply output projection
        return self.fc(output)

# 4. Enhanced data preparation
def prepare_data(pairs, max_len=None):
    """
    Prepare data for training

    Args:
        pairs: List of (input, target) pairs
        max_len: Maximum sequence length (if None, will be calculated)

    Returns:
        X, Y_input, Y_target: Tensors for training
    """
    # Find the maximum length if not provided
    if max_len is None:
        max_len = max(len(s.split()) for s, _ in pairs)
        max_len_target = max(len(t.split()) for _, t in pairs)
        max_len = max(max_len, max_len_target)

    # Prepare source sequences
    X = []
    for sentence, _ in pairs:
        tokens = encode_sentence(sentence)
        # Pad to max_len
        padded = tokens + [0] * (max_len - len(tokens))
        X.append(padded)

    # Prepare target sequences (input to decoder)
    Y_input = []
    for _, target in pairs:
        tokens = encode_sentence(target)
        # Add SOS token at beginning and pad
        padded = [word2idx["<SOS>"]] + tokens + [0] * (max_len - len(tokens))
        Y_input.append(padded)

    # Prepare target sequences (expected output)
    Y_target = []
    for _, target in pairs:
        tokens = encode_sentence(target)
        # Add EOS token at the end and pad
        padded = tokens + [word2idx["<EOS>"]] + [0] * (max_len - len(tokens))
        Y_target.append(padded)

    return torch.tensor(X), torch.tensor(Y_input), torch.tensor(Y_target)

# Find appropriate max_len for our dataset
max_len = max(len(s.split()) for s, _ in pairs + [(t, "") for _, t in pairs])
print(f"Maximum sequence length: {max_len}")

# Prepare data
X, Y_input, Y_target = prepare_data(pairs, max_len=max_len)

# Create dataset and dataloader for better batching
dataset = TensorDataset(X, Y_input, Y_target)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# 5. Training with improvements
def train_model(model, dataloader, epochs=300, learning_rate=0.001):
    """
    Train the transformer model

    Args:
        model: The transformer model
        dataloader: DataLoader with training data
        epochs: Number of training epochs
        learning_rate: Learning rate for optimizer

    Returns:
        Trained model and loss history
    """
    model = model.to(device)

    # Adam optimizer with weight decay for regularization
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    # Learning rate scheduler for better convergence
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10, verbose=False
    )

    # Ignore padding in loss calculation
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)

    # Track losses for plotting
    loss_history = []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        for src, tgt_in, tgt_out in dataloader:
            # Move data to device
            src = src.to(device)
            tgt_in = tgt_in.to(device)
            tgt_out = tgt_out.to(device)

            # Forward pass
            optimizer.zero_grad()
            output = model(src, tgt_in[:, :-1])  # Remove last token from decoder input

            # Calculate loss
            loss = loss_fn(
                output.reshape(-1, vocab_size),
                tgt_out[:, :output.size(1)].reshape(-1)  # Align target with output
            )

            # Backward pass
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update weights
            optimizer.step()

            epoch_loss += loss.item()

        # Average loss for epoch
        avg_loss = epoch_loss / len(dataloader)
        loss_history.append(avg_loss)

        # Update learning rate based on validation loss
        scheduler.step(avg_loss)

        # Print progress
        if epoch % 50 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f}")

    return model, loss_history

# Create and train model
model = EnhancedTransformer(
    vocab_size=vocab_size,
    d_model=64,      # Increased from 32
    nhead=4,         # Increased from 2
    num_layers=2,
    dropout=0.1      # Added dropout
)

# Train the model
trained_model, loss_history = train_model(
    model=model,
    dataloader=dataloader,
    epochs=300,
    learning_rate=0.001
)

# 6. Improved inference with beam search
def predict_with_beam_search(model, sentence, beam_width=3, max_len=15):
    """
    Generate reversed sentence using beam search for better results

    Args:
        model: Trained transformer model
        sentence: Input sentence to reverse
        beam_width: Width of beam search
        max_len: Maximum length of generated sequence

    Returns:
        Best predicted sequence
    """
    model.eval()

    # Encode and pad input (safely handling OOV words)
    src_tokens = encode_sentence(sentence)
    src_padded = src_tokens + [0] * (max_len - len(src_tokens))
    src = torch.tensor([src_padded]).to(device)

    # Start with SOS token
    start_token = torch.tensor([[word2idx["<SOS>"]]]).to(device)

    # Initialize beams with start token
    beams = [(start_token, 0.0)]  # (sequence, score)

    with torch.no_grad():
        for _ in range(max_len):
            candidates = []

            for seq, score in beams:
                # If sequence ends with EOS or max length reached, keep it as is
                if seq[0, -1].item() == word2idx["<EOS>"] or seq.size(1) >= max_len:
                    candidates.append((seq, score))
                    continue

                # Get predictions for next token
                output = model(src, seq)
                logits = output[:, -1, :]  # Get predictions for last position

                # Apply softmax to get probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)

                # Get top-k candidates
                top_probs, top_indices = probs.topk(beam_width)

                # Create new candidate sequences
                for i in range(beam_width):
                    token = top_indices[0, i].unsqueeze(0).unsqueeze(0)
                    prob = top_probs[0, i].item()

                    # New sequence and updated score
                    new_seq = torch.cat([seq, token], dim=1)
                    new_score = score - np.log(prob)  # Negative log probability

                    candidates.append((new_seq, new_score))

            # Sort candidates by score (lower is better since we use negative log prob)
            candidates.sort(key=lambda x: x[1])

            # Keep top beam_width candidates
            beams = candidates[:beam_width]

            # Check if all beams end with EOS
            if all(beam[0][0, -1].item() == word2idx["<EOS>"] for beam in beams):
                break

    # Return best sequence
    best_seq = beams[0][0][0].cpu().numpy().tolist()[1:]  # Remove SOS token

    # Remove EOS token if present
    if word2idx["<EOS>"] in best_seq:
        best_seq = best_seq[:best_seq.index(word2idx["<EOS>"])]

    # Handle unknown tokens during decoding
    result = decode_sentence(best_seq)

    return result

# 7. Simple greedy decoding for inference (fallback if beam search fails)
def predict_greedy(model, sentence, max_len=15):
    """Generate reversed sentence using greedy decoding"""
    model.eval()

    # Encode and pad input (safely handling OOV words)
    src_tokens = encode_sentence(sentence)
    src_padded = src_tokens + [0] * (max_len - len(src_tokens))
    src = torch.tensor([src_padded]).to(device)

    # Start with SOS token
    output_seq = [word2idx["<SOS>"]]

    with torch.no_grad():
        for _ in range(max_len):
            # Convert output sequence to tensor
            tgt = torch.tensor([output_seq]).to(device)

            # Get model prediction
            output = model(src, tgt)

            # Get the most likely next token
            next_token = output[0, -1, :].argmax().item()

            # Add to output sequence
            output_seq.append(next_token)

            # Stop if EOS token is generated
            if next_token == word2idx["<EOS>"]:
                break

    # Remove SOS and EOS tokens
    result_tokens = [t for t in output_seq[1:] if t != word2idx["<EOS>"]]

    # Decode to words, handling unknown tokens
    result = decode_sentence(result_tokens)

    return result

# 8. Evaluation and testing
def evaluate_model(model, test_sentences):
    """
    Evaluate model on test sentences

    Args:
        model: Trained transformer model
        test_sentences: List of test sentences

    Returns:
        Dictionary of results
    """
    results = {}

    for sentence in test_sentences:
        expected = reverse_words(sentence)

        try:
            # Try beam search first
            predicted = predict_with_beam_search(model, sentence)

            # If beam search returns empty, fall back to greedy search
            if not predicted.strip():
                predicted = predict_greedy(model, sentence)
        except Exception as e:
            print(f"Error with beam search: {e}")
            # Fall back to greedy search
            predicted = predict_greedy(model, sentence)

        results[sentence] = {
            "input": sentence,
            "expected": expected,
            "predicted": predicted,
            "correct": expected == predicted
        }

    # Calculate accuracy
    accuracy = sum(1 for r in results.values() if r["correct"]) / len(results)

    return results, accuracy

# Test sentences (including ones not seen during training)
test_sentences = [
    "you are great",
    "i am fine",
    "machine learning is fun",
    "this is a new example"
]

# Evaluate model
print("\nStarting model evaluation...")
results, accuracy = evaluate_model(trained_model, test_sentences)

# Print results
print("\n--- Model Evaluation ---")
print(f"Accuracy: {accuracy * 100:.2f}%\n")

for sentence, result in results.items():
    print(f"Input: {result['input']}")
    print(f"Expected: {result['expected']}")
    print(f"Predicted: {result['predicted']}")
    print(f"Correct: {'✓' if result['correct'] else '✗'}")
    print()

# 9. Save the trained model
def save_model(model, path="sentence_reversal_model.pt"):
    """Save model and vocabulary"""
    torch.save({
        "model_state_dict": model.state_dict(),
        "word2idx": word2idx,
        "idx2word": idx2word
    }, path)
    print(f"Model saved to {path}")

# Save the model
save_model(trained_model)

Using device: cpu
Vocabulary size: 32
Maximum sequence length: 5




Epoch 0/300 | Loss: 4.1159
Epoch 50/300 | Loss: 0.0155
Epoch 100/300 | Loss: 0.0033
Epoch 150/300 | Loss: 0.0029
Epoch 200/300 | Loss: 0.0022
Epoch 250/300 | Loss: 0.0028
Epoch 299/300 | Loss: 0.0020

Starting model evaluation...


  output = torch._nested_tensor_from_mask(



--- Model Evaluation ---
Accuracy: 75.00%

Input: you are great
Expected: great are you
Predicted: great are you
Correct: ✓

Input: i am fine
Expected: fine am i
Predicted: fine am i
Correct: ✓

Input: machine learning is fun
Expected: fun is learning machine
Predicted: fun is learning machine
Correct: ✓

Input: this is a new example
Expected: example new a is this
Predicted: example new a is this is this is this is this is this is
Correct: ✗

Model saved to sentence_reversal_model.pt
