diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f7ada88b22..c5dee2cade 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -90,8 +90,8 @@ #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" -#include "triton/Dialect/Triton/IR/Dialect.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" using namespace mlir; using namespace llvm; @@ -325,6 +325,40 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client, client->LookupAddressableDevice(PjRtLocalDeviceId(device_id))); } +// To keep in sync with JLAllocatorStats in src/XLA.jl +struct JLAllocatorStats { + int64_t num_allocs; + int64_t bytes_in_use; + int64_t peak_bytes_in_use; + int64_t largest_alloc_size; + int64_t bytes_limit; + int64_t bytes_reserved; + int64_t peak_bytes_reserved; + int64_t bytes_reservable_limit; + int64_t largest_free_block_bytes; + int64_t pool_bytes; + int64_t peak_pool_bytes; +}; + +extern "C" void PjRtDeviceGetAllocatorStats(PjRtDevice *device, + JLAllocatorStats *jlstats) { + auto stats = MyValueOrThrow(device->GetAllocatorStats()); + int64_t optnull = std::numeric_limits<int64_t>::min(); + + jlstats->num_allocs = stats.num_allocs; + jlstats->bytes_in_use = stats.bytes_in_use; + jlstats->peak_bytes_in_use = stats.peak_bytes_in_use; + jlstats->largest_alloc_size = stats.largest_alloc_size; + jlstats->bytes_limit = stats.bytes_limit.value_or(optnull); + jlstats->bytes_reserved = stats.bytes_reserved; + jlstats->peak_bytes_reserved = stats.peak_bytes_reserved; + jlstats->bytes_reservable_limit = + stats.bytes_reservable_limit.value_or(optnull); + jlstats->largest_free_block_bytes = stats.largest_free_block_bytes; + jlstats->pool_bytes = stats.pool_bytes.value_or(optnull); + jlstats->peak_pool_bytes = stats.peak_pool_bytes.value_or(optnull); +} + extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; } extern "C" PjRtDevice 
*BufferToDevice(PjRtBuffer *Buffer) { @@ -443,7 +477,7 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) { if (ReactantThrowError) { llvm::errs() << lmod << "\n"; ReactantThrowError(err_str.c_str()); - return wrap((mlir::ModuleOp)nullptr); + return wrap((mlir::ModuleOp) nullptr); } } mlir::MLIRContext &context = *unwrap(cctx); @@ -642,8 +676,8 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op, if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) { if (func.isExternal()) { - shouldRemove = true; - return success(); + shouldRemove = true; + return success(); } } @@ -678,13 +712,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, } bool shouldRemove = false; - if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) { + if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, + shouldRemove))) { assert(0 && "failed to update all uses"); } if (shouldRemove) - op.erase(); + op.erase(); else - SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); + SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); } prevMod.getBody()->getOperations().splice( prevMod.getBody()->getOperations().end(), diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index a86f5a3c9d..44424a50fc 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -411,6 +411,7 @@ cc_library( "-Wl,-exported_symbol,_ClientProcessIndex", "-Wl,-exported_symbol,_ClientGetDevice", "-Wl,-exported_symbol,_ClientGetAddressableDevice", +"-Wl,-exported_symbol,_PjRtDeviceGetAllocatorStats", "-Wl,-exported_symbol,_ExecutableFree", "-Wl,-exported_symbol,_BufferToDevice", "-Wl,-exported_symbol,_BufferToClient", diff --git a/src/XLA.jl b/src/XLA.jl index c7a526f2bc..29a77cfa10 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -231,6 +231,60 @@ function client(device::Device) end end +# To keep in sync with JLAllocatorStats in 
ReactantExtra/API.cpp +struct JLAllocatorStats + num_allocs::Int64 + bytes_in_use::Int64 + peak_bytes_in_use::Int64 + largest_alloc_size::Int64 + bytes_limit::Int64 + bytes_reserved::Int64 + peak_bytes_reserved::Int64 + bytes_reservable_limit::Int64 + largest_free_block_bytes::Int64 + pool_bytes::Int64 + peak_pool_bytes::Int64 +end + +struct AllocatorStats + num_allocs::Int64 + bytes_in_use::Int64 + peak_bytes_in_use::Int64 + largest_alloc_size::Int64 + bytes_limit::Union{Nothing,Int64} + bytes_reserved::Int64 + peak_bytes_reserved::Int64 + bytes_reservable_limit::Union{Nothing,Int64} + largest_free_block_bytes::Int64 + pool_bytes::Union{Nothing,Int64} + peak_pool_bytes::Union{Nothing,Int64} +end + +function allocatorstats( + device::Device=ClientGetDevice(default_backend[], default_device_idx[]) +) + ref = Ref{JLAllocatorStats}() + @ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats( + device.device::Ptr{Cvoid}, ref::Ptr{Cvoid} + )::Cvoid + stats = ref[] + + nullopt = typemin(Int64) + return AllocatorStats( + stats.num_allocs, + stats.bytes_in_use, + stats.peak_bytes_in_use, + stats.largest_alloc_size, + stats.bytes_limit == nullopt ? nothing : stats.bytes_limit, + stats.bytes_reserved, + stats.peak_bytes_reserved, + stats.bytes_reservable_limit == nullopt ? nothing : stats.bytes_reservable_limit, + stats.largest_free_block_bytes, + stats.pool_bytes == nullopt ? nothing : stats.pool_bytes, + stats.peak_pool_bytes == nullopt ? nothing : stats.peak_pool_bytes, + ) +end + # https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29 @inline primitive_type(::Type{Bool}) = 1