6 changes: 3 additions & 3 deletions CondaPkg.toml
@@ -1,7 +1,7 @@
[deps]
python = "<=3.13,>=3.9,<4"
python = "<=3.12,>=3.9,<4"

[pip.deps]
jax = ">= 0.6"
jax = ">= 0.5"
tensorflow = ">= 2.17"
numpy = ">= 2"
numpy = ">= 1, >= 2"
3 changes: 3 additions & 0 deletions Project.toml
@@ -39,6 +39,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -62,6 +63,7 @@ ReactantFloat8sExt = "Float8s"
ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantMPIExt = "MPI"
ReactantNNlibExt = ["NNlib", "Statistics"]
ReactantNPZExt = "NPZ"
ReactantOffsetArraysExt = "OffsetArrays"
ReactantOneHotArraysExt = "OneHotArrays"
ReactantPythonCallExt = "PythonCall"
@@ -96,6 +98,7 @@ Libdl = "1.10"
LinearAlgebra = "1.10"
MPI = "0.20"
NNlib = "0.9.26"
NPZ = "0.4"
OffsetArrays = "1"
OneHotArrays = "0.2.10"
OrderedCollections = "1"
19 changes: 19 additions & 0 deletions docs/src/api/serialization.md
@@ -27,3 +27,22 @@ or [TensorFlow Hub](https://tensorflow.org/hub). Refer to the
```@docs
Reactant.Serialization.export_as_tf_saved_model
```

## Exporting to JAX via EnzymeAD

!!! note "Load NPZ"

This export functionality requires the `NPZ` package to be loaded.

This export functionality generates:

1. A `.mlir` file containing the StableHLO representation of your Julia function
2. Input `.npz` files containing the input arrays for the function
3. A Python script that wraps the function for use with `enzyme_ad.jax.hlo_call`

The generated Python script can be immediately used with JAX and EnzymeAD without any
additional Julia dependencies.
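
A minimal usage sketch (the example function, inputs, and the `output_dir` keyword are illustrative assumptions; see the docstring below for the actual signature):

```julia
using Reactant, NPZ  # NPZ must be loaded for this export to be available

sumprod(x, y) = sum(x .* y)

x = Reactant.to_rarray(rand(Float32, 4, 4))
y = Reactant.to_rarray(rand(Float32, 4, 4))

# Hypothetical call: writes the StableHLO `.mlir` file, the input `.npz` file(s),
# and the wrapper Python script for `enzyme_ad.jax.hlo_call` into `output_dir`.
Reactant.Serialization.EnzymeJAX.export_to_enzymejax(
    sumprod, x, y; output_dir="exported_sumprod"
)
```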

```@docs
Reactant.Serialization.EnzymeJAX.export_to_enzymejax
```
24 changes: 24 additions & 0 deletions ext/ReactantNPZExt.jl
@@ -0,0 +1,24 @@
module ReactantNPZExt

using NPZ: npzwrite
using Reactant.Serialization: Serialization, EnzymeJAX

Serialization.serialization_supported(::Val{:NPZ}) = true

# Helper function to save all input data to a single NPZ file
function EnzymeJAX.save_inputs_npz_impl(
output_path::String, inputs::Dict{String,<:Union{AbstractArray,Number}}
)
# Transpose arrays for Python/NumPy (row-major vs column-major)
transposed_inputs = Dict{String,Union{AbstractArray,Number}}()
for (name, arr) in inputs
transposed_inputs[name] =
arr isa Number ? arr : permutedims(arr, reverse(1:ndims(arr)))
end

# Save all inputs to a single NPZ file with compression
npzwrite(output_path, transposed_inputs)
return output_path
end

end # module
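
As a quick, standalone illustration of the dimension reversal performed above (the file and variable names are only for the example): a 2×3 Julia array is written with its dimensions reversed, so NumPy reports shape (3, 2) when loading the archive.

```julia
using NPZ

a = reshape(collect(Float32, 1:6), 2, 3)        # 2×3 column-major Julia array
b = permutedims(a, reverse(1:ndims(a)))         # dimension order reversed -> 3×2
npzwrite("example_inputs.npz", Dict("x" => b))  # NumPy sees "x" with shape (3, 2)
```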
19 changes: 1 addition & 18 deletions ext/ReactantPythonCallExt/ReactantPythonCallExt.jl
@@ -3,6 +3,7 @@ module ReactantPythonCallExt
using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
using Reactant.Ops: @opcall
using Reactant.Serialization: NUMPY_SIMPLE_TYPES

const jaxptr = Ref{Py}()
const jnpptr = Ref{Py}()
@@ -15,24 +16,6 @@ const npptr = Ref{Py}()

const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false)

const NUMPY_SIMPLE_TYPES = Dict(
Bool => :bool,
Int8 => :int8,
Int16 => :int16,
Int32 => :int32,
Int64 => :int64,
UInt8 => :uint8,
UInt16 => :uint16,
UInt32 => :uint32,
UInt64 => :uint64,
Float16 => :float16,
Float32 => :float32,
Float64 => :float64,
ComplexF16 => :complex16,
ComplexF32 => :complex32,
ComplexF64 => :complex64,
)

function __init__()
try
jaxptr[] = pyimport("jax")
13 changes: 9 additions & 4 deletions src/Compiler.jl
@@ -1411,7 +1411,7 @@ function __get_compile_options_and_kwargs(;
)
end

function compile_mlir(f, args; client=nothing, kwargs...)
function compile_mlir(f, args; client=nothing, drop_unsupported_attributes=false, kwargs...)
client = client !== nothing ? client : XLA.default_backend()
backend = XLA.platform_name(client)

@@ -1441,6 +1441,11 @@ function compile_mlir(f, args; client=nothing, kwargs...)
mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
)

if drop_unsupported_attributes
# Drop some of our attributes
run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")
end

return mod, mlir_fn_res
end
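
A sketch of how the new keyword might be used from the Julia side (`compile_mlir` is an internal entry point; this call is illustrative and not part of the diff):

```julia
using Reactant

f(x) = sum(abs2, x)
args = (Reactant.to_rarray(rand(Float32, 8)),)

# Hypothetical: lower `f` to a StableHLO module with the enzymexla-specific
# attributes stripped, e.g. before handing the module to an external consumer.
mod, mlir_fn_res = Reactant.Compiler.compile_mlir(
    f, args; drop_unsupported_attributes=true
)
```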

@@ -3571,6 +3576,9 @@ function compile_xla(
mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
)

# Drop some of our attributes
run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")

# compile MLIR module to XLA executable
global_device_ids = collect(Int64, mlir_fn_res.global_device_ids)
mlir_fn_res.is_sharded && (device = nothing)
@@ -3584,9 +3592,6 @@
module_string = ""
end

# Drop some of our attributes
run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")

if before_xla_optimizations
exec = nothing
hlo_modules = XLA.HloModule(mod)
4 changes: 2 additions & 2 deletions src/Sharding.jl
@@ -949,7 +949,7 @@ function HloSharding(sharding::NamedSharding, client::XLA.IFRT.Client, _, x)
data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding)

# XXX: Can we auto-pad this case too? Will think about it later, for now use
# NamedSharidng
# NamedSharding
return data, ShardInfo(hlo_sharding, device_to_array_slices), nothing
end

Expand Down Expand Up @@ -997,7 +997,7 @@ function (sharding::HloSharding)(
data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding)

# XXX: Can we auto-pad this case too? Will think about it later, for now use
# NamedSharidng
# NamedSharding
return data, ShardInfo(sharding, device_to_array_slices), nothing
end
