Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Permalink
55c8588 replay gib
Browse files Browse the repository at this point in the history
  • Loading branch information
DrRacket committed Dec 5, 2017
1 parent 15556b9 commit 7ee335d
Show file tree
Hide file tree
Showing 20 changed files with 904 additions and 376 deletions.
521 changes: 439 additions & 82 deletions core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java

Large diffs are not rendered by default.

148 changes: 9 additions & 139 deletions core/src/main/java/hivemall/LearnerBaseUDTF.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package hivemall;

import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
import hivemall.mix.MixMessage.MixEventName;
import hivemall.mix.client.MixClient;
import hivemall.model.DenseModel;
Expand All @@ -29,22 +28,14 @@
import hivemall.model.SpaceEfficientDenseModel;
import hivemall.model.SparseModel;
import hivemall.model.SynchronizedModelWrapper;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.optimizer.DenseOptimizerFactory;
import hivemall.optimizer.Optimizer;
import hivemall.optimizer.SparseOptimizerFactory;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;

import javax.annotation.CheckForNull;
Expand All @@ -57,21 +48,15 @@
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public abstract class LearnerBaseUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class);

protected final boolean enableNewModel;
protected String preloadedModelFile;
protected boolean dense_model;
protected int model_dims;
protected boolean disable_halffloat;
Expand All @@ -97,7 +82,6 @@ protected boolean useCovariance() {
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("loadmodel", true, "Model file name in the distributed cache");
opts.addOption("dense", "densemodel", false, "Use dense model or not");
opts.addOption("dims", "feature_dimensions", true,
"The dimension of model [default: 16777216 (2^24)]");
Expand All @@ -119,7 +103,6 @@ protected Options getOptions() {
@Override
protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
throws UDFArgumentException {
String modelfile = null;
boolean denseModel = false;
int modelDims = -1;
boolean disableHalfFloat = false;
Expand All @@ -135,8 +118,6 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
String rawArgs = HiveUtils.getConstString(argOIs[2]);
cl = parseOptions(rawArgs);

modelfile = cl.getOptionValue("loadmodel");

denseModel = cl.hasOption("dense");
if (denseModel) {
modelDims = Primitives.parseInt(cl.getOptionValue("dims"), 16777216);
Expand All @@ -160,7 +141,6 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
ssl = cl.hasOption("ssl");
}

this.preloadedModelFile = modelfile;
this.dense_model = denseModel;
this.model_dims = modelDims;
this.disable_halffloat = disableHalfFloat;
Expand Down Expand Up @@ -272,124 +252,14 @@ protected int getInitialModelSize() {
return 16384;
}

protected void loadPredictionModel(PredictionModel model, String filename,
PrimitiveObjectInspector keyOI) {
final StopWatch elapsed = new StopWatch();
final long lines;
try {
if (useCovariance()) {
lines = loadPredictionModel(model, new File(filename), keyOI,
writableFloatObjectInspector, writableFloatObjectInspector);
} else {
lines = loadPredictionModel(model, new File(filename), keyOI,
writableFloatObjectInspector);
}
} catch (IOException e) {
throw new RuntimeException("Failed to load a model: " + filename, e);
} catch (SerDeException e) {
throw new RuntimeException("Failed to load a model: " + filename, e);
}
if (model.size() > 0) {
logger.info("Loaded " + model.size() + " features from distributed cache '" + filename
+ "' (" + lines + " lines) in " + elapsed);
}
}

private static long loadPredictionModel(PredictionModel model, File file,
PrimitiveObjectInspector keyOI, WritableFloatObjectInspector valueOI)
throws IOException, SerDeException {
long count = 0L;
if (!file.exists()) {
return count;
}
if (!file.getName().endsWith(".crc")) {
if (file.isDirectory()) {
for (File f : file.listFiles()) {
count += loadPredictionModel(model, f, keyOI, valueOI);
}
} else {
LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI);
StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
StructField keyRef = lineOI.getStructFieldRef("key");
StructField valueRef = lineOI.getStructFieldRef("value");
PrimitiveObjectInspector keyRefOI = (PrimitiveObjectInspector) keyRef.getFieldObjectInspector();
FloatObjectInspector varRefOI = (FloatObjectInspector) valueRef.getFieldObjectInspector();

BufferedReader reader = null;
try {
reader = HadoopUtils.getBufferedReader(file);
String line;
while ((line = reader.readLine()) != null) {
count++;
Text lineText = new Text(line);
Object lineObj = serde.deserialize(lineText);
List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
Object f0 = fields.get(0);
Object f1 = fields.get(1);
if (f0 == null || f1 == null) {
continue; // avoid the case that key or value is null
}
Object k = keyRefOI.getPrimitiveWritableObject(keyRefOI.copyObject(f0));
float v = varRefOI.get(f1);
model.set(k, new WeightValue(v, false));
}
} finally {
IOUtils.closeQuietly(reader);
}
}
}
return count;
}

private static long loadPredictionModel(PredictionModel model, File file,
PrimitiveObjectInspector featureOI, WritableFloatObjectInspector weightOI,
WritableFloatObjectInspector covarOI) throws IOException, SerDeException {
long count = 0L;
if (!file.exists()) {
return count;
}
if (!file.getName().endsWith(".crc")) {
if (file.isDirectory()) {
for (File f : file.listFiles()) {
count += loadPredictionModel(model, f, featureOI, weightOI, covarOI);
}
} else {
LazySimpleSerDe serde = HiveUtils.getLineSerde(featureOI, weightOI, covarOI);
StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
StructField c1ref = lineOI.getStructFieldRef("c1");
StructField c2ref = lineOI.getStructFieldRef("c2");
StructField c3ref = lineOI.getStructFieldRef("c3");
PrimitiveObjectInspector c1oi = (PrimitiveObjectInspector) c1ref.getFieldObjectInspector();
FloatObjectInspector c2oi = (FloatObjectInspector) c2ref.getFieldObjectInspector();
FloatObjectInspector c3oi = (FloatObjectInspector) c3ref.getFieldObjectInspector();

BufferedReader reader = null;
try {
reader = HadoopUtils.getBufferedReader(file);
String line;
while ((line = reader.readLine()) != null) {
count++;
Text lineText = new Text(line);
Object lineObj = serde.deserialize(lineText);
List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
Object f0 = fields.get(0);
Object f1 = fields.get(1);
Object f2 = fields.get(2);
if (f0 == null || f1 == null) {
continue; // avoid unexpected case
}
Object k = c1oi.getPrimitiveWritableObject(c1oi.copyObject(f0));
float v = c2oi.get(f1);
float cov = (f2 == null) ? WeightValueWithCovar.DEFAULT_COVAR
: c3oi.get(f2);
model.set(k, new WeightValueWithCovar(v, cov, false));
}
} finally {
IOUtils.closeQuietly(reader);
}
}
@Nonnull
protected ObjectInspector getFeatureOutputOI(@Nonnull PrimitiveObjectInspector featureInputOI)
throws UDFArgumentException {
if (dense_model) {
// TODO validation
return PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel
}
return count;
return ObjectInspectorUtils.getStandardObjectInspector(featureInputOI);
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/hivemall/UDTFWithOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ protected final Reporter getReporter() {
return mapredContext.getReporter();
}

protected static void reportProgress(@Nonnull Reporter reporter) {
protected static void reportProgress(@Nullable Reporter reporter) {
if (reporter != null) {
synchronized (reporter) {
reporter.progress();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,15 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu

processOptions(argOIs);

PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector
: featureInputOI;
this.model = createModel();
if (preloadedModelFile != null) {
loadPredictionModel(model, preloadedModelFile, featureOutputOI);
}

this.count = 0;
this.sampled = 0;
return getReturnOI(featureOutputOI);

return getReturnOI(getFeatureOutputOI(featureInputOI));
}

protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg)
@Nonnull
protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg)
throws UDFArgumentException {
this.featureListOI = (ListObjectInspector) arg;
ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector();
Expand All @@ -106,13 +102,13 @@ protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg)
return HiveUtils.asPrimitiveObjectInspector(featureRawOI);
}

protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) {
@Nonnull
protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("feature");
ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI);
fieldOIs.add(featureOI);
fieldOIs.add(featureOutputOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
if (useCovariance()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF {
@Override
protected String getLossOptionDescription() {
return "Loss function [HingeLoss (default), LogLoss, SquaredHingeLoss, ModifiedHuberLoss, or\n"
+ "a regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, HuberLoss]";
+ "a regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, "
+ "SquaredEpsilonInsensitiveLoss, HuberLoss]";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,19 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu

processOptions(argOIs);

PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector
: featureInputOI;
this.label2model = new HashMap<Object, PredictionModel>(64);
if (preloadedModelFile != null) {
loadPredictionModel(label2model, preloadedModelFile, labelInputOI, featureOutputOI);
}

this.count = 0;
return getReturnOI(labelInputOI, featureOutputOI);

return getReturnOI(labelInputOI, getFeatureOutputOI(featureInputOI));
}

@Override
protected int getInitialModelSize() {
return 8192;
}

protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg)
@Nonnull
protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg)
throws UDFArgumentException {
this.featureListOI = (ListObjectInspector) arg;
ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector();
Expand All @@ -133,17 +129,17 @@ protected PrimitiveObjectInspector processFeaturesOI(ObjectInspector arg)
return HiveUtils.asPrimitiveObjectInspector(featureRawOI);
}

protected StructObjectInspector getReturnOI(ObjectInspector labelRawOI,
ObjectInspector featureRawOI) {
@Nonnull
protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector labelRawOI,
@Nonnull ObjectInspector featureOutputOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("label");
ObjectInspector labelOI = ObjectInspectorUtils.getStandardObjectInspector(labelRawOI);
fieldOIs.add(labelOI);
fieldNames.add("feature");
ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI);
fieldOIs.add(featureOI);
fieldOIs.add(featureOutputOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
if (useCovariance()) {
Expand Down
Loading

0 comments on commit 7ee335d

Please sign in to comment.