# Wave Transformer Example with Ember ML

This notebook walks through the structure and forward pass of a Wave Transformer model built with the Ember ML framework. Wave Transformers use attention mechanisms to process sequential data, and this example shows how Ember ML's backend-agnostic components can be assembled into such an architecture.

In [None]:
# Import necessary libraries
import numpy as np

# Import Ember ML components
from ember_ml.ops import set_backend
from ember_ml.nn import tensor
from ember_ml import ops
from ember_ml.wave.models.wave_transformer import (
    WaveMultiHeadAttention,
    WaveTransformerEncoderLayer,
    WaveTransformerEncoder,
    WaveTransformer,
    create_wave_transformer,
)

# Set a backend (choose 'numpy', 'torch', or 'mlx')
# You can change this to see how the code runs on different backends
set_backend('numpy')
print(f"Using backend: {ops.get_backend()}")

## 1. Generate Synthetic Sequence Data

We first generate synthetic sequence data suitable for a transformer model. The tensor has shape `(batch_size, seq_length, embed_dim)`: a batch dimension, a sequence length dimension, and a feature (embedding) dimension.

In [None]:
def create_dummy_sequence_data(batch_size, seq_length, embed_dim):
    """Creates dummy sequence data."""
    return tensor.random_normal((batch_size, seq_length, embed_dim), dtype=tensor.float32)

# Define data parameters
batch_size = 32
seq_length = 50
embed_dim = 64 # Feature dimension / Embedding dimension

# Generate data
input_sequence = create_dummy_sequence_data(batch_size, seq_length, embed_dim)

print(f"Input sequence shape: {tensor.shape(input_sequence)}")

## 2. Define the Wave Transformer Model

We now define the Wave Transformer model. Its key components are Multi-Head Attention and Transformer Encoder Layers; rather than wiring them together by hand, we use the `create_wave_transformer` factory function for convenience.
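To make the multi-head attention component concrete, here is a minimal NumPy sketch of standard multi-head self-attention using the same `embed_dim` and `num_heads` values as this notebook. It is an illustration of the general technique, not Ember ML's `WaveMultiHeadAttention` implementation; the function names and the random weight matrices are assumptions made for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Standard multi-head self-attention (illustrative, not Ember ML's code)."""
    batch, seq_len, embed_dim = x.shape
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads

    def split_heads(t):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        return t.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(x @ w_q)
    k = split_heads(x @ w_k)
    v = split_heads(x @ w_v)

    # Scaled dot-product scores: (batch, heads, seq, seq)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    weights = softmax(scores, axis=-1)

    # Weighted sum of values, then merge heads back to (batch, seq, embed)
    context = weights @ v
    merged = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)
    return merged @ w_o, weights

rng = np.random.default_rng(0)
batch_size, seq_length, embed_dim, num_heads = 32, 50, 64, 8
x = rng.standard_normal((batch_size, seq_length, embed_dim)).astype(np.float32)
# Hypothetical projection weights, randomly initialised for the demo
w_q, w_k, w_v, w_o = (
    (rng.standard_normal((embed_dim, embed_dim)) * 0.1).astype(np.float32)
    for _ in range(4)
)

attn_out, attn_weights = multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(attn_out.shape)      # (32, 50, 64)
print(attn_weights.shape)  # (32, 8, 50, 50)
```

With `embed_dim = 64` and `num_heads = 8`, each head works on an 8-dimensional slice of the embedding, and the output keeps the input's shape.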

In [None]:
# Define model parameters
num_heads = 8
ff_hidden_dim = 128 # Hidden dimension for the feed-forward network
num_layers = 2 # Number of encoder layers

# Create the Wave Transformer model using the factory function
model = create_wave_transformer(
    seq_length=seq_length,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_hidden_dim=ff_hidden_dim,
    num_layers=num_layers,
)

print("Model Architecture:")
print(model)
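For intuition about what `ff_hidden_dim` controls, the sketch below shows the position-wise feed-forward sublayer found in a standard transformer encoder layer: expand to the hidden width, apply a nonlinearity, project back, then add the residual and normalise. This is a generic NumPy illustration under the assumption of a post-norm layout with ReLU; Ember ML's `WaveTransformerEncoderLayer` may differ, and all names here are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalise each position's feature vector to zero mean, unit variance
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def feed_forward_sublayer(x, w1, b1, w2, b2):
    """Position-wise feed-forward block with residual and layer norm."""
    hidden = np.maximum(x @ w1 + b1, 0.0)      # (batch, seq, ff_hidden_dim), ReLU
    projected = hidden @ w2 + b2               # back to (batch, seq, embed_dim)
    return layer_norm(x + projected)           # residual connection, then norm

rng = np.random.default_rng(0)
embed_dim, ff_hidden_dim = 64, 128
x = rng.standard_normal((32, 50, embed_dim))
# Hypothetical weights, randomly initialised for the demo
w1 = rng.standard_normal((embed_dim, ff_hidden_dim)) * 0.1
b1 = np.zeros(ff_hidden_dim)
w2 = rng.standard_normal((ff_hidden_dim, embed_dim)) * 0.1
b2 = np.zeros(embed_dim)

y = feed_forward_sublayer(x, w1, b1, w2, b2)
print(y.shape)
```

The hidden width only affects the intermediate activation; the sublayer's output has the same shape as its input, which is what lets encoder layers be stacked.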

## 3. Demonstrate Forward Pass

We will pass the synthetic data through the Wave Transformer model to demonstrate a forward pass and observe the output shape.

In [None]:
# Perform a forward pass
# Note: Transformer models often require attention masks, but for simplicity,
# we omit them in this basic forward pass demonstration.
output = model(input_sequence)

print(f"Output sequence shape: {tensor.shape(output)}")

# The output shape should be the same as the input shape for a standard transformer encoder
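The forward pass above omits attention masks. As a rough illustration of what such a mask does, the NumPy sketch below builds a padding mask from per-sequence lengths and applies it to attention scores before the softmax, so padded positions receive zero weight. It does not use Ember ML's API (how the model accepts a mask is not shown in this notebook), and the helper names are assumptions.

```python
import numpy as np

def padding_mask(lengths, seq_length):
    # (batch, seq): True for real tokens, False for padding
    return np.arange(seq_length)[None, :] < np.asarray(lengths)[:, None]

def masked_softmax(scores, key_mask):
    """Softmax over keys, ignoring masked-out (padded) key positions.

    scores:   (batch, seq, seq) attention scores
    key_mask: (batch, seq) boolean mask; each row needs at least one True
    """
    scores = np.where(key_mask[:, None, :], scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)  # exp(-inf) is exactly 0 for masked keys
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
seq_length = 4
lengths = [4, 2]  # second sequence has two padded positions
scores = rng.standard_normal((2, seq_length, seq_length))

mask = padding_mask(lengths, seq_length)
weights = masked_softmax(scores, mask)
print(mask)
print(np.round(weights[1], 3))  # last two columns are zero
```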

## 4. Notes on Training

Training a Transformer involves computing gradients and updating model parameters with an optimizer. Ember ML provides optimizers (`ember_ml.training`) and the ability to compute gradients (`ops.gradients`), but seamless end-to-end training of deep, layered models requires a full automatic differentiation system such as `GradientTape`. Without one, you would have to derive and implement the gradient of every operation by hand, which is impractical for a Transformer.

For a supervised task, you would typically define a loss function (e.g., from `ember_ml.training`), compute the loss between the model's predictions and the true labels, calculate the gradients of that loss with respect to the trainable parameters, and apply those gradients with an optimizer. The `ops.gradients` function computes gradients of a given output with respect to specific inputs (such as model parameters), so it can serve as the basis for a manual training loop.
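The loop described above (predict, compute loss, compute gradients, update) can be sketched in plain NumPy on a tiny linear model, where the gradients are simple enough to derive by hand. This is a minimal sketch of the general pattern only: it does not call `ops.gradients` or any Ember ML optimizer, and the data, learning rate, and variable names are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression data: y = x @ true_w + small noise
x = rng.standard_normal((256, 3))
true_w = np.array([[2.0], [-1.0], [0.5]])
y = x @ true_w + 0.01 * rng.standard_normal((256, 1))

# Trainable parameters
w = np.zeros((3, 1))
b = np.zeros((1,))
lr = 0.1
losses = []

for step in range(200):
    # 1. Forward pass
    pred = x @ w + b
    # 2. Mean squared error loss
    err = pred - y
    losses.append(np.mean(err ** 2))
    # 3. Hand-derived gradients of the loss w.r.t. each parameter
    grad_w = 2.0 * x.T @ err / len(x)
    grad_b = 2.0 * err.mean(axis=0)
    # 4. Plain gradient-descent update (the optimizer step)
    w -= lr * grad_w
    b -= lr * grad_b

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.6f}")
print("learned weights:", w.ravel())
```

For a single linear layer, step 3 is two lines. For a Transformer, the same step would need a hand-written gradient for every attention, normalisation, and feed-forward operation, which is why automatic differentiation matters in practice.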

## Conclusion

This notebook provided a basic demonstration of the Wave Transformer model in Ember ML, focusing on its structure and forward pass. It showcased how Ember ML's backend-agnostic modules can be used to build advanced architectures. While full training requires a more complete automatic differentiation system, the core components for building and running the model are available and work across different backends.