# Modify Schemalink

## JSON Structure Description

The JSON embedded_data object has the following structure:
```json
{
  "name": "domain name",
  "tables": {
    "table_1" : {
      "description" : {
        "text" : "this is a description of table_1",
        "vector" : [Vector]
      },
      "datatypes" : {
        "JOIN_KEY" : {
          "PK" : ["column_1", "column_3"],
          "FK" : {
            "column_3" : "table_2"
            }, 
          },
        "COLUMNS" : {
          "column_1" : "number",
          "column_2" : "text",
          "column_3" : "text"
        }
      },
      "class_labels" : {
        "class_label_1" : {
          "text" : "This is a class description for identify column type",
          "vector" : [Vector],
        },
        "class_label_2" : {
          "text" : "This is a class description for identify column type",
          "vector" : [Vector],
        }
      },
      "columns" : {
        "column_1" : {
          "text" : "This is a description of column_1",
          "vector" : [Vector],
          "column_classes" : ["class_label_1"]
        },
        "column_2" : {
          "text" : "This is a description of column_2",
          "vector" : [Vector],
          "column_classes" : ["class_label_2"]
        },
        "column_3" : {
          "text" : "This is a description of column_3",
          "vector" : [Vector],
          "column_classes" : ["class_label_1", "class_label_2"]
        }
      }
    }
  },
  "table_2" : {...}
}

In [37]:
import re, math, torch, os
from torch import Tensor
import numpy as np
import pandas as pd
from openai import OpenAI           # openai version 1.11.1

def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))

In [38]:
from sentence_transformers import SentenceTransformer
import json
# from api import encode

from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings

tokenizer = AutoTokenizer.from_pretrained("models/nsql-350M")
model = AutoModelForCausalLM.from_pretrained("models/nsql-350M")
sen_emb = SentenceTransformer("models/all-MiniLM-L6-v2")

def encode(text):
    return sen_emb.encode(text).tolist()


In [39]:
def most_relate_topic(text:str, table_classes:dict,topic_threshold_score:float=0.4) -> list:
    text_vec = encode(text)
    topic_scores = [float(cos_sim(info['vector'], text_vec)) for info in table_classes.values()]
    related_topic_indices = np.where(np.array(topic_scores) >= topic_threshold_score)[0]
    
    # If no topics meet the threshold, return the topic with maximum score
    if not len(related_topic_indices) : 
        related_topic_indices = [np.argmax(topic_scores)]

    related_topics = [list(table_classes.keys())[i] for i in related_topic_indices]
    return related_topics

In [40]:
def new_embed_domain(domain_name):
  schema_descriptions_files_path = f"{domain_name}/descriptions"
  schema_datatypes_files_path = f"{domain_name}/datatypes"
  
  domain = {}
  domain['name'] = domain_name
  domain['tables'] = {}


  #dev
  schema_classes_file_path = [file for file in os.listdir(domain_name) if file.endswith("_classes.json")][0]
  with open(os.path.join(domain_name, schema_classes_file_path), "r") as file:
    schema_classes = json.load(file)

  schema_classes_vector = dict()
  for table_name, table_classes in schema_classes.items():
    schema_classes_vector[table_name] = dict()
    for class_label, class_description in table_classes.items():
       schema_classes_vector[table_name][class_label] = { "text" : class_description,
                                                          "vector" : encode(class_description) }

  description_files = sorted(os.listdir(schema_descriptions_files_path))
  datatype_files = sorted(os.listdir(schema_datatypes_files_path))
  for description, datatype in zip(description_files, 
                                   datatype_files):

    with open(os.path.join(schema_descriptions_files_path, description), 'r') as file:
      description_body = json.load(file)
    
    with open(os.path.join(schema_datatypes_files_path, datatype), 'r') as file:
      datatype_body = json.load(file)

    # dev
      
    table_name = description_body['table']
    column_classes = schema_classes_vector[table_name]
    table_description = {}

    table_description['text'] = description_body['description']
    table_description['vector']= encode(description_body['description'])

    columns = {}
    for col, desc in description_body['columns'].items():
      column = {}
      column['text'] = desc
      column['vector'] = encode(desc)
      column['column_classes'] = most_relate_topic(desc, column_classes)
      columns[col] = column

    table = {}
    table['description'] = table_description
    table['datatypes'] = datatype_body
    table['class_labels'] = column_classes
    table['columns'] = columns

    domain['tables'][table_name] = table

  return domain

In [41]:
domain = 'coffee_shop'

with open(f'src/src_dev/{domain}/embedded_data.json', 'w') as file:
    json.dump(new_embed_domain(f'src/src_dev/{domain}'), file, indent=2)

# Preprocess to dataframe

In [42]:
domain = 'coffee_shop'

with open(f"src/src_dev/{domain}/embedded_data.json") as file:
    embedded_file = json.load(file)

embedded_file.keys()

dict_keys(['name', 'tables'])

In [43]:
def print_json_structure(data, indent=0):
    if isinstance(data, dict):
        for key, value in data.items():
            print('  ' * indent + str(key))
            print_json_structure(value, indent + 2)
    elif isinstance(data, list):
        for item in data:
            print_json_structure(item, indent)


print_json_structure(embedded_file)

name
tables
    happy_hour
        description
            text
            vector
        datatypes
            JOIN_KEY
                PK
                FK
                    Shop_ID
                        shop
            COLUMNS
                HH_ID
                Shop_ID
                Month
                Num_of_staff_in_charge
        class_labels
            EventDetails
                text
                vector
            StaffMetrics
                text
                vector
        columns
            HH_ID
                text
                vector
                column_classes
            Shop_ID
                text
                vector
                column_classes
            Month
                text
                vector
                column_classes
            Num_of_staff_in_charge
                text
                vector
                column_classes
    happy_hour_member
        description
            text
            vector
        data

In [44]:

def selected_columns(question, specific_tables:list, column_info_df, schema_classes, max_n=20) -> dict:

    _df = column_info_df[['Table', 'Column', 'Vector', 'Class_name']]
    _df = _df[_df['Table'].isin(specific_tables)]
    question_vector = sen_emb.encode(question)
    _df['Score'] = _df['Vector'].apply(lambda x: float(cos_sim(x, question_vector)))

    # {'event_detail': 6, 'member_information': 2, 'shop_information': 4}
    
    topic_selected = dict()
    for table in specific_tables:
        table_select = (max_n // len(specific_tables))
        topic_selected.update(most_relate_topic(question, table, schema_classes, top_n=table_select, base_n=1))
        
    used_schema = {table : dict() for table in specific_tables}
    used_cols = []

    for topic, num in topic_selected.items():
        selected_col_index = _df[_df['Class_name'] == topic]['Score'].sort_values(ascending=False).head(num).index
        used_cols.extend(_df.loc[selected_col_index, 'Column'].to_list())

    used_cols = list(set(used_cols))

    for i, row in _df[_df['Column'].isin(used_cols)].iterrows():
        used_schema[row['Table']][row['Column']] = round(row['Score'],3)

    return used_schema



In [45]:
def most_relate_topic(text:str, table:str, schema_classes,
                      top_n:int=10, base_n:int=1):
    text_vec = encode(text)
    topic_scores = [float(cos_sim(info['vector'], text_vec)) for info in schema_classes[table].values()]

    probs = (topic_scores / np.sum(topic_scores)) * top_n
    topic_selected = { key: max(base_n, math.ceil(score)) for key, score in zip(schema_classes[table].keys(), probs)}
    # {'event_detail': 6, 'member_information': 2, 'shop_information': 4}
    return topic_selected

In [46]:
df_data = { 'Table' : [],
            'Column' : [],
            'Description' : [],
            'Vector' : [],
            'Class_name' : []}

schema_classes = dict()

for table, table_info in embedded_file['tables'].items():
    schema_classes[table] = table_info['class_labels']
    for col in table_info['columns']:
        df_data['Table'].append(table)
        df_data['Column'].append(col)
        df_data['Description'].append(table_info['columns'][col]['text'])
        df_data['Vector'].append(table_info['columns'][col]['vector'])
        df_data['Class_name'].append(table_info['columns'][col]['column_classes'])

df = pd.DataFrame(df_data).explode('Class_name')
df.head()

Unnamed: 0,Table,Column,Description,Vector,Class_name
0,happy_hour,HH_ID,Unique identifier for the happy hour event,"[-0.030357593670487404, 0.11128713935613632, 0...",EventDetails
0,happy_hour,HH_ID,Unique identifier for the happy hour event,"[-0.030357593670487404, 0.11128713935613632, 0...",StaffMetrics
1,happy_hour,Shop_ID,Identifier of the shop hosting the happy hour,"[-0.020781254395842552, 0.10354368388652802, -...",EventDetails
1,happy_hour,Shop_ID,Identifier of the shop hosting the happy hour,"[-0.020781254395842552, 0.10354368388652802, -...",StaffMetrics
2,happy_hour,Month,Month in which the happy hour takes place,"[0.053571637719869614, 0.0968250185251236, -0....",EventDetails


In [47]:
print_json_structure(schema_classes)

happy_hour
    EventDetails
        text
        vector
    StaffMetrics
        text
        vector
happy_hour_member
    ParticipationDetails
        text
        vector
    SpendingMetrics
        text
        vector
member
    MemberDetails
        text
        vector
shop
    ShopDetails
        text
        vector


In [48]:
# coffee shop test

question = "How many shop has happy hour more than 2 hours and has maximum customer member"
tables = ['happy_hour', 'happy_hour_member', 'member', 'shop']
selected_columns(question, tables, df, schema_classes)

{'happy_hour': {'HH_ID': 0.494,
  'Shop_ID': 0.734,
  'Month': 0.547,
  'Num_of_staff_in_charge': 0.704},
 'happy_hour_member': {'HH_ID': 0.492,
  'Member_ID': 0.544,
  'Total_amount': 0.61},
 'member': {'Member_ID': 0.137,
  'Membership_card': 0.188,
  'Age': 0.168,
  'Time_of_purchase': 0.331,
  'Level_of_membership': 0.18,
  'Address': 0.093},
 'shop': {'Shop_ID': 0.361,
  'Address': 0.346,
  'Num_of_staff': 0.532,
  'Score': 0.42,
  'Open_Year': 0.35}}

# Test after modify schema link

In [33]:
from sentence_transformers import SentenceTransformer
import json
# from api import encode

from transformers import AutoTokenizer, AutoModelForCausalLM
from openai import OpenAI
import google.generativeai as genai

tokenizer = AutoTokenizer.from_pretrained("models/nsql-350M")
model = AutoModelForCausalLM.from_pretrained("models/nsql-350M")
sen_emb = SentenceTransformer("models/all-MiniLM-L6-v2")

def encode(text):
    return sen_emb.encode(text).tolist()


In [61]:
import re, math, torch
from torch import Tensor
import numpy as np
import pandas as pd

def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))

class SchemaLinking():

    def __init__(self, domain):
        
        self.domain = domain
        self.split_pattern = r'[\s\n;().]'
        self.verbose = False
        
        df_data = { 'Table' : [],
                    'Column' : [],
                    'Description' : [],
                    'Vector' : [],
                    'Class_name' : []}

        self.schema_classes = dict()
        self.schema_datatypes = {}      # { table1: { column1: datatype, ...}}
        self.table_descriptions = {}    # { table1: description, ...}
        self.sql_condition = {'=', '>', '<', '>=', '<=', '<>', '!='}

        # preparing object variable
        for table, table_info in domain['tables'].items():
            self.schema_classes[table] = table_info['class_labels']
            self.schema_datatypes[table] = table_info['datatypes']
            self.table_descriptions[table] = table_info['description']
            for col in table_info['columns']:
                df_data['Table'].append(table)
                df_data['Column'].append(col)
                df_data['Description'].append(table_info['columns'][col]['text'])
                df_data['Vector'].append(table_info['columns'][col]['vector'])
                df_data['Class_name'].append(table_info['columns'][col]['column_classes'])

        self.column_info_df = pd.DataFrame(df_data).explode('Class_name')
        self.schema_columns_lower = set(self.column_info_df['Column'].str.lower().values)
        self.schema_tables_lower = set(self.column_info_df['Table'].str.lower().values)

        
    
    
    def most_relate_topic(self, text:str, table:str, top_n:int=10, base_n:int=1) -> dict:
        """
        Determine the most related topics to the given text based on schema classes.

        Parameters:
        text (str): The sentence for which related topics need to be determined.
        table (str): The table representing the schema classes against which the sentence will be compared.
        top_n (int): The maximum number of selects for all topics to be selected based on relevance score. Defaults to 10.
        base_n (int): The minimum number of each topic must be selected. Defaults to 1.

        Returns:
        dict: A dictionary where keys are topic names and values are the count of how many times each topic should be selected.

        Example:
        SchemaLinking.most_relate_topic("example text", "example_table")
        {'topic1': 3, 'topic2': 2, 'topic3': 1}
        """

        text_vec = encode(text)
        # apply cosin similarity score for each column followed by question
        topic_scores = [float(cos_sim(info['vector'], text_vec)) for info in self.schema_classes[table].values()]
        # select number of topics based on score probability
        probs = (topic_scores / np.sum(topic_scores)) * top_n
        topic_selected = { key: max(base_n, math.ceil(score)) for key, score in zip(self.schema_classes[table].keys(), probs)}
        
        return topic_selected
    

    def filter_schema(self, question:str, specific_tables:list, max_n:int=10) -> dict:
        """
        Filter the schema to obtain only the columns of each specified table to be used for generating SQL based on a question.

        Parameters:
        question (str): The question for which the SQL schema needs to be filtered.
        specific_tables (list): List of specific tables for which columns should be selected.
        max_n (int): The maximum number of columns to select per table.

        Returns:
        dict: A dictionary containing selected columns for each table along with their relevance scores.

        Example:
        SchemaLinking.filter_schema("example question", ["table1", "table2"])
        {'table1': {'column1': 0.845, 'column2': 0.723}, 'table2': {'column3': 0.912, 'column4': 0.654}}
        """

        # filtered ued column
        _df = self.column_info_df[['Table', 'Column', 'Vector', 'Class_name']]
        _df = _df[_df['Table'].isin(specific_tables)]

        question_vector = encode(question)
        # apply similarity score of each column followed by question
        _df['Score'] = _df['Vector'].apply(lambda x: float(cos_sim(x, question_vector)))

        # check string matching conditions
        columns_match = []
        for word in question.split():
            if word.lower() in self.schema_columns_lower:
                columns_match.append(word)
            if word.lower() in self.schema_tables_lower and word not in specific_tables:
                specific_tables.append(word)

        # if columns match
        if columns_match:
            print("String matching", columns_match)
            _df.loc[_df['Column'].isin(columns_match), 'Score'] = 1.0
        
        # get the number for selecting each topic
        topic_selected = dict()
        for table in specific_tables:
            table_select = (max_n // len(specific_tables))
            topic_selected.update(self.most_relate_topic(question, table, top_n=table_select, base_n=1))
        
        # prepare used schema each table
        used_schema = {table : dict() for table in specific_tables}
        used_cols = []

        # select the top column n number followed by the highest score.
        for topic, num in topic_selected.items():
            selected_col_index = _df[_df['Class_name'] == topic]['Score'].sort_values(ascending=False).head(num).index
            used_cols.extend(_df.loc[selected_col_index, 'Column'].to_list())

        used_cols = list(set(used_cols))

        for i, row in _df[_df['Column'].isin(used_cols)].iterrows():
            used_schema[row['Table']][row['Column']] = round(row['Score'],3)

        # Primary keys and foreign keys are always selected when using more than one table
        if len(specific_tables) > 1:
            for table in specific_tables:
                table_pk = self.schema_datatypes[table]["JOIN_KEY"]["PK"]
                table_fk = self.schema_datatypes[table]["JOIN_KEY"]["FK"]
                for fk, ref_table_column in table_fk.items():
                    if list(ref_table_column.keys())[0] not in specific_tables: del table_fk[fk]
                column_keys = table_pk + list(table_fk.keys())
                for col in column_keys:
                    if col not in used_schema[table].keys():
                        used_schema[table][col] = 0.5

        return used_schema

    def table_col_of_sql(self, sql_query:str) -> dict:
        """
        Extract tables and their corresponding columns from the given SQL query.

        Parameters:
        sql_query (str): The SQL query from which tables and columns need to be extracted.

        Returns:
        dict: A dictionary containing tables as keys and lists of columns as values.

        Example:
        SchemaLinking.table_col_of_sql("SELECT column1, column2 FROM table1 WHERE column3 = 'value'")
        {'table1': ['column1', 'column2', 'column3']}
        """
        
        selected_schema = {}
        query_split = re.split(self.split_pattern, sql_query)
        for table in self.schema_datatypes.keys():
            if table in query_split:
                selected_col = []
                for col in self.schema_datatypes[table]['COLUMNS'].keys():
                    if col in query_split: selected_col.append(col)
                selected_schema[table] = selected_col

        return selected_schema

    def masking_query(self, sql_query:str, condition_value_mask:bool=True) -> str:
        """
        Mask specified columns and optionally condition values in the given SQL query.

        Parameters:
        sql_query (str): The SQL query to be masked.
        condition_value_mask (bool): Whether to mask condition values. Defaults to True.

        Returns:
        str: The masked SQL query.

        Example:
        SchemaLinking.masking_query("SELECT column1, column2 FROM table1 WHERE column3 = 'value'")
        SELECT [MASK], [MASK] FROM [MASK] WHERE [MASK] = [MASK]
        """

        if '*' in sql_query: sql_query = sql_query.replace('*', "[MASK]")
        query_split = re.split(r'(?<=[() .,;])|(?=[() .,;])', sql_query)
        mask_next = False

        for i in range(len(query_split)):
            token = query_split[i].lower()
            # prepare mask condition value
            if token.lower() == 'where': mask_next = True
            if condition_value_mask and mask_next and (token in self.sql_condition and i + 1 < len(query_split)):
                step_mask_next = 1
                # find the condition value
                while query_split[i + step_mask_next] == ' ': step_mask_next += 1
                query_split[i + step_mask_next] = "[MASK]"
            
            if token in self.schema_columns_lower or token in self.schema_tables_lower:
                query_split[i] = "[MASK]"

        return "".join(query_split)
    

In [10]:
import warnings

def generate_sql(prompt):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        generated_ids = model.generate(input_ids, max_length=1000)
        sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True).split('\n')[-1]
    return sql

In [34]:
OPENAI_API_KEY = "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"
GOOGLE_API_KEY = "AIzaSyCL1lMVUqwf0nQKtLPk30tv7VUXTiKE-fE"

llm_model_name = "gpt-3.5-turbo"

llm_stop = ['\n\n']
temperature = 0
llm_prompt = """You are a SQL query assistant.
I have some SQL where the [MASK] columns, condition values and tables is syntaxed and I want you to respond to output that populates the [MASK] column of the SQL input followed by the question and schema description (name - description - data type).
If you don't know which column to fill in Do not include columns that you have created yourself. And only columns defined from the schema must be used. 
Do not use columns from other tables or schema. must also be used from the same table defined in the input.
If you must enter conditional values Please decide the format or value based on the sample values of that column.
If that column has many too long category value please decide base on column description.

For example:
table :     cat - this table contain cat information 
columns :    id - number for identify cat | number
            name - name of cat | text
            age - age of cat | number
            birth_date - pet birthday in format 'YYYY-MM-DD' | datetime
            gender - gender of cat (male, female) | text

question : Show me number of cat for each gender which born before March 23, 2011.
input : SELECT [MASK], COUNT([MASK]) FROM [MASK] WHERE [MASK] < [MASK] GROUP BY [MASK] ;
output : SELECT gender, COUNT(*) FROM cat WHERE birth_date < '2011-03-23' GROUP BY gender;

"""

In [22]:
def create_prompt(schema_link:object, question:str, used_schema:dict) -> str:
    """
    Generate a prompt for applying into SQL generation model based on the question and schema.

    Parameters:
    schema_link (object): The instance of the class containing schema information.
    question (str): The question for which the prompt is generated.
    used_schema (dict): A dictionary containing tables as keys and lists of columns as values after filtering the schema.

    Returns:
    str: A prompt for applying into SQL generation model.

    Example:
    prompt = create_prompt(schema_instance, "What are the total sales?", 
                          { 'sales': {'date' : 0.3, 'amount' : 0.61}, 
                            'products': {'name' : 0.23, 'price' : 0.57}})
    print(prompt)

    CREATE TABLE sales ( date DATE, amount INT,PRIMARY KEY ("date") )
    -- Using valid SQLite, answer the following questions for the tables provided above.
    -- What are the total sales?
    SELECT
    """
    full_sql = ""
    for table, columns in used_schema.items():
        if not len(columns): continue       # pass this table when no column
        primary_keys = schema_link.schema_datatypes[table]["JOIN_KEY"]["PK"]
        foreign_keys = list(schema_link.schema_datatypes[table]["JOIN_KEY"]["FK"].keys())
        join_table_key = primary_keys + foreign_keys
        
        sql = f"CREATE TABLE {table} ("
        for column in columns:
            if column in join_table_key and len(join_table_key): join_table_key.remove(column)
            try:
                sql += f' {column} {schema_link.schema_datatypes[table]["COLUMNS"][column]},'
            except KeyError: 
                print(f"KeyError :{column}")
                
        if len(join_table_key): # key for join of table are remaining
            for column in join_table_key:
                sql += f' {column} {schema_link.schema_datatypes[table]["COLUMNS"][column]},'

        # All table contain PK (maybe)
        if len(primary_keys):
            sql += 'PRIMARY KEY ('
            for pk_type in primary_keys: sql += f'"{pk_type}" ,'
            sql = sql[:-1] + "),"

        if len(foreign_keys):
            for fk, ref_table_column in schema_link.schema_datatypes[table]["JOIN_KEY"]["FK"].items():
                sql += f' FOREIGN KEY ("{fk}") REFERENCES "{list(ref_table_column.keys())[0]}" ("{list(ref_table_column.values())[0]}"),'

        sql = sql[:-1] + " )\n\n"
        full_sql += sql
    prompt = full_sql + "-- Using valid SQLite, answer the following questions for the tables provided above."
    prompt = prompt + '\n' + '-- ' + question
    prompt = prompt + '\n' + "SELECT"

    return prompt

In [35]:
def LLM_fill_column(schema_link, question:str , used_schema:dict, masked_query:str, llm_model:str) -> str:
        """
        Fill the [MASK] query to complete the SQL query using the Language Model.

        Parameters:
        question (str): The question related to the SQL query.
        used_schema (dict): Dictionary containing the schema information.
        masked_query (str): The SQL query with masked columns.

        Returns:
        str: The complete SQL query.

        Example:
        SchemaLinking.LLM_fill_column(schema_instance, "Show all employees", {"employees": {"employee_id" : 0.45, "employee_name" : 0.37}}, "SELECT [MASK] FROM employees;")
        SELECT employee_id, employee_name FROM employees;
        """
        
        full_prompt = ""
        for table_name, column_score in used_schema.items():
            _df = schema_link.column_info_df[schema_link.column_info_df['Table'] == table_name][['Column', 'Description']].drop_duplicates()
            full_prompt += f"\ntable : {table_name} - {schema_link.table_descriptions[table_name]['text']}\ncolumns:"
            for column_name in column_score:
                full_prompt += f"\t{column_name} - {_df[_df['Column'] == column_name]['Description'].values[0]}"
                full_prompt += f" | {schema_link.schema_datatypes[table_name]['COLUMNS'][column_name]}\n"

        full_prompt += f"question : {question}\n"
        full_prompt += f"input : {masked_query}\noutput :"
        full_prompt = llm_prompt + full_prompt

        if llm_model in ['openai', 'gpt-3.5-turbo', 'gpt-4-turbo']: 
            try:
                client = OpenAI(api_key=OPENAI_API_KEY)
                response = client.chat.completions.create(
                    model=llm_model,
                    messages=[
                            {"role": "system",
                                "content": "I will give you some x-y examples followed by a x, you need to give me the y, and no other content."},
                            {"role": "user", 
                                "content": full_prompt},
                            ],
                    stop=llm_stop,
                    temperature=temperature
                )
                return response.choices[0].message.content
            
            except Exception as e:
                return f"OpenAI Error :{e}"
        elif llm_model in ['gemini-pro']:
            try:
                genai.configure(api_key=GOOGLE_API_KEY)
                gemini_model = genai.GenerativeModel(llm_model)
                response = gemini_model.generate_content(full_prompt)
                return response.text
            
            except Exception as e:
                 return f"Google AI Error : {e}"

In [36]:
def get_reason(schema_link:object, sql_result:str) -> str:
    """
    Get the reason message related to the selected columns and tables from the schema based on the SQL query.

    Parameters:
    schema_link (object): The instance of the class containing schema information.
    sql_result (str): The SQL query result for which the reason message is generated.

    Returns:
    str: The reason message explaining the selection of columns and tables from the schema.

    Example:
    get_reason(schema_instance, "SELECT column1, column2 FROM table1 WHERE column3 = 'value'")

    Table - table1 : Description of table1
        Column - column1 : Description of column1
        Column - column2 : Description of column2
        Column - column3 : Description of column3
    """

    table_col_sql = schema_link.table_col_of_sql(sql_result)
    reason = ""

    for table, cols in table_col_sql.items():
        _df = schema_link.column_info_df[schema_link.column_info_df['Table'] == table][['Column', 'Description']].drop_duplicates()
        table_reason = f"Table - {table}\t: {schema_link.table_descriptions[table]['text']}\n"
        if len(cols):       # have columns of table
            col_reason = "\n".join([f"\tColumn - {c}\t: {_df.loc[_df['Column'] == c, 'Description'].values[0]}" for c in cols])
        else: col_reason = ""
        reason += str(table_reason + col_reason + "\n\n")

    return reason

In [57]:
with open("src/src_dev/coffee_shop/embedded_data.json", "r") as f:
    domain = json.load(f)
    domain_tables = list(domain['tables'].keys())

schema_link = SchemaLinking(domain)
question = "Show me the shop ID and address of shops with sum of member's total amout more than 300 "
used_schema = schema_link.filter_schema(question, domain_tables, max_n=10)
used_schema

String matching ['address']


{'happy_hour': {'Shop_ID': 0.368,
  'Num_of_staff_in_charge': 0.291,
  'HH_ID': 0.5,
  'Month': 0.5},
 'happy_hour_member': {'Member_ID': 0.239,
  'Total_amount': 0.334,
  'HH_ID': 0.5},
 'member': {'Member_ID': 0.269,
  'Membership_card': 0.322,
  'Time_of_purchase': 0.328,
  'Address': 0.316},
 'shop': {'Shop_ID': 0.514, 'Address': 0.562, 'Num_of_staff': 0.592}}

In [68]:
with open("src/src_dev/pointx/embedded_data.json", "r") as f:
    domain = json.load(f)
    domain_tables = list(domain['tables'].keys())
    
schema_link = SchemaLinking(domain)
question = "Display the user id and revenue of user who has the highest total transactions id"
used_schema = schema_link.filter_schema(question, ['pointx_fbs_rpt_dly',], max_n=10)
used_schema

String matching ['id', 'id']


{'pointx_fbs_rpt_dly': {'user_id': 0.301,
  'user_ltv_revenue': 0.392,
  'user_ltv_currency': 0.323,
  'traffic_source_medium': 0.303,
  'ecommerce': 0.307,
  'ecoupon_rank': 0.293,
  'id': 1.0,
  'merchant_id': 0.346,
  'order_id': 0.309,
  'place_id': 0.147,
  'product_id': 0.249,
  'total_amount': 0.337,
  'transaction_id': 0.441,
  'transaction_status': 0.331}}

In [69]:
prompt = create_prompt(schema_link, question, used_schema)
print("========= PROMPT =========")
print(prompt)
print()
sql_result = generate_sql(prompt)
print("========= SQL =========")
print(sql_result)
print()
reason = get_reason(schema_link, sql_result)
print("========= REASON =========")
print(reason)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


CREATE TABLE pointx_fbs_rpt_dly ( user_id number, user_ltv_revenue number, user_ltv_currency number, traffic_source_medium text, ecommerce number, ecoupon_rank number, id number, merchant_id number, order_id number, place_id number, product_id number, total_amount number, transaction_id number, transaction_status number, event_id number, _date text,PRIMARY KEY ("event_id" ), FOREIGN KEY ("_date") REFERENCES "pointx_keymatrix_dly" ("_date") )

-- Using valid SQLite, answer the following questions for the tables provided above.
-- Display the user id and revenue of user who has the highest total transactions id
SELECT

SELECT user_id, MAX(total_amount) FROM pointx_fbs_rpt_dly GROUP BY user_id ORDER BY SUM(total_amount) DESC LIMIT 1;

Table - pointx_fbs_rpt_dly	: Table records user interactions with the PointX app daily, capturing events such as app opens and deletions, 
providing key insights into user behavior, app version usage, and device characteristics 
	Column - user_id	: The user 

In [70]:
masked_query = schema_link.masking_query(sql_result)
masked_query

'SELECT [MASK], MAX([MASK]) FROM [MASK] GROUP BY [MASK] ORDER BY SUM([MASK]) DESC LIMIT 1;'

In [73]:
print("========= QUESTION =========")
print(question)
print()
final_result = LLM_fill_column(schema_link, question, used_schema, masked_query, 'gpt-3.5-turbo')
print("========= SQL =========")
print(final_result)
print()
reason = get_reason(schema_link, final_result)
print("========= REASON =========")
print(reason)

Display the user id and revenue of user who has the highest total transactions id

SELECT user_id, user_ltv_revenue FROM pointx_fbs_rpt_dly GROUP BY user_id ORDER BY COUNT(transaction_id) DESC LIMIT 1;

Table - pointx_fbs_rpt_dly	: Table records user interactions with the PointX app daily, capturing events such as app opens and deletions, 
providing key insights into user behavior, app version usage, and device characteristics 
	Column - user_id	: The user ID set via the setUserId API.
	Column - user_ltv_revenue	: The Lifetime Value (revenue) of the user. This field is not populated in intraday tables.
	Column - transaction_id	: Transaction id




In [94]:
def table_selected(question:str,n_select:int , table_descriptions_vector:dict):
    question_vector = encode(question)
    table_scores = {table : round(float(cos_sim(table_vector, question_vector)),3) for table, table_vector in table_descriptions_vector.items()}
    # sum_score = sum(table_scores.values())
    # tab_select = {table: math.floor(n_select * score / sum_score) for table, score in table_scores.items()}
    return table_scores

In [95]:
table_description_vec = dict()

for table_name, table_info in domain['tables'].items():
    table_description_vec[table_name] = table_info['description']['vector']

print(question)
table_selected(question, 10, table_description_vec)

Show me the user id and total revenue of that id


{'pointx_cust_mly': 0.244,
 'pointx_fbs_rpt_dly': 0.138,
 'pointx_keymatrix_dly': 0.114}

# Convert SQL schema to Data source schema format

In [66]:
import sqlite3, os, json

In [5]:
folder_path = "src/spider/database"
select_db = ['musical',
             'farm', 
             'hospital_1', 
             'tvshow', 
             'cinema', 
             'restaurants', 
             'company_employee', 
             'company_offic', 
             'singer', 
             'coffee_shop']

db = []

if os.path.exists(folder_path) and os.path.isdir(folder_path):
    files = os.listdir(folder_path)
    for file in files:
        # if file in select_db:
        db_path = os.path.join(folder_path, file)
        sqlite_db = [os.path.join(db_path, sql) for sql in os.listdir(db_path) if ".sqlite" in sql]
        db.append(*sqlite_db)

db[:5]

['src/spider/database/browser_web/browser_web.sqlite',
 'src/spider/database/musical/musical.sqlite',
 'src/spider/database/farm/farm.sqlite',
 'src/spider/database/voter_1/voter_1.sqlite',
 'src/spider/database/game_injury/game_injury.sqlite']

In [49]:
def get_schema(sqlite_db):
    connection = sqlite3.connect(sqlite_db)
    cursor = connection.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    table_names = []
    for table in tables:
        table_name = table[0]
        table_names.append(table_name)
        print(f"Table: {table_name}")
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        for column in columns:
            _, cname, ctype, _, _, pk_sq = column
            print(f"\tColumn: {cname} {ctype} {pk_sq}")
            # Get foreign keys for the table
        cursor.execute(f"PRAGMA foreign_key_list({table_name});")
        foreign_keys = cursor.fetchall()
        print()

        if foreign_keys:
            print("Foreign Keys:")
            for fk in foreign_keys:
                _, _, to_table, fk_column, to_column, _, _, _ = fk
                print(f"\t{fk_column} REFERENCES {to_table}({to_column})")
            print()

    
    cursor.close()
    connection.close()
    return table_names

In [50]:
# db_map_tables = dict({})
for database_path in db[:5]:
    for db_name in select_db:
        if database_path.split('/')[-2] == db_name:
            # if table in exists_table : continue
            print(database_path)
            get_schema(database_path)
            # exists_table.append(table)
            print('---------------------------------')

src/spider/database/musical/musical.sqlite
Table: musical
	Column: Musical_ID INT 1
	Column: Name TEXT 0
	Column: Year INT 0
	Column: Award TEXT 0
	Column: Category TEXT 0
	Column: Nominee TEXT 0
	Column: Result TEXT 0

Table: actor
	Column: Actor_ID INT 1
	Column: Name TEXT 0
	Column: Musical_ID INT 0
	Column: Character TEXT 0
	Column: Duration TEXT 0
	Column: age INT 0

Foreign Keys:
	Musical_ID REFERENCES actor(Actor_ID)

---------------------------------
src/spider/database/farm/farm.sqlite
Table: city
	Column: City_ID INT 1
	Column: Official_Name TEXT 0
	Column: Status TEXT 0
	Column: Area_km_2 REAL 0
	Column: Population REAL 0
	Column: Census_Ranking TEXT 0

Table: farm
	Column: Farm_ID INT 1
	Column: Year INT 0
	Column: Total_Horses REAL 0
	Column: Working_Horses REAL 0
	Column: Total_Cattle REAL 0
	Column: Oxen REAL 0
	Column: Bulls REAL 0
	Column: Cows REAL 0
	Column: Pigs REAL 0
	Column: Sheep_and_Goats REAL 0

Table: farm_competition
	Column: Competition_ID INT 1
	Column: Ye

In [67]:
def transform_schema(schema_db_path, output_path):

    connection = sqlite3.connect(schema_db_path)
    cursor = connection.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    for table in tables:
        
        schema = {
                    "JOIN_KEY" : {
                        "PK" : list(),
                        "FK" : dict()
                    },
                    "COLUMNS" : dict()
        }
        table_name = table[0]
        print(f"Table: {table_name}")
        schema_file_name = f"{table_name}_datatype.json"
        output_file_path = os.path.join(output_path, schema_file_name)
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()

        for column in columns:
            _, cname, ctype, _, _, pk_sq = column
            schema['COLUMNS'][cname] = ctype

            # if this column is primary key
            if pk_sq: schema['JOIN_KEY']['PK'].append(cname)

            print(f"\tColumn: {cname} {ctype} {pk_sq}")

            # Get foreign keys for the table
        cursor.execute(f"PRAGMA foreign_key_list({table_name});")
        foreign_keys = cursor.fetchall()
        print()

        if foreign_keys:
            print("Foreign Keys:")
            for fk in foreign_keys:
                _, _, to_table, fk_column, to_column, _, _, _ = fk
                schema['JOIN_KEY']['FK'][fk_column] = {to_table : to_column}
                print(f"\t{fk_column} REFERENCES {to_table}({to_column})")
            print()

        with open(output_file_path, "w") as file:
            json.dump(schema, file, indent=2)
            print(f"Dump {output_file_path} sucess")
    cursor.close()
    connection.close()

In [68]:
transform_schema("src/src_dev/coffee_shop/coffee_shop.db", 'test')

Table: shop
	Column: Shop_ID INT 1
	Column: Address TEXT 0
	Column: Num_of_staff TEXT 0
	Column: Score REAL 0
	Column: Open_Year TEXT 0

Dump test/shop_datatype.json sucess
Table: member
	Column: Member_ID INT 1
	Column: Name TEXT 0
	Column: Membership_card TEXT 0
	Column: Age INT 0
	Column: Time_of_purchase INT 0
	Column: Level_of_membership INT 0
	Column: Address TEXT 0

Dump test/member_datatype.json sucess
Table: happy_hour
	Column: HH_ID INT 1
	Column: Shop_ID INT 2
	Column: Month TEXT 3
	Column: Num_of_shaff_in_charge INT 0

Foreign Keys:
	Shop_ID REFERENCES shop(Shop_ID)

Dump test/happy_hour_datatype.json sucess
Table: happy_hour_member
	Column: HH_ID INT 1
	Column: Member_ID INT 2
	Column: Total_amount REAL 0

Foreign Keys:
	Member_ID REFERENCES member(Member_ID)

Dump test/happy_hour_member_datatype.json sucess
