Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 92 additions & 2 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const _memory_tracker = Dict{Symbol, Dict{UInt, Tuple{Int, Type, Tuple}}}()

function ConcreteRNumber{T}(
data::T2;
client::XLA.Client=XLA.default_backend[],
Expand Down Expand Up @@ -66,9 +68,32 @@ function ConcreteRArray(
device::Union{Nothing,XLA.Device}=nothing,
) where {T,N}
device = device === nothing ? XLA.ClientGetDevice(client, idx) : device
return ConcreteRArray{T,N}(
XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing), size(data)

# Create array
arr = ConcreteRArray{T,N}(
XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing),
size(data)
)

# Track memory allocation with additional info
client_sym = Symbol(client)
if !haskey(_memory_tracker, client_sym)
_memory_tracker[client_sym] = Dict{UInt, Tuple{Int, Type, Tuple}}()
end
_memory_tracker[client_sym][objectid(arr)] = (
sizeof(T) * prod(size(data)), # size in bytes
T, # element type
size(data) # shape
)

# Add finalizer
finalizer(arr) do x
if haskey(_memory_tracker, client_sym)
delete!(_memory_tracker[client_sym], objectid(x))
end
end

return arr
end

ConcreteRArray(x::AnyConcreteRArray) = ConcreteRArray{eltype(x),ndims(x)}(x)
Expand Down Expand Up @@ -374,3 +399,68 @@ function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N}
fn(a, val, idxs...)
return a
end

"""
get_memory_allocated(; client=nothing) -> Int

Returns the amount of device memory currently allocated for RArrays in bytes.
If a client is specified, only the memory usage on that client is returned.
"""
function get_memory_allocated(; client=nothing)
if isnothing(client)
return sum(sum(first(x) for x in values(dict); init=0) for dict in values(_memory_tracker); init=0)
else
client_sym = Symbol(client)
dict = get(_memory_tracker, client_sym, Dict{UInt, Tuple{Int, Type, Tuple}}())
return sum(first(x) for x in values(dict); init=0)
end
end

# Optional: Add a convenience function for GB units
get_memory_allocated_gb(; client=nothing) = get_memory_allocated(; client) / 1024^3

"""
get_largest_arrays(k::Int=5; client=nothing) -> Vector{NamedTuple}

Return information about the k arrays occupying the most video memory, including their size, shape, and memory usage. If a client is specified, only the arrays on that client will be displayed.

Return format: Vector{(shape=..., size_bytes=..., size_mb=...)}
"""
function get_largest_arrays(k::Int=5; client=nothing)
arrays_info = []

clients_to_check = if isnothing(client)
keys(_memory_tracker)
else
[Symbol(client)]
end

for client_sym in clients_to_check
client_dict = get(_memory_tracker, client_sym, Dict{UInt, Tuple{Int, Type, Tuple}}())
for (obj_id, (size_bytes, type, shape)) in client_dict
push!(arrays_info, (
shape = shape,
type = type,
size_bytes = size_bytes,
size_mb = round(size_bytes / 1024^2, digits=2),
client = client_sym
))
end
end

sort!(arrays_info, by=x->x.size_bytes, rev=true)
return first(arrays_info, k)
end

"""
print_largest_arrays(k::Int=5; client=nothing)

Print the information of the k arrays occupying the most video memory.
"""
function print_largest_arrays(k::Int=5; client=nothing)
arrays = get_largest_arrays(k; client)
println("Top $k arrays by memory usage:")
for (i, arr) in enumerate(arrays)
println("$i. Type: $(arr.type), Shape: $(arr.shape), Size: $(arr.size_mb) MB, Client: $(arr.client)")
end
end
1 change: 1 addition & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ end

using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace
export get_memory_allocated, get_memory_allocated_gb, get_largest_arrays, print_largest_arrays

const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

Expand Down
Loading