In [3]:
# Install and Import Dependencies

!pip install -q transformers datasets gradio torchtext sqlalchemy nltk accelerate

import os
import json
import sqlite3
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoTokenizer, AutoModel
from typing import List, Dict, Tuple
import gradio as gr
import re
from collections import defaultdict
import nltk
# NLTK >= 3.9 makes word_tokenize load the 'punkt_tab' resource instead of the
# pickled 'punkt' model; fetch both so tokenization works on old and new NLTK
# (on releases without 'punkt_tab' that download just reports "not found").
nltk.download('punkt')
nltk.download('punkt_tab')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loss scaling only matters for CUDA fp16 autocast; disabling it explicitly on
# CPU avoids GradScaler's "CUDA is not available" warning.
scaler = GradScaler(enabled=torch.cuda.is_available())

In [4]:
# Load Spider Dataset

import os
import json
import zipfile
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

zip_path = "/content/drive/MyDrive/spider_data.zip"
extract_dir = "/content/spider"


def _extract_spider_zip(zip_path, extract_dir, prefix="spider_data"):
    """Extract every file under ``prefix/`` in ``zip_path`` into ``extract_dir``.

    Files are written to a sibling ``<extract_dir>.partial`` directory that is
    renamed into place only after the whole archive has been copied, so an
    interrupted run never leaves a half-populated ``extract_dir`` that a later
    run would mistake for a finished extraction. Members whose path would
    resolve outside the target directory (e.g. via ``../``) are rejected.

    Raises:
        ValueError: if a member would be written outside the target directory.
    """
    import shutil

    staging_dir = extract_dir + ".partial"
    shutil.rmtree(staging_dir, ignore_errors=True)  # leftovers of a failed run
    os.makedirs(staging_dir)
    staging_root = os.path.realpath(staging_dir)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        for member in zip_ref.namelist():
            # Only files inside the prefix folder; directory entries end in "/".
            if not member.startswith(prefix + "/") or member.endswith("/"):
                continue
            rel_path = os.path.relpath(member, prefix)
            out_path = os.path.realpath(os.path.join(staging_dir, rel_path))
            if os.path.commonpath([staging_root, out_path]) != staging_root:
                raise ValueError(f"Refusing to extract {member!r} outside {extract_dir}")
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            # Stream instead of reading the whole member into memory.
            with zip_ref.open(member) as src, open(out_path, "wb") as dst:
                shutil.copyfileobj(src, dst)

    os.replace(staging_dir, extract_dir)


if not os.path.exists(extract_dir):
    _extract_spider_zip(zip_path, extract_dir)
    print("Extracted spider_data →", extract_dir)
else:
    print("Already extracted.")

# Load helper
def load_json(filename):
    """Parse and return the JSON document stored at ``filename`` (read as UTF-8)."""
    with open(filename, encoding="utf-8") as handle:
        document = json.load(handle)
    return document

spider_dir = "/content/spider"
db_dir = os.path.join(spider_dir, "database")

# Spider ships the train questions, dev questions and the schema catalogue as JSON.
train_data, dev_data, table_schemas = (
    load_json(os.path.join(spider_dir, fname))
    for fname in ("train_spider.json", "dev.json", "tables.json")
)
# Index schemas by database id for O(1) lookup per example.
schema_dict = {entry['db_id']: entry for entry in table_schemas}

for label, count in (
    ("training examples", len(train_data)),
    ("dev examples", len(dev_data)),
    ("schema definitions", len(schema_dict)),
):
    print(f"Loaded {count} {label}")

In [5]:
# Build Schema Graph + Link Question to Schema

import nltk
from typing import List, Dict, Tuple
from collections import defaultdict
# word_tokenize needs 'punkt_tab' on NLTK >= 3.9 and 'punkt' on older releases.
nltk.download('punkt')
nltk.download('punkt_tab')

def tokenize(text: str) -> List[str]:
    """Lower-case ``text`` and split it into NLTK word tokens."""
    lowered = text.lower()
    return nltk.word_tokenize(lowered)

def extract_schema_elements(schema: Dict) -> Tuple[List[str], List[str]]:
    """Return ``(column_names, table_names)`` for a Spider-style schema, lower-cased.

    Columns owned by a table are rendered as ``table.column``; a column whose
    table id is negative keeps its bare name. The ``*`` pseudo-column is
    skipped, so positions in the returned column list do NOT line up with
    indices into ``schema['column_names_original']``.
    """
    tables = [name.lower() for name in schema['table_names_original']]
    columns = [
        f"{tables[tid]}.{name.lower()}" if tid >= 0 else name.lower()
        for tid, name in schema['column_names_original']
        if name != "*"
    ]
    return columns, tables

def get_values_from_db(db_id: str, schema: Dict, db_dir: str, limit: int = 1000) -> Dict[str, List[str]]:
    """Collect distinct, lower-cased cell values for every real column of a database.

    The database is opened read-only from ``<db_dir>/<db_id>/<db_id>.sqlite``.

    Args:
        db_id: database id (also the folder and file stem).
        schema: Spider-style schema dict (``table_names_original``,
            ``column_names_original``).
        db_dir: directory that contains the per-database folders.
        limit: maximum number of distinct values fetched per column
            (SQLite treats a negative limit as "no limit").

    Returns:
        ``defaultdict(list)`` mapping ``"table.column"`` (lower-cased) to its
        values as lower-cased strings; NULLs are dropped. Columns whose query
        fails are left out. If the database cannot be opened, an error is
        printed and the mapping is empty.
    """
    from contextlib import closing
    from pathlib import Path

    def _quote_ident(name: str) -> str:
        # Backtick quoting: keywords such as "order" and names with spaces work,
        # and unlike double quotes SQLite never falls back to treating an
        # unknown name as a string literal.
        return "`" + name.replace("`", "``") + "`"

    value_dict = defaultdict(list)
    try:
        db_path = Path(db_dir, db_id, f"{db_id}.sqlite").resolve()
        # mode=ro: fail on a missing file instead of silently creating an empty DB.
        conn = sqlite3.connect(f"{db_path.as_uri()}?mode=ro", uri=True)
        with closing(conn):
            cursor = conn.cursor()
            table_names = [t.lower() for t in schema['table_names_original']]
            for table_id, col_name in schema['column_names_original']:
                if col_name == "*" or table_id < 0:
                    continue
                table = table_names[table_id]
                col = col_name.lower()
                try:
                    cursor.execute(
                        f"SELECT DISTINCT {_quote_ident(col)} FROM {_quote_ident(table)} LIMIT ?",
                        (limit,),
                    )
                    values = [str(row[0]).lower() for row in cursor.fetchall() if row[0] is not None]
                    value_dict[f"{table}.{col}"] = values
                except Exception:
                    # Best effort: schema entries that don't exist in the file are skipped.
                    continue
    except Exception as e:
        print(f"[DB Error] Failed to connect for {db_id}: {e}")
    return value_dict

def get_relations(question: str, schema: Dict, db_dir: str) -> Dict:
    """Build the question/schema relation graph consumed by the RAT encoder.

    Nodes are the question tokens followed by the schema tokens (qualified
    ``table.column`` names first, then table names). Edges map
    ``(src_node, dst_node)`` index pairs to a relation label. Later passes
    overwrite earlier labels for the same pair, so the pass order matters.

    Args:
        question: natural-language question.
        schema: Spider-style schema dict (``db_id``, ``table_names_original``,
            ``column_names_original`` and optional ``foreign_keys`` /
            ``primary_keys`` holding indices into ``column_names_original``).
        db_dir: directory passed to ``get_values_from_db`` for value linking.

    Returns:
        dict with ``tokens``, ``edges``, ``q_len``, ``schema_len``,
        ``column_names`` and ``table_names``.
    """
    q_tokens = tokenize(question)
    col_names, tab_names = extract_schema_elements(schema)
    schema_tokens = col_names + tab_names
    all_nodes = q_tokens + schema_tokens
    q_len = len(q_tokens)
    tab_offset = q_len + len(col_names)
    edge_types = {}

    # foreign_keys / primary_keys index schema['column_names_original'], which
    # still contains the '*' entry that extract_schema_elements drops. Translate
    # those raw indices to positions in col_names (using them directly is off
    # by one for Spider schemas).
    raw_to_col = {}
    for raw_idx, (_, raw_name) in enumerate(schema['column_names_original']):
        if raw_name != "*":
            raw_to_col[raw_idx] = len(raw_to_col)

    # 1) Value linking: a question token equals a stored value of a column.
    try:
        value_dict = get_values_from_db(schema['db_id'], schema, db_dir)
        schema_pos = {}
        for j, s_tok in enumerate(schema_tokens):
            schema_pos.setdefault(s_tok, j)  # first occurrence, like list.index
        value_sets = {full_col: set(values) for full_col, values in value_dict.items()}
        for i, q_tok in enumerate(q_tokens):
            for full_col, values in value_sets.items():
                if q_tok in values and full_col in schema_pos:
                    node = q_len + schema_pos[full_col]
                    edge_types[(i, node)] = "value_match"
                    edge_types[(node, i)] = "value_match"
    except Exception as e:
        print(f"[Linking] Value linking failed: {e}")

    # 2) Name linking. Columns are compared by their bare name so the "table."
    # prefix does not stay glued to the first word ("singer.song_name" ->
    # ["song", "name"]).
    name_parts = [s_tok.split(".")[-1].split("_") for s_tok in schema_tokens]
    for i, q_tok in enumerate(q_tokens):
        for j, s_tok in enumerate(schema_tokens):
            if q_tok in name_parts[j] or s_tok in q_tok:
                edge_types[(i, q_len + j)] = "match"
                edge_types[(q_len + j, i)] = "match"

    # 3) Column <-> owning table.
    for i, col in enumerate(col_names):
        for j, tab in enumerate(tab_names):
            if col.startswith(tab + "."):
                ci, tj = q_len + i, tab_offset + j
                edge_types[(ci, tj)] = "belongs_to"
                edge_types[(tj, ci)] = "belongs_to"

    # 4) Columns of the same table (every ordered pair is visited).
    for i, col_i in enumerate(col_names):
        for j, col_j in enumerate(col_names):
            if i != j and col_i.split('.')[0] == col_j.split('.')[0]:
                edge_types[(q_len + i, q_len + j)] = "same_table"

    # 5) Foreign keys (override same_table for the pair).
    for col1_raw, col2_raw in schema.get('foreign_keys', []):
        if col1_raw in raw_to_col and col2_raw in raw_to_col:
            c1 = q_len + raw_to_col[col1_raw]
            c2 = q_len + raw_to_col[col2_raw]
            edge_types[(c1, c2)] = "foreign_key_forward"
            edge_types[(c2, c1)] = "foreign_key_backward"

    # 6) Primary-key column <-> its table. Accepts nested lists (composite
    # keys) as well as plain ints.
    for pk_entry in schema.get('primary_keys', []):
        pk_raws = pk_entry if isinstance(pk_entry, (list, tuple)) else [pk_entry]
        for pk_raw in pk_raws:
            if pk_raw not in raw_to_col:
                continue
            col_pos = raw_to_col[pk_raw]
            table_name = col_names[col_pos].split('.')[0]
            if table_name in tab_names:
                col = q_len + col_pos
                tab = tab_offset + tab_names.index(table_name)
                edge_types[(col, tab)] = "primary_key"
                edge_types[(tab, col)] = "primary_key"

    return {
        "tokens": all_nodes,
        "edges": edge_types,
        "q_len": q_len,
        "schema_len": len(schema_tokens),
        "column_names": col_names,
        "table_names": tab_names
    }

In [53]:
# Encoder(Transformer + Relation-Aware Attention)

from collections import defaultdict

class EmbeddingEncoder(nn.Module):
    """Per-word contextual embeddings from a pretrained Hugging Face encoder.

    Each word list is tokenized as pre-split words; the sub-word vectors that
    belong to one word are mean-pooled into a single vector for that word.
    """

    def __init__(self, model_name='microsoft/deberta-v3-base'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.transformer = AutoModel.from_pretrained(model_name)
        self.output_dim = self.transformer.config.hidden_size

    def forward(self, word_lists: List[List[str]]) -> Tuple[torch.Tensor, None]:
        """Embed a batch of pre-tokenized word lists.

        Args:
            word_lists: one list of words per example.

        Returns:
            ``(embeddings, None)`` where ``embeddings`` has shape
            ``[batch, max(len(words)), hidden_size]``. Row ``j`` of example
            ``i`` always corresponds to ``word_lists[i][j]``. Rows past the
            end of a list, and words whose sub-words were all removed by
            truncation, are zero vectors, so the output stays aligned with
            graphs built over the same word lists.
        """
        # Follow the model's own device rather than a global one.
        model_device = next(self.transformer.parameters()).device
        tokenized = self.tokenizer(
            word_lists,
            padding=True,
            is_split_into_words=True,
            truncation=True,
            return_tensors='pt',
            return_attention_mask=True
        )
        input_ids = tokenized['input_ids'].to(model_device)
        attention_mask = tokenized['attention_mask'].to(model_device)
        last_hidden = self.transformer(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        max_words = max((len(words) for words in word_lists), default=0)
        hidden_size = last_hidden.size(-1)
        pooled = []
        for i in range(len(word_lists)):
            word_ids = tokenized.word_ids(batch_index=i)
            # Sub-word positions that map to a real word (skip special/pad tokens).
            positions = [j for j, wid in enumerate(word_ids) if wid is not None]
            sums = last_hidden.new_zeros(max_words, hidden_size)
            counts = last_hidden.new_zeros(max_words)
            if positions:
                index = torch.tensor([word_ids[j] for j in positions], device=last_hidden.device)
                # Scatter-sum sub-word vectors into their word slot, then divide by count.
                sums = sums.index_add(0, index, last_hidden[i, positions])
                counts = counts.index_add(0, index, last_hidden.new_ones(len(positions)))
            pooled.append(sums / counts.clamp(min=1).unsqueeze(-1))
        return torch.stack(pooled), None


class RATEncoderLayer(nn.Module):
    """One relation-aware transformer layer (post-norm residual blocks).

    Query position ``i`` attends over every key position ``j`` with key
    ``x_j + rK[rel_mat[b, i, j]]`` and value ``x_j + rV[rel_mat[b, i, j]]``,
    where ``rK`` / ``rV`` are learned relation embeddings. Because each query
    has its own key/value set, ``nn.MultiheadAttention`` is run on a batch of
    ``B * L`` length-1 queries, each against its own ``L`` keys.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1, num_relations=10):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.rel_k = nn.Embedding(num_relations, dim)
        self.rel_v = nn.Embedding(num_relations, dim)
        self.dropout = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, rel_mat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: ``[B, L, dim]`` node representations.
            rel_mat: ``[B, L, L]`` integer relation ids in ``[0, num_relations)``.

        Returns:
            ``[B, L, dim]`` updated node representations.
        """
        B, L, D = x.size()
        rel_keys = self.rel_k(rel_mat)    # [B, L, L, D]
        rel_values = self.rel_v(rel_mat)  # [B, L, L, D]

        # x.unsqueeze(1) is [B, 1, L, D]: broadcasting along dim 1 yields
        # keys[b, i, j] = x[b, j] + r[b, i, j], i.e. query i sees every token j.
        # (x.unsqueeze(2) would give x[b, i] for every j, so no token could
        # attend to any other token.)
        keys = (x.unsqueeze(1) + rel_keys).reshape(B * L, L, D)
        values = (x.unsqueeze(1) + rel_values).reshape(B * L, L, D)
        queries = x.reshape(B * L, 1, D)

        attended, _ = self.attn(queries, keys, values)
        attended = attended.reshape(B, L, D)

        x = self.norm1(x + self.dropout(attended))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


class RATEncoder(nn.Module):
    """Project input features to ``hidden_dim`` and apply a stack of RAT layers.

    Every layer receives the same ``rel_mat`` relation-id matrix.
    """

    def __init__(self, input_dim, hidden_dim=256, num_layers=4, num_relations=10):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        stack = [
            RATEncoderLayer(hidden_dim, num_heads=8, dropout=0.1, num_relations=num_relations)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor, rel_mat: torch.Tensor) -> torch.Tensor:
        hidden = self.input_proj(x)
        for rat_layer in self.layers:
            hidden = rat_layer(hidden, rel_mat)
        return hidden


In [48]:
# Decoder(Not AST)

import torch
import torch.nn as nn
import torch.nn.functional as F
import sqlparse
import re

# Grammar for SQLTreeDecoder. Each key is a nonterminal; the decoder builds one
# nn.Linear classifier per key whose output classes are the listed constructors,
# in list order. Class indices are therefore positional: reordering an entry
# invalidates trained weights and the hard-coded targets in forward_supervised.
RULES = {
    'ROOT': ['sql'],
    # For 'sql' the entries are clause nonterminals that decode_node_beam
    # expands one after another, in this order.
    'sql': ['select', 'from', 'where', 'group_by', 'having', 'order_by', 'limit', 'set_op'],
    'select': ['Select'],
    'from': ['From'],
    # Optional clauses: predicting 'None' omits the clause.
    'where': ['None', 'Where'],
    'group_by': ['None', 'GroupBy'],
    'having': ['None', 'Having'],
    'order_by': ['None', 'OrderBy'],
    'limit': ['None', 'Limit'],
    'set_op': ['None', 'Union', 'Intersect', 'Except'],
    # Condition operators (grammar names, not SQL tokens).
    'cond': ['Eq', 'Gt', 'Lt', 'Ge', 'Le', 'Ne', 'And', 'Or', 'In', 'Like', 'Between', 'Exists', 'Not'],
    # 'val_unit' and 'table_unit' get classifiers but no decode_* method in
    # this notebook reads them.
    'val_unit': ['Column', 'Minus', 'Plus', 'Times', 'Divide'],
    'table_unit': ['Table', 'TableUnitSql'],
    # '' (index 0) means "no aggregate" when rendering SELECT columns.
    'agg': ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
}

class SQLTreeDecoder(nn.Module):
    """Grammar-guided (non-AST) SQL decoder driven by the module-level ``RULES``.

    Every nonterminal in ``RULES`` gets a linear classifier over its
    constructors. Column / table choices use pointer scores: a projected
    decoder state dotted with every encoder node vector.

    Notes:
        * ``self.lstm`` is fed its own hidden state as input
          (``self.lstm(h, (h, c))``), so ``encoder_dim`` must equal
          ``hidden_dim``.
        * All methods assume a batch of one (``enc[0]`` / ``squeeze(0)``).
        * NOTE(review): pointer scores range over *all* encoder nodes, but the
          chosen index is used directly (clamped, not offset by
          ``graph['q_len']``) as an index into ``graph['column_names']`` /
          ``graph['table_names']``. Whether that matches the training labels
          depends on ``extract_labels_from_sql`` (not visible here) — confirm.
    """

    # Grammar 'cond' constructor -> SQL operator emitted by decode_where.
    # Constructors with no single binary-operator rendering (And, Or, Between,
    # Exists, Not) fall back to '='.
    COND_SQL_OPS = {
        'Eq': '=', 'Gt': '>', 'Lt': '<', 'Ge': '>=', 'Le': '<=', 'Ne': '!=',
        'Like': 'LIKE', 'In': 'IN',
    }

    def __init__(self, encoder_dim, hidden_dim, num_values=100):
        super().__init__()
        self.lstm = nn.LSTMCell(encoder_dim, hidden_dim)
        self.rule_classifiers = nn.ModuleDict({
            rule: nn.Linear(hidden_dim, len(constructors)) for rule, constructors in RULES.items()
        })
        self.column_pointer = nn.Linear(hidden_dim, encoder_dim)
        self.table_pointer = nn.Linear(hidden_dim, encoder_dim)
        self.value_generator = nn.Linear(hidden_dim, num_values)
        self.order_direction_classifier = nn.Linear(hidden_dim, 2)
        self.encoder_dim = encoder_dim
        self.hidden_dim = hidden_dim

    @staticmethod
    def _sql_literal(val):
        """Render a decoded value as a SQL literal.

        Plain integers (no leading zeros) are emitted unquoted so they compare
        numerically: in SQLite a text literal compared with an expression that
        has no affinity, such as ``COUNT(...)``, is always greater. Everything
        else is single-quoted with embedded quotes doubled, so values like
        ``o'neil`` cannot break the statement.
        """
        text = str(val)
        if re.fullmatch(r"-?(?:0|[1-9]\d*)", text):
            return text
        return "'" + text.replace("'", "''") + "'"

    def forward(self, encoder_output, graph, beam_size=5):
        """Decode one SQL string for a single example.

        Args:
            encoder_output: RAT encoder output, assumed ``[1, num_nodes, encoder_dim]``.
            graph: relation graph dict (``column_names``, ``table_names``,
                optional ``value_vocab_inv``).
            beam_size: beam width for the top-level rule.

        Returns:
            The highest-scoring SQL string.
        """
        h = encoder_output.mean(dim=1).squeeze(0)
        c = torch.zeros_like(h)
        h, c = self.lstm(h, (h, c))
        beam = self.decode_node_beam("ROOT", h, encoder_output, graph, beam_size=beam_size)
        return beam[0]["sql"]

    def forward_supervised(self, encoder_output, graph, labels):
        """Teacher-forced loss for the SELECT, FROM and WHERE parts of one example.

        Args:
            encoder_output: RAT encoder output, assumed ``[1, num_nodes, encoder_dim]``.
            graph: relation graph dict (unused here, kept for a uniform signature).
            labels: dict with ``select_cols``, ``select_aggs``, ``tab_idx`` and
                ``conds`` (each cond has ``op``, ``col`` and ``val_id``).

        Returns:
            Summed cross-entropy loss (a scalar tensor).
        """
        loss = 0.0
        h = encoder_output.mean(dim=1).squeeze(0)
        c = torch.zeros_like(h)
        h, c = self.lstm(h, (h, c))

        # SELECT clause is always present (constructor index 0 = 'Select').
        rule_logits = self.rule_classifiers['select'](h)
        loss += F.cross_entropy(rule_logits.unsqueeze(0), torch.tensor([0], device=h.device))

        for col_idx, agg_id in zip(labels['select_cols'], labels['select_aggs']):
            agg_logits = self.rule_classifiers['agg'](h)
            loss += F.cross_entropy(agg_logits.unsqueeze(0), torch.tensor([agg_id], device=h.device))

            col_query = self.column_pointer(h)
            col_scores = torch.matmul(col_query, encoder_output.squeeze(0).T)
            col_idx = min(col_idx, encoder_output.size(1) - 1)
            loss += F.cross_entropy(col_scores.unsqueeze(0), torch.tensor([col_idx], device=h.device))

            h, c = self.lstm(h, (h, c))

        rule_logits = self.rule_classifiers['from'](h)
        loss += F.cross_entropy(rule_logits.unsqueeze(0), torch.tensor([0], device=h.device))

        tab_query = self.table_pointer(h)
        tab_scores = torch.matmul(tab_query, encoder_output.squeeze(0).T)
        tab_idx = min(labels['tab_idx'], encoder_output.size(1) - 1)
        loss += F.cross_entropy(tab_scores.unsqueeze(0), torch.tensor([tab_idx], device=h.device))

        h, c = self.lstm(h, (h, c))

        if labels['conds']:
            # Index 1 in RULES['where'] is 'Where' (clause present).
            rule_logits = self.rule_classifiers['where'](h)
            loss += F.cross_entropy(rule_logits.unsqueeze(0), torch.tensor([1], device=h.device))

            for cond in labels['conds']:
                op_logits = self.rule_classifiers['cond'](h)
                op_id = RULES['cond'].index(cond['op'].capitalize()) if cond['op'].capitalize() in RULES['cond'] else 0
                loss += F.cross_entropy(op_logits.unsqueeze(0), torch.tensor([op_id], device=h.device))

                col_query = self.column_pointer(h)
                col_scores = torch.matmul(col_query, encoder_output.squeeze(0).T)
                col_idx = min(cond['col'], encoder_output.size(1) - 1)
                loss += F.cross_entropy(col_scores.unsqueeze(0), torch.tensor([col_idx], device=h.device))

                val_logits = self.value_generator(h)
                val_id = cond['val_id']
                # An out-of-range target makes cross_entropy fail (on CUDA as a
                # device-side assert that poisons the context), so skip it.
                if 0 <= val_id < val_logits.size(-1):
                    loss += F.cross_entropy(val_logits.unsqueeze(0), torch.tensor([val_id], device=h.device))

                h, c = self.lstm(h, (h, c))

        return loss


    def decode_node_beam(self, rule, state, enc, graph, beam_size=5):
        """Expand nonterminal ``rule`` and return up to ``beam_size`` candidates.

        Each candidate is ``{"sql": str, "score": float}`` (log-probability
        based), sorted best first. ``'ROOT'`` recurses into ``'sql'``; ``'sql'``
        greedily decides each clause and returns a single joined candidate.
        """
        if rule not in self.rule_classifiers:
            return [{"sql": f"--{rule}--", "score": 0.0}]

        logits = self.rule_classifiers[rule](state)
        probs = F.log_softmax(logits, dim=-1)
        k = min(beam_size, len(RULES[rule]), probs.size(-1))
        topk_probs, topk_indices = probs.view(-1).topk(k)

        candidates = []
        for i in range(k):
            pred_idx = topk_indices[i].item()
            score = topk_probs[i].item()
            if pred_idx >= len(RULES[rule]):
                continue
            constructor = RULES[rule][pred_idx]

            if constructor == 'Select':
                result = self.decode_select(state, enc, graph)
            elif constructor == 'From':
                result = self.decode_from(state, enc, graph)
            elif constructor == 'Where':
                result = self.decode_where(state, enc, graph)
            elif constructor == 'GroupBy':
                result = self.decode_group_by(state, enc, graph)
            elif constructor == 'Having':
                result = self.decode_having(state, enc, graph)
            elif constructor == 'OrderBy':
                result = self.decode_order_by(state, enc, graph)
            elif constructor == 'Limit':
                result = self.decode_limit(state)
            elif constructor in {'Union', 'Intersect', 'Except'}:
                result = self.decode_set_op(constructor, state)
            elif rule == 'ROOT':
                result = self.decode_node_beam('sql', state, enc, graph)[0]["sql"]
            elif rule == 'sql':
                # Greedy per-clause decisions; the whole query becomes one candidate.
                clauses = []
                total_score = score
                for clause_rule in RULES['sql']:
                    if clause_rule not in self.rule_classifiers:
                        continue
                    clause_logits = self.rule_classifiers[clause_rule](state)
                    clause_probs = F.log_softmax(clause_logits, dim=-1)
                    clause_probs = clause_probs.view(-1)
                    clause_idx = torch.argmax(clause_probs).item()
                    clause_idx = min(clause_idx, len(RULES[clause_rule]) - 1)
                    clause_score = clause_probs[clause_idx].item()
                    clause_constructor = RULES[clause_rule][clause_idx]

                    total_score += clause_score
                    if clause_constructor != 'None':
                        sub_beam = self.decode_node_beam(clause_rule, state, enc, graph, beam_size=1)
                        if sub_beam and "sql" in sub_beam[0]:
                            clauses.append(sub_beam[0]["sql"])
                        else:
                            print(f"[Warning] Empty beam for clause '{clause_rule}' — skipping.")

                # HAVING is only valid together with GROUP BY.
                has_group_by = any("GROUP BY" in c for c in clauses)
                clauses = [c for c in clauses if not c.startswith("HAVING") or has_group_by]

                result = ' '.join(clauses)
                candidates.append({"sql": result, "score": total_score})
                return candidates
            else:
                result = f"--Unhandled {constructor}--"


            candidates.append({"sql": result, "score": score})

        return sorted(candidates, key=lambda x: -x["score"])[:beam_size]

    def decode_select(self, state, enc, graph, threshold=0.2):
        """SELECT every column whose pointer probability exceeds ``threshold``
        (at least the argmax), all with the same predicted aggregate."""
        agg_logits = self.rule_classifiers['agg'](state)
        agg_probs = F.softmax(agg_logits, dim=-1)
        col_query = self.column_pointer(state).unsqueeze(0)
        enc_ = enc[0]
        col_scores = torch.matmul(col_query, enc_.T).squeeze(0)
        col_probs = F.softmax(col_scores, dim=-1).view(-1)
        selected = [i for i, p in enumerate(col_probs.detach().cpu().tolist()) if p > threshold]
        if not selected:
            selected = [torch.argmax(col_scores).item()]
        parts = []
        for i in selected:
            i = min(i, len(graph['column_names']) - 1)
            col = graph['column_names'][i]
            agg_idx = torch.argmax(agg_probs).item()
            agg_idx = min(agg_idx, len(RULES['agg']) - 1)
            agg = RULES['agg'][agg_idx]

            parts.append(f"{agg}({col})" if agg else col)
        return "SELECT " + ", ".join(parts)

    def decode_from(self, state, enc, graph):
        """FROM the single highest-scoring table (no joins are produced)."""
        tab_query = self.table_pointer(state).unsqueeze(0)
        enc_ = enc[0]
        tab_scores = torch.matmul(tab_query, enc_.T).squeeze(0)
        tab_idx = torch.argmax(tab_scores).item()
        tab_idx = min(tab_idx, len(graph['table_names']) - 1)
        return f"FROM {graph['table_names'][tab_idx]}"

    def decode_where(self, state, enc, graph):
        """One ``WHERE <column> <op> <value>`` condition.

        The grammar constructor (e.g. ``'Ge'``) is translated to its SQL
        operator (``>=``) via ``COND_SQL_OPS``; the value comes from
        ``graph['value_vocab_inv']`` (placeholder ``val<id>`` if absent) and is
        rendered with ``_sql_literal``.
        """
        col_query = self.column_pointer(state).unsqueeze(0)
        enc_ = enc[0]
        col_scores = torch.matmul(col_query, enc_.T).squeeze(0)
        col_idx = torch.argmax(col_scores).item()
        col_idx = min(col_idx, len(graph['column_names']) - 1)
        col = graph['column_names'][col_idx]
        op_logits = self.rule_classifiers['cond'](state)
        op_idx = torch.argmax(op_logits).item()
        op_idx = min(op_idx, len(RULES['cond']) - 1)
        op = RULES['cond'][op_idx]

        val_logits = self.value_generator(state)
        val_id = torch.argmax(val_logits).item()
        val = graph.get("value_vocab_inv", {}).get(val_id, f"val{val_id}")

        sql_op = self.COND_SQL_OPS.get(op, '=')
        literal = self._sql_literal(val)
        if sql_op == 'IN':
            return f"WHERE {col} IN ({literal})"
        return f"WHERE {col} {sql_op} {literal}"


    def decode_group_by(self, state, enc, graph, threshold=0.2):
        """GROUP BY every column whose pointer probability exceeds ``threshold``."""
        col_query = self.column_pointer(state).unsqueeze(0)
        enc_ = enc[0]
        col_scores = torch.matmul(col_query, enc_.T).squeeze(0)
        col_probs = F.softmax(col_scores, dim=-1).view(-1)
        selected = [i for i, p in enumerate(col_probs.detach().cpu().tolist()) if p > threshold]
        if not selected:
            selected = [torch.argmax(col_scores).item()]
        cols = [graph['column_names'][i] if i < len(graph['column_names']) else f"col{i}" for i in selected]
        return "GROUP BY " + ", ".join(cols)

    def decode_having(self, state, enc, graph):
        """``HAVING COUNT(<column>) > <value>``; integers stay unquoted so the
        comparison with COUNT is numeric."""
        col_query = self.column_pointer(state).unsqueeze(0)
        enc_ = enc[0]
        col_scores = torch.matmul(col_query, enc_.T).squeeze(0)
        col_idx = torch.argmax(col_scores).item()
        col_idx = min(col_idx, len(graph['column_names']) - 1)
        col = graph['column_names'][col_idx]
        val_logits = self.value_generator(state)
        val_id = torch.argmax(val_logits).item()
        val = graph.get("value_vocab_inv", {}).get(val_id, f"val{val_id}")
        return f"HAVING COUNT({col}) > {self._sql_literal(val)}"

    def decode_order_by(self, state, enc, graph, threshold=0.2):
        """ORDER BY the selected columns, all in one predicted direction."""
        col_query = self.column_pointer(state).unsqueeze(0)
        enc_ = enc[0]
        col_scores = torch.matmul(col_query, enc_.T).squeeze(0)
        col_probs = F.softmax(col_scores, dim=-1).view(-1)
        selected = [i for i, p in enumerate(col_probs.detach().cpu().tolist()) if p > threshold]
        if not selected:
            selected = [torch.argmax(col_scores).item()]

        dir_logits = self.order_direction_classifier(state)
        dir_probs = F.softmax(dir_logits, dim=-1)

        dir_idx = torch.argmax(dir_probs).item()
        dir_idx = min(dir_idx, len(['ASC', 'DESC']) - 1)
        direction = ['ASC', 'DESC'][dir_idx]

        return "ORDER BY " + ", ".join([
            f"{graph['column_names'][i]} {direction}" if i < len(graph['column_names']) else f"col{i} {direction}"
            for i in selected
        ])

    def decode_limit(self, state):
        """Fixed ``LIMIT 10`` placeholder (the limit value is not predicted)."""
        return "LIMIT 10"

    def decode_set_op(self, constructor, state):
        """Placeholder text for set operations; the right-hand query is not decoded."""
        return f"{constructor.upper()} SELECT ..."

In [6]:
# Training Loop

from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
import torch
import gc

# Release memory held by objects from earlier runs of this cell.
gc.collect()
torch.cuda.empty_cache()

# Model stack: pretrained word embedder -> RAT encoder -> grammar decoder.
encoder_model = EmbeddingEncoder().to(device)
rat_encoder = RATEncoder(input_dim=encoder_model.output_dim).to(device)
decoder = SQLTreeDecoder(encoder_dim=256, hidden_dim=256, num_values=100).to(device)
decoder.num_values = 100

# Datasets and the shuffled training loader.
train_dataset, dev_dataset = (
    SpiderMiniDataset(split, schema_dict, db_dir) for split in (train_data, dev_data)
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

if 'value_vocab' not in globals():
    def build_value_vocab(schema_dict, db_dir, max_size=100):
        """Map up to ``max_size`` distinct DB cell values to ids ``1..max_size``.

        Id 0 is never assigned. Values are gathered with ``get_values_from_db``,
        database by database in ``schema_dict`` order, and collection stops as
        soon as the vocabulary is full.
        """
        vocab = {}
        for db_id, schema in schema_dict.items():
            for val_list in get_values_from_db(db_id, schema, db_dir).values():
                for val in val_list:
                    if len(vocab) >= max_size:
                        return vocab
                    vocab.setdefault(val, len(vocab) + 1)
        return vocab

    # Ids start at 1, so the decoder's value head (classes 0..out_features-1)
    # can only represent out_features - 1 vocabulary entries; a larger vocab
    # would produce out-of-range cross-entropy targets.
    value_vocab = build_value_vocab(
        schema_dict, db_dir, max_size=decoder.value_generator.out_features - 1
    )
    value_vocab_inv = {v: k for k, v in value_vocab.items()}

# One AdamW over all three modules (parameter order: embedder, RAT, decoder).
trainable_params = [
    *encoder_model.parameters(),
    *rat_encoder.parameters(),
    *decoder.parameters(),
]
optimizer = torch.optim.AdamW(trainable_params, lr=2e-5)
scaler = GradScaler()

# Early-stopping bookkeeping, tracked on dev execution accuracy.
best_acc, patience, epochs_no_improve, max_epochs = 0.0, 3, 0, 20

# Baseline scores before any training step.
acc = evaluate_execution_accuracy(decoder, encoder_model, rat_encoder, dev_dataset, db_dir)
em_score = evaluate_exact_match(decoder, encoder_model, rat_encoder, dev_dataset, db_dir)
print(f"Initial Dev Acc: {acc:.2%}, Exact Match: {em_score:.2%}")

for epoch in range(max_epochs):
    encoder_model.train(); rat_encoder.train(); decoder.train()
    print(f"\nEpoch {epoch+1}/{max_epochs}")

    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()
        # torch.cuda.amp.autocast() does not accept `device_type` (TypeError);
        # torch.autocast does, and enabled=False makes it a no-op on CPU.
        with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
            batch['x_embed'] = batch['x_embed'].float()
            x_all = rat_encoder(batch['x_embed'], batch['rel_mat'])
            loss_sum = 0.0

            for i in range(x_all.size(0)):
                try:
                    graph = batch['graphs'][i]
                    db_id = batch['db_ids'][i]
                    schema = schema_dict[db_id]
                    graph['value_vocab_inv'] = value_vocab_inv

                    labels = extract_labels_from_sql(batch['queries'][i], schema, value_vocab)
                    x = x_all[i:i+1]

                    loss = decoder.forward_supervised(x, graph, labels)
                    if loss is not None and torch.isfinite(loss):
                        loss_sum += loss
                except Exception as e:
                    print(f"[Skip] Step {step}, Example {i}: {e}")
                    continue

        # No usable example in this batch (all skipped or non-finite).
        if loss_sum == 0.0:
            continue

        scaler.scale(loss_sum).backward()
        # Unscale first so the 5.0 threshold applies to the true gradients, not
        # the loss-scaled ones; clip every parameter the optimizer updates.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            [p for group in optimizer.param_groups for p in group['params']], 5.0
        )
        scaler.step(optimizer)
        scaler.update()
        torch.cuda.empty_cache()

        if step % 10 == 0:
            print(f"🔹 Step {step} - Loss: {loss_sum.item():.4f}")


    acc = evaluate_execution_accuracy(decoder, encoder_model, rat_encoder, dev_dataset, db_dir)
    em_score = evaluate_exact_match(decoder, encoder_model, rat_encoder, dev_dataset, db_dir)
    # 1-based, matching the "Epoch N/max" header printed at the top of the epoch.
    print(f"📊 Epoch {epoch+1} — Dev Acc: {acc:.2%}, Exact Match: {em_score:.2%}")

    if acc > best_acc:
        best_acc = acc
        epochs_no_improve = 0
        torch.save({
            'encoder': encoder_model.state_dict(),
            'rat': rat_encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'optimizer': optimizer.state_dict()
        }, "/content/drive/MyDrive/rat_sql_best.pt")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping: no improvement in {patience} epochs.")
            break

In [2]:
# Gradio UI
import gradio as gr
import sqlite3

def extract_schema_from_sqlite(sqlite_path):
    """Build a Spider-style schema dict from an arbitrary SQLite file.

    Returns a dict with:
        * ``table_names_original``: user tables (SQLite's internal ``sqlite_*``
          tables such as ``sqlite_sequence`` are skipped);
        * ``column_names_original``: ``(table_idx, column_name)`` tuples in
          table order (unlike Spider there is no leading ``(-1, '*')`` entry);
        * ``db_id``: always ``'custom'``;
        * ``primary_keys`` / ``foreign_keys``: indices into
          ``column_names_original`` (foreign keys as ``[child, parent]`` pairs).
    """
    from contextlib import closing

    def quote(name):
        # Quote identifiers so reserved words ("order") and spaces work.
        return '"' + name.replace('"', '""') + '"'

    with closing(sqlite3.connect(sqlite_path)) as conn:
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall() if not row[0].lower().startswith("sqlite_")]

        schema = {
            'table_names_original': tables,
            'column_names_original': [],
            'db_id': 'custom',
            'primary_keys': [],
            'foreign_keys': [],
        }

        col_index = {}    # (table lower, column lower) -> index in column_names_original
        pk_by_table = {}  # table lower -> [(pk position, column name)]
        for i, table in enumerate(tables):
            cursor.execute(f"PRAGMA table_info({quote(table)});")
            # row layout: (cid, name, type, notnull, dflt_value, pk)
            for col in cursor.fetchall():
                col_name = col[1]
                idx = len(schema['column_names_original'])
                col_index[(table.lower(), col_name.lower())] = idx
                schema['column_names_original'].append((i, col_name))
                if col[5]:
                    schema['primary_keys'].append(idx)
                    pk_by_table.setdefault(table.lower(), []).append((col[5], col_name))

        for table in tables:
            cursor.execute(f"PRAGMA foreign_key_list({quote(table)});")
            # row layout: (id, seq, table, from, to, on_update, on_delete, match)
            for fk in cursor.fetchall():
                seq, ref_table, from_col, to_col = fk[1], fk[2], fk[3], fk[4]
                if to_col is None:
                    # "REFERENCES parent" without a column targets the parent's PK.
                    parent_pk = [name for _, name in sorted(pk_by_table.get(ref_table.lower(), []))]
                    to_col = parent_pk[seq] if seq < len(parent_pk) else None
                src = col_index.get((table.lower(), from_col.lower()))
                dst = col_index.get((ref_table.lower(), to_col.lower())) if to_col else None
                if src is not None and dst is not None:
                    schema['foreign_keys'].append([src, dst])

    return schema

def predict_sql_from_uploaded_db(question, db_file):
    """Gradio handler: generate SQL for ``question`` against an uploaded SQLite DB.

    Args:
        question: natural-language question.
        db_file: the ``gr.File`` value; either a filepath string (Gradio 4+
            default) or a tempfile-like object exposing ``.name`` (Gradio 3).
            ``None`` when nothing was uploaded.

    Returns:
        The generated SQL string, or a prompt to upload a database.
    """
    import shutil
    import tempfile

    if db_file is None:
        return "Please upload a SQLite database file."
    db_path = db_file if isinstance(db_file, str) else db_file.name
    schema = extract_schema_from_sqlite(db_path)

    # get_values_from_db reads <db_dir>/<db_id>/<db_id>.sqlite, so stage a copy
    # in that layout; otherwise value linking never finds the uploaded file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        staged_dir = os.path.join(tmp_dir, schema['db_id'])
        os.makedirs(staged_dir)
        shutil.copyfile(db_path, os.path.join(staged_dir, f"{schema['db_id']}.sqlite"))
        graph = get_relations(question, schema, db_dir=tmp_dir)
    graph['value_vocab_inv'] = value_vocab_inv
    tokens = graph['tokens']

    # Ensure dropout is off regardless of the mode earlier cells left the modules in.
    encoder_model.eval(); rat_encoder.eval(); decoder.eval()
    with torch.no_grad():
        x_embed, _ = encoder_model([tokens])
        rel_mat = build_relation_matrix(graph).unsqueeze(0).to(device)
        x = rat_encoder(x_embed, rel_mat)
        pred_sql = decoder(x, graph)

    return pred_sql

demo = gr.Interface(
    fn=predict_sql_from_uploaded_db,
    inputs=[
        gr.Textbox(label="Enter your natural language question"),
        gr.File(label="Upload your SQLite (.sqlite) database"),
    ],
    outputs=gr.Textbox(label="Generated SQL Query"),
    title="RAT-SQL Decoder (2025)",
    description="Upload a SQLite DB and ask a question. It will generate a SQL query using the RAT-SQL decoder.",
)
demo.launch()

It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://dcdbced473b64a6c08.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


