In [1]:
# Cell 1 â€” Imports & Config
import json
import torch
import torch.nn as nn
import sys
sys.path.append("..")

from torch.utils.data import Dataset, DataLoader
from models.sql_transformer import SQLTransformer
from src.utils import (
    tokens_to_ids,
    pad_sequence,
    create_attention_mask,
    set_seed,
    get_device
)
from src.vocab import PAD, TOKEN2ID



In [10]:
# =====================
# CONFIG
# =====================
DATASET_PATH = "../data/sql_ast/phase1_simple_select.json"
CHECKPOINT_PATH = "checkpoints/phase1_model.pt"

EPOCHS = 20
BATCH_SIZE = 16
LR = 1e-3
MAX_LEN = 10


In [11]:
# Cell 1 â€” Imports & Config
class Phase1Dataset(Dataset):
    def __init__(self, path):
        with open(path, "r") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens = self.data[idx]["input_tokens"]

        ids = tokens_to_ids(tokens)
        ids = pad_sequence(ids, MAX_LEN, TOKEN2ID[PAD])
        mask = create_attention_mask(ids, TOKEN2ID[PAD])

        return (
            torch.tensor(ids, dtype=torch.long),
            torch.tensor(mask, dtype=torch.long)
        )


In [12]:
# ðŸ”¹ Cell 3 â€” Metrics (Precision / Recall / F1)
def compute_prf(preds, labels, pad_id):
    tp = fp = fn = 0

    for p_seq, l_seq in zip(preds, labels):
        for p, l in zip(p_seq, l_seq):
            if l == pad_id:
                continue
            if p == l:
                tp += 1
            else:
                fp += 1
                fn += 1

    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)

    return precision, recall, f1


In [13]:
# ðŸ”¹ Cell 4 â€” Setup (Device, DataLoader, Model)
set_seed(42)
device = get_device()

dataset = Phase1Dataset(DATASET_PATH)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

model = SQLTransformer().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=TOKEN2ID[PAD])


In [14]:
# ðŸ”¹ Cell 5 â€” Training Loop (Notebook-Safe)
best_f1 = 0.0

for epoch in range(1, EPOCHS + 1):
    model.train()

    total_loss = 0.0
    all_preds = []
    all_labels = []

    for input_ids, attention_mask in loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        optimizer.zero_grad()

        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Teacher forcing
        logits = logits[:, :-1, :]
        labels = input_ids[:, 1:]

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1)
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = torch.argmax(logits, dim=-1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    precision, recall, f1 = compute_prf(
        all_preds,
        all_labels,
        pad_id=TOKEN2ID[PAD]
    )

    avg_loss = total_loss / len(loader)

    print(
        f"Epoch {epoch:02d} | "
        f"Loss: {avg_loss:.4f} | "
        f"P: {precision:.4f} | "
        f"R: {recall:.4f} | "
        f"F1: {f1:.4f}"
    )

    # Save best checkpoint
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print("ðŸ’¾ Checkpoint saved")


Epoch 01 | Loss: 1.9022 | P: 0.6464 | R: 0.6464 | F1: 0.6464
ðŸ’¾ Checkpoint saved
Epoch 02 | Loss: 0.2598 | P: 1.0000 | R: 1.0000 | F1: 1.0000
ðŸ’¾ Checkpoint saved
Epoch 03 | Loss: 0.0931 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 04 | Loss: 0.0553 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 05 | Loss: 0.0384 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 06 | Loss: 0.0291 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 07 | Loss: 0.0230 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 08 | Loss: 0.0190 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 09 | Loss: 0.0162 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 10 | Loss: 0.0143 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 11 | Loss: 0.0127 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 12 | Loss: 0.0114 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 13 | Loss: 0.0105 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 14 | Loss: 0.0096 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 15 | Loss: 0.0090 | P: 1.0000 | R: 1.0000 | F1: 1.0000
Epoch 16 | Loss: 0.0084 | P: 1.0000 | R: 

In [15]:
# ðŸ”¹ Cell 6 â€” Load Checkpoint Later (for Phase-2)
model = SQLTransformer().to(device)
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()


SQLTransformer(
  (embedding): Embedding(48, 128, padding_idx=0)
  (pos_embedding): Embedding(512, 128)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=256, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc_out): Linear(in_features=128, out_features=48, bias=True)
)