@@ -19,24 +19,24 @@ public final class LinearFitter implements Fitter {
1919 */
2020 public LinearFitter (final int n ) {
2121 final LinalgFactory <ColtMatrix > factory = ColtMatrix .factory ;
22- xTx = factory .newMatrix (n + 1 , n + 1 ,false );
23- xTy = new double [n + 1 ];
22+ xTx = factory .newMatrix (n , n ,false );
23+ xTy = new double [n ];
2424 }
2525
2626 /* (non-Javadoc)
2727 * @see com.mzlabs.count.util.Fitter#addObservation(double[], double, double)
2828 */
2929 @ Override
3030 public void addObservation (final double [] x , final double y , final double wt ) {
31- final int n = xTx .rows ()- 1 ;
31+ final int n = xTx .rows ();
3232 if (n !=x .length ) {
3333 throw new IllegalArgumentException ();
3434 }
35- for (int i =0 ;i <= n ;++i ) {
36- final double xi = i < n ? x [i ]: 1.0 ;
35+ for (int i =0 ;i <n ;++i ) {
36+ final double xi = x [i ];
3737 xTy [i ] += wt *xi *y ;
38- for (int j =0 ;j <= n ;++j ) {
39- final double xj = j < n ? x [j ]: 1.0 ;
38+ for (int j =0 ;j <n ;++j ) {
39+ final double xj = x [j ];
4040 xTx .set (i ,j ,xTx .get (i , j )+wt *xi *xj );
4141 }
4242 }
@@ -47,15 +47,15 @@ public void addObservation(final double[] x, final double y, final double wt) {
4747 */
4848 @ Override
4949 public double [] solve () {
50- final int n = xTx .rows ()- 1 ;
50+ final int n = xTx .rows ();
5151 final double epsilon = 1.0e-5 ;
5252 final double [] xTxii = new double [n +1 ];
53- for (int i =0 ;i <= n ;++i ) {
53+ for (int i =0 ;i <n ;++i ) {
5454 xTxii [i ] = xTx .get (i ,i );
5555 xTx .set (i ,i ,xTxii [i ]+epsilon ); // Ridge term
5656 }
5757 final double [] soln = xTx .solve (xTy );
58- for (int i =0 ;i <= n ;++i ) {
58+ for (int i =0 ;i <n ;++i ) {
5959 xTx .set (i ,i ,xTxii [i ]);
6060 }
6161 return soln ;
@@ -65,14 +65,13 @@ public double[] solve() {
6565 * @see com.mzlabs.count.util.Fitter#predict(double[], double[])
6666 */
6767 public double predict (final double [] soln , final double [] x ) {
68- final int n = soln .length - 1 ;
69- if ((n !=x .length )||(n + 1 !=soln .length )) {
68+ final int n = soln .length ;
69+ if ((n !=x .length )||(n !=soln .length )) {
7070 throw new IllegalArgumentException ();
7171 }
7272 double sum = 0.0 ;
73- for (int i =0 ;i <=n ;++i ) {
74- final double xi = i <n ?x [i ]:1.0 ;
75- sum += xi *soln [i ];
73+ for (int i =0 ;i <n ;++i ) {
74+ sum += x [i ]*soln [i ];
7675 }
7776 return sum ;
7877 }
0 commit comments