In [3]:
from langchain_google_genai import GoogleGenerativeAI
from langchain_core.output_parsers import JsonOutputParser
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
import os
from dotenv import load_dotenv

In [4]:
load_dotenv()

True

In [5]:
llm = GoogleGenerativeAI(model="gemini-pro",google_api_key=os.environ['GOOGLE_API_KEY'], temperature=0.2)

In [6]:
examples = [
    {
        "userInput": "i need bar chart who are the clients we worked in Germeny in last 6 months with no of shipments?",
        "sql_query": """select BillingClientName, count(ShipmentNumber) as ShipmentCount
                    from RevandVolume_ShipmentDate
                    where ShipmentDate >= DATEADD(MONTH, -6, GETDATE()) AND ShipmentDate <= GETDATE() 
                    and OriginCountry = 'Germeny' group by BillingClientName order by ShipmentCount desc
                    """,
        "chart_type": "bar chart",
        "chart_columns": "BillingClientName, ShipmentCount",
    },
    {
        "userInput": "i want Bar chart on agent list and their shipment count who worked with us in Sri Lanka in last 2 months.",
        "sql_query": """select distinct AgentName, count(*) as shipment_count 
                    from RevandVolume_ShipmentDate WHERE AgentCountryName like 'United kingdom' and INCO = 'EXW'  
                    and ShipmentDate >= DATEADD(MONTH, -2, GETDATE()) AND ShipmentDate <= GETDATE()
                    group by AgentName\n""",
        "chart_type": "bar chart",
        "chart_columns": "AgentName, shipment_count",
    }
]

In [7]:
exaple_prompt = PromptTemplate(
    input_variables=["userInput", "sql_query", "chart_type", "chart_columns"],
    template="Question: {userInput}\nSQL Query: {sql_query}\nChart Type: {chart_type}\nColumns: {chart_columns}",
)

In [8]:
print(exaple_prompt.format(**examples[0]))

Question: i need bar chart who are the clients we worked in Germeny in last 6 months with no of shipments?
SQL Query: select BillingClientName, count(ShipmentNumber) as ShipmentCount
                    from RevandVolume_ShipmentDate
                    where ShipmentDate >= DATEADD(MONTH, -6, GETDATE()) AND ShipmentDate <= GETDATE() 
                    and OriginCountry = 'Germeny' group by BillingClientName order by ShipmentCount desc
                    
Chart Type: bar chart
Columns: BillingClientName, ShipmentCount


In [144]:
chart_prompt = """You are now generating visualizations based on the User Input and the SQL query provided.
First Identify the visualization type from the following list: bar chart, area chart, scatter chart, line chart, tabulate.
Then Identify the columns to be displayed in the visualization. Pay attention to use only Question & SQL Query as input to generate the visualization.
If you cannot find the visualization type or columns, then just use Chart Type: None and Columns: None

Use the following format:

Question: Question here
SQLQuery: SQL Query here
Chart Type: Identified chart type which user needed to generate
Columns: Identified columns which needs to generate charts

No pre-amble."""

In [10]:
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX
print(PROMPT_SUFFIX)

Only use the following tables:
{table_info}

Question: {input}


In [145]:
prompt = FewShotPromptTemplate(
    examples=examples,
    example_prompt=exaple_prompt,
    prefix=chart_prompt,
    suffix="""Question: {input}
    SQLQuery:{sql_query}""",
    input_variables=["input", "sql_query"],
)

In [146]:
prompt

FewShotPromptTemplate(input_variables=['input', 'sql_query'], examples=[{'userInput': 'i need bar chart who are the clients we worked in Germeny in last 6 months with no of shipments?', 'sql_query': "select BillingClientName, count(ShipmentNumber) as ShipmentCount\n                    from RevandVolume_ShipmentDate\n                    where ShipmentDate >= DATEADD(MONTH, -6, GETDATE()) AND ShipmentDate <= GETDATE() \n                    and OriginCountry = 'Germeny' group by BillingClientName order by ShipmentCount desc\n                    ", 'chart_type': 'bar chart', 'chart_columns': 'BillingClientName, ShipmentCount'}, {'userInput': 'i want Bar chart on agent list and their shipment count who worked with us in Sri Lanka in last 2 months.', 'sql_query': "select distinct AgentName, count(*) as shipment_count \n                    from RevandVolume_ShipmentDate WHERE AgentCountryName like 'United kingdom' and INCO = 'EXW'  \n                    and ShipmentDate >= DATEADD(MONTH, -2, 

In [147]:
print(prompt.format(input="i need bar chart who are the clients we worked in Germeny in last 6 months with no of shipments?", sql_query="select ClientName, count(ShipmentNumber) as NumberofShipments\nfrom RevandVolume_ShipmentDate\nwhere ShipmentDate >= '2023-12-31' AND ShipmentDate <= GETDATE()\nand DestCity = 'Hamburg' and OriginCity = 'Colombo' \nand MainDepCode = 'FES' and ClientName is not null\ngroup by ClientName \norder by NumberofShipments Desc"))

You are now generating visualizations based on the User Input and the SQL query provided.
First Identify the visualization type from the following list: bar chart, area chart, scatter chart, line chart, tabulate.
Then Identify the columns to be displayed in the visualization. Pay attention to use only Question & SQL Query as input to generate the visualization.
If you cannot find the visualization type or columns, then just use Chart Type: None and Columns: None

Use the following format:

Question: Question here
SQLQuery: SQL Query here
Chart Type: Identified chart type which user needed to generate
Columns: Identified columns which needs to generate charts

No pre-amble.

Question: i need bar chart who are the clients we worked in Germeny in last 6 months with no of shipments?
SQL Query: select BillingClientName, count(ShipmentNumber) as ShipmentCount
                    from RevandVolume_ShipmentDate
                    where ShipmentDate >= DATEADD(MONTH, -6, GETDATE()) AND Shipmen

In [151]:
Q1 = prompt.format(input="i need who are the clients we worked in Germeny in last 6 months with no of shipments?", sql_query="select ClientName, count(ShipmentNumber) as NumberofShipments\nfrom RevandVolume_ShipmentDate\nwhere ShipmentDate >= '2023-12-31' AND ShipmentDate <= GETDATE()\nand DestCity = 'Hamburg' and OriginCity = 'Colombo' \nand MainDepCode = 'FES' and ClientName is not null\ngroup by ClientName \norder by NumberofShipments Desc")

: 

In [149]:
Q1

"You are now generating visualizations based on the User Input and the SQL query provided.\nFirst Identify the visualization type from the following list: bar chart, area chart, scatter chart, line chart, tabulate.\nThen Identify the columns to be displayed in the visualization. Pay attention to use only Question & SQL Query as input to generate the visualization.\nIf you cannot find the visualization type or columns, then just use Chart Type: None and Columns: None\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query here\nChart Type: Identified chart type which user needed to generate\nColumns: Identified columns which needs to generate charts\n\nNo pre-amble.\n\nQuestion: i need bar chart who are the clients we worked in Germeny in last 6 months with no of shipments?\nSQL Query: select BillingClientName, count(ShipmentNumber) as ShipmentCount\n                    from RevandVolume_ShipmentDate\n                    where ShipmentDate >= DATEADD(MONTH, -6, GETD

In [51]:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List, Set

In [84]:
class Output(BaseModel):
    ChartType: str = Field(description="The type of chart")
    Columns: list[str] = Field(description="The columns to be used for the chart")

In [85]:
parser = JsonOutputParser(pydantic_object=Output)

In [136]:
from typing import Any
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
class TexttoJson(BaseGenerationOutputParser[str]):
    "Pass the output from the LLM to the JSON"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        generation = result[0]
        json_output = convert_to_json(generation.text)
        return json_output

In [137]:
# chain = llm | parser
chain  = llm | TexttoJson()

In [138]:
import json
def convert_to_json(text):
    
    # # Split the text by newline character
    # lines = text.split("\n")

    # # Initialize an empty dictionary to store key-value pairs
    # output = {}

    # # Iterate over each line and extract key-value pairs
    # for line in lines:
    #     key, value = line.split(": ", 1)  # Split each line by ": " and keep only the first split
    #     if key == "Columns":
    #         value = [col.strip() for col in value.split(",")]
    #     output[key] = value

    # # Convert the dictionary to JSON format
    # json_output = json.dumps(output, indent=4)
    # json_output = json.loads(text)
    
# Parse the response string
    chart_type = text.split('\n')[0].split(':')[1].strip()
    columns = [column.strip() for column in text.split('\n')[1].split(':')[1].split(',')]

    # Construct a dictionary
    json_output = {
    "Chart Type": chart_type,
    "Columns": columns
        }

    return json_output


In [150]:
output = chain.invoke(Q1)
output # output

{'Chart Type': 'bar chart', 'Columns': ['ClientName', 'NumberofShipments']}