In [1]:
import os
import json
import random
import torch
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from difflib import get_close_matches
from tqdm import tqdm

# ======================= Dataset =========================
class TextToSQLDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_input_length=512, max_output_length=128):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        input_text, target_text = self.pairs[idx]

        input_encoding = self.tokenizer(
            input_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_input_length,
            return_tensors="pt"
        )

        target_encoding = self.tokenizer(
            target_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_output_length,
            return_tensors="pt"
        )

        labels = target_encoding["input_ids"]
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": input_encoding["input_ids"].squeeze(),
            "attention_mask": input_encoding["attention_mask"].squeeze(),
            "labels": labels.squeeze()
        }

# ======================= Data Utils ===========================
def load_spider_data(data_path):
    with open(os.path.join(data_path, 'train_spider.json')) as f:
        train_data = json.load(f)
    with open(os.path.join(data_path, 'train_others.json')) as f:
        train_data2 = json.load(f)
    with open(os.path.join(data_path, 'dev.json')) as f:
        dev_data = json.load(f)
    with open(os.path.join(data_path, 'tables.json')) as f:
        tables = json.load(f)
        
    train_data = train_data + train_data2
    return train_data, dev_data, tables

def split_train_val(train_data, val_ratio=0.25):
    db_to_examples = defaultdict(list)
    for ex in train_data:
        db_to_examples[ex['db_id']].append(ex)

    db_ids = list(db_to_examples.keys())
    random.shuffle(db_ids)
    split_idx = int(len(db_ids) * (1 - val_ratio))
    train_db_ids = set(db_ids[:split_idx])
    val_db_ids = set(db_ids[split_idx:])

    new_train = [ex for db in train_db_ids for ex in db_to_examples[db]]
    new_val = [ex for db in val_db_ids for ex in db_to_examples[db]]
    return new_train, new_val

def get_schema_dict(tables):
    schema_dict = {}
    for db in tables:
        db_id = db['db_id']
        table_names = db['table_names_original']
        column_names = db['column_names_original']
        columns = defaultdict(list)
        for table_idx, col_name in column_names:
            if table_idx >= 0:
                columns[table_names[table_idx]].append(col_name)
        schema_dict[db_id] = {'tables': table_names, 'columns': dict(columns)}
    return schema_dict

def serialize_schema(db_id, schema_dict):
    schema = schema_dict[db_id]
    parts = []
    for table in schema['tables']:
        cols = schema['columns'].get(table, [])
        table_str = f"[{table}] " + ", ".join(cols)
        parts.append(table_str)
    return " | ".join(parts)

def prepare_input_output(data, schema_dict):
    pairs = []
    for item in data:
        question = item['question']
        db_id = item['db_id']
        schema = serialize_schema(db_id, schema_dict)
        input_text = f"translate English to SQL: {question} <schema> {schema}"
        output_text = item['query']
        pairs.append((input_text, output_text))
    return pairs

def prepare_dataloaders(train_pairs, val_pairs, tokenizer, batch_size=8):
    train_dataset = TextToSQLDataset(train_pairs, tokenizer)
    val_dataset = TextToSQLDataset(val_pairs, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    return train_loader, val_loader

# ======================= Training ========================
def train(model, train_loader, val_loader, optimizer, device, tokenizer, num_epochs=10, save_path="/kaggle/working/t5_text2sql.pt"):
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 30)

        for step, batch in enumerate(train_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_train_loss += loss.item()

            if (step + 1) % 20 == 0 or (step + 1) == len(train_loader):
                print(f"[Batch {step + 1}/{len(train_loader)}] Train Loss: {loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_loader)
        print(f"Average Train Loss: {avg_train_loss:.4f}")

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for step, batch in enumerate(val_loader):
                batch = {k: v.to(device) for k, v in batch.items()}
                val_loss = model(**batch).loss.item()
                total_val_loss += val_loss

                if (step + 1) % 10 == 0 or (step + 1) == len(val_loader):
                    print(f"[Val Batch {step + 1}/{len(val_loader)}] Val Loss: {val_loss:.4f}")

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Average Val Loss: {avg_val_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"\n✅ Model saved to {save_path}")

# ======================= Inference & Postprocessing ========================
from difflib import get_close_matches
import re

def apply_postprocessing(pred_sql, question, db_id, schema_dict):
    # Extract schema info
    tables = list(schema_dict[db_id]['columns'].keys())
    columns = [col for cols in schema_dict[db_id]['columns'].values() for col in cols]

    # 1. COLUMN/TABLE NAME CORRECTION
    def correct_names(text, valid_names):
        tokens = set(re.findall(r'\b\w+\b', text))
        for token in tokens:
            match = get_close_matches(token, valid_names, n=1, cutoff=0.85)
            if match and token != match[0]:
                text = re.sub(rf'\b{re.escape(token)}\b', match[0], text, flags=re.IGNORECASE)
        return text

    pred_sql = correct_names(pred_sql, tables + columns)

    # 2. TABLE ALIAS CORRECTION
    aliases = {}
    alias_matches = re.findall(r'(FROM|JOIN)\s+(\w+)\s+AS\s+(\w+)', pred_sql, re.IGNORECASE)
    for _, table, alias in alias_matches:
        aliases[alias] = table

    for alias, table in aliases.items():
        if table in schema_dict[db_id]['columns']:
            valid_cols = [col for col in schema_dict[db_id]['columns'][table]]
            for col in valid_cols:
                wrong_pattern = rf'\b{alias}\d*\.{col}\b'
                correct_pattern = f'{alias}.{col}'
                pred_sql = re.sub(wrong_pattern, correct_pattern, pred_sql)

    # 3. REPLACING <unk> WITH < OR >
    if '<unk>' in pred_sql:
        comparative_less = {"less", "smaller", "fewer", "below", "under"}
        comparative_more = {"more", "greater", "higher", "above", "over"}

        question_tokens = question.lower().split()
        unk_match = re.search(r'(\w+)\s*<unk>', pred_sql)
        if unk_match:
            col_name = unk_match.group(1)
            col_tokens = col_name.lower().split('_')
            position = -1
            for i in range(len(question_tokens) - len(col_tokens) + 1):
                if question_tokens[i:i + len(col_tokens)] == col_tokens:
                    position = i
                    break

            context_window = question_tokens[max(0, position-2):position+len(col_tokens)+2]
            if any(word in comparative_less for word in context_window):
                pred_sql = pred_sql.replace('<unk>', '<')
            elif any(word in comparative_more for word in context_window):
                pred_sql = pred_sql.replace('<unk>', '>')
            else:
                pred_sql = pred_sql.replace('<unk>', '<')  # default fallback

    # 4. CLEANUP
    pred_sql = pred_sql.replace('<pad>', '').replace('</s>', '').replace('<s>', '').strip()
    if not pred_sql.endswith(';'):
        pred_sql += ';'

    return pred_sql

def generate_predictions(model_path, data_path, output_path="/kaggle/working/predictions.json"):
    print("\n🔍 Running inference and post-processing...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    tokenizer.add_tokens(["<schema>"])

    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    _, dev_data, tables = load_spider_data(data_path)
    schema_dict = get_schema_dict(tables)
    #schema_dict = build_schema_dict(tables)

    results = []

    for ex in tqdm(dev_data):
        question = ex["question"]
        db_id = ex["db_id"]
        schema = schema_dict[db_id]
        schema_str = serialize_schema(db_id, schema_dict)

        input_text = f"translate English to SQL: {question} <schema> {schema_str}"
        input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True).to(device)

        output_ids = model.generate(input_ids, max_length=128)
        pred = tokenizer.decode(output_ids[0], skip_special_tokens=False)
        #fixed_pred = postprocess_prediction(pred, question, schema)
        fixed_pred = apply_postprocessing(pred, question, db_id, schema_dict)

        results.append({"db_id": db_id, "query": fixed_pred})

    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\n✅ Predictions saved to {output_path}")


# ======================= Main ============================
def main():
    data_path = "/kaggle/input/spider-data/spider_data/"
    model_path = "/kaggle/input/spider-data/t5_text2sql_final.pt"


    """
    print("Loading and preparing data...")
    train_data, dev_data, tables = load_spider_data(data_path)
    train_data, val_data = split_train_val(train_data)
    schema_dict = get_schema_dict(tables)

    train_pairs = prepare_input_output(train_data, schema_dict)
    val_pairs = prepare_input_output(val_data, schema_dict)

    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    tokenizer.add_tokens(["<schema>"])

    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    model.resize_token_embeddings(len(tokenizer))

    train_loader, val_loader = prepare_dataloaders(train_pairs, val_pairs, tokenizer)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Starting training...")
    train(model, train_loader, val_loader, optimizer, device, tokenizer, save_path=model_path)
    """
    generate_predictions(model_path, data_path)



if __name__ == "__main__":
    main()





2025-05-10 05:59:21.407499: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746856761.694057      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746856761.777648      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered



🔍 Running inference and post-processing...


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

  model.load_state_dict(torch.load(model_path, map_location=device))
  0%|          | 0/1034 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
100%|██████████| 1034/1034 [10:54<00:00,  1.58it/s]


✅ Predictions saved to /kaggle/working/predictions.json





In [None]:
print("ho")

In [2]:
import json
import os

# Clone evaluation repo if not already
if not os.path.exists("test-suite-sql-eval"):
    !git clone https://github.com/taoyds/test-suite-sql-eval.git

# 1. Fix the gold file (dev.json)
with open("/kaggle/input/spider-data/spider_data/dev.json") as f:
    gold_raw = json.load(f)

with open("/kaggle/working/fixed_gold.txt", "w") as f:
    for ex in gold_raw:
        gold_query = ex['query'].strip().lower()  # convert to lowercase
        db_id = ex['db_id']
        if(gold_query[-1]!=";"):
            gold_query = gold_query + ";"
        f.write(f"{gold_query}\t{db_id}\n")

# 2. Fix the predictions file (predictions.json)
with open("/kaggle/working/predictions.json") as f:
    pred_raw = json.load(f)

with open("/kaggle/working/fixed_pred.txt", "w") as f:
    for ex in pred_raw:
        pred_query = ex['query'].strip().lower()  # convert to lowercase
        db_id = ex['db_id']
        f.write(f"{pred_query}\t{db_id}\n")

# 3. Check format of fixed_gold.txt
with open("/kaggle/working/fixed_gold.txt") as f:
    for i, line in enumerate(f, 1):
        try:
            parts = line.strip().split("\t")
            if len(parts) != 2:
                print(f"[GOLD] Line {i} is malformed:", parts)
        except Exception as e:
            print(f"[GOLD] Line {i} error:", e)

# 4. Check format of fixed_pred.txt
with open("/kaggle/working/fixed_pred.txt") as f:
    for i, line in enumerate(f, 1):
        try:
            parts = line.strip().split("\t")
            if len(parts) != 2:
                print(f"[PRED] Line {i} is malformed:", parts)
        except Exception as e:
            print(f"[PRED] Line {i} error:", e)

# 5. Run evaluation
!python test-suite-sql-eval/evaluation.py \
  --gold /kaggle/working/fixed_gold.txt \
  --pred /kaggle/working/fixed_pred.txt \
  --db /kaggle/input/spider-data/spider_data/database/ \
  --table /kaggle/input/spider-data/spider_data/tables.json \
  --etype all


Cloning into 'test-suite-sql-eval'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (30/30), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 61 (delta 21), reused 16 (delta 16), pack-reused 31 (from 1)[K
Receiving objects: 100% (61/61), 618.38 KiB | 5.62 MiB/s, done.
Resolving deltas: 100% (25/25), done.
medium pred: select avg(age) , min(age) , max(age) from singer where is_male = 'french';
medium gold: select avg(age) ,  min(age) ,  max(age) from singer where country  =  'france';

medium pred: select max(capacity) , avg(average) from stadium;
medium gold: select max(capacity), average from stadium;

medium pred: select name , capacity from stadium order by avg(capacity) desc limit 1;
medium gold: select name ,  capacity from stadium order by average desc limit 1;

medium pred: select name , capacity from stadium order by avg(capacity) desc limit 1;
medium gold: select name ,  capacity from stadium order by average desc limit 1;


In [29]:
#!git clone https://github.com/taoyds/spider.git

!python spider/evaluation.py \
  --gold /kaggle/input/spider-data/spider_data/dev.json \
  --pred /kaggle/working/predictions.json \
  --db /kaggle/input/spider-data/spider_data/database/ \
  --table /kaggle/input/spider-data/spider_data/tables.json


usage: evaluation.py [-h] [--gold GOLD] [--pred PRED] [--db DB] [--table TABLE]
                     [--etype {all,exec,match}] [--plug_value] [--keep_distinct]
                     [--progress_bar_for_each_datapoint]
evaluation.py: error: unrecognized arguments:   --etype all
