In [8]:
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.tools import BaseTool, Tool, tool
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage, FunctionMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt import ToolInvocation
from langchain.tools import format_tool_to_openai_function
from typing import TypedDict, Sequence
import re




import json

import os

load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")

os.environ["OPENAI_API_KEY"] = api_key

In [58]:
llm = ChatOpenAI(api_key=api_key, model="gpt-4", temperature=0, streaming=True)

# Memory saver for the workflow
memory = MemorySaver()


In [59]:
class AgentState(TypedDict):
    messages: Sequence[BaseMessage]

In [60]:
@tool("previous_cell", return_direct=True)
def previous_cell(input: str) -> str:
    """Fetch the content of the previous cell given its index. 
    This can potentially help provide more context to write the body of the function"""
    try:
        idx = int(input)
        with open("../test_data/test_book.ipynb", "r") as f:
            notebook_data = json.load(f)
            cells = notebook_data.get("cells", [])
            if 0 <= idx < len(cells) and cells[idx]["cell_type"] == "code":
                return "".join(cells[idx]["source"]).strip()
            return "No code found or invalid index."
    except Exception as e:
        return f"Error fetching previous cell: {str(e)}"

In [61]:
@tool("check_for_unwritten_function", return_direct=True)
def check_for_unwritten_function(code):
    """
    Given the code from a cell, determine if there are any function definitions whose body still 
    needs to be written in that cell, return their names if there are any.
    """
    unwritten_functions = []
    error = None

    # Split the code by newlines and analyze each line
    lines = code.split("\n")

    # Regex to match function definitions
    func_pattern = re.compile(r"def (\w+)\s?\(")

    for idx, line in enumerate(lines):
        match = func_pattern.match(line.strip())
        if match:
            func_name = match.group(1)
            # Check the next lines for the body of the function
            body_lines = lines[idx + 1:]  # Body starts from the next line after "def"
            
            # Identify the body of the function
            body = ""
            for body_line in body_lines:
                body += body_line.strip() + "\n"
            
            # Check the function body for unwritten elements: 'pass', comment-only, or empty
            if 'pass' in body.strip() or all(line.lstrip().startswith('#') for line in body_lines):
                unwritten_functions.append(func_name)

    result = {
        "unwritten_functions": unwritten_functions,
        "error": error
    }

    return json.dumps(result)


In [62]:
check_for_unwritten_function("""def convert_to_string(list): 
    pass""")

'{"unwritten_functions": ["convert_to_string"], "error": null}'

In [63]:
def test_check_for_unwritten_function():
    cases = [
        # (code input, expected unwritten functions, expect error)
        ("def convert_to_string(list):\n    pass", ["convert_to_string"], None),
        ("def empty_function():\n    # This is a placeholder", ["empty_function"], None),
        ("def valid_function():\n    # Initial comment\n    print('Hello, world!')", [], None),
        ("def func_with_pass():\n    pass\n\ndef func_with_body():\n    print('I do stuff')\n\ndef empty_func():",
         ["func_with_pass", "empty_func"], None),
        ("nums = [1, 2, 3, 4, 5]", [], None),
        ("def todo_func():\n    # TODO", ["todo_func"], None),
        ("def comment_only_func():\n    #", ["comment_only_func"], None)
    ]

    for i, (code, expected_functions, expected_error) in enumerate(cases):
        print(f"Test Case {i + 1}: Input: {code!r}")
        result = check_for_unwritten_function(code)
        result_dict = json.loads(result)

        unwritten = result_dict.get("unwritten_functions", [])
        error = result_dict.get("error")

        if unwritten == expected_functions and (expected_error is None or expected_error in (error or "")):
            print("✔ Test passed")
        else:
            print(f"✘ Test failed: Expected functions {expected_functions} and error '{expected_error}', "
                  f"but got functions {unwritten} and error '{error}'")


test_check_for_unwritten_function()

Test Case 1: Input: 'def convert_to_string(list):\n    pass'
✔ Test passed
Test Case 2: Input: 'def empty_function():\n    # This is a placeholder'
✔ Test passed
Test Case 3: Input: "def valid_function():\n    # Initial comment\n    print('Hello, world!')"
✔ Test passed
Test Case 4: Input: "def func_with_pass():\n    pass\n\ndef func_with_body():\n    print('I do stuff')\n\ndef empty_func():"
✔ Test passed
Test Case 5: Input: 'nums = [1, 2, 3, 4, 5]'
✔ Test passed
Test Case 6: Input: 'def todo_func():\n    # TODO'
✔ Test passed
Test Case 7: Input: 'def comment_only_func():\n    #'
✔ Test passed


In [64]:
@tool("find_function_calls", return_direct=True)
def find_function_calls(input: str) -> str:
    """Finds all cells in the notebook where the unwritten function is called. Returns a list of cell content."""
    try:
        func_name = input
        with open("../test_data/test_book.ipynb", "r") as f:
            notebook_data = json.load(f)
            cells = notebook_data.get("cells", [])
            return json.dumps(
                [
                    "".join(cell["source"])
                    for cell in cells
                    if cell["cell_type"] == "code" and func_name in "".join(cell["source"])
                ]
            )
    except Exception as e:
        return f"Error searching for function calls: {str(e)}"


In [65]:
tools = [previous_cell, check_for_unwritten_function, find_function_calls]
tool_executer = ToolExecutor(tools)
functions = [format_tool_to_openai_function(t) for t in tools]
model = llm.bind_functions(functions)

  tool_executer = ToolExecutor(tools)


In [66]:
def agent(state):
    messages = state['messages']
    # print("MESSAGES: ", messages)
    response = model.invoke(messages)
    return {"messages": messages + [response]}

In [67]:
def should_continue(state):
    messages = state["messages"]
    last_message = messages[-1]
    if "function_call" not in last_message.additional_kwargs:
        return "end"
    else: 
        return "continue"

In [68]:
def call_tool(state):
    messages = state["messages"]
    last_message = messages[-1]

    action = ToolInvocation(
        tool = last_message.additional_kwargs["function_call"]["name"],
        tool_input = json.loads(last_message.additional_kwargs["function_call"]["arguments"]))

    response = tool_executer.invoke(action)
    function_message = FunctionMessage(content=str(response), name = action.tool)
    return {"messages": messages + [function_message]}

In [69]:
workflow = StateGraph(AgentState)

workflow.add_node("agent", agent)
workflow.add_node("action", call_tool)

workflow.set_entry_point("agent")

workflow.add_conditional_edges("agent", should_continue, { "continue": "action", "end": END})

workflow.add_edge("action", "agent")

app = workflow.compile(checkpointer = memory)

In [72]:
def extract_code_blocks(content):
    """
    Extracts code blocks from the response content.
    """
    code_blocks = re.findall(r'```python\n(.*?)```', content, re.DOTALL)
    return "\n".join(code_blocks) if code_blocks else content


In [73]:
extract_code_blocks("""The function `convert_to_string` is called with a list of numbers as an argument. It seems like the function is expected to convert each number in the list to a string. Here is a possible implementation:

```python
def convert_to_string(lst):
    return [str(item) for item in lst]
```

This function uses a list comprehension to iterate over each item in the list, convert it to a string using the `str` function, and return a new list with the converted items.""")

'def convert_to_string(lst):\n    return [str(item) for item in lst]\n'

In [74]:
def fix_notebook_code_cells(input_path, output_path):
    with open(input_path, "r") as f:
        notebook_data = json.load(f)

    cells = notebook_data.get("cells", [])
    processed_cells = []

    for idx, cell in enumerate(cells):
        if cell["cell_type"] == "code":
            code = "".join(cell.get("source", "")).strip()
            if code:
                config = {"configurable": {"thread_id": "thread-1"}}

                system_message = SystemMessage(
                    content="You are an AI assistant for completing Python notebooks."
                )
                human_message = HumanMessage(
                    content=f"Cell {idx}: {code}\n\nCheck for unwritten functions. If there are, search for context to use to fill in the bodies of those functions. If not, return the original code cell."
                )
                
                inputs = {"messages": [system_message, human_message]}
                
                result = app.invoke(inputs, config)
                print("cell: ", idx)
                for output in app.stream(inputs, config):
                    for key, value in output.items():
                        print(f"Output from node: '{key}':")
                        print("---")
                        print(value)
                        print("\n---\n")
                
                fixed_code = result["messages"][-1].content
                fixed_code = extract_code_blocks(fixed_code)
                cell["source"] = [fixed_code]
            
            processed_cells.append(cell)
        elif cell["cell_type"] == "markdown":
            processed_cells.append(cell)

    notebook_data["cells"] = processed_cells

    with open(output_path, "w") as f:
        json.dump(notebook_data, f, indent=4)

    print(f"Updated notebook saved to {output_path}")


In [75]:
fix_notebook_code_cells("../test_data/test_book.ipynb", "../test_data/fixed_test.ipynb")

  action = ToolInvocation(
  action = ToolInvocation(


cell:  1
Output from node: 'agent':
---
{'messages': [SystemMessage(content='You are an AI assistant for completing Python notebooks.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Cell 1: def convert_to_string(list): \n    pass\n\nCheck for unwritten functions. If there are, search for context to use to fill in the bodies of those functions. If not, return the original code cell.', additional_kwargs={}, response_metadata={}), AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "code": "def convert_to_string(list): \\n    pass"\n}', 'name': 'check_for_unwritten_function'}}, response_metadata={'finish_reason': 'function_call'}, id='run-d8a61da0-84f0-4217-ad0c-ba2ad6b1fdf5-0')]}

---

Output from node: 'action':
---
{'messages': [SystemMessage(content='You are an AI assistant for completing Python notebooks.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Cell 1: def convert_to_string(list): \n    pass\n\nCheck for unwr

  action = ToolInvocation(


Output from node: 'agent':
---
{'messages': [SystemMessage(content='You are an AI assistant for completing Python notebooks.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Cell 1: def convert_to_string(list): \n    pass\n\nCheck for unwritten functions. If there are, search for context to use to fill in the bodies of those functions. If not, return the original code cell.', additional_kwargs={}, response_metadata={}), AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "code": "def convert_to_string(list): \\n    pass"\n}', 'name': 'check_for_unwritten_function'}}, response_metadata={'finish_reason': 'function_call'}, id='run-d8a61da0-84f0-4217-ad0c-ba2ad6b1fdf5-0'), FunctionMessage(content='{"unwritten_functions": ["convert_to_string"], "error": null}', additional_kwargs={}, response_metadata={}, name='check_for_unwritten_function'), AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "input": "convert_to_string

  action = ToolInvocation(


Output from node: 'agent':
---
{'messages': [SystemMessage(content='You are an AI assistant for completing Python notebooks.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Cell 1: def convert_to_string(list): \n    pass\n\nCheck for unwritten functions. If there are, search for context to use to fill in the bodies of those functions. If not, return the original code cell.', additional_kwargs={}, response_metadata={}), AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "code": "def convert_to_string(list): \\n    pass"\n}', 'name': 'check_for_unwritten_function'}}, response_metadata={'finish_reason': 'function_call'}, id='run-d8a61da0-84f0-4217-ad0c-ba2ad6b1fdf5-0'), FunctionMessage(content='{"unwritten_functions": ["convert_to_string"], "error": null}', additional_kwargs={}, response_metadata={}, name='check_for_unwritten_function'), AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "input": "convert_to_string

  action = ToolInvocation(


cell:  2
Output from node: 'agent':
---
{'messages': [SystemMessage(content='You are an AI assistant for completing Python notebooks.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Cell 2: nums = [1, 2, 3, 4, 5]\nstrings = convert_to_string(nums)\n\nCheck for unwritten functions. If there are, search for context to use to fill in the bodies of those functions. If not, return the original code cell.', additional_kwargs={}, response_metadata={}), AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "code": "nums = [1, 2, 3, 4, 5]\\nstrings = convert_to_string(nums)"\n}', 'name': 'check_for_unwritten_function'}}, response_metadata={'finish_reason': 'function_call'}, id='run-09a61ffb-50e3-4af4-a15e-c71035a6d24b-0')]}

---

Output from node: 'action':
---
{'messages': [SystemMessage(content='You are an AI assistant for completing Python notebooks.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Cell 2: nums = [1, 2, 3, 4, 5

  action = ToolInvocation(


Output from node: 'agent':
---
{'messages': [SystemMessage(content='You are an AI assistant for completing Python notebooks.', additional_kwargs={}, response_metadata={}), HumanMessage(content='Cell 2: nums = [1, 2, 3, 4, 5]\nstrings = convert_to_string(nums)\n\nCheck for unwritten functions. If there are, search for context to use to fill in the bodies of those functions. If not, return the original code cell.', additional_kwargs={}, response_metadata={}), AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "code": "nums = [1, 2, 3, 4, 5]\\nstrings = convert_to_string(nums)"\n}', 'name': 'check_for_unwritten_function'}}, response_metadata={'finish_reason': 'function_call'}, id='run-09a61ffb-50e3-4af4-a15e-c71035a6d24b-0'), FunctionMessage(content='{"unwritten_functions": [], "error": null}', additional_kwargs={}, response_metadata={}, name='check_for_unwritten_function'), AIMessage(content='There are no unwritten functions in the code cell. The original code 

# Testing and Evaluation - "leave k out cross fold validation" 
### Objective: ###
Evaluate the agent's performance by systematically modifying a pre-written notebook and measuring how well the agent restores missing function bodies.

### Process: ###

1. Input: A Jupyter notebook containing multiple fully implemented functions with a definitive expected output.
2. Loop Through Function Removals:
    * Start by systematically removing function bodies one by one and testing the agent’s ability to reconstruct them.
    * Extend this by removing combinations of multiple function bodies, increasing the difficulty progressively (remove 1 function, then all possible pairs, then groups of 3, ..., up to removing all functions).
3. For Each Modified Notebook:
    * Run the agent to attempt function reconstruction.
    * Execute the notebook to see if it still produces the expected output.
    * If the output is correct, increment the agent’s success score.
4. Performance Evaluation:
    * At each level (1 missing function, 2 missing functions, etc.), compute the average success rate (percentage of notebooks where the agent successfully restored functionality).
    * Print and analyze the aggregated success rate at each level to observe performance trends.

In [6]:
from openai import OpenAI
client = OpenAI(api_key=api_key)
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "system", "content": "Say hello!"}]
)
print(response)

NameError: name 'api_key' is not defined