# 03: QR Decomposition

**Module 1.2: Linear Systems & Least Squares**

## Learning Objectives

By the end of this notebook, you will:
1. Understand what QR decomposition is
2. See why QR is more stable than normal equations
3. Solve least squares using QR
4. Use PyTorch's built-in QR functions

## Resources
- Solomon, *Numerical Algorithms*, Chapter 5
- Cohen, *Practical Linear Algebra*, Chapter 8

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
plt.rcParams['figure.figsize'] = (10, 6)

---
## 1. What is QR Decomposition?

Any matrix $X \in \mathbb{R}^{m \times n}$ (with $m \geq n$) can be factored as:

$$X = QR$$

Where:
- $Q \in \mathbb{R}^{m \times n}$ has **orthonormal columns**: $Q^\top Q = I$
- $R \in \mathbb{R}^{n \times n}$ is **upper triangular**
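
As a small worked instance of this definition, consider the following $2 \times 2$ matrix. It is chosen here purely for illustration because its factors have simple entries; it is not used elsewhere in the notebook.

$$\begin{bmatrix} 3 & 1 \\ 4 & 2 \end{bmatrix} = \underbrace{\begin{bmatrix} 0.6 & -0.8 \\ 0.8 & 0.6 \end{bmatrix}}_{Q} \underbrace{\begin{bmatrix} 5 & 2.2 \\ 0 & 0.4 \end{bmatrix}}_{R}$$

Both columns of $Q$ have unit length ($0.6^2 + 0.8^2 = 1$) and their dot product is $0.6 \cdot (-0.8) + 0.8 \cdot 0.6 = 0$, so $Q^\top Q = I$. Multiplying out recovers the original entries, for example $0.6 \cdot 2.2 - 0.8 \cdot 0.4 = 1$. A library routine may return the same factorization with the sign of a column of $Q$ and the matching row of $R$ flipped; both versions are valid.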

### Naming Convention
- **Q**: the orthogonal factor (its columns are orthonormal)
- **R**: the "right triangular" factor (upper triangular)

In [None]:
# Basic QR decomposition
X = torch.tensor([[1., 2.],
                  [3., 4.],
                  [5., 6.]], dtype=torch.float64)

Q, R = torch.linalg.qr(X)

print("Original X:")
print(X)
print(f"\nShape: {X.shape}")

print("\nQ (orthonormal columns):")
print(Q.round(decimals=4))
print(f"Shape: {Q.shape}")

print("\nR (upper triangular):")
print(R.round(decimals=4))
print(f"Shape: {R.shape}")

# Verify Q'Q = I
print("\nQ'Q (should be identity):")
print((Q.T @ Q).round(decimals=10))

# Verify QR = X
print("\nQR (should equal X):")
print((Q @ R).round(decimals=10))

---
## 2. Why QR for Least Squares?

### Normal Equations (Unstable)

$$\hat{\beta} = (X^\top X)^{-1} X^\top y$$

Problem: forming $X^\top X$ squares the condition number, $\kappa(X^\top X) = \kappa(X)^2$, so roughly twice as many significant digits can be lost to rounding.

### QR Approach (Stable)

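Least squares minimizes $\|X\beta - y\|_2$; in general $X\beta = y$ has no exact solution. Substitute $X = QR$ and multiply by $Q^\top$ (using $Q^\top Q = I$), which keeps only the part of $y$ that lies in the column space of $X$:

$$QR\beta \approx y$$
$$R\beta = Q^\top y$$

Solve by back-substitution, since $R$ is upper triangular.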

**Key**: We never form $X^\top X$, so we keep $\kappa(X)$ instead of $\kappa(X)^2$!
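
The loss of information from forming $X^\top X$ can be seen without any library. The sketch below uses a classic textbook matrix (often attributed to Läuchli) with an illustrative value `eps = 1e-8`; the helper `gram` is a hypothetical name introduced only for this example. Here $\kappa(X) \approx \sqrt{2}/\varepsilon \approx 1.4 \times 10^{8}$, which float64 handles comfortably, while $\kappa(X)^2 \approx 2 \times 10^{16}$ exceeds the reciprocal of machine epsilon (about $4.5 \times 10^{15}$).

```python
# Plain-Python illustration of precision loss when forming X^T X.
# eps is an illustrative value chosen so that eps**2 falls below float64 resolution.
eps = 1e-8

# X has two linearly independent columns: (1, eps, 0) and (1, 0, eps)
X = [[1.0, 1.0],
     [eps, 0.0],
     [0.0, eps]]

def gram(X):
    """Return X^T X as a nested list (hypothetical helper for this sketch)."""
    n = len(X[0])
    return [[sum(row[i] * row[j] for row in X) for j in range(n)]
            for i in range(n)]

XtX = gram(X)

# The exact diagonal entries are 1 + eps**2, but that rounds to 1.0 in float64
det = XtX[0][0] * XtX[1][1] - XtX[0][1] * XtX[1][0]

print("X^T X      =", XtX)
print("det(X^T X) =", det)
```

Although $X$ has full column rank, the computed $X^\top X$ is exactly singular, so the normal equations cannot be solved reliably. A QR-based solve works with $X$ directly and never performs this rounding step.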

In [None]:
def solve_least_squares_normal(X, y):
    """Solve via normal equations (unstable)."""
    XtX = X.T @ X
    Xty = X.T @ y
    return torch.linalg.solve(XtX, Xty)

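def solve_least_squares_qr(X, y):
    """Solve via QR decomposition (stable)."""
    Q, R = torch.linalg.qr(X)
    Qty = Q.T @ y
    # solve_triangular expects a 2-D right-hand side, so add a column axis and drop it afterwards
    return torch.linalg.solve_triangular(R, Qty.unsqueeze(-1), upper=True).squeeze(-1)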

# Well-conditioned problem: both methods work
X = torch.tensor([[1., 0.],
                  [1., 1.],
                  [1., 2.],
                  [1., 3.]], dtype=torch.float64)
y = torch.tensor([1., 2., 2., 4.], dtype=torch.float64)

beta_normal = solve_least_squares_normal(X, y)
beta_qr = solve_least_squares_qr(X, y)

print("Well-conditioned problem:")
print(f"  Normal equations: {beta_normal.numpy().round(6)}")
print(f"  QR decomposition: {beta_qr.numpy().round(6)}")
print(f"  Difference: {(beta_normal - beta_qr).abs().max():.2e}")

In [None]:
# Ill-conditioned problem: normal equations fail
def create_ill_conditioned_problem(kappa):
    """Create a least squares problem with specified condition number."""
    m, n = 100, 5
    
    # Create X with known condition number
    U, _ = torch.linalg.qr(torch.randn(m, n, dtype=torch.float64))
    V, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
    S = torch.logspace(0, -np.log10(kappa), n, dtype=torch.float64)
    X = U @ torch.diag(S) @ V.T
    
    # True solution and noise-free response. Added noise would be amplified by
    # roughly κ(X) in both solvers and would mask the rounding error compared here.
    beta_true = torch.randn(n, dtype=torch.float64)
    y = X @ beta_true
    
    return X, y, beta_true

print("Comparison on ill-conditioned problems:")
print(f"{'κ(X)':<12} {'κ(XᵀX)':<15} {'Error (Normal)':<18} {'Error (QR)':<15}")
print("-" * 60)

for kappa in [1e2, 1e4, 1e6, 1e8]:
    X, y, beta_true = create_ill_conditioned_problem(kappa)
    
    beta_normal = solve_least_squares_normal(X, y)
    beta_qr = solve_least_squares_qr(X, y)
    
    error_normal = (beta_normal - beta_true).norm() / beta_true.norm()
    error_qr = (beta_qr - beta_true).norm() / beta_true.norm()
    
    print(f"{kappa:<12.0e} {kappa**2:<15.0e} {error_normal.item():<18.2e} {error_qr.item():<15.2e}")

print("\n→ QR remains stable even when normal equations fail!")

---
## 3. Deriving the QR Solution

Starting from normal equations:

$$X^\top X \beta = X^\top y$$

Substitute $X = QR$:

$$(QR)^\top (QR) \beta = (QR)^\top y$$

$$R^\top Q^\top Q R \beta = R^\top Q^\top y$$

Since $Q^\top Q = I$:

$$R^\top R \beta = R^\top Q^\top y$$

If $R$ is invertible (equivalently, $X$ has full column rank), multiply both sides by $(R^\top)^{-1}$:

$$R \beta = Q^\top y$$

This is an upper triangular system, so it can be solved by back-substitution.
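
Back-substitution for a general $n \times n$ upper triangular system can be sketched in a few lines of plain Python. This is a minimal illustration and not how PyTorch implements it; the function name `back_substitute` and the $3 \times 3$ example values are assumptions made for this sketch.

```python
def back_substitute(R, c):
    """Solve R @ beta = c for an upper triangular R given as nested lists."""
    n = len(c)
    beta = [0.0] * n
    # Work from the last row upward: row i involves only beta[i], ..., beta[n-1]
    for i in range(n - 1, -1, -1):
        if R[i][i] == 0:
            raise ValueError("R is singular: zero on the diagonal")
        # Contribution of the unknowns that are already solved
        known = sum(R[i][j] * beta[j] for j in range(i + 1, n))
        beta[i] = (c[i] - known) / R[i][i]
    return beta

# Illustrative system whose exact solution is beta = [1, 2, 3]
R = [[2.0, 1.0, -1.0],
     [0.0, 3.0,  2.0],
     [0.0, 0.0,  4.0]]
c = [1.0, 12.0, 12.0]

beta = back_substitute(R, c)
print(beta)  # [1.0, 2.0, 3.0]
```

Each unknown costs one pass over the part of its row to the right of the diagonal, so the whole solve takes about $n^2$ operations, which is negligible next to the factorization itself.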

In [None]:
# Step-by-step QR solution
X = torch.tensor([[1., 0.],
                  [1., 1.],
                  [1., 2.],
                  [1., 3.]], dtype=torch.float64)
y = torch.tensor([1., 2., 2., 4.], dtype=torch.float64)

print("Step-by-step QR solution:")
print("="*50)

# Step 1: QR decomposition
Q, R = torch.linalg.qr(X)
print("\nStep 1: X = QR")
print(f"Q =\n{Q.numpy().round(4)}")
print(f"\nR =\n{R.numpy().round(4)}")

# Step 2: Compute Q'y
Qty = Q.T @ y
print(f"\nStep 2: Q'y = {Qty.numpy().round(4)}")

# Step 3: Solve R*beta = Q'y by back-substitution
print(f"\nStep 3: Solve R·β = Q'y")
print(f"  {R[0,0]:.4f}·β₀ + {R[0,1]:.4f}·β₁ = {Qty[0]:.4f}")
print(f"  {R[1,0]:.4f}·β₀ + {R[1,1]:.4f}·β₁ = {Qty[1]:.4f}")

# Back-substitution
beta = torch.zeros(2, dtype=torch.float64)
beta[1] = Qty[1] / R[1, 1]  # From row 2
beta[0] = (Qty[0] - R[0, 1] * beta[1]) / R[0, 0]  # From row 1

print(f"\nSolution: β = {beta.numpy().round(4)}")

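# Check: the fitted values only approximate y, but the residual must be orthogonal to the columns of X
print(f"\nFitted X·β = {(X @ beta).numpy().round(4)}")
print(f"Observed y = {y.numpy()}")
print(f"X'(y - X·β) = {(X.T @ (y - X @ beta)).numpy().round(10)}  (should be ~0)")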

---
## 4. PyTorch's lstsq Function

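In practice, use `torch.linalg.lstsq()`. Its default LAPACK drivers are QR-based (`gelsy`, QR with column pivoting, on CPU; `gels` on CUDA), and SVD-based drivers can be selected with the `driver` argument.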

In [None]:
# Using PyTorch's built-in least squares
X = torch.tensor([[1., 0.],
                  [1., 1.],
                  [1., 2.],
                  [1., 3.]], dtype=torch.float64)
y = torch.tensor([1., 2., 2., 4.], dtype=torch.float64)

# lstsq returns: solution, residuals, rank, singular values
result = torch.linalg.lstsq(X, y.unsqueeze(1))

print("torch.linalg.lstsq output:")
print(f"  Solution: {result.solution.squeeze().numpy().round(4)}")
print(f"  Rank: {result.rank}")

# Compare with our QR solution
beta_qr = solve_least_squares_qr(X, y)
print(f"  Our QR solution: {beta_qr.numpy().round(4)}")
match = torch.allclose(result.solution.squeeze(), beta_qr)
print("  ✓ Match!" if match else "  ✗ Mismatch")

---
## 5. Computational Cost Comparison

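| Method | Error amplification | Flops (leading order) | Memory |
|--------|-----------------|------------|--------|
| Normal equations | $\kappa(X)^2$ | $\approx mn^2 + n^3/3$ | $O(n^2)$ for $X^\top X$ |
| QR decomposition | $\kappa(X)$ | $\approx 2mn^2 - 2n^3/3$ | $O(mn)$ for Q |

Both methods are $O(mn^2)$ when $m \geq n$. The counts are textbook estimates that assume a Cholesky solve on a symmetrically formed $X^\top X$ and Householder QR. Under those assumptions QR costs up to about twice as many flops when $m \gg n$, but it is much more stable.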
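
The cost gap can be estimated from standard textbook leading-order flop counts: $mn^2 + n^3/3$ for the normal equations (assuming the symmetry of $X^\top X$ is exploited and Cholesky is used) and $2mn^2 - 2n^3/3$ for Householder QR. The sketch below evaluates these formulas; the function names are hypothetical, and the counts are estimates rather than measurements of what PyTorch actually executes.

```python
def flops_normal_equations(m, n):
    """Leading-order estimate: form X^T X (m n^2) plus Cholesky (n^3 / 3)."""
    return m * n**2 + n**3 / 3

def flops_householder_qr(m, n):
    """Leading-order estimate for Householder QR: 2 m n^2 - (2/3) n^3."""
    return 2 * m * n**2 - 2 * n**3 / 3

def flop_ratio(m, n):
    """Estimated cost of QR relative to the normal equations."""
    return flops_householder_qr(m, n) / flops_normal_equations(m, n)

for m, n in [(100, 10), (1000, 50), (5000, 100)]:
    print(f"{m}x{n}: estimated QR / normal-equations flop ratio = {flop_ratio(m, n):.2f}")
```

The ratio approaches 2 for tall, thin matrices and drops to 1 when $m = n$. Measured wall-clock ratios can differ from these estimates, since they also depend on the linear algebra backend and on memory traffic.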

In [None]:
import time

def benchmark(m, n, n_trials=10):
    """Benchmark normal equations vs QR."""
    X = torch.randn(m, n, dtype=torch.float64)
    y = torch.randn(m, dtype=torch.float64)
    
    # Normal equations (perf_counter is a monotonic, high-resolution timer)
    start = time.perf_counter()
    for _ in range(n_trials):
        _ = solve_least_squares_normal(X, y)
    time_normal = (time.perf_counter() - start) / n_trials
    
    # QR
    start = time.perf_counter()
    for _ in range(n_trials):
        _ = solve_least_squares_qr(X, y)
    time_qr = (time.perf_counter() - start) / n_trials
    
    return time_normal, time_qr

print("Timing comparison (seconds):")
print(f"{'Size':<20} {'Normal Eq':<15} {'QR':<15} {'Ratio':<10}")
print("-" * 60)

for m, n in [(100, 10), (1000, 50), (5000, 100)]:
    t_normal, t_qr = benchmark(m, n)
    print(f"{f'{m}×{n}':<20} {t_normal:<15.6f} {t_qr:<15.6f} {t_qr/t_normal:<10.2f}")

print("\n→ QR is slightly slower but MUCH more stable")

---
## 6. Genomics Application: Stable Regression

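When fitting gene expression models, every gene is regressed on the same design matrix, so a single QR factorization of $X$ can be reused for all genes while avoiding the conditioning penalty of forming $X^\top X$.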

In [None]:
# Simulate gene expression regression
torch.manual_seed(42)

n_samples = 100
n_genes = 5

# Design matrix with some correlation (realistic)
# Intercept + Treatment + Batch + Age + BMI
intercept = torch.ones(n_samples)
treatment = torch.cat([torch.zeros(50), torch.ones(50)])
batch = torch.cat([torch.zeros(25), torch.ones(25), torch.zeros(25), torch.ones(25)])
age = torch.randn(n_samples) * 10 + 50
bmi = age * 0.3 + torch.randn(n_samples) * 5  # Correlated with age!

X = torch.stack([intercept, treatment, batch, age, bmi], dim=1).double()

# Check condition number (ratio of largest to smallest singular value)
S = torch.linalg.svdvals(X)
kappa = S[0] / S[-1]
print(f"Design matrix condition number: {kappa:.1f}")

# Simulate gene expression: one column of coefficients per gene, one row per covariate
true_betas = torch.randn(X.shape[1], n_genes, dtype=torch.float64)
Y = X @ true_betas + 0.5 * torch.randn(n_samples, n_genes, dtype=torch.float64)

# Solve for each gene using QR
print(f"\nFitting {n_genes} genes using QR decomposition...")

Q, R = torch.linalg.qr(X)
beta_estimates = torch.linalg.solve_triangular(R, Q.T @ Y, upper=True)

print(f"\nTrue vs Estimated coefficients for Gene 1:")
print(f"  True:      {true_betas[:, 0].numpy().round(3)}")
print(f"  Estimated: {beta_estimates[:, 0].numpy().round(3)}")

# Overall error
error = (beta_estimates - true_betas).norm() / true_betas.norm()
print(f"\nRelative error across all genes: {error:.4f}")

---
## Exercises

### Exercise 1: Manual QR
Use Gram-Schmidt to compute QR of a 3×2 matrix. Compare with PyTorch; the factors may differ by the sign of a column of $Q$ and the matching row of $R$.

### Exercise 2: Stability Test
Create increasingly ill-conditioned problems and compare normal equations vs QR error.

### Exercise 3: DESeq2 Simulation
Simulate the per-gene regression that DESeq2 performs using QR decomposition.

In [None]:
# Your solutions here


---
## Summary

| Concept | Key Point |
|---------|----------|
| QR decomposition | $X = QR$, Q has orthonormal columns, R is upper triangular |
| Stability | Uses $\kappa(X)$ not $\kappa(X)^2$ |
| Solution | $R\beta = Q^\top y$ (back-substitution) |
| Practice | Use `torch.linalg.lstsq()` |

## Next: 04_regularization.ipynb