From c2afd9ff3507e098b712833cbed728c48182f9dd Mon Sep 17 00:00:00 2001 From: shaojiewang Date: Thu, 30 Mar 2023 20:33:50 +0800 Subject: [PATCH] add nan inf check --- .../framework/details/nan_inf_utils_detail.cc | 72 ++-- .../framework/details/nan_inf_utils_detail.cu | 328 ++++++++++++++++-- .../framework/details/nan_inf_utils_detail.h | 324 ++++++++++++++++- paddle/fluid/platform/flags.cc | 22 ++ .../fluid/contrib/mixed_precision/__init__.py | 3 + .../mixed_precision/program_checker.py | 99 ++++++ 6 files changed, 787 insertions(+), 61 deletions(-) create mode 100644 python/paddle/fluid/contrib/mixed_precision/program_checker.py diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc index bce7b64e6d735..5e1eea56d0a6b 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -24,9 +24,28 @@ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h" +DECLARE_int32(check_nan_inf_level); + namespace paddle { namespace framework { namespace details { +struct DebugTools { + DebugTools() {} + std::string path = ""; +}; +static DebugTools debug_nan_inf; + +void SetNanInfDebugPath(const std::string& nan_inf_path) { + debug_nan_inf.path = nan_inf_path; + VLOG(4) << "Set the log's path of debug tools : " << nan_inf_path; +} + +std::string GetNanPath() { + if (debug_nan_inf.path.empty()) { + return ""; + } + return debug_nan_inf.path + "/"; +} static std::once_flag white_list_init_flag; @@ -90,7 +109,7 @@ static void InitWhiteListFormEnv() { const char* op_role_skip = std::getenv("PADDLE_INF_NAN_SKIP_ROLE"); const char* op_var_skip = std::getenv("PADDLE_INF_NAN_SKIP_VAR"); - if (op_type_skip != NULL) { + if (op_type_skip) { std::stringstream ss(op_type_skip); std::string op_type; while (std::getline(ss, op_type, ',')) { @@ -98,7 +117,7 @@ static void InitWhiteListFormEnv() { } } - if (op_role_skip != NULL) { + if (op_role_skip) { std::stringstream ss(op_role_skip); std::string op_role; while (std::getline(ss, op_role, ',')) { @@ -113,7 +132,7 @@ static void InitWhiteListFormEnv() { } } - if (op_var_skip != NULL) { + if (op_var_skip) { std::stringstream ss(op_var_skip); std::string op_var; while (std::getline(ss, op_var, ',')) { @@ -326,13 +345,13 @@ void TensorCheckerVisitor::apply( // use env strategy control in future, -1=print_all. int print_num = 3; CheckNanInf( - tensor_.data(), tensor_.numel(), print_num, op_type_, var_name_); + tensor.data(), tensor.numel(), print_num, op_type, var_name); } template <> void tensor_check(const std::string& op_type, const std::string& var_name, - const framework::Tensor& tensor, + const phi::DenseTensor& tensor, const platform::Place& place) { TensorCheckerVisitor vistor( op_type, var_name, tensor, place); @@ -348,9 +367,9 @@ void CheckVarHasNanOrInf(const std::string& op_type, platform::errors::NotFound( "Cannot find var: `%s` in op `%s`.", var_name, op_type)); - const Tensor* tensor{nullptr}; - if (var->IsType()) { - tensor = &var->Get(); + const phi::DenseTensor* tensor{nullptr}; + if (var->IsType()) { + tensor = &var->Get(); } else if (var->IsType()) { tensor = &var->Get().value(); } else { @@ -371,7 +390,8 @@ void CheckVarHasNanOrInf(const std::string& op_type, tensor_check(op_type, var_name, *tensor, place); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "Tensor[%s] use gpu place. PaddlePaddle must compile with GPU.", + "phi::DenseTensor[%s] use gpu place. 
PaddlePaddle must compile " + "with GPU.", var_name)); #endif return; @@ -400,10 +420,13 @@ void CheckVarHasNanOrInf(const std::string& op_type, flag, true, platform::errors::Fatal( - "Operator %s output Tensor %s contains Inf.", op_type, var_name)); + "Operator %s output phi::DenseTensor %s contains Inf.", + op_type, + var_name)); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "Tensor[%s] use xpu place. PaddlePaddle must compile with XPU.", + "phi::DenseTensor[%s] use xpu place. PaddlePaddle must compile " + "with XPU.", var_name)); #endif return; @@ -414,7 +437,7 @@ void CheckVarHasNanOrInf(const std::string& op_type, return; } - framework::LoDTensor cpu_tensor; + phi::DenseTensor cpu_tensor; cpu_tensor.Resize(tensor->dims()); float* cpu_data = static_cast( cpu_tensor.mutable_data(platform::CPUPlace(), tensor->dtype())); @@ -431,10 +454,13 @@ void CheckVarHasNanOrInf(const std::string& op_type, flag, true, platform::errors::Fatal( - "Operator %s output Tensor %s contains Inf.", op_type, var_name)); + "Operator %s output phi::DenseTensor %s contains Inf.", + op_type, + var_name)); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "Tensor[%s] use npu place. PaddlePaddle must compile with NPU.", + "phi::DenseTensor[%s] use npu place. PaddlePaddle must compile " + "with NPU.", var_name)); #endif return; @@ -473,8 +499,8 @@ using NpuOpRunner = paddle::operators::NpuOpRunner; constexpr int FLOAT_STATUS_SIZE = 8; -static framework::Tensor& npu_float_status() { - static framework::Tensor float_status; +static phi::DenseTensor& npu_float_status() { + static phi::DenseTensor float_status; return float_status; } @@ -494,7 +520,7 @@ void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op, flag.mutable_data({FLOAT_STATUS_SIZE}, place); NpuOpRunner("NPUAllocFloatStatus", {}, {flag}).Run(stream); - framework::Tensor tmp; + phi::DenseTensor tmp; tmp.mutable_data({FLOAT_STATUS_SIZE}, place); NpuOpRunner("NPUClearFloatStatus", {tmp}, {flag}).Run(stream); } @@ -503,9 +529,9 @@ void PrintNpuVarInfo(const std::string& op_type, const std::string& var_name, const framework::Variable* var, const platform::Place& place) { - const Tensor* tensor{nullptr}; - if (var->IsType()) { - tensor = &var->Get(); + const phi::DenseTensor* tensor{nullptr}; + if (var->IsType()) { + tensor = &var->Get(); } else if (var->IsType()) { tensor = &var->Get().value(); } else { @@ -528,7 +554,7 @@ void PrintNpuVarInfo(const std::string& op_type, VLOG(10) << "begin check " << op_type << " var_name:" << var_name << ", place:" << tensor->place() << ", numel:" << tensor->numel(); - framework::Tensor cpu_tensor; + phi::DenseTensor cpu_tensor; cpu_tensor.Resize(tensor->dims()); cpu_tensor.mutable_data(platform::CPUPlace(), tensor->dtype()); framework::TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor); @@ -575,13 +601,13 @@ static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op, auto stream = dev_ctx->stream(); auto& flag = npu_float_status(); - Tensor tmp; + phi::DenseTensor tmp; tmp.mutable_data({FLOAT_STATUS_SIZE}, place); // NPUGetFloatStatus updates data on input in-place. // tmp is only placeholder. 
NpuOpRunner("NPUGetFloatStatus", {flag}, {tmp}).Run(stream); - framework::Tensor cpu_tensor; + phi::DenseTensor cpu_tensor; auto cpu_place = platform::CPUPlace(); float* cpu_data = static_cast( cpu_tensor.mutable_data({FLOAT_STATUS_SIZE}, cpu_place)); diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cu b/paddle/fluid/framework/details/nan_inf_utils_detail.cu index 4aa24f8cb6ab8..514c58b8d7584 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cu +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cu @@ -22,6 +22,11 @@ #include "paddle/fluid/framework/details/nan_inf_utils_detail.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +DECLARE_int32(check_nan_inf_level); + namespace paddle { namespace framework { namespace details { @@ -133,27 +138,229 @@ __global__ void CheckNanInfKernel(const T* value, PrintNanInfKernel(value, numel, print_num, debug_info); } -template <> -template -void TensorCheckerVisitor::apply( - typename std::enable_if< - std::is_floating_point::value || - std::is_same>::value || - std::is_same>::value>::type*) - const { - int print_num = 3; +template +__device__ T BlockReduce(T value) { + __shared__ T shared_mem[1024]; - auto* dev_ctx = reinterpret_cast( - platform::DeviceContextPool::Instance().Get(tensor_.place())); - int dev_id = tensor_.place().device; + shared_mem[threadIdx.x] = value; + __syncthreads(); + + for (int stride = blockDim.x >> 1; stride > 0; stride = stride >> 1) { + if (threadIdx.x < stride) { + T value0 = shared_mem[threadIdx.x]; + T value1 = shared_mem[threadIdx.x + stride]; + T reduce_value; + if (ReduceType == 0) { + // max + reduce_value = value0 > value1 ? value0 : value1; + } else if (ReduceType == 1) { + // min + reduce_value = value0 < value1 ? 
value0 : value1; + } else if (ReduceType == 2) { + // sum + reduce_value = value0 + value1; + } + shared_mem[threadIdx.x] = reduce_value; + } + + if (stride > 16) { + __syncthreads(); + } + } + + __syncthreads(); + return shared_mem[0]; +} + +__device__ void BlockReduceNumNanInfAndWrite(const int64_t num_nan, + const int64_t num_inf, + const int64_t num_zero, + int64_t offset, + int64_t* num_nan_ptr, + int64_t* num_inf_ptr, + int64_t* num_zero_ptr) { + int64_t block_num_nan = BlockReduce(num_nan); + int64_t block_num_inf = BlockReduce(num_inf); + int64_t block_num_zero = BlockReduce(num_zero); + + if (threadIdx.x == 0) { + num_nan_ptr[offset] = block_num_nan; + num_inf_ptr[offset] = block_num_inf; + num_zero_ptr[offset] = block_num_zero; + } +} + +template < + typename T, + std::enable_if_t>::value || + std::is_same>::value, + bool> = true> +__device__ void BlockReduceMaxMinAndWrite(const T max_value, + const T min_value, + const T mean_value, + int64_t offset, + T* max_ptr, + T* min_ptr, + T* mean_ptr) { + // TODO(Xreki): support complex +} + +template < + typename T, + std::enable_if_t>::value && + !std::is_same>::value, + bool> = true> +__device__ void BlockReduceMaxMinAndWrite(const T max_value, + const T min_value, + const T mean_value, + int64_t offset, + T* max_ptr, + T* min_ptr, + T* mean_ptr) { + if (max_ptr && min_ptr && mean_ptr) { + __syncthreads(); + + T block_max_value = phi::funcs::blockReduceMax(max_value, FINAL_MASK); + T block_min_value = phi::funcs::blockReduceMin(min_value, FINAL_MASK); + T block_mean_value = phi::funcs::blockReduceSum(mean_value, FINAL_MASK); + + if (threadIdx.x == 0) { + max_ptr[offset] = block_max_value; + min_ptr[offset] = block_min_value; + mean_ptr[offset] = block_mean_value; + } + } +} + +template +__global__ void FindNanInfAndBlockMaxMin(const T* value_ptr, + const int64_t numel, + int64_t* block_num_nan_ptr, + int64_t* block_num_inf_ptr, + int64_t* block_num_zero_ptr, + MT* tensor_block_max_ptr, + MT* tensor_block_min_ptr, + MT* tensor_block_mean_ptr) { + int64_t i = threadIdx.x + blockIdx.x * blockDim.x; + + int64_t num_nan = 0; + int64_t num_inf = 0; + int64_t num_zero = 0; + + MT max_value = static_cast(i < numel ? value_ptr[i] : value_ptr[0]); + MT min_value = static_cast(i < numel ? value_ptr[i] : value_ptr[0]); + MT mean_value = static_cast(0); + for (; i < numel; i += blockDim.x * gridDim.x) { + MT value = static_cast(value_ptr[i]); + + max_value = value > max_value ? value : max_value; + min_value = value < min_value ? 
value : min_value; + mean_value += value / static_cast(numel); + + if (isnan(value)) { + num_nan += 1; + } else if (isinf(value)) { + num_inf += 1; + } + if (value == static_cast(0)) { + num_zero += 1; + } + } + + BlockReduceNumNanInfAndWrite(num_nan, + num_inf, + num_zero, + blockIdx.x, + block_num_nan_ptr, + block_num_inf_ptr, + block_num_zero_ptr); + + BlockReduceMaxMinAndWrite(max_value, + min_value, + mean_value, + blockIdx.x, + tensor_block_max_ptr, + tensor_block_min_ptr, + tensor_block_mean_ptr); +} + +template +__global__ void FindGlobalMaxMinAndPrint(const int64_t* block_num_nan_ptr, + const int64_t* block_num_inf_ptr, + const int64_t* block_num_zero_ptr, + const MT* tensor_block_max_ptr, + const MT* tensor_block_min_ptr, + const MT* tensor_block_mean_ptr, + const char* debug_info, + int64_t numel, + int64_t numel_max_min, + int check_nan_inf_level) { + if (blockIdx.x == 0 && threadIdx.x == 0) { + int64_t num_nan = 0; + int64_t num_inf = 0; + int64_t num_zero = 0; + + // numel_max_min <= 128 + for (int64_t i = 0; i < numel_max_min; ++i) { + num_nan += block_num_nan_ptr[i]; + num_inf += block_num_inf_ptr[i]; + num_zero += block_num_zero_ptr[i]; + } + + MT max_value = static_cast(0); + MT min_value = static_cast(0); + MT mean_value = static_cast(0); + if (tensor_block_max_ptr && tensor_block_min_ptr && tensor_block_mean_ptr) { + max_value = tensor_block_max_ptr[0]; + min_value = tensor_block_min_ptr[0]; + mean_value = tensor_block_mean_ptr[0]; + + // numel_max_min <= 128 + for (int64_t i = 1; i < numel_max_min; ++i) { + MT tmp_max_value = tensor_block_max_ptr[i]; + MT tmp_min_value = tensor_block_min_ptr[i]; + MT tmp_mean_value = tensor_block_mean_ptr[i]; + + max_value = tmp_max_value > max_value ? tmp_max_value : max_value; + min_value = tmp_min_value < min_value ? 
tmp_min_value : min_value; + mean_value += tmp_mean_value; + } + } + + PrintForDifferentLevel(debug_info, + numel, + num_nan, + num_inf, + num_zero, + max_value, + min_value, + mean_value, + check_nan_inf_level); + } +} + +template +inline std::string GetHintString(const std::string& op_type, + const std::string& var_name, + const phi::Place& place, + int dev_id = -1) { + std::string op_var = GetCpuHintString(op_type, var_name, place, dev_id); PADDLE_ENFORCE_EQ( (dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()), true, platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d", multi_op_var2gpu_str_mutex().size())); + return op_var; +} - std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ + "]"; - char* gpu_str_ptr = NULL; +template +static char* GetGpuHintStringPtr(const phi::GPUContext& ctx, + const std::string& op_type, + const std::string& var_name, + int dev_id) { + std::string op_var = + GetHintString(op_type, var_name, ctx.GetPlace(), dev_id); + char* gpu_str_ptr = nullptr; { auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id); @@ -162,9 +369,9 @@ void TensorCheckerVisitor::apply( std::lock_guard guard(op_var2gpu_str_mutex); if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert auto gpu_str_tensor = paddle::memory::Alloc( - dev_ctx->GetPlace(), + ctx.GetPlace(), op_var.length() + 1, - phi::Stream(reinterpret_cast(dev_ctx->stream()))); + phi::Stream(reinterpret_cast(ctx.stream()))); gpu_str_ptr = reinterpret_cast(gpu_str_tensor->ptr()); op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor)); @@ -182,13 +389,13 @@ void TensorCheckerVisitor::apply( iter->first.c_str(), op_var.length() + 1, hipMemcpyHostToDevice, - dev_ctx->stream())); + ctx.stream())); #else PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1, cudaMemcpyHostToDevice, - dev_ctx->stream())); + ctx.stream())); #endif } else { // get auto iter = op_var2gpu_str.find(op_var); @@ -201,6 +408,40 @@ void TensorCheckerVisitor::apply( gpu_str_ptr = reinterpret_cast(iter->second->ptr()); } } + return gpu_str_ptr; +} + +template <> +template +void TensorCheckerVisitor::apply( + typename std::enable_if< + std::is_floating_point::value || + std::is_same>::value || + std::is_same>::value>::type*) + const { + auto* dev_ctx = reinterpret_cast( + platform::DeviceContextPool::Instance().Get(tensor.place())); + int dev_id = tensor.place().device; + // Write log to file + auto file_path = GetNanPath(); + if (file_path.size() > 0) { + phi::DenseTensor cpu_tensor; + platform::CPUPlace cpu_place; + cpu_tensor.Resize(tensor.dims()); + // 1. copy from gpu to cpu + paddle::framework::TensorCopySync(tensor, cpu_place, &cpu_tensor); + auto* dev_ctx = reinterpret_cast( + platform::DeviceContextPool::Instance().Get(tensor.place())); + const std::string debug_info = + GetHintString(op_type, var_name, place, dev_id); + // 2. 
write log to file + CheckNanInfCpuImpl(cpu_tensor.data(), tensor.numel(), debug_info, "gpu"); + return; + } + + // Write log to window + char* gpu_str_ptr = + GetGpuHintStringPtr(*dev_ctx, op_type, var_name, dev_id); #ifdef __HIPCC__ // HIP will throw GPU memory access fault if threads > 256 @@ -210,27 +451,66 @@ void TensorCheckerVisitor::apply( #endif size_t blocks = std::min(static_cast(128), - static_cast((tensor_.numel() + threads - 1) / threads)); + static_cast((tensor.numel() + threads - 1) / threads)); #ifdef __HIPCC__ + int print_num = 3; + hipLaunchKernelGGL(CheckNanInfKernel, dim3(blocks), dim3(threads), 0, dev_ctx->stream(), - tensor_.data(), - tensor_.numel(), + tensor.data(), + tensor.numel(), print_num, gpu_str_ptr); #else - CheckNanInfKernel<<stream()>>>( - tensor_.data(), tensor_.numel(), print_num, gpu_str_ptr); + using MT = float; //typename details::MPTypeTrait::Type; + + int64_t numel_max_min = blocks; + + phi::DenseTensor block_num_nan_inf_zero; + block_num_nan_inf_zero.Resize({static_cast(3 * numel_max_min)}); + int64_t* block_num_nan_ptr = + dev_ctx->template Alloc(&block_num_nan_inf_zero); + int64_t* block_num_inf_ptr = block_num_nan_ptr + numel_max_min; + int64_t* block_num_zero_ptr = block_num_inf_ptr + numel_max_min; + + phi::DenseTensor tensor_block_max_min; + tensor_block_max_min.Resize({static_cast(3 * numel_max_min)}); + MT* tensor_block_max_ptr = dev_ctx->template Alloc(&tensor_block_max_min); + MT* tensor_block_min_ptr = tensor_block_max_ptr + numel_max_min; + MT* tensor_block_mean_ptr = tensor_block_max_ptr + 2 * numel_max_min; + + FindNanInfAndBlockMaxMin + <<stream()>>>(tensor.data(), + tensor.numel(), + block_num_nan_ptr, + block_num_inf_ptr, + block_num_zero_ptr, + tensor_block_max_ptr, + tensor_block_min_ptr, + tensor_block_mean_ptr); + + int check_nan_inf_level = FLAGS_check_nan_inf_level; + FindGlobalMaxMinAndPrint + <<<1, 1, 0, dev_ctx->stream()>>>(block_num_nan_ptr, + block_num_inf_ptr, + block_num_zero_ptr, + tensor_block_max_ptr, + tensor_block_min_ptr, + tensor_block_mean_ptr, + gpu_str_ptr, + tensor.numel(), + numel_max_min, + check_nan_inf_level); #endif } template <> void tensor_check(const std::string& op_type, const std::string& var_name, - const framework::Tensor& tensor, + const phi::DenseTensor& tensor, const platform::Place& place) { std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap); diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.h b/paddle/fluid/framework/details/nan_inf_utils_detail.h index 99186c43e129e..3c4679f1cf664 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.h +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.h @@ -19,26 +19,322 @@ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/eigen/extensions.h" +#ifdef _WIN32 +#include +#include +#define MKDIR(path) _mkdir(path) +#else +#include +#define MKDIR(path) mkdir(path, S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH) +#endif + +DECLARE_int32(check_nan_inf_level); namespace paddle { namespace framework { namespace details { +void SetNanInfDebugPath(const std::string& nan_inf_path); + +std::string GetNanPath(); + +template ::value, bool> = true> +HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { + if (check_nan_inf_level >= 3) { + return true; + } else if (check_nan_inf_level >= 2) { + MT fp16_max = + 
static_cast(std::numeric_limits::max()); + return max_value > fp16_max || min_value < -fp16_max; + } + return false; +} + +template ::value, bool> = true> +HOSTDEVICE bool NeedPrint(MT max_value, MT min_value, int check_nan_inf_level) { + if (check_nan_inf_level >= 3) { + return true; + } + return false; +} + +template +HOSTDEVICE void PrintForDifferentLevel(const char* debug_info, + int64_t numel, + int64_t num_nan, + int64_t num_inf, + int64_t num_zero, + MT max_value, + MT min_value, + MT mean_value, + int check_nan_inf_level) { + if (num_nan > 0 || num_inf > 0) { + printf( + "[PRECISION] [ERROR] in %s, numel=%lld, num_nan=%lld, " + "num_inf=%lld, num_zero=%lld, max=%e, min=%e, mean=%e\n", + debug_info, + static_cast(numel), // NOLINT + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + static_cast(num_zero), // NOLINT + static_cast(max_value), + static_cast(min_value), + static_cast(mean_value)); + if (check_nan_inf_level == 0) { +#if defined(__NVCC__) || defined(__HIPCC__) + PADDLE_ENFORCE(false, + "There are NAN or INF (num_nan=%ld, num_inf=%lld, " + "num_zero=%lld) in %s.", + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + static_cast(num_zero), // NOLINT + debug_info); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are NAN or INF (num_nan=%lld, num_inf=%lld, num_zero=%lld) in " + "%s.", + static_cast(num_nan), // NOLINT + static_cast(num_inf), // NOLINT + static_cast(num_zero), // NOLINT + debug_info)); +#endif + } + } else if (NeedPrint(max_value, min_value, check_nan_inf_level)) { + printf("[PRECISION] in %s, numel=%lld, max=%e, min=%e, mean=%e\n", + debug_info, + static_cast(numel), // NOLINT + static_cast(max_value), + static_cast(min_value), + static_cast(mean_value)); + } +} + +template +void PrintForDifferentLevelFile(const char* debug_info, + int64_t numel, + int64_t num_nan, + int64_t num_inf, + int64_t num_zero, + MT max_value, + MT min_value, + MT mean_value, + int check_nan_inf_level, + const std::string& log_name) { + int dev_id = 0; +#ifdef PADDLE_WITH_HIP + hipGetDevice(&dev_id); +#elif PADDLE_WITH_CUDA + cudaGetDevice(&dev_id); +#endif + auto file_path = GetNanPath(); + MKDIR(file_path.c_str()); + std::string file_name = "worker_" + log_name + "." 
+ std::to_string(dev_id); + std::string path = file_path + file_name; + std::ofstream outfile(path, std::ios::app); + if (!outfile.is_open()) { + return; + } + + if (num_nan > 0 || num_inf > 0) { + outfile << "[PRECISION] [ERROR] in " << debug_info + << ", numel=" << static_cast(numel) // NOLINT + << ", num_nan=" << static_cast(num_nan) // NOLINT + << ", num_inf=" << static_cast(num_inf) // NOLINT + << ", num_zero=" << static_cast(num_zero) // NOLINT + << ", max=" << static_cast(max_value) + << ", min=" << static_cast(min_value) + << ", mean=" << static_cast(mean_value) << std::endl; + } else if (NeedPrint(max_value, min_value, check_nan_inf_level)) { + outfile << "[PRECISION] in " << debug_info + << ", numel=" << static_cast(numel) // NOLINT + << ", max=" << static_cast(max_value) + << ", min=" << static_cast(min_value) + << ", mean=" << static_cast(mean_value) << std::endl; + } + outfile.close(); +} + +template +inline std::string GetCpuHintString(const std::string& op_type, + const std::string& var_name, + const phi::Place& place, + int device_id = -1) { + std::string dtype_str = DataTypeToString(DataTypeTrait::DataType()); + if (dtype_str == "float") { + dtype_str = "fp32"; + } else if (dtype_str == "double") { + dtype_str = "fp64"; + } else if (dtype_str == "::paddle::platform::float16") { + dtype_str = "fp16"; + } else if (dtype_str == "::paddle::platform::bfloat16") { + dtype_str = "bf16"; + } + + std::stringstream ss; + if (platform::is_gpu_place(place)) { + ss << "[device=gpu:" << device_id << ", "; + } else { + ss << "[device=cpu, "; + } + ss << "op=" << op_type << ", tensor=" << var_name << ", dtype=" << dtype_str + << "]"; + return ss.str(); +} + +template < + typename T, + std::enable_if_t>::value && + !std::is_same>::value, + bool> = true> +static void CheckNanInfCpuImpl(const T* value_ptr, + const int64_t numel, + const std::string& cpu_hint_str, + const std::string log_name = "cpu") { + using MT = float; //typename details::MPTypeTrait::Type; + +#ifdef _OPENMP + // Use maximum 4 threads to collect the nan and inf information. + //int num_threads = std::max(omp_get_num_threads(), 1); + //num_threads = std::min(num_threads, 4); + int num_threads = 1; +#else + int num_threads = 1; +#endif + + std::vector thread_num_nan(num_threads, 0); + std::vector thread_num_inf(num_threads, 0); + std::vector thread_num_zero(num_threads, 0); + std::vector thread_min_value(num_threads, static_cast(value_ptr[0])); + std::vector thread_max_value(num_threads, static_cast(value_ptr[0])); + std::vector thread_mean_value(num_threads, static_cast(0)); + +#ifdef _OPENMP +#pragma omp parallel num_threads(num_threads) +#endif + { +#ifdef _OPENMP + // int64_t tid = omp_get_thread_num(); + // int64_t chunk_size = (numel + num_threads - 1) / num_threads; + // int64_t begin = tid * chunk_size; + // int64_t end = chunk_size + begin > numel ? 
numel : chunk_size + begin; + int64_t tid = 0; + int64_t begin = 0; + int64_t end = numel; +#else + int64_t tid = 0; + int64_t begin = 0; + int64_t end = numel; +#endif + for (int64_t i = begin; i < end; ++i) { + MT value = static_cast(value_ptr[i]); + + thread_min_value[tid] = std::min(thread_min_value[tid], value); + thread_max_value[tid] = std::max(thread_max_value[tid], value); + thread_mean_value[tid] += value / static_cast(numel); + + if (std::isnan(value)) { + thread_num_nan[tid] += 1; + } else if (std::isinf(value)) { + thread_num_inf[tid] += 1; + } + if (value == 0) { + thread_num_zero[tid] += 1; + } + } + } + + int64_t num_nan = 0; + int64_t num_inf = 0; + int64_t num_zero = 0; + MT min_value = thread_min_value[0]; + MT max_value = thread_max_value[0]; + MT mean_value = static_cast(0); + for (int i = 0; i < num_threads; ++i) { + num_nan += thread_num_nan[i]; + num_inf += thread_num_inf[i]; + num_zero += thread_num_zero[i]; + min_value = std::min(thread_min_value[i], min_value); + max_value = std::max(thread_max_value[i], max_value); + mean_value += thread_mean_value[i]; + } + auto file_path = GetNanPath(); + // Write log to file + if (file_path.size() > 0) { + VLOG(4) << "[FLAGS_check_nan_inf_level=" << FLAGS_check_nan_inf_level + << "]. Write log to " << file_path; + PrintForDifferentLevelFile(cpu_hint_str.c_str(), + numel, + num_nan, + num_inf, + num_zero, + max_value, + min_value, + mean_value, + FLAGS_check_nan_inf_level, + log_name); + return; + } + + PrintForDifferentLevel(cpu_hint_str.c_str(), + numel, + num_nan, + num_inf, + num_zero, + max_value, + min_value, + mean_value, + FLAGS_check_nan_inf_level); +} + +template < + typename T, + std::enable_if_t>::value || + std::is_same>::value, + bool> = true> +void CheckNanInfCpuImpl(const T* value_ptr, + const int64_t numel, + const std::string& cpu_hint_str, + const std::string log_name = "cpu") { + using RealType = typename T::value_type; + + RealType real_sum = 0.0f, imag_sum = 0.0f; + +#ifdef _OPENMP +#pragma omp parallel for reduction(+ : real_sum) reduction(+ : imag_sum) +#endif + for (int64_t i = 0; i < numel; ++i) { + T value = value_ptr[i]; + real_sum += (value.real - value.real); + imag_sum += (value.imag - value.imag); + } + + if (std::isnan(real_sum) || std::isinf(real_sum) || std::isnan(imag_sum) || + std::isinf(imag_sum)) { + // hot fix for compile failed in gcc4.8 + // here also need print detail info of nan or inf later + PADDLE_THROW(platform::errors::PreconditionNotMet( + "There are NAN or INF in %s.", cpu_hint_str)); + } +} + template struct TensorCheckerVisitor { - TensorCheckerVisitor(const std::string& op_type, - const std::string& var_name, - const framework::Tensor& tensor, - const platform::Place& place) - : op_type_(op_type), - var_name_(var_name), - tensor_(tensor), - place_(place) {} + TensorCheckerVisitor(const std::string& o, + const std::string& v, + const phi::DenseTensor& t, + const platform::Place& p) + : op_type(o), var_name(v), tensor(t), place(p) {} template void apply( typename std::enable_if::value>::type* = 0) const { - VLOG(10) << var_name_ << " need not to check, it's type is not float point"; + VLOG(10) << var_name << " need not to check, it's type is not float point"; } template @@ -49,16 +345,16 @@ struct TensorCheckerVisitor { std::is_same>::value>::type* = 0) const; - std::string op_type_; - std::string var_name_; - const framework::Tensor& tensor_; - const platform::Place& place_; + std::string op_type; + std::string var_name; + const phi::DenseTensor& tensor; + const 
platform::Place& place; }; template void tensor_check(const std::string& op_type, const std::string& var_name, - const framework::Tensor& tensor, + const phi::DenseTensor& tensor, const platform::Place& place); } // namespace details diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 518aabbb09ead..bd4cc10880648 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -68,6 +68,28 @@ PADDLE_DEFINE_EXPORTED_bool( "Checking whether operator produce NAN/INF or not. It will be " "extremely slow so please use this flag wisely."); +/** + * Operator related FLAG + * Name: FLAGS_check_nan_inf_level + * Since Version: 2.5.0 + * Value Range: int32, default=0 + * Example: + * Note: Used to debug. Setting the check and print level when + * FLAGS_check_nan_inf is set. + * - 0, abort the process when any operator produce NAN/INF and only print the + * information of tensor which holds NAN/INF. + * - 1, continue the training or inference process and print the information of + * all tensors which holds NAN/INF. + * - 2, print the information of float tensors when the max or min value + * overflowing float16's limit. + * - 3, print the information of all tensors. + */ +PADDLE_DEFINE_EXPORTED_int32( + check_nan_inf_level, + 0, + "Setting the check and print level when FLAGS_check_nan_inf is set."); + + /** * Operator related FLAG * Name: FLAGS_check_nan_inf diff --git a/python/paddle/fluid/contrib/mixed_precision/__init__.py b/python/paddle/fluid/contrib/mixed_precision/__init__.py index 1dd5015ec80f2..ef89b881226c0 100644 --- a/python/paddle/fluid/contrib/mixed_precision/__init__.py +++ b/python/paddle/fluid/contrib/mixed_precision/__init__.py @@ -21,8 +21,11 @@ from . import fp16_utils from .fp16_utils import * from . import bf16 +from . import program_checker +from .program_checker import * __all__ = [] __all__ += decorator.__all__ __all__ += fp16_lists.__all__ __all__ += fp16_utils.__all__ +__all__ += program_checker.__all__ \ No newline at end of file diff --git a/python/paddle/fluid/contrib/mixed_precision/program_checker.py b/python/paddle/fluid/contrib/mixed_precision/program_checker.py new file mode 100644 index 0000000000000..d30f01dc84c87 --- /dev/null +++ b/python/paddle/fluid/contrib/mixed_precision/program_checker.py @@ -0,0 +1,99 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ... import core +from ... import default_main_program +from ... import default_startup_program +from ... import framework +from ... import layers +from ... import program_guard +from ... import unique_name +from . 
import fp16_utils
from .fp16_utils import rewrite_program
from .fp16_utils import cast_model_to_fp16
from .fp16_utils import cast_parameters_to_fp16
from .fp16_utils import update_role_var_grad
from .fp16_lists import AutoMixedPrecisionLists
from .amp_nn import check_finite_and_unscale
from .amp_nn import update_loss_scaling
import types
import warnings
import paddle

__all__ = ["collect_operator_stats"]


class op_checker(object):
    """Counts how many times an operator type produces fp32/fp16/bf16 outputs."""

    def __init__(self) -> None:
        self.op_name = None
        self.fp32_calls = 0
        self.fp16_calls = 0
        self.bf16_calls = 0

    def inc_calls(self, out_var):
        if out_var.dtype == core.VarDesc.VarType.FP32:
            self.fp32_calls = self.fp32_calls + 1
        if out_var.dtype == core.VarDesc.VarType.FP16:
            self.fp16_calls = self.fp16_calls + 1
        if out_var.dtype == core.VarDesc.VarType.BF16:
            self.bf16_calls = self.bf16_calls + 1


def collect_operator_stats(program=default_main_program()):
    """Traverse all blocks of ``program`` and print, for each operator type,
    how many of its outputs are fp32, fp16 and bf16 tensors."""
    op_checker_list = []
    global_block = program.global_block()

    for block in program.blocks:
        ops = block.ops
        for op in ops:
            # Skip reader operators, which do not produce computation outputs.
            if op.type in ('create_py_reader', 'read',
                           'create_double_buffer_reader'):
                continue

            op_name = op.type
            if 'Out' in op.output_names:
                out_names = op.output('Out')
            elif 'Y' in op.output_names:
                out_names = op.output('Y')
            elif 'X@GRAD' in op.output_names:
                out_names = op.output('X@GRAD')
            else:
                continue

            out_name = out_names[0]

            is_in_list = False
            for each_checker in op_checker_list:
                if op_name == each_checker.op_name:
                    each_checker.inc_calls(global_block.var(out_name))
                    is_in_list = True
                    break

            if not is_in_list:
                static_op_checker = op_checker()
                static_op_checker.op_name = op_name
                static_op_checker.inc_calls(global_block.var(out_name))
                op_checker_list.append(static_op_checker)

    for each_checker in op_checker_list:
        print(
            f"op={each_checker.op_name}, "
            f"fp32 calls={each_checker.fp32_calls}, "
            f"fp16 calls={each_checker.fp16_calls}, "
            f"bf16 calls={each_checker.bf16_calls}")
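
Usage note: a minimal sketch of how the new flag and the operator checker are meant to be used together, assuming the standard paddle.set_flags API and the import path exposed by the patched mixed_precision/__init__.py; the tiny fp16 program is only an illustrative stand-in for a real model. FLAGS_check_nan_inf enables the per-operator check implemented in nan_inf_utils_detail.*, FLAGS_check_nan_inf_level selects the reporting level documented in flags.cc, and collect_operator_stats prints per-operator fp32/fp16/bf16 output counts for a static Program.

import paddle
from paddle.fluid.contrib.mixed_precision import collect_operator_stats

paddle.enable_static()

# Turn on the per-operator NaN/Inf check added by this patch.
#   level 0: abort when an output holds NaN/Inf and print its statistics
#   level 1: keep running and print statistics of every NaN/Inf output
#   level 2: also print float tensors whose max/min exceed the fp16 range
#   level 3: print statistics of all checked tensors
paddle.set_flags({
    'FLAGS_check_nan_inf': True,
    'FLAGS_check_nan_inf_level': 2,
})

# Build a small fp16 program to inspect (illustrative network only).
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[16, 32], dtype='float16')
    w = paddle.static.create_parameter(shape=[32, 8], dtype='float16', name='w')
    y = paddle.nn.functional.relu(paddle.matmul(x, w))

# Print, for each operator type, how many outputs are fp32/fp16/bf16.
collect_operator_stats(main_prog)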