# Transformer Model Inference

This notebook demonstrates how to load a trained transformer model and use it for translation.
We'll load the model and tokenizers, then run inference on examples from the validation set as well as on custom inputs.

In [None]:
import html
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from IPython.display import display, HTML

from config import get_config, latest_weights_file_path
from train import get_model, get_ds, run_validation
from translate import translate

## Load Model and Datasets

First, we'll load the necessary components:
- Configure the device (CPU/GPU)
- Load configuration
- Get datasets and tokenizers
- Build the model
- Load pretrained weights

In [None]:
# Select the compute device: CUDA first, then Apple MPS, then CPU.
# torch.backends.mps is missing on older PyTorch builds, so it is looked up
# defensively instead of being accessed directly.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load configuration
config = get_config()
print(f"Source language: {config['lang_src']}")
print(f"Target language: {config['lang_tgt']}")
print(f"Dataset: {config['datasource']}")

try:
    # Load datasets and tokenizers
    print("Loading datasets and tokenizers...")
    start_time = time.time()
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    print(f"Datasets loaded in {time.time() - start_time:.2f} seconds")

    # Track vocabulary sizes for reference
    print(f"Source vocabulary size: {tokenizer_src.get_vocab_size():,}")
    print(f"Target vocabulary size: {tokenizer_tgt.get_vocab_size():,}")

    # Build the model and move it to the selected device
    print("Building model...")
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    # Count model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {total_params:,} parameters ({trainable_params:,} trainable)")

    # Load the pretrained weights, if a checkpoint file is available
    model_filename = latest_weights_file_path(config)
    if model_filename and Path(model_filename).exists():
        print(f"Loading weights from {model_filename}")
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code -- only load checkpoints from a trusted source.
        # map_location places the tensors directly on the selected device.
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        print(f"Loaded weights from epoch {state.get('epoch', 'unknown')}")
    else:
        # Best effort: the notebook continues, but the model keeps whatever
        # initial weights get_model gave it.
        print(f"WARNING: No weights file found at {model_filename}")
except Exception as e:
    print(f"Error loading model: {e}")
    # Re-raise so the cell visibly fails here. Swallowing the error would leave
    # `model` / `val_dataloader` undefined and make later cells fail with an
    # unrelated NameError.
    raise

## Run Validation on Test Examples

Let's evaluate our model on examples from the validation set. This helps verify that the model works correctly.

In [None]:
# Helper function to create HTML output for better visualization
def highlight_translation(source, target, predicted):
    """Build an HTML card showing a source sentence, its reference and the model output.

    Args:
        source: Source-language text.
        target: Reference translation.
        predicted: Translation produced by the model.

    Returns:
        str: An HTML fragment (a single bordered ``<div>``) suitable for
        ``IPython.display.HTML``.
    """
    # Escape the text so characters such as "<", ">" and "&" are displayed
    # literally instead of being interpreted as markup.
    source, target, predicted = (
        html.escape(str(text)) for text in (source, target, predicted)
    )
    return f"""
    <div style="padding: 10px; margin-bottom: 10px; border: 1px solid #ddd; border-radius: 5px;">
        <div style="font-weight: bold; color: #555;">Source:</div>
        <div style="padding: 5px 10px; margin-bottom: 5px;">{source}</div>
        <div style="font-weight: bold; color: #555;">Target:</div>
        <div style="padding: 5px 10px; margin-bottom: 5px;">{target}</div>
        <div style="font-weight: bold; color: #555;">Predicted:</div>
        <div style="padding: 5px 10px; background-color: #f5f5f5;">{predicted}</div>
    </div>
    """

# Every message emitted during validation is kept here so it can be parsed afterwards
validation_results = list()


def capture_validation(msg):
    """Store ``msg`` in ``validation_results`` and echo it to stdout."""
    validation_results.append(msg)
    print(msg)

# Put the model into evaluation mode before running inference
model.eval()

# Number of validation examples to run; raise it for a more thorough check
num_examples = 10

# Time the whole validation pass. Every message run_validation emits goes
# through capture_validation.
# NOTE(review): the positional 0 and None are passed through as-is; their meaning
# is defined by run_validation in train.py -- confirm there.
start_time = time.time()
run_validation(
    model,
    val_dataloader,
    tokenizer_src,
    tokenizer_tgt,
    config['seq_len'],
    device,
    capture_validation,
    0,
    None,
    num_examples=num_examples,
)
validation_elapsed = time.time() - start_time

print(f"\nValidation completed in {validation_elapsed:.2f} seconds")

# Parse the captured messages into source/target/predicted groups
current_group = {}
all_examples = []

for line in validation_results:
    # The exact message format is decided by run_validation (not visible here);
    # ignore leading whitespace so padded or right-aligned labels still match.
    text = str(line).lstrip()
    if text.startswith('SOURCE: '):
        # A new example starts here: discard any incomplete leftovers so fields
        # from different examples are never mixed together.
        current_group = {'source': text[len('SOURCE: '):].strip()}
    elif text.startswith('TARGET: '):
        current_group['target'] = text[len('TARGET: '):].strip()
    elif text.startswith('PREDICTED: '):
        current_group['predicted'] = text[len('PREDICTED: '):].strip()
        if len(current_group) == 3:  # source, target and predicted are all present
            all_examples.append(current_group)
        current_group = {}

# Display results in a more readable format
if not all_examples:
    print("No SOURCE/TARGET/PREDICTED groups found in the validation output.")

for example in all_examples:
    display(HTML(highlight_translation(
        example['source'],
        example['target'],
        example['predicted'],
    )))

## Translate Custom Text

Now let's try translating custom text inputs. We'll wrap the `translate` function in a small helper that also times each call.

In [None]:
# Enhanced custom translation function with timing
def translate_with_timing(text, temperature=1.0, show_details=True):
    """Translate ``text`` and measure how long the call takes.

    Args:
        text: Sentence to translate, or an int / digit-only string, which is
            labelled as a dataset example in the rendered card. How
            ``translate`` itself interprets such values is defined in the
            project's ``translate`` module.
        temperature: Forwarded unchanged to ``translate``.
        show_details: When true, render an HTML card with the source text, the
            translation and the elapsed time.

    Returns:
        tuple: ``(result, elapsed)`` -- the value returned by ``translate`` and
        the wall-clock duration of that call in seconds.
    """
    # perf_counter is monotonic and high-resolution, so short intervals are not
    # distorted by system clock adjustments.
    start_time = time.perf_counter()
    result = translate(text, temperature=temperature, show_progress=False)
    elapsed = time.perf_counter() - start_time

    if show_details:
        if isinstance(text, int) or text.isdigit():
            source_text = f"Example #{text} from dataset"
        else:
            source_text = text

        # Escape the text so "<", ">" and "&" are displayed literally rather
        # than parsed as markup. The returned result is left untouched.
        safe_source = html.escape(str(source_text))
        safe_result = html.escape(str(result))

        display(HTML(f"""
        <div style="padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0;">
            <div style="display: flex; justify-content: space-between;">
                <h3 style="margin-top: 0;">Translation Result</h3>
                <div style="color: #888;">Completed in {elapsed:.3f}s</div>
            </div>
            <div style="margin-bottom: 10px;">
                <div style="font-weight: bold; color: #555;">Source ({config['lang_src']}):</div>
                <div style="padding: 8px; background: #f8f8f8; border-radius: 4px;">{safe_source}</div>
            </div>
            <div>
                <div style="font-weight: bold; color: #555;">Translation ({config['lang_tgt']}):</div>
                <div style="padding: 8px; background: #f0f7ff; border-radius: 4px; font-size: 1.1em;">{safe_result}</div>
            </div>
        </div>
        """))

    return result, elapsed

# Example 1: run a free-form sentence through the helper with its default arguments
custom_sentence = "Why do I need to translate this?"
example1, time1 = translate_with_timing(custom_sentence)

In [None]:
# Example 2: pass an integer, which the helper labels as a dataset example
dataset_example_index = 34
example2, time2 = translate_with_timing(dataset_example_index)

## Translation with Different Parameters

Let's explore how the `temperature` parameter affects translation results and performance.

In [None]:
# Translate one sentence at several temperature values and compare the results
test_sentence = "The book is on the table near the window."
temperatures = [0.7, 1.0, 1.3]  # Conservative to diverse

results = []
timings = []

for temp in temperatures:
    print(f"\nTranslating with temperature = {temp}")
    result, elapsed = translate_with_timing(test_sentence, temperature=temp, show_details=False)
    results.append(result)
    timings.append(elapsed)

# Build the table rows up front. Translations are HTML-escaped so characters
# such as "<" or "&" are displayed literally instead of being parsed as markup.
temperature_rows = ''.join(f'''
        <tr>
            <td style="padding: 8px; border: 1px solid #ddd;">{row_temp}</td>
            <td style="padding: 8px; border: 1px solid #ddd;">{html.escape(str(row_result))}</td>
            <td style="padding: 8px; border: 1px solid #ddd;">{row_time:.3f}</td>
        </tr>
        ''' for row_temp, row_result, row_time in zip(temperatures, results, timings))

# Display comparison table
display(HTML(f"""
<div style="margin: 20px 0;">
    <h3>Temperature Comparison</h3>
    <p>Source: "{html.escape(test_sentence)}"</p>
    <table style="width: 100%; border-collapse: collapse; border: 1px solid #ddd;">
        <tr style="background-color: #f5f5f5;">
            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Temperature</th>
            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Translation</th>
            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Time (s)</th>
        </tr>
        {temperature_rows}
    </table>
</div>
"""))

# Plot timing comparison. The temperatures are used as categorical labels: as
# numeric x positions they are only 0.3 apart, so bars drawn with the default
# width of 0.8 would overlap each other.
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar([str(temp) for temp in temperatures], timings, color='skyblue')
ax.set_xlabel('Temperature')
ax.set_ylabel('Time (seconds)')
ax.set_title('Translation Time vs Temperature')
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

## Batch Translation Performance

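Let's measure translation performance over several sentences. The sentences are translated one at a time, so the numbers reflect sequential throughput rather than true batched inference.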

In [None]:
# Translate several sentences one after another and time each call
test_sentences = [
    "Hello, how are you today?",
    "I would like to buy a train ticket to Rome.",
    "The weather is beautiful this morning.",
    "Can you recommend a good restaurant nearby?",
    "What time does the museum open tomorrow?"
]

# Measure total and per-sentence time. perf_counter is monotonic and
# high-resolution, which suits short intervals better than time.time().
batch_start = time.perf_counter()
all_results = []
individual_times = []

for sentence in test_sentences:
    start = time.perf_counter()
    translation = translate(sentence, show_progress=False)
    end = time.perf_counter()
    all_results.append(translation)
    individual_times.append(end - start)

batch_total = time.perf_counter() - batch_start

# Performance metrics
avg_time = sum(individual_times) / len(individual_times)
throughput = len(test_sentences) / batch_total

# Display results
print(f"Batch translation completed in {batch_total:.2f} seconds")
print(f"Average time per sentence: {avg_time:.2f} seconds")
print(f"Throughput: {throughput:.2f} sentences per second")

# Build the table rows up front. Sentences and translations are HTML-escaped so
# characters such as "<" or "&" are displayed literally instead of being parsed
# as markup.
batch_rows = ''.join(f'''
        <tr>
            <td style="padding: 8px; border: 1px solid #ddd;">{html.escape(str(row_source))}</td>
            <td style="padding: 8px; border: 1px solid #ddd;">{html.escape(str(row_result))}</td>
            <td style="padding: 8px; border: 1px solid #ddd;">{row_time:.3f}</td>
        </tr>
        ''' for row_source, row_result, row_time in zip(test_sentences, all_results, individual_times))

# Show all translations in a formatted table
display(HTML(f"""
<div style="margin: 20px 0;">
    <h3>Batch Translation Results</h3>
    <table style="width: 100%; border-collapse: collapse; border: 1px solid #ddd;">
        <tr style="background-color: #f5f5f5;">
            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Source</th>
            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Translation</th>
            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Time (s)</th>
        </tr>
        {batch_rows}
    </table>
</div>
"""))

# Plot individual translation times, with sentences shortened to 20 characters
# so the tick labels stay readable.
tick_labels = [s[:20] + '...' if len(s) > 20 else s for s in test_sentences]
tick_positions = list(range(len(test_sentences)))

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(tick_positions, individual_times, color='lightgreen')
ax.axhline(y=avg_time, color='red', linestyle='--', label=f'Average: {avg_time:.2f}s')
ax.set_xticks(tick_positions)
# ha='right' anchors each rotated label at its end so it lines up with its bar
ax.set_xticklabels(tick_labels, rotation=45, ha='right')
ax.set_xlabel('Sentence')
ax.set_ylabel('Time (seconds)')
ax.set_title('Translation Time by Sentence')
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

## Performance Optimization Notes

Here are some optimization techniques that could further improve translation performance:

1. **Batch Processing**: Process multiple sentences at once to better utilize GPU parallelism
2. **Mixed Precision**: Use FP16 for faster computation on compatible GPUs
3. **Caching**: Cache encoder outputs for repeated translations of the same source
4. **Model Quantization**: Use quantized models (INT8) for faster inference
5. **Beam Search**: Implement beam search for better translation quality (a quality improvement rather than a speed-up)
6. **ONNX Export**: Export model to ONNX for runtime optimization
7. **JIT Compilation**: Use TorchScript for compiled model execution

These optimizations would require modifications to the core model code.