In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import spacy
import numpy as np
from collections import defaultdict, Counter

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

In [None]:
# Load the raw dataset. Each item is expected to provide the keys read below:
# "sql" (list of gold queries), "sentences" (each with "text", "variables",
# "question-split") and "query-split".
# Point DATA_PATH at "geography.json" to run the same pipeline on that file.
DATA_PATH = "atis.json"
with open(DATA_PATH, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

# Load spaCy tokenizer
nlp = spacy.load("en_core_web_sm")

# Structures to store splits and mappings
question_split = {"train": [], "dev": [], "test": []}
query_split = {"train": [], "dev": [], "test": []}
template_pool = set()

# Process each item, normalize gold SQLs to create template pool
normalized_sqls_per_item = []
for item in raw_data:
    # Only the first sentence's variable values are used for normalisation.
    sentences = item["sentences"]
    first_variables = sentences[0]["variables"] if sentences else {}

    gold_templates = []
    for sql in item["sql"]:
        # Replace all variable values with @var
        template = sql
        for var, val in first_variables.items():
            if not val:
                # str.replace("", x) inserts x between every character, which
                # would destroy the query; an empty value carries nothing to mask.
                continue
            template = template.replace(f'"{val}"', f'"@{var}"')
            # NOTE(review): this bare replace also hits substrings of identifiers
            # (e.g. a value of "1" inside "alias1") — confirm this is intended.
            # The saved evaluation output shows templates that still contain
            # plain placeholder names, so presumably the gold SQL rarely holds
            # literal values at all; verify against the data.
            template = template.replace(val, f'"@{var}"')
        gold_templates.append(template)
        template_pool.add(template)
    normalized_sqls_per_item.append((item, gold_templates))


# Deterministic template ids: sort so the mapping does not depend on set order.
template_list = sorted(template_pool)
template2id = {t: i for i, t in enumerate(template_list)}
id2template = dict(enumerate(template_list))

# Process entries and assign template IDs
for (item, gold_templates) in normalized_sqls_per_item:
    # Use shortest template (ties broken alphabetically) to assign template_id
    shortest_sql = min(gold_templates, key=lambda x: (len(x), x))
    template_id = template2id[shortest_sql]

    for sent in item["sentences"]:
        raw_text = sent["text"]
        variables = sent["variables"]

        # Replace placeholders in raw_text for model input. Longest names go
        # first so a name that is a prefix of another (city_name1 vs
        # city_name10) cannot clobber the longer placeholder.
        text = raw_text
        for var_name in sorted(variables, key=len, reverse=True):
            text = text.replace(var_name, variables[var_name])

        # Tokenize both the filled-in text and the placeholder text
        tokens = [t.text for t in nlp(text)]
        raw_tokens = [t.text for t in nlp(raw_text)]

        # Tag sequence over the placeholder tokens: a token that is itself a
        # variable name is tagged with that name, everything else with "O".
        tags = [tok if tok in variables else "O" for tok in raw_tokens]

        entry = {
            "text": text,
            "template_id": template_id,
            "template_sql": gold_templates,
            "variables": variables,
            "raw_text": raw_text,
            "tokens": tokens,
            "tags": tags,
        }

        # The same entry object is stored under both split schemes.
        question_split[sent["question-split"]].append(entry)
        query_split[item["query-split"]].append(entry)


In [None]:
# Frequency tables used later to build the token and tag vocabularies.
token_counter = Counter()
tag_counter = Counter()

# Only question_split is iterated; its entries are updated in place with the
# tokens/tags derived from the placeholder text ("raw_text").
for split_name in ("train", "dev", "test"):
    for entry in question_split[split_name]:
        placeholder_names = entry["variables"]
        tokens = [tok.text for tok in nlp(entry["raw_text"])]

        # A token that is a variable name is tagged with that name, else "O".
        tags = [tok if tok in placeholder_names else "O" for tok in tokens]

        tag_counter.update(tags)
        token_counter.update(tokens)

        # Store tokenized version and tags for use later
        entry["tokens"] = tokens
        entry["tags"] = tags


In [None]:
# Reserved symbols shared by the token and tag vocabularies.
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
PAD_TAG = "<PAD>"

# Token vocabulary: ids 0 and 1 are reserved, the rest follow the order in
# which tokens were first counted.
token2id = {PAD_TOKEN: 0, UNK_TOKEN: 1}
for tok in token_counter:
    token2id[tok] = len(token2id)

# Tag vocabulary: register the padding tag, then number tags in counter order.
tag_counter[PAD_TAG] = 1
tag2id = {tag_name: tag_idx for tag_idx, tag_name in enumerate(tag_counter)}


In [34]:
class ATISDataset(Dataset):
    """Wraps a list of prepared entries and serves them as id tensors.

    Each entry must provide "tokens", "tags" and "template_id". Tokens missing
    from ``token2id`` are mapped to the "<UNK>" id; tags must all be present in
    ``tag2id``.
    """

    def __init__(self, data, token2id, tag2id):
        self.data = data
        self.token2id = token2id
        self.tag2id = tag2id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        vocab = self.token2id
        tag_vocab = self.tag2id

        token_ids = [vocab.get(tok, vocab["<UNK>"]) for tok in example["tokens"]]
        tag_ids = [tag_vocab[tag] for tag in example["tags"]]

        sample = {}
        sample["token_ids"] = torch.tensor(token_ids, dtype=torch.long)
        sample["tag_ids"] = torch.tensor(tag_ids, dtype=torch.long)
        sample["template_id"] = torch.tensor(example["template_id"], dtype=torch.long)
        return sample
    
def pad_batch(batch):
    """Collate a list of samples into padded batch tensors.

    Every sample's "token_ids" and "tag_ids" are right-padded (in place) to the
    longest token sequence in the batch, using the "<PAD>" ids from the
    module-level ``token2id`` and ``tag2id``.

    Returns:
        (token_ids, tag_ids, template_ids) stacked along a new first dimension.
    """
    longest = max(len(sample["token_ids"]) for sample in batch)

    for sample in batch:
        missing = longest - len(sample["token_ids"])
        right_pad = (0, missing)
        sample["token_ids"] = F.pad(sample["token_ids"], right_pad, value=token2id["<PAD>"])
        sample["tag_ids"] = F.pad(sample["tag_ids"], right_pad, value=tag2id["<PAD>"])

    stacked = tuple(
        torch.stack([sample[field] for sample in batch])
        for field in ("token_ids", "tag_ids", "template_id")
    )
    token_ids, tag_ids, template_ids = stacked
    return token_ids, tag_ids, template_ids

# One dataset per question-based split, all sharing the same vocabularies.
train_set, dev_set, test_set = (
    ATISDataset(question_split[split_name], token2id, tag2id)
    for split_name in ("train", "dev", "test")
)

# Only the training loader shuffles; dev/test keep the dataset order.
loader_options = {"batch_size": 32, "collate_fn": pad_batch}
train_loader = DataLoader(train_set, shuffle=True, **loader_options)
dev_loader = DataLoader(dev_set, **loader_options)
test_loader = DataLoader(test_set, **loader_options)

In [35]:
# Fallback value per variable name: the value seen most often for that
# variable in the training questions. Unknown variables default to "".
var_value_counts = defaultdict(Counter)
for train_entry in question_split["train"]:
    for name, value in train_entry["variables"].items():
        var_value_counts[name][value] += 1

default_values = defaultdict(str)
default_values.update(
    (name, value_counts.most_common(1)[0][0])
    for name, value_counts in var_value_counts.items()
)

In [None]:
class LSTMClassifierTagger(nn.Module):
    """Single-layer LSTM with two heads: per-token tagging and template classification.

    Args:
        vocab_size: number of rows in the embedding table.
        tag_size: number of tag classes predicted per token.
        template_size: number of template classes predicted per sequence.
        embedding_dim: embedding width.
        hidden_dim: LSTM hidden size.
        dropout: dropout probability applied to the LSTM outputs and final state.
        padding_idx: embedding index treated as padding. When omitted, the
            "<PAD>" id of the module-level ``token2id`` is used, as before.
    """

    def __init__(self, vocab_size, tag_size, template_size, embedding_dim=100, hidden_dim=128, dropout=0.3,
                 padding_idx=None):
        super().__init__()
        if padding_idx is None:
            # Backward-compatible default: fall back to the notebook's vocabulary.
            padding_idx = token2id["<PAD>"]
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        self.tag_fc = nn.Linear(hidden_dim, tag_size)
        self.template_fc = nn.Linear(hidden_dim, template_size)

    def forward(self, x, lengths):
        """Run the model on a padded batch.

        Args:
            x: padded token ids, indexed as (batch, position).
            lengths: number of real (non-padding) tokens per row.

        Returns:
            (tag_logits, template_logits): per-position tag scores with the same
            width as ``x``, and one template score vector per row.
        """
        embedded = self.embedding(x)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (hn, _) = self.lstm(packed)
        # total_length keeps the output as wide as the input even when the
        # longest sequence in the batch is shorter than the padded width, so the
        # tag logits always line up with the padded gold tags.
        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, total_length=x.size(1)
        )

        # Apply dropout
        output = self.dropout(output)
        hn = self.dropout(hn[-1])  # final hidden state of the last LSTM layer

        tag_logits = self.tag_fc(output)
        template_logits = self.template_fc(hn)

        return tag_logits, template_logits

In [38]:
# Prefer the GPU when one is visible to PyTorch.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Output sizes follow the vocabularies built above.
model = LSTMClassifierTagger(
    vocab_size=len(token2id),
    tag_size=len(tag2id),
    template_size=len(template2id),
)
model = model.to(device)

In [39]:
def train(model, loader, optimizer, tag_loss_fn, template_loss_fn, epoch):
    """Run one training epoch and print the epoch metrics.

    Reads the module-level ``device``, ``token2id`` and ``tag2id``.

    Args:
        model: called as ``model(token_ids, lengths)`` and expected to return
            ``(tag_logits, template_logits)``.
        loader: yields ``(token_ids, tag_ids, template_ids)`` batches.
        optimizer: optimizer stepping the model parameters.
        tag_loss_fn: loss applied to flattened tag logits / tag ids.
        template_loss_fn: loss applied to template logits / template ids.
        epoch: epoch number, used only in the printed line.

    Returns:
        ``(avg_loss, tag_acc, template_acc)``. Each value is 0.0 when its
        denominator is zero (empty loader or dataset).
    """
    model.train()
    pad_token_id = token2id["<PAD>"]
    pad_tag_id = tag2id["<PAD>"]

    total_loss = 0
    total_template_correct = 0
    total_tag_correct = 0
    total_tokens = 0
    for token_ids, tag_ids, template_ids in loader:
        token_ids = token_ids.to(device)
        tag_ids = tag_ids.to(device)
        template_ids = template_ids.to(device)

        # Sequence lengths are recovered by counting non-padding token ids.
        lengths = (token_ids != pad_token_id).sum(dim=1)

        optimizer.zero_grad()
        tag_logits, template_logits = model(token_ids, lengths)

        # Tag loss over all positions, flattened to (positions, classes)
        tag_loss = tag_loss_fn(
            tag_logits.reshape(-1, tag_logits.shape[-1]),
            tag_ids.reshape(-1),
        )

        # Template loss
        template_loss = template_loss_fn(template_logits, template_ids)

        loss = tag_loss + template_loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Tag accuracy is measured on non-padding positions only
        tag_preds = tag_logits.argmax(dim=-1)
        mask = tag_ids != pad_tag_id
        total_tag_correct += ((tag_preds == tag_ids) & mask).sum().item()
        total_tokens += mask.sum().item()

        template_preds = template_logits.argmax(dim=-1)
        total_template_correct += (template_preds == template_ids).sum().item()

    # Guard every denominator, not just the token count, so an empty loader
    # reports zeros instead of raising ZeroDivisionError.
    num_batches = len(loader)
    num_examples = len(loader.dataset)
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    tag_acc = total_tag_correct / total_tokens if total_tokens > 0 else 0.0
    template_acc = total_template_correct / num_examples if num_examples > 0 else 0.0

    print(f"Epoch {epoch}: Loss={avg_loss:.4f}, TagAcc={tag_acc:.4f}, TemplateAcc={template_acc:.4f}")
    return avg_loss, tag_acc, template_acc

In [40]:
def validate(model, loader, tag_loss_fn, template_loss_fn, epoch, name="Dev"):
    """Evaluate loss and accuracies on a loader without updating the model.

    Reads the module-level ``device``, ``token2id`` and ``tag2id``.

    Args:
        model: called as ``model(token_ids, lengths)`` and expected to return
            ``(tag_logits, template_logits)``.
        loader: yields ``(token_ids, tag_ids, template_ids)`` batches.
        tag_loss_fn: loss applied to flattened tag logits / tag ids.
        template_loss_fn: loss applied to template logits / template ids.
        epoch: epoch number, used only in the printed line.
        name: label prefixed to the printed line.

    Returns:
        ``(avg_loss, tag_acc, template_acc)``. Each value is 0.0 when its
        denominator is zero (empty loader or dataset).
    """
    model.eval()
    pad_token_id = token2id["<PAD>"]
    pad_tag_id = tag2id["<PAD>"]

    total_loss = 0
    total_template_correct = 0
    total_tag_correct = 0
    total_tokens = 0
    with torch.no_grad():
        for token_ids, tag_ids, template_ids in loader:
            token_ids = token_ids.to(device)
            tag_ids = tag_ids.to(device)
            template_ids = template_ids.to(device)

            # Sequence lengths are recovered by counting non-padding token ids.
            lengths = (token_ids != pad_token_id).sum(dim=1)
            tag_logits, template_logits = model(token_ids, lengths)

            tag_loss = tag_loss_fn(
                tag_logits.reshape(-1, tag_logits.shape[-1]),
                tag_ids.reshape(-1),
            )
            template_loss = template_loss_fn(template_logits, template_ids)

            loss = tag_loss + template_loss
            total_loss += loss.item()

            # Tag accuracy is measured on non-padding positions only
            tag_preds = tag_logits.argmax(dim=-1)
            mask = tag_ids != pad_tag_id
            total_tag_correct += ((tag_preds == tag_ids) & mask).sum().item()
            total_tokens += mask.sum().item()

            template_preds = template_logits.argmax(dim=-1)
            total_template_correct += (template_preds == template_ids).sum().item()

    # Guard every denominator, not just the token count, so an empty split
    # reports zeros instead of raising ZeroDivisionError.
    num_batches = len(loader)
    num_examples = len(loader.dataset)
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    tag_acc = total_tag_correct / total_tokens if total_tokens > 0 else 0.0
    template_acc = total_template_correct / num_examples if num_examples > 0 else 0.0

    print(f"{name} Epoch {epoch}: Loss={avg_loss:.4f}, TagAcc={tag_acc:.4f}, TemplateAcc={template_acc:.4f}")
    return avg_loss, tag_acc, template_acc

In [41]:
# Adam with a small L2 penalty on all model parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

# Padded tag positions are excluded from the tagging loss; the template loss
# is a plain cross-entropy over one label per sequence.
pad_tag_index = tag2id["<PAD>"]
tag_loss_fn = nn.CrossEntropyLoss(ignore_index=pad_tag_index)
template_loss_fn = nn.CrossEntropyLoss()

In [42]:
# Sanity check: compare the most frequent template ids in train and test.
train_templates = [example["template_id"] for example in question_split["train"]]
test_templates = [example["template_id"] for example in question_split["test"]]

train_top = Counter(train_templates).most_common(5)
test_top = Counter(test_templates).most_common(5)

print("Top templates in train:", train_top)
print("Top templates in test:", test_top)

Top templates in train: [(223, 28), (4, 28), (171, 24), (55, 20), (41, 17)]
Top templates in test: [(223, 12), (4, 12), (55, 11), (204, 11), (41, 9)]


In [43]:
# Alternate one training pass and one dev evaluation per epoch (epochs 1..30).
NUM_EPOCHS = 30
for epoch in range(1, NUM_EPOCHS + 1):
    train(model, train_loader, optimizer, tag_loss_fn, template_loss_fn, epoch)
    validate(model, dev_loader, tag_loss_fn, template_loss_fn, epoch)

Epoch 1: Loss=7.2729, TagAcc=0.7068, TemplateAcc=0.0801
Dev Epoch 1: Loss=6.1550, TagAcc=0.9149, TemplateAcc=0.1224
Epoch 2: Loss=5.4033, TagAcc=0.9003, TemplateAcc=0.1129
Dev Epoch 2: Loss=5.3423, TagAcc=0.9176, TemplateAcc=0.1224
Epoch 3: Loss=4.5678, TagAcc=0.9541, TemplateAcc=0.1803
Dev Epoch 3: Loss=5.0064, TagAcc=0.9707, TemplateAcc=0.2449
Epoch 4: Loss=4.1134, TagAcc=0.9664, TemplateAcc=0.2240
Dev Epoch 4: Loss=4.7865, TagAcc=0.9707, TemplateAcc=0.2449
Epoch 5: Loss=3.7838, TagAcc=0.9681, TemplateAcc=0.2532
Dev Epoch 5: Loss=4.5137, TagAcc=0.9761, TemplateAcc=0.2857
Epoch 6: Loss=3.4023, TagAcc=0.9740, TemplateAcc=0.3698
Dev Epoch 6: Loss=4.2255, TagAcc=0.9867, TemplateAcc=0.3878
Epoch 7: Loss=3.1197, TagAcc=0.9826, TemplateAcc=0.4080
Dev Epoch 7: Loss=4.0687, TagAcc=0.9894, TemplateAcc=0.3878
Epoch 8: Loss=2.8276, TagAcc=0.9889, TemplateAcc=0.4809
Dev Epoch 8: Loss=3.8619, TagAcc=0.9947, TemplateAcc=0.4286
Epoch 9: Loss=2.5436, TagAcc=0.9917, TemplateAcc=0.5392
Dev Epoch 9: Los

In [44]:
# Reverse lookups used when decoding model predictions.
id2template = dict(enumerate(template_list))
id2tag = {tag_idx: tag_name for tag_name, tag_idx in tag2id.items()}
id2token = {tok_idx: tok for tok, tok_idx in token2id.items()}

In [None]:
import re
from collections import defaultdict

def fill_variables(template, var_map):
    """Substitute ``@name`` placeholders in a SQL template with quoted values.

    Each ``@name`` that ends at a word boundary is replaced by ``"value"``, so
    ``@v1`` does not match inside ``@v10``. Substitutions are applied one
    variable at a time, in ``var_map`` order.

    Args:
        template: SQL text containing ``@name`` placeholders.
        var_map: mapping from variable name to replacement value.

    Returns:
        The template with every matching placeholder replaced.
    """
    sql = template
    for var, val in var_map.items():
        quoted = f'"{val}"'
        # Use a callable replacement so the value is inserted verbatim; a
        # replacement *string* would have its backslashes and group references
        # (e.g. "\1", "\n") interpreted by re.sub.
        sql = re.sub(rf'@{re.escape(var)}\b', lambda _match, _text=quoted: _text, sql)
    return sql

def normalize_sql(s):
    """Return ``s`` lowercased with every whitespace run collapsed to one space."""
    pieces = s.strip().split()
    single_spaced = " ".join(pieces)
    return single_spaced.lower()

def simple_tokenize(text):
    """Split ``text`` into runs of word characters, dropping everything else."""
    return [match.group(0) for match in re.finditer(r"\b\w+\b", text)]

def tag_to_var(tag):
    """Convert a predicted tag into a lowercase variable name.

    A leading "B-" and then a leading "I-" prefix are removed if present.
    The prefixes are matched as whole strings: ``str.lstrip`` would instead
    strip any run of the characters "B", "I" and "-", eating the start of
    names such as "B-Boston" or "Item0".
    """
    for prefix in ("B-", "I-"):
        if tag.startswith(prefix):
            tag = tag[len(prefix):]
    return tag.lower()

def _extract_predicted_variables(tokens, pred_tags):
    """Build a {variable name: value} map from position-aligned tokens and tag ids.

    A token whose predicted tag is not "O" starts a value; immediately following
    tokens are appended while they are tagged "O" and start with an uppercase
    character. Only the first value found for each variable name is kept.
    """
    var_map = {}
    limit = min(len(tokens), len(pred_tags))
    pos = 0
    while pos < limit:
        tag = id2tag[pred_tags[pos]]
        if tag == "O":
            pos += 1
            continue

        var_name = tag_to_var(tag)
        merged_tokens = [tokens[pos]]
        nxt = pos + 1
        while nxt < limit and id2tag[pred_tags[nxt]] == "O" and tokens[nxt][0].isupper():
            merged_tokens.append(tokens[nxt])
            nxt += 1

        if var_name not in var_map:
            var_map[var_name] = " ".join(merged_tokens)
        pos = nxt
    return var_map


def _gold_sql_candidates(example):
    """Return every gold template of ``example`` filled with its own variable values."""
    candidates = []
    for sql in example["template_sql"]:
        if isinstance(sql, list):
            sql = " ".join(sql) if all(len(tok) > 1 for tok in sql) else "".join(sql)
        elif not isinstance(sql, str):
            sql = str(sql)
        candidates.append(fill_variables(sql, example["variables"]).strip())
    return candidates


def evaluate(model, loader, dataset_split, output_file="predicted_sql.txt"):
    """Decode SQL for every example, write it to ``output_file`` and print accuracies.

    Reads the module-level ``device``, ``token2id``, ``tag2id``, ``id2tag``,
    ``id2template`` and ``default_values``.

    Args:
        model: called as ``model(token_ids, lengths)`` and expected to return
            ``(tag_logits, template_logits)``.
        loader: yields ``(token_ids, tag_ids, template_ids)`` batches. Gold
            labels are taken from ``dataset_split`` by running index, so the
            loader must iterate in the same order as ``dataset_split``
            (i.e. without shuffling).
        dataset_split: list of entries with "text", "tags", "template_id",
            "template_sql" and "variables".
        output_file: path that receives one predicted SQL string per line.

    NOTE(review): values are read from a regex tokenization of ``entry["text"]``
    and paired by position with the predicted tags; this presumably assumes both
    tokenizations line up — confirm against how "tokens" is produced.
    """
    model.eval()
    correct = 0
    total = 0
    template_correct = 0
    tag_correct = 0
    tag_total = 0

    with open(output_file, "w", encoding="utf-8") as f, torch.no_grad():
        for token_ids, _gold_tag_ids, _gold_template_ids in loader:
            token_ids = token_ids.to(device)
            lengths = (token_ids != token2id["<PAD>"]).sum(dim=1)
            tag_logits, template_logits = model(token_ids, lengths)

            tag_preds = tag_logits.argmax(dim=-1).cpu().numpy()
            template_preds = template_logits.argmax(dim=-1).cpu().numpy()

            for row in range(tag_preds.shape[0]):
                example = dataset_split[total]
                pred_tags = tag_preds[row]
                tokens = simple_tokenize(example["text"])[:len(pred_tags)]
                pred_template_id = template_preds[row]
                sql_template = id2template[pred_template_id]

                # Template classification accuracy
                if pred_template_id == example["template_id"]:
                    template_correct += 1

                # Tagging accuracy over the positions both sequences share
                gold_tags = [tag2id[tag] for tag in example["tags"]]
                min_len = min(len(gold_tags), len(pred_tags))
                tag_correct += sum(1 for pred, gold in zip(pred_tags, gold_tags) if pred == gold)
                tag_total += min_len

                # Fill the predicted template; variables that were not
                # extracted fall back to their default value.
                var_map = _extract_predicted_variables(tokens, pred_tags)
                filled_sql = fill_variables(sql_template, {
                    var: var_map.get(var, default_values[var]) for var in default_values
                })

                f.write(filled_sql.strip() + "\n")

                # A prediction counts as correct if it matches any gold SQL
                # after whitespace/case normalisation.
                valid_norms = [normalize_sql(s) for s in _gold_sql_candidates(example)]
                if normalize_sql(filled_sql) in valid_norms:
                    correct += 1

                total += 1

    sql_accuracy = correct / total if total > 0 else 0.0
    template_accuracy = template_correct / total if total > 0 else 0.0
    tag_accuracy = tag_correct / tag_total if tag_total > 0 else 0.0

    print(f"\nFinal SQL Evaluation Accuracy: {sql_accuracy:.4f}")
    print(f"Template Classification Accuracy: {template_accuracy:.4f}")
    print(f"Tagging Accuracy: {tag_accuracy:.4f}")




In [46]:
# Final evaluation on the question-based test split (unshuffled loader).
evaluate(model, test_loader, dataset_split=question_split["test"])


--- Test Example 1 ---
Original Input: what is the biggest city in kansas
Predicted Template ID: 21
SQL Template:
 SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "state_name0" ) AND CITYalias0.STATE_NAME = "state_name0" ;
Predicted Tags: ['O', 'O', 'O', 'O', 'O', 'O', 'state_name0', 'O']
Tokens: ['what', 'is', 'the', 'biggest', 'city', 'in', 'kansas']
Extracted Variables: {'state_name0': 'kansas'}

Predicted Final SQL:
 SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "state_name0" ) AND CITYalias0.STATE_NAME = "state_name0" ;

Correct SQL Options:
→ SELECT CITYalias0.CITY_NAME FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION = ( SELECT MAX( CITYalias1.POPULATION ) FROM CITY AS CITYalias1 WHERE CITYalias1.STATE_NAME = "state_name0" ) 