In [32]:
import re
import json
import base64
import csv
import requests
import pandas as pd

### Create some utility functions

In [33]:
# Load a CSV file as a list of rows (each row is a list of strings).
def load_and_prepare_csv_data(data_path):
    """
    Read a CSV file with the standard-library csv module.

    Args:
        data_path: Path to the CSV file.

    Returns:
        list[list[str]]: Every row of the file, header row included, with all
        values left as strings (no type conversion is done). An empty file
        yields an empty list.
    """
    # newline='' is required by the csv module: it lets the reader handle
    # quoted fields that contain line breaks (and \r\n endings) itself instead
    # of having them rewritten by universal-newline translation.
    with open(data_path, 'r', newline='') as f:
        return list(csv.reader(f))

# Print a dataset: a DataFrame as a single text table, anything else item by item.
def display_data(data):
    """Print `data` to stdout.

    A pandas DataFrame is rendered with ``DataFrame.to_string()``. Any other
    iterable (for example a list of CSV rows) is printed one element per line
    using ``print``.
    """
    if not isinstance(data, pd.DataFrame):
        for record in data:
            print(record)
        return
    print(data.to_string())

# Get an LLM response from an Ollama-hosted model via its OpenAI-compatible
# /v1/chat/completions endpoint.
def get_llm_response(
    prompt,
    model="mistral:7b-instruct",
    url="http://localhost:11434/v1/chat/completions",
    timeout=600,
):
    """
    Send a single-turn, non-streaming chat request and return the reply text.

    Args:
        prompt: Text sent as the single user message.
        model: Model tag to request.
        url: Chat-completions endpoint to POST to.
        timeout: Seconds ``requests`` waits for the server before raising;
            pass None to wait indefinitely.

    Returns:
        str: ``choices[0].message.content`` from the JSON response.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within `timeout`.
    """
    headers = {"Content-Type": "application/json"}
    payload = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "stream": False,
    }

    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    # Surface HTTP errors directly instead of an opaque KeyError on 'choices'.
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']


def encode_image_b64(chart_path: str) -> tuple[str, str]:
    """
    Encode an image file to a base64 string.

    The media type is guessed from the file extension (e.g. ".jpg" ->
    "image/jpeg"); if it cannot be guessed or is not an image type, it
    falls back to "image/png".

    Args:
        chart_path (str): The path to the image file to encode.

    Returns:
        tuple[str, str]: A tuple containing the media type and the base64 encoded string.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes

    media_type, _encoding = mimetypes.guess_type(chart_path)
    if media_type is None or not media_type.startswith("image/"):
        media_type = "image/png"
    with open(chart_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")
    return media_type, encoded
        

def ensure_execute_python_tags(refined_code_body: str) -> str:
    """
    Ensure the code is wrapped in <execute_python> tags.

    The body is placed between an opening and a closing tag, each on its own
    line. Input that is already wrapped (ignoring surrounding whitespace) is
    returned unchanged, so calling this repeatedly never nests the tags.
    """
    stripped = refined_code_body.strip()
    if stripped.startswith("<execute_python>") and stripped.endswith("</execute_python>"):
        return refined_code_body
    return f"<execute_python>\n{refined_code_body}\n</execute_python>"


### load the data and display it

In [34]:
# Read the sales CSV into a DataFrame called `df` -- the LLM-generated plotting
# code executed later in the notebook refers to this exact name.
df = pd.read_csv("coffee_sales.csv")
display_data(df)

          date   time cash_type       card  price coffee_name  quarter  month  year
0   01/15/2024  09:00      card  VISA-1234    620    Espresso        1      1  2024
1   01/15/2024  09:05      card  VISA-1234    480   Americano        1      1  2024
2   01/15/2024  09:10      card  VISA-1234    530  Cappuccino        1      1  2024
3   01/15/2024  09:15      card  VISA-1234    700       Latte        1      1  2024
4   01/15/2024  09:20      card  VISA-1234    410       Mocha        1      1  2024
5   01/15/2024  09:25      card  VISA-1234    360  Flat White        1      1  2024
6   01/15/2024  09:30      card  VISA-1234    290   Cold Brew        1      1  2024
7   01/15/2025  09:00      card  VISA-5678    820    Espresso        1      1  2025
8   01/15/2025  09:05      card  VISA-5678    670   Americano        1      1  2025
9   01/15/2025  09:10      card  VISA-5678    760  Cappuccino        1      1  2025
10  01/15/2025  09:15      card  VISA-5678    980       Latte        1      

### Create a prompt that combines the data schema with the user instruction


In [38]:
# method to generate code to create a bar chart using ollama hosted llm
def generate_bar_chart_code(instruction: str, out_path_v1: str):
    """
    Ask the LLM for a matplotlib script that implements `instruction`.

    The prompt describes the columns of the DataFrame `df` the script will be
    run against and asks for code only, wrapped in <execute_python> tags, that
    saves the figure to `out_path_v1` with dpi=300.

    Args:
        instruction: The user's plotting request.
        out_path_v1: File path the generated code should save the chart to.

    Returns:
        str: The raw LLM reply (expected to contain the code inside
        <execute_python> tags).
    """
    # The schema mirrors coffee_sales.csv; its 'date' values are MM/DD/YYYY
    # strings (e.g. 01/15/2024), so the prompt must not claim a 2-digit year.
    prompt = f"""
    You are a data visualization expert.

    Return your answer *strictly* in this format:

    <execute_python>
    # valid python code here
    </execute_python>

    Do not add explanations, only the tags and the code.

    The code should create a visualization from a DataFrame 'df' with these columns:
    - date (MM/DD/YYYY)
    - time (HH:MM)
    - cash_type (card or cash)
    - card (string)
    - price (number)
    - coffee_name (string)
    - quarter (1-4)
    - month (1-12)
    - year (YYYY)

    User instruction: {instruction}

    Requirements for the code:
    1. Assume the DataFrame is already loaded as 'df'.
    2. Use matplotlib for plotting.
    3. Add clear title, axis labels, and legend if needed.
    4. Save the figure as '{out_path_v1}' with dpi=300.
    5. Do not call plt.show().
    6. Close all plots with plt.close() without any arguments.
    7. Add all necessary import python statements

    Return ONLY the code wrapped in <execute_python> tags.
    """

    return get_llm_response(prompt)



### Generate the LLM response containing the bar-chart code

In [39]:
# Ask the LLM for the first-draft (V1) plotting code for the Q1 2024 vs 2025 comparison.
instruction = "Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv."
out_path_v1 = "q1_sales_2024_2025.png"
code_v1 = generate_bar_chart_code(instruction=instruction, out_path_v1=out_path_v1)
print(code_v1)

 <execute_python>
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the DataFrame 'df' assuming it is already loaded

# Filter data for Q1 from 2024 and 2025, group by coffee_name and sum price
q1_sales = df[(df['quarter'] == 1) & (df['year'].isin([2024, 2025]))].groupby('coffee_name')['price'].sum()

# Create bar plot with clear title, axis labels, and legend
fig, ax = plt.subplots(figsize=(10, 6))
q1_sales.plot(kind='bar', rot=0, alpha=0.8)
ax.set_title('Q1 Coffee Sales Comparison, 2024 & 2025')
ax.set_xlabel('Coffee Name')
ax.set_ylabel('Total Sales')
ax.legend()

# Save the figure and close all plots
plt.savefig('q1_sales_2024_2025.png', dpi=300)
plt.close(fig)
</execute_python>


### Extract the code from the LLM-generated content and execute it

In [40]:
# Extract the code from the LLM response above and execute it.
# SECURITY: exec() runs model-generated code with this kernel's full
# privileges. Only do this with a trusted model; the code is printed first so
# it can be inspected.
match = re.search(r'<execute_python>(.*?)</execute_python>', code_v1, re.DOTALL)
if match:
    code = match.group(1).strip()
    print(code)
    # Hand the generated code its own copy of df so any in-place changes it
    # makes (e.g. converting 'date' to datetime) cannot leak into later cells.
    exec_globals = {"df": df.copy(), "__builtins__": __builtins__}
    exec(code, exec_globals)
else:
    print("No code found in the response")


import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the DataFrame 'df' assuming it is already loaded

# Filter data for Q1 from 2024 and 2025, group by coffee_name and sum price
q1_sales = df[(df['quarter'] == 1) & (df['year'].isin([2024, 2025]))].groupby('coffee_name')['price'].sum()

# Create bar plot with clear title, axis labels, and legend
fig, ax = plt.subplots(figsize=(10, 6))
q1_sales.plot(kind='bar', rot=0, alpha=0.8)
ax.set_title('Q1 Coffee Sales Comparison, 2024 & 2025')
ax.set_xlabel('Coffee Name')
ax.set_ylabel('Total Sales')
ax.legend()

# Save the figure and close all plots
plt.savefig('q1_sales_2024_2025.png', dpi=300)
plt.close(fig)


### now we reflect on the output that we generated 

In [41]:
def reflect_on_image_and_regenerate(
    chart_path: str,
    instruction: str,
    out_path_v2: str,
    code_v1: str,
) -> tuple[str, str]:
    """
    Critique the V1 chart and the original code against the instruction,
    then return refined matplotlib code.

    Args:
        chart_path: Path of the V1 chart image. It must exist (it is read and
            base64-encoded), but the image is not currently sent to the model;
            see the NOTE in the body.
        instruction: The user's plotting instruction.
        out_path_v2: File path the refined code is told to save its figure to.
        code_v1: The original generated code, given to the model as context.

    Returns:
        tuple[str, str]: (feedback, refined_code_with_tags). `feedback` is the
        "feedback" value from the JSON object in the reply, or a
        "Failed to ..." message if none could be parsed. `refined_code_with_tags`
        is the code found between <execute_python> tags, re-wrapped in those
        tags (with an empty body if the reply contained no code).
    """
    # NOTE(review): the chart is read and encoded (so a missing file still
    # raises), but it is never sent -- get_llm_response() posts a text-only
    # message, so the model critiques the code alone. Attaching the image would
    # need a vision-capable model and a multimodal message payload.
    _media_type, _chart_b64 = encode_image_b64(chart_path)

    prompt = f"""
    You are a data visualization expert.
    Your task: critique the attached chart and the original code against the given instruction,
    then return improved matplotlib code.

    Original code (for context):
    {code_v1}

    OUTPUT FORMAT (STRICT):
    1) First line: a valid JSON object with ONLY the "feedback" field.
    Example: {{"feedback": "The legend is unclear and the axis labels overlap."}}

    2) After a newline, output ONLY the refined Python code wrapped in:
    <execute_python>
    ...
    </execute_python>

    3) Import all necessary libraries in the code. Don't assume any imports from the original code.

    HARD CONSTRAINTS:
    - Do NOT include Markdown, backticks, or any extra prose outside the two parts above.
    - Use pandas/matplotlib only (no seaborn).
    - Assume df already exists; do not read from files.
    - Save to '{out_path_v2}' with dpi=300.
    - Always call plt.close() at the end (no plt.show()).
    - Include all necessary import statements.

    Schema (columns available in df):
    - date (MM/DD/YYYY)
    - time (HH:MM)
    - cash_type (card or cash)
    - card (string)
    - price (number)
    - coffee_name (string)
    - quarter (1-4)
    - month (1-12)
    - year (YYYY)

    Instruction:
    {instruction}
    """

    content = get_llm_response(prompt)

    feedback = _extract_feedback(content)
    refined_code = ensure_execute_python_tags(_extract_code_body(content))
    return feedback, refined_code


def _extract_feedback(content: str) -> str:
    """Return the "feedback" value of the first JSON object in `content` that has one.

    The reply mixes a JSON line with code, so it cannot be parsed as a whole.
    Instead, try to decode a complete JSON value starting at each '{' in turn.
    Unlike a non-greedy regex, this copes with braces inside the feedback text.
    """
    decoder = json.JSONDecoder()
    parse_error = None
    for brace in re.finditer(r"\{", content):
        try:
            obj, _end = decoder.raw_decode(content[brace.start():])
        except json.JSONDecodeError as exc:
            parse_error = parse_error or exc
            continue
        if isinstance(obj, dict) and "feedback" in obj:
            return str(obj["feedback"]).strip()
    if parse_error is not None:
        return f"Failed to parse JSON: {parse_error}"
    return "Failed to find JSON feedback in the model response"


def _extract_code_body(content: str) -> str:
    """Return the stripped code between the first <execute_python> tag pair, or "" if absent."""
    m_code = re.search(r"<execute_python>([\s\S]*?)</execute_python>", content)
    if not m_code:
        return ""
    body = m_code.group(1).strip()
    # Models sometimes add Markdown fences inside the tags despite the prompt.
    fenced = re.fullmatch(r"```(?:python|py)?\s*\n([\s\S]*?)\n?```", body)
    if fenced:
        body = fenced.group(1).strip()
    return body



In [42]:
# Critique the V1 chart/code and regenerate (V2). Reuse `instruction` and
# `out_path_v1` from the V1 cell instead of re-typing them, so both steps are
# guaranteed to refer to the same instruction and chart file.
out_path_v2 = "q1_sales_2024_2025_v2.png"
feedback, code_v2 = reflect_on_image_and_regenerate(
    chart_path=out_path_v1,
    instruction=instruction,
    out_path_v2=out_path_v2,
    code_v1=code_v1,  # pass in the original code for context
)

print("Feedback on V1 Chart\n\n", feedback)
print("Regenerated Code Output (V2)\n\n", code_v2)

Feedback on V1 Chart

 The original code does not filter out unnecessary columns, and it lacks error checking for invalid dates or missing values. Additionally, the x-axis label 'Coffee Name' might overlap with the bars if all coffee names are included due to limited space, leading to readability issues. Lastly, grouping by the 'coffee_name' without sorting may cause bars to appear in an arbitrary order making comparison difficult.
Regenerated Code Output (V2)

 <execute_python>
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.dates import DateFormatter

# Filter data for Q1 from 2024 and 2025, group by coffee_name, sum price, sort and handle missing values
filtered_df = df[(df['quarter'] == 1) & (df['year'].isin([2024, 2025]))]
filtered_q1_sales = filtered_df.pivot(index='coffee_name', columns='date', values='price').sum()
filtered_q1_sales[np.isnan(filtered_q1_sales)] = 0

# Create bar plot with clear title, axis labels, and legend, sorting the c

In [43]:
# Get the code within the <execute_python> tags and run the V2 script.
# SECURITY: exec() runs model-generated code with this kernel's full
# privileges. Only do this with a trusted model.
match = re.search(r"<execute_python>([\s\S]*?)</execute_python>", code_v2)
reflected_code = match.group(1).strip() if match else ""
if reflected_code:
    # Run on a copy of df so the generated code cannot mutate the shared frame.
    exec_globals = {"df": df.copy(), "__builtins__": __builtins__}
    exec(reflected_code, exec_globals)
else:
    # Covers both a missing tag pair and an empty body between the tags.
    print("No code found in the response")

