In [3]:
import json, os, re, hashlib, concurrent.futures as cf
from typing import List, Dict, Any, Tuple
from dotenv import load_dotenv
from tqdm import tqdm
from openai import OpenAI

"""-------------------------------------------------------------
0. 환경 설정
-------------------------------------------------------------"""
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError("OPENAI_API_KEY not found in environment.")
client = OpenAI(api_key=api_key)

"""-------------------------------------------------------------
1. 온톨로지 · 정규화 규칙
-------------------------------------------------------------"""
SLOT_ORDER = [
    "area", "price", "food", "people",
    "day", "time", "booking_ref"
]

SYNONYM_MAP: Dict[str, str] = {
    "centre": "centre", "center": "centre",
    "cheap": "cheap", "inexpensive": "cheap",
    "north": "north", "south": "south",
    "expensive": "expensive", "moderate": "moderate"
}

OPEN_VAL_MAP: Dict[str, str] = {
    r"^\+?\d{5,}$":      "phone_any",
    r"^[A-Z0-9]{6,}$":   "ref_any"
}

def canonicalize(raw: str) -> str:
    """slot1=val1|slot2=val2 문자열을 슬롯 순서에 맞춰 canonical form으로."""
    slot_dict = {s: "none" for s in SLOT_ORDER}
    for pair in raw.split("|"):
        if "=" not in pair:
            continue
        k, v = [x.strip().lower() for x in pair.split("=", 1)]
        if k not in slot_dict:
            continue
        v = SYNONYM_MAP.get(v, v) or "none"
        for pat, token in OPEN_VAL_MAP.items():
            if re.match(pat, v):
                v = token
                break
        slot_dict[k] = v
    return "|".join(f"{k}={slot_dict[k]}" for k in SLOT_ORDER)

def hash_state(canonical: str) -> str:
    """Return the SHA-256 hex digest of *canonical*, encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()

"""-------------------------------------------------------------
2. GPT 시스템 프롬프트
   * 반드시 top-level 에서 {"turns": [...]} 형태로 반환하도록 요구
-------------------------------------------------------------"""
SYSTEM_PROMPT = """
You are “MDP-Annotator-v2”, an expert annotator for task-oriented dialogues.
For each turn in the dialogue, extract MDP elements strictly from the text.

Return **exactly** the following JSON structure (no extra keys, no prose):
{
  "turns": [
    {   // one object per dialogue turn
      "turn_idx": <int>,
      "speaker": "user"|"system",
      "utterance": <string>,
      "state_before": {<slot>: <value>, ...},
      "action": [ {"type": <string>, "slot": <string>, "value": <string>} ],
      "state_after":  {<slot>: <value>, ...},
      "transition": {
        "slots_added":  {<slot>: <value>, ...},
        "slots_removed": [<slot>, ...],
        "goal_progress": "partial"|"complete"|"unchanged",
        "is_valid_transition": true|false,
        "justification": <string>
      },
      "reward": {
        "score": -1|0|1,
        "justification": <string>
      }
    }
  ]
}

***Strict rules***
- Do NOT invent information not present in the dialogue text.
- Output valid JSON only. No markdown, no comments.
"""

USER_PROMPT_TMPL = """
DIALOGUE HISTORY:

{dialogue_text}

TASK:
Annotate each turn according to the definitions and format in the system prompt.
Return ONLY the JSON.
"""

"""-------------------------------------------------------------
3. 보정·검증 유틸
-------------------------------------------------------------"""
def safe_normalize_turn(turn: Any) -> Dict[str, Any] | None:
    """턴 객체가 dict 형태인지 확인하고 필드 누락·잘못된 타입을 보정한다."""
    if not isinstance(turn, dict):
        return None
    turn.setdefault("state_before", {})
    turn.setdefault("state_after", {})
    if not isinstance(turn["state_before"], dict):
        turn["state_before"] = {}
    if not isinstance(turn["state_after"], dict):
        turn["state_after"] = {}
    return turn

"""-------------------------------------------------------------
4. LLM 호출 및 주석 함수
-------------------------------------------------------------"""

def join_utterances(dialogue: Dict[str, Any]) -> str:
    return "\n".join(f"{t['speaker'].upper()}: {t['utterance']}" for t in dialogue["turns"])

def _extract_turns(raw_json: Any) -> List[Any] | None:
    """Return the turn list from the model's top-level JSON, or None if absent.

    Accepts the requested ``{"turns": [...]}`` shape and, as a fallback, a bare
    top-level list.
    """
    if isinstance(raw_json, dict) and isinstance(raw_json.get("turns"), list):
        return raw_json["turns"]
    if isinstance(raw_json, list):
        return raw_json
    return None


def _state_to_raw(state: Dict[str, Any]) -> str:
    """Serialise a state dict into the ``slot=value|...`` string canonicalize() parses.

    canonicalize() splits on ``|``, so a literal ``|`` inside a key or value
    would produce bogus pairs. It is replaced with ``/`` first.
    """
    return "|".join(
        f"{str(k).replace('|', '/')}={str(v).replace('|', '/')}"
        for k, v in state.items()
    )


def _attach_state_ids(turn: Dict[str, Any]) -> None:
    """Add hashed canonical before/after state ids to *turn* (and its transition) in place."""
    turn["state_id_prev"] = hash_state(canonicalize(_state_to_raw(turn["state_before"])))
    turn["state_id_next"] = hash_state(canonicalize(_state_to_raw(turn["state_after"])))
    transition = turn.get("transition")
    if isinstance(transition, dict):
        transition["prev_state_id"] = turn["state_id_prev"]
        transition["next_state_id"] = turn["state_id_next"]


def annotate_dialogue(idx: int, dialogue: Dict[str, Any]) -> Dict[str, Any]:
    """Annotate one dialogue with MDP elements via the chat-completions API.

    Args:
        idx: Position of the dialogue in the input list; echoed in error results.
        dialogue: Dict with a ``"turns"`` list whose items have ``"speaker"``
            and ``"utterance"`` keys (as read by join_utterances).

    Returns:
        On success, ``{"turn_1": {...}, "turn_2": {...}, ...}``. Keys are
        numbered by each turn object's position in the model output, and each
        turn gets ``state_id_prev`` / ``state_id_next`` added. A
        ``"_parse_errors"`` entry maps position -> reason for any skipped
        non-dict turn objects.
        On failure, a dict with ``"error"`` and ``"dialogue_idx"`` (plus
        diagnostic keys such as ``"finish_reason"``, ``"prompt"`` or
        ``"error_type"``, depending on where it failed).

    Raises:
        Whatever join_utterances raises for a malformed *dialogue*. That call
        runs before the guarded section.
    """
    text_block = join_utterances(dialogue)
    user_prompt = USER_PROMPT_TMPL.format(dialogue_text=text_block)

    try:
        resp = client.chat.completions.create(
            model="gpt-4o",
            temperature=0.0,  # as deterministic as possible
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
        )
        choice = resp.choices[0]
        content = choice.message.content
        if content is None:
            # Without this guard, json.loads(None) would raise an opaque TypeError.
            return {
                "error": "Empty completion content",
                "dialogue_idx": idx,
                "finish_reason": choice.finish_reason,
                "prompt": text_block[:300],
            }
        try:
            raw_json = json.loads(content)
        except json.JSONDecodeError as e:
            # finish_reason == "length" means the JSON was cut off mid-output.
            return {
                "error": str(e),
                "dialogue_idx": idx,
                "finish_reason": choice.finish_reason,
                "prompt": text_block[:300],
            }

        turns_list = _extract_turns(raw_json)
        if turns_list is None:
            return {"error": "Un-parsable top-level JSON", "dialogue_idx": idx}

        parsed: Dict[str, Any] = {}
        parse_errors: Dict[int, str] = {}
        for i, raw_turn in enumerate(turns_list, 1):
            turn = safe_normalize_turn(raw_turn)
            if turn is None:
                parse_errors[i] = f"invalid type {type(raw_turn).__name__}"
                continue
            _attach_state_ids(turn)
            parsed[f"turn_{i}"] = turn

        if parse_errors:
            parsed["_parse_errors"] = parse_errors
        return parsed

    except Exception as e:
        # Best-effort boundary: one failing dialogue must not abort the batch.
        # error_type keeps exceptions whose str() is empty diagnosable.
        return {
            "error": str(e),
            "error_type": type(e).__name__,
            "dialogue_idx": idx,
            "prompt": text_block[:300],
        }

"""-------------------------------------------------------------
5. 병렬 주석
-------------------------------------------------------------"""

def parallel_annotate(dialogues: List[Dict[str, Any]], num_threads: int = 10) -> Dict[int, Dict]:
    results: Dict[int, Dict] = {}
    with cf.ThreadPoolExecutor(max_workers=num_threads) as exe:
        futures = {exe.submit(annotate_dialogue, idx, dlg): idx for idx, dlg in enumerate(dialogues)}
        for fut in tqdm(cf.as_completed(futures), total=len(futures), desc="Annotating"):
            idx = futures[fut]
            try:
                results[idx] = fut.result()
            except Exception as exc:
                results[idx] = {"error": str(exc)}
    return results

"""-------------------------------------------------------------
6. 데이터 로드 & 메인
-------------------------------------------------------------"""

def load_multiwoz_json(path: str) -> List[Dict[str, Any]]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data if isinstance(data, list) else list(data.values())

if __name__ == "__main__":
    FILE_PATH   = "dataset/train/dialogues_001.json"
    NUM_THREADS = 20
    OUTPUT_FILE = "results_dialogue001.json"

    dialogues = load_multiwoz_json(FILE_PATH)
    print(f"원본 대화 수: {len(dialogues)}")

    annotated = parallel_annotate(dialogues, NUM_THREADS)

    with open(OUTPUT_FILE, "w", encoding="utf-8") as fp:
        json.dump(annotated, fp, ensure_ascii=False, indent=2)

    print(f"완료: {OUTPUT_FILE} 저장")


원본 대화 수: 512


Annotating: 100%|██████████| 512/512 [10:17<00:00,  1.21s/it]


완료: results_dialogue001.json 저장
