This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
/
iter_image_recordio_2.cc
956 lines (887 loc) · 33.3 KB
/
iter_image_recordio_2.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file iter_image_recordio_2.cc
* \brief new version of recordio data iterator
*/
#include <mxnet/io.h>
#include <dmlc/parameter.h>
#include <dmlc/threadediter.h>
#include <dmlc/input_split_shuffle.h>
#include <dmlc/recordio.h>
#include <dmlc/base.h>
#include <dmlc/io.h>
#include <dmlc/omp.h>
#include <dmlc/common.h>
#include <dmlc/timer.h>
#include <memory>
#include <type_traits>
#if MXNET_USE_LIBJPEG_TURBO
#include <turbojpeg.h>
#endif
#include "./image_recordio.h"
#include "./image_augmenter.h"
#include "./image_iter_common.h"
#include "./inst_vector.h"
#include "../common/utils.h"
#include "../profiler/profiler.h"
namespace mxnet {
namespace io {
// Parser that turns ImageRecordIO chunks into DataBatch objects: it decodes and
// augments records on param_.preprocess_threads OpenMP threads and fills the
// data/label arrays of the batch passed to ParseNext.
template <typename DType>
class ImageRecordIOParser2 {
 public:
  // Initialize the parser from the iterator keyword arguments.
  inline void Init(const std::vector<std::pair<std::string, std::string>>& kwargs);
  // Rewind to the first record.
  // When the previous epoch ended with a round_batch wrap-around (overflow),
  // ParseNext already rewound the source and the instances it parsed past the
  // end of that batch (n_parsed_) become the start of the new epoch, so only
  // the overflow flag is cleared.
  inline void BeforeFirst() {
    if (batch_param_.round_batch == 0 || !overflow) {
      n_parsed_ = 0;
      source_->BeforeFirst();
    } else {
      overflow = false;
    }
  }
  // Fill `out` with the next batch. Returns false when the epoch is exhausted.
  inline bool ParseNext(DataBatch* out);

 private:
#if MXNET_USE_OPENCV
  // Convert one decoded image into the CHW tensor *data_ptr (channel reorder,
  // normalization and optional mirroring).
  template <int n_channels>
  void ProcessImage(const cv::Mat& res,
                    mshadow::Tensor<cpu, 3, DType>* data_ptr,
                    const bool is_mirrored,
                    const float contrast_scaled,
                    const float illumination_scaled);
#if MXNET_USE_LIBJPEG_TURBO
  // Decode with libjpeg-turbo when the buffer is JPEG, otherwise via OpenCV.
  cv::Mat TJimdecode(cv::Mat buf, int color);
#endif
#endif
  // Decode every record of `chunk`; returns how many were written into the output buffers.
  inline size_t ParseChunk(DType* data_dptr,
                           real_t* label_dptr,
                           const size_t current_size,
                           dmlc::InputSplit::Blob* chunk);
  // Compute the mean image over the whole record file and save it to mean_img.
  inline void CreateMeanImg();
  // magic number to seed prng
  static const int kRandMagic = 111;
  static const int kRandMagicNormalize = 0;
  /*! \brief parameters */
  ImageRecParserParam param_;
  ImageRecordParam record_param_;
  BatchParam batch_param_;
  ImageNormalizeParam normalize_param_;
#if MXNET_USE_OPENCV
  /*! \brief augmenters, one chain per decoding thread */
  std::vector<std::vector<std::unique_ptr<ImageAugmenter>>> augmenters_;
#endif
  /*! \brief random samplers, one per decoding thread */
  std::vector<std::unique_ptr<common::RANDOM_ENGINE>> prnds_;
  /*! \brief engine used to shuffle inst_order_ in legacy shuffle mode */
  common::RANDOM_ENGINE rnd_;
  /*! \brief data source */
  std::unique_ptr<dmlc::InputSplit> source_;
  /*! \brief label information, if any */
  std::unique_ptr<ImageLabelMap> label_map_;
  /*! \brief per-thread buffers for parsed instances that did not fit into the output batch */
  std::vector<InstVector<DType>> temp_;
  /*! \brief temp space */
  mshadow::TensorContainer<cpu, 3> img_;
  /*! \brief (thread id, position in temp_[thread id]) of buffered instances */
  std::vector<std::pair<size_t, size_t>> inst_order_;
  /*! \brief next position in inst_order_ to copy out */
  size_t inst_index_{0};
  /*! \brief number of buffered instances not yet copied to an output batch */
  size_t n_parsed_{0};
  /*! \brief set when a round_batch wrap-around ended the epoch */
  bool overflow{false};
  /*! \brief per-slot element counts of out->data: [0] image, [1] label */
  std::vector<size_t> unit_size_;
  /*! \brief mean image, if needed */
  mshadow::TensorContainer<cpu, 3> meanimg_;
  // whether to use legacy shuffle
  // (without IndexedRecordIO support)
  bool legacy_shuffle_{false};
  // whether mean image is ready.
  bool meanfile_ready_{false};
  /*! \brief OMPException obj to store and rethrow exceptions from omp blocks*/
  dmlc::OMPException omp_exc_;
};
template <typename DType>
inline void ImageRecordIOParser2<DType>::Init(
    const std::vector<std::pair<std::string, std::string>>& kwargs) {
#if MXNET_USE_OPENCV
  // Every parameter group reads the same kwargs and ignores keys it does not declare.
  param_.InitAllowUnknown(kwargs);
  record_param_.InitAllowUnknown(kwargs);
  batch_param_.InitAllowUnknown(kwargs);
  normalize_param_.InitAllowUnknown(kwargs);
  PrefetcherParam prefetch_param;
  prefetch_param.InitAllowUnknown(kwargs);
  n_parsed_ = 0;
  overflow = false;
  rnd_.seed(kRandMagic + record_param_.seed);
  int threadget = 1;
  if (prefetch_param.ctx == PrefetcherParam::CtxType::kCPU) {
    threadget = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
  } else {
    // Be conservative: cap decoding threads at half the processor count. This is
    // computed once on the calling thread instead of having every thread of a
    // parallel region write the same shared variable (a data race).
    const int maxthread = std::max(omp_get_num_procs() / 2, 1);
    param_.preprocess_threads = std::min(maxthread, param_.preprocess_threads);
    // Ask the OpenMP runtime how many threads it actually grants for this request;
    // only one thread records the answer.
#pragma omp parallel num_threads(param_.preprocess_threads)
    {
#pragma omp single
      threadget = omp_get_num_threads();
    }
  }
  param_.preprocess_threads = threadget;
  std::vector<std::string> aug_names = dmlc::Split(param_.aug_seq, ',');
  // One augmenter chain and one RNG per decoding thread. Both are rebuilt from
  // scratch so that a repeated Init does not keep stale per-thread state.
  augmenters_.clear();
  augmenters_.resize(threadget);
  prnds_.clear();
  // setup decoders
  for (int i = 0; i < threadget; ++i) {
    for (const auto& aug_name : aug_names) {
      augmenters_[i].emplace_back(ImageAugmenter::Create(aug_name));
      augmenters_[i].back()->Init(kwargs);
    }
    prnds_.emplace_back(new common::RANDOM_ENGINE((i + 1) * kRandMagic));
  }
  if (param_.path_imglist.length() != 0) {
    label_map_ = std::make_unique<ImageLabelMap>(
        param_.path_imglist.c_str(), param_.label_width, !param_.verbose);
  }
  CHECK(param_.path_imgrec.length() != 0) << "ImageRecordIter2: must specify image_rec";
  if (param_.verbose) {
    LOG(INFO) << "ImageRecordIOParser2: " << param_.path_imgrec << ", use " << threadget
              << " threads for decoding..";
  }
  legacy_shuffle_ = false;
  if (param_.path_imgidx.length() != 0) {
    // An index file enables shuffling inside the indexed_recordio split itself.
    source_.reset(dmlc::InputSplit::Create(param_.path_imgrec.c_str(),
                                           param_.path_imgidx.c_str(),
                                           param_.part_index,
                                           param_.num_parts,
                                           "indexed_recordio",
                                           record_param_.shuffle,
                                           record_param_.seed,
                                           batch_param_.batch_size));
  } else {
    source_.reset(dmlc::InputSplit::Create(
        param_.path_imgrec.c_str(), param_.part_index, param_.num_parts, "recordio"));
    // Without an index, shuffling happens in ParseNext on each parsed chunk.
    if (record_param_.shuffle)
      legacy_shuffle_ = true;
    if (param_.shuffle_chunk_size > 0) {
      if (param_.shuffle_chunk_size > 4096) {
        LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size
                  << " MB which is larger than 4096 MB, please set "
                     "smaller chunk size";
      }
      if (param_.shuffle_chunk_size < 4) {
        LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size
                  << " MB which is less than 4 MB, please set "
                     "larger chunk size";
      }
      // 1.1 ratio is for a bit more shuffle parts to avoid boundary issue
      size_t num_shuffle_parts = std::ceil(
          source_->GetTotalSize() * 1.1 / (param_.num_parts * (param_.shuffle_chunk_size << 20UL)));
      if (num_shuffle_parts > 1) {
        source_.reset(dmlc::InputSplitShuffle::Create(param_.path_imgrec.c_str(),
                                                      param_.part_index,
                                                      param_.num_parts,
                                                      "recordio",
                                                      num_shuffle_parts,
                                                      param_.shuffle_chunk_seed));
      }
      source_->HintChunkSize(param_.shuffle_chunk_size << 17UL);
    } else {
      // use 64 MB chunk when possible
      source_->HintChunkSize(64 << 20UL);
    }
  }
  // Normalize init
  if (!std::is_same<DType, uint8_t>::value) {
    meanimg_.set_pad(false);
    meanfile_ready_ = false;
    if (normalize_param_.mean_img.length() != 0) {
      // Probe for the mean image file (allow_null = true): compute it if missing.
      std::unique_ptr<dmlc::Stream> probe(
          dmlc::Stream::Create(normalize_param_.mean_img.c_str(), "r", true));
      if (probe == nullptr) {
        this->CreateMeanImg();
      } else {
        probe.reset(nullptr);
        if (param_.verbose) {
          LOG(INFO) << "Load mean image from " << normalize_param_.mean_img;
        }
        // use python compatible ndarray store format
        std::vector<NDArray> data;
        std::vector<std::string> keys;
        {
          std::unique_ptr<dmlc::Stream> fi(
              dmlc::Stream::Create(normalize_param_.mean_img.c_str(), "r"));
          NDArray::Load(fi.get(), &data, &keys);
        }
        CHECK_EQ(data.size(), 1U) << "Invalid mean image file format";
        data[0].WaitToRead();
        mshadow::Tensor<cpu, 3> src = data[0].data().get<cpu, 3, real_t>();
        meanimg_.Resize(src.shape_);
        mshadow::Copy(meanimg_, src);
        meanfile_ready_ = true;
        if (param_.verbose) {
          LOG(INFO) << "Load mean image from " << normalize_param_.mean_img << " completed";
        }
      }
    }
  }
#else
  LOG(FATAL) << "ImageRec need opencv to process";
#endif
}
template <typename DType>
inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch* out) {
  // After a round_batch wrap-around the epoch is over until BeforeFirst() is called.
  if (overflow) {
    return false;
  }
  CHECK(source_ != nullptr);
  dmlc::InputSplit::Blob chunk;
  size_t current_size = 0;
  out->index.resize(batch_param_.batch_size);
  // `out` may be a recycled batch whose padding count was set at the end of an
  // earlier epoch; a batch filled with fresh records has no padding.
  out->num_batch_padd = 0;
  // InitBatch: allocate the output arrays the first time this DataBatch is used.
  if (out->data.size() == 0) {
    // This assumes that DataInst given by
    // InstVector contains only 2 elements in
    // data vector (operator[] implementation)
    out->data.resize(2);
    unit_size_.resize(2);
    std::vector<index_t> shape_vec;
    shape_vec.push_back(batch_param_.batch_size);
    for (index_t dim = 0; dim < param_.data_shape.ndim(); ++dim) {
      shape_vec.push_back(param_.data_shape[dim]);
    }
    mxnet::TShape data_shape(shape_vec.begin(), shape_vec.end());
    shape_vec.clear();
    shape_vec.push_back(batch_param_.batch_size);
    shape_vec.push_back(param_.label_width);
    mxnet::TShape label_shape(shape_vec.begin(), shape_vec.end());
    auto ctx = Context::CPU(0);
    auto dev_id = param_.device_id;
    if (dev_id != -1) {
      ctx = Context::CPUPinned(dev_id);
    }
    const std::string profiler_scope =
        profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "image_io:";
    out->data.at(0) = NDArray(data_shape, ctx, false, mshadow::DataType<DType>::kFlag);
    out->data.at(0).AssignStorageInfo(profiler_scope, "data");
    out->data.at(1) = NDArray(label_shape, ctx, false, mshadow::DataType<real_t>::kFlag);
    out->data.at(1).AssignStorageInfo(profiler_scope, "label");
    unit_size_[0] = param_.data_shape.Size();
    unit_size_[1] = param_.label_width;
  }
  while (current_size < batch_param_.batch_size) {
    size_t n_to_out = 0;
    if (n_parsed_ == 0) {
      // Nothing buffered: read and decode the next chunk from the source.
      if (source_->NextBatch(&chunk, batch_param_.batch_size)) {
        inst_order_.clear();
        inst_index_ = 0;
        DType* data_dptr = static_cast<DType*>(out->data[0].data().dptr_);
        real_t* label_dptr = static_cast<real_t*>(out->data[1].data().dptr_);
        if (!legacy_shuffle_) {
          n_to_out = ParseChunk(data_dptr, label_dptr, current_size, &chunk);
        } else {
          // Shuffling needs every record buffered first, so nothing goes straight to `out`.
          n_to_out = ParseChunk(nullptr, nullptr, batch_param_.batch_size, &chunk);
        }
        // Count number of parsed images that do not fit into current out
        n_parsed_ = inst_order_.size();
        // shuffle instance order if needed
        if (legacy_shuffle_) {
          std::shuffle(inst_order_.begin(), inst_order_.end(), rnd_);
        }
      } else {
        // Source exhausted part-way through this batch.
        if (current_size == 0) {
          return false;
        }
        CHECK(!overflow) << "number of input images must be bigger than the batch size";
        // Slots [current_size, batch_size) receive no record from this epoch: they are
        // either refilled from the start of the data (round_batch) or keep stale values.
        // Report them as padding in both cases so callers can discard them.
        out->num_batch_padd = static_cast<int>(batch_param_.batch_size - current_size);
        if (batch_param_.round_batch != 0) {
          overflow = true;
          source_->BeforeFirst();
        } else {
          current_size = batch_param_.batch_size;
        }
        n_to_out = 0;
      }
    } else {
      // Drain instances buffered in temp_ by an earlier ParseChunk call.
      size_t n_to_copy =
          std::min(n_parsed_, static_cast<size_t>(batch_param_.batch_size) - current_size);
      n_parsed_ -= n_to_copy;
      // Copy
#pragma omp parallel for num_threads(param_.preprocess_threads)
      for (int i = 0; i < static_cast<int>(n_to_copy); ++i) {
        omp_exc_.Run([&] {
          std::pair<size_t, size_t> place = inst_order_[inst_index_ + i];
          const DataInst& batch = temp_[place.first][place.second];
          for (size_t j = 0; j < batch.data.size(); ++j) {
            CHECK_EQ(unit_size_[j], batch.data[j].Size());
            MSHADOW_TYPE_SWITCH(out->data[j].data().type_flag_, dtype, {
              mshadow::Copy(
                  out->data[j].data().FlatTo1D<cpu, dtype>().Slice(
                      (current_size + i) * unit_size_[j], (current_size + i + 1) * unit_size_[j]),
                  batch.data[j].get_with_shape<cpu, 1, dtype>(mshadow::Shape1(unit_size_[j])));
            });
          }
        });
      }
      omp_exc_.Rethrow();
      n_to_out = n_to_copy;
      inst_index_ += n_to_copy;
    }
    current_size += n_to_out;
  }
  return true;
}
#if MXNET_USE_OPENCV
template <typename DType>
template <int n_channels>
void ImageRecordIOParser2<DType>::ProcessImage(const cv::Mat& res,
                                               mshadow::Tensor<cpu, 3, DType>* data_ptr,
                                               const bool is_mirrored,
                                               const float contrast_scaled,
                                               const float illumination_scaled) {
  // Copy an interleaved OpenCV image (gray / BGR / BGRA bytes) into the planar
  // tensor *data_ptr, indexed [channel][row][col], in gray / RGB / RGBA order.
  //  - int8 output: subtract the (rounded) mean and saturate to int8.
  //  - float output: (pixel - mean) * contrast / std + illumination / std.
  //  - uint8 output: raw copy.
  // Mirroring writes each row's columns in reverse order.
  constexpr bool kIsUint8 = std::is_same<DType, uint8_t>::value;
  constexpr bool kIsInt8 = std::is_same<DType, int8_t>::value;
  // Per-channel normalization terms in RGBA order (all zero for uint8 output).
  float scale_by[4] = {0};
  float shift_by[4] = {0};
  float mean_f[4] = {0};
  int16_t mean_i[4] = {0};
  if (!kIsUint8) {
    scale_by[0] = contrast_scaled / normalize_param_.std_r;
    shift_by[0] = illumination_scaled / normalize_param_.std_r;
    scale_by[1] = contrast_scaled / normalize_param_.std_g;
    shift_by[1] = illumination_scaled / normalize_param_.std_g;
    scale_by[2] = contrast_scaled / normalize_param_.std_b;
    shift_by[2] = illumination_scaled / normalize_param_.std_b;
    scale_by[3] = contrast_scaled / normalize_param_.std_a;
    shift_by[3] = illumination_scaled / normalize_param_.std_a;
    // Scalar means are only used when no mean image has been loaded.
    if (!meanfile_ready_) {
      mean_f[0] = normalize_param_.mean_r;
      mean_i[0] = std::round(normalize_param_.mean_r);
      mean_f[1] = normalize_param_.mean_g;
      mean_i[1] = std::round(normalize_param_.mean_g);
      mean_f[2] = normalize_param_.mean_b;
      mean_i[2] = std::round(normalize_param_.mean_b);
      mean_f[3] = normalize_param_.mean_a;
      mean_i[3] = std::round(normalize_param_.mean_a);
    }
  }
  mshadow::Tensor<cpu, 3, DType>& dst = *data_ptr;
  for (int row = 0; row < res.rows; ++row) {
    const uchar* px = res.ptr<uchar>(row);
    for (int col = 0; col < res.cols; ++col, px += n_channels) {
      const int out_col = is_mirrored ? res.cols - col - 1 : col;
      for (int k = 0; k < n_channels; ++k) {
        // OpenCV stores BGR(A): swap B and R, leave gray and alpha in place.
        const int src_k = (n_channels > 1 && k < 3) ? 2 - k : k;
        DType value;
        if (kIsInt8) {
          const int16_t m = meanfile_ready_
                                ? static_cast<int16_t>(std::round(meanimg_[k][row][col]))
                                : mean_i[k];
          value = cv::saturate_cast<int8_t>(px[src_k] - m);
        } else {
          value = px[src_k];
          if (!kIsUint8) {
            const float m = meanfile_ready_ ? meanimg_[k][row][col] : mean_f[k];
            value = (value - m) * scale_by[k] + shift_by[k];
          }
        }
        dst[k][row][out_col] = value;
      }
    }
  }
}
#if MXNET_USE_LIBJPEG_TURBO
// True when the buffer starts with the JPEG start-of-image marker 0xFF 0xD8.
// The caller must guarantee that at least two bytes are readable.
bool is_jpeg(unsigned char* file) {
  return file[0] == 0xFF && file[1] == 0xD8;
}
template <typename DType>
cv::Mat ImageRecordIOParser2<DType>::TJimdecode(cv::Mat image, int color) {
  // Decode `image` (an encoded byte buffer) with libjpeg-turbo when it is a JPEG,
  // producing BGR (color != 0) or grayscale output. Any other input, or any
  // libjpeg-turbo failure, falls back to cv::imdecode.
  unsigned char* jpeg = image.ptr();
  const size_t jpeg_size = image.rows * image.cols;
  // Buffers shorter than the 2-byte JPEG signature must not be probed by is_jpeg.
  if (jpeg_size < 2 || !is_jpeg(jpeg)) {
    // If it is not JPEG then fall back to OpenCV
    return cv::imdecode(image, color);
  }
  tjhandle handle = tjInitDecompress();
  if (handle == nullptr) {
    return cv::imdecode(image, color);
  }
  int h, w, subsamp;
  if (tjDecompressHeader2(handle, jpeg, jpeg_size, &w, &h, &subsamp) != 0) {
    // Malformed header: release the decompressor before falling back to OpenCV.
    tjDestroy(handle);
    return cv::imdecode(image, color);
  }
  cv::Mat ret = cv::Mat(h, w, color ? CV_8UC3 : CV_8UC1);
  const int err = tjDecompress2(
      handle, jpeg, jpeg_size, ret.ptr(), w, 0, h, color ? TJPF_BGR : TJPF_GRAY, 0);
  // The decompressor is no longer needed on either path.
  tjDestroy(handle);
  if (err != 0) {
    // If it is a malformed JPEG then fall back to OpenCV
    return cv::imdecode(image, color);
  }
  return ret;
}
#endif
#endif
// Decode and augment every record of `chunk` on param_.preprocess_threads
// OpenMP threads.
//
// Records are numbered by a shared counter `gl_idx` that starts at
// `current_size`. A record whose number `idx` is below batch_size is written
// straight into slot `idx` of `data_dptr` / `label_dptr`. Every other record is
// appended to the calling thread's temp_[tid] buffer, and its
// (tid, position in temp_[tid]) pair is appended to inst_order_ so ParseNext
// can copy it out later. Callers that pass current_size == batch_size (the
// legacy-shuffle path and CreateMeanImg) route every record to temp_, so their
// null output pointers are never dereferenced.
//
// Returns the number of images written directly into the output buffers:
// min(batch_size, gl_idx) - current_size.
template <typename DType>
inline size_t ImageRecordIOParser2<DType>::ParseChunk(DType* data_dptr,
real_t* label_dptr,
const size_t current_size,
dmlc::InputSplit::Blob* chunk) {
temp_.resize(param_.preprocess_threads);
#if MXNET_USE_OPENCV
// save opencv out
// A single reader covers the whole chunk; threads take records from it inside
// the critical section below.
dmlc::RecordIOChunkReader reader(*chunk, 0, 1);
size_t gl_idx = current_size;
#pragma omp parallel num_threads(param_.preprocess_threads)
{
omp_exc_.Run([&] {
// temp_ has exactly preprocess_threads slots, so the runtime must grant that many threads.
CHECK(omp_get_num_threads() == param_.preprocess_threads);
int tid = omp_get_thread_num();
// dmlc::RecordIOChunkReader reader(*chunk, tid, param_.preprocess_threads);
ImageRecordIO rec;
dmlc::InputSplit::Blob blob;
// image data
InstVector<DType>& out_tmp = temp_[tid];
out_tmp.Clear();
while (true) {
bool reader_has_data;
size_t idx;
// Fetching a record, numbering it and registering its future temp_ slot happen
// together, so inst_order_ lists buffered records in reading order. out_tmp.Size()
// is the position the upcoming out_tmp.Push will occupy.
#pragma omp critical
{
reader_has_data = reader.NextRecord(&blob);
if (reader_has_data) {
idx = gl_idx++;
if (idx >= batch_param_.batch_size) {
inst_order_.push_back(std::make_pair(tid, out_tmp.Size()));
}
}
}
// idx is only assigned when a record was read; the loop exits before using it otherwise.
if (!reader_has_data)
break;
// Opencv decode and augments
cv::Mat res;
rec.Load(blob.dptr, blob.size);
cv::Mat buf(1, rec.content_size, CV_8U, rec.content);
// If augmentation seed is supplied
// Re-seed RNG to guarantee reproducible results
if (param_.seed_aug.has_value()) {
prnds_[tid]->seed(idx + param_.seed_aug.value() + kRandMagic);
}
// The first dimension of data_shape selects the decode mode: 1 = gray, 3 = color,
// 4 = keep the encoded channels (must decode to 4 channels).
switch (param_.data_shape[0]) {
case 1:
#if MXNET_USE_LIBJPEG_TURBO
res = TJimdecode(buf, 0);
#else
res = cv::imdecode(buf, 0);
#endif
break;
case 3:
#if MXNET_USE_LIBJPEG_TURBO
res = TJimdecode(buf, 1);
#else
res = cv::imdecode(buf, 1);
#endif
break;
case 4:
// -1 to keep the number of channel of the encoded image, and not force gray or color.
res = cv::imdecode(buf, -1);
CHECK_EQ(res.channels(), 4) << "Invalid image with index " << rec.image_index()
<< ". Expected 4 channels, got " << res.channels();
break;
default:
LOG(FATAL) << "Invalid output shape " << param_.data_shape;
}
const int n_channels = res.channels();
// load label before augmentations
// Label source priority: imglist map, then labels packed in the record, then the
// single scalar label in the record header.
std::vector<float> label_buf;
if (label_map_ != nullptr) {
label_buf = label_map_->FindCopy(rec.image_index());
} else if (rec.label != nullptr) {
CHECK_EQ(param_.label_width, rec.num_label) << "rec file provide " << rec.num_label
<< "-dimensional label "
"but label_width is set to "
<< param_.label_width;
label_buf.assign(rec.label, rec.label + rec.num_label);
} else {
CHECK_EQ(param_.label_width, 1)
<< "label_width must be 1 unless an imglist is provided "
"or the rec file is packed with multi dimensional label";
label_buf.assign(&rec.header.label, &rec.header.label + 1);
}
for (auto& aug : augmenters_[tid]) {
res = aug->Process(res, &label_buf, prnds_[tid].get());
}
// Destination: slot idx of the output batch, or a fresh entry in temp_[tid].
// NOTE(review): the direct-write path assumes the augmented image has exactly
// unit_size_[0] elements (n_channels * rows * cols); this is not checked here.
mshadow::Tensor<cpu, 3, DType> data;
if (idx < batch_param_.batch_size) {
data = mshadow::Tensor<cpu, 3, DType>(data_dptr + idx * unit_size_[0],
mshadow::Shape3(n_channels, res.rows, res.cols));
} else {
out_tmp.Push(static_cast<size_t>(rec.image_index()),
mshadow::Shape3(n_channels, res.rows, res.cols),
mshadow::Shape1(param_.label_width));
data = out_tmp.data().Back();
}
// Per-image random mirroring and, for non-uint8 output, random contrast and
// illumination jitter scaled by normalize_param_.scale.
std::uniform_real_distribution<float> rand_uniform(0, 1);
std::bernoulli_distribution coin_flip(0.5);
bool is_mirrored =
(normalize_param_.rand_mirror && coin_flip(*(prnds_[tid]))) || normalize_param_.mirror;
float contrast_scaled = 1;
float illumination_scaled = 0;
if (!std::is_same<DType, uint8_t>::value) {
contrast_scaled =
(rand_uniform(*(prnds_[tid])) * normalize_param_.max_random_contrast * 2 -
normalize_param_.max_random_contrast + 1) *
normalize_param_.scale;
illumination_scaled =
(rand_uniform(*(prnds_[tid])) * normalize_param_.max_random_illumination * 2 -
normalize_param_.max_random_illumination) *
normalize_param_.scale;
}
// For RGB or RGBA data, swap the B and R channel:
// OpenCV store as BGR (or BGRA) and we want RGB (or RGBA)
if (n_channels == 1) {
ProcessImage<1>(res, &data, is_mirrored, contrast_scaled, illumination_scaled);
} else if (n_channels == 3) {
ProcessImage<3>(res, &data, is_mirrored, contrast_scaled, illumination_scaled);
} else if (n_channels == 4) {
ProcessImage<4>(res, &data, is_mirrored, contrast_scaled, illumination_scaled);
}
// The label goes to the same destination (batch slot idx or temp_[tid]) as the image.
mshadow::Tensor<cpu, 1, real_t> label;
if (idx < batch_param_.batch_size) {
label = mshadow::Tensor<cpu, 1, real_t>(label_dptr + idx * unit_size_[1],
mshadow::Shape1(param_.label_width));
} else {
label = out_tmp.label().Back();
}
mshadow::Copy(
label,
mshadow::Tensor<cpu, 1>(dmlc::BeginPtr(label_buf), mshadow::Shape1(label_buf.size())));
res.release();
}
});
}
omp_exc_.Rethrow();
return (std::min(static_cast<size_t>(batch_param_.batch_size), gl_idx) - current_size);
#else
LOG(FATAL) << "Opencv is needed for image decoding and augmenting.";
return 0;
#endif
}
// Create the mean image by averaging every image of the record file, save it
// to normalize_param_.mean_img in MXNet's NDArray format, mark it ready and
// rewind the source.
template <typename DType>
inline void ImageRecordIOParser2<DType>::CreateMeanImg() {
  if (param_.verbose) {
    LOG(INFO) << "Cannot find " << normalize_param_.mean_img
              << ": create mean image, this will take some time...";
  }
  const double start = dmlc::GetTime();
  dmlc::InputSplit::Blob chunk;
  size_t imcnt = 0;  // NOLINT(*)
  while (source_->NextChunk(&chunk)) {
    inst_order_.clear();
    // Parse chunk w/o putting anything in out
    ParseChunk(nullptr, nullptr, batch_param_.batch_size, &chunk);
    for (const auto& place : inst_order_) {
      mshadow::Tensor<cpu, 3> outimg =
          temp_[place.first][place.second].data[0].template get<cpu, 3, real_t>();
      if (imcnt == 0) {
        meanimg_.Resize(outimg.shape_);
        mshadow::Copy(meanimg_, outimg);
      } else {
        meanimg_ += outimg;
      }
      imcnt += 1;
      if (imcnt % 10000L == 0 && param_.verbose) {
        LOG(INFO) << imcnt << " images processed, " << dmlc::GetTime() - start
                  << " sec elapsed";
      }
    }
  }
  // Averaging over zero images would divide by zero and save an unsized mean image.
  CHECK_GT(imcnt, 0U) << "Cannot create mean image " << normalize_param_.mean_img
                      << ": no images were read from " << param_.path_imgrec;
  meanimg_ *= (1.0f / imcnt);
  // save as mxnet python compatible format.
  TBlob tmp = meanimg_;
  {
    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(normalize_param_.mean_img.c_str(), "w"));
    NDArray::Save(fo.get(), {NDArray(tmp, 0)}, {"mean_img"});
  }
  if (param_.verbose) {
    LOG(INFO) << "Save mean image to " << normalize_param_.mean_img << "..";
  }
  meanfile_ready_ = true;
  this->BeforeFirst();
}
// Image RecordIO iterator that parses batches on a background thread
// (dmlc::ThreadedIter) and hands them out through Next()/Value().
template <typename DType = real_t>
class ImageRecordIter2 : public IIterator<DataBatch> {
 public:
  ImageRecordIter2() = default;
  ~ImageRecordIter2() override {
    // Stop the background producer (dmlc::ThreadedIter::Destroy).
    iter_.Destroy();
    // Batches obtained through iter_.Next(&out_) belong to this object until they
    // are handed back with Recycle, so the ones still held here are freed explicitly.
    while (!recycle_queue_.empty()) {
      delete recycle_queue_.front();
      recycle_queue_.pop();
    }
    delete out_;
    out_ = nullptr;
  }
  ImageRecordIter2(const ImageRecordIter2&) = delete;
  ImageRecordIter2& operator=(const ImageRecordIter2&) = delete;

  void Init(const std::vector<std::pair<std::string, std::string>>& kwargs) override {
    prefetch_param_.InitAllowUnknown(kwargs);
    parser_.Init(kwargs);
    // maximum prefetch threaded iter internal size
    const int kMaxPrefetchBuffer = 16;
    // init thread iter
    iter_.set_max_capacity(kMaxPrefetchBuffer);
    // The producer allocates a DataBatch on first use and refills it on later calls.
    iter_.Init(
        [this](DataBatch** dptr) {
          if (*dptr == nullptr) {
            *dptr = new DataBatch();
          }
          return parser_.ParseNext(*dptr);
        },
        [this]() { parser_.BeforeFirst(); });
  }
  void BeforeFirst() override {
    iter_.BeforeFirst();
  }
  // From iter_prefetcher.h
  bool Next() override {
    if (out_ != nullptr) {
      recycle_queue_.push(out_);
      out_ = nullptr;
    }
    // Only recycle the oldest batch once prefetch_buffer batches are held, so that
    // recently returned batches stay valid while consumers still read them.
    if (recycle_queue_.size() == prefetch_param_.prefetch_buffer) {
      DataBatch* old_batch = recycle_queue_.front();
      // can be more efficient on engine
      for (NDArray& arr : old_batch->data) {
        arr.WaitToWrite();
      }
      recycle_queue_.pop();
      iter_.Recycle(&old_batch);
    }
    return iter_.Next(&out_);
  }
  const DataBatch& Value() const override {
    return *out_;
  }

 private:
  /*! \brief Backend thread */
  dmlc::ThreadedIter<DataBatch> iter_;
  /*! \brief Parameters */
  PrefetcherParam prefetch_param_;
  /*! \brief output data (owned until pushed to recycle_queue_) */
  DataBatch* out_{nullptr};
  /*! \brief queue to be recycled (owned batches) */
  std::queue<DataBatch*> recycle_queue_;
  /* \brief parser */
  ImageRecordIOParser2<DType> parser_;
};
// Image RecordIO iterator for the CPU backend: each Next() parses one batch
// synchronously through the engine (serialized on var_) into a single reused
// DataBatch.
template <typename DType = real_t>
class ImageRecordIter2CPU : public IIterator<DataBatch> {
 public:
  ImageRecordIter2CPU()
      : out_(std::make_unique<DataBatch>()), var_(Engine::Get()->NewVariable()) {}
  ~ImageRecordIter2CPU() override {
    Engine::Get()->DeleteVariable([](mxnet::RunContext ctx) {}, Context::CPU(), var_);
  }
  // Owns an engine variable and an output batch; copying would release them twice.
  ImageRecordIter2CPU(const ImageRecordIter2CPU&) = delete;
  ImageRecordIter2CPU& operator=(const ImageRecordIter2CPU&) = delete;

  void Init(const std::vector<std::pair<std::string, std::string>>& kwargs) override {
    parser_.Init(kwargs);
  }
  void BeforeFirst() override {
    parser_.BeforeFirst();
  }
  // From iter_prefetcher.h
  bool Next() override {
    bool result = false;
    const auto engine = Engine::Get();
    // `result` is captured by reference; WaitForVar below guarantees the pushed
    // function has finished before it is read.
    engine->PushSync(
        [this, &result](RunContext ctx) { result = this->parser_.ParseNext(out_.get()); },
        Context::CPU(),
        {},
        {var_},
        FnProperty::kNormal,
        0,
        "DataLoader");
    engine->WaitForVar(var_);
    return result;
  }
  const DataBatch& Value() const override {
    return *out_;
  }

 private:
  /*! \brief output data, reused for every batch */
  std::unique_ptr<DataBatch> out_;
  /*! \brief engine variable written by every parse operation */
  Engine::VarHandle var_;
  /* \brief parser */
  ImageRecordIOParser2<DType> parser_;
};
// Front-end registered as ImageRecordIter: selects the concrete iterator from the
// `ctx` and `dtype` kwargs (PrefetcherParam) and forwards every call to it.
class ImageRecordIter2Wrapper : public IIterator<DataBatch> {
 public:
  void Init(const std::vector<std::pair<std::string, std::string>>& kwargs) override {
    PrefetcherParam prefetch_param;
    prefetch_param.InitAllowUnknown(kwargs);
    int dtype = mshadow::kFloat32;
    if (prefetch_param.dtype.has_value()) {
      dtype = prefetch_param.dtype.value();
    }
    if (prefetch_param.ctx == PrefetcherParam::CtxType::kCPU) {
      LOG(INFO) << "Create ImageRecordIter2 optimized for CPU backend. "
                << "Use omp threads instead of preprocess_threads.";
      switch (dtype) {
        case mshadow::kFloat32:
          record_iter_ = std::make_unique<ImageRecordIter2CPU<float>>();
          break;
        case mshadow::kUint8:
          record_iter_ = std::make_unique<ImageRecordIter2CPU<uint8_t>>();
          break;
        case mshadow::kInt8:
          record_iter_ = std::make_unique<ImageRecordIter2CPU<int8_t>>();
          break;
        default:
          LOG(FATAL) << "unknown dtype for ImageRecordIter2.";
      }
    } else {
      // For gpu
      switch (dtype) {
        case mshadow::kFloat32:
          record_iter_ = std::make_unique<ImageRecordIter2<float>>();
          break;
        case mshadow::kUint8:
          record_iter_ = std::make_unique<ImageRecordIter2<uint8_t>>();
          break;
        case mshadow::kInt8:
          record_iter_ = std::make_unique<ImageRecordIter2<int8_t>>();
          break;
        default:
          LOG(FATAL) << "unknown dtype for ImageRecordIter2.";
      }
    }
    record_iter_->Init(kwargs);
  }
  void BeforeFirst() override {
    record_iter_->BeforeFirst();
  }
  // From iter_prefetcher.h
  bool Next() override {
    return record_iter_->Next();
  }
  const DataBatch& Value() const override {
    return record_iter_->Value();
  }

 private:
  /*! \brief the selected concrete iterator; replaced (and freed) on re-Init */
  std::unique_ptr<IIterator<DataBatch>> record_iter_;
};
// Register ImageRecordIter. The body constructs ImageRecordIter2Wrapper, which
// picks ImageRecordIter2CPU or ImageRecordIter2 from the `ctx` kwarg and the
// element type from the `dtype` kwarg. The add_arguments list documents the
// parameter groups that the parser reads from the same kwargs.
MXNET_REGISTER_IO_ITER(ImageRecordIter)
.describe(R"code(Iterates on image RecordIO files
Reads batches of images from .rec RecordIO files. One can use ``im2rec.py`` tool
(in tools/) to pack raw image files into RecordIO files. This iterator is less
flexible to customization but is fast and has lot of language bindings. To
iterate over raw images directly use ``ImageIter`` instead (in Python).
Example::
data_iter = mx.io.ImageRecordIter(
path_imgrec="./sample.rec", # The target record file.
data_shape=(3, 227, 227), # Output data shape; 227x227 region will be cropped from the original image.
batch_size=4, # Number of items per batch.
resize=256 # Resize the shorter edge to 256 before cropping.
# You can specify more augmentation options. Use help(mx.io.ImageRecordIter) to see all the options.
)
# You can now use the data_iter to access batches of images.
batch = data_iter.next() # first batch.
images = batch.data[0] # This will contain 4 (=batch_size) images each of 3x227x227.
# process the images
...
data_iter.reset() # To restart the iterator from the beginning.
)code" ADD_FILELINE)
.add_arguments(ImageRecParserParam::__FIELDS__())
.add_arguments(ImageRecordParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.add_arguments(ListDefaultAugParams())
.add_arguments(ImageNormalizeParam::__FIELDS__())
.set_body([]() { return new ImageRecordIter2Wrapper(); });
// Register the deprecated ImageRecordUInt8Iter. It constructs
// ImageRecordIter2<uint8_t> directly, bypassing ImageRecordIter2Wrapper, so the
// `ctx`-based selection of the CPU-optimized iterator does not apply here.
// NOTE(review): ImageNormalizeParam is not listed, although the parser still reads
// normalize_param_.mirror / rand_mirror for uint8 output — confirm whether the
// documentation should list them.
MXNET_REGISTER_IO_ITER(ImageRecordUInt8Iter)
.describe(R"code(Iterating on image RecordIO files
.. note:: ImageRecordUInt8Iter is deprecated. Use ImageRecordIter(dtype='uint8') instead.
This iterator is identical to ``ImageRecordIter`` except for using ``uint8`` as
the data type instead of ``float``.
)code" ADD_FILELINE)
.add_arguments(ImageRecParserParam::__FIELDS__())
.add_arguments(ImageRecordParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.add_arguments(ListDefaultAugParams())
.set_body([]() { return new ImageRecordIter2<uint8_t>(); });
// Register the deprecated ImageRecordInt8Iter. It constructs
// ImageRecordIter2<int8_t> directly, bypassing ImageRecordIter2Wrapper.
// NOTE(review): ImageNormalizeParam is not listed, although int8 output uses the
// mean values and the mirror flags from normalize_param_ — confirm whether the
// documentation should list them.
MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter)
.describe(R"code(Iterating on image RecordIO files
.. note:: ``ImageRecordInt8Iter`` is deprecated. Use ImageRecordIter(dtype='int8') instead.
This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as
the data type instead of ``float``.
)code" ADD_FILELINE)
.add_arguments(ImageRecParserParam::__FIELDS__())
.add_arguments(ImageRecordParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.add_arguments(ListDefaultAugParams())
.set_body([]() { return new ImageRecordIter2<int8_t>(); });
} // namespace io
} // namespace mxnet