Skip to content

Commit d12c5f9

Browse files
committed
re-org code into a proper GLM framework
1 parent 91d6680 commit d12c5f9

File tree

8 files changed

+133
-74
lines changed

8 files changed

+133
-74
lines changed

Diff for: Count/src/com/mzlabs/count/ctab/CTab.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
import com.mzlabs.count.op.impl.SimpleSum;
1515
import com.mzlabs.count.op.impl.ThreadedSum;
1616
import com.mzlabs.count.op.iter.OrderStepperTot;
17-
import com.mzlabs.count.util.Fitter;
18-
import com.mzlabs.count.util.LogLinearFitter;
1917
import com.mzlabs.count.zeroone.ZeroOneCounter;
18+
import com.mzlabs.fit.Fitter;
19+
import com.mzlabs.fit.GLMFitter;
20+
import com.mzlabs.fit.SquareLossOfExp;
2021

2122

2223
public final class CTab {
@@ -215,7 +216,7 @@ public static void main(final String[] args) {
215216
System.out.println("n" + "\t" + "total" + "\t" + "target" + "\t" + "count" + "\t" + "date" + "\t" + "cacheSizes" + "\t" + "tableFinishTimeEst");
216217
for(int n=1;n<=9;++n) {
217218
final CTab ctab = new CTab(n,true);
218-
final Fitter lf = new LogLinearFitter();
219+
final Fitter lf = new GLMFitter(new SquareLossOfExp());
219220
final int tLast = (n*n-3*n+2)/2;
220221
for(int total=0;total<=tLast;++total) {
221222
final Date startTime = new Date();

Diff for: Count/src/com/mzlabs/count/util/Fitter.java renamed to Count/src/com/mzlabs/fit/Fitter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.mzlabs.count.util;
1+
package com.mzlabs.fit;
22

33
public interface Fitter {
44

Diff for: Count/src/com/mzlabs/count/util/LogLinearFitter.java renamed to Count/src/com/mzlabs/fit/GLMFitter.java

+9-66
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.mzlabs.count.util;
1+
package com.mzlabs.fit;
22

33
import java.util.ArrayList;
44
import java.util.Arrays;
@@ -7,32 +7,13 @@
77
import com.winvector.linalg.LinalgFactory;
88
import com.winvector.linalg.colt.ColtMatrix;
99

10-
public final class LogLinearFitter implements Fitter {
11-
public static final class Obs {
12-
public final double[] x;
13-
public final double y;
14-
public final double wt;
15-
16-
public Obs(final double[] x, final double y, final double wt) {
17-
this.x = Arrays.copyOf(x,x.length);
18-
this.y = y;
19-
this.wt = wt;
20-
}
21-
22-
@Override
23-
public String toString() {
24-
final StringBuilder b = new StringBuilder();
25-
b.append("" + wt + ":[");
26-
for(final double xi:x) {
27-
b.append(" " + xi);
28-
}
29-
b.append(" ]-> " + y);
30-
return b.toString();
31-
}
32-
}
33-
34-
private final ArrayList<Obs> obs = new ArrayList<Obs>();
10+
public final class GLMFitter implements Fitter {
11+
public final Link link;
12+
public final ArrayList<Obs> obs = new ArrayList<Obs>();
3513

14+
public GLMFitter(final Link link) {
15+
this.link = link;
16+
}
3617

3718
@Override
3819
public void addObservation(final double[] x, final double y, final double wt) {
@@ -52,44 +33,6 @@ public void addObservation(final double[] x, final double y, final double wt) {
5233
*
5334
*/
5435

55-
private static double dot(final double[] soln, final double[] x) {
56-
final int n = x.length;
57-
double sum = 0.0;
58-
for(int i=0;i<=n;++i) {
59-
final double xi = i<n?x[i]:1.0;
60-
sum += xi*soln[i];
61-
}
62-
return sum;
63-
}
64-
65-
private double errAndGradAndHessian(final double[] beta, final double[] grad, final ColtMatrix hessian) {
66-
final int dim = beta.length;
67-
Arrays.fill(grad,0.0);
68-
for(int i=0;i<dim;++i) {
69-
for(int j=0;j<dim;++j) {
70-
hessian.set(i,j,0.0);
71-
}
72-
}
73-
double err = 0.0;
74-
for(final Obs obsi: obs) {
75-
final double ebx = Math.exp(dot(beta,obsi.x));
76-
final double diff = obsi.y-ebx;
77-
err += diff*diff;
78-
final double gradCoef = -2*diff*ebx*obsi.wt;
79-
final double hessCoef = -2*(obsi.y-2*ebx)*ebx*obsi.wt;
80-
for(int i=0;i<dim;++i) {
81-
final double xi = i<dim-1?obsi.x[i]:1.0;
82-
grad[i] += gradCoef*xi;
83-
for(int j=0;j<dim;++j) {
84-
final double xj = j<dim-1?obsi.x[j]:1.0;
85-
final double hij = hessian.get(i,j);
86-
hessian.set(i,j,hij+xi*xj*hessCoef);
87-
}
88-
}
89-
}
90-
return err;
91-
}
92-
9336
@Override
9437
public double[] solve() {
9538
final LinalgFactory<ColtMatrix> factory = ColtMatrix.factory;
@@ -107,7 +50,7 @@ public double[] solve() {
10750
int nFails = 0;
10851
out:
10952
while(true) {
110-
final double err = errAndGradAndHessian(beta,grad,hessian);
53+
final double err = link.lossAndGradAndHessian(obs,beta,grad,hessian);
11154
if((null==bestBeta)||(err<bestErr)) {
11255
bestErr = err;
11356
bestBeta = Arrays.copyOf(beta,beta.length);
@@ -173,6 +116,6 @@ public double predict(final double[] soln, final double[] x) {
173116
if((n!=x.length)||(n+1!=soln.length)) {
174117
throw new IllegalArgumentException();
175118
}
176-
return Math.exp(dot(soln,x));
119+
return Math.exp(Obs.dot(soln,x));
177120
}
178121
}

Diff for: Count/src/com/mzlabs/count/util/LinearFitter.java renamed to Count/src/com/mzlabs/fit/LinearFitter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.mzlabs.count.util;
1+
package com.mzlabs.fit;
22

33
import com.winvector.linalg.LinalgFactory;
44
import com.winvector.linalg.colt.ColtMatrix;

Diff for: Count/src/com/mzlabs/fit/Link.java

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.mzlabs.fit;
2+
3+
import com.winvector.linalg.colt.ColtMatrix;
4+
5+
public interface Link {
6+
public double lossAndGradAndHessian(Iterable<Obs> obs, double[] beta, double[] grad, ColtMatrix hessian);
7+
public double inverseLink(double y);
8+
}

Diff for: Count/src/com/mzlabs/fit/Obs.java

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package com.mzlabs.fit;
2+
3+
import java.util.Arrays;
4+
5+
public final class Obs {
6+
public final double[] x;
7+
public final double y;
8+
public final double wt;
9+
10+
public Obs(final double[] x, final double y, final double wt) {
11+
this.x = Arrays.copyOf(x,x.length);
12+
this.y = y;
13+
this.wt = wt;
14+
}
15+
16+
/**
17+
*
18+
* @param soln soln.length==x.length+1
19+
* @param x
20+
* @return
21+
*/
22+
public static double dot(final double soln[], final double[] x) {
23+
final int n = x.length;
24+
if(soln.length!=n+1) {
25+
throw new IllegalArgumentException();
26+
}
27+
double sum = 0.0;
28+
for(int i=0;i<=n;++i) {
29+
final double xi = i<n?x[i]:1.0;
30+
sum += xi*soln[i];
31+
}
32+
return sum;
33+
}
34+
35+
/**
36+
*
37+
* @param soln.length==x.length+1
38+
* @return
39+
*/
40+
public double dot(final double[] soln) {
41+
return dot(soln,x);
42+
}
43+
44+
@Override
45+
public String toString() {
46+
final StringBuilder b = new StringBuilder();
47+
b.append("" + wt + ":[");
48+
for(final double xi:x) {
49+
b.append(" " + xi);
50+
}
51+
b.append(" ]-> " + y);
52+
return b.toString();
53+
}
54+
}

Diff for: Count/src/com/mzlabs/fit/SquareLossOfExp.java

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package com.mzlabs.fit;
2+
3+
import java.util.Arrays;
4+
5+
import com.winvector.linalg.colt.ColtMatrix;
6+
7+
/**
8+
* sum(y-e^{x.dot(beta)})^2
9+
* @author johnmount
10+
*
11+
*/
12+
public class SquareLossOfExp implements Link {
13+
14+
@Override
15+
public double lossAndGradAndHessian(final Iterable<Obs> obs, final double[] beta,
16+
final double[] grad, final ColtMatrix hessian) {
17+
final int dim = beta.length;
18+
Arrays.fill(grad,0.0);
19+
for(int i=0;i<dim;++i) {
20+
for(int j=0;j<dim;++j) {
21+
hessian.set(i,j,0.0);
22+
}
23+
}
24+
double err = 0.0;
25+
for(final Obs obsi: obs) {
26+
final double ebx = Math.exp(obsi.dot(beta));
27+
final double diff = obsi.y-ebx;
28+
err += diff*diff;
29+
final double gradCoef = -2*diff*ebx*obsi.wt;
30+
final double hessCoef = -2*(obsi.y-2*ebx)*ebx*obsi.wt;
31+
for(int i=0;i<dim;++i) {
32+
final double xi = i<dim-1?obsi.x[i]:1.0;
33+
grad[i] += gradCoef*xi;
34+
for(int j=0;j<dim;++j) {
35+
final double xj = j<dim-1?obsi.x[j]:1.0;
36+
final double hij = hessian.get(i,j);
37+
hessian.set(i,j,hij+xi*xj*hessCoef);
38+
}
39+
}
40+
}
41+
return err;
42+
}
43+
44+
@Override
45+
public double inverseLink(final double y) {
46+
return Math.exp(y);
47+
}
48+
49+
}

Diff for: Count/tests/com/mzlabs/count/util/TestLogLinFitter.java

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
package com.mzlabs.count.util;
22

3-
import static org.junit.Assert.*;
3+
import static org.junit.Assert.assertTrue;
44

55
import java.util.ArrayList;
66
import java.util.Random;
77

88
import org.junit.Test;
99

10-
import com.mzlabs.count.util.LogLinearFitter.Obs;
10+
import com.mzlabs.fit.Fitter;
11+
import com.mzlabs.fit.GLMFitter;
12+
import com.mzlabs.fit.LinearFitter;
13+
import com.mzlabs.fit.Obs;
14+
import com.mzlabs.fit.SquareLossOfExp;
1115

1216
public class TestLogLinFitter {
1317
@Test
1418
public void testLFit() {
1519
final Fitter lf = new LinearFitter(2);
16-
final LogLinearFitter llf = new LogLinearFitter();
20+
final Fitter llf = new GLMFitter(new SquareLossOfExp());
1721
final Random rand = new Random(343406L);
1822
final ArrayList<Obs> obs = new ArrayList<Obs>();
1923
for(int i=1;i<7;++i) {

0 commit comments

Comments
 (0)