In [1]:

# 🔹 Cell 1 — Imports & Config (Phase-4 JOIN)
import json
import torch
import torch.nn as nn
import sys
import os
sys.path.append("..")

from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from src.utils import (

    tokens_to_ids,
    pad_sequence,
    create_attention_mask,
    get_allowed_tokens
)

from src.vocab import PAD, TOKEN2ID, ID2TOKEN, UNK
from models.sql_transformer import SQLTransformer

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# ✅ Checkpoint directory for Phase-4 (JOIN) runs.
# Every other path in this notebook is relative to the notebook's working
# directory (presumably notebooks/): data is read from "../data/..." and the
# Phase-3 checkpoint from "checkpoints/...". The old value
# "notebooks/checkpoints/phase4_join" therefore created a nested
# notebooks/notebooks/checkpoints/... directory next to the real one.
CHECKPOINT_DIR = os.path.join("checkpoints", "phase4_join")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [4]:

# 🔹 Cell 2 — Phase-4 Config
# =========================
# CONFIG — PHASE 4 (JOIN)
# =========================

# Paths are relative to the notebook's working directory.
PHASE3_CKPT = "checkpoints/phase3_model.pt"   # Phase-3 weights to start from
PHASE4_CKPT = "checkpoints/phase4_model.pt"

PHASE4_PATH = "../data/sql_ast/phase4_join.json"

# JSON is UTF-8 by specification; pass the encoding explicitly so loading does
# not depend on the platform's locale default (e.g. cp1252 on Windows).
with open(PHASE4_PATH, "r", encoding="utf-8") as f:
    phase4_data = json.load(f)

print("Total Phase-4 samples:", len(phase4_data))

# Hyper-parameter notes (not used here; the values actually used are set in
# the DataLoader / optimizer / training cells):
#EPOCHS = 30
#BATCH_SIZE = 16
#LR = 3e-4

# JOIN adds: JOIN TABLE ON COL COL
#MAX_LEN = 40

Total Phase-4 samples: 2000


In [5]:
# 🔹 Cell 3 — Decoder Training Sample Builder (Phase-4)

def prepare_phase4_sample(sample):
    """
    Convert one phase4_join.json entry into a teacher-forcing (input, label) pair.

    The tokens in ``sample["input_tokens"]`` are mapped to ids. Tokens missing
    from TOKEN2ID fall back to the id of the vocabulary's UNK token.
    ``input_ids`` is the id sequence without its last element, and ``labels``
    is the sequence shifted left by one. So ``labels[t]`` is the token that
    follows ``input_ids[:t+1]``. Label positions holding the PAD id are set to
    -100, so CrossEntropyLoss(ignore_index=-100) and the metric helpers skip them.

    Args:
        sample: dict with an "input_tokens" list of token strings.

    Returns:
        dict with 1-D long tensors "input_ids" and "labels" of equal length
        (len(input_tokens) - 1; empty when there are fewer than two tokens).
    """
    # Use the UNK constant exported by src.vocab (the same one the loss uses
    # via ID2TOKEN.get(i, UNK)) instead of a hard-coded "<UNK>" literal, and
    # resolve its id once instead of once per token.
    unk_id = TOKEN2ID[UNK]
    token_ids = [TOKEN2ID.get(tok, unk_id) for tok in sample["input_tokens"]]

    # Teacher forcing: predict token t+1 from tokens 0..t.
    input_ids = torch.tensor(token_ids[:-1], dtype=torch.long)
    labels = torch.tensor(token_ids[1:], dtype=torch.long)

    # Ignore PAD targets in loss & metrics.
    labels[labels == TOKEN2ID[PAD]] = -100

    return {
        "input_ids": input_ids,
        "labels": labels
    }

In [6]:
# 🔹 Cell 4 — Phase-4 Dataset
class Phase4JoinDataset(Dataset):
    """Map-style dataset over the raw Phase-4 JOIN records.

    The raw list is stored untouched. Each record is converted into a training
    sample lazily, on access, by ``prepare_phase4_sample``.
    """

    def __init__(self, data):
        # Raw records as loaded from phase4_join.json.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        sample = prepare_phase4_sample(record)
        return sample

In [7]:
# 🔹 Cell — Phase-4 Collate Function (PAD + MASK)

def phase4_collate_fn(batch):
    """
    Right-pad a list of variable-length Phase-4 samples into batch tensors.

    Every row is padded up to the longest ``input_ids`` in the batch:
    ``input_ids`` with the PAD id, and ``labels`` with -100 (ignored by the
    loss and metrics). The result is a pair of rectangular (batch, max_len)
    long tensors, so the DataLoader can return them as one batch.
    """
    longest = max(item["input_ids"].size(0) for item in batch)
    pad_id = TOKEN2ID[PAD]
    n_rows = len(batch)

    # Start from fully padded tensors and copy each sample into its row prefix.
    inputs = torch.full((n_rows, longest), pad_id, dtype=torch.long)
    targets = torch.full((n_rows, longest), -100, dtype=torch.long)

    for row, item in enumerate(batch):
        length = item["input_ids"].size(0)
        inputs[row, :length] = item["input_ids"]
        targets[row, :length] = item["labels"]

    return {"input_ids": inputs, "labels": targets}

In [8]:
# üîπ üîπ Cell 5 ‚Äî DataLoade
BATCH_SIZE = 16

train_dataset = Phase4JoinDataset(phase4_data)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=phase4_collate_fn   # üî• THIS FIXES THE CRASH
)

In [9]:
# 🔹 Cell 6 — Grammar-Masked Loss
#
# A superseded version of this loss used to sit here as an unused,
# triple-quoted string. It built each step's grammar prefix from
# input_ids[b, :t], one token short of what the decoder has seen, and has been
# removed. The function below uses the correct prefix input_ids[b, :t+1].
import torch
import torch.nn.functional as F


def phase4_loss(logits, input_ids, labels, allowed_token_fn,
                mask_value=-1000.0, id2token=None, unk_token=None):
    """
    Grammar-masked cross-entropy loss for Phase-4 (JOIN) teacher forcing.

    For every supervised position (b, t), i.e. ``labels[b, t] != -100``, the
    decoder prefix ``input_ids[b, :t+1]`` is turned back into tokens and
    passed to ``allowed_token_fn``. When that returns a non-empty collection,
    every vocabulary id outside it gets ``mask_value`` added to its logit. An
    empty result leaves the position unmasked. Cross-entropy
    (``ignore_index=-100``, mean over non-ignored positions) is then computed
    on the masked logits.

    Args:
        logits: tensor of shape (B, T, V).
        input_ids: (B, T) decoder inputs (``token_ids[:-1]`` per sample).
        labels: (B, T) targets (``token_ids[1:]``); -100 marks ignored positions.
        allowed_token_fn: called as ``allowed_token_fn(tokens_so_far=...,
            schema_tables=None, schema_columns=None)``. It returns the allowed
            next-token ids; ids outside [0, V) are ignored.
        mask_value: additive penalty for disallowed ids. The default -1000.0
            (rather than -inf) keeps the masked logits finite.
        id2token: id -> token mapping supporting ``.get``; defaults to ``ID2TOKEN``.
        unk_token: token used for ids missing from ``id2token``; defaults to ``UNK``.

    Returns:
        Scalar loss tensor. It is NaN when every label is -100, because there
        is nothing to average.
    """
    if id2token is None:
        id2token = ID2TOKEN
    if unk_token is None:
        unk_token = UNK

    B, T, V = logits.size()

    # Build a boolean "disallowed" grid on the CPU from plain Python lists and
    # move it to the logits' device once. The old code allocated and filled a
    # fresh (V,) device tensor element-by-element for every (b, t) position.
    disallowed = torch.zeros((B, T, V), dtype=torch.bool)
    label_rows = labels.tolist()
    input_rows = input_ids.tolist()

    for b in range(B):
        for t in range(T):
            if label_rows[b][t] == -100:
                continue

            # labels[b, t] is token t+1 of the sample, so when predicting it
            # the decoder has seen input_ids[b, 0..t] inclusive.
            tokens_so_far = [id2token.get(i, unk_token) for i in input_rows[b][:t + 1]]

            allowed = allowed_token_fn(
                tokens_so_far=tokens_so_far,
                schema_tables=None,
                schema_columns=None
            )
            if not allowed:
                continue  # no grammar constraint at this position

            disallowed[b, t, :] = True
            # Drop out-of-range ids. A negative id would otherwise index from
            # the end and silently unmask an unrelated token.
            keep = [i for i in allowed if 0 <= i < V]
            if keep:
                disallowed[b, t, keep] = False

    mask = torch.zeros((B, T, V), dtype=logits.dtype)
    mask.masked_fill_(disallowed, mask_value)
    masked_logits = logits + mask.to(logits.device)

    # reduction="mean" averages over the non-ignored positions only.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
    return loss_fn(masked_logits.reshape(-1, V), labels.reshape(-1))




In [10]:
# 🔹 Cell 7 — Metric Helpers

def compute_prf(preds, labels, ignore_index=-100):
    """
    Token-level precision / recall / F1 over the non-ignored positions.

    Each kept position is a single-label prediction, so a wrong token counts
    as both a false positive and a false negative (fn == fp). As a result,
    precision, recall and F1 all equal token accuracy, up to the 1e-9
    smoothing term. Returns (0.0, 0.0, 0.0) when every label is ignored.

    Args:
        preds: tensor of predicted ids.
        labels: tensor of target ids with the same number of elements.
        ignore_index: label value to exclude (default -100).

    Returns:
        (precision, recall, f1) as Python floats.
    """
    eps = 1e-9
    flat_preds = preds.view(-1)
    flat_labels = labels.view(-1)

    keep = flat_labels != ignore_index
    kept_preds = flat_preds[keep]
    kept_labels = flat_labels[keep]

    hits = (kept_preds == kept_labels).sum().item()
    misses = (kept_preds != kept_labels).sum().item()

    precision = hits / (hits + misses + eps)
    recall = hits / (hits + misses + eps)  # misses act as both fp and fn
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, f1

In [11]:
# 🔹 Cell — Load Phase-3 model with Vocab Surgery
model = SQLTransformer().to(device)

# 1. Load the Phase-3 checkpoint (a state dict).
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint_state = torch.load(PHASE3_CKPT, map_location=device)

# 2. Get current model's state dict
model_state = model.state_dict()

# 3. Layers whose first dimension is the vocabulary size. They grew when the
#    Phase-4 vocab gained a token, and would otherwise raise "size mismatch"
#    in load_state_dict.
mismatched_layers = ["embedding.weight", "fc_out.weight", "fc_out.bias"]

for name, param in checkpoint_state.items():
    if name in mismatched_layers:
        print(f"Surgery on layer: {name}")
        old_weight = param
        new_weight = model_state[name]
        old_rows = old_weight.shape[0]

        # The old rows must fit inside the new tensor, with identical trailing dims.
        if old_rows > new_weight.shape[0] or old_weight.shape[1:] != new_weight.shape[1:]:
            raise ValueError(
                f"Cannot adapt {name}: checkpoint shape {tuple(old_weight.shape)} "
                f"does not fit model shape {tuple(new_weight.shape)}"
            )

        # Copy one row per old vocab id. Rows for ids added in Phase-4 keep the
        # model's fresh initialization. Slicing the first dim works for both
        # weight matrices and bias vectors.
        new_weight[:old_rows] = old_weight
        model_state[name] = new_weight
    else:
        # All other layers (Transformer blocks) are copied unchanged.
        model_state[name] = param

# 4. Load the modified state dict into the model
model.load_state_dict(model_state)

# Report the real vocab sizes instead of hard-coded numbers.
old_vocab = checkpoint_state["embedding.weight"].shape[0]
new_vocab = model_state["embedding.weight"].shape[0]
print(f"‚úÖ Surgery Complete: Phase-3 weights (size {old_vocab}) adapted to Phase-4 model (size {new_vocab})")

Surgery on layer: embedding.weight
Surgery on layer: fc_out.weight
Surgery on layer: fc_out.bias
‚úÖ Surgery Complete: Phase-3 weights (size 48) adapted to Phase-4 model (size 49)


In [12]:
def verify_grammar_with_data(dataset, allowed_fn, token2id=None, verbose=False):
    """
    Count next-token targets in the raw Phase-4 data that the grammar would block.

    For each sample's ``input_tokens`` and each step t, the prefix
    ``tokens[:t+1]`` is passed to ``allowed_fn`` with the same keyword
    arguments ``phase4_loss`` uses (schemas set to None). The target
    ``tokens[t+1]`` counts as a conflict only when the returned collection is
    non-empty and does not contain the target's id. An empty result means
    "unconstrained", matching the loss, which only masks when the allowed set
    is non-empty. Targets missing from the vocabulary map to None, so they are
    conflicts whenever the position is constrained.

    Args:
        dataset: indexable collection of dicts with an "input_tokens" list
            (e.g. the raw list loaded from phase4_join.json).
        allowed_fn: grammar function, e.g. ``get_allowed_tokens``.
        token2id: token -> id mapping supporting ``.get``; defaults to ``TOKEN2ID``.
        verbose: if True, print every conflict.

    Returns:
        The total number of conflicts (also printed).
    """
    if token2id is None:
        token2id = TOKEN2ID

    conflicts = 0
    for i in range(len(dataset)):
        tokens = dataset[i]["input_tokens"]  # full sequence from JSON

        for t in range(len(tokens) - 1):
            so_far = tokens[:t + 1]
            target = tokens[t + 1]
            allowed_ids = allowed_fn(
                tokens_so_far=so_far,
                schema_tables=None,
                schema_columns=None
            )
            if not allowed_ids:
                continue  # unconstrained position: nothing can be blocked

            if token2id.get(target) not in allowed_ids:
                conflicts += 1
                if verbose:
                    print(f"Sample {i} | Error at step {t}: '{target}' is BLOCKED after {so_far[-3:]}")

    print(f"Total Grammar Conflicts: {conflicts}")
    return conflicts

# Sanity check: every target in the raw data should be reachable under the grammar.
verify_grammar_with_data(dataset=phase4_data, allowed_fn=get_allowed_tokens)

Total Grammar Conflicts: 0


In [13]:
# 🔹 Cell 8 — Choose trainable parameters
import re

# Previous variant (output head only), kept for reference:
# for name, param in model.named_parameters():
#     if not name.startswith("fc_out"):
#         param.requires_grad = False
#     else:
#         param.requires_grad = True

# Unfreeze one Transformer block plus the output head so the model can learn
# JOIN context. Every other parameter, including the embedding, stays frozen.
# NOTE(review): the embedding row added by the vocab surgery keeps its random
# initialisation and stays frozen here — confirm that is intended.
UNFREEZE_BLOCK = 1  # adjust to your last block index

# Match "layers.<idx>" only as a complete path component. The previous
# substring test ("layers.1" in name) also matched "layers.10", "layers.11", ...
_unfreeze_re = re.compile(rf"layers\.{UNFREEZE_BLOCK}(?:\.|$)")

for name, param in model.named_parameters():
    param.requires_grad = bool(_unfreeze_re.search(name)) or "fc_out" in name
# 🔹 Cell 9 — Optimizer (only the parameters left trainable above)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=3e-4)

# 🔹 Cell 10 — Improved Training Loop
EPOCHS = 10
model.train()

for epoch in range(EPOCHS):
    total_loss, steps = 0.0, 0
    total_p, total_r, total_f1 = 0.0, 0.0, 0.0

    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids)
        loss = phase4_loss(logits, input_ids, labels, get_allowed_tokens)

        # Skip NaN losses. For example, a batch whose labels are all -100
        # yields a 0/0 mean.
        if torch.isnan(loss):
            print("‚ö†Ô∏è NaN Loss detected! Skipping batch.")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Metrics use the RAW logits, not the grammar-masked ones the loss is
        # computed on. A grammar-disallowed token can still have the highest
        # raw logit, so P/R/F1 can lag well behind a near-zero loss.
        preds = logits.detach().argmax(dim=-1)
        p, r, f1 = compute_prf(preds, labels)

        total_loss += loss.item()
        total_p += p
        total_r += r
        total_f1 += f1
        steps += 1

    # Avoid ZeroDivisionError when every batch of the epoch was skipped.
    if steps == 0:
        print(f"Epoch {epoch+1:02d} | no valid batches (all losses were NaN)")
        continue

    # ✅ REPORT AVERAGE LOSS (total_loss / steps)
    avg_loss = total_loss / steps
    print(f"Epoch {epoch+1:02d} | Avg Loss: {avg_loss:.4f} | P: {total_p/steps:.4f} | R: {total_r/steps:.4f} | F1: {total_f1/steps:.4f}")

Epoch 01 | Avg Loss: 0.1308 | P: 0.5365 | F1: 0.5365
Epoch 02 | Avg Loss: 0.0035 | P: 0.6776 | F1: 0.6776
Epoch 03 | Avg Loss: 0.0015 | P: 0.7048 | F1: 0.7048
Epoch 04 | Avg Loss: 0.0009 | P: 0.7164 | F1: 0.7164
Epoch 05 | Avg Loss: 0.0007 | P: 0.7277 | F1: 0.7277
Epoch 06 | Avg Loss: 0.0005 | P: 0.7322 | F1: 0.7322
Epoch 07 | Avg Loss: 0.0005 | P: 0.7362 | F1: 0.7362
Epoch 08 | Avg Loss: 0.0003 | P: 0.7363 | F1: 0.7363
Epoch 09 | Avg Loss: 0.0003 | P: 0.7423 | F1: 0.7423
Epoch 10 | Avg Loss: 0.0002 | P: 0.7451 | F1: 0.7451


In [None]:
# # 🔹 Cell 9 — Optimizer
# optimizer = AdamW(
#     filter(lambda p: p.requires_grad, model.parameters()),
#     lr=3e-4
# )

In [None]:
# # 🔹 Cell 10 — Phase-4 Training Loop (WITH METRICS)

# EPOCHS = 6

# model.to(device)
# model.train()

# for epoch in range(EPOCHS):
#     total_loss = 0.0
#     total_p, total_r, total_f1 = 0.0, 0.0, 0.0
#     steps = 0

#     for batch in train_loader:
#         optimizer.zero_grad()

#         input_ids = batch["input_ids"].to(device)
#         labels = batch["labels"].to(device)

#         logits = model(input_ids)

#         loss = phase4_loss(
#         logits=logits,
#         input_ids=input_ids,
#         labels=labels,
#         allowed_token_fn=get_allowed_tokens
#     )

#         loss.backward()
#         # 🔹 Gradient clipping torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#         optimizer.step()

#         # ---- metrics ----
#         preds = torch.argmax(logits, dim=-1)
#         p, r, f1 = compute_prf(preds, labels)

#         total_loss += loss.item()
#         total_p += p
#         total_r += r
#         total_f1 += f1
#         steps += 1

#     # ---- epoch summary ----
#     print(
#         f"Epoch {epoch+1}/{EPOCHS} | "
#         f"Loss: {total_loss:.4f} | "
#         f"P: {total_p/steps:.4f} | "
#         f"R: {total_r/steps:.4f} | "
#         f"F1: {total_f1/steps:.4f}"
#     )

#     # ---- save checkpoint ----
#     ckpt_path = os.path.join(
#         CHECKPOINT_DIR,
#         f"phase4_join_epoch_{epoch+1}.pt"
#     )
#     torch.save(model.state_dict(), ckpt_path)
#     print(f"üíæ Saved checkpoint ‚Üí {ckpt_path}")

In [None]:
# 🔹 Cell 11 — Save Model
