# Healthcare SQL Agent - Model-First MVP

## End-to-End Training and Inference Pipeline

This notebook implements a complete pipeline with **validation-focused improvements**:

1. Generate a synthetic T-SQL training dataset with an **ID vault** (extracts ALL identifiers)
2. Train a BPE tokenizer with **special tokens** (placeholders, sentinel, keywords)
3. Train a tiny decoder-only model with **loss masking** (only SQL tokens supervised)
4. Run inference with **robust SQL extraction** and **sentinel-based stopping**
5. Validate with **strict rules** (SELECT-only, placeholder integrity, schema checks)
6. Evaluate on 100 questions to measure the **pass rate** (target: ≥90%)

### Key Improvements:

- **ID Vault**: Extracts patient IDs, department IDs, provider IDs, years (1900-2100), and dates
- **SQL Extraction**: Handles a model that echoes the prompt by extracting only the SQL statement
- **Sentinel Token**: The `</SQL>` marker stops generation and helps parse the output
- **Loss Masking**: Training loss is computed only on SQL tokens (not on the prompt/schema)
- **Special Tokens**: Placeholders (`__ID_1__`, `__YEAR_1__`), the sentinel, and SQL keywords are preserved
- **Enhanced Validation**: Checks placeholder integrity, forbids DML/DDL, enforces SELECT-only

## 1. Install Dependencies

In [None]:
!pip install -q torch transformers tokenizers datasets tqdm jsonlines safetensors accelerate

## 2. Imports and Setup

In [None]:
import json
import random
import re
import os
from pathlib import Path
from typing import List, Dict, Tuple
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm import tqdm
import math

# Set random seeds
random.seed(42)
torch.manual_seed(42)

# Create directories
os.makedirs('data', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('artifacts/tokenizer', exist_ok=True)

print("Setup complete!")

## 3. Example Schema Definition

In [None]:
# Define example healthcare schema
EXAMPLE_SCHEMA = {
    "schema_name": "healthcare_analytics",
    "tables": [
        {
            "name": "Patients",
            "columns": [
                {"name": "PatientID", "type": "INT", "pk": True},
                {"name": "FirstName", "type": "VARCHAR(50)"},
                {"name": "LastName", "type": "VARCHAR(50)"},
                {"name": "DateOfBirth", "type": "DATE"},
                {"name": "Gender", "type": "VARCHAR(10)"},
                {"name": "InsuranceProvider", "type": "VARCHAR(100)"}
            ]
        },
        {
            "name": "Visits",
            "columns": [
                {"name": "VisitID", "type": "INT", "pk": True},
                {"name": "PatientID", "type": "INT", "fk": "Patients.PatientID"},
                {"name": "VisitDate", "type": "DATE"},
                {"name": "DepartmentID", "type": "INT"},
                {"name": "ProviderID", "type": "INT"},
                {"name": "VisitType", "type": "VARCHAR(50)"},
                {"name": "TotalCharge", "type": "DECIMAL(10,2)"}
            ]
        },
        {
            "name": "Departments",
            "columns": [
                {"name": "DepartmentID", "type": "INT", "pk": True},
                {"name": "DepartmentName", "type": "VARCHAR(100)"},
                {"name": "Location", "type": "VARCHAR(100)"}
            ]
        },
        {
            "name": "Providers",
            "columns": [
                {"name": "ProviderID", "type": "INT", "pk": True},
                {"name": "ProviderName", "type": "VARCHAR(100)"},
                {"name": "Specialty", "type": "VARCHAR(100)"},
                {"name": "DepartmentID", "type": "INT", "fk": "Departments.DepartmentID"}
            ]
        },
        {
            "name": "Diagnoses",
            "columns": [
                {"name": "DiagnosisID", "type": "INT", "pk": True},
                {"name": "VisitID", "type": "INT", "fk": "Visits.VisitID"},
                {"name": "ICDCode", "type": "VARCHAR(20)"},
                {"name": "DiagnosisDescription", "type": "VARCHAR(255)"}
            ]
        }
    ]
}

# Save schema
with open('data/example_schema.json', 'w') as f:
    json.dump(EXAMPLE_SCHEMA, f, indent=2)

print("Schema created with", len(EXAMPLE_SCHEMA['tables']), "tables")
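For prompting, this nested dict is later flattened into a single line of `Table ( col, col )` groups by `generate_schema_compact_text` in section 4. The sketch below shows that flattening as a standalone function on a small two-table subset. The names `compact_schema_text` and `mini_schema` are illustrative and not part of the pipeline:

```python
from typing import Dict, List


def compact_schema_text(schema: Dict) -> str:
    """Flatten a schema dict into 'Table ( col1, col2 ); Table2 ( ... )'."""
    parts: List[str] = []
    for table in schema["tables"]:
        cols = ", ".join(c["name"] for c in table["columns"])
        parts.append(f"{table['name']} ( {cols} )")
    return "; ".join(parts)


# Hypothetical two-table subset of EXAMPLE_SCHEMA (types and keys omitted)
mini_schema = {
    "schema_name": "demo",
    "tables": [
        {"name": "Patients", "columns": [{"name": "PatientID"}, {"name": "LastName"}]},
        {"name": "Visits", "columns": [{"name": "VisitID"}, {"name": "PatientID"}]},
    ],
}

print(compact_schema_text(mini_schema))
# Patients ( PatientID, LastName ); Visits ( VisitID, PatientID )
```

The flattening drops column types and the `pk`/`fk` annotations. This keeps prompts short, but it also means the model never sees which columns are foreign keys.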

## 4. Dataset Generator

**IMPROVED**: Enhanced with:
- **ID Vault**: `extract_placeholders()` now extracts ALL numeric identifiers (1-6 digits)
- **Year Detection**: Recognizes years (1900-2100) based on context ("in 2023", "during 2022")
- **Date Extraction**: Handles ISO dates (YYYY-MM-DD)
- **SQL Extraction**: `extract_sql_from_completion()` extracts clean SQL from model output
- **Sentinel Token**: Appends `</SQL>` to all training samples for reliable stopping

In [None]:
def generate_schema_compact_text(schema: Dict) -> str:
    """Generate compact schema text for prompts."""
    lines = []
    for table in schema['tables']:
        cols = ', '.join([c['name'] for c in table['columns']])
        lines.append(f"{table['name']} ( {cols} )")
    return '; '.join(lines)


def extract_placeholders(question: str) -> Tuple[str, Dict[str, str]]:
    """Extract IDs, years, and dates, replace with placeholders."""
    id_map = {}
    id_counter = 1
    year_counter = 1
    date_counter = 1

    # First pass: Extract ISO dates (YYYY-MM-DD)
    def replace_date(match):
        nonlocal date_counter
        value = match.group(0)
        placeholder = f"__DATE_{date_counter}__"
        id_map[placeholder] = value
        date_counter += 1
        return placeholder

    question_clean = re.sub(r'\b\d{4}-\d{2}-\d{2}\b', replace_date, question)

    # Second pass: Extract years (1900-2100) and all other 1-6 digit numbers.
    # A number counts as a year only when the word immediately before it is a
    # year cue ("in 2023", "during 2022"). A wider context window would mislabel
    # IDs, e.g. "charge for patient 2023" or "patient 2023 have in department 25".
    year_cue = re.compile(r'\b(?:in|during|for|since|until|year)\s+$', re.IGNORECASE)

    def replace_year_or_id(match):
        nonlocal id_counter, year_counter
        value = match.group(0)
        num = int(value)

        # question_clean still holds the pre-substitution string while re.sub runs
        context_before = question_clean[max(0, match.start() - 20):match.start()]
        is_year_context = year_cue.search(context_before) is not None

        if 1900 <= num <= 2100 and is_year_context:
            placeholder = f"__YEAR_{year_counter}__"
            year_counter += 1
        else:
            placeholder = f"__ID_{id_counter}__"
            id_counter += 1

        id_map[placeholder] = value
        return placeholder

    # Match 1-6 digit numbers
    question_clean = re.sub(r'\b\d{1,6}\b', replace_year_or_id, question_clean)

    return question_clean, id_map


def extract_sql_from_completion(completion: str) -> str:
    """Extract only the SQL statement from model completion."""
    # Find first SELECT (case-insensitive)
    upper_completion = completion.upper()
    select_pos = upper_completion.find('SELECT')

    if select_pos == -1:
        # No SELECT found, return as-is for validation to fail
        return completion.strip()

    # Extract from SELECT onwards
    sql_part = completion[select_pos:]

    # Look for sentinel </SQL> first
    if '</SQL>' in sql_part:
        end_pos = sql_part.find('</SQL>')
        sql_part = sql_part[:end_pos]
    else:
        # Look for first semicolon
        semicolon_pos = sql_part.find(';')
        if semicolon_pos != -1:
            sql_part = sql_part[:semicolon_pos + 1]

    # Ensure it ends with semicolon (strip first so no whitespace precedes it)
    sql_part = sql_part.strip()
    if not sql_part.endswith(';'):
        sql_part = sql_part + ';'

    return sql_part


def generate_dataset_samples(schema: Dict, num_samples: int) -> List[Dict]:
    """Generate synthetic SQL training samples."""
    samples = []
    schema_text = generate_schema_compact_text(schema)

    templates = [
        # COUNT templates
        {
            "question": "How many visits did patient {patient_id} have in department {dept_id}?",
            "sql": "SELECT COUNT(*) FROM Visits WHERE PatientID = {patient_id_ph} AND DepartmentID = {dept_id_ph};"
        },
        {
            "question": "Count total visits in {year}",
            "sql": "SELECT COUNT(*) FROM Visits WHERE YEAR(VisitDate) = {year_ph};"
        },
        # SUM templates
        {
            "question": "What is the total charge for patient {patient_id}?",
            "sql": "SELECT SUM(TotalCharge) FROM Visits WHERE PatientID = {patient_id_ph};"
        },
        {
            "question": "Total charges for department {dept_id} in {year}",
            "sql": "SELECT SUM(TotalCharge) FROM Visits WHERE DepartmentID = {dept_id_ph} AND YEAR(VisitDate) = {year_ph};"
        },
        # AVG templates
        {
            "question": "What is the average charge per visit for provider {provider_id}?",
            "sql": "SELECT AVG(TotalCharge) FROM Visits WHERE ProviderID = {provider_id_ph};"
        },
        # GROUP BY templates
        {
            "question": "Show visit counts by department for patient {patient_id}",
            "sql": "SELECT DepartmentID, COUNT(*) FROM Visits WHERE PatientID = {patient_id_ph} GROUP BY DepartmentID;"
        },
        {
            "question": "Show monthly visit counts in {year}",
            "sql": "SELECT MONTH(VisitDate) AS Month, COUNT(*) FROM Visits WHERE YEAR(VisitDate) = {year_ph} GROUP BY MONTH(VisitDate);"
        },
        {
            "question": "Total charges by provider in department {dept_id}",
            "sql": "SELECT ProviderID, SUM(TotalCharge) FROM Visits WHERE DepartmentID = {dept_id_ph} GROUP BY ProviderID;"
        },
        # JOIN templates
        {
            "question": "List all visits with patient names for patient {patient_id}",
            "sql": "SELECT V.VisitID, P.FirstName, P.LastName, V.VisitDate FROM Visits V JOIN Patients P ON V.PatientID = P.PatientID WHERE V.PatientID = {patient_id_ph};"
        },
        {
            "question": "Show visit counts by department name",
            "sql": "SELECT D.DepartmentName, COUNT(V.VisitID) FROM Visits V JOIN Departments D ON V.DepartmentID = D.DepartmentID GROUP BY D.DepartmentName;"
        },
    ]

    for i in range(num_samples):
        template = random.choice(templates)

        # Generate random IDs
        patient_id = random.randint(1000, 9999)
        dept_id = random.randint(10, 99)
        provider_id = random.randint(100, 999)
        year = random.randint(2020, 2024)

        # Fill template
        question = template['question'].format(
            patient_id=patient_id,
            dept_id=dept_id,
            provider_id=provider_id,
            year=year
        )

        # Extract placeholders (now handles all IDs properly)
        question_clean, id_map = extract_placeholders(question)

        # Build SQL with placeholders - need to match the order in id_map
        # The id_map keys are ordered by appearance in the question
        id_keys = list(id_map.keys())

        # Map template variables to placeholder positions
        sql_template = template['sql']
        if '{patient_id_ph}' in sql_template and '{dept_id_ph}' in sql_template:
            # Two IDs: patient first, then department
            sql = sql_template.format(patient_id_ph=id_keys[0], dept_id_ph=id_keys[1])
        elif '{dept_id_ph}' in sql_template and '{year_ph}' in sql_template:
            # Department and year
            sql = sql_template.format(dept_id_ph=id_keys[0], year_ph=id_keys[1])
        elif '{patient_id_ph}' in sql_template:
            sql = sql_template.format(patient_id_ph=id_keys[0])
        elif '{dept_id_ph}' in sql_template:
            sql = sql_template.format(dept_id_ph=id_keys[0])
        elif '{provider_id_ph}' in sql_template:
            sql = sql_template.format(provider_id_ph=id_keys[0])
        elif '{year_ph}' in sql_template:
            sql = sql_template.format(year_ph=id_keys[0])
        else:
            # No placeholders needed (e.g., "Show visit counts by department name")
            sql = sql_template

        # Add sentinel token
        sql = sql.strip()
        if not sql.endswith(';'):
            sql = sql + ';'
        sql = sql + ' </SQL>'

        # Create ID map text
        id_map_text = ' , '.join([f"{k} = {v}" for k, v in id_map.items()]) if id_map else "None"

        # Create training record
        record = {
            "schema_id": schema['schema_name'],
            "schema_text": schema_text,
            "question": question_clean,
            "id_map": id_map_text,
            "sql": sql
        }

        samples.append(record)

    return samples


# Generate datasets
print("Generating training dataset...")
train_samples = generate_dataset_samples(EXAMPLE_SCHEMA, 5000)
print("Generating validation dataset...")
val_samples = generate_dataset_samples(EXAMPLE_SCHEMA, 200)

# Save datasets
with open('data/train.jsonl', 'w') as f:
    for sample in train_samples:
        f.write(json.dumps(sample) + '\n')

with open('data/val.jsonl', 'w') as f:
    for sample in val_samples:
        f.write(json.dumps(sample) + '\n')

print(f"Generated {len(train_samples)} training samples")
print(f"Generated {len(val_samples)} validation samples")
print("\nExample sample:")
print(json.dumps(train_samples[0], indent=2))

print("\nID extraction test:")
test_q = "How many visits did patient 5432 have in department 25?"
clean, id_map = extract_placeholders(test_q)
print(f"Original: {test_q}")
print(f"Clean: {clean}")
print(f"ID Map: {id_map}")
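The ID vault is a reversible substitution. Numbers leave the question as placeholders and return only after the SQL has been validated. The sketch below shows that round trip as a standalone example, assuming the same `__ID_n__` / `__YEAR_n__` formats. `vault_numbers` and `reinject` are hypothetical helper names, and dates are omitted for brevity. A number is treated as a year only when the word immediately before it is a year cue, so `patient 2023` remains an ID:

```python
import re
from typing import Dict, Tuple

# The year cue must be the word immediately before the number ("in 2023", "during 2022")
YEAR_CUE = re.compile(r"\b(?:in|during|for|since|until|year)\s+$", re.IGNORECASE)
NUMBER = re.compile(r"\b\d{1,6}\b")


def vault_numbers(question: str) -> Tuple[str, Dict[str, str]]:
    """Replace 1-6 digit numbers with __ID_n__ / __YEAR_n__ placeholders."""
    id_map: Dict[str, str] = {}
    counters = {"ID": 0, "YEAR": 0}

    def substitute(match):
        value = match.group(0)
        before = question[: match.start()]
        is_year = 1900 <= int(value) <= 2100 and YEAR_CUE.search(before) is not None
        kind = "YEAR" if is_year else "ID"
        counters[kind] += 1
        placeholder = f"__{kind}_{counters[kind]}__"
        id_map[placeholder] = value  # the vault: placeholder -> original value
        return placeholder

    return NUMBER.sub(substitute, question), id_map


def reinject(sql: str, id_map: Dict[str, str]) -> str:
    """Swap placeholders back to their original values after validation."""
    for placeholder, value in id_map.items():
        sql = sql.replace(placeholder, value)
    return sql


clean, vault = vault_numbers("Total charges for department 45 in 2023")
print(clean)  # Total charges for department __ID_1__ in __YEAR_1__
print(vault)  # {'__ID_1__': '45', '__YEAR_1__': '2023'}
```

Because the model only ever sees placeholders, it never needs to copy digits correctly. The real values are substituted back only after the SQL passes validation.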

## 5. Train BPE Tokenizer

**IMPROVED**: Added special tokens:
- **Section markers**: `SCHEMA:`, `QUESTION:`, `ID_MAP:`, `SQL:` (never split)
- **Sentinel**: `</SQL>` (single token for reliable stopping)
- **Placeholders**: `__ID_1__` through `__ID_64__`, `__YEAR_1__` through `__YEAR_8__`, `__DATE_1__` through `__DATE_16__`
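- **SQL Keywords**: SELECT, FROM, WHERE, etc. (prevents splitting common SQL terms). Because these keywords, the placeholders, and `</SQL>` are all registered as special tokens, `tokenizer.decode()` drops them unless it is called with `skip_special_tokens=False`.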

In [None]:
def format_training_text(sample: Dict) -> str:
    """Format sample into training text with markers."""
    return f"""SCHEMA : {sample['schema_text']}
QUESTION : {sample['question']}
ID_MAP : {sample['id_map']}
SQL : {sample['sql']}"""


# Prepare training text for tokenizer
print("Preparing tokenizer training data...")
tokenizer_training_file = 'data/tokenizer_train.txt'
with open(tokenizer_training_file, 'w') as f:
    for sample in train_samples:
        f.write(format_training_text(sample) + '\n\n')

# Build special tokens list
special_tokens = ["[PAD]", "[UNK]", "[BOS]", "[EOS]"]

# Add section markers
special_tokens.extend(["SCHEMA", ":", "SCHEMA:", "QUESTION", "QUESTION:", "ID_MAP", "ID_MAP:", "SQL", "SQL:"])

# Add sentinel
special_tokens.append("</SQL>")

# Add placeholder tokens (ID, DATE, YEAR)
for i in range(1, 65):
    special_tokens.append(f"__ID_{i}__")
for i in range(1, 17):
    special_tokens.append(f"__DATE_{i}__")
for i in range(1, 9):
    special_tokens.append(f"__YEAR_{i}__")

# Add common SQL keywords as special tokens to prevent splitting
sql_keywords = ["SELECT", "FROM", "WHERE", "JOIN", "GROUP", "BY", "ORDER", "COUNT", "SUM", "AVG",
                "MIN", "MAX", "AS", "AND", "OR", "ON", "INNER", "LEFT", "RIGHT", "OUTER"]
special_tokens.extend(sql_keywords)

print(f"Training tokenizer with {len(special_tokens)} special tokens...")

# Train tokenizer
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=8000,
    min_frequency=2,
    special_tokens=special_tokens
)

tokenizer.train([tokenizer_training_file], trainer)

# Save tokenizer
tokenizer.save('artifacts/tokenizer/tokenizer.json')
print("Tokenizer trained and saved!")
print(f"Vocab size: {tokenizer.get_vocab_size()}")

# Test tokenizer
test_text = "SELECT COUNT(*) FROM Visits WHERE PatientID = __ID_1__;"
encoded = tokenizer.encode(test_text)
print(f"\nTest encoding: {test_text}")
print(f"Tokens: {encoded.tokens}")

# Test special tokens
test_text2 = "SQL : SELECT * FROM Visits ; </SQL>"
encoded2 = tokenizer.encode(test_text2)
print(f"\nTest special tokens: {test_text2}")
print(f"Tokens: {encoded2.tokens}")
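The sentinel and placeholders survive as single units because the `tokenizers` library cuts added/special tokens out of the raw text before the `Whitespace` pre-tokenizer splits the remainder with the regex `\w+|[^\w\s]+`. The sketch below is a pure-Python approximation of that two-stage split, not the library's implementation. `toy_pretokenize` is a hypothetical name, and BPE merges are not modeled:

```python
import re
from typing import List

# Regex used by the tokenizers `Whitespace` pre-tokenizer
WHITESPACE_SPLIT = re.compile(r"\w+|[^\w\s]+")


def toy_pretokenize(text: str, specials: List[str]) -> List[str]:
    """Cut out special tokens first, then split the remaining text like Whitespace()."""
    # Try longer alternatives first, so '</SQL>' wins over 'SQL' at the same position
    special_re = re.compile("|".join(re.escape(s) for s in sorted(specials, key=len, reverse=True)))
    pieces: List[str] = []
    pos = 0
    for match in special_re.finditer(text):
        pieces.extend(WHITESPACE_SPLIT.findall(text[pos:match.start()]))
        pieces.append(match.group(0))  # special token kept whole
        pos = match.end()
    pieces.extend(WHITESPACE_SPLIT.findall(text[pos:]))
    return pieces


specials = ["SELECT", "FROM", "WHERE", "COUNT", "SQL", "</SQL>", "__ID_1__", ":"]
print(toy_pretokenize("SELECT COUNT(*) FROM Visits WHERE PatientID = __ID_1__; </SQL>", specials))
# ['SELECT', 'COUNT', '(*)', 'FROM', 'Visits', 'WHERE', 'PatientID', '=', '__ID_1__', ';', '</SQL>']
```

`COUNT(*)` shows the other side of this split. A punctuation run such as `(*)` becomes a single pre-token, and because this tokenizer has no decoder configured, `decode` joins tokens with spaces. Decoded SQL therefore comes back as `COUNT (*)`, which is still valid T-SQL.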

## 6. Define Decoder-Only Transformer Model

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(self, x, mask=None):
        # Self-attention with residual
        attn_out, _ = self.attention(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + attn_out)
        # FFN with residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512, n_layers: int = 8, 
                 n_heads: int = 8, max_seq_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        ])
        
        # Output
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)
        
    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        
        # Create causal mask
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1)
        causal_mask = causal_mask.masked_fill(causal_mask == 1, float('-inf'))
        
        # Embeddings
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x, mask=causal_mask)
        
        x = self.norm(x)
        logits = self.output(x)
        
        return logits

# Model config (tiny for MVP: roughly 26-34M params, depending on the learned vocab size)
MODEL_CONFIG = {
    'vocab_size': tokenizer.get_vocab_size(),
    'd_model': 512,
    'n_layers': 8,
    'n_heads': 8,
    'max_seq_len': 512,
    'dropout': 0.1
}

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DecoderOnlyTransformer(**MODEL_CONFIG).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {total_params:,} parameters ({total_params/1e6:.1f}M)")
print(f"Device: {device}")
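The printed parameter count can be checked by hand from the layer shapes. Each block contains an `nn.MultiheadAttention` with default biases, two LayerNorms, and a 4x FFN, and the input and output embeddings are untied. The sketch below does that arithmetic, assuming the defaults in `MODEL_CONFIG`. `count_params` is a hypothetical helper, and its result should match `sum(p.numel() ...)` for this architecture:

```python
def count_params(vocab_size: int, d_model: int = 512, n_layers: int = 8,
                 max_seq_len: int = 512) -> int:
    """Parameter count of DecoderOnlyTransformer, derived from its layer shapes."""
    d = d_model
    attention = 4 * d * d + 4 * d                 # in_proj (3d x d + 3d) + out_proj (d x d + d)
    layer_norms = 2 * (2 * d)                     # norm1 + norm2, weight and bias each
    ffn = (d * 4 * d + 4 * d) + (4 * d * d + d)   # Linear(d, 4d) + Linear(4d, d)
    per_block = attention + layer_norms + ffn     # = 12d^2 + 13d
    embeddings = vocab_size * d + max_seq_len * d  # token + position tables
    head = 2 * d + (d * vocab_size + vocab_size)   # final LayerNorm + output Linear
    return n_layers * per_block + embeddings + head


print(f"{count_params(8000):,}")  # 33,682,240
```

At the tokenizer's 8000-token ceiling, this gives about 33.7M parameters. Roughly 25.2M of that sits in the eight transformer blocks, so the vocabulary size moves the total by only a few million.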

## 7. Dataset and DataLoader

**IMPROVED**: Implements loss masking:
- Returns `{'input_ids': ..., 'labels': ...}` instead of just input_ids
- Masks prompt tokens (SCHEMA, QUESTION, ID_MAP) with `-100` in labels
- Training loss computed **only on SQL tokens** after `SQL:` marker
- Prevents model from learning to echo the prompt unnecessarily

In [None]:
class SQLDataset(Dataset):
    def __init__(self, jsonl_path: str, tokenizer: Tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []

        # Load samples
        with open(jsonl_path, 'r') as f:
            for line in f:
                self.samples.append(json.loads(line))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Format text
        text = format_training_text(sample)

        # Tokenize
        encoded = self.tokenizer.encode(text)
        tokens = encoded.ids

        # Find the position right after "SQL :" to mask everything before it.
        # Re-using format_training_text with an empty SQL field guarantees the
        # prefix matches the training layout exactly.
        text_before_sql = format_training_text({**sample, 'sql': ''})
        encoded_before = self.tokenizer.encode(text_before_sql)
        mask_until = len(encoded_before.ids)

        # Create labels (for loss computation)
        labels = tokens.copy()

        # Mask prompt tokens (set to -100 so they're ignored in loss)
        for i in range(min(mask_until, len(labels))):
            labels[i] = -100

        # Truncate or pad
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
            labels = labels[:self.max_length]
        else:
            pad_len = self.max_length - len(tokens)
            tokens = tokens + [0] * pad_len  # 0 is [PAD]
            labels = labels + [-100] * pad_len  # Mask padding

        return {
            'input_ids': torch.tensor(tokens, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }


# Create datasets
train_dataset = SQLDataset('data/train.jsonl', tokenizer, max_length=MODEL_CONFIG['max_seq_len'])
val_dataset = SQLDataset('data/val.jsonl', tokenizer, max_length=MODEL_CONFIG['max_seq_len'])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")
print(f"Batches per epoch: {len(train_loader)}")

# Test loss masking
print("\nTesting loss masking...")
sample_batch = next(iter(train_loader))
print(f"Input shape: {sample_batch['input_ids'].shape}")
print(f"Labels shape: {sample_batch['labels'].shape}")
print(f"Masked tokens (label=-100): {(sample_batch['labels'][0] == -100).sum().item()}")
print(f"Unmasked tokens: {(sample_batch['labels'][0] != -100).sum().item()}")
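The masking in `__getitem__` reduces to simple list arithmetic on token IDs. The sketch below shows it in plain Python with small made-up IDs. `build_inputs_and_labels` is a hypothetical helper, and pad ID `0` follows the notebook's `[PAD]` convention:

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # nn.CrossEntropyLoss skips this target value by default


def build_inputs_and_labels(token_ids: List[int], prompt_len: int, max_length: int,
                            pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Mask prompt positions and padding with IGNORE_INDEX; keep SQL tokens as labels."""
    n_masked = min(prompt_len, len(token_ids))
    labels = [IGNORE_INDEX] * n_masked + token_ids[n_masked:]
    inputs, labels = token_ids[:max_length], labels[:max_length]
    pad = max_length - len(inputs)
    return inputs + [pad_id] * pad, labels + [IGNORE_INDEX] * pad


ids, labels = build_inputs_and_labels([5, 6, 7, 8, 9], prompt_len=3, max_length=7)
print(ids)     # [5, 6, 7, 8, 9, 0, 0]
print(labels)  # [-100, -100, -100, 8, 9, -100, -100]

# After the next-token shift, logits at position i are scored against labels[i + 1]
shifted = labels[1:]
print([i for i, t in enumerate(shifted) if t != IGNORE_INDEX])  # [2, 3]
```

The last prompt position (index 2 here, the `SQL :` marker) is supervised. It is the position that learns to emit the first SQL token.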

## 8. Training Loop

**IMPROVED**: Updated to use loss masking:
- Accepts `batch['input_ids']` and `batch['labels']` from dataset
- Uses `ignore_index=-100` in CrossEntropyLoss to skip masked tokens
- Only supervises SQL generation, not prompt reproduction

In [None]:
def train_epoch(model, train_loader, optimizer, device, epoch):
    model.train()
    total_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch_idx, batch in enumerate(pbar):
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        logits = model(input_ids)

        # Shift for next-token prediction
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        # Calculate loss (ignore_index=-100 handles both padding and masked prompt tokens)
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    return total_loss / len(train_loader)


def validate(model, val_loader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids)

            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            total_loss += loss.item()

    return total_loss / len(val_loader)


# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
num_epochs = 3  # Short training for MVP

print("Starting training...\n")

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
    val_loss = validate(model, val_loader, device)

    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'model_config': MODEL_CONFIG
    }
    torch.save(checkpoint, f'checkpoints/checkpoint_epoch_{epoch}.pt')
    print(f"Checkpoint saved: checkpoint_epoch_{epoch}.pt\n")

print("Training complete!")
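With the default `reduction='mean'`, `nn.CrossEntropyLoss(ignore_index=-100)` averages only over targets that are not ignored, so a long masked prompt does not dilute the SQL loss. The sketch below computes the same quantity in pure Python. `masked_cross_entropy` is illustrative and not a PyTorch API. As in PyTorch, a batch in which every target is ignored yields `nan`:

```python
import math
from typing import List


def masked_cross_entropy(logits: List[List[float]], targets: List[int],
                         ignore_index: int = -100) -> float:
    """Mean negative log-likelihood over positions whose target is not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # masked prompt or padding position: contributes nothing
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")


# Three positions: the first is masked, so the mean is taken over the last two only
logits = [[0.0, 0.0], [0.0, 0.0], [10.0, 0.0]]
targets = [-100, 1, 0]
print(masked_cross_entropy(logits, targets))  # ≈ 0.3466
```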

## 9. Inference and Generation

**IMPROVED**: Robust generation with:
- **Sentinel stopping**: Stops when `</SQL>` is generated
- **SQL extraction**: Uses `extract_sql_from_completion()` to extract clean SQL
- **Prompt echo handling**: Strips SCHEMA/QUESTION/ID_MAP if model echoes them
- **Proper formatting**: Matches training format with spacing (e.g., `SCHEMA : ...`)

In [None]:
def generate_sql(model, tokenizer, schema_text, question, id_map_text,
                 max_length=256, temperature=0.1, device='cuda'):
    """Generate SQL from a question using the trained model (greedy decoding)."""
    model.eval()

    # Build prompt with proper spacing to match training format
    prompt = f"""SCHEMA : {schema_text}
QUESTION : {question}
ID_MAP : {id_map_text}
SQL : """

    # Encode prompt
    prompt_ids = tokenizer.encode(prompt).ids
    generated_ids = list(prompt_ids)

    # Look up stop tokens by ID (token_to_id returns None if a token is missing)
    sentinel_id = tokenizer.token_to_id("</SQL>")
    eos_id = tokenizer.token_to_id("[EOS]")

    with torch.no_grad():
        for step in range(max_length):
            # Keep the context within the positional-embedding range
            context = generated_ids[-model.max_seq_len:]
            current_ids = torch.tensor([context], dtype=torch.long, device=device)

            # Forward pass
            logits = model(current_ids)

            # Greedy decoding: temperature scaling does not change the argmax
            next_token_logits = logits[0, -1, :] / temperature
            next_token_id = torch.argmax(next_token_logits).item()
            generated_ids.append(next_token_id)

            # Stop on the sentinel or EOS token
            if next_token_id in (sentinel_id, eos_id):
                break

            # Fallback stop: a semicolon after a reasonable amount of SQL.
            # Decode only the new IDs and keep special tokens (SELECT, __ID_1__, ...),
            # which tokenizer.decode() would otherwise drop.
            completion = tokenizer.decode(generated_ids[len(prompt_ids):], skip_special_tokens=False)
            if ';' in completion and step > 10:
                break

    # Decode the completion (not the prompt) and extract the SQL statement
    completion = tokenizer.decode(generated_ids[len(prompt_ids):], skip_special_tokens=False)
    sql = extract_sql_from_completion(completion)

    return sql


print("Inference function ready!")
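The decoding loop follows a small pattern: score the truncated context, append the argmax, stop on a stop ID, and decode only the new IDs. The model-free sketch below shows that pattern. `greedy_decode`, `toy_scores`, and `SUCCESSOR` are hypothetical stand-ins, with `toy_scores` playing the role of `model(...)[0, -1, :]`:

```python
from typing import Callable, List, Optional, Sequence


def greedy_decode(next_scores: Callable[[List[int]], Sequence[float]],
                  prompt_ids: List[int], stop_ids: Sequence[Optional[int]],
                  max_new_tokens: int = 256, max_context: int = 512) -> List[int]:
    """Append argmax tokens until a stop ID appears; return only the new IDs."""
    ids = list(prompt_ids)
    new_ids: List[int] = []
    for _ in range(max_new_tokens):
        scores = next_scores(ids[-max_context:])  # never exceed the position table
        token = max(range(len(scores)), key=lambda i: scores[i])
        ids.append(token)
        new_ids.append(token)
        if token in stop_ids:
            break
    return new_ids


# Hypothetical stand-in for the model: always favours a fixed successor of the last token
SUCCESSOR = {1: 2, 2: 3, 3: 4, 4: 9}
VOCAB = 10


def toy_scores(context: List[int]) -> List[float]:
    scores = [0.0] * VOCAB
    scores[SUCCESSOR.get(context[-1], 0)] = 1.0
    return scores


print(greedy_decode(toy_scores, prompt_ids=[0, 1], stop_ids=[9]))  # [2, 3, 4, 9]
```

In the notebook, `stop_ids` would be the IDs of `</SQL>` and `[EOS]`, looked up with `tokenizer.token_to_id`. The returned IDs are what would be passed to `tokenizer.decode(..., skip_special_tokens=False)`.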

## 10. Validation Gates

**IMPROVED**: Enhanced validation:
- **Placeholder integrity**: Checks `__ID_*__`, `__YEAR_*__`, `__DATE_*__` against ID map
- **Strict DML/DDL blocking**: Rejects INSERT, UPDATE, DELETE, MERGE, DROP, ALTER, CREATE, TRUNCATE, EXEC, GRANT, REVOKE
- **Single statement**: Must have exactly one semicolon
- **SELECT-only**: Must start with SELECT (case-insensitive)
- **Schema checks**: Checks that at least one known table name appears in the SQL

In [None]:
def validate_sql(sql: str, schema: Dict, id_map: Dict[str, str]) -> Tuple[bool, List[str]]:
    """Validate generated SQL against strict rules."""
    errors = []

    # Clean up SQL
    sql = sql.strip()

    # Rule 1: Must be exactly one statement (count semicolons)
    if sql.count(';') != 1:
        errors.append("Must contain exactly one statement ending with ';'")

    # Rule 2: Must end with semicolon (the message says "semicolon" so the
    # failure categorizer in section 11.5 can match it)
    if not sql.endswith(';'):
        errors.append("Must end with a semicolon (';')")

    # Rule 3: Must be SELECT only (case-insensitive)
    sql_upper = sql.upper()
    if not sql_upper.startswith('SELECT'):
        errors.append("Must start with SELECT")

    # Rule 4: Block DML/DDL keywords. Match whole words only, so identifiers
    # such as "LastUpdated" do not trigger a false positive on UPDATE.
    forbidden_keywords = ['INSERT', 'UPDATE', 'DELETE', 'MERGE', 'DROP', 'ALTER',
                          'CREATE', 'TRUNCATE', 'EXEC', 'EXECUTE', 'GRANT', 'REVOKE']
    for keyword in forbidden_keywords:
        if re.search(rf'\b{keyword}\b', sql_upper):
            errors.append(f"Forbidden keyword: {keyword}")

    # Rule 5: Schema-known tables only (basic check: at least one known table appears)
    table_names = [t['name'] for t in schema['tables']]
    found_known_table = False
    for table in table_names:
        if table.upper() in sql_upper:
            found_known_table = True
            break

    if not found_known_table:
        errors.append("No known tables found in SQL")

    # Rule 6: Placeholder integrity - check for ID, DATE, YEAR placeholders
    placeholders_in_sql = re.findall(r'__(?:ID|DATE|YEAR)_\d+__', sql)
    for ph in placeholders_in_sql:
        if ph not in id_map:
            errors.append(f"Unknown placeholder: {ph}")

    return len(errors) == 0, errors


def reinject_ids(sql: str, id_map: Dict[str, str]) -> str:
    """Replace placeholders with actual IDs."""
    for placeholder, value in id_map.items():
        sql = sql.replace(placeholder, value)
    return sql


print("Validation functions ready!")
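Rule 6 flags only placeholders that the ID map does not contain. A possible extension, sketched below as a hypothetical `placeholder_report` helper, also lists ID-map entries that the SQL never uses. This would catch a model that silently drops a filter, for example by omitting `DepartmentID = __ID_2__`:

```python
import re
from typing import Dict, List, Tuple

PLACEHOLDER = re.compile(r"__(?:ID|DATE|YEAR)_\d+__")


def placeholder_report(sql: str, id_map: Dict[str, str]) -> Tuple[List[str], List[str]]:
    """Return (unknown, unused): placeholders missing from id_map, and id_map keys absent from sql."""
    used = set(PLACEHOLDER.findall(sql))
    unknown = sorted(used - set(id_map))
    unused = sorted(set(id_map) - used)
    return unknown, unused


sql = "SELECT COUNT(*) FROM Visits WHERE PatientID = __ID_1__ AND DepartmentID = __ID_3__;"
print(placeholder_report(sql, {"__ID_1__": "5432", "__ID_2__": "25"}))
# (['__ID_3__'], ['__ID_2__'])
```

Whether an unused entry should fail validation or only raise a warning is a judgment call. Treating it as a failure is stricter, but it would also reject questions whose numbers are intentionally ignored.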

## 11. End-to-End Demo with 3 Example Questions

**IMPROVED**: Uses complete enhanced pipeline:
- Enhanced ID extraction (catches department IDs, years, etc.)
- Robust SQL extraction from model output
- Improved validation with all new rules

In [None]:
# Load latest checkpoint (adjust the epoch number if training ran for a different number of epochs)
latest_checkpoint = 'checkpoints/checkpoint_epoch_3.pt'
if os.path.exists(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}\n")
else:
    print("No checkpoint found, using untrained model\n")

# Switch to evaluation mode (disables dropout) before generating
model.eval()

schema_text = generate_schema_compact_text(EXAMPLE_SCHEMA)

# Example questions
test_questions = [
    "How many visits did patient 5432 have in department 25?",
    "What is the total charge for patient 7890?",
    "Show monthly visit counts in 2023"
]

print("=" * 80)
print("GENERATING SQL FOR TEST QUESTIONS")
print("=" * 80)

passed = 0
for i, question in enumerate(test_questions, 1):
    print(f"\n[Question {i}]")
    print(f"Original: {question}")

    # Extract placeholders (enhanced to catch all IDs)
    question_clean, id_map = extract_placeholders(question)
    print(f"Clean: {question_clean}")
    print(f"ID Map: {id_map}")

    # Generate SQL (with sentinel support)
    id_map_text = ' , '.join([f"{k} = {v}" for k, v in id_map.items()]) if id_map else "None"
    generated_sql = generate_sql(model, tokenizer, schema_text, question_clean, id_map_text, device=device)

    print("\nGenerated SQL (with placeholders):")
    print(f"  {generated_sql}")

    # Validate (with enhanced rules)
    is_valid, errors = validate_sql(generated_sql, EXAMPLE_SCHEMA, id_map)

    if is_valid:
        print("✓ Validation: PASSED")
        passed += 1

        # Reinject IDs
        final_sql = reinject_ids(generated_sql, id_map)
        print("\nFinal SQL (with real IDs):")
        print(f"  {final_sql}")
    else:
        print("✗ Validation: FAILED")
        for error in errors:
            print(f"  - {error}")

    print("-" * 80)

print("\n" + "=" * 80)
print(f"RESULTS: {passed}/{len(test_questions)} questions passed validation")
print("=" * 80)

if passed >= 2:
    print("\n✓ SUCCESS: Basic acceptance criteria met (>=2/3 passed)")
else:
    print("\n✗ INCOMPLETE: Need more training or better templates")
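The demo above relies on the ID vault round trip: identifiers in the question are swapped for placeholders, the map is serialized in the `K = V , K = V` form (for example `__ID_1__ = 5432 , __ID_2__ = 25`) and passed to `generate_sql`, and the real values are reinjected only after validation. The cell below is a minimal, self-contained sketch of that round trip, not the notebook's own `extract_placeholders`/`reinject_ids`. The helper names (`serialize_id_map`, `parse_id_map`, `reinject_ids_sketch`, `PLACEHOLDER_RE`), the assumption that every placeholder has the shape `__NAME_N__`, and the `visits` table and its columns are all hypothetical.

```python
import re
from typing import Dict

# Hypothetical helpers illustrating the ID-vault round trip; the notebook's own
# extract_placeholders / reinject_ids may differ in details.
# Assumption: placeholders look like __ID_1__, __YEAR_1__, __DATE_2__, ...
PLACEHOLDER_RE = re.compile(r"__[A-Z]+_\d+__")


def serialize_id_map(id_map: Dict[str, str]) -> str:
    """Render the vault in the 'K = V , K = V' form used by the demo cell."""
    return " , ".join(f"{k} = {v}" for k, v in id_map.items()) if id_map else "None"


def parse_id_map(text: str) -> Dict[str, str]:
    """Inverse of serialize_id_map ('None' means an empty vault)."""
    if text == "None":
        return {}
    id_map = {}
    for pair in text.split(" , "):
        key, value = pair.split(" = ", 1)
        id_map[key] = value
    return id_map


def reinject_ids_sketch(sql: str, id_map: Dict[str, str]) -> str:
    """Replace every placeholder with its real value in a single regex pass.

    Placeholders missing from id_map are left untouched so validation can flag them.
    """
    return PLACEHOLDER_RE.sub(lambda m: id_map.get(m.group(0), m.group(0)), sql)


# Illustrative usage (table and column names are hypothetical)
vault = {"__ID_1__": "5432", "__ID_2__": "25"}
vault_text = serialize_id_map(vault)
print(vault_text)  # __ID_1__ = 5432 , __ID_2__ = 25
assert parse_id_map(vault_text) == vault

sql = "SELECT COUNT(*) FROM visits WHERE patient_id = __ID_1__ AND department_id = __ID_2__;"
print(reinject_ids_sketch(sql, vault))
```

Substituting in a single regex pass means an inserted value is never rescanned, and because each placeholder ends in `__`, `__ID_1__` cannot accidentally match inside `__ID_10__`.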

## 11.5. Extended Evaluation (100 Test Questions)

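Run evaluation on 100 freshly generated questions to measure the pass rate and identify common failure patterns. These questions come from the same generator as the training data, so they are new samples rather than a strictly held-out test set.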
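With only 100 questions, the pass rate printed by the next cell is an estimate with real sampling noise, which matters when comparing it against the ≥90% target. One optional way to quantify this (not part of the original pipeline) is a Wilson score interval for a binomial proportion. `wilson_interval` is a hypothetical helper, and `z = 1.96` corresponds to roughly 95% confidence.

```python
import math


def wilson_interval(passed: int, total: int, z: float = 1.96) -> tuple[float, float]:
    """Wilson score interval for a pass rate (hypothetical helper; z=1.96 is about 95%)."""
    if total <= 0:
        raise ValueError("total must be positive")
    p = passed / total
    denom = 1 + z * z / total
    # Center is shifted toward 0.5; width shrinks as total grows
    center = (p + z * z / (2 * total)) / denom
    half = z * math.sqrt(p * (1 - p) / total + z * z / (4 * total * total)) / denom
    return max(0.0, center - half), min(1.0, center + half)


low, high = wilson_interval(90, 100)
print(f"90/100 passed -> 95% interval {low:.1%} to {high:.1%}")
```

For a measured 90/100 the lower end of the interval falls below 90%, so a single run that lands exactly on the target does not by itself show the target is met; repeating the evaluation with more questions narrows the interval.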

In [None]:
# Generate 100 test questions
print("Generating 100 test questions...")
test_samples = generate_dataset_samples(EXAMPLE_SCHEMA, 100)

# Run evaluation
print("Running evaluation...\n")
results = {
    'passed': 0,
    'failed': 0,
    'failure_categories': {
        'must_start_with_select': 0,
        'missing_semicolon': 0,
        'unknown_placeholder': 0,
        'multiple_statements': 0,
        'forbidden_keyword': 0,
        'no_known_tables': 0,
        'other': 0
    },
    'examples': {
        'passed': [],
        'failed': []
    }
}

for i, sample in enumerate(tqdm(test_samples, desc="Evaluating")):
    question_clean = sample['question']
    id_map_text = sample['id_map']
    
    # Parse ID map back to dict
    id_map = {}
    if id_map_text != "None":
        for pair in id_map_text.split(' , '):
            k, v = pair.split(' = ')
            id_map[k] = v
    
    # Generate SQL
    try:
        generated_sql = generate_sql(model, tokenizer, schema_text, question_clean, id_map_text, device=device)
        
        # Validate
        is_valid, errors = validate_sql(generated_sql, EXAMPLE_SCHEMA, id_map)
        
        if is_valid:
            results['passed'] += 1
            if len(results['examples']['passed']) < 5:
                results['examples']['passed'].append({
                    'question': question_clean,
                    'sql': generated_sql
                })
        else:
            results['failed'] += 1
            
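            # Categorize failure. A sample with several errors can increment several
            # categories, so category totals may exceed results['failed'].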
            categorized = False
            for error in errors:
                if 'start with SELECT' in error:
                    results['failure_categories']['must_start_with_select'] += 1
                    categorized = True
                elif 'semicolon' in error.lower():
                    results['failure_categories']['missing_semicolon'] += 1
                    categorized = True
                elif 'Unknown placeholder' in error:
                    results['failure_categories']['unknown_placeholder'] += 1
                    categorized = True
                elif 'one statement' in error:
                    results['failure_categories']['multiple_statements'] += 1
                    categorized = True
                elif 'Forbidden keyword' in error:
                    results['failure_categories']['forbidden_keyword'] += 1
                    categorized = True
                elif 'known tables' in error:
                    results['failure_categories']['no_known_tables'] += 1
                    categorized = True
            
            if not categorized:
                results['failure_categories']['other'] += 1
            
            if len(results['examples']['failed']) < 5:
                results['examples']['failed'].append({
                    'question': question_clean,
                    'sql': generated_sql,
                    'errors': errors
                })
    except Exception as e:
        results['failed'] += 1
        results['failure_categories']['other'] += 1
        if len(results['examples']['failed']) < 5:
            results['examples']['failed'].append({
                'question': question_clean,
                'sql': 'ERROR',
                'errors': [str(e)]
            })

# Calculate pass rate
pass_rate = (results['passed'] / len(test_samples)) * 100

# Print results
print("\n" + "=" * 80)
print("EVALUATION RESULTS")
print("=" * 80)
print(f"Total questions: {len(test_samples)}")
print(f"Passed: {results['passed']}")
print(f"Failed: {results['failed']}")
print(f"Pass rate: {pass_rate:.1f}%")
print("\n" + "-" * 80)
print("FAILURE CATEGORIES:")
for category, count in results['failure_categories'].items():
    if count > 0:
        print(f"  {category.replace('_', ' ').title()}: {count}")

print("\n" + "-" * 80)
print("EXAMPLE PASSED QUERIES (first 3):")
for i, ex in enumerate(results['examples']['passed'][:3], 1):
    print(f"\n  {i}. Q: {ex['question']}")
    print(f"     SQL: {ex['sql']}")

print("\n" + "-" * 80)
print("EXAMPLE FAILED QUERIES (first 3):")
for i, ex in enumerate(results['examples']['failed'][:3], 1):
    print(f"\n  {i}. Q: {ex['question']}")
    print(f"     SQL: {ex['sql']}")
    print(f"     Errors: {', '.join(ex['errors'])}")

print("\n" + "=" * 80)
if pass_rate >= 90:
    print("✓✓✓ EXCELLENT: Pass rate >= 90% - Target achieved!")
elif pass_rate >= 70:
    print("✓✓ GOOD: Pass rate >= 70% - Getting close!")
elif pass_rate >= 50:
    print("✓ ACCEPTABLE: Pass rate >= 50% - Needs improvement")
else:
    print("✗ NEEDS WORK: Pass rate < 50% - More training needed")
print("=" * 80)

# Save results to JSON
with open('data/eval_results.json', 'w') as f:
    json.dump(results, f, indent=2)
print("\nResults saved to data/eval_results.json")
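The `if`/`elif` chain above maps error messages to categories by substring. A table-driven variant keeps the patterns in one list, so adding a category becomes a one-line change. This is a hypothetical refactor sketch: `ERROR_PATTERNS` and `categorize_errors` are not part of the notebook, the substrings assume `validate_sql` words its errors the way the chain expects, the sample messages are invented, and unlike the original it matches every pattern case-insensitively. As in the original, each error counts toward its first matching category, and a failed sample whose errors match nothing counts once as `other`.

```python
from collections import Counter
from typing import Iterable, List, Tuple

# (lowercase substring, category) pairs mirroring the if/elif chain in the
# evaluation cell. The substrings are assumptions about validate_sql's wording.
ERROR_PATTERNS: List[Tuple[str, str]] = [
    ("start with select", "must_start_with_select"),
    ("semicolon", "missing_semicolon"),
    ("unknown placeholder", "unknown_placeholder"),
    ("one statement", "multiple_statements"),
    ("forbidden keyword", "forbidden_keyword"),
    ("known tables", "no_known_tables"),
]


def categorize_errors(errors: Iterable[str]) -> List[str]:
    """Return the first matching category per error, or ['other'] if none matched."""
    categories = []
    for error in errors:
        lowered = error.lower()
        for pattern, category in ERROR_PATTERNS:
            if pattern in lowered:
                categories.append(category)
                break
    return categories or ["other"]


# Illustrative usage with invented error messages
counts = Counter()
for errs in [["Query must start with SELECT"],
             ["Missing semicolon", "Forbidden keyword: DROP"],
             ["something unexpected"]]:
    counts.update(categorize_errors(errs))
print(dict(counts))
```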

## 12. Summary and Next Steps

In [None]:
print("""\n
=================================================================================
MVP PIPELINE COMPLETE!
=================================================================================

What we built:
✓ Example healthcare schema with 5 tables
✓ Dataset generator with 5,000+ training and 200+ validation samples
✓ BPE tokenizer trained on SQL corpus (vocab size: 8,000)
✓ Tiny decoder-only transformer (~50M parameters)
✓ Training loop with loss masking (SQL tokens only) and checkpointing
✓ Inference with prompt formatting, SQL extraction, and </SQL> sentinel stopping
✓ Validation gates (SELECT-only, schema checks, placeholders)
✓ ID placeholder vault and reinjection
✓ End-to-end demo with 3 test questions
✓ Extended evaluation on 100 generated questions with failure categorization

Files created:
- data/example_schema.json
- data/train.jsonl (5,000 samples)
- data/val.jsonl (200 samples)
- artifacts/tokenizer/tokenizer.json
- checkpoints/checkpoint_epoch_*.pt
- data/eval_results.json

Next steps to scale:
1. Increase dataset size to 200k-800k samples
2. Add more SQL template diversity (JOINs, subqueries, HAVING, etc.)
3. Scale model to ~300M parameters
4. Train for more epochs with learning rate scheduling
5. Extend the evaluation harness with a truly held-out test set
6. Tune hyperparameters (learning rate, batch size, etc.)
7. Implement beam search or nucleus sampling for generation
8. Add more sophisticated schema validation

=================================================================================
""")
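Next step 7 mentions nucleus sampling. As a minimal sketch, assuming the decoding loop inside `generate_sql` has access to a `(batch, vocab)` tensor of next-token logits at each step, top-p sampling could replace an argmax step like this. `sample_top_p` is a hypothetical helper, not part of the notebook, and it would still need to be combined with the existing `</SQL>` sentinel stop. For a validity-focused task, greedy decoding may remain the safer default; sampling mainly helps when several candidates are generated and the first one that passes validation is kept.

```python
import torch


def sample_top_p(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 1.0) -> torch.Tensor:
    """Nucleus sampling: draw one token per row from the smallest set of tokens
    whose cumulative probability reaches top_p.

    logits: (batch, vocab) next-token logits. Returns (batch, 1) token ids.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token if the probability mass *before* it already reaches top_p;
    # the most likely token always has zero mass before it, so it is always kept.
    remove = (cumulative - sorted_probs) >= top_p
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    # Sample a position in the sorted order, then map it back to a vocabulary id
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, choice)


# Illustrative usage: one decoding step for a batch of 2 with the 8,000-token vocab
logits = torch.randn(2, 8000)
next_ids = sample_top_p(logits, top_p=0.9, temperature=0.8)
print(next_ids.shape)  # torch.Size([2, 1])
```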