Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@
#include "xla/python/pjrt_ifrt/pjrt_topology.h"
#include "xla/python/pjrt_ifrt/pjrt_tuple.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;
using namespace llvm;
Expand Down Expand Up @@ -325,6 +325,40 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client,
client->LookupAddressableDevice(PjRtLocalDeviceId(device_id)));
}

// To keep in sync with JLAllocatorStats in src/XLA.jl
struct JLAllocatorStats {
int64_t num_allocs;
int64_t bytes_in_use;
int64_t peak_bytes_in_use;
int64_t largest_alloc_size;
int64_t bytes_limit;
int64_t bytes_reserved;
int64_t peak_bytes_reserved;
int64_t bytes_reservable_limit;
int64_t largest_free_block_bytes;
int64_t pool_bytes;
int64_t peak_pool_bytes;
};

extern "C" void PjRtDeviceGetAllocatorStats(PjRtDevice *device,
JLAllocatorStats *jlstats) {
auto stats = MyValueOrThrow(device->GetAllocatorStats());
int64_t optnull = std::numeric_limits<int64_t>::min();

jlstats->num_allocs = stats.num_allocs;
jlstats->bytes_in_use = stats.bytes_in_use;
jlstats->peak_bytes_in_use = stats.peak_bytes_in_use;
jlstats->largest_alloc_size = stats.largest_alloc_size;
jlstats->bytes_limit = stats.bytes_limit.value_or(optnull);
jlstats->bytes_reserved = stats.bytes_reserved;
jlstats->peak_bytes_reserved = stats.peak_bytes_reserved;
jlstats->bytes_reservable_limit =
stats.bytes_reservable_limit.value_or(optnull);
jlstats->largest_free_block_bytes = stats.largest_free_block_bytes;
jlstats->pool_bytes = stats.pool_bytes.value_or(optnull);
jlstats->peak_pool_bytes = stats.peak_pool_bytes.value_or(optnull);
}

Copy link
Contributor

@glou-nes glou-nes Jan 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's possible to replace this in order to limit GetAllocatorStats calls:

extern "C" tsl::AllocatorStats getAllocatorStats(PjRtDevice *Device) {
  auto stats = Device->GetAllocatorStats();
  return stats.value(); //probably use a tsl::AllocatorStats* to deal with unsupported devices.
}

Optional are represented with a 8 bytes discriminant. So a Tuple can be used here for instance.

struct AllocatorStats
    num_allocs::Int64
    bytes_in_use::Int64
    peak_bytes_in_use::Int64
    largest_alloc_size::Int64
    bytes_limit::Tuple{Int64,Int64}
    bytes_reserved::Int64
    peak_bytes_reserved::Int64
    bytes_reservable_limit::Tuple{Int64,Int64}
    largest_free_block_bytes::Int64
    pool_bytes::Tuple{Int64,Int64}
    peak_pool_bytes::Tuple{Int64,Int64}
end

@ccall mlir_c.getAllocatorStats(device::XLA.Device)::getAllocatorStats

I cannot test this locally, I cannot build ReactantExtra with cuda. And cpu device doesn't use AllocatorStats at all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m okay either way, but if you want to go this route, have it return the result via a pointer in the first arg and have the function return void (to avoid ABI issues with struct returning functions)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should the Tuple{Int,Int} be interpreted for std::optional<int64_t> ?

Copy link
Contributor

@glou-nes glou-nes Jan 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first one is the value, the second is the discriminant:

get_value(t::Tuple{Int,Int}) = t[2] % 2 == 1 ? t[1] : nothing

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm slightly concerned with directly doing this (relying on the ABI of std optional), could we instead make our own struct which explicitly contains the ints (which we unwrap in c++)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this kind of things, i use a std::tuple. you could use a std::tuple<bool,int> where the first value says if the value is valid or not.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using typemin(Int64) as a sentinel like I did here should be ok and C friendly ?

Copy link
Collaborator

@mofeing mofeing Jan 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@glou-nes reading it more closely, don't rely on passing that tsl::AllocatorStats directly. you have no guarantee it will work.

you can instead:

  1. save it into a pointer (e.g. using a copy/move constructor with new) and pass the pointer
  2. create your own C-struct and pass that, without C++ objects in it

Optional are represented with a 8 bytes discriminant. So a Tuple can be used here for instance.

i had problems with this kind of tricks even with the most naive C++ classes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure! Thanks for the feedback, seems fair to only consider C ABI. My solution is hacky anyway. That kind of bug are a pain to debug, I got several with Ops.convolution. 2) is probably the way here.

Copy link
Collaborator Author

@Pangoraw Pangoraw Jan 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the code to make a single call to GetAllocatorStats. But I could not try on a CUDA machine so far (only by using a manually constructed tsl::AllocatorStats).

extern "C" void ExecutableFree(xla::PjRtLoadedExecutable *exec) { delete exec; }

extern "C" PjRtDevice *BufferToDevice(PjRtBuffer *Buffer) {
Expand Down Expand Up @@ -443,7 +477,7 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
if (ReactantThrowError) {
llvm::errs() << lmod << "\n";
ReactantThrowError(err_str.c_str());
return wrap((mlir::ModuleOp)nullptr);
return wrap((mlir::ModuleOp) nullptr);
}
}
mlir::MLIRContext &context = *unwrap(cctx);
Expand Down Expand Up @@ -642,8 +676,8 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,

if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) {
if (func.isExternal()) {
shouldRemove = true;
return success();
shouldRemove = true;
return success();
}
}

Expand Down Expand Up @@ -678,13 +712,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
}

bool shouldRemove = false;
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) {
if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID,
shouldRemove))) {
assert(0 && "failed to update all uses");
}
if (shouldRemove)
op.erase();
op.erase();
else
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
}
prevMod.getBody()->getOperations().splice(
prevMod.getBody()->getOperations().end(),
Expand Down
1 change: 1 addition & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ cc_library(
"-Wl,-exported_symbol,_ClientProcessIndex",
"-Wl,-exported_symbol,_ClientGetDevice",
"-Wl,-exported_symbol,_ClientGetAddressableDevice",
"-Wl,-exported_symbol,_PjRtDeviceGetAllocatorStats",
"-Wl,-exported_symbol,_ExecutableFree",
"-Wl,-exported_symbol,_BufferToDevice",
"-Wl,-exported_symbol,_BufferToClient",
Expand Down
54 changes: 54 additions & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,60 @@ function client(device::Device)
end
end

# To keep in sync with JLAllocatorStats in ReactantExtra/API.cpp
struct JLAllocatorStats
num_allocs::Int64
bytes_in_use::Int64
peak_bytes_in_use::Int64
largest_alloc_size::Int64
bytes_limit::Int64
bytes_reserved::Int64
peak_bytes_reserved::Int64
bytes_reservable_limit::Int64
largest_free_block_bytes::Int64
pool_bytes::Int64
peak_pool_bytes::Int64
end

struct AllocatorStats
num_allocs::Int64
bytes_in_use::Int64
peak_bytes_in_use::Int64
largest_alloc_size::Int64
bytes_limit::Union{Nothing,Int64}
bytes_reserved::Int64
peak_bytes_reserved::Int64
bytes_reservable_limit::Union{Nothing,Int64}
largest_free_block_bytes::Int64
pool_bytes::Union{Nothing,Int64}
peak_pool_bytes::Union{Nothing,Int64}
end

function allocatorstats(
device::Device=ClientGetDevice(default_backend[], default_device_idx[])
)
ref = Ref{JLAllocatorStats}()
@ccall MLIR.API.mlir_c.PjRtDeviceGetAllocatorStats(
device.device::Ptr{Cvoid}, ref::Ptr{Cvoid}
)::Cvoid
stats = ref[]

nullopt = typemin(Int64)
return AllocatorStats(
stats.num_allocs,
stats.bytes_in_use,
stats.peak_bytes_in_use,
stats.largest_alloc_size,
stats.bytes_limit == nullopt ? nothing : stats.bytes_limit,
stats.bytes_reserved,
stats.peak_bytes_reserved,
stats.bytes_reservable_limit == nullopt ? nothing : stats.bytes_reservable_limit,
stats.largest_free_block_bytes,
stats.pool_bytes == nullopt ? nothing : stats.pool_bytes,
stats.peak_pool_bytes == nullopt ? nothing : stats.peak_pool_bytes,
)
end

# https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29
@inline primitive_type(::Type{Bool}) = 1

Expand Down
Loading