In [122]:
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.tools import BaseTool, Tool, tool
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage, FunctionMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt import ToolInvocation
from langchain.tools import format_tool_to_openai_function
from typing import TypedDict, Sequence




import json

import os

load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")

os.environ["OPENAI_API_KEY"] = api_key

In [123]:
llm = ChatOpenAI(api_key=api_key, model="gpt-4", temperature=0, streaming=True)

# Memory saver for the workflow
memory = MemorySaver()


In [124]:
class AgentState(TypedDict):
    messages: Sequence[BaseMessage]

In [125]:
@tool("previous_cell", return_direct=True)
def previous_cell(input: str) -> str:
    """Fetch the content of the previous cell given its index."""
    try:
        idx = int(input)
        with open("../test_data/test_book.ipynb", "r") as f:
            notebook_data = json.load(f)
            cells = notebook_data.get("cells", [])
            if 0 <= idx < len(cells) and cells[idx]["cell_type"] == "code":
                return "".join(cells[idx]["source"]).strip()
            return "No code found or invalid index."
    except Exception as e:
        return f"Error fetching previous cell: {str(e)}"

In [126]:
@tool("check_for_unwritten_function", return_direct=True)
def check_for_unwritten_function(code):
    """
    Given the code from a cell, determine if there are any function definitions whose body still 
    needs to be written in that cell, return their names if there are any.
    """
    unwritten_functions = []
    error = None

    # Split the code by newlines and analyze each line
    lines = code.split("\n")

    # Regex to match function definitions
    func_pattern = re.compile(r"def (\w+)\s?\(")

    for idx, line in enumerate(lines):
        match = func_pattern.match(line.strip())
        if match:
            func_name = match.group(1)
            # Check the next lines for the body of the function
            body_lines = lines[idx + 1:]  # Body starts from the next line after "def"
            
            # Identify the body of the function
            body = ""
            for body_line in body_lines:
                body += body_line.strip() + "\n"
            
            # Check the function body for unwritten elements: 'pass', comment-only, or empty
            if 'pass' in body.strip() or all(line.lstrip().startswith('#') for line in body_lines):
                unwritten_functions.append(func_name)

    result = {
        "unwritten_functions": unwritten_functions,
        "error": error
    }

    return json.dumps(result)


In [127]:
check_for_unwritten_function("""def convert_to_string(list): 
    pass""")

'{"unwritten_functions": ["convert_to_string"], "error": null}'

In [128]:
def test_check_for_unwritten_function():
    cases = [
        # (code input, expected unwritten functions, expect error)
        ("def convert_to_string(list):\n    pass", ["convert_to_string"], None),
        ("def empty_function():\n    # This is a placeholder", ["empty_function"], None),
        ("def valid_function():\n    # Initial comment\n    print('Hello, world!')", [], None),
        ("def func_with_pass():\n    pass\n\ndef func_with_body():\n    print('I do stuff')\n\ndef empty_func():",
         ["func_with_pass", "empty_func"], None),
        ("nums = [1, 2, 3, 4, 5]", [], None),
        ("def todo_func():\n    # TODO", ["todo_func"], None),
        ("def comment_only_func():\n    #", ["comment_only_func"], None)
    ]

    for i, (code, expected_functions, expected_error) in enumerate(cases):
        print(f"Test Case {i + 1}: Input: {code!r}")
        result = check_for_unwritten_function(code)
        result_dict = json.loads(result)

        unwritten = result_dict.get("unwritten_functions", [])
        error = result_dict.get("error")

        if unwritten == expected_functions and (expected_error is None or expected_error in (error or "")):
            print("✔ Test passed")
        else:
            print(f"✘ Test failed: Expected functions {expected_functions} and error '{expected_error}', "
                  f"but got functions {unwritten} and error '{error}'")


test_check_for_unwritten_function()

Test Case 1: Input: 'def convert_to_string(list):\n    pass'
✔ Test passed
Test Case 2: Input: 'def empty_function():\n    # This is a placeholder'
✔ Test passed
Test Case 3: Input: "def valid_function():\n    # Initial comment\n    print('Hello, world!')"
✔ Test passed
Test Case 4: Input: "def func_with_pass():\n    pass\n\ndef func_with_body():\n    print('I do stuff')\n\ndef empty_func():"
✔ Test passed
Test Case 5: Input: 'nums = [1, 2, 3, 4, 5]'
✔ Test passed
Test Case 6: Input: 'def todo_func():\n    # TODO'
✔ Test passed
Test Case 7: Input: 'def comment_only_func():\n    #'
✔ Test passed


In [129]:
@tool("find_function_calls", return_direct=True)
def find_function_calls(input: str) -> str:
    """Finds all cells in the notebook where the unwritten function is called. Returns a list of cell content."""
    try:
        func_name = input
        with open("../test_data/test_book.ipynb", "r") as f:
            notebook_data = json.load(f)
            cells = notebook_data.get("cells", [])
            return json.dumps(
                [
                    "".join(cell["source"])
                    for cell in cells
                    if cell["cell_type"] == "code" and func_name in "".join(cell["source"])
                ]
            )
    except Exception as e:
        return f"Error searching for function calls: {str(e)}"


In [130]:
tools = [previous_cell, check_for_unwritten_function, find_function_calls]
tool_executer = ToolExecutor(tools)
functions = [format_tool_to_openai_function(t) for t in tools]
model = llm.bind_functions(functions)

  tool_executer = ToolExecutor(tools)


In [131]:
def agent(state):
    messages = state['messages']
    response = model.invoke(messages)
    return {"messages": [response]}

In [132]:
def should_continue(state):
    messages = state["messages"]
    last_message = messages[-1]
    if "function_call" not in last_message.additional_kwargs:
        return "end"
    else: 
        return "continue"

In [133]:
def call_tool(state):
    messages = state["messages"]
    last_message = messages[-1]

    action = ToolInvocation(
        tool = last_message.additional_kwargs["function_call"]["name"],
        tool_input = json.loads(last_message.additional_kwargs["function_call"]["arguments"]))

    response = tool_executer.invoke(action)
    function_message = FunctionMessage(content=str(response), name = action.tool)
    return {"messages": [function_message]}

In [134]:
workflow = StateGraph(AgentState)

workflow.add_node("agent", agent)
workflow.add_node("action", call_tool)

workflow.set_entry_point("agent")

workflow.add_conditional_edges("agent", should_continue, { "continue": "action", "end": END})

workflow.add_edge("action", "agent")

app = workflow.compile(checkpointer = memory)

In [135]:
def fix_notebook_code_cells(input_path, output_path):
    with open(input_path, "r") as f:
        notebook_data = json.load(f)

    cells = notebook_data.get("cells", [])
    processed_cells = []

    for idx, cell in enumerate(cells):
        if cell["cell_type"] == "code":
            code = "".join(cell.get("source", "")).strip()
            if code:
                config = {"configurable": {"thread_id": "thread-1"}}

                system_message = SystemMessage(
                    content="You are an AI assistant for completing Python notebooks."
                )
                human_message = HumanMessage(
                    content=f"Cell {idx}: {code}\n\nCheck for unwritten functions. If there are, search for context to use to fill in the bodies of those functions. If not, return the original code cell."
                )
                
                inputs = {"messages": [system_message, human_message]}
                
                result = app.invoke(inputs, config)
                print("cell: ", idx)
                for output in app.stream(inputs, config):
                    for key, value in output.items():
                        print(f"Output from node: '{key}':")
                        print("---")
                        print(value)
                        print("\n---\n")
                
                fixed_code = result["messages"][-1].content
                cell["source"] = [fixed_code]
            
            processed_cells.append(cell)
        elif cell["cell_type"] == "markdown":
            processed_cells.append(cell)

    notebook_data["cells"] = processed_cells

    with open(output_path, "w") as f:
        json.dump(notebook_data, f, indent=4)

    print(f"Updated notebook saved to {output_path}")


In [136]:
fix_notebook_code_cells("../test_data/test_book.ipynb", "../test_data/fixed_test.ipynb")

  action = ToolInvocation(
  action = ToolInvocation(


cell:  1
Output from node: 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "code": "def convert_to_string(list): \\n    pass"\n}', 'name': 'check_for_unwritten_function'}}, response_metadata={'finish_reason': 'function_call'}, id='run-a80c42d4-63a2-464b-92c2-0785d6b47c98-0')]}

---

Output from node: 'action':
---
{'messages': [FunctionMessage(content='{"unwritten_functions": ["convert_to_string"], "error": null}', additional_kwargs={}, response_metadata={}, name='check_for_unwritten_function')]}

---



  action = ToolInvocation(


Output from node: 'agent':
---
{'messages': [AIMessage(content="The function `convert_to_string` is not yet written. Let's find where this function is called in the notebook.", additional_kwargs={'function_call': {'arguments': '{\n  "input": "convert_to_string"\n}', 'name': 'find_function_calls'}}, response_metadata={'finish_reason': 'function_call'}, id='run-2ef3fa12-2321-4eaa-bd2a-d5cba0da268c-0')]}

---

Output from node: 'action':
---
{'messages': [FunctionMessage(content='["def convert_to_string(list): \\n    pass", "nums = [1, 2, 3, 4, 5]\\nstrings = convert_to_string(nums)"]', additional_kwargs={}, response_metadata={}, name='find_function_calls')]}

---



  action = ToolInvocation(


Output from node: 'agent':
---
{'messages': [AIMessage(content='The function `convert_to_string` is called in the following cell:\n\n```python\nnums = [1, 2, 3, 4, 5]\nstrings = convert_to_string(nums)\n```', additional_kwargs={}, response_metadata={'finish_reason': 'stop'}, id='run-2bf1d852-0aa6-4237-a9cd-4b186a02f666-0')]}

---



  action = ToolInvocation(


cell:  2
Output from node: 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "code": "nums = [1, 2, 3, 4, 5]\\nstrings = convert_to_string(nums)"\n}', 'name': 'check_for_unwritten_function'}}, response_metadata={'finish_reason': 'function_call'}, id='run-37138a27-df6c-4513-a2e0-2ba62839d4cd-0')]}

---

Output from node: 'action':
---
{'messages': [FunctionMessage(content='{"unwritten_functions": [], "error": null}', additional_kwargs={}, response_metadata={}, name='check_for_unwritten_function')]}

---



  action = ToolInvocation(


Output from node: 'agent':
---
{'messages': [AIMessage(content='It seems like there are no unwritten functions in the code. Can you please provide more details or clarify your request?', additional_kwargs={}, response_metadata={'finish_reason': 'stop'}, id='run-9a06ce94-c7f9-4c6a-aafe-ddfaa54b00f2-0')]}

---

Updated notebook saved to ../test_data/fixed_test.ipynb
