In [None]:
pip install -q pymilvus towhee==1.1.3 gradio

In [1]:
import pandas as pd

df = pd.read_csv('question_answer.csv')
df.head()

Unnamed: 0,id,question,answer
0,0,Is Disability Insurance Required By Law?,Not generally. There are five states that requ...
1,1,Can Creditors Take Life Insurance After ...,If the person who passed away was the one with...
2,2,Does Travelers Insurance Have Renters Ins...,One of the insurance carriers I represent is T...
3,3,Can I Drive A New Car Home Without Ins...,Most auto dealers will not let you drive the c...
4,4,Is The Cash Surrender Value Of Life Ins...,Cash surrender value comes only with Whole Lif...


In [2]:
id_answer = df.set_index('id')['answer'].to_dict()

In [None]:
pip install wget

In [None]:
pip install --upgrade milvus

In [None]:
pip install pymilvus==2.4

In [30]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility,MilvusClient

connections.connect("default", host="127.0.0.1", port="19530")
client=MilvusClient(uri="http://localhost:19530")
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
    FieldSchema(name='id', dtype=DataType.VARCHAR, descrition='ids', max_length=500, is_primary=True, auto_id=False),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2',
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    
    return collection

collection = create_milvus_collection('question_answer', 768)
client.load_collection(
    collection_name="question_answer",
    replica_number=1 # Number of replicas to create on query nodes. Max value is 1 for Milvus Standalone, and no greater than `queryNode.replicas` for Milvus Cluster.
)      
#print(collection)

2024-05-05 00:27:09,195 - 4543446528 - milvus_client.py-milvus_client:641 - DEBUG: Created new connection using: 1c8c5bcd254341b09f2c3c869ae8ae16


In [33]:
%%time
from towhee import pipe, ops
import numpy as np
from towhee.datacollection import DataCollection

insert_pipe = (
    pipe.input('id', 'question', 'answer')
        .map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer'))
        .output()
)

import csv
with open('question_answer.csv', encoding='utf-8') as f:
    reader = csv.reader(f)
    next(reader)
    allRows=[]
    for row in reader:
        allRows.append(row)
    res=insert_pipe.batch(allRows)


CPU times: user 2min 57s, sys: 2min 25s, total: 5min 22s
Wall time: 38.5 s


In [38]:
print('Total number of inserted data is {}.'.format(collection.num_entities))
print('Total number of inserted data is {}.'.format(len(res)))

Total number of inserted data is 0.
Total number of inserted data is 1000.


In [39]:
%%time
collection.load()
ans_pipe = (
    pipe.input('question')
        .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map('vec', 'res', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer', limit=1))
        .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
        .output('question', 'answer')
)


ans = ans_pipe('Is  Disability  Insurance  Required  By  Law?')
ans = DataCollection(ans)
ans.show()

question,answer
Is Disability Insurance Required By Law?,Not generally. There are five states that require most all employers carry short term disability insurance on their employees. T...


CPU times: user 1.83 s, sys: 902 ms, total: 2.73 s
Wall time: 1.62 s


In [40]:
ans[0]['answer']

['Not generally. There are five states that require most all employers carry short term disability insurance on their employees. These states are: California, Hawaii, New Jersey, New York, and Rhode Island. Besides this mandatory short term disability law, there is no other legislative imperative for someone to purchase or be covered by disability insurance.']

In [42]:
def chat(message, history):
    history = history or []
    ans_pipe = (
        pipe.input('question')
            .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
            .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
            .map('vec', 'res', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='question_answer', limit=1))
            .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
            .output('question', 'answer')
    )

    response = ans_pipe(message).get()[1][0]
    history.append((message, response))
    return history, history

In [55]:
import gradio as gr

collection.load()
chatbot = gr.Chatbot()
interface = gradio.Interface(
    chat,
    ["text", "state"],
    [chatbot, "state"],
    allow_flagging="never",
)
interface.launch(inline=True, share=True)

Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://c55fab2513d699eacd.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


