In [119]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [120]:
# %matplotlib inline
# import matplotlib.pyplot as plt
import numpy as np
import torch
import time
from falkon import Falkon, kernels

## Set data dimensions

In [137]:
# Benchmark configuration (sizes of the synthetic tensors created below).
num_parallel = 15   # number of independent problems; leading (batch) dimension of X, W and alpha
N = 19_000          # rows in each X[i]
M = 2_000           # rows in each W[i] and in each alpha[i]
D = 1_024           # feature dimension shared by X and W
cuda = True         # whether the synthetic tensors are placed on the GPU

## Create synthetic data

In [138]:
# Seed the generator so a Restart & Run All benchmarks the same data every time.
torch.manual_seed(0)
# Sample directly on the target device: creating the tensors on the host and then
# calling .cuda() needs a full-size host allocation (about 1.2 GB for X alone)
# plus a host-to-device copy.
device = torch.device("cuda" if cuda else "cpu")
X = torch.randn(num_parallel, N, D, device=device)      # num_parallel problems, each N x D
W = torch.randn(num_parallel, M, D, device=device)      # num_parallel problems, each M x D
alpha = torch.randn(num_parallel, M, 1, device=device)  # vector multiplied by each N x M kernel matrix

## Normal kernels (falkon)

In [139]:
# Falkon Gaussian kernel constructed with the argument 1.0
# (presumably the bandwidth sigma — confirm against the falkon documentation).
kernel = kernels.GaussianKernel(1.0)

In [124]:
%%timeit
# Baseline: one falkon kernel-vector product per problem, executed sequentially.
preds = torch.empty(num_parallel, N, 1, device=X.device)
for i in range(num_parallel):
    preds[i] = kernel.mmv(X[i], W[i], alpha[i])
# CUDA work is queued asynchronously: wait for it to finish so %%timeit measures
# execution rather than launch overhead, as the other benchmark cells do.
if X.is_cuda:
    torch.cuda.synchronize()

26.6 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [125]:
# Scratch arithmetic left over from exploration; the tuple is echoed as the cell output.
# NOTE(review): 22.5/15 looks like a total time divided by num_parallel = 15 — confirm or remove this cell.
0.0015, 1/1000, 22.5/15

(0.0015, 0.001, 1.5)

## Normal kernels (in-core)

In [140]:
from falkon.la_helpers.cuda_la_helpers import square_norm

In [153]:
# Per-call wall-clock durations in seconds, one entry per call, for each stage of
# the in-core kernel computation. Re-running this cell resets them.
norm_times = []
mm_times = []
mulexp_times = []


def _sync_if_cuda(tensor):
    """Wait for queued CUDA work to finish; does nothing for CPU tensors."""
    if tensor.is_cuda:
        torch.cuda.synchronize()


# @torch.jit.script
def squared_euclidean_distance(x1, x2):
    """Return the matrix of squared Euclidean distances between rows of x1 and x2.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2.

    Args:
        x1: 2-D tensor of shape N x D.
        x2: 2-D tensor of shape M x D, on the same device as x1.

    Returns:
        N x M tensor of squared distances, clamped from below at 1e-30.

    Side effects:
        Appends the duration of the norm stage to ``norm_times`` and of the
        matrix-multiply stage to ``mm_times``.
    """
    # Flush work queued by the caller (e.g. the division by sigma) so that it is
    # not charged to the norm stage.
    _sync_if_cuda(x1)
    t_0 = time.perf_counter()
    x1_norm = torch.norm(x1, p=2, dim=-1, keepdim=True).pow_(2)  # N x 1
    x2_norm = torch.norm(x2, p=2, dim=-1, keepdim=True).pow_(2)  # M x 1
#     x1_norm = square_norm(x1, dim=-1, keepdim=True)
#     x2_norm = square_norm(x2, dim=-1, keepdim=True)
    _sync_if_cuda(x1)
    t_1 = time.perf_counter()
    # (1 x M) - 2 * (N x D) @ (D x M), then the N x 1 norms are broadcast-added in place.
    res = torch.addmm(x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
    _sync_if_cuda(x1)
    t_2 = time.perf_counter()
    # Rounding can leave tiny negative or zero entries; keep them strictly positive.
    res = res.clamp_min_(1e-30)
    norm_times.append(t_1 - t_0)
    mm_times.append(t_2 - t_1)
    # Make sure the clamp has finished before the caller starts its own timer.
    _sync_if_cuda(x1)
    return res


# @torch.jit.script
def full_rbf_kernel(X1, X2, sigma):
    """Return the N x M kernel matrix exp(-0.5 * ||X1/sigma - X2/sigma||^2).

    Args:
        X1: 2-D tensor of shape N x D.
        X2: 2-D tensor of shape M x D.
        sigma: value both inputs are divided by before the distance computation.

    Returns:
        N x M tensor holding the kernel values.

    Side effects:
        Appends the duration of the scale-and-exponentiate stage to ``mulexp_times``
        (in addition to the entries recorded by ``squared_euclidean_distance``).
    """
    pairwise_dists = squared_euclidean_distance(X1 / sigma, X2 / sigma)
    t_3 = time.perf_counter()
    pairwise_dists.mul_(-0.5).exp_()  # in place: the distance matrix becomes the kernel matrix
    _sync_if_cuda(pairwise_dists)
    t_4 = time.perf_counter()
    mulexp_times.append(t_4 - t_3)
    return pairwise_dists

In [160]:
%%timeit
# In-core baseline: build each N x M kernel matrix explicitly, then multiply it by alpha[i].
sigma = torch.tensor(1.0)  # loop-invariant, so it is built once rather than once per problem
preds = torch.empty(num_parallel, N, 1, device=X.device)
for i in range(num_parallel):
    preds[i] = full_rbf_kernel(X[i], W[i], sigma) @ alpha[i]
# Wait for queued CUDA work so the measurement covers execution; skipped on CPU,
# where torch.cuda.synchronize() would raise.
if X.is_cuda:
    torch.cuda.synchronize()

219 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [152]:
# Mean duration of each timed stage, converted from seconds to milliseconds.
print("Norm: %.2fms - MM: %.2fms - MulExp: %.2fms" % tuple(
    1000 * np.mean(samples) for samples in (norm_times, mm_times, mulexp_times)))

Norm: 0.79ms - MM: 6.99ms - MulExp: 1.36ms


In [161]:
# Same report as the previous cell, evaluated on the timing lists as they stand now:
# per-stage mean in milliseconds (the lists store seconds).
print("Norm: %.2fms - MM: %.2fms - MulExp: %.2fms" % tuple(
    1000 * np.mean(stage) for stage in [norm_times, mm_times, mulexp_times]))

Norm: 0.97ms - MM: 8.02ms - MulExp: 1.50ms


## Batch kernel (in-core)

In [96]:
@torch.jit.script
def batch_sqeuc(x1, x2):
    """Batched squared Euclidean distances between the rows of x1 and x2.

    Args:
        x1: tensor of shape B x N x D.
        x2: tensor of shape B x M x D.

    Returns:
        B x N x M tensor of squared distances, clamped from below at 1e-30.
    """
    row_sq_norms = torch.norm(x1, p=2, dim=-1, keepdim=True).pow_(2)  # B x N x 1
    col_sq_norms = torch.norm(x2, p=2, dim=-1, keepdim=True).pow_(2).transpose(-2, -1)  # B x 1 x M
    # (B x 1 x M) - 2 * (B x N x D) @ (B x D x M)
    dists = torch.baddbmm(col_sq_norms, x1, x2.transpose(-2, -1), alpha=-2)
    dists.add_(row_sq_norms)
    dists.clamp_min_(1e-30)
    return dists
@torch.jit.script
def batch_rbf_kernel(X1, X2, sigma):
    """Batched kernel matrix exp(-0.5 * ||X1/sigma - X2/sigma||^2).

    Args:
        X1: tensor of shape B x N x D.
        X2: tensor of shape B x M x D.
        sigma: tensor both inputs are divided by before the distance computation.

    Returns:
        B x N x M tensor of kernel values.
    """
    gram = batch_sqeuc(X1 / sigma, X2 / sigma)
    # Turn the distances into kernel values in place.
    gram.mul_(-0.5)
    gram.exp_()
    return gram

In [97]:
# Untimed run of the same expression that is benchmarked in the next cell.
# NOTE(review): presumably a warm-up so that TorchScript compilation is not part of the timing — confirm.
preds = batch_rbf_kernel(X, W, torch.tensor(1.0)) @ alpha
torch.cuda.synchronize()  # wait for the queued CUDA work to finish

In [98]:
%%timeit
# Batched in-core version: one call covers all problems, followed by a batched matmul with alpha.
preds = batch_rbf_kernel(X, W, torch.tensor(1.0)) @ alpha
torch.cuda.synchronize()  # CUDA work is asynchronous; wait so the measurement covers execution

8.1 ms ± 9.09 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Batch kernel (KeOps)

In [49]:
from pykeops.torch import Genred
from pykeops.torch import LazyTensor

In [50]:
# KeOps formula string: exp(-1/2 * sum((x/s - w/s)^2)) * a, i.e. a Gaussian kernel value
# multiplied by the third variable. IntInv(-2) is the constant 1/(-2).
formula = f'Exp(IntInv(-2) * Sum(Square(Var(0,{D},0)/s - Var(1,{D},1)/s))) * Var(2,1,1)'
# Declarations of the variables used in the formula, in Var(index, dimension, category) notation.
# NOTE(review): per the KeOps docs, category 0 is indexed by i and category 1 by j, and Pm(1)
# declares a 1-dimensional parameter — confirm against the installed pykeops version.
aliases = [
    f'Var(0,{D},0)',  # slot 0: D-dimensional, category 0
    f'Var(1,{D},1)',  # slot 1: D-dimensional, category 1
    'Var(2,1,1)',     # slot 2: 1-dimensional, category 1
    's = Pm(1)',      # named alias `s` used as the divisor in the formula
]

In [51]:
# Build the KeOps reduction for the formula: a 'Sum' reduction over axis=1 in float32.
fn = Genred(formula, aliases, reduction_op='Sum', axis=1, dtype='float32',
            rec_multVar_highdim=None, enable_chunks=True)
# Call arguments in the same order as the aliases: X, W, alpha, then a 1 x 1 tensor holding 1.0 for `s`.
variables = [X, W, alpha, torch.tensor([1.0], device=X.device).reshape(-1,1)]

In [56]:
%%timeit
# Evaluate the KeOps reduction on the batched tensors.
# NOTE(review): unlike the other benchmark cells there is no explicit torch.cuda.synchronize() here —
# confirm that the KeOps call blocks until the GPU work is done, otherwise the timing is optimistic.
preds2 = fn(*variables)

35.7 ms ± 36.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Batch kernel (by hand)

In [61]:
from falkon.mmv_ops import mmv_cuda
from falkon.kernels import GaussianKernel

In [62]:
# Rebinds `kernel` to a GaussianKernel built with 1.0, the same argument used in the first falkon section.
kernel = GaussianKernel(1.0)

In [84]:
%%timeit
# Single call to falkon's mmv_cuda.fmmv_cuda on the full batched tensors X, W and alpha.
out = mmv_cuda.fmmv_cuda(X, W, alpha, kernel)
torch.cuda.synchronize()  # wait for the queued CUDA work before the timer stops

9.12 ms ± 65.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
