Skip to content
Permalink
Browse files
Merge branch 'mini-batch-sgd' of https://github.com/Lewuathe/hivemall
…into Lewuathe-mini-batch-sgd
  • Loading branch information
myui committed Jan 5, 2016
2 parents 2d4b636 + f5f6e76 commit 1f5bbf3920f88f84969ed205043740d3e07a7b90
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 14 deletions.
@@ -64,6 +64,8 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions {
protected String preloadedModelFile;
protected boolean dense_model;
protected int model_dims;
protected boolean is_mini_batch;
protected float mini_batch_ratio;
protected boolean disable_halffloat;
protected String mixConnectInfo;
protected String mixSessionName;
@@ -85,6 +87,8 @@ protected Options getOptions() {
opts.addOption("loadmodel", true, "Model file name in the distributed cache");
opts.addOption("dense", "densemodel", false, "Use dense model or not");
opts.addOption("dims", "feature_dimensions", true, "The dimension of model [default: 16777216 (2^24)]");
opts.addOption("mini_batch", false, "Use mini batch algorithm or not");
opts.addOption("mini_batch_ratio", true, "The mini batch sampling ratio against all dataset");
opts.addOption("disable_halffloat", false, "Toggle this option to disable the use of SpaceEfficientDenseModel");
opts.addOption("mix", "mix_servers", true, "Comma separated list of MIX servers");
opts.addOption("mix_session", "mix_session_name", true, "Mix session name [default: ${mapred.job.id}]");
@@ -101,6 +105,8 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
String modelfile = null;
boolean denseModel = false;
int modelDims = -1;
boolean isMinibatch = false;
float miniBatchRatio = 1.f;
boolean disableHalfFloat = false;
String mixConnectInfo = null;
String mixSessionName = null;
@@ -119,6 +125,10 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
if(denseModel) {
modelDims = Primitives.parseInt(cl.getOptionValue("dims"), 16777216);
}
isMinibatch = cl.hasOption("mini_batch");
if (isMinibatch) {
miniBatchRatio = Primitives.parseFloat(cl.getOptionValue("mini_batch_ratio"), 1.f);
}

disableHalfFloat = cl.hasOption("disable_halffloat");

@@ -136,6 +146,8 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
this.preloadedModelFile = modelfile;
this.dense_model = denseModel;
this.model_dims = modelDims;
this.is_mini_batch = isMinibatch;
this.mini_batch_ratio = miniBatchRatio;
this.disable_halffloat = disableHalfFloat;
this.mixConnectInfo = mixConnectInfo;
this.mixSessionName = mixSessionName;
@@ -44,6 +44,10 @@ public float getValue() {
return value;
}

/**
 * Replaces the stored value of this feature.
 *
 * @param value the new value to store
 */
public void setValue(float value) {
    this.value = value;
}

@Nullable
public static FeatureValue parse(final Object o) throws IllegalArgumentException {
if(o == null) {
@@ -33,7 +33,7 @@
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

public class AROWRegressionUDTF extends OnlineRegressionUDTF {
public class AROWRegressionUDTF extends RegressionBaseUDTF {

/** Regularization parameter r */
protected float r;
@@ -36,7 +36,7 @@
/**
* ADADELTA: AN ADAPTIVE LEARNING RATE METHOD.
*/
public final class AdaDeltaUDTF extends OnlineRegressionUDTF {
public final class AdaDeltaUDTF extends RegressionBaseUDTF {

private float decay;
private float eps;
@@ -88,11 +88,11 @@ protected final void checkTargetValue(final float target) throws UDFArgumentExce
/**
 * Computes the logistic-loss gradient for one training example and applies
 * it through the online (per-example) update path.
 *
 * @param features  the feature vector of the training example
 * @param target    the true target value
 * @param predicted the model's prediction for the example
 */
@Override
protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
    float gradient = LossFunctions.logisticLoss(target, predicted);
    // Only the renamed onlineUpdate is called; the stale pre-rename
    // update(features, gradient) left over from the diff is removed so the
    // gradient is not applied twice.
    onlineUpdate(features, gradient);
}

@Override
protected void update(@Nonnull final FeatureValue[] features, float gradient) {
protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
final float g_g = gradient * (gradient / scaling);

for(FeatureValue f : features) {// w[i] += y * x[i]
@@ -36,7 +36,7 @@
/**
* ADAGRAD algorithm with element-wise adaptive learning rates.
*/
public final class AdaGradUDTF extends OnlineRegressionUDTF {
public final class AdaGradUDTF extends RegressionBaseUDTF {

private float eta;
private float eps;
@@ -88,11 +88,11 @@ protected final void checkTargetValue(final float target) throws UDFArgumentExce
/**
 * Computes the logistic-loss gradient for one training example and applies
 * it through the online (per-example) update path.
 *
 * @param features  the feature vector of the training example
 * @param target    the true target value
 * @param predicted the model's prediction for the example
 */
@Override
protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
    float gradient = LossFunctions.logisticLoss(target, predicted);
    // Only the renamed onlineUpdate is called; the stale pre-rename
    // update(features, gradient) left over from the diff is removed so the
    // gradient is not applied twice.
    onlineUpdate(features, gradient);
}

@Override
protected void update(@Nonnull final FeatureValue[] features, float gradient) {
protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
final float g_g = gradient * (gradient / scaling);

for(FeatureValue f : features) {// w[i] += y * x[i]
@@ -21,13 +21,15 @@
import hivemall.common.EtaEstimator;
import hivemall.common.LossFunctions;

import hivemall.io.IWeightValue;
import hivemall.io.WeightValue;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

public final class LogressUDTF extends OnlineRegressionUDTF {
public final class LogressUDTF extends RegressionBaseUDTF {

private EtaEstimator etaEstimator;

@@ -72,4 +74,13 @@ protected float computeUpdate(final float target, final float predicted) {
return eta * gradient;
}

/**
 * Merges the delta accumulated during mini-batch training into the previous
 * weight, averaging the delta over the number of sampled examples.
 *
 * @param old_w the previous weight value, or null if the feature is new
 * @param delta the accumulated gradient delta for the feature
 * @return the updated weight value
 */
@Override
protected IWeightValue getNewWeight(IWeightValue old_w, float delta) {
    final float oldWeight = (old_w == null) ? 0.f : old_w.get();
    if (sampled == 0) {
        // Guard against float division by zero: delta / 0 would yield
        // Infinity (or NaN for 0/0) and silently corrupt the model weight.
        return new WeightValue(oldWeight);
    }
    return new WeightValue(oldWeight + (delta / sampled));
}

}
@@ -31,7 +31,7 @@
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

public class PassiveAggressiveRegressionUDTF extends OnlineRegressionUDTF {
public class PassiveAggressiveRegressionUDTF extends RegressionBaseUDTF {

/** Aggressiveness parameter */
protected float c;
@@ -102,7 +102,7 @@ protected void train(@Nonnull final FeatureValue[] features, float target) {
float eta = eta(loss, margin); // min(C, loss / |x|^2)
float coeff = sign * eta;
if(!Float.isInfinite(coeff)) {
update(features, coeff);
onlineUpdate(features, coeff);
}
}
}
@@ -33,6 +33,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -52,8 +53,12 @@
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;

public abstract class OnlineRegressionUDTF extends LearnerBaseUDTF {
private static final Log logger = LogFactory.getLog(OnlineRegressionUDTF.class);
/**
* The base class for regression algorithms. RegressionBaseUDTF provides
* general implementation for online training and batch training.
*/
public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
private static final Log logger = LogFactory.getLog(RegressionBaseUDTF.class);

private ListObjectInspector featureListOI;
private PrimitiveObjectInspector featureInputOI;
@@ -62,6 +67,11 @@ public abstract class OnlineRegressionUDTF extends LearnerBaseUDTF {

protected PredictionModel model;
protected int count;
// The accumulated delta of each weight values.
protected FeatureValue[] accDelta;
// The number of samples which picked up through mini batch training.
protected int sampled;
protected Random rnd;

@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
@@ -82,6 +92,9 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
}

this.count = 0;
this.accDelta = null;
this.rnd = new Random(42);
this.sampled = 0;
return getReturnOI(featureOutputOI);
}

@@ -123,10 +136,14 @@ public void process(Object[] args) throws HiveException {
if(featureVector == null) {
return;
}
if (accDelta == null) {
accDelta = new FeatureValue[featureVector.length];
}
float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
checkTargetValue(target);

count++;

train(featureVector, target);
}

@@ -228,14 +245,61 @@ protected PredictionResult calcScoreAndVariance(@Nonnull final FeatureValue[] fe

/**
 * Computes the raw update coefficient for one example and dispatches it to
 * either the mini-batch accumulator or the immediate online update,
 * depending on whether the -mini_batch option was given.
 *
 * @param features  the feature vector of the training example
 * @param target    the true target value
 * @param predicted the model's prediction for the example
 */
protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
    float d = computeUpdate(target, predicted);
    // The stale pre-diff call `update(features, d)` is removed: only the
    // batch/online dispatch introduced by this change remains, so the
    // coefficient is applied exactly once.
    if (this.is_mini_batch) {
        batchUpdate(features, d);
    } else {
        onlineUpdate(features, d);
    }
}

/**
 * Computes the raw update value from the target and the prediction.
 * Subclasses that rely on the default {@code update(features, target,
 * predicted)} flow must override this method.
 *
 * @param target    the true target value
 * @param predicted the model's prediction
 * @return the update coefficient
 * @throws IllegalStateException always, unless overridden by a subclass
 */
protected float computeUpdate(float target, float predicted) {
    // A descriptive message makes the failure diagnosable; the original
    // threw a message-less IllegalStateException.
    throw new IllegalStateException("computeUpdate() is not implemented");
}

protected void update(@Nonnull final FeatureValue[] features, float coeff) {
/**
 * Computes the new weight from the previous weight and the accumulated
 * mini-batch delta. Subclasses that enable mini-batch training must
 * override this method.
 *
 * @param old_w the previous weight value, or null if the feature is new
 * @param delta the accumulated delta for the feature
 * @return the updated weight value
 * @throws IllegalStateException always, unless overridden by a subclass
 */
protected IWeightValue getNewWeight(IWeightValue old_w, float delta) {
    // A descriptive message makes the failure diagnosable; the original
    // threw a message-less IllegalStateException.
    throw new IllegalStateException("getNewWeight() is not implemented");
}
/**
 * Accumulates the per-feature delta ({@code x_i * coeff}) computed from one
 * sampled example into {@code accDelta}. The accumulated deltas are applied
 * to the model once, in {@code close()}.
 *
 * NOTE(review): deltas are accumulated by array position {@code i}, not by
 * feature key. This assumes every example presents its features in the same
 * order and with the same length as the first processed row (which sized
 * {@code accDelta}) -- TODO confirm this invariant against the callers.
 *
 * @param features the feature vector of the sampled example
 * @param coeff    the update coefficient computed for this example
 */
protected void accumulateDelta(@Nonnull final FeatureValue[] features, float coeff) {
    for (int i = 0; i < features.length; i++) {
        if (features[i] == null) {
            continue;
        }
        final Object x = features[i].getFeature();
        final float xi = features[i].getValue();
        float delta = xi * coeff;
        if (accDelta[i] == null) {
            // First contribution for this slot: record the feature key with it.
            accDelta[i] = new FeatureValue(x, delta);
        } else {
            // Fold this example's delta into the running sum for the slot.
            accDelta[i].setValue(accDelta[i].getValue() + delta);
        }
    }
}

/**
 * Mini-batch path of the update: with probability {@code mini_batch_ratio}
 * (via {@code rnd.nextFloat()}, uniform in [0,1)) the example's delta is
 * folded into the accumulator instead of being applied to the model
 * immediately; the model itself is only touched in {@code close()}.
 *
 * @param features the feature vector of the training example
 * @param coeff    the update coefficient computed for this example
 */
protected void batchUpdate(@Nonnull final FeatureValue[] features, float coeff) {
    if (rnd.nextFloat() <= this.mini_batch_ratio) {
        // NOTE(review): assumes all feature vectors share the length of the
        // first processed row, which sized accDelta -- TODO confirm.
        assert features.length == accDelta.length;
        accumulateDelta(features, coeff);
        sampled += 1;
    }
}

/**
* Calculate the update value for online training.
* @param features
* @param coeff
*/
protected void onlineUpdate(@Nonnull final FeatureValue[] features, float coeff) {
for(FeatureValue f : features) {// w[i] += y * x[i]
if(f == null) {
continue;
@@ -253,6 +317,17 @@ protected void update(@Nonnull final FeatureValue[] features, float coeff) {
public final void close() throws HiveException {
super.close();
if(model != null) {
// Update model with accumulated delta. This is done
// at the end of iteration only in case of batch training.
if (this.is_mini_batch) {
for (int i = 0; i < accDelta.length; i++) {
final Object x = accDelta[i].getFeature();
final float delta = accDelta[i].getValue();
IWeightValue old_w = model.get(x);
IWeightValue new_w = getNewWeight(old_w, delta);
model.set(x, new_w);
}
}
int numForwarded = 0;
if(useCovariance()) {
final WeightValueWithCovar probe = new WeightValueWithCovar();

0 comments on commit 1f5bbf3

Please sign in to comment.