In [None]:
# Baseline GPU status, taken before this notebook allocates anything.
# Requires the external `gpustat` CLI on PATH; the -c/-u flags presumably add the
# owning command and user to each process entry (confirm with `gpustat --help`).
!gpustat -cu

In [None]:
# Expose only GPU 1 to this kernel. This runs ahead of the torch/faiss imports in
# the next cell so the restriction is in place before CUDA is first used.
# NOTE(review): the device id is machine-specific - adjust it for your host.
%env CUDA_VISIBLE_DEVICES 1

In [None]:
import torch
import faiss
import faiss.contrib.torch_utils

In [None]:
# Fixed seed so the random database / query vectors are the same on every run.
torch.manual_seed(1234)

# Vector dimensionality: 3 * 7**2 = 147.
# NOTE(review): presumably 3 channels x a 7x7 patch - confirm with the author.
d = 3 * 7**2

# faiss GPU resource handle, shared by every GPU index built in this notebook.
res = faiss.StandardGpuResources()

# nb = database size, nq = number of queries.
nb, nq = 10000, 10000

# Database vectors: uniform noise in [0, 1), plus a ramp i / 1000 added to the
# first coordinate of row i.
xb = torch.rand(nb, d, device='cuda')
xb[:, 0].add_(torch.arange(nb, device='cuda') / 1000.)

# Query vectors, generated the same way (drawn after xb, so the RNG order is kept).
xq = torch.rand(nq, d, device='cuda')
xq[:, 0].add_(torch.arange(nq, device='cuda') / 1000.)

In [None]:
# This cell used to contain only the dangling expression `faiss.` (a tab-completion
# leftover), which is a SyntaxError and stops Restart & Run All. It now performs a
# small sanity check instead: how many GPUs faiss can see.
# NOTE(review): presumably 1, given the CUDA_VISIBLE_DEVICES setting - confirm.
faiss.get_num_gpus()

In [None]:
# Build a flat (brute-force) L2 index on the GPU from the database vectors and
# look up the k nearest database vectors for every query.
k = 4
index = faiss.GpuIndexFlatL2(res, d)
index.add(xb)
# The search previously hard-coded 1 here, which left `k` above unused; pass `k`
# so the cell does what its configuration says.
D, I = index.search(xq, k)

In [None]:
# Inspect the shape of the neighbour-index tensor returned by the search above;
# presumably (number of queries, neighbours requested per query) - confirm
# against the displayed output.
I.shape

In [None]:
# GPU status again, now that the data tensors and a faiss index live on the device.
!gpustat -cu

In [None]:
# Fresh flat L2 GPU index over the database vectors, used by the optimisation
# loop in the next cell. `k` is the number of neighbours requested per query.
k = 4
index = faiss.GpuIndexFlatL2(res, d)
index.add(xb)

In [None]:
# Turn the queries into a leaf tensor that SGD can update.
# NOTE: detach() shares storage with the previous `xq`, and `xq` is rebound here,
# so re-running this cell resumes from the already-optimised queries instead of
# the original random ones. Re-run the data-setup cell first for a clean start.
xq = xq.detach().requires_grad_()
opt = torch.optim.SGD([xq], lr=10)

for i in range(100000):
    opt.zero_grad()
    # The index lookup only supplies neighbour ids. Gradients reach xq solely
    # through the loss below (xb was created without requires_grad).
    # Only the nearest neighbour (column 0) is used, although k are fetched.
    _, I = index.search(xq, k)
    # mse_loss already averages over all elements (default reduction='mean'),
    # so the result is a scalar and needs no further .mean().
    loss = torch.nn.functional.mse_loss(xq, xb[I[:, 0]])
    loss.backward()
    opt.step()
    # Log the first step and then every 1000th step.
    if i == 0 or (i + 1) % 1000 == 0:
        print(i, loss.item())

In [None]:
def l2_dist(x, y):
    """Squared Euclidean distance between `x` and `y` along the last axis.

    The operands are subtracted with ordinary broadcasting, so the result has
    the broadcast shape of the inputs without its last dimension. No square
    root is taken.
    """
    delta = x - y
    squared = delta.pow(2)
    return squared.sum(dim=-1)

In [None]:
def search(index, xq, k, with_grad=False):
    """Return `(D, I)` for the `k` nearest neighbours of every row of `xq`.

    `I` always comes from `index.search(xq, k)`. When `with_grad` is false,
    `D` is the distance output of that same call. When `with_grad` is true,
    the index's distances are discarded and `D` is recomputed with `l2_dist`
    from `xq` and the rows of the notebook-level `xb` selected by `I`, i.e.
    by ordinary PyTorch operations on `xq`.

    Note: relies on the notebook globals `xb` and `l2_dist` when
    `with_grad` is true.
    """
    dists, ids = index.search(xq, k)
    if with_grad:
        # Insert a neighbour axis into xq so it broadcasts against xb[ids]
        # (presumably (nq, 1, d) vs (nq, k, d) - confirm against the index output).
        dists = l2_dist(xq.unsqueeze(1), xb[ids])
    return dists, ids

In [None]:
%%timeit
k = 4
with_grad = False
index = faiss.GpuIndexFlatL2(res, d)
index.add(xb)
D, I = search(index, xq, k, with_grad)
# D, I = index.search(xq, k)
# _, I = index.search(xq, k)
# D = l2_dist(xq.unsqueeze(1), xb[I])

In [None]:
# Single timed faiss search: nearest neighbour only, with distances recomputed
# through `search(..., with_grad=True)`.
k, with_grad = 1, True

# Index construction is kept outside the timed region.
index = faiss.GpuIndexFlatL2(res, d)
index.add(xb)

# CUDA events bracket the GPU work issued between the two record() calls.
start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

start.record()
D, I = search(index, xq, k, with_grad)
end.record()

# Both events must have completed before elapsed_time() can be read.
torch.cuda.synchronize()
print('time = %.2fms' % start.elapsed_time(end))

In [None]:
# GPU status after the faiss timing cells above.
!gpustat -cu

In [None]:
# Same GPU status check as the cell directly above.
# NOTE(review): duplicate cell, presumably a leftover - consider removing one.
!gpustat -cu

In [None]:
def calc_l2(x, y):
    """Pairwise squared L2 distances between the rows of `x` and the rows of `y`.

    Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y so that the
    bulk of the work is a single matrix product.

    Args:
        x: 2-D tensor of shape (nq, d).
        y: 2-D tensor of shape (nb, d).

    Returns:
        Tensor of shape (nq, nb); entry [i, j] is the squared distance between
        x[i] and y[j], clamped to be non-negative.
    """
    y_t = y.transpose(0, 1)                      # (d, nb)
    x_sq = x.pow(2).sum(dim=1, keepdim=True)     # (nq, 1)
    y_sq = y_t.pow(2).sum(dim=0, keepdim=True)   # (1, nb)
    cross = x @ y_t                              # (nq, nb)
    # Round-off in the expansion can push near-zero distances (e.g. identical
    # vectors) slightly below zero; a squared distance must not be negative.
    return (x_sq + y_sq - 2 * cross).clamp(min=0)

def calc_l2_(x, y_):
    """Pairwise squared L2 distances, with the second operand already transposed.

    Same computation as the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y,
    but `y_` is supplied column-wise so no transpose is performed here.

    Args:
        x: 2-D tensor of shape (nq, d).
        y_: 2-D tensor of shape (d, nb); column j is the j-th database vector.

    Returns:
        Tensor of shape (nq, nb); entry [i, j] is the squared distance between
        x[i] and y_[:, j], clamped to be non-negative.
    """
    x_sq = x.pow(2).sum(dim=1, keepdim=True)    # (nq, 1)
    y_sq = y_.pow(2).sum(dim=0, keepdim=True)   # (1, nb)
    cross = x @ y_                              # (nq, nb)
    # Round-off in the expansion can push near-zero distances (e.g. identical
    # vectors) slightly below zero; a squared distance must not be negative.
    return (x_sq + y_sq - 2 * cross).clamp(min=0)

In [None]:
# Brute-force PyTorch baseline: build the full query-by-database distance matrix
# with calc_l2 and take the row-wise minimum (nearest neighbour only).
with_grad = True

# CUDA events bracket the GPU work issued between the two record() calls.
start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

start.record()
with torch.set_grad_enabled(with_grad):
    dist = calc_l2(xq, xb)
    D, I = torch.min(dist, dim=1)
end.record()

# Both events must have completed before elapsed_time() can be read.
torch.cuda.synchronize()
print('time = %.2fms' % start.elapsed_time(end))

In [None]:
# Final GPU status, after the brute-force PyTorch distance computation above.
!gpustat -cu