diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 20f4d6a83c..7dabbe829f 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,3 +1,5 @@ +const _memory_tracker = Dict{Symbol, Dict{UInt, Tuple{Int, Type, Tuple}}}() + function ConcreteRNumber{T}( data::T2; client::XLA.Client=XLA.default_backend[], @@ -66,9 +68,32 @@ function ConcreteRArray( device::Union{Nothing,XLA.Device}=nothing, ) where {T,N} device = device === nothing ? XLA.ClientGetDevice(client, idx) : device - return ConcreteRArray{T,N}( - XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing), size(data) + + # Create array + arr = ConcreteRArray{T,N}( + XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, data, device), nothing), + size(data) + ) + + # Track memory allocation with additional info + client_sym = Symbol(client) + if !haskey(_memory_tracker, client_sym) + _memory_tracker[client_sym] = Dict{UInt, Tuple{Int, Type, Tuple}}() + end + _memory_tracker[client_sym][objectid(arr)] = ( + sizeof(T) * prod(size(data)), # size in bytes + T, # element type + size(data) # shape ) + + # Add finalizer + finalizer(arr) do x + if haskey(_memory_tracker, client_sym) + delete!(_memory_tracker[client_sym], objectid(x)) + end + end + + return arr end ConcreteRArray(x::AnyConcreteRArray) = ConcreteRArray{eltype(x),ndims(x)}(x) @@ -374,3 +399,68 @@ function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N} fn(a, val, idxs...) return a end + +""" + get_memory_allocated(; client=nothing) -> Int + +Returns the amount of device memory currently allocated for RArrays in bytes. +If a client is specified, only the memory usage on that client is returned. +""" +function get_memory_allocated(; client=nothing) + if isnothing(client) + return sum(sum(first(x) for x in values(dict); init=0) for dict in values(_memory_tracker); init=0) + else + client_sym = Symbol(client) + dict = get(_memory_tracker, client_sym, Dict{UInt, Tuple{Int, Type, Tuple}}()) + return sum(first(x) for x in values(dict); init=0) + end +end + +# Optional: Add a convenience function for GB units +get_memory_allocated_gb(; client=nothing) = get_memory_allocated(; client) / 1024^3 + +""" + get_largest_arrays(k::Int=5; client=nothing) -> Vector{NamedTuple} + +Return information about the k arrays occupying the most video memory, including their size, shape, and memory usage. If a client is specified, only the arrays on that client will be displayed. + +Return format: Vector{(shape=..., size_bytes=..., size_mb=...)} +""" +function get_largest_arrays(k::Int=5; client=nothing) + arrays_info = [] + + clients_to_check = if isnothing(client) + keys(_memory_tracker) + else + [Symbol(client)] + end + + for client_sym in clients_to_check + client_dict = get(_memory_tracker, client_sym, Dict{UInt, Tuple{Int, Type, Tuple}}()) + for (obj_id, (size_bytes, type, shape)) in client_dict + push!(arrays_info, ( + shape = shape, + type = type, + size_bytes = size_bytes, + size_mb = round(size_bytes / 1024^2, digits=2), + client = client_sym + )) + end + end + + sort!(arrays_info, by=x->x.size_bytes, rev=true) + return first(arrays_info, k) +end + +""" + print_largest_arrays(k::Int=5; client=nothing) + +Print the information of the k arrays occupying the most video memory. +""" +function print_largest_arrays(k::Int=5; client=nothing) + arrays = get_largest_arrays(k; client) + println("Top $k arrays by memory usage:") + for (i, arr) in enumerate(arrays) + println("$i. Type: $(arr.type), Shape: $(arr.shape), Size: $(arr.size_mb) MB, Client: $(arr.client)") + end +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 6396bd65c8..0f724988d0 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -230,6 +230,7 @@ end using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace +export get_memory_allocated, get_memory_allocated_gb, get_largest_arrays, print_largest_arrays const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()