diff --git a/cpp/src/arrow/memory_pool-test.cc b/cpp/src/arrow/memory_pool-test.cc index c5e3ef295a7f6..8915708ea2319 100644 --- a/cpp/src/arrow/memory_pool-test.cc +++ b/cpp/src/arrow/memory_pool-test.cc @@ -91,4 +91,25 @@ TEST(LoggingMemoryPool, Logging) { ASSERT_EQ(200, pool->max_memory()); } + +TEST(ProxyMemoryPool, Logging) { + MemoryPool* pool = default_memory_pool(); + + ProxyMemoryPool pp(pool); + + uint8_t* data; + ASSERT_OK(pool->Allocate(100, &data)); + + uint8_t* data2; + ASSERT_OK(pp.Allocate(300, &data2)); + + ASSERT_EQ(400, pool->bytes_allocated()); + ASSERT_EQ(300, pp.bytes_allocated()); + + pool->Free(data, 100); + pp.Free(data2, 300); + + ASSERT_EQ(0, pool->bytes_allocated()); + ASSERT_EQ(0, pp.bytes_allocated()); +} } // namespace arrow diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc index dedab7ea7291d..34bd600e83f44 100644 --- a/cpp/src/arrow/memory_pool.cc +++ b/cpp/src/arrow/memory_pool.cc @@ -201,4 +201,40 @@ int64_t LoggingMemoryPool::max_memory() const { std::cout << "max_memory: " << mem << std::endl; return mem; } + +ProxyMemoryPool::ProxyMemoryPool(MemoryPool* pool) : pool_(pool) {} + +Status ProxyMemoryPool::Allocate(int64_t size, uint8_t** out) { + RETURN_NOT_OK(pool_->Allocate(size, out)); + bytes_allocated_ += size; + { + std::lock_guard guard(lock_); + if (bytes_allocated_ > max_memory_) { + max_memory_ = bytes_allocated_.load(); + } + } + return Status::OK(); +} + +Status ProxyMemoryPool::Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) { + RETURN_NOT_OK(pool_->Reallocate(old_size, new_size, ptr)); + bytes_allocated_ += new_size - old_size; + { + std::lock_guard guard(lock_); + if (bytes_allocated_ > max_memory_) { + max_memory_ = bytes_allocated_.load(); + } + } + return Status::OK(); +} + +void ProxyMemoryPool::Free(uint8_t* buffer, int64_t size) { + pool_->Free(buffer, size); + bytes_allocated_ -= size; +} + +int64_t ProxyMemoryPool::bytes_allocated() const { return bytes_allocated_.load(); } + +int64_t ProxyMemoryPool::max_memory() const { return max_memory_.load(); } + } // namespace arrow diff --git a/cpp/src/arrow/memory_pool.h b/cpp/src/arrow/memory_pool.h index 348343b54e918..de588965d9eba 100644 --- a/cpp/src/arrow/memory_pool.h +++ b/cpp/src/arrow/memory_pool.h @@ -20,6 +20,7 @@ #include #include +#include #include "arrow/util/visibility.h" @@ -86,6 +87,30 @@ class ARROW_EXPORT LoggingMemoryPool : public MemoryPool { MemoryPool* pool_; }; +/// Derived class for memory allocation. +/// +/// Tracks the number of bytes and maximum memory allocated through its direct +/// calls. Actual allocation is delegated to MemoryPool class. +class ARROW_EXPORT ProxyMemoryPool : public MemoryPool { + public: + explicit ProxyMemoryPool(MemoryPool* pool); + + Status Allocate(int64_t size, uint8_t** out) override; + Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override; + + void Free(uint8_t* buffer, int64_t size) override; + + int64_t bytes_allocated() const override; + + int64_t max_memory() const override; + + private: + mutable std::mutex lock_; + MemoryPool* pool_; + std::atomic bytes_allocated_{0}; + std::atomic max_memory_{0}; +}; + ARROW_EXPORT MemoryPool* default_memory_pool(); #ifdef ARROW_NO_DEFAULT_MEMORY_POOL diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 09f907c8a8a93..e4caf446777b9 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -90,7 +90,7 @@ def parse_version(root): from pyarrow.lib import (Buffer, ResizableBuffer, foreign_buffer, py_buffer, compress, decompress, allocate_buffer) -from pyarrow.lib import (MemoryPool, total_allocated_bytes, +from pyarrow.lib import (MemoryPool, ProxyMemoryPool, total_allocated_bytes, set_memory_pool, default_memory_pool, log_memory_allocations) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 9a0b4687e124a..660c119bda569 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -192,6 +192,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CLoggingMemoryPool" arrow::LoggingMemoryPool"(CMemoryPool): CLoggingMemoryPool(CMemoryPool*) + cdef cppclass CProxyMemoryPool" arrow::ProxyMemoryPool"(CMemoryPool): + CProxyMemoryPool(CMemoryPool*) + cdef cppclass CBuffer" arrow::Buffer": CBuffer(const uint8_t* data, int64_t size) const uint8_t* data() diff --git a/python/pyarrow/memory.pxi b/python/pyarrow/memory.pxi index 3d2601f89c8f3..1a461ef20c726 100644 --- a/python/pyarrow/memory.pxi +++ b/python/pyarrow/memory.pxi @@ -44,6 +44,19 @@ cdef class LoggingMemoryPool(MemoryPool): self.init(self.logging_pool.get()) +cdef class ProxyMemoryPool(MemoryPool): + """ + Derived MemoryPool class that tracks the number of bytes and + maximum memory allocated through its direct calls. + """ + cdef: + unique_ptr[CProxyMemoryPool] proxy_pool + + def __cinit__(self, MemoryPool pool): + self.proxy_pool.reset(new CProxyMemoryPool(pool.pool)) + self.init(self.proxy_pool.get()) + + def default_memory_pool(): cdef: MemoryPool pool = MemoryPool()