In [1]:
import streamlit as st
# handle GPT API
from langchain import LLMChain

# formats the prompt history in a particular way
from langchain.prompts.prompt import PromptTemplate
from langchain.llms import OpenAI

import pandas as pd

from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationChain
from vega import VegaLite

ModuleNotFoundError: No module named 'vega'

In [None]:
# NOTE(review): `alt` is not bound by any import in the visible cells
# (presumably altair -- confirm), and this cell has no execution count.
# Unless `alt` is imported elsewhere, running it on a fresh kernel raises
# NameError.
alt.renderers.enable('notebook')

In [20]:
# Module-level dict in which model_initialisation() keeps the conversation
# memory (key "entity_memory") so that it survives repeated calls.
session_state = dict()

In [21]:
# Read the CSV into a DataFrame and keep its column labels; the labels are
# later passed to model_initialisation() as the list of available fields.
df = pd.read_csv(filepath_or_buffer="data/covid_worldwide.csv")
column_names = df.columns

In [47]:
def model_initialisation(TEMPERATURE, MODEL, K, column_names):
    """Build the ``ConversationChain`` that answers with vega-lite specs.

    The prompt tells the model to reply only with a vega-lite specification
    and lists the dataset fields it may use. The conversation memory is kept
    in the module-level ``session_state`` dict under ``"entity_memory"`` and
    is reused on later calls.

    Args:
        TEMPERATURE: Sampling temperature forwarded to ``OpenAI``.
        MODEL: Model name forwarded to ``OpenAI`` as ``model_name``.
        K: Window size for ``ConversationBufferWindowMemory``. Only used when
            ``session_state`` holds no memory yet; an existing memory object
            is reused as-is.
        column_names: The dataset's field names. Either an iterable of labels
            (list, tuple, pandas ``Index``, ...) or an already formatted
            string, which is inserted unchanged.

    Returns:
        The configured ``ConversationChain``.
    """
    # Render the fields explicitly instead of relying on "%s" % column_names:
    # that form raises TypeError for tuples and embeds the pandas
    # "Index([...], dtype=...)" repr in the prompt.
    if isinstance(column_names, str):
        fields = column_names
    else:
        fields = ", ".join(repr(str(name)) for name in column_names)

    # PromptTemplate treats {name} as a template variable, so literal braces
    # coming from field names must be doubled to survive formatting.
    fields = fields.replace("{", "{{").replace("}", "}}")

    # custom query template --> possible to add few shot examples in the future
    template = (
        """
    You are a great assistant at vega-lite visualization creation. No matter what
    the user ask, you should always response with a valid vega-lite specification
    in JSON.

    You should create the vega-lite specification based on user's query.

    Besides, Here are some requirements:
    1. Do not contain the key called 'data' in vega-lite specification.
    2. If the user ask many times, you should generate the specification based on the previous context.
    3. You should consider to aggregate the field if it is quantitative and the chart has a mark type of rect, bar, line, area or arc.
    4. The available fields in the dataset are:
    %s
    5. Always respond with exactly one vega-lite specfication. Not more, not less.
    6. If you use a color attribute, it must be inside the encoding block attribute of the specification.
    7. When the user tells you to give him a sample graph, then you give him a vega-lite specification that you think,
    will look good.
    8. remember to only respond with vega-lite specifications without additional explanations

    Current conversation:
    {history}
    Human: {input}
    AI Assistant:"""
        % (fields,)
    )

    PROMPT = PromptTemplate(
        input_variables=["history", "input"], template=template
    )

    # Create the OpenAI LLM wrapper. No API key is passed on purpose: secrets
    # must not live in the notebook. The wrapper falls back to the
    # OPENAI_API_KEY environment variable when `openai_api_key` is omitted.
    # NOTE(review): any key that was previously committed in this cell should
    # be treated as compromised and rotated.
    llm = OpenAI(
        temperature=TEMPERATURE,
        model_name=MODEL,
        verbose=False,
        streaming=True,
    )

    # Create the windowed buffer memory once and reuse it afterwards so the
    # conversation history is shared between chains built by later calls.
    if "entity_memory" not in session_state:
        session_state["entity_memory"] = ConversationBufferWindowMemory(k=K)

    # Create the ConversationChain object with the specified configuration
    chain = ConversationChain(
        llm=llm,
        prompt=PROMPT,
        memory=session_state["entity_memory"],
    )

    return chain

In [49]:
# Show the shared state dict; it is empty until model_initialisation()
# stores the conversation memory in it.
print(session_state)

{}


In [50]:
# Build the chain for the loaded dataset: gpt-3.5-turbo at temperature 1,
# with a memory window of 3.
Conversation = model_initialisation(
    TEMPERATURE=1,
    MODEL="gpt-3.5-turbo",
    K=3,
    column_names=column_names,
)


In [51]:
# Chart description sent verbatim as the user message. It is JSON-like text,
# not valid JSON: members are not separated by commas and the array items
# carry "0:", "1:", ... index prefixes, so json.loads() cannot parse it.
# The literal opens with three quotes followed by a fourth one; that fourth
# quote is the first character of the string itself.
query = """"fig_gpt_1_description":{
"mark":"arc"
"encoding":{
"theta":{
"field":"Total_Deaths"
"type":"quantitative"
"aggregate":"average"
}
"color":{
"field":"Country"
"type":"nominal"
}
"order":{
"field":"Total_Deaths"
"type":"quantitative"
"aggregate":"average"
}
}
"projection":{
"type":"identity"
}
"data":{
"values":[
0:{
"Total_Deaths":21689
"Country":"Austria"
}
1:{
"Total_Deaths":68399
"Country":"Japan"
}
2:{
"Total_Deaths":697074
"Country":"Brazil"
}
3:{
"Total_Deaths":1132935
"Country":"USA"
}
4:{
"Total_Deaths":74
"Country":"DPRK"
}
5:{
"Total_Deaths":16356
"Country":"Taiwan"
}
6:{
"Total_Deaths":111020
"Country":"Ukraine"
}
7:{
"Total_Deaths":35630
"Country":"Greece"
}
8:{
"Total_Deaths":165711
"Country":"Germany"
}
9:{
"Total_Deaths":142486
"Country":"Colombia"
}
]
}
}"""

In [52]:
# Send the chart description through the chain. The call returns the model's
# reply as a string; in the recorded run it is a vega-lite spec without the
# "data" key.
Conversation.run(input=query)

'{\n  "mark": "arc",\n  "encoding": {\n    "theta": {\n      "field": "Total_Deaths",\n      "type": "quantitative",\n      "aggregate": "average"\n    },\n    "color": {\n      "field": "Country",\n      "type": "nominal"\n    },\n    "order": {\n      "field": "Total_Deaths",\n      "type": "quantitative",\n      "aggregate": "average"\n    }\n  },\n  "projection": {\n    "type": "identity"\n  }\n}'

In [51]:
def get_low_level_values(nested_dict):
    """Collect every non-dict value of a nested dictionary, depth first.

    Values are returned in insertion order; whenever a value is itself a
    dict, its own leaves are spliced in at that position. Non-dict containers
    such as lists are treated as leaves and returned unchanged.

    Args:
        nested_dict: Dictionary whose values may be further dictionaries.

    Returns:
        A flat list with the leaf values.
    """
    leaves = []
    for item in nested_dict.values():
        # Recurse into sub-dictionaries, keep anything else as a single leaf.
        leaves += get_low_level_values(item) if isinstance(item, dict) else [item]
    return leaves

# Demo: flatten a small three-level dictionary and print the collected leaves.
nested_dict = {
    'a': 1,
    'b': {'c': 2, 'd': {'e': 3, 'f': 4}},
    'g': 5,
}

low_level_values = get_low_level_values(nested_dict)
print(low_level_values)


[1, 2, 3, 4, 5]
