In [1]:
import os
import json
from typing import Dict, Text, List


class MultiWOZDatabase:
    """In-memory view of the MultiWOZ domain databases (``<domain>_db.json``).

    Field names are lower-cased on load so they line up with the lower-case
    slot names used in dialogue states (e.g. ``arriveBy`` -> ``arriveby``).
    """

    def __init__(self, database_path: str, verbose: bool = True):
        """Load every domain database found under ``database_path``.

        Args:
            database_path: directory containing one ``<domain>_db.json`` file
                for each domain in ``self.DOMAINS``.
            verbose: when True (the default, matching the original behaviour),
                print one sample record and the field set of every domain.

        Raises:
            FileNotFoundError: if a domain file is missing.
            json.JSONDecodeError: if a domain file is not valid JSON.
        """
        self.database_path = database_path
        self.DOMAINS = [
            'restaurant',
            'hotel',
            'attraction',
            'train',
            'taxi',
            'police',
            'hospital'
        ]
        # domain -> records (a dict for 'taxi', a list of entity dicts otherwise).
        self.database_data = {}
        # domain -> set of lower-cased field names seen in that domain's records.
        self.database_keys = {}
        self._load_data()
        if verbose:
            self._print_structure_examples()

    def _load_data(self):
        """Load each domain's JSON, lower-case all field names, and collect the field keys."""
        for domain in self.DOMAINS:
            file_path = os.path.join(self.database_path, f"{domain}_db.json")
            with open(file_path, "r", encoding="utf-8") as f:
                raw = json.load(f)

            if domain == 'taxi':
                # The taxi DB is a single dict of attributes, not a list of entities.
                records = {k.lower(): v for k, v in raw.items()}
                keys = set(records)
            else:
                records = [{k.lower(): v for k, v in item.items()} for item in raw]
                # Union of the keys of every record (records may have optional fields).
                keys = set().union(*records)

            self.database_data[domain] = records
            self.database_keys[domain] = keys

    def _print_structure_examples(self):
        """Print one sample record and the extracted field names for every domain."""
        for domain in self.DOMAINS:
            print(f"=== {domain.upper()} ===")
            if domain == 'taxi':
                print("Sample data (taxi domain is a dict):")
                for k, v in list(self.database_data[domain].items())[:1]:
                    print(f"{k}: {v}")
            else:
                print("First item in data:")
                records = self.database_data[domain]
                # Guard against an empty domain file instead of raising IndexError.
                print(records[0] if records else None)

            print("Extracted keys:")
            print(self.database_keys[domain])
            print("\n")

    @staticmethod
    def _field_matches(entity_value, wanted: str) -> bool:
        """Case-insensitive match of one entity field against a lower-cased constraint value.

        A list-valued field matches when any element matches. Other non-string
        values (numbers, dicts) are compared via ``str()``.
        """
        if isinstance(entity_value, list):
            return wanted in (str(v).lower() for v in entity_value)
        return str(entity_value).lower() == wanted

    def query(self, domain: Text, constraints: Dict[Text, Text]) -> List[Dict]:
        """Return the entities of ``domain`` that satisfy every hard constraint.

        Args:
            domain: domain to query, e.g. ``'hotel'`` or ``'restaurant'``.
            constraints: slot -> value hard constraints, e.g.
                ``{'area': 'north', 'parking': 'yes'}``. Keys and values are
                matched case-insensitively. A constraint is ignored when its
                value is ``"dontcare"`` or when its key is not a field of this
                domain's database (e.g. booking slots such as ``bookday``);
                neither can filter entities.

        Returns:
            The matching entity dicts. For ``'taxi'`` this is always a
            one-element list holding the taxi attribute dict. For an unknown
            domain a warning is printed and ``[]`` is returned.
        """
        if domain not in self.database_data:
            print(f"[Warning] Domain '{domain}' not in database.")
            return []

        entities = self.database_data[domain]

        if domain == "taxi":
            # Taxi has no per-entity records; return everything.
            return [entities]

        known_keys = self.database_keys.get(domain, set())
        active = {}
        for key, value in (constraints or {}).items():
            key = key.lower()
            value = str(value).lower()
            if value == "dontcare" or key not in known_keys:
                continue
            active[key] = value

        return [
            entity for entity in entities
            # A record missing the field is compared as "" (same as before).
            if all(self._field_matches(entity.get(key, ""), value)
                   for key, value in active.items())
        ]

In [3]:
# Directory holding the per-domain <domain>_db.json files.
database_path = "./multiwoz_database"
database = MultiWOZDatabase(database_path=database_path)

=== RESTAURANT ===
First item in data:
{'address': 'Regent Street City Centre', 'area': 'centre', 'food': 'italian', 'id': '19210', 'introduction': 'Pizza hut is a large chain with restaurants nationwide offering convenience pizzas pasta and salads to eat in or take away', 'location': [52.20103, 0.126023], 'name': 'pizza hut city centre', 'phone': '01223323737', 'postcode': 'cb21ab', 'pricerange': 'cheap', 'type': 'restaurant'}
Extracted keys:
{'name', 'id', 'pricerange', 'area', 'introduction', 'signature', 'type', 'address', 'postcode', 'food', 'location', 'phone'}


=== HOTEL ===
First item in data:
{'address': '124 tenison road', 'area': 'east', 'internet': 'yes', 'parking': 'no', 'id': '0', 'location': [52.1963733, 0.1987426], 'name': 'a and b guest house', 'phone': '01223315702', 'postcode': 'cb12dp', 'price': {'double': '70', 'family': '90', 'single': '50'}, 'pricerange': 'moderate', 'stars': '4', 'takesbookings': 'yes', 'type': 'guesthouse'}
Extracted keys:
{'parking', 'price',

In [6]:
from datasets import load_dataset

# MultiWOZ 2.2 from the Hugging Face hub, indexed by split name;
# only the training split is used in the cells below.
dataset = load_dataset(path="multi_woz_v22")
data = dataset["train"]

In [7]:
import json
from collections import defaultdict

# Domains a dialogue may be attributed to (via its first listed service).
# 'police' and 'bus' are intentionally absent.
DOMAINS = "restaurant hotel attraction train taxi hospital".split()
# Maximum number of dialogues sampled per domain.
TOTAL = 50

def _flat_slot_values(frame_states):
    """Merge the ``slots_values`` of every frame state into one ``{slot_name: value}`` dict.

    Each state carries parallel ``slots_values_name`` / ``slots_values_list``
    lists. The first value of each slot is kept; slots with an empty value list
    are skipped; later frames overwrite earlier ones on a shared slot name.
    """
    flat = {}
    for frame_state in frame_states or []:
        slot_values = frame_state['slots_values']
        for name, values in zip(slot_values['slots_values_name'],
                                slot_values['slots_values_list']):
            if values:
                flat[name] = values[0]
    return flat


def iterate_dialogues(data, database, context_size=3, domains=None, max_per_domain=None):
    """Yield one retrieval example per customer turn of MultiWOZ 2.2 dialogues.

    A dialogue is attributed to its first listed service. Dialogues whose first
    service is missing or not in ``domains`` are skipped, and at most
    ``max_per_domain`` dialogues are taken per domain.

    Args:
        data: iterable of dialogues in the Hugging Face ``multi_woz_v22`` layout
            (``dialogue_id``, ``services``, and ``turns`` with parallel
            ``utterance`` / ``frames`` lists).
        database: object exposing ``query(domain, constraints) -> list``.
        context_size: number of most recent utterances joined into ``page_content``.
        domains: domains to keep; defaults to the module-level ``DOMAINS``.
        max_per_domain: per-domain dialogue cap; defaults to the module-level ``TOTAL``.

    Yields:
        dict with ``page_content``, ``question``, ``gt_state``, ``dialogue_id``
        and a ``metadata`` dict holding the domain, the state update since the
        previous customer turn, the full state, the last 6 utterances, the
        following assistant response ('' if none) and per-domain DB match counts.
    """
    if domains is None:
        domains = DOMAINS
    if max_per_domain is None:
        max_per_domain = TOTAL
    domain_counts = {d: 0 for d in domains}

    for dialog in data:
        dialogue_id = dialog['dialogue_id'].split('.')[0].lower()
        services = dialog['services']
        domain_gt = services[0] if services else ''

        # Skips dialogues with no service, 'bus', and any domain not tracked in
        # ``domains`` (e.g. 'police'); indexing domain_counts directly raised KeyError.
        if domain_gt not in domain_counts:
            continue
        if domain_counts[domain_gt] >= max_per_domain:
            continue
        domain_counts[domain_gt] += 1

        utterances = dialog['turns']['utterance']
        frames = dialog['turns']['frames']
        last_state = {}

        # Even indices are customer turns, odd indices assistant turns.
        for tn in range(0, len(utterances), 2):
            context = [
                f"Customer: {t}" if i % 2 == 0 else f"Assistant: {t}"
                for i, t in enumerate(utterances[:tn + 1])
            ]

            # Group flat "domain-slot" names into {domain: {slot: value}},
            # merging the states of all frames of this turn.
            new_state = {}
            for slot_name, value in _flat_slot_values(frames[tn]['state']).items():
                domain, name = slot_name.split('-', 1)
                new_state.setdefault(domain, {})[name] = value

            # Slots that are new or changed since the previous customer turn.
            state_update = {}
            for domain, domain_state in new_state.items():
                previous = last_state.get(domain, {})
                for slot, value in domain_state.items():
                    if slot not in previous or previous[slot] != value:
                        state_update.setdefault(domain, {})[slot] = value

            last_state = new_state

            # Number of DB entities matching the current state, per domain.
            database_results = {
                domain: len(database.query(domain, domain_state))
                for domain, domain_state in new_state.items()
            }

            # A dialogue ending on a customer turn has no reference response.
            response = utterances[tn + 1] if tn + 1 < len(utterances) else ''

            yield {
                'page_content': '\n'.join(context[-context_size:]),
                'question': utterances[tn],
                'gt_state': last_state,
                'dialogue_id': dialogue_id,
                'metadata': {
                    'domain': domain_gt,
                    'state': state_update,
                    'full_state': last_state,
                    'context': '\n'.join(context[-6:]),
                    'response': response,
                    'database': database_results
                }
            }

# Show the first generated example only (nothing is printed if there are none).
first_turn = next(iterate_dialogues(data, database), None)
if first_turn is not None:
    print(json.dumps(first_turn, indent=2, ensure_ascii=False))

{
  "page_content": "Customer: i need a place to dine in the center thats expensive",
  "question": "i need a place to dine in the center thats expensive",
  "gt_state": {
    "restaurant": {
      "area": "centre",
      "pricerange": "expensive"
    }
  },
  "dialogue_id": "pmul4398",
  "metadata": {
    "domain": "restaurant",
    "state": {
      "restaurant": {
        "area": "centre",
        "pricerange": "expensive"
      }
    },
    "full_state": {
      "restaurant": {
        "area": "centre",
        "pricerange": "expensive"
      }
    },
    "context": "Customer: i need a place to dine in the center thats expensive",
    "response": "I have several options for you; do you prefer African, Asian, or British food?",
    "database": {
      "restaurant": 33
    }
  }
}


In [9]:
import pickle
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from tqdm import tqdm

# Use HuggingFace English model
# English sentence-transformers model used to embed each context window.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# One Document per customer turn: the windowed context is what gets embedded;
# the rest of the example travels along as metadata.
turn_stream = tqdm(iterate_dialogues(data, database), desc="Processing dialogues", unit="turns")
docs = [
    Document(page_content=turn['page_content'], metadata=turn['metadata'])
    for turn in turn_stream
]

# Build the FAISS index over all documents.
faiss_vs = FAISS.from_documents(documents=docs, embedding=embeddings)

# Persist the whole vector store with pickle.
# NOTE(review): LangChain's documented persistence API is FAISS.save_local / load_local;
# pickle is kept so existing readers of this .vec file keep working. Only unpickle
# files you produced yourself -- pickle.load can execute arbitrary code.
with open("multiwoz-context-db.vec", "wb") as out_file:
    pickle.dump(faiss_vs, out_file)

print("FAISS vector database has been saved as multiwoz-context-db.vec")

  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Processing dialogues: 1914turns [00:14, 131.87turns/s] 


FAISS vector database has been saved as multiwoz-context-db.vec
