forked from mlpack/mlpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_impl.hpp
251 lines (216 loc) · 7.09 KB
/
load_impl.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
/**
* @file load_impl.hpp
* @author Ryan Curtin
*
* Implementation of templatized load() function defined in load.hpp.
*/
#ifndef __MLPACK_CORE_DATA_LOAD_IMPL_HPP
#define __MLPACK_CORE_DATA_LOAD_IMPL_HPP
// In case it hasn't already been included.
#include "load.hpp"
#include <algorithm>
#include <mlpack/core/util/timers.hpp>
namespace mlpack {
namespace data {
/**
 * Transpose the given matrix, overwriting it with the result.
 *
 * The ordinary transpose (X = arma::trans(X)) is attempted first.  If that
 * throws std::bad_alloc, the transpose is retried with arma::inplace_trans()
 * using its "lowmem" method.
 *
 * @param X Matrix to transpose; overwritten with its transpose.
 * @return false if the ordinary transpose succeeded; true if the "lowmem"
 *     fallback was used instead.
 */
template<typename eT>
inline bool inplace_transpose(arma::Mat<eT>& X)
{
  bool usedLowMemFallback = false;

  try
  {
    X = arma::trans(X);
  }
  catch (const std::bad_alloc&)
  {
    // The ordinary transpose could not allocate; retry with the "lowmem"
    // method (presumably slower but less memory-hungry; see Armadillo docs).
    arma::inplace_trans(X, "lowmem");
    usedLowMemFallback = true;
  }

  return usedLowMemFallback;
}
template<typename eT>
bool Load(const std::string& filename,
arma::Mat<eT>& matrix,
bool fatal,
bool transpose)
{
Timer::Start("loading_data");
// First we will try to discriminate by file extension.
size_t ext = filename.rfind('.');
if (ext == std::string::npos)
{
Timer::Stop("loading_data");
if (fatal)
Log::Fatal << "Cannot determine type of file '" << filename << "'; "
<< "no extension is present." << std::endl;
else
Log::Warn << "Cannot determine type of file '" << filename << "'; "
<< "no extension is present. Load failed." << std::endl;
return false;
}
// Get the extension and force it to lowercase.
std::string extension = filename.substr(ext + 1);
std::transform(extension.begin(), extension.end(), extension.begin(),
::tolower);
// Catch nonexistent files by opening the stream ourselves.
std::fstream stream;
stream.open(filename.c_str(), std::fstream::in);
if (!stream.is_open())
{
Timer::Stop("loading_data");
if (fatal)
Log::Fatal << "Cannot open file '" << filename << "'. " << std::endl;
else
Log::Warn << "Cannot open file '" << filename << "'; load failed."
<< std::endl;
return false;
}
bool unknownType = false;
arma::file_type loadType;
std::string stringType;
if (extension == "csv")
{
loadType = arma::csv_ascii;
stringType = "CSV data";
}
else if (extension == "txt")
{
// This could be raw ASCII or Armadillo ASCII (ASCII with size header).
// We'll let Armadillo do its guessing (although we have to check if it is
// arma_ascii ourselves) and see what we come up with.
// This is taken from load_auto_detect() in diskio_meat.hpp
const std::string ARMA_MAT_TXT = "ARMA_MAT_TXT";
char* rawHeader = new char[ARMA_MAT_TXT.length() + 1];
std::streampos pos = stream.tellg();
stream.read(rawHeader, std::streamsize(ARMA_MAT_TXT.length()));
rawHeader[ARMA_MAT_TXT.length()] = '\0';
stream.clear();
stream.seekg(pos); // Reset stream position after peeking.
if (std::string(rawHeader) == ARMA_MAT_TXT)
{
loadType = arma::arma_ascii;
stringType = "Armadillo ASCII formatted data";
}
else // It's not arma_ascii. Now we let Armadillo guess.
{
loadType = arma::diskio::guess_file_type(stream);
if (loadType == arma::raw_ascii) // Raw ASCII (space-separated).
stringType = "raw ASCII formatted data";
else if (loadType == arma::csv_ascii) // CSV can be .txt too.
stringType = "CSV data";
else // Unknown .txt... we will throw an error.
unknownType = true;
}
delete[] rawHeader;
}
else if (extension == "bin")
{
// This could be raw binary or Armadillo binary (binary with header). We
// will check to see if it is Armadillo binary.
const std::string ARMA_MAT_BIN = "ARMA_MAT_BIN";
char *rawHeader = new char[ARMA_MAT_BIN.length() + 1];
std::streampos pos = stream.tellg();
stream.read(rawHeader, std::streamsize(ARMA_MAT_BIN.length()));
rawHeader[ARMA_MAT_BIN.length()] = '\0';
stream.clear();
stream.seekg(pos); // Reset stream position after peeking.
if (std::string(rawHeader) == ARMA_MAT_BIN)
{
stringType = "Armadillo binary formatted data";
loadType = arma::arma_binary;
}
else // We can only assume it's raw binary.
{
stringType = "raw binary formatted data";
loadType = arma::raw_binary;
}
delete[] rawHeader;
}
else if (extension == "pgm")
{
loadType = arma::pgm_binary;
stringType = "PGM data";
}
else if (extension == "h5" || extension == "hdf5" || extension == "hdf" ||
extension == "he5")
{
#ifdef ARMA_USE_HDF5
loadType = arma::hdf5_binary;
stringType = "HDF5 data";
#if ARMA_VERSION_MAJOR == 4 && \
(ARMA_VERSION_MINOR >= 300 && ARMA_VERSION_MINOR <= 400)
Timer::Stop("loading_data");
if (fatal)
Log::Fatal << "Attempted to load '" << filename << "' as HDF5 data, but "
<< "Armadillo 4.300.0 through Armadillo 4.400.1 are known to have "
<< "bugs and one of these versions is in use. Load failed."
<< std::endl;
else
Log::Warn << "Attempted to load '" << filename << "' as HDF5 data, but "
<< "Armadillo 4.300.0 through Armadillo 4.400.1 are known to have "
<< "bugs and one of these versions is in use. Load failed."
<< std::endl;
return false;
#endif
#else
Timer::Stop("loading_data");
if (fatal)
Log::Fatal << "Attempted to load '" << filename << "' as HDF5 data, but "
<< "Armadillo was compiled without HDF5 support. Load failed."
<< std::endl;
else
Log::Warn << "Attempted to load '" << filename << "' as HDF5 data, but "
<< "Armadillo was compiled without HDF5 support. Load failed."
<< std::endl;
return false;
#endif
}
else // Unknown extension...
{
unknownType = true;
loadType = arma::raw_binary; // Won't be used; prevent a warning.
stringType = "";
}
// Provide error if we don't know the type.
if (unknownType)
{
Timer::Stop("loading_data");
if (fatal)
Log::Fatal << "Unable to detect type of '" << filename << "'; "
<< "incorrect extension?" << std::endl;
else
Log::Warn << "Unable to detect type of '" << filename << "'; load failed."
<< " Incorrect extension?" << std::endl;
return false;
}
// Try to load the file; but if it's raw_binary, it could be a problem.
if (loadType == arma::raw_binary)
Log::Warn << "Loading '" << filename << "' as " << stringType << "; "
<< "but this may not be the actual filetype!" << std::endl;
else
Log::Info << "Loading '" << filename << "' as " << stringType << ". "
<< std::flush;
const bool success = matrix.load(stream, loadType);
if (!success)
{
Log::Info << std::endl;
Timer::Stop("loading_data");
if (fatal)
Log::Fatal << "Loading from '" << filename << "' failed." << std::endl;
else
Log::Warn << "Loading from '" << filename << "' failed." << std::endl;
return false;
}
else
Log::Info << "Size is " << (transpose ? matrix.n_cols : matrix.n_rows)
<< " x " << (transpose ? matrix.n_rows : matrix.n_cols) << ".\n";
// Now transpose the matrix, if necessary.
if (transpose)
{
inplace_transpose(matrix);
}
Timer::Stop("loading_data");
// Finally, return the success indicator.
return success;
}
}; // namespace data
}; // namespace mlpack
#endif