In [30]:
import os

# Load API key from environment variable
gemini_api_key = os.getenv('AIzaSyBD9MWNN7Ntu32w_E_6EDQygvyJwQS-n6c')


NameError: name 'GeminiLLM' is not defined

In [4]:
from langchain_core.prompts import ChatPromptTemplate

template = """
You are a highly capable AI assistant with deep expertise in SQL and database querying.

Your task is to generate a syntactically correct, efficient, and safe SQL query based on:
- The given database schema
- The user's natural language question

---

Instructions:
1. Use only tables and columns that exist in the provided schema.
2. Use aliases (e.g., a, t, c) when they improve readability.
3. Format the SQL query with proper indentation.
4. Avoid SELECT * — be specific with columns when possible.
5. If aggregation, filtering, or ordering is needed, apply it correctly.
6. The query must be executable as-is on the schema given.
7. Do not explain the query or include commentary — output only the SQL.

---

Database Schema:
{schema}

---

User Question:
{question}

---

SQL Query:
"""
prompt = ChatPromptTemplate.from_template(template)


In [5]:
prompt.format(schema = "my schema", question = "how many users are present there?")

"Human: \nYou are a highly capable AI assistant with deep expertise in SQL and database querying.\n\nYour task is to generate a syntactically correct, efficient, and safe SQL query based on:\n- The given database schema\n- The user's natural language question\n\n---\n\nInstructions:\n1. Use only tables and columns that exist in the provided schema.\n2. Use aliases (e.g., a, t, c) when they improve readability.\n3. Format the SQL query with proper indentation.\n4. Avoid SELECT * — be specific with columns when possible.\n5. If aggregation, filtering, or ordering is needed, apply it correctly.\n6. The query must be executable as-is on the schema given.\n7. Do not explain the query or include commentary — output only the SQL.\n\n---\n\nDatabase Schema:\nmy schema\n\n---\n\nUser Question:\nhow many users are present there?\n\n---\n\nSQL Query:\n"

In [9]:
import langchain_community

from langchain_community.utilities import SQLDatabase
import mysql.connector



In [23]:
from langchain_community.utilities import SQLDatabase

mysql_uri = 'mysql+mysqlconnector://adit:yourpassword@localhost:3306/mychinook'
db = SQLDatabase.from_uri(mysql_uri)


In [25]:
db.run("SHOW TABLES FROM mychinook;")

"[('Album',), ('Artist',), ('Customer',), ('Employee',), ('Genre',), ('Invoice',), ('InvoiceLine',), ('MediaType',), ('Playlist',), ('PlaylistTrack',), ('Track',)]"

In [44]:
def get_schema(_):
    return db.get_table_info()  # Use the global or outer scoped db object

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


In [52]:
%pip install langchain --upgrade



Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [58]:
from langchain.schema import BaseOutputParser

class SimpleStrParser(BaseOutputParser):
    def parse(self, text: str) -> str:
        return text

# then in your chain:
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | SimpleStrParser()
)


In [62]:
import requests
from langchain.llms.base import LLM
from langchain_core.output_parsers import StrOutputParser  # Fixed import
from langchain_core.runnables import RunnablePassthrough   # Fixed import
from typing import Optional, List, Any, ClassVar

class GeminiLLM(LLM):
    """Custom LLM implementation for Google Gemini API"""
    
    llm_type: ClassVar[str] = "gemini"  # Fixed: added ClassVar annotation
    
    # Add type annotations for instance attributes
    api_key: str
    api_url: str
    
    def __init__(self, api_key: str):
        super().__init__()
        self.api_key = api_key
        # Updated to use the correct Gemini API endpoint
        self.api_url = (
            "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
        )
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Make API call to Gemini"""
        headers = {
            "Content-Type": "application/json",
        }
        
        # Updated payload structure for current Gemini API
        data = {
            "contents": [
                {
                    "parts": [
                        {"text": prompt}
                    ]
                }
            ],
            "generationConfig": {
                "maxOutputTokens": 2000,
                "temperature": 0.7,
            }
        }
        
        # Add stop sequences if provided
        if stop:
            data["generationConfig"]["stopSequences"] = stop
        
        # Add API key as query parameter (current Gemini API format)
        params = {"key": self.api_key}
        
        try:
            response = requests.post(
                self.api_url, 
                headers=headers, 
                json=data, 
                params=params
            )
            response.raise_for_status()
            
            resp_json = response.json()
            
            # Extract text from the current Gemini API response format
            if "candidates" in resp_json and len(resp_json["candidates"]) > 0:
                candidate = resp_json["candidates"][0]
                if "content" in candidate and "parts" in candidate["content"]:
                    return candidate["content"]["parts"][0]["text"]
                    
            return "No response generated"
            
        except requests.exceptions.RequestException as e:
            raise Exception(f"API request failed: {str(e)}")
        except (KeyError, IndexError) as e:
            raise Exception(f"Unexpected response format: {str(e)}")
    
    @property
    def _identifying_params(self) -> dict:
        """Return identifying parameters for the LLM"""
        return {"llm_type": self.llm_type}

# Example usage:
if __name__ == "__main__":
    # Initialize the LLM with your API key
    # api_key = "your_gemini_api_key_here"
    # gemini_llm = GeminiLLM(api_key=api_key)
    
    # Example of creating a chain with the fixed imports
    # chain = gemini_llm | StrOutputParser()
    
    # Test the LLM
    # result = gemini_llm("Hello, how are you?")
    # print(result)
    pass

In [70]:
# Complete working solution for Gemini LLM with SQL chain

import requests
from langchain.llms.base import LLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import PromptTemplate
from typing import Optional, List, Any, ClassVar

class GeminiLLM(LLM):
    """Working Gemini LLM implementation"""
    
    llm_type: ClassVar[str] = "gemini"
    api_key: str
    api_url: str
    
    def __init__(self, api_key: str):
        super().__init__()
        self.api_key = api_key
        self.api_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
    
    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Make API call to Gemini"""
        headers = {"Content-Type": "application/json"}
        
        data = {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "maxOutputTokens": 2000,
                "temperature": 0.1,
            }
        }
        
        if stop:
            data["generationConfig"]["stopSequences"] = stop
        
        params = {"key": self.api_key}
        
        try:
            response = requests.post(self.api_url, headers=headers, json=data, params=params, timeout=30)
            response.raise_for_status()
            
            resp_json = response.json()
            
            if "candidates" in resp_json and len(resp_json["candidates"]) > 0:
                candidate = resp_json["candidates"][0]
                if "content" in candidate and "parts" in candidate["content"]:
                    result = candidate["content"]["parts"][0]["text"]
                    
                    # Apply stop sequences manually
                    if stop:
                        for stop_seq in stop:
                            if stop_seq in result:
                                result = result.split(stop_seq)[0]
                                break
                    
                    return result.strip()
                    
            return "No response generated"
            
        except Exception as e:
            raise Exception(f"Gemini API error: {str(e)}")
    
    @property
    def _llm_type(self) -> str:
        return "gemini"
    
    @property
    def _identifying_params(self) -> dict:
        return {"llm_type": self.llm_type}

# Method 1: Direct function approach (RECOMMENDED)
def create_sql_query(question: str, schema: str, llm: GeminiLLM) -> str:
    """Create SQL query using Gemini LLM - most reliable approach"""
    
    prompt = f"""Based on the table schema below, write a SQL query to answer the user's question.

Schema:
{schema}

Question: {question}

SQL Query:"""
    
    # Call LLM with stop sequence to prevent extra text
    result = llm._call(prompt, stop=["\n\n", "SQLResult:", "Result:"])
    
    # Clean up the result
    result = result.strip()
    if result.startswith("```sql"):
        result = result.replace("```sql", "").replace("```", "").strip()
    
    return result

# Method 2: LangChain chain without .bind() (if you prefer chains)
def create_working_sql_chain(llm, get_schema_func):
    """Create a working SQL chain without using .bind()"""
    
    def run_sql_chain(inputs):
        question = inputs["question"]
        schema = get_schema_func()
        
        prompt = f"""Based on the table schema below, write a SQL query to answer the user's question.

Schema:
{schema}

Question: {question}

SQL Query:"""
        
        # Call LLM directly
        result = llm._call(prompt, stop=["\nSQLResult:", "\n\n"])
        return result.strip()
    
    return run_sql_chain

# Usage examples:
if __name__ == "__main__":
    # Initialize your LLM
    api_key = "AIzaSyBD9MWNN7Ntu32w_E_6EDQygvyJwQS-n6c"
    llm = GeminiLLM(api_key=api_key)
    
    # Test the LLM first
    try:
        test_result = llm._call("What is 2+2?")
        print(f"LLM Test: {test_result}")
        
        # Method 1: Direct approach (RECOMMENDED)
        def get_schema():
            # Replace this with your actual schema function
            return """
            Table: Album
            Columns: AlbumId (INTEGER), Title (TEXT), ArtistId (INTEGER)
            
            Table: Artist  
            Columns: ArtistId (INTEGER), Name (TEXT)
            """
        
        user_question = 'how many albums are there in the database?'
        
        # Use the direct function approach
        sql_result = create_sql_query(user_question, get_schema(), llm)
        print(f"Generated SQL: {sql_result}")
        
        # Method 2: Chain approach (if you prefer)
        sql_chain_func = create_working_sql_chain(llm, get_schema)
        chain_result = sql_chain_func({"question": user_question})
        print(f"Chain result: {chain_result}")
        
    except Exception as e:
        print(f"Error: {e}")

# If you want to use LangChain syntax, use this instead of .bind():
def create_langchain_compatible_chain(llm, prompt_template, get_schema_func):
    """Create a LangChain-style chain that works with custom LLM"""
    
    class CustomSQLChain:
        def __init__(self, llm, prompt, schema_func):
            self.llm = llm
            self.prompt = prompt
            self.schema_func = schema_func
        
        def invoke(self, inputs):
            schema = self.schema_func()
            formatted_prompt = self.prompt.format(schema=schema, question=inputs["question"])
            result = self.llm._call(formatted_prompt, stop=["\nSQLResult:"])
            return result.strip()
    
    return CustomSQLChain(llm, prompt_template, get_schema_func)

ValidationError: 2 validation errors for GeminiLLM
api_key
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.11/v/missing
api_url
  Field required [type=missing, input_value={}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.11/v/missing

In [66]:
user_question = 'how many albums are there in the database?'
sql_chain.invoke({"question": user_question})

# 'SELECT COUNT(*) AS TotalAlbums\nFROM Album;'


ConnectionError: HTTPSConnectionPool(host='api.gemini.api', port=443): Max retries exceeded with url: /v1/chat/completions (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x7922cdd0d190>: Failed to resolve 'api.gemini.api' ([Errno -2] Name or service not known)"))