forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
caffe2_pb.h
142 lines (132 loc) · 4.36 KB
/
caffe2_pb.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <caffe2/proto/caffe2.pb.h>
namespace caffe2 {
// Caffe2 shares the c10 device-type enumeration. The alias and constants
// below re-export it, and each of its enumerators, into namespace caffe2.
using DeviceType = at::DeviceType;

constexpr DeviceType CPU{DeviceType::CPU};
constexpr DeviceType CUDA{DeviceType::CUDA};
constexpr DeviceType OPENGL{DeviceType::OPENGL};
constexpr DeviceType OPENCL{DeviceType::OPENCL};
constexpr DeviceType MKLDNN{DeviceType::MKLDNN};
constexpr DeviceType IDEEP{DeviceType::IDEEP};
constexpr DeviceType HIP{DeviceType::HIP};
constexpr DeviceType COMPILE_TIME_MAX_DEVICE_TYPES{
    DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES};
constexpr DeviceType ONLY_FOR_TEST{DeviceType::ONLY_FOR_TEST};
// Maps a caffe2::DeviceTypeProto enumerator to the corresponding c10
// DeviceType.
//
// @param p  the protobuf device-type value.
// @return   the matching DeviceType.
// Any value without a case below is reported through AT_ERROR. The message
// includes the numeric value of `p`.
inline CAFFE2_API DeviceType ProtoToType(const caffe2::DeviceTypeProto p) {
  switch (p) {
    case caffe2::PROTO_CPU:
      return DeviceType::CPU;
    case caffe2::PROTO_CUDA:
      return DeviceType::CUDA;
    case caffe2::PROTO_OPENGL:
      return DeviceType::OPENGL;
    case caffe2::PROTO_OPENCL:
      return DeviceType::OPENCL;
    case caffe2::PROTO_MKLDNN:
      return DeviceType::MKLDNN;
    case caffe2::PROTO_IDEEP:
      return DeviceType::IDEEP;
    case caffe2::PROTO_HIP:
      return DeviceType::HIP;
    case caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES:
      return DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
    case caffe2::PROTO_ONLY_FOR_TEST:
      return DeviceType::ONLY_FOR_TEST;
    default:
      // Adjacent string literals are joined with no separator, so each piece
      // below carries the space it needs at its boundary.
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(p),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}
// Integer overload of ProtoToType(), for callers holding a raw int, such as a
// device-type value read from a DeviceOption message (TODO confirm the set of
// callers).
//
// The value is validated with the protobuf-generated DeviceTypeProto_IsValid()
// before the cast. When the generated enum has no fixed underlying type,
// static_cast of an int outside the enumeration's value range is undefined
// behaviour in C++17 (unspecified before). The enum overload's `default:`
// branch therefore cannot be relied on to catch such values, and untrusted
// serialized protos could otherwise supply them.
//
// @param p  raw device-type value.
// @return   the matching DeviceType.
// Invalid values are reported through AT_ERROR.
inline CAFFE2_API DeviceType ProtoToType(int p) {
  if (!caffe2::DeviceTypeProto_IsValid(p)) {
    AT_ERROR(
        "Unknown device: ",
        p,
        ". If you have recently updated the caffe2.proto file to add a new "
        "device type, did you forget to update the ProtoToType() and "
        "TypeToProto() functions to reflect such recent changes?");
  }
  return ProtoToType(static_cast<caffe2::DeviceTypeProto>(p));
}
// Maps a c10 DeviceType to the corresponding caffe2::DeviceTypeProto
// enumerator (the inverse of ProtoToType()).
//
// @param t  the device type to translate.
// @return   the matching DeviceTypeProto value.
// Any device type without a case below is reported through AT_ERROR. The
// message includes the numeric value of `t`.
inline CAFFE2_API DeviceTypeProto TypeToProto(const DeviceType& t) {
  switch (t) {
    case DeviceType::CPU:
      return caffe2::PROTO_CPU;
    case DeviceType::CUDA:
      return caffe2::PROTO_CUDA;
    case DeviceType::OPENGL:
      return caffe2::PROTO_OPENGL;
    case DeviceType::OPENCL:
      return caffe2::PROTO_OPENCL;
    case DeviceType::MKLDNN:
      return caffe2::PROTO_MKLDNN;
    case DeviceType::IDEEP:
      return caffe2::PROTO_IDEEP;
    case DeviceType::HIP:
      return caffe2::PROTO_HIP;
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
      return caffe2::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
    case DeviceType::ONLY_FOR_TEST:
      return caffe2::PROTO_ONLY_FOR_TEST;
    default:
      // Adjacent string literals are joined with no separator, so each piece
      // below carries the space it needs at its boundary.
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(t),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType() and "
          "TypeToProto() functions to reflect such recent changes?");
  }
}
// Builds a caffe2::DeviceOption message describing `device`.
//
// The device type is translated with TypeToProto(), which reports unknown
// types through AT_ERROR. The device index is recorded as follows:
//  - CPU: stored in numa_node_id, only when the index is not -1.
//  - CUDA / HIP: stored in device_id unconditionally, so an index of -1 is
//    written as -1.
//  - Other listed device types: no index is recorded.
//
// @param device  the device to describe.
// @return        a freshly built DeviceOption.
inline CAFFE2_API caffe2::DeviceOption DeviceToOption(
    const at::Device& device) {
  caffe2::DeviceOption option;
  auto type = device.type();
  option.set_device_type(TypeToProto(type));
  switch (type) {
    case DeviceType::CPU:
      if (device.index() != -1) {
        option.set_numa_node_id(device.index());
      }
      break;
    case DeviceType::CUDA:
    case DeviceType::HIP:
      option.set_device_id(device.index());
      break;
    case DeviceType::OPENGL:
    case DeviceType::OPENCL:
    case DeviceType::MKLDNN:
    case DeviceType::IDEEP:
    case DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES:
    case DeviceType::ONLY_FOR_TEST:
      break;
    default:
      // TypeToProto() above has already rejected types it does not know, so
      // this branch is reached only when TypeToProto() was extended but this
      // switch was not; hence the message also names DeviceToOption().
      AT_ERROR(
          "Unknown device: ",
          static_cast<int32_t>(type),
          ". If you have recently updated the caffe2.proto file to add a new "
          "device type, did you forget to update the ProtoToType(), "
          "TypeToProto() and DeviceToOption() functions to reflect such "
          "recent changes?");
  }
  return option;
}
// Converts a caffe2::DeviceOption message into an at::Device.
//
// The device index is taken from:
//  - numa_node_id for CPU options, when that field is set;
//  - device_id for CUDA and HIP options;
// and is -1 otherwise.
//
// @param option  the protobuf device description. It is taken by const
//                reference so the message is not copied on every call.
// @return        the equivalent at::Device.
// An unrecognised device_type() is reported through AT_ERROR by ProtoToType().
inline CAFFE2_API at::Device OptionToDevice(
    const caffe2::DeviceOption& option) {
  auto type = option.device_type();
  int32_t id = -1;
  switch (type) {
    case caffe2::PROTO_CPU:
      if (option.has_numa_node_id()) {
        id = option.numa_node_id();
      }
      break;
    case caffe2::PROTO_CUDA:
    case caffe2::PROTO_HIP:
      id = option.device_id();
      break;
    default:
      // Remaining device types carry no index; keep id == -1.
      break;
  }
  return at::Device(ProtoToType(type), id);
}
// Fills `*device_option` with the DeviceOption equivalent of `device`,
// replacing its contents via CopyFrom().
//
// @param device_option  destination message; must be non-null (AT_ASSERT).
// @param device         the device to describe.
inline void ExtractDeviceOption(
    DeviceOption* device_option,
    const at::Device& device) {
  AT_ASSERT(device_option);
  const caffe2::DeviceOption converted = DeviceToOption(device);
  device_option->CopyFrom(converted);
}
} // namespace caffe2