|
| 1 | +package com.mzlabs.count.util; |
| 2 | + |
| 3 | +import java.util.ArrayList; |
| 4 | +import java.util.Arrays; |
| 5 | + |
| 6 | + |
| 7 | +import com.winvector.linalg.LinalgFactory; |
| 8 | +import com.winvector.linalg.colt.ColtMatrix; |
| 9 | + |
| 10 | +public final class LogLinearFitter implements Fitter { |
| 11 | + public static final class Obs { |
| 12 | + public final double[] x; |
| 13 | + public final double y; |
| 14 | + public final double wt; |
| 15 | + |
| 16 | + public Obs(final double[] x, final double y, final double wt) { |
| 17 | + this.x = Arrays.copyOf(x,x.length); |
| 18 | + this.y = y; |
| 19 | + this.wt = wt; |
| 20 | + } |
| 21 | + |
| 22 | + @Override |
| 23 | + public String toString() { |
| 24 | + final StringBuilder b = new StringBuilder(); |
| 25 | + b.append("" + wt + ":["); |
| 26 | + for(final double xi:x) { |
| 27 | + b.append(" " + xi); |
| 28 | + } |
| 29 | + b.append(" ]-> " + y); |
| 30 | + return b.toString(); |
| 31 | + } |
| 32 | + } |
| 33 | + |
| 34 | + private final ArrayList<Obs> obs = new ArrayList<Obs>(); |
| 35 | + |
| 36 | + |
| 37 | + @Override |
| 38 | + public void addObservation(final double[] x, final double y, final double wt) { |
| 39 | + if(!obs.isEmpty()) { |
| 40 | + final int n = obs.get(0).x.length; |
| 41 | + if(n!=x.length) { |
| 42 | + throw new IllegalArgumentException(); |
| 43 | + } |
| 44 | + } |
| 45 | + final Obs obsi = new Obs(x,y,wt); |
| 46 | + obs.add(obsi); |
| 47 | + } |
| 48 | + |
| 49 | + /** |
| 50 | + * minimize sum_i wt[i] (e^{beta.x[i]} - y[i])^2 |
| 51 | + * via Newton's method over gradient (should equal zero) and Hessian (Jacobian of vector eqn) |
| 52 | + * |
| 53 | + */ |
| 54 | + |
| 55 | + private static double dot(final double[] soln, final double[] x) { |
| 56 | + final int n = x.length; |
| 57 | + double sum = 0.0; |
| 58 | + for(int i=0;i<=n;++i) { |
| 59 | + final double xi = i<n?x[i]:1.0; |
| 60 | + sum += xi*soln[i]; |
| 61 | + } |
| 62 | + return sum; |
| 63 | + } |
| 64 | + |
| 65 | + private double errAndGradAndHessian(final double[] beta, final double[] grad, final ColtMatrix hessian) { |
| 66 | + final int dim = beta.length; |
| 67 | + Arrays.fill(grad,0.0); |
| 68 | + for(int i=0;i<dim;++i) { |
| 69 | + for(int j=0;j<dim;++j) { |
| 70 | + hessian.set(i,j,0.0); |
| 71 | + } |
| 72 | + } |
| 73 | + double err = 0.0; |
| 74 | + for(final Obs obsi: obs) { |
| 75 | + final double ebx = Math.exp(dot(beta,obsi.x)); |
| 76 | + final double diff = obsi.y-ebx; |
| 77 | + err += diff*diff; |
| 78 | + final double gradCoef = -2*diff*ebx*obsi.wt; |
| 79 | + final double hessCoef = -2*(obsi.y-2*ebx)*ebx*obsi.wt; |
| 80 | + for(int i=0;i<dim;++i) { |
| 81 | + final double xi = i<dim-1?obsi.x[i]:1.0; |
| 82 | + grad[i] += gradCoef*xi; |
| 83 | + for(int j=0;j<dim;++j) { |
| 84 | + final double xj = j<dim-1?obsi.x[j]:1.0; |
| 85 | + final double hij = hessian.get(i,j); |
| 86 | + hessian.set(i,j,hij+xi*xj*hessCoef); |
| 87 | + } |
| 88 | + } |
| 89 | + } |
| 90 | + return err; |
| 91 | + } |
| 92 | + |
| 93 | + @Override |
| 94 | + public double[] solve() { |
| 95 | + final LinalgFactory<ColtMatrix> factory = ColtMatrix.factory; |
| 96 | + final int dim = obs.get(0).x.length+1; |
| 97 | + // start at solution to log(y) ~ b.x |
| 98 | + final Fitter sf = new LinearFitter(dim-1); |
| 99 | + for(final Obs obsi: obs) { |
| 100 | + sf.addObservation(obsi.x, Math.log(Math.max(1.0,obsi.y)), obsi.wt); |
| 101 | + } |
| 102 | + final double[] beta = sf.solve(); |
| 103 | + double bestErr = Double.POSITIVE_INFINITY; |
| 104 | + double[] bestBeta = Arrays.copyOf(beta,beta.length); |
| 105 | + final double[] grad = new double[dim]; |
| 106 | + final ColtMatrix hessian = factory.newMatrix(dim, dim, false); |
| 107 | + int nFails = 0; |
| 108 | + out: |
| 109 | + while(true) { |
| 110 | + final double err = errAndGradAndHessian(beta,grad,hessian); |
| 111 | + if((null==bestBeta)||(err<bestErr)) { |
| 112 | + bestErr = err; |
| 113 | + bestBeta = Arrays.copyOf(beta,beta.length); |
| 114 | + nFails = 0; |
| 115 | + } else { |
| 116 | + ++nFails; |
| 117 | + if(nFails>=5) { |
| 118 | + break out; |
| 119 | + } |
| 120 | + } |
| 121 | + double absGrad = 0.0; |
| 122 | + for(final double gi: grad) { |
| 123 | + absGrad += Math.abs(gi); |
| 124 | + } |
| 125 | + if(Double.isInfinite(absGrad)||Double.isNaN(absGrad)||(absGrad<=1.0e-8)) { |
| 126 | + break out; |
| 127 | + } |
| 128 | + try { |
| 129 | +// // neaten up system a touch before solving |
| 130 | +// double totAbs = 0.0; |
| 131 | +// for(int i=0;i<dim;++i) { |
| 132 | +// for(int j=0;j<dim;++j) { |
| 133 | +// totAbs += Math.abs(hessian.get(i,j)); |
| 134 | +// } |
| 135 | +// } |
| 136 | +// if(Double.isInfinite(totAbs)||Double.isNaN(totAbs)||(totAbs<=1.0e-8)) { |
| 137 | +// break out; |
| 138 | +// } |
| 139 | +// final double scale = (dim*dim)/totAbs; |
| 140 | +// for(int i=0;i<dim;++i) { |
| 141 | +// grad[i] *= scale; |
| 142 | +// for(int j=0;j<dim;++j) { |
| 143 | +// hessian.set(i,j,hessian.get(i,j)*scale); |
| 144 | +// } |
| 145 | +// } |
| 146 | +// for(int i=0;i<dim;++i) { |
| 147 | +// hessian.set(i,i,hessian.get(i,i)+1.e-5); // Ridge term |
| 148 | +// } |
| 149 | + final double[] delta = hessian.solve(grad); |
| 150 | + for(final double di: delta) { |
| 151 | + if(Double.isNaN(di)||Double.isNaN(di)) { |
| 152 | + break out; |
| 153 | + } |
| 154 | + } |
| 155 | + double deltaAbs = 0.0; |
| 156 | + for(int i=0;i<dim;++i) { |
| 157 | + beta[i] -= delta[i]; |
| 158 | + deltaAbs += Math.abs(delta[i]); |
| 159 | + } |
| 160 | + if(deltaAbs<=1.0e-7) { |
| 161 | + break out; |
| 162 | + } |
| 163 | + } catch (Exception ex) { |
| 164 | + break out; |
| 165 | + } |
| 166 | + } |
| 167 | + return bestBeta; |
| 168 | + } |
| 169 | + |
| 170 | + @Override |
| 171 | + public double predict(final double[] soln, final double[] x) { |
| 172 | + final int n = obs.get(0).x.length; |
| 173 | + if((n!=x.length)||(n+1!=soln.length)) { |
| 174 | + throw new IllegalArgumentException(); |
| 175 | + } |
| 176 | + return Math.exp(dot(soln,x)); |
| 177 | + } |
| 178 | +} |
0 commit comments