Skip to content
Permalink
Browse files
Close #123: [HIVEMALL-154] Refactor Field-aware Factorization Machines to support Instance-wise L2 normalization
  • Loading branch information
myui committed Oct 24, 2017
1 parent f469cce commit ad15923a1f88ee10b1d66efa92f5f79bbc6ea804
Showing 60 changed files with 2,435 additions and 4,477 deletions.
15 NOTICE
@@ -33,6 +33,17 @@ o hivemall/core/src/main/java/hivemall/utils/collections/OpenHashMap.java
https://github.com/slipperyseal/atomicobjects/
Licensed under the Apache License, Version 2.0

o hivemall/core/src/main/java/hivemall/utils/math/FastMath.java

Copyright 2012-2015 Jeff Hain

https://github.com/jeffhain/jafama/
Licensed under the Apache License, Version 2.0

Copyright (C) 1993 by Sun Microsystems, Inc.

Permission to use, copy, modify, and distribute this software is freely granted, provided that this notice is preserved.

------------------------------------------------------------------------------------------------------
Copyright notifications which have been relocated from ASF projects

@@ -50,7 +61,7 @@ o hivemall/core/src/main/java/hivemall/utils/buffer/DynamicByteArray.java
https://orc.apache.org/
Licensed under the Apache License, Version 2.0

o hivemall/spark/spark-2.0/extra-src/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/QueryTest.scala
hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
hivemall/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
@@ -66,3 +77,5 @@ o hivemall/core/src/main/java/hivemall/utils/buffer/DynamicByteArray.java

http://spark.apache.org/
Licensed under the Apache License, Version 2.0


@@ -139,7 +139,13 @@
<dependency>
<groupId>org.roaringbitmap</groupId>
<artifactId>RoaringBitmap</artifactId>
<version>[0.6,0.7)</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil</artifactId>
<version>[8.1.0,8.2)</version>
<scope>compile</scope>
</dependency>

@@ -190,7 +196,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>jar-with-dependencies</id>
@@ -212,6 +218,7 @@
<include>org.tukaani:xz</include>
<include>org.apache.commons:commons-math3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>it.unimi.dsi:fastutil</include>
</includes>
</artifactSet>
<transformers>
@@ -55,6 +55,8 @@

public abstract class LearnerBaseUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class);
private static final int DEFAULT_SPARSE_DIMS = 16384;
private static final int DEFAULT_DENSE_DIMS = 16777216;

protected final boolean enableNewModel;
protected boolean dense_model;
@@ -120,7 +122,7 @@ protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)

denseModel = cl.hasOption("dense");
if (denseModel) {
modelDims = Primitives.parseInt(cl.getOptionValue("dims"), 16777216);
modelDims = Primitives.parseInt(cl.getOptionValue("dims"), DEFAULT_DENSE_DIMS);
}
disableHalfFloat = cl.hasOption("disable_halffloat");

@@ -168,7 +170,7 @@ private final PredictionModel createOldModel(@Nullable String label) {
PredictionModel model;
final boolean useCovar = useCovariance();
if (dense_model) {
if (disable_halffloat == false && model_dims > 16777216) {
if (disable_halffloat == false && model_dims > DEFAULT_DENSE_DIMS) {
logger.info("Build a space efficient dense model with " + model_dims
+ " initial dimensions" + (useCovar ? " w/ covariances" : ""));
model = new SpaceEfficientDenseModel(model_dims, useCovar);
@@ -199,7 +201,7 @@ private final PredictionModel createNewModel(@Nullable String label) {
PredictionModel model;
final boolean useCovar = useCovariance();
if (dense_model) {
if (disable_halffloat == false && model_dims > 16777216) {
if (disable_halffloat == false && model_dims > DEFAULT_DENSE_DIMS) {
logger.info("Build a space efficient dense model with " + model_dims
+ " initial dimensions" + (useCovar ? " w/ covariances" : ""));
model = new NewSpaceEfficientDenseModel(model_dims, useCovar);
@@ -229,9 +231,11 @@ private final PredictionModel createNewModel(@Nullable String label) {
/**
 * Creates an {@link Optimizer} configured from the given option map.
 *
 * <p>The span as scraped contained both the pre- and post-change statements of this
 * method (two consecutive {@code return}s per branch — unreachable code that would not
 * compile); this body resolves to the post-change version.</p>
 *
 * @param options optimizer options (e.g. learning-rate scheme); must not be null
 * @return a dense or sparse optimizer, depending on the configured model type
 */
protected final Optimizer createOptimizer(@CheckForNull Map<String, String> options) {
    Preconditions.checkNotNull(options);
    if (dense_model) {
        // Fall back to the default dense dimensionality when dims was not specified (< 0)
        return DenseOptimizerFactory.create(model_dims < 0 ? DEFAULT_DENSE_DIMS : model_dims,
            options);
    } else {
        // Sparse models start small; the backing map grows on demand
        return SparseOptimizerFactory.create(model_dims < 0 ? DEFAULT_SPARSE_DIMS : model_dims,
            options);
    }
}

@@ -23,11 +23,12 @@
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
import hivemall.utils.collections.maps.Int2FloatOpenHashTable.IMapIterator;
import hivemall.optimizer.LossFunctions;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatMaps;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;

import java.util.ArrayList;
import java.util.List;
@@ -72,9 +73,9 @@ public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClas
// Model parameters

private float _w0;
private Int2FloatOpenHashTable _w1;
private Int2FloatOpenHashTable _w2;
private Int2FloatOpenHashTable _w3;
private Int2FloatMap _w1;
private Int2FloatMap _w2;
private Int2FloatMap _w3;

// ------------------------------------

@@ -182,11 +183,11 @@ public float eta(float loss, PredictionResult margin) {
@Override
protected PredictionModel createModel() {
this._w0 = 0.f;
this._w1 = new Int2FloatOpenHashTable(16384);
this._w1 = new Int2FloatOpenHashMap(16384);
_w1.defaultReturnValue(0.f);
this._w2 = new Int2FloatOpenHashTable(16384);
this._w2 = new Int2FloatOpenHashMap(16384);
_w2.defaultReturnValue(0.f);
this._w3 = new Int2FloatOpenHashTable(16384);
this._w3 = new Int2FloatOpenHashMap(16384);
_w3.defaultReturnValue(0.f);

return null;
@@ -351,13 +352,12 @@ public void close() throws HiveException {

row[2] = w1;
row[3] = w2;
final Int2FloatOpenHashTable w2map = _w2;
final IMapIterator w1itor = _w1.entries();
while (w1itor.next() != -1) {
int k = w1itor.getKey();
final Int2FloatMap w2map = _w2;
for (Int2FloatMap.Entry e : Int2FloatMaps.fastIterable(_w1)) {
int k = e.getIntKey();
Preconditions.checkArgument(k > 0, HiveException.class);
h.set(k);
w1.set(w1itor.getValue());
w1.set(e.getFloatValue());
w2.set(w2map.get(k));
forward(row); // h(f), w1, w2
}
@@ -369,12 +369,12 @@ public void close() throws HiveException {
row[3] = null;
row[4] = hk;
row[5] = w3;
final IMapIterator w3itor = _w3.entries();
while (w3itor.next() != -1) {
int k = w3itor.getKey();

for (Int2FloatMap.Entry e : Int2FloatMaps.fastIterable(_w3)) {
int k = e.getIntKey();
Preconditions.checkArgument(k > 0, HiveException.class);
hk.set(k);
w3.set(w3itor.getValue());
w3.set(e.getFloatValue());
forward(row); // hk(f), w3
}
this._w3 = null;
@@ -18,6 +18,9 @@
*/
package hivemall.common;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

@@ -61,6 +64,13 @@ public double getCumulativeLoss() {
return currLosses;
}

/**
 * Returns the cumulative loss of the current iteration averaged over the
 * number of observed training instances.
 *
 * @param numInstances number of training instances seen so far (non-negative)
 * @return {@code 0.d} when no instances have been observed; otherwise the mean loss
 */
public double getAverageLoss(@Nonnegative final long numInstances) {
    return (numInstances == 0) ? 0.d : (currLosses / numInstances);
}

public double getPreviousLoss() {
return prevLosses;
}
@@ -88,43 +98,42 @@ public boolean isConverged(final long observedTrainingExamples) {

if (currLosses > prevLosses) {
if (logger.isInfoEnabled()) {
logger.info("Iteration #" + curIter + " currLoss `" + currLosses
+ "` > prevLosses `" + prevLosses + '`');
logger.info("Iteration #" + curIter + " current cumulative loss `" + currLosses
+ "` > previous cumulative loss `" + prevLosses + '`');
}
this.readyToFinishIterations = false;
return false;
}

final double changeRate = (prevLosses - currLosses) / prevLosses;
final double changeRate = getChangeRate();
if (changeRate < convergenceRate) {
if (readyToFinishIterations) {
// NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY
if (logger.isInfoEnabled()) {
logger.info("Training converged at " + curIter + "-th iteration. [curLosses="
+ currLosses + ", prevLosses=" + prevLosses + ", changeRate="
+ changeRate + ']');
logger.info("Training converged at " + curIter + "-th iteration!\n"
+ getInfo(observedTrainingExamples));
}
return true;
} else {
if (logger.isInfoEnabled()) {
logger.info("Iteration #" + curIter + " [curLosses=" + currLosses
+ ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
+ ", #trainingExamples=" + observedTrainingExamples + ']');
logger.info(getInfo(observedTrainingExamples));
}
this.readyToFinishIterations = true;
}
} else {
if (logger.isInfoEnabled()) {
logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + ", prevLosses="
+ prevLosses + ", changeRate=" + changeRate + ", #trainingExamples="
+ observedTrainingExamples + ']');
logger.info(getInfo(observedTrainingExamples));
}
this.readyToFinishIterations = false;
}

return false;
}

/**
 * Computes the relative decrease of the cumulative loss with respect to the
 * previous iteration; used as the convergence criterion.
 */
double getChangeRate() {
    final double improvement = prevLosses - currLosses;
    return improvement / prevLosses;
}

public void next() {
this.prevLosses = currLosses;
this.currLosses = 0.d;
@@ -135,4 +144,16 @@ public int getCurrentIteration() {
return curIter;
}

/**
 * Builds a one-line, human-readable progress summary for logging:
 * iteration number, average/cumulative losses, change rate, and example count.
 *
 * @param observedTrainingExamples number of training examples seen so far
 * @return formatted progress string
 */
@Nonnull
public String getInfo(@Nonnegative final long observedTrainingExamples) {
    return "Iteration #" + curIter + " | "
            + "average loss=" + getAverageLoss(observedTrainingExamples) + ", "
            + "current cumulative loss=" + currLosses + ", "
            + "previous cumulative loss=" + prevLosses + ", "
            + "change rate=" + getChangeRate() + ", "
            + "#trainingExamples=" + observedTrainingExamples;
}

}
@@ -23,9 +23,9 @@
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.collections.lists.LongArrayList;
import hivemall.utils.collections.maps.Int2LongOpenHashTable;
import hivemall.utils.collections.maps.Int2LongOpenHashTable.MapIterator;
import hivemall.utils.lang.NumberUtils;
import it.unimi.dsi.fastutil.ints.Int2LongMap;
import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;

import java.text.NumberFormat;
import java.util.Locale;
@@ -42,9 +42,9 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
// LEARNING PARAMS
private float _w0;
@Nonnull
private final Int2LongOpenHashTable _map;
final Int2LongMap _map;
@Nonnull
private final HeapBuffer _buf;
final HeapBuffer _buf;

@Nonnull
private final LongArrayList _freelistW;
@@ -69,7 +69,8 @@ public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachi
public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) {
super(params);
this._w0 = 0.f;
this._map = new Int2LongOpenHashTable(DEFAULT_MAPSIZE);
this._map = new Int2LongOpenHashMap(DEFAULT_MAPSIZE);
_map.defaultReturnValue(-1L);
this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
this._freelistW = new LongArrayList();
this._freelistV = new LongArrayList();
@@ -326,54 +327,4 @@ public String toString() {
return getStatistics();
}

// Package-private accessor returning a fresh iterator over this model's entries.
@Nonnull
EntryIterator entries() {
return new EntryIterator(this);
}

/**
 * Iterator over the model's (entry index -&gt; heap-buffer offset) dictionary,
 * with two reusable probe {@code Entry} objects (one sized for W, one for V)
 * so that iteration allocates nothing per element.
 *
 * <p>Fix: the original declared {@code @Nonnull} on the {@code void} method
 * {@code getEntry} — a nullness annotation on a void return is meaningless and
 * has been removed.</p>
 */
static final class EntryIterator {

    @Nonnull
    private final MapIterator dictItor;
    @Nonnull
    private final Entry entryProbeW;
    @Nonnull
    private final Entry entryProbeV;

    EntryIterator(@Nonnull FFMStringFeatureMapModel model) {
        this.dictItor = model._map.entries();
        // Probe sized for the scalar weight W (1 slot)
        this.entryProbeW = new Entry(model._buf, 1);
        // Probe sized for the latent factor vector V (_factor slots)
        this.entryProbeV = new Entry(model._buf, model._factor);
    }

    @Nonnull
    Entry getEntryProbeW() {
        return entryProbeW;
    }

    @Nonnull
    Entry getEntryProbeV() {
        return entryProbeV;
    }

    boolean hasNext() {
        return dictItor.hasNext();
    }

    /** Advances the underlying dictionary iterator; false when exhausted. */
    boolean next() {
        return dictItor.next() != -1;
    }

    int getEntryIndex() {
        return dictItor.getKey();
    }

    /** Points the given probe at the current entry's heap-buffer offset. */
    void getEntry(@Nonnull final Entry probe) {
        long offset = dictItor.getValue();
        probe.setOffset(offset);
    }

}

}

0 comments on commit ad15923

Please sign in to comment.