In [117]:
import os
from dotenv import load_dotenv
from langchain.chains import create_sql_query_chain
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from sqlalchemy import create_engine
from langchain.sql_database import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool


from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Pull variables from a local .env file into the process environment.
load_dotenv()

# Load OpenAI API key
OPENAI_KEY = os.getenv('OPENAI_KEY')

# Setup the OpenAI LLM
llm = ChatOpenAI(api_key=OPENAI_KEY, temperature=0.5)

# Database connection string. It can be overridden through the DATABASE_URL
# environment variable (e.g. in .env) so that credentials do not have to live in
# the notebook; the default below is the original local development database.
DATABASE_URL = os.getenv(
    'DATABASE_URL',
    'mysql+pymysql://root:root@127.0.0.1:8889/mydatabase',
)

# Create a SQLAlchemy engine
engine = create_engine(DATABASE_URL)

# Wrap the engine with SQLDatabase
db = SQLDatabase(engine)


# Create SQL query chain using the SQLDatabase object
write_query = create_sql_query_chain(llm, db)

# Tool that runs a SQL string against `db`.
execute_query = QuerySQLDataBaseTool(db=db)

# Minimal pipeline: write the SQL, then execute it.
# NOTE(review): `chain` is reassigned further down in this cell, so this value is
# only a placeholder; it is kept so the name is always defined.
chain = write_query | execute_query


# Prompt that turns (question, SQL query, SQL result) into either a text answer or
# a chart specification. Single braces are template variables filled in at run
# time; doubled braces are escapes that render as literal braces, so every example
# row below is shown to the model as a real Python dict literal.
answer_prompt = PromptTemplate.from_template(
    """
    Based on the user's question and the SQL result, answer the question either by providing a direct text response or suggesting an appropriate graph type.

    Question: {question}
    SQL Query: {query}
    SQL Result: {result}

    Please decide if the data should be visualized using one of the following graph types: 'line chart', 'stack bar chart', 'bar chart', 'sankey chart'.
    If a graph is required, provide the data in the following formats:

    - **Line Chart**: Use a list of dictionaries with x and y values:
      ```python
      [
          {{"<x-axis name>": date, "<y-axis name>": value}},
          ...
      ]
      ```
    - **Stack Bar Chart**: Use a list of dictionaries with categories and stacked values:
      ```python
      [
          {{"<category name>": "Category", "<value1 name>": value1, "<value2 name>": value2}},
          ...
      ]
      ```
    - **Bar Chart**: Use a list of dictionaries with categories and values:
      ```python
      [
          {{"<category name>": "Category", "<value name>": value}},
          ...
      ]
      ```
    - **Sankey Chart**: Use a list of dictionaries with source, target, and value (use exactly the keys "source", "target" and "value"):
      ```python
      [
          {{"source": "Source", "target": "Target", "value": value}},
          ...
      ]
      ```

    If the answer is a single value or string, provide a direct text answer.

    Answer format:
    - graph_needed: "yes" or "no"
    - graph_type: one of ['line chart', 'stack bar chart', 'bar chart', 'sankey chart'] (if graph_needed is "yes")
    - data_array: python data list (if graph_needed is "yes")
    - text_answer: The direct answer (if graph_needed is "no")
    """
)



# Last stage: fill the prompt, send it to the LLM and keep only the reply text.
answer = answer_prompt | llm | StrOutputParser()

# Full pipeline, starting from {"question": ...}:
#   1. add "query"  - the SQL produced by `write_query`
#   2. add "result" - the output of running that SQL through `execute_query`
#   3. hand question/query/result to the `answer` stage
chain = (
    RunnablePassthrough
    .assign(query=write_query)
    .assign(result=itemgetter("query") | execute_query)
    | answer
)

# Show the repr of the SQL execution tool.
print(execute_query)

# Run a query
# Alternative question:
#query = "What is the overall trend of global investments in plastic circularity?"

query = "How Deal Value is divided according to the region and Archetype? "
result = chain.invoke({"question": query})

print(result)



db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x158955d50>
- graph_needed: "yes"
- graph_type: "stack bar chart"
- data_array: 
```python
[
    {'Region': 'North America', 'Recovery': 53876821478.0, 'Recycling': 27498521727.0, 'Materials': 7816251495.0},
    {'Region': 'Europe', 'Recovery': 40254387086.0, 'Recycling': 17211904704.0, 'Materials': 0}
]
```


In [119]:
import ast
import json
import re

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go

def _extract_scalar(field, text):
    """Return the value written after ``field:`` in ``text``, or None if absent."""
    # Preferred form: a double-quoted value, e.g. graph_type: "bar chart".
    quoted = re.search(rf'{field}:\s*"([\w\s]+)"', text)
    if quoted:
        return quoted.group(1).strip()
    # Fallback: an unquoted or single-quoted value on the same line. Only spaces
    # and tabs may follow the colon so the match never runs into the next field.
    bare = re.search(rf"{field}:[ \t]*'?([A-Za-z][\w \t]*)", text)
    return bare.group(1).strip() if bare else None


def _parse_data_array(raw):
    """Turn the textual list ``raw`` into a Python list, or None if it cannot be parsed."""
    attempts = (
        # Python literal syntax: single quotes, None/True/False, apostrophes
        # inside double-quoted strings. literal_eval never executes code.
        (ast.literal_eval, raw),
        # Strict JSON: null/true/false.
        (json.loads, raw),
        # Legacy behaviour: swap quotes, then parse as JSON.
        (json.loads, raw.replace("'", '"')),
    )
    for parser, candidate in attempts:
        try:
            value = parser(candidate)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            continue
        if isinstance(value, list):
            return value
    return None


def extract_fields(result):
    """Pull the chart fields out of a text answer.

    The text is expected to contain ``graph_needed: ...``, ``graph_type: ...``
    and ``data_array: ...`` entries, where the data array is a list literal that
    may be wrapped in a fenced code block and may span several lines.

    Args:
        result: The answer text. Anything that is not a string is treated as an
            answer with no fields.

    Returns:
        A ``(graph_needed, graph_type, data_array)`` tuple. The first two items
        are stripped strings and the third is a list; each item is None when the
        corresponding field is missing or cannot be parsed.
    """
    if not isinstance(result, str):
        return None, None, None

    graph_needed_value = _extract_scalar('graph_needed', result)
    graph_type_value = _extract_scalar('graph_type', result)

    # Fenced list first (any whitespace may sit between the brackets and the
    # rows); otherwise accept a bare list following the field name.
    data_array = re.search(
        r'data_array:\s*```[\w+-]*\s*(\[.*?\])\s*```', result, re.DOTALL
    )
    if data_array is None:
        data_array = re.search(r'data_array:\s*(\[.*\])', result, re.DOTALL)

    data_array_value = None
    if data_array:
        data_array_value = _parse_data_array(data_array.group(1).strip())
        if data_array_value is None:
            print("Error decoding JSON from data_array.")

    return graph_needed_value, graph_type_value, data_array_value

def plot_chart(graph_needed, graph_type, data_array):
    """Dispatch ``data_array`` to the plotting function matching ``graph_type``.

    Args:
        graph_needed: "yes" or "no"; "no" (any case, surrounding whitespace
            ignored) skips plotting.
        graph_type: One of 'line chart', 'stack bar chart', 'bar chart' or
            'sankey chart'. Matching ignores case and surrounding whitespace,
            and 'stacked bar chart' is accepted as an alias.
        data_array: List of row dictionaries to plot.

    Returns:
        None. A message is printed instead of plotting when no graph is needed,
        when there is no data, or when the graph type is not recognised.
    """
    if isinstance(graph_needed, str) and graph_needed.strip().lower() == "no":
        print("No graph needed.")
        return

    # A failed parse upstream yields None/empty data; the plotters cannot use it.
    if not data_array:
        print("No data available to plot.")
        return

    kind = graph_type.strip().lower() if isinstance(graph_type, str) else ""

    if kind == "line chart":
        plot_line_chart(data_array)
    elif kind in ("stack bar chart", "stacked bar chart"):
        plot_stack_bar_chart(data_array)
    elif kind == "bar chart":
        plot_bar_chart(data_array)
    elif kind == "sankey chart":
        plot_sankey_chart(data_array)
    else:
        print("Unknown graph type.")

def plot_line_chart(data):
    """Plot every column after the first as a line against the first column.

    Args:
        data: List of row dictionaries. The first key supplies the x values and
            each remaining key becomes one labelled line.

    Returns:
        None. Prints a message and draws nothing when ``data`` has no rows or
        fewer than two columns.
    """
    df = pd.DataFrame(data)
    # Without rows, or without at least one y column, there is nothing to draw.
    if df.empty or len(df.columns) < 2:
        print("No data to plot.")
        return

    x_col = df.columns[0]
    for column in df.columns[1:]:
        plt.plot(df[x_col], df[column], marker='o', label=column)
    plt.title('Line Chart')
    plt.xlabel(x_col)
    plt.ylabel('Values')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_stack_bar_chart(data, category_col=None):
    """Draw a stacked bar chart with one bar per category.

    Args:
        data: List of row dictionaries holding one category key plus one key per
            stacked series.
        category_col: Name of the key to use as the category axis. When omitted,
            'Region' is used if present (the original behaviour), otherwise the
            first key of the rows.

    Returns:
        None. Prints a message and draws nothing when ``data`` has no rows or
        fewer than two columns.
    """
    df = pd.DataFrame(data)
    if df.empty or len(df.columns) < 2:
        print("No data to plot.")
        return

    if category_col is None:
        category_col = 'Region' if 'Region' in df.columns else df.columns[0]

    # set_index returns a new frame, so the frame built from `data` is not mutated.
    df.set_index(category_col).plot(kind='bar', stacked=True, figsize=(12, 8))
    plt.title('Stack Bar Chart')
    plt.xlabel(category_col)
    plt.ylabel('Values')
    plt.legend(title='Categories', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

def plot_bar_chart(data, category_col=None):
    """Draw a grouped bar chart with one group of bars per category.

    Args:
        data: List of row dictionaries holding one category key plus one or more
            value keys.
        category_col: Name of the key to use on the x axis. When omitted,
            'Region' is used if present (the original behaviour), otherwise the
            first key of the rows.

    Returns:
        None. Prints a message and draws nothing when ``data`` has no rows or
        fewer than two columns.
    """
    df = pd.DataFrame(data)
    if df.empty or len(df.columns) < 2:
        print("No data to plot.")
        return

    if category_col is None:
        category_col = 'Region' if 'Region' in df.columns else df.columns[0]

    # Every column other than the category is plotted as its own bar series.
    value_cols = [col for col in df.columns if col != category_col]
    df.plot(kind='bar', x=category_col, y=value_cols, figsize=(12, 8))
    plt.title('Bar Chart')
    plt.xlabel(category_col)
    plt.ylabel('Values')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def plot_sankey_chart(data):
    """Render a Sankey diagram from a list of flows.

    Args:
        data: List of dictionaries, each with the keys 'source', 'target' and
            'value'.

    Returns:
        None. Prints a message and draws nothing when ``data`` is empty or when
        any row lacks one of the three required keys.
    """
    if not data:
        print("No data to plot.")
        return

    sources = [d.get('source') for d in data]
    targets = [d.get('target') for d in data]
    values = [d.get('value') for d in data]

    # A missing key would otherwise turn into a node or flow called None.
    if any(item is None for item in sources + targets + values):
        print("Sankey data must contain 'source', 'target' and 'value' keys.")
        return

    # dict.fromkeys de-duplicates while keeping first-seen order, so the node
    # order (and with it the layout) is the same on every run.
    unique_nodes = list(dict.fromkeys(sources + targets))
    node_indices = {node: idx for idx, node in enumerate(unique_nodes)}

    fig = go.Figure(go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color='black', width=0.5),
            label=unique_nodes
        ),
        link=dict(
            # Links refer to nodes by their position in `unique_nodes`.
            source=[node_indices[src] for src in sources],
            target=[node_indices[tgt] for tgt in targets],
            value=values
        )
    ))

    fig.update_layout(title_text='Sankey Diagram', font_size=10)
    fig.show()



In [120]:
# NOTE(review): `graph_type` is not assigned in any cell visible above, so this
# cell depends on kernel state from elsewhere - presumably an unpacked
# `extract_fields(result)` call; TODO confirm and add that call to the notebook.
graph_type

''

In [121]:
graph_needed