Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rgf feature importances #161

Merged
merged 20 commits into from Mar 6, 2018
83 changes: 0 additions & 83 deletions include/rgf/src/com/AzDmat.cpp
Expand Up @@ -326,47 +326,6 @@ double AzDmat::get(int row, int col) const
return column[col]->get(row);
}

#if 0
/*-------------------------------------------------------------*/
void AzDmat::dump(const AzOut &out, const char *header,
const AzStrArray *sp_row,
const AzStrArray *sp_col,
int cut_num) const
{
if (out.isNull()) return;

AzPrint o(out);

const char *my_header = "";
if (header != NULL) my_header = header;
o.writeln(my_header);

/* (row,col)=(r,c)\n */
o.printBegin("", "");
o.print("(row,col)=");
o.pair_inParen(row_num, col_num, ",");
o.printEnd();

int cx;
for (cx = 0; cx < col_num; ++cx) {
if (column[cx] == NULL) {
continue;
}

/* column=cx (col_header) */
o.printBegin("", " ", "=");
o.print("column", cx);
if (sp_col != NULL) {
o.inParen(sp_col->c_str(cx));
}
o.printEnd();

column[cx]->dump(out, "", sp_row, cut_num);
}
o.flush();
}
#endif

/*-------------------------------------------------------------*/
void AzDmat::dump(const AzOut &out, const char *header,
int max_col,
Expand Down Expand Up @@ -971,25 +930,6 @@ void AzDvect::max_abs(const AzDvect *v)
}
}

#if 0
/*-------------------------------------------------------------*/
void AzDvect::max_abs(const AzReadOnlyVector *v)
{
const char *eyec = "AzDvect::max_abs";
if (num != v->rowNum()) {
throw new AzException(eyec, "shape mismatch");
}
AzCursor cur;
for ( ; ; ) {
double val;
int ex = v->next(cur, val);
if (ex < 0) break;
val = (val > 0) ? val : -val;
elm[ex] = MAX(elm[ex], val);
}
}
#endif

/*-------------------------------------------------------------*/
void AzDvect::add_abs(const AzDvect *v)
{
Expand All @@ -1007,29 +947,6 @@ void AzDvect::add_abs(const AzDvect *v)
}
}

#if 0
/*-------------------------------------------------------------*/
void AzDvect::add_abs(const AzReadOnlyVector *v)
{
const char *eyec = "AzDvect::add_abs";
if (num != v->rowNum()) {
throw new AzException(eyec, "shape mismatch");
}
AzCursor cur;
for ( ; ; ) {
double val;
int ex = v->next(cur, val);
if (ex < 0) break;
if (val > 0) {
elm[ex] += val;
}
else {
elm[ex] -= val;
}
}
}
#endif

/*-------------------------------------------------------------*/
/* WARNING: This could be very slow. */
void AzDvect::add(const AzReadOnlyVector *vect1, double coefficient)
Expand Down
12 changes: 4 additions & 8 deletions include/rgf/src/com/AzDmat.hpp
Expand Up @@ -144,13 +144,11 @@ class AzDvect : /* implements */ public virtual AzReadOnlyVector {
int nonZeroRowNum() const;
void all(AzIFarr *ifa) const {
ifa->prepare(num);
int row;
for (row = 0; row < num; ++row) ifa->put(row, elm[row]);
for (int row = 0; row < num; ++row) ifa->put(row, elm[row]);
}
void zeroRowNo(AzIntArr *ia) const {
ia->reset();
int row;
for (row = 0; row < num; ++row) if (elm[row] == 0) ia->put(row);
for (int row = 0; row < num; ++row) if (elm[row] == 0) ia->put(row);
}

inline void set(int row, double val) {
Expand Down Expand Up @@ -183,8 +181,7 @@ class AzDvect : /* implements */ public virtual AzReadOnlyVector {
throw new AzException("AzDvect::set(inp_tmpl, num)", "Invalid input");
}
if (inp_num != num) _reform_noset(inp_num);
int ex;
for (ex = 0; ex < num; ++ex) elm[ex] = inp[ex];
for (int ex = 0; ex < num; ++ex) elm[ex] = inp[ex];
}

void set(double val);
Expand Down Expand Up @@ -272,8 +269,7 @@ class AzDvect : /* implements */ public virtual AzReadOnlyVector {
double normalize1();

inline void zeroOut() {
int ex;
for (ex = 0; ex < num; ++ex) elm[ex] = 0;
for (int ex = 0; ex < num; ++ex) elm[ex] = 0;
}

int next(AzCursor &cursor, double &out_val) const;
Expand Down
37 changes: 18 additions & 19 deletions include/rgf/src/tet/AzFindSplit.cpp
Expand Up @@ -88,8 +88,7 @@ void AzFindSplit::_findBestSplit(int nx,
double AzFindSplit::evalSplit(const Az_forFindSplit i[2],
double bestP[2]) const
{
double gain = 0;
gain += getBestGain(i[0].w_sum, i[0].wy_sum, &bestP[0]);
double gain = getBestGain(i[0].w_sum, i[0].wy_sum, &bestP[0]);
gain += getBestGain(i[1].w_sum, i[1].wy_sum, &bestP[1]);
return gain;
}
Expand All @@ -108,8 +107,8 @@ void AzFindSplit::loop(AzTrTsplit *best_split,
Az_forFindSplit i[2];
Az_forFindSplit *src = &i[1];
Az_forFindSplit *dest = &i[0];
double bestP[2] = {0,0};
int le_idx, gt_idx;

int le_idx, gt_idx;
if (sorted->isForward()) {
le_idx = 0;
gt_idx = 1;
Expand Down Expand Up @@ -156,10 +155,11 @@ void AzFindSplit::loop(AzTrTsplit *best_split,
src->wy_sum = total->wy_sum - dest->wy_sum;
src->w_sum = total->w_sum - dest->w_sum;

double bestP[2] = {0, 0};
const double gain = evalSplit(i, bestP);
if (gain > best_split->gain) {
best_split->reset_values(fx, value, gain,
bestP[le_idx], bestP[gt_idx]);
best_split->reset_values(fx, value, gain,
bestP[le_idx], bestP[gt_idx]);
}
}
}
Expand All @@ -168,25 +168,24 @@ void AzFindSplit::loop(AzTrTsplit *best_split,
void AzFindSplit::_pickFeats(int pick_num, int f_num)
{
if (pick_num < 1 || pick_num > f_num) {
throw new AzException("AzFindSplit::pickFeats", "out of range");
throw new AzException("AzFindSplit::pickFeats", "out of range");
}
ia_feats.reset();
ia_feats.reset();
if (pick_num == f_num) {
ia_fx = NULL;
return;
ia_fx = NULL;
return;
}

AzIntArr ia_onOff;
ia_onOff.reset(f_num, 0);
int *onOff = ia_onOff.point_u();
AzIntArr ia_onOff;
ia_onOff.reset(f_num, 0);
int *onOff = ia_onOff.point_u();
for ( ; ; ) {
if (ia_feats.size() >= pick_num) break;
int fx = rand() % f_num;
if (ia_feats.size() >= pick_num) break;
int fx = rand() % f_num;
if (onOff[fx] == 0) {
onOff[fx] = 1;
ia_feats.put(fx);
onOff[fx] = 1;
ia_feats.put(fx);
}
}
ia_fx = &ia_feats;
ia_fx = &ia_feats;
}

3 changes: 1 addition & 2 deletions include/rgf/src/tet/AzFindSplit.hpp
Expand Up @@ -82,8 +82,7 @@ class AzFindSplit
double *out_best_p) /* must not be null */
const = 0;
virtual double evalSplit(const Az_forFindSplit i[2],
double bestP[2]) /* output */
const;
double bestP[2]) /* output */ const;
/*----------------------------------------------------------------*/

void _findBestSplit(int nx,
Expand Down
10 changes: 4 additions & 6 deletions include/rgf/src/tet/AzOptOnTree.cpp
Expand Up @@ -75,10 +75,9 @@ void AzOptOnTree::reset(AzLossType l_type,
}

/*--------------------------------------------------------*/
void
AzOptOnTree::_warmup(const AzTrTreeEnsemble_ReadOnly *inp_ens,
const AzTrTreeFeat *inp_tree_feat,
const AzDvect *inp_v_p)
void AzOptOnTree::_warmup(const AzTrTreeEnsemble_ReadOnly *inp_ens,
const AzTrTreeFeat *inp_tree_feat,
const AzDvect *inp_v_p)
{
v_w.reform(inp_tree_feat->featNum());
var_const = inp_ens->constant() - fixed_const;
Expand Down Expand Up @@ -385,8 +384,7 @@ double AzOptOnTree::getDelta(const int *dxs,
double nsig,
double py_avg,
/*--- inout ---*/
AzRgf_forDelta *for_del) /* updated */
const
AzRgf_forDelta *for_del) /* updated */ const
{
const char *eyec = "AzOptOnTree::getDelta";
if (dxs == NULL) return 0;
Expand Down
7 changes: 3 additions & 4 deletions include/rgf/src/tet/AzRgf_FindSplit_TreeReg.cpp
Expand Up @@ -48,11 +48,10 @@ double AzRgf_FindSplit_TreeReg::evalSplit(

double penalty_diff = reg->penalty_diff(d); /* new - old */

double gain = 2*d[0]*i[0].wy_sum - d[0]*d[0]*i[0].w_sum
+ 2*d[1]*i[1].wy_sum - d[1]*d[1]*i[1].w_sum;
double gain = 2*d[0]*i[0].wy_sum - d[0]*d[0]*i[0].w_sum - nlam * penalty_diff;
gain += 2*d[1]*i[1].wy_sum - d[1]*d[1]*i[1].w_sum - nlam * penalty_diff;

gain -= 2 * nlam * penalty_diff;
/* "2*" b/c penalty is sum v^2/2 */

return gain;
return gain;
}
2 changes: 1 addition & 1 deletion include/rgf/src/tet/AzRgf_FindSplit_TreeReg.hpp
Expand Up @@ -59,6 +59,6 @@ class AzRgf_FindSplit_TreeReg : /* extends */ public virtual AzRgf_FindSplit_Df

//! override AzFindSplit::evalSplit
virtual double evalSplit(const Az_forFindSplit i[2],
double bestP[2]) const;
double bestP[2]) const;
};
#endif