In [1]:
# ðŸ”¹ Cell 1 â€” Imports & Config (Phase-3)
import json
import torch
import torch.nn as nn
import sys
sys.path.append("..")
from torch.utils.data import Dataset, DataLoader

from src.utils import (
    set_seed,
    get_device,
    tokens_to_ids,
    pad_sequence,
    create_attention_mask
)

from src.vocab import PAD, TOKEN2ID
from models.sql_transformer import SQLTransformer

In [6]:
# ðŸ”¹ Cell 2 â€” Phase-3 Config
# =====================
# CONFIG â€” PHASE 3
# =====================
PHASE1_CKPT = "checkpoints/phase1_model.pt"
PHASE2_CKPT = "checkpoints/phase2_model.pt"
PHASE3_CKPT = "checkpoints/phase3_model.pt"

DATASET_PATH = "../data/sql_ast/phase3_groupby_having.json"

EPOCHS = 25
BATCH_SIZE = 16
LR = 3e-4           # ðŸ”½ slightly lower for stability
MAX_LEN = 16        # ðŸ”¼ GROUP BY + HAVING needs longer seq

In [7]:
# ðŸ”¹ Cell 3 â€” Phase-3 Dataset
class Phase3Dataset(Dataset):
    def __init__(self, path):
        with open(path, "r") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens = self.data[idx]["input_tokens"]

        ids = tokens_to_ids(tokens)
        ids = pad_sequence(ids, MAX_LEN, TOKEN2ID[PAD])
        mask = create_attention_mask(ids, TOKEN2ID[PAD])

        return (
            torch.tensor(ids, dtype=torch.long),
            torch.tensor(mask, dtype=torch.long)
        )

In [8]:
# ðŸ”¹ Cell 4 â€” Precision / Recall / F1
def compute_prf(preds, labels, pad_id):
    tp = fp = fn = 0

    for p_seq, l_seq in zip(preds, labels):
        for p, l in zip(p_seq, l_seq):
            if l == pad_id:
                continue
            if p == l:
                tp += 1
            else:
                fp += 1
                fn += 1

    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)

    return precision, recall, f1

In [9]:
# ðŸ”¹ Cell 5 â€” Load Dataset & Model (ðŸ”¥ chaining weights)
set_seed(42)
device = get_device()

dataset = Phase3Dataset(DATASET_PATH)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = SQLTransformer().to(device)

# ðŸ”¥ IMPORTANT: load Phase-2 weights
model.load_state_dict(
    torch.load(PHASE2_CKPT, map_location=device)
)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=TOKEN2ID[PAD])

print("âœ… Phase-2 weights loaded, ready for Phase-3 training")

âœ… Phase-2 weights loaded, ready for Phase-3 training


In [10]:
# ðŸ”¹ Cell 6 â€” Phase-3 Training Loop
best_f1 = 0.0

for epoch in range(1, EPOCHS + 1):
    model.train()

    total_loss = 0.0
    all_preds = []
    all_labels = []

    for input_ids, attention_mask in loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        optimizer.zero_grad()

        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Teacher forcing
        logits = logits[:, :-1, :]
        labels = input_ids[:, 1:]

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    precision, recall, f1 = compute_prf(
        all_preds,
        all_labels,
        pad_id=TOKEN2ID[PAD]
    )

    avg_loss = total_loss / len(loader)

    print(
        f"Epoch {epoch:02d} | "
        f"Loss: {avg_loss:.4f} | "
        f"P: {precision:.4f} | "
        f"R: {recall:.4f} | "
        f"F1: {f1:.4f}"
    )

    # ðŸ’¾ Save best checkpoint
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), PHASE3_CKPT)
        print("ðŸ’¾ Phase-3 checkpoint saved")

Epoch 01 | Loss: 0.3501 | P: 0.9326 | R: 0.9326 | F1: 0.9326
ðŸ’¾ Phase-3 checkpoint saved
Epoch 02 | Loss: 0.0104 | P: 1.0000 | R: 1.0000 | F1: 1.0000
ðŸ’¾ Phase-3 checkpoint saved
Epoch 03 | Loss: 0.0054 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 04 | Loss: 0.0035 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 05 | Loss: 0.0025 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 06 | Loss: 0.0019 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 07 | Loss: 0.0015 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 08 | Loss: 0.0012 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 09 | Loss: 0.0010 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 10 | Loss: 0.0008 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 11 | Loss: 0.0007 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 12 | Loss: 0.0006 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 13 | Loss: 0.0005 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 14 | Loss: 0.0004 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 15 | Loss: 0.0004 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 16 | Loss: 0.0003 |