This project is a complete implementation of the Transformer architecture in Rust, as detailed in the paper "Attention Is All You Need" by Vaswani et al. (2017). It uses the Candle deep learning framework for building and training the model.
This project is currently under active development. The core architectural components are in place, with future work planned for the following:
- Full dataset loading and preprocessing pipelines.
- Evaluation metrics, such as BLEU score calculation.
- Implementation of the Transformer Big configuration.
- Support for exporting model weights to standard formats (e.g., ONNX).
The primary goal is to create a faithful reproduction of the original Transformer model, adhering to the architecture, components, and training specifications described in the paper. This implementation focuses on the Transformer Base configuration.
The codebase is organized into modular Rust files with clear separation of concerns:
- `src/main.rs`: The main binary entry point for running the training process and inference.
- `src/lib.rs`: The root of the `rustformer` library crate.
- `src/transformer.rs`: Contains the core building blocks of the Transformer, including the `Encoder`, `Decoder`, `EncoderLayer`, and `DecoderLayer` structs.
- `src/attention.rs`: Implements the `ScaledDotProductAttention` and `MultiHeadAttention` mechanisms.
- `src/config.rs`: Defines the model configuration and hyperparameters, such as `d_model`, `d_ff`, `n_heads`, and dropout rates.
- `src/model_args.rs`: Defines the `ModelArgs` struct for holding the model's architectural parameters.
- `src/data.rs`: Handles data loading, tokenization, and batching for training.
- `src/train.rs`: Implements the main training loop, including optimizer setup and learning rate scheduling.
- `src/optimizer.rs`: Defines the Adam optimizer with the specific hyperparameters (β₁, β₂, ε) from the paper.
- `src/generation.rs`: Contains logic for generating sequences during inference.
- `src/metrics.rs`: (Future) Intended for evaluation metrics such as BLEU score.
- `src/monitoring.rs`: (Future) Intended for logging and monitoring during training.
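To illustrate what `src/attention.rs` computes, here is a dependency-free sketch of scaled dot-product attention, Attention(Q, K, V) = softmax(QKᵀ/√d_k)V, on plain row-major matrices. The actual implementation operates on Candle tensors; the function and type names below are illustrative assumptions, not the crate's API.

```rust
// Illustrative sketch only: the real `src/attention.rs` uses Candle tensors.
type Mat = Vec<Vec<f32>>;

fn matmul(a: &Mat, b: &Mat) -> Mat {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; m]; n];
    for i in 0..n {
        for p in 0..k {
            for j in 0..m {
                out[i][j] += a[i][p] * b[p][j];
            }
        }
    }
    out
}

fn transpose(m: &Mat) -> Mat {
    let (rows, cols) = (m.len(), m[0].len());
    (0..cols).map(|j| (0..rows).map(|i| m[i][j]).collect()).collect()
}

// Numerically stable row-wise softmax: each row becomes a probability
// distribution over the keys.
fn softmax_rows(m: &mut Mat) {
    for row in m.iter_mut() {
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for x in row.iter_mut() {
            *x = (*x - max).exp();
            sum += *x;
        }
        for x in row.iter_mut() {
            *x /= sum;
        }
    }
}

fn scaled_dot_product_attention(q: &Mat, k: &Mat, v: &Mat) -> Mat {
    let d_k = q[0].len() as f32;
    let mut scores = matmul(q, &transpose(k)); // (seq_q × seq_k)
    for row in scores.iter_mut() {
        for s in row.iter_mut() {
            *s /= d_k.sqrt(); // scale by √d_k to keep dot products in range
        }
    }
    softmax_rows(&mut scores); // attention weights, summing to 1 per query
    matmul(&scores, v)         // weighted combination of the values
}

fn main() {
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let k = q.clone();
    let v = vec![vec![1.0], vec![2.0]];
    println!("{:?}", scaled_dot_product_attention(&q, &k, &v));
}
```

Because the softmax weights sum to 1, each output element lies between the minimum and maximum value entries, and masking (as in the decoder's self-attention) would be applied to `scores` before the softmax.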
This implementation follows the Transformer Base model as specified in the paper.
- Encoder: A stack of N=6 identical layers, each with a multi-head self-attention mechanism and a position-wise feed-forward network.
- Decoder: A stack of N=6 identical layers, featuring masked multi-head self-attention, encoder-decoder attention, and a position-wise feed-forward network.
- Multi-Head Attention: Projects queries, keys, and values into h=8 parallel attention heads, whose outputs are concatenated and linearly projected.
- Position-wise Feed-Forward Network: A two-layer fully connected network with a ReLU activation.
- Positional Encoding: Uses sine and cosine functions to inject position information into the input embeddings.
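The positional encoding can be sketched directly from the paper's formulas, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The function below is a minimal dependency-free illustration; the name and return type are assumptions, and the real code would build a Candle tensor to add to the embeddings.

```rust
// Sketch of the sinusoidal positional encoding table (max_len × d_model).
// Paired dimensions (2i, 2i+1) share one frequency: sin in the even slot,
// cos in the odd slot.
fn positional_encoding(max_len: usize, d_model: usize) -> Vec<Vec<f32>> {
    let mut pe = vec![vec![0.0f32; d_model]; max_len];
    for pos in 0..max_len {
        for i in 0..d_model {
            let exponent = (2 * (i / 2)) as f32 / d_model as f32;
            let angle = pos as f32 / 10000f32.powf(exponent);
            pe[pos][i] = if i % 2 == 0 { angle.sin() } else { angle.cos() };
        }
    }
    pe
}

fn main() {
    let pe = positional_encoding(4, 8);
    println!("{:?}", pe[1]); // encoding for position 1
}
```

Because the frequencies form a geometric progression from 2π to 10000·2π, each position gets a unique fingerprint, and relative offsets correspond to linear transformations of the encoding.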
| Parameter | Value |
|---|---|
| Layers (N) | 6 |
| d_model | 512 |
| d_ff | 2048 |
| Heads (h) | 8 |
| d_k / d_v | 64 |
| Dropout | 0.1 |
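The paper pairs these dimensions with a warmup-based learning-rate schedule, lrate = d_model⁻⁰·⁵ · min(step⁻⁰·⁵, step · warmup_steps⁻¹·⁵), with warmup_steps = 4000. A sketch of what `src/train.rs` presumably implements (the function name is illustrative):

```rust
// The paper's learning-rate schedule: the rate increases linearly for the
// first `warmup_steps` training steps, then decays proportionally to the
// inverse square root of the step number.
fn learning_rate(step: u32, d_model: u32, warmup_steps: u32) -> f64 {
    let step = step.max(1) as f64; // guard against step = 0
    let d_model = d_model as f64;
    let warmup = warmup_steps as f64;
    d_model.powf(-0.5) * step.powf(-0.5).min(step * warmup.powf(-1.5))
}

fn main() {
    for &s in &[1u32, 1000, 4000, 8000, 100_000] {
        println!("step {:>6}: lr = {:.6}", s, learning_rate(s, 512, 4000));
    }
}
```

The two branches of the `min` cross exactly at `step == warmup_steps`, where the rate peaks at d_model⁻⁰·⁵ · warmup_steps⁻⁰·⁵ (about 7·10⁻⁴ for the base configuration).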
- Build the project:

  ```sh
  cargo build --release
  ```

- Run the training:

  ```sh
  cargo run --release
  ```

  (Note: the dataset and command-line arguments for training configuration will be specified in `main.rs`.)