In [2]:
# preprocess.py
# 
# This script processes the ATIS dataset (atis.json) into a unified format
# suitable for classification, tagging, generation, and prompting models.
# Outputs:
#  - question_train.jsonl, question_dev.jsonl, question_test.jsonl
#  - query_train.jsonl,   query_dev.jsonl,   query_test.jsonl
#  - templates.json         # maps template_id -> SQL template (placeholders)
#  - tags_vocab.json        # list of all tag types ("O" + variable names)
#  - default_values.json    # default placeholder values per template for missed tags

import json
import os
import re
from collections import OrderedDict, defaultdict
from contextlib import ExitStack

# ---------- Configuration ----------
# INPUT_FILE: path to the raw ATIS JSON.
# OUTPUT_DIR: folder that receives every file written by this script.
INPUT_FILE = 'atis.json'
OUTPUT_DIR = 'processed'
# Make sure the output folder exists; an already existing folder is fine.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------- Load Raw Data ----------
# Read the whole input file and parse it into `data`.
with open(INPUT_FILE, mode='r', encoding='utf-8') as raw_file:
    data = json.loads(raw_file.read())

# ---------- Step 1: Extract Unique SQL Templates ----------
# Every entry carries a group of SQL strings; the shortest one (ties broken
# lexicographically) is used as that entry's template.
template_to_id = OrderedDict()   # template SQL -> template id
id_to_template = {}              # template id  -> template SQL
next_template_id = 0

action_records = []  # collects processed sentence-level records


def _placeholder_regex(variables):
    """Return a compiled regex matching any placeholder name as a whole word.

    Names are escaped so they are always matched literally. Returns None when
    *variables* is empty, because there is nothing to match.
    """
    if not variables:
        return None
    names = sorted(variables, key=len, reverse=True)
    alternatives = "|".join(re.escape(name) for name in names)
    return re.compile(r"\b(?:" + alternatives + r")\b")


def _substitute(text, variables, pattern):
    """Return *text* with every whole-word placeholder replaced by its value.

    A callable replacement is used so values are inserted verbatim (a plain
    replacement string would have its backslashes interpreted by ``re.sub``).
    Everything is replaced in one pass, so an inserted value is never scanned
    again for further placeholders.
    """
    if pattern is None:
        return text
    return pattern.sub(lambda match: variables[match.group(0)], text)


def _tokens_and_tags(raw_text, variables, pattern):
    """Whitespace-tokenize the filled text and tag tokens by placeholder position.

    A whitespace-delimited token of *raw_text* that is exactly a placeholder
    expands to the tokens of its value, each tagged with the placeholder name.
    Every other token is tagged 'O'; placeholders embedded in such a token
    (e.g. next to punctuation) are still replaced by their value.

    Tagging by position rather than by searching for the value keeps two
    placeholders that share a value apart and never tags ordinary words that
    merely equal some value. Placeholder names are assumed to contain no
    whitespace.
    """
    tokens, tags = [], []
    for raw_token in raw_text.split():
        is_placeholder = pattern is not None and pattern.fullmatch(raw_token)
        label = raw_token if is_placeholder else 'O'
        pieces = _substitute(raw_token, variables, pattern).split()
        tokens.extend(pieces)
        tags.extend([label] * len(pieces))
    return tokens, tags


for entry in data:
    sql_group = entry['sql']
    # pick shortest SQL (then alphabetical)
    shortest_sql = min(sql_group, key=lambda s: (len(s), s))
    if shortest_sql not in template_to_id:
        template_to_id[shortest_sql] = next_template_id
        id_to_template[next_template_id] = shortest_sql
        next_template_id += 1
    tid = template_to_id[shortest_sql]
    query_split = entry.get('query-split', '')

    # process each sentence in the group
    for sent in entry['sentences']:
        raw_text = sent['text']
        variables = sent['variables']        # placeholder -> actual value
        question_split = sent.get('question-split', '')
        pattern = _placeholder_regex(variables)

        # ---------- Step 2: Replace placeholders in text ----------
        # word-boundary replace, e.g. city_name0 -> BOSTON
        filled_text = _substitute(raw_text, variables, pattern)

        # ---------- Steps 3 + 4: Tokenize and Generate Tag Sequence ----------
        # Simple whitespace tokens; each is tagged 'O' or with the name of the
        # placeholder it was expanded from.
        tokens, tags = _tokens_and_tags(raw_text, variables, pattern)

        # ---------- Step 5: Fill SQL Template ----------
        # Same whole-word replacement as for the text, so that a name such as
        # city_name1 cannot rewrite part of city_name10.
        sql_template = shortest_sql
        sql_filled = _substitute(sql_template, variables, pattern)

        # ---------- Collect Record ----------
        record = {
            'text': filled_text,
            'text_tokens': tokens,
            'tags': tags,
            'template_sql': sql_template,
            'template_id': tid,
            'sql_with_vars_filled': sql_filled,
            'variables': variables,
            'question_split': question_split,
            'query_split': query_split
        }
        action_records.append(record)

# ---------- Step 6: Build Tag Vocabulary ----------
# The vocabulary is the sorted set of every tag that occurs in any record,
# e.g. ['O', 'airport_code0', 'city_name0', ...]
all_tags = {tag for rec in action_records for tag in rec['tags']}
tags_vocab = sorted(all_tags)
vocab_path = os.path.join(OUTPUT_DIR, 'tags_vocab.json')
with open(vocab_path, 'w', encoding='utf-8') as f:
    json.dump(tags_vocab, f, indent=2, ensure_ascii=False)

# ---------- Step 7: Compute Default Values per Template ----------
# Per template_id, remember the first value seen for each placeholder among
# the records whose question split is 'train'.
default_values = defaultdict(dict)
for rec in action_records:
    if rec['question_split'] != 'train':
        continue
    tid = rec['template_id']
    for placeholder, real_val in rec['variables'].items():
        # setdefault keeps the value stored first and ignores later ones
        default_values[tid].setdefault(placeholder, real_val)

# Save defaults
defaults_path = os.path.join(OUTPUT_DIR, 'default_values.json')
with open(defaults_path, 'w', encoding='utf-8') as f:
    json.dump(default_values, f, indent=2, ensure_ascii=False)

# ---------- Step 8: Write Out JSONL Splits ----------
# Each record is written to question_<split>.jsonl according to its question
# split and to query_<split>.jsonl according to its query split. Records whose
# split is not one of `splits` are skipped for that family of files.
splits = ['train', 'dev', 'test']

# ExitStack closes every opened file when the block is left, including when
# opening a later file, serialising a record, or writing raises.
with ExitStack() as stack:
    writers_q = {
        sp: stack.enter_context(
            open(os.path.join(OUTPUT_DIR, f'question_{sp}.jsonl'), 'w', encoding='utf-8')
        )
        for sp in splits
    }
    writers_g = {
        sp: stack.enter_context(
            open(os.path.join(OUTPUT_DIR, f'query_{sp}.jsonl'), 'w', encoding='utf-8')
        )
        for sp in splits
    }

    for rec in action_records:
        qsp = rec['question_split']
        gsp = rec['query_split']
        # Serialise once; the same line may go to both files.
        line = json.dumps(rec, ensure_ascii=False) + '\n'
        if qsp in writers_q:
            writers_q[qsp].write(line)
        if gsp in writers_g:
            writers_g[gsp].write(line)

# ---------- Step 9: Save Templates Mapping ----------
# Dump the template_id -> SQL template mapping, then report completion.
templates_path = os.path.join(OUTPUT_DIR, 'templates.json')
with open(templates_path, 'w', encoding='utf-8') as f:
    json.dump(id_to_template, f, indent=2, ensure_ascii=False)

print('✔️ Preprocessing done. All outputs in:', OUTPUT_DIR)

✔️ Preprocessing done. All outputs in: processed
