From 747b5e20bd31113d1090ff98c0f0241607c9b41e Mon Sep 17 00:00:00 2001 From: TatsuyaOgusu Date: Thu, 22 Sep 2016 18:13:33 +0900 Subject: [PATCH] change pointer type to smart pointer, remove mTrainData --- src/ofxSvm.cpp | 52 ++++++++++++++++---------------------------------- src/ofxSvm.h | 5 ++--- 2 files changed, 18 insertions(+), 39 deletions(-) diff --git a/src/ofxSvm.cpp b/src/ofxSvm.cpp index 38fe22c..eb110a0 100644 --- a/src/ofxSvm.cpp +++ b/src/ofxSvm.cpp @@ -213,7 +213,7 @@ bool ofxSvm::Data::scale(double lower, double upper, double y_lower, double y_up -ofxSvm::ofxSvm() : mModel(NULL), mTrainData(NULL) +ofxSvm::ofxSvm() { defaultParams(); svm_set_print_string_function(ofxSvm::printStdOut); @@ -227,10 +227,9 @@ ofxSvm::~ofxSvm() void ofxSvm::clear() { svm_destroy_param(&mParam); - if (mModel != NULL) + if (mModel) { - svm_free_model_content(mModel); - svm_free_and_destroy_model(&mModel); + svm_free_model_content(mModel.get()); } } @@ -263,12 +262,12 @@ void ofxSvm::train(const Data& data) { svm_problem prob; - const multimap >& v = data.mData; + const auto& v = data.mData; prob.l = v.size(); prob.y = new double[prob.l]; { - multimap >::const_iterator it = v.begin(); + auto it = v.begin(); int i = 0; while (it != v.end()) { @@ -287,7 +286,7 @@ void ofxSvm::train(const Data& data) svm_node* node = new svm_node[prob.l * nodeLength]; prob.x = new svm_node*[prob.l]; { - multimap >::const_iterator it = v.begin(); + auto it = v.begin(); int i = 0; while (it != v.end()) { @@ -318,54 +317,35 @@ void ofxSvm::train(const Data& data) ofLogVerbose(LOG_MODULE, "Start train..."); - mModel = svm_train(&prob, &mParam); + mModel = shared_ptr(svm_train(&prob, &mParam)); ofLogVerbose(LOG_MODULE, "Finish"); delete[] node; delete[] prob.x; delete[] prob.y; - - mTrainData = &data; } vector ofxSvm::predict(const Data& data) { vector res; - if (mTrainData == NULL) - { - ofLogError(LOG_MODULE, "did not trained yet"); - return res; - } - if (mModel == NULL) + if (!mModel) { - ofLogError(LOG_MODULE, "null model, befor do train or load model file"); + ofLogError(LOG_MODULE, "model is null, do train or load model file"); return res; } - if (svm_check_probability_model(mModel)) + if (svm_check_probability_model(mModel.get())) { ofLogError(LOG_MODULE, "provavility model is not available"); return res; } - if (data.mScaleParameter.isEnable != mTrainData->mScaleParameter.isEnable) - { - ofLogWarning(LOG_MODULE, "different scale"); - } - multimap >::const_iterator it = data.mData.begin(); + auto it = data.mData.begin(); while (it != data.mData.end()) { const vector& testVec = it->second; - if (testVec.size() != mTrainData->mDimension) - { - ofLogError(LOG_MODULE, "diffetent dimension"); - res.push_back(-DBL_MAX); - ++it; - continue; - } - vector testVector(testVec); svm_node* node = new svm_node[data.mDimension + 1]; @@ -376,7 +356,7 @@ vector ofxSvm::predict(const Data& data) } node[data.mDimension].index = -1; - res.push_back( svm_predict(mModel, node) ); + res.push_back( svm_predict(mModel.get(), node) ); delete[] node; ++it; @@ -392,12 +372,12 @@ void ofxSvm::saveModel(const string &filename) ofLogError(LOG_MODULE, "null model, befor do train or load model file"); return; } - svm_save_model(ofToDataPath(filename).c_str(), mModel); + svm_save_model(ofToDataPath(filename).c_str(), mModel.get()); } void ofxSvm::loadModel(const string &filename) { - mModel = svm_load_model(ofToDataPath(filename).c_str()); + mModel = shared_ptr(svm_load_model(ofToDataPath(filename).c_str())); } vector ofxSvm::getSupportVectorIndex() @@ -410,9 +390,9 @@ vector ofxSvm::getSupportVectorIndex() return dst; } - const int num = svm_get_nr_sv(mModel); + const int num = svm_get_nr_sv(mModel.get()); vector indices(num); - svm_get_sv_indices(mModel, indices.data()); + svm_get_sv_indices(mModel.get(), indices.data()); for(int i = 0; i < num; ++i ) { dst.push_back(indices[i] - 1); diff --git a/src/ofxSvm.h b/src/ofxSvm.h index d2afb38..182ba20 100644 --- a/src/ofxSvm.h +++ b/src/ofxSvm.h @@ -74,9 +74,8 @@ class ofxSvm void defaultParams(); protected: - svm_parameter mParam; - svm_model *mModel; - Data const *mTrainData; + svm_parameter mParam; + shared_ptr mModel; static void printStdOut(const char *s); };