# Direct TSQR
**Input**: Matrix $A$ of size $m\times n$ (with $m\gg n$)
1) **First step**: The matrix $A$ is subdivided row-wise into $p$ sub-matrices of size $m_j\times n$ (with $\sum_j m_j = m$ and $m_j\geq n$), and a *map* task performs a local QR decomposition on each sub-matrix, resulting in $\{(Q^{(1)}_1, R_1),(Q^{(1)}_2, R_2),...,(Q^{(1)}_p, R_p)\}$, where $Q_j^{(1)}\in\mathbb{R}^{m_j\times n}$ and $R_j\in\mathbb{R}^{n\times n}$.
2) **Second step**: a *reduce* task collects the $R_j$ factors and stacks them vertically into an $np\times n$ matrix. A second QR decomposition of this stacked matrix returns $\tilde{Q}=\begin{bmatrix}Q^{(2)}_1\\ \vdots\\ Q^{(2)}_p\end{bmatrix} \in \mathbb{R}^{np\times n}$ (each $Q_j^{(2)}$ is $n\times n$) and $\tilde{R}\in\mathbb{R}^{n\times n}$, which is the final $R$ factor of $A$.
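3) **Third step**: a *map* task builds the blocks of the final $Q$ matrix using $Q_j=Q_j^{(1)}Q_j^{(2)}$; stacking the $Q_j$ vertically gives $Q\in\mathbb{R}^{m\times n}$, so that $A=Q\tilde{R}$.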


In [1]:
from dask.distributed import Client, wait
from dask import delayed, compute
import dask.array as da

from numpy.linalg import qr
import numpy as np

import time

# Cluster configuration used by every cell below.
N_WORKERS = 4
THREADS_PER_WORKER = 1
MEMORY_PER_WORKER = "1GB"

# Create the Dask client with the settings above.
client = Client(
    n_workers=N_WORKERS,
    threads_per_worker=THREADS_PER_WORKER,
    memory_limit=MEMORY_PER_WORKER,
)

# Report the configuration and where to find the diagnostic dashboard.
print(f"DASK Client with {N_WORKERS} workers ({THREADS_PER_WORKER} threads, {MEMORY_PER_WORKER} memory)")
print("DASK Dashboard link:", client.dashboard_link)

DASK Client with 4 workers (1 threads, 1GB memory)
DASK Dashboard link: http://127.0.0.1:8787/status


In [2]:
# Step 1: (map) local QR on each worker
def local_qr(A_block: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Factor one row block of the input matrix with ``numpy.linalg.qr``.

    Parameters
    ----------
    A_block : np.ndarray
        The sub-matrix to factor.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The ``(Q, R)`` pair produced by ``qr`` with its default settings.
    """
    factors = qr(A_block)
    # Return a plain 2-tuple regardless of the container type qr() uses.
    return factors[0], factors[1]

# Step 2: (reduce) stack Rs and global QR
def global_qr(Rs: list[np.ndarray]) -> tuple[list[np.ndarray], np.ndarray]:
    """Stack the local R factors, factor the stack, and split the new Q per block.

    Parameters
    ----------
    Rs : list[np.ndarray]
        The 2-D R factors of the local QR decompositions, in block order.
        All must have the same number of columns. A block normally has as
        many rows as columns, but may have fewer when its source sub-matrix
        had fewer rows than columns.

    Returns
    -------
    Q2_blocks : list[np.ndarray]
        The Q factor of the stacked matrix, cut row-wise so that
        ``Q2_blocks[j]`` has exactly as many rows as ``Rs[j]``.
    R_final : np.ndarray
        The R factor of the stacked matrix.

    Raises
    ------
    ValueError
        If ``Rs`` is empty or the blocks cannot be stacked (raised by
        ``np.vstack``).
    """
    R_stacked = np.vstack(Rs)
    Q2, R_final = qr(R_stacked)

    # Cut Q2 at the real block boundaries instead of assuming every R block
    # has exactly n rows: a short block would otherwise shift every
    # subsequent slice and pair it with the wrong local Q.
    row_counts = [R_block.shape[0] for R_block in Rs]
    boundaries = np.cumsum(row_counts)[:-1].tolist()
    Q2_blocks = np.split(Q2, boundaries, axis=0)
    return Q2_blocks, R_final

# Step 3: (map) building the final Q
def block_matmul(Q1_block: np.ndarray, Q2_block: np.ndarray) -> np.ndarray:
    """Multiply a first-step Q block by its matching second-step Q block.

    Parameters
    ----------
    Q1_block : np.ndarray
        Left operand of the matrix product.
    Q2_block : np.ndarray
        Right operand of the matrix product.

    Returns
    -------
    np.ndarray
        The matrix product ``Q1_block @ Q2_block``.
    """
    product = np.matmul(Q1_block, Q2_block)
    return product


def direct_tsqr(
    A: da.Array,
    client: Client,
    mode: "str | None" = None,
) -> "tuple[np.ndarray | None, np.ndarray]":
    """Compute a QR factorisation of a tall-and-skinny Dask array (Direct TSQR).

    Each row block of ``A`` is factored on the workers, the resulting R
    factors are gathered on the client and factored once more, and the final
    Q is built on the workers by multiplying the first- and second-step Q
    blocks, then gathered and stacked on the client.

    Parameters
    ----------
    A : da.Array
        The matrix to factor. Its row chunks define the blocks; if it is
        split along the columns as well, the column chunks are merged first.
    client : Client
        Distributed client used to submit the tasks and move data.
    mode : str or None, optional
        ``"r"`` skips the third step and returns only the R factor. Any
        other value (default ``None``) also computes Q.

    Returns
    -------
    Q : np.ndarray or None
        The stacked Q factor as a NumPy array on the client, or ``None``
        when ``mode == "r"``.
    R : np.ndarray
        The R factor returned by the second (global) QR decomposition.
    """
    # Every block must span all columns; otherwise ravel() below would mix
    # column chunks in with the row blocks and the R factors would be wrong.
    if A.ndim == 2 and A.numblocks[1] != 1:
        A = A.rechunk({1: -1})

    A = A.persist()
    A_blocks_futures = client.compute(A.to_delayed().ravel().tolist())
    wait(A)

    # Step 1: (map) perform QR decomposition in parallel on each block
    QR1 = client.map(local_qr, A_blocks_futures)
    wait(QR1)
    # Drop the references to the input blocks so the cluster can free them.
    del A_blocks_futures, A

    # Split each (Q, R) pair into separate futures; the data stays on the workers.
    Q1s_future = client.map(lambda pair: pair[0], QR1)
    R1s_future = client.map(lambda pair: pair[1], QR1)
    wait(R1s_future)

    # Step 2: (reduce) perform global QR decomposition on the client
    R1s = client.gather(R1s_future)
    Q2s, R_final = global_qr(R1s)
    del R1s_future, R1s

    # Early return for R only mode
    if mode == "r":
        return None, R_final

    # Step 3: (map) building the final Q by multiplying Qs blocks
    Q2s_futures = client.scatter(Q2s)   # send the Q2 blocks to the workers
    wait(Q2s_futures)
    del Q2s

    Qs_future = client.map(block_matmul, Q1s_future, Q2s_futures)
    wait(Qs_future)
    del Q1s_future, Q2s_futures, QR1

    # The gathered blocks are NumPy arrays, so stacking them already yields
    # the final result; no further Dask computation is needed.
    Q = np.vstack(client.gather(Qs_future))
    del Qs_future
    return Q, R_final

In [3]:
# Synthetic tall-and-skinny test matrix: 4,000,000 rows by 4 columns.
data = np.random.rand(4_000_000, 4)

# Alternative dataset (uncomment both lines to use it instead):
#from sklearn.datasets import fetch_california_housing
#data = fetch_california_housing(as_frame=True).data.values

m = data.shape[0]
n = data.shape[1]

# Chunk along the rows only: m // N_WORKERS rows per chunk, all n columns.
row_chunk = m // N_WORKERS
X_da = da.from_array(data, chunks=(row_chunk, n))
X_da

Unnamed: 0,Array,Chunk
Bytes,122.07 MiB,30.52 MiB
Shape,"(4000000, 4)","(1000000, 4)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 122.07 MiB 30.52 MiB Shape (4000000, 4) (1000000, 4) Dask graph 4 chunks in 1 graph layer Data type float64 numpy.ndarray",4  4000000,

Unnamed: 0,Array,Chunk
Bytes,122.07 MiB,30.52 MiB
Shape,"(4000000, 4)","(1000000, 4)"
Dask graph,4 chunks in 1 graph layer,4 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [7]:
%%time

# Parallel computation
# Run Direct TSQR on the chunked Dask array. `mode` is left at its default,
# so both the Q and the R factors are computed and returned.
Q, R = direct_tsqr(X_da, client)

CPU times: user 58.8 ms, sys: 92.2 ms, total: 151 ms
Wall time: 284 ms


In [6]:
%%time

# Sequential computation
# Reference run: factor the full in-memory NumPy array with a single call to
# numpy.linalg.qr, for comparison with the Dask version.
Q_seq, R_seq = qr(data)

CPU times: user 195 ms, sys: 110 ms, total: 305 ms
Wall time: 292 ms


In [8]:
# Check the results
# Check the results
# 1) Q @ R must reproduce the input matrix element-wise (absolute tolerance 1e-6).
reconstruction_ok = np.isclose(Q @ R, data, atol=1e-6).all()
print("Reconstruction error check:", reconstruction_ok)

# 2) The columns of Q must be orthonormal: Q^T Q should equal the identity.
gram_residual = Q.T @ Q - np.eye(n)
orthogonality_ok = np.linalg.norm(gram_residual, ord='fro') < 1e-6
print("Orthogonality check:", orthogonality_ok)

Reconstruction error check: True
Orthogonality check: True


In [55]:
# Close the Dask client now that the experiment is finished.
client.close()