# Direct TSQR
**Input**: Matrix $A$ of size $m\times n$ (with $m\gg n$)
1) **First step**: the matrix $A$ is split row-wise into $p$ sub-matrices of size $m_j\times n$ (with $m_j\ge n$ and $\sum_j m_j = m$). A *map* task performs a local QR decomposition on each sub-matrix, resulting in $\{(Q^{(1)}_1, R_1),(Q^{(1)}_2, R_2),\dots,(Q^{(1)}_p, R_p)\}$, where $Q_j^{(1)}\in\mathbb{R}^{m_j\times n}$ and $R_j\in\mathbb{R}^{n\times n}$.
2) **Second step**: a *reduce* task stacks the $R_j$ vertically into an $np\times n$ matrix. A second QR decomposition of this matrix returns $\tilde{Q}=\begin{bmatrix}Q^{(2)}_1\\ \vdots\\ Q^{(2)}_p\end{bmatrix} \in \mathbb{R}^{np\times n}$ (each block $Q_j^{(2)}$ is $n\times n$) and $\tilde{R}\in\mathbb{R}^{n\times n}$, which is the final $R$ factor.
3) **Third step**: a *map* task builds the final $Q$ matrix block by block using $Q_j=Q_j^{(1)}Q_j^{(2)}$; stacking the $Q_j$ vertically gives $Q\in\mathbb{R}^{m\times n}$.
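
Why the third step gives a valid factorization can be seen by writing the three steps as a single block product. The short derivation below uses only the symbols defined above, plus $A_j$ as a label (introduced here for illustration) for the $j$-th row block of $A$:

$$
A=\begin{bmatrix}A_1\\ \vdots\\ A_p\end{bmatrix}
=\begin{bmatrix}Q^{(1)}_1R_1\\ \vdots\\ Q^{(1)}_pR_p\end{bmatrix}
=\operatorname{diag}\!\left(Q^{(1)}_1,\dots,Q^{(1)}_p\right)\begin{bmatrix}R_1\\ \vdots\\ R_p\end{bmatrix}
=\operatorname{diag}\!\left(Q^{(1)}_1,\dots,Q^{(1)}_p\right)\tilde{Q}\,\tilde{R}
=\begin{bmatrix}Q^{(1)}_1Q^{(2)}_1\\ \vdots\\ Q^{(1)}_pQ^{(2)}_p\end{bmatrix}\tilde{R}
=Q\,\tilde{R}
$$

The columns of $Q$ are orthonormal because each $Q^{(1)}_j$ and $\tilde{Q}$ have orthonormal columns:

$$
Q^TQ=\sum_{j=1}^{p}\left(Q^{(2)}_j\right)^T\left(Q^{(1)}_j\right)^TQ^{(1)}_jQ^{(2)}_j
=\sum_{j=1}^{p}\left(Q^{(2)}_j\right)^TQ^{(2)}_j
=\tilde{Q}^T\tilde{Q}=I_n
$$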


In [189]:
from dask.distributed import Client, wait
from dask import delayed, compute
import dask.array as da

from numpy.linalg import qr
import numpy as np

import time

N_WORKERS = 8
THREADS_PER_WORKER = 1
MEMORY_PER_WORKER = "1.5GB"
client = Client(n_workers=N_WORKERS, 
                threads_per_worker=THREADS_PER_WORKER, 
                memory_limit=MEMORY_PER_WORKER)

print(f"DASK Client with {N_WORKERS} workers ({THREADS_PER_WORKER} threads, {MEMORY_PER_WORKER} memory)")
print("DASK Dashboard link:", client.dashboard_link)

DASK Client with 8 workers (1 threads, 1.5GB memory)
DASK Dashboard link: http://127.0.0.1:8787/status


In [178]:
# Step 1: (map) local QR on each block
# Each delayed task receives its chunk as a plain NumPy array
@delayed(nout=2)
def local_qr(A_block : np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Reduced QR: Q is (m_j, n) and R is (n, n) when m_j >= n
    Q, R = qr(A_block)
    return Q, R

# Step 2: (reduce) stack Rs and global QR
@delayed(nout=2)
def global_qr(Rs : list[np.ndarray]) -> tuple[list[np.ndarray], np.ndarray]: 
    R_stacked = np.vstack(Rs)
    Q2, R_final = qr(R_stacked)
    p, n = len(Rs), R_final.shape[1]
    # Split Q2 (np x n) into p blocks of size n x n, one per local R
    Q2_blocks = [Q2[i*n:(i+1)*n, :] for i in range(p)]
    return Q2_blocks, R_final

# Step 3: (map) building the final Q
@delayed
def block_matmul(Q1_block : np.ndarray, Q2_block : np.ndarray) -> np.ndarray:
    # Plain NumPy matmul: wrapping the blocks in a dask array inside a task
    # would defer the product instead of computing it on the worker
    return Q1_block @ Q2_block

def direct_tsqr(A : da.Array) -> tuple[delayed, delayed]:
    # Assumes A is chunked along rows only (each chunk spans all n columns)
    A_blocks = A.to_delayed().ravel().tolist()
    
    # Step 1: (map) perform QR decomposition in parallel on each block
    QR1 = [local_qr(A_block) for A_block in A_blocks]
    Q1s = [Q1 for Q1, _ in QR1]
    R1s = [R1 for _, R1 in QR1]

    # Step 2: (reduce) perform global QR decomposition
    Q2s, R_final = global_qr(R1s)

    # Step 3: (map) building the final Q by multiplying Qs blocks
    Qs = [block_matmul(Q1s[i], Q2s[i]) for i in range(len(A_blocks))]
    Q = delayed(np.vstack)(Qs)
    return Q, R_final
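
As a cross-check for the delayed graph above, the same three steps can be written serially in plain NumPy. This is only a minimal reference sketch, not part of the notebook: the name `tsqr_reference`, the use of `np.array_split` to form the row blocks, and the random test matrix are assumptions made for illustration. It also assumes every block has at least $n$ rows, so that each $R_j$ is $n\times n$.

```python
import numpy as np

def tsqr_reference(A, p):
    """Serial Direct TSQR on p row blocks (hypothetical helper for testing)."""
    n = A.shape[1]
    blocks = np.array_split(A, p, axis=0)

    # Step 1: local reduced QR on every block
    local = [np.linalg.qr(block) for block in blocks]

    # Step 2: stack the R factors (np x n) and factor them again
    R_stacked = np.vstack([R_j for _, R_j in local])
    Q2, R_final = np.linalg.qr(R_stacked)

    # Step 3: Q_j = Q1_j @ Q2_j, with Q2_j the j-th n x n block of Q2
    Q = np.vstack([Q1_j @ Q2[j * n:(j + 1) * n, :]
                   for j, (Q1_j, _) in enumerate(local)])
    return Q, R_final

rng = np.random.default_rng(0)
A_test = rng.standard_normal((1000, 6))   # synthetic tall-and-skinny matrix
Q_ref, R_ref = tsqr_reference(A_test, p=4)

print("Reconstruction:", np.allclose(Q_ref @ R_ref, A_test))
print("Orthogonality:", np.allclose(Q_ref.T @ Q_ref, np.eye(6)))
```

Running the same matrix through `direct_tsqr` and comparing against this reference is one way to separate errors in the algorithm from errors in the task graph.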

In [190]:
from sklearn.datasets import fetch_california_housing

#data = np.random.rand(int(1e7), 4)
data = fetch_california_housing(as_frame=True).data
m, n = data.shape

X_da = da.from_array(data.values, chunks=(m // N_WORKERS, n))
X_da

| | Array | Chunk |
|---|---|---|
| Bytes | 1.26 MiB | 161.25 kiB |
| Shape | (20640, 8) | (2580, 8) |
| Dask graph | 8 chunks in 1 graph layer | |
| Data type | float64 numpy.ndarray | |
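
Here $m = 20640$ is divisible by 8, so `m // N_WORKERS` gives exactly 8 equal chunks. When $m$ is not divisible, floor division leaves a small remainder chunk, which can end up with fewer than $n$ rows; its local $R_j$ would then no longer be $n\times n$, and the slicing of `Q2` into $n$-row blocks in `global_qr` would be misaligned. The sketch below illustrates this with two hypothetical helpers (`chunk_sizes` and `balanced_rows` are not from the notebook), assuming row chunks are laid out as full blocks followed by one remainder block.

```python
def chunk_sizes(m, rows_per_chunk):
    """Row-chunk sizes for m rows: full blocks plus one remainder block."""
    full, rest = divmod(m, rows_per_chunk)
    return (rows_per_chunk,) * full + ((rest,) if rest else ())

def balanced_rows(m, p):
    """Ceiling division, so that at most p chunks are produced."""
    return -(-m // p)

n = 8                                     # number of columns in the example

even = chunk_sizes(20640, 20640 // 8)     # the notebook's case: 8 equal chunks
floor_case = chunk_sizes(20645, 20645 // 8)
ceil_case = chunk_sizes(20645, balanced_rows(20645, 8))

print(len(even), min(even))
print(len(floor_case), min(floor_case))   # extra chunk with fewer than n rows
print(len(ceil_case), min(ceil_case))     # 8 chunks, all with at least n rows
```

With ceiling division the last chunk is slightly smaller than the others but still far taller than $n$, so every local QR keeps the expected shapes.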


In [186]:
%%time

# Parallel computation 
Q_delayed, R_delayed = direct_tsqr(X_da)
Q, R = compute(Q_delayed, R_delayed)

CPU times: user 29.9 ms, sys: 4.56 ms, total: 34.5 ms
Wall time: 56.2 ms


In [197]:
%%time

# Sequential computation 
Q_seq, R_seq = qr(data.values)

CPU times: user 53.7 ms, sys: 1.95 ms, total: 55.6 ms
Wall time: 7.8 ms


In [198]:
# Check the results
Q = np.array(Q)
R = np.array(R)
print("Reconstruction error check:", np.isclose(Q @ R, data, atol=1e-6).all())
print("Orthogonality check:", np.linalg.norm(Q.T @ Q - np.eye(n), ord='fro') < 1e-6)

Reconstruction error check: True
Orthogonality check: True
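
The two checks above confirm that $Q R$ reproduces the data and that $Q$ has orthonormal columns, but they do not compare the factors with the sequential result. For a full-rank matrix the reduced QR factorization is unique only up to the signs of the columns of $Q$ and the matching rows of $R$, so a direct `np.allclose(R, R_seq)` may fail even when both are correct. A minimal sketch of a sign normalization follows; `normalize_qr_signs` is a hypothetical helper, and the second factorization is built here by flipping signs of a NumPy result rather than by running TSQR.

```python
import numpy as np

def normalize_qr_signs(Q, R):
    """Flip signs so that diag(R) >= 0; the product Q @ R is unchanged."""
    signs = np.sign(np.diag(R))
    signs[signs == 0] = 1.0
    # Scale column j of Q and row j of R by the same sign
    return Q * signs, signs[:, None] * R

rng = np.random.default_rng(1)
A_demo = rng.standard_normal((200, 5))    # synthetic full-rank matrix

# Two valid factorizations of the same matrix that differ only by signs
Qa, Ra = np.linalg.qr(A_demo)
D = rng.choice([-1.0, 1.0], size=5)
Qb, Rb = Qa * D, D[:, None] * Ra

Qa_n, Ra_n = normalize_qr_signs(Qa, Ra)
Qb_n, Rb_n = normalize_qr_signs(Qb, Rb)

print("Same R after normalization:", np.allclose(Ra_n, Rb_n))
print("Same Q after normalization:", np.allclose(Qa_n, Qb_n))
```

Applying the same normalization to `(Q, R)` and `(Q_seq, R_seq)` would allow an element-wise comparison of the parallel and sequential factors.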


In [188]:
client.close()