In [12]:
import langchain
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain import PromptTemplate
from langchain_community.llms import Ollama

import os
import sqlite3

from langchain_google_genai import GoogleGenerativeAI

import sqlalchemy

import os
from dotenv import load_dotenv
load_dotenv()

GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')

## Connect Database

In [4]:
pokemon_db = SQLDatabase.from_uri('sqlite:///pokedex.sqlite')

In [5]:
print(pokemon_db.dialect)

sqlite


## Connect to LLM

In [14]:
llm = GoogleGenerativeAI(model="gemini-1.5-pro-latest", google_api_key=GEMINI_API_KEY)

## Use SQL Chain

In [15]:
chain = create_sql_query_chain(llm, pokemon_db)

In [19]:
chain = create_sql_query_chain(llm, pokemon_db)

In [17]:
response = chain.invoke({"question": "Who is the most powerful pokemon? Generate the SQL query. Only give me the SQL query nothing else"}) #, "context": context})
response

'SELECT \n    "p"."identifier"\nFROM \n    "pokemon" AS "p"\nORDER BY \n    "p"."base_experience" DESC\nLIMIT 1'

In [18]:
pokemon_db.run(response)

OperationalError: (sqlite3.OperationalError) no such column: p.identifier
[SQL: SELECT 
    "p"."identifier"
FROM 
    "pokemon" AS "p"
ORDER BY 
    "p"."base_experience" DESC
LIMIT 1]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

## Use SQL Chain with DB Context

In [28]:
context = pokemon_db.get_context()

response = chain.invoke({"question": "How many pokemons are there?", "context": context})
response

'SELECT\n  COUNT(*)\nFROM "pokemon";'

In [30]:
pokemon_db.run(response)

'[(673,)]'

In [31]:
response = chain.invoke({"question": "How many pokemons have blue color?", "context": context})
response

'SELECT\n  COUNT(*)\nFROM "pokemon_species" AS t1\nINNER JOIN "pokemon_colors" AS t2\n  ON t1.color_id = t2.id\nWHERE\n  t2.identifier = \'blue\';'

In [32]:
pokemon_db.run(response)

'[(122,)]'

In [29]:
print(pokemon_db.get_usable_table_names())

['abilities', 'ability_changelog', 'ability_changelog_prose', 'ability_flavor_text', 'ability_names', 'ability_prose', 'berries', 'berry_firmness', 'berry_firmness_names', 'berry_flavors', 'conquest_episode_names', 'conquest_episode_warriors', 'conquest_episodes', 'conquest_kingdom_names', 'conquest_kingdoms', 'conquest_max_links', 'conquest_move_data', 'conquest_move_displacement_prose', 'conquest_move_displacements', 'conquest_move_effect_prose', 'conquest_move_effects', 'conquest_move_range_prose', 'conquest_move_ranges', 'conquest_pokemon_abilities', 'conquest_pokemon_evolution', 'conquest_pokemon_moves', 'conquest_pokemon_stats', 'conquest_stat_names', 'conquest_stats', 'conquest_transformation_pokemon', 'conquest_transformation_warriors', 'conquest_warrior_archetypes', 'conquest_warrior_names', 'conquest_warrior_rank_stat_map', 'conquest_warrior_ranks', 'conquest_warrior_skill_names', 'conquest_warrior_skills', 'conquest_warrior_specialties', 'conquest_warrior_stat_names', 'conqu

## ChatPromptTemplate with Validation

In [21]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
Remove all markdown related content and special characters.
Output the final SQL query only, nothing else should be there in the query."""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=pokemon_db.dialect)
validation_chain = prompt | llm | StrOutputParser()

full_chain = {"query": chain} | validation_chain

In [25]:
query = full_chain.invoke(
    {
        "question": 'Show me all most powerful pokemons with their names.',
        "context": context
    }
)
query

'SELECT "pokemon_species"."identifier", "pokemon_stats"."base_stat" FROM "pokemon_species" JOIN "pokemon_stats" ON "pokemon_species"."id" = "pokemon_stats"."pokemon_id" ORDER BY "pokemon_stats"."base_stat" DESC LIMIT 5 \n'

In [26]:
pokemon_db.run(query)

"[('blissey', 255), ('chansey', 250), ('shuckle', 230), ('shuckle', 230), ('steelix', 200)]"

In [33]:

query = full_chain.invoke(
    {
        "question": 'How many pokemons have blue color? Give me their details with their name.',
        "context": context
    }
)
query

'SELECT COUNT(*) FROM "pokemon_colors" WHERE "identifier" = \'blue\' \n'

In [34]:
pokemon_db.run(query)

'[(1,)]'