In [95]:
import json
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
import cohere

import pandas as pd
import numpy as np
from tqdm import tqdm
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from loguru import logger
from utility import milvus_collection_exists, create_milvus_collection
from vector_db import MilvusCollection

from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings


In [128]:
FILE = "./context_id.json"
FILE_QUERY = "./data/context_qa_en.json"
COLLECTION_NAME = 'Cohere_embedding'  # Collection name
DIMENSION = 1536  # Embeddings size, cohere embeddings default to 4096 with the large model
COUNT = 25091  # How many questions to embed and insert into Milvus
BATCH_SIZE = 96 # How large of batches to use for embedding and insertion
MILVUS_HOST = 'localhost'  # Milvus server URI
MILVUS_PORT = '19530'
COHERE_API_KEY = 'Y0zL0EiA9HyasDgJxWatSJ0QtjQ14fGU5O6drzWU'  # API key obtained from Cohere
OPENAI_API_KEY="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"
embeddings_model = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)

In [60]:
co = cohere.Client(COHERE_API_KEY)
model="embed-english-v3.0"


In [141]:
###########QUERT PROCESSING############
with open(FILE_QUERY, 'r') as f:
    data_query = json.load(f)


query_list = []
for data in data_query:
    query_list.append(data['question'])
print(len(query_list))

context_id = []
for data in data_query:
    context_id.append(data['context_id'])
print(len(context_id))
    

14656
14656


In [134]:
##############CONTEXT PROCESSING##################
with open(FILE, 'r') as f:
    data = json.load(f)
    

    
# Recursively split json data - If you need to access/manipulate the smaller json chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 512,
    chunk_overlap  = 128,
    length_function = len
)

def preprocess(data, splitter):
    data_split = []
    sub_context_id = 0
    for i in range(len(data)):
        texts = splitter.split_text(data[i]['context'])
        for text in texts:
            dic = {}
            dic['id'] = sub_context_id
            dic['sub_id'] = i
            dic['context'] = text
            sub_context_id+=1
            data_split.append(dic)
        
    return pd.DataFrame(data_split)


data_processed = preprocess(data, text_splitter)
len(data_processed),data_processed[:10]

(25091,
    id  sub_id                                            context
 0   0       0  WikiLeaks () is an international non-profit or...
 1   1       1  The war in Europe concluded with an invasion o...
 2   2       1  atomic bombs on the Japanese cities of Hiroshi...
 3   3       1  were set up by fiat by the Allies and war crim...
 4   4       2  The exact number of Arab casualties is unknown...
 5   5       2  and 200 irregulars. According to Henry Laurens...
 6   6       2  death. According to Laurens, the largest part ...
 7   7       3  As Thomas Hall (2000) notes, "The Sung Empire ...
 8   8       3  between the tributary states and empires, and ...
 9   9       3  in the European subsystem." In Werner Sombart'...)

In [64]:
def embed_query(texts):
    q_embeddings = co.embed(texts=texts,
                model=model,
                input_type="search_query")
    return q_embeddings.embeddings

def embed_doc(texts):
    doc_embeddings = co.embed(texts=texts,
                model=model,
                input_type="search_document")
    return doc_embeddings.embeddings

In [100]:
# Connect to Milvus Database
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)
# Remove collection if it already exists
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)
    
# Create collection which includes the id, title, and embedding.
fields = [
    FieldSchema(name='id', dtype=DataType.INT64,is_primary=True),
    FieldSchema(name='sub_id', dtype=DataType.INT64),
    FieldSchema(name='context_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]

schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

# Create an IVF_FLAT index for collection.
index_params = {
    'metric_type':'IP',
    'index_type':"IVF_FLAT",
    'params':{"nlist": 1024}
}
collection.create_index(field_name="context_embedding", index_params=index_params)
collection.load()

In [101]:
for batch in tqdm(np.array_split(data_processed, (COUNT/BATCH_SIZE) + 1)):
    contexts = batch['context'].tolist()
    embeddings = embeddings_model.embed_documents(contexts)
    id = batch['id'].tolist()
    sub_id = batch['sub_id'].tolist()
    
    data = [
        {
            'id':id[i],
            'sub_id': sub_id[i],
            'context_embedding': embeddings[i]
        } for i in range(len(id))
    ]

    collection.insert(data=data)

  return bound(*args, **kwds)
  0%|          | 0/262 [00:00<?, ?it/s]

100%|██████████| 262/262 [10:04<00:00,  2.31s/it]  


In [137]:
eb_query = []
for batch in tqdm(np.array_split(query_list, (len(query_list)/256) + 1)):
    embeddings = embeddings_model.embed_documents(batch)
    eb_query.extend(embeddings)

  0%|          | 0/58 [00:00<?, ?it/s]

100%|██████████| 58/58 [02:06<00:00,  2.17s/it]


In [138]:
len(eb_query),len(eb_query[0])

(14656, 1536)

In [153]:
search_params = {
    "metric_type": "IP",
    "params": {
        # search for vectors with a distance greater than 0.8
        "radius": 0.7
    }
}
# print(np.array(eb_query[0]).shape)
results = collection.search([np.array(eb_query[0])], "context_embedding", search_params, 5, output_fields=["sub_id"])
results

['["id: 0, distance: 0.850684404373169, entity: {\'sub_id\': 0}", "id: 2676, distance: 0.8443490266799927, entity: {\'sub_id\': 1434}", "id: 4523, distance: 0.839410126209259, entity: {\'sub_id\': 2497}", "id: 6878, distance: 0.8368374109268188, entity: {\'sub_id\': 3763}", "id: 7960, distance: 0.8250666856765747, entity: {\'sub_id\': 4376}"]']

In [155]:
recall = 0.

for i in range(len(eb_query)):
    
    results = collection.search([np.array(eb_query[i])], "context_embedding", search_params, 15, output_fields=["sub_id"])
    score = {}
    
    for result in results:
        for hit in result:
            cont_id = hit.sub_id
            distance = hit.distance
            if cont_id not in score:
                score[cont_id] = distance 
                
    print(score)
    length = len(score)
    
    #相似度从大到小5个索引
    if length < 6:
        keys = list(score.keys())
    else:
        values = np.array(list(score.values()))
        args = np.argpartition(values, 5)[::-1][:5]
        keys = np.array(list(score.keys()))[args]
    
    print(keys)
    
    cont_id = context_id[i]
    print(f"docs retrieved for query {i}: {keys} (expected: {cont_id})")
    
    if cont_id in keys:
        recall = (i * recall + 1) / (i + 1)
    else:
        recall *= i / (i + 1)
    
    print("recall: ", recall)
    print('\n')
    
    
   


{0: 0.850684404373169, 1434: 0.8443490266799927, 2497: 0.839410126209259, 3763: 0.8368374109268188, 4376: 0.8250666856765747, 4652: 0.8223198652267456, 12342: 0.815341591835022, 63: 0.8127641677856445, 2277: 0.8107476830482483, 6308: 0.8052963018417358, 3905: 0.8039102554321289, 12522: 0.8034877777099609}
[   0 2497 3763 4376 1434]
docs retrieved for query 0: [   0 2497 3763 4376 1434] (expected: 0)
recall:  1.0


{6103: 0.8678044080734253, 6778: 0.8553253412246704, 6059: 0.8539287447929382, 6075: 0.8529354929924011, 1: 0.8440402150154114, 6052: 0.8419313430786133, 4979: 0.8417752981185913, 6374: 0.8395288586616516, 2387: 0.8358622789382935, 10043: 0.8350499868392944, 2770: 0.8337222933769226, 6268: 0.8296891450881958, 2414: 0.8281675577163696, 376: 0.8269612789154053}
[6103 6059 6075    1 6052]
docs retrieved for query 1: [6103 6059 6075    1 6052] (expected: 1)
recall:  1.0


{2: 0.9146761298179626, 11584: 0.8476214408874512, 10165: 0.8427757620811462, 6423: 0.839458703994751, 6608: 