From 0c10392d2a3c96b40df57e6b406333e0a239b9f9 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Wed, 24 Oct 2018 17:14:15 +0900 Subject: [PATCH 1/3] Updated for debugging purpose --- core/src/main/java/hivemall/model/FeatureValue.java | 5 +++++ core/src/main/java/hivemall/optimizer/Optimizer.java | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/hivemall/model/FeatureValue.java b/core/src/main/java/hivemall/model/FeatureValue.java index d7aecd8c2..209f1ed8b 100644 --- a/core/src/main/java/hivemall/model/FeatureValue.java +++ b/core/src/main/java/hivemall/model/FeatureValue.java @@ -177,4 +177,9 @@ public static void parseFeatureAsString(@Nonnull final String s, } } + @Override + public String toString() { + return feature + ":" + value; + } + } diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java index 0cbac42e4..587adf2e0 100644 --- a/core/src/main/java/hivemall/optimizer/Optimizer.java +++ b/core/src/main/java/hivemall/optimizer/Optimizer.java @@ -71,7 +71,9 @@ public void proceedStep() { protected float update(@Nonnull final IWeightValue weight, final float gradient) { float oldWeight = weight.get(); float delta = computeDelta(weight, gradient); - float newWeight = oldWeight - _eta.eta(_numStep) * _reg.regularize(oldWeight, delta); + float eta = _eta.eta(_numStep); + float reg = _reg.regularize(oldWeight, delta); + float newWeight = oldWeight - eta * reg; weight.set(newWeight); return newWeight; } From e0dc4b954650c6751d6e37ee5ecf6c9656872b16 Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Wed, 24 Oct 2018 17:15:03 +0900 Subject: [PATCH 2/3] Introduced gradient clipping by value to avoid exploding gradients --- .../main/java/hivemall/GeneralLearnerBaseUDTF.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java index 0198e773d..4aad70a19 100644 --- a/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java +++ b/core/src/main/java/hivemall/GeneralLearnerBaseUDTF.java @@ -76,6 +76,8 @@ public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF { private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class); + private static final float MAX_DLOSS = 1e+12f; + private static final float MIN_DLOSS = -1e+12f; private ListObjectInspector featureListOI; private PrimitiveObjectInspector targetOI; @@ -168,6 +170,8 @@ protected Options getOptions() { opts.addOption("loss", "loss_function", true, getLossOptionDescription()); opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); + opts.addOption("iters", "iterations", true, + "The maximum number of iterations [default: 10]"); // conversion check opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: OFF]"); @@ -451,11 +455,16 @@ protected void update(@Nonnull final FeatureValue[] features, final float target float loss = lossFunction.loss(predicted, target); cvState.incrLoss(loss); // retain cumulative loss to check convergence - final float dloss = lossFunction.dloss(predicted, target); + float dloss = lossFunction.dloss(predicted, target); if (dloss == 0.f) { optimizer.proceedStep(); return; } + if (dloss < MIN_DLOSS) { + dloss = MIN_DLOSS; + } else if (dloss > MAX_DLOSS) { + dloss = MAX_DLOSS; + } if (is_mini_batch) { accumulateUpdate(features, dloss); From 7e932e99cfd990bb47ff7acfed44c19678fadc8f Mon Sep 17 00:00:00 2001 From: Makoto Yui Date: Wed, 24 Oct 2018 17:15:52 +0900 Subject: [PATCH 3/3] Added a unit test for gradient clipping --- .../regression/GeneralRegressorUDTFTest.java | 82 +++++++++++++++++- .../hivemall/regression/clipping_data.tsv.gz | Bin 0 -> 7948 bytes 2 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 core/src/test/resources/hivemall/regression/clipping_data.tsv.gz diff --git a/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java b/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java index 27553404f..a2a8696f3 100644 --- a/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java +++ b/core/src/test/java/hivemall/regression/GeneralRegressorUDTFTest.java @@ -22,13 +22,23 @@ import static hivemall.utils.hadoop.HiveUtils.lazyLong; import static hivemall.utils.hadoop.HiveUtils.lazyString; +import hivemall.TestUtils; +import hivemall.model.FeatureValue; +import hivemall.utils.lang.StringUtils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.List; +import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; import javax.annotation.Nonnull; -import hivemall.TestUtils; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.Collector; @@ -346,6 +356,76 @@ public void testSerialization() throws HiveException { new Object[][] {{Arrays.asList("1:-2", "2:-1"), 10.f}}); } + @Test + public void testGradientClippingSGD() throws IOException, HiveException { + String filePath = "clipping_data.tsv.gz"; + String options = "-loss squaredloss -opt SGD -reg no -eta0 0.01 -iter 1"; + + GeneralRegressorUDTF udtf = new GeneralRegressorUDTF(); + + ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector); + ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, options); + + udtf.initialize(new ObjectInspector[] {stringListOI, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, params}); + + BufferedReader reader = readFile(filePath); + String line = reader.readLine(); + for (int i = 0; line != null; i++) { + //System.out.println("> " + i); + //System.out.println(line); + + StringTokenizer tokenizer = new StringTokenizer(line, " "); + double y = Double.parseDouble(tokenizer.nextToken()); + List X = new ArrayList(); + while (tokenizer.hasMoreTokens()) { + String f = tokenizer.nextToken(); + X.add(f); + } + FeatureValue[] features = udtf.parseFeatures(X); + if (DEBUG) { + printLine(features, y); + } + + float yhat = udtf.predict(features); + //System.out.println(yhat); + if (Float.isNaN(yhat)) { + Assert.fail("NaN cause in line: " + i); + } + + udtf.process(new Object[] {X, y}); + + line = reader.readLine(); + } + + udtf.finalizeTraining(); + } + + private static void printLine(FeatureValue[] features, final double y) { + Arrays.sort(features, new Comparator() { + @Override + public int compare(FeatureValue o1, FeatureValue o2) { + int f1 = Integer.parseInt(o1.getFeatureAsString()); + int f2 = Integer.parseInt(o2.getFeatureAsString()); + return Integer.compare(f1, f2); + } + }); + System.out.print(y); + System.out.print(' '); + System.out.println(StringUtils.join(features, ' ')); + } + + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = GeneralRegressorUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } + private static void println(String msg) { if (DEBUG) { System.out.println(msg); diff --git a/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz b/core/src/test/resources/hivemall/regression/clipping_data.tsv.gz new file mode 100644 index 0000000000000000000000000000000000000000..c55a9992cc728ba8733697e08ae49ace44596f07 GIT binary patch literal 7948 zcmd^Ec~n|w+MiC^Nz$}+603}AF-#Ihq`|CLuc<;UM z^FF`jc>tAe*buVwE9CFip)%64v(pZz#3vp)a%k@nOy2J_V@QrmJ89QfOF^g3zlpB+ zdehwf^?44%`#-&=8=D9Qsxp^rswACI@v$h5@bgkyyOWJNikm{2g z2VdXz+#ilTJS;{Gy;aUeUiD0#v;&8h7aAa!=($OWBlLJ|r~^I(IFC;$Y4veN7f0Vo zXo}urUwZqeV77zubiwGIl=AcVx>>j?hmocIOjdYcS4=$36+GpJG@(2VOyyZ??CPTahgy(TQa5yOL)dcUebG|bEo=1RO&3tKvA zmOHqSgtN^?jbRAYYcv9K{M=B<*Rut|&SZgnIgE?DtaqWgoMy6jIxXK04i@%6bclA= zBrGa1B-{2^Pg(oSZ2!^vGnPxoo5Du&Pr~oUaAV-1cRa3-i%;5RspS5{VeiZl2{mOu zr!S$2kfNIAc3|j@s%z{jALV=3_%#)?&v1YmRhqL#@?(fEV)JBbE@ua-d{}-{<#M;iitxn2ErQ-A zQEW5hN|Hs2p!BD{v7Hq99~+EWfQ~I>GzdPVc{7|rQGPo0C| zwY^R$ab;9Sw857rt>iToRYnS8iDtQwYC}J~=Ac?aDxgB@ZR_bIZR4~H(1}!}gTuMG zRR}X|NG$%QkBP?bUX9Pk5X##3Ie5MMRZcuz+S80C`ASPRhoE=EQ8iD|iTzQ{eeK>U zUG0W!{XONa2LSv7v@D=(Pt2fgQf0I!zxs2wk-)n(YUOYSnt zmB{hZUCzYw4^Mp3w@5G7)@pjsbHnf$5Pq}kuNnI%E?wLnU!U!Q zGD)(_q&UI}R6V~`>>9-XLLcBB0P=uz@xSNnH5~PuUJ?DXx|1jSc&bL6(E+y4I1I26 z4wUh&UmBV%GP02e*&9A5Dn`ruWhbDIqrI7qlV(H%iNoKsQ^i<_+!`rJ#mB*w(K&4n zNpsd4q~ZL^0*v*Z`KE;Q>q>rzwmqmi|5ZO(e5b9LZlCgjh#fwb^KN2&W~^mjS)m-; zm*Y6mv|PAe%s_Kbor=`Eo@!(N)QV*xew5M8S9Tf)a?h0+CR0#n30qp-&>GJ^8@wRS zu&(sXQ0>zi5^4RWM;jd;AJ~p!`l4I5u;Z|5ob7#484bBS2q^VpjVd9Hm9C{%-g|7xH9WUAdi#Mo$dPouo>-NNBv;a%w6Xs^RIH&asV2mLX zWAGXW>iyi85{cIrj()hZ%pyHx!AdutJM;9|REXy{SGOU!lZ()|22~>c)CtSC(>2_) zt|?}JaTG}7YWL1#sqdG47-3(f7dkc2c_Xw2Q_y`A)hAu;8E!;tJ;P$-3p{1Ag+YRt z$-Vo3gKEP1tZvzM^ts#?ZvGf3rf*A7-?3qSk4GES>k_k-Y($p6edj>5kG&RE}gTND{l22Z%XE&sKc7JU7D`_@QZHikdH{&o(kYCi9T ziu|=8ijG}(^vys0al_NUUlr!7)}agO_6>~wQ0H8d>8UwH5d$%eb*_V3+_E?{9K^$* zpM@Z^eI7&SOW_`IRCP>J9Y|Os-T1l;-wmxzs=e;|W*U6vSkXk*sG0Bylkd>Z>S8aX z%KIVIO;6g=ia!27`lT`DD-ZwwEZ!LZ}fg@8-napz32ZA(?qXd9x>^YGMqhGZ9A!(43| z+qz$e3DJ`!-cXGvV)Q-q$Z&GEEbzMwuxCWiNpJje6HH9S&Q*Y=7tuDRYHr?drni>X z+yw+{Dz1U%eMPLgPe`c0!749fM|?bxE}FeT@-C{vWdzuaAT9NC6%8xgBx4Tj?>LWT z+WO-Lw=7&P#TmM(s1z!Kpay-B5DoRYSpXZoS&-^kZSnhI@H879PN+;G}9n~pO9gycJJ$aLcKg|QC z7c8}Eu68SfB?(mC?Dc5x$XV0G|6F0j%)_+p{g)K}-M6tzTQ$eeV-rV5_clojeA|@p zeUo?Hpm2OwlhMcb;TyRgvZ<1Ztr<^XL@8lF4+8cYs1rLuM_@?yWMAh@eCov&9#}k$ zaA)TIEmYL3h_Sqa+&nxTyKxvcta>G3T6nl2WZYEeh%(mhB04~5{fC(Q{mHeO`galu1vh>R4bW)SqA2?38laP13L8u{ zIYhYr>5=Z@SD3^*l9GoSFUuy(!d;vnHF0n=OH3_y!X?MCLrZH-XIP&{=Wrv$1S#L& zTR_2QWJ}1;Gu~cx2(plVBf4%fZ=6({3|`WG-yn4X8xKhRiPIVS7XonT6gvttD3)h* z#wh+=)b6aJ(Ekyk3Wh(%xDnXrm^(7}mBEu|`k5!yEIQf%w<%&rvKYoF@>Q3V(kqP# z!OWLB&8n(3K9(@GyOM6l;pA3#rRsPN#V^@VI38y81Ni zziS?Tm{dJPCJ8Pc@zvNF>^H`C7>R%e-}nUd$$jJQwRkUpVW|J_aC31db?)!pl033THp|$+Y{!v8C6-0n{&&h`-f7i z<$J@gdfvUD3Zrt06u0M<4MM@u%ez*uvW5;3Bm)>ldS?p<7iS&9*IX_b^m@Os#{bIg!N~s zXwv&dHCtQWp=Dhm#w4SZ`HrNjVc{PE0j9|Bl7-#ClI z;Fw%=af2buat8}r`_S)QWh&)eM|Kr2)Cm>iu6iqLsto)kn*wn8nW#{dH|HH`cklgxjYO=2+aE=>@h{cNr{i}uU`?U@*2cY-oe*8V~v9sBl-HF1obXYY*+F3b;>g}p0}(5;@`Spze>XRdq0 zMjD)E`krC|gn?VW#`F|m7g%=xWu)4&Z(L8Z9!vCmAbHnAo`Ujc8UQfPz3^`ktex9P zs>#yU->fBYk$zyTUAr+gBKCH1TloXjY_(}GP6FCt9xU^&g0^6Pt|IIcYFiP_Az%adg;%yzm`G%H&-TVyLETwWK8F|?Wj;5+f%hy&zofn#=g^uVw%bA1_SWyqm6aj(uTwAU09ND z^DP93%L^Arsk-`ZdZURsJXT*NR$Po*U`rp#gHI&R@$MT!X&}b9JE=AR@7$otV8v6! zU>8PQXl;D#`P0h>aRv4sj<#@UwTok9!Jx+N@H~3Iq65=b$2LM3`T~hp#JGKvu%o`q zmX`x?$bS0)zmbb>Hd}4-bHg5HliT9I(Ve4^PI?MOVDky6NzI=42cSd?JhDy~zj&2y za!FXL?o%PsUSh*TmhK;2H|#H>r=m--87DTll^I?t7X}m~`x9AvZkW zXkI(7kssTPNV5H|lt36J^JtX08GraNvr4I`2yjj&RK9wxyl}o3u zB~8O(inEVCz3X?Y%wc%Cq;sh6x`U_5(TcrrAuM{8wR&K3 z^XNZMt}pG%=tel8)Wf-DQ7ZgPNT9P-5k=9=YtMY$P=2Al*LPxT2OE-HI^9wb{Iwu$ z_QJ2mE6c-{r6UKypq=3y?=DRZyxy@h-JLa)Dm{-m_fdUpso%DpMGE