In [18]:
import numpy as np
from scipy.linalg import hadamard

def hadamard_sketch(r, m):
    # Constructs a Hadamard Sketch Matrix
    # Takes:
    # - r (dimension of )
    # - m (order of Hadamard Matrix)
    # Gives:
    # - Hadamard Sketch

    H = hadamard(2 ** m)
    D = np.diag(np.random.choice([-1, 1], int(2**m)))
    indices = np.random.choice(2 ** m, r)
    S = []
    for index in indices:
        new = np.zeros(2 ** m)
        new[index] = 1
        S.append(new)
    S = np.vstack(S)
    return S @ H @ D

def leverage_sample(A, r):
    m = A.shape[0]
    U,_,_ = np.linalg.svd(A)
    p = np.square(np.linalg.norm(U, axis=1))
    S_choices = np.random.choice(m, r, replace=True)
    S = np.zeros((r, m))
    for i in np.arange(r):
        S[i, S_choices[i]] = 1/(r * p[i])
    return S

def eps_JLT(r_1, r_2):
    return np.random.choice([-np.sqrt(3/r_1), np.sqrt(3/r_1), 0], size=(r_1, r_2), p=[1/6, 1/6, 2/3])

def approximate_leverage_sample(A, r_1, r_2, r):
    m = int(np.ceil(np.log2(A.shape[0])))
    A_tilde = np.zeros((2 ** m, 2 ** m))
    A_tilde[:A.shape[0], :A.shape[1]] = A
    S_1 = hadamard_sketch(r_1, m)
    S_2 = eps_JLT(r_1, r_2)
    approx = (A_tilde @ (np.linalg.pinv(S_1 @ A_tilde) @ S_2))[:A.shape[0]]
    p = np.square(np.linalg.norm(approx, axis=1))
    S_choices = np.random.choice(m, r, replace=True)
    S = np.zeros((r, A.shape[0]))
    for i in np.arange(r):
        S[i, S_choices[i]] = 1/(r * p[i])
    return S

def approx_solve(A, b, r_1, r_2, r, approx_leverage=False):
    if approx_leverage:
        S = approximate_leverage_sample(A, r_1, r_2, r)
        print(S.shape)
        return np.linalg.pinv(S @ A) @ (S @ b)
    else:
        S = leverage_sample(A, r)
        return np.linalg.pinv(S @ A) @ (S @ b)

In [19]:
import numpy as np
import scipy
import time

m, n = 1000, 10
A, x = np.random.randn(m, n), np.random.randn(n, 1)
b = A @ x + 0.1 * np.random.randn(m, 1)

#start1 = time.time()
#x_lstsq = np.linalg.pinv(A) @ b
#end1 = time.time()
start2 = time.time()
x_approx = approx_solve(A, b, 100, 10, 10, True)
end2 = time.time()
#print("Least-squares x:", x_lstsq)
#print("Time:", end1 - start1)
print("Approximate Least-squares x:", x_approx)
print("Time:", end2 - start2)
print("x:", x)

(10, 1000)
Approximate Least-squares x: [[ 0.12881877]
 [-0.47309552]
 [ 0.61222563]
 [ 0.60676244]
 [ 0.19314851]
 [-0.0966958 ]
 [-0.10197319]
 [-0.24522309]
 [ 0.22947198]
 [-0.01952384]]
Time: 0.09389901161193848
x: [[ 0.38891406]
 [-0.54732568]
 [ 0.34035253]
 [ 0.65510577]
 [-0.07885547]
 [ 0.07603404]
 [ 0.83528704]
 [-0.75888453]
 [ 0.80987631]
 [ 0.59725388]]
