eformer (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a comprehensive collection of tools for distributed computing, custom data structures, numerical optimization, and high-performance operations. Eformer aims to make it easier to build, scale, and optimize models efficiently while leveraging JAX's capabilities for high-performance computing.
The library is organized into several core modules:
- `aparser`: Advanced argument parsing utilities with dataclass integration
- `callib`: Custom function calling and Triton kernel integration
- `common_types`: Shared type definitions and sharding constants
- `escale`: Distributed sharding and parallelism utilities
- `executor`: Execution management and hardware-specific optimizations
- `jaximus`: Custom PyTree implementations and structured array utilities
- `mpric`: Mixed precision training and dynamic scaling infrastructure
- `ops`: Optimized operations including Flash Attention and quantization
- `optimizers`: Flexible optimizer configuration and factory patterns
- `pytree`: Enhanced tree manipulation and transformation utilities
Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling, enabling faster training and reduced memory footprint.
Tools for efficient sharding and distributed computation in JAX, allowing you to scale your models across multiple devices with various sharding strategies (a minimal JAX sketch follows the list):

- Data Parallelism (`DP`)
- Fully Sharded Data Parallel (`FSDP`)
- Tensor Parallelism (`TP`)
- Expert Parallelism (`EP`)
- Sequence Parallelism (`SP`)
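The sketch below is a minimal illustration using JAX's native `Mesh`, `NamedSharding`, and `PartitionSpec` primitives rather than the `escale` API itself; the axis names (`dp`, `tp`), mesh shape, and array sizes are assumptions chosen for the example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2D mesh: a data-parallel axis ("dp") and a tensor-parallel
# axis ("tp"). Assumes 8 available devices; adjust the shape to your hardware.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=("dp", "tp"))

# Shard activations along the batch ("dp") axis and weights along the
# hidden ("tp") axis.
x = jax.device_put(jnp.ones((32, 1024)), NamedSharding(mesh, P("dp", None)))
w = jax.device_put(jnp.ones((1024, 4096)), NamedSharding(mesh, P(None, "tp")))

@jax.jit
def matmul(x, w):
    return x @ w

y = matmul(x, w)  # computed across the mesh; inspect y.sharding to verify
```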
Enhanced utilities for creating custom PyTrees and `ArrayValue` objects, adapted from Equinox, providing flexible data structures for your models.
Custom function calling utilities with direct integration of Triton kernels in JAX, allowing you to optimize performance-critical operations.
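The sketch below illustrates the general pattern of calling a Triton kernel from JAX using the standalone `jax-triton` package's `triton_call`; it is not the `callib` API itself, and the kernel, block size, and shapes are assumptions made for the example.

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * block_size + tl.arange(0, block_size)
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(out_ptr + offsets, x + y)

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    block_size = 128  # assumes x.size is a multiple of block_size
    return jt.triton_call(
        x,
        y,
        kernel=add_kernel,
        out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
        grid=(x.size // block_size,),
        block_size=block_size,
    )

x = jnp.arange(1024, dtype=jnp.float32)
print(add(x, x))
```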
A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp, making it easy to experiment with different optimization strategies.
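As a point of reference, the sketch below configures AdamW with a warmup-cosine schedule directly through `optax`, the standard JAX optimizer library; the hyperparameters are illustrative and the eformer factory's own interface is not shown here.

```python
import optax

# Illustrative AdamW configuration with a warmup-cosine learning-rate schedule.
optimizer = optax.adamw(
    learning_rate=optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=3e-4,
        warmup_steps=1_000,
        decay_steps=100_000,
    ),
    b1=0.9,
    b2=0.95,
    weight_decay=0.1,
)

# Standard optax usage:
#   opt_state = optimizer.init(params)
#   updates, opt_state = optimizer.update(grads, opt_state, params)
#   params = optax.apply_updates(params, updates)
```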
- Flash Attention 2 implementation for GPUs/TPUs (via Triton and Pallas) for faster attention computations
- 8-bit and NF4 quantization for efficient model deployment
- Additional optimized operations under active development
For detailed API references and usage examples, see:
- Argument Parser (`aparser`)
- Triton Integration (`callib`)
- Sharding Utilities (`escale`)
- Execution Management (`executor`)
- Mixed Precision Infrastructure (`mpric`)
- Custom Operations (`ops`)
You can install `eformer` via pip:

```bash
pip install eformer
```
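Configure mixed precision with `mpric`'s `PrecisionHandler`: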
```python
from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True,
)
```
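Define a quantized array as a custom `ArrayValue` PyTree with `jaximus`: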
```python
import jax

from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import dequantize_row_q8_0, quantize_row_q8_0


class Array8B(ArrayValue):
    # 8-bit quantized array: stores int8 weights plus per-row scales.
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Reconstruct an approximation of the original array on demand.
        return dequantize_row_q8_0(self.weight, self.scale)


array = jax.random.normal(jax.random.key(0), (256, 64), "f2")  # "f2" == float16
qarray = Array8B(array)
```
We welcome contributions! Please read our Contributing Guidelines to get started.
This project is licensed under the Apache License 2.0. See the LICENSE file for details.