In [1]:
from src.utils import set_random_seed, load_data
from src.opts import apply_nys_precond
from src.kernels import get_kernel
import time
import numpy as np
from pykeops.torch import LazyTensor
import torch
import wandb

In [2]:
def rand_nys_appx(K_sm, K_mm, lambd, n, r, device):
    """Randomized rank-``r`` Nyström approximation of ``A = K_sm^T K_sm`` (an n x n operator).

    Stabilized Nyström sketch: draw an orthonormal test matrix ``Phi`` (n x r),
    form ``Y = A Phi``, add a small shift ``nu * Phi`` for numerical stability,
    factor the core matrix ``Phi^T (A + nu I) Phi`` with Cholesky, and return the
    eigen-pair estimate ``A ~= U diag(S) U^T``. The same ``nu`` that is added to
    ``Y`` is subtracted from the recovered eigenvalues.

    Args:
        K_sm: operator supporting ``.T`` and ``@`` such that ``K_sm.T @ (K_sm @ Phi)``
            is (n, r); callers pass kernel operators built with ``get_kernel``
            (a dense tensor also works).
        K_mm: currently unused; kept for API compatibility (the
            ``lambd * K_mm`` term below is disabled).
        lambd: currently unused, see ``K_mm``.
        n: dimension of ``A``.
        r: sketch size (rank of the approximation).
        device: device on which the random test matrix is drawn.

    Returns:
        U: (n, r) tensor with orthonormal columns.
        S: (r,) tensor of nonnegative eigenvalue estimates.
    """
    # Orthonormal Gaussian test matrix.
    Phi = torch.randn((n, r), device=device) / (n ** 0.5)
    Phi = torch.linalg.qr(Phi, mode='reduced')[0]

    Y = K_sm.T @ (K_sm @ Phi)
    # + lambd * (K_mm @ Phi)

    # Stability shift nu. It must be the same value that is added to Y and later
    # subtracted from the squared singular values, otherwise the eigenvalue
    # estimates are biased by the mismatch.
    # TODO: Modify the shift to improve stability
    shift = n * torch.finfo(Y.dtype).eps
    Y_shifted = Y + shift * Phi

    # Core matrix Phi^T (A + nu I) Phi for Cholesky.
    choleskytarget = Phi.t() @ Y_shifted

    try:
        C = torch.linalg.cholesky(choleskytarget)
    except torch.linalg.LinAlgError:
        # The core matrix is not numerically positive definite. Enlarge the shift
        # so its smallest eigenvalue becomes strictly positive (not merely ~0),
        # and apply the extra shift to Y as well so Y_shifted, the core matrix
        # and the final eigenvalue correction all use the same nu.
        min_eig = torch.linalg.eigvalsh(choleskytarget).min()
        extra = abs(min_eig.item()) + shift
        shift = shift + extra
        Y_shifted = Y_shifted + extra * Phi
        choleskytarget = Phi.t() @ Y_shifted
        C = torch.linalg.cholesky(choleskytarget)

    # B = Y_shifted C^{-T}, so B B^T is the Nyström approximation of A + nu I.
    B = torch.linalg.solve_triangular(C.t(), Y_shifted, upper=True, left=False)
    U, S, _ = torch.linalg.svd(B, full_matrices=False)
    # Undo the shift; clamp small negatives caused by round-off.
    S = torch.clamp(S.square() - shift, min=0.0)

    return U, S

In [3]:
def get_L(K_sm, K_mm, lambd, U, S, rho, num_iters=10):
    """Estimate the largest eigenvalue of the preconditioned Hessian by power iteration.

    The iterated operator is ``P^{-1/2} H P^{-1/2}`` with
    ``H = K_sm^T K_sm + lambd * K_mm`` and ``P^{-1/2}`` applied as
    ``U diag((S + rho)^{-1/2}) U^T + rho^{-1/2} (I - U U^T)``.

    Args:
        K_sm: operator with ``.T`` and ``@`` (callers pass a rescaled subsampled kernel).
        K_mm: operator with ``@`` for the regularization term.
        lambd: regularization weight.
        U: (n, r) orthonormal basis of the Nyström approximation.
        S: (r,) Nyström eigenvalues.
        rho: preconditioner regularization (must be > 0).
        num_iters: number of power iterations (default 10, the previous fixed value).

    Returns:
        0-dim tensor: Rayleigh-quotient estimate from the last iteration.

    Raises:
        ValueError: if ``num_iters < 1``.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    def _apply_inv_sqrt_precond(w):
        # P^{-1/2} w: scale the range of U by (S + rho)^{-1/2}, its complement by rho^{-1/2}.
        UTw = U.t() @ w
        return U @ (UTw / ((S + rho) ** 0.5)) + 1 / (rho ** 0.5) * (w - U @ UTw)

    n = U.shape[0]
    v = torch.randn(n, device=U.device)
    v = v / torch.linalg.norm(v)

    max_eig = None
    for _ in range(num_iters):
        # v is reassigned (never mutated in place), so no clone is needed.
        v_old = v

        v = _apply_inv_sqrt_precond(v)
        v = K_sm.T @ (K_sm @ v) + lambd * (K_mm @ v)
        v = _apply_inv_sqrt_precond(v)

        # Rayleigh quotient: v_old has unit norm.
        max_eig = torch.dot(v_old, v)

        v = v / torch.linalg.norm(v)

    return max_eig

In [4]:
def compute_metrics_dict(K_nm, K_mm, K_tst, a, b, b_tst, lambd, b_norm, task):
    """Compute training and test metrics for the current weights ``a``.

    Training metrics refer to the objective
    ``1/2 ||K_nm a - b||^2 + lambd/2 a^T K_mm a``, whose gradient is
    ``H a - K_nm^T b`` with ``H = K_nm^T K_nm + lambd K_mm``.

    Args:
        K_nm, K_mm, K_tst: kernel operators supporting ``@`` (and ``.T`` for K_nm).
        a: current weight vector.
        b: training targets; ``b_norm`` must equal ``||b||_2``.
        b_tst: test targets.
        lambd: regularization weight.
        task: ``'classification'`` (sign accuracy) or anything else (regression).

    Returns:
        dict with ``rel_residual`` and ``train_loss``, plus ``test_acc`` for
        classification or ``test_mse`` and ``smape`` otherwise. Note that
        ``test_mse`` is half the mean squared error (1/2 * ||pred - b_tst||^2 / n_tst).
    """
    K_nmTb = K_nm.T @ b
    residual = K_nm.T @ (K_nm @ a) + lambd * (K_mm @ a) - K_nmTb
    rel_residual = torch.norm(residual) / torch.norm(K_nmTb)
    # a^T (H a - 2 K_nm^T b) + ||b||^2 == ||K_nm a - b||^2 + lambd a^T K_mm a
    loss = 1/2 * (torch.dot(a, residual - K_nmTb) + b_norm ** 2)
    metrics_dict = {'rel_residual': rel_residual, 'train_loss': loss}

    pred = K_tst @ a
    n_tst = b_tst.shape[0]

    if task == 'classification':
        metrics_dict['test_acc'] = torch.sum(torch.sign(pred) == b_tst) / n_tst
    else:
        metrics_dict['test_mse'] = 1/2 * torch.norm(pred - b_tst) ** 2 / n_tst
        abs_err = (pred - b_tst).abs()
        denom = (pred.abs() + b_tst.abs()) / 2
        # A term with pred == b_tst == 0 has a 0/0 ratio; count it as 0 instead
        # of NaN. This happens at initialization, where a = 0 makes pred exactly 0.
        nonzero = denom > 0
        safe_denom = torch.where(nonzero, denom, torch.ones_like(denom))
        ratios = torch.where(nonzero, abs_err / safe_denom, torch.zeros_like(abs_err))
        metrics_dict['smape'] = torch.sum(ratios) / n_tst

    return metrics_dict

In [5]:
def compute_and_log_metrics(K_nm, K_mm, K_tst, y, b, b_tst, lambd, b_norm, iter_time,
                            task, i, log_freq):
    """Send one wandb log entry for iteration ``i``.

    Every entry carries ``iter_time``. When ``(i + 1)`` is a multiple of
    ``log_freq``, the full metrics from ``compute_metrics_dict`` (evaluated at
    weights ``y``) are merged in as well. Callers pass ``i = -1`` for the
    initial point, which therefore always includes the metrics.
    """
    payload = {'iter_time': iter_time}
    is_eval_step = (i + 1) % log_freq == 0
    if is_eval_step:
        payload = payload | compute_metrics_dict(K_nm, K_mm, K_tst, y, b, b_tst,
                                                 lambd, b_norm, task)
    wandb.log(payload)

In [6]:
def sketchysgd(x, b, x_tst, b_tst, kernel_params, m, lambd, task, a0, bg, bH, r, rho, max_iter, log_freq, device):
    """SketchySGD for inducing-point kernel ridge regression.

    Minimizes ``1/2 ||K_nm a - b||^2 + lambd/2 a^T K_mm a`` over the weights of
    ``m`` randomly chosen inducing points. Each step uses a minibatch gradient
    of size ``bg``, preconditioned by a rank-``r`` Nyström approximation built
    from ``bH`` subsampled rows. The step size is ``0.5 / L``, where ``L`` is a
    power-iteration estimate on an independent subsample. Metrics go to wandb
    via ``compute_and_log_metrics``.

    ``x`` and ``x_tst`` are 2-D (points x features). ``a0`` is the initial weight
    vector of length ``m`` and is not modified.
    """
    n = x.shape[0]
    b_norm = torch.linalg.norm(b)

    setup_start = time.time()

    # Inducing points and the lazy kernel blocks against them.
    centers = x[torch.from_numpy(np.random.choice(n, m, replace=False))]
    centers_j = LazyTensor(centers[None, :, :])

    def kernel_vs_centers(rows):
        # Lazy kernel between `rows` and the inducing points.
        return get_kernel(LazyTensor(rows[:, None, :]), centers_j, kernel_params)

    K_mm = kernel_vs_centers(centers)
    K_nm = kernel_vs_centers(x)
    K_tst = kernel_vs_centers(x_tst)

    # Nyström preconditioner: (n / bH) K_sm^T K_sm estimates K_nm^T K_nm.
    K_sm = kernel_vs_centers(x[torch.from_numpy(np.random.choice(n, bH, replace=False))])
    scale = (n / bH) ** 0.5
    U, S = rand_nys_appx(scale * K_sm, K_mm, lambd, m, r, device)

    # Step size as in PROMISE: power iteration with the inverse preconditioner and a
    # fresh subsampled Hessian in factorized form.
    K_sm_lr = kernel_vs_centers(x[torch.from_numpy(np.random.choice(n, bH, replace=False))])
    eta = 0.5 / (get_L(scale * K_sm_lr, K_mm, lambd, U, S, rho))

    a = a0.clone()

    # Metrics at the starting point (i = -1 always triggers a full evaluation).
    compute_and_log_metrics(K_nm, K_mm, K_tst, a, b, b_tst, lambd, b_norm,
                            time.time() - setup_start, task, -1, log_freq)

    for i in range(max_iter):
        iter_start = time.time()

        # Unbiased minibatch estimate of the full gradient.
        # TODO: Use a shuffling approach instead of random sampling to match PROMISE
        batch = torch.from_numpy(np.random.choice(n, bg, replace=False))
        K_batch = kernel_vs_centers(x[batch])
        grad = n/bg * (K_batch.T @ (K_batch @ a - b[batch])) + lambd * (K_mm @ a)

        # Preconditioned step with the automatically chosen learning rate.
        a -= eta * apply_nys_precond(U, S, rho, grad)

        compute_and_log_metrics(K_nm, K_mm, K_tst, a, b, b_tst, lambd, b_norm,
                                time.time() - iter_start, task, i, log_freq)

In [7]:
def sketchysvrg(x, b, x_tst, b_tst, kernel_params, m, lambd, task, a0, bg, bH, r, rho, update_freq, max_iter, log_freq, device):
    """SketchySVRG for inducing-point kernel ridge regression.

    Minimizes ``1/2 ||K_nm a - b||^2 + lambd/2 a^T K_mm a`` over the weights of
    ``m`` randomly chosen inducing points. The method uses SVRG-style
    variance-reduced gradients, preconditioned by a rank-``r`` Nyström
    approximation built from ``bH`` subsampled rows. Every ``update_freq``
    iterations the snapshot ``a_tilde`` and its full gradient ``g_bar`` are
    refreshed. The step size is ``0.1 / L``, where ``L`` is a power-iteration
    estimate on an independent subsample. Metrics go to wandb via
    ``compute_and_log_metrics``.

    ``x`` and ``x_tst`` are 2-D (points x features). ``a0`` is the initial weight
    vector of length ``m`` and is not modified.

    Raises:
        ValueError: if ``update_freq < 1``. This is checked before any expensive
            setup; otherwise ``i % update_freq`` would only fail after the kernel
            and preconditioner were built.
    """
    if update_freq < 1:
        raise ValueError(f"update_freq must be a positive integer, got {update_freq}")

    n = x.shape[0]
    b_norm = torch.linalg.norm(b)

    start_time = time.time()

    inducing_pts = torch.from_numpy(np.random.choice(n, m, replace=False))

    # Get inducing points kernel
    x_inducing_i = LazyTensor(x[inducing_pts][:, None, :])
    x_inducing_j = LazyTensor(x[inducing_pts][None, :, :])
    K_mm = get_kernel(x_inducing_i, x_inducing_j, kernel_params)

    # Get kernel between full training set and inducing points
    x_i = LazyTensor(x[:, None, :])
    K_nm = get_kernel(x_i, x_inducing_j, kernel_params)

    # Get kernel for test set
    x_tst_i = LazyTensor(x_tst[:, None, :])
    K_tst = get_kernel(x_tst_i, x_inducing_j, kernel_params)

    # Compute the preconditioner: (n / bH) K_sm^T K_sm estimates K_nm^T K_nm
    hess_pts = torch.from_numpy(np.random.choice(n, bH, replace=False))
    x_hess_i = LazyTensor(x[hess_pts][:, None, :])
    K_sm = get_kernel(x_hess_i, x_inducing_j, kernel_params)

    adj_factor = (n / bH) ** 0.5

    U, S = rand_nys_appx(adj_factor * K_sm, K_mm, lambd, m, r, device)

    # Automatically compute the learning rate
    # Do so as in PROMISE -- matvecs with inverse preconditioner and subsampled Hessian in factorized form
    hess_pts_lr = torch.from_numpy(np.random.choice(n, bH, replace=False))
    x_hess_lr_i = LazyTensor(x[hess_pts_lr][:, None, :])
    K_sm_lr = get_kernel(x_hess_lr_i, x_inducing_j, kernel_params)
    eta = (1/10) / get_L(adj_factor * K_sm_lr, K_mm, lambd, U, S, rho)

    a = a0.clone()
    a_tilde = None
    g_bar = None
    iter_time = time.time() - start_time

    # Compute and log metrics before any optimization is performed
    compute_and_log_metrics(K_nm, K_mm, K_tst, a, b, b_tst, lambd, b_norm, iter_time,
                            task, -1, log_freq)

    for i in range(max_iter):
        start_time = time.time()

        # Update snapshot and full gradient at snapshot (runs at i = 0, so
        # a_tilde / g_bar are always set before use)
        if i % update_freq == 0:
            a_tilde = a.clone()
            g_bar = K_nm.T @ (K_nm @ a_tilde - b) + lambd * (K_mm @ a_tilde)

        # Variance-reduced gradient: for this quadratic objective,
        # grad_batch(a) - grad_batch(a_tilde) depends only on a - a_tilde.
        # TODO: Use a shuffling approach instead of random sampling to match PROMISE
        idx = torch.from_numpy(np.random.choice(n, bg, replace=False))
        x_idx_i = LazyTensor(x[idx][:, None, :])
        K_nm_idx = get_kernel(x_idx_i, x_inducing_j, kernel_params)
        a_diff = a - a_tilde
        g_diff = n/bg * (K_nm_idx.T @ (K_nm_idx @ a_diff)) + lambd * (K_mm @ a_diff)

        # Apply the preconditioner
        step_dir = apply_nys_precond(U, S, rho, g_diff + g_bar)

        # Update params w/ auto learning rate and preconditioned stochastic gradient
        a -= eta * step_dir

        # Call function to compute and log metrics (as necessary)
        iter_time = time.time() - start_time
        compute_and_log_metrics(K_nm, K_mm, K_tst, a, b, b_tst, lambd, b_norm, iter_time,
                                task, i, log_freq)

In [8]:
def nystrom_pcg(x, b, x_tst, b_tst, kernel_params, m, lambd, task, a0, r, rho, max_iter, log_freq, device):
    """Nyström-preconditioned conjugate gradient for inducing-point kernel ridge regression.

    Solves ``(K_nm^T K_nm + lambd K_mm) a = K_nm^T b`` over the weights of ``m``
    randomly chosen inducing points. The rank-``r`` Nyström preconditioner is
    built from the full ``K_nm``. Metrics go to wandb via
    ``compute_and_log_metrics``.

    ``x`` and ``x_tst`` are 2-D (points x features). ``a0`` is the initial weight
    vector of length ``m`` and is not modified.
    """
    n = x.shape[0]
    b_norm = torch.linalg.norm(b)

    setup_start = time.time()

    # Inducing points and the lazy kernel blocks against them.
    centers = x[torch.from_numpy(np.random.choice(n, m, replace=False))]
    centers_j = LazyTensor(centers[None, :, :])

    def kernel_vs_centers(rows):
        # Lazy kernel between `rows` and the inducing points.
        return get_kernel(LazyTensor(rows[:, None, :]), centers_j, kernel_params)

    K_mm = kernel_vs_centers(centers)
    K_nm = kernel_vs_centers(x)
    K_tst = kernel_vs_centers(x_tst)

    def hess_matvec(w):
        # (K_nm^T K_nm + lambd K_mm) w
        return K_nm.T @ (K_nm @ w) + lambd * (K_mm @ w)

    rhs = K_nm.T @ b

    U, S = rand_nys_appx(K_nm, K_mm, lambd, m, r, device)

    # PCG state: residual, preconditioned residual, search direction, and r^T z
    # (carried across iterations instead of being recomputed).
    a = a0.clone()
    resid = rhs - hess_matvec(a0)
    z = apply_nys_precond(U, S, rho, resid)
    p = z.clone()
    rz = torch.dot(resid, z)

    # Metrics at the starting point (i = -1 always triggers a full evaluation).
    compute_and_log_metrics(K_nm, K_mm, K_tst, a, b, b_tst, lambd, b_norm,
                            time.time() - setup_start, task, -1, log_freq)

    for i in range(max_iter):
        iter_start = time.time()

        Hp = hess_matvec(p)
        alpha = rz / torch.dot(p, Hp)
        a += alpha * p
        resid -= alpha * Hp

        z = apply_nys_precond(U, S, rho, resid)
        rz_next = torch.dot(resid, z)
        beta = rz_next / rz
        rz = rz_next

        p = z + beta * p

        compute_and_log_metrics(K_nm, K_mm, K_tst, a, b, b_tst, lambd, b_norm,
                                time.time() - iter_start, task, i, log_freq)

In [9]:
# Experiment identity: dataset name passed to load_data and the global seed.
data = 'homo'
seed = 0
# Prefer the second GPU used for the original experiments, but fall back so
# Restart & Run All still works on machines with fewer GPUs (or none).
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [10]:
# Seed the RNGs before any sampling: the optimizers below draw indices with
# np.random.choice and random sketches / start vectors with torch.randn.
# NOTE(review): assumes set_random_seed (src.utils, not shown) seeds both numpy and torch — confirm.
set_random_seed(seed)

In [11]:
# Load the train/test data on `device`. Outside the wandb run only Xtr.shape[0]
# is used (to size update_freq in the next cell); the run cell reloads the data
# from its own config. NOTE(review): presumably `seed` controls the split — see src.utils.
Xtr, Xtst, ytr, ytst = load_data(data, seed, device)

In [12]:
m = 10000  # Number of inducing points
kernel_params = {'type': 'l1_laplace', 'sigma': 5120}
lambd = 1e-3  # Regularization weight on a^T K_mm a
task = 'regression'
bg = 256  # Minibatch size for stochastic gradients (SketchySGD / SketchySVRG)
r = 30  # Rank of the Nyström preconditioner
rho = 1e-1  # Preconditioner regularization (alternative tried: 1e-3)
# SVRG snapshot period: about one pass over the training data. Floor at 1 so a
# dataset smaller than bg cannot give update_freq == 0 (i % 0 would raise).
update_freq = max(1, Xtr.shape[0] // bg)
max_iter = 10000
log_freq = 100  # Full metrics every log_freq iterations; iter_time every iteration

# Optimizer to run: 'sketchysgd', 'sketchysvrg' or 'nystrom_pcg'
opt = 'sketchysvrg'

wandb_project = "sksgd_krr_testing"

In [13]:
# Configuration recorded with the wandb run. Keys shared by every optimizer come
# first; optimizer-specific hyperparameters are added only when they apply.
experiment_args = dict(
    dataset=data,
    task=task,
    kernel_params=kernel_params,
    lambd=lambd,
    m=m,
    opt=opt,
    r=r,
    rho=rho,
    max_iter=max_iter,
    log_freq=log_freq,
    seed=seed,
    device=device,
)

uses_minibatches = opt in ('sketchysgd', 'sketchysvrg')
if uses_minibatches:
    experiment_args['bg'] = bg
if opt == 'sketchysvrg':
    experiment_args['update_freq'] = update_freq


In [14]:
with wandb.init(project=wandb_project, config=experiment_args):
    # Access the experiment configuration
    config = wandb.config

    # Load the dataset
    Xtr, Xtst, ytr, ytst = load_data(config.dataset, config.seed, config.device)

    # Subsample size for the preconditioner and step-size estimate: sqrt(n)
    bH = int(Xtr.shape[0] ** 0.5)

    # Initialize at 0
    a0 = torch.zeros(config.m, device=config.device)

    # Re-seed right before optimizing so the run is reproducible even when this
    # cell is re-executed without re-running the earlier seeding cell.
    set_random_seed(config.seed)

    # Run the optimizer
    with torch.no_grad():
        if config.opt == 'sketchysgd':
            sketchysgd(Xtr, ytr, Xtst, ytst, config.kernel_params, config.m, config.lambd,
                       config.task, a0, config.bg, bH, config.r, config.rho,
                       config.max_iter, config.log_freq, config.device)
        elif config.opt == 'sketchysvrg':
            sketchysvrg(Xtr, ytr, Xtst, ytst, config.kernel_params, config.m, config.lambd,
                        config.task, a0, config.bg, bH, config.r, config.rho, config.update_freq,
                        config.max_iter, config.log_freq, config.device)
        elif config.opt == 'nystrom_pcg':
            nystrom_pcg(Xtr, ytr, Xtst, ytst, config.kernel_params, config.m, config.lambd,
                        config.task, a0, config.r, config.rho, config.max_iter,
                        config.log_freq, config.device)
        else:
            # Fail loudly instead of finishing an empty wandb run.
            raise ValueError(f"Unknown optimizer {config.opt!r}; expected "
                             "'sketchysgd', 'sketchysvrg' or 'nystrom_pcg'")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpratikrathore8[0m ([33msketchy-opts[0m). Use [1m`wandb login --relogin`[0m to force relogin




Traceback (most recent call last):
  File "/tmp/ipykernel_3988738/3064188348.py", line 20, in <module>
    sketchysvrg(Xtr, ytr, Xtst, ytst, config.kernel_params, config.m, config.lambd,
  File "/tmp/ipykernel_3988738/3145367713.py", line 61, in sketchysvrg
    g_diff = n/bg * (K_nm_idx.T @ (K_nm_idx @ a_diff)) + lambd * (K_mm @ a_diff)
  File "/home/pratikr/fast_krr/fast_krr_env/lib/python3.10/site-packages/pykeops/common/lazy_tensor.py", line 2524, in __matmul__
    Kv = Kv.sum(Kv.dim() - 2, **kwargs)  # Matrix-vector or Matrix-matrix product
  File "/home/pratikr/fast_krr/fast_krr_env/lib/python3.10/site-packages/pykeops/common/lazy_tensor.py", line 2096, in sum
    return self.reduction("Sum", axis=axis, **kwargs)
  File "/home/pratikr/fast_krr/fast_krr_env/lib/python3.10/site-packages/pykeops/common/lazy_tensor.py", line 775, in reduction
    return res()
  File "/home/pratikr/fast_krr/fast_krr_env/lib/python3.10/site-packages/pykeops/common/lazy_tensor.py", line 957, in __call__


VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
iter_time,▄▄▃▂█▂▃▄▃▃▂▄▂▆▃▂▃▂▂▃▂▁▂▃▂▃▃▁▁▃▂▂▂▂▂▂▂▂▆▁
rel_residual,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
smape,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iter_time,0.0885
rel_residual,0.00015
smape,0.05496
test_mse,0.11426
train_loss,11356.0


KeyboardInterrupt: 