In [1]:
import json.decoder
from utils.enums import LLM
from langchain_google_vertexai import VertexAI

from google.oauth2 import service_account
from google.cloud import aiplatform
import vertexai
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
import os

from dotenv import load_dotenv

# Load variables from a .env file; override=True is passed so that values from
# the file take precedence over variables already set in the environment.
load_dotenv(override=True)

# Safety configuration handed to the VertexAI client below: every harm category
# listed here is mapped to BLOCK_NONE.
safety_settings = {
    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
}


# GCP settings read from the environment; each is None when the variable is unset.
GCP_PROJECT = os.getenv("GCP_PROJECT")
GCP_REGION = os.getenv("GCP_REGION")
# Path to a service-account key file (it is passed to from_service_account_file).
GCP_CREDENTIALS = os.getenv("GCP_CREDENTIALS")

def init_gemini():
    """Initialise the aiplatform and vertexai SDKs for this notebook.

    Uses the module-level GCP_PROJECT, GCP_REGION and GCP_CREDENTIALS values.
    The service-account file is read once and the resulting credentials object
    is shared by both init calls.

    Raises:
        ValueError: if GCP_CREDENTIALS is unset or empty, so the failure points
            at the missing configuration rather than at the key-file loader.
    """
    if not GCP_CREDENTIALS:
        raise ValueError(
            "GCP_CREDENTIALS is not set; it must be the path to a service-account key file"
        )

    credentials = service_account.Credentials.from_service_account_file(GCP_CREDENTIALS)
    aiplatform.init(
        project=GCP_PROJECT,
        location=GCP_REGION,
        credentials=credentials,
    )
    vertexai.init(project=GCP_PROJECT, location=GCP_REGION, credentials=credentials)

# Configure the SDKs before the client is created.
# NOTE(review): presumably VertexAI picks up the project/credentials set here - confirm.
init_gemini()

# Shared LLM client used by the extraction loop further down.
llm = VertexAI(
    model="gemini-1.5-pro",
    temperature=0.5,
    safety_settings=safety_settings
)

In [2]:
# Prompt template file; the extraction loop later fills it in via
# prompt_template.format(SQL_QUERY=...).
prompt_path = "schema_extraction.txt"

with open(prompt_path, "r") as f:
    prompt_template = f.read()
# Bare expression so the notebook displays the loaded template as the cell output.
prompt_template

'You are given an SQL query. Your task is to extract the db_ids (the database identifiers), table_names (the names of the tables being queried), and column_names (the names of the columns referenced in the query). Organize this information in the form of a dictionary where:\n\nThe keys of the dictionary represent the db_ids (database identifiers).\nEach db_id maps to another dictionary where:\nThe keys are the table_names.\nThe values are a list of the column_names used from that table.\nThe final output should be a JSON object that follows this structure.\n\nTask Details:\ndb_ids: Extract the database identifiers from the query (these are usually in the form of project.dataset in systems like BigQuery). If the query does not specify a db_id (for example, in SQLite queries), set the db_id to "public".\ntable_names: Identify the table names used in the query, including any table wildcards if applicable (e.g., events_*).\ncolumn_names: Extract the column names from the query, including f

In [3]:
# Preprocessed dataset; the extraction loop iterates over it and reads the
# "instance_id" and "query" keys of each entry.
data_path = "preprocessed_data/spider2-lite/spider2-lite_preprocessed.json"
import json

with open(data_path, "r") as f:
    data = json.load(f)

In [4]:
def is_it_a_three_layer_dict(temp):
    """Check that ``temp`` has the shape ``{db: {table: [columns, ...]}}``.

    Every top-level value must be a dict and every value inside those dicts
    must be a list. The first offending entry is printed and ``False`` is
    returned. Any exception raised while walking the structure (for example
    when ``temp`` has no ``items`` method) is printed and also yields ``False``.
    An empty mapping is accepted.

    Args:
        temp: the parsed object to validate.

    Returns:
        bool: ``True`` when the whole structure matches, ``False`` otherwise.
    """
    try:
        for db_id, tables in temp.items():
            # Second layer must be a mapping of table name -> columns.
            if not isinstance(tables, dict):
                print(f"{db_id} {tables}")
                return False
            for table_name, columns in tables.items():
                # Third layer must be a list of column names.
                if not isinstance(columns, list):
                    print(f"{db_id} {table_name} {columns}")
                    return False
    except Exception as exc:
        print(exc)
        return False
    return True

In [5]:
from tqdm import tqdm

# Resume support: start from whatever a previous run already wrote to disk.
schemas = {}
if os.path.exists("schemas.json"):
    with open("schemas.json", "r") as f:
        schemas = json.load(f)
    print(f"Loaded {len(schemas)} schemas")

# instance_id -> {"query", "error"} for every entry that could not be extracted,
# either because a call raised or because no attempt produced a valid schema.
error_ids = {}

for i, entry in enumerate(tqdm(data)):
    instance_id = entry["instance_id"]
    query = entry["query"].replace("\n", " ")
    if instance_id in schemas:
        continue

    try:
        # The prompt is identical for every attempt, so build it once.
        message = prompt_template.format(SQL_QUERY=query)
        last_error = None
        for k in range(5):
            response = llm.invoke(message)
            # Keep only the body of the first ```json fenced block, if present.
            if "```json" in response:
                response = response.split("```json")[1]
                response = response.split("```")[0]
            try:
                temp_dict = json.loads(response)
            except json.JSONDecodeError as e:
                # Unparseable output is a bad generation like a wrongly shaped
                # one: retry instead of giving up on the first attempt.
                last_error = str(e)
                print(f"Retrying {k} at {instance_id}")
                continue
            if is_it_a_three_layer_dict(temp_dict):
                schemas[instance_id] = {
                    "query": query,
                    "schema": temp_dict
                }
                break
            last_error = "response is not a db -> table -> [columns] mapping"
            print(f"Retrying {k} at {instance_id}")
        else:
            # for/else: reached only when no attempt hit `break`. Record the
            # failure so the entry is not silently missing from both dicts.
            error_ids[instance_id] = {
                "query": query,
                "error": last_error
            }
            print(f"Error at {i}: {last_error}")

    except Exception as e:
        # Anything else (prompt formatting, the LLM call itself) is not retried.
        error_ids[instance_id] = {
            "query": query,
            "error": str(e)
        }
        print(f"Error at {i}: {str(e)}")

    # Periodic checkpoint so an interrupted run can be resumed.
    if i % 10 == 0:
        with open("schemas.json", "w") as f:
            json.dump(schemas, f, indent=4)

Loaded 271 schemas


 98%|█████████▊| 325/330 [04:12<00:18,  3.75s/it]

In [6]:
# Final save of the schemas: the extraction loop only checkpoints when
# i % 10 == 0, so results gathered after the last checkpoint are written here.
with open("schemas.json", "w") as f:
    json.dump(schemas, f, indent=4)

In [4]:
import json
# Reload the saved schemas from disk, replacing any in-memory `schemas`, so the
# post-processing below can run without repeating the extraction.
with open("schemas.json", "r") as f:
    schemas = json.load(f)

In [8]:
# Post-process the extracted schemas into `good_schemas`:
#   * entries with an empty schema are reported and kept with an empty schema;
#   * db ids shorter than 6 characters are reported and their tables are filed
#     under "public"; all other db ids in the same entry are kept as they are;
#   * entries with no such problem are copied over unchanged.
# NOTE(review): the 6-character threshold is taken as-is from the original
# heuristic - confirm it matches the shortest legitimate db id.
good_schemas = {}
for instance_id, entry in schemas.items():
    schema = entry["schema"]
    if not schema:
        # Empty (or missing) schema: nothing usable was extracted.
        print(f"{instance_id} {entry['query']}")
        good_schemas[instance_id] = {
            "query": entry["query"],
            "schema": {}
        }
        continue

    bad_generation = False
    fixed_schema = {}
    for db, tables_dict in schema.items():
        if len(db) < 6:
            print(f"{instance_id} {db}")
            bad_generation = True
            target_db = "public"
        else:
            target_db = db
        # Merge instead of overwrite so several db ids in one entry (including
        # several that collapse to "public") all survive.
        merged_tables = fixed_schema.setdefault(target_db, {})
        for table, columns in tables_dict.items():
            if table not in merged_tables:
                merged_tables[table] = list(columns)
            else:
                known = merged_tables[table]
                merged_tables[table] = known + [c for c in columns if c not in known]

    if not bad_generation:
        good_schemas[instance_id] = entry
    else:
        good_schemas[instance_id] = {
            "query": entry["query"],
            "schema": fixed_schema
        }

# # Save the good schemas
# with open("good_schemas.json", "w") as f:
#     json.dump(good_schemas, f, indent=4)