In [None]:
!pip install -q pymilvus towhee gradio

In [2]:
!curl -L https://github.com/towhee-io/examples/releases/download/data/question_answer.csv -O

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  595k  100  595k    0     0  1613k      0 --:--:-- --:--:-- --:--:-- 1613k


In [3]:
import pandas as pd

# Load the question/answer corpus and preview its first rows.
df = pd.read_csv(filepath_or_buffer='question_answer.csv')
df.head(n=5)

Unnamed: 0,id,question,answer
0,0,Is Disability Insurance Required By Law?,Not generally. There are five states that requ...
1,1,Can Creditors Take Life Insurance After ...,If the person who passed away was the one with...
2,2,Does Travelers Insurance Have Renters Ins...,One of the insurance carriers I represent is T...
3,3,Can I Drive A New Car Home Without Ins...,Most auto dealers will not let you drive the c...
4,4,Is The Cash Surrender Value Of Life Ins...,Cash surrender value comes only with Whole Lif...


In [None]:
# Peek at one question (label 6 of the default integer index).
df.loc[6, 'question']

'What  Does  AAA  Home  Insurance  Cover?'

In [4]:
# id -> answer text; used to turn search hits (ids) back into answers.
id_answer = dict(zip(df['id'], df['answer']))

In [5]:
# id -> question text, mirroring the id -> answer mapping.
id_question = dict(zip(df['id'], df['question']))

In [None]:
# Spot-check the id -> question mapping for one entry.
id_question[10]

'What  Does  Medicare  Part  B  Cover?'

In [None]:
# Spot-check the id -> answer mapping for the same entry.
id_answer[10]

'Medicare Part B covers the doctor services, outpatient hospital services, medical services and supplies. There is a monthly cost charged to the Social Security check received. There is a deductible and 20% copayments if incurred. In addition you pay all costs for services and supplies not covered by Medicare.'

In [6]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

In [7]:
import os
from getpass import getpass

# Connect to the Zilliz Cloud (managed Milvus) cluster. Host/port/user may be overridden via
# environment variables; the password is read from MILVUS_PASSWORD or prompted for once and
# cached there, so it is never written into the notebook.
# NOTE(review): a password was committed in plain text in earlier revisions — rotate it.
MILVUS_HOST = os.environ.get('MILVUS_HOST', 'in01-6fc54adbbf19526.gcp-us-west1.vectordb.zillizcloud.com')
MILVUS_PORT = os.environ.get('MILVUS_PORT', '443')
MILVUS_USER = os.environ.get('MILVUS_USER', 'db_admin')
if not os.environ.get('MILVUS_PASSWORD'):
    os.environ['MILVUS_PASSWORD'] = getpass('Milvus password: ')

connections.connect(
    "default",
    uri=f'https://{MILVUS_HOST}:{MILVUS_PORT}',
    user=MILVUS_USER,
    password=os.environ['MILVUS_PASSWORD'],
    secure=True,
)

In [8]:
def create_milvus_collection(collection_name, dim, drop_existing=True):
    """Create a Milvus collection holding one embedding per FAQ entry, plus its vector index.

    Schema: a VARCHAR primary key ``id`` (values supplied by the caller, ``auto_id=False``)
    and a ``dim``-dimensional FLOAT_VECTOR field ``embedding``, indexed with AUTOINDEX / L2.

    Args:
        collection_name: name of the collection to create.
        dim: dimensionality of the embedding vectors that will be inserted.
        drop_existing: if True (the default), an existing collection with the same name is
            dropped first and all of its data is lost. If False, the existing collection is
            returned as-is, without touching its schema or index.

    Returns:
        The ``Collection`` handle. This function does not call ``load()``; callers must load
        the collection before searching it.
    """
    if utility.has_collection(collection_name):
        if not drop_existing:
            return Collection(name=collection_name)
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.VARCHAR, description='ids',
                    max_length=500, is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR,
                    description='embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='question answering')
    collection = Collection(name=collection_name, schema=schema)

    # AUTOINDEX lets Milvus / Zilliz Cloud pick the index type and its build parameters,
    # so IVF-specific settings such as `nlist` are not passed.
    index_params = {
        'metric_type': 'L2',
        'index_type': 'AUTOINDEX',
        'params': {},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

In [9]:
# (Re)create the collection. 768 is presumably the output size of the DPR encoder used
# below — confirm it if the embedding model changes.
collection = create_milvus_collection(collection_name='question_answer', dim=768)

In [10]:
# Load the collection into memory; Milvus only searches loaded collections.
collection.load()

In [None]:
import os
from getpass import getpass

from towhee.dc2 import pipe, ops
import numpy as np
from towhee.datacollection import DataCollection

# Connection settings for the Zilliz Cloud cluster. The password comes from the
# MILVUS_PASSWORD environment variable (prompted for once if unset), never from the notebook.
MILVUS_HOST = os.environ.get('MILVUS_HOST', 'in01-6fc54adbbf19526.gcp-us-west1.vectordb.zillizcloud.com')
MILVUS_PORT = os.environ.get('MILVUS_PORT', '443')
MILVUS_USER = os.environ.get('MILVUS_USER', 'db_admin')
if not os.environ.get('MILVUS_PASSWORD'):
    os.environ['MILVUS_PASSWORD'] = getpass('Milvus password: ')

# Ingestion pipeline: embed each question with DPR, L2-normalise the vector, and insert
# (id, vec) into the `question_answer` collection. `answer` is accepted as an input but not
# sent to Milvus; answers are looked up by id on the notebook side instead.
insert_pipe = (
    pipe.input('id', 'question', 'answer')
        .map('question', 'vec', ops.text_embedding.dpr(model_name='facebook/dpr-ctx_encoder-single-nq-base'))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map(('id', 'vec'), 'insert_status', ops.ann_insert.milvus_client(host=MILVUS_HOST,
                                                                          port=MILVUS_PORT,
                                                                          user=MILVUS_USER,
                                                                          password=os.environ['MILVUS_PASSWORD'],
                                                                          collection_name='question_answer'))
        .output()
)

In [12]:
import csv

# Feed every CSV row (id, question, answer) through the ingestion pipeline.
# newline='' is what the csv module requires so that quoted fields containing line
# breaks (e.g. multi-line answers) are read as a single field.
with open('question_answer.csv', encoding='utf-8', newline='') as f:
    reader = csv.reader(f)
    next(reader, None)  # skip the header row; an empty file simply inserts nothing
    for row in reader:
        insert_pipe(*row)

In [15]:
# Seal the freshly inserted data so `collection.num_entities` counts it, then make sure
# the collection is loaded for search.
collection.flush()
collection.load()

In [20]:
# Report how many entities the collection currently holds.
print(f'Total number of inserted data is {collection.num_entities}.')

Total number of inserted data is 1000.


In [17]:
import os
from getpass import getpass

# Connection settings for the Zilliz Cloud cluster. The password comes from the
# MILVUS_PASSWORD environment variable (prompted for once if unset), never from the notebook.
MILVUS_HOST = os.environ.get('MILVUS_HOST', 'in01-6fc54adbbf19526.gcp-us-west1.vectordb.zillizcloud.com')
MILVUS_PORT = os.environ.get('MILVUS_PORT', '443')
MILVUS_USER = os.environ.get('MILVUS_USER', 'db_admin')
if not os.environ.get('MILVUS_PASSWORD'):
    os.environ['MILVUS_PASSWORD'] = getpass('Milvus password: ')

# Query pipeline: embed the question with the same DPR encoder used for ingestion,
# L2-normalise it, fetch the single nearest stored question (limit=1), and map each hit's
# id (its first element) back to the stored answer text via `id_answer`.
ans_pipe = (
    pipe.input('question')
        .map('question', 'vec', ops.text_embedding.dpr(model_name="facebook/dpr-ctx_encoder-single-nq-base"))
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x, axis=0))
        .map('vec', 'res', ops.ann_search.milvus_client(host=MILVUS_HOST,
                                                        port=MILVUS_PORT,
                                                        user=MILVUS_USER,
                                                        password=os.environ['MILVUS_PASSWORD'],
                                                        collection_name='question_answer',
                                                        limit=1))
        .map('res', 'answer', lambda x: [id_answer[int(i[0])] for i in x])
        .output('question', 'answer')
)

Cloning the repo: ann-search/milvus-client... Be patient and waiting printing 'Successfully'.
Successfully clone the repo: ann-search/milvus-client.


In [18]:
# Smoke-test the query pipeline with a question from the dataset (written with the doubled
# spaces that the questions have in the CSV).
ans = ans_pipe('Is  Disability  Insurance  Required  By  Law?')

In [19]:
# Wrap the pipeline output in a DataCollection to render it as a table.
# NOTE(review): this rebinds `ans`; re-running this cell alone does not re-run the query.
ans = DataCollection(ans)
ans.show()

question,answer
Is Disability Insurance Required By Law?,Not generally. There are five states that require most all employers carry short term disability insurance on their employees. T...


In [None]:
import towhee
def chat(message, history, *, pipeline=None):
    """Gradio callback: answer ``message`` with the closest FAQ entry and record the turn.

    Args:
        message: the user's question.
        history: list of ``(question, answer)`` tuples kept in Gradio state, or ``None`` /
            empty on the first call.
        pipeline: optional query pipeline whose result's ``.get()`` yields
            ``(question, answers)``. Defaults to the notebook-level ``ans_pipe``.

    Returns:
        ``(history, history)`` — once for the Chatbot display and once for the state output.
    """
    history = history or []
    # Reuse the pipeline built once at notebook level instead of rebuilding it (and
    # reloading the DPR model) on every message; it also queries the same Milvus cluster
    # the data was inserted into.
    if pipeline is None:
        pipeline = ans_pipe
    answers = pipeline(message).get()[1]
    # The search returns at most one hit; fall back to a message when nothing matched.
    response = answers[0] if answers else "Sorry, I couldn't find an answer to that question."
    history.append((message, response))
    return history, history

In [None]:
import gradio

# Make sure the collection is loaded before serving queries.
collection.load()
# NOTE(review): `color_map` (Chatbot) and `allow_screenshot` (Interface) were deprecated
# Gradio options that newer releases reject, so they are no longer passed.
chatbot = gradio.Chatbot()
interface = gradio.Interface(
    chat,
    ["text", "state"],
    [chatbot, "state"],
    allow_flagging="never",
)
# share=True publishes a temporary public gradio.live URL for the demo.
interface.launch(inline=True, share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://02acafe773b5315c4a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


