In [1]:
import os
import json
import re
import random
import pickle
import argparse
from typing import List, Dict, Set
from tqdm.auto import tqdm
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import transformers
from transformers import (
    T5TokenizerFast, 
    T5ForConditionalGeneration, 
    T5Config, 
    GenerationConfig,
    get_linear_schedule_with_warmup
)
import numpy as np

# Check for GPU
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

# --- CONFIGURATION FOR SCRATCH TRAINING ---
class ScratchConfig:
    # Data paths
    data_dir = "data"
    schema_meta_path = "schema_meta.json" # Ensure this is in your working dir or data dir
    
    # Model architecture
    model_name = "google-t5/t5-small"
    
    # Training Hyperparameters (Optimized for Scratch Training)
    learning_rate = 5e-4        # Higher LR is needed for scratch training
    weight_decay = 0.01
    batch_size = 16             # Adjust based on your GPU memory
    test_batch_size = 32
    max_n_epochs = 60           # Needs many more epochs than fine-tuning
    num_warmup_epochs = 5
    patience_epochs = 10        # Wait longer for improvements
    
    # Output files
    experiment_name = "t5_scratch_ec"
    output_dir = "scratch_results"
    
config = ScratchConfig()
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs("records", exist_ok=True) # For saving submission pkl files
os.makedirs("results", exist_ok=True) # For saving submission sql files

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# --- SCHEMA LOADING & HELPERS ---

def load_schema_meta(path):
    # Try loading from current dir, then data dir
    if not os.path.exists(path):
        path = os.path.join(config.data_dir, path)
    
    if not os.path.exists(path):
        print(f"WARNING: Schema meta file not found at {path}. Pruning will be disabled.")
        return {}, {}, {}
        
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data.get("ents", {}), data.get("defaults", {}), data.get("links", {})

ENTS, DEFAULTS, LINKS = load_schema_meta(config.schema_meta_path)

# Build Phrase->Table Lexicon for Pruning
PHRASE2TABLE = {}
for table, meta in DEFAULTS.items():
    phrase = meta["utt"].strip().lower()
    PHRASE2TABLE.setdefault(phrase, set()).add(table)
for table, cols in ENTS.items():
    for col, colmeta in cols.items():
        phrase = colmeta["utt"].strip().lower()
        PHRASE2TABLE.setdefault(phrase, set()).add(table)

def detect_tables(question, max_tables=8):
    """Heuristically find relevant tables + 1-hop neighbors"""
    q = question.lower()
    candidates = set()
    
    # 1. Direct Match
    for phrase, tables in PHRASE2TABLE.items():
        if phrase in q:
            candidates.update(tables)
            
    if not candidates: return set(ENTS.keys()) # Fallback
    
    # 2. 1-Hop Expansion (Foreign Keys)
    expanded = set(candidates)
    for t in candidates:
        # Outgoing links
        for neigh in LINKS.get(t, {}): expanded.add(neigh)
        # Incoming links
        for other, links in LINKS.items():
            if t in links: expanded.add(other)
            
    # 3. Cap size
    if len(expanded) > max_tables:
        return set(sorted(list(candidates))[:max_tables])
    return expanded

def serialize_schema(tables):
    parts = []
    for t in sorted(tables):
        if t in ENTS:
            cols = ", ".join(list(ENTS[t].keys())[:6]) # Limit cols per table
            parts.append(f"{t}({cols})")
    return "Tables:\n" + ",\n".join(parts)

def normalize_text(s):
    return " ".join(s.strip().split())

def build_input(nl_question):
    """Prefix + Pruned Schema + Question"""
    tables = detect_tables(nl_question)
    schema_str = serialize_schema(tables)
    return f"translate English to SQL.\n{schema_str}\n\nQuestion: {nl_question}\nSQL:"

def build_target(sql):
    """Normalize and Uppercase"""
    return normalize_text(sql).upper()

def read_lines(filepath):
    path = os.path.join(config.data_dir, filepath)
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]

In [3]:
# --- CUSTOM DATASET WITH VOCAB EXPANSION ---

# Initialize Tokenizer
tokenizer = T5TokenizerFast.from_pretrained(config.model_name)

# 1. Define SQL Vocabulary
SQL_KEYWORDS = [
    'SELECT', 'FROM', 'WHERE', 'GROUP', 'BY', 'ORDER', 'HAVING', 'LIMIT', 
    'JOIN', 'ON', 'AS', 'DISTINCT', 'COUNT', 'MAX', 'MIN', 'AVG', 'SUM',
    'AND', 'OR', 'NOT', 'IN', 'LIKE', 'BETWEEN', 'IS', 'NULL',
    'INTERSECT', 'UNION', 'EXCEPT', 'DESC', 'ASC'
]

# 2. Extract Schema Items
SCHEMA_ITEMS = []
for table, cols in ENTS.items():
    SCHEMA_ITEMS.append(table)
    SCHEMA_ITEMS.extend(cols.keys())
SCHEMA_ITEMS.extend(['=', '>', '<', '>=', '<=', '!=', '(', ')', ','])
NEW_TOKENS = sorted(list(set(SQL_KEYWORDS + SCHEMA_ITEMS)))

# 3. Add to Tokenizer
print(f"Adding {len(NEW_TOKENS)} new special tokens to tokenizer...")
tokenizer.add_special_tokens({'additional_special_tokens': NEW_TOKENS})
PAD_IDX = tokenizer.pad_token_id

class T5ScratchDataset(Dataset):
    def __init__(self, split):
        self.split = split
        self.max_enc = 512
        self.max_dec = 256
        self.data = self.load_data()
        
    def load_data(self):
        nl = read_lines(f"{self.split}.nl")
        if self.split != 'test':
            sql = read_lines(f"{self.split}.sql")
        else:
            sql = [""] * len(nl)
            
        processed = []
        print(f"Processing {self.split} set...")
        for q, s in tqdm(zip(nl, sql), total=len(nl)):
            enc_txt = build_input(q)
            dec_txt = build_target(s) if self.split != 'test' else ""
            
            enc = tokenizer(enc_txt, truncation=True, max_length=self.max_enc)
            dec = tokenizer(dec_txt, truncation=True, max_length=self.max_dec)
            
            processed.append({
                'enc_ids': enc.input_ids,
                'enc_mask': enc.attention_mask,
                'dec_ids': dec.input_ids
            })
        return processed
        
    def __len__(self): return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# --- COLLATE FUNCTIONS ---

def collate_fn(batch):
    enc_ids = pad_sequence([torch.tensor(x['enc_ids']) for x in batch], batch_first=True, padding_value=PAD_IDX)
    enc_mask = pad_sequence([torch.tensor(x['enc_mask']) for x in batch], batch_first=True, padding_value=0)
    
    dec_ids = pad_sequence([torch.tensor(x['dec_ids']) for x in batch], batch_first=True, padding_value=PAD_IDX)
    
    # Teacher forcing inputs: Shift right
    dec_in = torch.zeros_like(dec_ids)
    dec_in[:, 1:] = dec_ids[:, :-1]
    dec_in[:, 0] = PAD_IDX
    
    # Targets: Ignore padding in loss
    labels = dec_ids.clone()
    labels[labels == PAD_IDX] = -100
    
    return enc_ids, enc_mask, dec_in, labels

def test_collate_fn(batch):
    enc_ids = pad_sequence([torch.tensor(x['enc_ids']) for x in batch], batch_first=True, padding_value=PAD_IDX)
    enc_mask = pad_sequence([torch.tensor(x['enc_mask']) for x in batch], batch_first=True, padding_value=0)
    
    # Initial decoder input for generation
    bs = enc_ids.size(0)
    dec_in = torch.full((bs, 1), PAD_IDX, dtype=torch.long)
    
    return enc_ids, enc_mask, dec_in

# Create Dataloaders
train_ds = T5ScratchDataset("train")
dev_ds = T5ScratchDataset("dev")
test_ds = T5ScratchDataset("test")

train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(dev_ds, batch_size=config.test_batch_size, collate_fn=collate_fn)
test_loader = DataLoader(test_ds, batch_size=config.test_batch_size, collate_fn=test_collate_fn)



Adding 165 new special tokens to tokenizer...
Processing train set...


100%|██████████| 4225/4225 [00:02<00:00, 1420.71it/s]


Processing dev set...


100%|██████████| 466/466 [00:00<00:00, 1407.62it/s]


Processing test set...


100%|██████████| 432/432 [00:00<00:00, 3101.25it/s]


In [4]:
def initialize_model_scratch():
    print("Initializing T5-Small from SCRATCH (Random Weights)...")
    # Load config only, not weights
    t5_config = T5Config.from_pretrained(config.model_name)
    model = T5ForConditionalGeneration(t5_config)
    
    # RESIZE EMBEDDINGS for new SQL tokens
    model.resize_token_embeddings(len(tokenizer))
    
    model.to(DEVICE)
    return model

model = initialize_model_scratch()

# Setup Optimizer & Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
total_steps = len(train_loader) * config.max_n_epochs
warmup_steps = len(train_loader) * config.num_warmup_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

print(f"Model ready. Total params: {sum(p.numel() for p in model.parameters())}")

Initializing T5-Small from SCRATCH (Random Weights)...
Model ready. Total params: 60564480


In [5]:
def save_files(queries, sql_path, record_path):
    """Helper to save submission files"""
    with open(sql_path, 'w') as f:
        for q in queries:
            f.write(q + "\n")
    
    # Create dummy records file to satisfy submission format
    # (The real eval happens on gradescope, but we need the .pkl file)
    records = [[] for _ in queries] 
    with open(record_path, 'wb') as f:
        pickle.dump(records, f)
    print(f"Saved to {sql_path} and {record_path}")

def eval_epoch(model, loader):
    model.eval()
    total_loss = 0
    gen_sql = []
    
    gen_cfg = GenerationConfig(
        max_length=256, 
        pad_token_id=PAD_IDX, 
        eos_token_id=tokenizer.eos_token_id,
        num_beams=4 # Beam search helps significantly
    )
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval"):
            enc_ids, enc_mask, dec_in, labels = batch
            enc_ids, enc_mask = enc_ids.to(DEVICE), enc_mask.to(DEVICE)
            dec_in, labels = dec_in.to(DEVICE), labels.to(DEVICE)
            
            # Loss
            outputs = model(input_ids=enc_ids, attention_mask=enc_mask, decoder_input_ids=dec_in)
            loss = nn.CrossEntropyLoss(ignore_index=-100)(outputs.logits.view(-1, outputs.logits.size(-1)), labels.view(-1))
            total_loss += loss.item()
            
            # Generation (limit usage to save time if needed, but full dev recommended)
            preds = model.generate(input_ids=enc_ids, attention_mask=enc_mask, generation_config=gen_cfg)
            decoded = tokenizer.batch_decode(preds, skip_special_tokens=True)
            gen_sql.extend(decoded)
            
    return total_loss / len(loader), gen_sql

def train_epoch(model, loader):
    model.train()
    total_loss = 0
    
    for batch in tqdm(loader, desc="Train"):
        optimizer.zero_grad()
        enc_ids, enc_mask, dec_in, labels = batch
        enc_ids, enc_mask = enc_ids.to(DEVICE), enc_mask.to(DEVICE)
        dec_in, labels = dec_in.to(DEVICE), labels.to(DEVICE)
        
        outputs = model(input_ids=enc_ids, attention_mask=enc_mask, decoder_input_ids=dec_in)
        loss = nn.CrossEntropyLoss(ignore_index=-100)(outputs.logits.view(-1, outputs.logits.size(-1)), labels.view(-1))
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        
    return total_loss / len(loader)

In [6]:
best_loss = float('inf')
patience_counter = 0
best_model_path = os.path.join(config.output_dir, "best_model.pt")

print(f"Starting training for {config.max_n_epochs} epochs...")

for epoch in range(config.max_n_epochs):
    train_loss = train_epoch(model, train_loader)
    val_loss, val_queries = eval_epoch(model, dev_loader)
    
    print(f"Epoch {epoch+1}/{config.max_n_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    # Early Stopping & Saving
    if val_loss < best_loss:
        best_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print(f"  -> New Best Model Saved!")
        
        # Optional: Print a few examples to check syntax
        print(f"  Sample Gen: {val_queries[0]}")
    else:
        patience_counter += 1
        if patience_counter >= config.patience_epochs:
            print("Early stopping triggered.")
            break

print("Training Complete.")

Starting training for 60 epochs...


Train: 100%|██████████| 265/265 [00:37<00:00,  7.06it/s]
Eval: 100%|██████████| 15/15 [01:25<00:00,  5.72s/it]


Epoch 1/60 | Train Loss: 2.8717 | Val Loss: 0.7064
  -> New Best Model Saved!
  Sample Gen: FLIGHT_1.FLIGHT_ID FLIGHT FLIGHT_1 AIRP T_SERVICE AIRP T_SERVICE_SERVICE_1 CITY CITY_1 AIRP T_SERVICE AIRP T_SERVICE_2 CITY CITY CITY_2 FLIGHT_1. _AIRP T AIRP T_SERVICE_SERVICE_1.AIRP T_1.AIRP T_CODE AIRP T_SERVICE_SERVICE_1.CITY_1.CITY_1.CITY_CODE CITY_1.CITY_1.CITY_1.CITY_1.CITY_1.CITY_1.CITY_1.CITY_1.CITY_1.CITY_NAME'FLIGHT_1.TO_AIRP T AIRP T CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY_2.CITY


Train: 100%|██████████| 265/265 [00:36<00:00,  7.21it/s]
Eval: 100%|██████████| 15/15 [01:21<00:00,  5.46s/it]


Epoch 2/60 | Train Loss: 0.5075 | Val Loss: 0.3094
  -> New Best Model Saved!
  Sample Gen: FLIGHT_1.FLIGHT_ID FLIGHT FLIGHT_1 AIRP T_SERVICE AIRP T_SERVICE_1 CITY CITY_1 AIRP T_SERVICE AIRP T_SERVICE_2 CITY CITY_2 FLIGHT_1. _AIRP T AIRP T_SERVICE_1.AIRP T_CODE AIRP T_SERVICE_1.CITY_CODE CITY_1.CITY_CODE CITY_1.CITY_NAME'FLIGHT_1.TO_AIRP T AIRP T_SERVICE_2.AIRP T_CODE AIRP T_SERVICE_2.CITY_CODE CITY_2.CITY_CODE CITY_2.CITY_NAME '


Train: 100%|██████████| 265/265 [00:36<00:00,  7.19it/s]
Eval: 100%|██████████| 15/15 [01:24<00:00,  5.67s/it]


Epoch 3/60 | Train Loss: 0.2745 | Val Loss: 0.1994
  -> New Best Model Saved!
  Sample Gen: FLIGHT_1.FLIGHT_ID FLIGHT FLIGHT_1 AIRP T_SERVICE AIRP T_SERVICE_1 CITY CITY_1 AIRP T_SERVICE AIRP T_SERVICE_2 CITY CITY_2 DAYS DAYS_1 DATE_DAY DATE_DAY_1 FLIGHT_1.AIRL E_CODE'FLIGHT_1. _AIRP T AIRP T_SERVICE_1.AIRP T_CODE AIRP T_SERVICE_1.CITY_CODE CITY_1.CITY_CODE CITY_1.CITY_NAME 'BOST'FLIGHT_1.TO_AIRP T AIRP T_SERVICE_2.AIRP T_CODE AIRP T_SERVICE_2.CITY_CODE CITY_2.CITY_CODE CITY_2.CITY_NAME 'BOST'FL


Train: 100%|██████████| 265/265 [00:36<00:00,  7.23it/s]
Eval: 100%|██████████| 15/15 [01:23<00:00,  5.60s/it]


Epoch 4/60 | Train Loss: 0.2009 | Val Loss: 0.1641
  -> New Best Model Saved!
  Sample Gen: FLIGHT_1.FLIGHT_ID FLIGHT FLIGHT_1 AIRP T_SERVICE AIRP T_SERVICE_1 CITY CITY_1 AIRP T_SERVICE AIRP T_SERVICE_2 CITY CITY_2 DAYS DAYS_1 DATE_DAY DATE_DAY_1 FLIGHT_1. _AIRP T AIRP T_SERVICE_1.AIRP T_CODE AIRP T_SERVICE_1.CITY_CODE CITY_1.CITY_CODE CITY_1.CITY_NAME 'DENVER' FLIGHT_1.TO_AIRP T AIRP T_SERVICE_2.AIRP T_CODE AIRP T_SERVICE_2.CITY_CODE CITY_2.CITY_CODE CITY_2.CITY_NAME 'DENVER' FLIGHT_1.FLIGHT_DAYS DAYS


Train: 100%|██████████| 265/265 [00:37<00:00,  7.16it/s]
Eval: 100%|██████████| 15/15 [01:22<00:00,  5.51s/it]


Epoch 5/60 | Train Loss: 0.1689 | Val Loss: 0.1583
  -> New Best Model Saved!
  Sample Gen: FLIGHT_1.FLIGHT_ID FLIGHT FLIGHT_1 AIRP T_SERVICE AIRP T_SERVICE_1 CITY CITY_1 AIRP T_SERVICE AIRP T_SERVICE_2 CITY CITY_2 DAYS DAYS_1 DATE_DAY DATE_DAY_1 FLIGHT_1. _AIRP T AIRP T_SERVICE_1.AIRP T_CODE AIRP T_SERVICE_1.CITY_CODE CITY_1.CITY_CODE CITY_1.CITY_NAME 'DENVER' FLIGHT_1.TO_AIRP T AIRP T_SERVICE_2.AIRP T_CODE AIRP T_SERVICE_2.CITY_CODE CITY_2.CITY_CODE CITY_2.CITY_NAME 'DENVER' FLIGHT_1.FLIGHT_DAYS DAYS_1.


Train: 100%|██████████| 265/265 [00:37<00:00,  7.15it/s]
Eval: 100%|██████████| 15/15 [01:23<00:00,  5.57s/it]


Epoch 6/60 | Train Loss: 0.1786 | Val Loss: 0.1297
  -> New Best Model Saved!
  Sample Gen: FLIGHT_1.FLIGHT_ID FLIGHT FLIGHT_1 AIRP T_SERVICE AIRP T_SERVICE_1 CITY CITY_1 AIRP T_SERVICE AIRP T_SERVICE_2 CITY CITY_2 DAYS DAYS_1 DATE_DAY DATE_DAY_1 FLIGHT_1. _AIRP T AIRP T_SERVICE_1.AIRP T_CODE AIRP T_SERVICE_1.CITY_CODE CITY_1.CITY_CODE CITY_1.CITY_NAME 'DENVER' FLIGHT_1.TO_AIRP T AIRP T_SERVICE_2.AIRP T_CODE AIRP T_SERVICE_2.CITY_CODE CITY_2.CITY_CODE CITY_NAME 'DENVER'


Train: 100%|██████████| 265/265 [00:37<00:00,  7.15it/s]
Eval: 100%|██████████| 15/15 [01:22<00:00,  5.53s/it]


Epoch 7/60 | Train Loss: 0.1481 | Val Loss: 0.1209
  -> New Best Model Saved!
  Sample Gen: FLIGHT_1.FLIGHT_ID FLIGHT FLIGHT_1 AIRP T_SERVICE AIRP T_SERVICE_1 CITY CITY_1 AIRP T_SERVICE AIRP T_SERVICE_2 CITY CITY_2 DAYS DAYS_1 DATE_DAY DATE_DAY_1 FLIGHT_1. _AIRP T AIRP T_SERVICE_1.AIRP T_CODE AIRP T_SERVICE_1.CITY_CODE CITY_1.CITY_CODE CITY_1.CITY_NAME 'DENVER' FLIGHT_1.TO_AIRP T AIRP T_SERVICE_2.AIRP T_CODE AIRP T_SERVICE_2.CITY_CODE CITY_2.CITY_CODE CITY_2.CITY_NAME 'PHILADELPHIA' FLIGHT_1.FLIGHT_DAYS DAY


Train: 100%|██████████| 265/265 [00:36<00:00,  7.20it/s]
Eval: 100%|██████████| 15/15 [01:24<00:00,  5.65s/it]


Epoch 8/60 | Train Loss: 0.1517 | Val Loss: 0.1289


Train: 100%|██████████| 265/265 [00:36<00:00,  7.24it/s]
Eval: 100%|██████████| 15/15 [01:25<00:00,  5.67s/it]


Epoch 9/60 | Train Loss: 0.1607 | Val Loss: 0.1594


Train: 100%|██████████| 265/265 [00:36<00:00,  7.22it/s]
Eval: 100%|██████████| 15/15 [01:26<00:00,  5.75s/it]


Epoch 10/60 | Train Loss: 0.1899 | Val Loss: 0.1644


Train: 100%|██████████| 265/265 [00:37<00:00,  7.14it/s]
Eval: 100%|██████████| 15/15 [01:26<00:00,  5.76s/it]


Epoch 11/60 | Train Loss: 0.2071 | Val Loss: 0.1725


Train: 100%|██████████| 265/265 [00:36<00:00,  7.25it/s]
Eval: 100%|██████████| 15/15 [01:24<00:00,  5.62s/it]


Epoch 12/60 | Train Loss: 0.2150 | Val Loss: 0.1594


Train: 100%|██████████| 265/265 [00:37<00:00,  7.14it/s]
Eval: 100%|██████████| 15/15 [01:22<00:00,  5.52s/it]


Epoch 13/60 | Train Loss: 0.2033 | Val Loss: 0.1555


Train: 100%|██████████| 265/265 [00:37<00:00,  7.07it/s]
Eval: 100%|██████████| 15/15 [01:23<00:00,  5.56s/it]


Epoch 14/60 | Train Loss: 0.2188 | Val Loss: 0.1687


Train: 100%|██████████| 265/265 [00:37<00:00,  7.11it/s]
Eval: 100%|██████████| 15/15 [01:26<00:00,  5.75s/it]


Epoch 15/60 | Train Loss: 0.2436 | Val Loss: 0.2301


Train: 100%|██████████| 265/265 [00:37<00:00,  7.12it/s]
Eval: 100%|██████████| 15/15 [01:23<00:00,  5.54s/it]


Epoch 16/60 | Train Loss: 0.2696 | Val Loss: 0.1884


Train: 100%|██████████| 265/265 [00:37<00:00,  7.14it/s]
Eval: 100%|██████████| 15/15 [01:14<00:00,  4.96s/it]

Epoch 17/60 | Train Loss: 0.2790 | Val Loss: 0.2174
Early stopping triggered.
Training Complete.





In [7]:
# Load Best Model
print("Loading best model for inference...")
model.load_state_dict(torch.load(best_model_path))
model.eval()

test_queries = []
gen_cfg = GenerationConfig(
    max_length=256, 
    pad_token_id=PAD_IDX, 
    eos_token_id=tokenizer.eos_token_id,
    num_beams=4
)

print("Generating predictions on TEST set...")
with torch.no_grad():
    for batch in tqdm(test_loader):
        enc_ids, enc_mask, _ = batch
        enc_ids, enc_mask = enc_ids.to(DEVICE), enc_mask.to(DEVICE)
        
        preds = model.generate(input_ids=enc_ids, attention_mask=enc_mask, generation_config=gen_cfg)
        decoded = tokenizer.batch_decode(preds, skip_special_tokens=True)
        test_queries.extend(decoded)

# Define output paths per assignment spec
final_sql_path = "results/t5_ft_experiment_ec_test.sql"
final_pkl_path = "records/t5_ft_experiment_ec_test.pkl"

# Save
save_files(test_queries, final_sql_path, final_pkl_path)

print("\nDONE!")
print(f"Submission files generated:\n1. {final_sql_path}\n2. {final_pkl_path}")

Loading best model for inference...
Generating predictions on TEST set...


100%|██████████| 14/14 [01:14<00:00,  5.34s/it]

Saved to results/t5_ft_experiment_ec_test.sql and records/t5_ft_experiment_ec_test.pkl

DONE!
Submission files generated:
1. results/t5_ft_experiment_ec_test.sql
2. records/t5_ft_experiment_ec_test.pkl



