MLX backend for Burn — native Apple Silicon GPU acceleration for deep learning.
This crate provides a Burn backend using Apple's MLX framework, enabling high-performance machine learning on M1/M2/M3/M4 Macs.
- Native Apple Silicon: Direct GPU acceleration via Metal
- Unified Memory: Zero-copy data sharing between CPU and GPU
- Lazy Evaluation: Automatic operation fusion and optimization
- Full Burn Backend: FloatTensorOps, IntTensorOps, BoolTensorOps, ModuleOps, ActivationOps
- Training Support: Pooling operations with backward passes for autodiff

Requirements:

- macOS with Apple Silicon (M1/M2/M3/M4)
- Rust 1.75+

Add to your Cargo.toml:

```toml
[dependencies]
burn-mlx = "0.1"
burn = "0.16"
```

Basic usage:

```rust
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};
// Create tensors on Apple Silicon GPU
let device = MlxDevice::Gpu;
let a: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
let b: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
let c = a + b;
println!("Result shape: {:?}", c.shape());use burn::backend::Autodiff;
use burn_mlx::Mlx;
type TrainBackend = Autodiff<Mlx>;
// Now use TrainBackend for training with automatic differentiationburn-mlx provides full support for pooling operations with both forward and backward passes, enabling their use in training workflows.
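Building on the `TrainBackend` alias above, here is a minimal sketch of a backward pass using Burn's standard autodiff API (`require_grad`, `backward`, `grad`); this is illustrative and not code from this crate:

```rust
use burn::tensor::Tensor;
use burn_mlx::MlxDevice;

let device = MlxDevice::Gpu;

// Mark the input as requiring gradients, run a forward pass, and
// backpropagate from a scalar loss.
let x: Tensor<TrainBackend, 2> = Tensor::ones([2, 3], &device).require_grad();
let loss = (x.clone() * x.clone()).sum();
let grads = loss.backward();

// Gradients are read back on the inner (non-autodiff) Mlx backend.
if let Some(x_grad) = x.grad(&grads) {
    println!("dloss/dx = {}", x_grad);
}
```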
burn-mlx provides full support for pooling operations with both forward and backward passes, enabling their use in training workflows.

```rust
use burn::tensor::Tensor;
use burn::nn::pool::{AvgPool2d, AvgPool2dConfig};
use burn_mlx::{Mlx, MlxDevice};
let device = MlxDevice::Gpu;
// Create a 4D tensor: [batch, channels, height, width]
let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);
// Create avg pool layer with 2x2 kernel and stride 2
let config = AvgPool2dConfig::new([2, 2]).with_strides([2, 2]);
let pool = config.init();
let output = pool.forward(input);
// Output shape: [1, 3, 16, 16]
```

Max pooling follows the same pattern:

```rust
use burn::tensor::Tensor;
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn_mlx::{Mlx, MlxDevice};
let device = MlxDevice::Gpu;
let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);
// Create max pool layer with 2x2 kernel and stride 2
let config = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]);
let pool = config.init();
let output = pool.forward(input);
// Output shape: [1, 3, 16, 16]
```

1D pooling operates on [batch, channels, length] tensors:

```rust
use burn::tensor::Tensor;
use burn::nn::pool::{AvgPool1d, AvgPool1dConfig, MaxPool1d, MaxPool1dConfig};
use burn_mlx::{Mlx, MlxDevice};
let device = MlxDevice::Gpu;
// Create a 3D tensor: [batch, channels, length]
let input: Tensor<Mlx, 3> = Tensor::ones([1, 64, 128], &device);
// Average pooling
let avg_config = AvgPool1dConfig::new(4).with_stride(4);
let avg_pool = avg_config.init();
let avg_output = avg_pool.forward(input.clone());
// Output shape: [1, 64, 32]
// Max pooling
let max_config = MaxPool1dConfig::new(4).with_stride(4);
let max_pool = max_config.init();
let max_output = max_pool.forward(input);
// Output shape: [1, 64, 32]
```

Adaptive pooling produces a fixed output size regardless of the input's spatial dimensions:

```rust
use burn::tensor::Tensor;
use burn::nn::pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig};
use burn_mlx::{Mlx, MlxDevice};
let device = MlxDevice::Gpu;
let input: Tensor<Mlx, 4> = Tensor::ones([1, 512, 14, 14], &device);
// Adaptive pool to fixed output size (common before FC layers)
let config = AdaptiveAvgPool2dConfig::new([1, 1]);
let pool = config.init();
let output = pool.forward(input);
// Output shape: [1, 512, 1, 1]
```

The crate's `MlxTensor` type can also be used directly for lower-level operations:

```rust
use burn_mlx::{MlxTensor, MlxDevice};
let device = MlxDevice::Gpu;
// Create tensors
let a = MlxTensor::<f32>::ones(&[1024, 1024], device);
let b = MlxTensor::<f32>::ones(&[1024, 1024], device);
// Operations
let c = a.matmul(&b);
let d = c.relu();
let e = d.softmax();
// Evaluate lazy computation
e.eval().expect("evaluation failed");
```

Supported operations:

- Arithmetic: add, sub, mul, div, matmul
- Math: exp, log, sqrt, abs, neg, pow
- Reductions: sum, mean, max, min, argmax, argmin
- Shape: reshape, transpose, permute, expand, slice, flip, scatter
- Activations: ReLU, Sigmoid, Tanh, GELU, LeakyReLU, Softmax, LogSoftmax, HardSigmoid
- Conv1d, Conv2d (with proper NCHW layout handling)
- Embedding lookup
- Pooling (full forward and backward support):
  - AvgPool1d, AvgPool2d
  - MaxPool1d, MaxPool2d
  - MaxPool2d with indices
  - AdaptiveAvgPool1d, AdaptiveAvgPool2d
The pooling operations are implemented using MLX's `as_strided` function combined with reduction operations (a conceptual sketch of the window view follows the list):

- Forward Pass: Uses `as_strided` to create sliding window views over the input, then applies `mean_axes` (avg pool) or `max_axes` (max pool) for the reduction.
- Backward Pass:
  - AvgPool: Distributes gradients evenly across each pooling window using `scatter_add`
  - MaxPool: Uses the indices saved during the forward pass to scatter gradients back to the max positions
- Layout Handling: Automatically converts between Burn's NCHW format and MLX's native NHWC format.
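As a conceptual sketch (illustrative only, not this crate's internals), the shape and strides of the `as_strided` window view over an NHWC input can be derived as follows; reducing over the two kernel axes of this view yields the pooled output:

```rust
/// Hypothetical helper: compute the shape and element strides of the
/// sliding-window view that `as_strided` creates over an NHWC input.
fn pooling_window_view(
    shape: [usize; 4],   // input shape [n, h, w, c]
    strides: [usize; 4], // input element strides for [n, h, w, c]
    kernel: [usize; 2],  // [kh, kw]
    stride: [usize; 2],  // [sh, sw]
) -> ([usize; 6], [usize; 6]) {
    let [n, h, w, c] = shape;
    let [kh, kw] = kernel;
    let [sh, sw] = stride;
    let out_h = (h - kh) / sh + 1;
    let out_w = (w - kw) / sw + 1;

    // One window per output position, plus two axes spanning the kernel.
    let view_shape = [n, out_h, out_w, kh, kw, c];
    // The view reuses the input's memory: a step along an output axis jumps
    // by `stride` input rows/columns, while the kernel axes walk through the
    // window using the raw input strides.
    let view_strides = [
        strides[0],
        strides[1] * sh,
        strides[2] * sw,
        strides[1],
        strides[2],
        strides[3],
    ];
    (view_shape, view_strides)
}
```

For the average-pool backward pass, the same window geometry is walked in reverse: each output gradient is divided by the window size and scatter-added back to every input position its window covered.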
On Apple M-series chips, burn-mlx leverages:
- Metal Performance Shaders for optimized GPU kernels
- Unified memory architecture for efficient data transfer
- Lazy evaluation for automatic operation fusion
Typical matmul performance (1024x1024):
- ~12ms per operation on M1/M2
- Scales well with larger matrices

Known limitations:

- macOS only (Apple Silicon required)
- Conv3d and ConvTranspose operations are placeholders
- Quantization support is minimal
- Dilation in pooling operations is not yet supported
Licensed under Apache-2.0.