In [1]:
import tiledb
from tiledb.cloud import client
import tiledb.vector_search as vs
from tiledb.vector_search.utils import *

import numpy as np
import random
import sklearn
import string

In [2]:
!cd /tmp && wget https://github.com/TileDB-Inc/TileDB-Vector-Search/releases/download/0.0.1/siftsmall.tgz
!cd /tmp && tar xf siftsmall.tgz

--2023-10-06 11:03:33--  https://github.com/TileDB-Inc/TileDB-Vector-Search/releases/download/0.0.1/siftsmall.tgz
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/627523373/b1990696-797c-4876-86c9-24cb101f7922?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20231006%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20231006T110333Z&X-Amz-Expires=300&X-Amz-Signature=7be26420dc408c0519e72dbc3ced4d62439d4fbefd61b40ccba35a28cb3422fa&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=627523373&response-content-disposition=attachment%3B%20filename%3Dsiftsmall.tgz&response-content-type=application%2Foctet-stream [following]
--2023-10-06 11:03:33--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/627523373/b1990696-797c-4876-86c9-24cb101f792

In [3]:
def delete_if_exists(uri):
    """Delete the TileDB group at ``uri``, doing nothing when none exists.

    The group is opened with ``tiledb.Group(uri, "m")``. A ``tiledb.TileDBError``
    whose message contains "group does not exist" is treated as "nothing to
    delete"; any other ``TileDBError`` raised while opening is re-raised.

    Args:
        uri: Location of the group to remove.

    Returns:
        None.

    Raises:
        tiledb.TileDBError: if opening fails for any reason other than the
            group being absent. Errors from ``delete()`` are not caught here.
    """
    try:
        existing_group = tiledb.Group(uri, "m")
    except tiledb.TileDBError as open_error:
        # Only the "missing group" failure is expected; surface everything else.
        if "group does not exist" not in str(open_error):
            raise open_error
    else:
        existing_group.delete()


In [4]:
# Local file holding the base vectors that get ingested below.
source_uri = "/tmp/siftsmall_base.fvecs"

# The index URIs are built from the logged-in user's name plus a random
# 10-letter suffix generated for this run.
namespace = client.default_user().username
random_suffix = "".join(random.choices(string.ascii_letters, k=10))

# Staging notebook: use these fixed URIs instead of the randomized ones below.
# index_uri = f"tiledb://TileDB-Inc/s3://tiledb-unittest/groups/unit-tests/vector_search/{namespace}/sift10k_flat"
# ivf_index_uri = f"tiledb://TileDB-Inc/s3://tiledb-unittest/groups/unit-tests/vector_search/{namespace}/sift10k_ivf_flat"

# Local tests: one URI per index type, each ending in the random suffix.
index_uri = f"tiledb://{namespace}/s3://tiledb-unittest/groups/unit-tests/vector_search/{namespace}/sift10k_flat_{random_suffix}"
ivf_index_uri = f"tiledb://{namespace}/s3://tiledb-unittest/groups/unit-tests/vector_search/{namespace}/sift10k_ivf_flat_{random_suffix}"

# Remove any group already present at either URI before ingesting.
delete_if_exists(index_uri)
delete_if_exists(ivf_index_uri)

In [5]:
# Ingest the source vectors into an index of type FLAT stored at index_uri.
flat_index = vs.ingest(
    index_type="FLAT",
    source_uri=source_uri,
    index_uri=index_uri,
)

In [6]:
# Ingest the same source vectors into a second index, of type IVF_FLAT,
# stored at ivf_index_uri.
ivf_flat_index = vs.ingest(
    index_type="IVF_FLAT",
    source_uri=source_uri,
    index_uri=ivf_index_uri,
)

In [7]:
# Load the siftsmall query vectors together with their ground-truth file.
# load_fvecs / load_ivecs are not defined in this notebook; presumably they come
# from the wildcard import of tiledb.vector_search.utils — TODO confirm.
# ground_truth is indexed row by row alongside the query results in accuracy().
query_vectors = load_fvecs("/tmp/siftsmall_query.fvecs")
ground_truth = load_ivecs("/tmp/siftsmall_groundtruth.ivecs")

In [8]:
def accuracy(result, gt):
    """Return the fraction of returned ids that also appear in the ground truth.

    Each row ``result[i]`` is intersected with ``gt[i]`` via
    ``np.intersect1d`` (so an id repeated within a row is counted once). The
    matches summed over all rows are divided by the total number of returned
    ids.

    Args:
        result: Iterable of per-query id rows, e.g. the id array returned by
            an index query.
        gt: Ground-truth id rows, indexed in step with ``result``.

    Returns:
        float: ``matches / returned``, between 0.0 and 1.0. Returns 0.0 when
        ``result`` contains no ids at all, instead of dividing by zero.

    Raises:
        IndexError: if ``gt`` is a list or array with fewer rows than
            ``result``.
    """
    found = 0
    total = 0
    for i, row in enumerate(result):
        total += len(row)
        # The whole gt row is used: if it lists more neighbours than were
        # requested from the index, a match anywhere in that row counts.
        found += len(np.intersect1d(row, gt[i]))
    if total == 0:
        # Nothing was returned, so there is nothing to score.
        return 0.0
    return found / total

In [9]:
# FLAT query: request the 100 most similar vectors for every query vector.
result_d, result_i = flat_index.query(
    query_vectors,
    k=100,
)
ac = accuracy(result_i, ground_truth)
print(f"Accuracy: {ac}")
# The FLAT results are required to match the ground truth exactly.
assert ac == 1.0



Accuracy: 1.0


In [10]:
# IVF_FLAT query: request the 100 most similar vectors for every query vector.
# nprobe is tunable; this run uses 10.
result_ivf_d, result_ivf_i = ivf_flat_index.query(
    query_vectors,
    nprobe=10,
    k=100,
)
ac = accuracy(result_ivf_i, ground_truth)
print(f"Accuracy: {ac}")
# Pass only if at least 85% of the returned ids are in the ground truth.
assert ac >= 0.85

Accuracy: 0.9204


In [None]:
# Distributed query test: the same IVF_FLAT search, run with
# mode=BATCH and the work split into 2 partitions.
result_ivf_d, result_ivf_i = ivf_flat_index.query(
    query_vectors,
    nprobe=10,
    k=100,
    mode=tiledb.cloud.dag.Mode.BATCH,
    num_partitions=2,
)
ac = accuracy(result_ivf_i, ground_truth)
print(f"Accuracy: {ac}")
# Same 85% threshold as the non-distributed IVF_FLAT query.
assert ac >= 0.85

In [None]:
# Clean up: delete the groups at both index URIs used by this notebook.
delete_if_exists(index_uri)
delete_if_exists(ivf_index_uri)