Skip to content

Commit 65bcf30

Browse files
committed
work family of fitting examples
1 parent 5bbc118 commit 65bcf30

File tree

11 files changed

+346
-8
lines changed

11 files changed

+346
-8
lines changed

Count/expFit.tsv

Lines changed: 201 additions & 0 deletions
Large diffs are not rendered by default.

Count/src/com/mzlabs/fit/BalanceBasedJacobian.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ public final class BalanceBasedJacobian implements BalanceJacobianCalc {
1515
/**
 * Builds a calculator around the supplied link.
 *
 * @param link the link object kept in this instance's {@code link} field
 */
public BalanceBasedJacobian(final Link link) {
	this.link = link;
}
18+
19+
@Override
20+
public String toString() {
21+
return link.toString();
22+
}
1823

1924
@Override
2025
public BalanceJacobianCoef calc(final Obs obs, final double[] beta) {

Count/src/com/mzlabs/fit/DirectPoissonJacobian.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ public BalanceJacobianCoef calc(Obs obs, double[] beta) {
1818
return new BalanceJacobianCoef(balCoef,jacCoef);
1919
}
2020

21+
@Override
22+
public String toString() {
23+
return "PoissonRegression(log-link)";
24+
}
25+
2126
@Override
2227
public double evalEst(double[] beta, double[] x) {
2328
return Math.exp(Obs.dot(beta,x));

Count/src/com/mzlabs/fit/ExpectationAndSqLoss.java

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package com.mzlabs.fit;
22

3+
import java.util.ArrayList;
34
import java.util.Arrays;
5+
import java.util.Random;
46

7+
import com.winvector.linalg.LinalgFactory;
58
import com.winvector.linalg.colt.ColtMatrix;
69

710
public final class ExpectationAndSqLoss implements VectorFnWithJacobian {
@@ -24,7 +27,7 @@ public static double dotNE(double[] beta, double[] x) {
2427
}
2528

2629
@Override
27-
public double evalEst(double[] beta, double[] x) {
30+
public double evalEst(final double[] beta, final double[] x) {
2831
return link.invLink(dotNE(beta,x));
2932
}
3033

@@ -86,4 +89,104 @@ public void balanceAndJacobian(final Iterable<Obs> obs, final double[] beta,
8689
public int dim(final Obs obs) {
8790
return obs.x.length + 1;
8891
}
92+
93+
@Override
94+
public String toString() {
95+
return "ExpectationAndSquareLoss(" + link + ")";
96+
}
97+
98+
public static void main(final String[] args) {
99+
// build some example data
100+
final Random rand = new Random(343406L);
101+
final ArrayList<Obs> obs = new ArrayList<Obs>();
102+
for(int i=0;i<200;++i) {
103+
final double[] x = new double[] {1, rand.nextGaussian(), rand.nextGaussian()};
104+
final double y = Math.max(0.1,Math.exp(x[0] + 2.0+x[1] + 3.0+x[2]) + 10*rand.nextDouble());
105+
obs.add(new Obs(x,y,1.0));
106+
}
107+
// provision fitters
108+
final ExpectationAndSqLoss fn = new ExpectationAndSqLoss(LinkBasedGradHess.logLink);
109+
final int elIndex = 2;
110+
final NewtonFitter[] fitters = {
111+
new NewtonFitter(new SquareLossOfExp()),
112+
new NewtonFitter(DirectPoissonJacobian.poissonLink),
113+
new NewtonFitter(fn) // elIndex == 2
114+
};
115+
final int dim = obs.get(0).x.length;
116+
final int nTrain = obs.size()/2;
117+
final LinearFitter lf = new LinearFitter(dim);
118+
// scan data
119+
for(int i=0;i<nTrain;++i) {
120+
final Obs obsi = obs.get(i);
121+
lf.addObservation(obsi.x,Math.log(obsi.y),1.0);
122+
for(final NewtonFitter fitter: fitters) {
123+
fitter.addObservation(obsi.x,obsi.y,1.0);
124+
}
125+
}
126+
// solve
127+
final double[] lfSoln = lf.solve();
128+
final double[][] fSoln = new double[fitters.length][];
129+
for(int j=0;j<fitters.length;++j) {
130+
fSoln[j] = fitters[j].solve();
131+
}
132+
// check balance condition
133+
final int pdim = fSoln[elIndex].length;
134+
final double[] balance = new double[pdim];
135+
final LinalgFactory<ColtMatrix> factory = ColtMatrix.factory;
136+
final ColtMatrix jacobian = factory.newMatrix(pdim,pdim,false);
137+
fn.balanceAndJacobian(obs.subList(0,nTrain), fSoln[elIndex], balance, jacobian);
138+
double balanceCheck = 0.0;
139+
for(int i=0;i<nTrain;++i) {
140+
final Obs obsi = obs.get(i);
141+
final double llfit = fn.evalEst(fSoln[elIndex],obsi.x);
142+
balanceCheck += obsi.wt*(obsi.y-llfit);
143+
}
144+
for(final double bi: balance) {
145+
if(Math.abs(bi)>1.0e-3) {
146+
throw new IllegalStateException("didn't balance");
147+
}
148+
}
149+
if(Math.abs(balanceCheck)>1.0e-4) {
150+
throw new IllegalStateException("didn't balance");
151+
}
152+
// print data and estimates
153+
for(int j=1;j<dim;++j) {
154+
System.out.print("x"+j + "\t");
155+
}
156+
System.out.print("y" + "\t" + "isTrain" + "\t" + "logYest");
157+
for(int j=0;j<fitters.length;++j) {
158+
System.out.print("\t" + fitters[j].fn);
159+
}
160+
System.out.println();
161+
for(int i=0;i<obs.size();++i) {
162+
final Obs obsi = obs.get(i);
163+
for(int j=1;j<dim;++j) {
164+
System.out.print(obsi.x[j] + "\t");
165+
}
166+
System.out.print(obsi.y + "\t" + (i<nTrain?1:0) + "\t" + Math.exp(lf.evalEst(lfSoln,obsi.x)));
167+
for(int j=0;j<fitters.length;++j) {
168+
System.out.print("\t" + fitters[j].evalEst(fSoln[j],obsi.x));
169+
}
170+
System.out.println();
171+
}
172+
/**
173+
R: steps
174+
175+
library(ggplot2)
176+
library(reshape2)
177+
d <- read.table('expFit.tsv',sep='\t',stringsAsFactors=FALSE,header=TRUE)
178+
ests <- c('logYest','SquareLossOfExp','GLM.PoissonRegression.log.link..','ExpectationAndSquareLoss.log.link.')
179+
dTrain <- subset(d,isTrain==1)
180+
dTest <- subset(d,isTrain==0)
181+
for(v in ests) {
182+
print(paste(v,sum(dTrain$y-dTrain[,v]),sum((dTrain$y-dTrain[,v])^2)))
183+
}
184+
for(v in ests) {
185+
print(paste(v,sum(dTest$y-dTest[,v]),sum((dTest$y-dTest[,v])^2)))
186+
}
187+
dplot <- melt(subset(d,isTrain==1),id.vars=c('x1','x2','isTrain','y'),variable.name='estimate')
188+
ggplot(data=dplot,aes(x=value,y=y,color=estimate,shape=estimate)) + geom_point()
189+
ggplot(data=dplot,aes(x=value,y=y,color=estimate,shape=estimate)) + geom_point() + scale_x_log10() + scale_y_log10()
190+
*/
191+
}
89192
}

Count/src/com/mzlabs/fit/Fitter.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ public interface Fitter {
88
* @param y
99
* @param wt weight of observation (set to 1.0 in many cases)
1010
*/
11-
public abstract void addObservation(final double[] x, final double y,
11+
void addObservation(final double[] x, final double y,
1212
final double wt);
1313

14-
public abstract double[] solve();
14+
double[] solve();
15+
16+
/**
 * Evaluates the fitted estimate for a single input row.
 *
 * @param beta solution vector, as returned by {@code solve()}
 * @param x input row to evaluate
 * @return the estimate for {@code x} under {@code beta}
 */
double evalEst(double[] beta, double[] x);
1517
}

Count/src/com/mzlabs/fit/GLMModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ public GLMModel(final BalanceJacobianCalc balanceJacobianCalc) {
1111
this.balanceJacobianCalc = balanceJacobianCalc;
1212
}
1313

14+
@Override
15+
public String toString() {
16+
return "GLM(" + balanceJacobianCalc.toString() + ")";
17+
}
18+
1419
@Override
1520
public void balanceAndJacobian(final Iterable<Obs> obs, final double[] beta,
1621
final double[] balance, final ColtMatrix jacobian) {

Count/src/com/mzlabs/fit/LinearFitter.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,8 @@ public double[] solve() {
6161
return soln;
6262
}
6363

64-
/* (non-Javadoc)
65-
* @see com.mzlabs.count.util.Fitter#predict(double[], double[])
66-
*/
67-
public double predict(final double[] soln, final double[] x) {
64+
@Override
65+
public double evalEst(final double[] soln, final double[] x) {
6866
final int n = soln.length;
6967
if((n!=x.length)||(n!=soln.length)) {
7068
throw new IllegalArgumentException();

Count/src/com/mzlabs/fit/LinkBasedGradHess.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ public LinkBasedGradHess(final ProbYGivenZ prob, final Link link) {
1515
this.link = link;
1616
}
1717

18+
@Override
19+
public String toString() {
20+
return link.toString();
21+
}
22+
1823
private static double sq(final double z) {
1924
return z*z;
2025
}
@@ -77,6 +82,10 @@ public double invLink(final double z) {
7782
public double heuristicLink(double y) {
7883
return Math.log(Math.abs(y)+1.0); // near log(y), but well behaved
7984
}
85+
@Override
86+
public String toString() {
87+
return "log-link";
88+
}
8089
};
8190

8291

Count/src/com/mzlabs/fit/NewtonFitter.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,9 @@ public double[] solve() {
117117
}
118118
return beta;
119119
}
120+
121+
@Override
122+
public double evalEst(final double[] beta, final double[] x) {
123+
return fn.evalEst(beta, x);
124+
}
120125
}

Count/src/com/mzlabs/fit/SquareLossOfExp.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ public double evalEst(final double[] beta, final double[] x) {
1919
public double heuristicLink(final double y) {
2020
return Math.log(Math.abs(y)+1.0); // near log(y), but well behaved
2121
}
22+
23+
@Override
24+
public String toString() {
25+
return "SquareLossOfExp";
26+
}
2227

2328

2429
@Override

0 commit comments

Comments
 (0)