In [69]:
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.tools import BaseTool, Tool, tool
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage, FunctionMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt import ToolInvocation
from langchain.tools import format_tool_to_openai_function
from typing import TypedDict, Sequence




import json

import os

load_dotenv()

api_key = os.getenv("API_KEY")

In [70]:
llm = ChatOpenAI(model="gpt-4", temperature=0)

# Memory saver for the workflow
memory = MemorySaver()
config = {"configurable": {"thread_id": "thread-1"}}


In [71]:
class AgentState(TypedDict):
    messages: Sequence[BaseMessage]

In [72]:
@tool("previous_cell", return_direct=True)
def previous_cell(input: str) -> str:
    """Fetch the content of the previous cell given its index."""
    try:
        idx = int(input)
        with open("../test_data/test_book.ipynb", "r") as f:
            notebook_data = json.load(f)
            cells = notebook_data.get("cells", [])
            if 0 <= idx < len(cells) and cells[idx]["cell_type"] == "code":
                return "".join(cells[idx]["source"]).strip()
            return "No code found or invalid index."
    except Exception as e:
        return f"Error fetching previous cell: {str(e)}"

In [73]:
@tool("check_for_unwritten_function", return_direct=True)
def check_for_unwritten_function(input: str) -> str:
    """
    Given code from a cell, inspect for any function without a body (e.g., 'pass' or empty).
    Return the name of the function, or "" if all functions are complete.
    """
    import ast
    try:
        tree = ast.parse(input)
        for node in tree.body:
            if isinstance(node, ast.FunctionDef) and isinstance(node.body[0], ast.Pass):
                return node.name  # Function name
        return ""  # No unwritten functions
    except Exception as e:
        return f"Error analyzing code: {str(e)}"


In [74]:
@tool("find_function_calls", return_direct=True)
def find_function_calls(input: str) -> str:
    """Find all cells in the notebook where the given function is called. Returns a list of cell content."""
    try:
        func_name, idx = input.split("|")
        idx = int(idx)
        with open("../test_data/test_book.ipynb", "r") as f:
            notebook_data = json.load(f)
            cells = notebook_data.get("cells", [])
            return json.dumps(
                [
                    "".join(cell["source"])
                    for cell in cells
                    if cell["cell_type"] == "code" and func_name in "".join(cell["source"])
                ]
            )
    except Exception as e:
        return f"Error searching for function calls: {str(e)}"


In [75]:
tools = [previous_cell, check_for_unwritten_function, find_function_calls]
tool_executer = ToolExecutor(tools)
functions = [format_tool_to_openai_function(t) for t in tools]
model = llm.bind_functions(functions)

  tool_executer = ToolExecutor(tools)


In [76]:
def agent(state):
    messages = state['messages']
    response = model.invoke(messages)
    return {"messages": [response]}

In [77]:
def should_continue(state):
    messages = state["messages"]
    last_message = messages[-1]
    if "function_call" not in last_message.additional_kwargs:
        return "end"
    else: 
        return "continue"

In [78]:
def call_tool(state):
    messages = state["messages"]
    last_message = messages[-1]

    action = ToolInvocation(
        tool = last_message.additional_kwargs["function_call"]["name"],
        tool_input = json.loads(last_message.additional_kwargs["function_call"]["arguments"]))

    response = tool_executer.invoke(action)
    function_message = FunctionMessage(content=str(response), name = action.tool)
    return {"messages": [function_message]}

In [79]:
workflow = StateGraph(AgentState)

workflow.add_node("agent", agent)
workflow.add_node("action", call_tool)

workflow.set_entry_point("agent")

workflow.add_conditional_edges("agent", should_continue, { "continue": "action", "end": END})

workflow.add_edge("action", "agent")

app = workflow.compile(checkpointer = memory)

In [82]:
def fix_notebook_code_cells(input_path, output_path):
    with open(input_path, "r") as f:
        notebook_data = json.load(f)

    cells = notebook_data.get("cells", [])
    processed_cells = []

    for idx, cell in enumerate(cells):
        if cell["cell_type"] == "code":
            code = "".join(cell.get("source", "")).strip()
            if code:
                system_message = SystemMessage(
                    content="You are an AI assistant for completing Python notebooks."
                )
                human_message = HumanMessage(
                    content=f"Cell {idx}: {code}\n\nCheck for unwritten functions and fix them."
                )
                
                inputs = {"messages": [system_message, human_message]}
                
                result = app.invoke(inputs)
                for output in app.stream(inputs, config):
                    for key, value in output.items():
                        print(f"Output from node: '{key}':")
                        print("---")
                        print(value)
                        print("\n---\n")
                
                fixed_code = result["messages"][-1].content
                cell["source"] = [fixed_code]
            
            processed_cells.append(cell)
        elif cell["cell_type"] == "markdown":
            processed_cells.append(cell)

    notebook_data["cells"] = processed_cells

    with open(output_path, "w") as f:
        json.dump(notebook_data, f, indent=4)

    print(f"Updated notebook saved to {output_path}")


In [83]:
fix_notebook_code_cells("../test_data/test_book.ipynb", "../test_data/fixed_test.ipynb")

ValueError: Checkpointer requires one or more of the following 'configurable' keys: ['thread_id', 'checkpoint_ns', 'checkpoint_id']