In [16]:
# from git import Repo
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import LanguageParser
from langchain.text_splitter import Language
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.output_parsers.json import SimpleJsonOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma


In [17]:
import pandas as pd

df = pd.read_csv("/Users/borisyu/Desktop/github/AutoFinViz/example/data/Stock_price_TSLA.csv")

## Question Formulator

In [18]:
summary = summary = [{'column': 'Date',
  'properties': {'dtype': 'date',
   'min': '2022-11-21',
   'max': '2023-11-17',
   'num_unique_values': 250,
   'samples': ['2022-11-21', '2022-11-22', '2022-11-23'],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Open',
  'properties': {'dtype': 'number',
   'min': 103.0,
   'max': 296.040009,
   'num_unique_values': 248,
   'samples': [175.850006, 168.630005, 173.570007],
   'semantic_type': '',
   'description': ''}},
 {'column': 'High',
  'properties': {'dtype': 'number',
   'min': 111.75,
   'max': 299.290009,
   'num_unique_values': 246,
   'samples': [176.770004, 170.919998, 183.619995],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Low',
  'properties': {'dtype': 'number',
   'min': 101.809998,
   'max': 289.519989,
   'num_unique_values': 247,
   'samples': [167.539993, 166.190002, 172.5],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Close',
  'properties': {'dtype': 'number',
   'min': 108.099998,
   'max': 293.339996,
   'num_unique_values': 247,
   'samples': [167.869995, 169.910004, 183.199997],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Adj Close',
  'properties': {'dtype': 'number',
   'min': 108.099998,
   'max': 293.339996,
   'num_unique_values': 247,
   'samples': [167.869995, 169.910004, 183.199997],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Volume',
  'properties': {'dtype': 'number',
   'min': 50672700,
   'max': 306590600,
   'num_unique_values': 250,
   'samples': [92882700, 78452300, 109536700],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Daily Range',
  'properties': {'dtype': 'number',
   'min': 2.910004000000015,
   'max': 20.38000500000001,
   'num_unique_values': 240,
   'samples': [9.23001099999999, 4.729996, 11.119994999999989],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Percentage Change',
  'properties': {'dtype': 'number',
   'min': -0.08753273328663176,
   'max': 0.09766988349514556,
   'num_unique_values': 250,
   'samples': [-0.045379645878431296,
    0.007590576777839597,
    0.055481878271745375],
   'semantic_type': '',
   'description': ''}},
 {'column': 'Volume Weighted Average Price',
  'properties': {'dtype': 'number',
   'min': 138.933071859511,
   'max': 203.68287803146123,
   'num_unique_values': 250,
   'samples': [167.869995, 168.80409134966993, 174.41832731024735],
   'semantic_type': '',
   'description': ''}}]

n = 3

list_of_graph = ["OHLC Chart", "Candlestick Chart", "Moving Average Graph", "RSI Graph", "Moving Average Convergence Divergence",
                 "Waterfall Chart", "Funnel Chart", "Time Series graph with a range slider","Multi-Measurment Time Series Graph", ]

In [19]:

format_instruction = """
    THE OUTPUT MUST BE A CODE SNIPPET OF A VALID LIST OF JSON OBJECTS. IT MUST USE THE FOLLOWING FORMAT:

    ```
    [
        { "index": 0,  "title": "How", "visualization_type": "histogram", "x_axis": ["col_1", "col_2"], "y_axis": ["col_3", "col_4"]},
    ]
    ```
    THE OUTPUT SHOULD ONLY USE THE JSON FORMAT ABOVE.
"""


user_prompt = f"""

    CHOOSE {n} plots, that fit for dataframe df, among {list_of_graph} as the visulization_type 
    SET a QUESTION as the TITLE of the plot.
    CHOOSE columns name for x axis and y axis.

    FOLLOW this output format instruction:

    {format_instruction}

    'visualization_type' can ONLY within {list_of_graph}.

"""

In [20]:

graph_description = [
    {"visualization_type": "Time Series Graph with slider", "parameter": """
        df (pd.DataFrame): DataFrame containing the data for the plot.
        date (str): Column name in `df` representing the date axis.
        measurement (str): Column name in `df` representing the measurement axis.
    """}, 
    {"visualization_type": "Time Series Graph", "parameter": """
        df (pd.DataFrame): DataFrame containing the data for the plot.
        date (str): Column name in `df` representing the date axis.
        measurement (str): Column name in `df` representing the measurement axis.
       
    """},
    {"visualization_type": "Multi-Measurment Time Series Graph", "parameter": """
        df (pd.DataFrame): df (pd.DataFrame): DataFrame containing the data for the plot.
        date (str): Column name in `df` representing the date axis.
        measurements (List[str]): List of column names in `df` representing the multiple measurement axes.
    """},
    {"visualization_type": "OHLC Chart", "parameter": """
        df (pd.DataFrame): DataFrame containing the OHLC data.
        date (str): Column name in `df` representing the dates.
        open (str): Column name in `df` representing the opening prices.
        high (str): Column name in `df` representing the highest prices.
        low (str): Column name in `df` representing the lowest prices.
        close (str): Column name in `df` representing the closing prices.
    """},
    {"visualization_type": "Candlestick Chart", "parameter": """
        df (pd.DataFrame): DataFrame containing the data for the candlestick chart.
        date (str): Column name in `df` representing the dates.
        open (str): Column name in `df` representing the opening prices.
        high (str): Column name in `df` representing the highest prices during the period.
        low (str): Column name in `df` representing the lowest prices during the period.
        close (str): Column name in `df` representing the closing prices.
    """},
    {"visualization_type": "Moving Average Graph", "parameter": """
        df (pd.DataFrame): DataFrame containing the data for the plot. Must include 'Date' and 'Close' columns.
        date_col (str): Column name in `df` representing the date axis.
        close_col (str): Column name in `df` representing the closing price.
    """},
    {"visualization_type": "RSI Graph", "parameter": """
        df (pd.DataFrame): DataFrame containing the data for the calculation. Must include 'Date' and 'Close' columns.
        date_col (str): Column name in `df` representing the date axis.
        close_col (str): Column name in `df` representing the closing price.
    """},
    {"visualization_type": "Moving Average Convergence Divergence", "parameter": """
        df (pd.DataFrame): DataFrame containing the data for the calculation. Must include 'Date' and 'Close' columns.
        date_col (str): Column name in `df` representing the date axis.
        close_col (str): Column name in `df` representing the closing price.
    """},
    {"visualization_type": "Waterfall Chart", "parameter": """
        x_labels (list): A list of strings representing the categories or stages in the waterfall chart.
        y_values (list): A list of numerical values corresponding to each category or stage. Positive values represent increases, and negative values represent decreases.
    """},
    {"visualization_type": "Funnel Chart", "parameter": """
        df (pd.DataFrame): A dictionary with keys representing column names and values as lists of data points. 
                 It should contain data for both 'x' and 'y' axes of the funnel chart.
        x_col (str): The column from data to be used as values (quantitative axis) of the funnel.
        y_col (str): The column from data to be used as stages (qualitative axis) of the funnel.
    """}
]


In [21]:

llm = ChatOpenAI(model_name="gpt-3.5-turbo")

prompt ="""
    Given the summary of dataframe: {{summary}}
    Given description of each graph: {{graph_description}}

    {{user_prompt}}
"""

prompt = PromptTemplate.from_template(prompt, template_format="jinja2")
chain = prompt | llm | SimpleJsonOutputParser()
question_result = chain.invoke({"summary": summary, "graph_description": graph_description, "user_prompt": user_prompt})


In [22]:
print(question_result)

[{'index': 0, 'title': 'Daily Stock Prices', 'visualization_type': 'OHLC Chart', 'x_axis': ['Date'], 'y_axis': ['Open', 'High', 'Low', 'Close']}, {'index': 1, 'title': 'Moving Average', 'visualization_type': 'Moving Average Graph', 'x_axis': ['Date'], 'y_axis': ['Close']}, {'index': 2, 'title': 'Relative Strength Index (RSI)', 'visualization_type': 'RSI Graph', 'x_axis': ['Date'], 'y_axis': ['Close']}]


## Load & Split data

In [28]:
loader = GenericLoader.from_filesystem(
    "/Users/borisyu/Desktop/github/AutoFinViz/notebooks/graph",
    glob="**/*",
    suffixes=[".py"],
    parser=LanguageParser(language=Language.PYTHON, parser_threshold=500),
)
documents = loader.load()
len(documents)

1

In [29]:
python_splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
)
texts = python_splitter.split_documents(documents)
len(texts)

10

In [30]:
db = Chroma.from_documents(texts, OpenAIEmbeddings(disallowed_special=()))
retriever = db.as_retriever(
    search_type="mmr",  # Also test "similarity"
    search_kwargs={"k": 5},
)

## Visualizer

In [31]:
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

# RAG template
prompt_RAG = """
    You are a proficient python developer. Respond with the syntactically correct code for to the question below. Make sure you follow these rules:
    1. Use context to understand the APIs and how to use it & apply.
    2. Do not add license information to the output code.
    3. Do not include colab code in the output.
    4. Ensure all the requirements in the question are met.

    Question:
    {question}

    Context:
    {context}

    Helpful Response :
    """

prompt_RAG_tempate = PromptTemplate(
    template=prompt_RAG, input_variables=["context", "question"]
)

qa_chain = RetrievalQA.from_llm(
    llm=llm, prompt=prompt_RAG_tempate, retriever=retriever, return_source_documents=True
)

question= {'index': 10,
  'title': 'Stock Price Distribution',
  'visualization_type': 'Pie Chart',
  'x_axis': ['Close'],
  'y_axis': ['Volume']}


In [32]:
for viz in question_result:

    print(viz)

    question = f"""Write a python function that takes df dataframe and plot {viz['visualization_type']} in plotly library.
    USE the X_column {viz["x_axis"]}, Y_columns {viz["y_axis"]}
    SET title as "{viz['title']}".
    ONLY Return executable PYTHON code.
    NAME the function as plot(df)

    TEMPLATE:

    ```
    import ...

    def plot(df: pd.dataframe):

        return fig

    # DO NOT MODIFY ANYTHNIG BELOW
    # Dataframe 'df' is already defined so NO NEED TO CONSTRUCT DATAFRAME "df"
    fig.write_image("../example/figures/{viz['title']}.png")
    fig.show()
    ```
    THE OUTPUT SHOULD ONLY USE THE PYTHON FORMAT ABOVE.
    MUST execute the plot(df) after defining the plot().
    """

    results = qa_chain({"query": question})
    code = results["result"]
    print(code)

    try:
        fig = exec(code)
        continue
    except Exception as e:
        print(f"An error occurred: {e}")
        results = qa_chain({"query": question})
        code = results["result"]
        print(code)

    try:
        fig = exec(code)
    except Exception as e:
        print(f"An error occurred: {e}")

{'index': 0, 'title': 'Daily Stock Prices', 'visualization_type': 'OHLC Chart', 'x_axis': ['Date'], 'y_axis': ['Open', 'High', 'Low', 'Close']}
```
import pandas as pd
import plotly.graph_objects as go
from pandas.api.types import is_datetime64_any_dtype as is_datetime

def plot(df: pd.DataFrame):
    """
    Plots an OHLC (Open, High, Low, Close) chart using Plotly.

    Parameters:
    df (pd.DataFrame): DataFrame containing the OHLC data.

    Returns:
    go.Figure: Plotly figure object.

    Note:
    The function assumes that the DataFrame contains the following columns: 'Date', 'Open', 'High', 'Low', 'Close'.
    """

    # Ensure the date column is in datetime format
    if not is_datetime(df['Date']):
        df['Date'] = pd.to_datetime(df['Date'], infer_datetime_format=True, errors='coerce')

    # Create the OHLC chart
    fig = go.Figure(go.Ohlc(x=df['Date'], open=df['Open'], high=df['High'], low=df['Low'], close=df['Close'], name='Price'))

    # Set title
    fig.update_l

{'index': 1, 'title': 'Moving Average', 'visualization_type': 'Moving Average Graph', 'x_axis': ['Date'], 'y_axis': ['Close']}
import pandas as pd
import plotly.graph_objects as go

def plot(df: pd.DataFrame):

    # Ensure the date column is in datetime format
    if not is_datetime(df['Date']):
        df['Date'] = pd.to_datetime(df['Date'], infer_datetime_format=True, errors='coerce')

    # Calculate EMA and SMA
    df['EMA_9'] = df['Close'].ewm(9).mean().shift()
    df['EMA_22'] = df['Close'].ewm(22).mean().shift()
    df['SMA_5'] = df['Close'].rolling(5).mean().shift()
    df['SMA_10'] = df['Close'].rolling(10).mean().shift()
    df['SMA_15'] = df['Close'].rolling(15).mean().shift()
    df['SMA_30'] = df['Close'].rolling(30).mean().shift()

    # Create and display the plot
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df['Date'], y=df['EMA_9'], name='EMA 9'))
    fig.add_trace(go.Scatter(x=df['Date'], y=df['EMA_22'], name='EMA 22'))
    fig.add_trace(go.Scatter(x=df['Date

{'index': 2, 'title': 'Relative Strength Index (RSI)', 'visualization_type': 'RSI Graph', 'x_axis': ['Date'], 'y_axis': ['Close']}
import pandas as pd
import plotly.graph_objects as go

def plot(df: pd.DataFrame):
    """
    Plots the Relative Strength Index (RSI) using Plotly.

    Parameters:
    df (pd.DataFrame): DataFrame containing the data for the calculation. Must include 'Date' and 'Close' columns.

    Returns:
    fig (go.Figure): The RSI plot as a Plotly figure.
    """

    # Ensure the date column is in datetime format
    if not pd.api.types.is_datetime64_any_dtype(df['Date']):
        df['Date'] = pd.to_datetime(df['Date'], infer_datetime_format=True, errors='coerce')

    # RSI calculation
    delta = df['Close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14, min_periods=0).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14, min_periods=0).mean()
    rs = gain / loss
    rsi = 100 - (100 / (1 + rs))

    # Create the RSI plot
    fig = 