11package com .mzlabs .fit ;
22
3+ import java .util .ArrayList ;
34import java .util .Arrays ;
5+ import java .util .Random ;
46
7+ import com .winvector .linalg .LinalgFactory ;
58import com .winvector .linalg .colt .ColtMatrix ;
69
710public final class ExpectationAndSqLoss implements VectorFnWithJacobian {
@@ -24,7 +27,7 @@ public static double dotNE(double[] beta, double[] x) {
2427 }
2528
2629 @ Override
27- public double evalEst (double [] beta , double [] x ) {
30+ public double evalEst (final double [] beta , final double [] x ) {
2831 return link .invLink (dotNE (beta ,x ));
2932 }
3033
@@ -86,4 +89,104 @@ public void balanceAndJacobian(final Iterable<Obs> obs, final double[] beta,
8689 public int dim (final Obs obs ) {
8790 return obs .x .length + 1 ;
8891 }
92+
93+ @ Override
94+ public String toString () {
95+ return "ExpectationAndSquareLoss(" + link + ")" ;
96+ }
97+
98+ public static void main (final String [] args ) {
99+ // build some example data
100+ final Random rand = new Random (343406L );
101+ final ArrayList <Obs > obs = new ArrayList <Obs >();
102+ for (int i =0 ;i <200 ;++i ) {
103+ final double [] x = new double [] {1 , rand .nextGaussian (), rand .nextGaussian ()};
104+ final double y = Math .max (0.1 ,Math .exp (x [0 ] + 2.0 +x [1 ] + 3.0 +x [2 ]) + 10 *rand .nextDouble ());
105+ obs .add (new Obs (x ,y ,1.0 ));
106+ }
107+ // provision fitters
108+ final ExpectationAndSqLoss fn = new ExpectationAndSqLoss (LinkBasedGradHess .logLink );
109+ final int elIndex = 2 ;
110+ final NewtonFitter [] fitters = {
111+ new NewtonFitter (new SquareLossOfExp ()),
112+ new NewtonFitter (DirectPoissonJacobian .poissonLink ),
113+ new NewtonFitter (fn ) // elIndex == 2
114+ };
115+ final int dim = obs .get (0 ).x .length ;
116+ final int nTrain = obs .size ()/2 ;
117+ final LinearFitter lf = new LinearFitter (dim );
118+ // scan data
119+ for (int i =0 ;i <nTrain ;++i ) {
120+ final Obs obsi = obs .get (i );
121+ lf .addObservation (obsi .x ,Math .log (obsi .y ),1.0 );
122+ for (final NewtonFitter fitter : fitters ) {
123+ fitter .addObservation (obsi .x ,obsi .y ,1.0 );
124+ }
125+ }
126+ // solve
127+ final double [] lfSoln = lf .solve ();
128+ final double [][] fSoln = new double [fitters .length ][];
129+ for (int j =0 ;j <fitters .length ;++j ) {
130+ fSoln [j ] = fitters [j ].solve ();
131+ }
132+ // check balance condition
133+ final int pdim = fSoln [elIndex ].length ;
134+ final double [] balance = new double [pdim ];
135+ final LinalgFactory <ColtMatrix > factory = ColtMatrix .factory ;
136+ final ColtMatrix jacobian = factory .newMatrix (pdim ,pdim ,false );
137+ fn .balanceAndJacobian (obs .subList (0 ,nTrain ), fSoln [elIndex ], balance , jacobian );
138+ double balanceCheck = 0.0 ;
139+ for (int i =0 ;i <nTrain ;++i ) {
140+ final Obs obsi = obs .get (i );
141+ final double llfit = fn .evalEst (fSoln [elIndex ],obsi .x );
142+ balanceCheck += obsi .wt *(obsi .y -llfit );
143+ }
144+ for (final double bi : balance ) {
145+ if (Math .abs (bi )>1.0e-3 ) {
146+ throw new IllegalStateException ("didn't balance" );
147+ }
148+ }
149+ if (Math .abs (balanceCheck )>1.0e-4 ) {
150+ throw new IllegalStateException ("didn't balance" );
151+ }
152+ // print data and estimates
153+ for (int j =1 ;j <dim ;++j ) {
154+ System .out .print ("x" +j + "\t " );
155+ }
156+ System .out .print ("y" + "\t " + "isTrain" + "\t " + "logYest" );
157+ for (int j =0 ;j <fitters .length ;++j ) {
158+ System .out .print ("\t " + fitters [j ].fn );
159+ }
160+ System .out .println ();
161+ for (int i =0 ;i <obs .size ();++i ) {
162+ final Obs obsi = obs .get (i );
163+ for (int j =1 ;j <dim ;++j ) {
164+ System .out .print (obsi .x [j ] + "\t " );
165+ }
166+ System .out .print (obsi .y + "\t " + (i <nTrain ?1 :0 ) + "\t " + Math .exp (lf .evalEst (lfSoln ,obsi .x )));
167+ for (int j =0 ;j <fitters .length ;++j ) {
168+ System .out .print ("\t " + fitters [j ].evalEst (fSoln [j ],obsi .x ));
169+ }
170+ System .out .println ();
171+ }
172+ /**
173+ R: steps
174+
175+ library(ggplot2)
176+ library(reshape2)
177+ d <- read.table('expFit.tsv',sep='\t',stringsAsFactors=FALSE,header=TRUE)
178+ ests <- c('logYest','SquareLossOfExp','GLM.PoissonRegression.log.link..','ExpectationAndSquareLoss.log.link.')
179+ dTrain <- subset(d,isTrain==1)
180+ dTest <- subset(d,isTrain==0)
181+ for(v in ests) {
182+ print(paste(v,sum(dTrain$y-dTrain[,v]),sum((dTrain$y-dTrain[,v])^2)))
183+ }
184+ for(v in ests) {
185+ print(paste(v,sum(dTest$y-dTest[,v]),sum((dTest$y-dTest[,v])^2)))
186+ }
187+ dplot <- melt(subset(d,isTrain==1),id.vars=c('x1','x2','isTrain','y'),variable.name='estimate')
188+ ggplot(data=dplot,aes(x=value,y=y,color=estimate,shape=estimate)) + geom_point()
189+ ggplot(data=dplot,aes(x=value,y=y,color=estimate,shape=estimate)) + geom_point() + scale_x_log10() + scale_y_log10()
190+ */
191+ }
89192}
0 commit comments