In [3]:
from setting import METRIC_TYPE
from vector_db import MilvusCollection
from relational_db import SQLiteDB
from towhee import AutoConfig, AutoPipes
from json import load
from numpy import array, argpartition


collection = MilvusCollection('BART_keywords')
partition = 'en'
sqlite_db = SQLiteDB('context')

config = AutoConfig.load_config('sentence_embedding')
config.model = 'average_word_embeddings_glove.6B.300d'
sentence_embedding = AutoPipes.pipeline('sentence_embedding', config=config)

In [4]:
with open('../data/Query_BART_Keywords.json', 'r') as file:
    queries = load(file)

search_params = {
    "metric_type": METRIC_TYPE,
    "params": {
        # search for vectors with a distance greater than 0.8
        "radius": 0.8
    }
}

recall = 0.

for i, query in enumerate(queries):
    results = collection.search([embedding.get()[0] for embedding in sentence_embedding.batch(query['keywords'])], "embedding", search_params, 5, partition_names=[partition], output_fields=["context_ids"])
    
    score = {}
    
    for result in results:
        for hit in result:
            for context_id in hit.context_ids:
                distance = hit.distance
                if distance in score:
                    score[context_id] += distance
                else:
                    score[context_id] = distance
    
    length = len(score)
    
    if length < 6:
        keys = list(score.keys())
    else:
        values = array(list(score.values()))
        
        args = argpartition(values, 5)[:5]
        keys = array(list(score.keys()))[args]
    
    context_id = query['context_id']
    print(f"docs retrieved for query {i}: {keys} (expected: {context_id})")
    
    if context_id in keys:
        recall = (i * recall + 1) / (i + 1)
    else:
        recall *= i / (i + 1)
    
    print("recall: ", recall)
    print('\n')

# for result in results:
#     print("distance: ", result.distance)
#     entity = result.entity
#     print("found keyword: ", entity.keyword)
#     for i, context_id in enumerate(entity.context_ids):
#         print("==================================================================")
#         print("corresponding context: ", i + 1, ". ", sqlite_db.select(['context'], 'en', f'id = {context_id}')[0][0])
#     print('\n')

docs retrieved for query 0: [0, 1434] (expected: 0)
recall:  1.0

docs retrieved for query 1: [6470 9748 6541 6462 6746] (expected: 1)
recall:  0.5

docs retrieved for query 2: [13723 10634 13514 13362 10733] (expected: 2)
recall:  0.3333333333333333

docs retrieved for query 3: [] (expected: 3)
recall:  0.25

docs retrieved for query 4: [6470 9748 6541 6462 6746] (expected: 4)
recall:  0.2

docs retrieved for query 5: [ 2593   522     6 11327  6209] (expected: 5)
recall:  0.16666666666666669

docs retrieved for query 6: [ 2593   522     6 11327  6209] (expected: 6)
recall:  0.2857142857142857

docs retrieved for query 7: [ 4492  8038  6399 10625  3111] (expected: 7)
recall:  0.25

docs retrieved for query 8: [13723 10634 13514 13362 10733] (expected: 2)
recall:  0.2222222222222222

docs retrieved for query 9: [4376 2497 5251    8 2842] (expected: 8)
recall:  0.3

docs retrieved for query 10: [10426, 1541, 11425, 11785, 11362] (expected: 9)
recall:  0.2727272727272727

docs retrieved f

KeyboardInterrupt: 