In [10]:
import json.decoder
from utils.enums import LLM
from langchain_google_vertexai import VertexAI

from google.oauth2 import service_account
from google.cloud import aiplatform
import vertexai
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
import os

from dotenv import load_dotenv

# Load settings from the local .env file before any os.environ lookups below.
load_dotenv(override=True)

# Every harm category listed here is mapped to the same BLOCK_NONE threshold.
safety_settings = {
    category: HarmBlockThreshold.BLOCK_NONE
    for category in (
        HarmCategory.HARM_CATEGORY_UNSPECIFIED,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    )
}


# Vertex AI connection settings; each value is None when its variable is unset.
GCP_PROJECT = os.environ.get("GCP_PROJECT")
GCP_REGION = os.environ.get("GCP_REGION")
GCP_CREDENTIALS = os.environ.get("GCP_CREDENTIALS")

def init_gemini():
    """Initialise the Vertex AI SDKs with service-account credentials.

    Reads the module-level ``GCP_PROJECT``, ``GCP_REGION`` and
    ``GCP_CREDENTIALS`` values and passes them to both ``aiplatform.init`` and
    ``vertexai.init``.

    Raises:
        ValueError: if ``GCP_CREDENTIALS`` is unset or empty, so the failure
            names the missing setting instead of surfacing from file I/O.
    """
    if not GCP_CREDENTIALS:
        raise ValueError(
            "GCP_CREDENTIALS is not set; it must be the path to a "
            "service-account JSON key file."
        )

    # Parse the key file once and hand the same credentials object to both
    # SDK entry points instead of re-reading the file for each call.
    credentials = service_account.Credentials.from_service_account_file(
        GCP_CREDENTIALS
    )

    aiplatform.init(
        project=GCP_PROJECT,
        location=GCP_REGION,
        credentials=credentials,
    )
    vertexai.init(
        project=GCP_PROJECT,
        location=GCP_REGION,
        credentials=credentials,
    )

# Configure the Vertex AI SDKs before the client below is constructed.
init_gemini()

# Shared LLM client used by the later cells: Gemini 1.5 Pro, temperature 0,
# with the safety_settings mapping defined earlier in this cell.
llm = VertexAI(
    model="gemini-1.5-pro",
    temperature=0,
    safety_settings=safety_settings
)

In [11]:
prompt_path = "schema_extraction.txt"

# Decode with an explicit encoding so the prompt text does not depend on the
# platform's default locale encoding.
with open(prompt_path, "r", encoding="utf-8") as f:
    prompt_template = f.read()
prompt_template

'You are given an SQL query. Your task is to extract the db_ids (the database identifiers), table_names (the names of the tables being queried), and column_names (the names of the columns referenced in the query). Organize this information in the form of a dictionary where:\n\nThe keys of the dictionary represent the db_ids (database identifiers).\nEach db_id maps to another dictionary where:\nThe keys are the table_names.\nThe values are a list of the column_names used from that table.\nThe final output should be a JSON object that follows this structure.\n\nTask Details:\ndb_ids: Extract the database identifiers from the query (these are usually in the form of project.dataset in systems like BigQuery).\ntable_names: Identify the table names used in the query, including any table wildcards if applicable (e.g., events_*).\ncolumn_names: Extract the column names from the query, including fields nested inside structs (e.g., event_params.key, event_params.value.int_value).\n\nFew-Shot Exa

In [12]:
import json

data_path = "preprocessed_data/spider2-lite/spider2-lite_preprocessed.json"

# Decode with an explicit encoding so loading does not depend on the
# platform's default locale encoding.
with open(data_path, "r", encoding="utf-8") as f:
    data = json.load(f)

In [15]:
def _parse_schema(raw_response):
    """Return the ``"schema"`` value from the JSON payload of an LLM reply.

    When the reply contains a ```json fence, only the text between that fence
    and the next ``` is parsed; otherwise the whole reply is parsed as JSON.

    Raises:
        json.JSONDecodeError: if the payload is not valid JSON.
        KeyError, TypeError: if the decoded payload has no ``"schema"`` entry.
    """
    payload = raw_response
    if "```json" in payload:
        payload = payload.split("```json")[1].split("```")[0]
    return json.loads(payload)["schema"]


# Maps the entry index in `data` to {"query", "schema"} plus "error" on failure.
schemas = {}

for i, entry in enumerate(data):
    # NOTE(review): collapsing newlines would fold a `--` line comment into the
    # rest of the query text — confirm the dataset only uses /* */ comments.
    query = entry["query"].replace("\n", " ")

    message = prompt_template.format(SQL_QUERY=query)
    response = llm.invoke(message)
    try:
        schemas[i] = {
            "query": query,
            "schema": _parse_schema(response),
        }
    except Exception as e:
        # Best effort: keep going, but record why this entry has no schema.
        # Replies without a ```json fence also end up here when they are not
        # valid JSON, so every entry gets a record instead of being dropped.
        schemas[i] = {
            "query": query,
            "schema": None,
            "error": str(e),
        }

SELECT   COUNT(DISTINCT MDaysUsers.user_pseudo_id) AS n_day_inactive_users_count FROM   (     SELECT       user_pseudo_id     FROM       `bigquery-public-data.ga4_obfuscated_sample_ecommerce.events_*` AS T     CROSS JOIN       UNNEST(T.event_params) AS event_params     WHERE       event_params.key = 'engagement_time_msec' AND event_params.value.int_value > 0       /* Has engaged in last M = 7 days */       AND event_timestamp > UNIX_MICROS(TIMESTAMP_SUB(TIMESTAMP('2021-01-07 23:59:59'), INTERVAL 7 DAY))       /* Include only relevant tables based on the fixed timestamp */       AND _TABLE_SUFFIX BETWEEN '20210101' AND '20210107'   ) AS MDaysUsers LEFT JOIN   (     SELECT       user_pseudo_id     FROM       `bigquery-public-data.ga4_obfuscated_sample_ecommerce.events_*` AS T     CROSS JOIN       UNNEST(T.event_params) AS event_params     WHERE       event_params.key = 'engagement_time_msec' AND event_params.value.int_value > 0       /* Has engaged in last N = 2 days */       AND event_t

KeyboardInterrupt: 