@@ -143,7 +143,6 @@ ITensor* residualDenseBlock(INetworkDefinition *network, std::map<std::string, W
143
143
return ew1->getOutput (0 );
144
144
}
145
145
146
-
147
146
ITensor* RRDB (INetworkDefinition *network, std::map<std::string, Weights>& weightMap, ITensor* x, std::string lname)
148
147
{
149
148
ITensor* out = residualDenseBlock (network, weightMap, x, lname + " .rdb1" );
@@ -253,7 +252,24 @@ void createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig*
253
252
254
253
// Build engine
255
254
builder->setMaxBatchSize (maxBatchSize);
256
- config->setMaxWorkspaceSize (1 << 20 );
255
+ // config->setMaxWorkspaceSize(1 << 22);
256
+ config->setMaxWorkspaceSize (28 * (1 << 23 )); // 28MB
257
+
258
+ if (precision_mode == 16 ) {
259
+ std::cout << " ==== precision f16 ====" << std::endl << std::endl;
260
+ config->setFlag (BuilderFlag::kFP16 );
261
+ }
262
+ else if (precision_mode == 8 ) {
263
+ // std::cout << "==== precision int8 ====" << std::endl << std::endl;
264
+ // std::cout << "Your platform support int8: " << builder->platformHasFastInt8() << std::endl;
265
+ // assert(builder->platformHasFastInt8());
266
+ // config->setFlag(BuilderFlag::kINT8);
267
+ // Int8EntropyCalibrator2 *calibrator = new Int8EntropyCalibrator2(maxBatchSize, INPUT_W, INPUT_H, 0, "../data_calib/", "../Int8_calib_table/detr_int8_calib.table", INPUT_BLOB_NAME);
268
+ // config->setInt8Calibrator(calibrator);
269
+ }
270
+ else {
271
+ std::cout << " ==== precision f32 ====" << std::endl << std::endl;
272
+ }
257
273
258
274
std::cout << " Building engine, please wait for a while..." << std::endl;
259
275
IHostMemory* engine = builder->buildSerializedNetwork (*network, *config);
@@ -285,7 +301,7 @@ int main()
285
301
char engineFileName[] = " real-esrgan" ;
286
302
287
303
char engine_file_path[256 ];
288
- sprintf (engine_file_path, " ../Engine/%s .engine" , engineFileName);
304
+ sprintf (engine_file_path, " ../Engine/%s_%d .engine" , engineFileName, precision_mode );
289
305
290
306
// 1) engine file 만들기
291
307
// 강제 만들기 true면 무조건 다시 만들기
@@ -359,7 +375,7 @@ int main()
359
375
std::cout << " ===== input load done =====" << std::endl << std::endl;
360
376
361
377
uint64_t dur_time = 0 ;
362
- uint64_t iter_count = 1 ;
378
+ uint64_t iter_count = 10 ;
363
379
364
380
// CUDA 스트림 생성
365
381
cudaStream_t stream;
0 commit comments