diff --git a/README.md b/README.md index c124f2c..fcb8be2 100644 --- a/README.md +++ b/README.md @@ -26,11 +26,14 @@ previously possible and, ultimately, generate more accurate results. TFLMS is built into the `tensorflow-gpu` conda package so it is installed by default when you install the GPU enabled TensorFlow from WML CE. -The support is currently available in the [WML CE conda channel](https://public.dhe.ibm.com/ibmdl/export/pub/software/server/ibm-ai/conda/#/). + +The support is currently available for TensorFlow 2.2.0 in the [WML CE early access conda channel](https://public.dhe.ibm.com/ibmdl/export/pub/software/server/ibm-ai/conda-early-access/). + +The support is currently available for TensorFlow 2.1.0 in the [WML CE conda channel](https://public.dhe.ibm.com/ibmdl/export/pub/software/server/ibm-ai/conda/#/). + For more information on this channel, how to add channels, and install frameworks see [this WML CE install documentation](https://www.ibm.com/support/knowledgecenter/SS5SF7_1.7.0/navigation/wmlce_install.htm). - # How to enable TFLMS The TFLMS functionality is disabled by default in TensorFlow and needs to be @@ -153,33 +156,6 @@ process have socket affinity with the GPU which allows the fastest connection paths between system memory and GPU memory, which reduces the training or inferencing time. -# Memory defragmentation -When using very large tensors or during the course of a very long training -operation, the model's memory allocation and usage pattern may lead to -fragmented GPU memory and out of memory errors. When this occurs there is -enough free memory in the GPU for the next allocation, but it is in -non-contiguous blocks. In these cases, the process will fail and output a -message like this: - -``` -Enough free memory to satisfy the allocation request exists but it is fragmented. -Enabling Large Model Support defragmentation may avoid this failure. -``` - -TFLMS is capable of defragmenting sections of GPU memory to gather a -contiguous block large enough for the request. This feature waits for current -GPU computation to finish and then relocates active tensors to coalesce -contiguous free memory blocks. - -Even with the GPU computation cleared, the moving of active tensors carries -a risk of introducing NaN errors or other instability into the model. Despite -this risk it has performed well in multi-week training runs with very large -tensors and defragmentation called frequently. - -Due to the possible risk of instability the Large Model Support defragmentation -is disabled by default and can be enabled along with LMS with the `tf.config.experimental.set_lms_defrag_enabled(True)` API or the -`config.gpu_options.experimental.lms_defrag_enabled=True` ConfigProto setting. - # Model memory usage analysis with allocator statistics TFLMS adds several APIs to obtain GPU memory allocator statistics such as the number of allocations, the peak memory usage, the amount diff --git a/examples/AllocatorStats.md b/examples/AllocatorStats.md index adfb97a..6530f60 100644 --- a/examples/AllocatorStats.md +++ b/examples/AllocatorStats.md @@ -68,6 +68,25 @@ Returns the limit of reservable memory. **Parameter:** `gpu_id`: The zero indexed GPU ID for which to retrieve the statistic. +```python +tf.experimental.get_gpu_host_bytes_in_use(numa_node) +``` +Returns the current number of bytes in use in the GPU host (CPU memory) allocator. + +_Since: 2.2.0_ + +**Parameter:** `numa_node`: The ID of the NUMA node for the allocator. 
+ +```python +tf.experimental.get_gpu_host_peak_bytes_in_use(numa_node) +``` +Returns the peak number of bytes in use in the GPU host (CPU memory) allocator. + +_Since: 2.2.0_ + +**Parameter:** `numa_node`: The ID of the NUMA node for the allocator. + + ## Large Model Support Specific Statistics The Large Model Support specific statistics provide information about Large Model Support's memory management. The statics use the following terms: @@ -80,9 +99,6 @@ Inactive tensors are those tensors which are not currently being used by an executing operation or a soon-to-be executing operation. * reclaim bytes - Reclaimed bytes are the bytes of inactive tensors which have been moved from GPU memory to the system (host) memory. -* defragmentation - A method of producing contiguous memory blocks by moving -active bytes to allow free memory blocks between the active bytes to coalesce -into larger contiguous blocks. ```python @@ -114,41 +130,39 @@ Returns the number of reclaimed bytes. **Parameter:** `gpu_id`: The zero indexed GPU ID for which to retrieve the statistic. + ```python -tf.experimental.get_num_single_reclaims(gpu_id) +tf.experimental.get_current_bytes_reclaimed(gpu_id) ``` -Large Model Support will reclaim the bytes of single tensors when possible. -This returns the number of times single tensors' bytes were reclaimed. +Returns the current number of reclaimed bytes. + +_Since: 2.2.0_ **Parameter:** `gpu_id`: The zero indexed GPU ID for which to retrieve the statistic. + ```python -tf.experimental.get_num_full_reclaims(gpu_id) +tf.experimental.get_peak_bytes_reclaimed(gpu_id) ``` -When no single tensor reclamation is able to free enough GPU memory for the -allocation request, all tensors are reclaimed. This returns the number -of times all tensors were reclaimed. +Returns the peak number of reclaimed bytes. -**Parameter:** `gpu_id`: The zero indexed GPU ID for which to retrieve the statistic. +_Since: 2.2.0_ +**Parameter:** `gpu_id`: The zero indexed GPU ID for which to retrieve the statistic. ```python -tf.experimental.get_num_defragmentations(gpu_id) +tf.experimental.get_num_single_reclaims(gpu_id) ``` -GPU memory may become fragmented such that there are no contiguous blocks which -can fulfill an allocation request, even after reclaiming all inactive -tensors. In this case, active tensors may be moved to allow free blocks to be -coalesced to produce a contiguous memory block large enough to fulfill the -allocation request. The defragmentation function of Large Model Support is -disabled by default. This API returns the number of times defragmentation was -performed. +Large Model Support will reclaim the bytes of single tensors when possible. +This returns the number of times single tensors' bytes were reclaimed. **Parameter:** `gpu_id`: The zero indexed GPU ID for which to retrieve the statistic. - ```python -tf.experimental.get_bytes_defragged(gpu_id) +tf.experimental.get_num_full_reclaims(gpu_id) ``` -The number of bytes moved during GPU memory defragmentation. +When no single tensor reclamation is able to free enough GPU memory for the +allocation request, all tensors are reclaimed. This returns the number +of times all tensors were reclaimed. **Parameter:** `gpu_id`: The zero indexed GPU ID for which to retrieve the statistic. 
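The statistics APIs documented in the AllocatorStats.md hunk above can be polled from a Keras callback to watch reclaim activity during training. The following is a minimal sketch, not part of the patch: it assumes a WML CE TensorFlow 2.2.0 build with this Large Model Support patch applied, a single GPU with ID 0, and a GPU host allocator on NUMA node 0; the callback class name is illustrative only.

```python
import tensorflow as tf

# Enable Large Model Support before building the model (see the README section above).
tf.config.experimental.set_lms_enabled(True)

GPU_ID = 0      # zero-indexed GPU ID (assumed single-GPU setup)
NUMA_NODE = 0   # NUMA node of the GPU host allocator (assumed)


class LMSStatsEpochLogger(tf.keras.callbacks.Callback):
    """Illustrative callback: logs LMS allocator statistics after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        gib = 1073741824.0  # bytes per GiB
        print(
            f"epoch {epoch}: "
            f"allocs={tf.experimental.get_num_allocs(GPU_ID)} "
            f"single_reclaims={tf.experimental.get_num_single_reclaims(GPU_ID)} "
            f"full_reclaims={tf.experimental.get_num_full_reclaims(GPU_ID)} "
            f"GiB_reclaimed={tf.experimental.get_bytes_reclaimed(GPU_ID) / gib:.2f} "
            f"peak_GiB_reclaimed={tf.experimental.get_peak_bytes_reclaimed(GPU_ID) / gib:.2f} "
            f"host_GiB_in_use={tf.experimental.get_gpu_host_bytes_in_use(NUMA_NODE) / gib:.2f}"
        )


# Usage (hypothetical model and dataset):
# model.fit(dataset, epochs=5, callbacks=[LMSStatsEpochLogger()])
```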
diff --git a/examples/ManyModel.py b/examples/ManyModel.py index 6236d47..3a0857e 100644 --- a/examples/ManyModel.py +++ b/examples/ManyModel.py @@ -134,8 +134,6 @@ def get_callbacks(args): def run_model(args): if args.lms: tf.config.experimental.set_lms_enabled(True) - if args.lms_defrag: - tf.config.experimental.set_lms_defrag_enabled(True) image_dim = args.image_size opt = tf.keras.optimizers.RMSprop() @@ -209,14 +207,6 @@ def main(): help='Disable LMS (Default)') parser.set_defaults(lms=False) - defrag_group = parser.add_mutually_exclusive_group(required=False) - defrag_group.add_argument('--lms_defrag', dest='lms_defrag', - action='store_true', - help='Enable LMS defragmentation') - defrag_group.add_argument('--no-lms_defrag', dest='lms_defrag', - action='store_false', - help='Disable LMS defragmentation (Default)') - parser.set_defaults(lms_defrag=False) lms_stats = parser.add_mutually_exclusive_group(required=False) lms_stats.add_argument('--lms_stats', dest='lms_stats', action='store_true', help='Log LMS per-step stats to a file named ' diff --git a/examples/callbacks.py b/examples/callbacks.py index 99baa54..7f2afcc 100644 --- a/examples/callbacks.py +++ b/examples/callbacks.py @@ -27,7 +27,7 @@ nvtx.nvtxMarkA.restype = None STATS_KEYS = ['time', 'allocs', 'reclaim_ones', - 'reclaim_alls', 'defrags', 'gib_reclaimed', 'gib_defragged'] + 'reclaim_alls', 'gib_reclaimed'] class CudaProfileCallback(Callback): def __init__(self, profile_epoch, profile_batch_start, profile_batch_end): @@ -66,9 +66,7 @@ def _get_stats(self): stats['allocs'] = tf.experimental.get_num_allocs(self._gpu_id) stats['reclaim_ones'] = tf.experimental.get_num_single_reclaims(self._gpu_id) stats['reclaim_alls'] = tf.experimental.get_num_full_reclaims(self._gpu_id) - stats['defrags'] = tf.experimental.get_num_defragmentations(self._gpu_id) stats['gib_reclaimed'] = tf.experimental.get_bytes_reclaimed(self._gpu_id) / 1073741824.0 - stats['gib_defragged'] = tf.experimental.get_bytes_defragged(self._gpu_id) / 1073741824.0 return stats def step_begin(self): @@ -114,9 +112,7 @@ def write_step_stats(logfile, step_type, epoch, step_num, step_stats): row.append(step_stats['allocs']) row.append(step_stats['reclaim_ones']) row.append(step_stats['reclaim_alls']) - row.append(step_stats['defrags']) row.append(step_stats['gib_reclaimed']) - row.append(step_stats['gib_defragged']) with open(logfile, 'a+', newline='') as csvfile: statswriter = csv.writer(csvfile) statswriter.writerow(row) @@ -127,8 +123,7 @@ def write_step_log_header(logfile): statswriter = csv.writer(csvfile) statswriter.writerow(['step type', 'epoch', 'step', 'duration', 'allocs', 'reclaimOnes', - 'reclaimAlls', 'defrags', - 'GiB reclaimed', 'GiB defragged']) + 'reclaimAlls', 'GiB reclaimed']) class LMSStatsLogger(Callback): diff --git a/patches/tensorflow_v2.2.0_large_model_support.patch b/patches/tensorflow_v2.2.0_large_model_support.patch new file mode 100644 index 0000000..f618363 --- /dev/null +++ b/patches/tensorflow_v2.2.0_large_model_support.patch @@ -0,0 +1,2980 @@ +From 26031c9b25506eee02b0e3c3ef99f411b16b34dc Mon Sep 17 00:00:00 2001 +From: Samuel Matzek +Date: Mon, 2 Nov 2020 10:29:29 -0600 +Subject: [PATCH] TensorFlow Large Model Support for TensorFlow 2.2.0 + +This commit delivers TensorFlow Large Model Support +for TensorFlow at version 2.2.0. 
+ +See: https://github.com/IBM/tensorflow-large-model-support + +Co-authored-by: Matthew Brandyberry +Co-authored-by: Andres Lugo-Reyes +--- + tensorflow/c/eager/c_api.cc | 26 ++ + tensorflow/c/eager/c_api.h | 4 + + tensorflow/c/tf_tensor.cc | 1 + + tensorflow/c/tf_tensor_internal.h | 5 +- + tensorflow/compiler/jit/xla_launch_util.h | 6 +- + tensorflow/core/BUILD | 9 +- + tensorflow/core/common_runtime/bfc_allocator.cc | 284 ++++++++++++++- + tensorflow/core/common_runtime/bfc_allocator.h | 70 +++- + tensorflow/core/common_runtime/executor.cc | 14 +- + .../core/common_runtime/gpu/gpu_bfc_allocator.cc | 104 +++++- + .../core/common_runtime/gpu/gpu_bfc_allocator.h | 25 ++ + .../core/common_runtime/gpu/gpu_debug_allocator.cc | 8 + + .../core/common_runtime/gpu/gpu_debug_allocator.h | 2 + + tensorflow/core/common_runtime/gpu/gpu_device.cc | 3 + + .../core/common_runtime/gpu/gpu_event_mgr_test.cc | 3 +- + .../core/common_runtime/gpu/gpu_mem_allocator.h | 2 + + .../core/common_runtime/gpu/gpu_process_state.cc | 7 +- + tensorflow/core/framework/allocator.cc | 27 +- + tensorflow/core/framework/allocator.h | 110 +++++- + tensorflow/core/framework/op_kernel.cc | 21 ++ + tensorflow/core/framework/op_kernel.h | 10 + + tensorflow/core/framework/tensor.cc | 291 ++++++++++++++- + tensorflow/core/framework/tensor.h | 60 ++- + tensorflow/core/platform/default/mutex.cc | 22 ++ + tensorflow/core/platform/mutex.h | 72 ++++ + tensorflow/core/protobuf/config.proto | 4 + + tensorflow/lite/delegates/flex/buffer_map.cc | 5 +- + tensorflow/python/BUILD | 33 ++ + tensorflow/python/__init__.py | 1 + + tensorflow/python/eager/context.py | 28 +- + tensorflow/python/eager/pywrap_tensor.cc | 21 ++ + tensorflow/python/framework/bfc_allocator_stats.py | 89 +++++ + .../framework/bfc_allocator_stats_wrapper.cc | 404 +++++++++++++++++++++ + tensorflow/python/framework/config.py | 15 + + tensorflow/python/keras/engine/network.py | 10 +- + .../golden/v1/tensorflow.config.experimental.pbtxt | 8 + + .../api/golden/v1/tensorflow.experimental.pbtxt | 73 ++++ + .../golden/v2/tensorflow.config.experimental.pbtxt | 8 + + .../api/golden/v2/tensorflow.experimental.pbtxt | 73 ++++ + 39 files changed, 1902 insertions(+), 56 deletions(-) + create mode 100644 tensorflow/python/framework/bfc_allocator_stats.py + create mode 100644 tensorflow/python/framework/bfc_allocator_stats_wrapper.cc + +diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc +index 65f37f3021f..141ca3c1555 100644 +--- a/tensorflow/c/eager/c_api.cc ++++ b/tensorflow/c/eager/c_api.cc +@@ -1,4 +1,5 @@ + /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. ++ * Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+@@ -1554,6 +1555,31 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + return nullptr; + } + ++void TFE_TensorHandle_SetGraphId(TFE_TensorHandle* h, int64_t id) { ++ if (h->handle == nullptr) return; ++ tensorflow::TensorHandle* handle = ++ tensorflow::down_cast(h->handle.get()) ++ ->Handle(); ++ const tensorflow::Tensor* t = nullptr; ++ tensorflow::Status s = handle->Tensor(&t); ++ if (!s.ok()) return; ++ t->SetGraphId(id); ++} ++ ++bool TFE_TensorHandle_GraphId(TFE_TensorHandle* h, int64_t* id) { ++ if (h->handle == nullptr) return false; ++ tensorflow::TensorHandle* handle = ++ tensorflow::down_cast(h->handle.get()) ++ ->Handle(); ++ const tensorflow::Tensor* t = nullptr; ++ tensorflow::Status s = handle->Tensor(&t); ++ if (!s.ok()) return false; ++ tensorflow::int64 graph_id; ++ if (!t->GraphId(&graph_id)) return false; ++ *id = graph_id; ++ return true; ++} ++ + void TFE_ContextAddFunctionDef(TFE_Context* ctx, + const char* serialized_function_def, size_t size, + TF_Status* status) { +diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h +index 070b3a9bb60..d26ed0b8790 100644 +--- a/tensorflow/c/eager/c_api.h ++++ b/tensorflow/c/eager/c_api.h +@@ -1,4 +1,5 @@ + /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. ++ * Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -193,6 +194,9 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( + TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, + TF_Status* status); + ++TF_CAPI_EXPORT extern void TFE_TensorHandle_SetGraphId(TFE_TensorHandle* h, int64_t id); ++TF_CAPI_EXPORT extern bool TFE_TensorHandle_GraphId(TFE_TensorHandle* h, int64_t* id); ++ + // Debugging/Profiling information for TFE_TensorHandle + // + // TFE_TensorDebugInfo contains information useful for debugging and +diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc +index 4e75beceb3e..227d8834ade 100644 +--- a/tensorflow/c/tf_tensor.cc ++++ b/tensorflow/c/tf_tensor.cc +@@ -31,6 +31,7 @@ limitations under the License. 
+ using tensorflow::Status; + using tensorflow::Tensor; + using tensorflow::TensorBuffer; ++using tensorflow::SimpleTensorBufferBase; + using tensorflow::errors::FailedPrecondition; + using tensorflow::errors::InvalidArgument; + +diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h +index 08a55f26a83..76c8386089e 100644 +--- a/tensorflow/c/tf_tensor_internal.h ++++ b/tensorflow/c/tf_tensor_internal.h +@@ -34,12 +34,12 @@ typedef struct TF_Tensor { + std::unique_ptr tensor; + } TF_Tensor; + +-class TF_ManagedBuffer : public tensorflow::TensorBuffer { ++class TF_ManagedBuffer : public tensorflow::SimpleTensorBufferBase { + public: + TF_ManagedBuffer(void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg, bool owns_memory) +- : TensorBuffer(data), ++ : SimpleTensorBufferBase(data), + len_(len), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg), +@@ -50,7 +50,6 @@ class TF_ManagedBuffer : public tensorflow::TensorBuffer { + } + + size_t size() const override { return len_; } +- TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription( + tensorflow::AllocationDescription* proto) const override { + tensorflow::int64 rb = size(); +diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h +index 511e0f1451a..b0928628b72 100644 +--- a/tensorflow/compiler/jit/xla_launch_util.h ++++ b/tensorflow/compiler/jit/xla_launch_util.h +@@ -171,11 +171,11 @@ class XlaComputationLaunchContext { + + // A simple TensorBuffer implementation that allows us to create Tensors that + // take ownership of pre-allocated memory. +-class XlaTensorBuffer : public TensorBuffer { ++class XlaTensorBuffer : public SimpleTensorBufferBase { + public: + XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, + Allocator* allocator) +- : TensorBuffer(const_cast(ptr)), ++ : SimpleTensorBufferBase(const_cast(ptr)), + expected_size_(expected_size), + actual_size_(actual_size), + allocator_(allocator) {} +@@ -188,8 +188,6 @@ class XlaTensorBuffer : public TensorBuffer { + + size_t size() const override { return expected_size_; } + +- TensorBuffer* root_buffer() override { return this; } +- + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_allocated_bytes(actual_size_); + } +diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD +index b02eb89ebfc..33bb412048f 100644 +--- a/tensorflow/core/BUILD ++++ b/tensorflow/core/BUILD +@@ -2740,6 +2740,8 @@ cc_library( + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ ++ ":framework", ++ ":framework_internal", + ":lib", + ":lib_internal", + ":protos_all_cc", +@@ -2927,11 +2929,16 @@ tf_cuda_library( + srcs = [ + "common_runtime/gpu/gpu_bfc_allocator.cc", + ], +- hdrs = ["common_runtime/gpu/gpu_bfc_allocator.h"], ++ hdrs = [ ++ "common_runtime/gpu/gpu_bfc_allocator.h", ++ "common_runtime/gpu/gpu_process_state.h", ++ "common_runtime/process_state.h", ++ ], + features = ["parse_headers"], + visibility = ["//visibility:public"], + deps = [ + ":bfc_allocator", ++ ":gpu_lib", + ":gpu_mem_allocator", + ":lib", + ":lib_internal", +diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc +index 1100ba9684c..b620803b81a 100644 +--- a/tensorflow/core/common_runtime/bfc_allocator.cc ++++ b/tensorflow/core/common_runtime/bfc_allocator.cc +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow 
Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -19,6 +20,7 @@ limitations under the License. + + #include "absl/strings/string_view.h" + #include "tensorflow/core/common_runtime/allocator_retry.h" ++#include "tensorflow/core/framework/tensor.h" + #include "tensorflow/core/lib/core/bits.h" + #include "tensorflow/core/lib/strings/numbers.h" + #include "tensorflow/core/lib/strings/str_util.h" +@@ -35,6 +37,8 @@ limitations under the License. + + namespace tensorflow { + ++const string BFCAllocator::kGPUHostAllocatorName = "gpu_host_bfc"; ++const string BFCAllocator::kGPUHostMemLimitEnvVar = "TF_GPU_HOST_MEM_LIMIT_IN_MB"; + BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory, + bool allow_growth, const string& name, + bool garbage_collection) +@@ -359,6 +363,245 @@ void BFCAllocator::DeallocateRegions( + } + } + ++void BFCAllocator::RemoveReclaimedBytes(size_t size) { ++ mutex_lock l(lock_); ++ stats_.cur_bytes_reclaimed -= size; ++} ++ ++bool BFCAllocator::ReclaimListAdd(void* ptr, IntrusiveListHook* hook) { ++ LMSTensorBuffer* buf = hook->elem(); ++ size_t size = buf->size(); ++ ++ mutex_lock l(lock_); ++ stats_.bytes_inactive += size; ++ reclaim_list_.append(hook); ++ ++ VLOG(2) << "-> INACTIVE " << (void*)buf << " (" << size << ")"; ++ ++ bool pageout_predicted = PredictReclaim(buf); ++ ++ if (reclaim_waiter_) ++ reclaim_cv_.notify_all(); ++ ++ return pageout_predicted; ++} ++ ++bool BFCAllocator::ReclaimListRemove(void* ptr, IntrusiveListHook* hook) { ++ mutex_lock l(lock_); ++ return ReclaimListRemoveInternal(ptr, hook, false); ++} ++ ++void BFCAllocator::ReclaimListNotify() { ++ mutex_lock l(lock_); ++ if (reclaim_waiter_) ++ reclaim_cv_.notify_all(); ++} ++ ++bool BFCAllocator::ReclaimListRemoveInternal(void* ptr, IntrusiveListHook* hook, bool reclaimed) { ++ bool removed = hook->remove(); ++ CHECK(!reclaimed || removed); // reclaimed tensors must be on the list ++ if (removed) { ++ LMSTensorBuffer* buf = hook->elem(); ++ size_t size = buf->size(); ++ stats_.bytes_inactive -= size; ++ stats_.peak_bytes_active = std::max(stats_.peak_bytes_active, stats_.bytes_active()); ++ if (!reclaimed) { ++ // Activate chunk ++ VLOG(2) << "ACTIVE <- " << (void*)buf << " (" << size << ")" ++ <<" [active: " << stats_.bytes_active() << ", inactive: " << stats_.bytes_inactive << "]"; ++ } else { ++ // Free chunk ++ stats_.bytes_reclaimed += size; ++ stats_.cur_bytes_reclaimed += size; ++ stats_.peak_bytes_reclaimed = std::max(stats_.peak_bytes_reclaimed, stats_.cur_bytes_reclaimed); ++ ++ DeallocateRawInternal(ptr); ++ } ++ ++ RecordReclaim(buf, reclaimed); ++ ++ if (reclaim_waiter_) ++ reclaim_cv_.notify_all(); ++ } ++ ++ return removed; ++} ++ ++bool BFCAllocator::PredictReclaim(LMSTensorBuffer* buf) { ++ int64_t id; ++ int64 graph_id; ++ if (!buf->Id(&id) && !buf->GraphId(&graph_id)) ++ return false; ++ bool first_time = !buf->Id(&id) && buf->GraphId(&graph_id); ++ ++ if (first_time) { ++ VLOG(2) << "LMS: " << (void*)buf << " page-out prediction" ++ << " graph_id=" << (void*)graph_id ++ << " size=" << buf->size(); ++ ++ id = static_cast(graph_id); ++ buf->SetId(id); ++ } ++ LMSReclaimHistory& hist = reclaim_history_[id]; ++ if (first_time) ++ hist.reset(); ++ return hist.predict(); ++} ++ ++void BFCAllocator::RecordReclaim(const LMSTensorBuffer* buf, bool reclaimed) { ++ int64_t id; ++ bool has_id = 
buf->Id(&id); ++ if ( has_id ) { ++ LMSReclaimHistory& hist = reclaim_history_[id]; ++ if (VLOG_IS_ON(2)) { ++ bool no_hist = (reclaimed && (hist.predictions_remaining() <= 0)); ++ if (no_hist || (hist.predict() != reclaimed)) { ++ VLOG(2) << "LMS: " << (void*)buf << " page-out prediction " ++ << (reclaimed ? (no_hist ? "" : "miss ") : "wrong") ++ << " id=" << (void*)id ++ << " size=" << buf->size(); ++ } ++ else if (reclaimed) { ++ VLOG(2) << "LMS: " << (void*)buf << " page-out prediction hit" ++ << " id=" << (void*)id ++ << " size=" << buf->size(); ++ } ++ } ++ hist.record(reclaimed); ++ } ++} ++ ++BFCAllocator::ReclaimStatus BFCAllocator::TryReclaim(IntrusiveListHook* hook) { ++ LMSTensorBuffer* buf = hook->elem(); ++ void* ptr = buf->TryPageout(); ++ if (ptr == nullptr) { ++ // Pageout attempt was not successful. Wait on reclaim list notification and retry. ++ return ReclaimStatus::kRetry; ++ } ++ ReclaimListRemoveInternal(ptr, hook, true); ++ return ReclaimStatus::kSuccess; ++} ++ ++BFCAllocator::ReclaimStatus BFCAllocator::ReclaimOne(size_t requested_bytes) { ++ IntrusiveListHook* chosen_hook = nullptr; ++ auto hook = reclaim_list_.head(); ++ auto end = reclaim_list_.terminator(); ++ do { ++ LMSTensorBuffer* buf = hook->elem(); ++ size_t size = RoundedBytes(buf->size()); ++ ++ if (size < requested_bytes) { ++ size_t available = size; ++ ++ // Add in sizes of free neighbors ++ void* ptr = buf->GetDevicePtr(); ++ BFCAllocator::ChunkHandle h = this->region_manager_.get_handle(ptr); ++ DCHECK(h != kInvalidChunkHandle); ++ Chunk* chunk = ChunkFromHandle(h); ++ const std::array neighbors = {chunk->prev, chunk->next}; ++ for (ChunkHandle neighbor_handle : neighbors) { ++ if (neighbor_handle != kInvalidChunkHandle) { ++ Chunk* neighbor = ChunkFromHandle(h); ++ if (!neighbor->in_use()) ++ available += RoundedBytes(neighbor->size); ++ } ++ } ++ if (available < requested_bytes) { ++ hook = hook->next(); ++ continue; ++ } ++ } ++ chosen_hook = hook; ++ ++ } while ((hook != end) && (chosen_hook == nullptr)); ++ ++ if (chosen_hook == nullptr) ++ return ReclaimStatus::kUnavailable; ++ ++ return TryReclaim(chosen_hook); ++} ++ ++BFCAllocator::ReclaimStatus BFCAllocator::ReclaimFragments(size_t rounded_bytes) { ++ // TODO(mtbrandy): Attempt to reclaim smaller tensors that, when ++ // coalesced, will satisfy the request. ++ // Dumb and slow (but effective) placeholder implementation. ++ return ReclaimAll(); ++} ++ ++BFCAllocator::ReclaimStatus BFCAllocator::ReclaimAll() { ++ stats_.num_full_reclaims++; ++ ReclaimStatus status = ReclaimStatus::kUnavailable; ++ while (!reclaim_list_.empty() && ++ (status = TryReclaim(reclaim_list_.head())) == ReclaimStatus::kSuccess); ++ return status; ++} ++ ++void* BFCAllocator::ReclaimChunkPtr(BinNum bin_num, size_t rounded_bytes, ++ size_t num_bytes, uint64 freed_before, ++ mutex_lock& lock) { ++ while (!reclaim_list_.empty()) { ++ void* ptr; ++ ++ // Reclaim a single suitable inactive allocation ++ auto status = ReclaimOne(rounded_bytes); ++ if (status == ReclaimStatus::kSuccess) { ++ stats_.num_single_reclaims++; ++ ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before); ++ if (ptr != nullptr) { ++ return ptr; ++ } ++ VLOG(2) << "ReclaimOne: ineffective (" << rounded_bytes << ")"; ++ continue; ++ } ++ // ReclaimFragments is currently a dummy impl which calls ReclaimAll. ++ // Commenting out the ReclaimFragments call until it gets a unique ++ // implementation to avoid a double call to ReclaimAll. 
++ ++ // if (status == ReclaimStatus::kUnavailable) { ++ // // Reclaim and coalesce fragments of suitable inactive allocations ++ // status = ReclaimFragments(rounded_bytes); ++ // if (status == ReclaimStatus::kSuccess) { ++ // ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before); ++ // if (ptr != nullptr) { ++ // return ptr; ++ // } ++ // VLOG(2) << "ReclaimFragments: ineffective (" << rounded_bytes << ")"; ++ // continue; ++ // } ++ // } ++ ++ if (status == ReclaimStatus::kUnavailable) { ++ // Reclaim everything to give DeallocateFreeRegions the best chance of success. ++ status = ReclaimAll(); ++ if (status == ReclaimStatus::kSuccess) { ++ ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before); ++ if (ptr != nullptr) { ++ return ptr; ++ } ++ continue; ++ } ++ } ++ ++ if (status == ReclaimStatus::kUnavailable) { ++ continue; ++ } ++ ++ CHECK(status == ReclaimStatus::kRetry); ++ VLOG(2) << "ReclaimChunkPtr: wait (" << rounded_bytes << ")"; ++ reclaim_waiter_++; ++ reclaim_cv_.wait(lock); ++ reclaim_waiter_--; ++ VLOG(2) << "ReclaimChunkPtr: notified (" << rounded_bytes << ")"; ++ ++ // Retry FindChunkPtr since the allocation map may have changed. ++ ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before); ++ if (ptr != nullptr) { ++ return ptr; ++ } ++ } // end while reclaim list not empty ++ return nullptr; ++} ++ + void* BFCAllocator::AllocateRawInternal(size_t unused_alignment, + size_t num_bytes, + bool dump_log_on_failure, +@@ -409,6 +652,14 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment, + } + } + ++ // Try to swap out eligible tensor(s) ++ if (lms_enabled_) { ++ ptr = ReclaimChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before, l); ++ if (ptr != nullptr) { ++ return ptr; ++ } ++ } ++ + // Reaching this point means that no chunks can satisfy the request. Also, + // the unallocated bytes cannot satisfy the request. Before giving up, let's + // try deallocating free regions so that suballocator can combine them with +@@ -437,6 +688,18 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment, + << "\nCurrent allocation summary follows."; + DumpMemoryLog(rounded_bytes); + LOG(WARNING) << RenderOccupancy(); ++ ++ if (kGPUHostAllocatorName.compare(Name()) == 0) { ++ // The GPUHostAllocator has exhausted memory. ++ LOG(WARNING) << "The GPU host allocator ran out of memory. The " ++ << "allocator is limited to " ++ << strings::HumanReadableNumBytes(memory_limit_) ++ << " of memory. This can be increased by setting the " ++ << kGPUHostMemLimitEnvVar << " environment variable."; ++ } ++ else if (!lms_enabled_) { ++ LOG(WARNING) << "Enabling Large Model Support may avoid this failure."; ++ } + } + return nullptr; + } +@@ -511,6 +774,7 @@ void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, + stats_.bytes_in_use += chunk->size; + stats_.peak_bytes_in_use = + std::max(stats_.peak_bytes_in_use, stats_.bytes_in_use); ++ stats_.peak_bytes_active = std::max(stats_.peak_bytes_active, stats_.bytes_active()); + stats_.largest_alloc_size = + std::max(stats_.largest_alloc_size, chunk->size); + +@@ -585,16 +849,17 @@ void BFCAllocator::SplitChunk(BFCAllocator::ChunkHandle h, size_t num_bytes) { + void BFCAllocator::DeallocateRaw(void* ptr) { + VLOG(1) << "DeallocateRaw " << Name() << " " + << (ptr ? 
RequestedSize(ptr) : 0); +- DeallocateRawInternal(ptr); +- retry_helper_.NotifyDealloc(); +-} +- +-void BFCAllocator::DeallocateRawInternal(void* ptr) { + if (ptr == nullptr) { + VLOG(2) << "tried to deallocate nullptr"; + return; ++ } else { ++ mutex_lock l(lock_); ++ DeallocateRawInternal(ptr); + } +- mutex_lock l(lock_); ++ retry_helper_.NotifyDealloc(); ++} ++ ++void BFCAllocator::DeallocateRawInternal(void* ptr) { + + // Find the chunk from the ptr. + BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr); +@@ -1136,6 +1401,13 @@ void BFCAllocator::ClearStats() { + stats_.num_allocs = 0; + stats_.peak_bytes_in_use = stats_.bytes_in_use; + stats_.largest_alloc_size = 0; ++ stats_.peak_bytes_active = stats_.bytes_active(); ++ stats_.bytes_reclaimed = 0; ++ stats_.num_single_reclaims = 0; ++ stats_.num_full_reclaims = 0; ++ stats_.cur_bytes_reclaimed = 0; ++ stats_.peak_bytes_reclaimed = 0; ++ + } + + std::array +diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h +index c39652692b7..1aa1c205950 100644 +--- a/tensorflow/core/common_runtime/bfc_allocator.h ++++ b/tensorflow/core/common_runtime/bfc_allocator.h +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -46,7 +47,7 @@ class MemoryDump; + // coalescing. One assumption we make is that the process using this + // allocator owns pretty much all of the memory, and that nearly + // all requests to allocate memory go through this interface. +-class BFCAllocator : public Allocator { ++class BFCAllocator : public LMSAllocator { + public: + // Takes ownership of sub_allocator. + BFCAllocator(SubAllocator* sub_allocator, size_t total_memory, +@@ -85,6 +86,21 @@ class BFCAllocator : public Allocator { + + MemoryDump RecordMemoryMap(); + ++ void SetLMSConfig(bool enabled) { ++ lms_enabled_ = enabled; ++ } ++ LMSAllocator* AsLMSAllocator() final { ++ return (lms_enabled_) ? 
this : nullptr; ++ } ++ bool ReclaimListAdd(void* ptr, IntrusiveListHook* hook) override; ++ bool ReclaimListRemove(void* ptr, IntrusiveListHook* hook) override; ++ void ReclaimListNotify() override; ++ void RemoveReclaimedBytes(size_t size) override; ++ ++ static const int64 kDefaultGPUHostMemLimitInMB = 1LL << 16; /*64GB max by default*/ ++ static const string kGPUHostAllocatorName; ++ static const string kGPUHostMemLimitEnvVar; ++ + private: + struct Bin; + +@@ -545,6 +561,58 @@ class BFCAllocator : public Allocator { + int64 size_history_[MEM_DEBUG_SIZE_HISTORY_SIZE]; + #endif + ++ // Large Model Support ++ class LMSReclaimHistory { ++ public: ++ void record(bool reclaimed) { ++ data_ <<= 1; ++ if (reclaimed) ++ data_ |= 1; ++ n_++; ++ } ++ void reset() { ++ prev_ = data_; ++ prev_n_ = n_; ++ data_ = 0; ++ n_ = 0; ++ } ++ int predictions_remaining() const { return prev_n_ - n_; } ++ bool predict() { ++ int remain = predictions_remaining(); ++ return (remain > 0) && ((prev_ >> (remain - 1)) & 1) && ((prev_ >> remain) == data_); ++ } ++ uint32_t data() const { return data_; } ++ uint32_t prev() const { return prev_; } ++ ++ private: ++ uint32_t data_ = 0; ++ uint32_t prev_; ++ uint16_t n_ = 0; ++ uint16_t prev_n_ = 0; ++ }; ++ ++ bool lms_enabled_ = false; ++ IntrusiveList reclaim_list_ TF_GUARDED_BY(lock_); ++ std::unordered_map reclaim_history_ TF_GUARDED_BY(lock_); ++ condition_variable reclaim_cv_; ++ int reclaim_waiter_ = 0; ++ bool ReclaimListRemoveInternal(void* ptr, IntrusiveListHook* hook, bool reclaimed) ++ TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); ++ ++ enum class ReclaimStatus { ++ kSuccess, ++ kUnavailable, ++ kRetry, ++ }; ++ ReclaimStatus TryReclaim(IntrusiveListHook* hook) TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); ++ ReclaimStatus ReclaimOne(size_t rounded_bytes) TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); ++ ReclaimStatus ReclaimFragments(size_t rounded_bytes) TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); ++ ReclaimStatus ReclaimAll() TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); ++ void* ReclaimChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes, uint64 freed_before, ++ mutex_lock& lock) TF_EXCLUSIVE_LOCKS_REQUIRED(lock_); ++ bool PredictReclaim(LMSTensorBuffer* buf); ++ void RecordReclaim(const LMSTensorBuffer* buf, bool reclaimed); ++ + friend class GPUBFCAllocatorPrivateMethodsTest; + TF_DISALLOW_COPY_AND_ASSIGN(BFCAllocator); + }; +diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc +index d1896fd6710..c2980f243b8 100644 +--- a/tensorflow/core/common_runtime/executor.cc ++++ b/tensorflow/core/common_runtime/executor.cc +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -229,6 +230,11 @@ struct NodeItem { + // 0... for forward from that input. + const int* forward_from() const { return forward_from_base(); } + ++ // Return a unique Id in the graph for the given output. ++ int64 output_graphId(int i) const { ++ return reinterpret_cast(&output_attr_base()[i]); ++ } ++ + string DebugString() const { + string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id); + if (is_source) { +@@ -2125,8 +2131,12 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, + // we are in the tensor buffer. 
+ DataType dtype = val.dtype_safe(); + if (dtype == item.output_type(i)) { +- if (stats && val.tensor->IsInitialized()) { +- nodestats::SetOutput(stats, i, val.tensor); ++ Tensor* t = val.tensor; ++ if (t->IsInitialized()) { ++ if (stats) { ++ nodestats::SetOutput(stats, i, t); ++ } ++ t->SetGraphId(item.output_graphId(i)); + } + if (val.is_ref()) { + out->has_value = true; +diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc +index aeb5d33f3ca..41ef3975d72 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc ++++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -14,7 +15,9 @@ limitations under the License. + ==============================================================================*/ + + #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" ++#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" + ++#include "tensorflow/core/framework/tensor.h" + #include "tensorflow/core/lib/strings/strcat.h" + + namespace tensorflow { +@@ -83,6 +86,105 @@ GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator, + const string& name) + : BFCAllocator(sub_allocator, total_memory, + GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name, +- GPUBFCAllocator::GetGarbageCollectionValue()) {} ++ GPUBFCAllocator::GetGarbageCollectionValue()), ++ stream_exec_(sub_allocator->stream_executor()) { ++ if (gpu_options.experimental().lms_enabled()) { ++ SetLMSConfig(true); ++ H2D_stream_ = new se::Stream(stream_exec_); ++ H2D_stream_->Init(); ++ D2H_stream_ = new se::Stream(stream_exec_); ++ D2H_stream_->Init(); ++ event_mgr_ = EventMgrFactory::Singleton()->GetEventMgr(stream_exec_, gpu_options); ++ } ++} ++ ++void GPUBFCAllocator::SetStreams(se::Stream* compute) { ++ compute_stream_ = compute; ++} ++ ++void* GPUBFCAllocator::Pagein(const LMSTensorBuffer *buf) { ++ size_t nbytes = buf->size(); ++ void *host_ptr = buf->GetHostPtr(); ++ void *device_ptr = AllocateRaw(Allocator::kAllocatorAlignment, nbytes); ++ ++ VLOG(2) << "PAGEIN <- " << (void*)buf << " (" << nbytes << ")"; ++ se::DeviceMemoryBase dst(device_ptr, nbytes); ++ auto result = stream_exec_->SynchronousMemcpyH2D(host_ptr, nbytes, &dst); ++ CHECK(result.ok()); ++ return device_ptr; ++} ++ ++void* GPUBFCAllocator::PageinAsync(const LMSTensorBuffer *buf, ++ const std::function& done) { ++ size_t nbytes = buf->size(); ++ void *host_ptr = buf->GetHostPtr(); ++ void *device_ptr = buf->GetDevicePtr(); ++ ++ if (device_ptr == nullptr) { ++ device_ptr = AllocateRaw(Allocator::kAllocatorAlignment, nbytes); ++ } ++ ++ VLOG(2) << "PAGEIN <- " << (void*)buf << " (" << nbytes << ") ASYNC"; ++ se::DeviceMemoryBase dst(device_ptr, nbytes); ++ ++ // Wait for the compute stream to make sure the device buffer is truly available. 
++ H2D_stream_->ThenWaitFor(compute_stream_); ++ ++ H2D_stream_->ThenMemcpy(&dst, host_ptr, nbytes); ++ event_mgr_->ThenExecute(H2D_stream_, ++ [this, done]() { ++ CHECK(this->H2D_stream_->ok()); ++ done(); ++ }); ++ return device_ptr; ++} ++ ++void* GPUBFCAllocator::Pageout(const LMSTensorBuffer *buf) { ++ size_t nbytes = buf->size(); ++ void *device_ptr = buf->GetDevicePtr(); ++ void *host_ptr = buf->GetHostPtr(); ++ if (host_ptr == nullptr) { ++ host_ptr = host_allocator()->AllocateRaw(Allocator::kAllocatorAlignment, nbytes); ++ } ++ ++ VLOG(2) << "-> PAGEOUT " << (void*)buf << " (" << nbytes << ")"; ++ const se::DeviceMemoryBase src(device_ptr, nbytes); ++ auto result = stream_exec_->SynchronousMemcpyD2H(src, nbytes, host_ptr); ++ CHECK(result.ok()); ++ return host_ptr; ++} ++ ++void* GPUBFCAllocator::PageoutAsync(const LMSTensorBuffer *buf, ++ const std::function& done) { ++ size_t nbytes = buf->size(); ++ void *device_ptr = buf->GetDevicePtr(); ++ void *host_ptr = buf->GetHostPtr(); ++ if (host_ptr == nullptr) { ++ host_ptr = host_allocator()->AllocateRaw(Allocator::kAllocatorAlignment, nbytes); ++ } ++ ++ VLOG(2) << "-> PAGEOUT " << (void*)buf << " (" << nbytes << ") ASYNC"; ++ const se::DeviceMemoryBase src(device_ptr, nbytes); ++ ++ // Wait for the compute stream to make sure the data is available. ++ D2H_stream_->ThenWaitFor(compute_stream_); ++ ++ D2H_stream_->ThenMemcpy(host_ptr, src, nbytes); ++ event_mgr_->ThenExecute(D2H_stream_, ++ [this, done]() { ++ CHECK(this->D2H_stream_->ok()); ++ done(); ++ }); ++ return host_ptr; ++} ++ ++void GPUBFCAllocator::HostMemoryDeallocate(void *host_ptr) { ++ host_allocator()->DeallocateRaw(host_ptr); ++} ++ ++void GPUBFCAllocator::EnsureHostAllocator() { ++ std::call_once(host_allocator_init_, ++ [&] { host_allocator_ = GPUProcessState::singleton()->GetGpuHostAllocator(0); }); ++} + + } // namespace tensorflow +diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h +index 02b1a7418d8..aebb96b4c07 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h ++++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -22,7 +23,9 @@ limitations under the License. 
+ #include + + #include "tensorflow/core/common_runtime/bfc_allocator.h" ++#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" + #include "tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h" ++#include "tensorflow/core/platform/stream_executor.h" + #include "tensorflow/core/platform/thread_annotations.h" + #include "tensorflow/core/platform/types.h" + #include "tensorflow/core/protobuf/config.pb.h" +@@ -39,6 +42,13 @@ class GPUBFCAllocator : public BFCAllocator { + const GPUOptions& gpu_options, const string& name); + ~GPUBFCAllocator() override {} + ++ void SetStreams(se::Stream* compute) override; ++ void* Pagein(const LMSTensorBuffer *buf) override; ++ void* PageinAsync(const LMSTensorBuffer *buf, const std::function& done) override; ++ void* Pageout(const LMSTensorBuffer *buf) override; ++ void* PageoutAsync(const LMSTensorBuffer *buf, const std::function& done) override; ++ void HostMemoryDeallocate(void *host_ptr) override; ++ + TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator); + + #ifdef TENSORFLOW_MEM_DEBUG +@@ -48,6 +58,21 @@ class GPUBFCAllocator : public BFCAllocator { + private: + static bool GetAllowGrowthValue(const GPUOptions& gpu_options); + static bool GetGarbageCollectionValue(); ++ ++ // Large Model Support ++ se::StreamExecutor* stream_exec_; // not owned, non-null ++ se::Stream* H2D_stream_ = nullptr; ++ se::Stream* D2H_stream_ = nullptr; ++ se::Stream* compute_stream_ = nullptr; ++ EventMgr* event_mgr_ = nullptr; ++ Allocator* host_allocator_ = nullptr; ++ ++ void EnsureHostAllocator(); ++ std::once_flag host_allocator_init_; ++ inline Allocator* host_allocator() { ++ if (host_allocator_ == nullptr) EnsureHostAllocator(); ++ return host_allocator_; ++ } + }; + + } // namespace tensorflow +diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc +index a27294fc5ee..47ad4ccbfd6 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc ++++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc +@@ -138,6 +138,10 @@ absl::optional GPUDebugAllocator::GetStats() { + + void GPUDebugAllocator::ClearStats() { base_allocator_->ClearStats(); } + ++LMSAllocator* GPUDebugAllocator::AsLMSAllocator() { ++ return base_allocator_->AsLMSAllocator(); ++} ++ + bool GPUDebugAllocator::CheckHeader(void* ptr) { + return CheckMask(stream_exec_, static_cast(ptr) - MASK_BYTES, + before_mask); +@@ -214,4 +218,8 @@ absl::optional GPUNanResetAllocator::GetStats() { + + void GPUNanResetAllocator::ClearStats() { base_allocator_->ClearStats(); } + ++LMSAllocator* GPUNanResetAllocator::AsLMSAllocator() { ++ return base_allocator_->AsLMSAllocator(); ++} ++ + } // namespace tensorflow +diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h +index 09adc45e6d6..c8ac8309417 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h ++++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h +@@ -45,6 +45,7 @@ class GPUDebugAllocator : public Allocator { + int64 AllocationId(const void* ptr) const override; + absl::optional GetStats() override; + void ClearStats() override; ++ LMSAllocator* AsLMSAllocator() override; + + // For testing. 
+ bool CheckHeader(void* ptr); +@@ -73,6 +74,7 @@ class GPUNanResetAllocator : public Allocator { + size_t AllocatedSize(const void* ptr) const override; + absl::optional GetStats() override; + void ClearStats() override; ++ LMSAllocator* AsLMSAllocator() override; + + private: + Allocator* base_allocator_ = nullptr; // owned +diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc +index d72a99f3ca7..cd49de9c1fe 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc ++++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc +@@ -403,6 +403,9 @@ Status BaseGPUDevice::Init(const SessionOptions& options) { + timestamped_allocator_ ? gpu_allocator_ : nullptr, em_)); + } + ++ LMSAllocator* lms_allocator = gpu_allocator_->AsLMSAllocator(); ++ if (lms_allocator) lms_allocator->SetStreams(stream_->compute); ++ + gpu_device_info_ = new GpuDeviceInfo; + gpu_device_info_->stream = stream_->compute; + gpu_device_info_->default_context = device_context_; +diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc +index 680aec1ab29..44e112fd7ff 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc ++++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc +@@ -96,7 +96,7 @@ static std::atomic_int_fast64_t live_tensor_bytes(0); + class TestTensorBuffer : public TensorBuffer { + public: + explicit TestTensorBuffer(size_t bytes) +- : TensorBuffer(nullptr), bytes_(bytes) { ++ : bytes_(bytes) { + live_tensor_bytes += bytes_; + } + ~TestTensorBuffer() override { live_tensor_bytes -= bytes_; } +@@ -104,6 +104,7 @@ class TestTensorBuffer : public TensorBuffer { + size_t size() const override { return bytes_; } + + // Not used in this test ++ void* data() const override { return nullptr; } + TensorBuffer* root_buffer() override { return nullptr; } + void FillAllocationDescription(AllocationDescription* arg) const override {} + +diff --git a/tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h +index e14f2d9377a..4124f4cf5d1 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h ++++ b/tensorflow/core/common_runtime/gpu/gpu_mem_allocator.h +@@ -65,6 +65,8 @@ class GPUMemAllocator : public SubAllocator { + } + } + ++ se::StreamExecutor* stream_executor() { return stream_exec_; } ++ + private: + se::StreamExecutor* stream_exec_; // not owned, non-null + const PlatformGpuId gpu_id_; +diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +index 3141bb6d10b..72402730f7f 100644 +--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc ++++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +@@ -243,8 +243,8 @@ Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) { + gpu_host_free_visitors_[numa_node]); + // TODO(zheng-xq): evaluate whether 64GB by default is the best choice. 
+ int64 gpu_host_mem_limit_in_mb = -1; +- Status status = ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB", +- 1LL << 16 /*64GB max by default*/, ++ Status status = ReadInt64FromEnvVar(BFCAllocator::kGPUHostMemLimitEnvVar, ++ BFCAllocator::kDefaultGPUHostMemLimitInMB, + &gpu_host_mem_limit_in_mb); + if (!status.ok()) { + LOG(ERROR) << "GetGpuHostAllocator: " << status.error_message(); +@@ -253,7 +253,8 @@ Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) { + + Allocator* allocator = + new BFCAllocator(sub_allocator, gpu_host_mem_limit, +- true /*allow_growth*/, "gpu_host_bfc" /*name*/); ++ true /*allow_growth*/, ++ BFCAllocator::kGPUHostAllocatorName /*name*/); + + if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) { + // Wrap the allocator to track allocation ids for better logging +diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc +index 6757a9b593e..d32da99d557 100644 +--- a/tensorflow/core/framework/allocator.cc ++++ b/tensorflow/core/framework/allocator.cc +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -34,16 +35,30 @@ thread_local uint64 pending_step_id = 0; + + string AllocatorStats::DebugString() const { + return strings::Printf( +- "Limit: %20lld\n" +- "InUse: %20lld\n" +- "MaxInUse: %20lld\n" +- "NumAllocs: %20lld\n" +- "MaxAllocSize: %20lld\n", ++ "Limit: %20lld\n" ++ "InUse: %20lld\n" ++ "MaxInUse: %20lld\n" ++ "NumAllocs: %20lld\n" ++ "MaxAllocSize: %20lld\n" ++ "BytesInactive: %20lld\n" ++ "BytesActive: %20lld\n" ++ "PeakBytesActive: %20lld\n" ++ "TotalBytesReclaimed: %20lld\n" ++ "CurBytesReclaimed: %20lld\n" ++ "NumSingleReclaims: %20lld\n" ++ "NumFullReclaims: %20lld\n", + static_cast(this->bytes_limit ? *this->bytes_limit : 0), + static_cast(this->bytes_in_use), + static_cast(this->peak_bytes_in_use), + static_cast(this->num_allocs), +- static_cast(this->largest_alloc_size)); ++ static_cast(this->largest_alloc_size), ++ static_cast(this->bytes_inactive), ++ static_cast(this->bytes_active()), ++ static_cast(this->peak_bytes_active), ++ static_cast(this->bytes_reclaimed), ++ static_cast(this->cur_bytes_reclaimed), ++ static_cast(this->num_single_reclaims), ++ static_cast(this->num_full_reclaims)); + } + + constexpr size_t Allocator::kAllocatorAlignment; +diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h +index 2e239a4d6de..892d9ea88f6 100644 +--- a/tensorflow/core/framework/allocator.h ++++ b/tensorflow/core/framework/allocator.h +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -30,6 +31,10 @@ limitations under the License. + #include "tensorflow/core/platform/numa.h" + #include "tensorflow/core/platform/types.h" + ++namespace stream_executor { ++class Stream; ++} // namespace stream_executor ++ + namespace tensorflow { + + // Attributes for a single allocation call. Different calls to the same +@@ -106,17 +111,36 @@ struct AllocatorStats { + // if such a limit is known. 
+ absl::optional bytes_reservable_limit; + ++ // Stats for LMS ++ int64 bytes_inactive; // Number of inactive bytes (available for reclaim) ++ int64 bytes_active() const { return bytes_in_use - bytes_inactive; } ++ int64 cur_bytes_reclaimed; // Current number of reclaimed bytes ++ int64 peak_bytes_reclaimed; // The peak number of reclaimed bytes ++ int64 peak_bytes_active; // The peak active bytes ++ int64 bytes_reclaimed; // Cumulative number of bytes transferred (D2H) ++ int64 num_single_reclaims; // Number of single tensor reclaimations performed ++ int64 num_full_reclaims; // Number of calls to reclaim all inactive bytes ++ + AllocatorStats() + : num_allocs(0), + bytes_in_use(0), + peak_bytes_in_use(0), + largest_alloc_size(0), + bytes_reserved(0), +- peak_bytes_reserved(0) {} ++ peak_bytes_reserved(0), ++ bytes_inactive(0), ++ peak_bytes_active(0), ++ bytes_reclaimed(0), ++ num_single_reclaims(0), ++ num_full_reclaims(0), ++ cur_bytes_reclaimed(0), ++ peak_bytes_reclaimed(0) {} + + string DebugString() const; + }; + ++class LMSAllocator; ++ + // Allocator is an abstract interface for allocating and deallocating + // device memory. + class Allocator { +@@ -227,6 +251,26 @@ class Allocator { + virtual void ClearStats() {} + + virtual void SetSafeFrontier(uint64 count) {} ++ ++ virtual LMSAllocator* AsLMSAllocator() { return nullptr; } ++}; ++ ++template ++class IntrusiveListHook; ++class LMSTensorBuffer; ++ ++class LMSAllocator : public Allocator { ++ public: ++ virtual void SetStreams(stream_executor::Stream* compute) {} ++ virtual bool ReclaimListAdd(void* ptr, IntrusiveListHook* hook) { return false; } ++ virtual bool ReclaimListRemove(void* ptr, IntrusiveListHook* hook) { return false; } ++ virtual void ReclaimListNotify() {} ++ virtual void* Pagein(const LMSTensorBuffer* buf) { return nullptr; } ++ virtual void* PageinAsync(const LMSTensorBuffer* buf, const std::function& done) { return nullptr; } ++ virtual void* Pageout(const LMSTensorBuffer* buf) { return nullptr; } ++ virtual void* PageoutAsync(const LMSTensorBuffer* buf, const std::function& done) { return nullptr; } ++ virtual void HostMemoryDeallocate(void* host_ptr) {} ++ virtual void RemoveReclaimedBytes(size_t size) {}; + }; + + // An implementation of Allocator that delegates all calls to another Allocator. +@@ -393,6 +437,70 @@ class SubAllocator { + const std::vector free_visitors_; + }; + ++// IntrusiveList and IntrusiveListHook are used to manage the set of ++// inactive tensors for LMS implementations. ++// ++// Element objects embed the IntrustiveListHook, which provides the ++// following properties: ++// 1. Insertion and removal operations are O(1) and require no ++// memory allocation or deletion. ++// 2. Element destruction is valid and can be performed safely ++// regardless of list membership. 
++template ++class IntrusiveListHook { ++ public: ++ IntrusiveListHook(T *elem) : elem_(elem) { ++ next_ = prev_ = this; ++ } ++ ~IntrusiveListHook() { ++ remove(); ++ } ++ ++ bool attached() const { return next_ != this; } ++ bool detached() const { return next_ == this; } ++ ++ void insertbefore(IntrusiveListHook* x) { ++ CHECK(!x->attached()); ++ x->prev_ = prev_; ++ x->next_ = this; ++ prev_->next_ = x; ++ prev_ = x; ++ } ++ ++ bool remove() { ++ if (!attached()) return false; ++ ++ prev_->next_ = next_; ++ next_->prev_ = prev_; ++ next_ = prev_ = this; ++ return true; ++ } ++ IntrusiveListHook* next() const { return next_; } ++ IntrusiveListHook* prev() const { return prev_; } ++ T* elem() const { return elem_; } ++ ++ private: ++ IntrusiveListHook* next_; ++ IntrusiveListHook* prev_; ++ T* elem_; ++}; ++ ++template ++class IntrusiveList { ++ public: ++ IntrusiveList() : anchor_(nullptr) {} ++ ~IntrusiveList() {} ++ bool empty() const { return anchor_.detached(); } ++ void append(IntrusiveListHook* x) { anchor_.insertbefore(x); } ++ void prepend(IntrusiveListHook* x) { anchor_.next()->insertbefore(x); } ++ IntrusiveListHook* head() const { return anchor_.next(); } ++ IntrusiveListHook* tail() const { return anchor_.prev(); } ++ const IntrusiveListHook* terminator() const { return &anchor_; } ++ ++ private: ++ IntrusiveListHook anchor_; ++}; ++ + } // namespace tensorflow + + #endif // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_H_ +diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc +index 28476695993..a2791f235db 100644 +--- a/tensorflow/core/framework/op_kernel.cc ++++ b/tensorflow/core/framework/op_kernel.cc +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -345,9 +346,28 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) + SetStatus(s); + } + } ++ ++ if (VLOG_IS_ON(2)) { ++ if (op_kernel().AsAsync() != nullptr) ++ LOG(INFO) << "OpKernelContext \"" << op_kernel().name() << "\" (async)"; ++ else ++ LOG(INFO) << "OpKernelContext \"" << op_kernel().name() << "\""; ++ } ++ for (const TensorValue& value : *params_->inputs) { ++ if (value.tensor != nullptr) { ++ pin_tensor(value.tensor); ++ } ++ } + } + + OpKernelContext::~OpKernelContext() { ++ // TODO(mtbrandy): consider skipping unpin for any tensors that are ++ // about to be destroyed to avoid add/remove reclaim_list overhead. ++ for (auto buf : pinned_tensors_) { ++ buf->lms_unpin(); ++ buf->Unref(); ++ } ++ + for (TensorValue& value : outputs_) { + if (!value.is_ref()) { + delete value.tensor; +@@ -781,6 +801,7 @@ Status OpKernelContext::allocate_tensor( + LogMemory::RecordTensorAllocation(params_->op_kernel->name(), + params_->step_id, new_tensor); + } ++ pin_tensor(&new_tensor); + record_tensor_reference(new_tensor); + *out_tensor = std::move(new_tensor); + return Status::OK(); +diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h +index 3bdd0fee6cc..4749924c8d4 100644 +--- a/tensorflow/core/framework/op_kernel.h ++++ b/tensorflow/core/framework/op_kernel.h +@@ -1349,6 +1349,16 @@ class OpKernelContext { + // TODO(ayushd): change to absl::flat_hash_set. 
+ std::unique_ptr> allocated_scope_ids_; + ++ // Large Model Support ++ gtl::InlinedVector pinned_tensors_; ++ void pin_tensor(Tensor* tensor) { ++ TensorBuffer *buf = tensor->buf_; ++ if (buf != nullptr && buf->lms_pin()) { ++ buf->Ref(); ++ pinned_tensors_.push_back(buf); ++ } ++ } ++ + // The following data members are only used when allocation tracking is + // enabled, memory consumption is being recorded, or tensor access is being + // recorded. +diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc +index a7cc9f59b69..f2b91ffafbb 100644 +--- a/tensorflow/core/framework/tensor.cc ++++ b/tensorflow/core/framework/tensor.cc +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -29,6 +30,8 @@ limitations under the License. + + #include "tensorflow/core/framework/tensor.h" + ++#include ++ + #include "absl/strings/escaping.h" + #include "tensorflow/core/framework/allocation_description.pb.h" + #include "tensorflow/core/framework/log_memory.h" +@@ -50,6 +53,7 @@ limitations under the License. + #include "tensorflow/core/lib/strings/strcat.h" + #include "tensorflow/core/platform/logging.h" + #include "tensorflow/core/platform/macros.h" ++#include "tensorflow/core/platform/mutex.h" + #include "tensorflow/core/platform/protobuf.h" + #include "tensorflow/core/platform/tensor_coding.h" + #include "tensorflow/core/platform/types.h" +@@ -78,11 +82,64 @@ bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const { + + namespace { + ++// Large Model Support ++class BufferBase; ++class LMSTensorBufferImpl : public LMSTensorBuffer { ++ public: ++ LMSTensorBufferImpl(BufferBase *buf, LMSAllocator* alloc) : ++ buf_(buf), alloc_(alloc), pincount_(0), list_hook_(this) {} ++ ~LMSTensorBufferImpl(); ++ ++ void ensure_data(); ++ void pin(); ++ void unpin(); ++ void* TryPageout() override; ++ size_t size() const override; ++ void* GetHostPtr() const override; ++ void* GetDevicePtr() const override; ++ ++private: ++ void ensure_data_internal(); ++ void transition_wait(recursive_mutex_lock& l); ++ void transition_complete(); ++ ++ enum class State : uint16_t { ++ kInit, ++ kActive, ++ kInactive, ++ kSynced, ++ kReclaimed, ++ }; ++ enum class Transition : uint16_t { ++ kNone, ++ kPagingOut, ++ kPagingIn, ++ }; ++ BufferBase* const buf_; ++ LMSAllocator* const alloc_; ++ recursive_mutex lock_; ++ void* host_data_ TF_GUARDED_BY(lock_) = nullptr; ++ int pincount_ TF_GUARDED_BY(lock_); ++ State state_ TF_GUARDED_BY(lock_) = State::kInit; ++ Transition transition_ TF_GUARDED_BY(lock_) = Transition::kNone; ++ recursive_condition_variable transition_cv_; ++ int transition_waiter_ = 0; ++ ++ // Guarded by allocator mutex ++ IntrusiveListHook list_hook_; ++}; ++ + // An un-templated base class for Buffer. 
+ class BufferBase : public TensorBuffer { + public: + explicit BufferBase(Allocator* alloc, void* data_ptr) +- : TensorBuffer(data_ptr), alloc_(alloc) {} ++ : alloc_(alloc), ++ data_(data_ptr) { ++ LMSAllocator* lms_alloc = alloc->AsLMSAllocator(); ++ if (lms_alloc) { ++ lms_.reset(new LMSTensorBufferImpl(this, lms_alloc)); ++ } ++ } + + TensorBuffer* root_buffer() override { return this; } + +@@ -114,13 +171,47 @@ class BufferBase : public TensorBuffer { + } + } + ++ void* data() const override { ++ if (lms_enabled()) lms_->ensure_data(); ++ return data_; ++ } ++ ++ bool lms_pin() override { ++ if (!lms_enabled()) return false; ++ lms_->pin(); ++ return true; ++ } ++ ++ void lms_unpin() override { ++ DCHECK(lms_enabled()); ++ lms_->unpin(); ++ } ++ ++ void SetGraphId(int64 id) const override { ++ if (lms_enabled()) lms_->SetGraphId(id); ++ } ++ ++ bool GraphId(int64* id) const override { ++ if (!lms_enabled()) return false; ++ return lms_->GraphId(id); ++ } ++ ++ bool has_data() const override { ++ return lms_enabled() || data_ != nullptr; ++ } ++ + protected: + void RecordDeallocation() { +- LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()), ++ LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data_), + alloc_->Name()); + } + + Allocator* const alloc_; ++ void* data_; ++ std::unique_ptr lms_; ++ ++ friend class LMSTensorBufferImpl; // For access to data_ ++ bool lms_enabled() const { return lms_.get() != nullptr; } + }; + + // Typed ref-counted buffer: T[n]. +@@ -480,11 +571,18 @@ Buffer::Buffer(Allocator* a, int64 n, + + template + Buffer::~Buffer() { +- if (data()) { ++ if (lms_enabled()) { ++ // We don't need/want to perform page-in during destruction (there ++ // is no Dtor on the host for device memory), thus we tear down the ++ // LMS state here. ++ lms_.reset(nullptr); ++ } ++ ++ if (data_) { + if (MemoryLoggingEnabled()) { + RecordDeallocation(); + } +- TypedAllocator::Deallocate(alloc_, static_cast(data()), elem_); ++ TypedAllocator::Deallocate(alloc_, static_cast(data_), elem_); + } + } + +@@ -651,7 +749,7 @@ Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf) + } + + bool Tensor::IsInitialized() const { +- return (buf_ != nullptr && buf_->data() != nullptr) || ++ return (buf_ != nullptr && buf_->has_data()) || + shape_.num_elements() == 0; + } + +@@ -714,6 +812,14 @@ Status Tensor::BitcastFrom(const Tensor& other, DataType dtype, + return Status::OK(); + } + ++void Tensor::SetGraphId(int64 id) const { ++ if (buf_ != nullptr) buf_->SetGraphId(id); ++} ++ ++bool Tensor::GraphId(int64* id) const { ++ return (buf_ != nullptr) && buf_->GraphId(id); ++} ++ + // Notice that buf_ either points to a regular TensorBuffer or a SubBuffer. + // For the latter case, we have to make sure that the refcount is + // one both for the SubBuffer _and_ the underlying TensorBuffer. 
+@@ -775,7 +881,7 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape) + if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) { + CASES(type, buf_ = new Buffer(a, shape.num_elements())); + } +- if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) { ++ if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->has_data()) { + LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID, + *this); + } +@@ -789,8 +895,8 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape, + if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) { + CASES(type, buf_ = new Buffer(a, shape.num_elements(), allocation_attr)); + } +- if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged && +- buf_ != nullptr && buf_->data() != nullptr) { ++ if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged && buf_ != nullptr && ++ buf_->has_data()) { + LogMemory::RecordTensorAllocation("Unknown (with attributes)", + LogMemory::UNKNOWN_STEP_ID, *this); + } +@@ -832,8 +938,8 @@ class SubBuffer : public TensorBuffer { + public: + // This buffer is an alias to buf[delta, delta + n). + SubBuffer(TensorBuffer* buf, int64 delta, int64 n) +- : TensorBuffer(buf->base() + delta), +- root_(buf->root_buffer()), ++ : root_(buf->root_buffer()), ++ delta_(delta), + elem_(n) { + // Sanity check. The caller should ensure the sub buffer is valid. + CHECK_LE(root_->base(), this->base()); +@@ -845,6 +951,7 @@ class SubBuffer : public TensorBuffer { + root_->Ref(); + } + ++ void* data() const override { return root_->base() + delta_; } + size_t size() const override { return sizeof(T) * elem_; } + TensorBuffer* root_buffer() override { return root_; } + bool GetAllocatedBytes(size_t* out_bytes) const override { +@@ -853,9 +960,15 @@ class SubBuffer : public TensorBuffer { + void FillAllocationDescription(AllocationDescription* proto) const override { + root_->FillAllocationDescription(proto); + } ++ bool has_data() const override { return root_->has_data(); } ++ bool lms_pin() override { return root_->lms_pin(); } ++ void lms_unpin() override { root_->lms_unpin(); } ++ void SetGraphId(int64 id) const override { root_->SetGraphId(id); } ++ bool GraphId(int64* id) const override { return root_->GraphId(id); } + + private: + TensorBuffer* root_; ++ int64 delta_; + int64 elem_; + + ~SubBuffer() override { root_->Unref(); } +@@ -941,7 +1054,7 @@ bool Tensor::FromProto(Allocator* a, const TensorProto& proto) { + buf_ = p; + // TODO(misard) add tracking of which kernels and steps are calling + // FromProto. 
+- if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) { ++ if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->has_data()) { + LogMemory::RecordTensorAllocation("Unknown (from Proto)", + LogMemory::UNKNOWN_STEP_ID, *this); + } +@@ -1268,7 +1381,7 @@ string Tensor::DeviceSafeDebugString() const { + void Tensor::FillDescription(TensorDescription* description) const { + description->set_dtype(dtype()); + shape().AsProto(description->mutable_shape()); +- if (buf_ != nullptr && buf_->data() != nullptr) { ++ if (buf_ != nullptr && buf_->has_data()) { + buf_->FillAllocationDescription( + description->mutable_allocation_description()); + } +@@ -1300,4 +1413,158 @@ gtl::InlinedVector Tensor::ComputeFlatOuterDims( + return out_dims; + } + ++LMSTensorBufferImpl::~LMSTensorBufferImpl() { ++ DCHECK(transition_ == Transition::kNone); ++ if (pincount_ == 0 && (state_ == State::kInactive || state_ == State::kSynced)) { ++ alloc_->ReclaimListRemove(buf_->data_, &list_hook_); ++ } ++ if (state_ == State::kReclaimed) { ++ alloc_->RemoveReclaimedBytes(size()); ++ } ++ if (host_data_ != nullptr) { ++ alloc_->HostMemoryDeallocate(host_data_); ++ } ++} ++ ++inline void LMSTensorBufferImpl::ensure_data() { ++ recursive_mutex_lock l(lock_); ++ if (pincount_ == 0 && state_ != State::kInit) { ++ VLOG(2) << " ACCESS " << (void*)this; ++ ensure_data_internal(); ++ state_ = State::kInit; ++ } ++ DCHECK(buf_->data_ != nullptr); ++ if (transition_ == Transition::kPagingIn) { ++ transition_wait(l); ++ } ++} ++ ++inline void LMSTensorBufferImpl::pin() { ++ recursive_mutex_lock l(lock_); ++ if (++pincount_ == 1) { ++ VLOG(2) << " PIN " << (void*)this; ++ if (state_ != State::kInit) { ++ ensure_data_internal(); ++ } ++ state_ = State::kActive; ++ } ++ DCHECK(buf_->data_ != nullptr); ++ DCHECK(state_ == State::kActive); ++ DCHECK(pincount_ > 0); ++} ++ ++inline void LMSTensorBufferImpl::unpin() { ++ recursive_mutex_lock l(lock_); ++ DCHECK(buf_->data_ != nullptr); ++ DCHECK(state_ == State::kActive); ++ DCHECK(pincount_ > 0); ++ if (--pincount_ == 0) { ++ bool pageout = alloc_->ReclaimListAdd(buf_->data_, &list_hook_); ++ if (pageout && transition_ == Transition::kNone) { ++ // Speculative pageout requested by allocator ++ transition_ = Transition::kPagingOut; ++ buf_->Ref(); ++ host_data_ = alloc_->PageoutAsync(this, [this]() { this->transition_complete(); }); ++ DCHECK(host_data_ != nullptr); ++ } ++ state_ = State::kInactive; ++ VLOG(2) << " UNPIN " << (void*)this; ++ } ++} ++ ++void* LMSTensorBufferImpl::TryPageout() { ++ recursive_mutex_lock l(lock_, std::try_to_lock); ++ if (!l || transition_ != Transition::kNone) { ++ // Inability to acquire the lock means this is likely exiting ++ // the inactive state and thus not a good candidate to reclaim. ++ // ++ // Tensors in transition require waiting on the event, which we shouldn't ++ // to do while holding the allocator lock due to the possibility of deadlock. 
++ return nullptr; ++ } ++ ++ DCHECK(buf_->data_ != nullptr); ++ if (state_ == State::kInactive) { ++ host_data_ = alloc_->Pageout(this); ++ } else { ++ CHECK(state_ == State::kSynced); ++ // Nothing to do ++ } ++ DCHECK(host_data_ != nullptr); ++ void* old_device_ptr = buf_->data_; ++ buf_->data_ = nullptr; ++ state_ = State::kReclaimed; ++ ++ return old_device_ptr; ++} ++ ++inline size_t LMSTensorBufferImpl::size() const { ++ return buf_->size(); ++} ++ ++inline void* LMSTensorBufferImpl::GetHostPtr() const { ++ return host_data_; ++} ++ ++inline void* LMSTensorBufferImpl::GetDevicePtr() const { ++ return buf_->data_; ++} ++ ++void LMSTensorBufferImpl::ensure_data_internal() { ++ switch (state_) { ++ case State::kInactive: ++ case State::kSynced: ++ alloc_->ReclaimListRemove(buf_->data_, &list_hook_); ++ break; ++ case State::kReclaimed: ++ DCHECK(buf_->data_ == nullptr); ++ DCHECK(host_data_ != nullptr); ++ transition_ = Transition::kPagingIn; ++ buf_->Ref(); ++ buf_->data_ = alloc_->PageinAsync(this, [this]() { this->transition_complete(); }); ++ DCHECK(buf_->data_ != nullptr); ++ break; ++ case State::kInit: ++ case State::kActive: ++ // Nothing to do ++ break; ++ } ++} ++ ++void LMSTensorBufferImpl::transition_wait(recursive_mutex_lock& l) { ++ DCHECK(transition_ != Transition::kNone); ++ if (VLOG_IS_ON(2)) { ++ LOG(INFO) << "transition_wait: wait " << (void*)this << " (" << size() << ")"; ++ } ++ transition_waiter_++; ++ do { ++ transition_cv_.wait(l); ++ } while (transition_ != Transition::kNone); ++ transition_waiter_--; ++ if (VLOG_IS_ON(2)) { ++ LOG(INFO) << "transition_wait: notified " << (void*)this << " (" << size() << ")"; ++ } ++} ++ ++void LMSTensorBufferImpl::transition_complete() { ++ bool inactive = false; ++ { ++ recursive_mutex_lock l(lock_); ++ DCHECK(transition_ != Transition::kNone); ++ if (state_ == State::kInactive) { ++ state_ = State::kSynced; ++ inactive = true; ++ } ++ if (transition_ == Transition::kPagingIn) { ++ alloc_->RemoveReclaimedBytes(size()); ++ } ++ transition_ = Transition::kNone; ++ if (transition_waiter_) ++ transition_cv_.notify_all(); ++ } ++ bool destroyed = buf_->Unref(); ++ if (inactive && !destroyed) ++ alloc_->ReclaimListNotify(); ++} ++ + } // namespace tensorflow +diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h +index 11910766ba8..89dd281b857 100644 +--- a/tensorflow/core/framework/tensor.h ++++ b/tensorflow/core/framework/tensor.h +@@ -1,4 +1,5 @@ + /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -59,15 +60,11 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index); + /// Interface to access the raw ref-counted data buffer. + class TensorBuffer : public core::RefCounted { + public: +- explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {} ++ explicit TensorBuffer() {} + ~TensorBuffer() override {} + + /// \brief data() points to a memory region of size() bytes. +- /// +- /// NOTE(mrry): The `data()` method is not virtual for performance reasons. +- /// It can be called multiple times when the contents of a `Tensor` are +- /// accessed, and so making it non-virtual allows the body to be inlined. +- void* data() const { return data_; } ++ virtual void* data() const = 0; + + /// \brief Size (in bytes) of the buffer. 
+ virtual size_t size() const = 0; +@@ -90,6 +87,49 @@ class TensorBuffer : public core::RefCounted { + + /// \brief Whether this TensorBuffer owns the underlying memory. + virtual bool OwnsMemory() const { return true; } ++ virtual bool has_data() const { return data() != nullptr; } ++ ++ virtual bool lms_pin() { return false; } ++ virtual void lms_unpin() {} ++ virtual void SetGraphId(int64 id) const {} ++ virtual bool GraphId(int64* id) const { return false; } ++}; ++ ++class LMSTensorBuffer { ++ public: ++ virtual void* TryPageout() = 0; ++ virtual void* GetHostPtr() const = 0; ++ virtual void* GetDevicePtr() const = 0; ++ virtual size_t size() const = 0; ++ void SetId(int64_t id) { ++ id_ = id; ++ } ++ bool Id(int64_t* id) const { ++ if (id_ == 0) return false; ++ *id = id_; ++ return true; ++ } ++ ++ void SetGraphId(int64 id) { ++ graph_id_ = id; ++ } ++ ++ bool GraphId(int64* id) const { ++ if (graph_id_ == 0) return false; ++ *id = graph_id_; ++ return true; ++ } ++ ++ private: ++ int64_t id_ = 0; ++ int64 graph_id_ = 0; ++}; ++ ++class SimpleTensorBufferBase : public TensorBuffer { ++ public: ++ explicit SimpleTensorBufferBase(void* data_ptr) : data_(data_ptr) {} ++ void* data() const override { return data_; } ++ TensorBuffer* root_buffer() override { return this; } + + private: + void* const data_; +@@ -634,6 +674,8 @@ class Tensor { + const TensorShape& shape) { + TF_CHECK_OK(BitcastFrom(other, dtype, shape)); + } ++ void SetGraphId(int64 id) const; ++ bool GraphId(int64* id) const; + + // Returns true if the refcount on buf_ and any possible underlying root + // buffer is one. +@@ -663,6 +705,7 @@ class Tensor { + friend class TensorTestHelper; // For access to set_shape. + friend class CastOpBase; // For access to set_dtype. + friend class ScopedAllocator; // For access to buf_. ++ friend class OpKernelContext; // For access to buf_. + friend Status batch_util::CopyElementToSlice( + Tensor element, Tensor* parent, + int64 index); // For access to base(). +@@ -924,9 +967,9 @@ inline Tensor::Tensor(Tensor&& other) + other.buf_ = nullptr; + } + +-class Tensor::HostScalarTensorBufferBase : public TensorBuffer { ++class Tensor::HostScalarTensorBufferBase : public SimpleTensorBufferBase { + public: +- using TensorBuffer::TensorBuffer; ++ using SimpleTensorBufferBase::SimpleTensorBufferBase; + bool GetAllocatedBytes(size_t* out_bytes) const final; + void FillAllocationDescription(AllocationDescription* proto) const final; + }; +@@ -941,7 +984,6 @@ struct Tensor::ValueAndTensorBuffer { + explicit HostScalarTensorBuffer(void* data) + : HostScalarTensorBufferBase(data) {} + size_t size() const final { return sizeof(T); } +- TensorBuffer* root_buffer() final { return this; } + + // Override `operator delete` so that calling `delete this` in + // `core::Refcounted::Unref()` for an object of this type will free +diff --git a/tensorflow/core/platform/default/mutex.cc b/tensorflow/core/platform/default/mutex.cc +index 9101e8630be..a5ab3332726 100644 +--- a/tensorflow/core/platform/default/mutex.cc ++++ b/tensorflow/core/platform/default/mutex.cc +@@ -14,6 +14,7 @@ limitations under the License. 
+ ==============================================================================*/ + + #include "tensorflow/core/platform/mutex.h" ++#include // NOLINT + + #include + +@@ -85,6 +86,17 @@ static inline nsync::nsync_cv *cv_cast(internal::CVData *cv) { + return reinterpret_cast(cv); + } + ++recursive_mutex::recursive_mutex() { } ++ ++recursive_mutex::recursive_mutex(LinkerInitialized x) {} ++ ++void recursive_mutex::lock() { mu_.lock(); } ++ ++bool recursive_mutex::try_lock() { return mu_.try_lock(); }; ++ ++void recursive_mutex::unlock() { mu_.unlock(); } ++ ++ + condition_variable::condition_variable() { + nsync::nsync_cv_init(cv_cast(&cv_)); + } +@@ -99,6 +111,16 @@ void condition_variable::notify_all() { + nsync::nsync_cv_broadcast(cv_cast(&cv_)); + } + ++recursive_condition_variable::recursive_condition_variable() {} ++ ++void recursive_condition_variable::wait(recursive_mutex_lock &lock) { ++ cv_.wait(lock.mutex()->mu_); ++} ++ ++void recursive_condition_variable::notify_one() { cv_.notify_one(); } ++ ++void recursive_condition_variable::notify_all() { cv_.notify_all(); } ++ + namespace internal { + std::cv_status wait_until_system_clock( + CVData *cv_data, MuData *mu_data, +diff --git a/tensorflow/core/platform/mutex.h b/tensorflow/core/platform/mutex.h +index a668df4f7b3..a329f7aca9d 100644 +--- a/tensorflow/core/platform/mutex.h ++++ b/tensorflow/core/platform/mutex.h +@@ -105,6 +105,24 @@ class TF_LOCKABLE mutex { + internal::MuData mu_; + }; + ++// Wrap std::recusive_mutex so the Clang TF_GUARDED_BY thread safety annotations ++// will work for this mutex. ++class TF_LOCKABLE recursive_mutex { ++ public: ++ recursive_mutex(); ++ // The default implementation of the underlying mutex is safe to use after ++ // the linker initialization to zero. ++ explicit recursive_mutex(LinkerInitialized x); ++ ++ void lock() TF_EXCLUSIVE_LOCK_FUNCTION(); ++ bool try_lock() TF_EXCLUSIVE_TRYLOCK_FUNCTION(true); ++ void unlock() TF_UNLOCK_FUNCTION(); ++ ++ private: ++ friend class recursive_condition_variable; ++ std::recursive_mutex mu_; ++}; ++ + // A Condition represents a predicate on state protected by a mutex. The + // function must have no side-effects on that state. When passed to + // mutex::Await(), the function will be called with the mutex held. It may be +@@ -221,6 +239,44 @@ class TF_SCOPED_LOCKABLE tf_shared_lock { + #define tf_shared_lock(x) \ + static_assert(0, "tf_shared_lock_decl_missing_var_name"); + ++// Mimic a subset of the std::unique_lock functionality. ++class TF_SCOPED_LOCKABLE recursive_mutex_lock { ++ public: ++ typedef ::tensorflow::recursive_mutex mutex_type; ++ ++ explicit recursive_mutex_lock(mutex_type& mu) TF_EXCLUSIVE_LOCK_FUNCTION(mu) : mu_(&mu) { ++ mu_->lock(); ++ } ++ ++ recursive_mutex_lock(mutex_type& mu, std::try_to_lock_t) TF_EXCLUSIVE_LOCK_FUNCTION(mu) ++ : mu_(&mu) { ++ if (!mu.try_lock()) { ++ mu_ = nullptr; ++ } ++ } ++ ++ // Manually nulls out the source to prevent double-free. ++ // (std::move does not null the source pointer by default.) ++ recursive_mutex_lock(recursive_mutex_lock&& ml) noexcept TF_EXCLUSIVE_LOCK_FUNCTION(ml.mu_) ++ : mu_(ml.mu_) { ++ ml.mu_ = nullptr; ++ } ++ ~recursive_mutex_lock() TF_UNLOCK_FUNCTION() { ++ if (mu_ != nullptr) { ++ mu_->unlock(); ++ } ++ } ++ mutex_type* mutex() { return mu_; } ++ ++ explicit operator bool() const { return mu_ != nullptr; } ++ ++ private: ++ mutex_type* mu_; ++}; ++ ++// Catch bug where variable name is omitted, e.g. 
recursive_mutex_lock (mu); ++#define recursive_mutex_lock(x) static_assert(0, "recursive_mutex_lock_decl_missing_var_name"); ++ + // Mimic std::condition_variable. + class condition_variable { + public: +@@ -239,6 +295,22 @@ class condition_variable { + internal::CVData cv_; + }; + ++// Wrap std::condition_variable_any for recursive_mutex_lock ++class recursive_condition_variable { ++ public: ++ recursive_condition_variable(); ++ ++ void wait(recursive_mutex_lock& lock); ++ template ++ std::cv_status wait_for(recursive_mutex_lock& lock, ++ std::chrono::duration dur); ++ void notify_one(); ++ void notify_all(); ++ ++ private: ++ std::condition_variable_any cv_; ++}; ++ + // Like "cv->wait(*mu)", except that it only waits for up to "ms" milliseconds. + // + // Returns kCond_Timeout if the timeout expired without this +diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto +index 93f350f4c30..ab0fb3ae582 100644 +--- a/tensorflow/core/protobuf/config.proto ++++ b/tensorflow/core/protobuf/config.proto +@@ -185,6 +185,10 @@ message GPUOptions { + // launch an additional kernel will stall until an event + // completes. + int32 kernel_tracker_max_pending = 9; ++ ++ // If true, Large Model support is turned on in eager mode. ++ bool lms_enabled = 10; ++ + } + + // Everything inside experimental is subject to change and is not subject +diff --git a/tensorflow/lite/delegates/flex/buffer_map.cc b/tensorflow/lite/delegates/flex/buffer_map.cc +index c2611290c1b..d950e554d81 100644 +--- a/tensorflow/lite/delegates/flex/buffer_map.cc ++++ b/tensorflow/lite/delegates/flex/buffer_map.cc +@@ -26,10 +26,9 @@ namespace tflite { + namespace flex { + namespace { + // A tensor buffer that is allocated, deallocated and populated by TF Lite. 
+-class BaseTfLiteTensorBuffer : public tensorflow::TensorBuffer { +- using tensorflow::TensorBuffer::TensorBuffer; ++class BaseTfLiteTensorBuffer : public tensorflow::SimpleTensorBufferBase { ++ using tensorflow::SimpleTensorBufferBase::SimpleTensorBufferBase; + +- TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription( + tensorflow::AllocationDescription* proto) const override { + tensorflow::int64 rb = size(); +diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD +index 1d4db0a5ee9..f5095f9577d 100644 +--- a/tensorflow/python/BUILD ++++ b/tensorflow/python/BUILD +@@ -101,6 +101,7 @@ py_library( + "//third_party/py/tensorflow_core:__subpackages__", + ], + deps = [ ++ ":_pywrap_bfc_allocator_stats", + ":_pywrap_checkpoint_reader", + ":_pywrap_events_writer", + ":_pywrap_kernel_registry", +@@ -872,6 +873,7 @@ cc_library( + ], + ) + ++ + cc_library( + name = "py_func_lib", + srcs = ["lib/core/py_func.cc"], +@@ -1212,6 +1214,22 @@ py_library( + ], + ) + ++tf_python_pybind_extension( ++ name = "_pywrap_bfc_allocator_stats", ++ srcs = ["framework/bfc_allocator_stats_wrapper.cc"], ++ module_name = "_pywrap_bfc_allocator_stats", ++ deps = [ ++ ":pybind11_absl", ++ ":pybind11_lib", ++ "//tensorflow/core:protos_all_cc", ++ "//tensorflow/core:framework_headers_lib", ++ "//tensorflow/core:gpu_bfc_allocator", ++ "//third_party/python_runtime:headers", ++ "//third_party/eigen3", ++ "@pybind11", ++ ], ++) ++ + py_library( + name = "subscribe", + srcs = ["framework/subscribe.py"], +@@ -1235,6 +1253,7 @@ py_library( + ], + srcs_version = "PY2AND3", + deps = [ ++ ":_pywrap_bfc_allocator_stats", + ":_pywrap_checkpoint_reader", + ":_pywrap_debug_events_writer", + ":_pywrap_events_writer", +@@ -1269,6 +1288,7 @@ py_library( + ":tensor_util", + ":type_spec", + ":util", ++ ":bfc_allocator_stats", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/python/eager:context", +@@ -5746,6 +5766,18 @@ tf_cuda_library( + alwayslink = 1, + ) + ++ ++py_library( ++ name = "bfc_allocator_stats", ++ srcs = ["framework/bfc_allocator_stats.py"], ++ srcs_version = "PY2AND3", ++ deps = [ ++ ":_pywrap_bfc_allocator_stats", ++ ":util", ++ ], ++) ++ ++ + py_library( + name = "pywrap_tensorflow", + srcs = [ +@@ -7883,6 +7915,7 @@ py_library( + deps = [":_pywrap_model_analyzer"], + ) + ++ + tf_py_test( + name = "model_analyzer_test", + size = "small", +diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py +index f2e0e5127dd..576fa0e053e 100644 +--- a/tensorflow/python/__init__.py ++++ b/tensorflow/python/__init__.py +@@ -66,6 +66,7 @@ from tensorflow.python.framework.versions import * + from tensorflow.python.framework import config + from tensorflow.python.framework import errors + from tensorflow.python.framework import graph_util ++from tensorflow.python.framework import bfc_allocator_stats + + # Session + from tensorflow.python.client.client_lib import * +diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py +index 073c33383c3..2ad9accaafd 100644 +--- a/tensorflow/python/eager/context.py ++++ b/tensorflow/python/eager/context.py +@@ -1,4 +1,5 @@ + # Copyright 2017 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2019, 2020. IBM All Rights Reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. 
+@@ -432,6 +433,9 @@ class Context(object): + self._enable_mlir_bridge = None + self._optimizer_experimental_options = {} + ++ # LMS ++ self._lms_enabled = False ++ + _python_eager_context_create_counter.get_cell().increase_by(1) + # pylint: enable=redefined-outer-name + +@@ -982,6 +986,15 @@ class Context(object): + visible_device_list = [] + virtual_devices = [] + gpu_index = -1 ++ ++ # Check Large Model Support configuration ++ lms_enabled = None ++ ++ if self._lms_enabled is not None: ++ lms_enabled = self._lms_enabled ++ else: ++ lms_enabled = False ++ + memory_growths = set() + for dev in self.list_physical_devices("GPU"): + gpu_index += 1 +@@ -1016,7 +1029,8 @@ class Context(object): + allow_growth=allow_growth, + visible_device_list=",".join(visible_device_list), + experimental=config_pb2.GPUOptions.Experimental( +- virtual_devices=virtual_devices)) ++ virtual_devices=virtual_devices, ++ lms_enabled=lms_enabled)) + + @property + def function_call_options(self): +@@ -1366,6 +1380,18 @@ class Context(object): + + self._virtual_device_map[dev] = virtual_devices + ++ ++ @property ++ def lms_enabled(self): ++ return self._lms_enabled ++ ++ @lms_enabled.setter ++ def lms_enabled(self, lms_enabled): ++ self._lms_enabled = lms_enabled ++ ++ def get_lms_enabled(self): ++ return self._lms_enabled ++ + @property + def enable_mlir_bridge(self): + return self._enable_mlir_bridge +diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc +index f8e1fb568ac..cac4044ecbf 100644 +--- a/tensorflow/python/eager/pywrap_tensor.cc ++++ b/tensorflow/python/eager/pywrap_tensor.cc +@@ -1,4 +1,5 @@ + /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. ++Copyright 2019, 2020. IBM All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. +@@ -570,6 +571,23 @@ static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value, + return 0; + } + ++static PyObject* EagerTensor_graph_id(EagerTensor* self, void* unused) { ++ int64_t id; ++ if (self->handle && TFE_TensorHandle_GraphId(self->handle, &id)) { ++ return PyLong_FromLongLong(id); ++ } ++ Py_INCREF(Py_None); ++ return Py_None; ++} ++ ++static int EagerTensor_setgraph_id(EagerTensor* self, PyObject* value, ++ void* unused) { ++ if (self->handle) { ++ TFE_TensorHandle_SetGraphId(self->handle, PyLong_AsLongLong(value)); ++ } ++ return 0; ++} ++ + // Function `_copy_to_device`. + static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args, + PyObject* kwds) { +@@ -658,6 +676,9 @@ static PyGetSetDef EagerTensor_getsetters[] = { + {const_cast("_tensor_shape"), (getter)EagerTensor_tensor_shape, + (setter)EagerTensor_settensor_shape, + const_cast("Shape of the tensor."), nullptr}, ++ {const_cast("graph_id"), (getter)EagerTensor_graph_id, ++ (setter)EagerTensor_setgraph_id, const_cast("graph_id"), ++ nullptr}, + {nullptr} /* Sentinel */ + }; + +diff --git a/tensorflow/python/framework/bfc_allocator_stats.py b/tensorflow/python/framework/bfc_allocator_stats.py +new file mode 100644 +index 00000000000..d52a2675764 +--- /dev/null ++++ b/tensorflow/python/framework/bfc_allocator_stats.py +@@ -0,0 +1,89 @@ ++# Copyright 2019, 2020. IBM All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ============================================================================== ++ ++from tensorflow.python import _pywrap_bfc_allocator_stats as bfc_alloc_stats ++from tensorflow.python.util.tf_export import tf_export ++ ++@tf_export("experimental.get_num_allocs") ++def get_num_allocs( gpu_id ): ++ return bfc_alloc_stats.getNumAllocs( gpu_id ) ++ ++@tf_export("experimental.get_bytes_in_use") ++def get_bytes_in_use( gpu_id ): ++ return bfc_alloc_stats.getBytesInUse( gpu_id ) ++ ++@tf_export("experimental.get_peak_bytes_in_use") ++def get_peak_bytes_in_use( gpu_id ): ++ return bfc_alloc_stats.getPeakBytesInUse( gpu_id ) ++ ++@tf_export("experimental.get_largest_alloc_size") ++def get_largest_alloc_size( gpu_id ): ++ return bfc_alloc_stats.getLargestAllocSize( gpu_id ) ++ ++@tf_export("experimental.get_bytes_limit") ++def get_bytes_limit( gpu_id ): ++ return bfc_alloc_stats.getBytesLimit( gpu_id ) ++ ++@tf_export("experimental.get_bytes_reserved") ++def get_bytes_reserved( gpu_id ): ++ return bfc_alloc_stats.getBytesReserved( gpu_id ) ++ ++@tf_export("experimental.get_peak_bytes_reserved") ++def get_peak_bytes_reserved( gpu_id ): ++ return bfc_alloc_stats.getPeakBytesReserved( gpu_id ) ++ ++@tf_export("experimental.get_bytes_reservable_limit") ++def get_bytes_reservable_limit( gpu_id ): ++ return bfc_alloc_stats.getBytesReservableLimit( gpu_id ) ++ ++@tf_export("experimental.get_bytes_inactive") ++def get_bytes_inactive( gpu_id ): ++ return bfc_alloc_stats.getBytesInactive( gpu_id ) ++ ++@tf_export("experimental.get_bytes_active") ++def get_bytes_active( gpu_id ): ++ return bfc_alloc_stats.getBytesActive( gpu_id ) ++ ++@tf_export("experimental.get_peak_bytes_active") ++def get_peak_bytes_active( gpu_id ): ++ return bfc_alloc_stats.getPeakBytesActive( gpu_id ) ++ ++@tf_export("experimental.get_bytes_reclaimed") ++def get_bytes_reclaimed( gpu_id ): ++ return bfc_alloc_stats.getBytesReclaimed( gpu_id ) ++ ++@tf_export("experimental.get_peak_bytes_reclaimed") ++def get_peak_bytes_reclaimed( gpu_id ): ++ return bfc_alloc_stats.getPeakBytesReclaimed( gpu_id ) ++ ++@tf_export("experimental.get_current_bytes_reclaimed") ++def get_current_bytes_reclaimed( gpu_id ): ++ return bfc_alloc_stats.getCurrentBytesReclaimed( gpu_id ) ++ ++@tf_export("experimental.get_num_single_reclaims") ++def get_num_single_reclaims( gpu_id ): ++ return bfc_alloc_stats.getNumSingleReclaims( gpu_id ) ++ ++@tf_export("experimental.get_num_full_reclaims") ++def get_num_full_reclaims( gpu_id ): ++ return bfc_alloc_stats.getNumFullReclaims( gpu_id ) ++ ++@tf_export("experimental.get_gpu_host_bytes_in_use") ++def get_gpu_host_bytes_in_use( numa_node ): ++ return bfc_alloc_stats.getGPUHostBytesInUse( numa_node ) ++ ++@tf_export("experimental.get_gpu_host_peak_bytes_in_use") ++def get_gpu_host_peak_bytes_in_use( numa_node ): ++ return bfc_alloc_stats.getGPUHostPeakBytesInUse( numa_node ) +diff --git a/tensorflow/python/framework/bfc_allocator_stats_wrapper.cc b/tensorflow/python/framework/bfc_allocator_stats_wrapper.cc +new file mode 100644 +index 00000000000..fac774b41d5 +--- /dev/null ++++ 
b/tensorflow/python/framework/bfc_allocator_stats_wrapper.cc +@@ -0,0 +1,404 @@ ++/* Copyright 2020 IBM All Rights Reserved. ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++==============================================================================*/ ++#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" // for GPUProcessState class ++#include "tensorflow/core/common_runtime/gpu/gpu_id.h" // for TfGpuId type ++#include "tensorflow/core/framework/allocator.h" // for Allocator class ++#include "tensorflow/core/common_runtime/bfc_allocator.h" // for BFCAllocator ++#include "tensorflow/core/platform/logging.h" // for VLOG ++#include "include/pybind11/pybind11.h" ++ ++namespace tensorflow { ++ ++namespace { ++ ++namespace py = pybind11; ++ ++ absl::optional GetBFCAllocatorStats( int gpu_id ) ++ { ++ tensorflow::GPUProcessState * ps = tensorflow::GPUProcessState::singleton(); ++ bool gpu_registered = ps->HasGPUDevice(); ++ ++ if(gpu_registered) ++ { ++ // placeholder variable for input to `GetGPUAllocator` ++ // It will be ignored as we are making sure the gpu device has been created ++ // before we attempt to get the gpu allocator ++ size_t total_bytes = 1; ++ tensorflow::TfGpuId tf_gpu_id(gpu_id); ++ tensorflow::GPUOptions options; ++ std::string bfc = "BFC"; ++ options.set_allocator_type(bfc); ++ tensorflow::Allocator * allocator = ps->GetGPUAllocator(options, ++ tf_gpu_id, ++ total_bytes); ++ std::string name = allocator->Name(); ++ tensorflow::BFCAllocator * bfc_allocator = static_cast(allocator); ++ return bfc_allocator->GetStats(); ++ } ++ else ++ { ++ LOG(ERROR) << "(GetBFCAllocatorStats) No GPU device registered. 
Skipping getting stats\n"; ++ return absl::nullopt; ++ } ++ } ++ ++ ++ int64 getNumAllocs( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->num_allocs; ++ } ++ else ++ { ++ LOG(ERROR) << "(getNumAllocs) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getBytesInUse( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->bytes_in_use; ++ } ++ else ++ { ++ LOG(ERROR) << "(getBytesInUse) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getPeakBytesInUse( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->peak_bytes_in_use; ++ } ++ else ++ { ++ LOG(ERROR) << "(getPeakBytesInUse) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getLargestAllocSize( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->largest_alloc_size; ++ } ++ else ++ { ++ LOG(ERROR) << "(getLargestAllocSize) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getBytesLimit( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ if( allocator_stats->bytes_limit.has_value() ) ++ { ++ result = allocator_stats->bytes_limit.value(); ++ } ++ else ++ { ++ LOG(INFO) << "(getBytesLimit) - Optional value is empty"; ++ } ++ } ++ else ++ { ++ LOG(ERROR) << "(getBytesLimit) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getBytesReserved( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->bytes_reserved; ++ } ++ else ++ { ++ LOG(ERROR) << "(getBytesReserved) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getPeakBytesReserved( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->peak_bytes_reserved; ++ } ++ else ++ { ++ LOG(ERROR) << "(getPeakBytesReserved) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getBytesReservableLimit( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ if( allocator_stats->bytes_reservable_limit.has_value() ) ++ { ++ result = allocator_stats->bytes_reservable_limit.value(); ++ } ++ else ++ { ++ LOG(INFO) << "(getBytesReservableLimit) - Optional value is empty"; ++ } ++ } ++ else ++ { ++ LOG(ERROR) << "(getBytesReservableLimit) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getBytesInactive( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->bytes_inactive; ++ } ++ else ++ { ++ LOG(ERROR) << 
"(getBytesInactive) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getBytesActive( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->bytes_active(); ++ } ++ else ++ { ++ LOG(ERROR) << "(getBytesActive) - Could not retrieve BFC Allocator Stats"; ++ } ++ ++ return result; ++ } ++ ++ int64 getPeakBytesActive( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->peak_bytes_active; ++ } ++ else ++ { ++ LOG(ERROR) << "(getPeakBytesActive) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getBytesReclaimed( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->bytes_reclaimed; ++ } ++ else ++ { ++ LOG(ERROR) << "(getBytesReclaimed) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getPeakBytesReclaimed( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->peak_bytes_reclaimed; ++ } ++ else ++ { ++ LOG(ERROR) << "(getPeakBytesReclaimed) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getCurrentBytesReclaimed( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->cur_bytes_reclaimed; ++ } ++ else ++ { ++ LOG(ERROR) << "(getCurrentBytesReclaimed) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getNumSingleReclaims( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->num_single_reclaims; ++ } ++ else ++ { ++ LOG(ERROR) << "(getNumSingleReclaims) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getNumFullReclaims( int gpu_id ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetBFCAllocatorStats( gpu_id ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->num_full_reclaims; ++ } ++ else ++ { ++ LOG(ERROR) << "(getNumFullReclaims) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ // GPU host allocator ++ absl::optional GetGPUHostAllocatorStats ( int numa_node ) ++ { ++ tensorflow::GPUProcessState * ps = tensorflow::GPUProcessState::singleton(); ++ bool gpu_registered = ps->HasGPUDevice(); ++ ++ if(gpu_registered) ++ { ++ tensorflow::Allocator * allocator = ps->GetGpuHostAllocator(numa_node); ++ tensorflow::BFCAllocator * bfc_allocator = static_cast(allocator); ++ return bfc_allocator->GetStats(); ++ } ++ else ++ { ++ LOG(ERROR) << "(GetGPUHostAllocatorStats) No GPU device registered. 
Skipping getting stats\n"; ++ return absl::nullopt; ++ } ++ } ++ ++ int64 getGPUHostBytesInUse( int numa_node ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetGPUHostAllocatorStats( numa_node ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->bytes_in_use; ++ } ++ else ++ { ++ LOG(ERROR) << "(getGPUHostBytesInUse) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ int64 getGPUHostPeakBytesInUse( int numa_node ) ++ { ++ int64 result = -1; ++ absl::optional allocator_stats = GetGPUHostAllocatorStats( numa_node ); ++ ++ if( allocator_stats != absl::nullopt ) ++ { ++ result = allocator_stats->peak_bytes_in_use; ++ } ++ else ++ { ++ LOG(ERROR) << "(getGPUHostPeakBytesInUse) - Could not retrieve BFC Allocator Stats"; ++ } ++ return result; ++ } ++ ++ ++ ++}// namespace ++ ++PYBIND11_MODULE(_pywrap_bfc_allocator_stats, m) { ++ m.def("getNumAllocs", &getNumAllocs); ++ m.def("getBytesInUse", &getBytesInUse); ++ m.def("getPeakBytesInUse", &getPeakBytesInUse); ++ m.def("getLargestAllocSize", &getLargestAllocSize); ++ m.def("getBytesLimit", &getBytesLimit); ++ m.def("getBytesReserved", &getBytesReserved); ++ m.def("getPeakBytesReserved", &getPeakBytesReserved); ++ m.def("getBytesReservableLimit", &getBytesReservableLimit); ++ m.def("getBytesInactive", &getBytesInactive); ++ m.def("getBytesActive", &getBytesActive); ++ m.def("getPeakBytesActive", &getPeakBytesActive); ++ m.def("getBytesReclaimed", &getBytesReclaimed); ++ m.def("getCurrentBytesReclaimed", &getCurrentBytesReclaimed); ++ m.def("getPeakBytesReclaimed", &getPeakBytesReclaimed); ++ m.def("getNumSingleReclaims", &getNumSingleReclaims); ++ m.def("getNumFullReclaims", &getNumFullReclaims); ++ m.def("getGPUHostBytesInUse", &getGPUHostBytesInUse); ++ m.def("getGPUHostPeakBytesInUse", &getGPUHostPeakBytesInUse); ++} ++} // namespace tensorflow +diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py +index c696675fed8..fa7568d358f 100644 +--- a/tensorflow/python/framework/config.py ++++ b/tensorflow/python/framework/config.py +@@ -1,4 +1,5 @@ + # Copyright 2019 The TensorFlow Authors. All Rights Reserved. ++# Copyright 2019, 2020. IBM All Rights Reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. +@@ -500,6 +501,20 @@ def set_memory_growth(device, enable): + context.context().set_memory_growth(device, enable) + + ++@tf_export('config.experimental.get_lms_enabled') ++def get_lms_enabled(): ++ """Get value denoting whether LMS has been enabled ++ """ ++ return context.context().get_lms_enabled() ++ ++ ++@tf_export('config.experimental.set_lms_enabled') ++def set_lms_enabled(lms_enabled): ++ """Set value denoting whether LMS has been enabled ++ """ ++ context.context().lms_enabled = lms_enabled ++ ++ + @tf_export('config.get_logical_device_configuration', + 'config.experimental.get_virtual_device_configuration') + @deprecation.deprecated_endpoints( +diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py +index 6d628dbfbaf..be288a5529a 100644 +--- a/tensorflow/python/keras/engine/network.py ++++ b/tensorflow/python/keras/engine/network.py +@@ -833,10 +833,16 @@ class Network(base_layer.Layer): + + # Dictionary mapping reference tensors to computed tensors. 
+ tensor_dict = {} ++ ++ def _add_tensor_to_dict(id, t): ++ for x in t: ++ x.graph_id = id ++ tensor_dict[str(id)] = t ++ + for x, y in zip(self.inputs, inputs): + y = self._conform_to_reference_input(y, ref_input=x) + x_id = str(id(x)) +- tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] ++ _add_tensor_to_dict(id(x), [y] * self._tensor_usage_count[x_id]) + + depth_keys = list(self._nodes_by_depth.keys()) + depth_keys.sort(reverse=True) +@@ -891,7 +897,7 @@ class Network(base_layer.Layer): + for x, y in zip( + nest.flatten(node.output_tensors), nest.flatten(output_tensors)): + x_id = str(id(x)) +- tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] ++ _add_tensor_to_dict(id(x), [y] * self._tensor_usage_count[x_id]) + + output_tensors = [] + output_shapes = [] +diff --git a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt +index b8f92b30099..f390ca0b568 100644 +--- a/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt ++++ b/tensorflow/tools/api/golden/v1/tensorflow.config.experimental.pbtxt +@@ -20,6 +20,10 @@ tf_module { + name: "get_device_policy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } ++ member_method { ++ name: "get_lms_enabled" ++ argspec: "args=[], varargs=None, keywords=None, defaults=None" ++ } + member_method { + name: "get_memory_growth" + argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None" +@@ -48,6 +52,10 @@ tf_module { + name: "set_device_policy" + argspec: "args=[\'device_policy\'], varargs=None, keywords=None, defaults=None" + } ++ member_method { ++ name: "set_lms_enabled" ++ argspec: "args=[\'lms_enabled\'], varargs=None, keywords=None, defaults=None" ++ } + member_method { + name: "set_memory_growth" + argspec: "args=[\'device\', \'enable\'], varargs=None, keywords=None, defaults=None" +diff --git a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt +index ccd4919f59f..7df4e1b6a4e 100644 +--- a/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt ++++ b/tensorflow/tools/api/golden/v1/tensorflow.experimental.pbtxt +@@ -16,4 +16,77 @@ tf_module { + name: "output_all_intermediates" + argspec: "args=[\'state\'], varargs=None, keywords=None, defaults=None" + } ++ member_method { ++ name: "get_num_allocs" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_in_use" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_in_use" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_largest_alloc_size" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_limit" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_reserved" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_reserved" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_reservable_limit"" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_inactive" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ 
member_method { ++ name: "get_bytes_active" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_active" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_reclaimed" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_reclaimed" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_current_bytes_reclaimed" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_num_single_reclaims" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_num_full_reclaims" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_gpu_host_bytes_in_use" ++ argspec: "args=[\'numa_node\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_gpu_host_peak_bytes_in_use" ++ argspec: "args=[\'numa_node\'], varargs=None, keywords=None, defaults=None" ++ } ++ + } +diff --git a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt +index b8f92b30099..f390ca0b568 100644 +--- a/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt ++++ b/tensorflow/tools/api/golden/v2/tensorflow.config.experimental.pbtxt +@@ -20,6 +20,10 @@ tf_module { + name: "get_device_policy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } ++ member_method { ++ name: "get_lms_enabled" ++ argspec: "args=[], varargs=None, keywords=None, defaults=None" ++ } + member_method { + name: "get_memory_growth" + argspec: "args=[\'device\'], varargs=None, keywords=None, defaults=None" +@@ -48,6 +52,10 @@ tf_module { + name: "set_device_policy" + argspec: "args=[\'device_policy\'], varargs=None, keywords=None, defaults=None" + } ++ member_method { ++ name: "set_lms_enabled" ++ argspec: "args=[\'lms_enabled\'], varargs=None, keywords=None, defaults=None" ++ } + member_method { + name: "set_memory_growth" + argspec: "args=[\'device\', \'enable\'], varargs=None, keywords=None, defaults=None" +diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt +index 2e2579e698d..afe12eab516 100644 +--- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt ++++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.pbtxt +@@ -20,4 +20,77 @@ tf_module { + name: "function_executor_type" + argspec: "args=[\'executor_type\'], varargs=None, keywords=None, defaults=None" + } ++ member_method { ++ name: "get_num_allocs" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_in_use" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_in_use" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_largest_alloc_size" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_limit" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_reserved" ++ argspec: "args=[\'gpu_id\'], varargs=None, 
keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_reserved" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_reservable_limit" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_inactive" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_active" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_active" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_bytes_reclaimed" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_peak_bytes_reclaimed" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_current_bytes_reclaimed" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_num_single_reclaims" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_num_full_reclaims" ++ argspec: "args=[\'gpu_id\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_gpu_host_bytes_in_use" ++ argspec: "args=[\'numa_node\'], varargs=None, keywords=None, defaults=None" ++ } ++ member_method { ++ name: "get_gpu_host_peak_bytes_in_use" ++ argspec: "args=[\'numa_node\'], varargs=None, keywords=None, defaults=None" ++ } ++ + } +-- +2.15.1 +
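
The sketch below (not part of the patch) illustrates how the Python APIs introduced above fit together in eager mode: `tf.config.experimental.set_lms_enabled` and the `tf.experimental.*` allocator-statistics functions come from this patch, while the matmul workload, `gpu_id=0`, and `numa_node=0` are illustrative assumptions. It assumes a TensorFlow 2.x build with this patch applied and at least one visible GPU; the statistics wrappers return -1 when no GPU allocator has been created yet.

```python
# Minimal usage sketch for the LMS APIs added by this patch (illustrative only).
import tensorflow as tf

# Enable Large Model Support before any GPU work so the setting is applied
# when the per-GPU allocator options are built (see Context._compute_gpu_options).
tf.config.experimental.set_lms_enabled(True)
assert tf.config.experimental.get_lms_enabled()

# Run some GPU work so the BFC allocator for GPU 0 exists; otherwise the
# statistics calls log an error and return -1.
with tf.device("/GPU:0"):
    x = tf.random.uniform((4096, 4096))
    y = tf.matmul(x, x)

gpu_id = 0     # zero-indexed GPU ID (assumed)
numa_node = 0  # NUMA node of the GPU host (CPU) allocator (assumed)

print("peak bytes active:      ", tf.experimental.get_peak_bytes_active(gpu_id))
print("current bytes reclaimed:", tf.experimental.get_current_bytes_reclaimed(gpu_id))
print("peak bytes reclaimed:   ", tf.experimental.get_peak_bytes_reclaimed(gpu_id))
print("full reclaims:          ", tf.experimental.get_num_full_reclaims(gpu_id))
print("GPU host bytes in use:  ", tf.experimental.get_gpu_host_bytes_in_use(numa_node))
```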