In [1]:
from dotenv import load_dotenv
load_dotenv()

%load_ext autoreload
%autoreload 2

In [14]:
from langchain_core.messages import BaseMessage
from langgraph.graph import MessagesState
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from typing_extensions import Optional, Annotated, List, Sequence, Dict
import operator

class AgentInputState(MessagesState):
    pass

class AgentState(MessagesState):
    """
    Main state for the full multi-agent research system.
    
    Extends MessagesState with additional fields for research coordination.
    Note: Some fields are duplicated across different state classes for proper
    state management between subgraphs and the main workflow.
    """

    # Input message generated from user conversation history
    input_message: List[Dict[str, str]]
    diagram_type: str
    mermaid_code: str
    n_entities: str
    entity_names: list[str] = []
    entity_details: list[str] = []
    n_edges: str
    edges: list[str] = []
    bit_ranges: list[str] = []

    # # Messages exchanged with the supervisor agent for coordination
    # supervisor_messages: Annotated[Sequence[BaseMessage], add_messages]
    # # Raw unprocessed research notes collected during the research phase
    # raw_notes: Annotated[list[str], operator.add] = []
    # # Processed and structured notes ready for report generation
    # notes: Annotated[list[str], operator.add] = []
    # # Final formatted research report
    # final_report: str

# ===== STRUCTURED OUTPUT SCHEMAS =====
class StartMessageState(BaseModel):
    """Schema for invoking the supervisor agent."""
    diagram_type: str = Field(
        description="Type of diagram provided by the user"
    )
class OCRStateEntities(BaseModel):
    """Schema for invoking the OCR agent."""
    n_entities: str = Field(
        description="Number of entities detected in the image"
    )
    entity_names: List[str] = Field(
        description="Names of entities detected in the image"
    )
    entity_details: List[str] = Field(
        description="Details of C4 entities detected in the image"
    )
class OCRStateEdges(BaseModel):
    """Schema for invoking the OCR agent."""
    n_edges: str = Field(
        description="Number of edges detected in the image"
    )
    edges: List[str] = Field(
        description="Labels of the edges detected in the image"
    )
    bit_ranges: List[str] = Field(
        description="Bit ranges of the headers detected in the image"
    )
class CodeAgentState(BaseModel):
    """Schema for invoking the Code agent."""
    mermaid_code: str = Field(
        description="Mermaid code generated by the agent"
    )

In [15]:
from datetime import datetime
from typing_extensions import Literal
import base64
from pathlib import Path

from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, AIMessage, get_buffer_string
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command

from prompts.get_diag_prompt import get_diag_prompt
from prompts.extract_entity_prompt import ocr_extract_entity
from prompts.extract_edges_prompt import ocr_extract_edges
from prompts.write_mermaid_prompt import write_mermaid
from prompts.mermaid_examples import mermaid_example

# Initialize model
# model = init_chat_model(model="openai:gpt-4.1", temperature=0.0)
model = init_chat_model(model="openai:gpt-4o-mini", temperature=0.0)


# ===== WORKFLOW NODES =====

def supervisor_node(state: AgentState) -> AgentState:
    """
    Supervisor agent node that reviews research progress and provides feedback.
    """
    # Combine all messages for context
    structured_output_model = model.with_structured_output(StartMessageState)
    response = structured_output_model.invoke([*state["messages"], 
        HumanMessage(content=get_diag_prompt)
    ])
    print("Determined Diagram Type:", response.diagram_type)
    
    return {
        "diagram_type": response.diagram_type
    }

def ocr_agent_entity(state: AgentState) -> AgentState:
    """
    OCR agent node that extracts text from an image.
    """
    structured_output_model = model.with_structured_output(OCRStateEntities)
    response = structured_output_model.invoke([*state["messages"], 
        HumanMessage(content=ocr_extract_entity[state["diagram_type"]])
    ])

    print("Extracted Number of Entity Names:", response.n_entities)
    response.entity_names = [x.strip() for x in response.entity_names if x.strip()]
    print("Extracted Entity Names:", response.entity_names)    
    return {
        "n_entities": response.n_entities,
        "entity_names": response.entity_names,
        "entity_details": response.entity_details
    }

def ocr_agent_edge(state: AgentState) -> AgentState:
    """
    OCR agent node that extracts edges and bit ranges.
    """
    structured_output_model = model.with_structured_output(OCRStateEdges)
    response = structured_output_model.invoke([*state["messages"], 
        HumanMessage(content=ocr_extract_edges[state["diagram_type"]].format(
            entity_names=(state["entity_names"])
        ))
    ])
    response.bit_ranges = [x.strip() for x in response.bit_ranges if x.strip()]
    response.edges = [x.strip() for x in response.edges if x.strip()]
    # print("Extracted Edge Labels:", response.edges)
    # print("Extracted Bit Ranges:", response.bit_ranges)    
    return {
        "edges": response.edges,
        "bit_ranges": response.bit_ranges
    }
def code_agent(state: AgentState) -> AgentState:
    """
    Code generation agent node that creates mermaid code from user request.
    """
    structured_output_model = model.with_structured_output(CodeAgentState)
    response = structured_output_model.invoke([*state["messages"][1:], 
        HumanMessage(content=write_mermaid.format(
            diagram_type=state["diagram_type"],
            examples=mermaid_example[state["diagram_type"]],
            entity_names=state["entity_names"],
            entity_details=state["entity_details"],
            edges=state["edges"],
            bit_ranges=state["bit_ranges"]  
        ))
    ])
    return {
        "mermaid_code": response.mermaid_code
    }


# Build the scoping workflow
teching_graph = StateGraph(AgentState, input_schema=AgentInputState)
teching_graph.add_node("supervisor_node", supervisor_node)
teching_graph.add_node("ocr_agent_entity", ocr_agent_entity)
teching_graph.add_node("ocr_agent_edge", ocr_agent_edge)
teching_graph.add_node("code_agent", code_agent)

teching_graph.add_edge(START, "supervisor_node")
teching_graph.add_edge("supervisor_node", "ocr_agent_entity")
teching_graph.add_edge("ocr_agent_entity", "ocr_agent_edge")
teching_graph.add_edge("ocr_agent_edge", "code_agent")
teching_graph.add_edge("code_agent", END)
teching_graph_workflow = teching_graph.compile()



In [16]:
img_path = Path("42.png")
if not img_path.exists():
    raise FileNotFoundError("42.jpg not found in the current working directory")

b64 = base64.b64encode(img_path.read_bytes()).decode("ascii")
data_uri = f"data:image/jpeg;base64,{b64}"

first_message = HumanMessage(content=[
    {"type": "text", "text": "Give mermaid code of the given diagram."},
    {"type": "image_url", "image_url": {"url": data_uri}}
])

# result = teching_graph_workflow.invoke({"messages": [HumanMessage(content="I want to research the best coffee shops in")]}, config=thread)
result = teching_graph_workflow.invoke({"messages": first_message})
print(result["mermaid_code"])

Determined Diagram Type: C4
Extracted Number of Entity Names: 4
Extracted Entity Names: ['Customer', 'Event Manager', 'Event Booking Database', 'Registration System']
```mermaid
C4Context

        Enterprise_Boundary(b1, "Event Management Boundary") {
            Person(Customer, "Customer", "Customer wants to book an event")
            Person(EventManager, "Event Manager", "Organizes the event")
            SystemDb(EventBookingDatabase, "Event Booking Database", "Stores customer booking info")
            System(RegistrationSystem, "Registration System", "A website takes info from customer")
    }
    Rel(Customer, RegistrationSystem, "Uses")
    Rel(Customer, EventBookingDatabase, "Uses")
    Rel(EventManager, EventBookingDatabase, "Uses")

 UpdateLayoutConfig($c4ShapeInRow="3", $c4BoundaryInRow="1")
```
