In [5]:
from nl2sql import *

In [34]:
class NL2SQL_resp:
    """Container for the three results of one NL-to-SQL round trip.

    The constructor stores its arguments unchanged. ``str()`` renders the
    natural-language answer, the SQL text and the SQL output, in that order,
    separated by blank lines.
    """

    def __init__(self, nl_output, generated_sql, sql_output) -> None:
        # Final natural-language text produced by the model.
        self.nl_output = nl_output
        # SQL query text associated with the answer (as supplied by the caller).
        self.generated_sql = generated_sql
        # Result of executing that SQL (as supplied by the caller).
        self.sql_output = sql_output

    def __str__(self) -> str:
        sections = (self.nl_output, self.generated_sql, self.sql_output)
        # f"{x}" formats each section exactly as the combined f-string would.
        return "\n\n".join(f"{section}" for section in sections)
        
class DBAI_nl2sql(DBAI):
    """Natural-language-to-SQL agent built on ``DBAI`` and a Gemini model.

    The model is given three tools: list tables, fetch table metadata, and run
    a SQL query. ``get_sql`` services the model's tool calls until it replies
    with text. It then returns that text together with the last SQL query the
    model executed and that query's result.
    """

    # Tables used when the caller does not pass ``tables_list``. Stored as a
    # tuple and copied per instance so no two instances share one mutable list.
    _DEFAULT_TABLES = ('camain_oracle_hcm', 'camain_ps')

    def __init__(
            self,
            proj_id="proj-kous",
            dataset_id="Albertsons",
            tables_list=None,
            ):
        """Create the agent and its tool-enabled Gemini model.

        Args:
            proj_id: Project id, forwarded to ``DBAI.__init__``.
            dataset_id: Dataset id, forwarded to ``DBAI.__init__``. ``get_sql``
                later reads ``self.dataset_id``, presumably set by ``DBAI``.
            tables_list: Table names, forwarded to ``DBAI.__init__``. ``None``
                (the default) means ``['camain_oracle_hcm', 'camain_ps']``. A
                fresh list is built on each call.
        """
        if tables_list is None:
            tables_list = list(self._DEFAULT_TABLES)
        super().__init__(proj_id, dataset_id, tables_list)

        # Tool declarations presumably come from the star-imported nl2sql module.
        self.nl2sql_tool = Tool(
            function_declarations=[
                list_tables_func,
                get_table_metadata_func,
                sql_query_func,
            ],
        )

        # Very low temperature keeps sampling close to greedy for SQL generation.
        self.agent = GenerativeModel("gemini-1.5-pro-001",
                            generation_config={"temperature": 0.05},
                            safety_settings=safety_settings,
                            tools=[self.nl2sql_tool],
                            )

    @staticmethod
    def _extract_function_call(part):
        """Return ``(function_name, params)`` if ``part`` requests a tool call, else ``None``.

        NOTE(review): depending on the SDK version, a text-only part may lack
        ``function_call`` entirely or expose an empty one (blank name, ``None``
        args). Both cases are treated as "no call" here; confirm against the
        installed vertexai version.
        """
        function_call = getattr(part, "function_call", None)
        function_name = getattr(function_call, "name", None)
        if not function_name:
            return None
        args = getattr(function_call, "args", None)
        params = dict(args.items()) if args is not None else {}
        return function_name, params

    def get_sql(self, question):
        """Answer ``question`` by letting the model call the SQL tools until it replies in text.

        Args:
            question: Natural-language question about the dataset.

        Returns:
            NL2SQL_resp holding three things:
            - the model's final text answer;
            - the most recent SQL query the model ran (``None`` if it ran none);
            - that query's result (``None`` if no query ran).

        Raises:
            ValueError: If the model requests a function this agent does not provide.
        """
        chat = self.agent.start_chat()
        prompt = question + f"\n The dataset_id is {self.dataset_id}" + self.SYSTEM_PROMPT

        response = chat.send_message(prompt)
        part = response.candidates[0].content.parts[0]
        intermediate_steps = []

        while True:
            call = self._extract_function_call(part)
            if call is None:
                # The model answered in text, so the tool-calling loop is done.
                break
            function_name, params = call

            if function_name == "list_tables":
                api_response = self.api_list_tables()
            elif function_name == "get_table_metadata":
                api_response = self.api_get_table_metadata(params["table_id"])
            elif function_name == "sql_query":
                api_response = self.execute_sql_query(params["query"])
            else:
                # Fail loudly instead of sending an unset or stale response back.
                raise ValueError(f"Model requested unknown function {function_name!r}")

            response = chat.send_message(
                Part.from_function_response(
                    name=function_name,
                    response={
                        "content": api_response,
                    },
                ),
            )
            part = response.candidates[0].content.parts[0]
            intermediate_steps.append({
                'function_name': function_name,
                'function_params': params,
                'API_response': api_response,
                'response': part,
            })

        # Scan newest-first and stop at the first match. This yields the last
        # SQL the model ran, presumably the one its final answer relies on.
        last_sql_step = next(
            (step for step in reversed(intermediate_steps)
             if step['function_name'] == 'sql_query'),
            None,
        )
        if last_sql_step is None:
            generated_sql, sql_output = None, None
        else:
            generated_sql = last_sql_step['function_params']['query']
            sql_output = last_sql_step['API_response']

        return NL2SQL_resp(part.text, generated_sql, sql_output)

In [35]:
# NOTE(review): an empty tables_list is forwarded to DBAI as-is; confirm what DBAI does with it.
ai = DBAI_nl2sql(tables_list=[], dataset_id="Albertsons")

In [36]:
# Run one sample question through the agent; get_sql returns an NL2SQL_resp.
resp = ai.get_sql(question='how many contracts expired last year')