From 95fa87d135b6d37262671e779398f573dc63e95a Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Tue, 21 Jul 2015 21:58:44 -0400 Subject: [PATCH] Generify lsqr such that we can do distributed least squares --- src/lsqr.jl | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/lsqr.jl b/src/lsqr.jl index 392f54cb..73fcb247 100644 --- a/src/lsqr.jl +++ b/src/lsqr.jl @@ -59,13 +59,13 @@ function lsqr!(x, ch::ConvergenceHistory, A, b; damp=0, atol=sqrt(eps(Adivtype(A Anorm = Acond = ddnorm = res2 = xnorm = xxnorm = z = sn2 = zero(T) cs2 = -one(T) dampsq = damp*damp - tmpm = Array(T, m) - tmpn = Array(T, n) + tmpm = similar(b, T, m) + tmpn = similar(x, T, n) # Set up the first vectors u and v for the bidiagonalization. # These satisfy beta*u = b-A*x, alpha*v = A'u. - u = b-A*x - v = zeros(T, n) + u = b - A*x + v = copy(x) beta = norm(u) alpha = zero(T) if beta > 0 @@ -77,6 +77,7 @@ function lsqr!(x, ch::ConvergenceHistory, A, b; damp=0, atol=sqrt(eps(Adivtype(A scale!(v, one(T)/alpha) end w = copy(v) + wrho = similar(w) ch.mvps += 2 Arnorm = alpha*beta @@ -97,21 +98,24 @@ function lsqr!(x, ch::ConvergenceHistory, A, b; damp=0, atol=sqrt(eps(Adivtype(A # next beta, u, alpha, v. These satisfy the relations # beta*u = A*v - alpha*u, # alpha*v = A'*u - beta*v. + + # Note that the following three lines are a band aid for a GEMM: X: C := αAB + βC. + # This is already supported in A_mul_B! for sparse and distributed matrices, but not yet dense A_mul_B!(tmpm, A, v) - for i = 1:m - u[i] = tmpm[i] - alpha*u[i] - end + scale!(u, -alpha) + LinAlg.axpy!(one(eltype(tmpm)), tmpm, u) beta = norm(u) if beta > 0 scale!(u, one(T)/beta) Anorm = sqrt(Anorm*Anorm + alpha*alpha + beta*beta + dampsq) + # Note that the following three lines are a band aid for a GEMM: X: C := αA'B + βC. + # This is already supported in Ac_mul_B! for sparse and distributed matrices, but not yet dense Ac_mul_B!(tmpn, A, u) - for i = 1:n - v[i] = tmpn[i] - beta*v[i] - end + scale!(v, -beta) + LinAlg.axpy!(one(eltype(tmpn)), tmpn, v) alpha = norm(v) if alpha > 0 - for i = 1:n v[i] /= alpha; end + scale!(v, inv(alpha)) end end ch.mvps += 2 @@ -138,13 +142,13 @@ function lsqr!(x, ch::ConvergenceHistory, A, b; damp=0, atol=sqrt(eps(Adivtype(A # Update x and w t1 = phi /rho t2 = - theta/rho - for i = 1:n - wi = w[i] - x[i] += t1*wi - w[i] = v[i] + t2*wi - wirho = wi/rho - ddnorm += wirho*wirho - end + + LinAlg.axpy!(t1, w, x) + scale!(w, t2) + LinAlg.axpy!(one(t2), v, w) + copy!(wrho, w) + scale!(wrho, inv(rho)) + ddnorm += norm(wrho) # Use a plane rotation on the right to eliminate the # super-diagonal element (theta) of the upper-bidiagonal matrix.