Skip to content

Commit aa42e1d

Browse files
committed
had been calling linear fitter wrong (added checks for that)
added new log-linear fitter (minimizes sq-error, but doesn't match expectations)
1 parent b4cad30 commit aa42e1d

File tree

5 files changed

+274
-13
lines changed

5 files changed

+274
-13
lines changed

Count/src/com/mzlabs/count/ctab/CTab.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
import com.mzlabs.count.op.impl.SimpleSum;
1515
import com.mzlabs.count.op.impl.ThreadedSum;
1616
import com.mzlabs.count.op.iter.OrderStepperTot;
17-
import com.mzlabs.count.util.LinearFitter;
17+
import com.mzlabs.count.util.Fitter;
18+
import com.mzlabs.count.util.LogLinearFitter;
1819
import com.mzlabs.count.zeroone.ZeroOneCounter;
1920

2021

@@ -214,7 +215,7 @@ public static void main(final String[] args) {
214215
System.out.println("n" + "\t" + "total" + "\t" + "target" + "\t" + "count" + "\t" + "date" + "\t" + "cacheSizes" + "\t" + "tableFinishTimeEst");
215216
for(int n=8;n<=10;++n) {
216217
final CTab ctab = new CTab(n,true);
217-
final LinearFitter lf = new LinearFitter(1);
218+
final Fitter lf = new LogLinearFitter();
218219
final int tLast = (n*n-3*n+2)/2;
219220
for(int total=0;total<=tLast;++total) {
220221
final Date startTime = new Date();
@@ -223,15 +224,15 @@ public static void main(final String[] args) {
223224
final Date curTime = new Date();
224225
long remainingTimeEstMS = 10000;
225226
if(total>0) {
226-
// simplistic model: log(time) ~ a + b*size
227+
// simplistic model: time ~ exp(a + b*size + c*size*size)
227228
final double[] x = { total, total*total };
228229
final double y = 10000.0+curTime.getTime() - startTime.getTime();
229-
lf.addObservation(x, Math.log(y),1.0);
230+
lf.addObservation(x, y,1.0);
230231
final double[] beta = lf.solve();
231232
double timeEstMS = 0.0;
232233
for(int j=total+1;j<=tLast;++j) {
233-
final double predict = LinearFitter.predict(beta,new double[] {j, j*j});
234-
timeEstMS += Math.exp(predict);
234+
final double predict = lf.predict(beta,new double[] {j, j*j});
235+
timeEstMS += predict;
235236
}
236237
remainingTimeEstMS = (long)Math.ceil(timeEstMS);
237238
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package com.mzlabs.count.util;
2+
3+
public interface Fitter {
4+
5+
/**
6+
* add a y ~ f(x) observation
7+
* @param x
8+
* @param y
9+
* @param wt weight of observation (set to 1.0 in many cases)
10+
*/
11+
public abstract void addObservation(final double[] x, final double y,
12+
final double wt);
13+
14+
public abstract double[] solve();
15+
16+
/**
17+
*
18+
* @param soln
19+
* @param x length(soln)==length(x)+1
20+
* @return
21+
*/
22+
public abstract double predict(final double[] soln, final double[] x);
23+
24+
}

Count/src/com/mzlabs/count/util/LinearFitter.java

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* @author johnmount
1010
*
1111
*/
12-
public final class LinearFitter {
12+
public final class LinearFitter implements Fitter {
1313
private final ColtMatrix xTx;
1414
private final double[] xTy;
1515

@@ -23,14 +23,15 @@ public LinearFitter(final int n) {
2323
xTy = new double[n+1];
2424
}
2525

26-
/**
27-
* add a y ~ f(x) observation
28-
* @param x
29-
* @param y
30-
* @param wt weight of observation (set to 1.0 in many cases)
26+
/* (non-Javadoc)
27+
* @see com.mzlabs.count.util.Fitter#addObservation(double[], double, double)
3128
*/
29+
@Override
3230
public void addObservation(final double[] x, final double y, final double wt) {
3331
final int n = xTx.rows()-1;
32+
if(n!=x.length) {
33+
throw new IllegalArgumentException();
34+
}
3435
for(int i=0;i<=n;++i) {
3536
final double xi = i<n?x[i]:1.0;
3637
xTy[i] += wt*xi*y;
@@ -41,6 +42,10 @@ public void addObservation(final double[] x, final double y, final double wt) {
4142
}
4243
}
4344

45+
/* (non-Javadoc)
46+
* @see com.mzlabs.count.util.Fitter#solve()
47+
*/
48+
@Override
4449
public double[] solve() {
4550
final int n = xTx.rows()-1;
4651
final double epsilon = 1.0e-5;
@@ -56,8 +61,15 @@ public double[] solve() {
5661
return soln;
5762
}
5863

59-
public static double predict(final double[] soln, final double[] x) {
64+
/* (non-Javadoc)
65+
* @see com.mzlabs.count.util.Fitter#predict(double[], double[])
66+
*/
67+
@Override
68+
public double predict(final double[] soln, final double[] x) {
6069
final int n = soln.length-1;
70+
if((n!=x.length)||(n+1!=soln.length)) {
71+
throw new IllegalArgumentException();
72+
}
6173
double sum = 0.0;
6274
for(int i=0;i<=n;++i) {
6375
final double xi = i<n?x[i]:1.0;
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package com.mzlabs.count.util;
2+
3+
import java.util.ArrayList;
4+
import java.util.Arrays;
5+
6+
7+
import com.winvector.linalg.LinalgFactory;
8+
import com.winvector.linalg.colt.ColtMatrix;
9+
10+
public final class LogLinearFitter implements Fitter {
11+
public static final class Obs {
12+
public final double[] x;
13+
public final double y;
14+
public final double wt;
15+
16+
public Obs(final double[] x, final double y, final double wt) {
17+
this.x = Arrays.copyOf(x,x.length);
18+
this.y = y;
19+
this.wt = wt;
20+
}
21+
22+
@Override
23+
public String toString() {
24+
final StringBuilder b = new StringBuilder();
25+
b.append("" + wt + ":[");
26+
for(final double xi:x) {
27+
b.append(" " + xi);
28+
}
29+
b.append(" ]-> " + y);
30+
return b.toString();
31+
}
32+
}
33+
34+
private final ArrayList<Obs> obs = new ArrayList<Obs>();
35+
36+
37+
@Override
38+
public void addObservation(final double[] x, final double y, final double wt) {
39+
if(!obs.isEmpty()) {
40+
final int n = obs.get(0).x.length;
41+
if(n!=x.length) {
42+
throw new IllegalArgumentException();
43+
}
44+
}
45+
final Obs obsi = new Obs(x,y,wt);
46+
obs.add(obsi);
47+
}
48+
49+
/**
50+
* minimize sum_i wt[i] (e^{beta.x[i]} - y[i])^2
51+
* via Newton's method over gradient (should equal zero) and Hessian (Jacobian of vector eqn)
52+
*
53+
*/
54+
55+
private static double dot(final double[] soln, final double[] x) {
56+
final int n = x.length;
57+
double sum = 0.0;
58+
for(int i=0;i<=n;++i) {
59+
final double xi = i<n?x[i]:1.0;
60+
sum += xi*soln[i];
61+
}
62+
return sum;
63+
}
64+
65+
private double errAndGradAndHessian(final double[] beta, final double[] grad, final ColtMatrix hessian) {
66+
final int dim = beta.length;
67+
Arrays.fill(grad,0.0);
68+
for(int i=0;i<dim;++i) {
69+
for(int j=0;j<dim;++j) {
70+
hessian.set(i,j,0.0);
71+
}
72+
}
73+
double err = 0.0;
74+
for(final Obs obsi: obs) {
75+
final double ebx = Math.exp(dot(beta,obsi.x));
76+
final double diff = obsi.y-ebx;
77+
err += diff*diff;
78+
final double gradCoef = -2*diff*ebx*obsi.wt;
79+
final double hessCoef = -2*(obsi.y-2*ebx)*ebx*obsi.wt;
80+
for(int i=0;i<dim;++i) {
81+
final double xi = i<dim-1?obsi.x[i]:1.0;
82+
grad[i] += gradCoef*xi;
83+
for(int j=0;j<dim;++j) {
84+
final double xj = j<dim-1?obsi.x[j]:1.0;
85+
final double hij = hessian.get(i,j);
86+
hessian.set(i,j,hij+xi*xj*hessCoef);
87+
}
88+
}
89+
}
90+
return err;
91+
}
92+
93+
@Override
94+
public double[] solve() {
95+
final LinalgFactory<ColtMatrix> factory = ColtMatrix.factory;
96+
final int dim = obs.get(0).x.length+1;
97+
// start at solution to log(y) ~ b.x
98+
final Fitter sf = new LinearFitter(dim-1);
99+
for(final Obs obsi: obs) {
100+
sf.addObservation(obsi.x, Math.log(Math.max(1.0,obsi.y)), obsi.wt);
101+
}
102+
final double[] beta = sf.solve();
103+
double bestErr = Double.POSITIVE_INFINITY;
104+
double[] bestBeta = Arrays.copyOf(beta,beta.length);
105+
final double[] grad = new double[dim];
106+
final ColtMatrix hessian = factory.newMatrix(dim, dim, false);
107+
int nFails = 0;
108+
out:
109+
while(true) {
110+
final double err = errAndGradAndHessian(beta,grad,hessian);
111+
if((null==bestBeta)||(err<bestErr)) {
112+
bestErr = err;
113+
bestBeta = Arrays.copyOf(beta,beta.length);
114+
nFails = 0;
115+
} else {
116+
++nFails;
117+
if(nFails>=5) {
118+
break out;
119+
}
120+
}
121+
double absGrad = 0.0;
122+
for(final double gi: grad) {
123+
absGrad += Math.abs(gi);
124+
}
125+
if(Double.isInfinite(absGrad)||Double.isNaN(absGrad)||(absGrad<=1.0e-8)) {
126+
break out;
127+
}
128+
try {
129+
// // neaten up system a touch before solving
130+
// double totAbs = 0.0;
131+
// for(int i=0;i<dim;++i) {
132+
// for(int j=0;j<dim;++j) {
133+
// totAbs += Math.abs(hessian.get(i,j));
134+
// }
135+
// }
136+
// if(Double.isInfinite(totAbs)||Double.isNaN(totAbs)||(totAbs<=1.0e-8)) {
137+
// break out;
138+
// }
139+
// final double scale = (dim*dim)/totAbs;
140+
// for(int i=0;i<dim;++i) {
141+
// grad[i] *= scale;
142+
// for(int j=0;j<dim;++j) {
143+
// hessian.set(i,j,hessian.get(i,j)*scale);
144+
// }
145+
// }
146+
// for(int i=0;i<dim;++i) {
147+
// hessian.set(i,i,hessian.get(i,i)+1.e-5); // Ridge term
148+
// }
149+
final double[] delta = hessian.solve(grad);
150+
for(final double di: delta) {
151+
if(Double.isNaN(di)||Double.isNaN(di)) {
152+
break out;
153+
}
154+
}
155+
double deltaAbs = 0.0;
156+
for(int i=0;i<dim;++i) {
157+
beta[i] -= delta[i];
158+
deltaAbs += Math.abs(delta[i]);
159+
}
160+
if(deltaAbs<=1.0e-7) {
161+
break out;
162+
}
163+
} catch (Exception ex) {
164+
break out;
165+
}
166+
}
167+
return bestBeta;
168+
}
169+
170+
@Override
171+
public double predict(final double[] soln, final double[] x) {
172+
final int n = obs.get(0).x.length;
173+
if((n!=x.length)||(n+1!=soln.length)) {
174+
throw new IllegalArgumentException();
175+
}
176+
return Math.exp(dot(soln,x));
177+
}
178+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.mzlabs.count.util;
2+
3+
import static org.junit.Assert.*;
4+
5+
import java.util.ArrayList;
6+
import java.util.Random;
7+
8+
import org.junit.Test;
9+
10+
import com.mzlabs.count.util.LogLinearFitter.Obs;
11+
12+
public class TestLogLinFitter {
13+
@Test
14+
public void testLFit() {
15+
final Fitter lf = new LinearFitter(2);
16+
final LogLinearFitter llf = new LogLinearFitter();
17+
final Random rand = new Random(343406L);
18+
final ArrayList<Obs> obs = new ArrayList<Obs>();
19+
for(int i=1;i<7;++i) {
20+
final double y = Math.exp(2.0*i + 3.0*i*i);
21+
for(int j=0;j<10;++j) {
22+
final double[] x = new double[] {i,i*i};
23+
final double yObserved = y*(1+0.3*rand.nextGaussian());
24+
llf.addObservation(x,yObserved,1.0);
25+
lf.addObservation(x,Math.log(Math.max(1.0,yObserved)),1.0);
26+
obs.add(new Obs(x,y,1.0));
27+
}
28+
}
29+
final double[] lsoln = lf.solve();
30+
final double[] llsoln = llf.solve();
31+
//System.out.println("" + "y" + "\t" + "fit" + "\t" + "llfit");
32+
double sqLError = 0.0;
33+
double sqLLError = 0.0;
34+
for(final Obs obsi: obs) {
35+
final double y = obsi.y;
36+
final double[] x = obsi.x;
37+
final double lfit = Math.exp(lf.predict(lsoln, x));
38+
final double llfit = llf.predict(llsoln, x);
39+
//System.out.println("" + y + "\t" + lfit + "\t" + llfit );
40+
sqLError += Math.pow(lfit-y,2);
41+
sqLLError += Math.pow(llfit-y,2);
42+
}
43+
//System.out.println("errors\t" + sqLError + "\t" + sqLLError);
44+
assertTrue(sqLLError<sqLError);
45+
}
46+
}

0 commit comments

Comments
 (0)