Skip to content

Commit

Permalink
Merge pull request #4 from TatsuyaOGth/hotfix/use_smart_pointer
Browse files Browse the repository at this point in the history
change pointer type to smart pointer, remove mTrainData
  • Loading branch information
TatsuyaOGth committed Sep 23, 2016
2 parents 024a68d + 747b5e2 commit 3453a48
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 39 deletions.
52 changes: 16 additions & 36 deletions src/ofxSvm.cpp
Expand Up @@ -213,7 +213,7 @@ bool ofxSvm::Data::scale(double lower, double upper, double y_lower, double y_up



ofxSvm::ofxSvm() : mModel(NULL), mTrainData(NULL)
ofxSvm::ofxSvm()
{
defaultParams();
svm_set_print_string_function(ofxSvm::printStdOut);
Expand All @@ -227,10 +227,9 @@ ofxSvm::~ofxSvm()
void ofxSvm::clear()
{
svm_destroy_param(&mParam);
if (mModel != NULL)
if (mModel)
{
svm_free_model_content(mModel);
svm_free_and_destroy_model(&mModel);
svm_free_model_content(mModel.get());
}
}

Expand Down Expand Up @@ -263,12 +262,12 @@ void ofxSvm::train(const Data& data)
{
svm_problem prob;

const multimap<double, vector<double> >& v = data.mData;
const auto& v = data.mData;

prob.l = v.size();
prob.y = new double[prob.l];
{
multimap<double, vector<double> >::const_iterator it = v.begin();
auto it = v.begin();
int i = 0;
while (it != v.end())
{
Expand All @@ -287,7 +286,7 @@ void ofxSvm::train(const Data& data)
svm_node* node = new svm_node[prob.l * nodeLength];
prob.x = new svm_node*[prob.l];
{
multimap<double, vector<double> >::const_iterator it = v.begin();
auto it = v.begin();
int i = 0;
while (it != v.end())
{
Expand Down Expand Up @@ -318,54 +317,35 @@ void ofxSvm::train(const Data& data)

ofLogVerbose(LOG_MODULE, "Start train...");

mModel = svm_train(&prob, &mParam);
mModel = shared_ptr<svm_model>(svm_train(&prob, &mParam));

ofLogVerbose(LOG_MODULE, "Finish");

delete[] node;
delete[] prob.x;
delete[] prob.y;

mTrainData = &data;
}

vector<double> ofxSvm::predict(const Data& data)
{
vector<double> res;

if (mTrainData == NULL)
{
ofLogError(LOG_MODULE, "did not trained yet");
return res;
}
if (mModel == NULL)
if (!mModel)
{
ofLogError(LOG_MODULE, "null model, befor do train or load model file");
ofLogError(LOG_MODULE, "model is null, do train or load model file");
return res;
}
if (svm_check_probability_model(mModel))
if (svm_check_probability_model(mModel.get()))
{
ofLogError(LOG_MODULE, "provavility model is not available");
return res;
}
if (data.mScaleParameter.isEnable != mTrainData->mScaleParameter.isEnable)
{
ofLogWarning(LOG_MODULE, "different scale");
}

multimap<double, vector<double> >::const_iterator it = data.mData.begin();
auto it = data.mData.begin();
while (it != data.mData.end())
{
const vector<double>& testVec = it->second;

if (testVec.size() != mTrainData->mDimension)
{
ofLogError(LOG_MODULE, "diffetent dimension");
res.push_back(-DBL_MAX);
++it;
continue;
}

vector<double> testVector(testVec);

svm_node* node = new svm_node[data.mDimension + 1];
Expand All @@ -376,7 +356,7 @@ vector<double> ofxSvm::predict(const Data& data)
}
node[data.mDimension].index = -1;

res.push_back( svm_predict(mModel, node) );
res.push_back( svm_predict(mModel.get(), node) );
delete[] node;

++it;
Expand All @@ -392,12 +372,12 @@ void ofxSvm::saveModel(const string &filename)
ofLogError(LOG_MODULE, "null model, befor do train or load model file");
return;
}
svm_save_model(ofToDataPath(filename).c_str(), mModel);
svm_save_model(ofToDataPath(filename).c_str(), mModel.get());
}

void ofxSvm::loadModel(const string &filename)
{
mModel = svm_load_model(ofToDataPath(filename).c_str());
mModel = shared_ptr<svm_model>(svm_load_model(ofToDataPath(filename).c_str()));
}

vector<int> ofxSvm::getSupportVectorIndex()
Expand All @@ -410,9 +390,9 @@ vector<int> ofxSvm::getSupportVectorIndex()
return dst;
}

const int num = svm_get_nr_sv(mModel);
const int num = svm_get_nr_sv(mModel.get());
vector<int> indices(num);
svm_get_sv_indices(mModel, indices.data());
svm_get_sv_indices(mModel.get(), indices.data());
for(int i = 0; i < num; ++i )
{
dst.push_back(indices[i] - 1);
Expand Down
5 changes: 2 additions & 3 deletions src/ofxSvm.h
Expand Up @@ -74,9 +74,8 @@ class ofxSvm
void defaultParams();

protected:
svm_parameter mParam;
svm_model *mModel;
Data const *mTrainData;
svm_parameter mParam;
shared_ptr<svm_model> mModel;

static void printStdOut(const char *s);
};

0 comments on commit 3453a48

Please sign in to comment.