### 1.Data_Preprocessing for generation and LLM

In [1]:
import json
import re
import os

os.mkdir('cache')
os.mkdir('evaluation_results')
os.mkdir('seq2seq_results')
os.mkdir('vocab')
def load_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# pick up the shortest sql
def pick_shortest_sql(sql_list):
    queries = [q.strip() for q in sql_list if q.strip()]
    min_len = min(len(q) for q in queries)
    shortest_queries = [q for q in queries if len(q) == min_len]
    shortest_queries.sort()
    return shortest_queries[0]
# build templates, unique id
def build_templates(data, split_type):
    template_id_map = {}
    templates = {}
    default_values = {}
    current_id = 0

    for entry in data:
        split_value = entry.get(split_type + "-split")
        if split_value != "train":
            continue

        shortest_query = pick_shortest_sql(entry["sql"])
        if shortest_query not in template_id_map:
            template_id_map[shortest_query] = current_id
            templates[current_id] = shortest_query
            default_val_dict = {}

            for sent in entry.get("sentences", []):
                if sent.get("question-split") == "train":
                    variables = sent["variables"]
                    default_val_dict.update(variables)
                    break

            if not default_val_dict and "variables" in entry:
                for var_info in entry["variables"]:
                    default_val_dict[var_info["name"]] = var_info["example"]

            default_values[current_id] = default_val_dict
            current_id += 1

    return templates, default_values
# replace placeholder with viriables
def replace_placeholders(text, variables):
    placeholder_spans = sorted(
        ((match.start(), var)
         for var in variables
         for match in re.finditer(r'\b'+re.escape(var)+r'\b', text)),
        key=lambda x: x[0]
    )
    new_text = text
    for _, var in placeholder_spans:
        val = variables[var]
        new_text = re.sub(r'\b'+re.escape(var)+r'\b', val, new_text, count=1)
    return new_text

def process_sentence(text, variables, template_id=None, template_sql=None):
    full_text = replace_placeholders(text, variables)
    full_sql = template_sql
    if template_sql:
        for var_name, val in variables.items():
            pattern = re.compile(r'(?<=\")'+re.escape(var_name)+r'(?=\")')
            full_sql = pattern.sub(val, full_sql)

    return {
        "text": full_text,
        "variables": variables,
        "template_id": template_id,
        "sql": full_sql
    }
# train split and query split
def split_data(data, templates, split_type):
    splits = {"train": [], "dev": [], "test": []}
    template_sql_to_id = {sql: tid for tid, sql in templates.items()}

    for entry in data:
        shortest_query = pick_shortest_sql(entry["sql"])
        template_in_train = shortest_query in template_sql_to_id
        template_id = template_sql_to_id.get(shortest_query)
        template_sql = shortest_query

        entry_split = entry.get("query-split")

        for sent in entry["sentences"]:
            if split_type == "question":
                qsplit = sent["question-split"]
            else:
                qsplit = entry_split

            sample = process_sentence(
                sent["text"], sent["variables"],
                template_id if template_in_train else None,
                template_sql
            )

            splits[qsplit].append(sample)

    return splits
# save tentative results
def save_jsonl(filename, data_list):
    with open(filename, 'w', encoding='utf-8') as f:
        for item in data_list:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    data = load_data("./sources/atis.json")

    # Question split
    question_templates, question_defaults = build_templates(data, "question")
    question_splits = split_data(data, question_templates, "question")

    # Query split
    query_templates, query_defaults = build_templates(data, "query")
    query_splits = split_data(data, query_templates, "query")

    # Save splits
    for split in ["train", "dev", "test"]:
        save_jsonl(f"./datasets/question_{split}.jsonl", question_splits[split])
        save_jsonl(f"./datasets/query_{split}.jsonl", query_splits[split])

    # Save templates and defaults
    with open("./datasets/question_templates.json", "w", encoding='utf-8') as f:
        json.dump(question_templates, f, ensure_ascii=False, indent=2)
    with open("./datasets/question_defaults.json", "w", encoding='utf-8') as f:
        json.dump(question_defaults, f, ensure_ascii=False, indent=2)

    with open("./datasets/query_templates.json", "w", encoding='utf-8') as f:
        json.dump(query_templates, f, ensure_ascii=False, indent=2)
    with open("./datasets/query_defaults.json", "w", encoding='utf-8') as f:
        json.dump(query_defaults, f, ensure_ascii=False, indent=2)

    print("Data processing complete.")
    print("Question split sizes:", {k: len(v) for k, v in question_splits.items()})
    print("Query split sizes:", {k: len(v) for k, v in query_splits.items()})

Data processing complete.
Question split sizes: {'train': 4347, 'dev': 486, 'test': 447}
Query split sizes: {'train': 4812, 'dev': 121, 'test': 347}


## 2. Generation tasks

### 2.1 Data preparation for generation tasks

In [2]:
import json
import os
import pickle
import spacy
from typing import List, Tuple
from collections import Counter

PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"

# File paths
train_question_file = "./datasets/question_train.jsonl"
train_query_file = "./datasets/query_train.jsonl"
dev_question_file = "./datasets/question_dev.jsonl"
dev_query_file = "./datasets/query_dev.jsonl"
test_question_file = "./datasets/question_test.jsonl"
test_query_file = "./datasets/query_test.jsonl"

# Vocab and cache paths
vocab_dir = "./vocab"
cache_dir = "./cache"
os.makedirs(vocab_dir, exist_ok=True)
os.makedirs(cache_dir, exist_ok=True)

question_vocab_file = os.path.join(vocab_dir, "question_vocab.pkl")
query_vocab_file = os.path.join(vocab_dir, "query_vocab.pkl")

# Cache paths
train_question_cache = os.path.join(cache_dir, 'train_question_tokenized.pkl')
dev_question_cache = os.path.join(cache_dir, 'dev_question_tokenized.pkl')
test_question_cache = os.path.join(cache_dir, 'test_question_tokenized.pkl')

train_query_cache = os.path.join(cache_dir, 'train_query_tokenized.pkl')
dev_query_cache = os.path.join(cache_dir, 'dev_query_tokenized.pkl')
test_query_cache = os.path.join(cache_dir, 'test_query_tokenized.pkl')

print("Loading spaCy model...")
nlp = spacy.load("en_core_web_sm", disable=["ner", "parser", "lemmatizer"])

# Data loading
def load_data(filepath: str) -> List[Tuple[str, str]]:
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            entry = json.loads(line)
            data.append((entry['text'], entry['sql']))
    print(f"Loaded {len(data)} examples from {filepath}")
    return data

train_question_data = load_data(train_question_file)
dev_question_data = load_data(dev_question_file)
test_question_data = load_data(test_question_file)

train_query_data = load_data(train_query_file)
dev_query_data = load_data(dev_query_file)
test_query_data = load_data(test_query_file)

# Complete vocab construction and tokenization

# Special tokens
special_tokens = [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]

# Function to build vocabulary
def build_vocab(data_files: List[str], vocab_file: str):
    token_counter = Counter()

    for file in data_files:
        with open(file, 'r', encoding='utf-8') as f:
            for line in f:
                entry = json.loads(line)
                # Tokenize text with spaCy for question
                doc_question = nlp(entry['text'].strip())
                token_counter.update([token.text.lower() for token in doc_question])

                # Tokenize SQL by whitespace
                tokens_query = entry['sql'].strip().split()
                token_counter.update([token.lower() for token in tokens_query])

    # Build final vocab dict
    vocab = {token: idx for idx, token in enumerate(special_tokens)}

    # Add tokens from data to vocab
    for token, _ in token_counter.most_common():
        if token not in vocab:
            vocab[token] = len(vocab)

    # Save vocab
    with open(vocab_file, 'wb') as vf:
        pickle.dump(vocab, vf)

    print(f"Vocab of size {len(vocab)} created and saved at {vocab_file}")
    return vocab

# Build unified vocab
all_files = [train_question_file, dev_question_file, test_question_file,
             train_query_file, dev_query_file, test_query_file]

unified_vocab = build_vocab(all_files, os.path.join(vocab_dir, 'unified_vocab.pkl'))

def tokenize_and_cache(data_file: str, cache_file: str, vocab: dict):
    tokenized_data = []

    with open(data_file, 'r', encoding='utf-8') as f:
        for line in f:
            entry = json.loads(line)

            # Tokenize question text
            question_tokens = [vocab.get(token.text.lower(), vocab[UNK_TOKEN])
                               for token in nlp(entry['text'].strip())]

            # Tokenize SQL query
            sql_tokens = [vocab.get(token.lower(), vocab[UNK_TOKEN])
                          for token in entry['sql'].strip().split()]

            tokenized_data.append((question_tokens, sql_tokens))

    # Save tokenized data to cache file
    with open(cache_file, 'wb') as cf:
        pickle.dump(tokenized_data, cf)

    print(f"Tokenized {len(tokenized_data)} samples and cached to {cache_file}")

# Tokenization and caching for all 6 files
tokenize_and_cache(train_question_file, train_question_cache, unified_vocab)
tokenize_and_cache(dev_question_file, dev_question_cache, unified_vocab)
tokenize_and_cache(test_question_file, test_question_cache, unified_vocab)

tokenize_and_cache(train_query_file, train_query_cache, unified_vocab)
tokenize_and_cache(dev_query_file, dev_query_cache, unified_vocab)
tokenize_and_cache(test_query_file, test_query_cache, unified_vocab)

Loading spaCy model...
Loaded 4347 examples from ./datasets/question_train.jsonl
Loaded 486 examples from ./datasets/question_dev.jsonl
Loaded 447 examples from ./datasets/question_test.jsonl
Loaded 4812 examples from ./datasets/query_train.jsonl
Loaded 121 examples from ./datasets/query_dev.jsonl
Loaded 347 examples from ./datasets/query_test.jsonl
Vocab of size 1467 created and saved at ./vocab\unified_vocab.pkl
Tokenized 4347 samples and cached to ./cache\train_question_tokenized.pkl
Tokenized 486 samples and cached to ./cache\dev_question_tokenized.pkl
Tokenized 447 samples and cached to ./cache\test_question_tokenized.pkl
Tokenized 4812 samples and cached to ./cache\train_query_tokenized.pkl
Tokenized 121 samples and cached to ./cache\dev_query_tokenized.pkl
Tokenized 347 samples and cached to ./cache\test_query_tokenized.pkl


In [3]:
print("\n=== Question Split Examples ===")
for i in range(3):
    print(train_question_data[i])

print("\n=== Query Split Examples ===")
for i in range(3):
    print(train_query_data[i])


=== Question Split Examples ===
('list all the flights that arrive at MKE from various cities', 'SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT AS AIRPORTalias0 , FLIGHT AS FLIGHTalias0 WHERE AIRPORTalias0.AIRPORT_CODE = "MKE" AND FLIGHTalias0.TO_AIRPORT = AIRPORTalias0.AIRPORT_CODE ;')
('what flights from any city land at MKE', 'SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT AS AIRPORTalias0 , FLIGHT AS FLIGHTalias0 WHERE AIRPORTalias0.AIRPORT_CODE = "MKE" AND FLIGHTalias0.TO_AIRPORT = AIRPORTalias0.AIRPORT_CODE ;')
('show me the flights into DAL', 'SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT AS AIRPORTalias0 , FLIGHT AS FLIGHTalias0 WHERE AIRPORTalias0.AIRPORT_CODE = "DAL" AND FLIGHTalias0.TO_AIRPORT = AIRPORTalias0.AIRPORT_CODE ;')

=== Query Split Examples ===
('list all the flights that arrive at MKE from various cities', 'SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT AS AIRPORTalias0 , FLIGHT AS FLIGHTalias0 WHERE AIRPORTalias0.AIRPORT_CODE = "MKE" AND FL

In [4]:
with open('./cache/train_question_tokenized.pkl', 'rb') as file:
    data = pickle.load(file)
print(data)

[([108, 99, 36, 35, 121, 233, 207, 494, 9, 1012, 607], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 427, 5, 27, 4, 102, 17]), ([54, 35, 9, 277, 8, 902, 207, 494], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 427, 5, 27, 4, 102, 17]), ([61, 56, 36, 35, 412, 486], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 466, 5, 27, 4, 102, 17]), ([61, 56, 36, 35, 182, 207, 486], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 466, 5, 27, 4, 102, 17]), ([108, 99, 36, 35, 121, 233, 207, 494], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 427, 5, 27, 4, 102, 17]), ([108, 99, 36, 182, 35, 207, 494], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 427, 5, 27, 4, 102, 17]), ([54, 35, 902, 207, 494], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 427, 5, 27, 4, 102, 17]), ([61, 56, 36, 35, 26, 486], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6, 25, 14, 102, 4, 466, 5, 27, 4, 102, 17]), ([108, 99, 36, 714, 207, 494], [13, 16, 18, 9, 111, 6, 135, 7, 15, 6

In [5]:
with open('./vocab/unified_vocab.pkl', 'rb') as file:
    data = pickle.load(file)

print(data)

{'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3, '=': 4, 'and': 5, 'as': 6, ',': 7, 'city': 8, 'from': 9, 'airport_service': 10, ')': 11, '(': 12, 'select': 13, 'where': 14, 'flight': 15, 'distinct': 16, ';': 17, 'flightalias0.flight_id': 18, 'cityalias0.city_code': 19, 'cityalias0.city_name': 20, 'cityalias0': 21, 'airport_servicealias0.city_code': 22, 'airport_servicealias0.airport_code': 23, 'airport_servicealias0': 24, 'flightalias0': 25, 'to': 26, 'flightalias0.to_airport': 27, 'flightalias0.from_airport': 28, 'cityalias1.city_code': 29, 'cityalias1.city_name': 30, 'airport_servicealias1.city_code': 31, 'airport_servicealias1.airport_code': 32, 'cityalias1': 33, 'airport_servicealias1': 34, 'flights': 35, 'the': 36, 'in': 37, 'days': 38, 'month_number0': 39, 'year0': 40, 'day_number0': 41, 'date_day': 42, 'flightalias0.flight_days': 43, 'daysalias0.day_name': 44, 'daysalias0.days_code': 45, 'flightalias0.departure_time': 46, 'on': 47, 'fare': 48, 'daysalias0': 49, 'date_dayalias0.

### 2.2 LSTM

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import pickle
import os
import sqlglot

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory paths
cache_dir = "./cache"
vocab_dir = "./vocab"

# Special tokens
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"

# Load vocab
with open(os.path.join(vocab_dir, "unified_vocab.pkl"), 'rb') as f:
    vocab = pickle.load(f)

idx_to_token = {idx: tok for tok, idx in vocab.items()}
vocab_size = len(vocab)

# Load cached data
def load_cached_data(filename):
    with open(os.path.join(cache_dir, filename), 'rb') as f:
        return pickle.load(f)

train_q_data = load_cached_data("train_question_tokenized.pkl")
dev_q_data = load_cached_data("dev_question_tokenized.pkl")
test_q_data = load_cached_data("test_question_tokenized.pkl")

train_s_data = load_cached_data("train_query_tokenized.pkl")
dev_s_data = load_cached_data("dev_query_tokenized.pkl")
test_s_data = load_cached_data("test_query_tokenized.pkl")

# SQL normalization function
def normalize_sql(sql_str):
    return ' '.join(sql_str.lower().strip().rstrip(';').split())

# LSTM Seq2Seq without Attention
class LSTMSeq2Seq(nn.Module):
    def __init__(self, vocab_size, embed_size=128, hidden_size=256):
        super().__init__()
        self.hidden_size = hidden_size
        
        self.encoder_embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab[PAD_TOKEN])
        self.encoder_lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)

        self.decoder_embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab[PAD_TOKEN])
        self.decoder_lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, vocab_size)

    def forward(self, src, tgt):
        # Encoder
        embedded_src = self.encoder_embedding(src)
        _, (hidden, cell) = self.encoder_lstm(embedded_src)

        # Decoder
        embedded_tgt = self.decoder_embedding(tgt)
        decoder_output, _ = self.decoder_lstm(embedded_tgt, (hidden, cell))

        # Output
        outputs = self.fc_out(decoder_output)
        return outputs

def batchify(data, batch_size):
    random.shuffle(data)
    batches = []
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        src_batch, tgt_input_batch, tgt_output_batch = [], [], []
        for src_seq, tgt_seq in batch:
            src_batch.append(src_seq)
            tgt_input_batch.append([vocab[SOS_TOKEN]] + tgt_seq)
            tgt_output_batch.append(tgt_seq + [vocab[EOS_TOKEN]])

        src_max_len = max(len(seq) for seq in src_batch)
        tgt_max_len = max(len(seq) for seq in tgt_input_batch)

        src_batch_padded = [seq + [vocab[PAD_TOKEN]]*(src_max_len - len(seq)) for seq in src_batch]
        tgt_input_batch_padded = [seq + [vocab[PAD_TOKEN]]*(tgt_max_len - len(seq)) for seq in tgt_input_batch]
        tgt_output_batch_padded = [seq + [vocab[PAD_TOKEN]]*(tgt_max_len - len(seq)) for seq in tgt_output_batch]

        batches.append((torch.tensor(src_batch_padded, device=device),
                        torch.tensor(tgt_input_batch_padded, device=device),
                        torch.tensor(tgt_output_batch_padded, device=device)))
    return batches

def train_seq2seq(model, train_data, dev_data, epochs=3, batch_size=64, lr=1e-3):
    criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD_TOKEN])
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        batches = batchify(train_data, batch_size)

        for src, tgt_in, tgt_out in batches:
            optimizer.zero_grad()
            output = model(src, tgt_in)
            loss = criterion(output.view(-1, vocab_size), tgt_out.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(batches)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

        if dev_data:
            acc = evaluate(model, dev_data)
            print(f"Dev Accuracy: {acc*100:.2f}%")


def infer(model, src, max_len=100):
    model.eval()
    src_tensor = torch.tensor([src], device=device)
    with torch.no_grad():
        embedded_src = model.encoder_embedding(src_tensor)
        _, (hidden, cell) = model.encoder_lstm(embedded_src)

        tgt_idx = vocab[SOS_TOKEN]
        tgt_indices = []
        
        for _ in range(max_len):
            tgt_tensor = torch.tensor([[tgt_idx]], device=device)
            embedded_tgt = model.decoder_embedding(tgt_tensor)
            output, (hidden, cell) = model.decoder_lstm(embedded_tgt, (hidden, cell))
            logits = model.fc_out(output.squeeze(1))
            tgt_idx = logits.argmax(1).item()

            if tgt_idx in (vocab[EOS_TOKEN], vocab[PAD_TOKEN]):
                break
            tgt_indices.append(tgt_idx)

    return ' '.join(idx_to_token[idx] for idx in tgt_indices)

def evaluate(model, dataset):
    model.eval()
    correct = 0
    with torch.no_grad():
        for src, tgt in dataset:
            pred_sql = infer(model, src)
            gold_sql = ' '.join(idx_to_token[idx] for idx in tgt)
            try:
                if sqlglot.parse_one(normalize_sql(pred_sql)).sql() == sqlglot.parse_one(normalize_sql(gold_sql)).sql():
                    correct += 1
            except sqlglot.errors.ParseError:
                continue
    return correct / len(dataset)

# Instantiate model without Attention
model_lstm = LSTMSeq2Seq(vocab_size, embed_size=128, hidden_size=256).to(device)

Training and evaluation Question Split
print("Training LSTM without Attention (Question Split):")
# train_seq2seq(model_lstm, train_q_data, dev_q_data, epochs=30, batch_size=64, lr=1e-3)
q_acc = evaluate(model_lstm, test_q_data)
print(f"Question Split Test Accuracy: {q_acc*100:.2f}%")

Instantiate a new model for Query Split for fair comparison
model_lstm_query = LSTMSeq2Seq(vocab_size, embed_size=128, hidden_size=256).to(device)

# Training and evaluation Query Split
print("\nTraining LSTM without Attention (Query Split):")
train_seq2seq(model_lstm_query, train_s_data, dev_s_data, epochs=15, batch_size=64, lr=1e-3)
s_acc = evaluate(model_lstm_query, test_s_data)
print(f"Query Split Test Accuracy: {s_acc*100:.2f}%")

Training LSTM without Attention (Question Split):
Question Split Test Accuracy: 0.22%

Training LSTM without Attention (Query Split):
Epoch 1/15 - Loss: 3.0011
Dev Accuracy: 0.00%
Epoch 2/15 - Loss: 1.1864
Dev Accuracy: 0.00%
Epoch 3/15 - Loss: 0.7299
Dev Accuracy: 0.00%
Epoch 4/15 - Loss: 0.5358
Dev Accuracy: 0.00%
Epoch 5/15 - Loss: 0.4396
Dev Accuracy: 0.00%
Epoch 6/15 - Loss: 0.3793
Dev Accuracy: 0.00%
Epoch 7/15 - Loss: 0.3366
Dev Accuracy: 0.00%
Epoch 8/15 - Loss: 0.3065
Dev Accuracy: 0.00%


KeyboardInterrupt: 

### 2.3 LSTM Encoder-Decoder with Attention

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import pickle
import os
import sqlglot

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory paths
cache_dir = "./cache"
vocab_dir = "./vocab"

# Special tokens
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"

# Load vocab
with open(os.path.join(vocab_dir, "unified_vocab.pkl"), 'rb') as f:
    vocab = pickle.load(f)

idx_to_token = {idx: tok for tok, idx in vocab.items()}
vocab_size = len(vocab)

# Load cached data
def load_cached_data(filename):
    with open(os.path.join(cache_dir, filename), 'rb') as f:
        return pickle.load(f)

train_q_data = load_cached_data("train_question_tokenized.pkl")
dev_q_data = load_cached_data("dev_question_tokenized.pkl")
test_q_data = load_cached_data("test_question_tokenized.pkl")

train_s_data = load_cached_data("train_query_tokenized.pkl")
dev_s_data = load_cached_data("dev_query_tokenized.pkl")
test_s_data = load_cached_data("test_query_tokenized.pkl")

# SQL normalization function
def normalize_sql(sql_str):
    return ' '.join(sql_str.lower().strip().rstrip(';').split())

# LSTM Seq2Seq with Attention
class LSTMAttentionSeq2Seq(nn.Module):
    def __init__(self, vocab_size, embed_size=128, hidden_size=256):
        super().__init__()
        self.hidden_size = hidden_size

        self.encoder_embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab[PAD_TOKEN])
        self.encoder_lstm = nn.LSTM(embed_size, hidden_size, batch_first=True, bidirectional=True)

        self.decoder_embedding = nn.Embedding(vocab_size, embed_size, padding_idx=vocab[PAD_TOKEN])
        self.decoder_lstm = nn.LSTM(embed_size + hidden_size*2, hidden_size, batch_first=True)

        self.attn_linear = nn.Linear(hidden_size*3, hidden_size)
        self.fc_out = nn.Linear(hidden_size, vocab_size)

    def forward(self, src, tgt):
        # Encoder
        embedded_src = self.encoder_embedding(src)
        enc_output, (hidden, cell) = self.encoder_lstm(embedded_src)
        
        # Concatenate forward and backward states
        hidden = torch.cat((hidden[0], hidden[1]), dim=1).unsqueeze(0)
        cell = torch.cat((cell[0], cell[1]), dim=1).unsqueeze(0)

        # Decoder
        embedded_tgt = self.decoder_embedding(tgt)
        batch_size, tgt_len, _ = embedded_tgt.shape
        outputs = []

        dec_hidden, dec_cell = hidden, cell
        for t in range(tgt_len):
            dec_input_t = embedded_tgt[:, t:t+1, :]

            # Attention scores (Luong style)
            attn_weights = torch.bmm(dec_hidden.permute(1, 0, 2), enc_output.permute(0, 2, 1))
            attn_weights = torch.softmax(attn_weights, dim=-1)
            context_vector = torch.bmm(attn_weights, enc_output)

            # Combine context with input
            dec_input_combined = torch.cat([dec_input_t, context_vector], dim=2)
            output, (dec_hidden, dec_cell) = self.decoder_lstm(dec_input_combined, (dec_hidden, dec_cell))

            # Generate token
            output = output.squeeze(1)
            prediction = self.fc_out(torch.tanh(self.attn_linear(torch.cat([output, context_vector.squeeze(1)], dim=1))))
            outputs.append(prediction.unsqueeze(1))

        outputs = torch.cat(outputs, dim=1)
        return outputs
        
def batchify(data, batch_size):
    random.shuffle(data)
    batches = []
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        src_batch, tgt_input_batch, tgt_output_batch = [], [], []
        for src_seq, tgt_seq in batch:
            src_batch.append(src_seq)
            tgt_input_batch.append([vocab[SOS_TOKEN]] + tgt_seq)
            tgt_output_batch.append(tgt_seq + [vocab[EOS_TOKEN]])

        src_max_len = max(len(seq) for seq in src_batch)
        tgt_max_len = max(len(seq) for seq in tgt_input_batch)

        src_batch_padded = [seq + [vocab[PAD_TOKEN]]*(src_max_len - len(seq)) for seq in src_batch]
        tgt_input_batch_padded = [seq + [vocab[PAD_TOKEN]]*(tgt_max_len - len(seq)) for seq in tgt_input_batch]
        tgt_output_batch_padded = [seq + [vocab[PAD_TOKEN]]*(tgt_max_len - len(seq)) for seq in tgt_output_batch]

        batches.append((torch.tensor(src_batch_padded, device=device),
                        torch.tensor(tgt_input_batch_padded, device=device),
                        torch.tensor(tgt_output_batch_padded, device=device)))
    return batches

def train_seq2seq(model, train_data, dev_data, epochs=3, batch_size=64, lr=1e-3):
    criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD_TOKEN])
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        batches = batchify(train_data, batch_size)

        for src, tgt_in, tgt_out in batches:
            optimizer.zero_grad()
            output = model(src, tgt_in)
            loss = criterion(output.view(-1, vocab_size), tgt_out.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(batches)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

        if dev_data:
            acc = evaluate(model, dev_data)
            print(f"Dev Accuracy: {acc*100:.2f}%")

def infer(model, src, max_len=100):
    model.eval()
    src_tensor = torch.tensor([src], device=device)
    with torch.no_grad():
        embedded_src = model.encoder_embedding(src_tensor)
        enc_output, (hidden, cell) = model.encoder_lstm(embedded_src)
        hidden = torch.cat((hidden[0], hidden[1]), dim=1).unsqueeze(0)
        cell = torch.cat((cell[0], cell[1]), dim=1).unsqueeze(0)

        tgt_idx = vocab[SOS_TOKEN]
        tgt_indices = []

        dec_hidden, dec_cell = hidden, cell
        for _ in range(max_len):
            tgt_tensor = torch.tensor([[tgt_idx]], device=device)
            embedded_tgt = model.decoder_embedding(tgt_tensor)

            attn_weights = torch.bmm(dec_hidden.permute(1, 0, 2), enc_output.permute(0, 2, 1))
            attn_weights = torch.softmax(attn_weights, dim=-1)
            context_vector = torch.bmm(attn_weights, enc_output)

            dec_input_combined = torch.cat([embedded_tgt, context_vector], dim=2)
            output, (dec_hidden, dec_cell) = model.decoder_lstm(dec_input_combined, (dec_hidden, dec_cell))
            output = output.squeeze(1)
            prediction = model.fc_out(torch.tanh(model.attn_linear(torch.cat([output, context_vector.squeeze(1)], dim=1))))
            tgt_idx = prediction.argmax(1).item()

            if tgt_idx in (vocab[EOS_TOKEN], vocab[PAD_TOKEN]):
                break
            tgt_indices.append(tgt_idx)

    return ' '.join(idx_to_token[idx] for idx in tgt_indices)

def evaluate(model, dataset):
    model.eval()
    correct = 0
    with torch.no_grad():
        for src, tgt in dataset:
            pred_sql = infer(model, src)
            gold_sql = ' '.join(idx_to_token[idx] for idx in tgt)
            try:
                if sqlglot.parse_one(normalize_sql(pred_sql)).sql() == sqlglot.parse_one(normalize_sql(gold_sql)).sql():
                    correct += 1
            except sqlglot.errors.ParseError:
                continue
    return correct / len(dataset)

# Instantiate model with Attention
model_lstm_attn = LSTMAttentionSeq2Seq(vocab_size, embed_size=128, hidden_size=256).to(device)

# Training and evaluation Question Split
print("Training LSTM with Attention (Question Split):")
train_seq2seq(model_lstm_attn, train_q_data, dev_q_data, epochs=10, batch_size=64, lr=1e-3)
q_acc = evaluate(model_lstm_attn, test_q_data)
print(f"Question Split Test Accuracy: {q_acc*100:.2f}%")

# Instantiate a new model for Query Split
model_lstm_attn_query = LSTMAttentionSeq2Seq(vocab_size, embed_size=128, hidden_size=256).to(device)

# Training and evaluation Query Split
print("\nTraining LSTM with Attention (Query Split):")
train_seq2seq(model_lstm_attn_query, train_s_data, dev_s_data, epochs=10, batch_size=64, lr=1e-3)
s_acc = evaluate(model_lstm_attn_query, test_s_data)
print(f"Query Split Test Accuracy: {s_acc*100:.2f}%")

Training LSTM with Attention (Question Split):


RuntimeError: Expected hidden[0] size (1, 64, 256), got [1, 64, 512]

### 2.4 Transformer with attention

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import pickle
import os
import sqlglot

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory paths
cache_dir = "./cache"
vocab_dir = "./vocab"

# Special tokens
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"

# Load vocab
with open(os.path.join(vocab_dir, "unified_vocab.pkl"), 'rb') as f:
    vocab = pickle.load(f)

idx_to_token = {idx: tok for tok, idx in vocab.items()}

# Load cached data
def load_cached_data(filename):
    with open(os.path.join(cache_dir, filename), 'rb') as f:
        return pickle.load(f)

# Load question split data
train_q_data = load_cached_data("train_question_tokenized.pkl")
dev_q_data = load_cached_data("dev_question_tokenized.pkl")
test_q_data = load_cached_data("test_question_tokenized.pkl")

# Load query split data
train_s_data = load_cached_data("train_query_tokenized.pkl")
dev_s_data = load_cached_data("dev_query_tokenized.pkl")
test_s_data = load_cached_data("test_query_tokenized.pkl")

vocab_size = len(vocab)

# use Positional Encoding sin 
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=512):
        super().__init__()
        self.encoding = torch.zeros(max_len, embed_size, device=device, dtype=torch.float32)
        pos = torch.arange(0, max_len, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2, device=device).float() * (-torch.log(torch.tensor(10000.0, device=device)) / embed_size))
        self.encoding[:, 0::2] = torch.sin(pos * div_term)
        self.encoding[:, 1::2] = torch.cos(pos * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)]

# SQL normalization function, to prevent the impact of format
def normalize_sql(sql_str):
    return ' '.join(sql_str.lower().strip().rstrip(';').split())

# Transformer model with Mask and Positional Encoding
class TransformerSeq2Seq(nn.Module):
    def __init__(self, input_vocab_size, output_vocab_size, embed_size=128, num_heads=8,
                 num_encoder_layers=3, num_decoder_layers=3, ff_hidden_size=512,
                 dropout=0.1, max_seq_len=512):
        super().__init__()
        self.embed_size = embed_size
        
        padding_idx = vocab[PAD_TOKEN]
        assert padding_idx < input_vocab_size, f"PAD_TOKEN index ({padding_idx}) out of range (input_vocab_size {input_vocab_size})"
        assert padding_idx < output_vocab_size, f"PAD_TOKEN index ({padding_idx}) out of range (output_vocab_size {output_vocab_size})"

        self.src_embed = nn.Embedding(input_vocab_size, embed_size, padding_idx=padding_idx)
        self.tgt_embed = nn.Embedding(output_vocab_size, embed_size, padding_idx=padding_idx)
        self.pos_enc = PositionalEncoding(embed_size, max_seq_len)
        self.transformer = nn.Transformer(
            d_model=embed_size, nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=ff_hidden_size, dropout=dropout,
            batch_first=True
        )
        self.fc_out = nn.Linear(embed_size, output_vocab_size)

    def generate_mask(self, size):
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool().to(device)
        return mask

    def forward(self, src, tgt):
        src_emb = self.pos_enc(self.src_embed(src))
        tgt_emb = self.pos_enc(self.tgt_embed(tgt))
        src_pad_mask = (src == vocab[PAD_TOKEN])
        tgt_pad_mask = (tgt == vocab[PAD_TOKEN])
        tgt_mask = self.generate_mask(tgt.size(1))

        out = self.transformer(
            src_emb, tgt_emb,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            tgt_mask=tgt_mask
        )
        return self.fc_out(out)

def batchify(data, batch_size):
    random.shuffle(data)
    batches = []
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        src_batch, tgt_input_batch, tgt_output_batch = [], [], []
        for src_seq, tgt_seq in batch:
            src_batch.append(src_seq)
            tgt_input_batch.append([vocab[SOS_TOKEN]] + tgt_seq)
            tgt_output_batch.append(tgt_seq + [vocab[EOS_TOKEN]])

        src_max_len = max(len(seq) for seq in src_batch)
        tgt_max_len = max(len(seq) for seq in tgt_input_batch)

        src_batch_padded = [seq + [vocab[PAD_TOKEN]] * (src_max_len - len(seq)) for seq in src_batch]
        tgt_input_batch_padded = [seq + [vocab[PAD_TOKEN]] * (tgt_max_len - len(seq)) for seq in tgt_input_batch]
        tgt_output_batch_padded = [seq + [vocab[PAD_TOKEN]] * (tgt_max_len - len(seq)) for seq in tgt_output_batch]

        batches.append((torch.tensor(src_batch_padded, device=device),
                        torch.tensor(tgt_input_batch_padded, device=device),
                        torch.tensor(tgt_output_batch_padded, device=device)))
    return batches

def train_transformer(model, train_data, dev_data, epochs=10, batch_size=64, lr=1e-3):
    criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD_TOKEN])
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        batches = batchify(train_data, batch_size)

        for src, tgt_in, tgt_out in batches:
            optimizer.zero_grad()
            output = model(src, tgt_in)
            loss = criterion(output.view(-1, output.size(-1)), tgt_out.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(batches)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

        if dev_data:
            dev_acc = evaluate(model, dev_data)
            print(f"Dev Accuracy: {dev_acc*100:.2f}%")

def infer(model, src, max_len=100):
    model.eval()
    src_tensor = torch.tensor([src], device=device)
    tgt_indices = [vocab[SOS_TOKEN]]

    with torch.no_grad():
        for _ in range(max_len):
            tgt_tensor = torch.tensor([tgt_indices], device=device)
            output = model(src_tensor, tgt_tensor)
            next_token_logits = output[0, -1, :]
            next_token_idx = next_token_logits.argmax().item()

            if next_token_idx in (vocab[EOS_TOKEN], vocab[PAD_TOKEN]):
                break
            tgt_indices.append(next_token_idx)

    return ' '.join([idx_to_token[idx] for idx in tgt_indices[1:]])

def evaluate(model, dataset):
    model.eval()
    correct = 0
    with torch.no_grad():
        for src, tgt in dataset:
            pred_sql = infer(model, src)
            gold_sql = ' '.join(idx_to_token[idx] for idx in tgt)
            pred_norm = normalize_sql(pred_sql)
            gold_norm = normalize_sql(gold_sql)
            try:
                pred_parsed = sqlglot.parse_one(pred_norm).sql()
                gold_parsed = sqlglot.parse_one(gold_norm).sql()
                if pred_parsed == gold_parsed:
                    correct += 1
            except sqlglot.errors.ParseError:
                continue  

    return correct / len(dataset)

# Instantiate model
model = TransformerSeq2Seq(
    input_vocab_size=vocab_size,
    output_vocab_size=vocab_size,
    embed_size=128,
    num_heads=4,
    num_encoder_layers=1,
    num_decoder_layers=1,
    ff_hidden_size=256,
    dropout=0.1,
    max_seq_len=512
).to(device)

# Question split training and evaluation
print("Training on Question Split:")
train_transformer(model, train_q_data, dev_q_data, epochs=10, batch_size=64, lr=1e-3)
question_acc = evaluate(model, test_q_data)
print(f"Question Split Test Accuracy: {question_acc*100:.2f}%")

# Query split training and evaluation (instantiate a new model for fairness)
model_query = TransformerSeq2Seq(
    input_vocab_size=vocab_size,
    output_vocab_size=vocab_size,
    embed_size=128,
    num_heads=4,
    num_encoder_layers=1,
    num_decoder_layers=1,
    ff_hidden_size=256,
    dropout=0.1,
    max_seq_len=512
).to(device)

print("\nTraining on Query Split:")
train_transformer(model_query, train_s_data, dev_s_data, epochs=10, batch_size=64, lr=1e-3)
query_acc = evaluate(model_query, test_s_data)
print(f"Query Split Test Accuracy: {query_acc*100:.2f}%")

Training on Question Split:


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1/10 - Loss: 2.9358


  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


## 3. LLM tasks

### 3.1 Build prompts

In [6]:
import json
import random
import pandas as pd
import re
import os

random.seed(42)

# load schema provided by professor
def load_schema(csv_file_path):
    schema_df = pd.read_csv(csv_file_path)
    schema_df.columns = [col.strip().lower() for col in schema_df.columns]

    schema_prompt = "Database Schema:\n"
    current_table = ""
    columns = []

    for _, row in schema_df.iterrows():
        table = str(row['table name']).strip()
        column = str(row['field name']).strip()

        if table in ("-", "nan") or column in ("-", "nan"):
            continue

        if table != current_table:
            if current_table:
                schema_prompt += f"- {current_table}({', '.join(columns)})\n"
            current_table = table
            columns = [column]
        else:
            columns.append(column)

    if current_table and columns:
        schema_prompt += f"- {current_table}({', '.join(columns)})\n"

    return schema_prompt

# replace 
def replace_sql_placeholders(sql, variables):
    for var_name, val in variables.items():
        sql = re.sub(f'"{var_name}"', f'"{val}"', sql)
    return sql

# build prompts
def build_prompt_samples(train_data, test_data, schema_prompt, shot_count):
    prompts = []
    
    if shot_count > 0:
        examples = random.sample(train_data, shot_count)
    else:
        examples = []

    for test_item in test_data:
        prompt_parts = [schema_prompt]

        # add examples for different size of shots
        for ex in examples:
            example_sql = replace_sql_placeholders(ex['sql'], ex['variables'])
            prompt_parts.append(f"Question: {ex['text']}\nSQL: {example_sql}")

        test_question = test_item['text']
        gold_sql = replace_sql_placeholders(test_item['sql'], test_item['variables'])
        prompt_parts.append(f"Question: {test_question}\nSQL:")

        final_prompt = "\n\n".join(prompt_parts)

        prompts.append({
            "prompt": final_prompt,
            "gold_sql": gold_sql
        })

    return prompts

# save
def save_jsonl(file_path, data):
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def load_jsonl(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line.strip()))
    return data

if __name__ == '__main__':
    schema_file = './sources/atis-schema.csv'

    splits = ['question', 'query']

    for split in splits:
        train_file = f'./datasets/{split}_train.jsonl'
        test_file = f'./datasets/{split}_test.jsonl'

        schema_prompt = load_schema(schema_file)

        train_data = load_jsonl(train_file)
        test_data = load_jsonl(test_file)

        # Zero-shot
        zero_shot_prompts = build_prompt_samples(train_data, test_data, schema_prompt, shot_count=0)
        save_jsonl(f'./llm_prompts/{split}_zero_shot_test.jsonl', zero_shot_prompts)
        print(f"{split.capitalize()} zero-shot prompts saved.")

        # Few-shot (5 examples)
        few_shot_prompts = build_prompt_samples(train_data, test_data, schema_prompt, shot_count=5)
        save_jsonl(f'./llm_prompts/{split}_few_shot_test.jsonl', few_shot_prompts)
        print(f"{split.capitalize()} few-shot prompts saved.")

        # Many-shot (40 examples)
        many_shot_prompts = build_prompt_samples(train_data, test_data, schema_prompt, shot_count=40)
        save_jsonl(f'./llm_prompts/{split}_many_shot_test.jsonl', many_shot_prompts)
        print(f"{split.capitalize()} many-shot prompts saved.")

Question zero-shot prompts saved.
Question few-shot prompts saved.
Question many-shot prompts saved.
Query zero-shot prompts saved.
Query few-shot prompts saved.
Query many-shot prompts saved.


### 3.2 Use openai API for generation tasks

In [8]:
import json
import re
import pandas as pd
from openai import OpenAI
from dotenv import load_dotenv
import os
from sqlglot import parse_one

load_dotenv()
client = OpenAI()

# add schema into prompts
def load_schema(csv_file_path: str) -> str:
    schema_df = pd.read_csv(csv_file_path)
    schema_df.columns = [col.strip().lower() for col in schema_df.columns]

    schema_str = "Your task is to generate an SQL query based on the following database schema:\n"
    current_table = ""
    columns = []

    for _, row in schema_df.iterrows():
        table = str(row['table name']).strip()
        column = str(row['field name']).strip()

        if table in ("-", "nan") or column in ("-", "nan"):
            continue

        if table != current_table:
            if current_table:
                schema_str += f"- {current_table}({', '.join(columns)})\n"
            current_table = table
            columns = [column]
        else:
            columns.append(column)

    if current_table and columns:
        schema_str += f"- {current_table}({', '.join(columns)})\n"

    return schema_str

def load_prompt_data(file_path: str):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            ex = json.loads(line.strip())
            data.append((ex['prompt'], ex['gold_sql']))
    return data

def query_openai(prompt):
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}]
    )
    return response.choices[0].message.content.strip()

# extract sql from raw response of LLM
def extract_sql(text):
    match = re.search(r'```sql\n(.*?)\n```', text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()

# SQL nomalization
def normalize_sql(sql):
    return re.sub(r'\s+', ' ', sql.strip().lower())

# use sqlglot to evaluate 
def sql_semantic_similarity(sql1, sql2):
    try:
        parsed1 = parse_one(sql1, read='mysql')
        parsed2 = parse_one(sql2, read='mysql')
        return parsed1 == parsed2
    except Exception as e:
        print(f"Semantic parse error: {e}")
        return False

# evaluation
def evaluate_llm(test_data, result_file):
    correct_exact = 0
    correct_semantic = 0
    total = len(test_data)

    with open(result_file, 'w', encoding='utf-8') as f:
        for i, (prompt, gold_sql) in enumerate(test_data):
            try:
                raw_response = query_openai(prompt)
                pred_sql = extract_sql(raw_response)

                exact_match = normalize_sql(pred_sql) == normalize_sql(gold_sql)
                semantic_match = sql_semantic_similarity(pred_sql, gold_sql)

                correct_exact += int(exact_match)
                correct_semantic += int(semantic_match)

                result = (
                    f"\nExample [{i + 1}/{total}]\n"
                    f"Prompt:\n{prompt}\n\n"
                    f"Standard SQL:\n{gold_sql}\n"
                    f"Raw LLM Response:\n{raw_response}\n"
                    f"Extracted Predicted SQL:\n{pred_sql}\n"
                    f"Exact match: {exact_match}\n"
                    f"Semantic match: {semantic_match}\n"
                    + "-" * 50 + "\n"
                )

                print(result)
                f.write(result)

            except Exception as e:
                error_message = f"Error at example {i + 1}: {e}\n"
                print(error_message)
                f.write(error_message)

        exact_acc = correct_exact / total
        semantic_acc = correct_semantic / total

        summary = (
            f"\nFinal Exact Match Accuracy: {exact_acc:.3f}\n"
            f"Final Semantic Match Accuracy: {semantic_acc:.3f}\n"
        )
        print(summary)
        f.write(summary)

    return exact_acc, semantic_acc

if __name__ == "__main__":
    schema_info = load_schema('./sources/atis-schema.csv')
    prompt_dir = './llm_prompts/'

    results_dir = "./evaluation_results/"
    os.makedirs(results_dir, exist_ok=True)

    splits = ['question', 'query']
    shots = ['many','few', 'zero']

    for split in splits:
        for shot in shots:
            prompt_file = os.path.join(prompt_dir, f'{split}_{shot}_shot_test.jsonl')
            test_data = load_prompt_data(prompt_file)

            result_file = os.path.join(results_dir, f'{split}_{shot}_shot_results.txt')
            print(f"\nEvaluating {split.capitalize()} split - {shot.capitalize()}-shot...")

            evaluate_llm(test_data, result_file)


Evaluating Question split - Many-shot...

Example [1/279]
Prompt:
Database Schema:
- AIRCRAFT(AIRCRAFT_CODE, AIRCRAFT_DESCRIPTION, MANUFACTURER, BASIC_TYPE, ENGINES, PROPULSION, WIDE_BODY, WING_SPAN, LENGTH, WEIGHT, CAPACITY, PAY_LOAD, CRUISING_SPEED, RANGE_MILES, PRESSURIZED)
- AIRLINE(AIRLINE_CODE, AIRLINE_NAME, NOTE)
- AIRPORT(AIRPORT_CODE, AIRPORT_NAME, AIRPORT_LOCATION, STATE_CODE, COUNTRY_NAME, TIME_ZONE_CODE, MINIMUM_CONNECT_TIME)
- AIRPORT_SERVICE(CITY_CODE, AIRPORT_CODE, MILES_DISTANT, DIRECTION, MINUTES_DISTANT)
- CITY(CITY_CODE, CITY_NAME, STATE_CODE, COUNTRY_NAME, TIME_ZONE_CODE)
- CLASS_OF_SERVICE(BOOKING_CLASS, RANK, CLASS_DESCRIPTION)
- CODE_DESCRIPTION(CODE, DESCRIPTION)
- COMPARTMENT_CLASS(COMPARTMENT, CLASS_TYPE)
- DATE_DAY(MONTH_NUMBER, DAY_NUMBER, YEAR, DAY_NAME)
- DAYS(DAYS_CODE, DAY_NAME)
- DUAL_CARRIER(MAIN_AIRLINE, LOW_FLIGHT_NUMBER, HIGH_FLIGHT_NUMBER, DUAL_AIRLINE, SERVICE_NAME)
- EQUIPMENT_SEQUENCE(AIRCRAFT_CODE_SEQUENCE, AIRCRAFT_CODE)
- FARE(FARE_ID, FROM_

KeyboardInterrupt: 