## Required Imports

In [24]:
import pandas as pd
import numpy as np
import json
from transformers import T5Tokenizer
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Dataset
import sqlparse
import torch


## Getting data from Github

In [2]:
!git clone https://github.com/salesforce/WikiSQL.git

Cloning into 'WikiSQL'...
remote: Enumerating objects: 389, done.[K
remote: Counting objects: 100% (195/195), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 389 (delta 186), reused 154 (delta 154), pack-reused 194 (from 1)[K
Receiving objects: 100% (389/389), 50.72 MiB | 55.66 MiB/s, done.
Resolving deltas: 100% (213/213), done.


In [3]:
!cd WikiSQL && tar -xvf data.tar.bz2



data/
data/train.jsonl
data/test.tables.jsonl
data/test.db
data/dev.tables.jsonl
data/dev.db
data/test.jsonl
data/train.tables.jsonl
data/train.db
data/dev.jsonl


In [4]:
!ls WikiSQL/data

dev.db	   dev.tables.jsonl  test.jsonl		train.db     train.tables.jsonl
dev.jsonl  test.db	     test.tables.jsonl	train.jsonl


In [5]:
with open("WikiSQL/data/train.jsonl") as f:
    sample = json.loads(next(f))

with open("WikiSQL/data/train.tables.jsonl") as f:
    table = json.loads(next(f))

print("QUESTION:\n", sample["question"])
print("\nSQL LABEL:\n", sample["sql"])
print("\nTABLE SCHEMA:\n", table)

QUESTION:
 Tell me what the notes are for South Australia 

SQL LABEL:
 {'sel': 5, 'conds': [[3, 0, 'SOUTH AUSTRALIA']], 'agg': 0}

TABLE SCHEMA:
 {'id': '1-1000181-1', 'header': ['State/territory', 'Text/background colour', 'Format', 'Current slogan', 'Current series', 'Notes'], 'types': ['text', 'text', 'text', 'text', 'text', 'text'], 'rows': [['Australian Capital Territory', 'blue/white', 'Yaa·nna', 'ACT · CELEBRATION OF A CENTURY 2013', 'YIL·00A', 'Slogan screenprinted on plate'], ['New South Wales', 'black/yellow', 'aa·nn·aa', 'NEW SOUTH WALES', 'BX·99·HI', 'No slogan on current series'], ['New South Wales', 'black/white', 'aaa·nna', 'NSW', 'CPX·12A', 'Optional white slimline series'], ['Northern Territory', 'ochre/white', 'Ca·nn·aa', 'NT · OUTBACK AUSTRALIA', 'CB·06·ZZ', 'New series began in June 2011'], ['Queensland', 'maroon/white', 'nnn·aaa', 'QUEENSLAND · SUNSHINE STATE', '999·TLG', 'Slogan embossed on plate'], ['South Australia', 'black/white', 'Snnn·aaa', 'SOUTH AUSTRALIA'

## Converting WikiSQL into SQL String

In [31]:
AGG_OPS = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = ["=", ">", "<", "OP"]

def convert_sql(sample, table):
  col_names= table['header']

  select_col= col_names[sample['sql']['sel']]

  agg= AGG_OPS[sample['sql']['agg']]

  if agg:
    select_part = f"{agg}({select_col})"
  else:
    select_part = select_col


  where_clause=[]

  for c in sample['sql']['conds']:
    column= col_names[c[0]]
    op= COND_OPS[c[1]]
    value= c[2]
    where_clause.append(f"{column}{op}'{value}'")

  where = " AND ".join(where_clause)

  query = f"SELECT {select_part} FROM table"

  if where:
    query += f" WHERE {where}"

  return query

In [7]:
# schema-aware input
def build_schema(table):
  return ", ".join([
  f"{col}({typ})"
  for col, typ in zip(table['header'], table['types'])
  ])

## Creating Training Examples

In [34]:
training_pairs=[]

with open("WikiSQL/data/train.jsonl") as f_data, \
     open("WikiSQL/data/train.tables.jsonl") as f_tables:

     tables={}

     for t in f_tables:
      table_obj= json.loads(t)
      tables[table_obj['id']]= table_obj

     for line in f_data:
        sample = json.loads(line)
        actual_table = tables[sample['table_id']]

        schema= build_schema(actual_table)

        sql = convert_sql(sample, actual_table)

        inp= f"Schema: {schema}, Question: {sample['question']}"

        training_pairs.append((inp,sql))

     training_pairs= training_pairs[:8000]


In [35]:
len(training_pairs)

8000

## Tokenization

In [36]:
#Tokenize for T5
tokenizer= T5Tokenizer.from_pretrained("t5-small")
def tokenize(example):
  return tokenizer(
      example[0],
      text_target = example[1],
      truncation= True,
      padding= "max_length",
      max_length= 256
  )

In [37]:
# Converting training pairs into tokenized data
class SQLDataset(Dataset):
    def __init__(self, pairs, tokenizer):
        self.pairs = pairs
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        inp, out = self.pairs[idx]
        tokens = self.tokenizer(
            inp,
            text_target=out,
            truncation=True,
            padding="max_length",
            max_length=256
        )
        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
            "labels": tokens["labels"]
        }




In [38]:
tokenized_data= SQLDataset(training_pairs,tokenizer)

## Fine-tuning T5

In [39]:
model = T5ForConditionalGeneration.from_pretrained('t5-small')

args= TrainingArguments(
    output_dir= './sql_model',
    learning_rate= 3e-4,
    per_device_train_batch_size=16,
    num_train_epochs=2,
    save_steps=50
)

trainer= Trainer(
    model=model,
    args=args,
    train_dataset= tokenized_data
)

trainer.train()


Step,Training Loss
500,0.1153
1000,0.027


TrainOutput(global_step=1000, training_loss=0.0711275749206543, metrics={'train_runtime': 868.3336, 'train_samples_per_second': 18.426, 'train_steps_per_second': 1.152, 'total_flos': 1082734411776000.0, 'train_loss': 0.0711275749206543, 'epoch': 2.0})

In [40]:
trainer.save_model("./sql_model")
tokenizer.save_pretrained("./sql_model")

('./sql_model/tokenizer_config.json',
 './sql_model/special_tokens_map.json',
 './sql_model/spiece.model',
 './sql_model/added_tokens.json')

In [41]:
def is_valid_sql(sql):
    return bool(sqlparse.parse(sql))

## Inference

In [42]:
tokenizer = T5Tokenizer.from_pretrained("./sql_model")
model = T5ForConditionalGeneration.from_pretrained("./sql_model")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [43]:
def generate_sql(schema, question, max_length=128):
    input_text = f"Schema: {schema}, Question: {question}"

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        padding=True
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True
        )

    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sql


In [63]:
schema = "Name(text), Age(number), Salary(number), Department(text)"
question = "What is the maximum salary?"
sql= generate_sql(schema, question)

if is_valid_sql(sql):
    print("Generated SQL:", sql)
else:
    print("Invalid SQL generated")


Generated SQL: SELECT MAX(Salary) FROM table


In [64]:
schema = "Name(text), Age(number), Salary(number), Department(text)"
question = "What is the salary of Matthew?"

sql= generate_sql(schema, question)

if is_valid_sql(sql):
    print("Generated SQL:", sql)
else:
    print("Invalid SQL generated")


Generated SQL: SELECT Salary FROM table WHERE Name='Matthew'


In [65]:
schema = "Name(text), Age(number), Salary(number), Department(text)"
question = "How many records are there in Sales Department?"

sql= generate_sql(schema, question)

if is_valid_sql(sql):
    print("Generated SQL:", sql)
else:
    print("Invalid SQL generated")


Generated SQL: SELECT COUNT(Name) FROM table WHERE Department='Sales Department'
