forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_save_op_util.cc
128 lines (122 loc) · 4.14 KB
/
load_save_op_util.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include "caffe2/operators/load_save_op_util.h"
namespace caffe2 {
namespace load_save_op_util {
/**
 * Derives the blob name under which a DB record should be loaded.
 *
 * @param dbKey        key read from the DB; anything from the first
 *                     occurrence of kChunkIdSeparator onward is dropped
 * @param strip_prefix if non-empty and found anywhere in the remaining name,
 *                     the name is cut so that it starts right after the first
 *                     occurrence (the characters before it are removed too)
 * @param add_prefix   prepended to the resulting name (may be empty)
 * @return add_prefix + the stripped name
 */
std::string buildBlobNameFromDbKey(
    const std::string& dbKey,
    const std::string& strip_prefix,
    const std::string& add_prefix) {
  // substr(0, npos) keeps the whole key when no separator is present.
  const auto separator_pos = dbKey.find(kChunkIdSeparator);
  std::string name = dbKey.substr(0, separator_pos);

  if (!strip_prefix.empty()) {
    const auto prefix_pos = name.find(strip_prefix);
    if (prefix_pos != std::string::npos) {
      name.erase(0, prefix_pos + strip_prefix.size());
    }
  }

  return add_prefix + name;
}
/**
 * Deserializes one DB record into `blob` and updates the per-key loading
 * progress kept in `*blob_states_ptr`.
 *
 * A record for `key` takes one of three forms:
 *  - chunked content (proto.has_content_num_chunks()): each record carries one
 *    chunk id in [0, content_num_chunks); the blob is complete once
 *    content_num_chunks distinct ids have been seen;
 *  - a tensor, optionally covering only a segment of its elements: the blob is
 *    complete once the segment lengths add up to the product of the dims;
 *  - any other blob, which must arrive in exactly one record.
 *
 * @param blob            destination; Reset() before the first record of `key`
 * @param proto           the record being loaded
 * @param blob_states_ptr per-key progress; an entry is created on first sight
 * @param key             name of the blob this record belongs to
 * @param loaded_blobs    incremented when the blob for `key` becomes complete
 *
 * Throws (via CAFFE_ENFORCE) on duplicated or out-of-range chunks, extra or
 * oversized tensor parts, duplicated non-tensor blobs, and when records for
 * the same key mix chunked content with tensors.
 */
void ProcessBlob(
    Blob* blob,
    const BlobProto& proto,
    std::unordered_map<std::string, BlobState>* blob_states_ptr,
    const std::string& key,
    int* loaded_blobs) {
  auto& blob_states = *blob_states_ptr;
  // Whether an earlier record already created state for this key.
  // DeserializeBlob below cannot touch blob_states, so this stays valid.
  const bool seen_before = blob_states.count(key) != 0;
  if (!seen_before) {
    // We reset the blob so that any existing content is destroyed. This
    // is to guarantee correct device placement: if we are deserializing
    // into a TensorCUDA, without explicit Reset we might be loading data
    // into an existing TensorCUDA that has pre-allocated memory on a
    // different GPU.
    blob->Reset();
  }
  DeserializeBlob(proto, blob);

  if (proto.has_content_num_chunks()) {
    if (!seen_before) {
      blob_states[key] = BlobState(proto.content_num_chunks());
    }
    // References into an unordered_map stay valid; one lookup suffices.
    BlobState& state = blob_states[key];
    // Check the kind first. Otherwise a key previously loaded as a tensor
    // would have this chunk id recorded and range-checked against the
    // tensor's element count, and could fail with a misleading chunk-id error.
    CAFFE_ENFORCE(
        !state.is_tensor,
        "Proto with content_chunks can not store tensor: ",
        key);
    CAFFE_ENFORCE(
        state.seen_chunks_ids.insert(proto.content_chunk_id()).second,
        "Chunk with the same id has occurred twice for: ",
        key);
    CAFFE_ENFORCE(
        proto.content_chunk_id() >= 0 &&
            proto.content_chunk_id() < state.total_size,
        "Chunk id has to be not less than 0 and "
        "less than content_num_chunks for key: ",
        key);
    state.current_size++;
    CAFFE_ENFORCE(
        state.current_size <= state.total_size,
        "Found an extra part for an already filled blob: ",
        key);
    if (state.current_size == state.total_size) {
      (*loaded_blobs)++;
    }
    return;
  }

  if (!proto.has_tensor()) {
    // If blob is divided into chunks the field content_chunks has to be set,
    // otherwise only tensors can be seen multiple times as chunks.
    CAFFE_ENFORCE(!seen_before, "Blob duplicated: ", key);
    blob_states[key] = BlobState();
    (*loaded_blobs)++;
    return;
  }

  // From here on the record is a (possibly partial) tensor.
  const auto& tensor_proto = proto.tensor();
  if (seen_before) {
    BlobState& state = blob_states[key];
    CAFFE_ENFORCE(state.is_tensor, "Must be tensor ", key);
    CAFFE_ENFORCE(
        state.current_size < state.total_size,
        "Found an extra part for an already filled tensor: ",
        key);
    CAFFE_ENFORCE(
        tensor_proto.has_segment(),
        "Partial tensor must have a segment: ",
        key);
    state.current_size +=
        tensor_proto.segment().end() - tensor_proto.segment().begin();
    CAFFE_ENFORCE(
        state.current_size <= state.total_size,
        "Tensor parts are bigger than target size for tensor: ",
        key);
  } else {
    // Expected element count is the product of the declared dims.
    int64_t total_size = 1;
    for (const auto& dim : tensor_proto.dims()) {
      total_size *= dim;
    }
    // Without a segment the record carries the whole tensor.
    int64_t current_size = total_size;
    if (tensor_proto.has_segment()) {
      current_size =
          tensor_proto.segment().end() - tensor_proto.segment().begin();
    }
    blob_states[key] =
        BlobState(total_size, current_size, true /* is_tensor */);
  }

  const BlobState& tensor_state = blob_states[key];
  if (tensor_state.current_size == tensor_state.total_size) {
    (*loaded_blobs)++;
  }
}
void validateBlobStates(
const std::unordered_map<std::string, BlobState>& blob_states) {
for (const auto& iter : blob_states) {
const BlobState& blob_state = iter.second;
CAFFE_ENFORCE(
blob_state.current_size == blob_state.total_size,
"Data size mismatch for blob ",
iter.first,
". Expected: ",
blob_state.total_size,
" Read: ",
blob_state.current_size);
}
}
} // namespace load_save_op_util
} // namespace caffe2