**import依赖包，准备好要使用的辅助函数**

In [None]:
import os
import re
import json
import openai
import requests
from tqdm import tqdm
from dotenv import load_dotenv, find_dotenv
from weaviate import Client
from weaviate.util import generate_uuid5

# Load variables from the nearest .env file (if one is found) into os.environ,
# then hand the OpenAI key to the module-level openai client configuration.
load_dotenv(dotenv_path=find_dotenv())
openai.api_key = os.environ.get('OPENAI_API_KEY')

def get_embedding(text, model="text-embedding-ada-002"):
    """Return the embedding vector for *text* via ``openai.Embedding.create``.

    Only the first entry of the response's ``data`` list is returned.
    """
    payload = openai.Embedding.create(input=text, model=model)
    first_item = payload['data'][0]
    return first_item['embedding']

def get_completion(prompt, model="gpt-3.5-turbo"):
    """Send *prompt* as a single user message and return the reply text.

    Uses ``openai.ChatCompletion.create`` with temperature 0 and returns the
    content of the first choice.
    """
    chat = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    first_choice = chat.choices[0]
    return first_choice.message["content"]

def get_completion_instruct(prompt, model="gpt-3.5-turbo-instruct"):
    """Run a plain text completion for *prompt* and return the generated text.

    Calls ``openai.Completion.create`` with temperature 0 and at most 500
    tokens, and returns the text of the first choice.
    """
    completion = openai.Completion.create(
        model=model,
        prompt=prompt,
        temperature=0,
        max_tokens=500,
    )
    top_choice = completion.choices[0]
    return top_choice.text

**读取原始json数据，抽取所需字段为入库准备**

In [None]:
# Download the CrossWOZ hotel database once; later cells read it from disk.
url = 'https://raw.githubusercontent.com/thu-coai/CrossWOZ/master/data/crosswoz/database/hotel_db.json'
if not os.path.exists('hotel_db.json'):
    print("Downloading file...")
    # Bounded wait so a stalled connection does not hang the notebook forever.
    response = requests.get(url, timeout=60)
    # Fail loudly instead of caching an HTTP error page as hotel_db.json,
    # which would otherwise be skipped as "already exists" on every re-run.
    response.raise_for_status()
    with open('hotel_db.json', 'wb') as file:
        file.write(response.content)
    print("Download complete!")
else:
    print("File already exists.")

In [None]:
# Read the raw data and keep only the fields we need.
# The file contains Chinese text, so decode it explicitly as UTF-8 rather than
# relying on the platform's default encoding.
with open('hotel_db.json', 'r', encoding='utf-8') as f:
    items = json.load(f)

# Source (Chinese) field name -> property name used in the vector store.
keymap = {
    '名称'     : 'name',
    '酒店类型' : 'type',
    '地址'     : 'address',
    '地铁'     : 'subway',
    '电话'     : 'phone',
    '价格'     : 'price',
    '评分'     : 'rating',
    '酒店设施' : 'facilities'
}

hotels = []
for hotel_id, item in enumerate(items):
    # item[1] holds the attribute dict; item[0] is not used here.
    attrs = item[1]
    hotel = {}
    for k, v in attrs.items():
        if k in keymap:
            hotel[keymap[k]] = v
        if k == '价格':
            # Empty prices become the -1.0 sentinel so the field stays numeric.
            hotel[keymap[k]] = float(v) if v else -1.0
        if k == '酒店设施':
            hotel[keymap[k]] = f"酒店提供的设施:{';'.join(v)}"
        if k in ('名称', '地址'):
            # Whitespace-separated copy for BM25 keyword search: runs of
            # digits/ASCII letters/hyphens stay together, every other word
            # character (e.g. each CJK character) becomes its own token.
            hotel['_' + keymap[k]] = ' '.join(re.findall(r'[\dA-Za-z\-]+|\w', v))
    # Positional index in the source file doubles as a stable hotel id.
    hotel['hotel_id'] = hotel_id
    hotels.append(hotel)
print(hotels[10])
with open('hotel.json', 'w', encoding='utf-8') as f:
    json.dump(hotels, f, ensure_ascii=False, indent=2)

**数据库选择**

1. pinecone

支持基于tokenizer ids的稀疏（sparse）向量，可用于关键词（keyword）检索

2. milvus

Feature Roadmap
https://github.com/milvus-io/milvus/discussions/
https://wiki.lfaidata.foundation/display/MIL/Feature+plans

根据撰写本文时的路线图，"Hybrid search with BM25 and vector"计划在3.0+版本提供，即最早在撰写时的次年年初发布

3. weaviate

支持BM25，但（在本文使用的版本中）不能在同一次查询里与向量检索组合；可以分别检索两次再合并结果——下面的代码使用RRF（Reciprocal Rank Fusion）融合各路结果，而不是简单取交集


**连接数据库，处理字段并写入**

In [None]:
# 1. Connect to the local Weaviate instance. The OpenAI key is forwarded in a
#    request header for the text2vec-openai vectorizer module.
client = Client(
    url="http://localhost:8080",
    additional_headers={"X-OpenAI-Api-Key":os.getenv("OPENAI_API_KEY")}
)
# Drop any previous Hotel class so the cell can be re-run from scratch.
client.schema.delete_class("Hotel")
# 2. Create the schema. The class must be named "Hotel": that is the class the
#    batch import below writes to and that every search query reads from.
#    Property-level moduleConfig must be keyed by the class's vectorizer
#    ("text2vec-openai"); only `facilities` is vectorized, the rest are skipped.
schema = {
  "classes": [
    {
      "class": "Hotel",
      "description": "a hotel from the CrossWOZ hotel database",
      "properties": [
        { "dataType": ["number"], "description": "id of hotel", "name": "hotel_id" },
        {
          "dataType": ["text"],
          "description": "name of hotel, tokenized for keyword search",
          "name": "_name", # tokenized copy used for BM25 search
          "indexSearchable": True,
          "tokenization": "whitespace",
          "moduleConfig": {
            "text2vec-openai": { "skip": True }
          },
        },
        {
          "dataType": ["text"],
          "description": "name of hotel",
          "name": "name",
          "indexSearchable": False,
          "moduleConfig": {
            "text2vec-openai": { "skip": True }
          },
        },
        {
          "dataType": ["text"],
          "description": "type of hotel",
          "name": "type",
          "indexSearchable": False,
          "moduleConfig": {
            "text2vec-openai": { "skip": True }
          },
        },
        {
          "dataType": ["text"],
          "description": "address of hotel, tokenized for keyword search",
          "name": "_address", # tokenized copy used for BM25 search
          "indexSearchable": True,
          "tokenization": "whitespace",
          "moduleConfig": {
            "text2vec-openai": { "skip": True }
          },
        },
        {
          "dataType": ["text"],
          "description": "address of hotel",
          "name": "address",
          "indexSearchable": False,
          "moduleConfig": {
            "text2vec-openai": { "skip": True }
          },
        },
        {
          "dataType": ["text"],
          "description": "nearby subway",
          "name": "subway",
          "indexSearchable": False,
          "moduleConfig": {
            "text2vec-openai": { "skip": True }
          },
        },
        {
          "dataType": ["text"],
          "description": "phone of hotel",
          "name": "phone",
          "indexSearchable": False,
          "moduleConfig": {
            "text2vec-openai": { "skip": True }
          },
        },
        { "dataType": ["number"], "description": "price of hotel",   "name": "price" },
        { "dataType": ["number"], "description": "rating of hotel",  "name": "rating" },
        {
          "dataType": ["text"],
          "description": "facilities provided",
          "name": "facilities",
          "indexSearchable": True,
          "moduleConfig": {
            "text2vec-openai": { "skip": False }
          },
        },
      ],
      "vectorizer": "text2vec-openai",
      "moduleConfig": {
        "text2vec-openai": {
          "vectorizeClassName": False,
          "model": "ada",
          "modelVersion": "002",
          "type": "text"
        },
      },
    }
  ]
}
client.schema.create(schema) # a single class could also be created with client.schema.create_class(...)

# 3. Insert the data (only the first 100 hotels, to keep the demo cheap).
client.batch.configure(batch_size=4, dynamic=True)

for hotel in tqdm(hotels[:100]):
    client.batch.add_data_object(
        data_object=hotel,
        class_name="Hotel",
        # Deterministic UUID derived from the object, so re-inserting the
        # same hotel targets the same object id.
        uuid=generate_uuid5(hotel, "Hotel")
    )
client.batch.flush()

**拼装prompt，调用ChatGPT完成NLU任务，抽取到结构化信息**

In [None]:
# 拼装prompt完成NLU任务
def nlu(input_text):
    """Extract hotel search conditions from a user utterance with an LLM.

    Builds a few-shot prompt (task instruction, output schema, examples and the
    user input) and sends it through ``get_completion_instruct``. The model is
    asked to reply with a JSON object.

    Args:
        input_text: the user's utterance.

    Returns:
        dict: slot name -> value, using the keys listed in ``output_format``
        (e.g. ``"price.range.high"``, ``"sort.slot"``). ``type`` may be a
        string or a list of strings.

    Raises:
        json.JSONDecodeError: if the reply contains no parseable JSON object.
    """
    instruction = """
    你的任务是识别用户对酒店的选择条件
    酒店包含8个属性，分别是：名称(name)、酒店类型(type)、地址(address)、地铁(subway)、电话(phone)、价格(price)、评分(rating)、酒店设施(facilities)。其中酒店类型的取值只有以下四种：豪华型, 经济型, 舒适型, 高档型
    """

    # Output schema description
    output_format = """
    以JSON格式输出，包含字段如下（不要编造此外的其他字段，未明确提及不需输出）
      - name: string类型
      - type: string类型（用户提到多种类型时为string数组），取值范围：'豪华型', '经济型', '舒适型', '高档型'
      - address: string类型
      - subway: string类型
      - phone: string类型
      - facilities: string类型
      - price.range.low: float类型，取值范围大于0
      - price.range.high: float类型，取值范围大于0
      - rating.range.low: float类型，取值范围[0,5]
      - rating.range.high: float类型，取值范围[0,5]
      - sort.ordering: string类型，排序的顺序，取值范围：'ascend', 'descend'
      - sort.slot: string类型，用于排序的属性字段，取值范围：'price', 'rating'
    """

    # Few-shot examples; every answer line must be valid JSON, because the
    # model imitates their format.
    examples = """
    我想订一个400元以内评分高的酒店：
    {"price.range.high":400,"sort.slot":"rating","sort.ordering":"descend"}
    有经济型或者舒适型的酒店嘛，便宜点的：
    {"type":["经济型","舒适型"],"sort.slot":"price","sort.ordering":"ascend"}
    给我找个价格500元内，评分高于4分的酒店：
    {"price.range.high":500,"rating.range.low":4}
    订一家200到400元的酒店吧：
    {"price.range.low":200,"price.range.high":400}
    """

    prompt = f"""
    {instruction}

    {output_format}

    examples:
    {examples}

    user input：
    {input_text}

    """

    result = get_completion_instruct(prompt)
    # A completion model may wrap the JSON in stray text; parse only the
    # outermost {...} span when one exists.
    match = re.search(r"\{.*\}", result, re.S)
    return json.loads(match.group(0) if match else result)

In [None]:
# Re-create the Weaviate client with the same settings as the ingestion cell,
# so the search cells below can run without re-ingesting the data.
client = Client(
    url="http://localhost:8080",
    additional_headers={
        "X-OpenAI-Api-Key": os.getenv("OPENAI_API_KEY"),
    },
)

In [None]:
def rrf(rankings, k=60, id_key='hotel_id'):
    """Fuse several ranked result lists with Reciprocal Rank Fusion (RRF).

    Each document receives ``sum(1 / (k + rank))`` over the lists it appears
    in, where ``rank`` is its 1-based position in that list (the standard RRF
    formulation; k=60 is the commonly used default).

    Args:
        rankings: iterable of ranked lists. Each element is either a dict that
            carries the document id under ``id_key`` or a bare hashable id.
        k: smoothing constant that damps the influence of top ranks.
        id_key: dict key holding the document id.

    Returns:
        list of ``(doc_id, score)`` tuples sorted by descending score; ties
        keep first-seen order.
    """
    scores = {}
    for ranking in rankings:
        for rank, doc in enumerate(ranking, start=1):
            doc_id = doc[id_key] if isinstance(doc, dict) else doc
            scores[doc_id] = scores.get(doc_id, 0) + 1 / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)

# a = [{'hotel_id':'a','addr':'b'},{'hotel_id':'c','addr':'d'},{'hotel_id':'e','addr':'f'},{'hotel_id':'g','addr':'h'}]
# b = [{'hotel_id':'a','addr':'b'},{'hotel_id':'g','addr':'h'},{'hotel_id':'c','addr':'d'},{'hotel_id':'e','addr':'f'}]
# print(rrf([a,b]))

In [None]:
def _build_filters(state):
    """Translate NLU slot values into a Weaviate where-filter, or None if no conditions apply."""
    operands = []
    hotel_type = state.get('type')
    if hotel_type:
        # NLU may return one type or a list of acceptable types.
        types = hotel_type if isinstance(hotel_type, list) else [hotel_type]
        type_ops = [{"path": ["type"], "operator": "Equal", "valueString": t} for t in types]
        operands.append(type_ops[0] if len(type_ops) == 1 else {"operator": "Or", "operands": type_ops})
    bounds = (
        ('price.range.low', 'price', 'GreaterThan'),
        ('price.range.high', 'price', 'LessThan'),
        ('rating.range.low', 'rating', 'GreaterThan'),
        ('rating.range.high', 'rating', 'LessThan'),
    )
    for slot, path, operator in bounds:
        if slot in state:
            operands.append({"path": [path], "operator": operator, "valueNumber": state[slot]})
    if 'price.range.high' in state and 'price.range.low' not in state:
        # Hotels with no listed price were stored with a -1.0 sentinel; keep
        # them out of "cheaper than X" results.
        operands.append({"path": ["price"], "operator": "GreaterThanEqual", "valueNumber": 0})
    if not operands:
        return None
    if len(operands) == 1:
        return operands[0]
    return {"operator": "And", "operands": operands}


def _run_query(fields, filters, limit, near_text=None, bm25=None):
    """Run one Get query on the Hotel class and return its result list.

    ``near_text`` is passed to ``with_near_text``; ``bm25`` is a
    ``(query_text, properties)`` pair passed to ``with_bm25``.
    Raises RuntimeError if Weaviate reports errors.
    """
    query = client.query.get("Hotel", fields)
    if near_text is not None:
        query = query.with_near_text(near_text)
    if bm25 is not None:
        text, properties = bm25
        query = query.with_bm25(query=text, properties=properties)
    if filters:
        query = query.with_where(filters)
    result = query.with_limit(limit).do()
    if result.get('errors'):
        raise RuntimeError(f"Weaviate query failed: {result['errors']}")
    return result['data']['Get']['Hotel'] or []


def search(state, output_fields=None, limit=10):
    """Search hotels matching an NLU dialog state.

    Structured slots (type / price / rating ranges) become a where-filter that
    is applied to every sub-query. ``facilities`` drives a nearText vector
    search. ``name`` and ``address`` drive BM25 keyword searches on their
    tokenized ``_name`` / ``_address`` copies. The ranked lists are fused once
    with RRF. If no free-text slot produced hits, only the filter is applied.
    Results are optionally sorted by ``sort.slot`` / ``sort.ordering``.

    Args:
        state: dict of slot values as produced by ``nlu``.
        output_fields: properties to return. Defaults to
            hotel_id, name, type, rating, price. ``hotel_id`` is always added
            because fusion identifies documents by it.
        limit: maximum number of hits per sub-query and in the final result.

    Returns:
        list of hotel dicts (property name -> value).
    """
    if output_fields is None:
        output_fields = ["hotel_id", "name", "type", "rating", "price"]
    fields = list(output_fields)
    if "hotel_id" not in fields:
        fields.append("hotel_id")
    filters = _build_filters(state)

    ranked_lists = []
    # Vector search on facilities.
    if 'facilities' in state:
        ranked_lists.append(_run_query(
            fields, filters, limit,
            near_text={"concepts": [f"酒店提供:{state['facilities']}"]},
        ))
    # Keyword (BM25) search; the query is tokenized exactly like the stored copy.
    for slot, prop in (('name', '_name'), ('address', '_address')):
        if slot in state:
            text = ' '.join(re.findall(r'[\dA-Za-z\-]+|\w', state[slot]))
            ranked_lists.append(_run_query(fields, filters, limit, bm25=(text, [prop])))

    candidates = []
    if ranked_lists:
        # Fuse all lists in one RRF pass, then map the fused ids back to the
        # full hotel dicts (rrf itself only returns (id, score) pairs).
        docs = {}
        for ranking in ranked_lists:
            for doc in ranking:
                docs.setdefault(doc['hotel_id'], doc)
        candidates = [docs[doc_id] for doc_id, _ in rrf(ranked_lists)][:limit]

    # Condition-only search when no free-text slot matched anything.
    if not candidates:
        print("--- 字段搜索未命中，仅返回filter过滤结果 ---")
        candidates = _run_query(fields, filters, limit)

    sort_slot = state.get('sort.slot')
    if sort_slot:
        descending = state.get('sort.ordering') == 'descend'
        candidates = sorted(candidates, key=lambda x: x[sort_slot], reverse=descending)
    return candidates

In [None]:
# Filter-only search: every 舒适型 hotel (no free-text slot, so no fusion).
search(state={'type': '舒适型'})

In [None]:
class DialogManager:
    """Accumulate NLU slot values across dialog turns.

    Values from a newer turn overwrite older ones. If a low/high bound pair
    ends up inverted (low > high), any bound carried over from an earlier turn
    is dropped, so the constraint the user just stated survives. If both
    bounds came in the same turn they contradict each other, and both are
    dropped.
    """

    # (low, high) slot pairs that must satisfy low <= high.
    _RANGE_SLOTS = (
        ('price.range.low', 'price.range.high'),
        ('rating.range.low', 'rating.range.high'),
    )

    def __init__(self):
        self.state = {}

    def update_state(self, new_state):
        """Merge ``new_state`` into the dialog state and repair inverted ranges."""
        self.state.update(new_state)
        for low_key, high_key in self._RANGE_SLOTS:
            if low_key not in self.state or high_key not in self.state:
                continue
            if not self.state[low_key] > self.state[high_key]:
                continue
            stale = [key for key in (low_key, high_key) if key not in new_state]
            for key in stale or (low_key, high_key):
                del self.state[key]

    def get_state(self):
        """Return the current dialog state dict (the live object, not a copy)."""
        return self.state

# dm = DialogManager()
# dm.update_state({})
# print(dm.get_state())
# dm.update_state({"price.range.high": 400, "sort.slot": "rating", "sort.ordering": "descend"})
# print(dm.get_state())
# dm.update_state({"type": ["豪华型", "舒适型"], "sort.slot": "price", "sort.ordering": "ascend"})
# print(dm.get_state())
# dm.update_state({"price.range.high": 100, "rating.range.low": 4})
# print(dm.get_state())
# dm.update_state({"price.range.low": 200, "price.range.high": 400})
# print(dm.get_state())
# dm.update_state({"price.range.low": 500, "price.range.high": 400})
# print(dm.get_state())

In [None]:
# Other sample utterances to try (swap one in for `query` below):
# queries = ["我要订酒店","有800以内的豪华酒店吗，评分不要太低噢","要有室内泳池的"]
# query = "订一个有300到600元的酒店，评分高一点的"
# query = "给我搜一下那个格林豪泰，评分高的那家"
# query = "评分高于4，价格低于500的酒店"
query = "我要订个在顺义的酒店"

# Extract slots from the utterance, show them, then show the matching hotels.
state = nlu(query)
print(state)
print(json.dumps(search(state), indent=2, ensure_ascii=False))

In [None]:
import gradio as gr
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

# One dialog state shared by every request to this demo UI.
dm = DialogManager()


def chatbot_for_hotel(input_text):
    """Handle one user turn and return matching hotels as a DataFrame.

    Parses the utterance with ``nlu``, merges the slots into the shared
    dialog state, runs ``search`` on the merged state, and tabulates the hits
    (one row per hotel, one column per returned property).
    """
    state = nlu(input_text)
    dm.update_state(state)
    candidates = search(dm.get_state())
    if not candidates:
        # Placeholder column so the UI still renders an (empty) table.
        return pd.DataFrame({"hotel": []})
    # pandas aligns dict keys into columns and fills any missing property with
    # NaN, instead of failing on ragged columns.
    return pd.DataFrame(candidates)


iface = gr.Interface(fn=chatbot_for_hotel, inputs="text", outputs="dataframe")

iface.launch()