Skip to content

Commit

Permalink
SWIG interface further updated
Browse files Browse the repository at this point in the history
  • Loading branch information
WladimirSidorenko committed Jan 12, 2016
1 parent 6f1b0ac commit 02ca8fc
Show file tree
Hide file tree
Showing 11 changed files with 342 additions and 206 deletions.
2 changes: 1 addition & 1 deletion configure.ac
Expand Up @@ -157,9 +157,9 @@ AC_SUBST(libdir)
dnl ------------------------------------------------------------------
dnl Output the configure results.
dnl ------------------------------------------------------------------
AC_CONFIG_FILES([swig/python/setup.py], [chmod +x swig/python/setup.py])
AC_CONFIG_FILES(Makefile genbinary.sh include/Makefile lib/cqdb/Makefile dnl
lib/crf/Makefile frontend/Makefile swig/Makefile dnl
tests/Makefile)
AC_CONFIG_FILES([swig/python/setup.py], [chmod +x swig/python/setup.py])

AC_OUTPUT
112 changes: 90 additions & 22 deletions include/crfsuite.hpp
Expand Up @@ -32,6 +32,8 @@
#define __CRFSUITE_HPP__

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <string>
#include <stdexcept>
Expand All @@ -54,12 +56,12 @@ namespace CRFSuite

Trainer::~Trainer()
{
if (data != NULL) {
if (data) {
clear();
delete data;
data = NULL;
}
if (tr != NULL) {
if (tr) {
tr->release(tr);
tr = NULL;
}
Expand All @@ -82,6 +84,14 @@ namespace CRFSuite
throw std::runtime_error("Failed to create a dictionary instance for labels.");
}
}

// Create an instance of node label dictionary.
if (tr && tr->ftype == FTYPE_CRF1TREE && data->node_labels == NULL) {
int ret = crfsuite_create_instance("dictionary", (void**)&data->node_labels);
if (!ret) {
throw std::runtime_error("Failed to create a dictionary instance for node labels.");
}
}
}

void Trainer::clear()
Expand All @@ -97,61 +107,119 @@ namespace CRFSuite
data->attrs = NULL;
}

if (data->node_labels != NULL) {
data->node_labels->release(data->node_labels);
data->node_labels = NULL;
}
// crfsuite_data_init() is automatically called from `finish()`
crfsuite_data_finish(data);
crfsuite_data_init(data);
}
}

// Convert one training sequence (items `xseq`, labels `yseq`) into a
// crfsuite_instance_t, tag it with `group`, and append it to `data`.
//
// Throws std::invalid_argument when xseq and yseq differ in length, and
// std::runtime_error on malformed tree items, on a failed node-label
// allocation, or when crfsuite_tree_init() reports failure.
//
// NOTE(review): this listing looks like a rendered diff -- several statements
// appear twice, back to back (old form first, new form second). Read the
// second form of each pair as the intended code; confirm against the
// repository before relying on it.
void Trainer::append(const ItemSequence& xseq, const StringList& yseq, int group)
{
// Create dictionary objects if necessary.
if (data->attrs == NULL || data->labels == NULL) {
if (data->attrs == NULL || data->labels == NULL || \
(tr && tr->ftype == FTYPE_CRF1TREE && data->node_labels == NULL))
init();
}

// Make sure |y| == |x|.
if (xseq.size() != yseq.size()) {
std::stringstream ss;
ss << "The numbers of items and labels differ: |x| = " << xseq.size() << ", |y| = " << yseq.size();
ss << "The numbers of items and labels differ: |x| = " << \
xseq.size() << ", |y| = " << yseq.size();
throw std::invalid_argument(ss.str());
}

// Convert instance_type to crfsuite_instance_t.
// `i` is the index of the first real attribute inside an item (0 for plain
// models, 2 for tree models); `n_items` is the item's total entry count.
int i, n_items;
crfsuite_instance_t _inst;
crfsuite_instance_init_n(&_inst, xseq.size());
// NOTE(review): every `throw` below this point leaves without calling
// crfsuite_instance_finish(&_inst) -- presumably this leaks the partially
// built instance; confirm and consider a cleanup path.
for (size_t t = 0;t < xseq.size();++t) {
for (size_t t = 0; t < xseq.size(); ++t) {
const Item& item = xseq[t];
crfsuite_item_t* _item = &_inst.items[t];

// Set the attributes in the item.
crfsuite_item_init_n(_item, item.size());
for (size_t i = 0;i < item.size();++i) {
i = 0;
n_items = item.size();

// Tree-structured models: item[0] carries the node label and item[1] the
// parent label; only the remaining entries are treated as attributes.
if (tr && tr->ftype == FTYPE_CRF1TREE) {
if (n_items < 1)
throw std::runtime_error("Invalid tree format: node label should be given as attribute.");
else if (n_items < 2)
throw std::runtime_error("Invalid tree format: parent label should be given as attribute.");

// Allocate memory for attributes
i = 2;
crfsuite_item_init_n(_item, n_items - i);

// remember string label of this node
_item->id = data->node_labels->get(data->node_labels, item[0].attr.c_str());
// calloc zero-fills, and the +1 leaves room for the terminating NUL.
_item->node_label = (char *) calloc((item[0].attr.length() + 1), sizeof(char));

if (_item->node_label)
strcpy(_item->node_label, item[0].attr.c_str());
else
throw std::runtime_error("ERROR: Could not allocate memory for storing node label.\n");

// remember parent of this node
// A parent label of "_" is stored as -1.
if (item[1].attr == "_")
_item->prnt = -1;
else
_item->prnt = data->node_labels->get(data->node_labels, item[1].attr.c_str());
} else {
crfsuite_item_init_n(_item, n_items);
}

// add attributes
// NOTE(review): in tree mode the item was initialised for n_items - 2
// entries, yet `contents[i]` is indexed with i running from 2 to
// n_items - 1. If crfsuite_item_init_n() allocates exactly the requested
// count, this writes past the end -- confirm; `contents[i - offset]` may be
// what was intended.
// NOTE(review): skipping empty attributes via `continue` leaves that
// contents slot as crfsuite_item_init_n() left it -- verify that is a
// harmless value.
for (; i < n_items; ++i) {
if (item[i].attr.empty())
continue;

_item->contents[i].aid = data->attrs->get(data->attrs, item[i].attr.c_str());
_item->contents[i].value = (floatval_t)item[i].value;
}

// Set the label of the item.
_inst.labels[t] = data->labels->get(data->labels, yseq[t].c_str());
}

// initialize instance tree
if (tr && tr->ftype == FTYPE_CRF1TREE && crfsuite_tree_init(&_inst) != 0)
throw std::runtime_error("ERROR: Could not create tree for training instance.\n");

// assign instance to a group
_inst.group = group;

// Append the instance to the training set.
crfsuite_data_append(data, &_inst);

// Finish the instance.
crfsuite_instance_finish(&_inst);

/* clear dictionary of node labels so that new instances will
have dense representation of node ids again */
if (data->node_labels)
data->node_labels->reset(data->node_labels);
}

bool Trainer::select(const std::string& algorithm, const std::string& type)
{
int ret;
int ret = 0;

// Release the trainer if it is already initialized.
if (tr != NULL) {
tr->release(tr);
tr = NULL;
}

if (algorithm != "lbfgs" && type == "semim") {
std::stringstream ss;
ss << "ERROR: Training algorithm '" << algorithm << \
"' is not supported for this type of graphical model. Try `lbfgs' instead";
throw std::invalid_argument(ss.str());
}

// Build the trainer string ID.
std::string tid = "train/";
tid += type;
Expand All @@ -160,30 +228,29 @@ namespace CRFSuite

// Create an instance of a trainer.
ret = crfsuite_create_instance(tid.c_str(), (void**)&tr);
if (!ret) {
if (!ret)
return false;
}

// Set the callback function for receiving messages.
tr->set_message_callback(tr, this, __logging_callback);

return true;
}

// Run the selected training algorithm over the accumulated `data` and return
// the status code produced by tr->train().
//
// `model` is the model file name handed to the trainer; in the second
// (newer) return statement an empty string is passed on as NULL.
// `holdout` is forwarded unchanged.
//
// NOTE(review): `tr` is dereferenced without a null check here, whereas
// params() guards with `if (!tr)`. Calling train() before a successful
// select() presumably crashes -- confirm and consider the same guard.
// NOTE(review): the two return paths shown look like the old and new sides
// of a diff; only one of them is expected to exist in the real file.
int Trainer::train(const std::string& model, int holdout)
{
// Run the training algorithm.
int ret = tr->train(tr, data, model.c_str(), holdout);

return ret;
return tr->train(tr, data, model.empty()? NULL: model.c_str(), holdout);
}

StringList Trainer::params()
{
StringList pars;
if (!tr)
return pars;

crfsuite_params_t* params = tr->params(tr);
int n = params->num(params);
for (int i = 0;i < n;++i) {
for (int i = 0; i < n; ++i) {
char *name = NULL;
params->name(params, i, &name);
pars.push_back(name);
Expand Down Expand Up @@ -247,7 +314,8 @@ namespace CRFSuite


// All members are default-initialized in the class; the newer form of the
// constructor additionally sets m_ftype to FTYPE_NONE explicitly.
// NOTE(review): the two constructor headers below look like the old and new
// sides of a diff; only one is expected in the real file.
Tagger::Tagger()
Tagger::Tagger():
m_ftype(FTYPE_NONE)
{}

Tagger::~Tagger()
Expand Down Expand Up @@ -318,13 +386,13 @@ namespace CRFSuite
}
}

// Convenience wrapper: load `xseq` into the tagger via set(), then return
// the label sequence produced by viterbi().
//
// NOTE(review): two signatures are shown (old and new sides of a diff). The
// newer one takes `xseq` by non-const reference, so callers can no longer
// pass temporaries or const sequences. Presumably this was done for the SWIG
// bindings -- confirm it is intentional, since nothing visible in this
// function modifies xseq.
StringList Tagger::tag(const ItemSequence& xseq)
StringList Tagger::tag(ItemSequence& xseq)
{
set(xseq);
return viterbi();
}

void Tagger::set(const ItemSequence& xseq)
void Tagger::set(ItemSequence& xseq)
{
int ret;
StringList yseq;
Expand All @@ -338,7 +406,7 @@ namespace CRFSuite

// Build an instance.
crfsuite_instance_init_n(&_inst, xseq.size());
for (size_t t = 0;t < xseq.size();++t) {
for (size_t t = 0; t < xseq.size(); ++t) {
const Item& item = xseq[t];
crfsuite_item_t* _item = &_inst.items[t];

Expand Down Expand Up @@ -430,7 +498,7 @@ namespace CRFSuite
return yseq;
}

double Tagger::probability(const StringList& yseq)
double Tagger::probability(StringList& yseq)
{
int ret;
size_t T;
Expand Down

0 comments on commit 02ca8fc

Please sign in to comment.