In [1]:
import os
import json
import uuid
from transformers import AutoTokenizer, AutoModel
from milvus import default_server
from pymilvus import (
    connections, utility, Collection,
    CollectionSchema, FieldSchema, DataType
)
import torch

In [2]:
class VectorSearchWrapper:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")  # Use GPU if CUDA is available

        else:
            self.device = torch.device("cpu")
        self.EMBED_MODEL = 'sentence-transformers/all-mpnet-base-v2'
        self.DIM = 768
        self.json_file_path='combined files\\cleaned_and_combined_hyd.json'
        self.collection = None
        self.collection_name = "chat_demo"
        self.MILVUS_URI = 'http://localhost:19530'
        [self.MILVUS_HOST, self.MILVUS_PORT] = self.MILVUS_URI.split('://')[1].split(':')
        self.result=""
        self.embeddings=dict()
        self.history=list()
        self.run()
    def initialize_model(self):
        # Load the pre-trained model and tokenizer
        model_name = self.EMBED_MODEL
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
    def embedding(self,text_data):

        inputs = self.tokenizer(text_data, return_tensors='pt', padding=True, truncation=True)
        if self.device.type == 'cuda':
            inputs = {key: tensor.cuda() for key, tensor in inputs.items()}  # Move tensors to CUDA
        if self.device.type == 'xla':
            inputs = {key: tensor.to(self.device) for key, tensor in inputs.items()}# Move tensors to TPU
        with torch.no_grad():
            # Forward pass through the model
            outputs = self.model(**inputs)


        embeddings = outputs.last_hidden_state.mean(dim=1)  # Assuming you want to use mean pooling

        # Normalize the embeddings if needed
        embeddings=embeddings[0]

        normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)

        return normalized_embeddings
    
    def preprocess_and_embed(self):
        embedded_list = {}
        with open(self.json_file_path, 'r') as file:
            json_data = json.load(file)

        # Preprocess and embed each JSON entry
        for entry in json_data:
            curr_dict = {}
            address=entry['Location']
            terms = [term.strip() for term in address.split(',')]
            if len(terms)>4:
                replacable=', '.join(terms[-4:])
            else:
                replacable=address
            
            entry['text data']=entry['text data'].replace(address,replacable)
            text_data = entry["text data"].replace(f'Name: {entry["Name"]}','')
            entry_id = entry.get("Id", None)  # Extract the ID from JSON (assuming it has an "id" field)
            
            if text_data not in embedded_list:  # Avoid duplicate embeddings
                embedding = self.embedding(text_data)
                embedded_list[entry_id] = {
                    "embedding": embedding,
                    "text": text_data
                }

        return embedded_list
    def create_collection(self,collection_name):
        connections.connect(host=self.MILVUS_HOST, port=self.MILVUS_PORT)

        has_collection = utility.has_collection(collection_name)

        if has_collection:
            utility.drop_collection(collection_name)
        # Create collection
        fields = [
            FieldSchema(name='id', dtype=DataType.INT64, is_primary=True,auto_id=False),
            FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=self.DIM),
            FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=1000)
        ]
        schema = CollectionSchema(
            fields=fields,
            description="Towhee demo",
            enable_dynamic_field=True
        )
        collection = Collection(name=collection_name, schema=schema)

        # Change index here if you want to accelerate search
        index_params = {
            'metric_type': 'L2',
            'index_type': 'IVF_FLAT',
            'params': {'nlist': 8192,'nprobe':512}
        }
        collection.create_index(
            field_name='embedding',
            index_params=index_params
        )

        return collection

    def prepare_data(self):
        data = []
        embedded_data = self.embeddings

        for entry_id, entry in embedded_data.items():
            entity = {
                'id': entry_id,  # Use the provided ID
                'embedding': entry['embedding'],
                'text': entry['text']
            }
            data.append(entity)

        return data
    def insert_data(self,collection_name):
        self.collection = self.create_collection(collection_name)

        # Prepare data
        data_to_insert = self.prepare_data()
        # Insert data into the collection
        self.collection.insert(data_to_insert)
    def run(self,collection_name=None):
        if collection_name is None:
            collection_name = self.collection_name

        # Initialize model asynchronously
        self.initialize_model()
        

        # Preprocess and embed data asynchronously
        self.embeddings =self.preprocess_and_embed()

        # Insert data asynchronously
        self.insert_data(collection_name)
        
        
    def search_milvus(self,query):
        embedded_vec=self.embedding(query).cpu().numpy()
        collection=Collection(name=self.collection_name)
        collection.load()
        res=collection.search(
            data=[embedded_vec],
            anns_field="embedding",
            param={
            'metric_type': 'IP',
            'params': {'nlist': 4096}
                    },
            limit=3,

            output_fields=["text"]   )

        text_li=list()
        id_li=list()
        dist_li=list()
        for i, hits in enumerate(res):

            for hit in hits:
                id_li.append(hit.entity.id)
                dist_li.append(hit.entity.distance)
                text_li.append(hit.entity.get("text"))
        data=dict()
        data["id"]=id_li
        data["dist"]=dist_li
        data["text"]=text_li
        return data


In [3]:
default_server.start()

In [4]:
chatbot=VectorSearchWrapper()

In [5]:
data=chatbot.search_milvus(query="CBSE schools")

RPC error: [search], <MilvusException: (code=0, message=fail to search on all shard leaders, err=All attempts results:
attempt #1:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #2:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #3:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #4:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #5:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parame

MilvusException: <MilvusException: (code=0, message=fail to search on all shard leaders, err=All attempts results:
attempt #1:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #2:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #3:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #4:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
attempt #5:code: UnexpectedError, error: fail to Search, QueryNode ID=93, reason=collection:452604578851989266, metric type not match: expected=L2, actual=IP: invalid parameter
)>

In [6]:
print(data)

{'id': [280, 701, 948], 'dist': [0.7203517556190491, 0.7094092965126038, 0.708120584487915], 'text': ['Name: ~Category: Public Schools~Location: Survey No .1, Seetharampuram, Beside Kalyani Theatre, Old Bowenpally, Hasmathpet, Secunderabad - 500011~Faculty: ~Sports: Athletics, Carroms, Chess, Karate, Skating, Yoga~Amenities: Transport, Laboratory, Smart Classrooms, Computers Facility, Library~Board: CBSE~Years: -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10~Fee: -1~Since: Not Available~Strength: Not Available~', 'Name: ~Category: Public Schools~Location: Sri Sai Baba Temple Rd, M.I.G.H colony, MIGH Colony, Walker Town, Bhoiguda, Secunderabad, Telangana 500025~Faculty: ~Sports: ~Amenities: Transport~Board: CBSE~Years: -1, 0, 1, 2, 3, 4, 5, 6, 7, 8~Fee: -1~Since: 2020~Strength: Not Available~', 'Name: ~Category: Other~Location: Dilsukhnagar Public School, Badangpet, Telangana 500058, India~Faculty: ~Sports: ~Amenities: Transport~Board: CBSE, State Board~Years: 0~Fee: 34000~Since: Not Available~Str

In [7]:
default_server.stop()
