diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b5a724fc05a..55c31eb1698 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -3304,6 +3304,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/lower_functional_ops.h", "common_runtime/lower_while_op.h", "common_runtime/memory_planner.h", + "common_runtime/gpu_memory_planner.h", "common_runtime/memory_types.h", "common_runtime/metrics.h", "common_runtime/mkl_cpu_allocator.h", @@ -3325,9 +3326,11 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/session_factory.h", "common_runtime/simple_propagator_state.h", "common_runtime/single_threaded_cpu_device.h", + "common_runtime/size_class.h", "common_runtime/stats_publisher_interface.h", "common_runtime/step_stats_collector.h", "common_runtime/tensorpool_allocator.h", + "common_runtime/gpu_tensorpool_allocator.h", "common_runtime/threadpool_device.h", "common_runtime/process_state.h", "common_runtime/pool_allocator.h", @@ -3374,6 +3377,7 @@ tf_cuda_library( "common_runtime/lower_if_op.cc", "common_runtime/lower_while_op.cc", "common_runtime/memory_planner.cc", + "common_runtime/gpu_memory_planner.cc", "common_runtime/memory_types.cc", "common_runtime/metrics.cc", "common_runtime/mkl_cpu_allocator.cc", @@ -3404,6 +3408,7 @@ tf_cuda_library( "common_runtime/stats_publisher_interface.cc", "common_runtime/step_stats_collector.cc", "common_runtime/tensorpool_allocator.cc", + "common_runtime/gpu_tensorpool_allocator.cc", "common_runtime/threadpool_device.cc", "common_runtime/threadpool_device_factory.cc", "graph/gradients.cc", @@ -4670,6 +4675,25 @@ tf_cc_test_gpu( ], ) +tf_cc_test_gpu( + name = "gpu_tensorpool_allocator_test", + size = "medium", + srcs = ["common_runtime/gpu_tensorpool_allocator_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":core", + ":core_cpu", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":test", + ":test_main", + ":testlib", + ], +) + tf_cuda_cc_test( name = "gpu_device_unified_memory_test", size = "small", diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 939b12c1841..e4c3517103c 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include #include +#include +#include #include "absl/container/flat_hash_set.h" #include "tensorflow/core/common_runtime/collective_executor_mgr.h" @@ -32,6 +34,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/memory_planner.h" +#include "tensorflow/core/common_runtime/gpu_memory_planner.h" #include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_util.h" @@ -770,6 +773,21 @@ Status DirectSession::RunInternal( return Status::OK(); } +bool DirectSession::EnableTensorPoolTracking(ExecutorsAndKeys* executors_and_keys) { + static std::unordered_map has_training_graph; + if (has_training_graph.find(executors_and_keys) == has_training_graph.end()) { + for (const PerPartitionExecutorsAndLib& partition : + executors_and_keys->items) { + if (partition.graph->IsTrainingGraph()) { + has_training_graph[executors_and_keys] = true; + return true; + } + } + has_training_graph[executors_and_keys] = false; + } + return has_training_graph[executors_and_keys]; +} + Status DirectSession::Run(const RunOptions& run_options, const NamedTensorList& inputs, const std::vector& output_names, @@ -781,6 +799,7 @@ Status DirectSession::Run(const RunOptions& run_options, direct_session_runs->GetCell()->IncrementBy(1); ScopedMemoryCollector scoped_memory_collector; + std::unique_ptr scoped_memory_collector_gpu_ptr; // Extract the inputs names for this run of the session. std::vector input_tensor_names; @@ -804,6 +823,9 @@ Status DirectSession::Run(const RunOptions& run_options, { mutex_lock l(collective_graph_key_lock_); collective_graph_key_ = executors_and_keys->collective_graph_key; + if (EnableTensorPoolTracking(executors_and_keys)) { + scoped_memory_collector_gpu_ptr.reset(new ScopedMemoryCollectorGPU); + } } // Configure a call frame for the step, which we use to feed and diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index fcd08e9f38c..f95d56ffc0a 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -253,6 +253,9 @@ class DirectSession : public Session { RunMetadata* run_metadata, const thread::ThreadPoolOptions& threadpool_options); + // Returns whether enable tracking of tensorpool allocator + bool EnableTensorPoolTracking(ExecutorsAndKeys* executors_and_keys); + // Returns whether inter-op execution uses a global pool or the input // `run_options` requests being run on inter_op_thread_pool = 0 in case // multiple pools are configured. diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index cead8c1612a..1f66006d2ad 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -209,7 +209,7 @@ class EigenGpuStreamDevice : public ::Eigen::StreamInterface { LogMemory::RecordRawDeallocation(data->operation_, data->step_id_, data->address_, data->allocator_, false); } - data->allocator_->DeallocateRaw(data->address_); + data->allocator_->DeallocateRawAsync(data->address_); delete data; } diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc index d9e80c28232..1ae1ccec178 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/gpu/gpu_vmem_allocator.h" #include "tensorflow/core/common_runtime/pool_allocator.h" #include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/common_runtime/gpu_tensorpool_allocator.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" @@ -62,6 +63,12 @@ bool useCudaMallocAsyncAllocator() { std::strcmp(debug_allocator_str, "cuda_malloc_async") == 0; } +bool useTensorPoolAllocator() { + const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR"); + return debug_allocator_str != nullptr && + std::strcmp(debug_allocator_str, "tensorpool") == 0; +} + } // namespace /*static*/ GPUProcessState* GPUProcessState::singleton(GPUProcessState* ps) { @@ -122,21 +129,35 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options, (options.per_process_gpu_memory_fraction() > 1.0 || options.experimental().use_unified_memory()), gpu_visitors_[bus_id], {}); - GPUBFCAllocator* gpu_bfc_allocator = - new GPUBFCAllocator(sub_allocator, total_bytes, options, + Allocator* gpu_allocator = nullptr; + GPUBFCAllocator* gpu_bfc_allocator = nullptr; + if (useTensorPoolAllocator()) { + gpu_allocator = + new GPUTensorPoolAllocator(sub_allocator, + strings::StrCat("GPU_", tf_gpu_id.value(), "_tensorpool"), + total_bytes); + } else { + gpu_bfc_allocator = + new GPUBFCAllocator(sub_allocator, total_bytes, options, strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc")); - Allocator* gpu_allocator = gpu_bfc_allocator; - // GPUVMemAllocator will allocate host memory as backup after running out of - // gpu device memory to avoid OOM failures - gpu_allocator = maybe_create_gpu_vmem_allocator(gpu_allocator, - bus_id, - platform_gpu_id, - tf_gpu_id.value(), - stream_exec); + gpu_allocator = gpu_bfc_allocator; + // GPUVMemAllocator will allocate host memory as backup after running out of + // gpu device memory to avoid OOM failures + gpu_allocator = maybe_create_gpu_vmem_allocator(gpu_allocator, + bus_id, + platform_gpu_id, + tf_gpu_id.value(), + stream_exec); + } + SharedCounter* timing_counter = nullptr; if (options.experimental().timestamped_allocator()) { - timing_counter = new SharedCounter; - gpu_bfc_allocator->SetTimingCounter(timing_counter); + if (useTensorPoolAllocator()) { + LOG(WARNING) << "TensorPoolAllocator " << "don't support timestamped_allocator"; + } else { + timing_counter = new SharedCounter; + gpu_bfc_allocator->SetTimingCounter(timing_counter); + } } // If true, checks for memory overwrites by writing @@ -197,7 +218,10 @@ SharedCounter* GPUProcessState::GPUAllocatorCounter(TfGpuId tf_gpu_id) { << " but only have " << gpu_allocators_.size(); return nullptr; } - + if (useTensorPoolAllocator()) { + LOG(WARNING) << "TensorPoolAllocator " << "don't support timestamped_allocator"; + return nullptr; + } AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()]; if (allocator_parts.counter.get() == nullptr) { SharedCounter* timing_counter = new SharedCounter; diff --git a/tensorflow/core/common_runtime/gpu_memory_planner.cc b/tensorflow/core/common_runtime/gpu_memory_planner.cc new file mode 100644 index 00000000000..3794ec27cc8 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu_memory_planner.cc @@ -0,0 +1,554 @@ +#include "tensorflow/core/common_runtime/gpu_tensorpool_allocator.h" +#include "tensorflow/core/common_runtime/gpu_memory_planner.h" +#include "tensorflow/core/platform/logging.h" +#include 
"tensorflow/core/util/env_var.h" +#include +#include +#include +#include + +namespace tensorflow { + +namespace { +constexpr int64 DEFAULT_START_STATISTIC_STEP = 10; +constexpr int64 DEFAULT_STOP_STATISTIC_STEP = 110; + +bool AllocTimeCompare(AllocStats* s1, AllocStats*s2) { + return s1->begin < s2->begin; +} +} + +GPUMemoryPlanner::GPUMemoryPlanner() : + is_stats_(false), + inited_(false), + allocator_(nullptr), + thread_pool_(nullptr), + counter_(0), + start_step_(DEFAULT_START_STATISTIC_STEP), + stop_step_(DEFAULT_STOP_STATISTIC_STEP) { + InitStepInfo(); + InitPolicy(); +} + +GPUMemoryPlanner::~GPUMemoryPlanner() { +} + +void GPUMemoryPlanner::InitPolicy() { + lifetime_stats_polices_.emplace_back( + new GPULifetimePolicy(_4KB, _4KB_OFFSET, _32KB)); + lifetime_stats_polices_.emplace_back( + new GPULifetimePolicy(_8KB, _8KB_OFFSET, _32KB)); + lifetime_stats_polices_.emplace_back( + new GPULifetimePolicy(_16KB, _16KB_OFFSET, _32KB)); + small_bins_.resize(kClassNum); + for (int i = 0; i < kClassNum; ++i) { + small_bins_[i] = new GPULifetimeBin(i, kSizeClass[i]); + } +} + +void GPUMemoryPlanner::InitStepInfo() { + Status s = ReadInt64FromEnvVar("START_STATISTIC_STEP", + DEFAULT_START_STATISTIC_STEP, + &start_step_); + s = ReadInt64FromEnvVar("STOP_STATISTIC_STEP", + DEFAULT_STOP_STATISTIC_STEP, + &stop_step_); +} + +// lifetime policy +GPULifetimePolicy* GPUMemoryPlanner::BestLifetimePolicy() { + GPULifetimePolicy* best_policy = nullptr; + auto total_mem = std::numeric_limits::max(); + for (auto policy : lifetime_stats_polices_) { + auto policy_mem = policy->TotalMem(); + if (policy_mem < total_mem) { + best_policy = policy; + total_mem = policy_mem; + } + } + return best_policy; +} + +std::vector& GPUMemoryPlanner::GetSmallBins() { + return small_bins_; +} + +GPULifetimeBin* GPUMemoryPlanner::GetSmallBin(size_t size) { + return small_bins_[kSmallSizeMap.GetClass(size)]; +} + +void GPUMemoryPlanner::Reset() { + counter_ = 0; + Cleanup(); +} + +void GPUMemoryPlanner::StartCollect() { + if (is_stats_.load()) { + BestFit(); + ResetStats(); + } + + auto current = counter_.fetch_add(1); + if (current == start_step_) { + is_stats_ = true; + } else if (current == stop_step_) { + is_stats_ = false; + CollectDone(); + } + if (allocator_ != nullptr) { + allocator_->BeginStep(); + } +} + +void GPUMemoryPlanner::BestFit() { + for (auto policy : lifetime_stats_polices_) { + policy->BestFit(); + } + for (auto bin : small_bins_) { + bin->SmallFit(); + } +} + +void GPUMemoryPlanner::ResetStats() { + for (auto policy : lifetime_stats_polices_) { + policy->ResetStats(); + } + for (auto bin : small_bins_) { + bin->ResetStats(); + } + std::lock_guard l(stats_lock_); + for (auto s : alloc_stats_) { + delete s; + } + alloc_stats_.clear(); +} + +void GPUMemoryPlanner::StopCollect() { + // Make sure counter_ load is atomic. 
+} + +void GPUMemoryPlanner::CollectDone() { + Schedule([this]() { + if (allocator_ != nullptr) { + allocator_->Init(); + } + Cleanup(); + inited_ = true; + }); +} + +void GPUMemoryPlanner::Cleanup() { + for (auto policy : lifetime_stats_polices_) { + policy->Cleanup(); + } + for (auto bin : small_bins_) { + bin->Cleanup(); + } + std::lock_guard l(stats_lock_); + for (auto it : ptr_stats_) { + delete it.second; + } + ptr_stats_.clear(); + for (auto s : alloc_stats_) { + delete s; + } + alloc_stats_.clear(); +} + +void GPUMemoryPlanner::SetAllocator(GPUTensorPoolAllocator* allocator) { + allocator_ = allocator; +} + +void GPUMemoryPlanner::SetThreadPool(thread::ThreadPool* thread_pool) { + if (thread_pool_ == nullptr) { + thread_pool_ = thread_pool; + } +} +void GPUMemoryPlanner::Schedule(std::function f) { + if (thread_pool_ == nullptr) { + f(); + } else { + thread_pool_->Schedule(std::move(f)); + } +} + +void GPUMemoryPlanner::TrackAllocate(size_t alignment, size_t num_bytes, void* ptr) { + if (!is_stats_.load()) { + return; + } + + timeval tmp; + gettimeofday(&tmp, nullptr); + + auto alloc_stats = new AllocStats; + alloc_stats->begin = Timeval2Double(tmp); + alloc_stats->size = num_bytes; + { + std::lock_guard l(stats_lock_); + ptr_stats_[ptr] = alloc_stats; + } + + if (SmallAlloc(num_bytes)) { + GetSmallBin(num_bytes)->TrackAllocate(alignment); + return; + } + + for (auto lifetime_policy : lifetime_stats_polices_) { + lifetime_policy->TrackAllocate(alignment, num_bytes); + } +} + +void GPUMemoryPlanner::TrackDeallocate(void* ptr) { + if (!is_stats_.load()) { + return; + } + timeval tmp; + gettimeofday(&tmp, nullptr); + + AllocStats* alloc_stats; + { + std::lock_guard l(stats_lock_); + auto iter = ptr_stats_.find(ptr); + if (iter == ptr_stats_.end()) { + return; + } + alloc_stats = iter->second; + ptr_stats_.erase(iter); + alloc_stats_.emplace_back(alloc_stats); + } + alloc_stats->end = Timeval2Double(tmp); + + if (SmallAlloc(alloc_stats->size)) { + GetSmallBin(alloc_stats->size)->TrackDeallocate(alloc_stats); + return; + } + + for (auto lifetime_policy : lifetime_stats_polices_) { + lifetime_policy->TrackDeallocate(alloc_stats); + } +} + +GPULifetimePolicy::GPULifetimePolicy(size_t interval, + size_t interval_offset, size_t start) : + interval_(interval), interval_offset_(interval_offset), start_(start), + large_bin_index_(Index(_32MB, interval_, interval_offset_) + 1) { + auto cur = start_ + interval_; + bins_.resize(large_bin_index_); + for (auto i = 0; i < large_bin_index_; ++i) { + bins_[i] = new GPULifetimeBin(i, cur); + cur += interval_; + } +} + +void GPULifetimePolicy::TrackAllocate(size_t alignment, size_t num_bytes) { + auto index = Index(num_bytes, interval_, interval_offset_); + if (index < 0) { + LOG(ERROR) << "GPUTensorPoolAllocator Invalid Index:" << index + << ", size:" << num_bytes; + return; + } + GetBin(index)->TrackAllocate(alignment); +} + +GPULifetimeBin* GPULifetimePolicy::GetBin(size_t index) { + if (index >= large_bin_index_) { + std::lock_guard l(large_bin_lock_); + auto bin = large_bins_.find(index); + if (bin == large_bins_.end()) { + auto chunk_size = start_ + interval_ * (index + 1); + bin = large_bins_.emplace(index, new GPULifetimeBin(index, chunk_size)).first; + } + return bin->second; + } else { + return bins_[index]; + } +} + +void GPULifetimePolicy::TrackDeallocate(AllocStats* alloc_stats) { + auto index = Index(alloc_stats->size, interval_, interval_offset_); + if (index < 0) { + LOG(ERROR) << "GPUTensorPoolAllocator Invalid Index:" << index + << ", 
size:" << alloc_stats->size; + return; + } + GetBin(index)->TrackDeallocate(alloc_stats); +} + +size_t GPULifetimePolicy::TotalMem() const { + size_t total_mem = 0; + for (auto bin : bins_) { + total_mem += bin->TotalMem(); + } + { + std::lock_guard l(large_bin_lock_); + for (auto large_bin : large_bins_) { + auto bin_info = large_bin.second; + total_mem += bin_info->TotalMem(); + } + } + return total_mem; +} + +void GPULifetimePolicy::Dump() const { + // LOG(INFO_DEV) << "GPULifetimePolicy, start:" << start_ + // << ", interval:" << interval_ + // << ", Detail:"; + for (auto& b : bins_) { + b->Dump(); + } + { + std::lock_guard l(large_bin_lock_); + for (auto& large_bin : large_bins_) { + auto bin_info = large_bin.second; + bin_info->Dump(); + } + } +} + +void GPULifetimePolicy::Cleanup() { + for (auto bin : bins_) { + bin->Cleanup(); + } + { + std::lock_guard l(large_bin_lock_); + for (auto bin : large_bins_) { + auto bin_info = bin.second; + bin_info->Cleanup(); + } + } +} + +GPULifetimeBin::GPULifetimeBin(size_t bin_index, size_t chunk_size) + : bin_index_(bin_index), + chunk_size_(chunk_size), + max_alignment_(Allocator::kAllocatorAlignment) { +} + +GPULifetimeBin::~GPULifetimeBin() { +} + +void GPULifetimeBin::TrackAllocate(size_t alignment) { + std::lock_guard l(stats_lock_); + max_alignment_ = std::max(max_alignment_, alignment); +} + +void GPULifetimeBin::TrackDeallocate(AllocStats* stats) { + // multiple thread enter + std::lock_guard l(stats_lock_); + stats_.emplace_back(stats); +} + +void GPULifetimePolicy::BestFit() { + std::lock_guard l(large_bin_lock_); + for (auto it = large_bins_.rbegin(); + it != large_bins_.rend(); ++it) { + auto bin_info = it->second; + bin_info->BestFit(this); + } + for (auto it = bins_.rbegin(); it != bins_.rend(); ++it) { + (*it)->BestFit(this); + } +} + +std::vector& GPULifetimePolicy::GetBins() { + return bins_; +} + +std::map& GPULifetimePolicy::GetLargeBins() { + return large_bins_; +} + +size_t GPULifetimePolicy::Alignment() const { + return interval_; +} + +size_t GPULifetimePolicy::AlignmentOffset() const { + return interval_offset_; +} + +size_t GPULifetimePolicy::Interval() { + return interval_; +} + +void GPULifetimePolicy::ResetStats() { + { + std::lock_guard l(large_bin_lock_); + for (auto it : large_bins_) { + auto bin_info = it.second; + bin_info->ResetStats(); + } + } + for (auto bin : bins_) { + bin->ResetStats(); + } +} + +void GPULifetimeBin::Cleanup() { + for (auto block : blocks_) { + delete block; + } + blocks_.clear(); + + for (auto vblock : virtual_blocks_) { + delete vblock; + } + virtual_blocks_.clear(); + + // stats_ pointer's memory would be clear by blocks. + // protect stats_ could be touched by other thread. 
+ std::lock_guard l(stats_lock_); + stats_.clear(); +} + +void GPULifetimeBin::BestFit(GPULifetimePolicy* policy) { + std::lock_guard l(stats_lock_); + for (auto vb : virtual_blocks_) { + delete vb; + } + virtual_blocks_.clear(); + if (stats_.empty()) { + return; + } + // sort by alloc time + std::sort(stats_.begin(), stats_.end(), AllocTimeCompare); + for (auto s : stats_) { + auto block = FindBlock(s); + if (block != nullptr) { + block->Insert(s); + continue; + } + block = policy->FindBlock(s, bin_index_+1); + if (block != nullptr) { + block->Insert(s); + auto vblock = new VirtualGPUAllocBlock(block, chunk_size_); + virtual_blocks_.emplace_back(vblock); + continue; + } + block = new GPUAllocBlock(chunk_size_, bin_index_); + block->Insert(s); + blocks_.emplace_back(block); + } +} + +void GPULifetimeBin::SmallFit() { + std::lock_guard l(stats_lock_); + if (stats_.empty()) { + return; + } + for (auto s : stats_) { + auto block = FindBlock(s); + if (block != nullptr) { + block->Insert(s); + continue; + } + block = new GPUAllocBlock(chunk_size_, bin_index_); + block->Insert(s); + blocks_.emplace_back(block); + } +} + +void GPULifetimeBin::ResetStats() { + std::lock_guard l(stats_lock_); + for (auto b : blocks_) { + b->ResetStats(); + } + stats_.clear(); +} + +GPUAllocBlock* GPULifetimeBin::FindBlock(AllocStats* stats) { + for (auto block : blocks_) { + if (block->CanInsert(stats)) { + return block; + } + } + return nullptr; +} + +size_t GPULifetimeBin::BlockSize() const { + return blocks_.size(); +} + +size_t GPULifetimeBin::ChunkSize() const { + return chunk_size_; +} + +size_t GPULifetimeBin::Alignment() const { + return max_alignment_; +} + +GPUAllocBlock* GPULifetimePolicy::FindBlock( + AllocStats* stats, size_t bindex) { + for ( ; bindex < large_bin_index_; ++bindex) { + auto block = bins_[bindex]->FindBlock(stats); + if (block != nullptr) { + return block; + } + } + // no need to lock, BestFit already hold large_bin_lock_ firstly + for (auto it = large_bins_.lower_bound(bindex); + it != large_bins_.end(); ++it) { + auto block = (it->second)->FindBlock(stats); + if (block != nullptr) { + return block; + } + } + return nullptr; +} + +size_t GPULifetimeBin::TotalMem() const { + return blocks_.size() * RoundedBytes(chunk_size_, max_alignment_); +} + +void GPULifetimeBin::Dump() const { + size_t stats_size = 0; + { + std::lock_guard l(stats_lock_); + if (stats_.empty()) { + return; + } + stats_size = stats_.size(); + } + // LOG(INFO_DEV) << "Bin index:" << bin_index_ + // << ", chunk size:" << chunk_size_ + // << ", stats counter:" << stats_size + // << ", blocks counter:" << blocks_.size() + // << ", vblocks counter:" << virtual_blocks_.size() + // << ", realsize:" << blocks_.size() * chunk_size_; +} + +GPUAllocBlock::GPUAllocBlock(size_t size, size_t bin_index) + : size_(size), bin_index_(bin_index) { +} + +bool GPUAllocBlock::CanInsert(AllocStats* alloc_stats) { + for (auto s : stats_) { + if (s->IsOverlap(alloc_stats)) { + return false; + } + } + return true; +} + +void GPUAllocBlock::Insert(AllocStats* alloc_stats) { + // single thread enter + stats_.emplace_back(alloc_stats); +} + +void GPUAllocBlock::ResetStats() { + stats_.clear(); +} + +GPUMemoryPlannerFactory::GPUMemoryPlannerFactory() { + // Enable Memory Optimization by default + Status s = ReadBoolFromEnvVar("ENABLE_MEMORY_OPTIMIZATION", + true, + &enable_memory_opt_); + if (enable_memory_opt_) { + //LOG(INFO_DEV) << "Enable Memory Optimization!"; + memory_planner_ = new GPUMemoryPlanner(); + } else { + memory_planner_ = new 
NullableGPUMemoryPlanner(); + } +} + +} // tensorflow diff --git a/tensorflow/core/common_runtime/gpu_memory_planner.h b/tensorflow/core/common_runtime/gpu_memory_planner.h new file mode 100644 index 00000000000..93d2b3e5c49 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu_memory_planner.h @@ -0,0 +1,244 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_MEMORYPLANNER_GPU_H_ +#define TENSORFLOW_COMMON_RUNTIME_MEMORYPLANNER_GPU_H_ + +#include "tensorflow/core/lib/core/spin_lock.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/common_runtime/memory_planner.h" +#include "tensorflow/core/common_runtime/size_class.h" + +#include +#include +#include +#include + +namespace tensorflow { +namespace { + +const SizeMap kSmallSizeMap; + +inline size_t RoundedBytes(size_t bytes, size_t alignment) { + return alignment * ((bytes + alignment - 1) / alignment); +} +} + +class GPUAllocBlock { + public: + GPUAllocBlock(size_t size, size_t bin_index); + virtual ~GPUAllocBlock() {} + + void Insert(AllocStats* alloc_stats); + bool CanInsert(AllocStats* alloc_stats); + size_t BinIndex() const { return bin_index_; } + void ResetStats(); + + private: + std::vector stats_; // not owned + size_t size_; + size_t bin_index_; +}; + +class VirtualGPUAllocBlock { + public: + VirtualGPUAllocBlock(GPUAllocBlock* block, size_t s) : + internal_block_(block), size_(s) { + }; + + size_t BinIndex() const { + return internal_block_->BinIndex(); + } + + private: + GPUAllocBlock* internal_block_; + size_t size_; +}; + +class GPULifetimePolicy; +class GPULifetimeBin { + public: + GPULifetimeBin(size_t bin_index, size_t chunk_size); + virtual ~GPULifetimeBin(); + + void TrackAllocate(size_t alignment); + void TrackDeallocate(AllocStats* stats); + void BeginStep(); + size_t TotalMem() const; + void Dump() const; + void BestFit(GPULifetimePolicy* policy); + void SmallFit(); + void Cleanup(); + + GPUAllocBlock* FindBlock(AllocStats* stats); + + size_t BlockSize() const; + size_t ChunkSize() const; + size_t Alignment() const; + size_t BinIndex() const { return bin_index_; } + std::vector& VBlocks() { + return virtual_blocks_; + } + + void ResetStats(); + + private: + mutable spin_lock stats_lock_; + std::vector stats_; // not owned + std::vector blocks_; + std::vector virtual_blocks_; + size_t bin_index_; + size_t chunk_size_; + int64_t max_alignment_; +}; + +class GPULifetimePolicy { + public: + GPULifetimePolicy(size_t interval, size_t interval_offset, size_t start); + virtual ~GPULifetimePolicy() {}; + + void TrackAllocate(size_t alignment, size_t num_bytes); + void TrackDeallocate(AllocStats* stats); + size_t TotalMem() const; + + void Dump() const; + void Cleanup(); + + GPUAllocBlock* FindBlock(AllocStats* stats, size_t bin_index); + + void BestFit(); + size_t Interval(); + + std::vector& GetBins(); + std::map& GetLargeBins(); + + size_t Alignment() const; + size_t AlignmentOffset() const; + + void ResetStats(); + + private: + GPULifetimeBin* GetBin(size_t index); + + private: + std::vector bins_; + std::map large_bins_; + mutable spin_lock large_bin_lock_; + const size_t interval_; + const size_t interval_offset_; + const size_t start_; + const size_t large_bin_index_; +}; + +class GPUTensorPoolAllocator; +class GPUMemoryPlannerBase { + public: + virtual void SetAllocator(GPUTensorPoolAllocator* allocator) = 0; + virtual void SetThreadPool(thread::ThreadPool* thread_pool) = 0; + virtual void StartCollect() = 0; + 
virtual void StopCollect() = 0; + virtual void TrackAllocate(size_t alignment, size_t num_bytes, void* ptr) = 0; + virtual void TrackDeallocate(void* ptr) = 0; + virtual GPULifetimePolicy* BestLifetimePolicy() = 0; + virtual std::vector& GetSmallBins() = 0; + + virtual void Reset() = 0; +}; + +class NullableGPUMemoryPlanner : public GPUMemoryPlannerBase { + void SetAllocator(GPUTensorPoolAllocator* allocator) override {} + void SetThreadPool(thread::ThreadPool* thread_pool) override {} + void StartCollect() override {} + void StopCollect() override {} + void TrackAllocate(size_t alignment, size_t num_bytes, void* ptr) override {} + void TrackDeallocate(void* ptr) override {} + + GPULifetimePolicy* BestLifetimePolicy() override { + LOG(ERROR) << "Memory Optimization is disabled, shouldn't be here"; + return nullptr; + } + + std::vector& GetSmallBins() override { + static std::vector tmp; + LOG(ERROR) << "Memory Optimization is disabled, shouldn't be here"; + return tmp; + } + + void Reset() override {} +}; + +class GPUMemoryPlanner : public GPUMemoryPlannerBase { + public: + GPUMemoryPlanner(); + virtual ~GPUMemoryPlanner(); + + void SetAllocator(GPUTensorPoolAllocator* allocator) override; + void SetThreadPool(thread::ThreadPool* thread_pool) override; + + void StartCollect() override; + void StopCollect() override; + void TrackAllocate(size_t alignment, size_t num_bytes, void* ptr) override; + void TrackDeallocate(void* ptr) override; + + GPULifetimePolicy* BestLifetimePolicy() override; + std::vector& GetSmallBins() override; + void Reset() override; + + private: + void Schedule(std::function f); + void InitPolicy(); + void InitStepInfo(); + void CollectDone(); + void Cleanup(); + void ResetStats(); + void BestFit(); + + GPULifetimeBin* GetSmallBin(size_t size); + + private: + // statistics + std::atomic_bool is_stats_; + std::vector lifetime_stats_polices_; + std::vector small_bins_; + + GPUTensorPoolAllocator* allocator_; + thread::ThreadPool* thread_pool_; + + mutable spin_lock stats_lock_; + std::unordered_map ptr_stats_; + std::vector alloc_stats_; + + // step information + std::atomic counter_; + int64 start_step_; + int64 stop_step_; + std::atomic_bool inited_; +}; + +class GPUMemoryPlannerFactory { + public: + static GPUMemoryPlannerBase* GetMemoryPlanner() { + static GPUMemoryPlannerFactory factory; + return factory.memory_planner_; + } + + private: + GPUMemoryPlannerFactory(); + + private: + bool enable_memory_opt_; + GPUMemoryPlannerBase* memory_planner_; +}; + +class GPUScopedMemoryCollector { + public: + GPUScopedMemoryCollector() { + GPUMemoryPlannerFactory::GetMemoryPlanner()->StartCollect(); + } + ~GPUScopedMemoryCollector() { + GPUMemoryPlannerFactory::GetMemoryPlanner()->StopCollect(); + } +}; + +} + +#endif // TENSORFLOW_COMMON_RUNTIME_MEMORYPLANNER_GPU_H_ diff --git a/tensorflow/core/common_runtime/gpu_tensorpool_allocator.cc b/tensorflow/core/common_runtime/gpu_tensorpool_allocator.cc new file mode 100644 index 00000000000..97b60a9c3ef --- /dev/null +++ b/tensorflow/core/common_runtime/gpu_tensorpool_allocator.cc @@ -0,0 +1,510 @@ +#include "tensorflow/core/common_runtime/gpu_memory_planner.h" +#include "tensorflow/core/common_runtime/gpu_tensorpool_allocator.h" +#include "tensorflow/core/framework/allocator_registry.h" +#include "tensorflow/core/platform/mem.h" +#include +#include + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +namespace tensorflow { + +GPUTensorPoolAllocator::GPUTensorPoolAllocator( + SubAllocator*
sub_allocator, string name, size_t total_memory) : + name_(name), + stats_(false), + inited_(false), + initing_(false), + sub_allocator_(sub_allocator), + mem_planner_(GPUMemoryPlannerFactory::GetMemoryPlanner()), + large_bin_index_(0), + null_bin_counter_(0), + hit_counter_(0), + missed_counter_(0), + big_mem_begin_(nullptr), + big_mem_end_(nullptr), + small_mem_begin_(nullptr), + small_mem_end_(nullptr) { + mem_planner_->SetAllocator(this); + alloc_stats_.bytes_limit = static_cast(total_memory); +} + +GPUTensorPoolAllocator::~GPUTensorPoolAllocator() { + if (big_mem_begin_ != nullptr) { + sub_allocator_->Free(big_mem_begin_, big_bytes_); + } + if (small_mem_begin_ != nullptr) { + sub_allocator_->Free(small_mem_begin_, small_bytes_); + } + for (auto bin : lifetime_bins_) { + if (bin != nullptr) { + delete bin; + } + } + for (auto it : large_lifetime_bins_) { + delete it.second; + } + for (auto bin : small_bins_) { + if (bin != nullptr) { + delete bin; + } + } +} + +void GPUTensorPoolAllocator::Init() { + bool tmp = false; + if (initing_.compare_exchange_strong(tmp, true)) { + auto lifetime_policy = mem_planner_->BestLifetimePolicy(); + + alignment_ = lifetime_policy->Alignment(); + alignment_offset_ = lifetime_policy->AlignmentOffset(); + + big_bytes_ = 0; + std::map bin_to_offset; + + auto policy_bins = lifetime_policy->GetBins(); + large_bin_index_ = policy_bins.size(); + lifetime_bins_.resize(large_bin_index_); + + size_t max_alignment = 0; + + for (auto it = policy_bins.begin(); it != policy_bins.end(); + ++it) { + if ((*it)->BlockSize() > 0) { + // add padding between two bins + big_bytes_ = RoundedBytes(big_bytes_, (*it)->Alignment()); + bin_to_offset[(*it)->BinIndex()] = big_bytes_; + big_bytes_ += (*it)->TotalMem(); + max_alignment = std::max(max_alignment, (*it)->Alignment()); + } + } + + auto policy_large_bins = lifetime_policy->GetLargeBins(); + for (auto it = policy_large_bins.begin(); + it != policy_large_bins.end(); ++it) { + auto bin_info = it->second; + if (bin_info->BlockSize() > 0) { + // add padding between two bins + big_bytes_ = RoundedBytes(big_bytes_, bin_info->Alignment()); + bin_to_offset[bin_info->BinIndex()] = big_bytes_; + big_bytes_ += bin_info->TotalMem(); + max_alignment = std::max(max_alignment, bin_info->Alignment()); + } + } + + big_mem_begin_ = sub_allocator_->Alloc(max_alignment, big_bytes_); + if (big_bytes_ > 0 && big_mem_begin_ == nullptr) { + LOG(FATAL) << "OOM!!! 
Try to alloc(" + << max_alignment << ", " << big_bytes_ << ")"; + } + if (big_bytes_ > 0) { + big_mem_end_ = big_mem_begin_ + big_bytes_; + } else { + big_mem_end_ = nullptr; + } + + // create bigger bin first + for (auto rit = policy_large_bins.rbegin(); + rit != policy_large_bins.rend(); ++rit) { + auto bin_info = rit->second; + Bin* bin = nullptr; + if (bin_info->BlockSize() > 0) { + auto offset = bin_to_offset[rit->first]; + bin = new Bin(bin_info->BlockSize(), bin_info->ChunkSize(), + bin_info->Alignment(), bin_info->VBlocks(), + this, big_mem_begin_ + offset); + offset_to_bin_[offset] = bin; + } else if (bin_info->VBlocks().size() > 0) { + bin = new Bin(bin_info->BlockSize(), bin_info->ChunkSize(), + bin_info->Alignment(), bin_info->VBlocks(), + this, nullptr); + } + if (bin != nullptr) { + large_lifetime_bins_.emplace(rit->first, bin); + } + } + + for (auto it = policy_bins.rbegin(); it != policy_bins.rend(); + ++it) { + Bin* bin = nullptr; + if ((*it)->BlockSize() > 0) { + auto offset = bin_to_offset[(*it)->BinIndex()]; + bin = new Bin((*it)->BlockSize(), (*it)->ChunkSize(), + (*it)->Alignment(), (*it)->VBlocks(), + this, big_mem_begin_ + offset); + offset_to_bin_[offset] = bin; + } else if ((*it)->VBlocks().size() > 0) { + bin = new Bin((*it)->BlockSize(), (*it)->ChunkSize(), + (*it)->Alignment(), (*it)->VBlocks(), + this, nullptr); + } + lifetime_bins_[(*it)->BinIndex()] = bin; + } + + auto small_bins = mem_planner_->GetSmallBins(); + small_bins_.resize(small_bins.size()); + small_bytes_ = 0; + max_alignment = 0; + bin_to_offset.clear(); + for (auto b : small_bins) { + if (b->BlockSize() > 0) { + small_bytes_ = RoundedBytes(small_bytes_, b->Alignment()); + bin_to_offset[b->BinIndex()] = small_bytes_; + small_bytes_ += b->TotalMem(); + max_alignment = std::max(max_alignment, b->Alignment()); + } + } + + small_mem_begin_ = sub_allocator_->Alloc(max_alignment, small_bytes_); + if (small_bytes_ > 0 && small_mem_begin_ == nullptr) { + LOG(FATAL) << "OOM!!! 
Try to alloc(" + << max_alignment << ", " << small_bytes_ << ")"; + } + if (small_bytes_ > 0) { + small_mem_end_ = small_mem_begin_ + small_bytes_; + } else { + small_mem_end_ = nullptr; + } + + for (auto b : small_bins) { + SmallBin* bin = nullptr; + if (b->BlockSize() > 0) { + auto offset = bin_to_offset[b->BinIndex()]; + bin = new SmallBin(b->BlockSize(), b->ChunkSize(), + b->Alignment(), small_mem_begin_ + offset); + offset_to_small_bin_[offset] = bin; + } + small_bins_[b->BinIndex()] = bin; + } + + inited_ = true; + } +} + +void GPUTensorPoolAllocator::BeginStep() { + if (inited_.load()) { + for (auto b : lifetime_bins_) { + if (b != nullptr) { + b->BeginStep(); + } + } + for (auto it : large_lifetime_bins_) { + it.second->BeginStep(); + } + } + std::lock_guard l(free_lock_); + for (auto ptr : async_free_list_) { + sub_allocator_->Free(ptr, 0); + } + async_free_list_.clear(); +} + +void* GPUTensorPoolAllocator::AllocateRaw(size_t alignment, + size_t num_bytes) { + if (!inited_.load()) { + auto ptr = sub_allocator_->Alloc(alignment, num_bytes); + mem_planner_->TrackAllocate(alignment, num_bytes, ptr); + return ptr; + } + + if (SmallAlloc(num_bytes)) { + return SmallAllocate(alignment, num_bytes); + } + if (unlikely(stats_)) { + return BigAllocateStatistic(alignment, num_bytes); + } else { + return BigAllocate(alignment, num_bytes); + } +} + +void GPUTensorPoolAllocator::DeallocateRaw(void* ptr) { + if (!inited_.load()) { + mem_planner_->TrackDeallocate(ptr); + sub_allocator_->Free(ptr, 0); + } else if (IsBigOwned(ptr)) { + BigDeallocate(ptr); + } else if (IsSmallOwned(ptr)) { + SmallDeallocate(ptr); + } else { + sub_allocator_->Free(ptr, 0); + } +} + +void GPUTensorPoolAllocator::DeallocateRawAsync(void* ptr) { + if (!inited_.load()) { + mem_planner_->TrackDeallocate(ptr); + { + std::lock_guard l(free_lock_); + async_free_list_.push_back(ptr); + } + } else if (IsBigOwned(ptr)) { + BigDeallocate(ptr); + } else if (IsSmallOwned(ptr)) { + SmallDeallocate(ptr); + } else { + std::lock_guard l(free_lock_); + async_free_list_.push_back(ptr); + } +} + +absl::optional GPUTensorPoolAllocator::GetStats() { + return alloc_stats_; +} + +GPUTensorPoolAllocator::Bin* GPUTensorPoolAllocator::GetBin( + size_t bin_index) { + if (unlikely(bin_index < 0)) { + return nullptr; + } + + if (unlikely(bin_index >= large_bin_index_)) { + auto it = large_lifetime_bins_.find(bin_index); + if (it == large_lifetime_bins_.end()) { + return nullptr; + } else { + return it->second; + } + } + return lifetime_bins_[bin_index]; +} + +GPUTensorPoolAllocator::SmallBin* GPUTensorPoolAllocator::GetSmallBin( + size_t size) { + auto id = kSmallSizeMap.GetClass(size); + if (unlikely(id >= small_bins_.size())) { + LOG(FATAL) << "logic error"; + return nullptr; + } + return small_bins_[id]; +} + +GPUTensorPoolAllocator::SmallBin::SmallBin(size_t len, + size_t chunk_size, size_t alignment, void* begin) { + auto rounded_bytes = RoundedBytes(chunk_size, alignment); + auto buffer_size = rounded_bytes * len; + begin_ = begin; + if (begin != nullptr) { + end_ = begin + buffer_size; + } else { + end_ = nullptr; + } + + for (auto i = 0; i < len; ++i) { + buffer_.emplace(begin + rounded_bytes *i); + } +} + +void* GPUTensorPoolAllocator::SmallBin::AllocateRaw() { + std::lock_guard l(lock_); + if (unlikely(buffer_.empty())) { + return nullptr; + } + auto ptr = buffer_.top(); + buffer_.pop(); + return ptr; +} + +void GPUTensorPoolAllocator::SmallBin::DeallocateRaw(void* p) { + if (unlikely(begin_ == nullptr || p < begin_ || p > end_)) { + 
LOG(WARNING) << "probabaly memory corruption!! begin_: " << begin_ + << " end_: " << end_ << " p: " << p; + } + std::lock_guard l(lock_); + buffer_.emplace(p); +} + +GPUTensorPoolAllocator::Bin::Bin(size_t len, + size_t chunk_size, size_t alignment, + std::vector& vblocks, + GPUTensorPoolAllocator* tp, void* begin) : + buffer_(len, chunk_size, alignment, begin), + virtual_buffer_(vblocks, tp) { +} + +void* GPUTensorPoolAllocator::Bin::Allocate() { + auto ptr = buffer_.Allocate(); + if (ptr != nullptr) { + return ptr; + } + return virtual_buffer_.Allocate(); +} + +void* GPUTensorPoolAllocator::Bin::AllocateRaw() { + return buffer_.Allocate(); +} + +void GPUTensorPoolAllocator::Bin::DeallocateRaw(void* p) { + buffer_.Deallocate(p); +} + +void GPUTensorPoolAllocator::Bin::BeginStep() { + return virtual_buffer_.BeginStep(); +} + +GPUTensorPoolAllocator::Buffer::Buffer(size_t len, size_t chunk_size, + size_t alignment, void* begin) { + auto rounded_bytes = RoundedBytes(chunk_size, alignment); + auto buffer_size = rounded_bytes * len; + begin_ = begin; + if (begin != nullptr) { + end_ = begin + buffer_size; + } else { + end_ = nullptr; + } + + for (auto i = 0; i < len; ++i) { + buffer_.emplace(begin + rounded_bytes *i); + } +} + +void* GPUTensorPoolAllocator::Buffer::Allocate() { + std::lock_guard l(lock_); + if (unlikely(buffer_.empty())) { + return nullptr; + } + auto ptr = buffer_.top(); + buffer_.pop(); + return ptr; +} + +void GPUTensorPoolAllocator::Buffer::Deallocate(void* p) { + if (unlikely(begin_ == nullptr || p < begin_ || p > end_)) { + LOG(WARNING) << "probabaly memory corruption!! begin_: " << begin_ + << " end_: " << end_ << " p: " << p; + } + std::lock_guard l(lock_); + buffer_.emplace(p); +} + +GPUTensorPoolAllocator::VirtualBuffer::VirtualBuffer( + std::vector& vblocks, + GPUTensorPoolAllocator* tp) { + for (auto vblock : vblocks) { + auto bin_index = vblock->BinIndex(); + auto internal_bin = tp->GetBin(bin_index); + if (internal_bin == nullptr) { + LOG(WARNING) << "logic error or not allocate correctly"; + } + internal_bins_.emplace_back(internal_bin); + } + curr_index_ = 0; +} + +void* GPUTensorPoolAllocator::VirtualBuffer::Allocate() { + if (unlikely(internal_bins_.empty())) { + return nullptr; + } + auto index = curr_index_.fetch_add(1) % internal_bins_.size(); + auto bin = internal_bins_[index]; + return bin->AllocateRaw(); +} + +void GPUTensorPoolAllocator::VirtualBuffer::BeginStep() { + curr_index_ = 0; +} + +void GPUTensorPoolAllocator::DumpStats() { + if (stats_) { + double hit_rate = (double)hit_counter_ / + (hit_counter_ + missed_counter_ + null_bin_counter_); + LOG(INFO) << "If you're TensorFlow user, " + << "please ignore following debugging statistic." 
+ << "GPUTensorPoolAllocator Statistic:" + << " hit_counter[" << hit_counter_ + << "], missed_counter[" << missed_counter_ + << "], null_bin_counter[" << null_bin_counter_ + << "], hit_rate[" << hit_rate + << "]"; + + stats_ = false; + hit_counter_ = 0; + missed_counter_ = 0; + null_bin_counter_ = 0; + } else { + stats_ = true; + LOG(INFO) << "Start counting GPUTensorPoolAllocator"; + } +} + +bool GPUTensorPoolAllocator::IsBigOwned(void *ptr) { + return (ptr >= big_mem_begin_ && ptr <= big_mem_end_); +} + +bool GPUTensorPoolAllocator::IsSmallOwned(void *ptr) { + return (ptr >= small_mem_begin_ && ptr <= small_mem_end_); +} + +void* GPUTensorPoolAllocator::SmallAllocate(size_t alignment, size_t num_bytes) { + auto bin = GetSmallBin(num_bytes); + if (unlikely(bin == nullptr)) { + return sub_allocator_->Alloc(alignment, num_bytes); + } + auto ptr = bin->AllocateRaw(); + if (likely(ptr != nullptr)) { + return ptr; + } + return sub_allocator_->Alloc(alignment, num_bytes); +} + +void* GPUTensorPoolAllocator::BigAllocate(size_t alignment, + size_t num_bytes) { + auto id = Index(num_bytes, alignment_, alignment_offset_); + if (unlikely(id < 0)) { + return sub_allocator_->Alloc(alignment, num_bytes); + } + + auto b = GetBin(id); + if (unlikely(b == nullptr)) { + return sub_allocator_->Alloc(alignment, num_bytes); + } + + auto ptr = b->Allocate(); + if (likely(ptr != nullptr)) { + return ptr; + } + + return sub_allocator_->Alloc(alignment, num_bytes); +} + +// unlikely execute this path which do some atomic operations +void* GPUTensorPoolAllocator::BigAllocateStatistic(size_t alignment, + size_t num_bytes) { + auto id = Index(num_bytes, alignment_, alignment_offset_); + if (unlikely(id < 0)) { + return sub_allocator_->Alloc(alignment, num_bytes); + } + + auto b = GetBin(id); + if (unlikely(b == nullptr)) { + ++null_bin_counter_; + return sub_allocator_->Alloc(alignment, num_bytes); + } + + auto ptr = b->Allocate(); + if (likely(ptr != nullptr)) { + ++hit_counter_; + return ptr; + } + + ++missed_counter_; + return sub_allocator_->Alloc(alignment, num_bytes); +} + +void GPUTensorPoolAllocator::SmallDeallocate(void* ptr) { + size_t offset = reinterpret_cast(ptr) - + reinterpret_cast(small_mem_begin_); + auto it = offset_to_small_bin_.upper_bound(offset); + it = std::prev(it); + it->second->DeallocateRaw(ptr); +} + +void GPUTensorPoolAllocator::BigDeallocate(void* ptr) { + size_t offset = reinterpret_cast(ptr) - + reinterpret_cast(big_mem_begin_); + auto it = offset_to_bin_.upper_bound(offset); + it = std::prev(it); + it->second->DeallocateRaw(ptr); +} + +} // tensorflow diff --git a/tensorflow/core/common_runtime/gpu_tensorpool_allocator.h b/tensorflow/core/common_runtime/gpu_tensorpool_allocator.h new file mode 100644 index 00000000000..cd8574df230 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu_tensorpool_allocator.h @@ -0,0 +1,158 @@ +#ifndef TENSORFLOW_COMMON_RUNTIME_TENSORPOOL_ALLOCATOR_GPU_H_ +#define TENSORFLOW_COMMON_RUNTIME_TENSORPOOL_ALLOCATOR_GPU_H_ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/core/spin_lock.h" + +#include +#include +#include +#include + +namespace tensorflow { + +class GPUMemoryPlannerBase; +class VirtualGPUAllocBlock; + +class GPUTensorPoolAllocator : public Allocator { + public: + GPUTensorPoolAllocator(SubAllocator* sub_allocator, string name, + size_t total_memory); + ~GPUTensorPoolAllocator() override; + + GPUTensorPoolAllocator(const GPUTensorPoolAllocator&) = delete; + GPUTensorPoolAllocator& operator=(const 
GPUTensorPoolAllocator&) = delete; + + void Init(); + void BeginStep(); + + string Name() override { return name_; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + void DeallocateRawAsync(void* ptr) override; + + absl::optional GetStats() override; + + void DumpStats(); + + class Bin; + Bin* GetBin(size_t bin_index); + + class VirtualBuffer { + public: + VirtualBuffer(std::vector& vblocks, + GPUTensorPoolAllocator* tp); + virtual ~VirtualBuffer() {} + + void* Allocate(); + void BeginStep(); + + private: + std::vector internal_bins_; + std::atomic curr_index_; + }; + + class Buffer { + public: + Buffer(size_t len, size_t chunk_size, + size_t alignment, void* begin); + + void* Allocate(); + void Deallocate(void* p); + + private: + mutable spin_lock lock_; + std::stack buffer_; + void* begin_; + void* end_; + }; + + class Bin { + public: + Bin(size_t len, size_t chunk_size, size_t alignment, + std::vector& vblocks, + GPUTensorPoolAllocator* tp, void* begin); + virtual ~Bin(){} + + Bin(const Bin&) = delete; + Bin& operator=(const Bin&) = delete; + + void* Allocate(); + void* AllocateRaw(); + void DeallocateRaw(void* p); + + void BeginStep(); + + private: + Buffer buffer_; + VirtualBuffer virtual_buffer_; + }; + + class SmallBin { + public: + SmallBin(size_t len, size_t chunk_size, size_t alignment, void* begin); + virtual ~SmallBin(){} + + SmallBin(const SmallBin&) = delete; + Bin& operator=(const Bin&) = delete; + + void* AllocateRaw(); + void DeallocateRaw(void* p); + + private: + mutable spin_lock lock_; + std::stack buffer_; + void* begin_; + void* end_; + }; + + private: + bool IsBigOwned(void *ptr); + bool IsSmallOwned(void *ptr); + void* BigAllocate(size_t alignment, size_t num_bytes); + void* BigAllocateStatistic(size_t alignment, size_t num_bytes); + void BigDeallocate(void* ptr); + + SmallBin* GetSmallBin(size_t size); + void* SmallAllocate(size_t alignment, size_t num_bytes); + void SmallDeallocate(void* ptr); + + private: + mutable spin_lock free_lock_; + std::vector async_free_list_; + string name_; + AllocatorStats alloc_stats_; + + bool stats_; + std::atomic_bool inited_; + std::atomic_bool initing_; + + std::unique_ptr sub_allocator_; + GPUMemoryPlannerBase* mem_planner_; + + size_t large_bin_index_; + std::vector lifetime_bins_; + std::vector small_bins_; + std::map large_lifetime_bins_; + + size_t alignment_; + size_t alignment_offset_; + size_t big_bytes_; + void *big_mem_begin_; + void *big_mem_end_; + std::map offset_to_bin_; + + size_t small_bytes_; + void *small_mem_begin_; + void *small_mem_end_; + std::map offset_to_small_bin_; + + // Statistic + std::atomic null_bin_counter_; + std::atomic hit_counter_; + std::atomic missed_counter_; +}; + +} + +#endif // TENSORFLOW_COMMON_RUNTIME_TENSORPOOL_ALLOCATOR_GPU_H_ diff --git a/tensorflow/core/common_runtime/gpu_tensorpool_allocator_test.cc b/tensorflow/core/common_runtime/gpu_tensorpool_allocator_test.cc new file mode 100644 index 00000000000..29e72ffd8db --- /dev/null +++ b/tensorflow/core/common_runtime/gpu_tensorpool_allocator_test.cc @@ -0,0 +1,404 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +#include +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" +#include "tensorflow/core/common_runtime/gpu_memory_planner.h" +#include "tensorflow/core/common_runtime/gpu_tensorpool_allocator.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +TEST(GPUTensorPoolAllocatorTest, SmallAllocationWithAlignment32B) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(32, 100); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, SmallAllocationWithAlignment8B) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(8, 100); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, SmallAllocationWithAlignment16B) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(16, 100); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, AlignedSmallAllocationWithAlignment16B) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(16, 128); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, SmallAllocationWithoutAlignment) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + 
platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(64, 100); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, BigAllocation100KBWithoutAlignment) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(64, 100000); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, BigAllocation100KBWithAlignment128B) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(64, 100000); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, BigAllocation1MBWithoutAlignment32B) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(64, 1024*1024); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(GPUTensorPoolAllocatorTest, BigAllocation128KBWithoutAlignment16B) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + void* p = allocator.AllocateRaw(64, 128 * 1024); + EXPECT_TRUE(p != nullptr); + allocator.DeallocateRaw(p); +} + +TEST(TensorPoolAllcatorTest, MixedAllocationLoops) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + std::vector alignments = {8, 16, 32, 64, 128}; + std::vector sizes = {100, 512, 2048, 128*1024, 100 * 1024}; + std::vector vec; + for (int i = 0; i < 100; ++i) { + for (auto alignment : alignments) { + for (auto size : sizes) { + void* p = allocator.AllocateRaw(alignment, size); + EXPECT_TRUE(p != nullptr); + vec.emplace_back(p); + } + } + } + for (auto p : vec) { + allocator.DeallocateRaw(p); + } +} + +TEST(GPUTensorPoolAllocatorTest, MultipleThreadMixedAllocationLoops) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + auto func = [&allocator] { + std::vector alignments = {8, 16, 32, 64, 128}; + std::vector sizes = {100, 512, 2048, 128*1024, 100 * 1024}; + std::vector vec; + for (int i = 0; i < 100; ++i) { + for (auto alignment : alignments) { + for 
(auto size : sizes) { + void* p = allocator.AllocateRaw(alignment, size); + EXPECT_TRUE(p != nullptr); + vec.emplace_back(p); + } + } + } + for (auto p : vec) { + allocator.DeallocateRaw(p); + } + }; + + std::vector ths; + for (int i = 0; i < 3; ++i) { + auto th = new std::thread(func); + ths.emplace_back(th); + } + for (auto th : ths) { + th->join(); + } +} + +TEST(GPUTensorPoolAllocatorTest, MemoryPlannerBasic) { + GPUMemoryPlannerFactory::GetMemoryPlanner()->Reset(); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + for (int i = 0; i < 2000; ++i) { + GPUScopedMemoryCollector c; + std::vector alignments = {8, 16, 64, 128}; + std::vector sizes = {100, 128, 1024, 128*1024}; + std::vector vec; + for (int i = 0; i < 2; ++i) { + for (auto alignment : alignments) { + for (auto size : sizes) { + void* p = allocator.AllocateRaw(alignment, size); + EXPECT_TRUE(p != nullptr); + vec.emplace_back(p); + } + } + } + for (auto p : vec) { + allocator.DeallocateRaw(p); + } + } + std::vector alignments = {8, 16, 32, 64, 128}; + std::vector sizes = {100, 512, 2048, 128*1024, 100 * 1024}; + std::vector vec; + for (int i = 0; i < 100; ++i) { + for (auto alignment : alignments) { + for (auto size : sizes) { + void* p = allocator.AllocateRaw(alignment, size); + EXPECT_TRUE(p != nullptr); + vec.emplace_back(p); + } + } + } + for (auto p : vec) { + allocator.DeallocateRaw(p); + } + sleep(1); +} + +TEST(GPUTensorPoolAllocatorTest, MemoryPlannerSingletonTest) { + GPUMemoryPlannerFactory::GetMemoryPlanner()->Reset(); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + for (int i = 0; i < 2000; ++i) { + GPUScopedMemoryCollector c; + std::vector alignments = {8, 16, 64, 128}; + std::vector sizes = {100, 128, 1024, 128*1024}; + std::vector vec; + for (int i = 0; i < 2; ++i) { + for (auto alignment : alignments) { + for (auto size : sizes) { + void* p = allocator.AllocateRaw(alignment, size); + EXPECT_TRUE(p != nullptr); + vec.emplace_back(p); + } + } + } + for (auto p : vec) { + allocator.DeallocateRaw(p); + } + } + sleep(1); +} + +TEST(GPUTensorPoolAllocatorTest, MemoryPlannerMemoryConsumption) { + GPUMemoryPlannerFactory::GetMemoryPlanner()->Reset(); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30); + for (int i = 0; i < 2000; ++i) { + GPUScopedMemoryCollector c; + std::vector alignments = {64}; + std::vector sizes = {128*1024, 64*1024}; + std::vector vec; + for (int i = 0; i < 2; ++i) { + for (auto alignment : alignments) { + for (auto size : sizes) { + void* p = allocator.AllocateRaw(alignment, size); + EXPECT_TRUE(p != nullptr); + vec.emplace_back(p); + } + } + } + for (auto p : vec) { + allocator.DeallocateRaw(p); + } + } + sleep(1); +} + +TEST(GPUTensorPoolAllocatorTest, HugeMemoryAllocation) { + GPUMemoryPlannerFactory::GetMemoryPlanner()->Reset(); + PlatformGpuId 
+TEST(GPUTensorPoolAllocatorTest, HugeMemoryAllocation) {
+  GPUMemoryPlannerFactory::GetMemoryPlanner()->Reset();
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30);
+  for (int i = 0; i < 2000; ++i) {
+    GPUScopedMemoryCollector c;
+    std::vector<size_t> alignments = {64};
+    std::vector<size_t> sizes = {512 * 1024 * 1024, 64 * 1024 * 1024};
+    std::vector<void*> vec;
+    for (int j = 0; j < 2; ++j) {
+      for (auto alignment : alignments) {
+        for (auto size : sizes) {
+          void* p = allocator.AllocateRaw(alignment, size);
+          EXPECT_TRUE(p != nullptr);
+          vec.emplace_back(p);
+        }
+      }
+    }
+    for (auto p : vec) {
+      allocator.DeallocateRaw(p);
+    }
+  }
+  sleep(1);
+}
+
+static void BM_GPUTensorPoolAllocator_SmallAllocation(int iters) {
+  GPUMemoryPlannerFactory::GetMemoryPlanner()->Reset();
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30);
+  std::vector<int> sizes = {64, 128, 144, 256, 400, 512};
+
+  int size_index = 0;
+  while (--iters > 0) {
+    GPUScopedMemoryCollector c;
+    int bytes = sizes[size_index++ % sizes.size()];
+    void* p = allocator.AllocateRaw(1, bytes);
+    allocator.DeallocateRaw(p);
+  }
+}
+
+BENCHMARK(BM_GPUTensorPoolAllocator_SmallAllocation);
+
+static void BM_GPUTensorPoolAllocator_BigAllocation(int iters) {
+  GPUMemoryPlannerFactory::GetMemoryPlanner()->Reset();
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30);
+  std::vector<int> sizes = {64 * 1024, 128 * 1024,
+                            144 * 1024, 256 * 2048, 400 * 4096, 512 * 4096};
+
+  int size_index = 0;
+  while (--iters > 0) {
+    GPUScopedMemoryCollector c;
+    int bytes = sizes[size_index++ % sizes.size()];
+    void* p = allocator.AllocateRaw(1, bytes);
+    allocator.DeallocateRaw(p);
+  }
+}
+
+BENCHMARK(BM_GPUTensorPoolAllocator_BigAllocation);
+
+static void BM_BFCAllocatorGPU_SmallAllocation(int iters) {
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUOptions options;
+  options.set_allow_growth(true);
+  GPUBFCAllocator allocator(sub_allocator, 1LL << 30, options, "GPU_0_bfc");
+  std::vector<int> sizes = {64, 128, 144, 256, 400, 512};
+
+  int size_index = 0;
+  while (--iters > 0) {
+    int bytes = sizes[size_index++ % sizes.size()];
+    void* p = allocator.AllocateRaw(1, bytes);
+    allocator.DeallocateRaw(p);
+  }
+}
+
+BENCHMARK(BM_BFCAllocatorGPU_SmallAllocation);
+
+static void BM_BFCAllocatorGPU_BigAllocation(int iters) {
+  PlatformGpuId platform_gpu_id(0);
+  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+      platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUOptions options;
+  options.set_allow_growth(true);
+  GPUBFCAllocator allocator(sub_allocator, 1LL << 30, options, "GPU_0_bfc");
+  std::vector<int> sizes = {64 * 1024, 128 * 1024,
+                            144 * 1024, 256 * 2048, 400 * 4096, 512 * 4096};
+
+  int size_index = 0;
+  while (--iters > 0) {
+    int bytes = sizes[size_index++ % sizes.size()];
+    void* p = allocator.AllocateRaw(1, bytes);
+    allocator.DeallocateRaw(p);
+  }
+}
+
+BENCHMARK(BM_BFCAllocatorGPU_BigAllocation);
+
+}  // namespace
+}  // namespace tensorflow
+
+#include
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
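Taken together, the tests above also document the call pattern the new allocator expects. The sketch below condenses it into one place; it is illustrative only, assumes the GPUTensorPoolAllocator, GPUMemAllocator and GPUScopedMemoryCollector interfaces exercised in this test file, and copies the arbitrary pool name and 1 GB limit from the tests. The 2000-step loop mirrors the tests, which suggest the planner needs many observed steps before it adjusts the allocator.

// Illustrative sketch, not part of the patch.
void ExampleTensorPoolUsage() {
  PlatformGpuId platform_gpu_id(0);
  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
      platform_gpu_id, false /*use_unified_memory*/, {}, {});
  GPUTensorPoolAllocator allocator(sub_allocator, "GPU_0_tensorpool", 1 << 30);

  for (int step = 0; step < 2000; ++step) {
    // One collector per step, so the GPU memory planner can observe
    // allocation lifetimes within that step.
    GPUScopedMemoryCollector collector;
    void* p = allocator.AllocateRaw(/*alignment=*/64, /*num_bytes=*/4096);
    // ... use p ...
    allocator.DeallocateRaw(p);
  }
}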
diff --git a/tensorflow/core/common_runtime/size_class.h b/tensorflow/core/common_runtime/size_class.h
new file mode 100644
index 00000000000..96a2fdd406b
--- /dev/null
+++ b/tensorflow/core/common_runtime/size_class.h
@@ -0,0 +1,143 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_SIZE_CLASS_H_
+#define TENSORFLOW_COMMON_RUNTIME_SIZE_CLASS_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+constexpr int kClassNum = 67;
+constexpr int kMaxClassSize = 32 * 1024;  // 32 KB
+// This size-class table is taken from tcmalloc.
+const int kSizeClass[kClassNum] = {
+    0,
+    8,
+    16,
+    32,
+    48,
+    64,
+    80,
+    96,
+    112,
+    128,
+    144,
+    160,
+    176,
+    192,
+    208,
+    224,
+    240,
+    256,
+    272,
+    288,
+    304,
+    320,
+    336,
+    352,
+    368,
+    384,
+    400,
+    416,
+    448,
+    480,
+    512,
+    576,
+    640,
+    704,
+    768,
+    896,
+    1024,
+    1152,
+    1280,
+    1408,
+    1536,
+    1792,
+    2048,
+    2304,
+    2688,
+    2816,
+    3200,
+    3456,
+    3584,
+    4096,
+    4736,
+    5376,
+    6144,
+    6528,
+    6784,
+    7168,
+    8192,
+    9472,
+    10240,
+    12288,
+    13568,
+    14336,
+    16384,
+    20480,
+    24576,
+    28672,
+    32768 };
+
+class SizeMap {
+ public:
+  //-------------------------------------------------------------------
+  // Mapping from size to size_class and vice versa
+  //-------------------------------------------------------------------
+
+  // Sizes <= 1024 have an alignment >= 8.  So for such sizes we have an
+  // array indexed by ceil(size/8).  Sizes > 1024 have an alignment >= 128.
+  // So for these larger sizes we have an array indexed by ceil(size/128).
+  //
+  // We flatten both logical arrays into one physical array and use
+  // arithmetic to compute an appropriate index.  The constants used by
+  // ClassIndex() were selected to make the flattening work.
+  //
+  // Examples:
+  //   Size       Expression                       Index
+  //   -------------------------------------------------------
+  //   0          (0 + 7) / 8                      0
+  //   1          (1 + 7) / 8                      1
+  //   ...
+  //   1024       (1024 + 7) / 8                   128
+  //   1025       (1025 + 127 + (120<<7)) / 128    129
+  //   ...
+  //   32768      (32768 + 127 + (120<<7)) / 128   376
+  static constexpr int kMaxSmallSize = 1024;
+  static constexpr size_t kClassArraySize =
+      ((kMaxClassSize + 127 + (120 << 7)) >> 7) + 1;
+
+  SizeMap() {
+    int next_size = 0;
+    for (int c = 1; c < kClassNum; c++) {
+      const int max_size_in_class = kSizeClass[c];
+
+      for (int s = next_size; s <= max_size_in_class; s += 8) {
+        class_array_[ClassIndex(s)] = c;
+      }
+      next_size = max_size_in_class + 8;
+      if (next_size > kMaxClassSize) {
+        break;
+      }
+    }
+  }
+
+  inline size_t ClassIndex(size_t size) const {
+    if (size <= kMaxSmallSize) {
+      return (size + 7) >> 3;
+    } else if (size <= kMaxClassSize) {
+      return (size + 127 + (120 << 7)) >> 7;
+    }
+    LOG(ERROR) << "size " << size << " out of range";
+    return 0;
+  }
+
+  inline size_t GetClass(size_t size) const {
+    return class_array_[ClassIndex(size)];
+  }
+
+ private:
+  int class_array_[kClassArraySize];
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMMON_RUNTIME_SIZE_CLASS_H_
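The flattened-array arithmetic documented in the ClassIndex() comment is easy to sanity-check numerically. The snippet below is not part of the patch; it re-implements the index formula standalone and verifies the values from the comment table above.

#include <cassert>
#include <cstddef>

// Reproduces SizeMap::ClassIndex() from size_class.h to check the arithmetic.
static std::size_t ClassIndex(std::size_t size) {
  if (size <= 1024) return (size + 7) >> 3;  // ceil(size / 8)
  return (size + 127 + (120 << 7)) >> 7;     // ceil(size / 128) + 120
}

int main() {
  assert(ClassIndex(0) == 0);
  assert(ClassIndex(1) == 1);
  assert(ClassIndex(1024) == 128);   // last small-size slot
  assert(ClassIndex(1025) == 129);   // first large-size slot
  assert(ClassIndex(32768) == 376);  // == kClassArraySize - 1
  return 0;
}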
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index bd29949c7d6..72dc0893f43 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -126,6 +126,12 @@ class Allocator {
   // REQUIRES: "ptr" was previously returned by a call to AllocateRaw
   virtual void DeallocateRaw(void* ptr) = 0;
 
+  // Used in cudaStreamAddCallback, which must not make any CUDA API calls.
+  // Use this to avoid the global sync of cuMemFree before CUDA 11.2.
+  virtual void DeallocateRawAsync(void* ptr) {
+    DeallocateRaw(ptr);
+  }
+
   // Returns true if this allocator tracks the sizes of allocations.
   // RequestedSize and AllocatedSize must be overridden if
   // TracksAllocationSizes is overridden to return true.
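The default DeallocateRawAsync simply forwards to DeallocateRaw, so existing allocators are unaffected. Per the comment above, the point of overriding it is to avoid freeing device memory from inside a CUDA stream callback. Below is a standalone sketch of that deferred-free idea; it is not TensorFlow's GPU implementation, the class and member names are made up for illustration, and alignment handling is omitted.

#include <cstdlib>
#include <mutex>
#include <vector>

// Illustrative only: the real free happens outside the "async" call path,
// the way a GPU allocator can keep cuMemFree out of a stream callback.
class DeferredFreeAllocator {
 public:
  void* AllocateRaw(std::size_t /*alignment*/, std::size_t num_bytes) {
    DrainPending();  // safe here: ordinary host context
    return std::malloc(num_bytes);  // alignment handling omitted in this sketch
  }

  void DeallocateRaw(void* ptr) { std::free(ptr); }

  // May be called where freeing immediately is not allowed: just enqueue.
  void DeallocateRawAsync(void* ptr) {
    std::lock_guard<std::mutex> lock(mu_);
    pending_.push_back(ptr);
  }

 private:
  void DrainPending() {
    std::vector<void*> to_free;
    {
      std::lock_guard<std::mutex> lock(mu_);
      to_free.swap(pending_);
    }
    for (void* p : to_free) std::free(p);
  }

  std::mutex mu_;
  std::vector<void*> pending_;
};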
diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc
index a758ffbc878..0820aa48f40 100644
--- a/tensorflow/core/framework/tracking_allocator.cc
+++ b/tensorflow/core/framework/tracking_allocator.cc
@@ -109,6 +109,43 @@ void TrackingAllocator::DeallocateRaw(void* ptr) {
 }
 
+void TrackingAllocator::DeallocateRawAsync(void* ptr) {
+  // Freeing a null ptr is a no-op.
+  if (nullptr == ptr) {
+    return;
+  }
+  bool should_delete;
+  // Fetch the following outside the lock in case the call to
+  // AllocatedSize is slow.
+  bool tracks_allocation_sizes = allocator_->TracksAllocationSizes();
+  size_t allocated_bytes = 0;
+  if (tracks_allocation_sizes) {
+    allocated_bytes = allocator_->AllocatedSize(ptr);
+  } else if (track_sizes_locally_) {
+    mutex_lock lock(mu_);
+    auto itr = in_use_.find(ptr);
+    if (itr != in_use_.end()) {
+      tracks_allocation_sizes = true;
+      allocated_bytes = (*itr).second.allocated_size;
+      in_use_.erase(itr);
+    }
+  }
+  Allocator* allocator = allocator_;
+  {
+    mutex_lock lock(mu_);
+    if (tracks_allocation_sizes) {
+      CHECK_GE(allocated_, allocated_bytes);
+      allocated_ -= allocated_bytes;
+      allocations_.emplace_back(-allocated_bytes, Env::Default()->NowMicros());
+    }
+    should_delete = UnRef();
+  }
+  allocator->DeallocateRawAsync(ptr);
+  if (should_delete) {
+    delete this;
+  }
+}
+
 bool TrackingAllocator::TracksAllocationSizes() const {
   return track_sizes_locally_ || allocator_->TracksAllocationSizes();
 }
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
index 428bffd9e15..6ace0e9d648 100644
--- a/tensorflow/core/framework/tracking_allocator.h
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -61,6 +61,7 @@ class TrackingAllocator : public Allocator {
   void* AllocateRaw(size_t alignment, size_t num_bytes,
                     const AllocationAttributes& allocation_attr) override;
   void DeallocateRaw(void* ptr) override;
+  void DeallocateRawAsync(void* ptr) override;
   bool TracksAllocationSizes() const override;
   size_t RequestedSize(const void* ptr) const override;
   size_t AllocatedSize(const void* ptr) const override;
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 634f2a93a8f..45699235702 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -838,6 +838,15 @@ Status Graph::AddWhileContext(StringPiece frame_name,
   return Status::OK();
 }
 
+bool Graph::IsTrainingGraph() const {
+  for (Node* node : op_nodes()) {
+    if (node->name().find("gradient") != std::string::npos) {
+      return true;
+    }
+  }
+  return false;
+}
+
 std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
   std::unordered_map<string, Node*> result;
   for (Node* n : nodes()) {
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index dad7d568e41..1581cb4f24d 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -723,6 +723,9 @@ class Graph {
   // Builds a node name to node pointer index for all nodes in the graph.
   std::unordered_map<string, Node*> BuildNodeNameIndex() const;
 
+  // Returns true if this graph contains any gradient nodes.
+  bool IsTrainingGraph() const;
+
   // TODO(josh11b): uint64 hash() const;
 private: