# SIFTデータセットを使ったパフォーマンス検証


## データの準備
[SIFT](https://github.com/erikbern/ann-benchmarks/tree/main?tab=readme-ov-file)ベクトルデータセットを利用します。
各テストクエリについて、真の最近傍100件のID（`neighbors`）と距離（`distances`）がデータセットに含まれているので、検索結果の精度を測定できます。

In [None]:
# Download the SIFT-128 (Euclidean) benchmark file in HDF5 format from ANN-Benchmarks.
!wget http://ann-benchmarks.com/sift-128-euclidean.hdf5

In [30]:
import h5py
import pandas as pd
dataset_path = 'sift-128-euclidean.hdf5'

# Open the dataset read-only, print the path of every group/dataset it contains,
# then copy each dataset fully into memory (``[:]`` returns a NumPy array) so the
# arrays stay usable after the file is closed.
with h5py.File(dataset_path, 'r') as f:
    def print_hdf5_structure(name, obj):
        """``visititems`` callback: print the object's path (``obj`` is unused)."""
        print(name)

    f.visititems(print_hdf5_structure)
    distances, neighbors, train, test = (
        f[key][:] for key in ('distances', 'neighbors', 'train', 'test')
    )


distances
neighbors
test
train


In [31]:
# Convert the arrays to DataFrames with a 'vector' column (components as a
# Python list) and a string 'id' column. Train ids are the bare row numbers, so
# int(id) recovers the row used by the ground-truth neighbors; query ids get an
# "img" prefix.
train_df = pd.DataFrame({'vector': [list(row) for row in train]})
train_df['id'] = [str(n) for n in range(len(train_df))]

test_df = pd.DataFrame({'vector': [list(row) for row in test]})
test_df['id'] = ["img" + str(n) for n in range(len(test_df))]

In [32]:
# Ground truth as DataFrames: row n describes test query n (neighbor ids in
# neighbors_df, distances in distances_df). Presumably sorted nearest-first,
# since the recall code takes the first k columns as the true top-k.
distances_df, neighbors_df = pd.DataFrame(distances), pd.DataFrame(neighbors)

## 精度の計算
[Vertex AI ベクトル検索](https://cloud.google.com/vertex-ai/docs/vector-search/overview?hl=ja)で紹介されているRecallを精度指標とします。

このRecallは、検索結果の上位k件のうち、真の近傍（正解の上位k件）に含まれるものの割合です。

例えば、上位10件を取得するクエリを実行し、そのうち9件が正解の最近傍であった場合、Recallは`9/10 = 0.9`になります。

$$
\text{Recall} = \frac{\text{Number of relevant documents retrieved in top } k}{k}
$$

このRecallの値は0から1の範囲をとり、1に近いほど真の最近傍を多く検索できていることを示します。


In [33]:
def calculate_recall_byid(all_result_ids, k=10, neighbors=None):
    """Compute the mean Recall@k of search results against the ground truth.

    For each query, Recall@k = |true top-k ids ∩ retrieved top-k ids| / k.

    Args:
        all_result_ids: dict mapping a query's row position in the test set to
            the list of ids returned for it (as produced by ``perform_search``).
            Queries missing from the dict (e.g. because their search failed)
            are skipped instead of raising ``KeyError``.
        k: cut-off applied to both the ground-truth and the retrieved ids.
        neighbors: DataFrame of ground-truth neighbor ids, one row per query.
            Defaults to the notebook-global ``neighbors_df``.

    Returns:
        Mean recall over the evaluated queries, or ``nan`` if there are none.
    """
    import numpy as np

    if neighbors is None:
        neighbors = neighbors_df

    recalls = []
    # Iterate over the results actually present: perform_search omits queries
    # whose search raised, so the keys are not guaranteed to be 0..len-1.
    for n, result_ids in all_result_ids.items():
        # Positional lookup; matches test_df's default RangeIndex labels.
        true_ids = {int(res) for res in neighbors.iloc[n].tolist()[:k]}
        retrieved_ids = set(result_ids[:k])

        # Recall = size of the overlap of true and retrieved id sets, over k.
        recalls.append(len(true_ids & retrieved_ids) / float(k))

    num_samples = len(recalls)
    if num_samples == 0:
        print("No results to evaluate.")
        return float('nan')

    average_recall = np.mean(recalls)
    print(f"Average Recall over {num_samples} instances: {average_recall:.4f}")
    return average_recall

## 速度の計測
今回は簡易的な検証のため、専用のロードテストツールではなくJupyter Notebook上で以下のように検証を行います。クエリは1件ずつ逐次に発行するため、計測値は単一クライアントから見たレイテンシであり、並列負荷時のスループットではない点に注意してください。

In [71]:
import time
import pandas as pd
from tqdm import tqdm
import numpy as np

def perform_search(test_df, search_function, k=10, num_iterations=1, timeout=None):
    """Replay every query in ``test_df`` against ``search_function`` and report latency.

    Queries are issued one at a time (no concurrency), and only the
    ``search_function`` call itself is timed. For each iteration this prints the
    summed search time and the 90th/99th percentile latencies. At the end it
    prints the mean and population standard deviation of those figures across
    iterations.

    Args:
        test_df: DataFrame whose ``'vector'`` column holds the query vectors.
        search_function: callable ``(vector, timeout)`` returning a sequence of
            result ids.
        k: number of leading ids kept per query.
        num_iterations: number of passes over the whole query set.
        timeout: forwarded unchanged as the second argument of ``search_function``.

    Returns:
        dict mapping each ``test_df`` index label to the list of its first ``k``
        ids converted to ``int``. Later iterations overwrite earlier ones;
        queries whose search raised in every iteration are absent.
    """
    percentile_90_list = []
    percentile_99_list = []
    total_execution_times = []
    all_result_ids = {}
    num_queries = test_df.shape[0]

    for i in range(num_iterations):
        print(f"Iteration {i+1}")

        search_times = []
        num_failed = 0

        for index, row in tqdm(test_df.iterrows(), desc="Searching", total=num_queries):
            vector = row['vector']
            # perf_counter() is monotonic and high-resolution; time.time() is
            # wall-clock time and can jump (e.g. NTP adjustments), skewing latencies.
            start_time = time.perf_counter()
            try:
                response = search_function(vector, timeout)
            except Exception as e:
                # Keep the benchmark running; the failed query is excluded from
                # both the latency statistics and the returned results.
                print(f"Timeout or error occurred for index {index}: {e}")
                num_failed += 1
                continue
            search_times.append(time.perf_counter() - start_time)

            all_result_ids[index] = [int(res_id) for res_id in response[:k]]  # top-k ids

        if num_failed:
            print(f"{num_failed} of {num_queries} searches failed and were excluded.")
        if not search_times:
            # Percentiles of an empty sample are NaN and would poison the averages.
            print("No successful searches in this iteration; skipping its statistics.")
            continue

        times = pd.Series(search_times)
        percentile_90 = times.quantile(0.9) * 1000  # seconds -> milliseconds
        percentile_99 = times.quantile(0.99) * 1000  # seconds -> milliseconds
        # Sum of per-query latencies (sequential queries), not wall-clock time.
        total_time = sum(search_times)

        percentile_90_list.append(percentile_90)
        percentile_99_list.append(percentile_99)
        total_execution_times.append(total_time)

        print("Total execution time for searches: {:.3f} s".format(total_time))
        print("90th percentile of search times: {:.3f} ms".format(percentile_90))
        print("99th percentile of search times: {:.3f} ms".format(percentile_99))

    if not total_execution_times:
        print("No successful iterations; no summary statistics available.")
        return all_result_ids

    average_90 = np.mean(percentile_90_list)
    std_90 = np.std(percentile_90_list)

    average_99 = np.mean(percentile_99_list)
    std_99 = np.std(percentile_99_list)

    average_total_time = np.mean(total_execution_times)
    std_total_time = np.std(total_execution_times)

    print("Average total execution time: {:.3f} s (std: {:.3f})".format(average_total_time, std_total_time))
    print("Average 90th percentile of search times: {:.3f} ms (std: {:.3f})".format(average_90, std_90))
    print("Average 99th percentile of search times: {:.3f} ms (std: {:.3f})".format(average_99, std_99))

    return all_result_ids

# Vald

In [34]:
import grpc
from vald.v1.vald import insert_pb2_grpc
from vald.v1.vald import upsert_pb2_grpc
from vald.v1.vald import search_pb2_grpc
from vald.v1.vald import update_pb2_grpc
from vald.v1.vald import remove_pb2_grpc
from vald.v1.vald import object_pb2_grpc
from vald.v1.vald import index_pb2_grpc
from vald.v1.payload import payload_pb2
import time
from tqdm import tqdm

In [64]:
PORT_DEFAULT = ':8081'
host = "vald-lb-gateway.test-ns.svc.cluster.local"

# Channel options: raise the maximum gRPC metadata size to 32 KiB.
options = [('grpc.max_metadata_size', 32 * 1024)]

# Plain-text (no TLS) gRPC channel to "<host>:8081".
channel = grpc.insecure_channel(f"{host}{PORT_DEFAULT}", options=options)

# One stub per Vald service, all sharing the channel above.
insertStub = insert_pb2_grpc.InsertStub(channel)
upsertStub = upsert_pb2_grpc.UpsertStub(channel)
updateStub = update_pb2_grpc.UpdateStub(channel)
removeStub = remove_pb2_grpc.RemoveStub(channel)
objectStub = object_pb2_grpc.ObjectStub(channel)
searchStub = search_pb2_grpc.SearchStub(channel)
indexStub = index_pb2_grpc.IndexStub(channel)

# Per-request configs. Write operations set skip_strict_exist_check=True.
# Search returns 10 results; radius=-1 and epsilon=0.2 are passed as-is
# (see the Vald Search.Config documentation for their exact semantics).
insertConfig = payload_pb2.Insert.Config(skip_strict_exist_check=True)
updateConfig = payload_pb2.Update.Config(skip_strict_exist_check=True)
removeConfig = payload_pb2.Remove.Config(skip_strict_exist_check=True)
upsertConfig = payload_pb2.Upsert.Config(skip_strict_exist_check=True)
searchConfig = payload_pb2.Search.Config(num=10, radius=-1, epsilon=0.2)

## Insert

In [36]:
def generatorUpsertStream(df):
    """Build one Vald ``Upsert.Request`` per row of ``df``.

    Each request wraps ``Object.Vector(id=row['id'], vector=row['vector'])``
    together with the notebook-global ``upsertConfig``. Despite the name, the
    result is a fully materialized list (one entry per row, in row order),
    not a generator.
    """
    rows = tqdm(df.iterrows(), total=df.shape[0], desc="Preparing Upsert Data")
    return [
        payload_pb2.Upsert.Request(
            vector=payload_pb2.Object.Vector(id=row['id'], vector=row['vector']),
            config=upsertConfig,
        )
        for _, row in rows
    ]

def chunked_iterable(iterable, chunk_size):
    """Yield consecutive slices of at most ``chunk_size`` items.

    The argument must support ``len()`` and slicing (e.g. a list or string),
    not an arbitrary iterable. The last slice may be shorter than ``chunk_size``.
    """
    starts = range(0, len(iterable), chunk_size)
    yield from (iterable[lo:lo + chunk_size] for lo in starts)

# Materialize every Upsert request for the train set up front (a list, one per row).
vec = generatorUpsertStream(train_df)

Preparing Upsert Data: 100%|██████████| 1000000/1000000 [01:11<00:00, 14056.39it/s]


In [37]:
chunk_size = 20000

# Upsert the prepared requests in chunks, opening one StreamUpsert stream per
# chunk. Each streamed response carries either the stored location or an error
# status; count the errors instead of discarding them silently.
num_chunks = (len(vec) + chunk_size - 1) // chunk_size
upsert_errors = 0
for chunk in tqdm(chunked_iterable(vec, chunk_size), total=num_chunks, desc="Upsert Data"):
    for res in upsertStub.StreamUpsert(iter(chunk)):
        if res.HasField('status'):
            upsert_errors += 1
            if upsert_errors <= 10:  # print only the first few to avoid flooding output
                print(f"Upsert error: code={res.status.code} message={res.status.message}")

if upsert_errors:
    print(f"{upsert_errors} upsert request(s) failed.")

Upsert Data: 100%|██████████| 50/50 [05:42<00:00,  6.86s/it]


### Indexing完了確認

In [38]:
# Poll Vald until every vector is reported as stored on all replicas.
# NOTE: start_time is taken when this cell runs, i.e. after the upsert cell
# has already returned, so the printed duration is only the extra wait here,
# not the total indexing time.
index_size = 1000000
index_replica = 3  # presumably must match the Vald deployment's index_replica setting
check_interval = 5  # polling interval in seconds
max_wait = 60 * 60  # seconds; fail instead of polling forever

start_time = time.time()

while True:
    res = indexStub.IndexInfo(payload_pb2.Empty())
    current_count = res.stored

    print(f"Current index count: {current_count}")

    # The stored count is compared against index_size * replicas, i.e. each
    # vector is assumed to be counted once per replica.
    if current_count >= index_size * index_replica:
        end_time = time.time()
        print("Indexing completed in: {:.2f} seconds".format(end_time - start_time))
        break

    if time.time() - start_time > max_wait:
        raise TimeoutError(
            f"Stored count {current_count} did not reach "
            f"{index_size * index_replica} within {max_wait} s"
        )

    time.sleep(check_interval)

Current index count: 3000000
Indexing completed in: 0.01 seconds


## Search

In [72]:
def vald_search_function(vector, timeout=None):
    """Search Vald for ``vector`` and return the result ids as ints.

    Uses the notebook-global ``searchStub`` and ``searchConfig``. ``timeout``
    (seconds; ``None`` means no deadline) is forwarded to the gRPC call, so
    ``perform_search(..., timeout=...)`` is honoured the same way as for
    OpenSearch. Previously the argument was silently ignored.
    """
    response = searchStub.Search(
        payload_pb2.Search.Request(vector=vector, config=searchConfig),
        timeout=timeout,
    )
    return [int(result.id) for result in response.results]

# Vald
vald_all_result_ids = perform_search(test_df, vald_search_function, k=10, num_iterations=1)

Iteration 1


Searching: 100%|██████████| 10000/10000 [00:48<00:00, 207.01it/s]

Total execution time for searches: 46.998 s
90th percentile of search times: 8.540 ms
99th percentile of search times: 15.398 ms
Average total execution time: 46.998 s (std: 0.000)
Average 90th percentile of search times: 8.540 ms (std: 0.000)
Average 99th percentile of search times: 15.398 ms (std: 0.000)





In [73]:
# Vald: mean Recall@10 of the search results collected above.
average_recall = calculate_recall_byid(vald_all_result_ids)

Average Recall over 10000 instances: 0.9993


# Opensearch

In [41]:
from opensearchpy import OpenSearch
from opensearchpy.exceptions import ConnectionError
import time
import json

In [42]:
import urllib3
# Silence InsecureRequestWarning, which urllib3 emits for unverified HTTPS
# requests (the OpenSearch client is created with verify_certs=False).
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

In [43]:
import os

# OpenSearch admin password. Read from the environment so the secret is not
# typed into (and saved with) the notebook; falls back to '' when unset.
initial_admin_password = os.environ.get('OPENSEARCH_INITIAL_ADMIN_PASSWORD', '')

In [44]:
# Dimensionality of the train vectors (128 for SIFT, as shown in the output below).
vector_dim = len(train_df['vector'][0])
vector_dim

128

In [45]:
# OpenSearch client for the in-cluster service, authenticated as 'admin'.
# TLS is used, but the server certificate is NOT verified (verify_certs=False):
# acceptable only for a throwaway test cluster.
client = OpenSearch(
    hosts=[{'host': 'my-third-cluster.opensearch-3.svc.cluster.local', 'port': 9200}],
    http_auth=('admin', initial_admin_password),
    use_ssl=True,
    verify_certs=False
)



## Insert

In [48]:
def _send_bulk_with_retry(bulk_lines, max_retries, retry_delay, give_up_note):
    """Send one NDJSON ``_bulk`` request with retries; return True on full success.

    A request counts as failed if it raises OR if the response reports
    per-item errors (``"errors": true``), which comes back as a normal
    response rather than an exception.
    """
    body = "\n".join(bulk_lines) + "\n"  # the _bulk API requires a trailing newline
    for attempt in range(1, max_retries + 1):
        try:
            response = client.bulk(body=body)
        except Exception as e:
            error = e
        else:
            if not response.get('errors'):
                return True
            num_rejected = sum(
                1 for item in response.get('items', [])
                if any('error' in op for op in item.values())
            )
            error = f"{num_rejected} item(s) rejected by OpenSearch"

        print(f"Error indexing documents: {error}")
        if attempt < max_retries:
            print(f"Retrying... ({attempt}/{max_retries})")
            time.sleep(retry_delay)
        else:
            print(f"Max retries reached. {give_up_note}")
    return False


def index_documents_bulk(df, index_name, bulk_size, max_retries=3, retry_delay=5):
    """Bulk-index every row of ``df`` into ``index_name`` via the global ``client``.

    Each row becomes a document ``{'vec_id': id, 'vector': [floats]}`` whose
    ``_id`` is the row's ``'id'``. Documents are sent in ``_bulk`` requests of
    ``bulk_size`` documents. Every action carries an explicit ``_id``, so
    resending a whole batch on retry is idempotent.

    Args:
        df: DataFrame with ``'id'`` and ``'vector'`` columns.
        index_name: target index (must already exist).
        bulk_size: number of documents per ``_bulk`` request.
        max_retries: attempts per batch before the batch is given up.
        retry_delay: seconds to sleep between attempts.
    """
    bulk_data = []
    failed_batches = 0

    for _, row in tqdm(df.iterrows(), total=df.shape[0], desc="Indexing Documents"):
        action = {
            "index": {
                "_index": index_name,
                "_id": row['id']
            }
        }
        document = {
            'vec_id': row['id'],
            # Plain floats so json.dumps can serialize NumPy scalar components.
            'vector': [float(v) for v in row['vector']]
        }

        # NDJSON: one action line followed by one document line per document.
        bulk_data.append(json.dumps(action))
        bulk_data.append(json.dumps(document))

        if len(bulk_data) >= 2 * bulk_size:
            if not _send_bulk_with_retry(bulk_data, max_retries, retry_delay,
                                         "Skipping this batch."):
                failed_batches += 1
            bulk_data = []  # clear the buffer whether or not the batch succeeded

    # Send whatever is left over from the last partial batch.
    if bulk_data:
        if not _send_bulk_with_retry(bulk_data, max_retries, retry_delay,
                                     "Skipping final batch."):
            failed_batches += 1

    if failed_batches:
        print(f"Indexing finished, but {failed_batches} batch(es) could not be indexed.")
    else:
        print("All documents have been indexed.")


In [56]:
index_name = 'sift-nmslib-l2-24' # the index must be created beforehand (the knn query later targets its 'vector' field)
index_documents_bulk(train_df, index_name, 200)

Indexing Documents: 100%|██████████| 1000000/1000000 [03:35<00:00, 4629.99it/s]

All documents have been indexed.





### Indexing完了確認

In [52]:
index_count = 1000000
check_interval = 5  # polling interval in seconds

def get_document_count(index_name):
    """Return the primary-shard document count of ``index_name``.

    A ``ConnectionError`` (the opensearchpy one imported above) is reported
    and turned into ``None`` so the caller can simply poll again; any other
    error propagates.
    """
    try:
        stats = client.indices.stats(index=index_name)
    except ConnectionError as e:
        print(f"Connection error: {e}")
        return None
    return stats['_all']['primaries']['docs']['count']


# Poll until the primary-shard document count reaches index_count.
# NOTE: start_time is taken when this cell runs (presumably after the bulk
# load has returned), so the printed time is only the extra wait here, not the
# total indexing time.
start_time = time.time()
max_wait = 60 * 60  # seconds; fail instead of polling forever

while True:
    doc_count = get_document_count(index_name)

    if doc_count is not None:
        print(f"Current document count: {doc_count}")

        if doc_count >= index_count:
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Indexing completed in {elapsed_time:.2f} seconds.")
            break

    if time.time() - start_time > max_wait:
        raise TimeoutError(
            f"Document count did not reach {index_count} within {max_wait} s"
        )

    time.sleep(check_interval)

Current document count: 1000000
Indexing completed in 0.00 seconds.


## Search

In [67]:
def search(vector, index_name, timeout, k=10, ef_search=100):
    """Run an approximate k-NN query against ``index_name`` and return the raw response.

    Args:
        vector: query vector for the ``vector`` knn field.
        index_name: index to query.
        timeout: forwarded to ``client.search`` as ``request_timeout``.
        k: neighbors requested from the k-NN query. Also used as the result
            ``size`` so the two values cannot drift apart.
        ef_search: passed through ``method_parameters`` (HNSW search-time
            candidate list size; larger usually means higher recall, slower queries).

    Returns:
        The response dict from ``client.search``; hits are in ``['hits']['hits']``.
    """
    search_query = {
        "size": k,
        "query": {
            "knn": {
                "vector": {
                    "vector": vector,
                    "k": k,
                    "method_parameters": {
                        "ef_search": ef_search
                    }
                }
            }
        }
    }

    return client.search(
        index=index_name,
        body=search_query,
        request_timeout=timeout
    )

In [75]:
def open_search_function(vector, timeout):
    """Query the notebook-global ``index_name`` and return the hit ``_id``s as ints."""
    hits = search(vector, index_name, timeout)['hits']['hits']
    return [int(hit['_id']) for hit in hits]

opensearch_all_result_ids = perform_search(test_df, open_search_function, k=10, num_iterations=1, timeout=10)

Iteration 1


Searching: 100%|██████████| 10000/10000 [02:56<00:00, 56.67it/s]

Total execution time for searches: 173.882 s
90th percentile of search times: 19.394 ms
99th percentile of search times: 24.289 ms
Average total execution time: 173.882 s (std: 0.000)
Average 90th percentile of search times: 19.394 ms (std: 0.000)
Average 99th percentile of search times: 24.289 ms (std: 0.000)





In [76]:
# OpenSearch: mean Recall@10 of the search results collected above.
average_recall = calculate_recall_byid(opensearch_all_result_ids)

Average Recall over 10000 instances: 0.9989


## 謝辞
データセットは以下のライセンスに基づき使用させていただきました。

http://corpus-texmex.irisa.fr/

データセットを公開いただきましたLaurent Amsaleg様とHervé Jégou様、ANN Benchmarksを公開いただきましたErik Bernhardsson様に感謝を申し上げます。