From 1a12fd531df2c4969671cb99f722624fafa71865 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sun, 12 Jan 2025 12:17:43 +0100 Subject: [PATCH 1/6] allocator --- deps/ReactantExtra/API.cpp | 83 ++++++++++++++++++++++++++++++++++---- src/XLA.jl | 45 +++++++++++++++++++++ 2 files changed, 121 insertions(+), 7 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f7ada88b22..5f183454ac 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -90,8 +90,8 @@ #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" using namespace mlir; using namespace llvm; @@ -325,6 +325,74 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client, client->LookupAddressableDevice(PjRtLocalDeviceId(device_id))); } +extern "C" int64_t PjrtDeviceGetNumAllocs(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->num_allocs; +} +extern "C" int64_t PjrtDeviceGetBytesInUse(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->bytes_in_use; +} +extern "C" int64_t PjrtDeviceGetPeakBytesInUse(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->peak_bytes_in_use; +} +extern "C" int64_t PjrtDeviceGetLargestAllocSize(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->largest_alloc_size; +} +extern "C" int64_t PjrtDeviceGetBytesLimit(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->bytes_limit.value_or(std::numeric_limits::min()); +} +extern "C" int64_t PjrtDeviceGetBytesReserved(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->bytes_reserved; +} +extern "C" int64_t PjrtDeviceGetPeakBytesReserved(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->peak_bytes_reserved; +} +extern "C" int64_t PjrtDeviceGetBytesReservableLimit(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->bytes_reservable_limit.value_or( + std::numeric_limits::min()); +} +extern "C" int64_t PjrtDeviceGetLargestFreeBlockBytes(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->largest_free_block_bytes; +} +extern "C" int64_t PjrtDeviceGetPoolBytes(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->pool_bytes.value_or(std::numeric_limits::min()); +} +extern "C" int64_t PjrtDeviceGetPeakPoolBytes(PjrtDevice *device) { + auto stats = device->GetAllocatorStats(); + if (!stats.ok()) + return std::numeric_limits::min(); + return stats->peak_pool_bytes.value_or(std::numeric_limits::min()); +} + extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; } extern "C" PjRtDevice *BufferToDevice(PjRtBuffer *Buffer) { @@ -443,7 +511,7 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) { if (ReactantThrowError) { llvm::errs() << lmod << "\n"; ReactantThrowError(err_str.c_str()); - return wrap((mlir::ModuleOp)nullptr); + return wrap((mlir::ModuleOp) nullptr); } } mlir::MLIRContext &context = *unwrap(cctx); @@ -642,8 +710,8 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op, if (auto func = dyn_cast(op.getOperation())) { if (func.isExternal()) { - shouldRemove = true; - return success(); + shouldRemove = true; + return success(); } } @@ -678,13 +746,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, } bool shouldRemove = false; - if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) { + if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, + shouldRemove))) { assert(0 && "failed to update all uses"); } if (shouldRemove) - op.erase(); + op.erase(); else - SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); + SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); } prevMod.getBody()->getOperations().splice( prevMod.getBody()->getOperations().end(), diff --git a/src/XLA.jl b/src/XLA.jl index c7a526f2bc..b44fcb6490 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -231,6 +231,51 @@ function client(device::Device) end end +struct AllocatorStats + num_allocs::Int64 + bytes_in_use::Int64 + peak_bytes_in_use::Int64 + largest_alloc_size::Int64 + bytes_limit::Union{Nothing,Int64} + bytes_reserved::Int64 + peak_bytes_reserved::Int64 + bytes_reservable_limit::Union{Nothing,Int64} + largest_free_block_bytes::Int64 + pool_bytes::Union{Nothing,Int64} + peak_pool_bytes::Union{Nothing,Int64} +end + +function allocatorstats(device::Device) + num_allocs = @ccall MLIR.API.mlir_c.PjrtDeviceGetNumAllocs(device.device::Ptr{Cvoid})::Int64 + if num_allocs == typemin(Int64) + return nothing + end + + bytes_in_use = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesInUse(device.device::Ptr{Cvoid})::Int64 + peak_bytes_in_use = @ccall MLIR.API.mlir_c.PjrtDeviceGetPeakBytesInUse(device.device::Ptr{Cvoid})::Int64 + largest_alloc_size = @ccall MLIR.API.mlir_c.PjrtDeviceGetLargestAllocSize(device.device::Ptr{Cvoid})::Int64 + bytes_limit = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesLimit(device.device::Ptr{Cvoid})::Int64 + bytes_reserved = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesReserved(device.device::Ptr{Cvoid})::Int64 + peak_bytes_reserved = @ccall MLIR.API.mlir_c.PjrtDeviceGetPeakBytesReserved(device.device::Ptr{Cvoid})::Int64 + bytes_reservable_limit = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesReservableLimit(device.device::Ptr{Cvoid})::Int64 + largest_free_block_bytes = @ccall MLIR.API.mlir_c.PjrtDeviceGetLargestFreeBlockBytes(device.device::Ptr{Cvoid})::Int64 + pool_bytes = @ccall MLIR.API.mlir_c.PjrtDeviceGetPoolBytes(device.device::Ptr{Cvoid})::Int64 + peak_pool_bytes = @ccall MLIR.API.mlir_c.PjrtDeviceGetPeakPoolBytes(device.device::Ptr{Cvoid})::Int64 + + AllocatorStats( + num_allocs, + peak_bytes_in_use, + largest_alloc_size, + bytes_limit == typemin(Int64) ? Nothing : bytes_limit, + bytes_reserved, + peak_bytes_reserved, + bytes_reservable_limit == typemin(Int64) ? Nothing : bytes_reservable_limit, + largest_free_block_bytes, + pool_bytes == typemin(Int64) ? Nothing : pool_bytes, + peak_pool_bytes == typemin(Int64) ? Nothing : peak_pool_bytes, + ) +end + # https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29 @inline primitive_type(::Type{Bool}) = 1 From 5abd857a8725fea593da133638feff30e9de464b Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sun, 12 Jan 2025 17:08:03 +0100 Subject: [PATCH 2/6] allocator2 --- deps/ReactantExtra/API.cpp | 22 +++++++++++----------- deps/ReactantExtra/BUILD | 11 +++++++++++ src/XLA.jl | 22 +++++++++++----------- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 5f183454ac..c11e0a27e9 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -325,68 +325,68 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client, client->LookupAddressableDevice(PjRtLocalDeviceId(device_id))); } -extern "C" int64_t PjrtDeviceGetNumAllocs(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetNumAllocs(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->num_allocs; } -extern "C" int64_t PjrtDeviceGetBytesInUse(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetBytesInUse(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->bytes_in_use; } -extern "C" int64_t PjrtDeviceGetPeakBytesInUse(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetPeakBytesInUse(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->peak_bytes_in_use; } -extern "C" int64_t PjrtDeviceGetLargestAllocSize(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetLargestAllocSize(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->largest_alloc_size; } -extern "C" int64_t PjrtDeviceGetBytesLimit(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetBytesLimit(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->bytes_limit.value_or(std::numeric_limits::min()); } -extern "C" int64_t PjrtDeviceGetBytesReserved(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetBytesReserved(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->bytes_reserved; } -extern "C" int64_t PjrtDeviceGetPeakBytesReserved(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetPeakBytesReserved(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->peak_bytes_reserved; } -extern "C" int64_t PjrtDeviceGetBytesReservableLimit(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetBytesReservableLimit(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->bytes_reservable_limit.value_or( std::numeric_limits::min()); } -extern "C" int64_t PjrtDeviceGetLargestFreeBlockBytes(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetLargestFreeBlockBytes(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->largest_free_block_bytes; } -extern "C" int64_t PjrtDeviceGetPoolBytes(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetPoolBytes(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); return stats->pool_bytes.value_or(std::numeric_limits::min()); } -extern "C" int64_t PjrtDeviceGetPeakPoolBytes(PjrtDevice *device) { +extern "C" int64_t PjRtDeviceGetPeakPoolBytes(PjRtDevice *device) { auto stats = device->GetAllocatorStats(); if (!stats.ok()) return std::numeric_limits::min(); diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index a86f5a3c9d..782afa4ad9 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -411,6 +411,17 @@ cc_library( "-Wl,-exported_symbol,_ClientProcessIndex", "-Wl,-exported_symbol,_ClientGetDevice", "-Wl,-exported_symbol,_ClientGetAddressableDevice", +"-Wl,-exported_symbol,_PjRtDeviceGetNumAllocs", +"-Wl,-exported_symbol,_PjRtDeviceGetBytesInUse", +"-Wl,-exported_symbol,_PjRtDeviceGetPeakBytesInUse", +"-Wl,-exported_symbol,_PjRtDeviceGetLargestAllocSize", +"-Wl,-exported_symbol,_PjRtDeviceGetBytesLimit", +"-Wl,-exported_symbol,_PjRtDeviceGetBytesReserved", +"-Wl,-exported_symbol,_PjRtDeviceGetPeakBytesReserved", +"-Wl,-exported_symbol,_PjRtDeviceGetBytesReservableLimit", +"-Wl,-exported_symbol,_PjRtDeviceGetLargestFreeBlockBytes", +"-Wl,-exported_symbol,_PjRtDeviceGetPoolBytes", +"-Wl,-exported_symbol,_PjRtDeviceGetPeakPoolBytes", "-Wl,-exported_symbol,_ExecutableFree", "-Wl,-exported_symbol,_BufferToDevice", "-Wl,-exported_symbol,_BufferToClient", diff --git a/src/XLA.jl b/src/XLA.jl index b44fcb6490..d29e7bbfab 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -246,21 +246,21 @@ struct AllocatorStats end function allocatorstats(device::Device) - num_allocs = @ccall MLIR.API.mlir_c.PjrtDeviceGetNumAllocs(device.device::Ptr{Cvoid})::Int64 + num_allocs = @ccall MLIR.API.mlir_c.PjRtDeviceGetNumAllocs(device.device::Ptr{Cvoid})::Int64 if num_allocs == typemin(Int64) return nothing end - bytes_in_use = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesInUse(device.device::Ptr{Cvoid})::Int64 - peak_bytes_in_use = @ccall MLIR.API.mlir_c.PjrtDeviceGetPeakBytesInUse(device.device::Ptr{Cvoid})::Int64 - largest_alloc_size = @ccall MLIR.API.mlir_c.PjrtDeviceGetLargestAllocSize(device.device::Ptr{Cvoid})::Int64 - bytes_limit = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesLimit(device.device::Ptr{Cvoid})::Int64 - bytes_reserved = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesReserved(device.device::Ptr{Cvoid})::Int64 - peak_bytes_reserved = @ccall MLIR.API.mlir_c.PjrtDeviceGetPeakBytesReserved(device.device::Ptr{Cvoid})::Int64 - bytes_reservable_limit = @ccall MLIR.API.mlir_c.PjrtDeviceGetBytesReservableLimit(device.device::Ptr{Cvoid})::Int64 - largest_free_block_bytes = @ccall MLIR.API.mlir_c.PjrtDeviceGetLargestFreeBlockBytes(device.device::Ptr{Cvoid})::Int64 - pool_bytes = @ccall MLIR.API.mlir_c.PjrtDeviceGetPoolBytes(device.device::Ptr{Cvoid})::Int64 - peak_pool_bytes = @ccall MLIR.API.mlir_c.PjrtDeviceGetPeakPoolBytes(device.device::Ptr{Cvoid})::Int64 + bytes_in_use = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesInUse(device.device::Ptr{Cvoid})::Int64 + peak_bytes_in_use = @ccall MLIR.API.mlir_c.PjRtDeviceGetPeakBytesInUse(device.device::Ptr{Cvoid})::Int64 + largest_alloc_size = @ccall MLIR.API.mlir_c.PjRtDeviceGetLargestAllocSize(device.device::Ptr{Cvoid})::Int64 + bytes_limit = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesLimit(device.device::Ptr{Cvoid})::Int64 + bytes_reserved = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesReserved(device.device::Ptr{Cvoid})::Int64 + peak_bytes_reserved = @ccall MLIR.API.mlir_c.PjRtDeviceGetPeakBytesReserved(device.device::Ptr{Cvoid})::Int64 + bytes_reservable_limit = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesReservableLimit(device.device::Ptr{Cvoid})::Int64 + largest_free_block_bytes = @ccall MLIR.API.mlir_c.PjRtDeviceGetLargestFreeBlockBytes(device.device::Ptr{Cvoid})::Int64 + pool_bytes = @ccall MLIR.API.mlir_c.PjRtDeviceGetPoolBytes(device.device::Ptr{Cvoid})::Int64 + peak_pool_bytes = @ccall MLIR.API.mlir_c.PjRtDeviceGetPeakPoolBytes(device.device::Ptr{Cvoid})::Int64 AllocatorStats( num_allocs, From 4fd970e05a3139361710b31580214508bdf77735 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Sun, 12 Jan 2025 20:03:00 +0100 Subject: [PATCH 3/6] throw error when unsupported --- deps/ReactantExtra/API.cpp | 66 +++++++++++++------------------------- 1 file changed, 22 insertions(+), 44 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index c11e0a27e9..f7d130ded8 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -326,71 +326,49 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client, } extern "C" int64_t PjRtDeviceGetNumAllocs(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->num_allocs; + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.num_allocs; } extern "C" int64_t PjRtDeviceGetBytesInUse(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->bytes_in_use; + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.bytes_in_use; } extern "C" int64_t PjRtDeviceGetPeakBytesInUse(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->peak_bytes_in_use; + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.peak_bytes_in_use; } extern "C" int64_t PjRtDeviceGetLargestAllocSize(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->largest_alloc_size; + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.largest_alloc_size; } extern "C" int64_t PjRtDeviceGetBytesLimit(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->bytes_limit.value_or(std::numeric_limits::min()); + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.bytes_limit.value_or(std::numeric_limits::min()); } extern "C" int64_t PjRtDeviceGetBytesReserved(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->bytes_reserved; + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.bytes_reserved; } extern "C" int64_t PjRtDeviceGetPeakBytesReserved(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->peak_bytes_reserved; + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.peak_bytes_reserved; } extern "C" int64_t PjRtDeviceGetBytesReservableLimit(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->bytes_reservable_limit.value_or( + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.bytes_reservable_limit.value_or( std::numeric_limits::min()); } extern "C" int64_t PjRtDeviceGetLargestFreeBlockBytes(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->largest_free_block_bytes; + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.largest_free_block_bytes; } extern "C" int64_t PjRtDeviceGetPoolBytes(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->pool_bytes.value_or(std::numeric_limits::min()); + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.pool_bytes.value_or(std::numeric_limits::min()); } extern "C" int64_t PjRtDeviceGetPeakPoolBytes(PjRtDevice *device) { - auto stats = device->GetAllocatorStats(); - if (!stats.ok()) - return std::numeric_limits::min(); - return stats->peak_pool_bytes.value_or(std::numeric_limits::min()); + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + return stats.peak_pool_bytes.value_or(std::numeric_limits::min()); } extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; } From 592d2792dc8c4cafca1ec5eb24ff9b8663022365 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 13 Jan 2025 17:22:45 +0100 Subject: [PATCH 4/6] single GetAllocatorStats call --- deps/ReactantExtra/API.cpp | 74 ++++++++++++++++---------------------- deps/ReactantExtra/BUILD | 12 +------ src/XLA.jl | 56 +++++++++++++++-------------- 3 files changed, 62 insertions(+), 80 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f7d130ded8..c5dee2cade 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -325,50 +325,38 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client, client->LookupAddressableDevice(PjRtLocalDeviceId(device_id))); } -extern "C" int64_t PjRtDeviceGetNumAllocs(PjRtDevice *device) { +// To keep in sync with JLAllocatorStats in src/XLA.jl +struct JLAllocatorStats { + int64_t num_allocs; + int64_t bytes_in_use; + int64_t peak_bytes_in_use; + int64_t largest_alloc_size; + int64_t bytes_limit; + int64_t bytes_reserved; + int64_t peak_bytes_reserved; + int64_t bytes_reservable_limit; + int64_t largest_free_block_bytes; + int64_t pool_bytes; + int64_t peak_pool_bytes; +}; + +extern "C" void PjRtDeviceGetAllocatorStats(PjRtDevice *device, + JLAllocatorStats *jlstats) { auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.num_allocs; -} -extern "C" int64_t PjRtDeviceGetBytesInUse(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.bytes_in_use; -} -extern "C" int64_t PjRtDeviceGetPeakBytesInUse(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.peak_bytes_in_use; -} -extern "C" int64_t PjRtDeviceGetLargestAllocSize(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.largest_alloc_size; -} -extern "C" int64_t PjRtDeviceGetBytesLimit(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.bytes_limit.value_or(std::numeric_limits::min()); -} -extern "C" int64_t PjRtDeviceGetBytesReserved(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.bytes_reserved; -} -extern "C" int64_t PjRtDeviceGetPeakBytesReserved(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.peak_bytes_reserved; -} -extern "C" int64_t PjRtDeviceGetBytesReservableLimit(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.bytes_reservable_limit.value_or( - std::numeric_limits::min()); -} -extern "C" int64_t PjRtDeviceGetLargestFreeBlockBytes(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.largest_free_block_bytes; -} -extern "C" int64_t PjRtDeviceGetPoolBytes(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.pool_bytes.value_or(std::numeric_limits::min()); -} -extern "C" int64_t PjRtDeviceGetPeakPoolBytes(PjRtDevice *device) { - auto stats = MyValueOrThrow(device->GetAllocatorStats()); - return stats.peak_pool_bytes.value_or(std::numeric_limits::min()); + int64_t optnull = std::numeric_limits::min(); + + jlstats->num_allocs = stats.num_allocs; + jlstats->bytes_in_use = stats.bytes_in_use; + jlstats->peak_bytes_in_use = stats.peak_bytes_in_use; + jlstats->largest_alloc_size = stats.largest_alloc_size; + jlstats->bytes_limit = stats.bytes_limit.value_or(optnull); + jlstats->bytes_reserved = stats.bytes_reserved; + jlstats->peak_bytes_reserved = stats.peak_bytes_reserved; + jlstats->bytes_reservable_limit = + stats.bytes_reservable_limit.value_or(optnull); + jlstats->largest_free_block_bytes = stats.largest_free_block_bytes; + jlstats->pool_bytes = stats.pool_bytes.value_or(optnull); + jlstats->peak_pool_bytes = stats.peak_pool_bytes.value_or(optnull); } extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; } diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 782afa4ad9..44424a50fc 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -411,17 +411,7 @@ cc_library( "-Wl,-exported_symbol,_ClientProcessIndex", "-Wl,-exported_symbol,_ClientGetDevice", "-Wl,-exported_symbol,_ClientGetAddressableDevice", -"-Wl,-exported_symbol,_PjRtDeviceGetNumAllocs", -"-Wl,-exported_symbol,_PjRtDeviceGetBytesInUse", -"-Wl,-exported_symbol,_PjRtDeviceGetPeakBytesInUse", -"-Wl,-exported_symbol,_PjRtDeviceGetLargestAllocSize", -"-Wl,-exported_symbol,_PjRtDeviceGetBytesLimit", -"-Wl,-exported_symbol,_PjRtDeviceGetBytesReserved", -"-Wl,-exported_symbol,_PjRtDeviceGetPeakBytesReserved", -"-Wl,-exported_symbol,_PjRtDeviceGetBytesReservableLimit", -"-Wl,-exported_symbol,_PjRtDeviceGetLargestFreeBlockBytes", -"-Wl,-exported_symbol,_PjRtDeviceGetPoolBytes", -"-Wl,-exported_symbol,_PjRtDeviceGetPeakPoolBytes", +"-Wl,-exported_symbol,_PjRtDeviceGetAllocatorStats", "-Wl,-exported_symbol,_ExecutableFree", "-Wl,-exported_symbol,_BufferToDevice", "-Wl,-exported_symbol,_BufferToClient", diff --git a/src/XLA.jl b/src/XLA.jl index d29e7bbfab..37a3b3d368 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -231,6 +231,21 @@ function client(device::Device) end end +# To keep in sync with JLAllocatorStats in ReactantExtra/API.cpp +struct JLAllocatorStats + num_allocs::Int64 + bytes_in_use::Int64 + peak_bytes_in_use::Int64 + largest_alloc_size::Int64 + bytes_limit::Int64 + bytes_reserved::Int64 + peak_bytes_reserved::Int64 + bytes_reservable_limit::Int64 + largest_free_block_bytes::Int64 + pool_bytes::Int64 + peak_pool_bytes::Int64 +end + struct AllocatorStats num_allocs::Int64 bytes_in_use::Int64 @@ -245,34 +260,23 @@ struct AllocatorStats peak_pool_bytes::Union{Nothing,Int64} end -function allocatorstats(device::Device) - num_allocs = @ccall MLIR.API.mlir_c.PjRtDeviceGetNumAllocs(device.device::Ptr{Cvoid})::Int64 - if num_allocs == typemin(Int64) - return nothing - end - - bytes_in_use = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesInUse(device.device::Ptr{Cvoid})::Int64 - peak_bytes_in_use = @ccall MLIR.API.mlir_c.PjRtDeviceGetPeakBytesInUse(device.device::Ptr{Cvoid})::Int64 - largest_alloc_size = @ccall MLIR.API.mlir_c.PjRtDeviceGetLargestAllocSize(device.device::Ptr{Cvoid})::Int64 - bytes_limit = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesLimit(device.device::Ptr{Cvoid})::Int64 - bytes_reserved = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesReserved(device.device::Ptr{Cvoid})::Int64 - peak_bytes_reserved = @ccall MLIR.API.mlir_c.PjRtDeviceGetPeakBytesReserved(device.device::Ptr{Cvoid})::Int64 - bytes_reservable_limit = @ccall MLIR.API.mlir_c.PjRtDeviceGetBytesReservableLimit(device.device::Ptr{Cvoid})::Int64 - largest_free_block_bytes = @ccall MLIR.API.mlir_c.PjRtDeviceGetLargestFreeBlockBytes(device.device::Ptr{Cvoid})::Int64 - pool_bytes = @ccall MLIR.API.mlir_c.PjRtDeviceGetPoolBytes(device.device::Ptr{Cvoid})::Int64 - peak_pool_bytes = @ccall MLIR.API.mlir_c.PjRtDeviceGetPeakPoolBytes(device.device::Ptr{Cvoid})::Int64 +function allocatorstats(device::Device=ClientGetDevice(default_backend[], default_device_idx[])) + ref = Ref{JLAllocatorStats}() + @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats(device.device::Ptr{Cvoid}, ref::Ptr{Cvoid})::Cvoid + stats = ref[] + nullopt = typemin(Int64) AllocatorStats( - num_allocs, - peak_bytes_in_use, - largest_alloc_size, - bytes_limit == typemin(Int64) ? Nothing : bytes_limit, - bytes_reserved, - peak_bytes_reserved, - bytes_reservable_limit == typemin(Int64) ? Nothing : bytes_reservable_limit, - largest_free_block_bytes, - pool_bytes == typemin(Int64) ? Nothing : pool_bytes, - peak_pool_bytes == typemin(Int64) ? Nothing : peak_pool_bytes, + stats.num_allocs, + stats.peak_bytes_in_use, + stats.largest_alloc_size, + stats.bytes_limit == nullopt ? nothing : stats.bytes_limit, + stats.bytes_reserved, + stats.peak_bytes_reserved, + stats.bytes_reservable_limit == nullopt ? nothing : stats.bytes_reservable_limit, + stats.largest_free_block_bytes, + stats.pool_bytes == nullopt ? nothing : stats.pool_bytes, + stats.peak_pool_bytes == nullopt ? nothing : stats.peak_pool_bytes, ) end From c8349f5fbcee78e7439c0768b28a865ba74c53a9 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 13 Jan 2025 17:25:31 +0100 Subject: [PATCH 5/6] format --- src/XLA.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/XLA.jl b/src/XLA.jl index 37a3b3d368..0f61feb3ee 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -260,13 +260,17 @@ struct AllocatorStats peak_pool_bytes::Union{Nothing,Int64} end -function allocatorstats(device::Device=ClientGetDevice(default_backend[], default_device_idx[])) +function allocatorstats( + device::Device=ClientGetDevice(default_backend[], default_device_idx[]) +) ref = Ref{JLAllocatorStats}() - @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats(device.device::Ptr{Cvoid}, ref::Ptr{Cvoid})::Cvoid + @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats( + device.device::Ptr{Cvoid}, ref::Ptr{Cvoid} + )::Cvoid stats = ref[] nullopt = typemin(Int64) - AllocatorStats( + return AllocatorStats( stats.num_allocs, stats.peak_bytes_in_use, stats.largest_alloc_size, From 65d0d119f29398b077631ea95fb1b174a17fdb8b Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 13 Jan 2025 17:36:37 +0100 Subject: [PATCH 6/6] fixup --- src/XLA.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/XLA.jl b/src/XLA.jl index 0f61feb3ee..29a77cfa10 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -272,6 +272,7 @@ function allocatorstats( nullopt = typemin(Int64) return AllocatorStats( stats.num_allocs, + stats.bytes_in_use, stats.peak_bytes_in_use, stats.largest_alloc_size, stats.bytes_limit == nullopt ? nothing : stats.bytes_limit,