forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
types.h
83 lines (68 loc) · 2.2 KB
/
types.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#ifndef CAFFE2_CORE_TYPES_H_
#define CAFFE2_CORE_TYPES_H_
#include <cstdint>
#include <string>
#include <type_traits>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include <c10/util/typeid.h>
#include "caffe2/proto/caffe2_pb.h"
#include <c10/util/Half.h>
namespace caffe2 {

// Storage orders commonly used in image applications.
// UNKNOWN is what StringToStorageOrder returns for an unrecognized string.
enum StorageOrder {
  UNKNOWN = 0,
  NHWC = 1,
  NCHW = 2,
};

// Parses a storage-order string.
// Only the exact spellings "NHWC"/"nhwc" and "NCHW"/"nchw" are accepted;
// mixed case such as "Nhwc" is treated as unknown.
// @param str  the order string to parse.
// @return the matching StorageOrder. For any other input this logs an error
//         and returns StorageOrder::UNKNOWN instead of throwing.
inline StorageOrder StringToStorageOrder(const std::string& str) {
  if (str == "NHWC" || str == "nhwc") {
    return StorageOrder::NHWC;
  } else if (str == "NCHW" || str == "nchw") {
    return StorageOrder::NCHW;
  } else {
    LOG(ERROR) << "Unknown storage order string: " << str;
    return StorageOrder::UNKNOWN;
  }
}

// Returns the position of 'C' within the order string: 3 for NHWC, 1 for NCHW.
// NOTE(review): presumably callers use this as the channel axis of a 4-D
// tensor; confirm against call sites.
// @param str  the order string; parsed with StringToStorageOrder.
// @throws via CAFFE_THROW when str is not a recognized order. In that case
//         StringToStorageOrder has already logged an error as well.
inline int32_t GetDimFromOrderString(const std::string& str) {
  auto order = StringToStorageOrder(str);
  switch (order) {
    case StorageOrder::NHWC:
      return 3;
    case StorageOrder::NCHW:
      return 1;
    default:
      CAFFE_THROW("Unsupported storage order: ", str);
      // Presumably unreachable because CAFFE_THROW throws. It is kept so that
      // compilers which cannot see that do not warn about a missing return.
      return -1;
  }
}

// Returns '/', the character used to separate components of a name scope
// (per the function name).
inline constexpr char NameScopeSeparator() { return '/'; }

// Maps a TypeMeta to the corresponding caffe2 TensorProto::DataType
// protobuf enum. It is defined out of line; the handling of types with no
// mapping is not visible in this header.
CAFFE2_API TensorProto::DataType TypeMetaToDataType(const TypeMeta& meta);

// Maps a caffe2 TensorProto::DataType protobuf enum to its TypeMeta.
// It returns a reference whose lifetime is managed by the out-of-line
// definition (not visible here).
CAFFE2_API const TypeMeta& DataTypeToTypeMeta(const TensorProto::DataType& dt);

} // namespace caffe2
///////////////////////////////////////////////////////////////////////////////
// at::Half is defined in c10/util/Half.h. Half-precision float operators
// currently run mainly on CUDA GPUs.
// We do not use CUDA's __half type directly because doing so requires
// compiling with nvcc. at::Half (the float16 type) should be compatible with
// CUDA's __half, but it can be referenced without depending on CUDA.
// Compile-time guard: this header requires unsigned short to occupy exactly
// two bytes on the target platform.
static_assert(
    2 == sizeof(unsigned short),
    "Short on this platform is not 16 bit.");
namespace caffe2 {
// Helpers to avoid using typeinfo with -rtti
template <typename T>
inline bool fp16_type();
template <>
inline bool fp16_type<at::Half>() {
return true;
}
template <typename T>
inline bool fp16_type() {
return false;
}
} // namespace caffe2
#endif // CAFFE2_CORE_TYPES_H_