Skip to content
Permalink
Browse files
Merged in hotfixes (pull request #5)
Fixed a bug in PA1a and PA2a
  • Loading branch information
myui committed Oct 5, 2013
2 parents 63b5d27 + 11a401f commit b551cfae2066ebf113e70a0b9fcb76a3afacd449
Showing 3 changed files with 27 additions and 16 deletions.
@@ -4,11 +4,11 @@
* Copyright (C) 2013
* National Institute of Advanced Industrial Science and Technology (AIST)
* Registration Number: H25PRO-1520
*
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation.
*
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
@@ -4,11 +4,11 @@
* Copyright (C) 2013
* National Institute of Advanced Industrial Science and Technology (AIST)
* Registration Number: H25PRO-1520
*
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation.
*
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
@@ -108,7 +108,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
protected Options getOptions() {
Options opts = new Options();
opts.addOption("fh", "fhash", false, "Enable feature hashing (only used when feature is TEXT type) [default: off]");
opts.addOption("b", "bias", true, "Bias clause [default 1.0, 0.0 to disable]");
opts.addOption("b", "bias", true, "Bias clause [default 0.0 (disable)]");
return opts;
}

@@ -4,11 +4,11 @@
* Copyright (C) 2013
* National Institute of Advanced Industrial Science and Technology (AIST)
* Registration Number: H25PRO-1520
*
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation.
*
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
@@ -94,6 +94,8 @@ protected float aggressiveness() {

@Override
protected void train(Map<Object, FloatWritable> weights, Collection<?> features, float target) {
preTrain(target);

PredictionResult margin = calcScore(features);
float predicted = margin.getScore();
float loss = loss(target, predicted);
@@ -108,11 +110,12 @@ protected void train(Map<Object, FloatWritable> weights, Collection<?> features,
}
}

protected void preTrain(float target) {}

/**
* |w^t - y| - epsilon
*/
protected float loss(float target, float predicted) {
//return Math.abs(target - predicted) - epsilon;
return EpsilonInsensitiveLoss.loss(predicted, target, epsilon);
}

@@ -127,19 +130,23 @@ protected float eta(float loss, PredictionResult margin) {

public static class PA1a extends PassiveAggressiveRegressionUDTF {

private OnlineVariance target_stddev;
private OnlineVariance targetStdDev;

@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs)
throws UDFArgumentException {
this.target_stddev = new OnlineVariance();
this.targetStdDev = new OnlineVariance();
return super.initialize(argOIs);
}

@Override
protected void preTrain(float target) {
targetStdDev.handle(target);
}

@Override
protected float loss(float target, float predicted) {
float stddev = (float) target_stddev.stddev();
//return Math.abs(target - predicted) - (epsilon * stddev);
float stddev = (float) targetStdDev.stddev();
float e = epsilon * stddev;
return EpsilonInsensitiveLoss.loss(predicted, target, e);
}
@@ -164,19 +171,23 @@ protected float eta(float loss, PredictionResult margin) {

public static class PA2a extends PA2 {

private OnlineVariance target_stddev;
private OnlineVariance targetStdDev;

@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs)
throws UDFArgumentException {
this.target_stddev = new OnlineVariance();
this.targetStdDev = new OnlineVariance();
return super.initialize(argOIs);
}

@Override
protected void preTrain(float target) {
targetStdDev.handle(target);
}

@Override
protected float loss(float target, float predicted) {
float stddev = (float) target_stddev.stddev();
//return Math.abs(target - predicted) - (epsilon * stddev);
float stddev = (float) targetStdDev.stddev();
float e = epsilon * stddev;
return EpsilonInsensitiveLoss.loss(predicted, target, e);
}

0 comments on commit b551cfa

Please sign in to comment.