In [10]:
import json.decoder
from utils.enums import LLM
from langchain_google_vertexai import VertexAI

from google.oauth2 import service_account
from google.cloud import aiplatform
import vertexai
from langchain_google_vertexai import HarmBlockThreshold, HarmCategory
import os

from dotenv import load_dotenv

# Pull settings from the local .env file into os.environ; override=True makes the
# .env values take precedence over variables already present in the environment.
load_dotenv(override=True)

# Turn off Vertex AI content blocking for every harm category listed here
# (presumably so schema-extraction prompts are never refused by the safety filter).
safety_settings = {
    category: HarmBlockThreshold.BLOCK_NONE
    for category in (
        HarmCategory.HARM_CATEGORY_UNSPECIFIED,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    )
}

# GCP connection settings read from the environment; each is None when unset.
GCP_PROJECT, GCP_REGION, GCP_CREDENTIALS = (
    os.getenv(var_name) for var_name in ("GCP_PROJECT", "GCP_REGION", "GCP_CREDENTIALS")
)

def init_gemini():
    """Initialise the Vertex AI SDKs with the project, region and service account from the environment.

    Uses the module-level ``GCP_PROJECT``, ``GCP_REGION`` and ``GCP_CREDENTIALS``
    values (read via ``os.getenv`` after ``load_dotenv``).

    Raises:
        ValueError: if ``GCP_CREDENTIALS`` is unset or empty, instead of failing
            deep inside the key-file loader with a less obvious error.
    """
    if not GCP_CREDENTIALS:
        raise ValueError(
            "GCP_CREDENTIALS is not set; point it at a service-account JSON key file (e.g. in .env)."
        )

    # Read the key file once and share the same credentials object with both SDK entry points.
    credentials = service_account.Credentials.from_service_account_file(GCP_CREDENTIALS)

    aiplatform.init(project=GCP_PROJECT, location=GCP_REGION, credentials=credentials)
    # NOTE(review): in recent SDK versions vertexai.init appears to delegate to
    # aiplatform.init, which would make this call redundant — confirm before removing.
    vertexai.init(project=GCP_PROJECT, location=GCP_REGION, credentials=credentials)

# SDK initialisation must happen before the model client is created.
init_gemini()

# Gemini 1.5 Pro through Vertex AI. temperature=0 minimises sampling randomness,
# and every call uses the all-BLOCK_NONE safety_settings mapping.
llm = VertexAI(
    safety_settings=safety_settings,
    temperature=0,
    model="gemini-1.5-pro",
)

In [None]:
# Prompt template for schema extraction; the extraction loop fills it with
# str.format(SQL_QUERY=...), so it must contain a {SQL_QUERY} placeholder.
prompt_path = "schema_extraction.txt"

# Decode explicitly as UTF-8: without `encoding`, open() uses the platform's
# locale encoding, so the same file could decode differently across machines.
with open(prompt_path, "r", encoding="utf-8") as f:
    prompt_template = f.read()

# Fail early if the placeholder is missing; otherwise every prompt would silently omit the query.
assert "{SQL_QUERY}" in prompt_template, f"{prompt_path} has no {{SQL_QUERY}} placeholder"
prompt_template

In [12]:
import json

# Preprocessed examples (presumably Spider2-lite, per the path); each entry must
# carry a "query" field, which the extraction loop reads.
data_path = "preprocessed_data/spider2-lite/spider2-lite_preprocessed.json"

# Decode explicitly as UTF-8 rather than relying on the platform's locale encoding.
with open(data_path, "r", encoding="utf-8") as f:
    data = json.load(f)

# The extraction loop enumerates entries positionally, so the top-level JSON value must be a list.
assert isinstance(data, list), f"expected a JSON list of examples, got {type(data).__name__}"

In [None]:
def _extract_json_text(text: str) -> str:
    """Return the JSON payload embedded in an LLM reply.

    Prefers the body of the first ```json fenced block. Otherwise it uses the body
    of the first generic ``` fence. If there is no fence at all, it uses the whole
    reply, so unfenced answers are still parsed rather than silently dropped.
    """
    if "```json" in text:
        return text.split("```json", 1)[1].split("```", 1)[0]
    fence_parts = text.split("```")
    if len(fence_parts) >= 3:
        return fence_parts[1]
    return text


# index in `data` -> {"query", "schema"} on success,
# or {"query", "schema": None, "error", "raw_response"} on failure.
schemas = {}

for i, entry in enumerate(data):
    query = entry["query"].replace("\n", " ")

    # Kept outside the try on purpose: a broken template is a deterministic bug and should fail fast.
    # NOTE(review): str.format treats every `{`/`}` in the template as a placeholder, so any
    # literal braces (e.g. an example JSON answer) in schema_extraction.txt must be doubled.
    message = prompt_template.format(SQL_QUERY=query)

    raw_response = None
    try:
        # The API call is inside the try so that one failed request (quota, transient
        # error, blocked output) is recorded for this entry instead of aborting the batch.
        raw_response = llm.invoke(message)
        parsed = json.loads(_extract_json_text(raw_response))
        schemas[i] = {
            "query": query,
            "schema": parsed["schema"],
        }
    except Exception as e:
        schemas[i] = {
            "query": query,
            "schema": None,
            "error": str(e),
            "raw_response": raw_response,
        }

# Surface the failure count so systematic problems (e.g. bad credentials) are not hidden.
n_failed = sum("error" in record for record in schemas.values())
print(f"Extracted schemas for {len(schemas) - n_failed}/{len(schemas)} queries; {n_failed} failed (see 'error').")