# Reflection Pattern

### Multi-modal LLM approach

1. **Generate an initial version (V1):**
Use a Large Language Model (LLM) to create the first version of the plotting code.

2. **Execute code and create chart:** 
Run the generated code and display the resulting chart.

3. **Reflect on the output:**
Evaluate both the code and the chart using an LLM to detect areas for improvement (e.g., clarity, accuracy, design).

4. **Generate and execute improved version (V2):**
Produce a refined version of the plotting code based on reflection insights and render the enhanced chart.

In [1]:
import re
import json
import pandas as pd

import os
from openai import OpenAI
import openai
from dotenv import load_dotenv

load_dotenv(".env", override = True)

True

In [2]:
# Read the API key from the environment (populated by load_dotenv above) and set it
# on the openai module.
# NOTE(review): OpenAI() presumably also picks up OPENAI_API_KEY from the environment
# by itself, which would make this assignment redundant — confirm before removing.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Shared client object; generate_chart_code and image_openai_call both call
# client.responses.create on it.
client = OpenAI()

In [3]:
df = pd.read_csv("coffee_sales.csv")

In [4]:
df.head()

Unnamed: 0,date,time,cash_type,card,price,coffee_name
0,2024-03-01,06:14,card,ANON-0000-0000-0001,3.87,Latte
1,2024-03-01,11:10,card,ANON-0000-0000-0002,3.87,Hot Chocolate
2,2024-03-01,11:19,card,ANON-0000-0000-0002,3.87,Hot Chocolate
3,2024-03-01,11:37,card,ANON-0000-0000-0003,2.89,Americano
4,2024-03-01,12:56,card,ANON-0000-0000-0004,3.87,Latte


In [5]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3636 entries, 0 to 3635
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   date         3636 non-null   object 
 1   time         3636 non-null   object 
 2   cash_type    3636 non-null   object 
 3   card         3547 non-null   object 
 4   price        3636 non-null   float64
 5   coffee_name  3636 non-null   object 
dtypes: float64(1), object(5)
memory usage: 170.6+ KB


In [7]:
# Parse the date strings into datetimes; errors="coerce" turns unparseable values
# into NaT instead of raising.
df["date"] = pd.to_datetime(df["date"], errors="coerce")
# Derive calendar parts from the parsed date; the prompts sent to the LLM list
# these columns (quarter, month, year) as part of the DataFrame schema.
df["quarter"] = df["date"].dt.quarter
df["month"] = df["date"].dt.month
df["year"] = df["date"].dt.year

In [8]:
df.head()

Unnamed: 0,date,time,cash_type,card,price,coffee_name,quarter,month,year
0,2024-03-01,06:14,card,ANON-0000-0000-0001,3.87,Latte,1,3,2024
1,2024-03-01,11:10,card,ANON-0000-0000-0002,3.87,Hot Chocolate,1,3,2024
2,2024-03-01,11:19,card,ANON-0000-0000-0002,3.87,Hot Chocolate,1,3,2024
3,2024-03-01,11:37,card,ANON-0000-0000-0003,2.89,Americano,1,3,2024
4,2024-03-01,12:56,card,ANON-0000-0000-0004,3.87,Latte,1,3,2024


In [9]:
def generate_chart_code(instruction: str, model: str, out_path_v1: str) -> str:
    """Ask an LLM for first-draft matplotlib code that plots from ``df``.

    Args:
        instruction: Natural-language description of the chart to produce.
        model: Model name passed to ``client.responses.create``.
        out_path_v1: File path the generated code is told to save the figure to.

    Returns:
        The raw text of the model reply (``output_text``). The prompt asks for
        code wrapped in ``<execute_python>`` tags, but the reply is returned
        as-is without being checked.
    """
    # The request text spells out the output format, the DataFrame schema and
    # the plotting requirements; only the instruction and the output path vary.
    llm_request = f"""
    You are a data visualization expert.

    Return your answer *strictly* in this format:

    <execute_python>
    # valid python code here
    </execute_python>

    Do not add explanations, only the tags and the code.

    The code should create a visualization from a DataFrame 'df' with these columns:
    - date (M/D/YY)
    - time (HH:MM)
    - cash_type (card or cash)
    - card (string)
    - price (number)
    - coffee_name (string)
    - quarter (1-4)
    - month (1-12)
    - year (YYYY)

    User instruction: {instruction}

    
    Requirements for the code:
    1. Assume the DataFrame is already loaded as 'df'.
    2. Use matplotlib for plotting.
    3. Add clear title, axis labels, and legend if needed.
    4. Save the figure as '{out_path_v1}' with dpi=300.
    5. Do not call plt.show().
    6. Close all plots with plt.close().
    7. Add all necessary import python statements

    Return ONLY the code wrapped in <execute_python> tags.
    """
    # Relies on the module-level OpenAI `client`.
    return client.responses.create(model=model, input=llm_request).output_text

In [10]:
# Ask the model for the first-draft plotting code. The result is the raw model
# reply (tags included); the figure path given here is only written into the prompt.
code_v1 = generate_chart_code(
    instruction="Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv.", 
    model="gpt-4o-mini", 
    out_path_v1="chart_v1.png"
)


In [11]:
print(code_v1)

<execute_python>
import pandas as pd
import matplotlib.pyplot as plt

# Filter data for Q1 coffee sales in 2024 and 2025
q1_sales = df[(df['quarter'] == 1) & (df['year'].isin([2024, 2025]))]

# Group by year and coffee_name, summing the prices
grouped_sales = q1_sales.groupby(['year', 'coffee_name'])['price'].sum().unstack()

# Plotting
plt.figure(figsize=(10, 6))
grouped_sales.plot(kind='bar', edgecolor='black')

plt.title('Q1 Coffee Sales Comparison (2024 vs 2025)')
plt.xlabel('Coffee Type')
plt.ylabel('Total Sales ($)')
plt.xticks(rotation=45)
plt.legend(title='Year')
plt.tight_layout()

# Save the figure
plt.savefig('chart_v1.png', dpi=300)
plt.close()
</execute_python>


In [14]:
# Capture everything between the tags; [\s\S] matches newlines too, and the
# non-greedy quantifier stops at the first closing tag.
pattern = r"<execute_python>([\s\S]*?)</execute_python>"

# `match` is None when the model reply does not contain a tagged block.
match = re.search(pattern, code_v1)

In [22]:
# Pull the generated code out of the tag wrapper. Fail here with a clear message
# when the model ignored the output format; otherwise `initial_code` would stay
# unbound and a later cell would fail with an unrelated NameError.
if match is None:
    raise ValueError(
        "No <execute_python>...</execute_python> block found in the model output."
    )
initial_code = match.group(1).strip()

In [23]:
print(initial_code)

import pandas as pd
import matplotlib.pyplot as plt

# Filter data for Q1 coffee sales in 2024 and 2025
q1_sales = df[(df['quarter'] == 1) & (df['year'].isin([2024, 2025]))]

# Group by year and coffee_name, summing the prices
grouped_sales = q1_sales.groupby(['year', 'coffee_name'])['price'].sum().unstack()

# Plotting
plt.figure(figsize=(10, 6))
grouped_sales.plot(kind='bar', edgecolor='black')

plt.title('Q1 Coffee Sales Comparison (2024 vs 2025)')
plt.xlabel('Coffee Type')
plt.ylabel('Total Sales ($)')
plt.xticks(rotation=45)
plt.legend(title='Year')
plt.tight_layout()

# Save the figure
plt.savefig('chart_v1.png', dpi=300)
plt.close()


In [24]:
exec_globals = {"df": df}

In [28]:
import matplotlib.pyplot as plt

In [29]:
# NOTE(security): this runs model-generated code with the full privileges of the
# notebook process; review `initial_code` before executing it.
# `exec_globals` is the namespace the code runs in; it was seeded with `df` above.
exec(initial_code, exec_globals)

<Figure size 1000x600 with 0 Axes>

In [35]:
import base64
import mimetypes
from pathlib import Path

In [34]:
def encode_image_b64(path: str) -> tuple[str, str]:
    """Read an image file and return it as ``(media_type, base64_text)``.

    The media type is guessed from the file name; when no type can be guessed
    it falls back to ``"image/png"``. The file content is base64-encoded and
    returned as a UTF-8 decoded string.
    """
    guessed_type = mimetypes.guess_type(path)[0]
    with open(path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_text = base64.b64encode(raw_bytes).decode("utf-8")
    return (guessed_type if guessed_type else "image/png"), encoded_text

In [55]:
media_type, b64 = encode_image_b64("chart_v1.png")

In [61]:
def image_openai_call(model_name: str, prompt: str, media_type: str, b64: str) -> str:
    """Send a text prompt plus one inline image to an OpenAI model.

    The image is passed as a ``data:`` URL built from ``media_type`` and the
    base64 payload ``b64``. Returns the reply's ``output_text`` with
    surrounding whitespace removed, or an empty string when it is empty/None.
    """
    user_message = {
        "role": "user",
        "content": [
            {"type": "input_text", "text": prompt},
            {
                "type": "input_image",
                "image_url": f"data:{media_type};base64,{b64}",
            },
        ],
    }
    # Relies on the module-level OpenAI `client`.
    reply = client.responses.create(model=model_name, input=[user_message])
    reply_text = reply.output_text
    return reply_text.strip() if reply_text else ""

In [70]:
def image_anthropic_call(model_name: str, prompt: str, media_type: str, b64: str) -> str:
    """
    Call Anthropic Claude (messages.create) with text+image and return *all* text blocks concatenated.
    Adds a system message to enforce strict JSON output.

    Args:
        model_name: Claude model name passed to ``messages.create``.
        prompt: Text part of the user message.
        media_type: MIME type of the image (e.g. ``"image/png"``).
        b64: Base64-encoded image bytes.

    Returns:
        The text of every ``text`` content block in the reply, joined together
        and stripped of surrounding whitespace ("" when there is none).
    """
    # NOTE(review): `anthropic_client` is not defined in the visible cells of this
    # notebook; it must exist at module level before this function is called.
    # NOTE(review): the system message demands a single JSON object, while the
    # reflection prompt asks for a JSON line followed by <execute_python> tags —
    # confirm that Claude replies still contain the tagged code block.
    msg = anthropic_client.messages.create(
        model=model_name,
        max_tokens=2000,
        temperature=0,
        system=(
            "You are a careful assistant. Respond with a single valid JSON object only. "
            "Do not include markdown, code fences, or commentary outside JSON."
        ),
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "source": {"type": "base64", "media_type": media_type, "data": b64}},
            ],
        }],
    )

    # The reply content is a list of blocks; keep only the text ones. Without this
    # return the function yielded None and the caller crashed on `content.strip()`.
    text_parts = [
        block.text
        for block in (msg.content or [])
        if getattr(block, "type", None) == "text"
    ]
    return "".join(text_parts).strip()


In [64]:
def ensure_execute_python_tags(text: str) -> str:
    """Return ``text`` wrapped in ``<execute_python>`` tags unless already tagged.

    Surrounding whitespace is trimmed, a leading ``` / ```python fence and a
    trailing ``` fence are removed, and the result is wrapped in the tags only
    when it does not already contain an opening ``<execute_python>`` tag.
    """
    fence_pattern = r"^```(?:python)?\s*|\s*```$"
    cleaned = re.sub(fence_pattern, "", text.strip()).strip()
    if "<execute_python>" in cleaned:
        return cleaned
    return "<execute_python>\n" + cleaned + "\n</execute_python>"

In [65]:
def _parse_feedback(content: str) -> str:
    """Extract the "feedback" text from a model reply, or a failure message."""
    lines = content.strip().splitlines()
    first_line = lines[0].strip() if lines else ""

    try:
        obj = json.loads(first_line)
        # Valid JSON that is not an object (e.g. a bare string or number) has no
        # "feedback" field; treat it like a parse failure instead of crashing.
        if not isinstance(obj, dict):
            raise ValueError("first line is not a JSON object")
    except ValueError as first_error:
        # Fallback: decode the first JSON object found anywhere in the reply.
        # raw_decode stops at the end of that object, so braces inside the
        # feedback text or trailing code do not break the parse.
        start = content.find("{")
        if start == -1:
            return f"Failed to find JSON: {first_error}"
        try:
            obj, _ = json.JSONDecoder().raw_decode(content[start:])
        except ValueError as second_error:
            return f"Failed to parse JSON: {second_error}"

    return str(obj.get("feedback", "")).strip()


def reflect_on_image_and_regenerate(
    chart_path: str,
    instruction: str,
    model_name: str,
    out_path_v2: str,
    code_v1: str,
) -> tuple[str, str]:
    """
    Critique the chart IMAGE and the original code against the instruction,
    then return refined matplotlib code.

    Args:
        chart_path: Path of the rendered V1 chart image to attach to the request.
        instruction: The user's original plotting instruction.
        model_name: Model to call; names containing "claude" or "anthropic"
            are routed to the Anthropic helper, everything else to OpenAI.
        out_path_v2: File path the refined code is told to save the figure to.
        code_v1: The original code, included in the prompt for context.

    Returns:
        (feedback, refined_code_with_tags). ``feedback`` is a failure message
        when no JSON could be parsed; the refined code is always wrapped in
        <execute_python> tags and has an empty body when the reply had no
        tagged block.
    """
    media_type, b64 = encode_image_b64(chart_path)

    prompt = f"""
    You are a data visualization expert.
    Your task: critique the attached chart and the original code against the given instruction,
    then return improved matplotlib code.

    Original code (for context):
    {code_v1}

    OUTPUT FORMAT (STRICT!):
    1) First line: a valid JSON object with ONLY the "feedback" field.
    Example: {{"feedback": "The legend is unclear and the axis labels overlap."}}

    2) After a newline, output ONLY the refined Python code wrapped in:
    <execute_python>
    ...
    </execute_python>

    3) Import all necessary libraries in the code. Don't assume any imports from the original code.

    HARD CONSTRAINTS:
    - Do NOT include Markdown, backticks, or any extra prose outside the two parts above.
    - Use pandas/matplotlib only (no seaborn).
    - Assume df already exists; do not read from files.
    - Save to '{out_path_v2}' with dpi=300.
    - Always call plt.close() at the end (no plt.show()).
    - Include all necessary import statements.

    Schema (columns available in df):
    - date (M/D/YY)
    - time (HH:MM)
    - cash_type (card or cash)
    - card (string)
    - price (number)
    - coffee_name (string)
    - quarter (1-4)
    - month (1-12)
    - year (YYYY)

    Instruction:
    {instruction}
    """

    # Route by model name: Claude/Anthropic models go through the Anthropic
    # helper (which adds a system prompt); everything else goes to OpenAI.
    lower = model_name.lower()
    if "claude" in lower or "anthropic" in lower:
        content = image_anthropic_call(model_name, prompt, media_type, b64)
    else:
        content = image_openai_call(model_name, prompt, media_type, b64)

    # --- Parse the feedback JSON (first line, with a whole-reply fallback) ---
    feedback = _parse_feedback(content)

    # --- Extract refined code from <execute_python>...</execute_python> ---
    m_code = re.search(r"<execute_python>([\s\S]*?)</execute_python>", content)
    refined_code_body = m_code.group(1).strip() if m_code else ""
    # Re-wrap so callers can always extract the code with the same tag regex.
    refined_code = ensure_execute_python_tags(refined_code_body)

    return feedback, refined_code

In [66]:
# Generate feedback alongside reflected code
feedback, code_v2 = reflect_on_image_and_regenerate(
    chart_path="chart_v1.png",            
    instruction="Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv.", 
    model_name="o4-mini",
    out_path_v2="chart_v2.png",
    code_v1=code_v1,   # pass in the original code for context        
)

# utils.print_html(feedback, title="Feedback on V1 Chart")
# utils.print_html(code_v2, title="Regenerated Code Output (V2)")

In [67]:
feedback

'The legend title is misleading and the grouping is inverted—the bars are grouped by year rather than by coffee type, the x-axis label is incorrect, and the y-axis lacks currency formatting. Pivot the data so coffee types are on the x-axis, use a correct legend title, add a dollar formatter, and improve layout and grid.'

In [69]:
print(code_v2)

<execute_python>
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# Filter data for Q1 coffee sales in 2024 and 2025
q1_sales = df[(df['quarter'] == 1) & (df['year'].isin([2024, 2025]))]

# Pivot so coffee types are the index and years are columns
sales_pivot = (
    q1_sales
    .groupby(['coffee_name', 'year'])['price']
    .sum()
    .unstack(fill_value=0)
    .sort_index()
)

# Plotting
fig, ax = plt.subplots(figsize=(12, 6))
sales_pivot.plot(kind='bar', ax=ax, edgecolor='black')

# Formatting
ax.set_title('Q1 Coffee Sales Comparison (2024 vs 2025)', fontsize=14)
ax.set_xlabel('Coffee Type', fontsize=12)
ax.set_ylabel('Total Sales ($)', fontsize=12)
ax.legend(title='Year')
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'${x:,.0f}'))
plt.xticks(rotation=45, ha='right')
ax.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()
plt.savefig('chart_v2.png', dpi=300)
plt.close()
</execute_python>


In [71]:
# Get the code within the <execute_python> tags
match = re.search(r"<execute_python>([\s\S]*?)</execute_python>", code_v2)
# Nothing is executed when the tags are missing; code_v2 comes from
# reflect_on_image_and_regenerate, which wraps its result in the tags.
if match:
    reflected_code = match.group(1).strip()
    # Fresh namespace holding only `df`, so nothing defined by the V1 run leaks in.
    exec_globals = {"df": df}
    # NOTE(security): runs model-generated code with the notebook's privileges.
    exec(reflected_code, exec_globals)


In [72]:
def _execute_tagged_code(llm_output: str, df: pd.DataFrame, label: str) -> None:
    """Run the code inside <execute_python> tags with only `df` in its globals."""
    match = re.search(r"<execute_python>([\s\S]*?)</execute_python>", llm_output)
    if match is None:
        # Skipping silently would let later steps work on a missing or stale
        # chart file, so surface the format violation right away.
        raise ValueError(
            f"No <execute_python>...</execute_python> block found in the {label} model output."
        )
    # NOTE(security): this executes model-generated code with the full privileges
    # of the current process; only run it on output you are willing to trust.
    exec(match.group(1).strip(), {"df": df})


def run_workflow(
    dataset_path: str,
    user_instructions: str,
    generation_model: str,
    reflection_model: str,
    image_basename: str = "chart",
) -> dict[str, str]:
    """
    End-to-end pipeline:
      1) load dataset
      2) generate V1 code
      3) execute V1 → produce <image_basename>_v1.png
      4) reflect on V1 (image + original code) → feedback + refined code
      5) execute V2 → produce <image_basename>_v2.png

    Args:
        dataset_path: CSV file to load; it must contain a "date" column.
        user_instructions: Natural-language description of the desired chart.
        generation_model: Model used to write the first-draft code.
        reflection_model: Model used to critique the chart and rewrite the code.
        image_basename: Prefix for the two chart file names.

    Returns a dict with all artifacts (codes, feedback, image paths).

    Raises:
        ValueError: If a model reply contains no <execute_python> block to run.
    """
    # 0) Load dataset and derive the calendar columns listed in the prompts
    df = pd.read_csv(dataset_path)
    df["date"] = pd.to_datetime(df["date"], errors="coerce")
    df["quarter"] = df["date"].dt.quarter
    df["month"] = df["date"].dt.month
    df["year"] = df["date"].dt.year

    # Paths to store charts
    out_v1 = f"{image_basename}_v1.png"
    out_v2 = f"{image_basename}_v2.png"

    # 1) Generate code (V1)
    print("Step 1: Generating chart code (V1)… 📈")
    code_v1 = generate_chart_code(
        instruction=user_instructions,
        model=generation_model,
        out_path_v1=out_v1,
    )

    print("LLM output with first draft code (V1)")
    print()
    print(code_v1)
    print()

    # 2) Execute V1 (extract the <execute_python> block and run it immediately)
    print("Step 2: Executing chart code (V1)… 💻")
    _execute_tagged_code(code_v1, df, "V1")

    # 3) Reflect on V1 (image + original code) to get feedback and refined code (V2)
    print("Step 3: Reflecting on V1 (image + code) and generating improvements… 🔁")

    feedback, code_v2 = reflect_on_image_and_regenerate(
        chart_path=out_v1,
        instruction=user_instructions,
        model_name=reflection_model,
        out_path_v2=out_v2,
        code_v1=code_v1,  # pass original code for context
    )
    print("Reflection feedback on V1")
    print(feedback)
    print()
    print("LLM output with revised code (V2)")
    print(code_v2)
    print()

    # 4) Execute V2 (extract the <execute_python> block and run it immediately)
    print("Step 4: Executing refined chart code (V2)… 🖼️")
    _execute_tagged_code(code_v2, df, "V2")

    return {
        "code_v1": code_v1,
        "chart_v1": out_v1,
        "feedback": feedback,
        "code_v2": code_v2,
        "chart_v2": out_v2,
    }

In [73]:
# Inputs for one end-to-end run: the first model drafts the chart code, the
# second critiques the rendered chart and rewrites the code.
user_instructions="Create a plot comparing Q1 coffee sales in 2024 and 2025 using the data in coffee_sales.csv." # write your instruction here
generation_model="gpt-4.1-mini"
reflection_model="o4-mini"
# Prefix for the output files: run_workflow saves <basename>_v1.png and <basename>_v2.png.
image_basename="drink_sales"

# Run the complete agentic workflow
# (the returned artifact dict is discarded; progress is printed by run_workflow)
_ = run_workflow(
    dataset_path="coffee_sales.csv",
    user_instructions=user_instructions,
    generation_model=generation_model,
    reflection_model=reflection_model,
    image_basename=image_basename
)

Step 1: Generating chart code (V1)… 📈
LLM output with first draft code (V1)

<execute_python>
import matplotlib.pyplot as plt
import pandas as pd

# Filter for Q1 sales in 2024 and 2025
df_q1 = df[(df['quarter'] == 1) & (df['year'].isin([2024, 2025]))]

# Aggregate sales by coffee_name and year
sales_summary = df_q1.groupby(['year', 'coffee_name'])['price'].sum().unstack()

# Plotting
fig, ax = plt.subplots(figsize=(10, 6))
sales_summary.plot(kind='bar', ax=ax)

ax.set_title('Q1 Coffee Sales Comparison: 2024 vs 2025')
ax.set_xlabel('Coffee Name')
ax.set_ylabel('Total Sales ($)')
ax.legend(title='Year')

plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('drink_sales_v1.png', dpi=300)
plt.close()
</execute_python>

Step 2: Executing chart code (V1)… 💻
Step 3: Reflecting on V1 (image + code) and generating improvements… 🔁
Reflection feedback on V1
The current chart uses years as the x-axis but labels it ‚ÄòCoffee Name‚Äô and shows coffee types in a legend t