Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Permalink
3804789 replay starts
Browse files Browse the repository at this point in the history
  • Loading branch information
DrRacket committed Dec 7, 2017
1 parent b565661 commit a6de3e0
Show file tree
Hide file tree
Showing 64 changed files with 2,750 additions and 1,934 deletions.
21 changes: 14 additions & 7 deletions core/src/main/java/hivemall/common/ConversionState.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,25 @@ public boolean isConverged(final long observedTrainingExamples) {
if (changeRate < convergenceRate) {
if (readyToFinishIterations) {
// NOTE: this is never true at the first iteration, where prevLosses == Double.POSITIVE_INFINITY
logger.info("Training converged at " + curIter + "-th iteration. [curLosses="
+ currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
+ ']');
if (logger.isInfoEnabled()) {
logger.info("Training converged at " + curIter + "-th iteration. [curLosses="
+ currLosses + ", prevLosses=" + prevLosses + ", changeRate="
+ changeRate + ']');
}
return true;
} else {
if (logger.isInfoEnabled()) {
logger.info("Iteration #" + curIter + " [curLosses=" + currLosses
+ ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
+ ", #trainingExamples=" + observedTrainingExamples + ']');
}
this.readyToFinishIterations = true;
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("Iteration #" + curIter + " [curLosses=" + currLosses
+ ", prevLosses=" + prevLosses + ", changeRate=" + changeRate
+ ", #trainingExamples=" + observedTrainingExamples + ']');
if (logger.isInfoEnabled()) {
logger.info("Iteration #" + curIter + " [curLosses=" + currLosses + ", prevLosses="
+ prevLosses + ", changeRate=" + changeRate + ", #trainingExamples="
+ observedTrainingExamples + ']');
}
this.readyToFinishIterations = false;
}
Expand Down
242 changes: 173 additions & 69 deletions core/src/main/java/hivemall/fm/Entry.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,27 @@

import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.math.MathUtils;

import java.util.Arrays;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

class Entry {

@Nonnull
protected final HeapBuffer _buf;
@Nonnegative
protected final int _size;
@Nonnegative
protected final int _factors;

// temporary variables used only in training phase
protected int _key;
@Nonnegative
protected long _offset;

Entry(@Nonnull HeapBuffer buf, int factors) {
Expand All @@ -39,128 +49,210 @@ class Entry {
this._factors = factors;
}

Entry(@Nonnull HeapBuffer buf, int factors, long offset) {
this(buf, factors, Entry.sizeOf(factors), offset);
Entry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
this(buf, 1, key, offset);
}

Entry(@Nonnull HeapBuffer buf, int factors, int key, @Nonnegative long offset) {
this(buf, factors, Entry.sizeOf(factors), key, offset);
}

private Entry(@Nonnull HeapBuffer buf, int factors, int size, long offset) {
private Entry(@Nonnull HeapBuffer buf, int factors, int size, int key, @Nonnegative long offset) {
this._buf = buf;
this._size = size;
this._factors = factors;
setOffset(offset);
this._key = key;
this._offset = offset;
}

int getSize() {
final int getSize() {
return _size;
}

long getOffset() {
final int getKey() {
return _key;
}

final long getOffset() {
return _offset;
}

void setOffset(long offset) {
final void setOffset(final long offset) {
this._offset = offset;
}

float getW() {
final float getW() {
return _buf.getFloat(_offset);
}

void setW(final float value) {
final void setW(final float value) {
_buf.putFloat(_offset, value);
}

void getV(@Nonnull final float[] Vf) {
final long offset = _offset + SizeOf.FLOAT;
final void getV(@Nonnull final float[] Vf) {
final long offset = _offset;
final int len = Vf.length;
for (int i = 0; i < len; i++) {
Vf[i] = _buf.getFloat(offset + SizeOf.FLOAT * i);
for (int f = 0; f < len; f++) {
long index = offset + SizeOf.FLOAT * f;
Vf[f] = _buf.getFloat(index);
}
}

void setV(@Nonnull final float[] Vf) {
final long offset = _offset + SizeOf.FLOAT;
final void setV(@Nonnull final float[] Vf) {
final long offset = _offset;
final int len = Vf.length;
for (int i = 0; i < len; i++) {
_buf.putFloat(offset + SizeOf.FLOAT * i, Vf[i]);
for (int f = 0; f < len; f++) {
long index = offset + SizeOf.FLOAT * f;
_buf.putFloat(index, Vf[f]);
}
}

float getV(final int f) {
return _buf.getFloat(_offset + SizeOf.FLOAT + SizeOf.FLOAT * f);
final float getV(final int f) {
long index = _offset + SizeOf.FLOAT * f;
return _buf.getFloat(index);
}

void setV(final int f, final float value) {
long index = _offset + SizeOf.FLOAT + SizeOf.FLOAT * f;
final void setV(final int f, final float value) {
long index = _offset + SizeOf.FLOAT * f;
_buf.putFloat(index, value);
}

double getSumOfSquaredGradientsV() {
double getSumOfSquaredGradients(@Nonnegative int f) {
throw new UnsupportedOperationException();
}

void addGradientV(float grad) {
void addGradient(@Nonnegative int f, float grad) {
throw new UnsupportedOperationException();
}

float updateZ(float gradW, float alpha) {
final float updateZ(final float gradW, final float alpha) {
float w = getW();
return updateZ(0, w, gradW, alpha);
}

float updateZ(@Nonnegative int f, float W, float gradW, float alpha) {
throw new UnsupportedOperationException();
}

double updateN(float gradW) {
final double updateN(final float gradW) {
return updateN(0, gradW);
}

double updateN(@Nonnegative int f, float gradW) {
throw new UnsupportedOperationException();
}

static int sizeOf(int factors) {
return SizeOf.FLOAT + SizeOf.FLOAT * factors;
/**
 * Returns {@code true} when this entry is an entry for W (see {@link #isEntryW(int)}),
 * or when it is an entry for V and every one of its {@code _factors} floats, read from
 * {@code _buf} starting at {@code _offset}, passes
 * {@code MathUtils.closeToZero(value, 1E-9f)}.
 *
 * @return {@code false} as soon as one factor of a V entry is not close to zero
 */
boolean removable() {
    if (isEntryW(_key)) {
        // entry for W: no factors are inspected
        return true;
    }
    // entry for V
    final long base = _offset;
    int f = 0;
    while (f < _factors) {
        final float value = _buf.getFloat(base + SizeOf.FLOAT * f);
        if (!MathUtils.closeToZero(value, 1E-9f)) {
            return false;
        }
        f++;
    }
    return true;
}

/**
 * Does nothing in this base class. {@code AdaGradEntry} and {@code FTRLEntry} override
 * it to zero the extra per-factor values they keep in {@code _buf}.
 */
void clear() {}

/**
 * Computes the size of a plain entry that stores one float per factor.
 *
 * @param factors number of floats held by the entry; must be at least 1, otherwise the
 *        {@code Preconditions.checkArgument} call rejects it
 * @return {@code SizeOf.FLOAT * factors}
 */
static int sizeOf(@Nonnegative final int factors) {
    Preconditions.checkArgument(factors >= 1, "Factors must be greater than 0: " + factors);
    return SizeOf.FLOAT * factors;
}

/**
 * Tells whether a key denotes an entry for W: negative keys do, zero and positive keys
 * denote an entry for V.
 *
 * @param i the entry key
 * @return {@code true} if {@code i} is negative
 */
static boolean isEntryW(final int i) {
    return 0 > i;
}

/**
 * Renders the entry as {@code "W=<w>"} when its key marks an entry for W, otherwise as
 * {@code "V=[...]"} listing its {@code _factors} floats.
 */
@Override
public String toString() {
    if (!Entry.isEntryW(_key)) {
        final float[] factorValues = new float[_factors];
        getV(factorValues);
        return "V=" + Arrays.toString(factorValues);
    }
    return "W=" + getW();
}

static class AdaGradEntry extends Entry {
static final class AdaGradEntry extends Entry {

final long _gg_offset;

AdaGradEntry(@Nonnull HeapBuffer buf, int factors, long offset) {
super(buf, factors, AdaGradEntry.sizeOf(factors), offset);
this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors;
AdaGradEntry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
this(buf, 1, key, offset);
}

private AdaGradEntry(@Nonnull HeapBuffer buf, int factors, int size, long offset) {
super(buf, factors, size, offset);
this._gg_offset = _offset + SizeOf.FLOAT + SizeOf.FLOAT * _factors;
AdaGradEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key,
@Nonnegative long offset) {
super(buf, factors, AdaGradEntry.sizeOf(factors), key, offset);
this._gg_offset = _offset + Entry.sizeOf(factors);
}

@Override
double getSumOfSquaredGradientsV() {
return _buf.getDouble(_gg_offset);
double getSumOfSquaredGradients(@Nonnegative final int f) {
Preconditions.checkArgument(f >= 0);

long offset = _gg_offset + SizeOf.DOUBLE * f;
return _buf.getDouble(offset);
}

@Override
void addGradientV(float grad) {
double v = _buf.getDouble(_gg_offset);
void addGradient(@Nonnegative final int f, final float grad) {
Preconditions.checkArgument(f >= 0);

long offset = _gg_offset + SizeOf.DOUBLE * f;
double v = _buf.getDouble(offset);
v += grad * grad;
_buf.putDouble(_gg_offset, v);
_buf.putDouble(offset, v);
}

static int sizeOf(int factors) {
return Entry.sizeOf(factors) + SizeOf.DOUBLE;
/**
 * Writes {@code 0.0} into each of the {@code _factors} double slots that start at
 * {@code _gg_offset}, i.e. the values returned by {@code getSumOfSquaredGradients(f)}.
 */
@Override
void clear() {
    int f = 0;
    while (f < _factors) {
        final long slot = _gg_offset + SizeOf.DOUBLE * f;
        _buf.putDouble(slot, 0.d);
        f++;
    }
}

/**
 * Computes the size of an AdaGrad entry: the plain entry size plus one double per factor.
 *
 * @param factors number of factors held by the entry
 * @return {@code Entry.sizeOf(factors) + SizeOf.DOUBLE * factors}
 */
static int sizeOf(@Nonnegative final int factors) {
    final int plainSize = Entry.sizeOf(factors);
    final int gradientSize = SizeOf.DOUBLE * factors;
    return plainSize + gradientSize;
}

/**
 * Appends {@code ", gg=[...]"} to the base-class rendering, where the array holds
 * {@code getSumOfSquaredGradients(f)} for every factor.
 */
@Override
public String toString() {
    final double[] sums = new double[_factors];
    int f = 0;
    while (f < _factors) {
        sums[f] = getSumOfSquaredGradients(f);
        f++;
    }
    final String base = super.toString();
    return base + ", gg=" + Arrays.toString(sums);
}

}

static final class FTRLEntry extends AdaGradEntry {
static final class FTRLEntry extends Entry {

final long _z_offset;

FTRLEntry(@Nonnull HeapBuffer buf, int factors, long offset) {
super(buf, factors, FTRLEntry.sizeOf(factors), offset);
this._z_offset = _gg_offset + SizeOf.DOUBLE;
FTRLEntry(@Nonnull HeapBuffer buf, int key, long offset) {
this(buf, 1, key, offset);
}

FTRLEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, long offset) {
super(buf, factors, FTRLEntry.sizeOf(factors), key, offset);
this._z_offset = _offset + Entry.sizeOf(factors);
}

@Override
float updateZ(float gradW, float alpha) {
final float W = getW();
final float z = getZ();
final double n = getN();
float updateZ(final int f, final float W, final float gradW, final float alpha) {
Preconditions.checkArgument(f >= 0);

final long zOffset = offsetZ(f);

final float z = _buf.getFloat(zOffset);
final double n = _buf.getFloat(offsetN(f)); // implicit cast to float

double gg = gradW * gradW;
float sigma = (float) ((Math.sqrt(n + gg) - Math.sqrt(n)) / alpha);
Expand All @@ -171,44 +263,56 @@ float updateZ(float gradW, float alpha) {
+ gradW + ", sigma=" + sigma + ", W=" + W + ", n=" + n + ", gg=" + gg
+ ", alpha=" + alpha);
}
setZ(newZ);
_buf.putFloat(zOffset, newZ);
return newZ;
}

private float getZ() {
return _buf.getFloat(_z_offset);
}

private void setZ(final float value) {
_buf.putFloat(_z_offset, value);
}

@Override
double updateN(final float gradW) {
final double n = getN();
double updateN(final int f, final float gradW) {
Preconditions.checkArgument(f >= 0);

final long nOffset = offsetN(f);

final double n = _buf.getFloat(nOffset);
final double newN = n + gradW * gradW;
if (!NumberUtils.isFinite(newN)) {
throw new IllegalStateException("Got newN " + newN + " where n=" + n + ", gradW="
+ gradW);
}
setN(newN);
_buf.putFloat(nOffset, NumberUtils.castToFloat(newN)); // cast may throw ArithmeticException
return newN;
}

private double getN() {
long index = _z_offset + SizeOf.FLOAT;
return _buf.getDouble(index);
private long offsetZ(@Nonnegative final int f) {
return _z_offset + SizeOf.FLOAT * f;
}

private void setN(final double value) {
long index = _z_offset + SizeOf.FLOAT;
_buf.putDouble(index, value);
private long offsetN(@Nonnegative final int f) {
return _z_offset + SizeOf.FLOAT * (_factors + f);
}

static int sizeOf(int factors) {
return AdaGradEntry.sizeOf(factors) + SizeOf.FLOAT + SizeOf.DOUBLE;
/**
 * Writes {@code 0.0f} into the Z slot and the N slot of every factor, in that order,
 * using {@code offsetZ(f)} and {@code offsetN(f)} to locate them in {@code _buf}.
 */
@Override
void clear() {
    int f = 0;
    while (f < _factors) {
        final long zSlot = offsetZ(f);
        _buf.putFloat(zSlot, 0.f);
        final long nSlot = offsetN(f);
        _buf.putFloat(nSlot, 0.f);
        f++;
    }
}

/**
 * Computes the size of an FTRL entry: the plain entry size plus two floats (Z and N)
 * per factor.
 *
 * @param factors number of factors held by the entry
 * @return {@code Entry.sizeOf(factors) + (SizeOf.FLOAT + SizeOf.FLOAT) * factors}
 */
static int sizeOf(@Nonnegative final int factors) {
    final int plainSize = Entry.sizeOf(factors);
    final int perFactor = SizeOf.FLOAT + SizeOf.FLOAT; // one float for Z, one for N
    return plainSize + perFactor * factors;
}

/**
 * Appends {@code ", Z=[...], N=[...]"} to the base-class rendering, where the arrays
 * hold the floats read from {@code offsetZ(f)} and {@code offsetN(f)} for every factor.
 */
@Override
public String toString() {
    final int count = _factors;
    final float[] zValues = new float[count];
    final float[] nValues = new float[count];
    int f = 0;
    while (f < count) {
        zValues[f] = _buf.getFloat(offsetZ(f));
        nValues[f] = _buf.getFloat(offsetN(f));
        f++;
    }
    final String base = super.toString();
    return base + ", Z=" + Arrays.toString(zValues) + ", N=" + Arrays.toString(nValues);
}
}

}
Loading

0 comments on commit a6de3e0

Please sign in to comment.