In [1]:
import os
import re
import json
from tqdm.auto import tqdm
import pickle
import json

from openai import OpenAI
api_key=json.load(open("api.json"))[0]
client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")

import ollama
model="llama3.1:8b"

In [2]:
intent_instruction='''You are an expert NLU data generation assistant for a campus chatbot.
Your task is to generate {num_query} varied data items for given entity and intent.
**Instructions:**
1. You should generate {num_query} possible queries corresponding to the given entity and intent.
2. The output should be a vaild json snippet like {example}
3. The {num_query} queries should be included in a list.
in json file that is "query":[{num_query} generated] 
4. Each possible query must contain one or some of the given entities.
5. Each possible query must contain one or some of the given intents.
6. You CANNOT add new eneities in the generated query! 
For example:
"intent": "ask_business_location", "generated_query":"What are the business hours of KIMS SALON?"
then a NEW intent "ask_business_time" is added.
7. Output json snippet ONLY. 

**Exapmle:**
{
  "query": ["where is Nasi Kandar restaurant?", "how can i go to Nasi Kandar restaurant?"],
  "intent": "ask_restaurant_location",
  "entities": [
    {"Nasi Kandar restaurant":"restaurant_name"} 
  ]
}
'''

In [16]:
class NLU_data_generator:
    def __init__(self,intent_lib_path=r"data\IE\intent_lib.json"):
        self.data=json.load(open(intent_lib_path,"r",encoding="utf-8"))
        self.categories=list(self.data.keys())
        self.responses={
          category:None for category in self.categories
        }
    
    def create_prompt(self,category,num_query):
        message_set=[
          {"role":"system","content":intent_instruction}
          ]
        for intent in self.data[category][1:]:
          for entity_name in self.data[category][0]:
            prompt=f"""**Your Task:**
            Generate {num_query} new and unique query examples for the intent '{intent}' containing entity {entity_name}:{category}.
            """
            message_set.append({
              "role":"user",
              "content":prompt
            })
        return message_set
    
    def get_response(self,category,num_query):
      message_set=self.create_prompt(category,num_query)
      response_set=[]
      for msg in tqdm(message_set[1:],desc=f"Generating data for {category}"):
        try:
            # response=client.chat.completions.create(
            #     model="deepseek-chat",
            #     messages=[
            #       message_set[0],msg
            #     ],
            #     stream=False
            # ).choices[0].message.content
            response=ollama.chat(model=model, messages=[
                  message_set[0],msg
                ])["message"]["content"]

            response=re.sub(
              pattern=r"```json\n|\n```",
              repl='',
              string=response
            )
            try:
                json.loads(response)
                response_set.append(response)
            except json.JSONDecodeError:
                print(f"Warning: Invalid JSON response for {category}, skipping this item")
                continue
                
        except Exception as e:
            print(f"Error generating response for {category}: {e}")
            continue
            
      self.responses[category]=response_set

    def save_data(self,file_path):
        # 使用写入模式而不是追加模式
        with open(file_path,"w",encoding="utf-8") as f:
            f.write("{\n")
            valid_categories = []
            
            # 先过滤出有有效数据的类别
            for category in self.categories:
                if (self.responses[category] is not None and 
                    len(self.responses[category]) > 0):
                    valid_categories.append(category)
            
            for i, category in enumerate(valid_categories):
                f.write(f'"{category}":[\n')
                
                # 过滤出有效的JSON数据项
                valid_items = []
                for data_item in self.responses[category]:
                    try:
                        # 验证JSON格式
                        json.loads(data_item)
                        valid_items.append(data_item)
                    except json.JSONDecodeError:
                        print(f"Warning: Skipping invalid JSON item in {category}")
                        continue
                
                # 写入有效的数据项
                for j, data_item in enumerate(valid_items):
                    f.write(data_item)
                    if j == len(valid_items) - 1:
                        f.write("\n")
                    else:
                        f.write(",\n")
                
                # 添加类别间的逗号
                if i == len(valid_categories) - 1:
                    f.write("]\n")
                else:
                    f.write("],\n")
                    
            f.write("}")
            
        print(f"数据已保存到 {file_path}")
        print(f"成功保存的类别: {valid_categories}")
        for category in self.categories:
            if category not in valid_categories:
                print(f"警告: {category} 类别没有有效数据")


In [4]:
generator=NLU_data_generator()

In [6]:
for c in tqdm(generator.categories,desc="Generating data"):
    generator.get_response(c,4)
with open("data/generator.pkl", "wb") as f:
    pickle.dump(generator, f)

Generating data:   0%|          | 0/5 [00:00<?, ?it/s]

Generating data for business:   0%|          | 0/39 [00:00<?, ?it/s]

Generating data for restaurant:   0%|          | 0/272 [00:00<?, ?it/s]

Generating data for facility:   0%|          | 0/60 [00:00<?, ?it/s]

Generating data for building:   0%|          | 0/10 [00:00<?, ?it/s]

Generating data for handbook:   0%|          | 0/35 [00:00<?, ?it/s]

FileNotFoundError: [Errno 2] No such file or directory: 'py_code/data/generator.pkl'

In [18]:
with open("data/generator.pkl","rb") as f:
    generator=pickle.load(f)
generator.save_data("data/intent_train_data.json")

数据已保存到 data/intent_train_data.json
成功保存的类别: ['business', 'restaurant', 'facility', 'building', 'handbook']
