<a href="https://colab.research.google.com/github/AndrewMichael2020/my-gpt-2-2/blob/main/colab_train_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Healthcare "SQL Agent" Model — Colab Training Notebook

This notebook trains a small decoder-only model **from scratch** to generate **one T-SQL SELECT** statement from a schema-aware question.

Key fixes:
- Deterministic ID vault for all identifiers
- Hard `</SQL>` sentinel
- Loss masking by construction (prompt masked, SQL supervised)
- Token-based completion slicing


In [1]:
# !pip -q install -U tokenizers==0.22.1


In [2]:
import os, re, json, time, random
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace


In [3]:
\
# 1) Repro + Config
# Seed the Python and torch RNGs used by dataset generation and model init.
SEED = 1337
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Use the GPU when torch reports one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Working directories for generated data, tokenizer artifacts and checkpoints.
DATA_DIR = "/content/data_sql_agent"
ART_DIR  = "/content/artifacts_sql_agent"
CKPT_DIR = "/content/checkpoints_sql_agent"

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(ART_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

TRAIN_JSONL = os.path.join(DATA_DIR, "train.jsonl")
VAL_JSONL   = os.path.join(DATA_DIR, "val.jsonl")
TOKENIZER_PATH = os.path.join(ART_DIR, "tokenizer.json")

# Number of synthetic samples written to each split.
N_TRAIN = 5000
N_VAL   = 200

# Training hyperparameters. MAX_LEN is both the padded sequence length of the
# dataset and the size of the model's positional embedding table.
MAX_LEN = 384
BATCH_SIZE = 8
LR = 2e-4
EPOCHS = 3
WARMUP_STEPS = 200
GRAD_CLIP = 1.0

# Transformer dimensions.
D_MODEL = 384
N_HEAD  = 6
N_LAYER = 8
D_FF    = 4 * D_MODEL
DROPOUT = 0.1

# Generation settings; a TEMPERATURE of 0.0 selects greedy (argmax) decoding
# in generate_ids.
GEN_MAX_NEW_TOKENS = 160
TEMPERATURE = 0.0

# Sentinel appended after every target SQL statement.
SQL_END = "</SQL>"


Device: cuda


In [4]:
\
# 2) Schema + Templates
# Identifier and flat text rendering of the schema that is placed in every prompt.
SCHEMA_ID = "healthcare_analytics"
SCHEMA_TEXT = (
    "Patients ( PatientID, FirstName, LastName, DateOfBirth, Gender, InsuranceProvider ); "
    "Visits ( VisitID, PatientID, VisitDate, DepartmentID, ProviderID, VisitType, TotalCharge ); "
    "Departments ( DepartmentID, DepartmentName, Location ); "
    "Providers ( ProviderID, ProviderName, Specialty, DepartmentID ); "
    "Diagnoses ( DiagnosisID, VisitID, ICDCode, DiagnosisDescription )"
)

# (question template, SQL template) pairs. {PID}, {DID} and {YEAR} are filled
# with the same random values on both sides by render_template.
QUESTION_TEMPLATES = [
    ("How many visits did patient {PID} have in department {DID}?",
     "SELECT COUNT(*) FROM Visits WHERE PatientID = {PID} AND DepartmentID = {DID};"),
    ("What is the total charge for patient {PID}?",
     "SELECT SUM(TotalCharge) FROM Visits WHERE PatientID = {PID};"),
    ("Show monthly visit counts in {YEAR}",
     "SELECT MONTH(VisitDate) AS [Month], COUNT(*) AS VisitCount "
     "FROM Visits WHERE YEAR(VisitDate) = {YEAR} GROUP BY MONTH(VisitDate) ORDER BY [Month];"),
    ("List providers in department {DID} ordered by name",
     "SELECT ProviderID, ProviderName FROM Providers WHERE DepartmentID = {DID} ORDER BY ProviderName;"),
    ("Count visits by department in {YEAR}",
     "SELECT DepartmentID, COUNT(*) AS VisitCount FROM Visits "
     "WHERE YEAR(VisitDate) = {YEAR} GROUP BY DepartmentID ORDER BY VisitCount DESC;"),
]

def rand_id(min_v: int, max_v: int) -> int:
    """Return a random integer drawn from the inclusive range [min_v, max_v]."""
    upper_exclusive = max_v + 1
    return random.randrange(min_v, upper_exclusive)

def rand_year() -> int:
    """Return a random calendar year between 2018 and 2024, both inclusive."""
    first_year, last_year = 2018, 2024
    return random.randrange(first_year, last_year + 1)


In [5]:
\
# 3) ID Vault (Aggressive, Deterministic)
def extract_placeholders(question: str) -> Tuple[str, Dict[str, str]]:
    """Replace literal dates, years and integers in *question* with placeholders.

    ISO dates (``YYYY-MM-DD``) are replaced first with ``__DATE_n__``. Every
    remaining standalone run of 1-6 digits then becomes ``__YEAR_n__`` when it
    looks like a year (see ``looks_like_year``) and ``__ID_n__`` otherwise.
    Each kind is numbered from 1 in order of appearance.

    Args:
        question: Raw natural-language question.

    Returns:
        ``(clean_question, id_map)`` where ``id_map`` maps every placeholder
        to the literal text it replaced.
    """
    id_map: Dict[str, str] = {}
    # One running counter per placeholder kind.
    counters = {"DATE": 0, "YEAR": 0, "ID": 0}

    date_pat = re.compile(r"\b\d{4}-\d{2}-\d{2}\b")
    int_pat = re.compile(r"\b\d{1,6}\b")

    def vault(kind: str, value: str) -> str:
        """Register *value* under the next placeholder of *kind* and return it."""
        counters[kind] += 1
        ph = f"__{kind}_{counters[kind]}__"
        id_map[ph] = value
        return ph

    def looks_like_year(num_str: str, text: str, start_idx: int) -> bool:
        """Return True for 1900..2100 values preceded by a year-ish trigger word."""
        try:
            n = int(num_str)
        except ValueError:
            # Only a non-numeric string can land here; never swallow other errors.
            return False
        if not (1900 <= n <= 2100):
            return False
        # Look at the 18 characters immediately before the number.
        left = max(0, start_idx - 18)
        window = text[left:start_idx].lower()
        triggers = [" in ", " year", " during", " for ", " within", " by "]
        return any(t in window for t in triggers)

    # Dates go first so their digit groups are not picked up as integers.
    # The digits inside "__DATE_n__" are surrounded by underscores (word
    # characters), so the \b-anchored integer pattern does not match them.
    dated = date_pat.sub(lambda m: vault("DATE", m.group(0)), question)

    def repl_int(m):
        value = m.group(0)
        # Match offsets refer to `dated`, the string being substituted.
        kind = "YEAR" if looks_like_year(value, dated, m.start()) else "ID"
        return vault(kind, value)

    return int_pat.sub(repl_int, dated), id_map

def id_map_to_text(id_map: Dict[str, str]) -> str:
    """Render *id_map* as ``"PH = value ; PH = value"`` in a stable order.

    Entries are ordered by placeholder kind (DATE, ID, YEAR, ...) and then by
    numeric index. The previous key used ``ph.split("_")[1]``, which is always
    the empty string for ``__KIND_n__`` names, so the kind was never part of
    the ordering.

    Args:
        id_map: Mapping from placeholder (e.g. ``__ID_1__``) to literal value.

    Returns:
        The joined text, or ``""`` for an empty map.
    """
    def sort_key(ph: str):
        m = re.fullmatch(r"__(\w+?)_(\d+)__", ph)
        if m is None:
            # Unrecognised key shape: order by the raw name.
            return (ph, 0)
        return (m.group(1), int(m.group(2)))

    return " ; ".join(f"{ph} = {id_map[ph]}" for ph in sorted(id_map, key=sort_key))

def finalize_sql(sql: str) -> str:
    """Trim *sql*, make sure it ends with ';', and append the SQL_END sentinel."""
    statement = sql.strip()
    if statement[-1:] != ";":
        statement += ";"
    return f"{statement} {SQL_END}"

# Quick demonstration of the ID vault on a sample question.
q0 = "How many visits did patient 5432 have in department 25?"
clean0, map0 = extract_placeholders(q0)
print("Original:", q0)
print("Clean:", clean0)
print("ID Map:", map0)


Original: How many visits did patient 5432 have in department 25?
Clean: How many visits did patient __ID_1__ have in department __ID_2__?
ID Map: {'__ID_1__': '5432', '__ID_2__': '25'}


In [6]:
\
# 4) Generate dataset (JSONL)
def render_template(q_tmpl: str, sql_tmpl: str) -> str:
    """Build one JSONL training record from a (question, SQL) template pair.

    Random patient, department and year values are substituted into both
    templates, the question is run through ``extract_placeholders``, and the
    same literal values are replaced by their placeholders in the SQL.

    Returns:
        The sample serialized as a single-line JSON string.
    """
    # Dict literals evaluate top to bottom, so the RNG is consumed in the
    # order patient id, department id, year.
    fills = {
        "PID": str(rand_id(1000, 9999)),
        "DID": str(rand_id(1, 99)),
        "YEAR": str(rand_year()),
    }
    question = q_tmpl.format(**fills)
    sql_text = sql_tmpl.format(**fills)

    clean_q, id_map = extract_placeholders(question)

    # When a literal value occurs more than once, its first placeholder wins.
    first_ph = {}
    for ph, val in id_map.items():
        if val not in first_ph:
            first_ph[val] = ph

    # Swap whole-word occurrences of each literal in the SQL for its placeholder.
    for val, ph in first_ph.items():
        sql_text = re.sub(rf"\b{re.escape(val)}\b", ph, sql_text)

    record = {
        "schema_id": SCHEMA_ID,
        "schema_text": SCHEMA_TEXT,
        "question": clean_q,
        "id_map": id_map_to_text(id_map),
        "sql": finalize_sql(sql_text),
    }
    return json.dumps(record, ensure_ascii=False)

def write_jsonl(path: str, n: int):
    """Write *n* randomly rendered samples to *path*, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(
            render_template(*random.choice(QUESTION_TEMPLATES)) + "\n"
            for _ in range(n)
        )

# Write both splits, then read back the first training record as a sanity check.
print("Generating training dataset...")
write_jsonl(TRAIN_JSONL, N_TRAIN)
print("Generating validation dataset...")
write_jsonl(VAL_JSONL, N_VAL)

print(f"Generated {N_TRAIN} training samples")
print(f"Generated {N_VAL} validation samples")

with open(TRAIN_JSONL, "r", encoding="utf-8") as f:
    ex = json.loads(next(f))
print("Example sample:")
print(json.dumps(ex, indent=2))


Generating training dataset...
Generating validation dataset...
Generated 5000 training samples
Generated 200 validation samples
Example sample:
{
  "schema_id": "healthcare_analytics",
  "schema_text": "Patients ( PatientID, FirstName, LastName, DateOfBirth, Gender, InsuranceProvider ); Visits ( VisitID, PatientID, VisitDate, DepartmentID, ProviderID, VisitType, TotalCharge ); Departments ( DepartmentID, DepartmentName, Location ); Providers ( ProviderID, ProviderName, Specialty, DepartmentID ); Diagnoses ( DiagnosisID, VisitID, ICDCode, DiagnosisDescription )",
  "question": "Count visits by department in __YEAR_1__",
  "id_map": "__YEAR_1__ = 2020",
  "sql": "SELECT DepartmentID, COUNT(*) AS VisitCount FROM Visits WHERE YEAR(VisitDate) = __YEAR_1__ GROUP BY DepartmentID ORDER BY VisitCount DESC; </SQL>"
}


In [7]:
\
# 5) Tokenizer training (BPE) with special tokens
def format_training_text(sample: Dict) -> str:
    """Render a full training example (prompt plus target SQL) as text.

    Used to feed the tokenizer trainer. The layout is RULES, SCHEMA, QUESTION,
    ID_MAP, ALLOWED_PLACEHOLDERS and SQL, one per line.
    """
    # NOTE(review): inside a raw string the doubled backslashes make this
    # pattern match a literal "\" followed by "w"/"d", so for ordinary id_map
    # text nothing matches and ALLOWED_PLACEHOLDERS is rendered empty. Changing
    # it alters the prompt format, so it has to be changed together with
    # build_prompt_text / build_prompt and the model retrained.
    allowed_txt = ' '.join(re.findall(r"__\\w+_\\d+__", sample['id_map']))
    rules = 'RULES: Output must be exactly one T-SQL SELECT statement. Start with SELECT. End with a semicolon. Use only schema tables/columns. Use placeholders exactly as provided. Append </SQL> at the end.'
    sections = [
        rules,
        f"SCHEMA: {sample['schema_text']}",
        f"QUESTION: {sample['question']}",
        f"ID_MAP: {sample['id_map']}",
        f"ALLOWED_PLACEHOLDERS: {allowed_txt}",
        f"SQL: {sample['sql']}",
    ]
    return "\n".join(sections)

def iter_training_texts(jsonl_path: str):
    """Lazily yield the formatted training text for every record in a JSONL file."""
    with open(jsonl_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            yield format_training_text(json.loads(raw_line))

# Special tokens: control tokens, prompt section markers, the SQL end
# sentinel, the placeholder prefixes, and every concrete placeholder the
# vault is expected to emit (64 IDs, 16 dates, 8 years).
special_tokens = [
    "[PAD]", "[UNK]", "[BOS]", "[EOS]",
    "SCHEMA:", "QUESTION:", "ID_MAP:", "SQL:",
    SQL_END,
    "__ID_", "__DATE_", "__YEAR_",
]
for i in range(1, 65):
    special_tokens.append(f"__ID_{i}__")
for i in range(1, 17):
    special_tokens.append(f"__DATE_{i}__")
for i in range(1, 9):
    special_tokens.append(f"__YEAR_{i}__")

# BPE model with whitespace pre-tokenization.
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(vocab_size=2000, min_frequency=2, special_tokens=special_tokens, show_progress=True)

# Train on the fully formatted training examples (prompt + SQL) and persist.
texts = list(iter_training_texts(TRAIN_JSONL))
print(f"Training tokenizer with {len(special_tokens)} special tokens...")
tokenizer.train_from_iterator(texts, trainer=trainer)
tokenizer.save(TOKENIZER_PATH)
print("Tokenizer saved:", TOKENIZER_PATH)
print("Vocab size:", tokenizer.get_vocab_size())

# Spot checks: print how placeholders, "SQL:" and "</SQL>" are tokenized.
t1 = "SELECT COUNT(*) FROM Visits WHERE PatientID = __ID_1__;"
print("Tokens:", tokenizer.encode(t1).tokens)

t2 = "SQL: SELECT * FROM Visits; </SQL>"
print("Tokens:", tokenizer.encode(t2).tokens)
print("SQL: tokens:", tokenizer.encode("SQL:").tokens, tokenizer.encode("SQL:").ids)


Training tokenizer with 100 special tokens...
Tokenizer saved: /content/artifacts_sql_agent/tokenizer.json
Vocab size: 981
Tokens: ['SELECT', 'COUNT', '(*)', 'FROM', 'Visits', 'WHERE', 'PatientID', '=', '__ID_1__', ';']
Tokens: ['SQL:', 'SELECT', '*', 'FROM', 'Visits', ';', '</SQL>']
SQL: tokens: ['SQL:'] [7]


In [8]:
\
# 6) Dataset with loss masking BY CONSTRUCTION
# Token ids used for padding and for the end-of-SQL sentinel.
PAD_ID = tokenizer.token_to_id("[PAD]")
# Takes the first id of the encoded sentinel.
# NOTE(review): this assumes "</SQL>" encodes to a single token — confirm.
SQL_END_ID = tokenizer.encode(SQL_END).ids[0]
assert PAD_ID is not None and SQL_END_ID is not None

def build_prompt_text(sample: Dict) -> str:
    """Render the prompt part of a training sample, ending with the bare ``SQL:`` tag.

    The layout is RULES, SCHEMA, QUESTION, ID_MAP, ALLOWED_PLACEHOLDERS and
    SQL:, one per line; the target SQL is appended separately by the dataset.
    """
    # NOTE(review): inside a raw string the doubled backslashes make this
    # pattern match a literal "\" followed by "w"/"d", so for ordinary id_map
    # text nothing matches and ALLOWED_PLACEHOLDERS is rendered empty. Changing
    # it alters the prompt format, so it has to be changed together with the
    # inference-time build_prompt and the model retrained.
    allowed_txt = ' '.join(re.findall(r"__\\w+_\\d+__", sample['id_map']))
    rules = 'RULES: Output must be exactly one T-SQL SELECT statement. Start with SELECT. End with a semicolon. Use only schema tables/columns. Use placeholders exactly as provided. Append </SQL> at the end.'
    sections = [
        rules,
        f"SCHEMA: {sample['schema_text']}",
        f"QUESTION: {sample['question']}",
        f"ID_MAP: {sample['id_map']}",
        f"ALLOWED_PLACEHOLDERS: {allowed_txt}",
        "SQL:",
    ]
    return "\n".join(sections)

class SQLDataset(Dataset):
    """JSONL-backed dataset producing fixed-length ``input_ids`` / ``labels`` pairs.

    Labels are -100 (ignored by the loss) for every prompt token and every
    padding position, so only the SQL completion tokens are supervised.
    """

    def __init__(self, jsonl_path: str, tokenizer: Tokenizer, max_len: int):
        self.tokenizer = tokenizer
        self.max_len = max_len
        with open(jsonl_path, "r", encoding="utf-8") as f:
            self.samples = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        record = self.samples[idx]

        # Prompt and completion are encoded separately so the boundary between
        # masked and supervised tokens is exact.
        prompt_ids = self.tokenizer.encode(build_prompt_text(record)).ids
        sql_ids = self.tokenizer.encode(" " + record["sql"].strip()).ids

        limit = self.max_len
        input_ids = (prompt_ids + sql_ids)[:limit]
        labels = ([-100] * len(prompt_ids) + list(sql_ids))[:limit]

        # Right-pad both sequences up to max_len.
        pad_len = limit - len(input_ids)
        if pad_len > 0:
            input_ids.extend([PAD_ID] * pad_len)
            labels.extend([-100] * pad_len)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

def collate(batch):
    """Stack per-sample ``input_ids`` and ``labels`` tensors into batch tensors.

    Returns:
        ``(input_ids, labels)``, each stacked along a new leading dimension.
    """
    stacked = tuple(
        torch.stack([item[field] for item in batch], dim=0)
        for field in ("input_ids", "labels")
    )
    return stacked

# Build datasets and loaders; only the training loader is shuffled.
train_ds = SQLDataset(TRAIN_JSONL, tokenizer, MAX_LEN)
val_ds   = SQLDataset(VAL_JSONL, tokenizer, MAX_LEN)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

# Sanity check: index of the first label that is not masked (-100) in the
# first sample of one batch.
x, y = next(iter(train_loader))
first = (y[0] != -100).nonzero(as_tuple=True)[0][0].item()
print("First supervised label idx:", first)
# print("Decoded near boundary:", tokenizer.decode(x[0].tolist(, skip_special_tokens=False)[max(0, first-10):first+25]))


First supervised label idx: 116


In [9]:
\
# 7) Decoder-only model (TransformerEncoder + causal mask)
class GPTSmall(nn.Module):
    """Small decoder-only language model built from encoder layers plus a causal mask.

    Token and learned positional embeddings are summed, passed through a
    stack of pre-norm ``nn.TransformerEncoderLayer`` blocks, a final
    LayerNorm, and a bias-free linear head producing per-position logits.
    """

    def __init__(self, vocab_size: int, d_model: int, n_head: int, n_layer: int, d_ff: int, dropout: float, max_len: int):
        super().__init__()
        # Submodules are created in a fixed order so seeded initialisation
        # does not change.
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)

        layer_kwargs = dict(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.blocks = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), num_layers=n_layer
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.max_len = max_len

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return logits for *input_ids*; the input is unpacked as (batch, seq)."""
        batch, seq_len = input_ids.shape
        device = input_ids.device

        positions = torch.arange(0, seq_len, device=device).unsqueeze(0).expand(batch, seq_len)
        hidden = self.dropout(self.tok_emb(input_ids) + self.pos_emb(positions))

        # True above the diagonal: a position may not attend to later ones.
        future_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        hidden = self.blocks(hidden, mask=future_mask)
        return self.head(self.ln_f(hidden))

# Instantiate the model with the tokenizer's vocabulary size and report its
# total parameter count.
vocab_size = tokenizer.get_vocab_size()
model = GPTSmall(vocab_size, D_MODEL, N_HEAD, N_LAYER, D_FF, DROPOUT, MAX_LEN).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,} ({n_params/1e6:.2f}M)")


Model parameters: 15,097,344 (15.10M)




In [10]:
\
# 8) Training loop + checkpointing + warmup
def get_lr(step: int, base_lr: float, warmup: int) -> float:
    """Return the learning rate for *step*: linear warmup, then constant.

    Args:
        step: Current optimizer step.
        base_lr: Learning rate used once warmup has finished.
        warmup: Number of warmup steps.
    """
    in_warmup = step < warmup
    if not in_warmup:
        return base_lr
    # max(1, ...) guards the division when warmup is zero or negative.
    ramp = step / max(1, warmup)
    return base_lr * ramp

# AdamW optimizer, gradient scaler for mixed precision (only enabled on CUDA),
# and a cross-entropy loss that skips positions labelled -100.
# NOTE(review): the captured cell output shows a warning for the
# torch.cuda.amp.GradScaler call on the installed torch version.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, betas=(0.9, 0.95), weight_decay=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

def save_ckpt(epoch: int, global_step: int):
    """Save model and optimizer state for *epoch* to ``checkpoint_epoch_<epoch>.pt``."""
    path = os.path.join(CKPT_DIR, f"checkpoint_epoch_{epoch}.pt")
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(state, path)
    print("Checkpoint saved:", path)

def load_latest_ckpt() -> Tuple[int, int]:
    """Restore model and optimizer from the highest-numbered epoch checkpoint.

    Only files named exactly ``checkpoint_epoch_<digits>.pt`` in ``CKPT_DIR``
    are considered. Other files that merely share the prefix and suffix
    (e.g. ``checkpoint_epoch_best.pt``) are ignored instead of crashing the
    epoch-number parse.

    Returns:
        ``(epoch, global_step)`` from the loaded checkpoint, or ``(0, 0)``
        when no checkpoint exists.
    """
    name_pat = re.compile(r"checkpoint_epoch_(\d+)\.pt")
    by_epoch = {}
    for fname in os.listdir(CKPT_DIR):
        m = name_pat.fullmatch(fname)
        if m:
            by_epoch[int(m.group(1))] = fname
    if not by_epoch:
        return 0, 0

    latest = by_epoch[max(by_epoch)]
    # NOTE: torch.load unpickles the file; only load checkpoints written by
    # this notebook, never files from an untrusted source.
    ckpt = torch.load(os.path.join(CKPT_DIR, latest), map_location=DEVICE)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    print("Loaded checkpoint:", latest)
    return ckpt["epoch"], ckpt["global_step"]

@torch.no_grad()
def eval_loss(loader: DataLoader) -> float:
    """Return the mean per-batch next-token loss over *loader*.

    The model is switched to eval mode for the pass and back to train mode
    before returning. An empty loader yields 0.0.
    """
    model.eval()
    use_amp = DEVICE == "cuda"
    loss_sum = 0.0
    batches = 0
    for input_ids, labels in loader:
        input_ids, labels = input_ids.to(DEVICE), labels.to(DEVICE)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(input_ids)
            # Position t predicts token t+1: drop the last logit and first label.
            preds = logits[:, :-1, :].contiguous()
            targets = labels[:, 1:].contiguous()
            loss = loss_fct(preds.view(-1, preds.size(-1)), targets.view(-1))
        loss_sum += loss.item()
        batches += 1
    model.train()
    return loss_sum / max(1, batches)

# Resume from the newest checkpoint if one exists; otherwise start at epoch 1.
start_epoch, global_step = load_latest_ckpt()

print("Starting training...")
for epoch in range(start_epoch + 1, EPOCHS + 1):
    model.train()
    t0 = time.time()
    running = 0.0  # sum of batch losses for this epoch
    n = 0          # number of batches seen this epoch

    for batch_idx, (input_ids, labels) in enumerate(train_loader, start=1):
        global_step += 1
        # Learning rate is set manually each step (linear warmup, then constant).
        lr_now = get_lr(global_step, LR, WARMUP_STEPS)
        for pg in optimizer.param_groups:
            pg["lr"] = lr_now

        input_ids = input_ids.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
            logits = model(input_ids)
            # Shift so that position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Unscale before clipping so GRAD_CLIP applies to the true gradient norm.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        running += loss.item()
        n += 1

        # Log the running mean loss every 100 batches.
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch} step {batch_idx}/{len(train_loader)} loss={running/n:.4f} lr={lr_now:.2e}")

    # End of epoch: report train/validation loss and write a checkpoint.
    train_loss = running / max(1, n)
    val_loss = eval_loss(val_loader)
    dt = time.time() - t0
    print(f"\nEpoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f} (time {dt:.1f}s)")
    save_ckpt(epoch, global_step)

print("Training complete!")


  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))
  with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):


Starting training...
Epoch 1 step 100/625 loss=3.4958 lr=1.00e-04
Epoch 1 step 200/625 loss=1.7778 lr=2.00e-04
Epoch 1 step 300/625 loss=1.1859 lr=2.00e-04
Epoch 1 step 400/625 loss=0.8897 lr=2.00e-04
Epoch 1 step 500/625 loss=0.7119 lr=2.00e-04
Epoch 1 step 600/625 loss=0.5932 lr=2.00e-04


  with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):



Epoch 1: Train Loss = 0.5695, Val Loss = 0.0000 (time 20.8s)
Checkpoint saved: /content/checkpoints_sql_agent/checkpoint_epoch_1.pt
Epoch 2 step 100/625 loss=0.0000 lr=2.00e-04
Epoch 2 step 200/625 loss=0.0000 lr=2.00e-04
Epoch 2 step 300/625 loss=0.0000 lr=2.00e-04
Epoch 2 step 400/625 loss=0.0009 lr=2.00e-04
Epoch 2 step 500/625 loss=0.0008 lr=2.00e-04
Epoch 2 step 600/625 loss=0.0009 lr=2.00e-04

Epoch 2: Train Loss = 0.0009, Val Loss = 0.0000 (time 20.1s)
Checkpoint saved: /content/checkpoints_sql_agent/checkpoint_epoch_2.pt
Epoch 3 step 100/625 loss=0.0000 lr=2.00e-04
Epoch 3 step 200/625 loss=0.0000 lr=2.00e-04
Epoch 3 step 300/625 loss=0.0000 lr=2.00e-04
Epoch 3 step 400/625 loss=0.0000 lr=2.00e-04
Epoch 3 step 500/625 loss=0.0000 lr=2.00e-04
Epoch 3 step 600/625 loss=0.0000 lr=2.00e-04

Epoch 3: Train Loss = 0.0000, Val Loss = 0.0000 (time 19.9s)
Checkpoint saved: /content/checkpoints_sql_agent/checkpoint_epoch_3.pt
Training complete!


In [11]:
\
# 9) Inference: token slicing + sentinel stop + extraction + validation
def build_prompt(schema_text: str, clean_question: str, id_map_text: str) -> str:
    """Render the inference prompt, ending with the bare ``SQL:`` tag.

    Args:
        schema_text: Flat text description of the schema.
        clean_question: Question with literals already replaced by placeholders.
        id_map_text: Text form of the placeholder-to-value map.
    """
    rules = 'RULES: Output must be exactly one T-SQL SELECT statement. Start with SELECT. End with a semicolon. Use only schema tables/columns. Use placeholders exactly as provided. Append </SQL> at the end.'
    # NOTE(review): inside a raw string the doubled backslashes make this
    # pattern match a literal "\" followed by "w"/"d", so for ordinary id_map
    # text nothing matches and ALLOWED_PLACEHOLDERS is rendered empty. The
    # training-time build_prompt_text does the same, so the two stay
    # consistent; changing one requires changing the other and retraining.
    allowed = ' '.join(re.findall(r"__\\w+_\\d+__", id_map_text))
    sections = [
        rules,
        f"SCHEMA: {schema_text}",
        f"QUESTION: {clean_question}",
        f"ID_MAP: {id_map_text}",
        f"ALLOWED_PLACEHOLDERS: {allowed}",
        "SQL:",
    ]
    return "\n".join(sections)

@torch.no_grad()
def generate_ids(prompt_text: str, max_new_tokens: int = GEN_MAX_NEW_TOKENS) -> Tuple[List[int], int]:
    """Autoregressively generate a completion for *prompt_text*.

    Decoding is greedy unless the global ``TEMPERATURE`` is positive, in which
    case tokens are sampled. Placeholder tokens that do not occur in the
    prompt are masked out at every step. Generation stops after the first
    semicolon (the ``</SQL>`` sentinel id is then appended), when the model
    emits the sentinel itself, or after *max_new_tokens* tokens.

    Args:
        prompt_text: Fully rendered prompt ending in ``SQL:``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        ``(all_ids, prompt_len)`` where ``all_ids`` is the complete prompt
        followed by the generated ids, so ``all_ids[prompt_len:]`` is always
        exactly the completion.
    """
    model.eval()
    prompt_ids = tokenizer.encode(prompt_text).ids

    placeholder_pat = re.compile(r"__\w+_\d+__")

    # Ids of placeholders that appear in the prompt; only those that encode
    # to a single token can be whitelisted.
    allowed_ph_ids = set()
    for ph in set(placeholder_pat.findall(prompt_text)):
        enc_ids = tokenizer.encode(ph).ids
        if len(enc_ids) == 1:
            allowed_ph_ids.add(enc_ids[0])

    # Every other placeholder token in the vocabulary is forbidden. This does
    # not change between steps, so it is computed once up front.
    forbid = sorted(
        tid
        for tok, tid in tokenizer.get_vocab().items()
        if placeholder_pat.fullmatch(tok) and tid not in allowed_ph_ids
    )
    forbid_idx = torch.tensor(forbid, device=DEVICE, dtype=torch.long) if forbid else None

    semi_id = tokenizer.encode(";").ids[0]

    # The full sequence is kept separately from the model input window, so
    # clipping the window to MAX_LEN never drops prompt tokens from the
    # returned ids and the returned prompt length stays valid.
    sequence = list(prompt_ids)
    for _ in range(max_new_tokens):
        window = sequence[-MAX_LEN:]
        input_ids = torch.tensor([window], dtype=torch.long, device=DEVICE)
        next_logits = model(input_ids)[:, -1, :].clone()

        if forbid_idx is not None:
            next_logits[:, forbid_idx] = -1e9

        if TEMPERATURE and TEMPERATURE > 0:
            probs = torch.softmax(next_logits / TEMPERATURE, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
        else:
            next_id = torch.argmax(next_logits, dim=-1, keepdim=True)

        tok = int(next_id.item())
        sequence.append(tok)

        # Stop at semicolon, then append </SQL> for consistency.
        if tok == semi_id:
            sequence.append(SQL_END_ID)
            break

        if tok == SQL_END_ID:
            break

    return sequence, len(prompt_ids)

def extract_sql(completion_text: str) -> str:
    """Pull the first ``SELECT ... ;`` statement out of a model completion.

    Text after the SQL_END sentinel is discarded first. Returns ``""`` when
    no SELECT keyword or no terminating semicolon is found.
    """
    # split() hands back the whole text when the sentinel is absent.
    body = completion_text.split(SQL_END, 1)[0]

    start = re.search(r"\bSELECT\b", body, flags=re.IGNORECASE)
    if start is None:
        return ""

    statement, semicolon, _rest = body[start.start():].partition(";")
    if not semicolon:
        return ""
    return (statement + semicolon).strip()

# Statement keywords that must never appear in generated SQL.
DISALLOWED = re.compile(r"\b(INSERT|UPDATE|DELETE|MERGE|DROP|ALTER|CREATE|EXEC|GRANT|REVOKE|TRUNCATE)\b", re.IGNORECASE)
# Table names of the schema; at least one must be referenced.
KNOWN_TABLES = {"Patients","Visits","Departments","Providers","Diagnoses"}

def validate_sql(sql: str, id_map: Dict[str,str]) -> Tuple[bool, List[str]]:
    """Check a placeholder-form SQL statement against the output rules.

    Args:
        sql: Candidate SQL still containing ``__KIND_n__`` placeholders.
        id_map: Placeholder-to-value map extracted from the question.

    Returns:
        ``(ok, errors)``: ``ok`` is True only when ``errors`` is empty.
        Errors are reported in a deterministic order.
    """
    errs: List[str] = []
    s = sql.strip()
    if not s:
        errs.append("No SQL extracted.")
        return False, errs

    # Shape: a single SELECT statement terminated by exactly one semicolon.
    if not re.match(r"^SELECT\b", s, flags=re.IGNORECASE):
        errs.append("Must start with SELECT.")
    if not s.endswith(";"):
        errs.append("Must end with ';'.")
    if ";" in s[:-1]:
        errs.append("Must contain exactly one statement ending with ';'.")
    if DISALLOWED.search(s):
        errs.append("Contains disallowed keyword.")

    # Dangling comparisons such as "x = ;" or "x = AND ...".
    if re.search(r"=\s*;", s):
        errs.append("Empty predicate value (found '= ;').")
    # \b after every keyword so identifiers that merely start with AND/OR
    # (e.g. "= Orthopedics", "= OrgID") are not reported as empty predicates.
    if re.search(r"=\s*(?:(?:AND|OR|GROUP|ORDER|HAVING)\b|$)", s, flags=re.IGNORECASE):
        errs.append("Empty predicate value (found '= AND/OR/GROUP/ORDER').")

    # Placeholder accounting; sorted so the error order does not depend on
    # set iteration order.
    used = re.findall(r"__\w+_\d+__", s)
    used_sorted = sorted(set(used))

    for ph in used_sorted:
        if ph not in id_map:
            errs.append(f"Unknown placeholder: {ph}")

    missing = sorted(set(id_map) - set(used_sorted))
    if missing:
        errs.append(f"Missing required placeholder(s): {', '.join(missing)}")

    for ph in used_sorted:
        if used.count(ph) > 3:
            errs.append(f"Placeholder repeated too many times: {ph}")

    if not any(re.search(rf"\b{t}\b", s) for t in KNOWN_TABLES):
        errs.append("No known tables found in SQL.")

    return len(errs) == 0, errs

def render_ids(sql: str, id_map: Dict[str,str]) -> str:
    """Substitute every placeholder in *sql* with its literal value from *id_map*.

    Placeholders are processed longest name first.
    """
    longest_first = sorted(id_map, key=len, reverse=True)
    rendered = sql
    for placeholder in longest_first:
        rendered = rendered.replace(placeholder, id_map[placeholder])
    return rendered

def generate_sql(question: str) -> Dict:
    """Run the full question-to-SQL pipeline and return a result record.

    Steps: vault the literals, build the prompt, generate token ids, slice
    off the prompt by token count, extract and validate the SQL, and
    substitute the literals back in only when validation passed.

    Returns:
        Dict with keys ``original``, ``clean``, ``id_map``, ``completion``,
        ``sql_placeholders``, ``valid``, ``errors`` and ``sql_final``
        (``""`` when the SQL is invalid).
    """
    clean_q, id_map = extract_placeholders(question)
    prompt = build_prompt(SCHEMA_TEXT, clean_q, id_map_to_text(id_map))

    all_ids, prompt_len = generate_ids(prompt)
    # Token-based slicing: everything after the prompt tokens is the completion.
    completion_text = tokenizer.decode(all_ids[prompt_len:], skip_special_tokens=False)

    sql_ph = extract_sql(completion_text)
    ok, errs = validate_sql(sql_ph, id_map)

    if ok:
        sql_final = render_ids(sql_ph, id_map)
    else:
        sql_final = ""

    return {
        "original": question,
        "clean": clean_q,
        "id_map": id_map,
        "completion": completion_text,
        "sql_placeholders": sql_ph,
        "valid": ok,
        "errors": errs,
        "sql_final": sql_final,
    }

# Smoke test: run three hand-written questions through the full pipeline and
# print either the final SQL or the validation errors.
tests = [
    "How many visits did patient 5432 have in department 25?",
    "What is the total charge for patient 7890?",
    "Show monthly visit counts in 2023",
]

for i, q in enumerate(tests, 1):
    r = generate_sql(q)
    print(f"\n[Q{i}] {q}")
    print("Clean:", r["clean"])
    print("ID Map:", r["id_map"])
    print("SQL (placeholders):", r["sql_placeholders"])
    print("Valid:", r["valid"])
    if not r["valid"]:
        print("Errors:", r["errors"])
        print("Completion preview:", r["completion"][:220])
    else:
        print("SQL (final):", r["sql_final"])



[Q1] How many visits did patient 5432 have in department 25?
Clean: How many visits did patient __ID_1__ have in department __ID_2__?
ID Map: {'__ID_1__': '5432', '__ID_2__': '25'}
SQL (placeholders): SELECT COUNT (*) FROM Visits WHERE PatientID = __ID_1__ AND DepartmentID = __ID_2__ ;
Valid: True
SQL (final): SELECT COUNT (*) FROM Visits WHERE PatientID = 5432 AND DepartmentID = 25 ;

[Q2] What is the total charge for patient 7890?
Clean: What is the total charge for patient __ID_1__?
ID Map: {'__ID_1__': '7890'}
SQL (placeholders): SELECT SUM ( TotalCharge ) FROM Visits WHERE PatientID = __ID_1__ ;
Valid: True
SQL (final): SELECT SUM ( TotalCharge ) FROM Visits WHERE PatientID = 7890 ;

[Q3] Show monthly visit counts in 2023
Clean: Show monthly visit counts in __YEAR_1__
ID Map: {'__YEAR_1__': '2023'}
SQL (placeholders): SELECT MONTH ( VisitDate ) AS [ Month ], COUNT (*) AS VisitCount FROM Visits WHERE YEAR ( VisitDate ) = __YEAR_1__ GROUP BY MONTH ( VisitDate ) ORDER BY [ Month ];


In [12]:
\
# 10) Pass-rate evaluation (synthetic)
def sample_questions(n: int = 100) -> List[str]:
    """Return *n* random questions rendered from QUESTION_TEMPLATES."""
    def one_question() -> str:
        # RNG order: template choice, patient id, department id, year.
        template = random.choice(QUESTION_TEMPLATES)[0]
        fills = {
            "PID": str(rand_id(1000, 9999)),
            "DID": str(rand_id(1, 99)),
            "YEAR": str(rand_year()),
        }
        return template.format(**fills)

    return [one_question() for _ in range(n)]

def eval_pass_rate(n: int = 100):
    """Generate SQL for *n* synthetic questions and print the validation pass rate.

    Also prints up to ten of the most frequent first-error messages. With
    ``n == 0`` a 0.0% rate is reported instead of dividing by zero.

    Args:
        n: Number of synthetic questions to evaluate.
    """
    qs = sample_questions(n)
    passed = 0
    reasons = {}
    for q in qs:
        r = generate_sql(q)
        if r["valid"]:
            passed += 1
        else:
            # Bucket failures by their first reported error.
            key = r["errors"][0] if r["errors"] else "Unknown"
            reasons[key] = reasons.get(key, 0) + 1

    rate = passed / n if n > 0 else 0.0
    print(f"Pass rate: {passed}/{n} = {rate:.1%}")
    for k, v in sorted(reasons.items(), key=lambda kv: kv[1], reverse=True)[:10]:
        print(f"- {k}: {v}")

# Run the synthetic pass-rate check on 100 freshly sampled questions.
eval_pass_rate(100)


Pass rate: 100/100 = 100.0%


In [15]:
# # === Interactive Q&A cell ===
# # Ask any question; it will print the cleaned question, placeholders, and final SQL.

# def ask(q: str):
#     r = generate_sql(q)
#     print("\nQUESTION:", r["original"])
#     print("CLEAN:", r["clean"])
#     print("ID MAP:", r["id_map"])
#     print("SQL (placeholders):", r["sql_placeholders"])
#     print("VALID:", r["valid"])
#     if r["valid"]:
#         print("SQL (final):", r["sql_final"])
#     else:
#         print("ERRORS:", r["errors"])
#         print("COMPLETION (preview):", r["completion"][:300])

# while True:
#     q = input("\nAsk a question (blank to stop): ").strip()
#     if not q:
#         break
#     ask(q)


In [14]:
# === Compact the trained model (fp16 + optional int8 dynamic quantization) ===
# fp16: smaller + fast on GPU (use model.half()).
# int8 dynamic: much smaller + faster on CPU (Linear layers quantized).

import os, torch
from pathlib import Path

# Output directory for the compacted (fp16 / int8) model files.
OUT_DIR = "/content/compact_sql_agent"
os.makedirs(OUT_DIR, exist_ok=True)

def file_mb(p):
    """Return the size of the file at path *p* in mebibytes (bytes / 1024**2)."""
    size_in_bytes = Path(p).stat().st_size
    bytes_per_mib = 1024 ** 2
    return size_in_bytes / bytes_per_mib

# nn.Module.to()/.half()/.float() modify the module in place and return it.
# Converting the live `model` would round its fp32 weights to fp16 (so the
# later "fp32" copy and the int8 model would be built from degraded weights)
# and leave the global model on CPU while DEVICE still points at the GPU.
# Every export therefore works on its own deep copy.
import copy

# 1) Save FP16 checkpoint (best for GPU inference)
fp16_path = os.path.join(OUT_DIR, "model_fp16.pt")
model_fp16 = copy.deepcopy(model).to("cpu").eval().half()
torch.save({"model_state_dict": model_fp16.state_dict(), "vocab_size": tokenizer.get_vocab_size()}, fp16_path)
print("Saved FP16:", fp16_path, f"({file_mb(fp16_path):.2f} MB)")

# 2) Save INT8 dynamic-quantized checkpoint (best for CPU inference)
# Note: runs on CPU; do NOT move this quantized model to CUDA.
from torch.ao.quantization import quantize_dynamic

int8_path = os.path.join(OUT_DIR, "model_int8_dynamic.pt")
# Quantize from an untouched full-precision copy of the trained weights.
model_fp32 = copy.deepcopy(model).to("cpu").eval().float()
model_int8 = quantize_dynamic(model_fp32, {torch.nn.Linear}, dtype=torch.qint8)
torch.save({"model_state_dict": model_int8.state_dict(), "vocab_size": tokenizer.get_vocab_size()}, int8_path)
print("Saved INT8 dynamic:", int8_path, f"({file_mb(int8_path):.2f} MB)")

print("\nHow to use:")
print("- GPU: load fp16 state_dict into GPTSmall(...).half().to('cuda')")
print("- CPU: load int8 state_dict into quantized model (keep on CPU)")


Saved FP16: /content/compact_sql_agent/model_fp16.pt (28.83 MB)
Saved INT8 dynamic: /content/compact_sql_agent/model_int8_dynamic.pt (29.56 MB)

How to use:
- GPU: load fp16 state_dict into GPTSmall(...).half().to('cuda')
- CPU: load int8 state_dict into quantized model (keep on CPU)


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  model_int8 = quantize_dynamic(model_fp32, {torch.nn.Linear}, dtype=torch.qint8)
