In [2]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import spacy
import numpy as np
from collections import defaultdict, Counter

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

In [42]:
#Load your dataset
with open("atis.json", "r") as f:
    raw_data = json.load(f)

# with open("geography.json", "r") as f:
#     raw_data = json.load(f)

# Load spaCy tokenizer
nlp = spacy.load("en_core_web_sm")

# Structures to store splits and mappings
question_split = {"train": [], "dev": [], "test": []}
query_split = {"train": [], "dev": [], "test": []}
template_pool = set()

# Process each item, normalize gold SQLs to create template pool
normalized_sqls_per_item = []
for item in raw_data:
    gold_templates = []
    for sql in item["sql"]:
        template = sql
        for var, val in item["sentences"][0]["variables"].items():
            template = template.replace(f'"{val}"', f'"@{var}"')
            template = template.replace(val, f'"@{var}"')
        gold_templates.append(template)
        template_pool.add(template)
    normalized_sqls_per_item.append((item, gold_templates))

template_list = sorted(template_pool)
template2id = {t: i for i, t in enumerate(template_list)}
id2template = {i: t for t, i in template2id.items()}

# Process entries and assign template IDs
for (item, gold_templates) in normalized_sqls_per_item:
    shortest_sql = sorted(gold_templates, key=lambda x: (len(x), x))[0]
    template_id = template2id[shortest_sql]

    for sent in item["sentences"]:
        raw_text = sent["text"]
        variables = sent["variables"]

        # Replace placeholders in raw_text for model input
        text = raw_text
        for var_name, var_value in variables.items():
            text = text.replace(var_name, var_value)

        # Tokenize
        tokens = [t.text for t in nlp(text)]

        # Create tag sequence
        tags = []
        for tok in tokens:
            matched = False
            for var_name, var_value in variables.items():
                if tok == var_value or tok in var_value.split():
                    tags.append(var_name)
                    matched = True
                    break
            if not matched:
                tags.append("O")

        # Fill in gold SQL (for seq2seq generation target)
        gold_sql = shortest_sql
        for var_name, var_value in variables.items():
            gold_sql = gold_sql.replace(f"@{var_name}", f'"{var_value}"')

        entry = {
            "text": text,
            "template_id": template_id,
            "template_sql": gold_templates,
            "variables": variables,
            "raw_text": raw_text,
            "tokens": tokens,
            "tags": tags,
            "gold_sql": gold_sql.strip(),
        }

        # Add to split
        question_split[sent["question-split"]].append(entry)
        query_split[item["query-split"]].append(entry)

In [43]:
token_counter = Counter()
tag_counter = Counter()

for entry in question_split["train"]:  
    text = entry["text"]
    tokens = [t.text for t in nlp(text)]

    tags = []
    for tok in tokens:
        matched = False
        for var_name, var_val in entry["variables"].items():
            if tok == var_val or tok in var_val.split():
                tags.append(var_name)
                tag_counter[var_name] += 1
                matched = True
                break
        if not matched:
            tags.append("O")
            tag_counter["O"] += 1

    token_counter.update(tokens)
    entry["tokens"] = tokens
    entry["tags"] = tags

In [44]:

for split in ["dev", "test"]:
    for entry in question_split[split]:
        text = entry["text"]
        tokens = [t.text for t in nlp(text)]
        tags = []
        for tok in tokens:
            matched = False
            for var_name, var_val in entry["variables"].items():
                if tok == var_val or tok in var_val.split():
                    tags.append(var_name)
                    matched = True
                    break
            if not matched:
                tags.append("O")
        entry["tokens"] = tokens
        entry["tags"] = tags

In [45]:
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"

# Token vocab
token2id = {PAD_TOKEN: 0, UNK_TOKEN: 1}
for token in token_counter:
    token2id[token] = len(token2id)

# Tag vocab 
PAD_TAG = "<PAD>"
tag_counter[PAD_TAG] = 1  

# tag2id = {}
# for tag in tag_counter:
#     tag2id[tag] = len(tag2id)

tag2id = {"<PAD>": 0, "O": 1}
for tag in tag_counter:
    if tag not in tag2id:
        tag2id[tag] = len(tag2id)

In [46]:
class Seq2SeqDataset(Dataset):
    def __init__(self, data, token2id, sql_token2id):
        self.data = data
        self.token2id = token2id
        self.sql_token2id = sql_token2id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        # Encoder input
        token_ids = [self.token2id.get(tok, self.token2id["<UNK>"]) for tok in entry["tokens"]]
        # Decoder target (shortest SQL template filled with values)
        sql = sorted(entry["template_sql"], key=lambda x: (len(x), x))[0]
        for var, val in entry["variables"].items():
            sql = sql.replace(f"@{var}", f'"{val}"')
        sql_tokens = re.findall(r"\w+|[^\s\w]", sql)
        target_ids = [self.sql_token2id.get(tok, self.sql_token2id["<UNK>"]) for tok in sql_tokens]

        return {
            "token_ids": torch.tensor(token_ids, dtype=torch.long),
            "target_ids": torch.tensor(target_ids, dtype=torch.long),
        }

def pad_seq2seq_batch(batch):
    max_input_len = max(len(item["token_ids"]) for item in batch)
    max_target_len = max(len(item["target_ids"]) for item in batch)

    for item in batch:
        item["token_ids"] = F.pad(item["token_ids"], (0, max_input_len - len(item["token_ids"])), value=token2id["<PAD>"])
        item["target_ids"] = F.pad(item["target_ids"], (0, max_target_len - len(item["target_ids"])), value=sql_token2id["<PAD>"])

    token_ids = torch.stack([item["token_ids"] for item in batch])
    target_ids = torch.stack([item["target_ids"] for item in batch])
    return token_ids, target_ids

In [47]:
import re
from collections import Counter

# Special tokens
SQL_PAD = "<PAD>"
SQL_UNK = "<UNK>"
SQL_START = "<START>"
SQL_END = "<END>"

# Step 1: Collect all SQL tokens from training set
sql_token_counter = Counter()

for entry in question_split["train"]:
    sql = sorted(entry["template_sql"], key=lambda x: (len(x), x))[0]
    for var, val in entry["variables"].items():
        sql = sql.replace(f"@{var}", f'"{val}"')
    # Basic tokenization: split by words and punctuation
    sql_tokens = re.findall(r"\w+|[^\s\w]", sql)
    sql_token_counter.update(sql_tokens)

# Step 2: Create vocab dict
sql_token2id = {
    SQL_PAD: 0,
    SQL_UNK: 1,
    SQL_START: 2,
    SQL_END: 3
}

for token in sql_token_counter:
    sql_token2id[token] = len(sql_token2id)

# Step 3: Reverse map for inference (optional)
id2sql_token = {i: tok for tok, i in sql_token2id.items()}

In [48]:
seq2seq_train_set = Seq2SeqDataset(question_split["train"], token2id, sql_token2id)
seq2seq_dev_set = Seq2SeqDataset(question_split["dev"], token2id, sql_token2id)
seq2seq_test_set = Seq2SeqDataset(question_split["test"], token2id, sql_token2id)

seq2seq_train_loader = DataLoader(seq2seq_train_set, batch_size=32, shuffle=True, collate_fn=pad_seq2seq_batch)
seq2seq_dev_loader = DataLoader(seq2seq_dev_set, batch_size=32, collate_fn=pad_seq2seq_batch)
seq2seq_test_loader = DataLoader(seq2seq_test_set, batch_size=32, collate_fn=pad_seq2seq_batch)
seq2seq_query_test_set = Seq2SeqDataset(query_split["test"], token2id, sql_token2id)
seq2seq_query_test_loader = DataLoader(seq2seq_query_test_set, batch_size=32, collate_fn=pad_seq2seq_batch)


In [49]:
class LSTMEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=token2id["<PAD>"])
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, input_seq, lengths):
        embedded = self.embedding(input_seq)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (h_n, c_n) = self.lstm(packed)
        encoder_outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        return encoder_outputs, (h_n, c_n)


In [50]:
class LSTMDecoderWithAttention(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=sql_token2id["<PAD>"])
        self.lstm = nn.LSTM(embedding_dim + hidden_dim, hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward_step(self, input_token, hidden, encoder_outputs):
        embedded = self.embedding(input_token)  # (B, 1, E)

        # Attention: compute dot-product attention weights
        h_t = hidden[0][-1]  # (B, H)
        attn_scores = torch.bmm(encoder_outputs, h_t.unsqueeze(2)).squeeze(2)  # (B, T)
        attn_weights = torch.softmax(attn_scores, dim=1)  # (B, T)

        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # (B, 1, H)
        lstm_input = torch.cat([embedded, context], dim=2)  # (B, 1, E+H)

        output, hidden = self.lstm(lstm_input, hidden)  # output: (B, 1, H)
        output = self.out(output)  # (B, 1, V)

        return output, hidden, attn_weights

    def forward(self, trg, hidden, encoder_outputs):
        outputs = []
        input_token = trg[:, 0].unsqueeze(1)  # <START>

        for t in range(1, trg.size(1)):
            output, hidden, _ = self.forward_step(input_token, hidden, encoder_outputs)
            outputs.append(output)
            input_token = trg[:, t].unsqueeze(1)

        outputs = torch.cat(outputs, dim=1)  # (B, T-1, V)
        return outputs

In [51]:
class Seq2SeqWithAttention(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg):
        batch_size = src.size(0)
        trg_len = trg.size(1)
        vocab_size = self.decoder.out.out_features

        outputs = torch.zeros(batch_size, trg_len - 1, vocab_size).to(self.device)

        lengths = (src != token2id["<PAD>"]).sum(dim=1)
        embedded = self.encoder.embedding(src)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (h, c) = self.encoder.lstm(packed)
        encoder_outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        output = self.decoder(trg, (h, c), encoder_outputs)
        return output

In [52]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = LSTMEncoder(len(token2id), embedding_dim=100, hidden_dim=128)
decoder = LSTMDecoderWithAttention(len(sql_token2id), embedding_dim=100, hidden_dim=128)
model = Seq2SeqWithAttention(encoder, decoder, device).to(device)


In [53]:
def train_seq2seq(model, loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0

    for token_ids, target_ids in loader:
        token_ids = token_ids.to(device)
        target_ids = target_ids.to(device)

        optimizer.zero_grad()

        decoder_input = target_ids[:, :-1]
        decoder_target = target_ids[:, 1:]

        output = model(token_ids, decoder_input)  # [B, T-1, V]
        output = output.reshape(-1, output.shape[-1])
        target = decoder_target.reshape(-1)

        # Sanity check: prevent shape mismatch crash
        min_len = min(output.size(0), target.size(0))
        loss = loss_fn(output[:min_len], target[:min_len])

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch}: Loss={avg_loss:.4f}")


In [54]:
def validate_seq2seq(model, loader, loss_fn, epoch, name="Dev"):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for token_ids, target_ids in loader:
            token_ids = token_ids.to(device)
            target_ids = target_ids.to(device)

            decoder_input = target_ids[:, :-1]
            decoder_target = target_ids[:, 1:]

            output = model(token_ids, decoder_input)  # [B, T-1, V]
            output = output.reshape(-1, output.shape[-1])
            target = decoder_target.reshape(-1)

            min_len = min(output.size(0), target.size(0))
            loss = loss_fn(output[:min_len], target[:min_len])
            total_loss += loss.item()

            pred = output.argmax(dim=-1)
            correct += (pred[:min_len] == target[:min_len]).sum().item()
            total += min_len

    avg_loss = total_loss / len(loader)
    acc = correct / total if total > 0 else 0.0
    print(f"{name} Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={acc:.4f}")


In [55]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss(ignore_index=sql_token2id["<PAD>"])

for epoch in range(1, 11):
    train_seq2seq(model, seq2seq_train_loader, loss_fn, optimizer, epoch)
    validate_seq2seq(model, seq2seq_dev_loader, loss_fn, epoch)

Epoch 1: Loss=3.5441
Dev Epoch 1: Loss=3.1747, Accuracy=0.0833
Epoch 2: Loss=3.0440
Dev Epoch 2: Loss=3.0186, Accuracy=0.0824
Epoch 3: Loss=2.9198
Dev Epoch 3: Loss=2.9331, Accuracy=0.0823
Epoch 4: Loss=2.8472
Dev Epoch 4: Loss=2.8838, Accuracy=0.0825
Epoch 5: Loss=2.7977
Dev Epoch 5: Loss=2.8429, Accuracy=0.0828
Epoch 6: Loss=2.7571
Dev Epoch 6: Loss=2.8146, Accuracy=0.0822
Epoch 7: Loss=2.7238
Dev Epoch 7: Loss=2.7928, Accuracy=0.0830
Epoch 8: Loss=2.6978
Dev Epoch 8: Loss=2.7715, Accuracy=0.0829
Epoch 9: Loss=2.6848
Dev Epoch 9: Loss=2.7579, Accuracy=0.0832
Epoch 10: Loss=2.6646
Dev Epoch 10: Loss=2.7463, Accuracy=0.0833


In [56]:

def normalize_sql(s):
    return " ".join(s.strip().split()).lower()

def generate_sql_from_model_with_attention(model, loader, output_file="seq2seq_attn_test_predictions.txt"):
    model.eval()
    predictions = []
    golds = []
    correct = 0
    total = 0

    with open(output_file, "w") as f:
        with torch.no_grad():
            for token_ids, target_ids in loader:
                token_ids = token_ids.to(device)
                lengths = (token_ids != token2id["<PAD>"]).sum(dim=1)
                batch_size = token_ids.size(0)

                # Encode
                embedded = model.encoder.embedding(token_ids)
                packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
                packed_output, (h, c) = model.encoder.lstm(packed)
                encoder_outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

                # Decode
                input_token = torch.full((batch_size, 1), sql_token2id["<START>"], dtype=torch.long, device=device)
                outputs = []

                for _ in range(100):
                    output, (h, c), _ = model.decoder.forward_step(input_token, (h, c), encoder_outputs)
                    logits = output[:, -1, :]
                    pred_token = logits.argmax(dim=-1, keepdim=True)
                    outputs.append(pred_token)
                    input_token = pred_token

                output_seqs = torch.cat(outputs, dim=1).cpu().numpy()

                for i in range(batch_size):
                    pred_tokens = output_seqs[i].tolist()
                    if sql_token2id["<END>"] in pred_tokens:
                        pred_tokens = pred_tokens[:pred_tokens.index(sql_token2id["<END>"])]
                    pred_sql = normalize_sql(" ".join(id2sql_token[tok] for tok in pred_tokens))

                    gold_tokens = target_ids[i].tolist()
                    if sql_token2id["<END>"] in gold_tokens:
                        gold_tokens = gold_tokens[:gold_tokens.index(sql_token2id["<END>"])]
                    gold_sql = normalize_sql(" ".join(id2sql_token[tok] for tok in gold_tokens))

                    predictions.append(pred_sql)
                    golds.append(gold_sql)
                    f.write(pred_sql + "\n")

                    if pred_sql == gold_sql:
                        correct += 1
                    total += 1

    acc = correct / total if total > 0 else 0.0
    print(f"[Test Seq2Seq+Attention] SQL Generation Accuracy: {acc:.4f}")
    return predictions, golds


In [57]:
print("[Evaluation on Question Split]")
preds_question_attention, golds_question_attention = generate_sql_from_model_with_attention(
    model, 
    seq2seq_test_loader, 
    output_file="seq2seq_question_preds.txt"
)

print("[Evaluation on Query Split]")
preds_query_attention, golds_query_attention = generate_sql_from_model_with_attention(
    model, 
    seq2seq_query_test_loader,  
    output_file="seq2seq_query_preds.txt"
)


[Evaluation on Question Split]
[Test Seq2Seq+Attention] SQL Generation Accuracy: 0.0000
[Evaluation on Query Split]
[Test Seq2Seq+Attention] SQL Generation Accuracy: 0.0000
