Skip to content

Commit 5a75d66

Browse files
committed
rebase fitter on balance and Jacobian instead of gradient and Hessian
1 parent 0b5ca93 commit 5a75d66

File tree

12 files changed

+270
-148
lines changed

12 files changed

+270
-148
lines changed

Count/src/com/mzlabs/count/ctab/CTab.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import com.mzlabs.count.op.impl.ThreadedSum;
1616
import com.mzlabs.count.op.iter.OrderStepperTot;
1717
import com.mzlabs.count.zeroone.ZeroOneCounter;
18-
import com.mzlabs.fit.GLMModel;
18+
import com.mzlabs.fit.DirectPoissonJacobian;
1919
import com.mzlabs.fit.NewtonFitter;
2020

2121

@@ -216,7 +216,7 @@ public static void main(final String[] args) {
216216
for(int n=1;n<=9;++n) {
217217
final CTab ctab = new CTab(n,true);
218218
//final NewtonFitter lf = new NewtonFitter(new SquareLossOfExp());
219-
final NewtonFitter lf = new NewtonFitter(GLMModel.PoissonLink);
219+
final NewtonFitter lf = new NewtonFitter(DirectPoissonJacobian.poissonLink);
220220
final int tLast = (n*n-3*n+2)/2;
221221
for(int total=0;total<=tLast;++total) {
222222
final Date startTime = new Date();
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.mzlabs.fit;
2+
3+
4+
/**
5+
* This returns a set of equations that should be zero (what we have been calling a gradient) and the Jacobian of these equations (what we
6+
* have been calling the hessian). We don't care if there is an underlying function that we are the gradient of. And if we are using this
7+
* to build a GLM link we are assuming the probability model is such that the gradient of log(P(y|f(obs.dot(beta)))) is our given balance equations.
8+
* Note we only use the first two positions of the link.
9+
* @author johnmount
10+
*
11+
*/
12+
public final class BalanceBasedJacobian implements BalanceJacobianCalc {
13+
public final Link link;
14+
15+
public BalanceBasedJacobian(final Link link) {
16+
this.link = link;
17+
}
18+
19+
@Override
20+
public BalanceJacobianCoef calc(final Obs obs, final double[] beta) {
21+
final double[] z = new double[3];
22+
final double bx = obs.dot(beta);
23+
link.invLink(bx,z);
24+
final double balanceCoef = obs.wt*(obs.y-z[0]); // hard coded Poisson gradient component
25+
final double jacobianCoef = obs.wt*(-z[1]);
26+
return new BalanceJacobianCoef(balanceCoef,jacobianCoef);
27+
}
28+
29+
@Override
30+
public double evalEst(final double[] beta, final double[] x) {
31+
return link.invLink(Obs.dot(beta,x));
32+
}
33+
34+
public static final BalanceBasedJacobian poissonJacobian = new BalanceBasedJacobian(LinkBasedGradHess.LogLink);
35+
public static GLMModel poissonLink = new GLMModel(poissonJacobian);
36+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package com.mzlabs.fit;
2+
3+
public interface BalanceJacobianCalc {
4+
/**
5+
*
6+
* @param obs
7+
* @param beta beta.length==obs.x.length+1
8+
* @return balance eqns and Jacobian coefficients of underlying loss function for parametes beta with respect to datum obs
9+
*/
10+
BalanceJacobianCoef calc(Obs obs, double[] beta);
11+
12+
/**
13+
*
14+
* @param beta
15+
* @param x x.length==beta.length-1 (last beta is the coefficient matchng an implicit constant term by convention)
16+
* @return the estimate of y given parameters beta and datum x
17+
*/
18+
double evalEst(double[] beta, double[] x);
19+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.mzlabs.fit;
2+
3+
public final class BalanceJacobianCoef {
4+
public final double balanceCoef;
5+
public final double jacobianCoef;
6+
7+
public BalanceJacobianCoef(final double balanceCoef, final double jacobianCoef) {
8+
this.balanceCoef = balanceCoef;
9+
this.jacobianCoef = jacobianCoef;
10+
}
11+
12+
public double absDiff(final BalanceJacobianCoef o) {
13+
return Math.abs(balanceCoef-o.balanceCoef) + Math.abs(jacobianCoef-jacobianCoef);
14+
}
15+
16+
@Override
17+
public String toString() {
18+
return "g:" + balanceCoef + " h:" + jacobianCoef;
19+
}
20+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.mzlabs.fit;
2+
3+
4+
/**
5+
* Direct implementation of gradient and Hessian coefficient calculation
6+
* @author johnmount
7+
*
8+
*/
9+
public final class DirectPoissonJacobian implements BalanceJacobianCalc {
10+
private DirectPoissonJacobian() {
11+
}
12+
13+
@Override
14+
public BalanceJacobianCoef calc(Obs obs, double[] beta) {
15+
final double bx = obs.dot(beta);
16+
final double balCoef = obs.wt*(obs.y-Math.exp(bx)); // hard coded Poisson gradient component
17+
final double jacCoef = obs.wt*(-Math.exp(bx)); // hard coded Poisson hessian component
18+
return new BalanceJacobianCoef(balCoef,jacCoef);
19+
}
20+
21+
@Override
22+
public double evalEst(double[] beta, double[] x) {
23+
return Math.exp(Obs.dot(beta,x));
24+
}
25+
26+
public static final DirectPoissonJacobian poissonGradHess = new DirectPoissonJacobian();
27+
public static GLMModel poissonLink = new GLMModel(poissonGradHess);
28+
}

Count/src/com/mzlabs/fit/GLMModel.java

Lines changed: 16 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -4,124 +4,39 @@
44

55
import com.winvector.linalg.colt.ColtMatrix;
66

7-
public class GLMModel implements VectorFnWithGradAndHessian {
8-
public final boolean debug;
9-
public final ProbYGivenZ prob;
10-
public final Link link;
11-
12-
13-
public GLMModel(final ProbYGivenZ prob, final Link link, final boolean debug) {
14-
this.prob = prob;
15-
this.link = link;
16-
this.debug = debug;
17-
}
18-
19-
@Override
20-
public double evalEst(final double[] beta, final double[] x) {
21-
return link.invLink(Obs.dot(beta,x));
22-
}
7+
public class GLMModel implements VectorFnWithJacobian {
8+
public final BalanceJacobianCalc balanceJacobianCalc;
239

24-
public Double checkGradCoef(final Obs obsi, final double[] beta) {
25-
return null;
26-
}
27-
28-
public Double checkHessianCoef(final Obs obsi, final double[] beta) {
29-
return null;
10+
public GLMModel(final BalanceJacobianCalc balanceJacobianCalc) {
11+
this.balanceJacobianCalc = balanceJacobianCalc;
3012
}
3113

32-
public static final ProbYGivenZ PoissonProbability = new ProbYGivenZ() {
33-
@Override
34-
public void eval(final double y, final double z, final double[] res) {
35-
if(y>0) {
36-
final double fyz = Math.exp(y*Math.log(z)-z); // ignoring a gamma(y+1) term here
37-
res[0] = fyz;
38-
res[1] = fyz*(y/z-1);
39-
res[2] = fyz*((y/z-1)*(y/z-1)-y/(z*z));
40-
} else {
41-
final double fyz = Math.exp(-z); // ignoring a gamma(y+1) term here
42-
res[0] = fyz;
43-
res[1] = -fyz;
44-
res[2] = fyz;
45-
}
46-
}
47-
};
48-
public static final Link LogLink = new Link() {
49-
@Override
50-
public void invLink(final double z, final double[] res) {
51-
final double ez = Math.exp(z);
52-
res[0] = ez;
53-
res[1] = ez;
54-
res[2] = ez;
55-
}
56-
@Override
57-
public double invLink(final double z) {
58-
return Math.exp(z);
59-
}
60-
};
61-
62-
public static GLMModel PoissonLink = new GLMModel(PoissonProbability,LogLink,false);
63-
public static GLMModel PoissonLinkDebug = new GLMModel(PoissonProbability,LogLink,true) {
64-
@Override
65-
public Double checkGradCoef(final Obs obsi, final double[] beta) {
66-
final double bx = obsi.dot(beta);
67-
return obsi.wt*(obsi.y-Math.exp(bx)); // hard coded Poisson gradient component
68-
}
69-
@Override
70-
public Double checkHessianCoef(final Obs obsi, final double[] beta) {
71-
final double bx = obsi.dot(beta);
72-
return obsi.wt*(-Math.exp(bx)); // hard coded Poisson hessian component
73-
}
74-
};
75-
76-
77-
private static double sq(final double z) {
78-
return z*z;
79-
}
80-
81-
8214
@Override
83-
public double lossAndGradAndHessian(final Iterable<Obs> obs, final double[] beta,
84-
final double[] grad, final ColtMatrix hessian) {
15+
public void balanceAndJacobian(final Iterable<Obs> obs, final double[] beta,
16+
final double[] balance, final ColtMatrix jacobian) {
8517
final int dim = beta.length;
86-
Arrays.fill(grad,0.0);
18+
Arrays.fill(balance,0.0);
8719
for(int i=0;i<dim;++i) {
8820
for(int j=0;j<dim;++j) {
89-
hessian.set(i,j,0.0);
21+
jacobian.set(i,j,0.0);
9022
}
9123
}
92-
final double[] p = new double[3];
93-
final double[] z = new double[3];
94-
double sum = 0.0;
9524
for(final Obs obsi: obs) {
96-
final double bx = obsi.dot(beta);
97-
link.invLink(bx,z);
98-
prob.eval(obsi.y,z[0],p);
99-
sum += obsi.wt*Math.log(z[0]);
100-
final double gradScale = obsi.wt*z[1]*p[1]/p[0];
101-
final double hessCoef = obsi.wt*(-sq(z[1])*sq(p[1])/sq(p[0]) + sq(z[1])*p[2]/p[0] + z[2]*p[1]/p[0]);
102-
if(debug) {
103-
final Double gradP = checkGradCoef(obsi,beta);
104-
final Double hessP = checkHessianCoef(obsi,beta);
105-
if((null!=gradP)&&(Math.abs(gradScale-gradP)>=1.0e-6)) {
106-
throw new IllegalStateException("gradient checks didn't match");
107-
}
108-
if((null!=hessP)&&(Math.abs(hessCoef-hessP)>1.0e-6)) {
109-
throw new IllegalStateException("Hessian checks didn't match");
110-
}
111-
}
25+
final BalanceJacobianCoef ghc = balanceJacobianCalc.calc(obsi,beta);
11226
for(int i=0;i<dim;++i) {
11327
final double xi = i<dim-1?obsi.x[i]:1.0;
114-
grad[i] += gradScale*xi;
28+
balance[i] += ghc.balanceCoef*xi;
11529
for(int j=0;j<dim;++j) {
11630
final double xj = j<dim-1?obsi.x[j]:1.0;
117-
final double hij = hessian.get(i,j);
118-
hessian.set(i,j,hij+xi*xj*hessCoef);
31+
final double hij = jacobian.get(i,j);
32+
jacobian.set(i,j,hij+xi*xj*ghc.jacobianCoef);
11933
}
12034
}
12135
}
122-
return sum;
12336
}
12437

125-
126-
38+
@Override
39+
public double evalEst(double[] beta, double[] x) {
40+
return balanceJacobianCalc.evalEst(beta,x);
41+
}
12742
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package com.mzlabs.fit;
2+
3+
/**
4+
* calculate the scalar portion of gradient and hessian (balance and jacobian in larger scope) directly using
5+
* link and probability model
6+
* @author johnmount
7+
*
8+
*/
9+
public final class LinkBasedGradHess implements BalanceJacobianCalc {
10+
public final ProbYGivenZ prob;
11+
public final Link link;
12+
13+
public LinkBasedGradHess(final ProbYGivenZ prob, final Link link) {
14+
this.prob = prob;
15+
this.link = link;
16+
}
17+
18+
private static double sq(final double z) {
19+
return z*z;
20+
}
21+
22+
@Override
23+
public BalanceJacobianCoef calc(Obs obs, double[] beta) {
24+
final double[] p = new double[3];
25+
final double[] z = new double[3];
26+
final double bx = obs.dot(beta);
27+
link.invLink(bx,z);
28+
prob.eval(obs.y,z[0],p);
29+
final double gradCoef = obs.wt*z[1]*p[1]/p[0];
30+
final double hessCoef = obs.wt*(-sq(z[1])*sq(p[1])/sq(p[0]) + sq(z[1])*p[2]/p[0] + z[2]*p[1]/p[0]);
31+
return new BalanceJacobianCoef(gradCoef,hessCoef);
32+
}
33+
34+
@Override
35+
public double evalEst(double[] beta, double[] x) {
36+
return link.invLink(Obs.dot(beta,x));
37+
}
38+
39+
40+
41+
42+
public static final ProbYGivenZ PoissonProbability = new ProbYGivenZ() {
43+
@Override
44+
public void eval(final double y, final double z, final double[] res) {
45+
if(y>0) {
46+
final double fyz = Math.exp(y*Math.log(z)-z); // ignoring a gamma(y+1) term here
47+
res[0] = fyz;
48+
res[1] = fyz*(y/z-1);
49+
res[2] = fyz*((y/z-1)*(y/z-1)-y/(z*z));
50+
} else {
51+
final double fyz = Math.exp(-z); // ignoring a gamma(y+1) term here
52+
res[0] = fyz;
53+
res[1] = -fyz;
54+
res[2] = fyz;
55+
}
56+
}
57+
};
58+
59+
public static final Link LogLink = new Link() {
60+
@Override
61+
public void invLink(final double z, final double[] res) {
62+
final double ez = Math.exp(z);
63+
res[0] = ez;
64+
res[1] = ez;
65+
res[2] = ez;
66+
}
67+
@Override
68+
public double invLink(final double z) {
69+
return Math.exp(z);
70+
}
71+
};
72+
73+
74+
public static final LinkBasedGradHess poissonGradHess = new LinkBasedGradHess(PoissonProbability,LogLink);
75+
public static GLMModel poissonLink = new GLMModel(poissonGradHess);
76+
}

0 commit comments

Comments
 (0)