# Healthcare SQL Agent - Model-First MVP
## End-to-End Training and Inference Pipeline

This notebook implements a complete pipeline:
1. Generate synthetic T-SQL training dataset
2. Train a BPE tokenizer
3. Train a tiny decoder-only model from scratch
4. Run inference with validation
5. Demo with 3 example questions

## 1. Install Dependencies

In [None]:
!pip install -q torch transformers tokenizers datasets tqdm jsonlines safetensors accelerate

## 2. Imports and Setup

In [None]:
import json
import random
import re
import os
from pathlib import Path
from typing import List, Dict, Tuple
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm import tqdm
import math

# Set random seeds
random.seed(42)
torch.manual_seed(42)

# Create directories
os.makedirs('data', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('artifacts/tokenizer', exist_ok=True)

print("Setup complete!")

## 3. Example Schema Definition

In [None]:
# Define example healthcare schema
EXAMPLE_SCHEMA = {
    "schema_name": "healthcare_analytics",
    "tables": [
        {
            "name": "Patients",
            "columns": [
                {"name": "PatientID", "type": "INT", "pk": True},
                {"name": "FirstName", "type": "VARCHAR(50)"},
                {"name": "LastName", "type": "VARCHAR(50)"},
                {"name": "DateOfBirth", "type": "DATE"},
                {"name": "Gender", "type": "VARCHAR(10)"},
                {"name": "InsuranceProvider", "type": "VARCHAR(100)"}
            ]
        },
        {
            "name": "Visits",
            "columns": [
                {"name": "VisitID", "type": "INT", "pk": True},
                {"name": "PatientID", "type": "INT", "fk": "Patients.PatientID"},
                {"name": "VisitDate", "type": "DATE"},
                {"name": "DepartmentID", "type": "INT"},
                {"name": "ProviderID", "type": "INT"},
                {"name": "VisitType", "type": "VARCHAR(50)"},
                {"name": "TotalCharge", "type": "DECIMAL(10,2)"}
            ]
        },
        {
            "name": "Departments",
            "columns": [
                {"name": "DepartmentID", "type": "INT", "pk": True},
                {"name": "DepartmentName", "type": "VARCHAR(100)"},
                {"name": "Location", "type": "VARCHAR(100)"}
            ]
        },
        {
            "name": "Providers",
            "columns": [
                {"name": "ProviderID", "type": "INT", "pk": True},
                {"name": "ProviderName", "type": "VARCHAR(100)"},
                {"name": "Specialty", "type": "VARCHAR(100)"},
                {"name": "DepartmentID", "type": "INT", "fk": "Departments.DepartmentID"}
            ]
        },
        {
            "name": "Diagnoses",
            "columns": [
                {"name": "DiagnosisID", "type": "INT", "pk": True},
                {"name": "VisitID", "type": "INT", "fk": "Visits.VisitID"},
                {"name": "ICDCode", "type": "VARCHAR(20)"},
                {"name": "DiagnosisDescription", "type": "VARCHAR(255)"}
            ]
        }
    ]
}

# Save schema
with open('data/example_schema.json', 'w') as f:
    json.dump(EXAMPLE_SCHEMA, f, indent=2)

print("Schema created with", len(EXAMPLE_SCHEMA['tables']), "tables")
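The `fk` entries in `EXAMPLE_SCHEMA` are not read anywhere else in this notebook. As a hedged sketch of one way they could be used (generating JOIN templates, or checking that a generated JOIN follows a declared key), the hypothetical helpers `fk_edges` and `join_clause` below turn them into join conditions. Only declared keys are found, so `Visits.DepartmentID` and `Visits.ProviderID`, which carry no `fk` entry, would not produce a join.

```python
from typing import Dict, List, Tuple

def fk_edges(schema: Dict) -> List[Tuple[str, str]]:
    """Return (child column, parent column) pairs from the 'fk' entries."""
    edges = []
    for table in schema["tables"]:
        for col in table["columns"]:
            if "fk" in col:
                edges.append((f"{table['name']}.{col['name']}", col["fk"]))
    return edges

def join_clause(edge: Tuple[str, str]) -> str:
    """Render one foreign-key edge as a T-SQL JOIN on the parent table."""
    child, parent = edge
    parent_table = parent.split(".")[0]
    return f"JOIN {parent_table} ON {child} = {parent}"

# Usage with a two-table excerpt in the same schema format as above
mini_schema = {
    "tables": [
        {"name": "Patients", "columns": [{"name": "PatientID", "type": "INT", "pk": True}]},
        {"name": "Visits", "columns": [
            {"name": "VisitID", "type": "INT", "pk": True},
            {"name": "PatientID", "type": "INT", "fk": "Patients.PatientID"},
        ]},
    ]
}
edges = fk_edges(mini_schema)
print(edges)
print(join_clause(edges[0]))
```

Run against `EXAMPLE_SCHEMA`, `fk_edges` would list the three declared keys: Visits to Patients, Providers to Departments, and Diagnoses to Visits.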

## 4. Dataset Generator

In [None]:
def generate_schema_compact_text(schema: Dict) -> str:
    """Generate compact schema text for prompts."""
    lines = []
    for table in schema['tables']:
        cols = ', '.join([c['name'] for c in table['columns']])
        lines.append(f"{table['name']}({cols})")
    return '; '.join(lines)

def extract_placeholders(question: str) -> Tuple[str, Dict[str, str]]:
    """Replace dates and numeric IDs in a question with placeholders.

    Returns the cleaned question and a map {placeholder: original value}.
    Dates are extracted first so the year inside '2023-01-15' is not captured
    as a numeric ID; a value that appears twice reuses its placeholder.
    """
    id_map: Dict[str, str] = {}
    value_to_placeholder: Dict[str, str] = {}
    counters = {"ID": 0, "DATE": 0}

    def make_replacer(kind: str):
        def replace(match):
            value = match.group(0)
            if value not in value_to_placeholder:
                counters[kind] += 1
                placeholder = f"__{kind}_{counters[kind]}__"
                value_to_placeholder[value] = placeholder
                id_map[placeholder] = value
            return value_to_placeholder[value]
        return replace

    # 1) ISO dates (YYYY-MM-DD)
    question_clean = re.sub(r'\b\d{4}-\d{2}-\d{2}\b', make_replacer("DATE"), question)

    # 2) Every remaining standalone integer (simple heuristic for MVP): patient,
    #    department and provider IDs as well as years. Digits inside existing
    #    placeholders are skipped because underscores leave no word boundary.
    question_clean = re.sub(r'\b\d+\b', make_replacer("ID"), question_clean)

    return question_clean, id_map

def generate_dataset_samples(schema: Dict, num_samples: int) -> List[Dict]:
    """Generate synthetic SQL training samples."""
    samples = []
    schema_text = generate_schema_compact_text(schema)
    
    templates = [
        # COUNT templates
        {
            "question": "How many visits did patient {patient_id} have in department {dept_id}?",
            "sql": "SELECT COUNT(*) FROM Visits WHERE PatientID = {patient_id_ph} AND DepartmentID = {dept_id_ph};"
        },
        {
            "question": "Count total visits in {year}",
            "sql": "SELECT COUNT(*) FROM Visits WHERE YEAR(VisitDate) = {year_ph};"
        },
        # SUM templates
        {
            "question": "What is the total charge for patient {patient_id}?",
            "sql": "SELECT SUM(TotalCharge) FROM Visits WHERE PatientID = {patient_id_ph};"
        },
        {
            "question": "Total charges for department {dept_id} in {year}",
            "sql": "SELECT SUM(TotalCharge) FROM Visits WHERE DepartmentID = {dept_id_ph} AND YEAR(VisitDate) = {year_ph};"
        },
        # AVG templates
        {
            "question": "What is the average charge per visit for provider {provider_id}?",
            "sql": "SELECT AVG(TotalCharge) FROM Visits WHERE ProviderID = {provider_id_ph};"
        },
        # GROUP BY templates
        {
            "question": "Show visit counts by department for patient {patient_id}",
            "sql": "SELECT DepartmentID, COUNT(*) FROM Visits WHERE PatientID = {patient_id_ph} GROUP BY DepartmentID;"
        },
        {
            "question": "Show monthly visit counts in {year}",
            "sql": "SELECT MONTH(VisitDate) AS Month, COUNT(*) FROM Visits WHERE YEAR(VisitDate) = {year_ph} GROUP BY MONTH(VisitDate);"
        },
        {
            "question": "Total charges by provider in department {dept_id}",
            "sql": "SELECT ProviderID, SUM(TotalCharge) FROM Visits WHERE DepartmentID = {dept_id_ph} GROUP BY ProviderID;"
        },
        # JOIN templates
        {
            "question": "List all visits with patient names for patient {patient_id}",
            "sql": "SELECT V.VisitID, P.FirstName, P.LastName, V.VisitDate FROM Visits V JOIN Patients P ON V.PatientID = P.PatientID WHERE V.PatientID = {patient_id_ph};"
        },
        {
            "question": "Show visit counts by department name",
            "sql": "SELECT D.DepartmentName, COUNT(V.VisitID) FROM Visits V JOIN Departments D ON V.DepartmentID = D.DepartmentID GROUP BY D.DepartmentName;"
        },
    ]
    
    for i in range(num_samples):
        template = random.choice(templates)
        
        # Generate random IDs
        patient_id = random.randint(1000, 9999)
        dept_id = random.randint(10, 99)
        provider_id = random.randint(100, 999)
        year = random.randint(2020, 2024)
        
        # Fill the question template with concrete values
        values = {
            "patient_id": patient_id,
            "dept_id": dept_id,
            "provider_id": provider_id,
            "year": year,
        }
        question = template['question'].format(**values)

        # Extract placeholders from the question
        question_clean, id_map = extract_placeholders(question)

        # Build SQL: fill it with the same concrete values, then swap each value
        # for the placeholder the question received, so SQL and ID_MAP always agree
        sql = template['sql'].format(**{f"{name}_ph": value for name, value in values.items()})
        for placeholder, value in id_map.items():
            sql = re.sub(rf'\b{re.escape(value)}\b', placeholder, sql)
        
        # Create ID map text
        id_map_text = ', '.join([f"{k}={v}" for k, v in id_map.items()]) if id_map else "None"
        
        # Create training record
        record = {
            "schema_id": schema['schema_name'],
            "schema_text": schema_text,
            "question": question_clean,
            "id_map": id_map_text,
            "sql": sql
        }
        
        samples.append(record)
    
    return samples

# Generate datasets
print("Generating training dataset...")
train_samples = generate_dataset_samples(EXAMPLE_SCHEMA, 5000)

print("Generating validation dataset...")
val_samples = generate_dataset_samples(EXAMPLE_SCHEMA, 200)

# Save datasets
with open('data/train.jsonl', 'w') as f:
    for sample in train_samples:
        f.write(json.dumps(sample) + '\n')

with open('data/val.jsonl', 'w') as f:
    for sample in val_samples:
        f.write(json.dumps(sample) + '\n')

print(f"Generated {len(train_samples)} training samples")
print(f"Generated {len(val_samples)} validation samples")
print("\nExample sample:")
print(json.dumps(train_samples[0], indent=2))
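Because every training record pairs a question, an `ID_MAP` string and SQL that must refer to the same placeholders, a cheap consistency check over the generated records can catch template mistakes before they reach training. The hypothetical helpers `parse_id_map` and `check_record` below are a sketch that assumes the `k=v, k=v` / `None` format written by the generator above.

```python
import re
from typing import Dict, List

PLACEHOLDER_RE = re.compile(r"__(?:ID|DATE)_\d+__")

def parse_id_map(text: str) -> Dict[str, str]:
    """Parse the 'k=v, k=v' (or 'None') ID_MAP text back into a dict."""
    if text.strip() == "None":
        return {}
    pairs = (item.split("=", 1) for item in text.split(", "))
    return {k.strip(): v.strip() for k, v in pairs}

def check_record(record: Dict[str, str]) -> List[str]:
    """Return problems where SQL placeholders and the ID map disagree."""
    id_map = parse_id_map(record["id_map"])
    in_sql = set(PLACEHOLDER_RE.findall(record["sql"]))
    in_question = set(PLACEHOLDER_RE.findall(record["question"]))
    problems = [f"SQL uses undefined {ph}" for ph in sorted(in_sql - set(id_map))]
    problems += [f"question placeholder {ph} unused in SQL"
                 for ph in sorted(in_question - in_sql)]
    return problems

good = {
    "question": "How many visits did patient __ID_1__ have in department __ID_2__?",
    "id_map": "__ID_1__=5432, __ID_2__=25",
    "sql": "SELECT COUNT(*) FROM Visits WHERE PatientID = __ID_1__ AND DepartmentID = __ID_2__;",
}
# The question kept a literal ID, but the SQL refers to a placeholder
bad = {
    "question": "What is the average charge per visit for provider 412?",
    "id_map": "None",
    "sql": "SELECT AVG(TotalCharge) FROM Visits WHERE ProviderID = __ID_1__;",
}
print(check_record(good), check_record(bad))
```

In the notebook, `[r for r in train_samples if check_record(r)]` would list any generated records whose placeholders do not line up.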

## 5. Train BPE Tokenizer

In [None]:
def format_training_text(sample: Dict) -> str:
    """Format sample into training text with markers."""
    return f"""SCHEMA: {sample['schema_text']}
QUESTION: {sample['question']}
ID_MAP: {sample['id_map']}
SQL: {sample['sql']}"""

# Prepare training text for tokenizer
print("Preparing tokenizer training data...")
tokenizer_training_file = 'data/tokenizer_train.txt'
with open(tokenizer_training_file, 'w') as f:
    for sample in train_samples:
        f.write(format_training_text(sample) + '\n\n')

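# Train tokenizer
# Byte-level BPE: the pre-tokenizer keeps spaces and newlines inside the tokens
# and the matching decoder restores them, so decode(encode(text)) == text.
# A Whitespace pre-tokenizer discards that information, so decoded text comes
# back with extra spaces (e.g. "SQL:" becomes "SQL :").
from tokenizers import pre_tokenizers, decoders

print("Training BPE tokenizer...")
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()

trainer = BpeTrainer(
    vocab_size=8000,  # upper bound; the learned vocabulary may be smaller
    min_frequency=2,
    special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"],  # ids 0-3, in this order
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

tokenizer.train([tokenizer_training_file], trainer)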

# Save tokenizer
tokenizer.save('artifacts/tokenizer/tokenizer.json')
print("Tokenizer trained and saved!")
print(f"Vocab size: {tokenizer.get_vocab_size()}")

# Test tokenizer
test_text = "SELECT COUNT(*) FROM Visits WHERE PatientID = __ID_1__;"
encoded = tokenizer.encode(test_text)
print(f"\nTest encoding: {test_text}")
print(f"Tokens: {encoded.tokens}")
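Whether generated SQL can be recovered exactly depends on the pre-tokenizer and decoder, not only on the BPE merges. The stdlib-only illustration below mimics the split rule documented for the `Whitespace` pre-tokenizer (`\w+|[^\w\s]+`) and a decode that re-joins pieces with single spaces, and contrasts it with a byte round trip. It does not call the `tokenizers` library; with a real BPE model and no decoder, subword pieces inside a word can also come back space-separated.

```python
import re

def whitespace_pretokenize(text: str) -> list:
    # Split rule of the Whitespace pre-tokenizer: runs of word chars or of punctuation
    return re.findall(r"\w+|[^\w\s]+", text)

text = "SQL: SELECT COUNT(*) FROM Visits;"

# Re-joining the pieces with single spaces loses the original spacing
lossy = " ".join(whitespace_pretokenize(text))

# Byte-level view: every character survives as bytes, so decoding is exact
lossless = bytes(list(text.encode("utf-8"))).decode("utf-8")

print(lossy)
print(lossless)
```

The lossy version no longer contains the literal marker `SQL:`, so any post-processing that searches the decoded text for it would miss.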

## 6. Define Decoder-Only Transformer Model

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(self, x, mask=None):
        # Self-attention with residual
        attn_out, _ = self.attention(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + attn_out)
        # FFN with residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512, n_layers: int = 8, 
                 n_heads: int = 8, max_seq_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        ])
        
        # Output
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)
        
    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        
        # Create causal mask
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1)
        causal_mask = causal_mask.masked_fill(causal_mask == 1, float('-inf'))
        
        # Embeddings
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(x)
        
        # Transformer blocks
        for block in self.blocks:
            x = block(x, mask=causal_mask)
        
        x = self.norm(x)
        logits = self.output(x)
        
        return logits

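# Model config (small for MVP: about 34M params at the full 8,000-token vocab,
# fewer if BPE learns a smaller vocabulary; the exact count is printed below)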
MODEL_CONFIG = {
    'vocab_size': tokenizer.get_vocab_size(),
    'd_model': 512,
    'n_layers': 8,
    'n_heads': 8,
    'max_seq_len': 512,
    'dropout': 0.1
}

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DecoderOnlyTransformer(**MODEL_CONFIG).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {total_params:,} parameters ({total_params/1e6:.1f}M)")
print(f"Device: {device}")
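The printed parameter count can be sanity-checked by hand. The hypothetical helper `estimate_params` below counts weights and biases for the layers defined above: token and position embeddings; per block an `nn.MultiheadAttention` with its default biases, two `LayerNorm`s and a 4x FFN; then the final `LayerNorm` and the output projection. It assumes no weight tying between the embedding and output layers, which matches the code.

```python
def estimate_params(vocab_size: int, d_model: int, n_layers: int,
                    max_seq_len: int, ffn_mult: int = 4) -> int:
    """Count parameters of DecoderOnlyTransformer as defined above (no weight tying)."""
    d, h = d_model, ffn_mult * d_model
    # nn.MultiheadAttention: packed Q/K/V projection + output projection, with biases
    attention = (3 * d * d + 3 * d) + (d * d + d)
    # Two LayerNorms per block, each with weight and bias
    norms = 2 * (2 * d)
    # FFN: Linear(d, h) then Linear(h, d), with biases
    ffn = (d * h + h) + (h * d + d)
    per_block = attention + norms + ffn
    embeddings = vocab_size * d + max_seq_len * d   # token + learned position tables
    head = 2 * d + (d * vocab_size + vocab_size)    # final LayerNorm + output Linear
    return embeddings + n_layers * per_block + head

for vocab in (2000, 8000):
    print(vocab, f"{estimate_params(vocab, 512, 8, 512):,}")
```

At `d_model=512` and 8 layers the blocks alone contribute about 25.2M parameters; the remainder grows with the vocabulary, reaching about 33.7M in total at 8,000 tokens.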

## 7. Dataset and DataLoader

In [None]:
class SQLDataset(Dataset):
    def __init__(self, jsonl_path: str, tokenizer: Tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []
        
        # Load samples
        with open(jsonl_path, 'r') as f:
            for line in f:
                self.samples.append(json.loads(line))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Format text
        text = format_training_text(sample)
        
        # Tokenize and append [EOS] so the model learns where the SQL ends
        eos_id = self.tokenizer.token_to_id("[EOS]")
        tokens = self.tokenizer.encode(text).ids + [eos_id]
        
        # Truncate or pad to a fixed length; 0 is [PAD], which the loss ignores
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
        else:
            tokens = tokens + [0] * (self.max_length - len(tokens))
        
        return torch.tensor(tokens, dtype=torch.long)

# Create datasets
train_dataset = SQLDataset('data/train.jsonl', tokenizer, max_length=MODEL_CONFIG['max_seq_len'])
val_dataset = SQLDataset('data/val.jsonl', tokenizer, max_length=MODEL_CONFIG['max_seq_len'])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Val dataset: {len(val_dataset)} samples")
print(f"Batches per epoch: {len(train_loader)}")
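The pad/truncate logic in `SQLDataset.__getitem__` can be written as a pure function, which makes its edge cases easy to test. The sketch below (hypothetical helpers `pad_or_truncate` and `padding_fraction`) optionally appends an end-of-sequence id before truncating and right-padding with `[PAD]` (id 0). `padding_fraction` is one way to judge how much of each fixed 512-token row is padding, and therefore wasted compute.

```python
from typing import List, Optional

def pad_or_truncate(ids: List[int], max_length: int, pad_id: int = 0,
                    eos_id: Optional[int] = None) -> List[int]:
    """Fixed-length sequence: optional EOS, then truncation, then right padding."""
    tokens = list(ids)
    if eos_id is not None:
        tokens.append(eos_id)          # marks where the SQL ends
    tokens = tokens[:max_length]       # over-long inputs lose their tail (and EOS)
    return tokens + [pad_id] * (max_length - len(tokens))

def padding_fraction(batch: List[List[int]], pad_id: int = 0) -> float:
    """Share of positions in a batch that are padding."""
    total = sum(len(seq) for seq in batch)
    pads = sum(tok == pad_id for seq in batch for tok in seq)
    return pads / total if total else 0.0

print(pad_or_truncate([5, 6, 7], 6, pad_id=0, eos_id=3))
print(padding_fraction([[5, 6, 0, 0], [7, 0, 0, 0]]))
```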

## 8. Training Loop

In [None]:
def train_epoch(model, train_loader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch_idx, input_ids in enumerate(pbar):
        input_ids = input_ids.to(device)
        
        # Forward pass
        logits = model(input_ids)
        
        # Shift for next-token prediction
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        
        # Calculate loss
        loss_fct = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(train_loader)

def validate(model, val_loader, device):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for input_ids in val_loader:
            input_ids = input_ids.to(device)
            
            logits = model(input_ids)
            
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            
            loss_fct = nn.CrossEntropyLoss(ignore_index=0)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            
            total_loss += loss.item()
    
    return total_loss / len(val_loader)

# Training setup
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
num_epochs = 3  # Short training for MVP

print("Starting training...\n")
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
    val_loss = validate(model, val_loader, device)
    
    print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'model_config': MODEL_CONFIG
    }
    torch.save(checkpoint, f'checkpoints/checkpoint_epoch_{epoch}.pt')
    print(f"Checkpoint saved: checkpoint_epoch_{epoch}.pt\n")

print("Training complete!")
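What the shift plus `CrossEntropyLoss(ignore_index=0)` computes can be spelled out for a single sequence in plain Python: the logits at position `t` score the token at `t + 1`, and targets equal to `[PAD]` are skipped, so the loss is the mean negative log-likelihood over the remaining targets. The helper `masked_next_token_loss` is a hypothetical illustration; returning 0.0 when no target remains is a choice made here.

```python
import math
from typing import List

def masked_next_token_loss(logits: List[List[float]], tokens: List[int],
                           pad_id: int = 0) -> float:
    """Mean cross-entropy of next-token predictions, ignoring [PAD] targets."""
    total, count = 0.0, 0
    for t in range(len(tokens) - 1):
        target = tokens[t + 1]            # logits[t] predicts tokens[t + 1]
        if target == pad_id:
            continue                      # same effect as ignore_index=pad_id
        row = logits[t]
        m = max(row)                      # log-sum-exp, shifted for stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]      # -log softmax(row)[target]
        count += 1
    return total / count if count else 0.0

# Uniform logits over 4 tokens: only the first target (2) counts, loss = log(4)
print(masked_next_token_loss([[0.0] * 4] * 4, [1, 2, 0, 0]))
```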

## 9. Inference and Generation

In [None]:
def generate_sql(model, tokenizer, schema_text, question, id_map_text, 
                 max_length=256, temperature=0.0, device='cuda'):
    """Generate SQL from a question using the trained model.

    temperature == 0 uses greedy (argmax) decoding; temperature > 0 samples
    from the softmax of the temperature-scaled logits.
    """
    model.eval()
    
    # Build the prompt in the same layout as format_training_text. It ends at
    # "SQL:" (no trailing space) so the space before the first SQL keyword is
    # tokenized the same way it was during training.
    prompt = f"""SCHEMA: {schema_text}
QUESTION: {question}
ID_MAP: {id_map_text}
SQL:"""
    
    # Encode prompt
    prompt_ids = tokenizer.encode(prompt).ids
    generated_ids = list(prompt_ids)
    eos_id = tokenizer.token_to_id("[EOS]")
    
    with torch.no_grad():
        for _ in range(max_length):
            # Keep at most max_seq_len tokens: the position embedding table
            # has no entries beyond that
            context = generated_ids[-model.max_seq_len:]
            current_ids = torch.tensor([context], dtype=torch.long, device=device)
            
            # Next-token logits at the last position
            next_token_logits = model(current_ids)[0, -1, :]
            
            if temperature > 0:
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1).item()
            else:
                next_token_id = torch.argmax(next_token_logits).item()
            
            if next_token_id == eos_id:
                break
            generated_ids.append(next_token_id)
            
            # Stop as soon as the completion contains the statement terminator
            if ';' in tokenizer.decode(generated_ids[len(prompt_ids):]):
                break
    
    # Decode only the completion (not the prompt), then keep it up to the first ';'
    sql_part = tokenizer.decode(generated_ids[len(prompt_ids):]).strip()
    if ';' in sql_part:
        sql_part = sql_part.split(';')[0] + ';'
    return sql_part

print("Inference function ready!")
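Each decoding step picks one token from the next-token distribution. Nucleus (top-p) sampling, listed under next steps, keeps only the smallest set of most-probable tokens whose cumulative probability reaches `p`, renormalizes, and samples from that set. The sketch below works on a plain probability list with the standard library; in the notebook the list would come from a softmax over `next_token_logits`. The names `top_p_filter` and `sample_top_p` are hypothetical.

```python
import random
from typing import Dict, List

def top_p_filter(probs: List[float], p: float) -> Dict[int, float]:
    """Keep the most probable token ids until their total mass reaches p, renormalized."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)                 # include the token that crosses the threshold
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}

def sample_top_p(probs: List[float], p: float, rng: random.Random) -> int:
    """Draw one token id from the nucleus."""
    nucleus = top_p_filter(probs, p)
    ids, weights = zip(*nucleus.items())
    return rng.choices(ids, weights=weights, k=1)[0]

probs = [0.5, 0.3, 0.15, 0.05]
print(top_p_filter(probs, 0.75))   # tokens 0 and 1, renormalized
print(sample_top_p(probs, 0.75, random.Random(0)))
```

For SQL generation, where one exact statement is wanted, greedy decoding is usually the safer default; top-p mainly helps when several diverse candidates are generated and then filtered by the validation gates.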

## 10. Validation Gates

In [None]:
def validate_sql(sql: str, schema: Dict, id_map: Dict[str, str]) -> Tuple[bool, List[str]]:
    """Validate generated SQL against strict rules."""
    errors = []
    
    # Rule 1: Must be exactly one statement
    if sql.count(';') != 1:
        errors.append("Must contain exactly one statement ending with ';'")
    
    # Rule 2: Must end with semicolon
    if not sql.strip().endswith(';'):
        errors.append("Must end with ';'")
    
    # Rule 3: Must be SELECT only
    sql_upper = sql.upper()
    if not sql_upper.strip().startswith('SELECT'):
        errors.append("Must start with SELECT")
    
    # Rule 4: Block DML/DDL keywords (whole words only, so an identifier such
    # as "CreatedAt" does not trigger the CREATE check)
    forbidden_keywords = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']
    for keyword in forbidden_keywords:
        if re.search(rf'\b{keyword}\b', sql_upper):
            errors.append(f"Forbidden keyword: {keyword}")
    
    # Rule 5: Must reference at least one table from the schema (basic check;
    # it does not reject unknown tables that appear alongside a known one)
    table_names = [t['name'].upper() for t in schema['tables']]
    if not any(re.search(rf'\b{name}\b', sql_upper) for name in table_names):
        errors.append("No known tables found in SQL")
    
    # Rule 6: Placeholder integrity
    placeholders_in_sql = re.findall(r'__ID_\d+__|__DATE_\d+__', sql)
    for ph in placeholders_in_sql:
        if ph not in id_map:
            errors.append(f"Unknown placeholder: {ph}")
    
    return len(errors) == 0, errors

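def reinject_ids(sql: str, id_map: Dict[str, str]) -> str:
    """Replace placeholders with their original values.

    Numeric IDs are inserted as-is; dates are wrapped in single quotes so
    T-SQL treats them as date literals rather than arithmetic (2023-01-15).
    """
    for placeholder, value in id_map.items():
        if placeholder.startswith('__DATE_'):
            value = f"'{value}'"
        sql = sql.replace(placeholder, value)
    return sql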

print("Validation functions ready!")
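Substring matching and whole-word matching behave differently for the forbidden-keyword rule, and the difference is easy to see in isolation. The sketch below compares the two on a query over a hypothetical `AuditLog` table (not part of `EXAMPLE_SCHEMA`) whose column names merely contain `CREATE` and `UPDATE`.

```python
import re
from typing import List

FORBIDDEN = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE', 'EXEC', 'EXECUTE']

def forbidden_substring(sql: str) -> List[str]:
    """Flag a keyword if it occurs anywhere, even inside an identifier."""
    upper = sql.upper()
    return [k for k in FORBIDDEN if k in upper]

def forbidden_word(sql: str) -> List[str]:
    """Flag a keyword only when it stands as a whole word."""
    upper = sql.upper()
    return [k for k in FORBIDDEN if re.search(rf"\b{k}\b", upper)]

safe_query = "SELECT CreatedAt, LastUpdatedBy FROM AuditLog;"  # hypothetical table
print(forbidden_substring(safe_query))   # false positives
print(forbidden_word(safe_query))        # nothing flagged
print(forbidden_word("DROP TABLE Visits;"))
```

Whole-word matching still catches real statements, but neither variant is a SQL parser; a read-only database login remains the stronger safeguard.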

## 11. End-to-End Demo with 3 Example Questions

In [None]:
# Load latest checkpoint
latest_checkpoint = 'checkpoints/checkpoint_epoch_3.pt'
if os.path.exists(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}\n")
else:
    print("No checkpoint found, using untrained model\n")

schema_text = generate_schema_compact_text(EXAMPLE_SCHEMA)

# Example questions
test_questions = [
    "How many visits did patient 5432 have in department 25?",
    "What is the total charge for patient 7890?",
    "Show monthly visit counts in 2023"
]

print("=" * 80)
print("GENERATING SQL FOR TEST QUESTIONS")
print("=" * 80)

passed = 0
for i, question in enumerate(test_questions, 1):
    print(f"\n[Question {i}]")
    print(f"Original: {question}")
    
    # Extract placeholders
    question_clean, id_map = extract_placeholders(question)
    print(f"Clean: {question_clean}")
    print(f"ID Map: {id_map}")
    
    # Generate SQL
    id_map_text = ', '.join([f"{k}={v}" for k, v in id_map.items()]) if id_map else "None"
    generated_sql = generate_sql(model, tokenizer, schema_text, question_clean, id_map_text, device=device)
    
    print(f"\nGenerated SQL (with placeholders):")
    print(f"  {generated_sql}")
    
    # Validate
    is_valid, errors = validate_sql(generated_sql, EXAMPLE_SCHEMA, id_map)
    
    if is_valid:
        print("✓ Validation: PASSED")
        passed += 1
        
        # Reinject IDs
        final_sql = reinject_ids(generated_sql, id_map)
        print(f"\nFinal SQL (with real IDs):")
        print(f"  {final_sql}")
    else:
        print("✗ Validation: FAILED")
        for error in errors:
            print(f"  - {error}")
    
    print("-" * 80)

print(f"\n" + "=" * 80)
print(f"RESULTS: {passed}/{len(test_questions)} questions passed validation")
print("=" * 80)

if passed >= 2:
    print("\n✓ SUCCESS: MVP acceptance criteria met (>=2/3 passed)")
else:
    print("\n✗ INCOMPLETE: Need more training or better templates")
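The demo only checks that the output passes the validation gates, not that it answers the question. A held-out evaluation (next step 5) needs a reference SQL per question. The sketch below uses a deliberately crude metric, string exact match after whitespace and case normalization (hypothetical helpers `normalize_sql` and `exact_match_rate`); it is stricter than execution accuracy, because two different queries can return the same rows.

```python
import re
from typing import List, Tuple

def normalize_sql(sql: str) -> str:
    """Canonical form for string comparison only (not for execution)."""
    s = re.sub(r"\s*([(),;*=.])\s*", r"\1", sql)  # drop spaces around punctuation
    s = re.sub(r"\s+", " ", s)                      # collapse remaining whitespace
    # Uppercasing also folds string literals, acceptable under a
    # case-insensitive collation but not in general
    return s.strip().upper()

def exact_match_rate(pairs: List[Tuple[str, str]]) -> float:
    """Share of (predicted, reference) pairs that match after normalization."""
    if not pairs:
        return 0.0
    hits = sum(normalize_sql(pred) == normalize_sql(gold) for pred, gold in pairs)
    return hits / len(pairs)

pairs = [
    ("SELECT  COUNT ( * ) FROM Visits ;", "select count(*) from Visits;"),
    ("SELECT SUM(TotalCharge) FROM Visits;", "SELECT AVG(TotalCharge) FROM Visits;"),
]
print(exact_match_rate(pairs))
```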

## 12. Summary and Next Steps

In [None]:
print("""\n
=================================================================================
MVP PIPELINE COMPLETE!
=================================================================================

What we built:
✓ Example healthcare schema with 5 tables
✓ Dataset generator with 5,000 training and 200 validation samples
✓ BPE tokenizer trained on the formatted training text (target vocab size: 8,000)
✓ Small decoder-only transformer (roughly 26-34M parameters, depending on the learned vocab size)
✓ Training loop with checkpointing
✓ Inference with prompt formatting
✓ Validation gates (SELECT-only, schema checks, placeholders)
✓ ID placeholder vault and reinjection
✓ End-to-end demo with 3 test questions

Files created:
- data/example_schema.json
- data/train.jsonl (5,000 samples)
- data/val.jsonl (200 samples)
- artifacts/tokenizer/tokenizer.json
- checkpoints/checkpoint_epoch_*.pt

Next steps to scale:
1. Increase dataset size to 200k-800k samples
2. Add more SQL template diversity (more JOIN patterns, subqueries, HAVING, etc.)
3. Scale model to ~300M parameters
4. Train for more epochs with learning rate scheduling
5. Add evaluation harness with held-out test set
6. Fine-tune hyperparameters (learning rate, batch size, etc.)
7. Implement beam search or nucleus sampling for generation
8. Add more sophisticated schema validation

=================================================================================
""")