In [None]:
# Installation and Imports

# OpenAI (v0.28), Pandas, and tqdm
!pip -q install "openai==0.28" pandas tqdm 

import os, json, gzip, time, difflib
import pandas as pd
from tqdm import tqdm
from typing import Dict, Any, List, Tuple
from getpass import getpass
import openai
import random
import copy

print("OpenAI SDK version:", openai.__version__)

In [None]:
# API Key Request

# Reuse OPENAI_API_KEY from the environment when it is set, so "Restart & Run All" can proceed
# without an interactive prompt; otherwise ask for the key without echoing it.
openai.api_key = os.environ.get("OPENAI_API_KEY") or getpass("Enter your OpenAI API key (will not echo): ")

In [None]:
# Configuration

# Data locations (relative paths, resolved from the notebook's working directory)
FILE_PATH = "DATA/abcd_v1.1.json"
ONTOLOGY_PATH = "DATA/ontology.json"  # optional; enrich label sets if present

# Models: PRIMARY_MODEL is tried first, then each FALLBACK_MODELS entry that differs from it
PRIMARY_MODEL = "gpt-4o"
FALLBACK_MODELS = ["gpt-4o", "gpt-3.5-turbo-0125"]

# Run size and pacing
MAX_SAMPLES = 250          # Cap for cost. Use 250 at max
REQUEST_DELAY_SEC = 0.7    # seconds slept between API requests

In [None]:
# Helper Functions

def call_chat_model_safely(messages, model):
    """Send one chat request and return the assistant's text; never raises.

    Uses the legacy ``openai.ChatCompletion`` interface (openai==0.28) with ``temperature=0``.

    Args:
        messages: Chat messages in OpenAI format (list of ``{"role": ..., "content": ...}`` dicts).
        model: Name of the model to query.

    Returns:
        The first choice's message content as a string. Returns ``""`` when the request fails, the
        response has no choices, or the content is missing/null. Callers can therefore always
        treat the result as a string and use ``""`` as the "try another model" signal.
    """
    try:
        resp = openai.ChatCompletion.create(model=model, messages=messages, temperature=0)
        choices = resp.get("choices") or []
        if not choices:
            return ""
        content = choices[0]["message"].get("content")
        # `content` may be null (e.g. function-call responses); keep the str contract.
        return content if isinstance(content, str) else ""
    except Exception as e:
        # Best-effort by design: the caller falls back to the next model on "".
        print(f"[chat:{model}] error: {e}")
        return ""

def random_seeded_examples(dataset, size=250, seed=42):
    """Return a reproducible random sample of ``dataset``.

    The sample is drawn from a private ``random.Random(seed)`` generator, not by reseeding the
    module-level RNG, so calling this leaves the global ``random`` state untouched.
    ``Random(seed)`` yields the same stream as ``random.seed(seed)``, so the samples are identical
    to those drawn by seeding the global generator.

    Args:
        dataset: Sequence to sample from (not modified).
        size: Maximum number of items to return. It is clamped to ``len(dataset)``, so small or
            empty splits return all their items (shuffled) instead of raising ``ValueError``.
        seed: Seed for the private generator.

    Returns:
        A new list of up to ``size`` distinct items from ``dataset``.
    """
    rng = random.Random(seed)
    return rng.sample(dataset, min(size, len(dataset)))

def load_json(path: str):
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded Python object."""
    with open(path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
    
def convo_to_transcript(convo: Dict[str,Any]) -> str:
    """Flatten a conversation's ``original`` turns into one space-separated transcript string.

    Each turn is unpacked as a 2-item ``(speaker, text)`` pair and rendered as ``"speaker: text"``.
    A conversation with no ``original`` key yields ``""``.
    """
    turns = convo.get("original", [])
    pieces = []
    for speaker, utterance in turns:
        pieces.append(f"{speaker}: {utterance}")
    return " ".join(pieces)

def try_parse_json(text: str):
    """Parse model output as JSON, tolerating surrounding prose or markdown fences.

    First the whole stripped text is parsed. If that fails, the span from the first ``{`` to the
    last ``}`` is parsed instead.

    Returns:
        The decoded value, or ``None`` for empty input or when neither attempt succeeds.
    """
    if not text:
        return None
    cleaned = text.strip()
    try:
        return json.loads(cleaned)
    except Exception:
        pass

    # Fallback: carve out the outermost {...} span (e.g. from ```json fences or chatty output).
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return None
    try:
        return json.loads(cleaned[start:end + 1])
    except Exception:
        return None

def closest_label(pred: str, choices: List[str], cutoff: float = 0.6) -> str:
    """Snap a free-text model prediction onto the closest valid label.

    Matching is tried in order: exact, case-insensitive exact, then fuzzy (``difflib``).

    Args:
        pred: The model's predicted label. Non-string or empty values yield ``""``.
        choices: Valid labels as any iterable (e.g. a set from ``label_sets``). Non-string entries
            are ignored because they cannot be compared textually.
        cutoff: Minimum ``difflib`` similarity ratio (0-1) a fuzzy match must reach.

    Returns:
        A label taken from ``choices``, or ``""`` when nothing is close enough.
    """
    if not isinstance(pred, str) or not pred or not choices:
        return ""
    # A sorted list of string labels gives deterministic results even when `choices` is a set
    # (set iteration order of strings varies between interpreter runs). It also keeps non-strings
    # away from .lower() and difflib, which would raise on them.
    candidates = sorted(c for c in choices if isinstance(c, str))
    if pred in candidates:
        return pred
    # Case-insensitive exact match; on case collisions the first label in sorted order wins.
    lowmap = {}
    for c in candidates:
        lowmap.setdefault(c.lower(), c)
    if pred.lower() in lowmap:
        return lowmap[pred.lower()]
    # Fuzzy match to nearest valid label
    best = difflib.get_close_matches(pred, candidates, n=1, cutoff=cutoff)
    return best[0] if best else ""

In [None]:
# Load Dataset & Split into Train, Dev, and Test

abcd = load_json(FILE_PATH)
# A missing (or null) split becomes an empty list.
train_split, dev_split, test_split = (
    abcd.get(split_name, []) or [] for split_name in ("train", "dev", "test")
)
print(f"Train examples: {len(train_split)} | Dev examples: {len(dev_split)} | Test examples: {len(test_split)}")

# Draw a fixed, seeded sample (250 conversations by default) from each split, so prompt testing,
# validation and final evaluation always work on the same conversations.
train_seeded_examples, dev_seeded_examples, test_seeded_examples = (
    random_seeded_examples(split) for split in (train_split, dev_split, test_split)
)
print(f"Seeded Train examples: {len(train_seeded_examples)} | Seeded Dev examples: {len(dev_seeded_examples)} | Seeded Test examples: {len(test_seeded_examples)}")

In [None]:
# Build label sets

# Output schema the model must fill in. Key order matters: the schema is serialized verbatim into
# the extraction prompt with json.dumps, and label-set collection iterates it in this order.
SCHEMA = {
    "personal": {
        "customer_name": "",
        "email": "",
        "member_level": "",
        "phone": "",
        "username": "",
    },
    "order": {
        "street_address": "",
        "full_address": "",
        "city": "",
        "num_products": "",
        "order_id": "",
        "packaging": "",
        "payment_method": "",
        "products": "[]",
        "purchase_date": "",
        "state": "",
        "zip_code": "",
    },
    "product": {
        "names": [],
        "amounts": [],
    },
    "flow": "",
    "subflow": "",
}

def extract_unique_values_from_conversations(dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collect the distinct values seen for every SCHEMA feature across a dataset's scenarios.

    The result mirrors SCHEMA's shape:
    - dict sections (personal/order/product) map each field name to a set;
    - flat features (flow/subflow) map directly to a set.

    Values of None or "" are skipped. List values are added element-wise rather than as a whole.
    """
    collected: Dict[str, Any] = {}
    for section, template in SCHEMA.items():
        if isinstance(template, dict):
            collected[section] = {field: set() for field in template}
        else:
            collected[section] = set()

    for convo in dataset:
        scenario = convo.get("scenario", {})
        for section, template in SCHEMA.items():
            # Flat feature stored directly on the scenario (flow, subflow)
            if not isinstance(template, dict):
                top_value = scenario.get(section)
                if top_value not in (None, ""):
                    collected[section].add(top_value)
                continue

            # Nested section: look up each field declared in the schema
            section_data = scenario.get(section, {})
            for field in template:
                field_value = section_data.get(field)
                if field_value in (None, ""):
                    continue
                bucket = collected[section][field]
                if isinstance(field_value, list):
                    bucket.update(field_value)
                else:
                    bucket.add(field_value)

    return collected

# Label sets are derived from the training split only.
label_sets = extract_unique_values_from_conversations(train_split)

def enrich_with_ontology(unique_values: Dict[str, Any], ontology_path: str) -> Dict[str, Any]:
    """Add label strings from an optional ontology JSON file to the collected label sets.

    The ontology is walked recursively. Each string leaf is attributed using the dict keys on the
    path leading to it:

    - nested features (``unique_values[section][field]`` is a set): the string is added only when
      ``field`` itself is on the path. Matching on the section name alone copied every string
      under e.g. a "personal" key into *all* personal fields (email values into phone, etc.).
    - flat features (``unique_values[section]`` is a set, e.g. flow/subflow): the string is added
      when ``section`` is on the path.

    The file is optional. If it does not exist, or cannot be read or parsed, a message is printed
    and ``unique_values`` is returned without further changes.

    Args:
        unique_values: Label sets shaped like the output of
            ``extract_unique_values_from_conversations``. They are mutated in place.
        ontology_path: Path to the ontology JSON file.

    Returns:
        The same ``unique_values`` object, for convenient reassignment.
    """

    def walk_and_add(item, path=()):
        # `path` is an immutable tuple of the dict keys above `item` (no shared mutable default).
        if isinstance(item, dict):
            for key, value in item.items():
                walk_and_add(value, path + (key,))
        elif isinstance(item, list):
            # List elements inherit the path of the list itself.
            for element in item:
                walk_and_add(element, path)
        elif isinstance(item, str):
            for section, content in unique_values.items():
                if isinstance(content, dict):
                    for field_name, bucket in content.items():
                        if field_name in path:
                            bucket.add(item)
                elif section in path:
                    content.add(item)

    if not os.path.exists(ontology_path):
        print(f"[ontology] {ontology_path} not found; using dataset-derived labels only")
        return unique_values

    try:
        with open(ontology_path, "r", encoding="utf-8") as f:
            ontology = json.load(f)
        walk_and_add(ontology)
    except Exception as e:
        # Best-effort: a broken ontology must not stop the notebook.
        print(f"[ontology] Could not parse {ontology_path}: {e}")
        return unique_values

    print(f"[ontology] Successfully enriched features with values from {ontology_path}")
    return unique_values

# Merge any extra labels from the optional ontology file into the training-derived label sets.
label_sets = enrich_with_ontology(unique_values=label_sets, ontology_path=ONTOLOGY_PATH)

In [None]:
# Metadata extractor

# Full list of features. Only use the features you are responsible for.
# label_opts maps each feature name to a prompt line listing its valid labels. It is generated
# from `label_sets` itself, so every collected feature appears under its exact name (e.g.
# "member_level") in SCHEMA order.


def _format_label_option(feature: str, values) -> str:
    """Render the prompt line listing every valid label for ``feature``.

    Labels are shown in set-literal form (``{'a', 'b'}``, or ``set()`` when empty) but in sorted
    order. Iterating a set of strings follows hash randomization, so printing the set directly
    made the prompt text, and hence temperature-0 outputs, differ between kernel restarts.
    """
    ordered = sorted(values, key=str)  # key=str tolerates mixed types (e.g. ints next to strings)
    rendered = "{" + ", ".join(repr(v) for v in ordered) + "}" if ordered else "set()"
    return f"- Valid {feature} labels (pick exactly one, copy verbatim): {rendered}\n"


label_opts = {
    feature: _format_label_option(feature, values)
    for section, content in label_sets.items()
    # Nested sections contribute one entry per field; flow/subflow contribute themselves.
    for feature, values in (content.items() if isinstance(content, dict) else [(section, content)])
}
# Few-shot demonstrations inlined near the top of every extraction prompt. They show only the
# expected flow/subflow labels for short dialogs, not the JSON output format. Note: turns here are
# one per line, whereas convo_to_transcript joins real turns with single spaces.
fewshot_examples = """
EXAMPLES:

Example 1:
Transcript:
\"\"\"
customer: Hi! I need to return an item, can you help me with that?
agent: sure, may I have your name please?
customer: I got the wrong size.
\"\"\"
Correct flow: product_defect
Correct subflow: return_size

Example 2:
Transcript:
\"\"\"
customer: just wanted to check on the status of a refund
agent: everything in order, soon I will indicate the status of your refund.
customer: how much long till it is refunded
\"\"\"
Correct flow: product_defect
Correct subflow: refund_status

Example 3:
Transcript:
\"\"\"
customer: I've got a promo code and I want to know when they expire.
agent: sure! let me check that.
agent: Ok, all promo codes expire after 7 days without fail.
\"\"\"
Correct flow: storewide_query
Correct subflow: timing_4
"""

def _build_extraction_prompt(transcript: str, opts: list) -> str:
    """Assemble the user prompt for one transcript.

    The prompt contains, in order:
    - the few-shot examples;
    - classification guidelines and valid-label lists for the features named in ``opts``;
    - the JSON output rules, including ``SCHEMA``;
    - the transcript itself.
    """
    # Only label lists for requested features are inlined; names missing from label_opts are skipped.
    selected_instr = ''.join(label_opts[opt] for opt in opts if opt in label_opts)

    # Add task-specific guidance for subflow / flow
    classification_guidelines = ""
    if "subflow" in opts:
        classification_guidelines += (
            "SUBFLOW CLASSIFICATION GUIDELINES:\n"
            "- Base the subflow ONLY on the CUSTOMER'S main goal, not on the agent's mistakes or speculation.\n"
            "- Ignore greetings, small talk, politeness, apologies, and generic empathy.\n"
            "- If multiple issues are mentioned, choose the one the customer repeats or emphasizes most.\n"
            "- If the customer is changing or cancelling something BEFORE delivery, prefer subflows related to\n"
            "  change-address, modify-order, or cancel-order.\n"
            "- If the customer reports a problem AFTER delivery (missing item, damaged item, wrong size, stain,\n"
            "  refund questions), choose the corresponding post-delivery subflow.\n"
            "- Always choose EXACTLY ONE subflow label. Do NOT output multiple labels.\n\n"

            "SUBFLOW KEYWORD MAPPING RULES:\n"
            "- refund_status → Use when the customer asks WHEN a refund will be processed, the timing of a refund,\n"
            "  or the progress of a refund.\n"
            "- refund_update → Use when the customer already HAS a refund in progress and wants an UPDATE.\n"
            "- return_size → Wrong size, need a different size, too small, too large, incorrect fit.\n"
            "- timing_4 → Promo code timing questions, expiration dates, offer deadlines.\n"
            "- status_due_amount → When the customer asks HOW MUCH they owe or the billing amount.\n"
            "- status_due_date → When the customer asks WHEN a bill, subscription, or payment is due.\n"
            "- manage_change_address → When the customer wants to update or correct their shipping address.\n"
            "- manage_upgrade → When the customer wants to upgrade or modify an order BEFORE shipment.\n"
            "- status_credit_missing → When the customer is missing expected credits or account adjustments.\n"
            "- status_delivery_time → Questions about WHEN an order will arrive or delivery timing.\n"
            "- shopping_cart → Issues adding/removing items, checkout problems, or cart not updating.\n\n"

            "FINE-GRAINED LABEL SELECTION RULE:\n"
            "- When labels share the same prefix (e.g., jacket_how_1 vs jacket_how_4), choose the variant that\n"
            "  BEST MATCHES the CUSTOMER'S specific question.\n"
            "- Do NOT default automatically to the `_1` or `_status` version.\n\n"
        )

    if "flow" in opts:
        classification_guidelines += (
            "FLOW CLASSIFICATION GUIDELINES:\n"
            "- The flow is the high-level category of the interaction (e.g., order_issue, subscription, etc.).\n"
            "- Pick the single most appropriate flow that best groups the customer’s main request.\n\n"
        )

    label_instr = (
        "CLASSIFICATION CONSTRAINTS:\n"
        + classification_guidelines
        + selected_instr
        + "- Do NOT invent new labels. Use only the valid labels above.\n"
        "- If uncertain, pick the single most likely label.\n"
    )

    return (
        "Convert the following customer-support dialog into structured metadata.\n\n"
        f"{fewshot_examples}\n"
        f"{label_instr}\n"
        "OUTPUT RULES:\n"
        "- Return STRICT JSON only (no prose, no markdown).\n"
        "- Use this exact schema and field types:\n"
        f"{json.dumps(SCHEMA, separators=(',', ':'))}\n"
        "- For fields you are not sure about, you may leave them as \"\" or [].\n"
        "- Your TOP PRIORITY is to correctly classify 'subflow' (and 'flow' if requested).\n\n"
        f"Dialog transcript:\n{transcript}\n"
    )


def _coerce_to_schema(data: dict, label_sets: dict) -> dict:
    """Merge a parsed model response into a fresh copy of ``SCHEMA`` and snap labels.

    - Keys not defined in ``SCHEMA`` are dropped, so the result has exactly the schema's top-level
      keys.
    - A dict section (personal/order/product) is taken from the model only when the model returned
      a dict for it. Anything else (null, a string, a list) keeps the empty template, because
      ``extraction`` flattens these sections with ``pd.json_normalize``, which expects dicts.
    - A flat field (flow/subflow) is taken from the model only when it is a string or a list;
      otherwise the "" default is kept.
    - String or list values of keys present in ``label_sets`` are snapped to the nearest valid label
      with ``closest_label``, which yields "" when nothing is close.
    """
    out = copy.deepcopy(SCHEMA)  # fresh structure per call; SCHEMA itself is never mutated
    for key, value in data.items():
        if key not in out:
            continue
        if isinstance(out[key], dict):
            if isinstance(value, dict):
                out[key] = value
        elif isinstance(value, (str, list)):
            out[key] = value

    # Post-process: normalize labels using closest_label(). Dict sections are left as returned.
    for key, value in out.items():
        if key in label_sets and isinstance(value, str):
            out[key] = closest_label(value, label_sets[key])
        elif key in label_sets and isinstance(value, list):
            out[key] = [closest_label(v, label_sets[key]) for v in value]

    return out


def extract_metadata_from_transcript(transcript: str, opts: list, label_sets: dict=label_sets) -> dict:
    """Ask the chat model to fill ``SCHEMA`` for one transcript, trying fallback models on failure.

    Args:
        transcript: Flattened dialog text.
        opts: Feature names (e.g. ["subflow", "flow"]) whose valid-label lists and classification
            guidelines are added to the prompt.
        label_sets: Valid labels per feature, used to snap flow/subflow predictions. The default is
            the module-level ``label_sets`` as it was when this function was defined.

    Returns:
        A dict with exactly ``SCHEMA``'s top-level keys, built from the first model response that
        parses as a JSON object. Returns a fresh copy of ``SCHEMA`` if every model fails.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are an information extraction and classification model. "
                "Always return valid JSON that exactly matches the schema. "
                "Do not add explanations or comments."
            ),
        },
        {"role": "user", "content": _build_extraction_prompt(transcript, opts)},
    ]

    # Try primary and fallback models
    models_to_try = [PRIMARY_MODEL] + [m for m in FALLBACK_MODELS if m != PRIMARY_MODEL]

    for model_name in models_to_try:
        content = call_chat_model_safely(messages, model_name)
        data = try_parse_json(content)
        if isinstance(data, dict):
            return _coerce_to_schema(data, label_sets)

        if content:
            print(f"[warn:{model_name}] unparsable output (first 160 chars): {content[:160]}")
        time.sleep(REQUEST_DELAY_SEC)  # back off before the next model

    # If no model produced valid JSON, return empty schema
    return copy.deepcopy(SCHEMA)

In [None]:
# Keep excess or experimental prompt/extraction code here (not executed). CLEAR EVERYTHING BELOW THIS LINE BEFORE PUSHING TO MAIN!!!

In [None]:
# Essential prompt testing setup functions

# Dataframe creator
def build_dataframe(examples: list, features_to_include: list, max_samples: int=MAX_SAMPLES):
    """Flatten up to ``max_samples`` conversations into one DataFrame row each.

    Each row holds ``convo_id``, the ground-truth scenario features, and the flattened
    ``transcript``. Feature values come from ``scenario.personal``, ``scenario.order``,
    ``scenario.product`` and the top-level ``flow``/``subflow``; missing values become None.
    When ``features_to_include`` is non-empty, only those feature columns are kept.
    """
    # Set membership is faster than scanning a list for every key.
    wanted = set(features_to_include) if features_to_include else None
    layout = (
        ("personal", ("customer_name", "email", "member_level", "phone", "username")),
        ("order", ("street_address", "full_address", "city", "num_products", "order_id",
                   "packaging", "payment_method", "products", "purchase_date", "state", "zip_code")),
        ("product", ("names", "amounts")),
    )

    records = []
    for convo in examples[:max_samples]:
        scenario = convo.get("scenario", {})
        flat = {}
        for section, fields in layout:
            section_data = scenario.get(section, {})
            for field in fields:
                flat[field] = section_data.get(field)
        flat["flow"] = scenario.get("flow")
        flat["subflow"] = scenario.get("subflow")

        if wanted:
            flat = {name: value for name, value in flat.items() if name in wanted}

        records.append({
            "convo_id": convo.get("convo_id",""),
            **flat,
            "transcript": convo_to_transcript(convo),
        })

    return pd.DataFrame(records)

# Make predictions (extraction)
def extraction(df, feature_options : list):
    """Run the LLM extractor on every transcript in ``df`` and append the requested predictions.

    Args:
        df: DataFrame with a ``transcript`` column, as produced by ``build_dataframe``.
        feature_options: Feature names to extract and append (e.g. ["flow", "subflow"]).

    Returns:
        ``df`` (re-indexed 0..n-1) plus one ``extracted_<feature>`` column per requested feature.
        A feature that no row produced (or an empty ``df``) yields a column filled with "" instead
        of raising KeyError.
    """
    predicted_feature_values = []
    for transcript in tqdm(df["transcript"], desc="Extracting Metadata"):
        predicted_feature_values.append(extract_metadata_from_transcript(transcript, feature_options))
        time.sleep(REQUEST_DELAY_SEC)  # rate-limit between conversations

    # Makes feature names appear cleanly (example: "email" instead of "personal.email")
    extracted = []
    for section in ["personal", "order", "product", "flow", "subflow"]:
        values = [item[section] for item in predicted_feature_values]
        if section in ["flow", "subflow"]:
            extract = pd.DataFrame({f"extracted_{section}": values})
        else:
            extract = pd.json_normalize(values).add_prefix("extracted_")
        extracted.append(extract)
    extracted_df = pd.concat(extracted, axis=1).reset_index(drop=True)

    # A model may echo a key inside more than one section. Keep the last occurrence, so the
    # top-level flow/subflow columns (added last) win and the reindex below sees unique labels.
    extracted_df = extracted_df.loc[:, ~extracted_df.columns.duplicated(keep="last")]

    # json_normalize only creates columns for keys present in at least one row. Reindex so every
    # requested feature has a column (schema default "") instead of raising KeyError.
    wanted_columns = [f"extracted_{f}" for f in feature_options]
    extracted_df = extracted_df.reindex(columns=wanted_columns, fill_value="")

    final_df = pd.concat([df.reset_index(drop=True), extracted_df], axis=1)

    return final_df

# Measures extraction accuracy per feature
def accuracy(df, feature_options: list):
    """Print exact-match accuracy of ``extracted_<feature>`` against ``<feature>`` for each feature.

    Missing values on either side are filled with "" *before* the string conversion, so a
    ground-truth None and an empty or NaN prediction count as agreement. Converting first turned
    them into the literal strings "None"/"nan" and made the fill a no-op. Values are compared as
    strings, so e.g. 3 and "3" match.
    """
    for feature in feature_options:
        gt = df[feature].fillna("").astype(str)
        ex = df[f"extracted_{feature}"].fillna("").astype(str)
        acc = (gt == ex).mean()
        print(f"{feature} accuracy: {acc: .2%}")

In [None]:
# ============================================================
# 1️. TRAINING EVALUATION
# ============================================================
# Purpose: Tune the extraction function (extract_metadata_from_transcript)
#          against training examples to refine prompt behavior and schema.
# Notes:
# - Re-run this cell as often as needed while tuning prompts.
# - Expect to modify the extraction prompt/function between runs.
# ============================================================

# Features under test (e.g. ["flow", "subflow"]). CLEAR VALUE BEFORE PUSHING FILE TO MAIN!!!
feature_options = ["subflow", "flow"]
# Number of seeded examples to score; build_dataframe defaults to MAX_SAMPLES (250).
# Earlier cells do not need to be re-run after changing this.
examples = 50

df = build_dataframe(train_seeded_examples, feature_options, max_samples=examples)
print("DataFrame shape:", df.shape)

final_df = extraction(df, feature_options)
pd.options.display.max_colwidth = 25
display(final_df.head(250))

accuracy(final_df, feature_options)


In [None]:
# ============================================================
# 2️. DEVELOPMENT EVALUATION
# ============================================================
# Purpose: Score the *current best* prompt on development examples.
# Notes:
# - Dev data is not looked at while tuning.
# - Run this cell only once training accuracy has stabilized.
# - Shows whether the prompt generalizes beyond the training sample.
# ============================================================

df = build_dataframe(dev_seeded_examples, feature_options, max_samples=examples)
print("DataFrame shape:", df.shape)

final_df = extraction(df, feature_options)
pd.options.display.max_colwidth = 25
display(final_df.head(250))

accuracy(final_df, feature_options)

In [None]:
# ============================================================
# 3️. FINAL TEST EVALUATION
# ============================================================
# Purpose: Final score on test data, once dev results look good.
# Notes:
# - Run this cell only after prompt tuning is complete.
# - Keep it clean and reproducible.
# ============================================================

df = build_dataframe(test_seeded_examples, feature_options, max_samples=examples)
print("DataFrame shape:", df.shape)

final_df = extraction(df, feature_options)
pd.options.display.max_colwidth = 25
display(final_df.head(250))

accuracy(final_df, feature_options)