In [24]:
import torch
from scipy.stats import ortho_group

In [25]:
def rand_nys_appx(K, n, r, device):
    # Calculate sketch
    Phi = torch.randn((n, r), device=device) / (n ** 0.5)
    Phi = torch.linalg.qr(Phi, mode='reduced')[0]

    Y = K @ Phi

    # Calculate shift
    shift = torch.finfo(Y.dtype).eps
    Y_shifted = Y + shift * Phi

    # Calculate Phi^T * K * Phi (w/ shift) for Cholesky
    choleskytarget = torch.mm(Phi.t(), Y_shifted)

    # Perform Cholesky decomposition
    C = torch.linalg.cholesky(choleskytarget)

    B = torch.linalg.solve_triangular(C.t(), Y_shifted, upper=True, left=False)
    U, S, _ = torch.linalg.svd(B, full_matrices=False)
    S = torch.max(torch.square(S) - shift, torch.tensor(0.0))

    return U, S

In [26]:
# Generate an orthogonal matrix of size 1000 by 1000
n = 1000
V = ortho_group.rvs(n)

V = torch.tensor(V, dtype=torch.float32).to('cuda')

Sigma = torch.diag(torch.arange(1, n+1, dtype=torch.float32) ** -3).to('cuda')

print(Sigma)

K = V @ Sigma @ V.t()

U, S = rand_nys_appx(K, n, 1000, 'cuda')

print(S)

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 1.2500e-01, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.7037e-02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.0060e-09, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 1.0030e-09,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         1.0000e-09]], device='cuda:0')
tensor([9.9941e-01, 1.2496e-01, 3.7036e-02, 1.5620e-02, 7.9991e-03, 4.6295e-03,
        2.9147e-03, 1.9530e-03, 1.3719e-03, 1.0001e-03, 7.5122e-04, 5.7878e-04,
        4.5523e-04, 3.6449e-04, 2.9633e-04, 2.4416e-04, 2.0354e-04, 1.7149e-04,
        1.4582e-04, 1.2502e-04, 1.0801e-04, 9.3933e-05, 8.2207e-05, 7.2357e-05,
        6.4011e-05, 5.6901e-05, 5.0814e-05, 4.5562e-05, 4.1007e-05,