Skip to content

Commit 9bd524b

Browse files
committed
get rid of implicit dc term in fitters
1 parent 69e9e0b commit 9bd524b

File tree

9 files changed

+36
-40
lines changed

9 files changed

+36
-40
lines changed

Count/src/com/mzlabs/count/ctab/CTab.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,6 @@ public static void main(final String[] args) {
215215
System.out.println("n" + "\t" + "total" + "\t" + "target" + "\t" + "count" + "\t" + "date" + "\t" + "cacheSizes" + "\t" + "tableFinishTimeEst");
216216
for(int n=1;n<=9;++n) {
217217
final CTab ctab = new CTab(n,true);
218-
//final NewtonFitter lf = new NewtonFitter(new SquareLossOfExp());
219218
final NewtonFitter lf = new NewtonFitter(DirectPoissonJacobian.poissonLink);
220219
final int tLast = (n*n-3*n+2)/2;
221220
for(int total=0;total<=tLast;++total) {
@@ -226,14 +225,14 @@ public static void main(final String[] args) {
226225
long remainingTimeEstMS = 10000;
227226
if(total>2) {
228227
// simplistic model: time ~ exp(a + b*size)
229-
final double[] x = { total };
228+
final double[] x = { 1, total };
230229
final double y = 10000.0+curTime.getTime() - startTime.getTime();
231230
lf.addObservation(x,y,1.0);
232231
if(total>6) {
233232
final double[] beta = lf.solve();
234233
double timeEstMS = 0.0;
235234
for(int j=total+1;j<=tLast;++j) {
236-
final double predict = lf.link.evalEst(beta,new double[] {j});
235+
final double predict = lf.link.evalEst(beta,new double[] {1,j});
237236
timeEstMS += predict;
238237
}
239238
remainingTimeEstMS = (long)Math.ceil(timeEstMS);

Count/src/com/mzlabs/fit/BalanceJacobianCalc.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ public interface BalanceJacobianCalc {
44
/**
55
*
66
* @param obs
7-
* @param beta beta.length==obs.x.length+1
7+
* @param beta beta.length==obs.x.length
88
* @return balance eqns and Jacobian coefficients of underlying loss function for parametes beta with respect to datum obs
99
*/
1010
BalanceJacobianCoef calc(Obs obs, double[] beta);
1111

1212
/**
1313
*
1414
* @param beta
15-
* @param x x.length==beta.length-1 (last beta is the coefficient matchng an implicit constant term by convention)
15+
* @param x x.length==beta.length (last beta is the coefficient matchng an implicit constant term by convention)
1616
* @return the estimate of y given parameters beta and datum x
1717
*/
1818
double evalEst(double[] beta, double[] x);

Count/src/com/mzlabs/fit/GLMModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ public void balanceAndJacobian(final Iterable<Obs> obs, final double[] beta,
2424
for(final Obs obsi: obs) {
2525
final BalanceJacobianCoef ghc = balanceJacobianCalc.calc(obsi,beta);
2626
for(int i=0;i<dim;++i) {
27-
final double xi = i<dim-1?obsi.x[i]:1.0;
27+
final double xi = obsi.x[i];
2828
balance[i] += ghc.balanceCoef*xi;
2929
for(int j=0;j<dim;++j) {
30-
final double xj = j<dim-1?obsi.x[j]:1.0;
30+
final double xj = obsi.x[j];
3131
final double hij = jacobian.get(i,j);
3232
jacobian.set(i,j,hij+xi*xj*ghc.jacobianCoef);
3333
}

Count/src/com/mzlabs/fit/LinearFitter.java

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,24 @@ public final class LinearFitter implements Fitter {
1919
*/
2020
public LinearFitter(final int n) {
2121
final LinalgFactory<ColtMatrix> factory = ColtMatrix.factory;
22-
xTx = factory.newMatrix(n+1,n+1,false);
23-
xTy = new double[n+1];
22+
xTx = factory.newMatrix(n,n,false);
23+
xTy = new double[n];
2424
}
2525

2626
/* (non-Javadoc)
2727
* @see com.mzlabs.count.util.Fitter#addObservation(double[], double, double)
2828
*/
2929
@Override
3030
public void addObservation(final double[] x, final double y, final double wt) {
31-
final int n = xTx.rows()-1;
31+
final int n = xTx.rows();
3232
if(n!=x.length) {
3333
throw new IllegalArgumentException();
3434
}
35-
for(int i=0;i<=n;++i) {
36-
final double xi = i<n?x[i]:1.0;
35+
for(int i=0;i<n;++i) {
36+
final double xi = x[i];
3737
xTy[i] += wt*xi*y;
38-
for(int j=0;j<=n;++j) {
39-
final double xj = j<n?x[j]:1.0;
38+
for(int j=0;j<n;++j) {
39+
final double xj = x[j];
4040
xTx.set(i,j,xTx.get(i, j)+wt*xi*xj);
4141
}
4242
}
@@ -47,15 +47,15 @@ public void addObservation(final double[] x, final double y, final double wt) {
4747
*/
4848
@Override
4949
public double[] solve() {
50-
final int n = xTx.rows()-1;
50+
final int n = xTx.rows();
5151
final double epsilon = 1.0e-5;
5252
final double[] xTxii = new double[n+1];
53-
for(int i=0;i<=n;++i) {
53+
for(int i=0;i<n;++i) {
5454
xTxii[i] = xTx.get(i,i);
5555
xTx.set(i,i,xTxii[i]+epsilon); // Ridge term
5656
}
5757
final double[] soln = xTx.solve(xTy);
58-
for(int i=0;i<=n;++i) {
58+
for(int i=0;i<n;++i) {
5959
xTx.set(i,i,xTxii[i]);
6060
}
6161
return soln;
@@ -65,14 +65,13 @@ public double[] solve() {
6565
* @see com.mzlabs.count.util.Fitter#predict(double[], double[])
6666
*/
6767
public double predict(final double[] soln, final double[] x) {
68-
final int n = soln.length-1;
69-
if((n!=x.length)||(n+1!=soln.length)) {
68+
final int n = soln.length;
69+
if((n!=x.length)||(n!=soln.length)) {
7070
throw new IllegalArgumentException();
7171
}
7272
double sum = 0.0;
73-
for(int i=0;i<=n;++i) {
74-
final double xi = i<n?x[i]:1.0;
75-
sum += xi*soln[i];
73+
for(int i=0;i<n;++i) {
74+
sum += x[i]*soln[i];
7675
}
7776
return sum;
7877
}

Count/src/com/mzlabs/fit/NewtonFitter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ public void addObservation(final double[] x, final double y, final double wt) {
3434
@Override
3535
public double[] solve() {
3636
final LinalgFactory<ColtMatrix> factory = ColtMatrix.factory;
37-
final int dim = obs.get(0).x.length+1;
37+
final int dim = obs.get(0).x.length;
3838
// roughly: often solving y ~ f(b.x), so start at f^-1(y) ~ b.x
39-
final Fitter sf = new LinearFitter(dim-1);
39+
final Fitter sf = new LinearFitter(dim);
4040
for(final Obs obsi: obs) {
4141
sf.addObservation(obsi.x, link.heuristicLink(obsi.y), obsi.wt);
4242
}

Count/src/com/mzlabs/fit/Obs.java

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,25 @@ public Obs(final double[] x, final double y, final double wt) {
1515

1616
/**
1717
*
18-
* @param soln soln.length==x.length+1
18+
* @param soln soln.length==x.length
1919
* @param x
2020
* @return
2121
*/
2222
public static double dot(final double soln[], final double[] x) {
2323
final int n = x.length;
24-
if(soln.length!=n+1) {
24+
if(soln.length!=n) {
2525
throw new IllegalArgumentException();
2626
}
2727
double sum = 0.0;
28-
for(int i=0;i<=n;++i) {
29-
final double xi = i<n?x[i]:1.0;
30-
sum += xi*soln[i];
28+
for(int i=0;i<n;++i) {
29+
sum += x[i]*soln[i];
3130
}
32-
return sum;
31+
return sum;
3332
}
3433

3534
/**
3635
*
37-
* @param soln.length==x.length+1
36+
* @param soln.length==x.length
3837
* @return
3938
*/
4039
public double dot(final double[] soln) {

Count/src/com/mzlabs/fit/SquareLossOfExp.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ public void balanceAndJacobian(final Iterable<Obs> obs, final double[] beta,
3737
final double gradCoef = -2*diff*ebx*obsi.wt;
3838
final double hessCoef = -2*(obsi.y-2*ebx)*ebx*obsi.wt;
3939
for(int i=0;i<dim;++i) {
40-
final double xi = i<dim-1?obsi.x[i]:1.0;
40+
final double xi = obsi.x[i];
4141
balance[i] += gradCoef*xi;
4242
for(int j=0;j<dim;++j) {
43-
final double xj = j<dim-1?obsi.x[j]:1.0;
43+
final double xj = obsi.x[j];
4444
final double hij = jacobian.get(i,j);
4545
jacobian.set(i,j,hij+xi*xj*hessCoef);
4646
}

Count/src/com/mzlabs/fit/VectorFnWithJacobian.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public interface VectorFnWithJacobian {
1111
/**
1212
*
1313
* @param beta
14-
* @param x x.length==beta.length-1 (last beta is the coefficient matchng an implicit constant term by convention)
14+
* @param x x.length==beta.length (last beta is the coefficient matchng an implicit constant term by convention)
1515
* @return the estimate of y given parameters beta and datum x
1616
*/
1717
double evalEst(double[] beta, double[] x);

Count/tests/com/mzlabs/count/util/TestLogLinFitter.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
public class TestLogLinFitter {
2020
@Test
2121
public void testLFit() {
22-
final LinearFitter lf = new LinearFitter(1);
22+
final LinearFitter lf = new LinearFitter(2);
2323
final NewtonFitter llf = new NewtonFitter(new SquareLossOfExp());
2424
final Random rand = new Random(343406L);
2525
final ArrayList<Obs> obs = new ArrayList<Obs>();
2626
for(int i=1;i<7;++i) {
2727
final double y = Math.exp(2.0*i);
2828
for(int j=0;j<10;++j) {
29-
final double[] x = new double[] {i};
29+
final double[] x = new double[] {1,i};
3030
final double yObserved = y*(1+0.3*rand.nextGaussian());
3131
llf.addObservation(x,yObserved,1.0);
3232
lf.addObservation(x,Math.log(Math.max(1.0,yObserved)),1.0);
@@ -55,7 +55,7 @@ public void testLFit() {
5555
@Test
5656
public void testPLinks() {
5757
final double y = 1.55528;
58-
final double[] x = { 5.0 };
58+
final double[] x = { 1.0, 5.0 };
5959
final Obs obs = new Obs(x,y,1.0);
6060
final double[] beta = { 0.2 , -0.1};
6161
final BalanceJacobianCoef lpgh = LinkBasedGradHess.poissonGradHess.calc(obs, beta);
@@ -72,7 +72,7 @@ public void testPFit() {
7272
final Random rand = new Random(343406L);
7373
for(int i=1;i<=5;++i) {
7474
final double y = Math.exp(0.4*i) + rand.nextGaussian();
75-
final double[] x = new double[] {i};
75+
final double[] x = new double[] {1,i};
7676
llf.addObservation(x,y,1.0);
7777
obs.add(new Obs(x,y,1.0));
7878
}
@@ -89,8 +89,7 @@ public void testPFit() {
8989
final double[] x = obsi.x;
9090
final double llfit = llf.link.evalEst(llsoln,x);
9191
for(int i=0;i<dim;++i) {
92-
final double xi = i<dim-1?obsi.x[i]:1.0;
93-
sums[i] += obsi.wt*xi*(y-llfit);
92+
sums[i] += obsi.wt*x[i]*(y-llfit);
9493
}
9594
}
9695
for(final double si: sums) {

0 commit comments

Comments
 (0)