ENH: Return n_iter_ from liblinear and print convergence warnings
MechCoder authored and ogrisel committed Aug 1, 2014
1 parent 3edfa3c commit c8c72fd
Showing 10 changed files with 300 additions and 142 deletions.
4 changes: 4 additions & 0 deletions sklearn/linear_model/logistic.py
@@ -639,6 +639,10 @@ class LogisticRegression(BaseLibLinear, LinearClassifierMixin,
Intercept (a.k.a. bias) added to the decision function.
If `fit_intercept` is set to False, the intercept is set to zero.
`n_iter_` : int | array, shape (n_classes,)
Number of iterations run per class. Valid only for the liblinear
solver.
See also
--------
SGDClassifier : incrementally trained logistic regression (when given
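To make the new attribute concrete, here is a minimal usage sketch (not part of the commit). It assumes an estimator that exposes `max_iter`, as the liblinear base class in this PR does, and that `ConvergenceWarning` is importable from `sklearn.utils` (it moved to `sklearn.exceptions` in later releases).

```python
import warnings

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils import ConvergenceWarning  # sklearn.exceptions in later releases

X, y = make_classification(n_samples=500, n_features=50, random_state=0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # A deliberately tiny iteration budget to force the new warning.
    clf = LogisticRegression(max_iter=1).fit(X, y)

print(clf.n_iter_)  # scalar for binary problems, array of shape (n_classes,) otherwise
print(any(issubclass(w.category, ConvergenceWarning) for w in caught))
```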
18 changes: 11 additions & 7 deletions sklearn/svm/base.py
@@ -716,17 +716,21 @@ def fit(self, X, y):

# LibLinear wants targets as doubles, even for classification
y_ind = np.asarray(y_ind, dtype=np.float64).ravel()
raw_coef_ = liblinear.train_wrap(X, y_ind,
sp.isspmatrix(X),
self._get_solver_type(),
self.tol, self._get_bias(),
self.C, self.class_weight_,
self.max_iter,
rnd.randint(np.iinfo('i').max))
raw_coef_, self.n_iter_ = liblinear.train_wrap(
X, y_ind, sp.isspmatrix(X), self._get_solver_type(),
self.tol, self._get_bias(), self.C, self.class_weight_,
self.max_iter, rnd.randint(np.iinfo('i').max)
)
# Regarding rnd.randint(..) in the above signature:
# seed for srand in range [0..INT_MAX); due to limitations in Numpy
# on 32-bit platforms, we can't get to the UINT_MAX limit that
# srand supports
for n_iter in self.n_iter_:
if n_iter >= self.max_iter:
warnings.warn("Liblinear failed to converge, increase "
"the number of iterations.", ConvergenceWarning)
if len(self.classes_) == 2:
self.n_iter_ = self.n_iter_[0]

if self.fit_intercept:
self.coef_ = raw_coef_[:, :-1]
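In words: `train_wrap` now returns one iteration count per one-vs-rest sub-problem, any count that reached `max_iter` triggers a `ConvergenceWarning`, and binary problems collapse the length-1 array to a scalar. A self-contained sketch of that post-fit logic (illustrative only; the warning class here is a stand-in for sklearn's):

```python
import warnings


class ConvergenceWarning(UserWarning):
    """Stand-in for sklearn's ConvergenceWarning."""


def postprocess_n_iter(n_iter_, max_iter, n_classes):
    # Warn once per sub-problem whose solver exhausted its budget.
    for n_iter in n_iter_:
        if n_iter >= max_iter:
            warnings.warn("Liblinear failed to converge, increase "
                          "the number of iterations.", ConvergenceWarning)
    # Binary classification trains a single sub-problem, so the
    # per-class array collapses to a scalar.
    return n_iter_[0] if n_classes == 2 else n_iter_
```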
324 changes: 220 additions & 104 deletions sklearn/svm/liblinear.c

Large diffs are not rendered by default. (liblinear.c is the Cython-generated counterpart of liblinear.pyx, so its changes track the .pyx and .pxd diffs below.)

1 change: 1 addition & 0 deletions sklearn/svm/liblinear.pxd
@@ -13,6 +13,7 @@ cdef extern from "src/liblinear/linear.h":
model *train(problem_const_ptr prob, parameter_const_ptr param) nogil
int get_nr_feature (model *model)
int get_nr_class (model *model)
void get_n_iter (model *model, int *n_iter)
void free_and_destroy_model (model **)
void destroy_param (parameter *)

9 changes: 8 additions & 1 deletion sklearn/svm/liblinear.pyx
@@ -55,6 +55,13 @@ def train_wrap(X, np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
# coef matrix holder created as fortran since that's what's used in liblinear
cdef np.ndarray[np.float64_t, ndim=2, mode='fortran'] w
cdef int nr_class = get_nr_class(model)

cdef int labels_ = nr_class
if nr_class == 2:
labels_ = 1
cdef np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter = np.zeros(labels_, dtype=np.int32)
get_n_iter(model, <int *>n_iter.data)

cdef int nr_feature = get_nr_feature(model)
if bias > 0: nr_feature = nr_feature + 1
if nr_class == 2 and solver_type != 4: # solver is not Crammer-Singer
@@ -71,7 +78,7 @@ def train_wrap(X, np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
free_parameter(param)
# destroy_param(param) don't call this or it will destroy class_weight_label and class_weight

return w
return w, n_iter


def set_verbosity_wrap(int verbosity):
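The visible contract change: `train_wrap` used to return only the coefficient matrix and now returns a `(coef, n_iter)` pair, where the int32 `n_iter` vector has one slot for binary problems and `nr_class` slots otherwise. A NumPy stand-in with the same return shape (hypothetical, for illustration only):

```python
import numpy as np


def mock_train_wrap(nr_class, nr_feature):
    """Stand-in mirroring the updated train_wrap return shape."""
    slots = 1 if nr_class == 2 else nr_class
    w = np.zeros((slots, nr_feature), order='F')  # liblinear works in Fortran order
    n_iter = np.zeros(slots, dtype=np.int32)      # one count per sub-problem
    return w, n_iter


w, n_iter = mock_train_wrap(nr_class=2, nr_feature=10)
assert n_iter.shape == (1,)  # binary: single sub-problem, single count
```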
69 changes: 46 additions & 23 deletions sklearn/svm/src/liblinear/linear.cpp
@@ -480,7 +480,7 @@ class Solver_MCSVM_CS
public:
Solver_MCSVM_CS(const problem *prob, int nr_class, double *C, double eps=0.1, int max_iter=100000);
~Solver_MCSVM_CS();
void Solve(double *w);
int Solve(double *w);
private:
void solve_sub_problem(double A_i, int yi, double C_yi, int active_i, double *alpha_new);
bool be_shrunk(int i, int m, int yi, double alpha_i, double minG);
@@ -555,7 +555,7 @@ bool Solver_MCSVM_CS::be_shrunk(int i, int m, int yi, double alpha_i, double min
return false;
}

void Solver_MCSVM_CS::Solve(double *w)
int Solver_MCSVM_CS::Solve(double *w)
{
int i, m, s;
int iter = 0;
@@ -765,6 +765,7 @@ void Solver_MCSVM_CS::Solve(double *w)
delete [] alpha_index;
delete [] y_index;
delete [] active_size_i;
return iter;
}

// A coordinate descent algorithm for
@@ -797,7 +798,7 @@ void Solver_MCSVM_CS::Solve(double *w)
#define GETI(i) (y[i]+1)
// To support weights for instances, use GETI(i) (i)

static void solve_l2r_l1l2_svc(
static int solve_l2r_l1l2_svc(
const problem *prob, double *w, double eps,
double Cp, double Cn, int solver_type, int max_iter)
{
@@ -983,6 +984,7 @@ static void solve_l2r_l1l2_svc(
delete [] alpha;
delete [] y;
delete [] index;
return iter;
}


@@ -1014,7 +1016,7 @@ static void solve_l2r_l1l2_svc(
#define GETI(i) (0)
// To support weights for instances, use GETI(i) (i)

static void solve_l2r_l1l2_svr(
static int solve_l2r_l1l2_svr(
const problem *prob, double *w, const parameter *param,
int solver_type, int max_iter)
{
@@ -1215,6 +1217,7 @@ static void solve_l2r_l1l2_svr(
delete [] beta;
delete [] QD;
delete [] index;
return iter;
}


@@ -1240,7 +1243,7 @@ static void solve_l2r_l1l2_svr(
#define GETI(i) (y[i]+1)
// To support weights for instances, use GETI(i) (i)

void solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, double Cn,
int solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, double Cn,
int max_iter)
{
int l = prob->l;
@@ -1395,6 +1398,7 @@ void solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, do
delete [] alpha;
delete [] y;
delete [] index;
return iter;
}

// A coordinate descent algorithm for
@@ -1414,7 +1418,7 @@ void solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, do
#define GETI(i) (y[i]+1)
// To support weights for instances, use GETI(i) (i)

static void solve_l1r_l2_svc(
static int solve_l1r_l2_svc(
problem *prob_col, double *w, double eps,
double Cp, double Cn, int max_iter)
{
@@ -1681,6 +1685,7 @@ static void solve_l1r_l2_svc(
delete [] y;
delete [] b;
delete [] xj_sq;
return iter;
}

// A coordinate descent algorithm for
@@ -1700,7 +1705,7 @@ static void solve_l1r_l2_svc(
#define GETI(i) (y[i]+1)
// To support weights for instances, use GETI(i) (i)

static void solve_l1r_lr(
static int solve_l1r_lr(
const problem *prob_col, double *w, double eps,
double Cp, double Cn, int max_newton_iter)
{
@@ -2061,6 +2066,7 @@ static void solve_l1r_lr(
delete [] exp_wTx_new;
delete [] tau;
delete [] D;
return newton_iter;
}

// transpose matrix X from row format to column format
@@ -2211,12 +2217,13 @@ static void group_classes(const problem *prob, int *nr_class_ret, int **label_re
free(data_label);
}

static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
static int train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
{
double eps=param->eps;
int max_iter=param->max_iter;
int pos = 0;
int neg = 0;
int n_iter = 0;  /* left at 0 if the solver type is unknown */
for(int i=0;i<prob->l;i++)
if(prob->y[i] > 0)
pos++;
@@ -2240,7 +2247,7 @@ static void train_one(const problem *prob, const parameter *param, double *w, do
fun_obj=new l2r_lr_fun(prob, C);
TRON tron_obj(fun_obj, primal_solver_tol, max_iter);
tron_obj.set_print_string(liblinear_print_string);
tron_obj.tron(w);
n_iter=tron_obj.tron(w);
delete fun_obj;
delete [] C;
break;
@@ -2258,23 +2265,23 @@ static void train_one(const problem *prob, const parameter *param, double *w, do
fun_obj=new l2r_l2_svc_fun(prob, C);
TRON tron_obj(fun_obj, primal_solver_tol, max_iter);
tron_obj.set_print_string(liblinear_print_string);
tron_obj.tron(w);
n_iter=tron_obj.tron(w);
delete fun_obj;
delete [] C;
break;
}
case L2R_L2LOSS_SVC_DUAL:
solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L2LOSS_SVC_DUAL, max_iter);
n_iter=solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L2LOSS_SVC_DUAL, max_iter);
break;
case L2R_L1LOSS_SVC_DUAL:
solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L1LOSS_SVC_DUAL, max_iter);
n_iter=solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L1LOSS_SVC_DUAL, max_iter);
break;
case L1R_L2LOSS_SVC:
{
problem prob_col;
feature_node *x_space = NULL;
transpose(prob, &x_space ,&prob_col);
solve_l1r_l2_svc(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
n_iter=solve_l1r_l2_svc(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
delete [] prob_col.y;
delete [] prob_col.x;
delete [] x_space;
@@ -2285,14 +2292,14 @@ static void train_one(const problem *prob, const parameter *param, double *w, do
problem prob_col;
feature_node *x_space = NULL;
transpose(prob, &x_space ,&prob_col);
solve_l1r_lr(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
n_iter=solve_l1r_lr(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
delete [] prob_col.y;
delete [] prob_col.x;
delete [] x_space;
break;
}
case L2R_LR_DUAL:
solve_l2r_lr_dual(prob, w, eps, Cp, Cn, max_iter);
n_iter=solve_l2r_lr_dual(prob, w, eps, Cp, Cn, max_iter);
break;
case L2R_L2LOSS_SVR:
{
@@ -2303,22 +2310,23 @@ static void train_one(const problem *prob, const parameter *param, double *w, do
fun_obj=new l2r_l2_svr_fun(prob, C, param->p);
TRON tron_obj(fun_obj, param->eps, max_iter);
tron_obj.set_print_string(liblinear_print_string);
tron_obj.tron(w);
n_iter=tron_obj.tron(w);
delete fun_obj;
delete [] C;
break;

}
case L2R_L1LOSS_SVR_DUAL:
solve_l2r_l1l2_svr(prob, w, param, L2R_L1LOSS_SVR_DUAL, max_iter);
n_iter=solve_l2r_l1l2_svr(prob, w, param, L2R_L1LOSS_SVR_DUAL, max_iter);
break;
case L2R_L2LOSS_SVR_DUAL:
solve_l2r_l1l2_svr(prob, w, param, L2R_L2LOSS_SVR_DUAL, max_iter);
n_iter=solve_l2r_l1l2_svr(prob, w, param, L2R_L2LOSS_SVR_DUAL, max_iter);
break;
default:
fprintf(stderr, "ERROR: unknown solver_type\n");
break;
}
return n_iter;
}

//
@@ -2330,6 +2338,7 @@ model* train(const problem *prob, const parameter *param)
int l = prob->l;
int n = prob->n;
int w_size = prob->n;
int n_iter;
model *model_ = Malloc(model,1);

if(prob->bias>=0)
@@ -2344,9 +2353,10 @@ model* train(const problem *prob, const parameter *param)
param->solver_type == L2R_L2LOSS_SVR_DUAL)
{
model_->w = Malloc(double, w_size);
model_->n_iter = Malloc(int, 1);
model_->nr_class = 2;
model_->label = NULL;
train_one(prob, param, &model_->w[0], 0, 0);
model_->n_iter[0] = train_one(prob, param, &model_->w[0], 0, 0);
}
else
{
@@ -2398,31 +2408,33 @@ model* train(const problem *prob, const parameter *param)
if(param->solver_type == MCSVM_CS)
{
model_->w=Malloc(double, n*nr_class);
model_->n_iter=Malloc(int, 1);
for(i=0;i<nr_class;i++)
for(j=start[i];j<start[i]+count[i];j++)
sub_prob.y[j] = i;
Solver_MCSVM_CS Solver(&sub_prob, nr_class, weighted_C, param->eps);
Solver.Solve(model_->w);
model_->n_iter[0]=Solver.Solve(model_->w);
}
else
{
if(nr_class == 2)
{
model_->w=Malloc(double, w_size);

model_->n_iter=Malloc(int, 1);
int e0 = start[0]+count[0];
k=0;
for(; k<e0; k++)
sub_prob.y[k] = -1;
for(; k<sub_prob.l; k++)
sub_prob.y[k] = +1;

train_one(&sub_prob, param, &model_->w[0], weighted_C[1], weighted_C[0]);
model_->n_iter[0]=train_one(&sub_prob, param, &model_->w[0], weighted_C[1], weighted_C[0]);
}
else
{
model_->w=Malloc(double, w_size*nr_class);
double *w=Malloc(double, w_size);
model_->n_iter=Malloc(int, nr_class);
for(i=0;i<nr_class;i++)
{
int si = start[i];
@@ -2436,7 +2448,7 @@ model* train(const problem *prob, const parameter *param)
for(; k<sub_prob.l; k++)
sub_prob.y[k] = -1;

train_one(&sub_prob, param, w, weighted_C[i], param->C);
model_->n_iter[i]=train_one(&sub_prob, param, w, weighted_C[i], param->C);

for(int j=0;j<w_size;j++)
model_->w[j*nr_class+i] = w[j];
@@ -2795,6 +2807,17 @@ void get_labels(const model *model_, int* label)
label[i] = model_->label[i];
}

void get_n_iter(const model *model_, int* n_iter)
{
int labels = model_->nr_class;
if (labels == 2)
labels = 1; /* binary models store a single count */
if (model_->n_iter != NULL)
for(int i=0;i<labels;i++)
n_iter[i] = model_->n_iter[i];
}

void free_model_content(struct model *model_ptr)
{
if(model_ptr->w != NULL)
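Note what the count means per solver: train_one returns TRON trust-region iterations for the primal L2 solvers, outer coordinate-descent iterations for the dual solvers, and Newton iterations for solve_l1r_lr; get_n_iter then copies one count per one-vs-rest sub-problem out of the model. A small Python mirror of that copy rule (a hypothetical stand-in for the C above, not part of the commit):

```python
def get_n_iter(nr_class, stored_counts):
    """Mirror of get_n_iter: binary models store a single count,
    multiclass one-vs-rest models store one count per class."""
    labels = 1 if nr_class == 2 else nr_class
    return list(stored_counts[:labels])


assert get_n_iter(2, [37]) == [37]            # binary: one sub-problem
assert get_n_iter(3, [5, 8, 2]) == [5, 8, 2]  # one count per class
```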
2 changes: 2 additions & 0 deletions sklearn/svm/src/liblinear/linear.h
@@ -43,6 +43,7 @@ struct model
double *w;
int *label; /* label of each class */
double bias;
int *n_iter; /* no. of iterations of each class */
};

struct model* train(const struct problem *prob, const struct parameter *param);
@@ -58,6 +59,7 @@ struct model *load_model(const char *model_file_name);
int get_nr_feature(const struct model *model_);
int get_nr_class(const struct model *model_);
void get_labels(const struct model *model_, int* label);
void get_n_iter(const struct model *model_, int* n_iter);

void free_model_content(struct model *model_ptr);
void free_and_destroy_model(struct model **model_ptr_ptr);
3 changes: 2 additions & 1 deletion sklearn/svm/src/liblinear/tron.cpp
@@ -44,7 +44,7 @@ TRON::~TRON()
{
}

void TRON::tron(double *w)
int TRON::tron(double *w)
{
// Parameters for updating the iterates.
double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
@@ -146,6 +146,7 @@ void TRON::tron(double *w)
delete[] r;
delete[] w_new;
delete[] s;
return --iter; /* iter starts at 1, so completed iterations = iter - 1 */
}

int TRON::trcg(double delta, double *g, double *s, double *r)
