### Language-to-SQL transformers
#### With schema graph relationships (Graphix-T5)

In [1]:
# Let's explore the data: load the dataset published under the id "xlangai/spider".
import datasets


from datasets import load_dataset

# `ds` is indexed by split name in later cells ('train' and 'validation').
ds = load_dataset("xlangai/spider")

  from .autonotebook import tqdm as notebook_tqdm
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating train split: 100%|██████████| 7000/7000 [00:00<00:00, 156800.60 examples/s]
Generating validation split: 100%|██████████| 1034/1034 [00:00<00:00, 117166.29 examples/s]


In [None]:
# Peek at the first training example to see which fields each record carries.
print(ds['train'][0])


{'db_id': 'department_management', 'query': 'SELECT count(*) FROM head WHERE age  >  56', 'question': 'How many heads of the departments are older than 56 ?', 'query_toks': ['SELECT', 'count', '(', '*', ')', 'FROM', 'head', 'WHERE', 'age', '>', '56'], 'query_toks_no_value': ['select', 'count', '(', '*', ')', 'from', 'head', 'where', 'age', '>', 'value'], 'question_toks': ['How', 'many', 'heads', 'of', 'the', 'departments', 'are', 'older', 'than', '56', '?']}


In [None]:
import pandas as pd

# Materialise the training split as a DataFrame for easier inspection.
df = pd.DataFrame(ds['train'])

# NOTE(review): despite its name, df_test is built from the 'validation' split.
df_test = pd.DataFrame(ds['validation'])

In [7]:
# Show the first two training rows.
df.head(2)

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks
0,department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...","[How, many, heads, of, the, departments, are, ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","List the name, born state and age of the heads...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","[List, the, name, ,, born, state, and, age, of..."


In [None]:
import pandas as pd

# Lift pandas' display limits so the nested schema lists are shown in full.
# NOTE(review): these options are set globally, so they stay in effect for
# every later cell in this kernel session.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)
pd.set_option('display.width', None)

# Load the schema descriptions; the relative path means tables.json must be in
# the current working directory.
tables = pd.read_json('tables.json')

# Print only the column_names entry of the department_management database.
print(tables[tables['db_id'] == 'department_management']['column_names'])

Unnamed: 0,column_names,column_names_original,column_types,db_id,foreign_keys,primary_keys,table_names,table_names_original
159,"[[-1, *], [0, department id], [0, name], [0, creation], [0, ranking], [0, budget in billions], [0, num employees], [1, head id], [1, name], [1, born state], [1, age], [2, department id], [2, head id], [2, temporary acting]]","[[-1, *], [0, Department_ID], [0, Name], [0, Creation], [0, Ranking], [0, Budget_in_Billions], [0, Num_Employees], [1, head_ID], [1, name], [1, born_state], [1, age], [2, department_ID], [2, head_ID], [2, temporary_acting]]","[text, number, text, text, number, number, number, number, text, text, number, number, number, text]",department_management,"[[12, 7], [11, 1]]","[1, 7, 11]","[department, head, management]","[department, head, management]"


In [20]:
def extract_schema_for_db(db_id, tables_df, qualify_columns=False):
    """Build a node/edge description of a single database schema.

    Parameters
    ----------
    db_id : str
        Value to look up in the ``db_id`` column of ``tables_df``.
    tables_df : pandas.DataFrame
        One row per database, with at least the columns ``db_id``,
        ``table_names``, ``column_names``, ``primary_keys`` and
        ``foreign_keys``. Every ``column_names`` entry is indexed as
        ``[table_index, column_name]``; a table index of ``-1`` marks a global
        column that gets no node and no table-column edge. ``primary_keys``
        holds indices into ``column_names`` and ``foreign_keys`` holds pairs
        of such indices.
    qualify_columns : bool, default False
        With the default, column nodes are labelled by bare column name. Two
        tables that share a column name then produce identical labels, and a
        foreign key between them shows up as a self-loop such as
        ``('head id', 'head id')``. Pass ``True`` to label column nodes as
        ``"<table>.<column>"`` so every column node is unambiguous.

    Returns
    -------
    dict
        ``{"nodes": [...], "edges": [...]}``. ``nodes`` lists the table names
        followed by the column labels. ``edges`` lists ``(table, column)``
        pairs, then ``(column, "<column> (PK)")`` pairs for primary keys, then
        ``(from_column, to_column)`` pairs for foreign keys.

    Raises
    ------
    ValueError
        If no row of ``tables_df`` has the requested ``db_id``.
    """
    # Select the matching row directly. Converting the whole frame to records
    # on every call is wasteful when this runs once per training example.
    matches = tables_df[tables_df["db_id"] == db_id]
    if matches.empty:
        raise ValueError(f"Schema for db_id {db_id} not found.")
    schema = matches.iloc[0].to_dict()  # first match wins

    tables = list(schema["table_names"])
    columns = schema["column_names"]
    primary_keys = schema["primary_keys"]
    foreign_keys = schema["foreign_keys"]

    def _label(col_idx):
        # Label for the column at ``col_idx``; table-qualified only on request.
        table_idx, col_name = columns[col_idx][0], columns[col_idx][1]
        if qualify_columns and table_idx != -1:
            return f"{tables[table_idx]}.{col_name}"
        return f"{col_name}"

    # Indices of real (non-global) columns, in schema order.
    column_ids = [idx for idx, col in enumerate(columns) if col[0] != -1]

    nodes = tables + [_label(idx) for idx in column_ids]

    # Table-Column edges
    edges = [(tables[columns[idx][0]], _label(idx)) for idx in column_ids]

    # Primary Key edges
    for pk in primary_keys:
        edges.append((_label(pk), f"{_label(pk)} (PK)"))

    # Foreign Key edges
    for fk in foreign_keys:
        edges.append((_label(fk[0]), _label(fk[1])))

    return {"nodes": nodes, "edges": edges}




In [21]:
# Example usage: build the schema graph for one database and display it.
db_id = "department_management"

schema_graph = extract_schema_for_db(db_id, tables)
schema_graph

{'nodes': ['department',
  'head',
  'management',
  'department id',
  'name',
  'creation',
  'ranking',
  'budget in billions',
  'num employees',
  'head id',
  'name',
  'born state',
  'age',
  'department id',
  'head id',
  'temporary acting'],
 'edges': [('department', 'department id'),
  ('department', 'name'),
  ('department', 'creation'),
  ('department', 'ranking'),
  ('department', 'budget in billions'),
  ('department', 'num employees'),
  ('head', 'head id'),
  ('head', 'name'),
  ('head', 'born state'),
  ('head', 'age'),
  ('management', 'department id'),
  ('management', 'head id'),
  ('management', 'temporary acting'),
  ('department id', 'department id (PK)'),
  ('head id', 'head id (PK)'),
  ('department id', 'department id (PK)'),
  ('head id', 'head id'),
  ('department id', 'department id')]}

In [None]:
def prepare_t5_inputs_targets(data, tables_data):
    inputs = []
    targets = []

    for record in data:
        db_id = record["db_id"]
        question = record["question"]
        query = record["query"]

        # Extract schema for the db_id
        schema = extract_schema_for_db(db_id, tables_data)

        # Serialize the schema
        tables = ", ".join(schema["nodes"][: len(schema["nodes"]) // 2])
        edges = schema["edges"]  # Log table, refs
