## Importing Necessary Libraries and Setting up Paths

In [1]:
import json
import csv
import time
import sqlite3
import re
import random
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import sqlglot
from sqlglot import parse_one

# In Colab use:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check device availability
print("=" * 60)
print("ENVIRONMENT CHECK")
print("=" * 60)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# if torch.cuda.is_available():
#     DEVICE = torch.device("cuda")
#     print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
#     DEVICE = torch.device("mps")
#     print("Using Apple Silicon MPS")
# else:
#     DEVICE = torch.device("cpu")
#     print("Using CPU")

# print(f"Device selected: {DEVICE}")
# print("=" * 60)

# ============================================================
# File Paths
# ============================================================

# Data paths
TRAIN_JSONL = Path("train_text2sql.jsonl")
VAL_JSONL = Path("val_text2sql.jsonl")
TEST_JSONL = Path("test_hospital_1.jsonl")

# # Database path
# SQLITE_DB = Path("spider_data/database/hospital_1/hospital_1.sqlite")

# Database path for colab
SQLITE_DB = Path("hospital_1.sqlite")

# Save directory for model checkpoints
SAVE_DIR = Path("saved_model")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Results
RESULTS_CSV = Path("results_hospital_1_bilstm.csv")

print("\nFILE PATHS:")
print("-" * 60)
print(f"Training data:    {TRAIN_JSONL}")
print(f"Validation data:  {VAL_JSONL}")
print(f"Test data:        {TEST_JSONL}")
print(f"SQLite DB:        {SQLITE_DB}")
print(f"Save directory:   {SAVE_DIR}")
print(f"Results CSV:      {RESULTS_CSV}")
print("-" * 60)

# Verify files exist
print("\nFILE VERIFICATION:")
print("-" * 60)
files_to_check = {
    "Training data": TRAIN_JSONL,
    "Validation data": VAL_JSONL,
    "Test data": TEST_JSONL,
    "SQLite DB": SQLITE_DB,
}

all_exist = True
for name, path in files_to_check.items():
    exists = path.exists()
    status = "✓ EXISTS" if exists else "✗ MISSING"
    print(f"{name:20s}: {status}")
    if not exists:
        all_exist = False

print("-" * 60)

if all_exist:
    print("\n✅ All required files found!")
else:
    print("\n⚠️  WARNING: Some files are missing. Please upload them before proceeding.")

ENVIRONMENT CHECK
PyTorch version: 2.8.0+cu126
CUDA available: True

FILE PATHS:
------------------------------------------------------------
Training data:    train_text2sql.jsonl
Validation data:  val_text2sql.jsonl
Test data:        test_hospital_1.jsonl
SQLite DB:        hospital_1.sqlite
Save directory:   saved_model
Results CSV:      results_hospital_1_bilstm.csv
------------------------------------------------------------

FILE VERIFICATION:
------------------------------------------------------------
Training data       : ✓ EXISTS
Validation data     : ✓ EXISTS
Test data           : ✓ EXISTS
SQLite DB           : ✓ EXISTS
------------------------------------------------------------

✅ All required files found!


## Setting up the Configurations

In [2]:
print("=" * 60)
print("CONFIGURATION PARAMETERS")
print("=" * 60)

# ============================================================
# Training Hyperparameters
# ============================================================

EPOCHS = 15
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
GRAD_CLIP = 0.5
SEED = 42

print("\nTraining Hyperparameters:")
print("-" * 60)
print(f"Epochs:          {EPOCHS}")
print(f"Batch size:      {BATCH_SIZE}")
print(f"Learning rate:   {LEARNING_RATE}")
print(f"Gradient clip:   {GRAD_CLIP}")
print(f"Random seed:     {SEED}")

# ============================================================
# Model Architecture Parameters
# ============================================================

EMB_DIM = 512       # Embedding dimension
HID_DIM = 512       # Hidden dimension
NUM_LAYERS = 2      # Number of LSTM layers in encoder
DROPOUT = 0.2       # Dropout rate

print("\nModel Architecture:")
print("-" * 60)
print(f"Embedding dim:   {EMB_DIM}")
print(f"Hidden dim:      {HID_DIM}")
print(f"Encoder layers:  {NUM_LAYERS}")
print(f"Dropout rate:    {DROPOUT}")

# ============================================================
# Generation Parameters
# ============================================================

MAX_DECODE_LEN = 128        # Maximum SQL tokens to generate
MAX_DECODE_PREVIEW = 64     # Preview length during training

print("\nGeneration Parameters:")
print("-" * 60)
print(f"Max decode length:     {MAX_DECODE_LEN}")
print(f"Preview decode length: {MAX_DECODE_PREVIEW}")

# ============================================================
# Data Processing
# ============================================================

SCHEMA_CONDITION = True     # Prepend schema to question

print("\nData Processing:")
print("-" * 60)
print(f"Schema conditioning: {SCHEMA_CONDITION}")

# ============================================================
# Set Random Seeds for Reproducibility
# ============================================================

random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    # MPS doesn't need special seeding
    pass

print("\n✅ Random seeds set for reproducibility")

CONFIGURATION PARAMETERS

Training Hyperparameters:
------------------------------------------------------------
Epochs:          15
Batch size:      32
Learning rate:   0.0001
Gradient clip:   0.5
Random seed:     42

Model Architecture:
------------------------------------------------------------
Embedding dim:   512
Hidden dim:      512
Encoder layers:  2
Dropout rate:    0.2

Generation Parameters:
------------------------------------------------------------
Max decode length:     128
Preview decode length: 64

Data Processing:
------------------------------------------------------------
Schema conditioning: True

✅ Random seeds set for reproducibility


## Tokenizer and Vocab Definition

In [3]:
print("=" * 60)
print("TOKENIZER AND VOCABULARY")
print("=" * 60)

class VocabTok:
    """
    Simple vocabulary and tokenizer for SQL generation.
    Uses basic regex-based tokenization to split identifiers, numbers, and operators.
    """

    # Special tokens
    PAD = "<pad>"
    UNK = "<unk>"
    BOS = "<bos>"
    EOS = "<eos>"

    def __init__(self):
        """Initialize vocabulary with special tokens."""
        self.stoi = {
            self.PAD: 0,
            self.UNK: 1,
            self.BOS: 2,
            self.EOS: 3
        }
        self.itos = [self.PAD, self.UNK, self.BOS, self.EOS]

    def add_token(self, tok):
        """Add a single token to vocabulary if not present."""
        if tok not in self.stoi:
            self.stoi[tok] = len(self.itos)
            self.itos.append(tok)

    def add_sentence(self, s):
        """Tokenize a sentence and add all tokens to vocabulary."""
        for t in self._basic_tokenize(s):
            self.add_token(t)

    def encode(self, s, add_bos=False, add_eos=False):
        """
        Convert a string to token IDs.

        Args:
            s: Input string
            add_bos: Whether to prepend BOS token
            add_eos: Whether to append EOS token

        Returns:
            List of token IDs
        """
        ids = []
        if add_bos:
            ids.append(self.stoi[self.BOS])

        ids += [self.stoi.get(t, self.stoi[self.UNK]) for t in self._basic_tokenize(s)]

        if add_eos:
            ids.append(self.stoi[self.EOS])

        return ids

    def decode(self, ids):
        """
        Convert token IDs back to string.

        Args:
            ids: List of token IDs

        Returns:
            Decoded string
        """
        return " ".join(self.itos[i] for i in ids if i < len(self.itos))

    def to_json(self):
        """Serialize vocabulary to JSON-compatible dict."""
        return {"itos": self.itos}

    @classmethod
    def from_json(cls, obj):
        """Load vocabulary from JSON object."""
        v = cls()
        v.itos = obj["itos"]
        v.stoi = {t: i for i, t in enumerate(v.itos)}
        return v

    @staticmethod
    def _basic_tokenize(s):
        """
        Basic regex tokenization that splits:
        - Identifiers: [A-Za-z_][A-Za-z_0-9]*
        - Numbers: [0-9]+
        - Multi-char operators: ==, !=, >=, <=
        - Single characters: everything else

        Preserves case for SQL keywords and identifiers.
        """
        return re.findall(r"[A-Za-z_][A-Za-z_0-9]*|[0-9]+|==|!=|>=|<=|[^\s]", s)

    def __len__(self):
        """Return vocabulary size."""
        return len(self.itos)


# ============================================================
# Test the Tokenizer
# ============================================================

print("\nTesting Tokenizer:")
print("-" * 60)

# Create a test tokenizer
test_tok = VocabTok()

# Add some test sentences
test_sentences = [
    "SELECT * FROM Department WHERE DepartmentID = 1 ;",
    "SELECT Name FROM Physician WHERE EmployeeID >= 100 ;"
]

for sent in test_sentences:
    test_tok.add_sentence(sent)
    print(f"Added: {sent}")

print(f"\nVocabulary size: {len(test_tok)}")
print(f"Sample tokens: {test_tok.itos[:20]}")

# Test encoding
test_text = "SELECT Name FROM Department ;"
encoded = test_tok.encode(test_text, add_bos=True, add_eos=True)
decoded = test_tok.decode(encoded)

print(f"\nEncoding test:")
print(f"  Original: {test_text}")
print(f"  Encoded:  {encoded}")
print(f"  Decoded:  {decoded}")

# Test special tokens
print(f"\nSpecial token IDs:")
print(f"  PAD: {test_tok.stoi[VocabTok.PAD]}")
print(f"  UNK: {test_tok.stoi[VocabTok.UNK]}")
print(f"  BOS: {test_tok.stoi[VocabTok.BOS]}")
print(f"  EOS: {test_tok.stoi[VocabTok.EOS]}")

print("\n✅ Tokenizer working correctly!")

TOKENIZER AND VOCABULARY

Testing Tokenizer:
------------------------------------------------------------
Added: SELECT * FROM Department WHERE DepartmentID = 1 ;
Added: SELECT Name FROM Physician WHERE EmployeeID >= 100 ;

Vocabulary size: 18
Sample tokens: ['<pad>', '<unk>', '<bos>', '<eos>', 'SELECT', '*', 'FROM', 'Department', 'WHERE', 'DepartmentID', '=', '1', ';', 'Name', 'Physician', 'EmployeeID', '>=', '100']

Encoding test:
  Original: SELECT Name FROM Department ;
  Encoded:  [2, 4, 13, 6, 7, 12, 3]
  Decoded:  <bos> SELECT Name FROM Department ; <eos>

Special token IDs:
  PAD: 0
  UNK: 1
  BOS: 2
  EOS: 3

✅ Tokenizer working correctly!


## Data Loading & Dataset Class

In [4]:
print("=" * 60)
print("DATA LOADING & DATASET CLASS")
print("=" * 60)

# ============================================================
# Utility Function to Load JSONL Files
# ============================================================

def load_jsonl(path):
    """
    Load JSONL file and return list of dictionaries.

    Args:
        path: Path to JSONL file

    Returns:
        List of dictionaries (one per line)
    """
    if not path or not path.exists():
        return []

    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))

    return rows


# ============================================================
# PyTorch Dataset Class
# ============================================================

class T2SQLDataset(Dataset):
    """
    PyTorch Dataset for Text-to-SQL training.

    Each example contains:
    - Source: Question (optionally with schema)
    - Target: Gold SQL query
    - Schema: Database schema text (for masking)
    """

    def __init__(self, rows, tok, schema_condition=True):
        """
        Args:
            rows: List of data dictionaries
            tok: VocabTok tokenizer instance
            schema_condition: Whether to prepend schema to question
        """
        self.rows = rows
        self.tok = tok
        self.schema_condition = schema_condition

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        """
        Get a single example.

        Returns:
            src: Tensor of source token IDs
            tgt: Tensor of target token IDs (with BOS/EOS)
            schema: Schema text string (for vocab masking)
        """
        r = self.rows[idx]

        # Build source text (question + optional schema)
        if self.schema_condition:
            src_text = f"Schema:\n{r['schema_serialized']}\nQuestion:\n{r['question']}"
        else:
            src_text = r['question']

        # Target SQL
        tgt_text = r['gold_query']

        # Encode to token IDs
        src = torch.tensor(self.tok.encode(src_text), dtype=torch.long)
        tgt = torch.tensor(self.tok.encode(tgt_text, add_bos=True, add_eos=True), dtype=torch.long)

        return src, tgt, r['schema_serialized']


# ============================================================
# Collate Function for Batching
# ============================================================

def collate_fn(batch):
    """
    Collate function to pad sequences in a batch.

    Args:
        batch: List of (src, tgt, schema) tuples

    Returns:
        srcs: Padded source tensor [B, max_src_len]
        tgts: Padded target tensor [B, max_tgt_len]
        schemas: List of schema strings
    """
    srcs, tgts, schemas = zip(*batch)

    # Pad sequences with PAD token (id=0)
    srcs = pad_sequence(srcs, batch_first=True, padding_value=0)
    tgts = pad_sequence(tgts, batch_first=True, padding_value=0)

    return srcs, tgts, list(schemas)


# ============================================================
# Build Vocabulary from Data
# ============================================================

def build_vocab(train_rows, val_rows, test_rows):
    """
    Build vocabulary from training, validation, and test data.

    Args:
        train_rows: Training examples
        val_rows: Validation examples
        test_rows: Test examples (to ensure schema tokens are in vocab)

    Returns:
        VocabTok instance with complete vocabulary
    """
    tok = VocabTok()

    # Add tokens from train + val (schema + question + gold SQL)
    for r in (train_rows + val_rows):
        tok.add_sentence(f"Schema:\n{r['schema_serialized']}\nQuestion:\n{r['question']}")
        tok.add_sentence(r["gold_query"])

    # Add tokens from test schema (both original and lowercase for robustness)
    for r in test_rows:
        schema = r["schema_serialized"]
        tok.add_sentence(schema)
        tok.add_sentence(schema.lower())

    return tok


# ============================================================
# Load Data and Build Vocabulary
# ============================================================

print("\nLoading datasets...")
print("-" * 60)

train_rows = load_jsonl(TRAIN_JSONL)
val_rows = load_jsonl(VAL_JSONL)
test_rows = load_jsonl(TEST_JSONL)

print(f"Training examples:   {len(train_rows)}")
print(f"Validation examples: {len(val_rows)}")
print(f"Test examples:       {len(test_rows)}")

print("\nBuilding vocabulary...")
print("-" * 60)

vocab = build_vocab(train_rows, val_rows, test_rows)

print(f"Vocabulary size: {len(vocab)}")
print(f"Sample vocab (first 30 tokens):")
print(f"  {vocab.itos[:30]}")

# ============================================================
# Create Datasets
# ============================================================

print("\nCreating PyTorch datasets...")
print("-" * 60)

train_dataset = T2SQLDataset(train_rows, vocab, schema_condition=SCHEMA_CONDITION)
val_dataset = T2SQLDataset(val_rows, vocab, schema_condition=SCHEMA_CONDITION)
test_dataset = T2SQLDataset(test_rows, vocab, schema_condition=SCHEMA_CONDITION)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size:   {len(val_dataset)}")
print(f"Test dataset size:  {len(test_dataset)}")

# ============================================================
# Create DataLoaders
# ============================================================

print("\nCreating DataLoaders...")
print("-" * 60)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,  # Batch size 1 for evaluation
    shuffle=False,
    collate_fn=collate_fn
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches:   {len(val_loader)}")
print(f"Test batches:  {len(test_loader)}")

# ============================================================
# Test Data Loading
# ============================================================

print("\nTesting data loading...")
print("-" * 60)

# Get one batch
src_batch, tgt_batch, schema_batch = next(iter(train_loader))

print(f"Source batch shape: {src_batch.shape}")
print(f"Target batch shape: {tgt_batch.shape}")
print(f"Number of schemas:  {len(schema_batch)}")

# Decode first example
print(f"\nFirst example in batch:")
print(f"  Source (decoded): {vocab.decode(src_batch[0].tolist())[:200]}...")
print(f"  Target (decoded): {vocab.decode(tgt_batch[0].tolist())}")

print("\n✅ Data loading working correctly!")

DATA LOADING & DATASET CLASS

Loading datasets...
------------------------------------------------------------
Training examples:   8559
Validation examples: 1034
Test examples:       100

Building vocabulary...
------------------------------------------------------------
Vocabulary size: 7468
Sample vocab (first 30 tokens):
  ['<pad>', '<unk>', '<bos>', '<eos>', 'Schema', ':', 'Database', 'department_management', 'Tables', '-', 'department', '(', 'Department_ID', '*', ',', 'Name', 'Creation', 'Ranking', 'Budget_in_Billions', 'Num_Employees', ')', 'head', 'head_ID', 'name', 'born_state', 'age', 'management', 'department_ID', 'temporary_acting', 'FK']

Creating PyTorch datasets...
------------------------------------------------------------
Train dataset size: 8559
Val dataset size:   1034
Test dataset size:  100

Creating DataLoaders...
------------------------------------------------------------
Train batches: 268
Val batches:   33
Test batches:  100

Testing data loading...
---------

In [5]:
# Diagnostic check
print("\n" + "=" * 60)
print("VALIDATION CHECK")
print("=" * 60)

# Quick sanity check on data
sample_batch = next(iter(train_loader))
src, tgt, schemas = sample_batch
print(f"Source shape: {src.shape}")
print(f"Target shape: {tgt.shape}")
print(f"Source sample (first 50 tokens): {src[0][:50].tolist()}")
print(f"Target sample (first 20 tokens): {tgt[0][:20].tolist()}")
print(f"Max source value: {src.max().item()}")
print(f"Max target value: {tgt.max().item()}")
print(f"Vocab size: {len(vocab)}")
print("=" * 60)


VALIDATION CHECK
Source shape: torch.Size([32, 517])
Target shape: torch.Size([32, 69])
Source sample (first 50 tokens): [4, 5, 6, 5, 2844, 8, 5, 9, 2845, 11, 2846, 13, 14, 2847, 14, 1095, 14, 2318, 14, 158, 20, 9, 2848, 11, 2847, 13, 14, 158, 14, 2849, 20, 9, 2850, 11, 158, 13, 14, 2847, 14, 2851, 20, 9, 2852, 11, 2847, 13, 14, 2853, 14, 158]
Target sample (first 20 tokens): [2, 44, 2862, 46, 2861, 47, 191, 112, 2866, 80, 191, 1255, 2867, 48, 3, 0, 0, 0, 0, 0]
Max source value: 6742
Max target value: 6742
Vocab size: 7468


## Schema-Aware Vocabulary Masking

In [6]:
print("=" * 60)
print("SCHEMA-AWARE VOCABULARY MASKING")
print("=" * 60)

# ============================================================
# SQL Keywords and Tokens
# ============================================================

# Standard SQL keywords (lowercase)
SQL_KEYWORDS = {
    "select", "from", "where", "group", "by", "order", "limit", "having", "distinct", "as",
    "join", "inner", "left", "right", "on", "and", "or", "not", "in", "between", "like",
    "count", "sum", "avg", "min", "max", "asc", "desc", "union", "intersect", "except",
    "case", "when", "then", "end", "exists", "all", "any", "is", "null"
}

# SQL operators and punctuation
SQL_TOKENS = {
    "*", "(", ")", ",", ".", "=", "<", ">", "<=", ">=", "!=", ";", "+", "-", "/", "%", "||"
}

print("\nSQL Keywords defined:")
print(f"  Total keywords: {len(SQL_KEYWORDS)}")
print(f"  Sample: {list(SQL_KEYWORDS)[:10]}")

print("\nSQL Tokens defined:")
print(f"  Total tokens: {len(SQL_TOKENS)}")
print(f"  All tokens: {SQL_TOKENS}")


# ============================================================
# Extract Identifiers from Schema
# ============================================================

def extract_identifiers_from_schema(schema_text):
    """
    Extract table and column names from schema text.

    Args:
        schema_text: Schema string (e.g., from schema_serialized field)

    Returns:
        Set of identifiers (both original case and lowercase)
    """
    names = set()

    for line in schema_text.splitlines():
        # Find all identifiers (alphanumeric + underscore)
        for tok in re.findall(r"[A-Za-z_][A-Za-z_0-9]*", line):
            # Skip common schema metadata words
            if tok.lower() in {"database", "tables", "foreign", "keys"}:
                continue

            # Add both original case and lowercase
            names.add(tok)
            names.add(tok.lower())

    return names


# ============================================================
# Build Vocabulary Mask for Example
# ============================================================

def build_vocab_mask_for_example(tok, schema_text):
    """
    Create a boolean mask indicating which vocabulary tokens are allowed
    for a given schema.

    Allowed tokens:
    - SQL keywords (both cases)
    - SQL operators/punctuation
    - Identifiers from schema (original + lowercase)
    - Digits 0-9
    - Special tokens (BOS, EOS, UNK)
    - PAD is NOT allowed (should never be predicted)

    Args:
        tok: VocabTok instance
        schema_text: Schema string

    Returns:
        Boolean tensor [V] where True = allowed
    """
    allowed = set(SQL_TOKENS)

    # Add SQL keywords in both original and uppercase
    for word in SQL_KEYWORDS:
        allowed.add(word)
        allowed.add(word.upper())

    # Add identifiers from schema
    allowed.update(extract_identifiers_from_schema(schema_text))

    # Add digits
    allowed.update(list("0123456789"))

    # Add special tokens (except PAD)
    allowed.update([VocabTok.BOS, VocabTok.EOS, VocabTok.UNK])

    # Create mask
    mask = torch.zeros(len(tok.itos), dtype=torch.bool)

    for i, token in enumerate(tok.itos):
        if token == VocabTok.PAD:
            mask[i] = False  # Never predict PAD
        elif token in allowed or token.lower() in allowed:
            mask[i] = True
        else:
            mask[i] = False

    return mask


# ============================================================
# Expand Mask with Gold Tokens (for Teacher Forcing)
# ============================================================

def expand_mask_with_gold(vocab_mask_batch, target_batch, pad_idx=0):
    """
    Expand vocabulary mask to include all gold target tokens.
    This ensures teacher forcing doesn't fail during training.

    Args:
        vocab_mask_batch: [B, V] boolean mask
        target_batch: [B, T] target token IDs
        pad_idx: Padding token ID

    Returns:
        Updated vocab_mask_batch with gold tokens allowed
    """
    B, T = target_batch.size()

    for b in range(B):
        # Get unique non-padding gold tokens for this example
        gold_ids = target_batch[b][target_batch[b] != pad_idx].unique()

        # Allow all gold tokens
        vocab_mask_batch[b, gold_ids] = True

    return vocab_mask_batch


# ============================================================
# Test Schema Masking
# ============================================================

print("\n" + "-" * 60)
print("Testing Schema Masking:")
print("-" * 60)

# Get first test example
test_example = test_rows[0]
test_schema = test_example['schema_serialized']
test_question = test_example['question']
test_gold = test_example['gold_query']

print(f"\nTest Example:")
print(f"  Schema:\n{test_schema[:200]}...")
print(f"  Question: {test_question}")
print(f"  Gold SQL: {test_gold}")

# Extract identifiers
identifiers = extract_identifiers_from_schema(test_schema)
print(f"\nExtracted Identifiers ({len(identifiers)} total):")
print(f"  {sorted(list(identifiers))[:20]}...")

# Build mask
mask = build_vocab_mask_for_example(vocab, test_schema)
allowed_count = mask.sum().item()
total_count = len(mask)

print(f"\nVocabulary Mask:")
print(f"  Total vocab size: {total_count}")
print(f"  Allowed tokens:   {allowed_count}")
print(f"  Blocked tokens:   {total_count - allowed_count}")
print(f"  Allowed ratio:    {allowed_count / total_count:.2%}")

# Show some allowed tokens
allowed_tokens = [vocab.itos[i] for i in range(len(vocab.itos)) if mask[i]]
print(f"\nSample allowed tokens (first 30):")
print(f"  {allowed_tokens[:30]}")

# Show some blocked tokens
blocked_tokens = [vocab.itos[i] for i in range(len(vocab.itos)) if not mask[i]]
print(f"\nSample blocked tokens (first 20):")
print(f"  {blocked_tokens[:20]}")

print("\n✅ Schema masking working correctly!")

SCHEMA-AWARE VOCABULARY MASKING

SQL Keywords defined:
  Total keywords: 40
  Sample: ['having', 'right', 'exists', 'order', 'like', 'or', 'intersect', 'between', 'case', 'in']

SQL Tokens defined:
  Total tokens: 17
  All tokens: {'=', '.', '%', ';', '<=', '<', '>=', '!=', ',', '(', '>', '-', '/', '||', ')', '*', '+'}

------------------------------------------------------------
Testing Schema Masking:
------------------------------------------------------------

Test Example:
  Schema:
Database: hospital_1
Tables:
- Physician(EmployeeID*, Name, Position, SSN)
- Department(DepartmentID*, Name, Head)
- Affiliated_With(Physician*, Department, PrimaryAffiliation)
- Procedures(Code*, Nam...
  Question: Which department has the largest number of employees?
  Gold SQL: SELECT name FROM department GROUP BY departmentID ORDER BY count(departmentID) DESC LIMIT 1;

Extracted Identifiers (109 total):
  ['Address', 'Affiliated_With', 'Appointment', 'AppointmentID', 'AssistingNurse', 'Block', 'Blo

## Model Architecture

### Attention Mechanism

In [7]:
print("\n6.1: Luong Attention Mechanism")
print("-" * 60)

class LuongAttention(nn.Module):
    """
    Luong (multiplicative) attention mechanism.
    Computes attention weights over encoder outputs using decoder hidden state.
    """

    def __init__(self, hid_dim):
        """
        Args:
            hid_dim: Hidden dimension size
        """
        super().__init__()
        self.W = nn.Linear(hid_dim, hid_dim, bias=False)

    def forward(self, dec_h, enc_out, mask=None):
        """
        Compute attention context vector.

        Args:
            dec_h: Decoder hidden state [B, H]
            enc_out: Encoder outputs [B, T, H]
            mask: Attention mask [B, T] (True = attend, False = ignore)

        Returns:
            ctx: Context vector [B, H]
            attn: Attention weights [B, T]
        """
        # Compute attention scores: score[i] = enc_out[i] · W(dec_h)
        # [B, T, H] @ [B, H, 1] -> [B, T, 1] -> [B, T]
        score = torch.bmm(enc_out, self.W(dec_h).unsqueeze(2)).squeeze(2)

        # Apply mask (set padded positions to -inf)
        if mask is not None:
            score = score.masked_fill(~mask, -100)  # CHANGED: -1e9 -> -1e4

        # Softmax to get attention weights
        attn = torch.softmax(score, dim=1)  # [B, T]

        # Compute context vector as weighted sum of encoder outputs
        # [B, 1, T] @ [B, T, H] -> [B, 1, H] -> [B, H]
        ctx = torch.bmm(attn.unsqueeze(1), enc_out).squeeze(1)

        return ctx, attn

print("✅ LuongAttention defined")


6.1: Luong Attention Mechanism
------------------------------------------------------------
✅ LuongAttention defined


### Encoder

In [8]:
print("\n6.2: Bidirectional LSTM Encoder")
print("-" * 60)

class Encoder(nn.Module):
    """
    Bidirectional LSTM encoder.
    Processes input sequence and returns hidden states.
    """

    def __init__(self, vocab_size, emb_dim=512, hid_dim=512,
                 num_layers=2, pad_idx=0, dropout=0.2):
        """
        Args:
            vocab_size: Size of vocabulary
            emb_dim: Embedding dimension
            hid_dim: Hidden dimension (will be split for bidirectional)
            num_layers: Number of LSTM layers
            pad_idx: Padding token index
            dropout: Dropout rate
        """
        super().__init__()

        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)

        # Bidirectional LSTM: each direction has hid_dim//2 units
        self.rnn = nn.LSTM(
            emb_dim,
            hid_dim // 2,  # Each direction gets half the hidden dim
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """
        Encode source sequence.

        Args:
            src: Source token IDs [B, T]

        Returns:
            out: Encoder outputs [B, T, H]
            (h, c): Final hidden and cell states
                    h: [2*num_layers, B, H//2]
                    c: [2*num_layers, B, H//2]
        """
        # Embed and apply dropout
        x = self.dropout(self.emb(src))  # [B, T, E]

        # Pass through BiLSTM
        out, (h, c) = self.rnn(x)

        return out, (h, c)

print("✅ Encoder defined")
print(f"   - Input: [B, T] token IDs")
print(f"   - Output: [B, T, {HID_DIM}] encoder outputs")
print(f"   - Hidden: [{2*NUM_LAYERS}, B, {HID_DIM//2}] per direction")


6.2: Bidirectional LSTM Encoder
------------------------------------------------------------
✅ Encoder defined
   - Input: [B, T] token IDs
   - Output: [B, T, 512] encoder outputs
   - Hidden: [4, B, 256] per direction


### Decoder

In [9]:
print("\n6.3: LSTM Decoder with Attention")
print("-" * 60)

class Decoder(nn.Module):
    """
    LSTM decoder with Luong attention.
    Generates output sequence one token at a time.
    """

    def __init__(self, vocab_size, emb_dim=512, hid_dim=512,
                 pad_idx=0, dropout=0.2):
        """
        Args:
            vocab_size: Size of vocabulary
            emb_dim: Embedding dimension
            hid_dim: Hidden dimension
            pad_idx: Padding token index
            dropout: Dropout rate
        """
        super().__init__()

        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)

        # LSTM input: concatenation of embedding + context vector
        self.rnn = nn.LSTM(
            emb_dim + hid_dim,  # Input: [embedding, context]
            hid_dim,
            num_layers=1,  # Single layer decoder
            batch_first=True
        )

        self.attn = LuongAttention(hid_dim)

        # Output projection: [decoder_out, context] -> vocab logits
        self.out = nn.Linear(hid_dim + hid_dim, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_out, hidden, enc_mask=None, vocab_mask=None):
        """
        Teacher-forced forward pass (training).

        Args:
            trg: Target sequence [B, T] (includes BOS)
            enc_out: Encoder outputs [B, T_src, H]
            hidden: Initial decoder hidden state (h, c)
            enc_mask: Encoder padding mask [B, T_src]
            vocab_mask: Allowed vocabulary mask [B, V]

        Returns:
            logits: Output logits [B, T-1, V] (predicts trg[:,1:])
        """
        B, T = trg.size()

        # Embed target tokens
        x = self.dropout(self.emb(trg))  # [B, T, E]

        h, c = hidden
        logits = []

        # Decode one step at a time (teacher forcing)
        for t in range(T - 1):  # Predict next token, so stop at T-1
            # Compute attention using current hidden state
            dec_h = h[-1]  # [B, H]
            ctx, _ = self.attn(dec_h, enc_out, mask=enc_mask)  # [B, H]

            # LSTM input: [current_embedding, context_vector]
            rnn_in = torch.cat([x[:, t:t+1, :], ctx.unsqueeze(1)], dim=-1)  # [B, 1, E+H]

            # LSTM step
            o, (h, c) = self.rnn(rnn_in, (h, c))  # o: [B, 1, H]

            # Compute output logits
            logit = self.out(torch.cat([o.squeeze(1), ctx], dim=-1))  # [B, V]

            # Apply vocabulary mask if provided
            if vocab_mask is not None:
                logit = logit.masked_fill(~vocab_mask, -100)  # CHANGED: -1e9 -> -1e4

            logits.append(logit.unsqueeze(1))

        return torch.cat(logits, dim=1)  # [B, T-1, V]

    def step(self, y_prev, enc_out, hidden, enc_mask=None, vocab_mask=None):
        """
        Single decoding step (inference).

        Args:
            y_prev: Previous token [B]
            enc_out: Encoder outputs [B, T_src, H]
            hidden: Previous hidden state (h, c)
            enc_mask: Encoder padding mask [B, T_src]
            vocab_mask: Allowed vocabulary mask [B, V]

        Returns:
            logits: Output logits [B, V]
            next_id: Predicted token [B]
            hidden: Updated hidden state (h, c)
        """
        # Embed previous token
        emb = self.dropout(self.emb(y_prev).unsqueeze(1))  # [B, 1, E]

        # Compute attention
        dec_h = hidden[0][-1]  # [B, H]
        ctx, _ = self.attn(dec_h, enc_out, mask=enc_mask)  # [B, H]

        # LSTM input
        rnn_in = torch.cat([emb, ctx.unsqueeze(1)], dim=-1)  # [B, 1, E+H]

        # LSTM step
        o, hidden = self.rnn(rnn_in, hidden)  # [B, 1, H]

        # Output logits
        logits = self.out(torch.cat([o.squeeze(1), ctx], dim=-1))  # [B, V]

        # Apply vocabulary mask
        if vocab_mask is not None:
            logits = logits.masked_fill(~vocab_mask, -100)  # CHANGED: -1e9 -> -1e4

        # Greedy selection
        next_id = torch.argmax(logits, dim=-1)

        return logits, next_id, hidden

print("✅ Decoder defined")
print(f"   - Forward: Teacher forcing for training")
print(f"   - Step: Greedy decoding for inference")


6.3: LSTM Decoder with Attention
------------------------------------------------------------
✅ Decoder defined
   - Forward: Teacher forcing for training
   - Step: Greedy decoding for inference


### Seq2Seq Model

In [10]:
print("\n6.4: BiLSTM Seq2Seq Model")
print("-" * 60)

class BiLSTMSeq2SQL(nn.Module):
    """
    Complete Seq2Seq model for text-to-SQL generation.
    Combines encoder, decoder, and handles state initialization.
    """

    def __init__(self, vocab_size, pad_idx=0, emb_dim=512,
                 hid_dim=512, num_layers=2, dropout=0.2):
        """
        Args:
            vocab_size: Size of vocabulary
            pad_idx: Padding token index
            emb_dim: Embedding dimension
            hid_dim: Hidden dimension
            num_layers: Number of encoder LSTM layers
            dropout: Dropout rate
        """
        super().__init__()

        self.encoder = Encoder(
            vocab_size, emb_dim, hid_dim,
            num_layers=num_layers,
            pad_idx=pad_idx,
            dropout=dropout
        )

        self.decoder = Decoder(
            vocab_size, emb_dim, hid_dim,
            pad_idx=pad_idx,
            dropout=dropout
        )

        self.pad_idx = pad_idx
        self.hid_dim = hid_dim

        # Projection layers to initialize decoder state from encoder
        self.dec_init_h = nn.Linear(hid_dim, hid_dim)
        self.dec_init_c = nn.Linear(hid_dim, hid_dim)

    def _init_decoder_state(self, enc_h, enc_c):
        """
        Initialize decoder hidden state from encoder final state.
        Combines forward and backward directions.

        Args:
            enc_h: Encoder hidden state [2*num_layers, B, H//2]
            enc_c: Encoder cell state [2*num_layers, B, H//2]

        Returns:
            (h0, c0): Decoder initial state, each [1, B, H]
        """
        # Take last layer, concat forward and backward
        fwd_h = enc_h[-2]  # [B, H//2]
        bwd_h = enc_h[-1]  # [B, H//2]
        fwd_c = enc_c[-2]
        bwd_c = enc_c[-1]

        # Concatenate and project
        h0 = torch.tanh(self.dec_init_h(torch.cat([fwd_h, bwd_h], dim=-1)))  # [B, H]
        c0 = torch.tanh(self.dec_init_c(torch.cat([fwd_c, bwd_c], dim=-1)))  # [B, H]

        return h0.unsqueeze(0), c0.unsqueeze(0)  # [1, B, H]

    def forward(self, src, trg, vocab_mask=None):
        """
        Forward pass (training with teacher forcing).

        Args:
            src: Source sequence [B, T_src]
            trg: Target sequence [B, T_trg] (includes BOS)
            vocab_mask: Vocabulary mask [B, V] or [V]

        Returns:
            logits: Output logits [B, T_trg-1, V]
        """
        # Encode
        enc_out, (h, c) = self.encoder(src)

        # Create encoder padding mask for attention
        enc_mask = (src != self.pad_idx)

        # Initialize decoder state
        h0, c0 = self._init_decoder_state(h, c)

        # Expand vocab mask to batch size if needed
        if vocab_mask is not None:
            if vocab_mask.dim() == 1:
                vocab_mask = vocab_mask.unsqueeze(0).expand(src.size(0), -1)
            vocab_mask = vocab_mask.to(src.device)

        # Decode
        logits = self.decoder(trg, enc_out, (h0, c0),
                             enc_mask=enc_mask, vocab_mask=vocab_mask)

        return logits

    @torch.no_grad()
    def greedy_decode(self, src, bos_id, eos_id, max_len=128, vocab_mask=None):
        """
        Greedy decoding for inference.

        Args:
            src: Source sequence [B, T_src]
            bos_id: BOS token ID
            eos_id: EOS token ID
            max_len: Maximum generation length
            vocab_mask: Vocabulary mask [B, V] or [V]

        Returns:
            Generated token IDs [B, T_gen]
        """
        self.eval()

        # Encode
        enc_out, (h, c) = self.encoder(src)
        enc_mask = (src != self.pad_idx)

        # Initialize decoder state
        h0, c0 = self._init_decoder_state(h, c)

        # Expand vocab mask
        if vocab_mask is not None:
            if vocab_mask.dim() == 1:
                vocab_mask = vocab_mask.unsqueeze(0).expand(src.size(0), -1)
            vocab_mask = vocab_mask.to(src.device)

        # Start with BOS token
        B = src.size(0)
        y = torch.full((B,), bos_id, dtype=torch.long, device=src.device)

        out_ids = []
        hidden = (h0, c0)

        # Generate tokens one by one
        for _ in range(max_len):
            _, y, hidden = self.decoder.step(
                y, enc_out, hidden,
                enc_mask=enc_mask,
                vocab_mask=vocab_mask
            )
            out_ids.append(y.clone())

            # Stop if all sequences generated EOS
            if torch.all(y == eos_id):
                break

        return torch.stack(out_ids, dim=1) if out_ids else \
               torch.zeros(B, 0, dtype=torch.long, device=src.device)

print("✅ BiLSTMSeq2SQL defined")
print(f"   - Vocab size will be: {len(vocab)}")
print(f"   - Embedding dim: {EMB_DIM}")
print(f"   - Hidden dim: {HID_DIM}")
print(f"   - Encoder layers: {NUM_LAYERS}")


# ============================================================
# Instantiate Model
# ============================================================

print("\n" + "-" * 60)
print("Instantiating Model...")
print("-" * 60)

model = BiLSTMSeq2SQL(
    vocab_size=len(vocab),
    pad_idx=0,
    emb_dim=EMB_DIM,
    hid_dim=HID_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
)

model = model.to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Statistics:")
print(f"  Total parameters:      {total_params:,}")
print(f"  Trainable parameters:  {trainable_params:,}")
print(f"  Model size (approx):   {total_params * 4 / 1024 / 1024:.2f} MB")

print("\n✅ Model architecture complete!")


6.4: BiLSTM Seq2Seq Model
------------------------------------------------------------
✅ BiLSTMSeq2SQL defined
   - Vocab size will be: 7468
   - Embedding dim: 512
   - Hidden dim: 512
   - Encoder layers: 2

------------------------------------------------------------
Instantiating Model...
------------------------------------------------------------

Model Statistics:
  Total parameters:      22,393,132
  Trainable parameters:  22,393,132
  Model size (approx):   85.42 MB

✅ Model architecture complete!


### Pretraining Diagnostic Check

In [11]:
print("\n" + "=" * 60)
print("PRE-TRAINING DIAGNOSTIC")
print("=" * 60)

# Check if validation loader is working
print(f"Validation loader batches: {len(val_loader)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Try one validation batch
val_src, val_tgt, val_schemas = next(iter(val_loader))
print(f"Val source shape: {val_src.shape}")
print(f"Val target shape: {val_tgt.shape}")
print(f"Val source max: {val_src.max().item()}")
print(f"Val target max: {val_tgt.max().item()}")

# Check for NaN or extreme values
print(f"Val source has NaN: {torch.isnan(val_src).any()}")
print(f"Val target has NaN: {torch.isnan(val_tgt).any()}")

print("=" * 60)


PRE-TRAINING DIAGNOSTIC
Validation loader batches: 33
Validation dataset size: 1034
Val source shape: torch.Size([32, 113])
Val target shape: torch.Size([32, 47])
Val source max: 6762
Val target max: 6764
Val source has NaN: False
Val target has NaN: False


## Training the Model

In [12]:
print("=" * 60)
print("TRAINING THE MODEL")
print("=" * 60)

# ============================================================
# Setup Training Components
# ============================================================

print("\nSetting up training components...")
print("-" * 60)

# Loss function (ignore padding)
criterion = nn.CrossEntropyLoss(ignore_index=0)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler (cosine annealing)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS
)

# Special token IDs
BOS_ID = vocab.stoi[VocabTok.BOS]
EOS_ID = vocab.stoi[VocabTok.EOS]

print(f"✅ Loss function: CrossEntropyLoss (ignore_index=0)")
print(f"✅ Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"✅ Scheduler: CosineAnnealingLR")
print(f"✅ BOS token ID: {BOS_ID}")
print(f"✅ EOS token ID: {EOS_ID}")


# ============================================================
# Training Function
# ============================================================

def run_epoch(data_loader, train_mode=True):
    """
    Run one epoch of training or validation.

    Args:
        data_loader: DataLoader to iterate over
        train_mode: If True, update weights. If False, just evaluate.

    Returns:
        Average loss for the epoch
    """
    model.train(train_mode)

    total_loss = 0.0
    num_batches = 0

    for batch_idx, (src, tgt, schemas) in enumerate(data_loader):
        # Move to device
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # Build vocabulary mask for each example in batch
        vocab_masks = []
        for schema in schemas:
            mask = build_vocab_mask_for_example(vocab, schema)
            vocab_masks.append(mask)

        vocab_mask_batch = torch.stack(vocab_masks, dim=0).to(DEVICE)  # [B, V]

        # During training, expand mask to include gold tokens (for teacher forcing)
        if train_mode:
            vocab_mask_batch = expand_mask_with_gold(vocab_mask_batch, tgt, pad_idx=0)

        # Forward pass
        # Model returns logits aligned with tgt[:,1:] (predict next token)
        logits = model(src, tgt, vocab_mask=vocab_mask_batch)  # [B, T-1, V]

        # Compute loss
        # Target: tgt[:,1:] (everything after BOS)
        gold = tgt[:, 1:].contiguous().view(-1)  # [B*(T-1)]
        loss = criterion(logits.view(-1, len(vocab)), gold)

        # Check for NaN loss
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"WARNING: NaN/Inf loss detected at batch {batch_idx}")
            continue

        # Backward pass and optimization (only in train mode)
        if train_mode:
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

            optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Print progress for very first batch
        if batch_idx == 0 and num_batches == 1:
            print(f"  First batch loss: {loss.item():.4f}")

    avg_loss = total_loss / max(1, num_batches)

    # Sanity check
    if avg_loss > 1000:
        print(f"  WARNING: Unusually high average loss: {avg_loss:.2f}")

    return avg_loss

# ============================================================
# Training Loop
# ============================================================

print("\n" + "=" * 60)
print("STARTING TRAINING")
print("=" * 60)

best_val_loss = float('inf')
training_history = {
    'train_loss': [],
    'val_loss': [],
    'learning_rates': []
}

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()

    # Train
    train_loss = run_epoch(train_loader, train_mode=True)

    # Validate
    val_loss = run_epoch(val_loader, train_mode=False)

    # Step scheduler
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]

    # Record history
    training_history['train_loss'].append(train_loss)
    training_history['val_loss'].append(val_loss)
    training_history['learning_rates'].append(current_lr)

    epoch_time = (time.time() - epoch_start_time) / 60

    # ============================================================
    # Preview Generation
    # ============================================================

    with torch.no_grad():
        # Get one batch from validation set for preview
        preview_src, preview_tgt, preview_schemas = next(iter(val_loader))
        preview_src = preview_src[:1].to(DEVICE)  # Take first example
        preview_schema = preview_schemas[0]

        # Build vocab mask
        preview_mask = build_vocab_mask_for_example(vocab, preview_schema)
        preview_mask = preview_mask.unsqueeze(0).to(DEVICE)

        # Generate
        generated_ids = model.greedy_decode(
            preview_src,
            BOS_ID,
            EOS_ID,
            max_len=MAX_DECODE_PREVIEW,
            vocab_mask=preview_mask
        )

        # Decode
        if generated_ids.numel() > 0:
            pred_sql = vocab.decode(generated_ids[0].tolist())
        else:
            pred_sql = "<empty>"

    # ============================================================
    # Logging
    # ============================================================

    print(f"\n[Epoch {epoch:02d}/{EPOCHS}]")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  LR:         {current_lr:.2e}")
    print(f"  Time:       {epoch_time:.2f} min")
    print(f"  Preview SQL: {pred_sql}")

    # ============================================================
    # Save Best Model
    # ============================================================

    if val_loss < best_val_loss:
        best_val_loss = val_loss

        # Save model checkpoint
        checkpoint = {
            'model': model.state_dict(),
            'vocab': vocab.to_json(),
            'epoch': epoch,
            'val_loss': val_loss,
            'train_loss': train_loss,
            'config': {
                'vocab_size': len(vocab),
                'emb_dim': EMB_DIM,
                'hid_dim': HID_DIM,
                'num_layers': NUM_LAYERS,
                'dropout': DROPOUT
            }
        }

        torch.save(checkpoint, SAVE_DIR / "bilstm_model.pt")

        # Also save vocab separately for easy loading
        with open(SAVE_DIR / "bilstm_vocab.json", "w", encoding="utf-8") as f:
            json.dump(vocab.to_json(), f, ensure_ascii=False, indent=2)

        print(f"  ✅ Saved checkpoint (best val loss: {best_val_loss:.4f})")


# ============================================================
# Training Complete
# ============================================================

print("\n" + "=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)
print(f"\nBest validation loss: {best_val_loss:.4f}")
print(f"Model saved to: {SAVE_DIR / 'bilstm_model.pt'}")
print(f"Vocab saved to: {SAVE_DIR / 'bilstm_vocab.json'}")

# Save training history
history_path = SAVE_DIR / "training_history.json"
with open(history_path, "w") as f:
    json.dump(training_history, f, indent=2)

print(f"Training history saved to: {history_path}")

TRAINING THE MODEL

Setting up training components...
------------------------------------------------------------
✅ Loss function: CrossEntropyLoss (ignore_index=0)
✅ Optimizer: AdamW (lr=0.0001)
✅ Scheduler: CosineAnnealingLR
✅ BOS token ID: 2
✅ EOS token ID: 3

STARTING TRAINING
  First batch loss: 5.0786
  First batch loss: 11.2567

[Epoch 01/15]
  Train Loss: 2.2665
  Val Loss:   19.2808
  LR:         9.89e-05
  Time:       6.44 min
  Preview SQL: SELECT count ( * ) FROM stadium ; <eos>
  ✅ Saved checkpoint (best val loss: 19.2808)
  First batch loss: 1.2082
  First batch loss: 11.0046

[Epoch 02/15]
  Train Loss: 1.0243
  Val Loss:   19.1156
  LR:         9.57e-05
  Time:       6.42 min
  Preview SQL: SELECT count ( * ) FROM stadium ; <eos>
  ✅ Saved checkpoint (best val loss: 19.1156)
  First batch loss: 0.7941
  First batch loss: 10.9505

[Epoch 03/15]
  Train Loss: 0.8213
  Val Loss:   19.1011
  LR:         9.05e-05
  Time:       6.43 min
  Preview SQL: SELECT song_name FROM s

## Evaluation

In [14]:
# ============================================================
# Section 8: Evaluation on hospital_1 Test Set
# ============================================================

print("=" * 60)
print("EVALUATION ON HOSPITAL_1 TEST SET")
print("=" * 60)

# ============================================================
# SQL Utilities
# ============================================================

print("\nSetting up evaluation utilities...")
print("-" * 60)

def canonical_sql(sql_text):
    """
    Normalize SQL to canonical form using sqlglot.

    Args:
        sql_text: SQL query string

    Returns:
        Normalized SQL string, or None if parsing fails
    """
    if not sql_text:
        return None
    try:
        ast = parse_one(sql_text, read="sqlite")
        return ast.sql(dialect="sqlite", pretty=False)
    except Exception:
        return None


def try_execute(conn, sql_text):
    """
    Execute SQL query and return result set.

    Args:
        conn: SQLite connection
        sql_text: SQL query string

    Returns:
        (result_set, error):
            - result_set: Set of tuples (rows), or None if error
            - error: Error message string, or None if success
    """
    try:
        cur = conn.execute(sql_text)
        rows = cur.fetchall()

        # Normalize floats to 6 decimal places
        normalized = []
        for row in rows:
            norm_row = []
            for val in row:
                if isinstance(val, float):
                    norm_row.append(round(val, 6))
                else:
                    norm_row.append(val)
            normalized.append(tuple(norm_row))

        return set(normalized), None

    except Exception as e:
        return None, str(e)


def extract_sql_from_tokens(tok, token_ids):
    """
    Convert token IDs back to SQL string.

    Args:
        tok: VocabTok instance
        token_ids: List of token IDs

    Returns:
        SQL string
    """
    # Decode tokens
    text = " ".join(tok.itos[i] for i in token_ids if i < len(tok.itos))
    text = text.strip()

    # Clean up special tokens
    text = text.replace("<bos>", "").replace("<eos>", "").replace("<pad>", "")
    text = text.strip()

    # Ensure semicolon at end
    if ";" in text:
        text = text.split(";", 1)[0] + ";"

    return text


print("✅ SQL utilities defined")


# ============================================================
# Load Trained Model
# ============================================================

print("\nLoading trained model...")
print("-" * 60)

# Load checkpoint
checkpoint_path = SAVE_DIR / "bilstm_model.pt"
vocab_path = SAVE_DIR / "bilstm_vocab.json"

if not checkpoint_path.exists():
    print(f"❌ ERROR: Model checkpoint not found at {checkpoint_path}")
    print("Please complete training first!")
else:
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    # Extract config
    config = checkpoint.get('config', {
        'vocab_size': len(vocab),
        'emb_dim': EMB_DIM,
        'hid_dim': HID_DIM,
        'num_layers': NUM_LAYERS,
        'dropout': DROPOUT
    })

    print(f"✅ Loaded checkpoint from epoch {checkpoint.get('epoch', '?')}")
    print(f"   Train loss: {checkpoint.get('train_loss', 0):.4f}")
    print(f"   Val loss:   {checkpoint.get('val_loss', 0):.4f}")

    # Recreate model with saved config
    eval_model = BiLSTMSeq2SQL(
        vocab_size=config['vocab_size'],
        pad_idx=0,
        emb_dim=config['emb_dim'],
        hid_dim=config['hid_dim'],
        num_layers=config['num_layers'],
        dropout=config['dropout']
    )

    # Load weights
    eval_model.load_state_dict(checkpoint['model'])
    eval_model = eval_model.to(DEVICE)
    eval_model.eval()

    print(f"✅ Model loaded and ready for evaluation")


# ============================================================
# Connect to hospital_1 Database
# ============================================================

print("\nConnecting to database...")
print("-" * 60)

conn = sqlite3.connect(str(SQLITE_DB))
conn.execute("PRAGMA foreign_keys=ON")

print(f"✅ Connected to: {SQLITE_DB}")


# ============================================================
# Evaluation Loop
# ============================================================

print("\n" + "=" * 60)
print("RUNNING EVALUATION")
print("=" * 60)

results = []
n_examples = len(test_dataset)

em_count = 0
ex_count = 0
valid_count = 0
latencies = []

print(f"\nEvaluating on {n_examples} examples from hospital_1...")
print("-" * 60)

for i, (src, tgt, schemas) in enumerate(test_loader, 1):  # CHANGED: schema -> schemas
    # Move to device
    src = src.to(DEVICE)

    # Get example data
    test_ex = test_rows[i - 1]
    question = test_ex['question']
    gold_sql = test_ex['gold_query']

    # Extract schema string from list (ADDED)
    schema = schemas[0] if isinstance(schemas, (list, tuple)) else schemas

    # Build vocabulary mask for this example
    vocab_mask = build_vocab_mask_for_example(vocab, schema)
    vocab_mask = vocab_mask.unsqueeze(0).to(DEVICE)

    # Generate SQL
    start_time = time.time()

    with torch.no_grad():
        generated_ids = eval_model.greedy_decode(
            src,
            BOS_ID,
            EOS_ID,
            max_len=MAX_DECODE_LEN,
            vocab_mask=vocab_mask
        )

    gen_time_ms = (time.time() - start_time) * 1000.0
    latencies.append(gen_time_ms)

    # Extract SQL from generated tokens
    if generated_ids.numel() > 0:
        pred_ids = generated_ids[0].tolist()
        pred_sql_raw = extract_sql_from_tokens(vocab, pred_ids)
    else:
        pred_sql_raw = ""

    # Normalize both predicted and gold SQL
    pred_sql_norm = canonical_sql(pred_sql_raw)
    gold_sql_norm = canonical_sql(gold_sql)

    # ============================================================
    # Compute Metrics
    # ============================================================

    # Exact Match (EM)
    em = int(
        pred_sql_norm is not None and
        gold_sql_norm is not None and
        pred_sql_norm == gold_sql_norm
    )

    # Execution Accuracy (EX) and Valid SQL
    valid = 0
    ex_ok = 0
    error = None

    if pred_sql_norm is not None:
        # Try to execute predicted SQL
        pred_rows, error = try_execute(conn, pred_sql_norm)

        if pred_rows is not None:
            valid = 1  # SQL is valid (executed without error)

            # Execute gold SQL
            gold_rows, gold_error = try_execute(conn, gold_sql_norm or gold_sql)

            if gold_rows is not None:
                # Compare result sets
                ex_ok = int(pred_rows == gold_rows)
            else:
                error = f"Gold SQL failed: {gold_error}"
    else:
        error = "ParseError: Could not parse predicted SQL"

    # Update counters
    em_count += em
    ex_count += ex_ok
    valid_count += valid

    # Store result
    results.append({
        "id": test_ex["id"],
        "question": question,
        "gold_sql": gold_sql,
        "pred_sql_raw": pred_sql_raw,
        "pred_sql_norm": pred_sql_norm or "",
        "em": em,
        "ex": ex_ok,
        "valid_sql": valid,
        "latency_ms": round(gen_time_ms, 2),
        "error": error or ""
    })

    # Progress update
    if i % 10 == 0 or i == n_examples:
        print(f"[{i}/{n_examples}] EM={em_count/i:.3f} EX={ex_count/i:.3f} Valid={valid_count/i:.3f}")


# ============================================================
# Save Results
# ============================================================

print("\nSaving results...")
print("-" * 60)

with open(RESULTS_CSV, "w", newline="", encoding="utf-8") as f:
    if results:
        writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        writer.writeheader()
        writer.writerows(results)

print(f"✅ Results saved to: {RESULTS_CSV}")


# ============================================================
# Summary Statistics
# ============================================================

em_rate = em_count / n_examples
ex_rate = ex_count / n_examples
valid_rate = valid_count / n_examples
median_latency = sorted(latencies)[len(latencies) // 2] if latencies else 0

print("\n" + "=" * 60)
print("EVALUATION SUMMARY")
print("=" * 60)
print(f"\nModel: BiLSTM Encoder-Decoder")
print(f"Test Dataset: hospital_1")
print(f"Examples: {n_examples}")
print(f"\nMetrics:")
print(f"  Exact Match (EM):        {em_rate:.3%}")
print(f"  Execution Accuracy (EX): {ex_rate:.3%}")
print(f"  Valid-SQL rate:          {valid_rate:.3%}")
print(f"\nPerformance:")
print(f"  Median generation time:  {median_latency:.1f} ms")
print(f"\nResults saved to: {RESULTS_CSV}")
print("=" * 60)

# Close database connection
conn.close()

print("\n" + "=" * 60)
print("Section 8 Complete!")
print("=" * 60)

EVALUATION ON HOSPITAL_1 TEST SET

Setting up evaluation utilities...
------------------------------------------------------------
✅ SQL utilities defined

Loading trained model...
------------------------------------------------------------
✅ Loaded checkpoint from epoch 3
   Train loss: 0.8213
   Val loss:   19.1011
✅ Model loaded and ready for evaluation

Connecting to database...
------------------------------------------------------------
✅ Connected to: hospital_1.sqlite

RUNNING EVALUATION

Evaluating on 100 examples from hospital_1...
------------------------------------------------------------
[10/100] EM=0.000 EX=0.000 Valid=0.000
[20/100] EM=0.000 EX=0.000 Valid=0.000
[30/100] EM=0.000 EX=0.000 Valid=0.000
[40/100] EM=0.000 EX=0.000 Valid=0.000
[50/100] EM=0.000 EX=0.000 Valid=0.000
[60/100] EM=0.000 EX=0.000 Valid=0.000
[70/100] EM=0.000 EX=0.000 Valid=0.000
[80/100] EM=0.000 EX=0.000 Valid=0.000
[90/100] EM=0.000 EX=0.000 Valid=0.000
[100/100] EM=0.000 EX=0.000 Valid=0.000