|
4 | 4 |
|
5 | 5 | import com.winvector.linalg.colt.ColtMatrix; |
6 | 6 |
|
7 | | -public class GLMModel implements VectorFnWithGradAndHessian { |
8 | | - public final boolean debug; |
9 | | - public final ProbYGivenZ prob; |
10 | | - public final Link link; |
11 | | - |
12 | | - |
13 | | - public GLMModel(final ProbYGivenZ prob, final Link link, final boolean debug) { |
14 | | - this.prob = prob; |
15 | | - this.link = link; |
16 | | - this.debug = debug; |
17 | | - } |
18 | | - |
19 | | - @Override |
20 | | - public double evalEst(final double[] beta, final double[] x) { |
21 | | - return link.invLink(Obs.dot(beta,x)); |
22 | | - } |
| 7 | +public class GLMModel implements VectorFnWithJacobian { |
| 8 | + public final BalanceJacobianCalc balanceJacobianCalc; |
23 | 9 |
|
24 | | - public Double checkGradCoef(final Obs obsi, final double[] beta) { |
25 | | - return null; |
26 | | - } |
27 | | - |
28 | | - public Double checkHessianCoef(final Obs obsi, final double[] beta) { |
29 | | - return null; |
| 10 | + public GLMModel(final BalanceJacobianCalc balanceJacobianCalc) { |
| 11 | + this.balanceJacobianCalc = balanceJacobianCalc; |
30 | 12 | } |
31 | 13 |
|
32 | | - public static final ProbYGivenZ PoissonProbability = new ProbYGivenZ() { |
33 | | - @Override |
34 | | - public void eval(final double y, final double z, final double[] res) { |
35 | | - if(y>0) { |
36 | | - final double fyz = Math.exp(y*Math.log(z)-z); // ignoring a gamma(y+1) term here |
37 | | - res[0] = fyz; |
38 | | - res[1] = fyz*(y/z-1); |
39 | | - res[2] = fyz*((y/z-1)*(y/z-1)-y/(z*z)); |
40 | | - } else { |
41 | | - final double fyz = Math.exp(-z); // ignoring a gamma(y+1) term here |
42 | | - res[0] = fyz; |
43 | | - res[1] = -fyz; |
44 | | - res[2] = fyz; |
45 | | - } |
46 | | - } |
47 | | - }; |
48 | | - public static final Link LogLink = new Link() { |
49 | | - @Override |
50 | | - public void invLink(final double z, final double[] res) { |
51 | | - final double ez = Math.exp(z); |
52 | | - res[0] = ez; |
53 | | - res[1] = ez; |
54 | | - res[2] = ez; |
55 | | - } |
56 | | - @Override |
57 | | - public double invLink(final double z) { |
58 | | - return Math.exp(z); |
59 | | - } |
60 | | - }; |
61 | | - |
62 | | - public static GLMModel PoissonLink = new GLMModel(PoissonProbability,LogLink,false); |
63 | | - public static GLMModel PoissonLinkDebug = new GLMModel(PoissonProbability,LogLink,true) { |
64 | | - @Override |
65 | | - public Double checkGradCoef(final Obs obsi, final double[] beta) { |
66 | | - final double bx = obsi.dot(beta); |
67 | | - return obsi.wt*(obsi.y-Math.exp(bx)); // hard coded Poisson gradient component |
68 | | - } |
69 | | - @Override |
70 | | - public Double checkHessianCoef(final Obs obsi, final double[] beta) { |
71 | | - final double bx = obsi.dot(beta); |
72 | | - return obsi.wt*(-Math.exp(bx)); // hard coded Poisson hessian component |
73 | | - } |
74 | | - }; |
75 | | - |
76 | | - |
77 | | - private static double sq(final double z) { |
78 | | - return z*z; |
79 | | - } |
80 | | - |
81 | | - |
82 | 14 | @Override |
83 | | - public double lossAndGradAndHessian(final Iterable<Obs> obs, final double[] beta, |
84 | | - final double[] grad, final ColtMatrix hessian) { |
| 15 | + public void balanceAndJacobian(final Iterable<Obs> obs, final double[] beta, |
| 16 | + final double[] balance, final ColtMatrix jacobian) { |
85 | 17 | final int dim = beta.length; |
86 | | - Arrays.fill(grad,0.0); |
| 18 | + Arrays.fill(balance,0.0); |
87 | 19 | for(int i=0;i<dim;++i) { |
88 | 20 | for(int j=0;j<dim;++j) { |
89 | | - hessian.set(i,j,0.0); |
| 21 | + jacobian.set(i,j,0.0); |
90 | 22 | } |
91 | 23 | } |
92 | | - final double[] p = new double[3]; |
93 | | - final double[] z = new double[3]; |
94 | | - double sum = 0.0; |
95 | 24 | for(final Obs obsi: obs) { |
96 | | - final double bx = obsi.dot(beta); |
97 | | - link.invLink(bx,z); |
98 | | - prob.eval(obsi.y,z[0],p); |
99 | | - sum += obsi.wt*Math.log(z[0]); |
100 | | - final double gradScale = obsi.wt*z[1]*p[1]/p[0]; |
101 | | - final double hessCoef = obsi.wt*(-sq(z[1])*sq(p[1])/sq(p[0]) + sq(z[1])*p[2]/p[0] + z[2]*p[1]/p[0]); |
102 | | - if(debug) { |
103 | | - final Double gradP = checkGradCoef(obsi,beta); |
104 | | - final Double hessP = checkHessianCoef(obsi,beta); |
105 | | - if((null!=gradP)&&(Math.abs(gradScale-gradP)>=1.0e-6)) { |
106 | | - throw new IllegalStateException("gradient checks didn't match"); |
107 | | - } |
108 | | - if((null!=hessP)&&(Math.abs(hessCoef-hessP)>1.0e-6)) { |
109 | | - throw new IllegalStateException("Hessian checks didn't match"); |
110 | | - } |
111 | | - } |
| 25 | + final BalanceJacobianCoef ghc = balanceJacobianCalc.calc(obsi,beta); |
112 | 26 | for(int i=0;i<dim;++i) { |
113 | 27 | final double xi = i<dim-1?obsi.x[i]:1.0; |
114 | | - grad[i] += gradScale*xi; |
| 28 | + balance[i] += ghc.balanceCoef*xi; |
115 | 29 | for(int j=0;j<dim;++j) { |
116 | 30 | final double xj = j<dim-1?obsi.x[j]:1.0; |
117 | | - final double hij = hessian.get(i,j); |
118 | | - hessian.set(i,j,hij+xi*xj*hessCoef); |
| 31 | + final double hij = jacobian.get(i,j); |
| 32 | + jacobian.set(i,j,hij+xi*xj*ghc.jacobianCoef); |
119 | 33 | } |
120 | 34 | } |
121 | 35 | } |
122 | | - return sum; |
123 | 36 | } |
124 | 37 |
|
125 | | - |
126 | | - |
| 38 | + @Override |
| 39 | + public double evalEst(double[] beta, double[] x) { |
| 40 | + return balanceJacobianCalc.evalEst(beta,x); |
| 41 | + } |
127 | 42 | } |
0 commit comments