In [1]:
import os
import json
import uuid
from transformers import AutoTokenizer, AutoModel
from milvus import default_server
from pymilvus import (
    connections, utility, Collection,
    CollectionSchema, FieldSchema, DataType
)
import torch

In [3]:
# Start the embedded Milvus server from the `milvus` package (stopped in the final cell).
# NOTE(review): presumably this is the server `testing` reaches at localhost:19530 — confirm the port matches.
default_server.start()

In [10]:
class testing:
    """Retrieval harness: embeds a query with a Hugging Face encoder and searches
    an existing Milvus collection for the closest stored texts.

    Args:
        model_name: Hugging Face model id used for both the tokenizer and the encoder.
        collection_name: Name of an existing Milvus collection; it is loaded on init.
        milvus_uri: Server address in ``scheme://host:port`` form.
    """

    def __init__(self,
                 model_name: str = 'sentence-transformers/all-mpnet-base-v2',
                 collection_name: str = "chat_demo",
                 milvus_uri: str = 'http://localhost:19530') -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # embedding() moves the inputs to self.device, so the weights must live
        # there too — otherwise the forward pass fails on a CUDA machine.
        self.model.to(self.device)
        self.model.eval()

        self.collection = None
        self.collection_name = collection_name
        self.MILVUS_URI = milvus_uri
        # "http://host:port" -> host and port (both kept as strings).
        [self.MILVUS_HOST, self.MILVUS_PORT] = self.MILVUS_URI.split('://')[1].split(':')
        connections.connect(host=self.MILVUS_HOST, port=self.MILVUS_PORT)
        self.collection = Collection(name=self.collection_name)
        # Load the collection so it can be searched.
        self.collection.load()

    def embedding(self, text_data):
        """Return the L2-normalized, mask-aware mean-pooled embedding of the first input.

        Args:
            text_data: A string or list of strings accepted by the tokenizer.

        Returns:
            1-D tensor on ``self.device``: the embedding of the first (or only) input.
        """
        encoded = self.tokenizer(text_data, return_tensors='pt', padding=True, truncation=True)
        inputs = {key: tensor.to(self.device) for key, tensor in encoded.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)

        token_embeddings = outputs.last_hidden_state  # (batch, seq, hidden)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            pooled = token_embeddings.mean(dim=1)
        else:
            # Average only over real tokens; padding positions would otherwise
            # pull the embedding of shorter inputs in a batch toward pad states.
            mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
            pooled = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        return torch.nn.functional.normalize(pooled[0], p=2, dim=-1)

    def search(self, query, limit=3, nprobe=512):
        """Embed ``query`` and return the top ``limit`` hits from the collection.

        Args:
            query: Text to embed and search for.
            limit: Number of hits to return (default 3).
            nprobe: Number of index clusters probed at search time (default 512).

        Returns:
            dict with parallel lists: ``"id"`` (hit primary keys), ``"dist"``
            (inner-product scores, metric ``IP``) and ``"text"`` (stored ``text`` field).
        """
        query_vec = self.embedding(query).cpu().numpy().tolist()

        res = self.collection.search(
            data=[query_vec],
            anns_field="embedding",
            # nlist is fixed when the index is built; only nprobe applies at query time.
            param={'metric_type': 'IP', 'params': {'nprobe': nprobe}},
            limit=limit,
            output_fields=["text"],
        )

        data = {"id": [], "dist": [], "text": []}
        for hits in res:
            for hit in hits:
                data["id"].append(hit.id)
                data["dist"].append(hit.distance)
                data["text"].append(hit.entity.get("text"))
        return data

In [11]:
# Build the retrieval harness: loads the tokenizer/model and connects to the Milvus collection.
tester=testing()

In [12]:
# Load the evaluation cases: each entry is read below via its 'query' and 'ans' keys.
# os.path.join keeps the path portable (a literal backslash only works on Windows).
with open(os.path.join('testcases', 'two variable testcases.json'), 'r', encoding='utf-8') as fh:
    test_data = json.load(fh)

In [13]:
# Sanity check: how many evaluation cases were loaded.
num_test_cases = len(test_data)
print(num_test_cases)

1131


In [15]:
# Score retrieval over every test case.
#   p_count    - retrieved ids found in the case's 'ans'
#   n_count    - retrieved ids not found in 'ans'
#   crct_count - cases where at least one retrieved id is in 'ans'
# NOTE(review): assumes 'ans' holds the expected hit ids (same type as Milvus ids) — confirm.
p_count = 0
crct_count = 0
n_count = 0
empty_count = 0  # cases for which the search returned no hits at all
for test_case in test_data:
    test_res = tester.search(test_case['query'])
    retrieved_ids = test_res['id']
    if not retrieved_ids:
        empty_count += 1
        continue
    matched = [i for i in retrieved_ids if i in test_case['ans']]
    p_count += len(matched)
    n_count += len(retrieved_ids) - len(matched)
    if matched:
        crct_count += 1
        # One line per correctly answered case, not one per matching id.
        print(test_case['query'], test_res['text'])

# All-zero counters are otherwise indistinguishable from "nothing was retrieved".
if empty_count:
    print(f"WARNING: {empty_count} of {len(test_data)} queries returned no hits; "
          "check that the collection is populated and loaded.")
            

In [16]:
# Report the evaluation counters with labels so the numbers are interpretable on their own.
print(f"relevant ids retrieved: {p_count}, "
      f"irrelevant ids retrieved: {n_count}, "
      f"cases with a relevant hit: {crct_count}/{len(test_data)}")

0 0 0


In [10]:
# Shut down the embedded Milvus server started at the top of the notebook.
# NOTE(review): this cell's execution count is out of order (In [10] after In [16]),
# so its recorded run is from an earlier kernel session — re-run it last.
default_server.stop()