Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jun 4, 2018
1 parent c446026 commit 63aac3f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 36 deletions.
2 changes: 1 addition & 1 deletion amalgamation/amalgamation.py
Expand Up @@ -23,7 +23,7 @@
import platform

blacklist = [
'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
'Windows.h', 'intrin.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
Expand Down
64 changes: 29 additions & 35 deletions src/storage/pooled_storage_manager.h
Expand Up @@ -27,7 +27,14 @@

#if MXNET_USE_CUDA
#include <cuda_runtime.h>
#if defined(_MSC_VER)
#include <Windows.h>
#include <intrin.h>
#pragma intrinsic(_BitScanForward)
#pragma intrinsic(_BitScanForward64)
#endif // defined(_MSC_VER)
#endif // MXNET_USE_CUDA

#include <mxnet/base.h>
#include <mxnet/storage.h>
#include <unordered_map>
Expand Down Expand Up @@ -212,46 +219,46 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
#define clz(x) __builtin_clzl(x)
#define ctz(x) __builtin_ctzl(x)

#elif defined(__WINDOWS__)
#elif defined(_MSC_VER)
#define clz(x) __lzcnt64(x)
uint64_t __inline ctz(uint64_t value) {
QWORD trailing_zero = 0;
_BitScanForward64(&trailing_zero, value)
return trailing_zero;
}
uint64_t __inline clz(uint64_t value) {
QWORD leading_zero = 0;
_BitScanReverse64(&leading_zero, value)
return 63 - leading_zero;
int __inline ctz(size_t value) {
DWORD trailing_zero = 0;
_BitScanForward64(&trailing_zero, value);
return static_cast<int>(trailing_zero);
}

#else
#define clz(x) flsl(x)
#define ctz(x) ffsl(x)

#endif // defined(__clang__) || defined(__GNUC__)


#elif __SIZEOF_SIZE_T__ == __SIZEOF_INT__

#if defined(__clang__) || defined(__GNUC__) || defined(__WINDOWS__)
#if defined(__clang__) || defined(__GNUC__)
#define clz(x) __builtin_clz(x)
#define ctz(x) __builtin_ctz(x)

#elif defined(__WINDOWS__)
uint32_t __inline clz(uint32_t value) {
DWORD leading_zero = 0;
_BitScanReverse(&leading_zero, value)
return 31 - leading_zero;
}
uint32_t __inline ctz(uint32_t value) {
#elif defined(_MSC_VER)
#define clz(x) __lzcnt(x)
int __inline ctz(size_t value) {
DWORD trailing_zero = 0;
_BitScanForward(&trailing_zero, value)
return trailing_zero;
_BitScanForward(&trailing_zero, value);
return static_cast<int>(trailing_zero);
}

#else
#define clz(x) fls(x)
#define ctz(x) ffs(x)

#endif // defined(__clang__) || defined(__GNUC__)

#endif // __SIZEOF_SIZE_T__

#if defined(__clang__) || defined(__GNUC__) || defined(__WINDOWS__)
inline int log2_round_up(size_t s) {
int result = addr_width - 1 - clz(s);
return result + ((ctz(s) < result)?1:0);
return result + ((ctz(s) < result) ? 1 : 0);
}
inline int div_pow2_round_up(size_t s, int divisor_log2) {
// (1025, 10) -> 2
Expand All @@ -260,26 +267,13 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
int ffs = ctz(s); // find first set
return (s >> divisor_log2) + (ffs < divisor_log2 ? 1 : 0);
}
#else
inline int log2_round_up(size_t s) {
return static_cast<int>(std::ceil(std::log2(s)));
}
inline int div_pow2_round_up(size_t s, int divisor_log2) {
// (1025, 10) -> 2
// (2048, 10) -> 2
// (2049, 10) -> 3
int divisor = std::pow(2, divisor_log2);
return s / divisor + (s % divisor ? 1 : 0);
}
#endif // defined(__clang__) || defined(__GNUC__) || defined(__WINDOWS__)
inline int get_bucket(size_t s) {
int log_size = log2_round_up(s);
if (log_size > static_cast<int>(cut_off_))
return div_pow2_round_up(s, cut_off_) - 1 + cut_off_;
else
return std::max(log_size, static_cast<int>(page_size_));
}

inline size_t get_size(int bucket) {
if (bucket <= static_cast<int>(cut_off_))
return 1ul << bucket;
Expand Down

0 comments on commit 63aac3f

Please sign in to comment.