In [1]:
import pandas as pd
import json
import random
import re
import os
from pathlib import Path
from transformers import pipeline

## Load all CSV files into DataFrames

In [2]:
# The notebook is expected to run from one directory below the project root
# (e.g. <root>/notebooks/); raw inputs live under <root>/data/raw.
project_root = str(Path.cwd().resolve().parent)
raw_data_dir = os.path.join(project_root, "data", "raw")
entities_dir = os.path.join(raw_data_dir, "entities")
intents_dir = os.path.join(raw_data_dir, "intents")

# Seed Python's `random`. It drives augmentation word choice, entity sampling
# and the final shuffle, so seeding it makes the generated dataset reproducible.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

In [3]:
# --- Load Entity Value Lists ---
# Each entity CSV has a 'text' column with one candidate value per row. These
# lists are sampled later to fill the {placeholder} slots of the templates.
# dtype=str keeps values (e.g. IDs) exactly as written in the file, with no
# numeric parsing or lost leading zeros. dropna() discards empty rows that
# would otherwise be injected as NaN.
try:
    practitioner_df = pd.read_csv(os.path.join(entities_dir, "practitioner_name.csv"), dtype=str)
    practitioner_list = practitioner_df['text'].dropna().tolist()

    appointment_type_df = pd.read_csv(os.path.join(entities_dir, "appointment_type.csv"), dtype=str)
    appointment_type_list = appointment_type_df['text'].dropna().tolist()

    # Load the pre-generated appointment IDs
    appointment_id_df = pd.read_csv(os.path.join(entities_dir, "appointment_id.csv"), dtype=str)
    appointment_id_list = appointment_id_df['text'].dropna().tolist()

    print(f"Loaded {len(practitioner_list)} practitioners.")
    print(f"Loaded {len(appointment_type_list)} appointment types.")
    print(f"Loaded {len(appointment_id_list)} pre-generated appointment IDs.")

except FileNotFoundError as e:
    print(f"Error loading entity CSVs: {e}")
    print("Please ensure you have run id_generator.py and all entity files are in place.")
    # The entity lists are required by the injection step. Stop here rather
    # than fail several cells later with a confusing NameError.
    raise


# --- Load Intent Template DataFrames ---
# A missing intent file is tolerated: that intent is simply absent from
# intent_dfs, and every later cell checks membership before using it.
intent_dfs = {}
intent_files = [
    "schedule", "reschedule", "cancel", "query_avail",
    "greeting", "bye", "positive_reply", "negative_reply", "oos"
]

for intent_name in intent_files:
    try:
        path = os.path.join(intents_dir, f"{intent_name}.csv")
        intent_dfs[intent_name] = pd.read_csv(path)
        print(f"Loaded '{intent_name}.csv' with {len(intent_dfs[intent_name])} rows.")
    except FileNotFoundError:
        print(f"Warning: Could not find '{intent_name}.csv'. Skipping.")

Loaded 149 practitioners.
Loaded 147 appointment types.
Loaded 150 pre-generated appointment IDs.
Loaded 'schedule.csv' with 139 rows.
Loaded 'reschedule.csv' with 143 rows.
Loaded 'cancel.csv' with 142 rows.
Loaded 'query_avail.csv' with 140 rows.
Loaded 'greeting.csv' with 150 rows.
Loaded 'bye.csv' with 150 rows.
Loaded 'positive_reply.csv' with 150 rows.
Loaded 'negative_reply.csv' with 150 rows.
Loaded 'oos.csv' with 150 rows.


## Data Augmentation
This is the core section of the notebook. It defines `augment_templates`, which uses a pre-trained masked language model to create new variations of the template sentences.

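- How it works:

	1. A `fill-mask` pipeline from `transformers` (DistilBERT) is initialized once. It predicts likely words for a masked position from the surrounding context.
	2. For each template, the function makes a fixed number of attempts (`num_augmentations_per_template`, default 3).
	3. Each attempt picks one word at random. Placeholders (like `{practitioner_name}`) are never picked, so the entity slots are preserved.
	4. That single word is replaced with the model's mask token (`[MASK]` for DistilBERT).
	5. The model returns its top-5 candidate words for the mask. Unusable candidates (e.g. sub-word fragments or the original word itself) are discarded.
	6. One of the remaining candidates is chosen at random and substituted back, producing a paraphrase of the original template. The result is passed through `cleanup_template` to repair any damaged placeholders.

The original templates are kept and duplicate variants are removed. We apply this function to the templates of the complex intents (`schedule`, `reschedule`, `cancel`, `query_avail`).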

In [4]:
# --- Initialize Augmentation Pipeline ---
# Masked language model that proposes in-context replacement words for a
# single masked position; built once and reused by augment_templates.
unmasker = pipeline(task='fill-mask', model='distilbert-base-uncased')

# --- Define Placeholders and Cleanup Function ---
# Template slots that must survive augmentation untouched and are later
# filled with entity values.
known_placeholders = [
    "{practitioner_name}",
    "{appointment_type}",
    "{appointment_id}",
]

def cleanup_template(text):
    """
    Sanitizes an augmented template to fix corrupted placeholders.

    Placeholder names come from the module-level ``known_placeholders`` list,
    whose entries have the form ``"{name}"``. Three repairs are applied in order:

    1. ``{name`` missing its closing brace becomes ``{name}``. This happens only
       when the next character is not a word character, so ``{name_x`` is
       not treated as a damaged ``{name}``.
    2. Any ``{`` that does not open a complete known placeholder is removed.
    3. Any ``}`` that does not close a known placeholder is removed.

    Args:
        text: Template string, possibly containing damaged placeholders.

    Returns:
        The repaired string. Intact known placeholders are never modified.
    """
    names = [ph[1:-1] for ph in known_placeholders]

    # 1. Fix missing closing braces. Each occurrence is matched individually,
    #    so "{practitioner_name} and {practitioner_name" is repaired even though
    #    one copy is already intact.
    for name in names:
        text = re.sub(
            r'\{' + re.escape(name) + r'(?![\w}])',
            lambda _match, ph='{' + name + '}': ph,
            text,
        )

    # 2. Remove any "{" that does not start a complete known placeholder.
    #    The pattern is built from known_placeholders, so adding an entity type
    #    needs no regex edit.
    opens_placeholder = '|'.join(re.escape(name) + r'\}' for name in names)
    text = re.sub(r'\{(?!' + opens_placeholder + r')', '', text)

    # 3. Remove any "}" that does not close a known placeholder. Python's
    #    look-behind must be fixed-width, so chain one look-behind per name.
    closes_placeholder = ''.join(r'(?<!\{' + re.escape(name) + r')' for name in names)
    text = re.sub(closes_placeholder + r'\}', '', text)

    return text

def _replaceable_spans(template):
    """Return (start, end) character spans of the words in *template* that may be masked."""
    spans = []
    for match in re.finditer(r'\S+', template):
        token = match.group()
        # Never touch a token containing a brace. This protects bare
        # placeholders and also ones glued to punctuation, such as
        # "({appointment_id})" or "{practitioner_name}?".
        if '{' in token or '}' in token:
            continue
        # Keep trailing sentence punctuation outside the masked span, so
        # "today?" becomes "<new word>?" instead of losing its "?".
        word = token.rstrip('.,!?;:')
        if word:  # pure-punctuation tokens are not words worth replacing
            spans.append((match.start(), match.start() + len(word)))
    return spans


def _is_usable_replacement(candidate, original_word):
    """True if a fill-mask prediction may replace *original_word*."""
    return (
        '##' not in candidate                                # WordPiece sub-word fragment
        and '{' not in candidate and '}' not in candidate    # would corrupt placeholders
        and any(ch.isalnum() for ch in candidate)            # bare punctuation is not a word
        and candidate.lower() != original_word.lower()       # must actually change the text
    )


def augment_templates(templates, num_augmentations_per_template=3):
    """
    Augments a list of templates using contextual word replacement
    and cleans up any resulting placeholder corruption.

    Each attempt works on one template as follows:

    - pick a random replaceable word and mask it;
    - ask the module-level ``unmasker`` fill-mask pipeline for its top-5
      candidates and drop unusable ones;
    - splice a randomly chosen survivor back into the original string and pass
      the result through ``cleanup_template``.

    This is repeated ``num_augmentations_per_template`` times per template. All
    other characters of the template (spacing, punctuation, placeholders) are
    left exactly as they were.

    Args:
        templates: Iterable of template strings containing ``{placeholder}`` slots.
        num_augmentations_per_template: Replacement attempts per template.
            Attempts without a usable prediction add nothing, and duplicates are
            merged, so each template yields at most this many new variants.

    Returns:
        List of unique templates. The originals come first, in input order,
        followed by the new variants in the order they were generated.
    """
    # Materialize once so a generator argument is not exhausted before the loop.
    originals = list(templates)
    augmented_templates = list(originals)  # Start with all original templates
    mask_token = unmasker.tokenizer.mask_token

    for template in originals:
        candidates = _replaceable_spans(template)
        if not candidates:
            continue

        for _ in range(num_augmentations_per_template):
            start, end = random.choice(candidates)
            original_word = template[start:end]
            masked_text = template[:start] + mask_token + template[end:]

            predictions = [
                pred['token_str'] for pred in unmasker(masked_text, top_k=5)
                if _is_usable_replacement(pred['token_str'], original_word)
            ]

            if predictions:
                new_word = random.choice(predictions)
                # Splice into the original string instead of re-joining tokens,
                # so whitespace and punctuation stay exactly as written.
                augmented_sentence = template[:start] + new_word + template[end:]
                augmented_templates.append(cleanup_template(augmented_sentence))

    # dict.fromkeys de-duplicates while keeping first-seen order. set() would
    # make the order depend on per-process string hashing, so the dataset could
    # not be reproduced even with a fixed random seed.
    return list(dict.fromkeys(augmented_templates))

# --- Apply Augmentation to Complex Intents ---
print("\n--- Augmenting Complex Intents ---")
complex_intents = ["schedule", "reschedule", "cancel", "query_avail"]

for intent in complex_intents:
    if intent not in intent_dfs:
        continue

    original_df = intent_dfs[intent]
    # intent_dfs is overwritten below. Without this guard, re-running the cell
    # would augment the already-augmented templates again.
    if original_df.attrs.get("augmented", False):
        print(f"Skipping '{intent}': already augmented (re-run the loading cell to start over).")
        continue

    print(f"Augmenting intent: '{intent}'...")
    original_templates = original_df['text'].tolist()

    augmented_templates = augment_templates(original_templates)

    augmented_df = pd.DataFrame(augmented_templates, columns=['text'])
    augmented_df['intent'] = intent
    augmented_df.attrs["augmented"] = True

    intent_dfs[intent] = augmented_df
    print(f"  -> Original: {len(original_templates)} templates | Augmented: {len(augmented_df)} templates")

Device set to use cpu


\n--- Augmenting Complex Intents ---
Augmenting intent: 'schedule'...
  -> Original: 139 templates | Augmented: 534 templates
Augmenting intent: 'reschedule'...
  -> Original: 143 templates | Augmented: 554 templates
Augmenting intent: 'cancel'...
  -> Original: 142 templates | Augmented: 548 templates
Augmenting intent: 'query_avail'...
  -> Original: 140 templates | Augmented: 536 templates


## Process Intents and Inject Entities

In [5]:
final_dataset = []

# --- Process Simple Intents ---
# Simple intents have no entity slots, so each template becomes one example as-is.
simple_intents = ["greeting", "bye", "positive_reply", "negative_reply", "oos"]
for intent_name in simple_intents:
    if intent_name in intent_dfs:
        simple_df = intent_dfs[intent_name]
        for text, intent in zip(simple_df['text'], simple_df['intent']):
            final_dataset.append({"text": text, "intent": intent, "entities": []})
print(f"Processed {len(final_dataset)} entries from simple intents.")

# --- Process Complex Intents ---
entity_map = {
    'practitioner_name': practitioner_list,
    'appointment_type': appointment_type_list,
    'appointment_id': appointment_id_list
}

# A placeholder is "{label}" where the label contains no braces. Using "[^{}]+"
# rather than ".+?" stops a stray "{" earlier in the text from swallowing the
# real placeholder into one bogus label.
placeholder_pattern = re.compile(r"\{([^{}]+)\}")


def inject_entities(template, values_by_label):
    """
    Replace every ``{label}`` in *template* with a random value for that label.

    Args:
        template: Template string containing ``{label}`` placeholders.
        values_by_label: Mapping from label to a non-empty list of candidate values.

    Returns:
        ``(text, entities)``. Each entity is ``{"start", "end", "label"}``, with
        character offsets into the returned text (end exclusive), listed left
        to right. Placeholders whose label is unknown are left verbatim and
        reported with a warning.
    """
    parts = []
    entities = []
    out_len = 0      # length of the output text assembled so far
    last_end = 0     # end offset of the previous placeholder in *template*

    for match in placeholder_pattern.finditer(template):
        label = match.group(1)
        literal = template[last_end:match.start()]
        parts.append(literal)
        out_len += len(literal)

        if label in values_by_label:
            # str() guards against non-string values (e.g. numeric IDs).
            value = str(random.choice(values_by_label[label]))
            entities.append({"start": out_len, "end": out_len + len(value), "label": label})
        else:
            print(f"Warning: Found unknown placeholder '{label}'")
            value = match.group(0)

        parts.append(value)
        out_len += len(value)
        last_end = match.end()

    parts.append(template[last_end:])
    return "".join(parts), entities


for intent_name in complex_intents:
    if intent_name in intent_dfs:
        complex_df = intent_dfs[intent_name]
        for template, intent in zip(complex_df['text'], complex_df['intent']):
            injected_text, entities = inject_entities(template, entity_map)
            final_dataset.append({
                "text": injected_text,
                "intent": intent,
                "entities": entities,
            })
print(f"Processing complete. Total entries in dataset: {len(final_dataset)}.")

Processed 750 entries from simple intents.
Processing complete. Total entries in dataset: 2922.


## Shuffle and Save the final JSONL dataset

In [7]:
# Shuffle so that examples from different intents are interleaved, then write
# one JSON object per line (JSONL). ensure_ascii=False keeps non-ASCII text
# human-readable in the file.
random.shuffle(final_dataset)
print("Dataset shuffled.")

output_path_jsonl = os.path.join(raw_data_dir, "hasd.jsonl")
with open(output_path_jsonl, mode='w', encoding='utf-8') as jsonl_file:
    jsonl_file.writelines(
        json.dumps(record, ensure_ascii=False) + '\n' for record in final_dataset
    )

print("✅ Success! Dataset correctly saved.")

Dataset shuffled.
✅ Success! Dataset correctly saved.
